diff --git a/README.md b/README.md
index 14454af..aff9e60 100644
--- a/README.md
+++ b/README.md
@@ -17,7 +17,7 @@ The required models have been downloaded beforehand from Hugging Face and saved
 
 ### Faster Whisper
 
-We use `large-v3`, currently the biggest and most accurate multilingual [Faster Whisper](https://github.com/SYSTRAN/faster-whisper) model available. Languages supported by the model are:
+We support `large-v2` and `large-v3` (default) multilingual [Faster Whisper](https://github.com/SYSTRAN/faster-whisper) models. Languages supported by the models are:
 
 afrikaans, arabic, armenian, azerbaijani, belarusian, bosnian, bulgarian, catalan,
 chinese, croatian, czech, danish, dutch, english, estonian, finnish, french, galician,
@@ -27,7 +27,11 @@ norwegian, persian, polish, portuguese, romanian, russian, serbian, slovak, slov
 spanish, swahili, swedish, tagalog, tamil, thai, turkish, ukrainian, urdu, vietnamese,
 welsh
 
-The model is covered by the [MIT licence]((https://huggingface.co/models?license=license:mit)) and has been pre-downloaded from Hugging Face to
+The models are covered by the [MIT licence](https://huggingface.co/models?license=license:mit) and have been pre-downloaded from Hugging Face to
+
+`/scratch/shareddata/dldata/huggingface-hub-cache/hub/models--Systran--faster-whisper-large-v2`
+
+and `/scratch/shareddata/dldata/huggingface-hub-cache/hub/models--Systran--faster-whisper-large-v3`
@@ -168,6 +172,9 @@ The audio file(s) can be in any common audio (.wav, .mp3, .aff, etc.) or video (
 The transcription and diarization results (.txt and .csv files) corresponding to each audio file
 will be written to `results/` next to the file. See [below](#output-formats) for details.
 
+> **Note:** While speech2text uses the `large-v3` model by default, users can specify the model with the `SPEECH2TEXT_WHISPER_MODEL` environment variable. Note, however, that only the `large-v2` and `large-v3` models have been pre-downloaded.
+
+
 ## Output formats
 
 The output formats are `.csv` and `.txt`. For example, output files corresponding to input audio files
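As a quick illustration of the new README note, this is the intended user-facing flow; a minimal sketch, assuming the `speech2text` wrapper is on `PATH` and `audiofile.mp3` is a placeholder filename:

```bash
# Default model (large-v3); language given explicitly, as recommended
export SPEECH2TEXT_LANGUAGE=finnish
speech2text audiofile.mp3

# Explicitly select the other pre-downloaded model
export SPEECH2TEXT_WHISPER_MODEL=large-v2
speech2text audiofile.mp3
```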
diff --git a/bin/speech2text b/bin/speech2text
index 973d889..cd49e33 100755
--- a/bin/speech2text
+++ b/bin/speech2text
@@ -77,12 +77,6 @@ do
     esac
 done
 
-# Set env variables
-export OMP_NUM_THREADS=$SPEECH2TEXT_CPUS_PER_TASK
-export OMP_PROC_BIND=true
-export KMP_AFFINITY=granularity=fine,compact,1,0
-export KMP_AFFINITY=granularity=fine,compact,1,0
-
 # Folder in which this script is located
 # https://stackoverflow.com/questions/39340169/dir-cd-dirname-bash-source0-pwd-how-does-that-work
 SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
diff --git a/modules/speech2text/20240130.lua b/modules/speech2text/20240130.lua
new file mode 100644
index 0000000..3103bf6
--- /dev/null
+++ b/modules/speech2text/20240130.lua
@@ -0,0 +1,81 @@
+help_text = [[
+
+This app does speech2text with diarization.
+
+Example run on a single file:
+
+    export SPEECH2TEXT_EMAIL=john.smith@aalto.fi
+    export SPEECH2TEXT_LANGUAGE=finnish
+    speech2text audiofile.mp3
+
+Example run on a folder containing one or more audio files:
+
+    export SPEECH2TEXT_EMAIL=jane.smith@aalto.fi
+    export SPEECH2TEXT_LANGUAGE=finnish
+    speech2text audiofiles/
+
+The audio files can be in any common audio (.wav, .mp3, .aff, etc.) or video (.mp4, .mov, etc.) format.
+
+The speech2text app writes result files to a subfolder results/ next to each audio file.
+Result filenames are the audio filename with .txt and .csv extensions. For example, result files
+corresponding to audiofile.mp3 are written to results/audiofile.txt and results/audiofile.csv.
+Result files for a folder audiofiles/ will be written to audiofiles/results/.
+
+Notification emails will be sent to SPEECH2TEXT_EMAIL. If SPEECH2TEXT_EMAIL is left
+unspecified, no notifications are sent.
+
+Supported languages are:
+
+afrikaans, arabic, armenian, azerbaijani, belarusian, bosnian, bulgarian, catalan,
+chinese, croatian, czech, danish, dutch, english, estonian, finnish, french, galician,
+german, greek, hebrew, hindi, hungarian, icelandic, indonesian, italian, japanese,
+kannada, kazakh, korean, latvian, lithuanian, macedonian, malay, marathi, maori, nepali,
+norwegian, persian, polish, portuguese, romanian, russian, serbian, slovak, slovenian,
+spanish, swahili, swedish, tagalog, tamil, thai, turkish, ukrainian, urdu, vietnamese,
+welsh
+
+You can leave the language variable SPEECH2TEXT_LANGUAGE unspecified, in which case
+speech2text tries to detect the language automatically. Specifying the language
+explicitly is, however, recommended.
+]]
+
+local version = "20240130"
+whatis("Name : Aalto speech2text")
+whatis("Version : " .. version)
+help(help_text)
+
+local speech2text = "/share/apps/manual_installations/speech2text/" .. version .. "/bin/"
+local conda_env = "/share/apps/manual_installations/speech2text/" .. version .. "/env/bin/"
+
+prepend_path("PATH", speech2text)
+prepend_path("PATH", conda_env)
+
+local hf_home = "/scratch/shareddata/dldata/huggingface-hub-cache/"
+local pyannote_cache = hf_home .. "hub/"
+local torch_home = "/scratch/shareddata/speech2text"
+local pyannote_config = "/share/apps/manual_installations/speech2text/" .. version .. "/pyannote/config.yml"
+local numba_cache = "/tmp"
+local mplconfigdir = "/tmp"
+
+pushenv("HF_HOME", hf_home)
+pushenv("PYANNOTE_CACHE", pyannote_cache)
+pushenv("TORCH_HOME", torch_home)
+pushenv("XDG_CACHE_HOME", torch_home)
+pushenv("PYANNOTE_CONFIG", pyannote_config)
+pushenv("NUMBA_CACHE_DIR", numba_cache)
+pushenv("MPLCONFIGDIR", mplconfigdir)
+
+local speech2text_mem = "8G"
+local speech2text_cpus_per_task = "6"
+local speech2text_tmp = os.getenv("WRKDIR") .. "/.speech2text"
+
+pushenv("SPEECH2TEXT_MEM", speech2text_mem)
+pushenv("SPEECH2TEXT_CPUS_PER_TASK", speech2text_cpus_per_task)
+pushenv("SPEECH2TEXT_TMP", speech2text_tmp)
+
+pushenv("HF_HUB_OFFLINE", "1")
+
+if mode() == "load" then
+    LmodMessage("For more information, run 'module spider speech2text/" .. version .. "'")
+end
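For context, the new module file would be taken into use with standard Lmod commands; the module name and version follow the `LmodMessage` hint above, and this sketch assumes the `modules/` tree is on `MODULEPATH`:

```bash
module load speech2text/20240130    # prepends PATH, sets cache locations and SPEECH2TEXT_* defaults
module spider speech2text/20240130  # prints the full help text, including supported languages
speech2text audiofile.mp3
```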
diff --git a/src/settings.py b/src/settings.py
index 592d93f..7f9bd73 100644
--- a/src/settings.py
+++ b/src/settings.py
@@ -58,5 +58,7 @@
     "welsh": "cy",
 }
 
-supported_languages_reverse = {value: key for key, value in supported_languages.items()}
+
+available_whisper_models = ["large-v2", "large-v3"]
+default_whisper_model = "large-v3"
diff --git a/src/speech2text.py b/src/speech2text.py
index 8b6d2b4..cd3457d 100644
--- a/src/speech2text.py
+++ b/src/speech2text.py
@@ -9,14 +9,12 @@
 from pathlib import Path
 from typing import Optional, Union
 
-import numpy as np
 import pandas as pd
 import torch
 import torch.multiprocessing as mp
 import whisperx
 from numba.core.errors import (NumbaDeprecationWarning,
                                NumbaPendingDeprecationWarning)
-from pydub import AudioSegment
 from whisperx.types import TranscriptionResult
 
 import settings
@@ -76,6 +74,12 @@ def get_argument_parser():
         default=os.getenv("SPEECH2TEXT_LANGUAGE"),
         help="Audio language. Optional but recommended.",
     )
+    parser.add_argument(
+        "--SPEECH2TEXT_WHISPER_MODEL",
+        type=str,
+        default=os.getenv("SPEECH2TEXT_WHISPER_MODEL"),
+        help=f"Whisper model. Defaults to {settings.default_whisper_model}.",
+    )
     return parser
@@ -218,12 +222,18 @@ def write_alignment_to_txt_file(alignment, output_file_stem):
 
 
 def load_whisperx_model(
-    name: str = "large-v3",
+    name: str,
     device: Optional[Union[str, torch.device]] = None,
 ):
     if device is None:
         device = "cuda" if torch.cuda.is_available() else "cpu"
 
+    if name not in settings.available_whisper_models:
+        logger.warning(
+            f"Specified model '{name}' not among available models: {settings.available_whisper_models}. Opting to use the default model '{settings.default_whisper_model}' instead"
+        )
+        name = settings.default_whisper_model
+
     compute_type = "float16" if device == "cuda" else "int8"
     try:
         model = whisperx.load_model(
@@ -254,9 +264,12 @@ def read_input_file_from_array_file(input_file, slurm_array_task_id):
     return new_input_file
 
 
-def transcribe(file: str, language: str, result_list) -> TranscriptionResult:
+def transcribe(
+    file: str, model_name: str, language: str, result_list
+) -> TranscriptionResult:
     batch_size = calculate_max_batch_size()
-    model = load_whisperx_model()
+    model = load_whisperx_model(model_name)
+
     try:
         segs, _ = model.transcribe(
             file, batch_size=batch_size, language=language
@@ -317,9 +330,32 @@ def main():
 
     logger.info(f".. .. Wav conversion done in {time.time()-t0:.1f} seconds")
 
+    # Use the default Whisper model if none is given
+    model_name = args.SPEECH2TEXT_WHISPER_MODEL
+    if model_name is None:
+        model_name = settings.default_whisper_model
+
+    # Check language if given
     language = args.SPEECH2TEXT_LANGUAGE
-    if language and language.lower() in settings.supported_languages:
-        language = settings.supported_languages[language.lower()]
+    if language:
+        if language.lower() in settings.supported_languages.keys():
+            # Language is given in long form: convert to short form (two-letter abbreviation)
+            language = settings.supported_languages[language.lower()]
+        elif language.lower() in settings.supported_languages.values():
+            # Language is given in short form
+            pass
+        else:
+            # Given language is not supported
+            pretty_language_list = ", ".join(
+                [
+                    f"{lang} ({short})"
+                    for lang, short in settings.supported_languages.items()
+                ]
+            )
+            logger.warning(
+                f"Given language '{language}' not found among supported languages: {pretty_language_list}. Opting to detect language automatically"
+            )
+            language = None
 
     with mp.Manager() as manager:
         shared_dict = manager.dict()
@@ -328,6 +364,7 @@
             target=transcribe,
             args=(
                 input_file_wav,
+                model_name,
                 language,
                 shared_dict,
             ),
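A sketch of the new language and model handling for a direct (non-Slurm) run; the `src/` path is assumed from the diff headers, and `klingon` is a deliberately unsupported placeholder:

```bash
# Long-form and two-letter language codes are now both accepted
export SPEECH2TEXT_LANGUAGE=fi            # equivalent to: finnish
export SPEECH2TEXT_WHISPER_MODEL=large-v3
python3 src/speech2text.py audiofile.wav

# An unsupported language only logs a warning and falls back to automatic detection
export SPEECH2TEXT_LANGUAGE=klingon
python3 src/speech2text.py audiofile.wav
```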
diff --git a/src/submit.py b/src/submit.py
index a8caac5..e120681 100644
--- a/src/submit.py
+++ b/src/submit.py
@@ -11,7 +11,6 @@
 import os
 import shlex
 import subprocess
-import time
 from pathlib import Path, PosixPath
 
 import settings
@@ -62,6 +61,12 @@ def get_argument_parser():
         default=os.getenv("SPEECH2TEXT_LANGUAGE"),
         help="Language. Optional.",
     )
+    parser.add_argument(
+        "--SPEECH2TEXT_WHISPER_MODEL",
+        type=str,
+        default=os.getenv("SPEECH2TEXT_WHISPER_MODEL"),
+        help=f"Whisper model. Default is {settings.default_whisper_model}.",
+    )
     return parser
@@ -195,6 +200,8 @@ def create_sbatch_script_for_array_job(
 #SBATCH --mail-type=BEGIN
 #SBATCH --mail-type=END
 #SBATCH --mail-type=FAIL
+export OMP_NUM_THREADS={cpus_per_task}
+export KMP_AFFINITY=granularity=fine,compact
 python3 {python_source_dir}/speech2text.py {input_file}
 """
     tmp_file_sh = (Path(tmp_dir) / str(job_name)).with_suffix(".sh")
@@ -327,6 +334,24 @@ def check_email(email):
     )
 
 
+def check_whisper_model(name):
+    if name is None:
+        print(
+            f"Whisper model not given, using default '{settings.default_whisper_model}'.\n"
+        )
+        return True
+
+    elif name in settings.available_whisper_models:
+        print(f"Given Whisper model '{name}' is available.\n")
+        return True
+
+    print(
+        f"Submission failed: Given Whisper model '{name}' is not among available models:\n\n{' '.join(settings.available_whisper_models)}.\n"
+    )
+
+    return False
+
+
 def main():
     # Parse arguments
     parser = get_argument_parser()
@@ -343,6 +368,10 @@ def main():
     # Check email
     check_email(args.SPEECH2TEXT_EMAIL)
 
+    # Check Whisper model name
+    if not check_whisper_model(args.SPEECH2TEXT_WHISPER_MODEL):
+        return
+
     # Notify about temporary folder location
     print(
         f"Log files (.out) and batch submit scripts (.sh) will be written to: {args.SPEECH2TEXT_TMP}\n"
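Finally, a sketch of the submit-time validation added above, assuming the `speech2text` wrapper forwards the environment variable to `submit.py`; `tiny` is a hypothetical model name that has not been pre-downloaded, and the quoted messages are the ones defined in `check_whisper_model`:

```bash
# An available model passes the check and submission proceeds
export SPEECH2TEXT_WHISPER_MODEL=large-v2
speech2text audiofile.mp3
# -> Given Whisper model 'large-v2' is available.

# Anything else aborts before any Slurm job is created
export SPEECH2TEXT_WHISPER_MODEL=tiny
speech2text audiofile.mp3
# -> Submission failed: Given Whisper model 'tiny' is not among available models: ...
```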