Skip to content

Commit

Permalink
ADD: available whisper models and default whisper model
Browse files Browse the repository at this point in the history
  • Loading branch information
ruokolt committed Feb 1, 2024
1 parent 7f12ce6 commit fe16279
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 6 deletions.
4 changes: 3 additions & 1 deletion src/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,5 +58,7 @@
"welsh": "cy",
}


supported_languages_reverse = {value: key for key, value in supported_languages.items()}

available_whisper_models = ["large-v2", "large-v3"]
default_whisper_model = "large-v3"
22 changes: 18 additions & 4 deletions src/speech2text.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from pathlib import Path
from typing import Optional, Union

import numpy as np
import pandas as pd
import torch
import torch.multiprocessing as mp
Expand Down Expand Up @@ -76,6 +75,12 @@ def get_argument_parser():
default=os.getenv("SPEECH2TEXT_LANGUAGE"),
help="Audio language. Optional but recommended.",
)
parser.add_argument(
"--SPEECH2TEXT_WHISPER_MODEL",
type=str,
default=os.getenv("SPEECH2TEXT_WHISPER_MODEL"),
help=f"Whisper model. Defaults to {settings.default_whisper_model}.",
)

return parser

Expand Down Expand Up @@ -218,7 +223,7 @@ def write_alignment_to_txt_file(alignment, output_file_stem):


def load_whisperx_model(
name: str = "large-v3",
name: str,
device: Optional[Union[str, torch.device]] = None,
):
if device is None:
Expand Down Expand Up @@ -254,9 +259,11 @@ def read_input_file_from_array_file(input_file, slurm_array_task_id):
return new_input_file


def transcribe(file: str, language: str, result_list) -> TranscriptionResult:
def transcribe(
file: str, model_name: str, language: str, result_list
) -> TranscriptionResult:
batch_size = calculate_max_batch_size()
model = load_whisperx_model()
model = load_whisperx_model(model_name)
try:
segs, _ = model.transcribe(
file, batch_size=batch_size, language=language
Expand Down Expand Up @@ -317,6 +324,12 @@ def main():

logger.info(f".. .. Wav conversion done in {time.time()-t0:.1f} seconds")

# Check Whisper model name if given
model_name = args.SPEECH2TEXT_WHISPER_MODEL
if model_name is None:
model_name = settings.default_whisper_model

# Check language if given
language = args.SPEECH2TEXT_LANGUAGE
if language and language.lower() in settings.supported_languages:
language = settings.supported_languages[language.lower()]
Expand All @@ -328,6 +341,7 @@ def main():
target=transcribe,
args=(
input_file_wav,
model_name,
language,
shared_dict,
),
Expand Down
29 changes: 28 additions & 1 deletion src/submit.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import os
import shlex
import subprocess
import time
from pathlib import Path, PosixPath

import settings
Expand Down Expand Up @@ -62,6 +61,12 @@ def get_argument_parser():
default=os.getenv("SPEECH2TEXT_LANGUAGE"),
help="Language. Optional.",
)
parser.add_argument(
"--SPEECH2TEXT_WHISPER_MODEL",
type=str,
default=os.getenv("SPEECH2TEXT_WHISPER_MODEL"),
help=f"Whisper model. Default is {settings.default_whisper_model}.",
)

return parser

Expand Down Expand Up @@ -329,6 +334,24 @@ def check_email(email):
)


def check_whisper_model(name):
if name is None:
print(
f"Whisper model not given, using default '{settings.default_whisper_model}'.\n"
)
return True

elif name in settings.available_whisper_models:
print(f"Given Whisper model '{name}' is available.\n")
return True

print(
f"Submission failed: Given Whisper model '{name}' is not among available models:\n\n{' '.join(settings.available_whisper_models)}.\n"
)

return False


def main():
# Parse arguments
parser = get_argument_parser()
Expand All @@ -345,6 +368,10 @@ def main():
# Check email
check_email(args.SPEECH2TEXT_EMAIL)

# Check Whisper model name
if not check_whisper_model(args.SPEEHCH2TEXT_WHISPER_MODEL):
return

# Notify about temporary folder location
print(
f"Log files (.out) and batch submit scripts (.sh) will be written to: {args.SPEECH2TEXT_TMP}\n"
Expand Down

0 comments on commit fe16279

Please sign in to comment.