diff --git a/CMakeLists.txt b/CMakeLists.txt
index e64f45c..9233158 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -123,6 +123,7 @@ target_sources(
   src/translation/language_codes.cpp
   src/translation/translation.cpp
   src/translation/translation-utils.cpp
+  src/ui/filter-replace-utils.cpp
   src/translation/translation-language-utils.cpp
   src/ui/filter-replace-dialog.cpp)
 
@@ -147,6 +148,7 @@ if(ENABLE_TESTS)
     src/whisper-utils/vad-processing.cpp
     src/translation/language_codes.cpp
     src/translation/translation.cpp
+    src/ui/filter-replace-utils.cpp
     src/translation/translation-language-utils.cpp)
 
   find_libav(${CMAKE_PROJECT_NAME}-tests)
diff --git a/src/tests/README.md b/src/tests/README.md
index e47f730..ed9e52f 100644
--- a/src/tests/README.md
+++ b/src/tests/README.md
@@ -112,7 +112,7 @@ The JSON config file can look e.g. like
   "silero_vad_model_file": ".../obs-localvocal/data/models/silero-vad/silero_vad.onnx",
   "ct2_model_folder": ".../obs-localvocal/models/m2m-100-418M",
   "fix_utf8": true,
-  "suppress_sentences": "다음 영상에서 만나요!\nMBC 뉴스 김지경입니다\nMBC 뉴스 김성현입니다\n구독과 좋아요 눌러주세요!\n구독과 좋아요는 저에게 아주 큰\n다음 영상에서 만나요\n끝까지 시청해주셔서 감사합니다\n구독과 좋아요 부탁드립니다!\nMBC 뉴스 이준범입니다\nMBC 뉴스 문재인입니다\nMBC 뉴스 김지연입니다\nMBC 뉴스 안영백입니다.\nMBC 뉴스 이덕영입니다\nMBC 뉴스 김상현입니다\n구독과 좋아요 눌러주세요!\n구독과 좋아요 부탁드",
+  "filter_words_replace": "[{\"key\": \"다음 영상에서 만나요!\", \"value\":\"\"}]",
   "overlap_ms": 150,
   "log_level": "debug",
   "whisper_sampling_method": 0
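Note that `filter_words_replace` replaces the old newline-separated `suppress_sentences` string with a JSON array of key/value pairs, and that array is itself stored as a string inside the (also JSON) test config, so the inner quotes must be escaped. A minimal sketch of producing a well-formed value with Python's standard `json` module (the sample phrase is the one from the README hunk above; everything else is illustrative):

```python
import json

# Each entry replaces occurrences of "key" with "value";
# an empty "value" simply removes the phrase.
pairs = [{"key": "다음 영상에서 만나요!", "value": ""}]

# Inner dump: the string stored under "filter_words_replace".
inner = json.dumps(pairs, ensure_ascii=False)

# Outer dump: embedding that string in the config escapes the inner
# quotes, yielding the exact form shown in the README diff above.
print(json.dumps({"filter_words_replace": inner}, ensure_ascii=False, indent=2))
```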
diff --git a/src/tests/evaluate_output.py b/src/tests/evaluate_output.py
index 873ce25..e4ae59e 100644
--- a/src/tests/evaluate_output.py
+++ b/src/tests/evaluate_output.py
@@ -1,33 +1,80 @@
 import Levenshtein
 import argparse
-from diff_match_patch import diff_match_patch
+import unicodedata
+import re
+import difflib
 
-def visualize_differences(ref_text, hyp_text):
-    dmp = diff_match_patch()
-    diffs = dmp.diff_main(hyp_text, ref_text, checklines=True)
-    html = dmp.diff_prettyHtml(diffs)
-    return html
+def remove_accents(text):
+    return ''.join(c for c in unicodedata.normalize('NFD', text)
+                   if unicodedata.category(c) != 'Mn')
 
-def calculate_wer(ref_text, hyp_text):
-    ref_words = ref_text.split()
-    hyp_words = hyp_text.split()
+def clean_text(text):
+    # Remove punctuation and special characters
+    text = re.sub(r'[^\w\s]', '', text)
+    # Collapse extra whitespace
+    text = re.sub(r'\s+', ' ', text).strip()
+    return text
 
-    distance = Levenshtein.distance(ref_words, hyp_words)
-    wer = distance / len(ref_words)
+def normalize_spanish_gender_postfixes(text):
+    # Collapse gendered "-a" word endings to the neutral "-e" form
+    text = re.sub(r'\b(\w+?)a\b', r'\1e', text)
+    return text
+
+def tokenize(text, should_remove_accents=False, remove_punctuation=False):
+    # Optionally strip accents and punctuation, then lowercase and split
+    if should_remove_accents:
+        text = remove_accents(text)
+    text = normalize_spanish_gender_postfixes(text)
+    if remove_punctuation:
+        text = clean_text(text)
+    tokens = text.lower().split()
+    return tokens
+
+def calculate_wer(ref_text_tokens, hyp_text_tokens):
+    # Token-level edit distance, normalized by the longer sequence
+    # so the score stays in [0, 1]
+    distance = Levenshtein.distance(ref_text_tokens, hyp_text_tokens, weights=(1, 1, 1))
+    wer = distance / max(len(ref_text_tokens), len(hyp_text_tokens))
     return wer
 
-def calculate_cer(ref_text, hyp_text):
+def calculate_cer(ref_text_tokens, hyp_text_tokens):
+    # Join tokens into a single string
+    ref_text = ' '.join(ref_text_tokens)
+    hyp_text = ' '.join(hyp_text_tokens)
     distance = Levenshtein.distance(ref_text, hyp_text)
     cer = distance / len(ref_text)
     return cer
 
-def compare_tokens(ref_tokens, hyp_tokens):
-    comparisons = []
-    for ref_token, hyp_token in zip(ref_tokens, hyp_tokens):
-        distance = Levenshtein.distance(ref_token, hyp_token)
-        comparison = {'ref_token': ref_token, 'hyp_token': hyp_token, 'error_rate': distance / len(ref_token)}
-        comparisons.append(comparison)
-    return comparisons
+def print_alignment(ref_words, hyp_words):
+    d = difflib.Differ()
+    diff = list(d.compare(ref_words, hyp_words))
+
+    print("\nToken-by-token alignment:")
+    print("Reference | Hypothesis")
+    print("-" * 30)
+
+    ref_token = hyp_token = ""
+    for token in diff:
+        if token.startswith("  "):  # Common token
+            if ref_token or hyp_token:
+                print(f"{ref_token:<10} | {hyp_token:<10}")
+                ref_token = hyp_token = ""
+            print(f"{token[2:]:<10} | {token[2:]:<10}")
+        elif token.startswith("- "):  # Token in reference, not in hypothesis
+            ref_token = token[2:]
+        elif token.startswith("+ "):  # Token in hypothesis, not in reference
+            hyp_token = token[2:]
+            if ref_token:
+                print(f"{ref_token:<10} | {hyp_token:<10} (Substitution)")
+                ref_token = hyp_token = ""
+            else:
+                print(f"{'':10} | {hyp_token:<10} (Insertion)")
+                hyp_token = ""
+
+    # Print any remaining tokens
+    if ref_token:
+        print(f"{ref_token:<10} | {'':10} (Deletion)")
+    elif hyp_token:
+        print(f"{'':10} | {hyp_token:<10} (Insertion)")
 
 def read_text_from_file(file_path, join_sentences=True):
     with open(file_path, 'r', encoding='utf-8', errors='ignore') as file:
@@ -41,28 +88,27 @@
 parser = argparse.ArgumentParser(description='Evaluate output')
 parser.add_argument('ref_file_path', type=str, help='Path to the reference file')
 parser.add_argument('hyp_file_path', type=str, help='Path to the hypothesis file')
+parser.add_argument('--remove_accents', action='store_true', help='Remove accents from text')
+parser.add_argument('--remove_punctuation', action='store_true', help='Remove punctuation from text')
+parser.add_argument('--print_alignment', action='store_true', help='Print the alignment to the console')
+parser.add_argument('--write_tokens', action='store_true', help='Write the tokens to a file')
 args = parser.parse_args()
 
-ref_text = read_text_from_file(args.ref_file_path)
-hyp_text = read_text_from_file(args.hyp_file_path)
-wer = calculate_wer(ref_text, hyp_text)
-cer = calculate_cer(ref_text, hyp_text)
-print("Word Error Rate (WER):", wer)
-print("Character Error Rate (CER):", cer)
-
-ref_text = '\n'.join(read_text_from_file(args.ref_file_path, join_sentences=False))
-hyp_text = '\n'.join(read_text_from_file(args.hyp_file_path, join_sentences=False))
-html_diff = visualize_differences(ref_text, hyp_text)
-with open("diff_visualization.html", "w", encoding="utf-8") as file:
-    file.write(html_diff)
+ref_text = read_text_from_file(args.ref_file_path, join_sentences=True)
+hyp_text = read_text_from_file(args.hyp_file_path, join_sentences=True)
+ref_tokens = tokenize(ref_text, should_remove_accents=args.remove_accents, remove_punctuation=args.remove_punctuation)
+hyp_tokens = tokenize(hyp_text, should_remove_accents=args.remove_accents, remove_punctuation=args.remove_punctuation)
 
-from Bio.Align import PairwiseAligner
+if args.print_alignment:
+    print_alignment(ref_tokens, hyp_tokens)
 
-aligner = PairwiseAligner()
+if args.write_tokens:
+    with open("ref_tokens.txt", "w", encoding="utf-8") as file:
+        file.write('\n'.join(ref_tokens))
+    with open("hyp_tokens.txt", "w", encoding="utf-8") as file:
+        file.write('\n'.join(hyp_tokens))
 
-alignments = aligner.align(ref_text, hyp_text)
+wer = calculate_wer(ref_tokens, hyp_tokens)
 
-# write the first alignment to a file
-with open("alignment.txt", "w", encoding="utf-8") as file:
-    file.write(alignments[0].format())
+print(f'"{args.ref_file_path}" WER: {wer:.2}')
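The rewritten script computes a token-level WER and, unlike the classic definition, normalizes by the longer of the two token sequences so the score stays within [0, 1]. A toy sketch of the same computation (the sentences are made up; `Levenshtein` is the package the script already imports):

```python
import Levenshtein

ref_tokens = "the quick brown fox".split()
hyp_tokens = "the quick brown dog jumps".split()

# Unit-cost edit distance over tokens, as in calculate_wer() above.
distance = Levenshtein.distance(ref_tokens, hyp_tokens, weights=(1, 1, 1))
wer = distance / max(len(ref_tokens), len(hyp_tokens))
print(f"WER: {wer:.2}")  # 0.4: one substitution plus one insertion over 5 tokens
```

A typical invocation would be `python evaluate_output.py ref.txt hyp.txt --remove_punctuation --print_alignment` (the file names here are hypothetical).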
"w", encoding="utf-8") as file: - file.write(alignments[0].format()) +print(f"\"{args.ref_file_path}\" WER: \"{wer:.2}\"") diff --git a/src/tests/localvocal-offline-test.cpp b/src/tests/localvocal-offline-test.cpp index ee936af..3c0f4a4 100644 --- a/src/tests/localvocal-offline-test.cpp +++ b/src/tests/localvocal-offline-test.cpp @@ -20,6 +20,7 @@ #include "whisper-utils/vad-processing.h" #include "audio-file-utils.h" #include "translation/language_codes.h" +#include "ui/filter-replace-utils.h" #include #include @@ -430,6 +431,12 @@ int wmain(int argc, wchar_t *argv[]) config["no_context"] ? "true" : "false"); gf->whisper_params.no_context = config["no_context"]; } + if (config.contains("filter_words_replace")) { + obs_log(LOG_INFO, "Setting filter_words_replace to %s", + config["filter_words_replace"]); + gf->filter_words_replace = deserialize_filter_words_replace( + config["filter_words_replace"]); + } // set log level if (logLevelStr == "debug") { gf->log_level = LOG_DEBUG; diff --git a/src/transcription-filter-properties.cpp b/src/transcription-filter-properties.cpp index 4a3693f..815b9e3 100644 --- a/src/transcription-filter-properties.cpp +++ b/src/transcription-filter-properties.cpp @@ -11,6 +11,7 @@ #include "model-utils/model-downloader-types.h" #include "translation/language_codes.h" #include "ui/filter-replace-dialog.h" +#include "ui/filter-replace-utils.h" #include #include diff --git a/src/transcription-filter.cpp b/src/transcription-filter.cpp index 657fea6..90e67eb 100644 --- a/src/transcription-filter.cpp +++ b/src/transcription-filter.cpp @@ -30,6 +30,7 @@ #include "translation/translation.h" #include "translation/translation-includes.h" #include "ui/filter-replace-dialog.h" +#include "ui/filter-replace-utils.h" void set_source_signals(transcription_filter_data *gf, obs_source_t *parent_source) { diff --git a/src/ui/filter-replace-dialog.cpp b/src/ui/filter-replace-dialog.cpp index 1464082..b491a31 100644 --- a/src/ui/filter-replace-dialog.cpp +++ b/src/ui/filter-replace-dialog.cpp @@ -73,32 +73,3 @@ void FilterReplaceDialog::editFilter(QTableWidgetItem *item) // use the row number to update the filter_words_replace map ctx->filter_words_replace[item->row()] = std::make_tuple(key, value); } - -std::string serialize_filter_words_replace( - const std::vector> &filter_words_replace) -{ - if (filter_words_replace.empty()) { - return "[]"; - } - // use JSON to serialize the filter_words_replace map - nlohmann::json j; - for (const auto &entry : filter_words_replace) { - j.push_back({{"key", std::get<0>(entry)}, {"value", std::get<1>(entry)}}); - } - return j.dump(); -} - -std::vector> -deserialize_filter_words_replace(const std::string &filter_words_replace_str) -{ - if (filter_words_replace_str.empty()) { - return {}; - } - // use JSON to deserialize the filter_words_replace map - std::vector> filter_words_replace; - nlohmann::json j = nlohmann::json::parse(filter_words_replace_str); - for (const auto &entry : j) { - filter_words_replace.push_back(std::make_tuple(entry["key"], entry["value"])); - } - return filter_words_replace; -} diff --git a/src/ui/filter-replace-dialog.h b/src/ui/filter-replace-dialog.h index c605531..d392a80 100644 --- a/src/ui/filter-replace-dialog.h +++ b/src/ui/filter-replace-dialog.h @@ -27,9 +27,4 @@ private slots: void editFilter(QTableWidgetItem *item); }; -std::string serialize_filter_words_replace( - const std::vector> &filter_words_replace); -std::vector> -deserialize_filter_words_replace(const std::string &filter_words_replace_str); - #endif // 
diff --git a/src/ui/filter-replace-utils.cpp b/src/ui/filter-replace-utils.cpp
new file mode 100644
index 0000000..14af016
--- /dev/null
+++ b/src/ui/filter-replace-utils.cpp
@@ -0,0 +1,32 @@
+#include "filter-replace-utils.h"
+
+#include <nlohmann/json.hpp>
+
+std::string serialize_filter_words_replace(
+	const std::vector<std::tuple<std::string, std::string>> &filter_words_replace)
+{
+	if (filter_words_replace.empty()) {
+		return "[]";
+	}
+	// use JSON to serialize the filter_words_replace map
+	nlohmann::json j;
+	for (const auto &entry : filter_words_replace) {
+		j.push_back({{"key", std::get<0>(entry)}, {"value", std::get<1>(entry)}});
+	}
+	return j.dump();
+}
+
+std::vector<std::tuple<std::string, std::string>>
+deserialize_filter_words_replace(const std::string &filter_words_replace_str)
+{
+	if (filter_words_replace_str.empty()) {
+		return {};
+	}
+	// use JSON to deserialize the filter_words_replace map
+	std::vector<std::tuple<std::string, std::string>> filter_words_replace;
+	nlohmann::json j = nlohmann::json::parse(filter_words_replace_str);
+	for (const auto &entry : j) {
+		filter_words_replace.push_back(std::make_tuple(entry["key"], entry["value"]));
+	}
+	return filter_words_replace;
+}
diff --git a/src/ui/filter-replace-utils.h b/src/ui/filter-replace-utils.h
new file mode 100644
index 0000000..485b968
--- /dev/null
+++ b/src/ui/filter-replace-utils.h
@@ -0,0 +1,12 @@
+#ifndef FILTER_REPLACE_UTILS_H
+#define FILTER_REPLACE_UTILS_H
+
+#include <string>
+#include <vector>
+
+std::string serialize_filter_words_replace(
+	const std::vector<std::tuple<std::string, std::string>> &filter_words_replace);
+std::vector<std::tuple<std::string, std::string>>
+deserialize_filter_words_replace(const std::string &filter_words_replace_str);
+
+#endif /* FILTER_REPLACE_UTILS_H */
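The moved helpers are intended to round-trip: `deserialize_filter_words_replace(serialize_filter_words_replace(x))` yields `x` back. A sketch of that property over the same JSON wire format, written in Python for brevity rather than against the actual C++ API (the sample pairs are invented):

```python
import json

def serialize_filter_words_replace(pairs):
    # Mirrors the C++ helper: a list of {"key", "value"} objects; "[]" when empty.
    return json.dumps([{"key": k, "value": v} for k, v in pairs], ensure_ascii=False)

def deserialize_filter_words_replace(s):
    # Mirrors the C++ helper: an empty string deserializes to an empty list.
    if not s:
        return []
    return [(e["key"], e["value"]) for e in json.loads(s)]

pairs = [("다음 영상에서 만나요!", ""), ("MBC 뉴스", "")]
assert deserialize_filter_words_replace(serialize_filter_words_replace(pairs)) == pairs
```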