Skip to content

Commit

Permalink
fix whisper model loading language (#139)
Browse files Browse the repository at this point in the history
* refactor: Add boolean flag for whisper model loaded status

* refactor: Improve handling of whisper model paths in transcription filter

* refactor: Update whisper model path and add flag for model loaded status
  • Loading branch information
royshil authored Jul 17, 2024
1 parent 44f072b commit 4e3fdcd
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 25 deletions.
1 change: 1 addition & 0 deletions src/transcription-filter-data.h
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ struct transcription_filter_data {
// Output file path to write the subtitles
std::string output_file_path;
std::string whisper_model_file_currently_loaded;
bool whisper_model_loaded_new;

// Use std for thread and mutex
std::thread whisper_thread;
Expand Down
60 changes: 35 additions & 25 deletions src/transcription-filter-properties.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,14 +69,20 @@ bool file_output_select_changed(obs_properties_t *props, obs_property_t *propert
return true;
}

bool external_model_file_selection(obs_properties_t *props, obs_property_t *property,
bool external_model_file_selection(void *data_, obs_properties_t *props, obs_property_t *property,
obs_data_t *settings)
{
UNUSED_PARAMETER(property);
struct transcription_filter_data *gf_ =
static_cast<struct transcription_filter_data *>(data_);
// If the selected model is the external model, show the external model file selection
// input
const char *new_model_path = obs_data_get_string(settings, "whisper_model_path");
const bool is_external = strcmp(new_model_path, "!!!external!!!") == 0;
const char *new_model_path_cstr =
obs_data_get_string(settings, "whisper_model_path") != nullptr
? obs_data_get_string(settings, "whisper_model_path")
: "";
const std::string new_model_path = new_model_path_cstr;
const bool is_external = (new_model_path.find("!!!external!!!") != std::string::npos);
if (is_external) {
obs_property_set_visible(obs_properties_get(props, "whisper_model_path_external"),
true);
Expand All @@ -85,26 +91,29 @@ bool external_model_file_selection(obs_properties_t *props, obs_property_t *prop
false);
}

const std::string model_name = new_model_path;
// if the model is english-only -> hide all the languages but english
const bool is_english_only_internal = (model_name.find("English") != std::string::npos) &&
!is_external;
// clear the language selection list ("whisper_language_select")
obs_property_t *prop_lang = obs_properties_get(props, "whisper_language_select");
obs_property_list_clear(prop_lang);
if (is_english_only_internal) {
// add only the english language
obs_property_list_add_string(prop_lang, "English", "en");
// set the language to english
obs_data_set_string(settings, "whisper_language_select", "en");
} else {
// add all the languages
for (const auto &lang : whisper_available_lang) {
obs_property_list_add_string(prop_lang, lang.second.c_str(),
lang.first.c_str());
// check if this is a new model selection
if (gf_->whisper_model_loaded_new) {
// if the model is english-only -> hide all the languages but english
const bool is_english_only_internal =
(new_model_path.find("English") != std::string::npos) && !is_external;
// clear the language selection list ("whisper_language_select")
obs_property_t *prop_lang = obs_properties_get(props, "whisper_language_select");
obs_property_list_clear(prop_lang);
if (is_english_only_internal) {
// add only the english language
obs_property_list_add_string(prop_lang, "English", "en");
// set the language to english
obs_data_set_string(settings, "whisper_language_select", "en");
} else {
// add all the languages
for (const auto &lang : whisper_available_lang) {
obs_property_list_add_string(prop_lang, lang.second.c_str(),
lang.first.c_str());
}
// set the language to auto (default)
obs_data_set_string(settings, "whisper_language_select", "auto");
}
// set the language to auto (default)
obs_data_set_string(settings, "whisper_language_select", "auto");
gf_->whisper_model_loaded_new = false;
}
return true;
}
Expand All @@ -131,7 +140,8 @@ bool translation_external_model_selection(obs_properties_t *props, obs_property_
return true;
}

void add_transcription_group_properties(obs_properties_t *ppts)
void add_transcription_group_properties(obs_properties_t *ppts,
struct transcription_filter_data *gf)
{
// add "Transcription" group
obs_properties_t *transcription_group = obs_properties_create();
Expand Down Expand Up @@ -159,7 +169,7 @@ void add_transcription_group_properties(obs_properties_t *ppts)
obs_property_set_visible(obs_properties_get(ppts, "whisper_model_path_external"), false);

// Add a callback to the model list to handle the external model file selection
obs_property_set_modified_callback(whisper_models_list, external_model_file_selection);
obs_property_set_modified_callback2(whisper_models_list, external_model_file_selection, gf);
}

void add_translation_group_properties(obs_properties_t *ppts)
Expand Down Expand Up @@ -474,7 +484,7 @@ obs_properties_t *transcription_filter_properties(void *data)
obs_property_set_modified_callback(advanced_settings, advanced_settings_callback);

add_general_group_properties(ppts);
add_transcription_group_properties(ppts);
add_transcription_group_properties(ppts, gf);
add_translation_group_properties(ppts);
add_file_output_group_properties(ppts);
add_buffered_output_group_properties(ppts);
Expand Down
2 changes: 2 additions & 0 deletions src/whisper-utils/whisper-model-utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ void update_whisper_model(struct transcription_filter_data *gf)
// model path changed
obs_log(gf->log_level, "model path changed from %s to %s",
gf->whisper_model_path.c_str(), new_model_path.c_str());

gf->whisper_model_loaded_new = true;
}

// check if the new model is external file
Expand Down

0 comments on commit 4e3fdcd

Please sign in to comment.