whisperX Integration (#3)
* add whisperx dependencies

Forcing torch-related packages to download from the pytorch channel; otherwise, CPU-only builds are installed.
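
A quick sanity check after building the environment — a minimal sketch, assuming `torch` imports from the new environment — to confirm that a CUDA build (rather than a CPU-only build) was installed:

```python
import torch

# On a CPU-only build, torch.version.cuda is None and CUDA is never available,
# even when the job actually runs on a GPU node.
print(torch.__version__, torch.version.cuda, torch.cuda.is_available())
```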

* add get length function for audio file

* add whisperX compatibility

* added diarization function

* Added diarization pipeline

Resampling and converting to wav are done in the pipeline, so every file format is now compatible with offline diarization.

* add required dependencies

* add parallel computation

Transcription and diarization can now be run in parallel, which almost doubles the speed.
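
The pattern is roughly the one used in `src/speech2text.py` below: two `torch.multiprocessing` processes writing into a shared manager dict. A minimal sketch with placeholder workers (the real workers call whisperX and pyannote):

```python
import torch.multiprocessing as mp

def transcribe(path, results):
    # placeholder for the whisperX transcription call
    results["segments"] = f"segments of {path}"

def diarize(path, results):
    # placeholder for the pyannote diarization call
    results["diarization"] = f"speaker turns of {path}"

if __name__ == "__main__":
    mp.set_start_method("spawn")  # needed when CUDA tensors cross process boundaries
    with mp.Manager() as manager:
        results = manager.dict()
        jobs = [
            mp.Process(target=transcribe, args=("audio.wav", results)),
            mp.Process(target=diarize, args=("audio.wav", results)),
        ]
        for job in jobs:
            job.start()
        for job in jobs:
            job.join()
        print(dict(results))
```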

* better logging

Logging from multiprocessing workers can be challenging. Removing the logger from the worker function ensures that the output comes out in order and is less confusing.

* fix get audio length

Some input files had corrupted headers, which caused ffprobe to return N/A for the audio duration. The duration is now fetched from the ffmpeg downsampling step.
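
A minimal sketch of the idea — decode with ffmpeg and derive the duration from the decoded samples, so a broken header cannot produce N/A (illustrative helper, not the repo's exact function; assumes `ffmpeg` is on PATH):

```python
import subprocess
import numpy as np

SAMPLE_RATE = 16_000

def load_audio_and_duration(path: str) -> tuple[np.ndarray, float]:
    # Decode and downsample to 16 kHz mono PCM; the duration follows from the
    # number of decoded samples, independent of the container's header.
    cmd = [
        "ffmpeg", "-nostdin", "-i", path,
        "-f", "s16le", "-ac", "1", "-ar", str(SAMPLE_RATE), "-",
    ]
    raw = subprocess.run(cmd, capture_output=True, check=True).stdout
    audio = np.frombuffer(raw, dtype=np.int16).astype(np.float32) / 32768.0
    return audio, len(audio) / SAMPLE_RATE
```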

* added required tag for input file

* fix diarize time error

* added support for GPU sbatch script

* optimization of batch size for each GPU

* increase upper limit estimation

With high CPU usage on some nodes, loading the pipeline takes more time.

* add support for float32 GPU

The Tesla P100 GPU only supports float32. Single-precision calculations require more VRAM, so the batch_size needs to be adjusted.
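
A hypothetical heuristic along these lines — scale the batch size to available VRAM and halve it when the card has to fall back to float32 (illustrative only; the real `calculate_max_batch_size` in `src/utils.py` may differ):

```python
import torch

def estimate_batch_size(bytes_per_item: int = 512 * 1024**2) -> int:
    # Rough VRAM budget per batch item; the numbers are purely illustrative.
    if not torch.cuda.is_available():
        return 8
    total = torch.cuda.get_device_properties(torch.cuda.current_device()).total_memory
    batch = max(1, int(total // bytes_per_item))
    major, _ = torch.cuda.get_device_capability()
    if major < 7:     # cards below compute capability 7.0 (e.g. Tesla P100) run float32 here
        batch //= 2   # single precision roughly doubles the memory per batch item
    return max(1, batch)
```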

* DOC: clean comments, typos

* DOC: update readme and user guide

* fix iteration over PosixPath bug

* black + isort

---------

Co-authored-by: Teemu Ruokolainen <[email protected]>
hsnfirooz and ruokolt authored Jan 30, 2024
1 parent 9fbc1c3 commit a67f271
Showing 7 changed files with 353 additions and 148 deletions.
11 changes: 1 addition & 10 deletions README.md
@@ -1,6 +1,6 @@
# speech2text

This repo contains instructions for setting up and applying the speech2text app on Aalto Triton cluster. The app utilizes [Faster Whisper](https://github.com/SYSTRAN/faster-whisper) automatic speech recognition tool and [Pyannote](https://huggingface.co/pyannote/speaker-diarization) speaker detection (diarization) pipeline. The speech recognition and diarization steps are run sequentially (and independently) and their result segments are combined (aligned) using a simple algorithm which for each transcription segment finds the most overlapping (in time) speaker segment.
This repo contains instructions for setting up and applying the speech2text app on Aalto Triton cluster. The app utilizes [WhisperX](https://github.com/m-bain/whisperX) automatic speech recognition tool and [Pyannote](https://huggingface.co/pyannote/speaker-diarization) speaker detection (diarization) pipeline. The speech recognition and diarization steps are run independently and their result segments are combined (aligned) using a simple algorithm which for each transcription segment finds the most overlapping (in time) speaker segment.
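
An illustrative sketch of that overlap rule (not the app's exact `align()` implementation): for each transcription segment, pick the speaker segment with the largest temporal overlap.

```python
def most_overlapping_speaker(t_start, t_end, speaker_segments):
    """speaker_segments: list of (start, end, speaker) tuples from diarization."""
    def overlap(s_start, s_end):
        return max(0.0, min(t_end, s_end) - max(t_start, s_start))

    best = max(speaker_segments, key=lambda s: overlap(s[0], s[1]), default=None)
    return best[2] if best is not None and overlap(best[0], best[1]) > 0 else None

# Transcription segment 3.0-6.0 s overlaps SPEAKER_01 (2.5 s) more than SPEAKER_00 (1.0 s).
print(most_overlapping_speaker(3.0, 6.0, [(0.0, 4.0, "SPEAKER_00"),
                                          (3.5, 10.0, "SPEAKER_01")]))
```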

The required models are described [here](#models).

@@ -142,7 +142,6 @@ MPLCONFIGDIR
SPEECH2TEXT_TMP
SPEECH2TEXT_MEM
SPEECH2TEXT_CPUS_PER_TASK
SPEECH2TEXT_TIME
```

Note that you can leave the language variable unspecified, in which case speech2text tries to detect the language automatically. Specifying the language explicitly is, however, recommended.
@@ -259,14 +258,6 @@ The documentation can be found in `docs/build/`. A good place to start is the in

## Known Issues

### Inference using CPUs versus GPUs

The recommended way to do inference with Whisper is to use GPUs. However, on Triton, we have to make a compromise between GPU queues and inference efficiency. All the scripts use CPUs by default.

### Increasing the number of CPUs for inference

There is a plateauing problem with running Whisper inference with multiple CPUs (not GPUs). Increasing the number of CPUs speeds up inference until around 8 CPUs but plateaus and begins to slow down after 16. See related discussion where same behavior has been observed: [https://github.com/ggerganov/whisper.cpp/issues/200](https://github.com/ggerganov/whisper.cpp/issues/200) Therefore, in all the scripts, the number of CPUs is set to 8 by default.

### Audio files with more than one language

If a single audio file contains speech in more than one language, result files will (probably) still be produced but the results will (probably) be nonsensical to some extent. This is because even when using automatic language detection, Whisper appears to [detect the first language it encounters (if not given specifically) and stick to it until the end of the audio file, translating other encountered languages to the first language](https://github.com/openai/whisper/discussions/49).
16 changes: 3 additions & 13 deletions docs/source/user_guide.md
@@ -59,8 +59,6 @@ Go to [Open On Demand](http://ood.triton.aalto.fi) and log in with your Aalto us

## Copy your data to Triton

> **_NOTE:_** If you are familiar with Triton usage, feel free to use any of the available approaches [to connect to Triton](https://scicomp.aalto.fi/triton/ref/#connecting) and [to transfer you data](https://scicomp.aalto.fi/triton/tut/remotedata/) and skip directly to next section.
On the Open On Demand front page, click the `Files` dropdown menu from the left upper corner and select `Work /scratch/work/yourusername`.

![](images/files_workspace.png)
@@ -158,7 +156,7 @@ These variables are valid until the end of current terminal session.
>```bash
>unset SPEECH2TEXT_EMAIL
>```
> This is equal to not running the `export [email protected] command` command in the first place. However, receiving the notifications is recommended.
> This is equal to not running the `export [email protected]` command in the first place. However, receiving the notifications is recommended.
Finally, submit all the audio files in your folder to the Triton job queue (remember to replace `my-audio-folder` with the name of the folder you just uploaded) with
```
@@ -205,7 +203,7 @@ slurm queue
```
which tells you for each job ID if the job is still in the queue waiting for resources (_PENDING_) or already running (_RUNNING_).

>**_NOTE:_** As a rule of thumb, the results will be ready at the next day latest. However, if you receive an email saying the processing has failed or have not received any emails within, say, an hour of running the speech2text command, something has gone wrong. In this case, visit RSEs at [the daily Zoom help session at 13.00-14.00](https://scicomp.aalto.fi/help/garage/#id1) and we will figure it out.
>**_NOTE:_** As a rule of thumb, you can expect the results to be ready within an hour. However, if you receive an email saying the processing has failed or have not received any emails within an hour of running the speech2text command, something has gone wrong. In this case, visit RSEs at [the daily Zoom help session at 13.00-14.00](https://scicomp.aalto.fi/help/garage/#id1) and we will figure it out.
If you have no more work to submit at this time, you are free to close the terminal window and log out from Open On Demand. If the browser asks for confirmation (`This page is asking you to confirm that you want to leave — information you’ve entered may not be saved.`), you can answer "yes".

@@ -297,15 +295,7 @@ If you do not need your audio and/or result files and/or folders, you can remove
### My transcription has a weird segment where a word or two are repeated over and over.

This is a quite known issue with the OpenAI Whisper speech recognition model. This behavior is sometimes triggered
by bad audio quality during that segment (background noise, mic issues, people talking over each other). However, sometimes this seems to happen even with good audio quality. Unfortunately, there is nothing we can do about this at the moment: you have to go through that particular audio segment and transcribe it manually.

### My speech2text process ran over night and I got noted that the job failed due to time limit.

The run time of speech2text on a single audio file is limited to 24 hours by default. If you have very large audio files (several hours), you can try setting the maximum run time to a larger value, e.g. 72 hours, with
```
export SPEECH2TEXT_TIME=72:00:00
```
Run the speech2text on your file/folder again normally according to the guide [above](#run-speech2text-on-triton).
by bad audio quality during that segment (background noise, mic issues, people talking over each other). However, sometimes this seems to happen even with good audio quality. Unfortunately, there is nothing we can do about this at the moment: you have to go through that particular audio segment and transcribe it manually.

### I accidentally closed the browser tab/window when speech2text was still running.

9 changes: 6 additions & 3 deletions env.yml
@@ -2,15 +2,18 @@ name: speech2text
channels:
- conda-forge
- pytorch
- nvidia
dependencies:
- git
- ffmpeg
- libsndfile
- python=3.10
- pydub
- pytorch
- torchvision
- pytorch::pytorch
- pytorch::torchvision
- pytorch::torchaudio
- pytorch::pytorch-cuda==11.8
- pip
- pip:
- pyannote.audio
- faster-whisper
- whisperx @ git+https://github.com/m-bain/whisperx.git
4 changes: 2 additions & 2 deletions src/settings.py
@@ -55,8 +55,8 @@
"ukrainian": "uk",
"urdu": "ur",
"vietnamese": "vi",
"welsh": "cy"
"welsh": "cy",
}


supported_languages_reverse = {value : key for key, value in supported_languages.items()}
supported_languages_reverse = {value: key for key, value in supported_languages.items()}
206 changes: 107 additions & 99 deletions src/speech2text.py
@@ -1,4 +1,5 @@
import argparse
import gc
import json
import logging
import os
@@ -7,19 +8,21 @@
from collections import defaultdict
from pathlib import Path
from typing import Optional, Union
from pydub import AudioSegment

import faster_whisper
import numpy as np
import pandas as pd
import torch
import torch.multiprocessing as mp
import whisperx
from numba.core.errors import (NumbaDeprecationWarning,
NumbaPendingDeprecationWarning)
from pyannote.audio import Pipeline

from submit import parse_output_dir
from utils import seconds_to_human_readable_format
from pydub import AudioSegment
from whisperx.types import TranscriptionResult

import settings
from submit import parse_output_dir
from utils import (DiarizationPipeline, calculate_max_batch_size, load_audio,
seconds_to_human_readable_format)

# https://numba.pydata.org/numba-doc/dev/reference/deprecation.html
warnings.simplefilter("ignore", category=NumbaDeprecationWarning)
@@ -34,6 +37,10 @@
logger = logging.getLogger("__name__")
logging.getLogger("faster_whisper").setLevel(logging.WARNING)

# Sharing CUDA tensors between processes requires a spawn or forkserver start method
if __name__ == "__main__":
mp.set_start_method("spawn")


def get_argument_parser():
parser = argparse.ArgumentParser(
@@ -116,11 +123,10 @@ def align(segments, diarization):
dict
"""
transcription_segments = [
(segment.start, segment.end, segment.text) for segment in segments
(segment["start"], segment["end"], segment["text"]) for segment in segments
]
diarization_segments = [
(segment.start, segment.end, speaker)
for segment, _, speaker in diarization.itertracks(yield_label=True)
(start, end, speaker) for _, _, speaker, start, end in diarization.to_numpy()
]
alignment = defaultdict(list)
for transcription_start, transcription_end, text in transcription_segments:
@@ -211,45 +217,30 @@ def write_alignment_to_txt_file(alignment, output_file_stem):
logger.info(f".. .. Wrote TXT output to: {output_file}")


def load_faster_whisper_model(
def load_whisperx_model(
name: str = "large-v3",
device: Optional[Union[str, torch.device]] = None,
):
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"

model = faster_whisper.WhisperModel(
name,
device=device,
cpu_threads=6,
compute_type="int8",
)

return model


def load_diarization_pipeline(config_file, auth_token):
"""
For more info on the config file, see 'Offline use' at:
https://github.com/pyannote/pyannote-audio/blob/develop/tutorials/applying_a_pipeline.ipynb
"""

if Path(config_file).is_file():
logger.info(".. .. Local config file found")
pipeline = Pipeline.from_pretrained(config_file)
elif auth_token:
logger.info(".. .. Environment variable AUTH_TOKEN found")
pipeline = Pipeline.from_pretrained(
"pyannote/speaker-diarization-3.1",
use_auth_token=auth_token,
compute_type = "float16" if device == "cuda" else "int8"
try:
model = whisperx.load_model(
name,
device=device,
threads=6,
compute_type=compute_type,
)
else:
logger.error(
"One of these is required: local pyannote config file or environment variable AUTH_TOKEN to download model from HuggingFace hub"
except ValueError:
compute_type = "float32"
model = whisperx.load_model(
name,
device=device,
threads=6,
compute_type=compute_type,
)
raise ValueError

return pipeline
return model


def read_input_file_from_array_file(input_file, slurm_array_task_id):
@@ -263,33 +254,36 @@ def read_input_file_from_array_file(input_file, slurm_array_task_id):
return new_input_file


def convert_to_wav(input_file, tmp_dir):
"""Pyannote diarization pipeline does handle resampling to ensure 16 kHz and
stereo/mono mixing. However, number of supported audio/video formats appears to be
limited and not listed in README. To be sure, we convert all files to .wav beforehand.
def transcribe(file: str, language: str, result_list) -> TranscriptionResult:
batch_size = calculate_max_batch_size()
model = load_whisperx_model()
try:
segs, _ = model.transcribe(
file, batch_size=batch_size, language=language
).values()
except RuntimeError:
logger.warning(
f"Current CUDA device {torch.cuda.current_device()} doesn't have enough memory. Reducing batch_size {batch_size} by half."
)

https://huggingface.co/pyannote/speaker-diarization-3.1
"""
gc.collect()
torch.cuda.empty_cache()

if str(input_file).lower().endswith(".wav"):
logger.info(f".. .. File is already in wav format: {input_file}")
return input_file
batch_size /= 2
segs, _ = model.transcribe(
file, batch_size=int(batch_size), language=language
).values()

if not Path(input_file).is_file():
logger.info(f".. .. File does not exist: {input_file}")
return None
result_list["segments"] = segs
result_list["transcribe_time"] = time.time()

converted_file = Path(tmp_dir) / Path(Path(input_file).name).with_suffix(".wav")
if Path(converted_file).is_file():
logger.info(f".. .. Converted file {converted_file} already exists.")
return converted_file
try:
AudioSegment.from_file(input_file).export(converted_file, format="wav")
logger.info(f".. .. File converted to wav: {converted_file}")
return converted_file
except Exception as err:
logger.info(f".. .. Error while converting file: {err}")
return None

def diarization(file: str, config: str, token: str, result_list):
diarization_pipeline = DiarizationPipeline(config_file=config, auth_token=token)
diarization = diarization_pipeline(file)

result_list["diarization"] = diarization
result_list["diarize_time"] = time.time()


def main():
@@ -311,59 +305,73 @@ def main():
)

# .wav conversion
logger.info(f".. Convert input file to wav format for pyannote diarization pipeline: {args.INPUT_FILE}")
logger.info(
f".. Convert input file to wav format for pyannote diarization pipeline: {args.INPUT_FILE}"
)
t0 = time.time()
input_file_wav = convert_to_wav(args.INPUT_FILE, args.SPEECH2TEXT_TMP)
if input_file_wav is None:
try:
input_file_wav, _ = load_audio(args.INPUT_FILE)
except Exception as e:
logger.error(f".. .. Input file could not be converted: {args.INPUT_FILE}")
return
logger.info(f".. .. Wav conversion done in {time.time()-t0:.1f} seconds")

# Diarization
logger.info(".. Load diarization pipeline")
t0 = time.time()
diarization_pipeline = load_diarization_pipeline(args.PYANNOTE_CONFIG, args.AUTH_TOKEN)
logger.info(f".. .. Pipeline loaded in {time.time()-t0:.1f} seconds")
raise (e)

logger.info(f".. Diarize input file: {input_file_wav}")
t0 = time.time()
diarization = diarization_pipeline(input_file_wav)
logger.info(f".. .. Diarization finished in {time.time()-t0:.1f} seconds")

# Transcription
logger.info(".. Load faster_whisper model")
t0 = time.time()
faster_whisper_model = load_faster_whisper_model()
logger.info(f".. .. Model loaded in {time.time()-t0:.1f} seconds")
logger.info(f".. .. Wav conversion done in {time.time()-t0:.1f} seconds")

logger.info(f".. Transcribe input file: {args.INPUT_FILE}")
t0 = time.time()
language = args.SPEECH2TEXT_LANGUAGE
if language and language.lower() in settings.supported_languages:
language = settings.supported_languages[language.lower()]
segments, info = faster_whisper_model.transcribe(
args.INPUT_FILE, language=language, beam_size=5
)
if language is None:
logger.info(f".. .. Automatically detected language '{settings.supported_languages_reverse[info.language]}' with probability {info.language_probability:.2f}")
segments = list(segments)
logger.info(f".. .. Transcription finished in {time.time()-t0:.1f} seconds")

with mp.Manager() as manager:
shared_dict = manager.dict()

process1 = mp.Process(
target=transcribe,
args=(
input_file_wav,
language,
shared_dict,
),
)
process2 = mp.Process(
target=diarization,
args=(
input_file_wav,
args.PYANNOTE_CONFIG,
args.AUTH_TOKEN,
shared_dict,
),
)

t0 = time.time()
logger.info(f".. Starting transcription task for {args.INPUT_FILE}")
process1.start()

logger.info(f".. Starting diarization task for {args.INPUT_FILE}")
process2.start()

process1.join()
process2.join()

logger.info(
f".. .. Transcription finished in {shared_dict['transcribe_time']-t0:.1f} seconds"
)
logger.info(
f".. .. Diarization finished in {shared_dict['diarize_time']-t0:.1f} seconds"
)

segments = shared_dict["segments"]
diarization_results = shared_dict["diarization"]

# Alignment
logger.info(".. Align transcription and diarization")
alignment = align(segments, diarization)
alignment = align(segments, diarization_results)

logger.info(f".. Write alignment to output")
output_dir = parse_output_dir(args.INPUT_FILE)
output_file_stem = parse_output_file_stem(output_dir, args.INPUT_FILE)
write_alignment_to_csv_file(alignment, output_file_stem)
write_alignment_to_txt_file(alignment, output_file_stem)

# Clean up
if input_file_wav != args.INPUT_FILE:
logger.info(f".. Remove the converted wav file")
Path(input_file_wav).unlink()

logger.info(f"Finished.")

