Skip to content

Commit

Permalink
lint: black + isort
Browse files (browse the repository at this point in the history)
  • Loading branch information
ruokolt committed Mar 13, 2024
1 parent 4d0c41a commit f5c7a22
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 32 deletions.
10 changes: 7 additions & 3 deletions src/speech2text.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,16 +266,20 @@ def write_alignment_to_txt_file(alignment: dict, output_file_stem: Path):
def load_whisperx_model(
name: str,
language: Optional[str] = None,
device: Optional[Union[str, torch.device]] = 'cuda'):
device: Optional[Union[str, torch.device]] = "cuda",
):
"""
Load a Whisper model in GPU.
Will raise an error if CUDA is not available. This is due to batch_size optimization method in utils.py.
The submitted script will run on a GPU node, so this should not be a problem. The only issue is with a
The submitted script will run on a GPU node, so this should not be a problem. The only issue is with a
hardware failure.
"""
if not torch.cuda.is_available():
raise ValueError("CUDA is not available. Check the hardware failures for " + subprocess.check_output(['hostname']).decode())
raise ValueError(
"CUDA is not available. Check the hardware failures for "
+ subprocess.check_output(["hostname"]).decode()
)

if name not in settings.available_whisper_models:
logger.warning(
Expand Down
56 changes: 28 additions & 28 deletions src/submit.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@
warnings.filterwarnings("ignore")

import argparse
from argparse import Namespace
import json
import os
import re
import shlex
import subprocess
from argparse import Namespace
from pathlib import Path, PosixPath

import settings
Expand Down Expand Up @@ -73,7 +73,7 @@ def get_argument_parser():
return parser


def get_existing_result_files(input_file: str, output_dir: str) -> 'tuple[list, list]':
def get_existing_result_files(input_file: str, output_dir: str) -> "tuple[list, list]":
"""
For the input file or folder, check if the expected result files exist already in the output directory.
Expand Down Expand Up @@ -106,12 +106,12 @@ def get_existing_result_files(input_file: str, output_dir: str) -> 'tuple[list,
def parse_job_name(input_path: str) -> Path:
"""
Convert input file/folder to path object.
Parameters
----------
input_path: str
The input path for the audio files.
Returns
-------
Path
Expand All @@ -120,16 +120,15 @@ def parse_job_name(input_path: str) -> Path:
return Path(input_path).name


def parse_output_dir(input_path: str,
create_if_not_exists: bool = True) -> str:
def parse_output_dir(input_path: str, create_if_not_exists: bool = True) -> str:
"""
Create the output directory for the results.
Parameters
----------
input_path: str
The input path for the audio files.
Returns
-------
output_dir: str
Expand All @@ -148,10 +147,9 @@ def parse_output_dir(input_path: str,
return output_dir


def create_array_input_file(input_dir: str,
output_dir: str,
job_name: Path,
tmp_dir) -> str:
def create_array_input_file(
input_dir: str, output_dir: str, job_name: Path, tmp_dir
) -> str:
"""
Process the input directory and create a json file with the list of audio files to process.
Expand All @@ -175,14 +173,16 @@ def create_array_input_file(input_dir: str,
input_files = []
for input_file in Path(input_dir).glob("*.*"):
try:
result = subprocess.run(["ffmpeg", "-i", str(input_file)], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
result = subprocess.run(
["ffmpeg", "-i", str(input_file)],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)
except subprocess.CalledProcessError as e:
print(f"Error processing {input_file}: {e}")
continue
if "Audio:" not in str(result.stderr):
print(
f".. {input_file}: Skip since it's not an audio file."
)
print(f".. {input_file}: Skip since it's not an audio file.")
continue
existing, missing = get_existing_result_files(input_file, output_dir)
if existing and not missing:
Expand Down Expand Up @@ -259,13 +259,15 @@ def estimate_job_time(input_path: PosixPath) -> str:
return add_durations(PIPELINE_LOADING_TIME, audio_processing_time)


def create_sbatch_script_for_array_job(input_file: str,
job_name: Path,
mem: int,
cpus_per_task: int,
time: str,
email: str,
tmp_dir: str) -> str:
def create_sbatch_script_for_array_job(
input_file: str,
job_name: Path,
mem: int,
cpus_per_task: int,
time: str,
email: str,
tmp_dir: str,
) -> str:
"""
Create the sbatch script for the array job.
Expand Down Expand Up @@ -316,8 +318,7 @@ def create_sbatch_script_for_array_job(input_file: str,
return tmp_file_sh


def submit_dir(args: Namespace,
job_name: Path):
def submit_dir(args: Namespace, job_name: Path):
"""
Run sbatch command to submit the job to the cluster.
Expand Down Expand Up @@ -387,8 +388,7 @@ def create_sbatch_script_for_single_file(
return tmp_file_sh


def submit_file(args: Namespace,
job_name: Path):
def submit_file(args: Namespace, job_name: Path):
"""
Run sbatch command to submit the job to the cluster.
Expand Down Expand Up @@ -500,7 +500,7 @@ def check_whisper_model(name: str) -> bool:
----------
name: str
The Whisper model to check.
Returns
-------
Boolean:
Expand Down
2 changes: 1 addition & 1 deletion src/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def calculate_max_batch_size() -> int:
Parameters
----------
None
Returns
-------
batch_size:
Expand Down

0 comments on commit f5c7a22

Please sign in to comment.