-
Notifications
You must be signed in to change notification settings - Fork 977
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Switching to pipeline for HF whisper (#814)
- Loading branch information
1 parent
cf340bc
commit 3d8f5da
Showing
4 changed files
with
44 additions
and
79 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,98 +1,61 @@ | ||
import sys | ||
import logging | ||
from typing import Optional, Union | ||
|
||
import numpy as np | ||
from tqdm import tqdm | ||
|
||
import whisper | ||
import torch | ||
from transformers import WhisperProcessor, WhisperForConditionalGeneration | ||
|
||
def cuda_is_viable(min_vram_gb=10): | ||
if not torch.cuda.is_available(): | ||
return False | ||
|
||
total_memory = torch.cuda.get_device_properties(0).total_memory / 1e9 # Convert bytes to GB | ||
if total_memory < min_vram_gb: | ||
return False | ||
|
||
return True | ||
|
||
|
||
def load_model(model_name_or_path: str): | ||
processor = WhisperProcessor.from_pretrained(model_name_or_path) | ||
model = WhisperForConditionalGeneration.from_pretrained(model_name_or_path) | ||
|
||
if cuda_is_viable(): | ||
logging.debug("CUDA is available and has enough VRAM, moving model to GPU.") | ||
model.to("cuda") | ||
|
||
return TransformersWhisper(processor, model) | ||
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline | ||
|
||
|
||
class TransformersWhisper: | ||
def __init__( | ||
self, processor: WhisperProcessor, model: WhisperForConditionalGeneration | ||
self, model_id: str | ||
): | ||
self.processor = processor | ||
self.model = model | ||
self.SAMPLE_RATE = whisper.audio.SAMPLE_RATE | ||
self.N_SAMPLES_IN_CHUNK = whisper.audio.N_SAMPLES | ||
self.model_id = model_id | ||
|
||
# Patch implementation of transcribing with transformers' WhisperProcessor until long-form transcription and | ||
# timestamps are available. See: https://github.com/huggingface/transformers/issues/19887, | ||
# https://github.com/huggingface/transformers/pull/20620. | ||
def transcribe( | ||
self, | ||
audio: Union[str, np.ndarray], | ||
language: str, | ||
task: str, | ||
verbose: Optional[bool] = None, | ||
): | ||
if isinstance(audio, str): | ||
audio = whisper.load_audio(audio, sr=self.SAMPLE_RATE) | ||
|
||
self.model.config.forced_decoder_ids = self.processor.get_decoder_prompt_ids( | ||
task=task, language=language | ||
device = "cuda" if torch.cuda.is_available() else "cpu" | ||
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32 | ||
|
||
model = AutoModelForSpeechSeq2Seq.from_pretrained( | ||
self.model_id, torch_dtype=torch_dtype, use_safetensors=True | ||
) | ||
|
||
segments = [] | ||
all_predicted_ids = [] | ||
model.generation_config.language = language | ||
model.to(device) | ||
|
||
num_samples = audio.size | ||
seek = 0 | ||
with tqdm( | ||
total=num_samples, unit="samples", disable=verbose is not False | ||
) as progress_bar: | ||
while seek < num_samples: | ||
chunk = audio[seek : seek + self.N_SAMPLES_IN_CHUNK] | ||
input_features = self.processor( | ||
chunk, return_tensors="pt", sampling_rate=self.SAMPLE_RATE | ||
).input_features.to(self.model.device) | ||
predicted_ids = self.model.generate(input_features) | ||
all_predicted_ids.extend(predicted_ids) | ||
text: str = self.processor.batch_decode( | ||
predicted_ids, skip_special_tokens=True | ||
)[0] | ||
if text.strip() != "": | ||
segments.append( | ||
{ | ||
"start": seek / self.SAMPLE_RATE, | ||
"end": min(seek + self.N_SAMPLES_IN_CHUNK, num_samples) | ||
/ self.SAMPLE_RATE, | ||
"text": text, | ||
} | ||
) | ||
processor = AutoProcessor.from_pretrained(self.model_id) | ||
|
||
pipe = pipeline( | ||
"automatic-speech-recognition", | ||
generate_kwargs={"language": language, "task": task}, | ||
model=model, | ||
tokenizer=processor.tokenizer, | ||
feature_extractor=processor.feature_extractor, | ||
chunk_length_s=30, | ||
torch_dtype=torch_dtype, | ||
device=device, | ||
) | ||
|
||
progress_bar.update( | ||
min(seek + self.N_SAMPLES_IN_CHUNK, num_samples) - seek | ||
) | ||
seek += self.N_SAMPLES_IN_CHUNK | ||
transcript = pipe(audio, return_timestamps=True) | ||
|
||
segments = [] | ||
for chunk in transcript['chunks']: | ||
start, end = chunk['timestamp'] | ||
text = chunk['text'] | ||
segments.append({ | ||
"start": start, | ||
"end": end, | ||
"text": text, | ||
"translation": "" | ||
}) | ||
|
||
return { | ||
"text": self.processor.batch_decode( | ||
all_predicted_ids, skip_special_tokens=True | ||
)[0], | ||
"text": transcript['text'], | ||
"segments": segments, | ||
} | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters