
Submit streaming code and examples #43

Closed
wants to merge 10 commits
192 changes: 192 additions & 0 deletions examples/streaming_decode_files.rs
@@ -0,0 +1,192 @@
use clap::{arg, Parser};
use sherpa_rs::common_config::FeatureConfig;
use sherpa_rs::recognizer::online_recognizer::{
    OnlineCtcFstDecoderConfig, OnlineModelConfig, OnlineParaformerModelConfig, OnlineRecognizer,
    OnlineRecognizerConfig, OnlineTransducerModelConfig, OnlineZipformer2CtcModelConfig,
};
use sherpa_rs::stream::online_stream::OnlineStream;

/// Streaming speech recognition: decode wave files with a streaming model
#[derive(Parser, Debug)]
#[command(version, about, long_about = None)]
struct Args {
    /// Path to the wave file to decode
    wave_file: String,

    /// Path to the transducer encoder model
    #[arg(long, default_value = "")]
    encoder: String,

    /// Path to the transducer decoder model
    #[arg(long, default_value = "")]
    decoder: String,

    /// Path to the transducer joiner model
    #[arg(long, default_value = "")]
    joiner: String,

    /// Path to the paraformer encoder model
    #[arg(long, default_value = "")]
    paraformer_encoder: String,

    /// Path to the paraformer decoder model
    #[arg(long, default_value = "")]
    paraformer_decoder: String,

    /// Path to the zipformer2 CTC model
    #[arg(long, default_value = "")]
    zipformer2_ctc: String,

    /// Path to the tokens file
    #[arg(long, default_value = "")]
    tokens: String,

    /// Number of threads for computing
    #[arg(long, default_value = "1")]
    num_threads: i32,

    /// Whether to show debug messages
    #[arg(long, default_value = "0")]
    debug: i32,

    /// Optional. Used for loading the model in a faster way
    #[arg(long)]
    model_type: Option<String>,

    /// Provider to use
    #[arg(long, default_value = "cpu")]
    provider: String,

    /// Decoding method. Possible values: greedy_search, modified_beam_search
    #[arg(long, default_value = "greedy_search")]
    decoding_method: String,

    /// Used only when --decoding-method is modified_beam_search
    #[arg(long, default_value = "4")]
    max_active_paths: i32,

    /// If not empty, path to rule FST for inverse text normalization
    #[arg(long, default_value = "")]
    rule_fsts: String,

    /// If not empty, path to rule FST archives for inverse text normalization
    #[arg(long, default_value = "")]
    rule_fars: String,
}

fn main() {
    // Parse command-line arguments into the `Args` struct
    let args = Args::parse();

    println!("Reading {}", args.wave_file);

    let (samples, sample_rate) = read_wave(&args.wave_file);

    println!("Initializing recognizer (may take several seconds)");
    let config = OnlineRecognizerConfig {
Owner:

Simplify configs with impl Default like in other examples.

Contributor Author:

Rust's Default implies there is a sensible default case, but this configuration obviously doesn't have one. The better practice would be to provide a series of constructors such as new_transducer, new_paraformer, etc. As for Default, there really isn't a particular default standard here.
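
For illustration, a minimal sketch of the constructor approach argued for here, using the field set visible in this diff (new_transducer is hypothetical and not part of this PR):

```rust
// Hypothetical constructor (not in this PR): fills only the
// transducer-related fields and leaves everything else neutral,
// so no Default impl is needed.
impl OnlineModelConfig {
    pub fn new_transducer(encoder: String, decoder: String, joiner: String, tokens: String) -> Self {
        Self {
            transducer: OnlineTransducerModelConfig { encoder, decoder, joiner },
            paraformer: OnlineParaformerModelConfig {
                encoder: String::new(),
                decoder: String::new(),
            },
            zipformer2_ctc: OnlineZipformer2CtcModelConfig { model: String::new() },
            tokens,
            num_threads: 1,
            provider: "cpu".to_string(),
            debug: 0,
            model_type: None,
            modeling_unit: None,
            bpe_vocab: None,
            tokens_buf: None,
            tokens_buf_size: None,
        }
    }
}
```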

Owner:

I agree, but even new methods for the structs would be too verbose given all the irrelevant fields. So maybe it's worth using Default: the examples will show how to use it correctly, and sherpa-onnx will throw an error with info when a config is wrong.
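
And a sketch of the Default-based style suggested above, assuming #[derive(Default)] were added to the config structs (which is not the case in this PR):

```rust
// Assumes the config structs derive Default (not the case in this PR).
// Struct-update syntax then keeps the example down to the relevant fields,
// and sherpa-onnx reports a config error if something required is missing.
let config = OnlineRecognizerConfig {
    feat_config: FeatureConfig { sample_rate: 16000, feature_dim: 80 },
    model_config: OnlineModelConfig {
        transducer: OnlineTransducerModelConfig {
            encoder: args.encoder,
            decoder: args.decoder,
            joiner: args.joiner,
        },
        tokens: args.tokens,
        ..Default::default()
    },
    decoding_method: "greedy_search".to_string(),
    ..Default::default()
};
```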

        feat_config: FeatureConfig {
            sample_rate: 16000,
            feature_dim: 80,
        },
        model_config: OnlineModelConfig {
            transducer: OnlineTransducerModelConfig {
                encoder: args.encoder,
                decoder: args.decoder,
                joiner: args.joiner,
            },
            paraformer: OnlineParaformerModelConfig {
                encoder: args.paraformer_encoder,
                decoder: args.paraformer_decoder,
            },
            zipformer2_ctc: OnlineZipformer2CtcModelConfig {
                model: args.zipformer2_ctc,
            },
            tokens: args.tokens,
            num_threads: args.num_threads,
            provider: args.provider,
            debug: args.debug,
            model_type: args.model_type,
            modeling_unit: None,
            bpe_vocab: None,
            tokens_buf: None,
            tokens_buf_size: None,
        },
        decoding_method: args.decoding_method,
        max_active_paths: args.max_active_paths,
        enable_endpoint: 0,
        rule1_min_trailing_silence: 0.0,
        rule2_min_trailing_silence: 0.0,
        rule3_min_utterance_length: 0.0,
        hotwords_file: "".to_string(),
        hotwords_score: 0.0,
        blank_penalty: 0.0,
        ctc_fst_decoder_config: OnlineCtcFstDecoderConfig {
            graph: "".to_string(),
            max_active: 0,
        },
        rule_fsts: args.rule_fsts,
        rule_fars: args.rule_fars,
        hotwords_buf: "".to_string(),
        hotwords_buf_size: 0,
    };

    let recognizer = OnlineRecognizer::new(&config);
    println!("Recognizer created!");

    println!("Start decoding!");
    let stream = OnlineStream::new(&recognizer);

    stream.accept_waveform(sample_rate, &samples);

    // Append 0.3 seconds of silence so the final words are flushed out
    let tail_padding = vec![0.0; (sample_rate as f32 * 0.3) as usize];
    stream.accept_waveform(sample_rate, &tail_padding);

    while recognizer.is_ready(&stream) {
        recognizer.decode(&stream);
    }
    println!("Decoding done!");

    let result = recognizer.get_result(&stream);
    println!("{}", result.text.to_lowercase());
    println!(
        "Wave duration: {} seconds",
        samples.len() as f32 / sample_rate as f32
    );
}
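
For reference, with a downloaded streaming transducer model (the file names here are illustrative), this example would be invoked via cargo's example runner: `cargo run --example streaming_decode_files -- --encoder encoder.onnx --decoder decoder.onnx --joiner joiner.onnx --tokens tokens.txt test.wav`.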

/// Reads a WAV file and returns its samples and sample rate
///
/// # Arguments
///
/// * `filename` - Path to the WAV file
///
/// # Returns
///
/// * `samples` - The sample data
/// * `sample_rate` - The sample rate
fn read_wave(filename: &str) -> (Vec<f32>, i32) {
Owner:

Use internal function for that

Contributor Author:

> Use internal function for that

This is only for testing, and the code logic is not suitable for the framework.

Owner:

It's useful to keep the example as short and simple as possible, with only the relevant info, so readers can dive in quickly. That's why it's worth using the internal function like in the other examples; I don't even use clap in most of the examples when it isn't necessary.

Contributor Author:

This module is only meant to do safe wrapping, not higher-level packaging; the higher level should keep using what was built before. The structs you made earlier could be implemented as wrappers around this safe layer.
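
To make the layering concrete, a sketch with illustrative names (SafeRecognizer and FriendlyRecognizer are hypothetical, not from this PR):

```rust
use sherpa_rs_sys::SherpaOnnxOfflineRecognizer;

// Layer 1 (this PR's scope): safe packaging only. Own the raw C pointer
// and expose safe methods. A Drop impl would release the pointer; see the
// sketch after OfflineRecognizerResult below.
pub struct SafeRecognizer {
    pointer: *const SherpaOnnxOfflineRecognizer,
}

// Layer 2 (the pre-existing high-level structs): ergonomics built on top
// of the safe layer, never touching the raw pointer directly.
pub struct FriendlyRecognizer {
    inner: SafeRecognizer,
}
```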

    let mut reader = hound::WavReader::open(filename).expect("Failed to open WAV file");
    let spec = reader.spec();

    if spec.sample_format != hound::SampleFormat::Int {
        panic!("Support only PCM format. Given: {:?}", spec.sample_format);
    }

    if spec.channels != 1 {
        panic!("Support only 1 channel wave file. Given: {}", spec.channels);
    }

    if spec.bits_per_sample != 16 {
        panic!(
            "Support only 16-bit per sample. Given: {}",
            spec.bits_per_sample
        );
    }

    // Normalize i16 PCM samples to f32 in [-1.0, 1.0)
    let samples: Vec<f32> = reader
        .samples::<i16>()
        .map(|s| s.unwrap() as f32 / 32768.0)
        .collect();

    (samples, spec.sample_rate as i32)
}
9 changes: 9 additions & 0 deletions src/common_config.rs
@@ -0,0 +1,9 @@
/// Configuration for the feature extractor
pub struct FeatureConfig {
    /// Sample rate expected by the model. It is 16000 for all
    /// pre-trained models provided by us
    pub sample_rate: i32,
    /// Feature dimension expected by the model. It is 80 for all
    /// pre-trained models provided by us
    pub feature_dim: i32,
}
3 changes: 3 additions & 0 deletions src/lib.rs
@@ -1,10 +1,13 @@
pub mod audio_tag;
pub mod common_config;
pub mod diarize;
pub mod embedding_manager;
pub mod keyword_spot;
pub mod language_id;
pub mod punctuate;
pub mod recognizer;
pub mod speaker_id;
pub mod stream;
pub mod vad;
pub mod whisper;
pub mod zipformer;
40 changes: 40 additions & 0 deletions src/recognizer.rs
@@ -0,0 +1,40 @@
//! Speech recognition with [Next-gen Kaldi].
//!
//! [sherpa-onnx] is an open-source speech recognition framework for [Next-gen Kaldi].
//! It depends only on [onnxruntime], supporting both streaming and non-streaming
//! speech recognition.
//!
//! It does not need to access the network during recognition and everything
//! runs locally.
//!
//! It supports a variety of platforms, such as Linux (x86_64, aarch64, arm),
//! Windows (x86_64, x86), macOS (x86_64, arm64), etc.
//!
//! Usage examples:
//!
//! 1. Real-time speech recognition from a microphone
//!
//! Please see
//! https://github.com/k2-fsa/sherpa-onnx/tree/master/go-api-examples/real-time-speech-recognition-from-microphone
//!
//! 2. Decode files using a non-streaming model
//!
//! Please see
//! https://github.com/k2-fsa/sherpa-onnx/tree/master/go-api-examples/non-streaming-decode-files
//!
//! 3. Decode files using a streaming model
//!
//! Please see
//! https://github.com/k2-fsa/sherpa-onnx/tree/master/go-api-examples/streaming-decode-files
//!
//! 4. Convert text to speech using a non-streaming model
//!
//! Please see
//! https://github.com/k2-fsa/sherpa-onnx/tree/master/go-api-examples/non-streaming-tts
//!
//! [sherpa-onnx]: https://github.com/k2-fsa/sherpa-onnx
//! [onnxruntime]: https://github.com/microsoft/onnxruntime
//! [Next-gen Kaldi]: https://github.com/k2-fsa/
Owner:

Remove all comments

Contributor Author:

Ok


pub mod offline_recognizer;
pub mod online_recognizer;
114 changes: 114 additions & 0 deletions src/recognizer/offline_recognizer.rs
@@ -0,0 +1,114 @@
use crate::common_config::FeatureConfig;
use sherpa_rs_sys::SherpaOnnxOfflineRecognizer;

/// Configuration for offline/non-streaming transducer.
///
/// Please refer to
/// https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-transducer/index.html
/// to download pre-trained models
struct OfflineTransducerModelConfig {
    encoder: String, // Path to the encoder model, i.e., encoder.onnx or encoder.int8.onnx
    decoder: String, // Path to the decoder model
    joiner: String,  // Path to the joiner model
}

/// Configuration for offline/non-streaming paraformer.
///
/// Please refer to
/// https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-paraformer/index.html
/// to download pre-trained models
Owner:

Remove most of the comments in the files. You can add a link to the specific sherpa docs at the top of the file.

Contributor Author:

Ok

Contributor Author:

> Remove most of the comments in the files. You can add a link to the specific sherpa docs at the top of the file.

Consider moving it to the crate-level documentation instead.

Owner:

I want to avoid maintaining documentation as much as possible. We'll handle that in another PR/issue.

Contributor Author:

Ok

struct OfflineParaformerModelConfig {
    model: String, // Path to the model, e.g., model.onnx or model.int8.onnx
}

/// Configuration for offline/non-streaming NeMo CTC models.
///
/// Please refer to
/// https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-ctc/index.html
/// to download pre-trained models
struct OfflineNemoEncDecCtcModelConfig {
    model: String, // Path to the model, e.g., model.onnx or model.int8.onnx
}

struct OfflineWhisperModelConfig {
    encoder: String,
    decoder: String,
    language: String,
    task: String,
    tail_paddings: i32,
}

struct OfflineTdnnModelConfig {
    model: String,
}

struct OfflineSenseVoiceModelConfig {
    model: String,
    language: String,
    use_inverse_text_normalization: i32,
}

/// Configuration for offline LM.
struct OfflineLMConfig {
    model: String, // Path to the model
    scale: f32,    // scale for LM score
}

struct OfflineModelConfig {
    transducer: OfflineTransducerModelConfig,
    paraformer: OfflineParaformerModelConfig,
    nemo_ctc: OfflineNemoEncDecCtcModelConfig,
    whisper: OfflineWhisperModelConfig,
    tdnn: OfflineTdnnModelConfig,
    sense_voice: OfflineSenseVoiceModelConfig,
    tokens: String, // Path to tokens.txt

    // Number of threads to use for neural network computation
    num_threads: i32,

    // 1 to print model meta information while loading
    debug: i32,

    // Optional. Valid values: cpu, cuda, coreml
    provider: String,

    // Optional. Specify it for faster model initialization.
    model_type: String,

    modeling_unit: String,   // Optional. cjkchar, bpe, cjkchar+bpe
    bpe_vocab: String,       // Optional.
    tele_speech_ctc: String, // Optional.
}

/// Configuration for the offline/non-streaming recognizer.
struct OfflineRecognizerConfig {
    feat_config: FeatureConfig,
    model_config: OfflineModelConfig,
    lm_config: OfflineLMConfig,

    // Valid decoding methods: greedy_search, modified_beam_search
    decoding_method: String,

    // Used only when decoding_method is modified_beam_search.
    max_active_paths: i32,
    hotwords_file: String,
    hotwords_score: f32,
    blank_penalty: f32,
    rule_fsts: String,
    rule_fars: String,
}

/// It wraps a pointer from C
struct OfflineRecognizer {
    pointer: *const SherpaOnnxOfflineRecognizer,
}

/// It contains the recognition result of an offline stream.
struct OfflineRecognizerResult {
    text: String,
    tokens: Vec<String>,
    timestamps: Vec<f32>,
    lang: String,
    emotion: String,
    event: String,
}
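
Since OfflineRecognizer owns a raw C pointer, the natural companion is a Drop impl that releases it. A minimal sketch, assuming sherpa_rs_sys exposes the sherpa-onnx C API's SherpaOnnxDestroyOfflineRecognizer (an assumption, not code from this PR):

```rust
use sherpa_rs_sys::{SherpaOnnxDestroyOfflineRecognizer, SherpaOnnxOfflineRecognizer};

struct OfflineRecognizer {
    pointer: *const SherpaOnnxOfflineRecognizer,
}

impl Drop for OfflineRecognizer {
    fn drop(&mut self) {
        // Assumed binding to the C API's destroy function; releasing here
        // ties the C object's lifetime to the Rust wrapper.
        unsafe { SherpaOnnxDestroyOfflineRecognizer(self.pointer) };
    }
}
```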