diff --git a/app.py b/app.py index db351362..2ed217b6 100644 --- a/app.py +++ b/app.py @@ -19,25 +19,38 @@ def __init__(self, args): self.whisper_inf = self.init_whisper() print(f"Use \"{self.args.whisper_type}\" implementation") print(f"Device \"{self.whisper_inf.device}\" is detected") - self.nllb_inf = NLLBInference() - self.deepl_api = DeepLAPI() + self.nllb_inf = NLLBInference( + model_dir=self.args.nllb_model_dir, + output_dir=self.args.output_dir + ) + self.deepl_api = DeepLAPI( + output_dir=self.args.output_dir + ) def init_whisper(self): whisper_type = self.args.whisper_type.lower().strip() if whisper_type in ["faster_whisper", "faster-whisper", "fasterwhisper"]: - whisper_inf = FasterWhisperInference() - whisper_inf.model_dir = self.args.faster_whisper_model_dir + whisper_inf = FasterWhisperInference( + model_dir=self.args.faster_whisper_model_dir, + output_dir=self.args.output_dir + ) elif whisper_type in ["whisper"]: - whisper_inf = WhisperInference() - whisper_inf.model_dir = self.args.whisper_model_dir + whisper_inf = WhisperInference( + model_dir=self.args.whisper_model_dir, + output_dir=self.args.output_dir + ) elif whisper_type in ["insanely_fast_whisper", "insanely-fast-whisper", "insanelyfastwhisper", "insanely_faster_whisper", "insanely-faster-whisper", "insanelyfasterwhisper"]: - whisper_inf = InsanelyFastWhisperInference() - whisper_inf.model_dir = self.args.insanely_fast_whisper_model_dir + whisper_inf = InsanelyFastWhisperInference( + model_dir=self.args.insanely_fast_whisper_model_dir, + output_dir=self.args.output_dir + ) else: - whisper_inf = FasterWhisperInference() - whisper_inf.model_dir = self.args.faster_whisper_model_dir + whisper_inf = FasterWhisperInference( + model_dir=self.args.faster_whisper_model_dir, + output_dir=self.args.output_dir + ) return whisper_inf @staticmethod @@ -366,7 +379,7 @@ def launch(self): # Create the parser for command-line arguments parser = argparse.ArgumentParser() -parser.add_argument('--whisper_type', 
type=str, default="faster-whisper", help='A type of the whisper implementation between: ["whisper", "faster-whisper"]') +parser.add_argument('--whisper_type', type=str, default="faster-whisper", help='A type of the whisper implementation between: ["whisper", "faster-whisper", "insanely-fast-whisper"]') parser.add_argument('--share', type=bool, default=False, nargs='?', const=True, help='Gradio share value') parser.add_argument('--server_name', type=str, default=None, help='Gradio server host') parser.add_argument('--server_port', type=int, default=None, help='Gradio server port') @@ -379,6 +392,8 @@ def launch(self): parser.add_argument('--whisper_model_dir', type=str, default=os.path.join("models", "Whisper"), help='Directory path of the whisper model') parser.add_argument('--faster_whisper_model_dir', type=str, default=os.path.join("models", "Whisper", "faster-whisper"), help='Directory path of the faster-whisper model') parser.add_argument('--insanely_fast_whisper_model_dir', type=str, default=os.path.join("models", "Whisper", "insanely-fast-whisper"), help='Directory path of the insanely-fast-whisper model') +parser.add_argument('--nllb_model_dir', type=str, default=os.path.join("models", "NLLB"), help='Directory path of the Facebook NLLB model') +parser.add_argument('--output_dir', type=str, default=os.path.join("outputs"), help='Directory path of the outputs') _args = parser.parse_args() if __name__ == "__main__": diff --git a/modules/deepl_api.py b/modules/deepl_api.py index c7d52d0f..62751040 100644 --- a/modules/deepl_api.py +++ b/modules/deepl_api.py @@ -82,11 +82,14 @@ class DeepLAPI: - def __init__(self): + def __init__(self, + output_dir: str + ): self.api_interval = 1 self.max_text_batch_size = 50 self.available_target_langs = DEEPL_AVAILABLE_TARGET_LANGS self.available_source_langs = DEEPL_AVAILABLE_SOURCE_LANGS + self.output_dir = output_dir def translate_deepl(self, auth_key: str, @@ -111,6 +114,7 @@ def translate_deepl(self, Boolean value that is 
about pro user or not from gr.Checkbox(). progress: gr.Progress Indicator to show progress directly in gradio. + Returns ---------- A List of @@ -140,7 +144,7 @@ def translate_deepl(self, timestamp = datetime.now().strftime("%m%d%H%M%S") file_name = file_name[:-9] - output_path = os.path.join("outputs", "translations", f"{file_name}-{timestamp}.srt") + output_path = os.path.join(self.output_dir, "translations", f"{file_name}-{timestamp}.srt") write_file(subtitle, output_path) elif file_ext == ".vtt": @@ -160,7 +164,7 @@ def translate_deepl(self, timestamp = datetime.now().strftime("%m%d%H%M%S") file_name = file_name[:-9] - output_path = os.path.join("outputs", "translations", f"{file_name}-{timestamp}.vtt") + output_path = os.path.join(self.output_dir, "translations", f"{file_name}-{timestamp}.vtt") write_file(subtitle, output_path) diff --git a/modules/faster_whisper_inference.py b/modules/faster_whisper_inference.py index 8d9a37af..94387cec 100644 --- a/modules/faster_whisper_inference.py +++ b/modules/faster_whisper_inference.py @@ -17,9 +17,13 @@ class FasterWhisperInference(WhisperBase): - def __init__(self): + def __init__(self, + model_dir: str, + output_dir: str + ): super().__init__( - model_dir=os.path.join("models", "Whisper", "faster-whisper") + model_dir=model_dir, + output_dir=output_dir ) self.model_paths = self.get_model_paths() self.available_models = self.model_paths.keys() diff --git a/modules/insanely_fast_whisper_inference.py b/modules/insanely_fast_whisper_inference.py index 47ae147f..2404270b 100644 --- a/modules/insanely_fast_whisper_inference.py +++ b/modules/insanely_fast_whisper_inference.py @@ -15,9 +15,13 @@ class InsanelyFastWhisperInference(WhisperBase): - def __init__(self): + def __init__(self, + model_dir: str, + output_dir: str + ): super().__init__( - model_dir=os.path.join("models", "Whisper", "insanely_fast_whisper") + model_dir=model_dir, + output_dir=output_dir ) openai_models = whisper.available_models() distil_models = 
["distil-large-v2", "distil-large-v3", "distil-medium.en", "distil-small.en"] diff --git a/modules/nllb_inference.py b/modules/nllb_inference.py index 43eb4e36..73fcbf8a 100644 --- a/modules/nllb_inference.py +++ b/modules/nllb_inference.py @@ -6,9 +6,13 @@ class NLLBInference(TranslationBase): - def __init__(self): + def __init__(self, + model_dir: str, + output_dir: str + ): super().__init__( - model_dir=os.path.join("models", "NLLB") + model_dir=model_dir, + output_dir=output_dir ) self.tokenizer = None self.available_models = ["facebook/nllb-200-3.3B", "facebook/nllb-200-1.3B", "facebook/nllb-200-distilled-600M"] diff --git a/modules/translation_base.py b/modules/translation_base.py index 1aa37504..0a23b775 100644 --- a/modules/translation_base.py +++ b/modules/translation_base.py @@ -11,11 +11,14 @@ class TranslationBase(ABC): def __init__(self, - model_dir: str): + model_dir: str, + output_dir: str): super().__init__() self.model = None self.model_dir = model_dir + self.output_dir = output_dir os.makedirs(self.model_dir, exist_ok=True) + os.makedirs(self.output_dir, exist_ok=True) self.current_model_size = None self.device = self.get_device() @@ -87,7 +90,7 @@ def translate_file(self, timestamp = datetime.now().strftime("%m%d%H%M%S") if add_timestamp: - output_path = os.path.join("outputs", "translations", f"{file_name}-{timestamp}") + output_path = os.path.join(self.output_dir, "translations", f"{file_name}-{timestamp}.srt") else: - output_path = os.path.join("outputs", "translations", f"{file_name}.srt") + output_path = os.path.join(self.output_dir, "translations", f"{file_name}.srt") @@ -102,9 +105,9 @@ def translate_file(self, timestamp = datetime.now().strftime("%m%d%H%M%S") if add_timestamp: - output_path = os.path.join("outputs", "translations", f"{file_name}-{timestamp}") + output_path = os.path.join(self.output_dir, "translations", f"{file_name}-{timestamp}.vtt") else: - output_path = os.path.join("outputs", "translations", f"{file_name}.vtt") + output_path = os.path.join(self.output_dir, "translations", f"{file_name}.vtt") 
write_file(subtitle, output_path) files_info[file_name] = subtitle diff --git a/modules/whisper_Inference.py b/modules/whisper_Inference.py index e2a071f1..2a16d4db 100644 --- a/modules/whisper_Inference.py +++ b/modules/whisper_Inference.py @@ -11,9 +11,13 @@ class WhisperInference(WhisperBase): - def __init__(self): + def __init__(self, + model_dir: str, + output_dir: str + ): super().__init__( - model_dir=os.path.join("models", "Whisper") + model_dir=model_dir, + output_dir=output_dir ) def transcribe(self, diff --git a/modules/whisper_base.py b/modules/whisper_base.py index 78a40458..4b6e43bd 100644 --- a/modules/whisper_base.py +++ b/modules/whisper_base.py @@ -15,10 +15,14 @@ class WhisperBase(ABC): def __init__(self, - model_dir: str): + model_dir: str, + output_dir: str + ): self.model = None self.current_model_size = None self.model_dir = model_dir + self.output_dir = output_dir + os.makedirs(self.output_dir, exist_ok=True) os.makedirs(self.model_dir, exist_ok=True) self.available_models = whisper.available_models() self.available_langs = sorted(list(whisper.tokenizer.LANGUAGES.values())) @@ -88,7 +92,8 @@ def transcribe_file(self, file_name=file_name, transcribed_segments=transcribed_segments, add_timestamp=add_timestamp, - file_format=file_format + file_format=file_format, + output_dir=self.output_dir ) files_info[file_name] = {"subtitle": subtitle, "time_for_task": time_for_task, "path": file_path} @@ -152,7 +157,8 @@ def transcribe_mic(self, file_name="Mic", transcribed_segments=transcribed_segments, add_timestamp=True, - file_format=file_format + file_format=file_format, + output_dir=self.output_dir ) result_str = f"Done in {self.format_time(time_for_task)}! 
Subtitle file is in the outputs folder.\n\n{subtitle}" @@ -211,7 +217,8 @@ def transcribe_youtube(self, file_name=file_name, transcribed_segments=transcribed_segments, add_timestamp=add_timestamp, - file_format=file_format + file_format=file_format, + output_dir=self.output_dir ) result_str = f"Done in {self.format_time(time_for_task)}! Subtitle file is in the outputs folder.\n\n{subtitle}" @@ -237,6 +244,7 @@ def generate_and_write_file(file_name: str, transcribed_segments: list, add_timestamp: bool, file_format: str, + output_dir: str ) -> str: """ Writes subtitle file @@ -251,6 +259,8 @@ def generate_and_write_file(file_name: str, Determines whether to add a timestamp to the end of the filename. file_format: str File format to write. Supported formats: [SRT, WebVTT, txt] + output_dir: str + Directory path of the output Returns ---------- @@ -261,9 +271,9 @@ def generate_and_write_file(file_name: str, """ timestamp = datetime.now().strftime("%m%d%H%M%S") if add_timestamp: - output_path = os.path.join("outputs", f"{file_name}-{timestamp}") + output_path = os.path.join(output_dir, f"{file_name}-{timestamp}") else: - output_path = os.path.join("outputs", f"{file_name}") + output_path = os.path.join(output_dir, f"{file_name}") if file_format == "SRT": content = get_srt(transcribed_segments)