diff --git a/data/models/silero-vad/silero_vad.onnx b/data/models/silero-vad/silero_vad.onnx index e6db48d..b3e3a90 100644 Binary files a/data/models/silero-vad/silero_vad.onnx and b/data/models/silero-vad/silero_vad.onnx differ diff --git a/src/transcription-filter-data.h b/src/transcription-filter-data.h index 96de825..8a640f3 100644 --- a/src/transcription-filter-data.h +++ b/src/transcription-filter-data.h @@ -48,6 +48,7 @@ struct transcription_filter_data { /* Resampler */ audio_resampler_t *resampler_to_whisper; + struct circlebuf resampled_buffer; /* whisper */ std::string whisper_model_path; diff --git a/src/transcription-filter.cpp b/src/transcription-filter.cpp index b6d7cc2..3683c18 100644 --- a/src/transcription-filter.cpp +++ b/src/transcription-filter.cpp @@ -108,6 +108,7 @@ struct obs_audio_data *transcription_filter_filter_audio(void *data, struct obs_ // calculate timestamp offset from the start of the stream info.timestamp_offset_ns = now_ns() - gf->start_timestamp_ms * 1000000; circlebuf_push_back(&gf->info_buffer, &info, sizeof(info)); + gf->wshiper_thread_cv.notify_one(); } return audio; @@ -154,6 +155,8 @@ void transcription_filter_destroy(void *data) } circlebuf_free(&gf->info_buffer); + circlebuf_free(&gf->resampled_buffer); + if (gf->captions_monitor.isEnabled()) { gf->captions_monitor.stopThread(); } @@ -443,6 +446,7 @@ void *transcription_filter_create(obs_data_t *settings, obs_source_t *filter) } circlebuf_init(&gf->info_buffer); circlebuf_init(&gf->whisper_buffer); + circlebuf_init(&gf->resampled_buffer); // allocate copy buffers gf->copy_buffers[0] = diff --git a/src/whisper-utils/silero-vad-onnx.cpp b/src/whisper-utils/silero-vad-onnx.cpp index 857649d..078e47c 100644 --- a/src/whisper-utils/silero-vad-onnx.cpp +++ b/src/whisper-utils/silero-vad-onnx.cpp @@ -92,11 +92,11 @@ void VadIterator::init_onnx_model(const SileroString &model_path) session = std::make_shared(env, model_path.c_str(), session_options); }; -void VadIterator::reset_states(bool reset_hc) +void VadIterator::reset_states(bool reset_state) { - if (reset_hc) { - std::memset(_h.data(), 0.0f, _h.size() * sizeof(float)); - std::memset(_c.data(), 0.0f, _c.size() * sizeof(float)); + if (reset_state) { + // Call reset before each audio start + std::memset(_state.data(), 0.0f, _state.size() * sizeof(float)); triggered = false; } temp_end = 0; @@ -115,19 +115,16 @@ float VadIterator::predict_one(const std::vector &data) input.assign(data.begin(), data.end()); Ort::Value input_ort = Ort::Value::CreateTensor(memory_info, input.data(), input.size(), input_node_dims, 2); + Ort::Value state_ort = Ort::Value::CreateTensor( + memory_info, _state.data(), _state.size(), state_node_dims, 3); Ort::Value sr_ort = Ort::Value::CreateTensor(memory_info, sr.data(), sr.size(), sr_node_dims, 1); - Ort::Value h_ort = - Ort::Value::CreateTensor(memory_info, _h.data(), _h.size(), hc_node_dims, 3); - Ort::Value c_ort = - Ort::Value::CreateTensor(memory_info, _c.data(), _c.size(), hc_node_dims, 3); // Clear and add inputs ort_inputs.clear(); ort_inputs.emplace_back(std::move(input_ort)); + ort_inputs.emplace_back(std::move(state_ort)); ort_inputs.emplace_back(std::move(sr_ort)); - ort_inputs.emplace_back(std::move(h_ort)); - ort_inputs.emplace_back(std::move(c_ort)); // Infer ort_outputs = session->Run(Ort::RunOptions{nullptr}, input_node_names.data(), @@ -136,10 +133,8 @@ float VadIterator::predict_one(const std::vector &data) // Output probability & update h,c recursively float speech_prob = ort_outputs[0].GetTensorMutableData()[0]; - float *hn = ort_outputs[1].GetTensorMutableData(); - std::memcpy(_h.data(), hn, size_hc * sizeof(float)); - float *cn = ort_outputs[2].GetTensorMutableData(); - std::memcpy(_c.data(), cn, size_hc * sizeof(float)); + float *stateN = ort_outputs[1].GetTensorMutableData(); + std::memcpy(_state.data(), stateN, size_state * sizeof(float)); return speech_prob; } @@ -264,9 +259,9 @@ void VadIterator::predict(const std::vector &data) } }; -void VadIterator::process(const std::vector &input_wav, bool reset_hc) +void VadIterator::process(const std::vector &input_wav, bool reset_state) { - reset_states(reset_hc); + reset_states(reset_state); audio_length_samples = (int)input_wav.size(); @@ -290,7 +285,7 @@ void VadIterator::process(const std::vector &input_wav, bool reset_hc) void VadIterator::process(const std::vector &input_wav, std::vector &output_wav) { - process(input_wav, true); + process(input_wav); collect_chunks(input_wav, output_wav); } @@ -352,8 +347,7 @@ VadIterator::VadIterator(const SileroString &ModelPath, int Sample_rate, int win input_node_dims[0] = 1; input_node_dims[1] = window_size_samples; - _h.resize(size_hc); - _c.resize(size_hc); + _state.resize(size_state); sr.resize(1); sr[0] = sample_rate; }; diff --git a/src/whisper-utils/silero-vad-onnx.h b/src/whisper-utils/silero-vad-onnx.h index 3d4de6c..cb284a4 100644 --- a/src/whisper-utils/silero-vad-onnx.h +++ b/src/whisper-utils/silero-vad-onnx.h @@ -43,18 +43,20 @@ class VadIterator { private: void init_engine_threads(int inter_threads, int intra_threads); void init_onnx_model(const SileroString &model_path); - void reset_states(bool reset_hc); + void reset_states(bool reset_state); float predict_one(const std::vector &data); void predict(const std::vector &data); public: - void process(const std::vector &input_wav, bool reset_hc = true); + void process(const std::vector &input_wav, bool reset_state = true); void process(const std::vector &input_wav, std::vector &output_wav); void collect_chunks(const std::vector &input_wav, std::vector &output_wav); const std::vector get_speech_timestamps() const; void drop_chunks(const std::vector &input_wav, std::vector &output_wav); void set_threshold(float threshold_) { this->threshold = threshold_; } + int64_t get_window_size_samples() const { return window_size_samples; } + private: // model config int64_t window_size_samples; // Assign when init, support 256 512 768 for 8k; 512 1024 1536 for 16k. @@ -84,27 +86,26 @@ class VadIterator { // Inputs std::vector ort_inputs; - std::vector input_node_names = {"input", "sr", "h", "c"}; + std::vector input_node_names = {"input", "state", "sr"}; std::vector input; + unsigned int size_state = 2 * 1 * 128; // It's FIXED. + std::vector _state; std::vector sr; - unsigned int size_hc = 2 * 1 * 64; // It's FIXED. - std::vector _h; - std::vector _c; int64_t input_node_dims[2] = {}; + const int64_t state_node_dims[3] = {2, 1, 128}; const int64_t sr_node_dims[1] = {1}; - const int64_t hc_node_dims[3] = {2, 1, 64}; // Outputs std::vector ort_outputs; - std::vector output_node_names = {"output", "hn", "cn"}; + std::vector output_node_names = {"output", "stateN"}; public: // Construction VadIterator(const SileroString &ModelPath, int Sample_rate = 16000, - int windows_frame_size = 64, float Threshold = 0.5, - int min_silence_duration_ms = 0, int speech_pad_ms = 64, - int min_speech_duration_ms = 64, + int windows_frame_size = 32, float Threshold = 0.5, + int min_silence_duration_ms = 0, int speech_pad_ms = 32, + int min_speech_duration_ms = 32, float max_speech_duration_s = std::numeric_limits::infinity()); // Default constructor diff --git a/src/whisper-utils/whisper-processing.cpp b/src/whisper-utils/whisper-processing.cpp index b290e8f..7239b04 100644 --- a/src/whisper-utils/whisper-processing.cpp +++ b/src/whisper-utils/whisper-processing.cpp @@ -2,6 +2,8 @@ #include +#include + #include "plugin-support.h" #include "transcription-filter-data.h" #include "whisper-processing.h" @@ -392,22 +394,43 @@ vad_state vad_based_segmentation(transcription_filter_data *gf, vad_state last_v num_frames_from_infos, overlap_size); gf->last_num_frames = num_frames_from_infos + overlap_size; - // resample to 16kHz - float *resampled_16khz[MAX_PREPROC_CHANNELS]; - uint32_t resampled_16khz_frames; - uint64_t ts_offset; - audio_resampler_resample(gf->resampler_to_whisper, (uint8_t **)resampled_16khz, - &resampled_16khz_frames, &ts_offset, - (const uint8_t **)gf->copy_buffers, - (uint32_t)num_frames_from_infos); + { + // resample to 16kHz + float *resampled_16khz[MAX_PREPROC_CHANNELS]; + uint32_t resampled_16khz_frames; + uint64_t ts_offset; + { + ProfileScope("resample"); + audio_resampler_resample(gf->resampler_to_whisper, + (uint8_t **)resampled_16khz, + &resampled_16khz_frames, &ts_offset, + (const uint8_t **)gf->copy_buffers, + (uint32_t)num_frames_from_infos); + } + + obs_log(gf->log_level, "resampled: %d channels, %d frames, %f ms", + (int)gf->channels, (int)resampled_16khz_frames, + (float)resampled_16khz_frames / WHISPER_SAMPLE_RATE * 1000.0f); + circlebuf_push_back(&gf->resampled_buffer, resampled_16khz[0], + resampled_16khz_frames * sizeof(float)); + } + + if (gf->resampled_buffer.size < (gf->vad->get_window_size_samples() * sizeof(float))) + return last_vad_state; - obs_log(gf->log_level, "resampled: %d channels, %d frames, %f ms", (int)gf->channels, - (int)resampled_16khz_frames, - (float)resampled_16khz_frames / WHISPER_SAMPLE_RATE * 1000.0f); + size_t len = + gf->resampled_buffer.size / (gf->vad->get_window_size_samples() * sizeof(float)); - std::vector vad_input(resampled_16khz[0], - resampled_16khz[0] + resampled_16khz_frames); - gf->vad->process(vad_input, false); + std::vector vad_input; + vad_input.resize(len * gf->vad->get_window_size_samples()); + circlebuf_pop_front(&gf->resampled_buffer, vad_input.data(), + vad_input.size() * sizeof(float)); + + obs_log(gf->log_level, "sending %d frames to vad", vad_input.size()); + { + ProfileScope("vad->process"); + gf->vad->process(vad_input, !last_vad_state.vad_on); + } const uint64_t start_ts_offset_ms = start_timestamp_offset_ns / 1000000; const uint64_t end_ts_offset_ms = end_timestamp_offset_ns / 1000000; @@ -417,8 +440,7 @@ vad_state vad_based_segmentation(transcription_filter_data *gf, vad_state last_v std::vector stamps = gf->vad->get_speech_timestamps(); if (stamps.size() == 0) { - obs_log(gf->log_level, "VAD detected no speech in %d frames", - resampled_16khz_frames); + obs_log(gf->log_level, "VAD detected no speech in %u frames", vad_input.size()); if (last_vad_state.vad_on) { obs_log(gf->log_level, "Last VAD was ON: segment end -> send to inference"); run_inference_and_callbacks(gf, last_vad_state.start_ts_offest_ms, @@ -428,7 +450,7 @@ vad_state vad_based_segmentation(transcription_filter_data *gf, vad_state last_v } if (gf->enable_audio_chunks_callback) { - audio_chunk_callback(gf, resampled_16khz[0], resampled_16khz_frames, + audio_chunk_callback(gf, vad_input.data(), vad_input.size(), VAD_STATE_IS_OFF, {DETECTION_RESULT_SILENCE, "[silence]", @@ -452,16 +474,16 @@ vad_state vad_based_segmentation(transcription_filter_data *gf, vad_state last_v } int end_frame = stamps[i].end; - if (i == stamps.size() - 1 && stamps[i].end < (int)resampled_16khz_frames) { + if (i == stamps.size() - 1 && stamps[i].end < (int)vad_input.size()) { // take at least 100ms of audio after the last speech segment, if available end_frame = std::min(end_frame + WHISPER_SAMPLE_RATE / 10, - (int)resampled_16khz_frames); + (int)vad_input.size()); } const int number_of_frames = end_frame - start_frame; // push the data into gf-whisper_buffer - circlebuf_push_back(&gf->whisper_buffer, resampled_16khz[0] + start_frame, + circlebuf_push_back(&gf->whisper_buffer, vad_input.data() + start_frame, number_of_frames * sizeof(float)); obs_log(gf->log_level, @@ -472,7 +494,7 @@ vad_state vad_based_segmentation(transcription_filter_data *gf, vad_state last_v gf->whisper_buffer.size / sizeof(float) * 1000 / WHISPER_SAMPLE_RATE); // segment "end" is in the middle of the buffer, send it to inference - if (stamps[i].end < (int)resampled_16khz_frames) { + if (stamps[i].end < (int)vad_input.size()) { // new "ending" segment (not up to the end of the buffer) obs_log(gf->log_level, "VAD segment end -> send to inference"); // find the end timestamp of the segment @@ -545,30 +567,24 @@ void whisper_loop(void *data) obs_log(gf->log_level, "Starting whisper thread"); vad_state current_vad_state = {false, 0, 0, 0}; - // 500 ms worth of audio is needed for VAD segmentation - uint32_t min_num_bytes_for_vad = (gf->sample_rate / 2) * sizeof(float); + + const char *whisper_loop_name = "Whisper loop"; + profile_register_root(whisper_loop_name, 50 * 1000 * 1000); // Thread main loop while (true) { + ProfileScope(whisper_loop_name); { + ProfileScope("lock whisper ctx"); std::lock_guard lock(gf->whisper_ctx_mutex); + ProfileScope("locked whisper ctx"); if (gf->whisper_context == nullptr) { obs_log(LOG_WARNING, "Whisper context is null, exiting thread"); break; } } - uint32_t num_bytes_on_input = 0; - { - // scoped lock the buffer mutex - std::lock_guard lock(gf->whisper_buf_mutex); - num_bytes_on_input = (uint32_t)gf->input_buffers[0].size; - } - - // only run vad segmentation if there are at least 500 ms of audio in the buffer - if (num_bytes_on_input > min_num_bytes_for_vad) { - current_vad_state = vad_based_segmentation(gf, current_vad_state); - } + current_vad_state = vad_based_segmentation(gf, current_vad_state); if (!gf->cleared_last_sub) { // check if we should clear the current sub depending on the minimum subtitle duration