Skip to content

Commit

Permalink
Version 20240130 (#5)
Browse files Browse the repository at this point in the history
Create a new version modules/speech2text/20240130 compatible with the migration from whisper to whisperx and from large-v2 to large-v3. The model version is controlled using a new environment variable SPEECH2TEXT_WHISPER_MODEL, default is "large-v3".
  • Loading branch information
ruokolt authored Feb 8, 2024
1 parent a67f271 commit cfb5ba3
Show file tree
Hide file tree
Showing 6 changed files with 167 additions and 17 deletions.
11 changes: 9 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ The required models have been downloaded beforehand from Hugging Face and saved

### Faster Whisper

We use `large-v3`, currently the biggest and most accurate multilingual [Faster Whisper](https://github.com/SYSTRAN/faster-whisper) model available. Languages supported by the model are:
We support `large-v2` and `large-v3` (default) multilingual [Faster Whisper](https://github.com/SYSTRAN/faster-whisper) models. Languages supported by the models are:

afrikaans, arabic, armenian, azerbaijani, belarusian, bosnian, bulgarian, catalan,
chinese, croatian, czech, danish, dutch, english, estonian, finnish, french, galician,
Expand All @@ -27,7 +27,11 @@ norwegian, persian, polish, portuguese, romanian, russian, serbian, slovak, slov
spanish, swahili, swedish, tagalog, tamil, thai, turkish, ukrainian, urdu, vietnamese,
welsh

The model is covered by the [MIT licence]((https://huggingface.co/models?license=license:mit)) and has been pre-downloaded from Hugging Face to
The models are covered by the [MIT licence](https://huggingface.co/models?license=license:mit) and have been pre-downloaded from Hugging Face to

`/scratch/shareddata/dldata/huggingface-hub-cache/hub/models--Systran--faster-whisper-large-v2`

and

`/scratch/shareddata/dldata/huggingface-hub-cache/hub/models--Systran--faster-whisper-large-v3`

Expand Down Expand Up @@ -168,6 +172,9 @@ The audio file(s) can be in any common audio (.wav, .mp3, .aff, etc.) or video (
The transcription and diarization results (.txt and .csv files) corresponding to each audio file will be written to `results/` next to the file. See [below](#output-formats) for details.


> **__NOTE:__** While speech2text by default uses the `large-v3` model, the user can specify the model with the `SPEECH2TEXT_WHISPER_MODEL` environment variable. Note, however, that only the `large-v2` and `large-v3` models have been pre-downloaded.

## Output formats

The output formats are `.csv` and `.txt`. For example, output files corresponding to input audio files
Expand Down
6 changes: 0 additions & 6 deletions bin/speech2text
Original file line number Diff line number Diff line change
Expand Up @@ -77,12 +77,6 @@ do
esac
done

# Set env variables
export OMP_NUM_THREADS=$SPEECH2TEXT_CPUS_PER_TASK
export OMP_PROC_BIND=true
export KMP_AFFINITY=granularity=fine,compact,1,0
export KMP_AFFINITY=granularity=fine,compact,1,0

# Folder in which this script is located
# https://stackoverflow.com/questions/39340169/dir-cd-dirname-bash-source0-pwd-how-does-that-work
SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
Expand Down
81 changes: 81 additions & 0 deletions modules/speech2text/20240130.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
-- Lmod modulefile for the Aalto speech2text app (whisperx-based version).
-- Puts the wrapper scripts on PATH and points all model caches at the
-- shared scratch area so nothing is downloaded at run time.
-- Fix: help_text was an accidental global (now local); WRKDIR is now
-- checked explicitly instead of failing with a cryptic concat-nil error.
local help_text = [[
This app does speech2text with diarization.
Example run on a single file:
export [email protected]
export SPEECH2TEXT_LANGUAGE=finnish
speech2text audiofile.mp3
Example run on a folder containing one or more audio file:
export [email protected]
export SPEECH2TEXT_LANGUAGE=finnish
speech2text audiofiles/
The audio files can be in any common audio (.wav, .mp3, .aff, etc.) or video (.mp4, .mov, etc.) format.
The speech2text app writes result files to a subfolder results/ next to each audio file.
Result filenames are the audio filename with .txt and .csv extensions. For example, result files
corresponding to audiofile.mp3 are written to results/audiofile.txt and results/audiofile.csv.
Result files in a folder audiofiles/ will be written to folder audiofiles/results/.
Notification emails will be sent to SPEECH2TEXT_EMAIL. If SPEECH2TEXT_EMAIL is left
unspecified, no notifications are sent.
Supported languages are:
afrikaans, arabic, armenian, azerbaijani, belarusian, bosnian, bulgarian, catalan,
chinese, croatian, czech, danish, dutch, english, estonian, finnish, french, galician,
german, greek, hebrew, hindi, hungarian, icelandic, indonesian, italian, japanese,
kannada, kazakh, korean, latvian, lithuanian, macedonian, malay, marathi, maori, nepali,
norwegian, persian, polish, portuguese, romanian, russian, serbian, slovak, slovenian,
spanish, swahili, swedish, tagalog, tamil, thai, turkish, ukrainian, urdu, vietnamese,
welsh
You can leave the language variable SPEECH2TEXT_LANGUAGE unspecified, in which case
speech2text tries to detect the language automatically. Specifying the language
explicitly is, however, recommended.
]]

-- Version string doubles as the installation directory name below.
local version = "20240130"

whatis("Name : Aalto speech2text")
whatis("Version :" .. version)
help(help_text)

-- Installation prefix for this version of the app.
local install_root = "/share/apps/manual_installations/speech2text/" .. version
local speech2text = install_root .. "/bin/"
local conda_env = install_root .. "/env/bin/"

-- The conda environment is prepended last, so it takes PATH precedence
-- over the wrapper script directory.
prepend_path("PATH", speech2text)
prepend_path("PATH", conda_env)

-- Model caches on shared scratch (pre-downloaded Hugging Face models).
local hf_home = "/scratch/shareddata/dldata/huggingface-hub-cache/"
local pyannote_cache = hf_home .. "hub/"
local torch_home = "/scratch/shareddata/speech2text"
local pyannote_config = install_root .. "/pyannote/config.yml"
-- numba and matplotlib need writable cache directories; /tmp always is.
local numba_cache = "/tmp"
local mplconfigdir = "/tmp"

pushenv("HF_HOME", hf_home)
pushenv("PYANNOTE_CACHE", pyannote_cache)
pushenv("TORCH_HOME", torch_home)
pushenv("XDG_CACHE_HOME", torch_home)
pushenv("PYANNOTE_CONFIG", pyannote_config)
pushenv("NUMBA_CACHE_DIR", numba_cache)
pushenv("MPLCONFIGDIR", mplconfigdir)

-- Default Slurm resources consumed by the submit scripts.
local speech2text_mem = "8G"
local speech2text_cpus_per_task = "6"

-- Guard: fail with a clear message if WRKDIR is unset, instead of the
-- original "attempt to concatenate a nil value" error.
local wrkdir = os.getenv("WRKDIR")
if wrkdir == nil then
  LmodError("speech2text: environment variable WRKDIR is not set")
end
local speech2text_tmp = wrkdir .. "/.speech2text"

pushenv("SPEECH2TEXT_MEM", speech2text_mem)
pushenv("SPEECH2TEXT_CPUS_PER_TASK", speech2text_cpus_per_task)
pushenv("SPEECH2TEXT_TMP", speech2text_tmp)

-- Force Hugging Face libraries to use only the local cache (no network).
pushenv("HF_HUB_OFFLINE", "1")

if mode() == "load" then
  LmodMessage("For more information, run 'module spider speech2text/" .. version .. "'")
end

4 changes: 3 additions & 1 deletion src/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,5 +58,7 @@
"welsh": "cy",
}


supported_languages_reverse = {value: key for key, value in supported_languages.items()}

# Whisper checkpoints pre-downloaded to the shared model cache; submit.py
# rejects other names and speech2text.py falls back to the default.
available_whisper_models = ["large-v2", "large-v3"]
# Used when SPEECH2TEXT_WHISPER_MODEL is unset or not an available model.
default_whisper_model = "large-v3"
51 changes: 44 additions & 7 deletions src/speech2text.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,12 @@
from pathlib import Path
from typing import Optional, Union

import numpy as np
import pandas as pd
import torch
import torch.multiprocessing as mp
import whisperx
from numba.core.errors import (NumbaDeprecationWarning,
NumbaPendingDeprecationWarning)
from pydub import AudioSegment
from whisperx.types import TranscriptionResult

import settings
Expand Down Expand Up @@ -76,6 +74,12 @@ def get_argument_parser():
default=os.getenv("SPEECH2TEXT_LANGUAGE"),
help="Audio language. Optional but recommended.",
)
parser.add_argument(
"--SPEECH2TEXT_WHISPER_MODEL",
type=str,
default=os.getenv("SPEECH2TEXT_WHISPER_MODEL"),
help=f"Whisper model. Defaults to {settings.default_whisper_model}.",
)

return parser

Expand Down Expand Up @@ -218,12 +222,18 @@ def write_alignment_to_txt_file(alignment, output_file_stem):


def load_whisperx_model(
name: str = "large-v3",
name: str,
device: Optional[Union[str, torch.device]] = None,
):
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"

if name not in settings.available_whisper_models:
logger.warning(
f"Specified model '{name}' not among available models: {settings.available_whisper_models}. Opting to use the default model '{settings.default_whisper_model}' instead"
)
name = settings.default_whisper_model

compute_type = "float16" if device == "cuda" else "int8"
try:
model = whisperx.load_model(
Expand Down Expand Up @@ -254,9 +264,12 @@ def read_input_file_from_array_file(input_file, slurm_array_task_id):
return new_input_file


def transcribe(file: str, language: str, result_list) -> TranscriptionResult:
def transcribe(
file: str, model_name: str, language: str, result_list
) -> TranscriptionResult:
batch_size = calculate_max_batch_size()
model = load_whisperx_model()
model = load_whisperx_model(model_name)

try:
segs, _ = model.transcribe(
file, batch_size=batch_size, language=language
Expand Down Expand Up @@ -317,9 +330,32 @@ def main():

logger.info(f".. .. Wav conversion done in {time.time()-t0:.1f} seconds")

# Check Whisper model name if given
model_name = args.SPEECH2TEXT_WHISPER_MODEL
if model_name is None:
model_name = settings.default_whisper_model

# Check language if given
language = args.SPEECH2TEXT_LANGUAGE
if language and language.lower() in settings.supported_languages:
language = settings.supported_languages[language.lower()]
if language:
if language.lower() in settings.supported_languages.keys():
# Language is given in OK long form: convert to short form (two-letter abbreviation)
language = settings.supported_languages[language.lower()]
elif language.lower() in settings.supported_languages.values():
# Language is given in OK short form
pass
else:
# Given language not OK
pretty_language_list = ", ".join(
[
f"{lang} ({short})"
for lang, short in settings.supported_languages.items()
]
)
logger.warning(
f"Given language '{language}' not found among supported languages: {pretty_language_list}. Opting to detect language automatically"
)
language = None

with mp.Manager() as manager:
shared_dict = manager.dict()
Expand All @@ -328,6 +364,7 @@ def main():
target=transcribe,
args=(
input_file_wav,
model_name,
language,
shared_dict,
),
Expand Down
31 changes: 30 additions & 1 deletion src/submit.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import os
import shlex
import subprocess
import time
from pathlib import Path, PosixPath

import settings
Expand Down Expand Up @@ -62,6 +61,12 @@ def get_argument_parser():
default=os.getenv("SPEECH2TEXT_LANGUAGE"),
help="Language. Optional.",
)
parser.add_argument(
"--SPEECH2TEXT_WHISPER_MODEL",
type=str,
default=os.getenv("SPEECH2TEXT_WHISPER_MODEL"),
help=f"Whisper model. Default is {settings.default_whisper_model}.",
)

return parser

Expand Down Expand Up @@ -195,6 +200,8 @@ def create_sbatch_script_for_array_job(
#SBATCH --mail-type=BEGIN
#SBATCH --mail-type=END
#SBATCH --mail-type=FAIL
export OMP_NUM_THREADS={cpus_per_task}
export KMP_AFFINITY=granularity=fine,compact
python3 {python_source_dir}/speech2text.py {input_file}
"""
tmp_file_sh = (Path(tmp_dir) / str(job_name)).with_suffix(".sh")
Expand Down Expand Up @@ -327,6 +334,24 @@ def check_email(email):
)


def check_whisper_model(name):
    """Validate a requested Whisper model name before job submission.

    A missing name (None) is accepted: the pipeline will use the default
    model. Prints a human-readable status message in every case.

    :param name: model name from SPEECH2TEXT_WHISPER_MODEL, or None
    :return: True if submission may proceed, False otherwise
    """
    # No model requested: fall back to the configured default.
    if name is None:
        print(
            f"Whisper model not given, using default '{settings.default_whisper_model}'.\n"
        )
        return True

    # Requested model is one of the pre-downloaded checkpoints.
    if name in settings.available_whisper_models:
        print(f"Given Whisper model '{name}' is available.\n")
        return True

    # Unknown model: refuse the submission and list the valid choices.
    print(
        f"Submission failed: Given Whisper model '{name}' is not among available models:\n\n{' '.join(settings.available_whisper_models)}.\n"
    )
    return False


def main():
# Parse arguments
parser = get_argument_parser()
Expand All @@ -343,6 +368,10 @@ def main():
# Check email
check_email(args.SPEECH2TEXT_EMAIL)

# Check Whisper model name
if not check_whisper_model(args.SPEECH2TEXT_WHISPER_MODEL):
return

# Notify about temporary folder location
print(
f"Log files (.out) and batch submit scripts (.sh) will be written to: {args.SPEECH2TEXT_TMP}\n"
Expand Down

0 comments on commit cfb5ba3

Please sign in to comment.