Skip to content

Commit

Permalink
Upgrade silero vad v5 (and some other changes) (#148)
Browse files Browse the repository at this point in the history
* Add accessor for VAD window size in samples

* Feed buffered audio data to VAD in proper window sizes

* Wake whisper thread whenever audio is received

* Update silero VAD to v5

* Only reset VAD state between chunks of activity
  • Loading branch information
palana authored Aug 2, 2024
1 parent a173e22 commit 0592fa7
Show file tree
Hide file tree
Showing 6 changed files with 80 additions and 64 deletions.
Binary file modified data/models/silero-vad/silero_vad.onnx
Binary file not shown.
1 change: 1 addition & 0 deletions src/transcription-filter-data.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ struct transcription_filter_data {

/* Resampler */
audio_resampler_t *resampler_to_whisper;
struct circlebuf resampled_buffer;

/* whisper */
std::string whisper_model_path;
Expand Down
4 changes: 4 additions & 0 deletions src/transcription-filter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ struct obs_audio_data *transcription_filter_filter_audio(void *data, struct obs_
// calculate timestamp offset from the start of the stream
info.timestamp_offset_ns = now_ns() - gf->start_timestamp_ms * 1000000;
circlebuf_push_back(&gf->info_buffer, &info, sizeof(info));
gf->wshiper_thread_cv.notify_one();
}

return audio;
Expand Down Expand Up @@ -154,6 +155,8 @@ void transcription_filter_destroy(void *data)
}
circlebuf_free(&gf->info_buffer);

circlebuf_free(&gf->resampled_buffer);

if (gf->captions_monitor.isEnabled()) {
gf->captions_monitor.stopThread();
}
Expand Down Expand Up @@ -443,6 +446,7 @@ void *transcription_filter_create(obs_data_t *settings, obs_source_t *filter)
}
circlebuf_init(&gf->info_buffer);
circlebuf_init(&gf->whisper_buffer);
circlebuf_init(&gf->resampled_buffer);

// allocate copy buffers
gf->copy_buffers[0] =
Expand Down
32 changes: 13 additions & 19 deletions src/whisper-utils/silero-vad-onnx.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,11 +92,11 @@ void VadIterator::init_onnx_model(const SileroString &model_path)
session = std::make_shared<Ort::Session>(env, model_path.c_str(), session_options);
};

void VadIterator::reset_states(bool reset_hc)
void VadIterator::reset_states(bool reset_state)
{
if (reset_hc) {
std::memset(_h.data(), 0.0f, _h.size() * sizeof(float));
std::memset(_c.data(), 0.0f, _c.size() * sizeof(float));
if (reset_state) {
// Call reset before each audio start
std::memset(_state.data(), 0.0f, _state.size() * sizeof(float));
triggered = false;
}
temp_end = 0;
Expand All @@ -115,19 +115,16 @@ float VadIterator::predict_one(const std::vector<float> &data)
input.assign(data.begin(), data.end());
Ort::Value input_ort = Ort::Value::CreateTensor<float>(memory_info, input.data(),
input.size(), input_node_dims, 2);
Ort::Value state_ort = Ort::Value::CreateTensor<float>(
memory_info, _state.data(), _state.size(), state_node_dims, 3);
Ort::Value sr_ort = Ort::Value::CreateTensor<int64_t>(memory_info, sr.data(), sr.size(),
sr_node_dims, 1);
Ort::Value h_ort =
Ort::Value::CreateTensor<float>(memory_info, _h.data(), _h.size(), hc_node_dims, 3);
Ort::Value c_ort =
Ort::Value::CreateTensor<float>(memory_info, _c.data(), _c.size(), hc_node_dims, 3);

// Clear and add inputs
ort_inputs.clear();
ort_inputs.emplace_back(std::move(input_ort));
ort_inputs.emplace_back(std::move(state_ort));
ort_inputs.emplace_back(std::move(sr_ort));
ort_inputs.emplace_back(std::move(h_ort));
ort_inputs.emplace_back(std::move(c_ort));

// Infer
ort_outputs = session->Run(Ort::RunOptions{nullptr}, input_node_names.data(),
Expand All @@ -136,10 +133,8 @@ float VadIterator::predict_one(const std::vector<float> &data)

// Output probability & update h,c recursively
float speech_prob = ort_outputs[0].GetTensorMutableData<float>()[0];
float *hn = ort_outputs[1].GetTensorMutableData<float>();
std::memcpy(_h.data(), hn, size_hc * sizeof(float));
float *cn = ort_outputs[2].GetTensorMutableData<float>();
std::memcpy(_c.data(), cn, size_hc * sizeof(float));
float *stateN = ort_outputs[1].GetTensorMutableData<float>();
std::memcpy(_state.data(), stateN, size_state * sizeof(float));

return speech_prob;
}
Expand Down Expand Up @@ -264,9 +259,9 @@ void VadIterator::predict(const std::vector<float> &data)
}
};

void VadIterator::process(const std::vector<float> &input_wav, bool reset_hc)
void VadIterator::process(const std::vector<float> &input_wav, bool reset_state)
{
reset_states(reset_hc);
reset_states(reset_state);

audio_length_samples = (int)input_wav.size();

Expand All @@ -290,7 +285,7 @@ void VadIterator::process(const std::vector<float> &input_wav, bool reset_hc)

void VadIterator::process(const std::vector<float> &input_wav, std::vector<float> &output_wav)
{
process(input_wav, true);
process(input_wav);
collect_chunks(input_wav, output_wav);
}

Expand Down Expand Up @@ -352,8 +347,7 @@ VadIterator::VadIterator(const SileroString &ModelPath, int Sample_rate, int win
input_node_dims[0] = 1;
input_node_dims[1] = window_size_samples;

_h.resize(size_hc);
_c.resize(size_hc);
_state.resize(size_state);
sr.resize(1);
sr[0] = sample_rate;
};
23 changes: 12 additions & 11 deletions src/whisper-utils/silero-vad-onnx.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,18 +43,20 @@ class VadIterator {
private:
void init_engine_threads(int inter_threads, int intra_threads);
void init_onnx_model(const SileroString &model_path);
void reset_states(bool reset_hc);
void reset_states(bool reset_state);
float predict_one(const std::vector<float> &data);
void predict(const std::vector<float> &data);

public:
void process(const std::vector<float> &input_wav, bool reset_hc = true);
void process(const std::vector<float> &input_wav, bool reset_state = true);
void process(const std::vector<float> &input_wav, std::vector<float> &output_wav);
void collect_chunks(const std::vector<float> &input_wav, std::vector<float> &output_wav);
const std::vector<timestamp_t> get_speech_timestamps() const;
void drop_chunks(const std::vector<float> &input_wav, std::vector<float> &output_wav);
void set_threshold(float threshold_) { this->threshold = threshold_; }

int64_t get_window_size_samples() const { return window_size_samples; }

private:
// model config
int64_t window_size_samples; // Assign when init, support 256 512 768 for 8k; 512 1024 1536 for 16k.
Expand Down Expand Up @@ -84,27 +86,26 @@ class VadIterator {
// Inputs
std::vector<Ort::Value> ort_inputs;

std::vector<const char *> input_node_names = {"input", "sr", "h", "c"};
std::vector<const char *> input_node_names = {"input", "state", "sr"};
std::vector<float> input;
unsigned int size_state = 2 * 1 * 128; // It's FIXED.
std::vector<float> _state;
std::vector<int64_t> sr;
unsigned int size_hc = 2 * 1 * 64; // It's FIXED.
std::vector<float> _h;
std::vector<float> _c;

int64_t input_node_dims[2] = {};
const int64_t state_node_dims[3] = {2, 1, 128};
const int64_t sr_node_dims[1] = {1};
const int64_t hc_node_dims[3] = {2, 1, 64};

// Outputs
std::vector<Ort::Value> ort_outputs;
std::vector<const char *> output_node_names = {"output", "hn", "cn"};
std::vector<const char *> output_node_names = {"output", "stateN"};

public:
// Construction
VadIterator(const SileroString &ModelPath, int Sample_rate = 16000,
int windows_frame_size = 64, float Threshold = 0.5,
int min_silence_duration_ms = 0, int speech_pad_ms = 64,
int min_speech_duration_ms = 64,
int windows_frame_size = 32, float Threshold = 0.5,
int min_silence_duration_ms = 0, int speech_pad_ms = 32,
int min_speech_duration_ms = 32,
float max_speech_duration_s = std::numeric_limits<float>::infinity());

// Default constructor
Expand Down
84 changes: 50 additions & 34 deletions src/whisper-utils/whisper-processing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

#include <obs-module.h>

#include <util/profiler.hpp>

#include "plugin-support.h"
#include "transcription-filter-data.h"
#include "whisper-processing.h"
Expand Down Expand Up @@ -392,22 +394,43 @@ vad_state vad_based_segmentation(transcription_filter_data *gf, vad_state last_v
num_frames_from_infos, overlap_size);
gf->last_num_frames = num_frames_from_infos + overlap_size;

// resample to 16kHz
float *resampled_16khz[MAX_PREPROC_CHANNELS];
uint32_t resampled_16khz_frames;
uint64_t ts_offset;
audio_resampler_resample(gf->resampler_to_whisper, (uint8_t **)resampled_16khz,
&resampled_16khz_frames, &ts_offset,
(const uint8_t **)gf->copy_buffers,
(uint32_t)num_frames_from_infos);
{
// resample to 16kHz
float *resampled_16khz[MAX_PREPROC_CHANNELS];
uint32_t resampled_16khz_frames;
uint64_t ts_offset;
{
ProfileScope("resample");
audio_resampler_resample(gf->resampler_to_whisper,
(uint8_t **)resampled_16khz,
&resampled_16khz_frames, &ts_offset,
(const uint8_t **)gf->copy_buffers,
(uint32_t)num_frames_from_infos);
}

obs_log(gf->log_level, "resampled: %d channels, %d frames, %f ms",
(int)gf->channels, (int)resampled_16khz_frames,
(float)resampled_16khz_frames / WHISPER_SAMPLE_RATE * 1000.0f);
circlebuf_push_back(&gf->resampled_buffer, resampled_16khz[0],
resampled_16khz_frames * sizeof(float));
}

if (gf->resampled_buffer.size < (gf->vad->get_window_size_samples() * sizeof(float)))
return last_vad_state;

obs_log(gf->log_level, "resampled: %d channels, %d frames, %f ms", (int)gf->channels,
(int)resampled_16khz_frames,
(float)resampled_16khz_frames / WHISPER_SAMPLE_RATE * 1000.0f);
size_t len =
gf->resampled_buffer.size / (gf->vad->get_window_size_samples() * sizeof(float));

std::vector<float> vad_input(resampled_16khz[0],
resampled_16khz[0] + resampled_16khz_frames);
gf->vad->process(vad_input, false);
std::vector<float> vad_input;
vad_input.resize(len * gf->vad->get_window_size_samples());
circlebuf_pop_front(&gf->resampled_buffer, vad_input.data(),
vad_input.size() * sizeof(float));

obs_log(gf->log_level, "sending %d frames to vad", vad_input.size());
{
ProfileScope("vad->process");
gf->vad->process(vad_input, !last_vad_state.vad_on);
}

const uint64_t start_ts_offset_ms = start_timestamp_offset_ns / 1000000;
const uint64_t end_ts_offset_ms = end_timestamp_offset_ns / 1000000;
Expand All @@ -417,8 +440,7 @@ vad_state vad_based_segmentation(transcription_filter_data *gf, vad_state last_v

std::vector<timestamp_t> stamps = gf->vad->get_speech_timestamps();
if (stamps.size() == 0) {
obs_log(gf->log_level, "VAD detected no speech in %d frames",
resampled_16khz_frames);
obs_log(gf->log_level, "VAD detected no speech in %u frames", vad_input.size());
if (last_vad_state.vad_on) {
obs_log(gf->log_level, "Last VAD was ON: segment end -> send to inference");
run_inference_and_callbacks(gf, last_vad_state.start_ts_offest_ms,
Expand All @@ -428,7 +450,7 @@ vad_state vad_based_segmentation(transcription_filter_data *gf, vad_state last_v
}

if (gf->enable_audio_chunks_callback) {
audio_chunk_callback(gf, resampled_16khz[0], resampled_16khz_frames,
audio_chunk_callback(gf, vad_input.data(), vad_input.size(),
VAD_STATE_IS_OFF,
{DETECTION_RESULT_SILENCE,
"[silence]",
Expand All @@ -452,16 +474,16 @@ vad_state vad_based_segmentation(transcription_filter_data *gf, vad_state last_v
}

int end_frame = stamps[i].end;
if (i == stamps.size() - 1 && stamps[i].end < (int)resampled_16khz_frames) {
if (i == stamps.size() - 1 && stamps[i].end < (int)vad_input.size()) {
// take at least 100ms of audio after the last speech segment, if available
end_frame = std::min(end_frame + WHISPER_SAMPLE_RATE / 10,
(int)resampled_16khz_frames);
(int)vad_input.size());
}

const int number_of_frames = end_frame - start_frame;

// push the data into gf-whisper_buffer
circlebuf_push_back(&gf->whisper_buffer, resampled_16khz[0] + start_frame,
circlebuf_push_back(&gf->whisper_buffer, vad_input.data() + start_frame,
number_of_frames * sizeof(float));

obs_log(gf->log_level,
Expand All @@ -472,7 +494,7 @@ vad_state vad_based_segmentation(transcription_filter_data *gf, vad_state last_v
gf->whisper_buffer.size / sizeof(float) * 1000 / WHISPER_SAMPLE_RATE);

// segment "end" is in the middle of the buffer, send it to inference
if (stamps[i].end < (int)resampled_16khz_frames) {
if (stamps[i].end < (int)vad_input.size()) {
// new "ending" segment (not up to the end of the buffer)
obs_log(gf->log_level, "VAD segment end -> send to inference");
// find the end timestamp of the segment
Expand Down Expand Up @@ -545,30 +567,24 @@ void whisper_loop(void *data)
obs_log(gf->log_level, "Starting whisper thread");

vad_state current_vad_state = {false, 0, 0, 0};
// 500 ms worth of audio is needed for VAD segmentation
uint32_t min_num_bytes_for_vad = (gf->sample_rate / 2) * sizeof(float);

const char *whisper_loop_name = "Whisper loop";
profile_register_root(whisper_loop_name, 50 * 1000 * 1000);

// Thread main loop
while (true) {
ProfileScope(whisper_loop_name);
{
ProfileScope("lock whisper ctx");
std::lock_guard<std::mutex> lock(gf->whisper_ctx_mutex);
ProfileScope("locked whisper ctx");
if (gf->whisper_context == nullptr) {
obs_log(LOG_WARNING, "Whisper context is null, exiting thread");
break;
}
}

uint32_t num_bytes_on_input = 0;
{
// scoped lock the buffer mutex
std::lock_guard<std::mutex> lock(gf->whisper_buf_mutex);
num_bytes_on_input = (uint32_t)gf->input_buffers[0].size;
}

// only run vad segmentation if there are at least 500 ms of audio in the buffer
if (num_bytes_on_input > min_num_bytes_for_vad) {
current_vad_state = vad_based_segmentation(gf, current_vad_state);
}
current_vad_state = vad_based_segmentation(gf, current_vad_state);

if (!gf->cleared_last_sub) {
// check if we should clear the current sub depending on the minimum subtitle duration
Expand Down

0 comments on commit 0592fa7

Please sign in to comment.