diff --git a/app.py b/app.py index a81d82d..44d0168 100644 --- a/app.py +++ b/app.py @@ -115,9 +115,14 @@ def launch(self): visible=self.args.colab, value="") cb_include_subdirectory = gr.Checkbox(label="Include Subdirectory Files", - info="When using Input Folder Path above, whether to include all files in the subdirectory or not", + info="When using Input Folder Path above, whether to include all files in the subdirectory or not.", visible=self.args.colab, value=False) + cb_save_same_dir = gr.Checkbox(label="Save outputs at same directory", + info="When using Input Folder Path above, whether to save output in the same directory as inputs or not, in addition to the original" + " output directory.", + visible=self.args.colab, + value=True) pipeline_params, dd_file_format, cb_timestamp = self.create_pipeline_inputs() with gr.Row(): @@ -127,9 +132,11 @@ def launch(self): files_subtitles = gr.Files(label=_("Downloadable output file"), scale=3, interactive=False) btn_openfolder = gr.Button('📂', scale=1) - params = [input_file, tb_input_folder, cb_include_subdirectory, dd_file_format, cb_timestamp] + params = [input_file, tb_input_folder, cb_include_subdirectory, cb_save_same_dir, + dd_file_format, cb_timestamp] + params = params + pipeline_params btn_run.click(fn=self.whisper_inf.transcribe_file, - inputs=params + pipeline_params, + inputs=params, outputs=[tb_indicator, files_subtitles]) btn_openfolder.click(fn=lambda: self.open_folder("outputs"), inputs=None, outputs=None) diff --git a/configs/default_parameters.yaml b/configs/default_parameters.yaml index 89ea204..4a7172f 100644 --- a/configs/default_parameters.yaml +++ b/configs/default_parameters.yaml @@ -30,7 +30,7 @@ whisper: hotwords: null language_detection_threshold: 0.5 language_detection_segments: 1 - add_timestamp: true + add_timestamp: false vad: vad_filter: false @@ -62,4 +62,4 @@ translation: source_lang: null target_lang: null max_length: 200 - add_timestamp: true + add_timestamp: false diff --git 
a/modules/whisper/base_transcription_pipeline.py b/modules/whisper/base_transcription_pipeline.py index 8172a25..0d03608 100644 --- a/modules/whisper/base_transcription_pipeline.py +++ b/modules/whisper/base_transcription_pipeline.py @@ -185,6 +185,7 @@ def transcribe_file(self, files: Optional[List] = None, input_folder_path: Optional[str] = None, include_subdirectory: Optional[str] = None, + save_same_dir: Optional[str] = None, file_format: str = "SRT", add_timestamp: bool = True, progress=gr.Progress(), @@ -201,7 +202,11 @@ def transcribe_file(self, Input folder path to transcribe from gr.Textbox(). If this is provided, `files` will be ignored and this will be used instead. include_subdirectory: Optional[str] - When using Input Folder Path above, whether to include all files in the subdirectory or not + When using `input_folder_path`, whether to include all files in the subdirectory or not + save_same_dir: Optional[str] + When using `input_folder_path`, whether to save output in the same directory as inputs or not, in addition + to the original output directory. This feature is only available when using `input_folder_path`, because + gradio currently only allows the cached file path to be used in the function. file_format: str Subtitle File format to write from gr.Dropdown(). 
Supported format: [SRT, WebVTT, txt] add_timestamp: bool @@ -242,6 +247,17 @@ def transcribe_file(self, ) file_name, file_ext = os.path.splitext(os.path.basename(file)) + if save_same_dir and input_folder_path: + output_dir = os.path.dirname(file) + subtitle, file_path = generate_file( + output_dir=output_dir, + output_file_name=file_name, + output_format=file_format, + result=transcribed_segments, + add_timestamp=add_timestamp, + **writer_options + ) + subtitle, file_path = generate_file( output_dir=self.output_dir, output_file_name=file_name, diff --git a/tests/test_transcription.py b/tests/test_transcription.py index 5fc85cb..3f89f5a 100644 --- a/tests/test_transcription.py +++ b/tests/test_transcription.py @@ -60,6 +60,7 @@ def test_transcribe( [audio_path], None, None, + None, "SRT", False, gr.Progress(),