Upstream bug fixes and use Whisper Large V3 by default
kylemclaren committed Jul 3, 2024
1 parent 7095731 commit cfdab9b
Showing 2 changed files with 93 additions and 35 deletions.
4 changes: 4 additions & 0 deletions cog.yaml
@@ -23,5 +23,9 @@ build:
- "openai-whisper==20231106"
- ipython

# commands run after the environment is set up
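# pget (github.com/replicate/pget) is Replicate's parallel downloader; predict.py shells out to it to fetch model weights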
run:
- curl -o /usr/local/bin/pget -L "https://github.com/replicate/pget/releases/download/v0.8.2/pget_linux_x86_64" && chmod +x /usr/local/bin/pget

# predict.py defines how predictions are run on your model
predict: "predict.py:Predictor"
124 changes: 89 additions & 35 deletions predict.py
@@ -1,24 +1,20 @@
# License disclaimer - this code has been modified from the original version to fix a long-standing upstream bug

import io
import os
from typing import Optional, Any
import os
import time
import subprocess
import torch
import numpy as np
import cProfile
import pstats
from pstats import SortKey
import time

from cog import BasePredictor, Input, Path, BaseModel

import whisper
from whisper.model import Whisper, ModelDimensions
from whisper.tokenizer import LANGUAGES, TO_LANGUAGE_CODE
from whisper.utils import format_timestamp

MODEL_CACHE = "weights"
BASE_URL = f"https://weights.replicate.delivery/default/whisper-v3/{MODEL_CACHE}/"


class Output(BaseModel):
class ModelOutput(BaseModel):
detected_language: str
transcription: str
segments: Any
@@ -27,24 +27,69 @@ class Output(BaseModel):
srt_file: Optional[Path]


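# Fetch a file (or tar archive) with pget. load_model() below calls this as,
# e.g., download_weights(BASE_URL + "large-v3.pt", "weights/large-v3.pt").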
def download_weights(url: str, dest: str) -> None:
start = time.time()
print("[!] Initiating download from URL: ", url)
print("[~] Destination path: ", dest)
if ".tar" in dest:
dest = os.path.dirname(dest)
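# assemble pget flags: -v (verbose) and -f (force overwrite), appending "x" so tar archives are extracted in place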
command = ["pget", "-vf" + ("x" if ".tar" in url else ""), url, dest]
try:
print(f"[~] Running command: {' '.join(command)}")
subprocess.check_call(command, close_fds=False)
except subprocess.CalledProcessError as e:
print(
f"[ERROR] Failed to download weights. Command '{' '.join(e.cmd)}' returned non-zero exit status {e.returncode}."
)
raise
print("[+] Download completed in: ", time.time() - start, "seconds")


class Predictor(BasePredictor):
def setup(self):
"""Loads whisper models into memory to make running multiple predictions efficient"""

with open(f"./weights/large-v2.pt", "rb") as fp:
checkpoint = torch.load(fp, map_location="cpu")
dims = ModelDimensions(**checkpoint["dims"])
self.model = Whisper(dims)
self.model.load_state_dict(checkpoint["model_state_dict"])
self.model.to("cuda")
"""Load the large-v3 model"""
self.model_cache = MODEL_CACHE
self.models = {}
self.current_model = "large-v3"
self.load_model("large-v3")

def load_model(self, model_name):
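# Download the checkpoint on first use and cache the loaded model in self.models,
# so repeated predictions skip both the download and torch.load.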
if model_name not in self.models:
if not os.path.exists(self.model_cache):
os.makedirs(self.model_cache)

model_file = f"{model_name}.pt"
url = BASE_URL + model_file
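# e.g. https://weights.replicate.delivery/default/whisper-v3/weights/large-v3.pt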
dest_path = os.path.join(self.model_cache, model_file)

if not os.path.exists(dest_path):
download_weights(url, dest_path)

with open(dest_path, "rb") as fp:
checkpoint = torch.load(fp, map_location="cpu")
dims = ModelDimensions(**checkpoint["dims"])
model = Whisper(dims)
model.load_state_dict(checkpoint["model_state_dict"])
model.to("cuda")

self.models[model_name] = model
self.current_model = model_name
return self.models[model_name]

def predict(
self,
audio: Path = Input(description="Audio file"),
# Note: We only serve the large-v3 model to reduce switching costs and because it meets most users' needs.
# Other model sizes (base, small, tiny) are commented out as they're not currently offered.
model: str = Input(
default="large-v2",
choices=["large", "large-v2"],
description="Choose a Whisper model.",
choices=[
"large-v3",
# "base",
# "small",
# "tiny",
],
default="large-v3",
description="Whisper model size (currently only large-v3 is supported).",
),
transcription: str = Input(
choices=["plain text", "srt", "vtt"],
@@ -56,10 +97,11 @@ def predict(
description="Translate the text to English when set to True",
),
language: str = Input(
choices=sorted(LANGUAGES.keys())
choices=["auto"]
+ sorted(LANGUAGES.keys())
+ sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]),
default=None,
description="language spoken in the audio, specify None to perform language detection",
default="auto",
description="Language spoken in the audio, specify 'auto' for automatic language detection",
),
temperature: float = Input(
default=0,
@@ -96,20 +138,29 @@ def predict(
no_speech_threshold: float = Input(
default=0.6,
description="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence",
)
) -> Output:
),
) -> ModelOutput:
"""Transcribes and optionally translates a single audio file"""
print(f"Transcribe with {model} model")
model = self.model
print(f"Transcribe with {model} model.")

if model != self.current_model:
self.model = self.load_model(model)
else:
self.model = self.models[self.current_model]

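# Build the temperature schedule Whisper falls back through when decoding fails;
# e.g. temperature=0 with increment 0.2 yields (0.0, 0.2, 0.4, 0.6, 0.8, 1.0).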
if temperature_increment_on_fallback is not None:
temperature = tuple(
np.arange(temperature, 1.0 + 1e-6, temperature_increment_on_fallback)
)
else:
temperature = [temperature]

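# "auto" (in any casing) becomes None so Whisper runs language detection; full
# names such as "Spanish" are mapped to ISO codes via TO_LANGUAGE_CODE.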
normalized_language = language.lower() if language.lower() != "auto" else None
if normalized_language and normalized_language not in LANGUAGES:
normalized_language = TO_LANGUAGE_CODE.get(normalized_language, normalized_language)

args = {
"language": language,
"language": normalized_language,
"patience": patience,
"suppress_tokens": suppress_tokens,
"initial_prompt": initial_prompt,
@@ -118,10 +169,10 @@ def predict(
"logprob_threshold": logprob_threshold,
"no_speech_threshold": no_speech_threshold,
"fp16": True,
"verbose": False
"verbose": False,
}
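# inference_mode() disables gradient tracking entirely, reducing memory use during decoding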
with torch.inference_mode():
result = model.transcribe(str(audio), temperature=temperature, **args)
result = self.model.transcribe(str(audio), temperature=temperature, **args)

if transcription == "plain text":
transcription = result["text"]
@@ -131,13 +182,16 @@ def predict(
transcription = write_vtt(result["segments"])

if translate:
translation = model.transcribe(
translation = self.model.transcribe(
str(audio), task="translate", temperature=temperature, **args
)

return Output(
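# result["language"] is an ISO 639-1 code; map it to a readable name, falling
# back to the raw code for anything missing from LANGUAGES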
detected_language_code = result["language"]
detected_language_name = LANGUAGES.get(detected_language_code, detected_language_code)

return ModelOutput(
segments=result["segments"],
detected_language=LANGUAGES[result["language"]],
detected_language=detected_language_name,
transcription=transcription,
translation=translation["text"] if translate else None,
)
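For a quick local smoke test of the new default, something like the following should exercise the large-v3 path end to end (this assumes cog is installed and uses a hypothetical sample.wav as input):

cog predict -i audio=@sample.wav -i language=auto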
