From dcf368d38f60064889167f7f1e83e374f1b22445 Mon Sep 17 00:00:00 2001 From: Roy Shilkrot Date: Fri, 23 Aug 2024 16:17:34 -0400 Subject: [PATCH] feat: Add max_sub_duration field to transcription filter data --- src/transcription-filter-callbacks.cpp | 28 ++++--- src/transcription-filter-data.h | 2 + src/transcription-filter-properties.cpp | 29 ++++++- src/transcription-filter.cpp | 2 + src/whisper-utils/token-buffer-thread.cpp | 92 +++++++++++++++++------ src/whisper-utils/token-buffer-thread.h | 28 ++++++- src/whisper-utils/vad-processing.cpp | 50 ++++++++++-- src/whisper-utils/vad-processing.h | 1 + src/whisper-utils/whisper-processing.cpp | 2 +- src/whisper-utils/whisper-utils.cpp | 21 +----- 10 files changed, 189 insertions(+), 66 deletions(-) diff --git a/src/transcription-filter-callbacks.cpp b/src/transcription-filter-callbacks.cpp index f5c2209..22e240b 100644 --- a/src/transcription-filter-callbacks.cpp +++ b/src/transcription-filter-callbacks.cpp @@ -199,11 +199,6 @@ void set_text_callback(struct transcription_filter_data *gf, const DetectionResultWithText &resultIn) { DetectionResultWithText result = resultIn; - if (!result.text.empty() && (result.result == DETECTION_RESULT_SPEECH || - result.result == DETECTION_RESULT_PARTIAL)) { - gf->last_sub_render_time = now_ms(); - gf->cleared_last_sub = false; - } std::string str_copy = result.text; @@ -243,10 +238,12 @@ void set_text_callback(struct transcription_filter_data *gf, str_copy = translated_sentence; } else { if (gf->buffered_output) { - if (result.result == DETECTION_RESULT_SPEECH) { - // buffered output - add the sentence to the monitor - gf->translation_monitor.addSentence(translated_sentence); - } + // buffered output - add the sentence to the monitor + gf->translation_monitor.addSentenceFromStdString( + translated_sentence, + get_time_point_from_ms(result.start_timestamp_ms), + get_time_point_from_ms(result.end_timestamp_ms), + result.result == DETECTION_RESULT_PARTIAL); } else { // non-buffered output - send the sentence to the selected source send_caption_to_source(gf->translation_output, translated_sentence, @@ -256,9 +253,10 @@ void set_text_callback(struct transcription_filter_data *gf, } if (gf->buffered_output) { - if (result.result == DETECTION_RESULT_SPEECH) { - gf->captions_monitor.addSentence(str_copy); - } + gf->captions_monitor.addSentenceFromStdString( + str_copy, get_time_point_from_ms(result.start_timestamp_ms), + get_time_point_from_ms(result.end_timestamp_ms), + result.result == DETECTION_RESULT_PARTIAL); } else { // non-buffered output - send the sentence to the selected source send_caption_to_source(gf->text_source_name, str_copy, gf); @@ -273,6 +271,12 @@ void set_text_callback(struct transcription_filter_data *gf, result.result == DETECTION_RESULT_SPEECH) { send_sentence_to_file(gf, result, str_copy, translated_sentence); } + + if (!result.text.empty() && (result.result == DETECTION_RESULT_SPEECH || + result.result == DETECTION_RESULT_PARTIAL)) { + gf->last_sub_render_time = now_ms(); + gf->cleared_last_sub = false; + } }; void recording_state_callback(enum obs_frontend_event event, void *data) diff --git a/src/transcription-filter-data.h b/src/transcription-filter-data.h index 9cefd13..9205c96 100644 --- a/src/transcription-filter-data.h +++ b/src/transcription-filter-data.h @@ -36,6 +36,8 @@ struct transcription_filter_data { size_t sentence_number; // Minimal subtitle duration in ms size_t min_sub_duration; + // Maximal subtitle duration in ms + size_t max_sub_duration; // Last time a subtitle was rendered uint64_t last_sub_render_time; bool cleared_last_sub; diff --git a/src/transcription-filter-properties.cpp b/src/transcription-filter-properties.cpp index acecf4f..e689b8d 100644 --- a/src/transcription-filter-properties.cpp +++ b/src/transcription-filter-properties.cpp @@ -290,6 +290,30 @@ void add_buffered_output_group_properties(obs_properties_t *ppts) OBS_COMBO_TYPE_LIST, OBS_COMBO_FORMAT_INT); obs_property_list_add_int(buffer_type_list, "Character", SEGMENTATION_TOKEN); obs_property_list_add_int(buffer_type_list, "Word", SEGMENTATION_WORD); + obs_property_list_add_int(buffer_type_list, "Sentence", SEGMENTATION_SENTENCE); + // add callback to the segmentation selection to set default values + obs_property_set_modified_callback(buffer_type_list, [](obs_properties_t *props, + obs_property_t *property, + obs_data_t *settings) { + UNUSED_PARAMETER(property); + const int segmentation_type = obs_data_get_int(settings, "buffer_output_type"); + // set default values for the number of lines and characters per line + switch (segmentation_type) { + case SEGMENTATION_TOKEN: + obs_data_set_int(settings, "buffer_num_lines", 2); + obs_data_set_int(settings, "buffer_num_chars_per_line", 30); + break; + case SEGMENTATION_WORD: + obs_data_set_int(settings, "buffer_num_lines", 2); + obs_data_set_int(settings, "buffer_num_chars_per_line", 10); + break; + case SEGMENTATION_SENTENCE: + obs_data_set_int(settings, "buffer_num_lines", 2); + obs_data_set_int(settings, "buffer_num_chars_per_line", 2); + break; + } + return true; + }); // add buffer lines parameter obs_properties_add_int_slider(buffered_output_group, "buffer_num_lines", MT_("buffer_num_lines"), 1, 5, 1); @@ -310,6 +334,8 @@ void add_advanced_group_properties(obs_properties_t *ppts, struct transcription_ obs_properties_add_int_slider(advanced_config_group, "min_sub_duration", MT_("min_sub_duration"), 1000, 5000, 50); + obs_properties_add_int_slider(advanced_config_group, "max_sub_duration", + MT_("max_sub_duration"), 1000, 5000, 50); obs_properties_add_float_slider(advanced_config_group, "sentence_psum_accept_thresh", MT_("sentence_psum_accept_thresh"), 0.0, 1.0, 0.05); @@ -535,7 +561,8 @@ void transcription_filter_defaults(obs_data_t *s) obs_data_set_default_bool(s, "truncate_output_file", false); obs_data_set_default_bool(s, "only_while_recording", false); obs_data_set_default_bool(s, "rename_file_to_match_recording", true); - obs_data_set_default_int(s, "min_sub_duration", 3000); + obs_data_set_default_int(s, "min_sub_duration", 1000); + obs_data_set_default_int(s, "max_sub_duration", 3000); obs_data_set_default_bool(s, "advanced_settings", false); obs_data_set_default_bool(s, "translate", false); obs_data_set_default_string(s, "translate_target_language", "__es__"); diff --git a/src/transcription-filter.cpp b/src/transcription-filter.cpp index d47876a..357030f 100644 --- a/src/transcription-filter.cpp +++ b/src/transcription-filter.cpp @@ -187,6 +187,7 @@ void transcription_filter_update(void *data, obs_data_t *s) gf->sentence_number = 1; gf->process_while_muted = obs_data_get_bool(s, "process_while_muted"); gf->min_sub_duration = (int)obs_data_get_int(s, "min_sub_duration"); + gf->max_sub_duration = (int)obs_data_get_int(s, "max_sub_duration"); gf->last_sub_render_time = now_ms(); gf->duration_filter_threshold = (float)obs_data_get_double(s, "duration_filter_threshold"); gf->partial_transcription = obs_data_get_bool(s, "partial_group"); @@ -432,6 +433,7 @@ void *transcription_filter_create(obs_data_t *settings, obs_source_t *filter) gf->frames = (size_t)((float)gf->sample_rate / (1000.0f / MAX_MS_WORK_BUFFER)); gf->last_num_frames = 0; gf->min_sub_duration = (int)obs_data_get_int(settings, "min_sub_duration"); + gf->max_sub_duration = (int)obs_data_get_int(settings, "max_sub_duration"); gf->last_sub_render_time = now_ms(); gf->log_level = (int)obs_data_get_int(settings, "log_level"); gf->save_srt = obs_data_get_bool(settings, "subtitle_save_srt"); diff --git a/src/whisper-utils/token-buffer-thread.cpp b/src/whisper-utils/token-buffer-thread.cpp index ac34534..dc4b3fa 100644 --- a/src/whisper-utils/token-buffer-thread.cpp +++ b/src/whisper-utils/token-buffer-thread.cpp @@ -6,6 +6,9 @@ #include "whisper-utils.h" #include "transcription-utils.h" +#include +#include + #include #ifdef _WIN32 @@ -75,37 +78,74 @@ void TokenBufferThread::log_token_vector(const std::vector &tokens) obs_log(LOG_INFO, "TokenBufferThread::log_token_vector: '%s'", output.c_str()); } -void TokenBufferThread::addSentence(const std::string &sentence) +void TokenBufferThread::addSentenceFromStdString(const std::string &sentence, + TokenBufferTimePoint start_time, + TokenBufferTimePoint end_time, bool is_partial) { + if (sentence.empty()) { + return; + } #ifdef _WIN32 // on windows convert from multibyte to wide char int count = MultiByteToWideChar(CP_UTF8, 0, sentence.c_str(), (int)sentence.length(), NULL, 0); - std::wstring sentence_ws(count, 0); + TokenBufferString sentence_ws(count, 0); MultiByteToWideChar(CP_UTF8, 0, sentence.c_str(), (int)sentence.length(), &sentence_ws[0], count); #else - std::string sentence_ws = sentence; + TokenBufferString sentence_ws = sentence; #endif - // split to characters - std::vector characters; - for (const auto &c : sentence_ws) { - characters.push_back(TokenBufferString(1, c)); + + TokenBufferSentence sentence_for_add; + sentence_for_add.start_time = start_time; + sentence_for_add.end_time = end_time; + + if (this->segmentation == SEGMENTATION_WORD) { + // split the sentence to words + std::vector words; + std::basic_istringstream iss(sentence_ws); + TokenBufferString word; + while (iss >> word) { + words.push_back(word); + } + // add the words to a sentence + for (const auto &word : words) { + sentence_for_add.tokens.push_back({word, is_partial}); + sentence_for_add.tokens.push_back({SPACE, is_partial}); + } + } else if (this->segmentation == SEGMENTATION_TOKEN) { + // split to characters + std::vector characters; + for (const auto &c : sentence_ws) { + characters.push_back(TokenBufferString(1, c)); + } + // add the characters to a sentece + for (const auto &character : characters) { + sentence_for_add.tokens.push_back({character, is_partial}); + } + } else { + // add the whole sentence as a single token + sentence_for_add.tokens.push_back({sentence_ws, is_partial}); + sentence_for_add.tokens.push_back({SPACE, is_partial}); } + addSentence(sentence_for_add); +} - std::lock_guard lock(inputQueueMutex); +void TokenBufferThread::addSentence(const TokenBufferSentence &sentence) +{ + std::lock_guard lock(this->inputQueueMutex); - // add the characters to the inputQueue - for (const auto &character : characters) { + // add the tokens to the inputQueue + for (const auto &character : sentence.tokens) { inputQueue.push_back(character); } - inputQueue.push_back(SPACE); + inputQueue.push_back({SPACE, sentence.tokens.back().is_partial}); // add to the contribution queue as well - for (const auto &character : characters) { + for (const auto &character : sentence.tokens) { contributionQueue.push_back(character); } - contributionQueue.push_back(SPACE); + contributionQueue.push_back({SPACE, sentence.tokens.back().is_partial}); this->lastContributionTime = std::chrono::steady_clock::now(); } @@ -148,7 +188,7 @@ void TokenBufferThread::monitor() if (this->segmentation == SEGMENTATION_TOKEN) { // pop tokens until a space is found while (!presentationQueue.empty() && - presentationQueue.front() != SPACE) { + presentationQueue.front().token != SPACE) { presentationQueue.pop_front(); } } @@ -158,6 +198,13 @@ void TokenBufferThread::monitor() std::lock_guard lock(inputQueueMutex); if (!inputQueue.empty()) { + // if the input on the inputQueue is partial - first remove all partials + // from the end of the presentation queue + while (!presentationQueue.empty() && + presentationQueue.back().is_partial) { + presentationQueue.pop_back(); + } + // if there are token on the input queue // then add to the presentation queue based on the segmentation if (this->segmentation == SEGMENTATION_SENTENCE) { @@ -171,16 +218,17 @@ void TokenBufferThread::monitor() presentationQueue.push_back(inputQueue.front()); inputQueue.pop_front(); } else { + // SEGMENTATION_WORD // skip spaces in the beginning of the input queue while (!inputQueue.empty() && - inputQueue.front() == SPACE) { + inputQueue.front().token == SPACE) { inputQueue.pop_front(); } // add one word to the presentation queue - TokenBufferString word; + TokenBufferToken word; while (!inputQueue.empty() && - inputQueue.front() != SPACE) { - word += inputQueue.front(); + inputQueue.front().token != SPACE) { + word = inputQueue.front(); inputQueue.pop_front(); } presentationQueue.push_back(word); @@ -200,7 +248,7 @@ void TokenBufferThread::monitor() size_t wordsInSentence = 0; for (size_t i = 0; i < presentationQueue.size(); i++) { const auto &word = presentationQueue[i]; - sentences.back() += word + SPACE; + sentences.back() += word.token + SPACE; wordsInSentence++; if (wordsInSentence == this->numPerSentence) { sentences.push_back(TokenBufferString()); @@ -211,12 +259,12 @@ void TokenBufferThread::monitor() for (size_t i = 0; i < presentationQueue.size(); i++) { const auto &token = presentationQueue[i]; // skip spaces in the beginning of a sentence (tokensInSentence == 0) - if (token == SPACE && + if (token.token == SPACE && sentences.back().length() == 0) { continue; } - sentences.back() += token; + sentences.back() += token.token; if (sentences.back().length() == this->numPerSentence) { // if the next character is not a space - this is a broken word @@ -280,7 +328,7 @@ void TokenBufferThread::monitor() // take the contribution queue and send it to the output TokenBufferString contribution; for (const auto &token : contributionQueue) { - contribution += token; + contribution += token.token; } contributionQueue.clear(); #ifdef _WIN32 diff --git a/src/whisper-utils/token-buffer-thread.h b/src/whisper-utils/token-buffer-thread.h index 13be208..7666669 100644 --- a/src/whisper-utils/token-buffer-thread.h +++ b/src/whisper-utils/token-buffer-thread.h @@ -16,8 +16,10 @@ #ifdef _WIN32 typedef std::wstring TokenBufferString; +typedef wchar_t TokenBufferChar; #else typedef std::string TokenBufferString; +typedef char TokenBufferChar; #endif struct transcription_filter_data; @@ -27,6 +29,22 @@ enum TokenBufferSpeed { SPEED_SLOW = 0, SPEED_NORMAL, SPEED_FAST }; typedef std::chrono::time_point TokenBufferTimePoint; +inline std::chrono::time_point get_time_point_from_ms(uint64_t ms) +{ + return std::chrono::time_point(std::chrono::milliseconds(ms)); +} + +struct TokenBufferToken { + TokenBufferString token; + bool is_partial; +}; + +struct TokenBufferSentence { + std::vector tokens; + TokenBufferTimePoint start_time; + TokenBufferTimePoint end_time; +}; + class TokenBufferThread { public: // default constructor @@ -40,7 +58,9 @@ class TokenBufferThread { std::chrono::seconds maxTime_, TokenBufferSegmentation segmentation_ = SEGMENTATION_TOKEN); - void addSentence(const std::string &sentence); + void addSentenceFromStdString(const std::string &sentence, TokenBufferTimePoint start_time, + TokenBufferTimePoint end_time, bool is_partial = false); + void addSentence(const TokenBufferSentence &sentence); void clear(); void stopThread(); @@ -59,9 +79,9 @@ class TokenBufferThread { void log_token_vector(const std::vector &tokens); int getWaitTime(TokenBufferSpeed speed) const; struct transcription_filter_data *gf; - std::deque inputQueue; - std::deque presentationQueue; - std::deque contributionQueue; + std::deque inputQueue; + std::deque presentationQueue; + std::deque contributionQueue; std::thread workerThread; std::mutex inputQueueMutex; std::mutex presentationQueueMutex; diff --git a/src/whisper-utils/vad-processing.cpp b/src/whisper-utils/vad-processing.cpp index 2a1f8d4..9f104b9 100644 --- a/src/whisper-utils/vad-processing.cpp +++ b/src/whisper-utils/vad-processing.cpp @@ -5,6 +5,11 @@ #include "vad-processing.h" +#ifdef _WIN32 +#define NOMINMAX +#include +#endif + vad_state vad_based_segmentation(transcription_filter_data *gf, vad_state last_vad_state) { uint32_t num_frames_from_infos = 0; @@ -96,7 +101,7 @@ vad_state vad_based_segmentation(transcription_filter_data *gf, vad_state last_v } const size_t vad_window_size_samples = gf->vad->get_window_size_samples() * sizeof(float); - const size_t min_vad_buffer_size = vad_window_size_samples * 15; + const size_t min_vad_buffer_size = vad_window_size_samples * 8; if (gf->resampled_buffer.size < min_vad_buffer_size) return last_vad_state; @@ -156,11 +161,11 @@ vad_state vad_based_segmentation(transcription_filter_data *gf, vad_state last_v } int end_frame = stamps[i].end; - if (i == stamps.size() - 1 && stamps[i].end < (int)vad_input.size()) { - // take at least 100ms of audio after the last speech segment, if available - end_frame = std::min(end_frame + WHISPER_SAMPLE_RATE / 10, - (int)vad_input.size()); - } + // if (i == stamps.size() - 1 && stamps[i].end < (int)vad_input.size()) { + // // take at least 100ms of audio after the last speech segment, if available + // end_frame = std::min(end_frame + WHISPER_SAMPLE_RATE / 10, + // (int)vad_input.size()); + // } const int number_of_frames = end_frame - start_frame; @@ -196,12 +201,22 @@ vad_state vad_based_segmentation(transcription_filter_data *gf, vad_state last_v // end not reached - speech is ongoing current_vad_state.vad_on = true; if (last_vad_state.vad_on) { + obs_log(gf->log_level, + "last vad state was: ON, start ts: %llu, end ts: %llu", + last_vad_state.start_ts_offest_ms, last_vad_state.end_ts_offset_ms); current_vad_state.start_ts_offest_ms = last_vad_state.start_ts_offest_ms; } else { + obs_log(gf->log_level, + "last vad state was: OFF, start ts: %llu, end ts: %llu. start_ts_offset_ms: %llu, start_frame: %d", + last_vad_state.start_ts_offest_ms, last_vad_state.end_ts_offset_ms, + start_ts_offset_ms, start_frame); current_vad_state.start_ts_offest_ms = start_ts_offset_ms + start_frame * 1000 / WHISPER_SAMPLE_RATE; } - obs_log(gf->log_level, "end not reached. vad state: start ts: %llu, end ts: %llu", + current_vad_state.end_ts_offset_ms = + start_ts_offset_ms + end_frame * 1000 / WHISPER_SAMPLE_RATE; + obs_log(gf->log_level, + "end not reached. vad state: ON, start ts: %llu, end ts: %llu", current_vad_state.start_ts_offest_ms, current_vad_state.end_ts_offset_ms); last_vad_state = current_vad_state; @@ -235,3 +250,24 @@ vad_state vad_based_segmentation(transcription_filter_data *gf, vad_state last_v return current_vad_state; } + +void initialize_vad(transcription_filter_data *gf, const char *silero_vad_model_file) +{ + // initialize Silero VAD +#ifdef _WIN32 + // convert mbstring to wstring + int count = MultiByteToWideChar(CP_UTF8, 0, silero_vad_model_file, + strlen(silero_vad_model_file), NULL, 0); + std::wstring silero_vad_model_path(count, 0); + MultiByteToWideChar(CP_UTF8, 0, silero_vad_model_file, strlen(silero_vad_model_file), + &silero_vad_model_path[0], count); + obs_log(gf->log_level, "Create silero VAD: %S", silero_vad_model_path.c_str()); +#else + std::string silero_vad_model_path = silero_vad_model_file; + obs_log(gf->log_level, "Create silero VAD: %s", silero_vad_model_path.c_str()); +#endif + // roughly following https://github.com/SYSTRAN/faster-whisper/blob/master/faster_whisper/vad.py + // for silero vad parameters + gf->vad.reset(new VadIterator(silero_vad_model_path, WHISPER_SAMPLE_RATE, 32, 0.5f, 200, + 100, 100)); +} diff --git a/src/whisper-utils/vad-processing.h b/src/whisper-utils/vad-processing.h index 237abe6..1593d8a 100644 --- a/src/whisper-utils/vad-processing.h +++ b/src/whisper-utils/vad-processing.h @@ -11,5 +11,6 @@ struct vad_state { }; vad_state vad_based_segmentation(transcription_filter_data *gf, vad_state last_vad_state); +void initialize_vad(transcription_filter_data *gf, const char *silero_vad_model_file); #endif // VAD_PROCESSING_H diff --git a/src/whisper-utils/whisper-processing.cpp b/src/whisper-utils/whisper-processing.cpp index 17e817a..1a1e700 100644 --- a/src/whisper-utils/whisper-processing.cpp +++ b/src/whisper-utils/whisper-processing.cpp @@ -361,7 +361,7 @@ void whisper_loop(void *data) if (!gf->cleared_last_sub) { // check if we should clear the current sub depending on the minimum subtitle duration uint64_t now = now_ms(); - if ((now - gf->last_sub_render_time) > gf->min_sub_duration) { + if ((now - gf->last_sub_render_time) > gf->max_sub_duration) { // clear the current sub, call the callback with an empty string obs_log(gf->log_level, "Clearing current subtitle. now: %lu ms, last: %lu ms", now, diff --git a/src/whisper-utils/whisper-utils.cpp b/src/whisper-utils/whisper-utils.cpp index c2e4929..84f3b0a 100644 --- a/src/whisper-utils/whisper-utils.cpp +++ b/src/whisper-utils/whisper-utils.cpp @@ -2,13 +2,10 @@ #include "plugin-support.h" #include "model-utils/model-downloader.h" #include "whisper-processing.h" +#include "vad-processing.h" #include -#ifdef _WIN32 -#include -#endif - void shutdown_whisper_thread(struct transcription_filter_data *gf) { obs_log(gf->log_level, "shutdown_whisper_thread"); @@ -40,21 +37,7 @@ void start_whisper_thread_with_path(struct transcription_filter_data *gf, } // initialize Silero VAD -#ifdef _WIN32 - // convert mbstring to wstring - int count = MultiByteToWideChar(CP_UTF8, 0, silero_vad_model_file, - strlen(silero_vad_model_file), NULL, 0); - std::wstring silero_vad_model_path(count, 0); - MultiByteToWideChar(CP_UTF8, 0, silero_vad_model_file, strlen(silero_vad_model_file), - &silero_vad_model_path[0], count); - obs_log(gf->log_level, "Create silero VAD: %S", silero_vad_model_path.c_str()); -#else - std::string silero_vad_model_path = silero_vad_model_file; - obs_log(gf->log_level, "Create silero VAD: %s", silero_vad_model_path.c_str()); -#endif - // roughly following https://github.com/SYSTRAN/faster-whisper/blob/master/faster_whisper/vad.py - // for silero vad parameters - gf->vad.reset(new VadIterator(silero_vad_model_path, WHISPER_SAMPLE_RATE)); + initialize_vad(gf, silero_vad_model_file); obs_log(gf->log_level, "Create whisper context"); gf->whisper_context = init_whisper_context(whisper_model_path, gf);