fix: bugs in test whisperx code
kurianbenoy committed Feb 21, 2024
1 parent d0daf04 commit 1c6f4c7
Showing 1 changed file with 22 additions and 16 deletions.
38 changes: 22 additions & 16 deletions _experiments/test_whisperx.py
@@ -1,39 +1,45 @@
import whisperx
import gc

device = "cuda" audio_file = "Bishop Thomas Tharayil speaks about Dr Shashi Tharoor at Lourdes Forane Church Thiruvananthapuram.mp4"
batch_size = 16 # reduce if low on GPU mem
compute_type = "float16" # change to "int8" if low on GPU mem (may reduce accuracy)
# 1. Transcribe with original whisper (batched)
model = whisperx.load_model("large-v2", device, compute_type=compute_type
device = "cuda"
batch_size = 16 # reduce if low on GPU mem
compute_type = "float16" # change to "int8" if low on GPU mem (may reduce accuracy)

# 1. Transcribe with original whisper (batched)
model = whisperx.load_model("large-v2", device, compute_type=compute_type)

# save model to local path (optional)
# model_dir = "/path/"
# model = whisperx.load_model("large-v2", device, compute_type=compute_type, download_root=model_dir)

-audio = whisperx.load_audio(audio_file)
+audio = whisperx.load_audio(
+    "Bishop Thomas Tharayil speaks about Dr Shashi Tharoor at Lourdes Forane Church Thiruvananthapuram.mp4"
+)  # noqa
result = model.transcribe(audio, batch_size=batch_size)
-print(result["segments"]) # before alignment
+print(result["segments"])  # before alignment

# delete model if low on GPU resources
# import gc; gc.collect(); torch.cuda.empty_cache(); del model

# 2. Align whisper output
-model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
-result = whisperx.align(result["segments"], model_a, metadata, audio, device, return_char_alignments=False)
+model_a, metadata = whisperx.load_align_model(
+    language_code=result["language"], device=device
+)
+result = whisperx.align(
+    result["segments"], model_a, metadata, audio, device, return_char_alignments=False
+)

-print(result["segments"]) # after alignment
+print(result["segments"])  # after alignment

# delete model if low on GPU resources
# import gc; gc.collect(); torch.cuda.empty_cache(); del model_a

# 3. Assign speaker labels
-#diarize_model = whisperx.DiarizationPipeline(use_auth_token=YOUR_HF_TOKEN, device=device)
+# diarize_model = whisperx.DiarizationPipeline(use_auth_token=YOUR_HF_TOKEN, device=device)

# add min/max number of speakers if known
-#diarize_segments = diarize_model(audio)
+# diarize_segments = diarize_model(audio)
# diarize_model(audio, min_speakers=min_speakers, max_speakers=max_speakers)

-#result = whisperx.assign_word_speakers(diarize_segments, result)
-#print(diarize_segments)
-#print(result["segments"]) # segments are now assigned speaker IDs
+# result = whisperx.assign_word_speakers(diarize_segments, result)
+# print(diarize_segments)
+# print(result["segments"])  # segments are now assigned speaker IDs
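For reference, the speaker-labelling step (step 3) stays commented out in this commit. A minimal sketch of how it could be appended to the fixed script is below; it reuses device, audio, and result from the script above, mirrors the commented-out whisperx.DiarizationPipeline and whisperx.assign_word_speakers calls, and assumes a Hugging Face access token is supplied via an HF_TOKEN environment variable (that variable name is a placeholder, not part of this file).

import os

# 3. Assign speaker labels (sketch of the commented-out step above)
HF_TOKEN = os.environ["HF_TOKEN"]  # placeholder: a real Hugging Face token is assumed
diarize_model = whisperx.DiarizationPipeline(use_auth_token=HF_TOKEN, device=device)

# pass min_speakers= / max_speakers= here if the number of speakers is known
diarize_segments = diarize_model(audio)

result = whisperx.assign_word_speakers(diarize_segments, result)
print(diarize_segments)
print(result["segments"])  # segments are now assigned speaker IDs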
