-
-
Notifications
You must be signed in to change notification settings - Fork 245
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
tests, workflow fixes + torch version
- Loading branch information
Kye
committed
Nov 11, 2023
1 parent
e8e024f
commit 2e6efb4
Showing
10 changed files
with
1,748 additions
and
129 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -70,4 +70,4 @@ rich | |
|
||
mkdocs | ||
mkdocs-material | ||
mkdocs-glightbox | ||
mkdocs-glightbox |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,125 @@ | ||
"""An ultra fast speech to text model.""" | ||
# speech to text tool | ||
|
||
import os | ||
import subprocess | ||
|
||
import whisperx | ||
from pydub import AudioSegment | ||
from pytube import YouTube | ||
|
||
|
||
class WhisperX:
    """Transcribe YouTube videos or local audio files with WhisperX.

    Runs the three-stage WhisperX pipeline: batched Whisper transcription,
    word-level alignment, and speaker diarization.

    Example:
        >>> speech_to_text = WhisperX("https://youtube.com/watch?v=...")
        >>> transcription = speech_to_text.transcribe_youtube_video()
        >>> print(transcription)
    """

    def __init__(
        self,
        video_url,
        audio_format="mp3",
        device="cuda",
        batch_size=16,
        compute_type="float16",
        hf_api_key=None,
    ):
        """
        Args:
            video_url: URL of the YouTube video to transcribe.
            audio_format: Target audio format for the extracted track (e.g. "mp3").
            device: Torch device used for inference (e.g. "cuda" or "cpu").
            batch_size: Batch size for the batched transcription pass.
            compute_type: Numeric precision for the Whisper model (e.g. "float16").
            hf_api_key: Hugging Face token for the gated diarization pipeline.
        """
        self.video_url = video_url
        self.audio_format = audio_format
        self.device = device
        self.batch_size = batch_size
        self.compute_type = compute_type
        self.hf_api_key = hf_api_key

    def install(self):
        """Install runtime dependencies via pip.

        NOTE(review): best-effort — return codes are not checked, so a failed
        install is silent. Preserved as-is to avoid changing behavior.
        """
        subprocess.run(["pip", "install", "whisperx"])
        subprocess.run(["pip", "install", "pytube"])
        subprocess.run(["pip", "install", "pydub"])

    def download_youtube_video(self):
        """Download the video's audio stream and convert it to ``self.audio_format``.

        Returns:
            Path of the converted audio file (e.g. "video.mp3").
        """
        audio_file = f"video.{self.audio_format}"

        # Download the audio-only stream 📥
        yt = YouTube(self.video_url)
        yt_stream = yt.streams.filter(only_audio=True).first()
        yt_stream.download(filename="video.mp4")

        # Convert the downloaded container to the requested audio format 🎧
        video = AudioSegment.from_file("video.mp4", format="mp4")
        video.export(audio_file, format=self.audio_format)
        os.remove("video.mp4")

        return audio_file

    def transcribe_youtube_video(self):
        """Download, transcribe, align, and diarize the configured YouTube video.

        Returns:
            The full transcription as a single string, or None if the
            alignment result contains no "segments" key.
        """
        audio_file = self.download_youtube_video()

        # 1. Transcribe with original Whisper (batched) 🗣️
        # BUG FIX: previously hardcoded device="cuda", batch_size=16 and
        # compute_type="float16" in local variables, silently ignoring the
        # values the caller supplied to __init__.
        model = whisperx.load_model(
            "large-v2", self.device, compute_type=self.compute_type
        )
        audio = whisperx.load_audio(audio_file)
        result = model.transcribe(audio, batch_size=self.batch_size)

        # 2. Align Whisper output 🔍
        model_a, metadata = whisperx.load_align_model(
            language_code=result["language"], device=self.device
        )
        result = whisperx.align(
            result["segments"],
            model_a,
            metadata,
            audio,
            self.device,
            return_char_alignments=False,
        )

        # 3. Assign speaker labels 🏷️
        diarize_model = whisperx.DiarizationPipeline(
            use_auth_token=self.hf_api_key, device=self.device
        )
        # NOTE(review): the diarization output is discarded, so speaker labels
        # are never merged into the transcript (whisperx.assign_word_speakers
        # would be required for that). Preserved as-is to keep behavior.
        diarize_model(audio_file)

        try:
            segments = result["segments"]
            transcription = " ".join(segment["text"] for segment in segments)
            return transcription
        except KeyError:
            # Falls through and implicitly returns None.
            print("The key 'segments' is not found in the result.")

    def transcribe(self, audio_file):
        """Transcribe, align, and diarize an existing local audio file.

        Args:
            audio_file: Path to the audio file to transcribe.

        Returns:
            The full transcription as a single string, or None if the
            alignment result contains no "segments" key.
        """
        # BUG FIX: compute_type was passed positionally, which binds to
        # whisperx.load_model's third parameter (device_index), not
        # compute_type. Pass it by keyword, matching the sibling method.
        model = whisperx.load_model(
            "large-v2", self.device, compute_type=self.compute_type
        )
        audio = whisperx.load_audio(audio_file)
        result = model.transcribe(audio, batch_size=self.batch_size)

        # 2. Align Whisper output 🔍
        model_a, metadata = whisperx.load_align_model(
            language_code=result["language"], device=self.device
        )
        result = whisperx.align(
            result["segments"],
            model_a,
            metadata,
            audio,
            self.device,
            return_char_alignments=False,
        )

        # 3. Assign speaker labels 🏷️
        diarize_model = whisperx.DiarizationPipeline(
            use_auth_token=self.hf_api_key, device=self.device
        )
        # NOTE(review): diarization output discarded here too — see
        # transcribe_youtube_video. Preserved as-is.
        diarize_model(audio_file)

        try:
            segments = result["segments"]
            transcription = " ".join(segment["text"] for segment in segments)
            return transcription
        except KeyError:
            # Falls through and implicitly returns None.
            print("The key 'segments' is not found in the result.")
This file was deleted.
Oops, something went wrong.
Oops, something went wrong.