-
Notifications
You must be signed in to change notification settings - Fork 12
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Submit streaming code and examples #43
Changes from 6 commits
a891216
ccf78a5
6e40a9c
9d7c13d
8964974
18224b0
084d09f
ac71ac5
2633bef
25d3dc9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,192 @@ | ||
use clap::{arg, Parser}; | ||
use sherpa_rs::common_config::FeatureConfig; | ||
use sherpa_rs::recognizer::online_recognizer::{ | ||
OnlineCtcFstDecoderConfig, OnlineModelConfig, OnlineParaformerModelConfig, OnlineRecognizer, | ||
OnlineRecognizerConfig, OnlineTransducerModelConfig, OnlineZipformer2CtcModelConfig, | ||
}; | ||
use sherpa_rs::stream::online_stream::OnlineStream; | ||
|
||
/// Streaming | ||
#[derive(Parser, Debug)] | ||
#[command(version, about, long_about = None)] | ||
struct Args { | ||
/// Please provide one wave file | ||
wave_file: String, | ||
|
||
/// Path to the transducer encoder model | ||
#[arg(long, default_value = "")] | ||
encoder: String, | ||
|
||
/// Path to the transducer decoder model | ||
#[arg(long, default_value = "")] | ||
decoder: String, | ||
|
||
/// Path to the transducer joiner model | ||
#[arg(long, default_value = "")] | ||
joiner: String, | ||
|
||
/// Path to the paraformer encoder model | ||
#[arg(long, default_value = "")] | ||
paraformer_encoder: String, | ||
|
||
/// Path to the paraformer decoder model | ||
#[arg(long, default_value = "")] | ||
paraformer_decoder: String, | ||
|
||
/// Path to the zipformer2 CTC model | ||
#[arg(long, default_value = "")] | ||
zipformer2_ctc: String, | ||
|
||
/// Path to the tokens file | ||
#[arg(long, default_value = "")] | ||
tokens: String, | ||
|
||
/// Number of threads for computing | ||
#[arg(long, default_value = "1")] | ||
num_threads: i32, | ||
|
||
/// Whether to show debug message | ||
#[arg(long, default_value = "0")] | ||
debug: i32, | ||
|
||
/// Optional. Used for loading the model in a faster way | ||
#[arg(long)] | ||
model_type: Option<String>, | ||
|
||
/// Provider to use | ||
#[arg(long, default_value = "cpu")] | ||
provider: String, | ||
|
||
/// Decoding method. Possible values: greedy_search, modified_beam_search | ||
#[arg(long, default_value = "greedy_search")] | ||
decoding_method: String, | ||
|
||
/// Used only when --decoding-method is modified_beam_search | ||
#[arg(long, default_value = "4")] | ||
max_active_paths: i32, | ||
|
||
/// If not empty, path to rule fst for inverse text normalization | ||
#[arg(long, default_value = "")] | ||
rule_fsts: String, | ||
|
||
/// If not empty, path to rule fst archives for inverse text normalization | ||
#[arg(long, default_value = "")] | ||
rule_fars: String, | ||
} | ||
|
||
fn main() {
    // Parse command-line arguments into `Args` struct
    let args = Args::parse();

    println!("Reading {}", args.wave_file);

    let (samples, sample_rate) = read_wave(&args.wave_file);

    println!("Initializing recognizer (may take several seconds)");
    let config = OnlineRecognizerConfig {
        // 16 kHz / 80-dim features — the values the pre-trained models
        // expect (see FeatureConfig docs).
        feat_config: FeatureConfig {
            sample_rate: 16000,
            feature_dim: 80,
        },
        model_config: OnlineModelConfig {
            // Model paths come straight from the CLI; unused families are
            // left as empty strings (the Args defaults). Presumably only
            // one family should be non-empty at a time — confirm against
            // sherpa-onnx behavior.
            transducer: OnlineTransducerModelConfig {
                encoder: args.encoder,
                decoder: args.decoder,
                joiner: args.joiner,
            },
            paraformer: OnlineParaformerModelConfig {
                encoder: args.paraformer_encoder,
                decoder: args.paraformer_decoder,
            },
            zipformer2_ctc: OnlineZipformer2CtcModelConfig {
                model: args.zipformer2_ctc,
            },
            tokens: args.tokens,
            num_threads: args.num_threads,
            provider: args.provider,
            debug: args.debug,
            model_type: args.model_type,
            modeling_unit: None,
            bpe_vocab: None,
            tokens_buf: None,
            tokens_buf_size: None,
        },
        decoding_method: args.decoding_method,
        max_active_paths: args.max_active_paths,
        // Endpoint detection disabled (0); the rule thresholds below are
        // left at zero.
        enable_endpoint: 0,
        rule1_min_trailing_silence: 0.0,
        rule2_min_trailing_silence: 0.0,
        rule3_min_utterance_length: 0.0,
        // No hotword boosting, no blank penalty, no CTC FST graph.
        hotwords_file: "".to_string(),
        hotwords_score: 0.0,
        blank_penalty: 0.0,
        ctc_fst_decoder_config: OnlineCtcFstDecoderConfig {
            graph: "".to_string(),
            max_active: 0,
        },
        rule_fsts: args.rule_fsts,
        rule_fars: args.rule_fars,
        hotwords_buf: "".to_string(),
        hotwords_buf_size: 0,
    };

    let recognizer = OnlineRecognizer::new(&config);
    println!("Recognizer created!");

    println!("Start decoding!");
    let stream = OnlineStream::new(&recognizer);

    // Feed the whole file into the stream.
    stream.accept_waveform(sample_rate, &samples);

    // Append 0.3 s of silence — presumably to flush the model's internal
    // buffers at the end of the utterance; TODO confirm.
    let tail_padding = vec![0.0; (sample_rate as f32 * 0.3) as usize];
    stream.accept_waveform(sample_rate, &tail_padding);

    // Drain the stream: keep decoding while enough frames are buffered.
    while recognizer.is_ready(&stream) {
        recognizer.decode(&stream);
    }
    println!("Decoding done!");

    let result = recognizer.get_result(&stream);
    println!("{}", result.text.to_lowercase());
    println!(
        "Wave duration: {} seconds",
        samples.len() as f32 / sample_rate as f32
    );
}
|
||
/// 读取 WAV 文件并返回样本和采样率 | ||
/// | ||
/// # 参数 | ||
/// | ||
/// * `filename` - WAV 文件的路径 | ||
/// | ||
/// # 返回 | ||
/// | ||
/// * `samples` - 样本数据 | ||
/// * `sample_rate` - 采样率 | ||
fn read_wave(filename: &str) -> (Vec<f32>, i32) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Use internal function for that There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
This is only for testing, and the code logic is not suitable for the framework There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's useful to keep the example short and simple as possible with only the relevant info to dive in quickly. That's why it's useful to use the internal function like in other examples and even i don't use clap in most of the examples when not necessarily There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This place is just to do safe packaging, not to do higher level packaging, and higher level to use those that were made before. The structs you made before could be structs that were wrapped with this safe. |
||
let mut reader = hound::WavReader::open(filename).expect("Failed to open WAV file"); | ||
let spec = reader.spec(); | ||
|
||
if spec.sample_format != hound::SampleFormat::Int { | ||
panic!("Support only PCM format. Given: {:?}", spec.sample_format); | ||
} | ||
|
||
if spec.channels != 1 { | ||
panic!("Support only 1 channel wave file. Given: {}", spec.channels); | ||
} | ||
|
||
if spec.bits_per_sample != 16 { | ||
panic!( | ||
"Support only 16-bit per sample. Given: {}", | ||
spec.bits_per_sample | ||
); | ||
} | ||
|
||
let samples: Vec<f32> = reader | ||
.samples::<i16>() | ||
.map(|s| s.unwrap() as f32 / 32768.0) | ||
.collect(); | ||
|
||
(samples, spec.sample_rate as i32) | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
/// Configuration for the feature extractor.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FeatureConfig {
    /// Sample rate expected by the model. It is 16000 for all
    /// pre-trained models provided by us
    pub sample_rate: i32,
    /// Feature dimension expected by the model. It is 80 for all
    /// pre-trained models provided by us
    pub feature_dim: i32,
}
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
//! Speech recognition with [Next-gen Kaldi].
//!
//! [sherpa-onnx] is an open-source speech recognition framework for [Next-gen Kaldi].
//! It depends only on [onnxruntime], supporting both streaming and non-streaming
//! speech recognition.
//!
//! It does not need to access the network during recognition and everything
//! runs locally.
//!
//! It supports a variety of platforms, such as Linux (x86_64, aarch64, arm),
//! Windows (x86_64, x86), macOS (x86_64, arm64), etc.
//!
//! Usage examples (the linked examples use the sherpa-onnx Go API, but the
//! call sequence is the same for these bindings):
//!
//! 1. Real-time speech recognition from a microphone
//!
//! Please see
//! https://github.com/k2-fsa/sherpa-onnx/tree/master/go-api-examples/real-time-speech-recognition-from-microphone
//!
//! 2. Decode files using a non-streaming model
//!
//! Please see
//! https://github.com/k2-fsa/sherpa-onnx/tree/master/go-api-examples/non-streaming-decode-files
//!
//! 3. Decode files using a streaming model
//!
//! Please see
//! https://github.com/k2-fsa/sherpa-onnx/tree/master/go-api-examples/streaming-decode-files
//!
//! 4. Convert text to speech using a non-streaming model
//!
//! Please see
//! https://github.com/k2-fsa/sherpa-onnx/tree/master/go-api-examples/non-streaming-tts
//!
//! [sherpa-onnx]: https://github.com/k2-fsa/sherpa-onnx
//! [onnxruntime]: https://github.com/microsoft/onnxruntime
//! [Next-gen Kaldi]: https://github.com/k2-fsa/

pub mod offline_recognizer;
pub mod online_recognizer;
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
use crate::common_config::FeatureConfig; | ||
use sherpa_rs_sys::SherpaOnnxOfflineRecognizer; | ||
|
||
/// Configuration for offline/non-streaming transducer.
///
/// Please refer to
/// https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-transducer/index.html
/// to download pre-trained models
pub struct OfflineTransducerModelConfig {
    /// Path to the encoder model, i.e., encoder.onnx or encoder.int8.onnx
    pub encoder: String,
    /// Path to the decoder model
    pub decoder: String,
    /// Path to the joiner model
    pub joiner: String,
}
|
||
/// Configuration for offline/non-streaming paraformer.
///
/// Please refer to
/// https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-paraformer/index.html
/// to download pre-trained models
pub struct OfflineParaformerModelConfig {
    /// Path to the model, e.g., model.onnx or model.int8.onnx
    pub model: String,
}
|
||
/// Configuration for offline/non-streaming NeMo CTC models.
///
/// Please refer to
/// https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-ctc/index.html
/// to download pre-trained models
pub struct OfflineNemoEncDecCtcModelConfig {
    /// Path to the model, e.g., model.onnx or model.int8.onnx
    pub model: String,
}
|
||
/// Configuration for offline/non-streaming Whisper models.
pub struct OfflineWhisperModelConfig {
    /// Path to the encoder model
    pub encoder: String,
    /// Path to the decoder model
    pub decoder: String,
    // NOTE(review): semantics of the fields below are not visible here —
    // confirm against the sherpa-onnx Whisper documentation.
    pub language: String,
    pub task: String,
    pub tail_paddings: i32,
}
|
||
/// Configuration for offline/non-streaming TDNN models.
pub struct OfflineTdnnModelConfig {
    /// Path to the model, e.g., model.onnx or model.int8.onnx
    pub model: String,
}
|
||
/// Configuration for offline/non-streaming SenseVoice models.
pub struct OfflineSenseVoiceModelConfig {
    /// Path to the model
    pub model: String,
    // NOTE(review): accepted language values are not visible here —
    // confirm against the sherpa-onnx SenseVoice documentation.
    pub language: String,
    /// 0/1 flag controlling inverse text normalization
    pub use_inverse_text_normalization: i32,
}
|
||
/// Configuration for offline LM.
pub struct OfflineLMConfig {
    /// Path to the model
    pub model: String,
    /// Scale for LM score
    pub scale: f32,
}
|
||
struct OfflineModelConfig { | ||
transducer: OfflineTransducerModelConfig, | ||
paraformer: OfflineParaformerModelConfig, | ||
nemo_ctc: OfflineNemoEncDecCtcModelConfig, | ||
whisper: OfflineWhisperModelConfig, | ||
tdnn: OfflineTdnnModelConfig, | ||
sense_voice: OfflineSenseVoiceModelConfig, | ||
tokens: String, // Path to tokens.txt | ||
|
||
// Number of threads to use for neural network computation | ||
num_threads: i32, | ||
|
||
// 1 to print model meta information while loading | ||
debug: i32, | ||
|
||
// Optional. Valid values: cpu, cuda, coreml | ||
provider: String, | ||
|
||
// Optional. Specify it for faster model initialization. | ||
model_type: String, | ||
|
||
modeling_unit: String, // Optional. cjkchar, bpe, cjkchar+bpe | ||
bpe_vocab: String, // Optional. | ||
tele_speech_ctc: String, // Optional. | ||
} | ||
|
||
/// Configuration for the offline/non-streaming recognizer. | ||
struct OfflineRecognizerConfig { | ||
feat_config: FeatureConfig, | ||
model_config: OfflineModelConfig, | ||
lm_config: OfflineLMConfig, | ||
|
||
// Valid decoding method: greedy_search, modified_beam_search | ||
decoding_method: String, | ||
|
||
// Used only when DecodingMethod is modified_beam_search. | ||
max_active_paths: i32, | ||
hotwords_file: String, | ||
hotwords_score: f32, | ||
blank_penalty: f32, | ||
rule_fsts: String, | ||
rule_fars: String, | ||
} | ||
|
||
/// It wraps a pointer from C | ||
struct OfflineRecognizer { | ||
pointer: *const SherpaOnnxOfflineRecognizer, | ||
} | ||
|
||
/// It contains recognition result of an offline stream.
pub struct OfflineRecognizerResult {
    /// The recognized text
    pub text: String,
    /// The decoded tokens
    pub tokens: Vec<String>,
    // NOTE(review): timestamp units (presumably seconds per token) are
    // not visible here — confirm against the sherpa-onnx C API docs.
    pub timestamps: Vec<f32>,
    pub lang: String,
    pub emotion: String,
    pub event: String,
}
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Simplify configs with impl default like in other examples
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Rust's `Default` implies that a meaningful default case exists, but this configuration obviously doesn't have one. The best practice would be to provide a series of `new` methods, such as `new_transducer`, `new_paraformer`, etc.; for `Default`, there really isn't a sensible default standard here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree, but even using `new` methods for the structs would be too verbose given the many irrelevant fields. So maybe it's worth using `Default`; the examples will show how to use it correctly, and sherpa-onnx will throw an error with diagnostic info when a configuration is invalid.