diff --git a/CMakeLists.txt b/CMakeLists.txt index cea698e..6784e0f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -101,6 +101,32 @@ include(cmake/BuildICU.cmake) target_link_libraries(${CMAKE_PROJECT_NAME} PRIVATE ICU) target_include_directories(${CMAKE_PROJECT_NAME} SYSTEM PUBLIC ${ICU_INCLUDE_DIR}) +if(WIN32 OR APPLE) + if(NOT buildspec) + file(READ "${CMAKE_CURRENT_SOURCE_DIR}/buildspec.json" buildspec) + endif() + string( + JSON + version + GET + ${buildspec} + dependencies + prebuilt + version) + if(MSVC) + set(arch ${CMAKE_GENERATOR_PLATFORM}) + elseif(APPLE) + set(arch universal) + endif() + set(deps_root "${CMAKE_CURRENT_SOURCE_DIR}/.deps/obs-deps-${version}-${arch}") + message(STATUS "deps_root: ${deps_root}") + target_include_directories(${CMAKE_PROJECT_NAME} PRIVATE "${deps_root}/include") +else() + include(cmake/FetchWebsocketpp.cmake) + target_link_libraries(${CMAKE_PROJECT_NAME} PRIVATE websocketpp) + target_include_directories(${CMAKE_PROJECT_NAME} PRIVATE ${asio_SOURCE_DIR}/asio/include/) +endif() + target_sources( ${CMAKE_PROJECT_NAME} PRIVATE src/plugin-main.c @@ -120,12 +146,15 @@ target_sources( src/whisper-utils/silero-vad-onnx.cpp src/whisper-utils/token-buffer-thread.cpp src/whisper-utils/vad-processing.cpp + src/whisper-utils/resample-utils.cpp src/translation/language_codes.cpp src/translation/translation.cpp src/translation/translation-utils.cpp src/ui/filter-replace-utils.cpp src/translation/translation-language-utils.cpp - src/ui/filter-replace-dialog.cpp) + src/ui/filter-replace-dialog.cpp + src/stenographer/stenographer.cpp + src/stenographer/stenographer-util.cpp) set_target_properties_plugin(${CMAKE_PROJECT_NAME} PROPERTIES OUTPUT_NAME ${_name}) diff --git a/cmake/FetchWebsocketpp.cmake b/cmake/FetchWebsocketpp.cmake new file mode 100644 index 0000000..4ee7d46 --- /dev/null +++ b/cmake/FetchWebsocketpp.cmake @@ -0,0 +1,24 @@ +include(FetchContent) + +FetchContent_Declare( + websocketpp + URL 
https://github.com/zaphoyd/websocketpp/archive/refs/tags/0.8.2.tar.gz + URL_HASH SHA256=6ce889d85ecdc2d8fa07408d6787e7352510750daa66b5ad44aacb47bea76755) + +# Only download the content, don't configure or build it +FetchContent_GetProperties(websocketpp) +if(NOT websocketpp_POPULATED) + FetchContent_Populate(websocketpp) +endif() + +# Add WebSocket++ as an interface library +add_library(websocketpp INTERFACE) +target_include_directories(websocketpp INTERFACE ${websocketpp_SOURCE_DIR}) + +# Fetch ASIO +FetchContent_Declare( + asio + URL https://github.com/chriskohlhoff/asio/archive/asio-1-28-0.tar.gz + URL_HASH SHA256=226438b0798099ad2a202563a83571ce06dd13b570d8fded4840dbc1f97fa328) + +FetchContent_MakeAvailable(websocketpp asio) diff --git a/cmake/linux/compilerconfig.cmake b/cmake/linux/compilerconfig.cmake index 647c4b3..8931ba3 100644 --- a/cmake/linux/compilerconfig.cmake +++ b/cmake/linux/compilerconfig.cmake @@ -21,6 +21,7 @@ set(_obs_gcc_c_options -Wformat-security -Wno-conversion -Wno-deprecated-declarations + -Wno-error=conversion -Wno-error=deprecated-declarations -Wno-float-conversion -Wno-implicit-fallthrough @@ -42,14 +43,13 @@ set(_obs_gcc_c_options -Wvla) # gcc options for C++ -set(_obs_gcc_cxx_options - # cmake-format: sortable - ${_obs_gcc_c_options} -Wconversion -Wfloat-conversion -Winvalid-offsetof -Wno-overloaded-virtual) +set(_obs_gcc_cxx_options # cmake-format: sortable + ${_obs_gcc_c_options} -Winvalid-offsetof -Wno-overloaded-virtual) add_compile_options( -fopenmp-simd "$<$:${_obs_gcc_c_options}>" - "$<$:-Wint-conversion;-Wno-missing-prototypes;-Wno-strict-prototypes;-Wpointer-sign>" + "$<$:-Wno-missing-prototypes;-Wno-strict-prototypes;-Wpointer-sign>" "$<$:${_obs_gcc_cxx_options}>" "$<$:${_obs_clang_c_options}>" "$<$:${_obs_clang_cxx_options}>") diff --git a/data/locale/en-US.ini b/data/locale/en-US.ini index 9ef4d18..338f880 100644 --- a/data/locale/en-US.ini +++ b/data/locale/en-US.ini @@ -87,4 +87,7 @@ Active_VAD="Active VAD" 
Hybrid_VAD="Hybrid VAD" translate_only_full_sentences="Translate only full sentences" duration_filter_threshold="Duration filter" -segment_duration="Segment duration" \ No newline at end of file +segment_duration="Segment duration" +stenographer_parameters="Stenographer Options" +stenographer_delay="Audio Delay" +translation_remove_punctuation_from_start="Remove punctuation from sentence start" \ No newline at end of file diff --git a/src/stenographer/stenographer-util.cpp b/src/stenographer/stenographer-util.cpp new file mode 100644 index 0000000..59278cb --- /dev/null +++ b/src/stenographer/stenographer-util.cpp @@ -0,0 +1,67 @@ + +#include "stenographer-util.h" +#include "transcription-filter-data.h" +#include "transcription-utils.h" + +#include + +#include +#include + +/** + * @brief Applies a simple delay to the audio data for stenographer mode. + * + * This function stores the incoming audio data in a buffer and processes it after a specified delay. + * The delayed audio data is then emitted, replacing the original audio data in the buffer. + * If the buffer does not yet contain enough data to satisfy the delay, the audio buffer is filled with silence. + * + * @param gf Pointer to the transcription filter data structure containing the delay buffer and configuration. + * @param audio Pointer to the audio data structure containing the audio frames to be processed. + * @return Pointer to the processed audio data structure with the applied delay. + */ +struct obs_audio_data *stenographer_simple_delay(transcription_filter_data *gf, + struct obs_audio_data *audio) +{ + // Stenographer mode - apply delay. + // Store the audio data in a buffer and process it after the delay. 
+ // push the data to the back of gf->stenographer_delay_buffer + for (size_t c = 0; c < gf->channels; c++) { + // take an audio->frames * sizeof(float) bytes chunk from audio->data[c] and push it + // to the back of the buffer as a float + std::vector<float> audio_data_chunk((float *)audio->data[c], + ((float *)audio->data[c]) + audio->frames); + gf->stenographer_delay_buffers[c].insert(gf->stenographer_delay_buffers[c].end(), + audio_data_chunk.begin(), + audio_data_chunk.end()); + } + + // If the buffer is larger than the delay, emit the oldest data + // Take from the buffer as much as requested by the incoming audio data + size_t delay_frames = + (size_t)((float)gf->sample_rate * (float)gf->stenographer_delay_ms / 1000.0f) + + audio->frames; + + if (gf->stenographer_delay_buffers[0].size() >= delay_frames) { + // Replace data on the audio buffer with the delayed data + for (size_t c = 0; c < gf->channels; c++) { + // take exactly audio->frames from the buffer + std::vector<float> audio_data(gf->stenographer_delay_buffers[c].begin(), + gf->stenographer_delay_buffers[c].begin() + + audio->frames); + // remove the oldest buffers from the delay buffer + gf->stenographer_delay_buffers[c].erase( + gf->stenographer_delay_buffers[c].begin(), + gf->stenographer_delay_buffers[c].begin() + audio->frames); + + // replace the data on the audio buffer with the delayed data + memcpy(audio->data[c], audio_data.data(), + audio_data.size() * sizeof(float)); + } + } else { + // Fill the audio buffer with silence + for (size_t c = 0; c < gf->channels; c++) { + memset(audio->data[c], 0, audio->frames * sizeof(float)); + } + } + return audio; +} diff --git a/src/stenographer/stenographer-util.h b/src/stenographer/stenographer-util.h new file mode 100644 index 0000000..3d0fd27 --- /dev/null +++ b/src/stenographer/stenographer-util.h @@ -0,0 +1,10 @@ +#ifndef STENOGRAPHER_UTIL_H +#define STENOGRAPHER_UTIL_H + +struct transcription_filter_data; +struct obs_audio_data; + +struct obs_audio_data 
*stenographer_simple_delay(transcription_filter_data *gf, + struct obs_audio_data *audio); + +#endif /* STENOGRAPHER_UTIL_H */ \ No newline at end of file diff --git a/src/stenographer/stenographer.cpp b/src/stenographer/stenographer.cpp new file mode 100644 index 0000000..95550eb --- /dev/null +++ b/src/stenographer/stenographer.cpp @@ -0,0 +1,221 @@ +#include + +#include "stenographer.h" +#include "plugin-support.h" +#include "whisper-utils/resample-utils.h" +#include "transcription-utils.h" + +#define ASIO_STANDALONE +#define _WEBSOCKETPP_CPP11_TYPE_TRAITS_ + +#include +#include +#include +#include +#include +#include +#include + +using json = nlohmann::json; +typedef websocketpp::server server; + +// WAV header structure +struct WAVHeader { + char riff[4] = {'R', 'I', 'F', 'F'}; + uint32_t overall_size; + char wave[4] = {'W', 'A', 'V', 'E'}; + char fmt_chunk_marker[4] = {'f', 'm', 't', ' '}; + uint32_t length_of_fmt = 16; + uint16_t format_type = 1; + uint16_t channels = 1; + uint32_t sample_rate = 16000; + uint32_t byterate; + uint16_t block_align; + uint16_t bits_per_sample = 16; + char data_chunk_header[4] = {'d', 'a', 't', 'a'}; + uint32_t data_size; +}; + +class TranscriptionHandler::Impl { +public: + using MessageCallback = TranscriptionHandler::MessageCallback; + + explicit Impl(transcription_filter_data *gf_, MessageCallback callback) + : gf(gf_), + messageCallback(callback), + running(false) + { + wsServer.init_asio(); + + wsServer.set_open_handler([this](websocketpp::connection_hdl hdl) { + std::lock_guard lock(mutex); + connection = hdl; + }); + + wsServer.set_message_handler( + [this](websocketpp::connection_hdl hdl, server::message_ptr msg) { + UNUSED_PARAMETER(hdl); + handleIncomingMessage(msg->get_payload()); + }); + + // Initialize WAV header + wavHeader.byterate = + wavHeader.sample_rate * wavHeader.channels * wavHeader.bits_per_sample / 8; + wavHeader.block_align = wavHeader.channels * wavHeader.bits_per_sample / 8; + } + + void start() + { + 
if (!running) { + running = true; + serverThread = std::async(std::launch::async, [this]() { + wsServer.listen(9002); + wsServer.start_accept(); + wsServer.run(); + }); + + processingThread = + std::async(std::launch::async, [this]() { processAudioQueue(); }); + } + } + + void stop() + { + if (running) { + running = false; + wsServer.stop(); + if (serverThread.valid()) + serverThread.wait(); + if (processingThread.valid()) + processingThread.wait(); + } + } + +private: + transcription_filter_data *gf; + server wsServer; + websocketpp::connection_hdl connection; + MessageCallback messageCallback; + std::queue> audioQueue; + std::mutex mutex; + std::atomic running; + std::future serverThread; + std::future processingThread; + + void handleIncomingMessage(const std::string &message) + { + try { + json j = json::parse(message); + std::string type = j["type"].get(); + std::string text = j["text"].get(); + + uint64_t start_timestamp = j["start_timestamp"].get(); + uint64_t end_timestamp = j["end_timestamp"].get(); + + messageCallback(type, text, start_timestamp, end_timestamp); + } catch (json::parse_error &e) { + obs_log(LOG_ERROR, "Failed to parse JSON message: %s", e.what()); + } catch (json::type_error &e) { + obs_log(LOG_ERROR, "Failed to parse JSON message: %s", e.what()); + } + } + + void processAudioQueue() + { + while (running) { + // get data from buffer and resample to 16kHz + uint64_t start_timestamp_offset_ns = 0; + uint64_t end_timestamp_offset_ns = 0; + + const int ret = get_data_from_buf_and_resample( + gf, start_timestamp_offset_ns, end_timestamp_offset_ns); + if (ret != 0) { + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + continue; + } + + std::vector audio_input; + audio_input.resize(gf->resampled_buffer.size / sizeof(float)); + circlebuf_pop_front(&gf->resampled_buffer, audio_input.data(), + audio_input.size() * sizeof(float)); + + std::vector pcmData(audio_input.size()); + for (size_t i = 0; i < audio_input.size(); ++i) { + pcmData[i] 
= static_cast(audio_input[i] * 32767.0f); + } + + if (!pcmData.empty()) { + json timestampInfo = {{"start_timestamp", + start_timestamp_offset_ns}, + {"end_timestamp", end_timestamp_offset_ns}}; + if (connection.lock()) { + wsServer.send(connection, timestampInfo.dump(), + websocketpp::frame::opcode::text); + } + sendAudioData(pcmData); + } else { + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + } + + if (!gf->cleared_last_sub) { + // check if we should clear the current sub depending on the minimum subtitle duration + uint64_t now = now_ms(); + if ((now - gf->last_sub_render_time) > gf->max_sub_duration) { + // clear the current sub, call the callback with an empty string + obs_log(gf->log_level, + "Clearing current subtitle. now: %lu ms, last: %lu ms", + now, gf->last_sub_render_time); + clear_current_caption(gf); + } + } + } + } + + WAVHeader wavHeader; + std::vector audioBuffer; + + void sendAudioData(const std::vector &audioData) + { + std::lock_guard lock(mutex); + if (connection.lock()) { + audioBuffer.insert(audioBuffer.end(), audioData.begin(), audioData.end()); + + // If we have accumulated enough data, send it as a WAV file + if (audioBuffer.size() >= 8000) { // 0.5 seconds of audio at 16kHz + wavHeader.data_size = audioBuffer.size() * sizeof(int16_t); + wavHeader.overall_size = + wavHeader.data_size + sizeof(WAVHeader) - 8; + + std::vector wavData(sizeof(WAVHeader) + + wavHeader.data_size); + std::memcpy(wavData.data(), &wavHeader, sizeof(WAVHeader)); + std::memcpy(wavData.data() + sizeof(WAVHeader), audioBuffer.data(), + wavHeader.data_size); + + wsServer.send(connection, wavData.data(), wavData.size(), + websocketpp::frame::opcode::binary); + + audioBuffer.clear(); + } + } + } +}; + +TranscriptionHandler::TranscriptionHandler(transcription_filter_data *gf_, MessageCallback callback) + : pimpl(std::make_unique(std::move(gf_), std::move(callback))) +{ +} + +TranscriptionHandler::~TranscriptionHandler() = default; + 
+TranscriptionHandler::TranscriptionHandler(TranscriptionHandler &&) noexcept = default; +TranscriptionHandler &TranscriptionHandler::operator=(TranscriptionHandler &&) noexcept = default; + +void TranscriptionHandler::start() +{ + pimpl->start(); +} +void TranscriptionHandler::stop() +{ + pimpl->stop(); +} diff --git a/src/stenographer/stenographer.h b/src/stenographer/stenographer.h new file mode 100644 index 0000000..cb7a26e --- /dev/null +++ b/src/stenographer/stenographer.h @@ -0,0 +1,33 @@ +#pragma once + +// Forward declaration +struct transcription_filter_data; + +#include +#include +#include + +class TranscriptionHandler { +public: + using MessageCallback = + std::function; + + explicit TranscriptionHandler(transcription_filter_data *gf_, MessageCallback callback); + ~TranscriptionHandler(); + + // Disable copy + TranscriptionHandler(const TranscriptionHandler &) = delete; + TranscriptionHandler &operator=(const TranscriptionHandler &) = delete; + + // Enable move + TranscriptionHandler(TranscriptionHandler &&) noexcept; + TranscriptionHandler &operator=(TranscriptionHandler &&) noexcept; + + void start(); + void stop(); + +private: + class Impl; + std::unique_ptr pimpl; +}; \ No newline at end of file diff --git a/src/stenographer/stenographer_interface.html b/src/stenographer/stenographer_interface.html new file mode 100644 index 0000000..c9f99b0 --- /dev/null +++ b/src/stenographer/stenographer_interface.html @@ -0,0 +1,232 @@ + + + + + + Stenographer Interface + + + +

Stenographer Interface

+
+ + +
+ +
Timestamp (s):
+ +
Connection Status: Disconnected
+
Audio Status: Not started
+ + + + + \ No newline at end of file diff --git a/src/tests/localvocal-offline-test.cpp b/src/tests/localvocal-offline-test.cpp index c4a9042..ba57a19 100644 --- a/src/tests/localvocal-offline-test.cpp +++ b/src/tests/localvocal-offline-test.cpp @@ -352,6 +352,21 @@ void set_text_callback(struct transcription_filter_data *gf, } }; +void clear_current_caption(transcription_filter_data *gf_) +{ + if (gf_->captions_monitor.isEnabled()) { + gf_->captions_monitor.clear(); + gf_->translation_monitor.clear(); + } + // reset translation context + gf_->last_text_for_translation = ""; + gf_->last_text_translation = ""; + gf_->translation_ctx.last_input_tokens.clear(); + gf_->translation_ctx.last_translation_tokens.clear(); + gf_->last_transcription_sentence.clear(); + gf_->cleared_last_sub = true; +} + void release_context(transcription_filter_data *gf) { obs_log(LOG_INFO, "destroy"); diff --git a/src/transcription-filter-callbacks.cpp b/src/transcription-filter-callbacks.cpp index 0095c54..0b9a8b2 100644 --- a/src/transcription-filter-callbacks.cpp +++ b/src/transcription-filter-callbacks.cpp @@ -239,8 +239,10 @@ void set_text_callback(struct transcription_filter_data *gf, gf->translate_only_full_sentences ? result.result == DETECTION_RESULT_SPEECH : true; // send the sentence to translation (if enabled) + std::string source_language = result.language.empty() ? gf->whisper_params.language + : result.language; std::string translated_sentence = - should_translate ? send_sentence_to_translation(str_copy, gf, result.language) : ""; + should_translate ? 
send_sentence_to_translation(str_copy, gf, source_language) : ""; if (gf->translate) { if (gf->translation_output == "none") { @@ -395,6 +397,7 @@ void reset_caption_state(transcription_filter_data *gf_) if (gf_->input_buffers[c].data != nullptr) { circlebuf_free(&gf_->input_buffers[c]); } + gf_->stenographer_delay_buffers[c].clear(); } if (gf_->info_buffer.data != nullptr) { circlebuf_free(&gf_->info_buffer); @@ -452,15 +455,16 @@ void enable_callback(void *data_, calldata_t *cd) { transcription_filter_data *gf_ = static_cast(data_); bool enable = calldata_bool(cd, "enabled"); + reset_caption_state(gf_); if (enable) { obs_log(gf_->log_level, "enable_callback: enable"); gf_->active = true; - reset_caption_state(gf_); - update_whisper_model(gf_); + if (!gf_->stenographer_enabled) { + update_whisper_model(gf_); + } } else { obs_log(gf_->log_level, "enable_callback: disable"); gf_->active = false; - reset_caption_state(gf_); shutdown_whisper_thread(gf_); } } diff --git a/src/transcription-filter-data.h b/src/transcription-filter-data.h index e8990be..996faa5 100644 --- a/src/transcription-filter-data.h +++ b/src/transcription-filter-data.h @@ -19,12 +19,13 @@ #include "whisper-utils/silero-vad-onnx.h" #include "whisper-utils/whisper-processing.h" #include "whisper-utils/token-buffer-thread.h" +#include "stenographer/stenographer.h" #define MAX_PREPROC_CHANNELS 10 struct transcription_filter_data { obs_source_t *context; // obs filter source (this filter) - size_t channels; // number of channels + size_t channels; // number of channels in the input uint32_t sample_rate; // input sample rate // How many input frames (in input sample rate) are needed for the next whisper frame size_t frames; @@ -128,6 +129,11 @@ struct transcription_filter_data { TokenBufferSegmentation buffered_output_output_type = TokenBufferSegmentation::SEGMENTATION_TOKEN; + bool stenographer_enabled = false; + TranscriptionHandler *transcription_handler = nullptr; + int stenographer_delay_ms = 1000; 
+ std::deque stenographer_delay_buffers[MAX_PREPROC_CHANNELS]; + // ctor transcription_filter_data() : whisper_buf_mutex(), whisper_ctx_mutex(), wshiper_thread_cv() { diff --git a/src/transcription-filter-properties.cpp b/src/transcription-filter-properties.cpp index 0f96e30..434c63c 100644 --- a/src/transcription-filter-properties.cpp +++ b/src/transcription-filter-properties.cpp @@ -259,6 +259,9 @@ void add_translation_group_properties(obs_properties_t *ppts) MT_("translation_max_input_length"), 1, 100, 5); obs_properties_add_int_slider(translation_group, "translation_no_repeat_ngram_size", MT_("translation_no_repeat_ngram_size"), 1, 10, 1); + // add remove_punctuation_from_start boolean + obs_properties_add_bool(translation_group, "translation_remove_punctuation_from_start", + MT_("translation_remove_punctuation_from_start")); } void add_file_output_group_properties(obs_properties_t *ppts) @@ -506,6 +509,18 @@ void add_general_group_properties(obs_properties_t *ppts) } } +void add_stenographer_group_properties(obs_properties_t *ppts) +{ + // add group for stenographer options + obs_properties_t *stenographer_group = obs_properties_create(); + obs_properties_add_group(ppts, "stenographer_group", MT_("stenographer_parameters"), + OBS_GROUP_CHECKABLE, stenographer_group); + + // add delay amount for partial transcription + obs_properties_add_int_slider(stenographer_group, "stenographer_delay", + MT_("stenographer_delay"), 0, 12000, 100); +} + void add_partial_group_properties(obs_properties_t *ppts) { // add a group for partial transcription @@ -546,6 +561,7 @@ obs_properties_t *transcription_filter_properties(void *data) add_advanced_group_properties(ppts, gf); add_logging_group_properties(ppts); add_partial_group_properties(ppts); + add_stenographer_group_properties(ppts); add_whisper_params_group_properties(ppts); // Add a informative text about the plugin @@ -596,6 +612,8 @@ void transcription_filter_defaults(obs_data_t *s) obs_data_set_default_double(s, 
"sentence_psum_accept_thresh", 0.4); obs_data_set_default_bool(s, "partial_group", true); obs_data_set_default_int(s, "partial_latency", 1100); + obs_data_set_default_bool(s, "stenographer_group", false); + obs_data_set_default_int(s, "stenographer_delay", 3000); // translation options obs_data_set_default_double(s, "translation_sampling_temperature", 0.1); @@ -604,6 +622,7 @@ void transcription_filter_defaults(obs_data_t *s) obs_data_set_default_int(s, "translation_max_decoding_length", 65); obs_data_set_default_int(s, "translation_no_repeat_ngram_size", 1); obs_data_set_default_int(s, "translation_max_input_length", 65); + obs_data_set_default_bool(s, "translation_remove_punctuation_from_start", false); // Whisper parameters obs_data_set_default_int(s, "whisper_sampling_method", WHISPER_SAMPLING_BEAM_SEARCH); diff --git a/src/transcription-filter.cpp b/src/transcription-filter.cpp index 9d4ebf5..88cbe28 100644 --- a/src/transcription-filter.cpp +++ b/src/transcription-filter.cpp @@ -31,6 +31,7 @@ #include "translation/translation-includes.h" #include "ui/filter-replace-dialog.h" #include "ui/filter-replace-utils.h" +#include "stenographer/stenographer-util.h" void set_source_signals(transcription_filter_data *gf, obs_source_t *parent_source) { @@ -80,7 +81,7 @@ struct obs_audio_data *transcription_filter_filter_audio(void *data, struct obs_ return audio; } - if (gf->whisper_context == nullptr) { + if (gf->whisper_context == nullptr && !gf->stenographer_enabled) { // Whisper not initialized, just pass through return audio; } @@ -103,6 +104,8 @@ struct obs_audio_data *transcription_filter_filter_audio(void *data, struct obs_ circlebuf_push_back(&gf->input_buffers[c], audio->data[c], audio->frames * sizeof(float)); } + obs_log(LOG_DEBUG, "currently %lu bytes in the audio input buffer", + gf->input_buffers[0].size); // push audio packet info (timestamp/frame count) to info circlebuf struct transcription_filter_audio_info info = {0}; info.frames = audio->frames; // 
number of frames in this packet @@ -112,6 +115,11 @@ struct obs_audio_data *transcription_filter_filter_audio(void *data, struct obs_ gf->wshiper_thread_cv.notify_one(); } + if (gf->stenographer_enabled) { + // Stenographer mode - apply delay. + return stenographer_simple_delay(gf, audio); + } + return audio; } @@ -164,6 +172,10 @@ void transcription_filter_destroy(void *data) if (gf->translation_monitor.isEnabled()) { gf->translation_monitor.stopThread(); } + if (gf->transcription_handler != nullptr) { + gf->transcription_handler->stop(); + delete gf->transcription_handler; + } bfree(gf); } @@ -293,6 +305,9 @@ void transcription_filter_update(void *data, obs_data_t *s) std::string new_translate_model_index = obs_data_get_string(s, "translate_model"); std::string new_translation_model_path_external = obs_data_get_string(s, "translation_model_path_external"); + gf->translation_ctx.remove_punctuation_from_start = + obs_data_get_bool(s, "translation_remove_punctuation_from_start"); + gf->translation_ctx.log_level = gf->log_level; if (new_translate) { if (new_translate != gf->translate || @@ -404,7 +419,11 @@ void transcription_filter_update(void *data, obs_data_t *s) } } - if (gf->context != nullptr && (obs_source_enabled(gf->context) || gf->initial_creation)) { + // check if stenographer is enabled + bool new_stenographer_enabled = obs_data_get_bool(s, "stenographer_group"); + + if (!new_stenographer_enabled && gf->context != nullptr && + (obs_source_enabled(gf->context) || gf->initial_creation)) { if (gf->initial_creation) { obs_log(LOG_INFO, "Initial filter creation and source enabled"); @@ -424,7 +443,38 @@ void transcription_filter_update(void *data, obs_data_t *s) } } } else { - obs_log(LOG_INFO, "Filter not enabled, not updating whisper model."); + obs_log(LOG_INFO, "Transcription not enabled, not updating whisper model."); + } + + if (new_stenographer_enabled != gf->stenographer_enabled) { + gf->stenographer_enabled = new_stenographer_enabled; + if 
(gf->stenographer_enabled) { + obs_log(gf->log_level, "Stenographer enabled"); + shutdown_whisper_thread(gf); // stop whisper + gf->stenographer_delay_ms = (int)obs_data_get_int(s, "stenographer_delay"); + gf->transcription_handler = new TranscriptionHandler( + gf, [gf](const std::string &type, const std::string &text, + uint64_t start_timestamp, uint64_t end_timestamp) { + DetectionResultWithText result; + result.text = text; + result.result = + (type == "partial") + ? DetectionResult::DETECTION_RESULT_PARTIAL + : DetectionResult::DETECTION_RESULT_SPEECH; + result.start_timestamp_ms = start_timestamp; + result.end_timestamp_ms = end_timestamp; + set_text_callback(gf, result); + }); + gf->transcription_handler->start(); + } else { + obs_log(gf->log_level, "Stenographer disabled, restarting whisper"); + if (gf->transcription_handler) { + gf->transcription_handler->stop(); + delete gf->transcription_handler; + gf->transcription_handler = nullptr; + } + update_whisper_model(gf); // restart whisper + } } } diff --git a/src/transcription-utils.cpp b/src/transcription-utils.cpp index 727d3df..48c284a 100644 --- a/src/transcription-utils.cpp +++ b/src/transcription-utils.cpp @@ -82,6 +82,36 @@ std::string fix_utf8(const std::string &str) #endif } +/** + * @brief Trims leading and trailing whitespace characters from the given string. + * + * This function removes any whitespace characters (spaces, tabs, newlines, etc.) + * from the beginning and end of the input string, returning a new string with + * the whitespace removed. + * + * @param str The input string to be trimmed. + * @return A new string with leading and trailing whitespace removed. 
+ */ +std::string trim(const std::string& str) { + std::string str_copy = str; + + // remove trailing spaces, newlines, tabs or punctuation + auto last_non_space = std::find_if(str_copy.rbegin(), str_copy.rend(), + [](unsigned char ch) { + return !std::isspace(ch) && !std::ispunct(ch); + }).base(); + str_copy.erase(last_non_space, str_copy.end()); + + // remove leading spaces, newlines, tabs or punctuation + auto first_non_space = std::find_if(str_copy.begin(), str_copy.end(), + [](unsigned char ch) { + return !std::isspace(ch) && !std::ispunct(ch); + }); + str_copy.erase(str_copy.begin(), first_non_space); + + return str_copy; +} + /* * Remove leading and trailing non-alphabetic characters from a string. * This function is used to remove leading and trailing spaces, newlines, tabs or punctuation. @@ -111,21 +141,7 @@ std::string remove_leading_trailing_nonalpha(const std::string &str) return ""; } } - std::string str_copy = str; - // remove trailing spaces, newlines, tabs or punctuation - auto last_non_space = - std::find_if(str_copy.rbegin(), str_copy.rend(), [](unsigned char ch) { - return !std::isspace(ch) || !std::ispunct(ch); - }).base(); - str_copy.erase(last_non_space, str_copy.end()); - // remove leading spaces, newlines, tabs or punctuation - auto first_non_space = std::find_if(str_copy.begin(), str_copy.end(), - [](unsigned char ch) { - return !std::isspace(ch) || !std::ispunct(ch); - }) + - 1; - str_copy.erase(str_copy.begin(), first_non_space); - return str_copy; + return trim(str); } std::vector split(const std::string &string, char delimiter) diff --git a/src/transcription-utils.h b/src/transcription-utils.h index 5fdd0cf..437cd6d 100644 --- a/src/transcription-utils.h +++ b/src/transcription-utils.h @@ -32,6 +32,34 @@ inline uint64_t now_ns() .count(); } +/** + * @brief Calculates the elapsed time in nanoseconds since a given start time. 
+ * + * This function takes a starting time in nanoseconds and returns the + * difference between the current time and the starting time. + * + * @param start_ns The starting time in nanoseconds. + * @return The elapsed time in nanoseconds since the start time. + */ +inline uint64_t ns_since(uint64_t start_ns) +{ + return now_ns() - start_ns; +} + +/** + * @brief Calculates the elapsed time in milliseconds since a given start time. + * + * This function takes a start time in milliseconds and returns the difference + * between the current time (in milliseconds) and the start time. + * + * @param start_ms The start time in milliseconds. + * @return The elapsed time in milliseconds since the start time. + */ +inline uint64_t ms_since(uint64_t start_ms) +{ + return now_ms() - start_ms; +} + // Split a string into words based on spaces std::vector split_words(const std::string &str_copy); diff --git a/src/translation/translation.cpp b/src/translation/translation.cpp index 0701d95..c495d0e 100644 --- a/src/translation/translation.cpp +++ b/src/translation/translation.cpp @@ -117,6 +117,9 @@ int translate(struct translation_context &translation_ctx, const std::string &te std::vector input_tokens = {source_lang, ""}; if (translation_ctx.add_context > 0 && translation_ctx.last_input_tokens.size() > 0) { + obs_log(translation_ctx.log_level, + "Adding last input tokens to input tokens, size: %d", + (int)translation_ctx.last_input_tokens.size()); // add the last input tokens sentences to the input tokens for (const auto &tokens : translation_ctx.last_input_tokens) { input_tokens.insert(input_tokens.end(), tokens.begin(), @@ -133,13 +136,24 @@ int translate(struct translation_context &translation_ctx, const std::string &te for (const auto &token : input_tokens) { input_tokens_str += token + ", "; } - obs_log(LOG_INFO, "Input tokens: %s", input_tokens_str.c_str()); - - translation_ctx.last_input_tokens.push_back(new_input_tokens); - // remove the oldest input tokens - while 
(translation_ctx.last_input_tokens.size() > - (size_t)translation_ctx.add_context) { - translation_ctx.last_input_tokens.pop_front(); + obs_log(translation_ctx.log_level, "Input tokens: %s", + input_tokens_str.c_str()); + + if (translation_ctx.add_context > 0) { + translation_ctx.last_input_tokens.push_back(new_input_tokens); + obs_log(translation_ctx.log_level, + "Adding last input context. Last input tokens deque size: %d", + (int)translation_ctx.last_input_tokens.size()); + // remove the oldest input tokens + while (translation_ctx.last_input_tokens.size() > + (size_t)translation_ctx.add_context) { + obs_log(translation_ctx.log_level, + "Removing oldest input tokens context, size: %d", + (int)translation_ctx.last_input_tokens.size()); + translation_ctx.last_input_tokens.pop_front(); + } + } else { + translation_ctx.last_input_tokens.clear(); } const std::vector> batch = {input_tokens}; @@ -149,6 +163,9 @@ int translate(struct translation_context &translation_ctx, const std::string &te // add the last translation tokens to the target prefix if (translation_ctx.add_context > 0 && translation_ctx.last_translation_tokens.size() > 0) { + obs_log(translation_ctx.log_level, + "Adding last translation tokens to target prefix, size: %d", + (int)translation_ctx.last_translation_tokens.size()); for (const auto &tokens : translation_ctx.last_translation_tokens) { target_prefix.insert(target_prefix.end(), tokens.begin(), tokens.end()); @@ -160,7 +177,8 @@ int translate(struct translation_context &translation_ctx, const std::string &te for (const auto &token : target_prefix) { target_prefix_str += token + ","; } - obs_log(LOG_INFO, "Target prefix: %s", target_prefix_str.c_str()); + obs_log(translation_ctx.log_level, "Target prefix: %s", + target_prefix_str.c_str()); const std::vector> target_prefix_batch = { target_prefix}; @@ -189,21 +207,33 @@ int translate(struct translation_context &translation_ctx, const std::string &te for (const auto &token : translation_tokens) { 
translation_tokens_str += token + ", "; } - obs_log(LOG_INFO, "Translation tokens: %s", translation_tokens_str.c_str()); - - // save the translation tokens - translation_ctx.last_translation_tokens.push_back(translation_tokens); - // remove the oldest translation tokens - while (translation_ctx.last_translation_tokens.size() > - (size_t)translation_ctx.add_context) { - translation_ctx.last_translation_tokens.pop_front(); + obs_log(translation_ctx.log_level, "Translation tokens: %s", + translation_tokens_str.c_str()); + + if (translation_ctx.add_context > 0) { + // save the translation tokens + translation_ctx.last_translation_tokens.push_back(translation_tokens); + // remove the oldest translation tokens + while (translation_ctx.last_translation_tokens.size() > + (size_t)translation_ctx.add_context) { + obs_log(translation_ctx.log_level, + "Removing oldest translation tokens context, size: %d", + (int)translation_ctx.last_translation_tokens.size()); + translation_ctx.last_translation_tokens.pop_front(); + } + obs_log(translation_ctx.log_level, "Last translation tokens deque size: %d", + (int)translation_ctx.last_translation_tokens.size()); + } else { + translation_ctx.last_translation_tokens.clear(); } - obs_log(LOG_INFO, "Last translation tokens deque size: %d", - (int)translation_ctx.last_translation_tokens.size()); // detokenize const std::string result_ = translation_ctx.detokenizer(translation_tokens); - result = remove_start_punctuation(result_); + if (translation_ctx.remove_punctuation_from_start) { + result = remove_start_punctuation(result_); + } else { + result = result_; + } } catch (std::exception &e) { obs_log(LOG_ERROR, "Error: %s", e.what()); return OBS_POLYGLOT_TRANSLATION_FAIL; diff --git a/src/translation/translation.h b/src/translation/translation.h index c740726..adbf4b2 100644 --- a/src/translation/translation.h +++ b/src/translation/translation.h @@ -31,6 +31,8 @@ struct translation_context { // How many sentences to use as context for the 
next translation int add_context; InputTokenizationStyle input_tokenization_style; + bool remove_punctuation_from_start; + int log_level = 400; }; int build_translation_context(struct translation_context &translation_ctx); diff --git a/src/whisper-utils/resample-utils.cpp b/src/whisper-utils/resample-utils.cpp new file mode 100644 index 0000000..2de5edc --- /dev/null +++ b/src/whisper-utils/resample-utils.cpp @@ -0,0 +1,97 @@ +#include + +#include "resample-utils.h" + +int get_data_from_buf_and_resample(transcription_filter_data *gf, + uint64_t &start_timestamp_offset_ns, + uint64_t &end_timestamp_offset_ns) +{ + uint32_t num_frames_from_infos = 0; + + { + // scoped lock the buffer mutex + std::lock_guard lock(gf->whisper_buf_mutex); + + if (gf->input_buffers[0].size == 0) { + return 1; + } + + obs_log(LOG_DEBUG, "segmentation: currently %lu bytes in the audio input buffer", + gf->input_buffers[0].size); + + // max number of frames is 10 seconds worth of audio + const size_t max_num_frames = gf->sample_rate * 10; + + // pop all infos from the info buffer and mark the beginning timestamp from the first + // info as the beginning timestamp of the segment + struct transcription_filter_audio_info info_from_buf = {0}; + const size_t size_of_audio_info = sizeof(transcription_filter_audio_info); + while (gf->info_buffer.size >= size_of_audio_info) { + circlebuf_pop_front(&gf->info_buffer, &info_from_buf, size_of_audio_info); + num_frames_from_infos += info_from_buf.frames; + if (start_timestamp_offset_ns == 0) { + start_timestamp_offset_ns = info_from_buf.timestamp_offset_ns; + } + // Check if we're within the needed segment length + if (num_frames_from_infos > max_num_frames) { + // too big, push the last info into the buffer's front where it was + num_frames_from_infos -= info_from_buf.frames; + circlebuf_push_front(&gf->info_buffer, &info_from_buf, + size_of_audio_info); + break; + } + } + // calculate the end timestamp from the info plus the number of frames in the 
packet + end_timestamp_offset_ns = info_from_buf.timestamp_offset_ns + + info_from_buf.frames * 1000000000 / gf->sample_rate; + + if (start_timestamp_offset_ns > end_timestamp_offset_ns) { + // this may happen when the incoming media has a timestamp reset + // in this case, we should figure out the start timestamp from the end timestamp + // and the number of frames + start_timestamp_offset_ns = + end_timestamp_offset_ns - + num_frames_from_infos * 1000000000 / gf->sample_rate; + } + + for (size_t c = 0; c < gf->channels; c++) { + // zero the rest of copy_buffers + memset(gf->copy_buffers[c], 0, gf->frames * sizeof(float)); + } + + /* Pop from input circlebuf */ + for (size_t c = 0; c < gf->channels; c++) { + // Push the new data to copy_buffers[c] + circlebuf_pop_front(&gf->input_buffers[c], gf->copy_buffers[c], + num_frames_from_infos * sizeof(float)); + } + } + + obs_log(LOG_DEBUG, "found %d frames from info buffer.", num_frames_from_infos); + gf->last_num_frames = num_frames_from_infos; + + { + // resample to 16kHz + float *resampled_16khz[MAX_PREPROC_CHANNELS]; + uint32_t resampled_16khz_frames; + uint64_t ts_offset; + { + ProfileScope("resample"); + audio_resampler_resample(gf->resampler_to_whisper, + (uint8_t **)resampled_16khz, + &resampled_16khz_frames, &ts_offset, + (const uint8_t **)gf->copy_buffers, + (uint32_t)num_frames_from_infos); + } + + circlebuf_push_back(&gf->resampled_buffer, resampled_16khz[0], + resampled_16khz_frames * sizeof(float)); + obs_log(LOG_DEBUG, + "resampled: %d channels, %d frames, %f ms, current size: %lu bytes", + (int)gf->channels, (int)resampled_16khz_frames, + (float)resampled_16khz_frames / WHISPER_SAMPLE_RATE * 1000.0f, + gf->resampled_buffer.size); + } + + return 0; +} diff --git a/src/whisper-utils/resample-utils.h b/src/whisper-utils/resample-utils.h new file mode 100644 index 0000000..c2d2872 --- /dev/null +++ b/src/whisper-utils/resample-utils.h @@ -0,0 +1,10 @@ +#ifndef RESAMPLE_UTILS_H +#define RESAMPLE_UTILS_H + 
+#include "transcription-filter-data.h" + +int get_data_from_buf_and_resample(transcription_filter_data *gf, + uint64_t &start_timestamp_offset_ns, + uint64_t &end_timestamp_offset_ns); + +#endif diff --git a/src/whisper-utils/vad-processing.cpp b/src/whisper-utils/vad-processing.cpp index 0e9c744..d0a9266 100644 --- a/src/whisper-utils/vad-processing.cpp +++ b/src/whisper-utils/vad-processing.cpp @@ -4,107 +4,13 @@ #include "transcription-filter-data.h" #include "vad-processing.h" +#include "resample-utils.h" #ifdef _WIN32 #define NOMINMAX #include #endif -int get_data_from_buf_and_resample(transcription_filter_data *gf, - uint64_t &start_timestamp_offset_ns, - uint64_t &end_timestamp_offset_ns) -{ - uint32_t num_frames_from_infos = 0; - - { - // scoped lock the buffer mutex - std::lock_guard lock(gf->whisper_buf_mutex); - - if (gf->input_buffers[0].size == 0) { - return 1; - } - - obs_log(gf->log_level, - "segmentation: currently %lu bytes in the audio input buffer", - gf->input_buffers[0].size); - - // max number of frames is 10 seconds worth of audio - const size_t max_num_frames = gf->sample_rate * 10; - - // pop all infos from the info buffer and mark the beginning timestamp from the first - // info as the beginning timestamp of the segment - struct transcription_filter_audio_info info_from_buf = {0}; - const size_t size_of_audio_info = sizeof(transcription_filter_audio_info); - while (gf->info_buffer.size >= size_of_audio_info) { - circlebuf_pop_front(&gf->info_buffer, &info_from_buf, size_of_audio_info); - num_frames_from_infos += info_from_buf.frames; - if (start_timestamp_offset_ns == 0) { - start_timestamp_offset_ns = info_from_buf.timestamp_offset_ns; - } - // Check if we're within the needed segment length - if (num_frames_from_infos > max_num_frames) { - // too big, push the last info into the buffer's front where it was - num_frames_from_infos -= info_from_buf.frames; - circlebuf_push_front(&gf->info_buffer, &info_from_buf, - size_of_audio_info); - 
break; - } - } - // calculate the end timestamp from the info plus the number of frames in the packet - end_timestamp_offset_ns = info_from_buf.timestamp_offset_ns + - info_from_buf.frames * 1000000000 / gf->sample_rate; - - if (start_timestamp_offset_ns > end_timestamp_offset_ns) { - // this may happen when the incoming media has a timestamp reset - // in this case, we should figure out the start timestamp from the end timestamp - // and the number of frames - start_timestamp_offset_ns = - end_timestamp_offset_ns - - num_frames_from_infos * 1000000000 / gf->sample_rate; - } - - for (size_t c = 0; c < gf->channels; c++) { - // zero the rest of copy_buffers - memset(gf->copy_buffers[c], 0, gf->frames * sizeof(float)); - } - - /* Pop from input circlebuf */ - for (size_t c = 0; c < gf->channels; c++) { - // Push the new data to copy_buffers[c] - circlebuf_pop_front(&gf->input_buffers[c], gf->copy_buffers[c], - num_frames_from_infos * sizeof(float)); - } - } - - obs_log(gf->log_level, "found %d frames from info buffer.", num_frames_from_infos); - gf->last_num_frames = num_frames_from_infos; - - { - // resample to 16kHz - float *resampled_16khz[MAX_PREPROC_CHANNELS]; - uint32_t resampled_16khz_frames; - uint64_t ts_offset; - { - ProfileScope("resample"); - audio_resampler_resample(gf->resampler_to_whisper, - (uint8_t **)resampled_16khz, - &resampled_16khz_frames, &ts_offset, - (const uint8_t **)gf->copy_buffers, - (uint32_t)num_frames_from_infos); - } - - circlebuf_push_back(&gf->resampled_buffer, resampled_16khz[0], - resampled_16khz_frames * sizeof(float)); - obs_log(gf->log_level, - "resampled: %d channels, %d frames, %f ms, current size: %lu bytes", - (int)gf->channels, (int)resampled_16khz_frames, - (float)resampled_16khz_frames / WHISPER_SAMPLE_RATE * 1000.0f, - gf->resampled_buffer.size); - } - - return 0; -} - vad_state vad_based_segmentation(transcription_filter_data *gf, vad_state last_vad_state) { // get data from buffer and resample