Skip to content

Commit

Permalink
Merge pull request #290 from jhj0517/fix/defaults
Browse files Browse the repository at this point in the history
Add defaults to functions
  • Loading branch information
jhj0517 authored Sep 24, 2024
2 parents 6a7425a + aa11c47 commit 29cce95
Show file tree
Hide file tree
Showing 6 changed files with 21 additions and 20 deletions.
2 changes: 1 addition & 1 deletion modules/translation/nllb_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def update_model(self,
model_size: str,
src_lang: str,
tgt_lang: str,
progress: gr.Progress
progress: gr.Progress = gr.Progress()
):
if model_size != self.current_model_size or self.model is None:
print("\nInitializing NLLB Model..\n")
Expand Down
2 changes: 1 addition & 1 deletion modules/translation/translation_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def update_model(self,
model_size: str,
src_lang: str,
tgt_lang: str,
progress: gr.Progress
progress: gr.Progress = gr.Progress()
):
pass

Expand Down
4 changes: 2 additions & 2 deletions modules/whisper/faster_whisper_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def __init__(self,

def transcribe(self,
audio: Union[str, BinaryIO, np.ndarray],
progress: gr.Progress,
progress: gr.Progress = gr.Progress(),
*whisper_params,
) -> Tuple[List[dict], float]:
"""
Expand Down Expand Up @@ -126,7 +126,7 @@ def transcribe(self,
def update_model(self,
model_size: str,
compute_type: str,
progress: gr.Progress
progress: gr.Progress = gr.Progress()
):
"""
Update current model setting
Expand Down
4 changes: 2 additions & 2 deletions modules/whisper/insanely_fast_whisper_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def __init__(self,

def transcribe(self,
audio: Union[str, np.ndarray, torch.Tensor],
progress: gr.Progress,
progress: gr.Progress = gr.Progress(),
*whisper_params,
) -> Tuple[List[dict], float]:
"""
Expand Down Expand Up @@ -98,7 +98,7 @@ def transcribe(self,
def update_model(self,
model_size: str,
compute_type: str,
progress: gr.Progress,
progress: gr.Progress = gr.Progress(),
):
"""
Update current model setting
Expand Down
4 changes: 2 additions & 2 deletions modules/whisper/whisper_Inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def __init__(self,

def transcribe(self,
audio: Union[str, np.ndarray, torch.Tensor],
progress: gr.Progress,
progress: gr.Progress = gr.Progress(),
*whisper_params,
) -> Tuple[List[dict], float]:
"""
Expand Down Expand Up @@ -79,7 +79,7 @@ def progress_callback(progress_value):
def update_model(self,
model_size: str,
compute_type: str,
progress: gr.Progress,
progress: gr.Progress = gr.Progress(),
):
"""
Update current model setting
Expand Down
25 changes: 13 additions & 12 deletions modules/whisper/whisper_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def __init__(self,
@abstractmethod
def transcribe(self,
audio: Union[str, BinaryIO, np.ndarray],
progress: gr.Progress,
progress: gr.Progress = gr.Progress(),
*whisper_params,
):
"""Inference whisper model to transcribe"""
Expand All @@ -63,7 +63,7 @@ def transcribe(self,
def update_model(self,
model_size: str,
compute_type: str,
progress: gr.Progress
progress: gr.Progress = gr.Progress()
):
"""Initialize whisper model"""
pass
Expand Down Expand Up @@ -171,10 +171,10 @@ def run(self,
return result, elapsed_time

def transcribe_file(self,
files: list,
input_folder_path: str,
file_format: str,
add_timestamp: bool,
files: Optional[List] = None,
input_folder_path: Optional[str] = None,
file_format: str = "SRT",
add_timestamp: bool = True,
progress=gr.Progress(),
*whisper_params,
) -> list:
Expand Down Expand Up @@ -250,8 +250,8 @@ def transcribe_file(self,

def transcribe_mic(self,
mic_audio: str,
file_format: str,
add_timestamp: bool,
file_format: str = "SRT",
add_timestamp: bool = True,
progress=gr.Progress(),
*whisper_params,
) -> list:
Expand Down Expand Up @@ -306,8 +306,8 @@ def transcribe_mic(self,

def transcribe_youtube(self,
youtube_link: str,
file_format: str,
add_timestamp: bool,
file_format: str = "SRT",
add_timestamp: bool = True,
progress=gr.Progress(),
*whisper_params,
) -> list:
Expand Down Expand Up @@ -411,11 +411,12 @@ def generate_and_write_file(file_name: str,
else:
output_path = os.path.join(output_dir, f"{file_name}")

if file_format == "SRT":
file_format = file_format.strip().lower()
if file_format == "srt":
content = get_srt(transcribed_segments)
output_path += '.srt'

elif file_format == "WebVTT":
elif file_format == "webvtt":
content = get_vtt(transcribed_segments)
output_path += '.vtt'

Expand Down

0 comments on commit 29cce95

Please sign in to comment.