From 3f330f05442b333815fea6073cb0b5f56681e7f8 Mon Sep 17 00:00:00 2001 From: linuxlurak <3813355+linuxlurak@users.noreply.github.com> Date: Mon, 4 Nov 2024 14:35:44 +0100 Subject: [PATCH 1/3] Update default_parameters.yaml Including file_format. --- configs/default_parameters.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/configs/default_parameters.yaml b/configs/default_parameters.yaml index 3a12d148..e317d25a 100644 --- a/configs/default_parameters.yaml +++ b/configs/default_parameters.yaml @@ -1,5 +1,6 @@ whisper: model_size: "large-v2" + file_format: "SRT" lang: "Automatic Detection" is_translate: false beam_size: 5 From 8a4343101e1467cf33237577df37274fd65a0d54 Mon Sep 17 00:00:00 2001 From: linuxlurak <3813355+linuxlurak@users.noreply.github.com> Date: Mon, 4 Nov 2024 14:37:08 +0100 Subject: [PATCH 2/3] Update app.py Including file_format loading from default config. --- app.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app.py b/app.py index c22e52df..175522b0 100644 --- a/app.py +++ b/app.py @@ -53,7 +53,7 @@ def create_pipeline_inputs(self): dd_lang = gr.Dropdown(choices=self.whisper_inf.available_langs + [AUTOMATIC_DETECTION], value=AUTOMATIC_DETECTION if whisper_params["lang"] == AUTOMATIC_DETECTION.unwrap() else whisper_params["lang"], label=_("Language")) - dd_file_format = gr.Dropdown(choices=["SRT", "WebVTT", "txt", "LRC"], value="SRT", label=_("File Format")) + dd_file_format = gr.Dropdown(choices=["SRT", "WebVTT", "txt", "LRC"], value=whisper_params["file_format"], label=_("File Format")) with gr.Row(): cb_translate = gr.Checkbox(value=whisper_params["is_translate"], label=_("Translate to English?"), interactive=True) From e2844449722beee014637e807d3426a7a8983f50 Mon Sep 17 00:00:00 2001 From: jhj0517 <97279763+jhj0517@users.noreply.github.com> Date: Mon, 4 Nov 2024 23:21:57 +0900 Subject: [PATCH 3/3] Add gradio parameter `file_format` to cache --- modules/whisper/base_transcription_pipeline.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/modules/whisper/base_transcription_pipeline.py b/modules/whisper/base_transcription_pipeline.py index 4882abd8..2791dc69 100644 --- a/modules/whisper/base_transcription_pipeline.py +++ b/modules/whisper/base_transcription_pipeline.py @@ -71,6 +71,7 @@ def update_model(self, def run(self, audio: Union[str, BinaryIO, np.ndarray], progress: gr.Progress = gr.Progress(), + file_format: str = "SRT", add_timestamp: bool = True, *pipeline_params, ) -> Tuple[List[Segment], float]: @@ -86,6 +87,8 @@ def run(self, Audio input. This can be file path or binary type. progress: gr.Progress Indicator to show progress directly in gradio. + file_format: str + Subtitle file format between ["SRT", "WebVTT", "txt", "lrc"] add_timestamp: bool Whether to add a timestamp at the end of the filename. *pipeline_params: tuple @@ -168,6 +171,7 @@ def run(self, self.cache_parameters( params=params, + file_format=file_format, add_timestamp=add_timestamp ) return result, elapsed_time @@ -224,6 +228,7 @@ def transcribe_file(self, transcribed_segments, time_for_task = self.run( file, progress, + file_format, add_timestamp, *pipeline_params, ) @@ -298,6 +303,7 @@ def transcribe_mic(self, transcribed_segments, time_for_task = self.run( mic_audio, progress, + file_format, add_timestamp, *pipeline_params, ) @@ -364,6 +370,7 @@ def transcribe_youtube(self, transcribed_segments, time_for_task = self.run( audio, progress, + file_format, add_timestamp, *pipeline_params, ) @@ -513,7 +520,8 @@ def validate_gradio_values(params: TranscriptionPipelineParams): @staticmethod def cache_parameters( params: TranscriptionPipelineParams, - add_timestamp: bool + file_format: str = "SRT", + add_timestamp: bool = True ): """Cache parameters to the yaml file""" cached_params = load_yaml(DEFAULT_PARAMETERS_CONFIG_PATH) @@ -521,6 +529,7 @@ def cache_parameters( cached_yaml = {**cached_params, **param_to_cache} cached_yaml["whisper"]["add_timestamp"] = add_timestamp + cached_yaml["whisper"]["file_format"] = file_format supress_token = cached_yaml["whisper"].get("suppress_tokens", None) if supress_token and isinstance(supress_token, list):