Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: Update whisper model path handling in transcription filter #117

Merged
merged 2 commits into from
Jun 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/model-utils/model-downloader-ui.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,9 @@ ModelDownloader::~ModelDownloader()
}
delete this->download_thread;
}
delete this->download_worker;
if (this->download_worker != nullptr) {
delete this->download_worker;
}
}

ModelDownloadWorker::~ModelDownloadWorker()
Expand Down
5 changes: 1 addition & 4 deletions src/transcription-filter-callbacks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -300,10 +300,7 @@ void enable_callback(void *data_, calldata_t *cd)
obs_log(gf_->log_level, "enable_callback: enable");
gf_->active = true;
reset_caption_state(gf_);
// get filter settings from gf_->context
obs_data_t *settings = obs_source_get_settings(gf_->context);
update_whisper_model(gf_, settings);
obs_data_release(settings);
update_whisper_model(gf_);
} else {
obs_log(gf_->log_level, "enable_callback: disable");
gf_->active = false;
Expand Down
28 changes: 17 additions & 11 deletions src/transcription-filter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ void transcription_filter_update(void *data, obs_data_t *s)
struct transcription_filter_data *gf =
static_cast<struct transcription_filter_data *>(data);

gf->log_level = LOG_INFO; //(int)obs_data_get_int(s, "log_level");
gf->log_level = (int)obs_data_get_int(s, "log_level");
gf->vad_enabled = obs_data_get_bool(s, "vad_enabled");
gf->log_words = obs_data_get_bool(s, "log_words");
gf->caption_to_stream = obs_data_get_bool(s, "caption_to_stream");
Expand Down Expand Up @@ -216,9 +216,9 @@ void transcription_filter_update(void *data, obs_data_t *s)
}
}
} else {
obs_log(LOG_INFO, "buffered_output disable");
obs_log(gf->log_level, "buffered_output disable");
if (gf->buffered_output) {
obs_log(LOG_INFO, "buffered_output currently enabled, disabling");
obs_log(gf->log_level, "buffered_output currently enabled, disabling");
if (gf->captions_monitor.isEnabled()) {
gf->captions_monitor.clear();
gf->captions_monitor.stopThread();
Expand Down Expand Up @@ -312,8 +312,12 @@ void transcription_filter_update(void *data, obs_data_t *s)
obs_data_get_string(s, "whisper_language_select");
} else {
// take the language from gf->target_lang
gf->whisper_params.language =
language_codes_2_reverse[gf->target_lang].c_str();
if (language_codes_2_reverse.count(gf->target_lang) > 0) {
gf->whisper_params.language =
language_codes_2_reverse[gf->target_lang].c_str();
} else {
gf->whisper_params.language = "auto";
}
}
gf->whisper_params.initial_prompt = obs_data_get_string(s, "initial_prompt");
gf->whisper_params.n_threads = (int)obs_data_get_int(s, "n_threads");
Expand Down Expand Up @@ -345,11 +349,11 @@ void transcription_filter_update(void *data, obs_data_t *s)
}
}

if (gf->initial_creation && obs_source_enabled(gf->context)) {
if (gf->initial_creation && gf->context != nullptr && obs_source_enabled(gf->context)) {
obs_log(LOG_INFO, "Initial filter creation and source enabled");

// source was enabled on creation
obs_data_t *settings = obs_source_get_settings(gf->context);
update_whisper_model(gf, settings);
obs_data_release(settings);
update_whisper_model(gf);
gf->active = true;
gf->initial_creation = false;
}
Expand Down Expand Up @@ -413,7 +417,7 @@ void *transcription_filter_create(obs_data_t *settings, obs_source_t *filter)
const char *subtitle_sources = obs_data_get_string(settings, "subtitle_sources");
if (subtitle_sources == nullptr || strcmp(subtitle_sources, "none") == 0 ||
strcmp(subtitle_sources, "(null)") == 0 || strlen(subtitle_sources) == 0) {
obs_log(LOG_INFO, "create text source");
obs_log(gf->log_level, "Create text source");
// check if a source called "LocalVocal Subtitles" exists
obs_source_t *source = obs_get_source_by_name("LocalVocal Subtitles");
if (source) {
Expand Down Expand Up @@ -497,7 +501,7 @@ void transcription_filter_hide(void *data)

void transcription_filter_defaults(obs_data_t *s)
{
obs_log(LOG_INFO, "filter defaults");
obs_log(LOG_DEBUG, "filter defaults");

obs_data_set_default_bool(s, "buffered_output", false);
obs_data_set_default_int(s, "buffer_num_lines", 2);
Expand Down Expand Up @@ -658,6 +662,8 @@ obs_properties_t *transcription_filter_properties(void *data)
obs_property_list_add_string(prop_lang, lang.second.c_str(),
lang.first.c_str());
}
// set the language to auto (default)
obs_data_set_string(settings, "whisper_language_select", "auto");
}
}
return true;
Expand Down
28 changes: 21 additions & 7 deletions src/whisper-utils/whisper-model-utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,28 @@
#include "plugin-support.h"
#include "model-utils/model-downloader.h"

void update_whisper_model(struct transcription_filter_data *gf, obs_data_t *s)
void update_whisper_model(struct transcription_filter_data *gf)
{
// update the whisper model path
if (gf->context == nullptr) {
obs_log(LOG_ERROR, "obs_source_t context is null");
return;
}

obs_data_t *s = obs_source_get_settings(gf->context);
if (s == nullptr) {
obs_log(LOG_ERROR, "obs_data_t settings is null");
return;
}

// Get settings from context
std::string new_model_path = obs_data_get_string(s, "whisper_model_path");
std::string external_model_file_path =
obs_data_get_string(s, "whisper_model_path_external");
const bool new_dtw_timestamps = obs_data_get_bool(s, "dtw_token_timestamps");
obs_data_release(s);

// update the whisper model path

const bool is_external_model = new_model_path.find("!!!external!!!") != std::string::npos;

char *silero_vad_model_file = obs_module_file("models/silero-vad/silero_vad.onnx");
Expand Down Expand Up @@ -73,8 +91,6 @@ void update_whisper_model(struct transcription_filter_data *gf, obs_data_t *s)
}
} else {
// new model is external file, get file location from file property
std::string external_model_file_path =
obs_data_get_string(s, "whisper_model_path_external");
if (external_model_file_path.empty()) {
obs_log(LOG_WARNING, "External model file path is empty");
} else {
Expand All @@ -98,13 +114,11 @@ void update_whisper_model(struct transcription_filter_data *gf, obs_data_t *s)
gf->whisper_model_path.c_str(), new_model_path.c_str());
}

const bool new_dtw_timestamps = obs_data_get_bool(s, "dtw_token_timestamps");

if (new_dtw_timestamps != gf->enable_token_ts_dtw) {
// dtw_token_timestamps changed
obs_log(gf->log_level, "dtw_token_timestamps changed from %d to %d",
gf->enable_token_ts_dtw, new_dtw_timestamps);
gf->enable_token_ts_dtw = obs_data_get_bool(s, "dtw_token_timestamps");
gf->enable_token_ts_dtw = new_dtw_timestamps;
shutdown_whisper_thread(gf);
start_whisper_thread_with_path(gf, gf->whisper_model_path,
silero_vad_model_file_str.c_str());
Expand Down
2 changes: 1 addition & 1 deletion src/whisper-utils/whisper-model-utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,6 @@

#include "transcription-filter-data.h"

void update_whisper_model(struct transcription_filter_data *gf, obs_data_t *s);
void update_whisper_model(struct transcription_filter_data *gf);

#endif // WHISPER_MODEL_UTILS_H
4 changes: 2 additions & 2 deletions src/whisper-utils/whisper-processing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -476,7 +476,7 @@ void whisper_loop(void *data)
struct transcription_filter_data *gf =
static_cast<struct transcription_filter_data *>(data);

obs_log(LOG_INFO, "starting whisper thread");
obs_log(gf->log_level, "Starting whisper thread");

vad_state current_vad_state = {false, 0, 0};
// 500 ms worth of audio is needed for VAD segmentation
Expand Down Expand Up @@ -511,5 +511,5 @@ void whisper_loop(void *data)
gf->wshiper_thread_cv.wait_for(lock, std::chrono::milliseconds(50));
}

obs_log(LOG_INFO, "exiting whisper thread");
obs_log(gf->log_level, "Exiting whisper thread");
}
Loading