
tests, workflow fixes + torch version
Kye committed Nov 11, 2023
1 parent e8e024f commit 2e6efb4
Showing 10 changed files with 1,748 additions and 129 deletions.
1 change: 1 addition & 0 deletions .github/workflows/docs.yml
@@ -15,5 +15,6 @@ jobs:
         with:
           python-version: 3.x
       - run: pip install mkdocs-material
+      - run: pip install mkdocs-glightbox
       - run: pip install "mkdocstrings[python]"
       - run: mkdocs gh-deploy --force
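A quick local smoke test for this docs toolchain, a sketch only: it assumes the repo's mkdocs.yml enables the glightbox plugin, which is what installing mkdocs-glightbox here presumes.

    # hypothetical local check mirroring the workflow's install steps
    import subprocess

    subprocess.run(
        ["pip", "install", "mkdocs-material", "mkdocs-glightbox", "mkdocstrings[python]"],
        check=True,
    )
    # --strict turns warnings (e.g. a misconfigured plugin) into build failures
    subprocess.run(["mkdocs", "build", "--strict"], check=True)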
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -30,7 +30,7 @@ asyncio = "*"
 nest_asyncio = "*"
 einops = "*"
 google-generativeai = "*"
-torch = "*"
+torch = "2.1.0"
 langchain-experimental = "*"
 playwright = "*"
 duckduckgo-search = "*"
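Pinning torch to an exact version makes stale environments easy to detect. A minimal sanity check, assuming the pin above is in effect:

    import torch

    # wheel builds may append a suffix such as "+cu121", so match the prefix only
    assert torch.__version__.startswith("2.1.0"), torch.__version__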
2 changes: 1 addition & 1 deletion requirements.txt
@@ -70,4 +70,4 @@ rich
 
 mkdocs
 mkdocs-material
-mkdocs-glightbox
+mkdocs-glightbox
2 changes: 1 addition & 1 deletion swarms/models/openai_chat.py
@@ -214,7 +214,7 @@ def is_lc_serializable(cls) -> bool:
     # Check for classes that derive from this class (as some of them
     # may assume openai_api_key is a str)
     # openai_api_key: Optional[str] = Field(default=None, alias="api_key")
-    openai_api_key = "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"
+    openai_api_key: Optional[str] = Field(default=None, alias="api_key")
     """Automatically inferred from env var `OPENAI_API_KEY` if not provided."""
     openai_api_base: Optional[str] = Field(default=None, alias="base_url")
     """Base URL path for API requests, leave blank if not using a proxy or service
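With the hardcoded key removed, the field is once again resolved from the `OPENAI_API_KEY` environment variable, as its docstring says. A minimal sketch of supplying the key safely:

    import os

    # read the key from the environment rather than baking it into source;
    # the Field alias also allows passing api_key= at construction time
    api_key = os.environ.get("OPENAI_API_KEY")
    if api_key is None:
        raise RuntimeError("Set OPENAI_API_KEY before constructing the chat model")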
126 changes: 125 additions & 1 deletion swarms/models/whisperx.py
@@ -1 +1,125 @@
"""An ultra fast speech to text model."""
# speech to text tool

import os
import subprocess

import whisperx
from pydub import AudioSegment
from pytube import YouTube


class WhisperX:
def __init__(
self,
video_url,
audio_format="mp3",
device="cuda",
batch_size=16,
compute_type="float16",
hf_api_key=None,
):
"""
# Example usage
video_url = "url"
speech_to_text = WhisperX(video_url)
transcription = speech_to_text.transcribe_youtube_video()
print(transcription)
"""
self.video_url = video_url
self.audio_format = audio_format
self.device = device
self.batch_size = batch_size
self.compute_type = compute_type
self.hf_api_key = hf_api_key

    def install(self):
        # install runtime dependencies; assumes pip is available on PATH
        subprocess.run(["pip", "install", "whisperx"], check=True)
        subprocess.run(["pip", "install", "pytube"], check=True)
        subprocess.run(["pip", "install", "pydub"], check=True)

def download_youtube_video(self):
audio_file = f"video.{self.audio_format}"

# Download video 📥
yt = YouTube(self.video_url)
yt_stream = yt.streams.filter(only_audio=True).first()
yt_stream.download(filename="video.mp4")

# Convert video to audio 🎧
video = AudioSegment.from_file("video.mp4", format="mp4")
video.export(audio_file, format=self.audio_format)
os.remove("video.mp4")

return audio_file

    def transcribe_youtube_video(self):
        audio_file = self.download_youtube_video()

        # 1. Transcribe with original Whisper (batched) 🗣️
        # use the settings passed to the constructor rather than hardcoded ones
        model = whisperx.load_model(
            "large-v2", self.device, compute_type=self.compute_type
        )
        audio = whisperx.load_audio(audio_file)
        result = model.transcribe(audio, batch_size=self.batch_size)

        # 2. Align Whisper output 🔍
        model_a, metadata = whisperx.load_align_model(
            language_code=result["language"], device=self.device
        )
        result = whisperx.align(
            result["segments"],
            model_a,
            metadata,
            audio,
            self.device,
            return_char_alignments=False,
        )

        # 3. Assign speaker labels 🏷️
        diarize_model = whisperx.DiarizationPipeline(
            use_auth_token=self.hf_api_key, device=self.device
        )
        diarize_segments = diarize_model(audio_file)
        result = whisperx.assign_word_speakers(diarize_segments, result)

try:
segments = result["segments"]
transcription = " ".join(segment["text"] for segment in segments)
return transcription
except KeyError:
print("The key 'segments' is not found in the result.")

    def transcribe(self, audio_file):
        # 1. Transcribe with original Whisper (batched) 🗣️
        # compute_type must be passed by keyword; positionally it would bind
        # to load_model's device_index parameter
        model = whisperx.load_model(
            "large-v2", self.device, compute_type=self.compute_type
        )
        audio = whisperx.load_audio(audio_file)
        result = model.transcribe(audio, batch_size=self.batch_size)

# 2. Align Whisper output 🔍
model_a, metadata = whisperx.load_align_model(
language_code=result["language"], device=self.device
)

result = whisperx.align(
result["segments"],
model_a,
metadata,
audio,
self.device,
return_char_alignments=False,
)

# 3. Assign speaker labels 🏷️
diarize_model = whisperx.DiarizationPipeline(
use_auth_token=self.hf_api_key, device=self.device
)

        diarize_segments = diarize_model(audio_file)
        result = whisperx.assign_word_speakers(diarize_segments, result)

try:
segments = result["segments"]
transcription = " ".join(segment["text"] for segment in segments)
return transcription
except KeyError:
print("The key 'segments' is not found in the result.")
125 changes: 0 additions & 125 deletions swarms/tools/stt.py

This file was deleted.

