diff --git a/modules/diarize/audio_loader.py b/modules/diarize/audio_loader.py index 6db3cc64..2efd4216 100644 --- a/modules/diarize/audio_loader.py +++ b/modules/diarize/audio_loader.py @@ -2,6 +2,8 @@ import subprocess from functools import lru_cache from typing import Optional, Union +from scipy.io.wavfile import write +import tempfile import numpy as np import torch @@ -24,32 +26,43 @@ def exact_div(x, y): TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN) # 20ms per audio token -def load_audio(file: str, sr: int = SAMPLE_RATE): +def load_audio(file: Union[str, np.ndarray], sr: int = SAMPLE_RATE) -> np.ndarray: """ - Open an audio file and read as mono waveform, resampling as necessary + Open an audio file or process a numpy array containing audio data as mono waveform, resampling as necessary. Parameters ---------- - file: str - The audio file to open + file: Union[str, np.ndarray] + The audio file to open or a numpy array containing the audio data. sr: int - The sample rate to resample the audio if necessary + The sample rate to resample the audio if necessary. Returns ------- A NumPy array containing the audio waveform, in float32 dtype. """ + if isinstance(file, np.ndarray): + if file.dtype != np.float32: + file = file.astype(np.float32) + if file.ndim > 1: + file = np.mean(file, axis=1) + + temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav") + write(temp_file.name, SAMPLE_RATE, (file * 32768).astype(np.int16)) + temp_file_path = temp_file.name + temp_file.close() + else: + temp_file_path = file + try: - # Launches a subprocess to decode audio while down-mixing and resampling as necessary. - # Requires the ffmpeg CLI to be installed. cmd = [ "ffmpeg", "-nostdin", "-threads", "0", "-i", - file, + temp_file_path, "-f", "s16le", "-ac", @@ -63,6 +76,9 @@ def load_audio(file: str, sr: int = SAMPLE_RATE): out = subprocess.run(cmd, capture_output=True, check=True).stdout except subprocess.CalledProcessError as e: raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e + finally: + if isinstance(file, np.ndarray): + os.remove(temp_file_path) return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0 diff --git a/modules/diarize/diarizer.py b/modules/diarize/diarizer.py index c12c2357..592e1bf7 100644 --- a/modules/diarize/diarizer.py +++ b/modules/diarize/diarizer.py @@ -1,6 +1,7 @@ import os import torch -from typing import List +from typing import List, Union, BinaryIO +import numpy as np import time import logging @@ -20,7 +21,7 @@ def __init__(self, self.pipe = None def run(self, - audio: str, + audio: Union[str, BinaryIO, np.ndarray], transcribed_result: List[dict], use_auth_token: str, device: str