Skip to content

Commit

Permalink
Merge pull request #173 from jhj0517/fix/refactor-scalability
Browse files Browse the repository at this point in the history
Refactor for scalability across several Whisper implementation types
  • Loading branch information
jhj0517 authored Jun 18, 2024
2 parents d868316 + 091209e commit d843d51
Show file tree
Hide file tree
Showing 8 changed files with 550 additions and 730 deletions.
34 changes: 21 additions & 13 deletions app.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import gradio as gr
import os
import argparse
import webbrowser

from modules.whisper_Inference import WhisperInference
from modules.faster_whisper_inference import FasterWhisperInference
Expand All @@ -16,17 +15,26 @@ class App:
def __init__(self, args):
self.args = args
self.app = gr.Blocks(css=CSS, theme=self.args.theme)
self.whisper_inf = WhisperInference() if self.args.disable_faster_whisper else FasterWhisperInference()
if isinstance(self.whisper_inf, FasterWhisperInference):
self.whisper_inf.model_dir = args.faster_whisper_model_dir
print("Use Faster Whisper implementation")
else:
self.whisper_inf.model_dir = args.whisper_model_dir
print("Use Open AI Whisper implementation")
self.whisper_inf = self.init_whisper()
print(f"Use \"{self.args.whisper_type}\" implementation")
print(f"Device \"{self.whisper_inf.device}\" is detected")
self.nllb_inf = NLLBInference()
self.deepl_api = DeepLAPI()

def init_whisper(self):
    """
    Build the Whisper inference backend selected by ``self.args.whisper_type``.

    The configured value is lower-cased and stripped before matching:

    - ``"whisper"`` selects ``WhisperInference`` and uses
      ``self.args.whisper_model_dir`` as its ``model_dir``.
    - ``"faster_whisper"`` / ``"faster-whisper"`` and any unrecognised value
      select ``FasterWhisperInference`` and use
      ``self.args.faster_whisper_model_dir`` as its ``model_dir``.

    Returns:
        The constructed inference object with ``model_dir`` already assigned.
    """
    whisper_type = self.args.whisper_type.lower().strip()

    # The branches are mutually exclusive so exactly one backend is
    # constructed. Previously the faster-whisper branch was followed by an
    # independent ``if``/``else``, which built a second FasterWhisperInference
    # and threw the first one away.
    if whisper_type == "whisper":
        whisper_inf = WhisperInference()
        whisper_inf.model_dir = self.args.whisper_model_dir
    else:
        # Covers "faster_whisper", "faster-whisper" and the fallback for
        # unknown values (faster-whisper is the default implementation).
        whisper_inf = FasterWhisperInference()
        whisper_inf.model_dir = self.args.faster_whisper_model_dir
    return whisper_inf

@staticmethod
def open_folder(folder_path: str):
if os.path.exists(folder_path):
Expand Down Expand Up @@ -61,7 +69,7 @@ def launch(self):
cb_translate = gr.Checkbox(value=False, label="Translate to English?", interactive=True)
with gr.Row():
cb_timestamp = gr.Checkbox(value=True, label="Add a timestamp to the end of the filename", interactive=True)
with gr.Accordion("VAD Options", open=False, visible=not self.args.disable_faster_whisper):
with gr.Accordion("VAD Options", open=False, visible=isinstance(self.whisper_inf, FasterWhisperInference)):
cb_vad_filter = gr.Checkbox(label="Enable Silero VAD Filter", value=False, interactive=True)
sd_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="Speech Threshold", value=0.5)
nb_min_speech_duration_ms = gr.Number(label="Minimum Speech Duration (ms)", precision=0, value=250)
Expand Down Expand Up @@ -135,7 +143,7 @@ def launch(self):
with gr.Row():
cb_timestamp = gr.Checkbox(value=True, label="Add a timestamp to the end of the filename",
interactive=True)
with gr.Accordion("VAD Options", open=False, visible=not self.args.disable_faster_whisper):
with gr.Accordion("VAD Options", open=False, visible=isinstance(self.whisper_inf, FasterWhisperInference)):
cb_vad_filter = gr.Checkbox(label="Enable Silero VAD Filter", value=False, interactive=True)
sd_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="Speech Threshold", value=0.5)
nb_min_speech_duration_ms = gr.Number(label="Minimum Speech Duration (ms)", precision=0, value=250)
Expand Down Expand Up @@ -201,7 +209,7 @@ def launch(self):
dd_file_format = gr.Dropdown(["SRT", "WebVTT", "txt"], value="SRT", label="File Format")
with gr.Row():
cb_translate = gr.Checkbox(value=False, label="Translate to English?", interactive=True)
with gr.Accordion("VAD Options", open=False, visible=not self.args.disable_faster_whisper):
with gr.Accordion("VAD Options", open=False, visible=isinstance(self.whisper_inf, FasterWhisperInference)):
cb_vad_filter = gr.Checkbox(label="Enable Silero VAD Filter", value=False, interactive=True)
sd_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="Speech Threshold", value=0.5)
nb_min_speech_duration_ms = gr.Number(label="Minimum Speech Duration (ms)", precision=0, value=250)
Expand Down Expand Up @@ -289,7 +297,7 @@ def launch(self):

with gr.TabItem("NLLB"): # sub tab2
with gr.Row():
dd_nllb_model = gr.Dropdown(label="Model", value=self.nllb_inf.default_model_size,
dd_nllb_model = gr.Dropdown(label="Model", value="facebook/nllb-200-1.3B",
choices=self.nllb_inf.available_models)
dd_nllb_sourcelang = gr.Dropdown(label="Source Language",
choices=self.nllb_inf.available_source_langs)
Expand Down Expand Up @@ -332,7 +340,7 @@ def launch(self):

# Create the parser for command-line arguments
parser = argparse.ArgumentParser()
parser.add_argument('--disable_faster_whisper', type=bool, default=False, nargs='?', const=True, help='Disable the faster_whisper implementation. faster_whipser is implemented by https://github.com/guillaumekln/faster-whisper')
parser.add_argument('--whisper_type', type=str, default="faster-whisper", help='A type of the whisper implementation between: ["whisper", "faster-whisper"]')
parser.add_argument('--share', type=bool, default=False, nargs='?', const=True, help='Gradio share value')
parser.add_argument('--server_name', type=str, default=None, help='Gradio server host')
parser.add_argument('--server_port', type=int, default=None, help='Gradio server port')
Expand Down
23 changes: 0 additions & 23 deletions modules/base_interface.py

This file was deleted.

Loading

0 comments on commit d843d51

Please sign in to comment.