Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More offline test improvements #153

Merged
merged 4 commits into from
Aug 14, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 71 additions & 26 deletions src/tests/localvocal-offline-test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ void obs_log(int log_level, const char *format, ...)

auto diff = now - start;

static std::mutex log_mutex;
auto lock = std::lock_guard(log_mutex);
// print timestamp
printf("[%02d:%02d:%02d.%03d] [%02d:%02lld.%03lld] ", now_tm.tm_hour, now_tm.tm_min,
now_tm.tm_sec, (int)(epoch.count() % 1000),
Expand Down Expand Up @@ -194,6 +196,11 @@ create_context(int sample_rate, int channels, const std::string &whisper_model_p
return gf;
}

// Shared state for handing finished segments from audio_chunk_callback() to
// the saver thread (json_segments_saver_thread_function()).
// Guards json_segments_input and json_segments_input_finished.
std::mutex json_segments_input_mutex;
// Notified after a segment is queued and after the finished flag is set.
std::condition_variable json_segments_input_cv;
// Segments queued by the callback, waiting to be picked up by the saver thread.
std::vector<nlohmann::json> json_segments_input;
// Set to true (under the mutex) to tell the saver thread to exit once the
// queue is empty.
bool json_segments_input_finished = false;

void audio_chunk_callback(struct transcription_filter_data *gf, const float *pcm32f_data,
size_t frames, int vad_state, const DetectionResultWithText &result)
{
Expand All @@ -214,33 +221,56 @@ void audio_chunk_callback(struct transcription_filter_data *gf, const float *pcm
// obs_log(gf->log_level, "Saving %lu frames to %s", frames, filename.c_str());
// write_audio_wav_file(filename.c_str(), pcm32f_data, frames);

// append a row to the array in the segments.json file
std::string segments_filename = "segments.json";
nlohmann::json segments_json;

// Read existing segments from file
std::ifstream segments_file(segments_filename);
if (segments_file.is_open()) {
segments_file >> segments_json;
segments_file.close();
}

// Create a new segment object
nlohmann::json segment;
segment["start_time"] = result.start_timestamp_ms / 1000.0;
segment["end_time"] = result.end_timestamp_ms / 1000.0;
segment["segment_label"] = result.text;

// Add the new segment to the segments array
segments_json.push_back(segment);
{
auto lock = std::lock_guard(json_segments_input_mutex);

// Add the new segment to the segments array
json_segments_input.push_back(segment);
}
json_segments_input_cv.notify_one();
}

// Thread body that persists queued segments to "segments.json".
// Repeatedly waits for entries in json_segments_input, appends them to an
// in-memory JSON array and rewrites the file with the full array. Returns when
// json_segments_input_finished is set and the queue is empty.
void json_segments_saver_thread_function()
{
std::string segments_filename = "segments.json";
// Accumulates every segment received so far; lives across loop iterations.
nlohmann::json segments_json;

// Local batch buffer with the same type as the shared queue.
decltype(json_segments_input) json_segments_input_local;

for (;;) {
{
auto lock = std::unique_lock(json_segments_input_mutex);
// Sleep until there is something to save; exit if the producer has
// signalled completion and nothing is left in the queue.
while (json_segments_input.empty()) {
if (json_segments_input_finished)
return;
json_segments_input_cv.wait(lock, [&] {
return json_segments_input_finished ||
!json_segments_input.empty();
});
}

// Grab the whole pending batch while holding the lock; the file write
// below happens after the lock is released at the end of this scope.
std::swap(json_segments_input, json_segments_input_local);
// After the swap the shared queue holds the previous batch's moved-from
// leftovers (the local buffer is never cleared elsewhere), so drop them.
json_segments_input.clear();
}

for (auto &elem : json_segments_input_local) {
segments_json.push_back(std::move(elem));
}

// NOTE(review): the following write-back block duplicates the one after it
// and references `gf`, which is not declared in this function — it looks
// like the pre-change lines from the diff view; confirm before relying on it.
// Write the updated segments back to the file
std::ofstream segments_file_out(segments_filename);
if (segments_file_out.is_open()) {
segments_file_out << std::setw(4) << segments_json << std::endl;
segments_file_out.close();
} else {
obs_log(gf->log_level, "Failed to open %s", segments_filename.c_str());
// Write the updated segments back to the file
// (std::ofstream's default open mode truncates, so the complete accumulated
// array is rewritten on every batch.)
std::ofstream segments_file_out(segments_filename);
if (segments_file_out.is_open()) {
// std::setw(4) asks nlohmann::json to pretty-print with a 4-space indent.
segments_file_out << std::setw(4) << segments_json << std::endl;
segments_file_out.close();
} else {
obs_log(LOG_INFO, "Failed to open %s", segments_filename.c_str());
}
}
}

Expand Down Expand Up @@ -361,6 +391,7 @@ int wmain(int argc, wchar_t *argv[])

std::cout << "LocalVocal Offline Test" << std::endl;
transcription_filter_data *gf = nullptr;
std::optional<std::thread> audio_chunk_saver_thread;

std::vector<std::vector<uint8_t>> audio =
read_audio_file(filenameStr.c_str(), [&](int sample_rate, int channels) {
Expand Down Expand Up @@ -419,6 +450,10 @@ int wmain(int argc, wchar_t *argv[])
return 1;
}

if (gf->enable_audio_chunks_callback) {
audio_chunk_saver_thread.emplace(json_segments_saver_thread_function);
}

// truncate the output file
obs_log(LOG_INFO, "Truncating output file");
std::ofstream output_file(gf->output_file_path, std::ios::trunc);
Expand All @@ -437,10 +472,10 @@ int wmain(int argc, wchar_t *argv[])

obs_log(LOG_INFO, "Sending samples to whisper buffer");
// 25 ms worth of frames
int frames = gf->sample_rate * window_size_in_ms.count() / 1000;
size_t frames = gf->sample_rate * window_size_in_ms.count() / 1000;
const int frame_size_bytes = sizeof(float);
int frames_size_bytes = frames * frame_size_bytes;
int frames_count = 0;
size_t frames_size_bytes = frames * frame_size_bytes;
size_t frames_count = 0;
int64_t start_time = std::chrono::duration_cast<std::chrono::nanoseconds>(
std::chrono::system_clock::now().time_since_epoch())
.count();
Expand All @@ -464,12 +499,13 @@ int wmain(int argc, wchar_t *argv[])
if (false && now > max_wait)
break;

if (gf->input_buffers->size == 0)
break;

gf->input_cv->wait_for(
lock, std::chrono::milliseconds(10), [&] {
lock, std::chrono::milliseconds(1), [&] {
return gf->input_buffers->size == 0;
});
if (gf->input_buffers->size == 0)
break;
}
// push back current audio data to input circlebuf
for (size_t c = 0; c < gf->channels; c++) {
Expand Down Expand Up @@ -533,6 +569,15 @@ int wmain(int argc, wchar_t *argv[])
}
}

if (audio_chunk_saver_thread.has_value()) {
{
auto lock = std::lock_guard(json_segments_input_mutex);
json_segments_input_finished = true;
}
json_segments_input_cv.notify_one();
audio_chunk_saver_thread->join();
}

release_context(gf);

obs_log(LOG_INFO, "LocalVocal Offline Test Done");
Expand Down
Loading