Skip to content

Commit

Permalink
feat: Add max_sub_duration field to transcription filter data
Browse files Browse the repository at this point in the history
  • Loading branch information
royshil committed Aug 23, 2024
1 parent f4d2cfc commit dcf368d
Show file tree
Hide file tree
Showing 10 changed files with 189 additions and 66 deletions.
28 changes: 16 additions & 12 deletions src/transcription-filter-callbacks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -199,11 +199,6 @@ void set_text_callback(struct transcription_filter_data *gf,
const DetectionResultWithText &resultIn)
{
DetectionResultWithText result = resultIn;
if (!result.text.empty() && (result.result == DETECTION_RESULT_SPEECH ||
result.result == DETECTION_RESULT_PARTIAL)) {
gf->last_sub_render_time = now_ms();
gf->cleared_last_sub = false;
}

std::string str_copy = result.text;

Expand Down Expand Up @@ -243,10 +238,12 @@ void set_text_callback(struct transcription_filter_data *gf,
str_copy = translated_sentence;
} else {
if (gf->buffered_output) {
if (result.result == DETECTION_RESULT_SPEECH) {
// buffered output - add the sentence to the monitor
gf->translation_monitor.addSentence(translated_sentence);
}
// buffered output - add the sentence to the monitor
gf->translation_monitor.addSentenceFromStdString(
translated_sentence,
get_time_point_from_ms(result.start_timestamp_ms),
get_time_point_from_ms(result.end_timestamp_ms),
result.result == DETECTION_RESULT_PARTIAL);
} else {
// non-buffered output - send the sentence to the selected source
send_caption_to_source(gf->translation_output, translated_sentence,
Expand All @@ -256,9 +253,10 @@ void set_text_callback(struct transcription_filter_data *gf,
}

if (gf->buffered_output) {
if (result.result == DETECTION_RESULT_SPEECH) {
gf->captions_monitor.addSentence(str_copy);
}
gf->captions_monitor.addSentenceFromStdString(
str_copy, get_time_point_from_ms(result.start_timestamp_ms),
get_time_point_from_ms(result.end_timestamp_ms),
result.result == DETECTION_RESULT_PARTIAL);
} else {
// non-buffered output - send the sentence to the selected source
send_caption_to_source(gf->text_source_name, str_copy, gf);
Expand All @@ -273,6 +271,12 @@ void set_text_callback(struct transcription_filter_data *gf,
result.result == DETECTION_RESULT_SPEECH) {
send_sentence_to_file(gf, result, str_copy, translated_sentence);
}

if (!result.text.empty() && (result.result == DETECTION_RESULT_SPEECH ||
result.result == DETECTION_RESULT_PARTIAL)) {
gf->last_sub_render_time = now_ms();
gf->cleared_last_sub = false;
}
};

void recording_state_callback(enum obs_frontend_event event, void *data)
Expand Down
2 changes: 2 additions & 0 deletions src/transcription-filter-data.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ struct transcription_filter_data {
size_t sentence_number;
// Minimal subtitle duration in ms
size_t min_sub_duration;
// Maximal subtitle duration in ms
size_t max_sub_duration;
// Last time a subtitle was rendered
uint64_t last_sub_render_time;
bool cleared_last_sub;
Expand Down
29 changes: 28 additions & 1 deletion src/transcription-filter-properties.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,30 @@ void add_buffered_output_group_properties(obs_properties_t *ppts)
OBS_COMBO_TYPE_LIST, OBS_COMBO_FORMAT_INT);
obs_property_list_add_int(buffer_type_list, "Character", SEGMENTATION_TOKEN);
obs_property_list_add_int(buffer_type_list, "Word", SEGMENTATION_WORD);
obs_property_list_add_int(buffer_type_list, "Sentence", SEGMENTATION_SENTENCE);
// add callback to the segmentation selection to set default values
obs_property_set_modified_callback(buffer_type_list, [](obs_properties_t *props,

Check failure on line 295 in src/transcription-filter-properties.cpp

View workflow job for this annotation

GitHub Actions / Build Project 🧱 / Build for macOS 🍏 (arm64)

unused parameter 'props' [-Werror,-Wunused-parameter]
obs_property_t *property,
obs_data_t *settings) {
UNUSED_PARAMETER(property);
const int segmentation_type = obs_data_get_int(settings, "buffer_output_type");
// set default values for the number of lines and characters per line
switch (segmentation_type) {
case SEGMENTATION_TOKEN:
obs_data_set_int(settings, "buffer_num_lines", 2);
obs_data_set_int(settings, "buffer_num_chars_per_line", 30);
break;
case SEGMENTATION_WORD:
obs_data_set_int(settings, "buffer_num_lines", 2);
obs_data_set_int(settings, "buffer_num_chars_per_line", 10);
break;
case SEGMENTATION_SENTENCE:
obs_data_set_int(settings, "buffer_num_lines", 2);
obs_data_set_int(settings, "buffer_num_chars_per_line", 2);
break;
}
return true;
});
// add buffer lines parameter
obs_properties_add_int_slider(buffered_output_group, "buffer_num_lines",
MT_("buffer_num_lines"), 1, 5, 1);
Expand All @@ -310,6 +334,8 @@ void add_advanced_group_properties(obs_properties_t *ppts, struct transcription_

obs_properties_add_int_slider(advanced_config_group, "min_sub_duration",
MT_("min_sub_duration"), 1000, 5000, 50);
obs_properties_add_int_slider(advanced_config_group, "max_sub_duration",
MT_("max_sub_duration"), 1000, 5000, 50);
obs_properties_add_float_slider(advanced_config_group, "sentence_psum_accept_thresh",
MT_("sentence_psum_accept_thresh"), 0.0, 1.0, 0.05);

Expand Down Expand Up @@ -535,7 +561,8 @@ void transcription_filter_defaults(obs_data_t *s)
obs_data_set_default_bool(s, "truncate_output_file", false);
obs_data_set_default_bool(s, "only_while_recording", false);
obs_data_set_default_bool(s, "rename_file_to_match_recording", true);
obs_data_set_default_int(s, "min_sub_duration", 3000);
obs_data_set_default_int(s, "min_sub_duration", 1000);
obs_data_set_default_int(s, "max_sub_duration", 3000);
obs_data_set_default_bool(s, "advanced_settings", false);
obs_data_set_default_bool(s, "translate", false);
obs_data_set_default_string(s, "translate_target_language", "__es__");
Expand Down
2 changes: 2 additions & 0 deletions src/transcription-filter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@ void transcription_filter_update(void *data, obs_data_t *s)
gf->sentence_number = 1;
gf->process_while_muted = obs_data_get_bool(s, "process_while_muted");
gf->min_sub_duration = (int)obs_data_get_int(s, "min_sub_duration");
gf->max_sub_duration = (int)obs_data_get_int(s, "max_sub_duration");
gf->last_sub_render_time = now_ms();
gf->duration_filter_threshold = (float)obs_data_get_double(s, "duration_filter_threshold");
gf->partial_transcription = obs_data_get_bool(s, "partial_group");
Expand Down Expand Up @@ -432,6 +433,7 @@ void *transcription_filter_create(obs_data_t *settings, obs_source_t *filter)
gf->frames = (size_t)((float)gf->sample_rate / (1000.0f / MAX_MS_WORK_BUFFER));
gf->last_num_frames = 0;
gf->min_sub_duration = (int)obs_data_get_int(settings, "min_sub_duration");
gf->max_sub_duration = (int)obs_data_get_int(settings, "max_sub_duration");
gf->last_sub_render_time = now_ms();
gf->log_level = (int)obs_data_get_int(settings, "log_level");
gf->save_srt = obs_data_get_bool(settings, "subtitle_save_srt");
Expand Down
92 changes: 70 additions & 22 deletions src/whisper-utils/token-buffer-thread.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
#include "whisper-utils.h"
#include "transcription-utils.h"

#include <iostream>
#include <sstream>

#include <obs-module.h>

#ifdef _WIN32
Expand Down Expand Up @@ -75,37 +78,74 @@ void TokenBufferThread::log_token_vector(const std::vector<std::string> &tokens)
obs_log(LOG_INFO, "TokenBufferThread::log_token_vector: '%s'", output.c_str());
}

void TokenBufferThread::addSentence(const std::string &sentence)
void TokenBufferThread::addSentenceFromStdString(const std::string &sentence,
TokenBufferTimePoint start_time,
TokenBufferTimePoint end_time, bool is_partial)
{
if (sentence.empty()) {
return;
}
#ifdef _WIN32
// on windows convert from multibyte to wide char
int count =
MultiByteToWideChar(CP_UTF8, 0, sentence.c_str(), (int)sentence.length(), NULL, 0);
std::wstring sentence_ws(count, 0);
TokenBufferString sentence_ws(count, 0);
MultiByteToWideChar(CP_UTF8, 0, sentence.c_str(), (int)sentence.length(), &sentence_ws[0],
count);
#else
std::string sentence_ws = sentence;
TokenBufferString sentence_ws = sentence;
#endif
// split to characters
std::vector<TokenBufferString> characters;
for (const auto &c : sentence_ws) {
characters.push_back(TokenBufferString(1, c));

TokenBufferSentence sentence_for_add;
sentence_for_add.start_time = start_time;
sentence_for_add.end_time = end_time;

if (this->segmentation == SEGMENTATION_WORD) {
// split the sentence to words
std::vector<TokenBufferString> words;
std::basic_istringstream<TokenBufferString::value_type> iss(sentence_ws);
TokenBufferString word;
while (iss >> word) {
words.push_back(word);
}
// add the words to a sentence
for (const auto &word : words) {
sentence_for_add.tokens.push_back({word, is_partial});
sentence_for_add.tokens.push_back({SPACE, is_partial});
}
} else if (this->segmentation == SEGMENTATION_TOKEN) {
// split to characters
std::vector<TokenBufferString> characters;
for (const auto &c : sentence_ws) {
characters.push_back(TokenBufferString(1, c));
}
// add the characters to a sentece
for (const auto &character : characters) {
sentence_for_add.tokens.push_back({character, is_partial});
}
} else {
// add the whole sentence as a single token
sentence_for_add.tokens.push_back({sentence_ws, is_partial});
sentence_for_add.tokens.push_back({SPACE, is_partial});
}
addSentence(sentence_for_add);
}

std::lock_guard<std::mutex> lock(inputQueueMutex);
void TokenBufferThread::addSentence(const TokenBufferSentence &sentence)
{
std::lock_guard<std::mutex> lock(this->inputQueueMutex);

// add the characters to the inputQueue
for (const auto &character : characters) {
// add the tokens to the inputQueue
for (const auto &character : sentence.tokens) {
inputQueue.push_back(character);
}
inputQueue.push_back(SPACE);
inputQueue.push_back({SPACE, sentence.tokens.back().is_partial});

// add to the contribution queue as well
for (const auto &character : characters) {
for (const auto &character : sentence.tokens) {
contributionQueue.push_back(character);
}
contributionQueue.push_back(SPACE);
contributionQueue.push_back({SPACE, sentence.tokens.back().is_partial});
this->lastContributionTime = std::chrono::steady_clock::now();
}

Expand Down Expand Up @@ -148,7 +188,7 @@ void TokenBufferThread::monitor()
if (this->segmentation == SEGMENTATION_TOKEN) {
// pop tokens until a space is found
while (!presentationQueue.empty() &&
presentationQueue.front() != SPACE) {
presentationQueue.front().token != SPACE) {
presentationQueue.pop_front();
}
}
Expand All @@ -158,6 +198,13 @@ void TokenBufferThread::monitor()
std::lock_guard<std::mutex> lock(inputQueueMutex);

if (!inputQueue.empty()) {
// if the input on the inputQueue is partial - first remove all partials
// from the end of the presentation queue
while (!presentationQueue.empty() &&
presentationQueue.back().is_partial) {
presentationQueue.pop_back();
}

// if there are token on the input queue
// then add to the presentation queue based on the segmentation
if (this->segmentation == SEGMENTATION_SENTENCE) {
Expand All @@ -171,16 +218,17 @@ void TokenBufferThread::monitor()
presentationQueue.push_back(inputQueue.front());
inputQueue.pop_front();
} else {
// SEGMENTATION_WORD
// skip spaces in the beginning of the input queue
while (!inputQueue.empty() &&
inputQueue.front() == SPACE) {
inputQueue.front().token == SPACE) {
inputQueue.pop_front();
}
// add one word to the presentation queue
TokenBufferString word;
TokenBufferToken word;
while (!inputQueue.empty() &&
inputQueue.front() != SPACE) {
word += inputQueue.front();
inputQueue.front().token != SPACE) {
word = inputQueue.front();
inputQueue.pop_front();
}
presentationQueue.push_back(word);
Expand All @@ -200,7 +248,7 @@ void TokenBufferThread::monitor()
size_t wordsInSentence = 0;
for (size_t i = 0; i < presentationQueue.size(); i++) {
const auto &word = presentationQueue[i];
sentences.back() += word + SPACE;
sentences.back() += word.token + SPACE;
wordsInSentence++;
if (wordsInSentence == this->numPerSentence) {
sentences.push_back(TokenBufferString());
Expand All @@ -211,12 +259,12 @@ void TokenBufferThread::monitor()
for (size_t i = 0; i < presentationQueue.size(); i++) {
const auto &token = presentationQueue[i];
// skip spaces in the beginning of a sentence (tokensInSentence == 0)
if (token == SPACE &&
if (token.token == SPACE &&
sentences.back().length() == 0) {
continue;
}

sentences.back() += token;
sentences.back() += token.token;
if (sentences.back().length() ==
this->numPerSentence) {
// if the next character is not a space - this is a broken word
Expand Down Expand Up @@ -280,7 +328,7 @@ void TokenBufferThread::monitor()
// take the contribution queue and send it to the output
TokenBufferString contribution;
for (const auto &token : contributionQueue) {
contribution += token;
contribution += token.token;
}
contributionQueue.clear();
#ifdef _WIN32
Expand Down
28 changes: 24 additions & 4 deletions src/whisper-utils/token-buffer-thread.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@

#ifdef _WIN32
typedef std::wstring TokenBufferString;
typedef wchar_t TokenBufferChar;
#else
typedef std::string TokenBufferString;
typedef char TokenBufferChar;
#endif

struct transcription_filter_data;
Expand All @@ -27,6 +29,22 @@ enum TokenBufferSpeed { SPEED_SLOW = 0, SPEED_NORMAL, SPEED_FAST };

typedef std::chrono::time_point<std::chrono::steady_clock> TokenBufferTimePoint;

inline std::chrono::time_point<std::chrono::steady_clock> get_time_point_from_ms(uint64_t ms)
{
return std::chrono::time_point<std::chrono::steady_clock>(std::chrono::milliseconds(ms));
}

struct TokenBufferToken {
TokenBufferString token;
bool is_partial;
};

struct TokenBufferSentence {
std::vector<TokenBufferToken> tokens;
TokenBufferTimePoint start_time;
TokenBufferTimePoint end_time;
};

class TokenBufferThread {
public:
// default constructor
Expand All @@ -40,7 +58,9 @@ class TokenBufferThread {
std::chrono::seconds maxTime_,
TokenBufferSegmentation segmentation_ = SEGMENTATION_TOKEN);

void addSentence(const std::string &sentence);
void addSentenceFromStdString(const std::string &sentence, TokenBufferTimePoint start_time,
TokenBufferTimePoint end_time, bool is_partial = false);
void addSentence(const TokenBufferSentence &sentence);
void clear();
void stopThread();

Expand All @@ -59,9 +79,9 @@ class TokenBufferThread {
void log_token_vector(const std::vector<std::string> &tokens);
int getWaitTime(TokenBufferSpeed speed) const;
struct transcription_filter_data *gf;
std::deque<TokenBufferString> inputQueue;
std::deque<TokenBufferString> presentationQueue;
std::deque<TokenBufferString> contributionQueue;
std::deque<TokenBufferToken> inputQueue;
std::deque<TokenBufferToken> presentationQueue;
std::deque<TokenBufferToken> contributionQueue;
std::thread workerThread;
std::mutex inputQueueMutex;
std::mutex presentationQueueMutex;
Expand Down
Loading

0 comments on commit dcf368d

Please sign in to comment.