Skip to content

Commit

Permalink
refactor: Add target SPM loading and decoding logic in translation module (#149)
Browse files: browse the repository at this point in the history

* refactor: Add target SPM loading and decoding logic in translation module

* refactor: Update target SPM loading error handling in translation module
  • Loading branch information
royshil authored Aug 3, 2024
1 parent 0592fa7 commit 7707af0
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 2 deletions.
26 changes: 24 additions & 2 deletions src/translation/translation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@ int build_translation_context(struct translation_context &translation_ctx)
obs_log(LOG_INFO, "Building translation context from '%s'...", local_model_path.c_str());
// find the SPM file in the model folder
std::string local_spm_path = find_file_in_folder_by_regex_expression(
local_model_path, "(sentencepiece|spm|spiece).*?\\.model");
local_model_path, "(sentencepiece|spm|spiece|source).*?\\.(model|spm)");
std::string target_spm_path =
find_file_in_folder_by_regex_expression(local_model_path, "target.*?\\.spm");

try {
obs_log(LOG_INFO, "Loading SPM from %s", local_spm_path.c_str());
Expand All @@ -42,6 +44,22 @@ int build_translation_context(struct translation_context &translation_ctx)
return OBS_POLYGLOT_TRANSLATION_INIT_FAIL;
}

// Load a dedicated target-language SentencePiece model when one was found in
// the model folder; otherwise leave target_processor empty so the detokenizer
// falls back to the source processor.
if (!target_spm_path.empty()) {
	obs_log(LOG_INFO, "Loading target SPM from %s", target_spm_path.c_str());
	// reset(new ...) destroys any previously held processor before taking
	// ownership of the fresh one.
	translation_ctx.target_processor.reset(
		new sentencepiece::SentencePieceProcessor());
	const auto target_status =
		translation_ctx.target_processor->Load(target_spm_path);
	if (!target_status.ok()) {
		obs_log(LOG_ERROR, "Failed to load target SPM: %s",
			target_status.ToString().c_str());
		return OBS_POLYGLOT_TRANSLATION_INIT_FAIL;
	}
} else {
	obs_log(LOG_INFO, "Target SPM not found, using source SPM for target");
	// Use reset(), not release(): release() only relinquishes ownership and
	// returns the raw pointer, so a processor loaded by an earlier build of
	// this context would be leaked. reset() destroys it and leaves the
	// pointer null.
	translation_ctx.target_processor.reset();
}

translation_ctx.tokenizer = [&translation_ctx](const std::string &text) {
std::vector<std::string> tokens;
translation_ctx.processor->Encode(text, &tokens);
Expand All @@ -50,7 +68,11 @@ int build_translation_context(struct translation_context &translation_ctx)
translation_ctx.detokenizer =
[&translation_ctx](const std::vector<std::string> &tokens) {
std::string text;
translation_ctx.processor->Decode(tokens, &text);
if (translation_ctx.target_processor) {
translation_ctx.target_processor->Decode(tokens, &text);
} else {
translation_ctx.processor->Decode(tokens, &text);
}
return std::regex_replace(text, std::regex("<unk>"), "UNK");
};

Expand Down
1 change: 1 addition & 0 deletions src/translation/translation.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class SentencePieceProcessor;
struct translation_context {
std::string local_model_folder_path;
std::unique_ptr<sentencepiece::SentencePieceProcessor> processor;
std::unique_ptr<sentencepiece::SentencePieceProcessor> target_processor;
std::unique_ptr<ctranslate2::Translator> translator;
std::unique_ptr<ctranslate2::TranslationOptions> options;
std::function<std::vector<std::string>(const std::string &)> tokenizer;
Expand Down

0 comments on commit 7707af0

Please sign in to comment.