From a8912164f7dadaf4830317806606c89eaad40142 Mon Sep 17 00:00:00 2001
From: gonghuijun <303690073@qq.com>
Date: Mon, 21 Oct 2024 13:52:32 +0800
Subject: [PATCH 1/4] examples/tts: add tts_rule_fsts param

---
 examples/tts.rs | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/examples/tts.rs b/examples/tts.rs
index 21015f9..6e27352 100644
--- a/examples/tts.rs
+++ b/examples/tts.rs
@@ -47,6 +47,9 @@ struct Args {
     #[arg(long)]
     lexicon: Option<String>,
 
+    #[arg(long)]
+    tts_rule_fsts: Option<String>,
+
     #[arg(long)]
     sid: Option<i32>,
 
@@ -85,6 +88,7 @@ fn main() {
     let tts_config = sherpa_rs::tts::OfflineTtsConfig {
         model: args.model,
         max_num_sentences,
+        rule_fsts: args.tts_rule_fsts.unwrap_or_default(),
         ..Default::default()
     };
     let mut tts = sherpa_rs::tts::OfflineTts::new(tts_config, vits_config);

From ccf78a5f14bb599729942a7263324a53d5b533a9 Mon Sep 17 00:00:00 2001
From: gonghuijun <303690073@qq.com>
Date: Mon, 21 Oct 2024 13:55:56 +0800
Subject: [PATCH 2/4] sherpa-rs-sys: gate compile_error! in build.rs by target
 OS (macOS now compiles)

---
 sys/build.rs | 12 ++++++++++--
 1 file changed, 10 insertions(+), 2 deletions(-)

diff --git a/sys/build.rs b/sys/build.rs
index fa20b48..4b0407a 100644
--- a/sys/build.rs
+++ b/sys/build.rs
@@ -5,7 +5,11 @@ use std::fs;
 use std::path::{Path, PathBuf};
 use std::process::Command;
 
-#[cfg(all(feature = "download-binaries", feature = "cuda"))]
+#[cfg(all(
+    feature = "download-binaries",
+    feature = "cuda",
+    any(target_os = "windows", target_os = "linux")
+))]
 compile_error!(
     "Features 'download-binaries' and 'cuda' cannot be enabled simultaneously.\n\
     If you wish to use the 'cuda' feature, disable the 'download-binaries' feature by setting `default-features = false` in Cargo.toml.\n\
@@ -13,7 +17,11 @@ compile_error!(
     cargo build --features cuda --no-default-features"
 );
 
-#[cfg(all(feature = "download-binaries", feature = "directml"))]
+#[cfg(all(
+    feature = "download-binaries",
+    feature = "directml",
+    target_os = "windows"
+))]
 compile_error!(
     "Features 'download-binaries' and 'directml' cannot be enabled simultaneously.\n\
     If you wish to use the 'directml' feature, disable the 'download-binaries' feature by setting `default-features = false` in Cargo.toml.\n\

From 8964974b92d3af7b9255461ca1a12709315df5d1 Mon Sep 17 00:00:00 2001
From: gonghuijun <303690073@qq.com>
Date: Mon, 28 Oct 2024 23:02:44 +0800
Subject: [PATCH 3/4] Add streaming (online) recognizer and a decoding example

---
 examples/streaming_decode_files.rs   | 192 +++++++++++++++++
 src/lib.rs                           |   2 +
 src/recognizer.rs                    |  40 ++++
 src/recognizer/offline_recognizer.rs |   0
 src/recognizer/online_recognizer.rs  | 309 +++++++++++++++++++++++++++
 src/stream.rs                        |   2 +
 src/stream/offline_stream.rs         |   0
 src/stream/online_stream.rs          |  78 +++++++
 src/utils.rs                         |   1 +
 9 files changed, 624 insertions(+)
 create mode 100644 examples/streaming_decode_files.rs
 create mode 100644 src/recognizer.rs
 create mode 100644 src/recognizer/offline_recognizer.rs
 create mode 100644 src/recognizer/online_recognizer.rs
 create mode 100644 src/stream.rs
 create mode 100644 src/stream/offline_stream.rs
 create mode 100644 src/stream/online_stream.rs

diff --git a/examples/streaming_decode_files.rs b/examples/streaming_decode_files.rs
new file mode 100644
index 0000000..36bb18a
--- /dev/null
+++ b/examples/streaming_decode_files.rs
@@ -0,0 +1,192 @@
+use clap::{arg, Parser};
+use sherpa_rs::recognizer::online_recognizer::{
+    FeatureConfig, OnlineCtcFstDecoderConfig, OnlineModelConfig, OnlineParaformerModelConfig,
+    OnlineRecognizer, OnlineRecognizerConfig, OnlineTransducerModelConfig,
+    OnlineZipformer2CtcModelConfig,
+};
+use sherpa_rs::stream::online_stream::OnlineStream;
+
+/// Decode wave files with a streaming (online) model
+#[derive(Parser, Debug)]
+#[command(version, about, long_about = None)]
+struct Args {
+    /// Please provide one wave file
+    wave_file: String,
+
+    /// Path to the transducer encoder model
+    #[arg(long, default_value = "")]
+    encoder: String,
+
+    /// Path to the transducer decoder model
+    #[arg(long, default_value = "")]
+    decoder: String,
+
+    /// Path to the transducer joiner model
+    #[arg(long, default_value = "")]
+    joiner: String,
+
+    /// Path to the paraformer encoder model
+    #[arg(long, default_value = "")]
+    paraformer_encoder: String,
+
+    /// Path to the paraformer decoder model
+    #[arg(long, default_value = "")]
+    paraformer_decoder: String,
+
+    /// Path to the zipformer2 CTC model
+    #[arg(long, default_value = "")]
+    zipformer2_ctc: String,
+
+    /// Path to the tokens file
+    #[arg(long, default_value = "")]
+    tokens: String,
+
+    /// Number of threads for computing
+    #[arg(long, default_value = "1")]
+    num_threads: i32,
+
+    /// Whether to show debug messages
+    #[arg(long, default_value = "0")]
+    debug: i32,
+
+    /// Optional. Used for loading the model in a faster way
+    #[arg(long)]
+    model_type: Option<String>,
+
+    /// Provider to use
+    #[arg(long, default_value = "cpu")]
+    provider: String,
+
+    /// Decoding method. Possible values: greedy_search, modified_beam_search
+    #[arg(long, default_value = "greedy_search")]
+    decoding_method: String,
+
+    /// Used only when --decoding-method is modified_beam_search
+    #[arg(long, default_value = "4")]
+    max_active_paths: i32,
+
+    /// If not empty, path to rule fst for inverse text normalization
+    #[arg(long, default_value = "")]
+    rule_fsts: String,
+
+    /// If not empty, path to rule fst archives for inverse text normalization
+    #[arg(long, default_value = "")]
+    rule_fars: String,
+}
+
+fn main() {
+    // Parse command-line arguments into `Args` struct
+    let args = Args::parse();
+
+    println!("Reading {}", args.wave_file);
+
+    let (samples, sample_rate) = read_wave(&args.wave_file);
+
+    println!("Initializing recognizer (may take several seconds)");
+    let config = OnlineRecognizerConfig {
+        feat_config: FeatureConfig {
+            sample_rate: 16000,
+            feature_dim: 80,
+        },
+        model_config: OnlineModelConfig {
+            transducer: OnlineTransducerModelConfig {
+                encoder: args.encoder,
+                decoder: args.decoder,
+                joiner: args.joiner,
+            },
+            paraformer: OnlineParaformerModelConfig {
+                encoder: args.paraformer_encoder,
+                decoder: args.paraformer_decoder,
+            },
+            zipformer2_ctc: OnlineZipformer2CtcModelConfig {
+                model: args.zipformer2_ctc,
+            },
+            tokens: args.tokens,
+            num_threads: args.num_threads,
+            provider: args.provider,
+            debug: args.debug,
+            model_type: args.model_type,
+            modeling_unit: None,
+            bpe_vocab: None,
+            tokens_buf: None,
+            tokens_buf_size: None,
+        },
+        decoding_method: args.decoding_method,
+        max_active_paths: args.max_active_paths,
+        enable_endpoint: 0,
+        rule1_min_trailing_silence: 0.0,
+        rule2_min_trailing_silence: 0.0,
+        rule3_min_utterance_length: 0.0,
+        hotwords_file: "".to_string(),
+        hotwords_score: 0.0,
+        blank_penalty: 0.0,
+        ctc_fst_decoder_config: OnlineCtcFstDecoderConfig {
+            graph: "".to_string(),
+            max_active: 0,
+        },
+        rule_fsts: args.rule_fsts,
+        rule_fars: args.rule_fars,
+        hotwords_buf: "".to_string(),
+        hotwords_buf_size: 0,
+    };
+
+    let recognizer = OnlineRecognizer::new(&config);
+    println!("Recognizer created!");
+
+    println!("Start decoding!");
+    let stream = OnlineStream::new(&recognizer);
+
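+    // The whole file is fed in one call here; a real-time caller would
+    // instead call accept_waveform() repeatedly on small chunks as audio
+    // arrives, checking is_ready()/decode() after each chunk.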
+    stream.accept_waveform(sample_rate, &samples);
+
+    let tail_padding = vec![0.0; (sample_rate as f32 * 0.3) as usize];
+    stream.accept_waveform(sample_rate, &tail_padding);
+
+    while recognizer.is_ready(&stream) {
+        recognizer.decode(&stream);
+    }
+    println!("Decoding done!");
+
+    let result = recognizer.get_result(&stream);
+    println!("{}", result.text.to_lowercase());
+    println!(
+        "Wave duration: {} seconds",
+        samples.len() as f32 / sample_rate as f32
+    );
+}
+
+/// Read a WAV file and return its samples and sample rate
+///
+/// # Arguments
+///
+/// * `filename` - path to the WAV file
+///
+/// # Returns
+///
+/// * `samples` - the audio samples, normalized to [-1.0, 1.0]
+/// * `sample_rate` - the sample rate
+fn read_wave(filename: &str) -> (Vec<f32>, i32) {
+    let mut reader = hound::WavReader::open(filename).expect("Failed to open WAV file");
+    let spec = reader.spec();
+
+    if spec.sample_format != hound::SampleFormat::Int {
+        panic!("Support only PCM format. Given: {:?}", spec.sample_format);
+    }
+
+    if spec.channels != 1 {
+        panic!("Support only 1 channel wave file. Given: {}", spec.channels);
+    }
+
+    if spec.bits_per_sample != 16 {
+        panic!(
+            "Support only 16-bit per sample. Given: {}",
+            spec.bits_per_sample
+        );
+    }
+
+    let samples: Vec<f32> = reader
+        .samples::<i16>()
+        .map(|s| s.unwrap() as f32 / 32768.0)
+        .collect();
+
+    (samples, spec.sample_rate as i32)
+}
diff --git a/src/lib.rs b/src/lib.rs
index 0b283a4..082c7d7 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -4,7 +4,9 @@ pub mod embedding_manager;
 pub mod keyword_spot;
 pub mod language_id;
 pub mod punctuate;
+pub mod recognizer;
 pub mod speaker_id;
+pub mod stream;
 pub mod vad;
 pub mod whisper;
 pub mod zipformer;
diff --git a/src/recognizer.rs b/src/recognizer.rs
new file mode 100644
index 0000000..de6ef92
--- /dev/null
+++ b/src/recognizer.rs
@@ -0,0 +1,40 @@
+//! Speech recognition with [Next-gen Kaldi].
+//!
+//! [sherpa-onnx] is an open-source speech recognition framework for [Next-gen Kaldi].
+//! It depends only on [onnxruntime], supporting both streaming and non-streaming
+//! speech recognition.
+//!
+//! It does not need to access the network during recognition and everything
+//! runs locally.
+//!
+//! It supports a variety of platforms, such as Linux (x86_64, aarch64, arm),
+//! Windows (x86_64, x86), macOS (x86_64, arm64), etc.
+//!
+//! Usage examples:
+//!
+//! 1. Real-time speech recognition from a microphone
+//!
+//! Please see
+//! https://github.com/k2-fsa/sherpa-onnx/tree/master/go-api-examples/real-time-speech-recognition-from-microphone
+//!
+//! 2. Decode files using a non-streaming model
+//!
+//! Please see
+//! https://github.com/k2-fsa/sherpa-onnx/tree/master/go-api-examples/non-streaming-decode-files
+//!
+//! 3. Decode files using a streaming model
+//!
+//! Please see
+//! https://github.com/k2-fsa/sherpa-onnx/tree/master/go-api-examples/streaming-decode-files
+//!
+//! 4. Convert text to speech using a non-streaming model
+//!
+//! Please see
+//! https://github.com/k2-fsa/sherpa-onnx/tree/master/go-api-examples/non-streaming-tts
+//!
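+//! The examples above are for the Go bindings of sherpa-onnx; see
+//! examples/streaming_decode_files.rs in this crate for a Rust version of
+//! example 3.
+//!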
+//! [sherpa-onnx]: https://github.com/k2-fsa/sherpa-onnx
+//! [onnxruntime]: https://github.com/microsoft/onnxruntime
+//! [Next-gen Kaldi]: https://github.com/k2-fsa/
+
+pub mod offline_recognizer;
+pub mod online_recognizer;
diff --git a/src/recognizer/offline_recognizer.rs b/src/recognizer/offline_recognizer.rs
new file mode 100644
index 0000000..e69de29
diff --git a/src/recognizer/online_recognizer.rs b/src/recognizer/online_recognizer.rs
new file mode 100644
index 0000000..1a67b5a
--- /dev/null
+++ b/src/recognizer/online_recognizer.rs
@@ -0,0 +1,309 @@
+use crate::stream::online_stream::{InitialState, OnlineStream, State};
+use crate::utils;
+use crate::utils::RawCStr;
+use sherpa_rs_sys::{
+    SherpaOnnxCreateOnlineStream, SherpaOnnxDecodeMultipleOnlineStreams,
+    SherpaOnnxDecodeOnlineStream, SherpaOnnxDestroyOnlineRecognizer,
+    SherpaOnnxDestroyOnlineRecognizerResult, SherpaOnnxFeatureConfig,
+    SherpaOnnxGetOnlineStreamResult, SherpaOnnxIsOnlineStreamReady,
+    SherpaOnnxOnlineCtcFstDecoderConfig, SherpaOnnxOnlineModelConfig,
+    SherpaOnnxOnlineParaformerModelConfig, SherpaOnnxOnlineRecognizer,
+    SherpaOnnxOnlineRecognizerConfig, SherpaOnnxOnlineStream, SherpaOnnxOnlineStreamIsEndpoint,
+    SherpaOnnxOnlineStreamReset, SherpaOnnxOnlineTransducerModelConfig,
+    SherpaOnnxOnlineZipformer2CtcModelConfig,
+};
+use std::marker::PhantomData;
+
+/// Configuration for online/streaming transducer models
+///
+/// Please refer to
+/// https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-transducer/index.html
+/// to download pre-trained models
+pub struct OnlineTransducerModelConfig {
+    pub encoder: String, // Path to the encoder model, e.g., encoder.onnx or encoder.int8.onnx
+    pub decoder: String, // Path to the decoder model.
+    pub joiner: String,  // Path to the joiner model.
+}
+
+/// Configuration for online/streaming paraformer models
+///
+/// Please refer to
+/// https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-paraformer/index.html
+/// to download pre-trained models
+pub struct OnlineParaformerModelConfig {
+    pub encoder: String, // Path to the encoder model, e.g., encoder.onnx or encoder.int8.onnx
+    pub decoder: String, // Path to the decoder model.
+}
+
+/// Configuration for online/streaming Zipformer2 CTC models
+///
+/// Please refer to
+/// https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-ctc/index.html
+/// to download pre-trained models
+pub struct OnlineZipformer2CtcModelConfig {
+    pub model: String, // Path to the onnx model
+}
+
+/// Configuration for online/streaming models
+///
+/// Please refer to
+/// https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-transducer/index.html
+/// https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-paraformer/index.html
+/// to download pre-trained models
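+///
+/// Only one of `transducer`, `paraformer` and `zipformer2_ctc` is expected to
+/// be populated; leave the unused configs with empty paths, as in
+/// examples/streaming_decode_files.rs.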
+pub struct OnlineModelConfig {
+    pub transducer: OnlineTransducerModelConfig,
+    pub paraformer: OnlineParaformerModelConfig,
+    pub zipformer2_ctc: OnlineZipformer2CtcModelConfig,
+    pub tokens: String,   // Path to tokens.txt
+    pub num_threads: i32, // Number of threads to use for neural network computation
+    pub provider: String, // Optional. Valid values are: cpu, cuda, coreml
+    pub debug: i32,       // 1 to show model meta information while loading it.
+    pub model_type: Option<String>, // Optional. You can specify it for faster model initialization
+    pub modeling_unit: Option<String>, // Optional. cjkchar, bpe, cjkchar+bpe
+    pub bpe_vocab: Option<String>,  // Optional.
+    pub tokens_buf: Option<String>, // Optional.
+    pub tokens_buf_size: Option<i32>, // Optional.
+}
+
+/// Configuration for the feature extractor
+pub struct FeatureConfig {
+    /// Sample rate expected by the model. It is 16000 for all
+    /// pre-trained models provided by us
+    pub sample_rate: i32,
+    /// Feature dimension expected by the model. It is 80 for all
+    /// pre-trained models provided by us
+    pub feature_dim: i32,
+}
+
+pub struct OnlineCtcFstDecoderConfig {
+    pub graph: String,
+    pub max_active: i32,
+}
+
+/// Configuration for the online/streaming recognizer.
+pub struct OnlineRecognizerConfig {
+    pub feat_config: FeatureConfig,
+    pub model_config: OnlineModelConfig,
+
+    /// Valid decoding methods: greedy_search, modified_beam_search
+    pub decoding_method: String,
+
+    /// Used only when decoding_method is modified_beam_search. It specifies
+    /// the maximum number of paths to keep during the search
+    pub max_active_paths: i32,
+
+    pub enable_endpoint: i32, // 1 to enable endpoint detection.
+
+    /// Please see
+    /// https://k2-fsa.github.io/sherpa/ncnn/endpoint.html
+    /// for the meaning of rule1_min_trailing_silence, rule2_min_trailing_silence
+    /// and rule3_min_utterance_length.
+    pub rule1_min_trailing_silence: f32,
+    pub rule2_min_trailing_silence: f32,
+    pub rule3_min_utterance_length: f32,
+    pub hotwords_file: String,
+    pub hotwords_score: f32,
+    pub blank_penalty: f32,
+    pub ctc_fst_decoder_config: OnlineCtcFstDecoderConfig,
+    pub rule_fsts: String,
+    pub rule_fars: String,
+    pub hotwords_buf: String,
+    pub hotwords_buf_size: i32,
+}
+
+/// It contains the recognition result for an online stream.
+pub struct OnlineRecognizerResult {
+    pub text: String,
+}
+
+/// The online recognizer class. It wraps a pointer from C.
+pub struct OnlineRecognizer {
+    pointer: *const SherpaOnnxOnlineRecognizer,
+}
+
+impl Drop for OnlineRecognizer {
+    fn drop(&mut self) {
+        self.delete();
+    }
+}
+
+impl OnlineRecognizer {
+    /// Create a new online recognizer. The recognizer frees its C resources
+    /// automatically when dropped.
+    pub fn new(config: &OnlineRecognizerConfig) -> Self {
+        let transducer_encoder = RawCStr::new(&config.model_config.transducer.encoder);
+        let transducer_decoder = RawCStr::new(&config.model_config.transducer.decoder);
+        let transducer_joiner = RawCStr::new(&config.model_config.transducer.joiner);
+        let paraformer_encoder = RawCStr::new(&config.model_config.paraformer.encoder);
+        let paraformer_decoder = RawCStr::new(&config.model_config.paraformer.decoder);
+        let zipformer2_ctc_model = RawCStr::new(&config.model_config.zipformer2_ctc.model);
+        let tokens = RawCStr::new(&config.model_config.tokens);
+        let tokens_buf = config
+            .model_config
+            .tokens_buf
+            .as_ref()
+            .map_or_else(|| RawCStr::new(""), |tokens_buf| RawCStr::new(tokens_buf));
+        let provider = RawCStr::new(&config.model_config.provider);
+        let model_type = config
+            .model_config
+            .model_type
+            .as_ref()
+            .map_or_else(|| RawCStr::new(""), |model_type| RawCStr::new(model_type));
+        let modeling_unit = config.model_config.modeling_unit.as_ref().map_or_else(
+            || RawCStr::new(""),
+            |modeling_unit| RawCStr::new(modeling_unit),
+        );
+        let bpe_vocab = config
+            .model_config
+            .bpe_vocab
+            .as_ref()
+            .map_or_else(|| RawCStr::new(""), |bpe_vocab| RawCStr::new(bpe_vocab));
+        let decoding_method = RawCStr::new(&config.decoding_method);
+        let hotwords_file = RawCStr::new(&config.hotwords_file);
+        let hotwords_buf = RawCStr::new(&config.hotwords_buf);
+        let rule_fsts = RawCStr::new(&config.rule_fsts);
+        let rule_fars = RawCStr::new(&config.rule_fars);
+        let graph = RawCStr::new(&config.ctc_fst_decoder_config.graph);
+
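+        // The RawCStr values above own the C strings; they must stay alive
+        // until SherpaOnnxCreateOnlineRecognizer below returns, since the C
+        // config only borrows the raw pointers.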
+        let c_config = SherpaOnnxOnlineRecognizerConfig {
+            feat_config: SherpaOnnxFeatureConfig {
+                sample_rate: config.feat_config.sample_rate,
+                feature_dim: config.feat_config.feature_dim,
+            },
+            model_config: SherpaOnnxOnlineModelConfig {
+                transducer: SherpaOnnxOnlineTransducerModelConfig {
+                    encoder: transducer_encoder.as_ptr(),
+                    decoder: transducer_decoder.as_ptr(),
+                    joiner: transducer_joiner.as_ptr(),
+                },
+                paraformer: SherpaOnnxOnlineParaformerModelConfig {
+                    encoder: paraformer_encoder.as_ptr(),
+                    decoder: paraformer_decoder.as_ptr(),
+                },
+                zipformer2_ctc: SherpaOnnxOnlineZipformer2CtcModelConfig {
+                    model: zipformer2_ctc_model.as_ptr(),
+                },
+                tokens: tokens.as_ptr(),
+                tokens_buf: tokens_buf.as_ptr(),
+                tokens_buf_size: config.model_config.tokens_buf_size.unwrap_or(0),
+                num_threads: config.model_config.num_threads,
+                provider: provider.as_ptr(),
+                debug: config.model_config.debug,
+                model_type: model_type.as_ptr(),
+                modeling_unit: modeling_unit.as_ptr(),
+                bpe_vocab: bpe_vocab.as_ptr(),
+            },
+            decoding_method: decoding_method.as_ptr(),
+            max_active_paths: config.max_active_paths,
+            enable_endpoint: config.enable_endpoint,
+            rule1_min_trailing_silence: config.rule1_min_trailing_silence,
+            rule2_min_trailing_silence: config.rule2_min_trailing_silence,
+            rule3_min_utterance_length: config.rule3_min_utterance_length,
+            hotwords_file: hotwords_file.as_ptr(),
+            hotwords_buf: hotwords_buf.as_ptr(),
+            hotwords_buf_size: config.hotwords_buf_size,
+            hotwords_score: config.hotwords_score,
+            blank_penalty: config.blank_penalty,
+            rule_fsts: rule_fsts.as_ptr(),
+            rule_fars: rule_fars.as_ptr(),
+            ctc_fst_decoder_config: SherpaOnnxOnlineCtcFstDecoderConfig {
+                graph: graph.as_ptr(),
+                max_active: config.ctc_fst_decoder_config.max_active,
+            },
+        };
+
+        let recognizer = unsafe { sherpa_rs_sys::SherpaOnnxCreateOnlineRecognizer(&c_config) };
+
+        OnlineRecognizer {
+            pointer: recognizer,
+        }
+    }
+
+    /// Free the internal pointer inside the recognizer to avoid memory leak.
+    fn delete(&mut self) {
+        unsafe {
+            SherpaOnnxDestroyOnlineRecognizer(self.pointer);
+        }
+    }
+
+    /// Create a new online stream. The returned stream frees its C resources
+    /// automatically when dropped.
+    pub fn new_stream(&self) -> OnlineStream<InitialState> {
+        let stream = unsafe { SherpaOnnxCreateOnlineStream(self.pointer) };
+        OnlineStream {
+            pointer: stream,
+            _marker: PhantomData,
+        }
+    }
+
+    /// Check whether the stream has enough feature frames for decoding.
+    /// Return true if this stream is ready for decoding. Return false otherwise.
+    ///
+    /// You will usually use it like below:
+    ///
+    ///     while recognizer.is_ready(&stream) {
+    ///         recognizer.decode(&stream);
+    ///     }
+    pub fn is_ready<S: State>(&self, stream: &OnlineStream<S>) -> bool {
+        unsafe { SherpaOnnxIsOnlineStreamReady(self.pointer, stream.pointer) == 1 }
+    }
+
+    /// Return true if an endpoint is detected.
+    ///
+    /// You usually use it like below:
+    ///
+    ///     if recognizer.is_endpoint(&stream) {
+    ///         // do your own stuff after detecting an endpoint
+    ///
+    ///         recognizer.reset(&stream);
+    ///     }
+    pub fn is_endpoint<S: State>(&self, stream: &OnlineStream<S>) -> bool {
+        unsafe { SherpaOnnxOnlineStreamIsEndpoint(self.pointer, stream.pointer) == 1 }
+    }
+
+    /// After calling this function, the internal neural network model states
+    /// are reset and is_endpoint() would return false. get_result() would also
+    /// return an empty string.
+    pub fn reset<S: State>(&self, stream: &OnlineStream<S>) {
+        unsafe {
+            SherpaOnnxOnlineStreamReset(self.pointer, stream.pointer);
+        }
+    }
+
+    /// Decode the stream. Before calling this function, you have to ensure
+    /// that is_ready() returns true for this stream; otherwise the behavior
+    /// is undefined.
+    ///
+    /// You usually use it like below:
+    ///
+    ///     while recognizer.is_ready(&stream) {
+    ///         recognizer.decode(&stream);
+    ///     }
+    pub fn decode<S: State>(&self, stream: &OnlineStream<S>) {
+        unsafe {
+            SherpaOnnxDecodeOnlineStream(self.pointer, stream.pointer);
+        }
+    }
+
+    /// Decode multiple streams in parallel, i.e., in batch.
+    /// You have to ensure that each stream is ready for decoding; otherwise
+    /// the behavior is undefined.
+    pub fn decode_streams<S: State>(&self, streams: &[OnlineStream<S>]) {
+        let mut c_streams: Vec<*const SherpaOnnxOnlineStream> =
+            streams.iter().map(|s| s.pointer).collect();
+        unsafe {
+            SherpaOnnxDecodeMultipleOnlineStreams(
+                self.pointer,
+                c_streams.as_mut_ptr(),
+                c_streams.len() as i32,
+            );
+        }
+    }
+
+    /// Get the current result of the stream since the last invocation of reset()
+    pub fn get_result<S: State>(&self, stream: &OnlineStream<S>) -> OnlineRecognizerResult {
+        let result = unsafe { SherpaOnnxGetOnlineStreamResult(self.pointer, stream.pointer) };
+        let text = utils::cstr_to_string((unsafe { *result }).text);
+        unsafe {
+            SherpaOnnxDestroyOnlineRecognizerResult(result);
+        }
+        OnlineRecognizerResult { text }
+    }
+}
diff --git a/src/stream.rs b/src/stream.rs
new file mode 100644
index 0000000..f8c46fc
--- /dev/null
+++ b/src/stream.rs
@@ -0,0 +1,2 @@
+pub mod offline_stream;
+pub mod online_stream;
diff --git a/src/stream/offline_stream.rs b/src/stream/offline_stream.rs
new file mode 100644
index 0000000..e69de29
diff --git a/src/stream/online_stream.rs b/src/stream/online_stream.rs
new file mode 100644
index 0000000..b986129
--- /dev/null
+++ b/src/stream/online_stream.rs
@@ -0,0 +1,78 @@
+use crate::recognizer::online_recognizer::OnlineRecognizer;
+use sherpa_rs_sys::{
+    SherpaOnnxDestroyOnlineStream, SherpaOnnxOnlineStream, SherpaOnnxOnlineStreamAcceptWaveform,
+    SherpaOnnxOnlineStreamInputFinished,
+};
+use std::marker::PhantomData;
+
+/// The online stream class. It wraps a pointer from C and tracks the input
+/// state of the stream at the type level (a typestate).
+pub struct OnlineStream<S: State> {
+    pub(crate) pointer: *const SherpaOnnxOnlineStream,
+    pub(crate) _marker: PhantomData<S>,
+}
+
+pub trait State {}
+
+pub struct InitialState;
+pub struct InputFinishedCalledState;
+
+impl State for InitialState {}
+
+impl State for InputFinishedCalledState {}
+
+/// Delete the internal pointer inside the stream to avoid memory leak.
+impl<S: State> Drop for OnlineStream<S> {
+    fn drop(&mut self) {
+        self.delete();
+    }
+}
+
+impl<S: State> OnlineStream<S> {
+    /// Delete the internal pointer inside the stream to avoid memory leak.
+    fn delete(&mut self) {
+        unsafe {
+            SherpaOnnxDestroyOnlineStream(self.pointer);
+        }
+    }
+}
+
+impl OnlineStream<InitialState> {
+    /// Create a stream for the given recognizer. The stream frees its C
+    /// resources automatically when dropped.
+    pub fn new(recognizer: &OnlineRecognizer) -> Self {
+        recognizer.new_stream()
+    }
+
+    /// Signal that there will be no incoming audio samples.
+    /// After calling this function, you cannot call accept_waveform() any
+    /// longer; the type system enforces this by consuming the stream and
+    /// returning it in the InputFinishedCalledState.
+    ///
+    /// The main purpose of this function is to flush the remaining audio samples
+    /// buffered inside for feature extraction.
+    pub fn input_finished(self) -> OnlineStream<InputFinishedCalledState> {
+        unsafe {
+            SherpaOnnxOnlineStreamInputFinished(self.pointer);
+        }
+        let pointer = self.pointer;
+        // `self` must not run its Drop here: ownership of the C stream moves
+        // to the returned wrapper, and dropping would free it prematurely.
+        std::mem::forget(self);
+        OnlineStream {
+            pointer,
+            _marker: PhantomData,
+        }
+    }
+
+    /// Input audio samples for the stream.
+    ///
+    /// sample_rate is the actual sample rate of the input audio samples. If it
+    /// is different from the sample rate expected by the feature extractor, we will
+    /// do resampling inside.
+    ///
+    /// samples contains audio samples. Each sample is in the range [-1, 1].
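+    ///
+    /// A minimal sketch of the intended call pattern (assuming `recognizer`
+    /// was already created from a valid OnlineRecognizerConfig):
+    ///
+    ///     let stream = OnlineStream::new(&recognizer);
+    ///     stream.accept_waveform(16000, &samples);
+    ///     let stream = stream.input_finished(); // no more audio after this
+    ///     while recognizer.is_ready(&stream) {
+    ///         recognizer.decode(&stream);
+    ///     }
+    ///     println!("{}", recognizer.get_result(&stream).text);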
+    pub fn accept_waveform(&self, sample_rate: i32, samples: &[f32]) {
+        unsafe {
+            SherpaOnnxOnlineStreamAcceptWaveform(
+                self.pointer,
+                sample_rate,
+                samples.as_ptr(),
+                samples.len() as i32,
+            );
+        }
+    }
+}
diff --git a/src/utils.rs b/src/utils.rs
index 8a4f0fb..f03c1cd 100644
--- a/src/utils.rs
+++ b/src/utils.rs
@@ -19,6 +19,7 @@ impl RawCStr {
     /// This function only returns the raw pointer and does not transfer ownership.
     /// The pointer remains valid as long as the `CStr` instance exists.
     /// Be cautious not to deallocate or modify the pointer after using `CStr::new`.
+    #[inline]
     pub fn as_ptr(&self) -> *const i8 {
         self.ptr
     }

From 18224b0ef36bcf9f5c745cdeae9047f40f48fee0 Mon Sep 17 00:00:00 2001
From: gonghuijun <303690073@qq.com>
Date: Mon, 28 Oct 2024 23:20:29 +0800
Subject: [PATCH 4/4] Move FeatureConfig to common_config; stub out the
 offline recognizer and stream

---
 examples/streaming_decode_files.rs   |   6 +-
 src/common_config.rs                 |   9 +++
 src/lib.rs                           |   1 +
 src/recognizer/offline_recognizer.rs | 114 +++++++++++++++++++++++++++
 src/recognizer/online_recognizer.rs  |  11 +--
 src/stream/offline_stream.rs         |   6 ++
 6 files changed, 134 insertions(+), 13 deletions(-)
 create mode 100644 src/common_config.rs

diff --git a/examples/streaming_decode_files.rs b/examples/streaming_decode_files.rs
index 36bb18a..fcc1b94 100644
--- a/examples/streaming_decode_files.rs
+++ b/examples/streaming_decode_files.rs
@@ -1,8 +1,8 @@
 use clap::{arg, Parser};
+use sherpa_rs::common_config::FeatureConfig;
 use sherpa_rs::recognizer::online_recognizer::{
-    FeatureConfig, OnlineCtcFstDecoderConfig, OnlineModelConfig, OnlineParaformerModelConfig,
-    OnlineRecognizer, OnlineRecognizerConfig, OnlineTransducerModelConfig,
-    OnlineZipformer2CtcModelConfig,
+    OnlineCtcFstDecoderConfig, OnlineModelConfig, OnlineParaformerModelConfig, OnlineRecognizer,
+    OnlineRecognizerConfig, OnlineTransducerModelConfig, OnlineZipformer2CtcModelConfig,
 };
 use sherpa_rs::stream::online_stream::OnlineStream;
 
diff --git a/src/common_config.rs b/src/common_config.rs
new file mode 100644
index 0000000..e91409e
--- /dev/null
+++ b/src/common_config.rs
@@ -0,0 +1,9 @@
+/// Configuration for the feature extractor
+pub struct FeatureConfig {
+    /// Sample rate expected by the model. It is 16000 for all
+    /// pre-trained models provided by us
+    pub sample_rate: i32,
+    /// Feature dimension expected by the model. It is 80 for all
+    /// pre-trained models provided by us
+    pub feature_dim: i32,
+}
diff --git a/src/lib.rs b/src/lib.rs
index 082c7d7..eee18f1 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,4 +1,5 @@
 pub mod audio_tag;
+pub mod common_config;
 pub mod diarize;
 pub mod embedding_manager;
 pub mod keyword_spot;
diff --git a/src/recognizer/offline_recognizer.rs b/src/recognizer/offline_recognizer.rs
index e69de29..214cb59 100644
--- a/src/recognizer/offline_recognizer.rs
+++ b/src/recognizer/offline_recognizer.rs
@@ -0,0 +1,114 @@
+use crate::common_config::FeatureConfig;
+use sherpa_rs_sys::SherpaOnnxOfflineRecognizer;
+
+/// Configuration for offline/non-streaming transducer.
+///
+/// Please refer to
+/// https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-transducer/index.html
+/// to download pre-trained models
+pub struct OfflineTransducerModelConfig {
+    pub encoder: String, // Path to the encoder model, e.g., encoder.onnx or encoder.int8.onnx
+    pub decoder: String, // Path to the decoder model
+    pub joiner: String,  // Path to the joiner model
+}
+
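+// NOTE: the structs in this file mirror SherpaOnnxOfflineRecognizerConfig
+// from the sherpa-onnx C API; they are placeholders and are not yet wired to
+// the C functions.
+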
+/// Configuration for offline/non-streaming paraformer.
+///
+/// Please refer to
+/// https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-paraformer/index.html
+/// to download pre-trained models
+pub struct OfflineParaformerModelConfig {
+    pub model: String, // Path to the model, e.g., model.onnx or model.int8.onnx
+}
+
+/// Configuration for offline/non-streaming NeMo CTC models.
+///
+/// Please refer to
+/// https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-ctc/index.html
+/// to download pre-trained models
+pub struct OfflineNemoEncDecCtcModelConfig {
+    pub model: String, // Path to the model, e.g., model.onnx or model.int8.onnx
+}
+
+pub struct OfflineWhisperModelConfig {
+    pub encoder: String,
+    pub decoder: String,
+    pub language: String,
+    pub task: String,
+    pub tail_paddings: i32,
+}
+
+pub struct OfflineTdnnModelConfig {
+    pub model: String,
+}
+
+pub struct OfflineSenseVoiceModelConfig {
+    pub model: String,
+    pub language: String,
+    pub use_inverse_text_normalization: i32,
+}
+
+/// Configuration for offline LM.
+pub struct OfflineLMConfig {
+    pub model: String, // Path to the model
+    pub scale: f32,    // scale for LM score
+}
+
+pub struct OfflineModelConfig {
+    pub transducer: OfflineTransducerModelConfig,
+    pub paraformer: OfflineParaformerModelConfig,
+    pub nemo_ctc: OfflineNemoEncDecCtcModelConfig,
+    pub whisper: OfflineWhisperModelConfig,
+    pub tdnn: OfflineTdnnModelConfig,
+    pub sense_voice: OfflineSenseVoiceModelConfig,
+    pub tokens: String, // Path to tokens.txt
+
+    // Number of threads to use for neural network computation
+    pub num_threads: i32,
+
+    // 1 to print model meta information while loading
+    pub debug: i32,
+
+    // Optional. Valid values: cpu, cuda, coreml
+    pub provider: String,
+
+    // Optional. Specify it for faster model initialization.
+    pub model_type: String,
+
+    pub modeling_unit: String,   // Optional. cjkchar, bpe, cjkchar+bpe
+    pub bpe_vocab: String,       // Optional.
+    pub tele_speech_ctc: String, // Optional.
+}
+
+/// Configuration for the offline/non-streaming recognizer.
+pub struct OfflineRecognizerConfig {
+    pub feat_config: FeatureConfig,
+    pub model_config: OfflineModelConfig,
+    pub lm_config: OfflineLMConfig,
+
+    // Valid decoding methods: greedy_search, modified_beam_search
+    pub decoding_method: String,
+
+    // Used only when decoding_method is modified_beam_search.
+    pub max_active_paths: i32,
+    pub hotwords_file: String,
+    pub hotwords_score: f32,
+    pub blank_penalty: f32,
+    pub rule_fsts: String,
+    pub rule_fars: String,
+}
+
+/// It wraps a pointer from C
+pub struct OfflineRecognizer {
+    pointer: *const SherpaOnnxOfflineRecognizer,
+}
+
+/// It contains the recognition result of an offline stream.
+pub struct OfflineRecognizerResult {
+    pub text: String,
+    pub tokens: Vec<String>,
+    pub timestamps: Vec<f32>,
+    pub lang: String,
+    pub emotion: String,
+    pub event: String,
+}
diff --git a/src/recognizer/online_recognizer.rs b/src/recognizer/online_recognizer.rs
index 1a67b5a..043a9d0 100644
--- a/src/recognizer/online_recognizer.rs
+++ b/src/recognizer/online_recognizer.rs
@@ -1,3 +1,4 @@
+use crate::common_config::FeatureConfig;
 use crate::stream::online_stream::{InitialState, OnlineStream, State};
 use crate::utils;
 use crate::utils::RawCStr;
@@ -63,16 +64,6 @@ pub struct OnlineModelConfig {
     pub tokens_buf_size: Option<i32>, // Optional.
 }
 
-/// Configuration for the feature extractor
-pub struct FeatureConfig {
-    /// Sample rate expected by the model. It is 16000 for all
-    /// pre-trained models provided by us
-    pub sample_rate: i32,
-    /// Feature dimension expected by the model. It is 80 for all
-    /// pre-trained models provided by us
-    pub feature_dim: i32,
-}
-
 pub struct OnlineCtcFstDecoderConfig {
     pub graph: String,
     pub max_active: i32,
diff --git a/src/stream/offline_stream.rs b/src/stream/offline_stream.rs
index e69de29..5349e83 100644
--- a/src/stream/offline_stream.rs
+++ b/src/stream/offline_stream.rs
@@ -0,0 +1,6 @@
+use sherpa_rs_sys::SherpaOnnxOfflineStream;
+
+/// It wraps a pointer from C
+pub struct OfflineStream {
+    pub(crate) pointer: *const SherpaOnnxOfflineStream,
+}
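+
+// NOTE: like the offline recognizer, this stub still needs to be wired to the
+// sherpa-onnx C API (e.g. SherpaOnnxCreateOfflineStream) in a follow-up.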