Submit streaming code and examples #42

Closed · 6 commits
192 changes: 192 additions & 0 deletions examples/streaming_decode_files.rs
@@ -0,0 +1,192 @@
use clap::{arg, Parser};
use sherpa_rs::common_config::FeatureConfig;
use sherpa_rs::recognizer::online_recognizer::{
OnlineCtcFstDecoderConfig, OnlineModelConfig, OnlineParaformerModelConfig, OnlineRecognizer,
OnlineRecognizerConfig, OnlineTransducerModelConfig, OnlineZipformer2CtcModelConfig,
};
use sherpa_rs::stream::online_stream::OnlineStream;

/// Streaming speech recognition: decode a wave file with a streaming (online) model
#[derive(Parser, Debug)]
#[command(version, about, long_about = None)]
struct Args {
/// Path to the wave file to decode
wave_file: String,

/// Path to the transducer encoder model
#[arg(long, default_value = "")]
encoder: String,

/// Path to the transducer decoder model
#[arg(long, default_value = "")]
decoder: String,

/// Path to the transducer joiner model
#[arg(long, default_value = "")]
joiner: String,

/// Path to the paraformer encoder model
#[arg(long, default_value = "")]
paraformer_encoder: String,

/// Path to the paraformer decoder model
#[arg(long, default_value = "")]
paraformer_decoder: String,

/// Path to the zipformer2 CTC model
#[arg(long, default_value = "")]
zipformer2_ctc: String,

/// Path to the tokens file
#[arg(long, default_value = "")]
tokens: String,

/// Number of threads for computing
#[arg(long, default_value = "1")]
num_threads: i32,

/// Whether to show debug messages (1 to enable)
#[arg(long, default_value = "0")]
debug: i32,

/// Optional. Specify it to speed up model loading
#[arg(long)]
model_type: Option<String>,

/// Provider to use
#[arg(long, default_value = "cpu")]
provider: String,

/// Decoding method. Possible values: greedy_search, modified_beam_search
#[arg(long, default_value = "greedy_search")]
decoding_method: String,

/// Used only when --decoding-method is modified_beam_search
#[arg(long, default_value = "4")]
max_active_paths: i32,

/// If not empty, path to rule fst for inverse text normalization
#[arg(long, default_value = "")]
rule_fsts: String,

/// If not empty, path to rule fst archives for inverse text normalization
#[arg(long, default_value = "")]
rule_fars: String,
}

fn main() {
// Parse command-line arguments into `Args` struct
let args = Args::parse();

println!("Reading {}", args.wave_file);

let (samples, sample_rate) = read_wave(&args.wave_file);

println!("Initializing recognizer (may take several seconds)");
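// Unused model configs stay as empty strings; only the model whose paths are set is loaded.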
let config = OnlineRecognizerConfig {
feat_config: FeatureConfig {
sample_rate: 16000,
feature_dim: 80,
},
model_config: OnlineModelConfig {
transducer: OnlineTransducerModelConfig {
encoder: args.encoder,
decoder: args.decoder,
joiner: args.joiner,
},
paraformer: OnlineParaformerModelConfig {
encoder: args.paraformer_encoder,
decoder: args.paraformer_decoder,
},
zipformer2_ctc: OnlineZipformer2CtcModelConfig {
model: args.zipformer2_ctc,
},
tokens: args.tokens,
num_threads: args.num_threads,
provider: args.provider,
debug: args.debug,
model_type: args.model_type,
modeling_unit: None,
bpe_vocab: None,
tokens_buf: None,
tokens_buf_size: None,
},
decoding_method: args.decoding_method,
max_active_paths: args.max_active_paths,
enable_endpoint: 0,
rule1_min_trailing_silence: 0.0,
rule2_min_trailing_silence: 0.0,
rule3_min_utterance_length: 0.0,
hotwords_file: "".to_string(),
hotwords_score: 0.0,
blank_penalty: 0.0,
ctc_fst_decoder_config: OnlineCtcFstDecoderConfig {
graph: "".to_string(),
max_active: 0,
},
rule_fsts: args.rule_fsts,
rule_fars: args.rule_fars,
hotwords_buf: "".to_string(),
hotwords_buf_size: 0,
};

let recognizer = OnlineRecognizer::new(&config);
println!("Recognizer created!");

println!("Start decoding!");
let stream = OnlineStream::new(&recognizer);

stream.accept_waveform(sample_rate, &samples);

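// Append ~0.3 s of trailing silence so the model flushes its final frames.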
let tail_padding = vec![0.0; (sample_rate as f32 * 0.3) as usize];
stream.accept_waveform(sample_rate, &tail_padding);

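// Decode while the recognizer has enough buffered feature frames.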
while recognizer.is_ready(&stream) {
recognizer.decode(&stream);
}
println!("Decoding done!");

let result = recognizer.get_result(&stream);
println!("{}", result.text.to_lowercase());
println!(
"Wave duration: {} seconds",
samples.len() as f32 / sample_rate as f32
);
}

/// Reads a WAV file and returns its samples and sample rate
///
/// # Arguments
///
/// * `filename` - Path to the WAV file
///
/// # Returns
///
/// * `samples` - Audio samples normalized to the range [-1.0, 1.0)
/// * `sample_rate` - Sample rate in Hz
fn read_wave(filename: &str) -> (Vec<f32>, i32) {
let mut reader = hound::WavReader::open(filename).expect("Failed to open WAV file");
let spec = reader.spec();

if spec.sample_format != hound::SampleFormat::Int {
panic!("Only PCM (integer) format is supported. Given: {:?}", spec.sample_format);
}

if spec.channels != 1 {
panic!("Only single-channel wave files are supported. Given: {}", spec.channels);
}

if spec.bits_per_sample != 16 {
panic!(
"Only 16-bit samples are supported. Given: {}",
spec.bits_per_sample
);
}

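// Normalize 16-bit PCM samples to f32 in the range [-1.0, 1.0).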
let samples: Vec<f32> = reader
.samples::<i16>()
.map(|s| s.unwrap() as f32 / 32768.0)
.collect();

(samples, spec.sample_rate as i32)
}
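A note for review: the example feeds the entire file into the stream at once, which works but does not show the incremental behavior a streaming model is built for. Below is a minimal sketch that feeds 100 ms chunks and polls partial results after each one, the way a microphone source would; it assumes `accept_waveform` accepts any `&[f32]` slice (as the calls above suggest), and the helper name and chunk size are illustrative, not part of this PR. The example itself should run with something like `cargo run --example streaming_decode_files -- --encoder encoder.onnx --decoder decoder.onnx --joiner joiner.onnx --tokens tokens.txt test.wav` (model paths illustrative).

use sherpa_rs::recognizer::online_recognizer::OnlineRecognizer;
use sherpa_rs::stream::online_stream::OnlineStream;

/// Hypothetical helper (not in this PR): feed audio in ~100 ms chunks and
/// print the growing partial result, mimicking live capture.
fn decode_in_chunks(recognizer: &OnlineRecognizer, samples: &[f32], sample_rate: i32) {
    let stream = OnlineStream::new(recognizer);
    let chunk = (sample_rate as usize) / 10; // 100 ms of samples per chunk
    for piece in samples.chunks(chunk) {
        stream.accept_waveform(sample_rate, piece);
        // Same drain loop as in the example above.
        while recognizer.is_ready(&stream) {
            recognizer.decode(&stream);
        }
        println!("partial: {}", recognizer.get_result(&stream).text);
    }
}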
9 changes: 9 additions & 0 deletions src/common_config.rs
@@ -0,0 +1,9 @@
/// Configuration for the feature extractor
pub struct FeatureConfig {
/// Sample rate expected by the model. It is 16000 for all
/// pre-trained models provided by us
pub sample_rate: i32,
/// Feature dimension expected by the model. It is 80 for all
/// pre-trained models provided by us
pub feature_dim: i32,
}
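Per the doc comments, every pre-trained model referenced here expects 16 kHz audio and 80-dimensional features, so the struct is typically filled in exactly as the streaming example above does. A tiny sketch (the function name is illustrative):

use sherpa_rs::common_config::FeatureConfig;

/// The values documented above: 16 kHz input, 80-dimensional features.
fn default_feat_config() -> FeatureConfig {
    FeatureConfig {
        sample_rate: 16000,
        feature_dim: 80,
    }
}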
3 changes: 3 additions & 0 deletions src/lib.rs
@@ -1,10 +1,13 @@
pub mod audio_tag;
pub mod common_config;
pub mod diarize;
pub mod embedding_manager;
pub mod keyword_spot;
pub mod language_id;
pub mod punctuate;
pub mod recognizer;
pub mod speaker_id;
pub mod stream;
pub mod vad;
pub mod whisper;
pub mod zipformer;
40 changes: 40 additions & 0 deletions src/recognizer.rs
@@ -0,0 +1,40 @@
//! Speech recognition with [Next-gen Kaldi].
//!
//! [sherpa-onnx] is an open-source speech recognition framework for [Next-gen Kaldi].
//! It depends only on [onnxruntime], supporting both streaming and non-streaming
//! speech recognition.
//!
//! It does not need to access the network during recognition and everything
//! runs locally.
//!
//! It supports a variety of platforms, such as Linux (x86_64, aarch64, arm),
//! Windows (x86_64, x86), macOS (x86_64, arm64), etc.
//!
//! Usage examples:
//!
//! 1. Real-time speech recognition from a microphone
//!
//! Please see
//! https://github.com/k2-fsa/sherpa-onnx/tree/master/go-api-examples/real-time-speech-recognition-from-microphone
//!
//! 2. Decode files using a non-streaming model
//!
//! Please see
//! https://github.com/k2-fsa/sherpa-onnx/tree/master/go-api-examples/non-streaming-decode-files
//!
//! 3. Decode files using a streaming model
//!
//! Please see
//! https://github.com/k2-fsa/sherpa-onnx/tree/master/go-api-examples/streaming-decode-files
//!
//! 4. Convert text to speech using a non-streaming model
//!
//! Please see
//! https://github.com/k2-fsa/sherpa-onnx/tree/master/go-api-examples/non-streaming-tts
//!
//! [sherpa-onnx]: https://github.com/k2-fsa/sherpa-onnx
//! [onnxruntime]: https://github.com/microsoft/onnxruntime
//! [Next-gen Kaldi]: https://github.com/k2-fsa/

pub mod offline_recognizer;
pub mod online_recognizer;
114 changes: 114 additions & 0 deletions src/recognizer/offline_recognizer.rs
@@ -0,0 +1,114 @@
use crate::common_config::FeatureConfig;
use sherpa_rs_sys::SherpaOnnxOfflineRecognizer;

/// Configuration for offline/non-streaming transducer.
///
/// Please refer to
/// https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-transducer/index.html
/// to download pre-trained models
pub struct OfflineTransducerModelConfig {
pub encoder: String, // Path to the encoder model, e.g., encoder.onnx or encoder.int8.onnx
pub decoder: String, // Path to the decoder model
pub joiner: String, // Path to the joiner model
}

/// Configuration for offline/non-streaming paraformer.
///
/// Please refer to
/// https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-paraformer/index.html
/// to download pre-trained models
pub struct OfflineParaformerModelConfig {
pub model: String, // Path to the model, e.g., model.onnx or model.int8.onnx
}

/// Configuration for offline/non-streaming NeMo CTC models.
///
/// Please refer to
/// https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-ctc/index.html
/// to download pre-trained models
pub struct OfflineNemoEncDecCtcModelConfig {
pub model: String, // Path to the model, e.g., model.onnx or model.int8.onnx
}

/// Configuration for offline/non-streaming Whisper models.
pub struct OfflineWhisperModelConfig {
pub encoder: String,
pub decoder: String,
pub language: String,
pub task: String,
pub tail_paddings: i32,
}

/// Configuration for offline TDNN models.
pub struct OfflineTdnnModelConfig {
pub model: String,
}

/// Configuration for offline SenseVoice models.
pub struct OfflineSenseVoiceModelConfig {
pub model: String,
pub language: String,
pub use_inverse_text_normalization: i32,
}

/// Configuration for offline LM.
pub struct OfflineLMConfig {
pub model: String, // Path to the model
pub scale: f32, // Scale for LM scores
}

/// Configuration for offline/non-streaming models.
pub struct OfflineModelConfig {
pub transducer: OfflineTransducerModelConfig,
pub paraformer: OfflineParaformerModelConfig,
pub nemo_ctc: OfflineNemoEncDecCtcModelConfig,
pub whisper: OfflineWhisperModelConfig,
pub tdnn: OfflineTdnnModelConfig,
pub sense_voice: OfflineSenseVoiceModelConfig,
pub tokens: String, // Path to tokens.txt

// Number of threads to use for neural network computation
pub num_threads: i32,

// 1 to print model meta information while loading
pub debug: i32,

// Optional. Valid values: cpu, cuda, coreml
pub provider: String,

// Optional. Specify it for faster model initialization
pub model_type: String,

pub modeling_unit: String, // Optional. cjkchar, bpe, cjkchar+bpe
pub bpe_vocab: String, // Optional.
pub tele_speech_ctc: String, // Optional.
}

/// Configuration for the offline/non-streaming recognizer.
pub struct OfflineRecognizerConfig {
pub feat_config: FeatureConfig,
pub model_config: OfflineModelConfig,
pub lm_config: OfflineLMConfig,

// Valid decoding methods: greedy_search, modified_beam_search
pub decoding_method: String,

// Used only when decoding_method is modified_beam_search
pub max_active_paths: i32,
pub hotwords_file: String,
pub hotwords_score: f32,
pub blank_penalty: f32,
pub rule_fsts: String,
pub rule_fars: String,
}

/// Wraps a pointer to the underlying C recognizer.
pub struct OfflineRecognizer {
pointer: *const SherpaOnnxOfflineRecognizer,
}

/// Recognition result for an offline stream.
pub struct OfflineRecognizerResult {
pub text: String,
pub tokens: Vec<String>,
pub timestamps: Vec<f32>,
pub lang: String,
pub emotion: String,
pub event: String,
}
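This file only adds the types; no constructor or decode methods are included in the diff. If the offline API ends up mirroring the online one from the example, usage would presumably look like the sketch below. `OfflineRecognizer::new`, `create_stream`, `decode`, and `get_result` are assumptions for illustration, not functions added by this PR.

// Hypothetical usage sketch; none of these methods exist in this diff yet.
fn transcribe(config: &OfflineRecognizerConfig, samples: &[f32], sample_rate: i32) -> String {
    let recognizer = OfflineRecognizer::new(config); // assumed constructor
    let stream = recognizer.create_stream(); // assumed helper for offline streams
    stream.accept_waveform(sample_rate, samples);
    recognizer.decode(&stream); // offline models decode in one shot
    recognizer.get_result(&stream).text
}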