add type hints for code consistency
Hossein Firooz committed Mar 8, 2024
1 parent 382b779 commit 35a018e
Showing 3 changed files with 298 additions and 50 deletions.
133 changes: 109 additions & 24 deletions src/speech2text.py
@@ -3,6 +3,7 @@
import json
import logging
import os
import subprocess
import time
import warnings
from collections import defaultdict
@@ -85,7 +86,26 @@ def get_argument_parser():
return parser


def compute_overlap(start1, end1, start2, end2):
def compute_overlap(start1: float, end1: float, start2: float, end2: float) -> float:
"""
Compute the overlap between two segments.
Parameters
----------
start1 : float
Start time of the first segment.
end1 : float
End time of the first segment.
start2 : float
Start time of the second segment.
end2 : float
End time of the second segment.
Returns
-------
float:
The overlap in time between the two segments.
"""
if start1 > end1 or start2 > end2:
raise ValueError("Start of segment can't be larger than its end.")

@@ -98,7 +118,7 @@ def compute_overlap(start1, end1, start2, end2):
return abs(end_overlap - start_overlap)
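
A quick sanity check of the intended behavior (the middle of the function body is elided in this hunk; values are illustrative):

# Segments [0.0, 5.0] and [3.0, 8.0] share the window [3.0, 5.0]:
assert compute_overlap(0.0, 5.0, 3.0, 8.0) == 2.0
# Reversed bounds are rejected:
# compute_overlap(5.0, 0.0, 3.0, 8.0)  -> raises ValueError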


def align(segments, diarization):
def align(segments, diarization) -> dict:
"""
Align diarization with transcription.
@@ -107,15 +127,6 @@ def align(segments, diarization):
If no diarization segment overlaps with a given transcription segment, the speaker
for that transcription segment is None.
The output object is a dict of lists:
{
"start" : [0.0, 4.5, 7.0],
"end" : [3.3, 6.0, 10.0],
"transcription" : ["This is first first speaker segment", "This is the second", "This is from an unknown speaker"],
"speaker": ["SPEAKER_00", "SPEAKER_01", None]
}
Parameters
----------
transcription : list
@@ -125,7 +136,14 @@ def align(segments, diarization):
Returns
-------
dict
dict:
The output object is a dict of lists:
{
"start" : [0.0, 4.5, 7.0],
"end" : [3.3, 6.0, 10.0],
"transcription" : ["This is first first speaker segment", "This is the second", "This is from an unknown speaker"],
"speaker": ["SPEAKER_00", "SPEAKER_01", None]
}
"""
transcription_segments = [
(segment["start"], segment["end"], segment["text"]) for segment in segments
@@ -157,11 +175,24 @@ def align(segments, diarization):
return alignment
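
A usage sketch mirroring the docstring's example. The shape of the diarization argument is an assumption here, since the hunk hides how it is iterated:

segments = [
    {"start": 0.0, "end": 3.3, "text": "This is the first speaker segment"},
    {"start": 4.5, "end": 6.0, "text": "This is the second"},
]
# Assumed shape: (start, end, speaker) turns, e.g. derived from a pyannote result.
diarization = [(0.0, 3.5, "SPEAKER_00"), (4.4, 6.1, "SPEAKER_01")]

alignment = align(segments, diarization)
# Expected, per the docstring:
# {"start": [0.0, 4.5], "end": [3.3, 6.0],
#  "transcription": ["This is the first speaker segment", "This is the second"],
#  "speaker": ["SPEAKER_00", "SPEAKER_01"]}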


def parse_output_file_stem(output_dir, input_file):
def parse_output_file_stem(output_dir: str, input_file: str) -> Path:
"""
Build the output file stem from the input file name and the output directory.
"""
return Path(output_dir) / Path(Path(input_file).name)
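
A quick illustration (paths are hypothetical):

stem = parse_output_file_stem("results", "/data/audio/interview.wav")
# stem == Path("results/interview.wav"); the suffix is replaced later,
# e.g. by with_suffix(".csv") in write_alignment_to_csv_file below.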


def write_alignment_to_csv_file(alignment, output_file_stem):
def write_alignment_to_csv_file(alignment: dict, output_file_stem: Path):
"""
Write the alignment to a CSV file.
Parameters
----------
alignment : dict
The alignment dictionary for start, end, speaker, and transcription.
output_file_stem : Path
The output file stem; the suffix is set per output format.
"""
df = pd.DataFrame.from_dict(alignment)
output_file = str(Path(output_file_stem).with_suffix(".csv"))
df.to_csv(
@@ -173,7 +204,17 @@ def write_alignment_to_csv_file(alignment, output_file_stem):
logger.info(f".. .. Wrote CSV output to: {output_file}")
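
For instance, reusing the alignment dict from the align docstring (values illustrative):

alignment = {
    "start": [0.0, 4.5],
    "end": [3.3, 6.0],
    "transcription": ["This is the first speaker segment", "This is the second"],
    "speaker": ["SPEAKER_00", "SPEAKER_01"],
}
write_alignment_to_csv_file(alignment, Path("results/interview.wav"))
# -> writes results/interview.csv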


def write_alignment_to_txt_file(alignment, output_file_stem):
def write_alignment_to_txt_file(alignment: dict, output_file_stem: Path):
"""
Write the alignment data to a text file.
Parameters
----------
alignment : dict
The alignment dictionary for start, end, speaker, and transcription.
output_file_stem : Path
The output file stem; the suffix is set per output format.
"""
# Group lines by speaker
all_lines_grouped_by_speaker = []
lines_speaker = []
@@ -224,21 +265,29 @@ def write_alignment_to_txt_file(alignment, output_file_stem):

def load_whisperx_model(
name: str,
device: Optional[Union[str, torch.device]] = None,
):
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
language: Optional[str] = None,
device: Optional[Union[str, torch.device]] = 'cuda'):
"""
Load a Whisper model on GPU.
Raises an error if CUDA is not available, because the batch_size optimization in utils.py requires a GPU.
The submitted script runs on a GPU node, so in practice this should only fail on a hardware failure.
"""
if not torch.cuda.is_available():
raise ValueError("CUDA is not available. Check for hardware failures on " + subprocess.check_output(['hostname']).decode())

if name not in settings.available_whisper_models:
logger.warning(
f"Specified model '{name}' not among available models: {settings.available_whisper_models}. Opting to use the default model '{settings.default_whisper_model}' instead"
)
name = settings.default_whisper_model

compute_type = "float16" if device == "cuda" else "int8"
compute_type = "float16"
try:
model = whisperx.load_model(
name,
language=language,
device=device,
threads=6,
compute_type=compute_type,
@@ -247,14 +296,20 @@ def load_whisperx_model(
compute_type = "float32"
model = whisperx.load_model(
name,
language=language,
device=device,
threads=6,
compute_type=compute_type,
)
return model
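
Usage is then a one-liner; the model name below is illustrative and must appear in settings.available_whisper_models, otherwise the default model is substituted:

model = load_whisperx_model("large-v2", language="en")  # assumes a CUDA device is present

Falling back from float16 to float32 keeps GPUs without efficient half-precision support working, at the cost of memory and speed.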


def read_input_file_from_array_file(input_file, slurm_array_task_id):
def read_input_file_from_array_file(input_file: str, slurm_array_task_id: str):
"""
Read a single audio path from a JSON file with an array of audio paths.
Returns the audio path at the given index.
"""
logger.info(f".. Read item {slurm_array_task_id} from {input_file}")
input_files = []
with open(input_file, "r") as fin:
@@ -266,15 +321,30 @@ def read_input_file_from_array_file(input_file, slurm_array_task_id):


def transcribe(
file: str, model_name: str, language: str, result_list
file: str, model_name: str, language: str, result_list: dict
) -> TranscriptionResult:
"""
The main transcription function, based on WhisperX.
Parameters
----------
file : str
The input audio file.
model_name : str
The Whisper model name.
language : str
The language of the audio. If not set, the language is detected automatically.
result_list : dict
The dictionary to store the result.
"""
batch_size = calculate_max_batch_size()
model = load_whisperx_model(model_name)
model = load_whisperx_model(model_name, language)

try:
segs, _ = model.transcribe(
file, batch_size=batch_size, language=language
).values()
# If the batch size is too large, halve it and retry to avoid a CUDA out-of-memory error.
except RuntimeError:
logger.warning(
f"Current CUDA device {torch.cuda.current_device()} doesn't have enough memory. Reducing batch_size {batch_size} by half."
@@ -292,7 +362,21 @@ def transcribe(
result_list["transcribe_time"] = time.time()
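
The except branch above performs a single halve-and-retry; written generically, the pattern looks like this minimal sketch (not the repository's code):

def run_with_batch_backoff(fn, batch_size: int, min_batch_size: int = 1):
    """Call fn(batch_size), halving the batch on a RuntimeError (CUDA OOM)."""
    while True:
        try:
            return fn(batch_size)
        except RuntimeError:
            if batch_size <= min_batch_size:
                raise  # cannot shrink further; re-raise
            batch_size //= 2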


def diarization(file: str, config: str, token: str, result_list):
def diarization(file: str, config: str, token: str, result_list: dict):
"""
The main diarization function, based on the pyannote model.
Parameters
----------
file : str
The input audio file.
config : str
Configuration for the pyannote model.
token : str
Token to access the HF model if the config file is not available.
result_list : dict
The dictionary to store the result.
"""
diarization_pipeline = DiarizationPipeline(config_file=config, auth_token=token)
diarization = diarization_pipeline(file)

@@ -358,6 +442,7 @@ def main():
)
language = None

# Create two separate processes for transcription and diarization using torch multiprocessing
with mp.Manager() as manager:
shared_dict = manager.dict()

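Beyond this point the diff is truncated; a minimal sketch of the two-process pattern the comment above describes, assuming mp is torch.multiprocessing and using a stand-in worker (names are illustrative):

import torch.multiprocessing as mp

def demo_worker(tag: str, result_list) -> None:
    # Stand-in for transcribe(...) / diarization(...), which write into result_list.
    result_list[tag] = f"{tag} finished"

if __name__ == "__main__":
    with mp.Manager() as manager:
        shared_dict = manager.dict()
        workers = [
            mp.Process(target=demo_worker, args=("transcription", shared_dict)),
            mp.Process(target=demo_worker, args=("diarization", shared_dict)),
        ]
        for w in workers:
            w.start()
        for w in workers:
            w.join()
        print(dict(shared_dict))  # both results collected in the shared dict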
