Skip to content

Commit

Permalink
Merge pull request #178 from jhj0517/feature/add-output-dir
Browse files Browse the repository at this point in the history
Add output dir arg
  • Loading branch information
jhj0517 authored Jun 25, 2024
2 parents 0450240 + 94904d8 commit a230be5
Show file tree
Hide file tree
Showing 8 changed files with 80 additions and 32 deletions.
37 changes: 26 additions & 11 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,25 +19,38 @@ def __init__(self, args):
self.whisper_inf = self.init_whisper()
print(f"Use \"{self.args.whisper_type}\" implementation")
print(f"Device \"{self.whisper_inf.device}\" is detected")
self.nllb_inf = NLLBInference()
self.deepl_api = DeepLAPI()
self.nllb_inf = NLLBInference(
model_dir=self.args.nllb_model_dir,
output_dir=self.args.output_dir
)
self.deepl_api = DeepLAPI(
output_dir=self.args.output_dir
)

def init_whisper(self):
    """Instantiate the Whisper implementation selected by ``--whisper_type``.

    The requested type is normalized (lower-cased, stripped) and matched
    against the known aliases. Any unrecognized value falls back to the
    faster-whisper implementation, matching the original else branch.

    Returns
    ----------
    WhisperBase
        A concrete inference backend constructed with its model directory
        and the shared output directory from the parsed CLI arguments.
    """
    whisper_type = self.args.whisper_type.lower().strip()

    insanely_fast_aliases = [
        "insanely_fast_whisper", "insanely-fast-whisper", "insanelyfastwhisper",
        "insanely_faster_whisper", "insanely-faster-whisper", "insanelyfasterwhisper",
    ]

    if whisper_type in ["whisper"]:
        whisper_inf = WhisperInference(
            model_dir=self.args.whisper_model_dir,
            output_dir=self.args.output_dir
        )
    elif whisper_type in insanely_fast_aliases:
        whisper_inf = InsanelyFastWhisperInference(
            model_dir=self.args.insanely_fast_whisper_model_dir,
            output_dir=self.args.output_dir
        )
    else:
        # Both the faster-whisper aliases ("faster_whisper", "faster-whisper",
        # "fasterwhisper") and any unknown value land here; the original code
        # duplicated this exact construction in two branches.
        whisper_inf = FasterWhisperInference(
            model_dir=self.args.faster_whisper_model_dir,
            output_dir=self.args.output_dir
        )
    return whisper_inf

@staticmethod
Expand Down Expand Up @@ -366,7 +379,7 @@ def launch(self):

# Create the parser for command-line arguments
parser = argparse.ArgumentParser()
parser.add_argument('--whisper_type', type=str, default="faster-whisper", help='A type of the whisper implementation between: ["whisper", "faster-whisper"]')
parser.add_argument('--whisper_type', type=str, default="faster-whisper", help='A type of the whisper implementation between: ["whisper", "faster-whisper", "insanely-fast-whisper"]')
parser.add_argument('--share', type=bool, default=False, nargs='?', const=True, help='Gradio share value')
parser.add_argument('--server_name', type=str, default=None, help='Gradio server host')
parser.add_argument('--server_port', type=int, default=None, help='Gradio server port')
Expand All @@ -379,6 +392,8 @@ def launch(self):
parser.add_argument('--whisper_model_dir', type=str, default=os.path.join("models", "Whisper"), help='Directory path of the whisper model')
parser.add_argument('--faster_whisper_model_dir', type=str, default=os.path.join("models", "Whisper", "faster-whisper"), help='Directory path of the faster-whisper model')
parser.add_argument('--insanely_fast_whisper_model_dir', type=str, default=os.path.join("models", "Whisper", "insanely-fast-whisper"), help='Directory path of the insanely-fast-whisper model')
parser.add_argument('--nllb_model_dir', type=str, default=os.path.join("models", "NLLB"), help='Directory path of the Facebook NLLB model')
parser.add_argument('--output_dir', type=str, default=os.path.join("outputs"), help='Directory path of the outputs')
_args = parser.parse_args()

if __name__ == "__main__":
Expand Down
10 changes: 7 additions & 3 deletions modules/deepl_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,11 +82,14 @@


class DeepLAPI:
def __init__(self,
             output_dir: str
             ):
    """Hold DeepL API batching limits, language tables, and the output path.

    Parameters
    ----------
    output_dir: str
        Directory under which translated subtitle files are written.
    """
    # Throttling knobs: one-second pause between calls, at most 50 texts
    # per request batch.
    self.api_interval = 1
    self.max_text_batch_size = 50
    # Language tables come from the module-level DeepL constants.
    self.available_target_langs = DEEPL_AVAILABLE_TARGET_LANGS
    self.available_source_langs = DEEPL_AVAILABLE_SOURCE_LANGS
    self.output_dir = output_dir

def translate_deepl(self,
auth_key: str,
Expand All @@ -111,6 +114,7 @@ def translate_deepl(self,
Boolean value that is about pro user or not from gr.Checkbox().
progress: gr.Progress
Indicator to show progress directly in gradio.
Returns
----------
A List of
Expand Down Expand Up @@ -140,7 +144,7 @@ def translate_deepl(self,
timestamp = datetime.now().strftime("%m%d%H%M%S")

file_name = file_name[:-9]
output_path = os.path.join("outputs", "translations", f"{file_name}-{timestamp}.srt")
output_path = os.path.join(self.output_dir, "translations", f"{file_name}-{timestamp}.srt")
write_file(subtitle, output_path)

elif file_ext == ".vtt":
Expand All @@ -160,7 +164,7 @@ def translate_deepl(self,
timestamp = datetime.now().strftime("%m%d%H%M%S")

file_name = file_name[:-9]
output_path = os.path.join("outputs", "translations", f"{file_name}-{timestamp}.vtt")
output_path = os.path.join(self.output_dir, "translations", f"{file_name}-{timestamp}.vtt")

write_file(subtitle, output_path)

Expand Down
8 changes: 6 additions & 2 deletions modules/faster_whisper_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,13 @@


class FasterWhisperInference(WhisperBase):
def __init__(self):
def __init__(self,
model_dir: str,
output_dir: str
):
super().__init__(
model_dir=os.path.join("models", "Whisper", "faster-whisper")
model_dir=model_dir,
output_dir=output_dir
)
self.model_paths = self.get_model_paths()
self.available_models = self.model_paths.keys()
Expand Down
8 changes: 6 additions & 2 deletions modules/insanely_fast_whisper_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,13 @@


class InsanelyFastWhisperInference(WhisperBase):
def __init__(self):
def __init__(self,
model_dir: str,
output_dir: str
):
super().__init__(
model_dir=os.path.join("models", "Whisper", "insanely_fast_whisper")
model_dir=model_dir,
output_dir=output_dir
)
openai_models = whisper.available_models()
distil_models = ["distil-large-v2", "distil-large-v3", "distil-medium.en", "distil-small.en"]
Expand Down
8 changes: 6 additions & 2 deletions modules/nllb_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,13 @@


class NLLBInference(TranslationBase):
def __init__(self):
def __init__(self,
model_dir: str,
output_dir: str
):
super().__init__(
model_dir=os.path.join("models", "NLLB")
model_dir=model_dir,
output_dir=output_dir
)
self.tokenizer = None
self.available_models = ["facebook/nllb-200-3.3B", "facebook/nllb-200-1.3B", "facebook/nllb-200-distilled-600M"]
Expand Down
11 changes: 7 additions & 4 deletions modules/translation_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,14 @@

class TranslationBase(ABC):
def __init__(self,
             model_dir: str,
             output_dir: str):
    """Initialize shared translation-backend state.

    Parameters
    ----------
    model_dir: str
        Directory where translation model weights are stored/downloaded.
    output_dir: str
        Directory where translated files are written.
    """
    super().__init__()
    self.model = None
    self.current_model_size = None
    self.model_dir = model_dir
    self.output_dir = output_dir
    # Ensure both directories exist up front so later downloads/writes
    # never fail on a missing path.
    os.makedirs(self.model_dir, exist_ok=True)
    os.makedirs(self.output_dir, exist_ok=True)
    self.device = self.get_device()

Expand Down Expand Up @@ -87,7 +90,7 @@ def translate_file(self,

timestamp = datetime.now().strftime("%m%d%H%M%S")
if add_timestamp:
output_path = os.path.join("outputs", "translations", f"{file_name}-{timestamp}")
output_path = os.path.join("outputs", "translations", f"{file_name}-{timestamp}.srt")
else:
output_path = os.path.join("outputs", "translations", f"{file_name}.srt")

Expand All @@ -102,9 +105,9 @@ def translate_file(self,

timestamp = datetime.now().strftime("%m%d%H%M%S")
if add_timestamp:
output_path = os.path.join("outputs", "translations", f"{file_name}-{timestamp}")
output_path = os.path.join(self.output_dir, "translations", f"{file_name}-{timestamp}.vtt")
else:
output_path = os.path.join("outputs", "translations", f"{file_name}.vtt")
output_path = os.path.join(self.output_dir, "translations", f"{file_name}.vtt")

write_file(subtitle, output_path)
files_info[file_name] = subtitle
Expand Down
8 changes: 6 additions & 2 deletions modules/whisper_Inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,13 @@


class WhisperInference(WhisperBase):
def __init__(self,
             model_dir: str,
             output_dir: str
             ):
    """Construct the openai-whisper backend.

    Thin wrapper: all state setup (directory creation, model lists,
    device detection) happens in WhisperBase.
    """
    super().__init__(
        model_dir=model_dir,
        output_dir=output_dir
    )

def transcribe(self,
Expand Down
22 changes: 16 additions & 6 deletions modules/whisper_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,14 @@

class WhisperBase(ABC):
def __init__(self,
model_dir: str):
model_dir: str,
output_dir: str
):
self.model = None
self.current_model_size = None
self.model_dir = model_dir
self.output_dir = output_dir
os.makedirs(self.output_dir, exist_ok=True)
os.makedirs(self.model_dir, exist_ok=True)
self.available_models = whisper.available_models()
self.available_langs = sorted(list(whisper.tokenizer.LANGUAGES.values()))
Expand Down Expand Up @@ -88,7 +92,8 @@ def transcribe_file(self,
file_name=file_name,
transcribed_segments=transcribed_segments,
add_timestamp=add_timestamp,
file_format=file_format
file_format=file_format,
output_dir=self.output_dir
)
files_info[file_name] = {"subtitle": subtitle, "time_for_task": time_for_task, "path": file_path}

Expand Down Expand Up @@ -152,7 +157,8 @@ def transcribe_mic(self,
file_name="Mic",
transcribed_segments=transcribed_segments,
add_timestamp=True,
file_format=file_format
file_format=file_format,
output_dir=self.output_dir
)

result_str = f"Done in {self.format_time(time_for_task)}! Subtitle file is in the outputs folder.\n\n{subtitle}"
Expand Down Expand Up @@ -211,7 +217,8 @@ def transcribe_youtube(self,
file_name=file_name,
transcribed_segments=transcribed_segments,
add_timestamp=add_timestamp,
file_format=file_format
file_format=file_format,
output_dir=self.output_dir
)
result_str = f"Done in {self.format_time(time_for_task)}! Subtitle file is in the outputs folder.\n\n{subtitle}"

Expand All @@ -237,6 +244,7 @@ def generate_and_write_file(file_name: str,
transcribed_segments: list,
add_timestamp: bool,
file_format: str,
output_dir: str
) -> str:
"""
Writes subtitle file
Expand All @@ -251,6 +259,8 @@ def generate_and_write_file(file_name: str,
Determines whether to add a timestamp to the end of the filename.
file_format: str
File format to write. Supported formats: [SRT, WebVTT, txt]
output_dir: str
Directory path of the output
Returns
----------
Expand All @@ -261,9 +271,9 @@ def generate_and_write_file(file_name: str,
"""
timestamp = datetime.now().strftime("%m%d%H%M%S")
if add_timestamp:
output_path = os.path.join("outputs", f"{file_name}-{timestamp}")
output_path = os.path.join(output_dir, f"{file_name}-{timestamp}")
else:
output_path = os.path.join("outputs", f"{file_name}")
output_path = os.path.join(output_dir, f"{file_name}")

if file_format == "SRT":
content = get_srt(transcribed_segments)
Expand Down

0 comments on commit a230be5

Please sign in to comment.