Skip to content

Commit

Permalink
Switched large model to v3-turbo
Browse files Browse the repository at this point in the history
  • Loading branch information
tmoroney committed Dec 11, 2024
1 parent c4dda48 commit f5c346b
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 41 deletions.
2 changes: 1 addition & 1 deletion AutoSubs-App/src/GlobalContext.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ export function GlobalProvider({ children }: React.PropsWithChildren<{}>) {
} else {
console.log(`Transcription Server INFO: "${line}"`);
}
} else if (line.includes("model.bin:")) {
} else if (line.includes("model.bin:") || line.includes("weights.safetensors:")) {
const percentageMatch = line.match(/(\d+)%/);
if (percentageMatch && percentageMatch[1]) {
const percentage = parseInt(percentageMatch[1], 10);
Expand Down
30 changes: 17 additions & 13 deletions AutoSubs-App/src/pages/home-page.tsx
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import { useState } from "react"
import { useState } from "react";
import { platform } from '@tauri-apps/plugin-os';
import { cn } from "@/lib/utils"
import {
Bird,
Expand Down Expand Up @@ -605,7 +606,7 @@ export function HomePage() {
<p>
Whisper{" "}
<span className="font-medium text-foreground">
Large-V3
Large
</span>
</p>
<p className="text-xs" data-description>
Expand Down Expand Up @@ -683,18 +684,21 @@ export function HomePage() {
<Input value={maxChars} id="maxChars" type="number" placeholder="30" onChange={(e) => setMaxChars(Math.abs(Number.parseInt(e.target.value)))} />
</div>
</div>
<div className="flex items-center space-x-4 rounded-md border p-4">
<Pickaxe className="w-5" />
<div className="flex-1 space-y-1">
<p className="text-sm font-medium leading-none">
Force Align Words
</p>
<p className="text-xs text-muted-foreground">
Improve word level timing
</p>

{platform() === 'windows' && (
<div className="flex items-center space-x-4 rounded-md border p-4">
<Pickaxe className="w-5" />
<div className="flex-1 space-y-1">
<p className="text-sm font-medium leading-none">
Force Align Words
</p>
<p className="text-xs text-muted-foreground">
Improve word level timing
</p>
</div>
<Switch checked={alignWords} onCheckedChange={(checked) => setAlignWords(checked)} />
</div>
<Switch checked={alignWords} onCheckedChange={(checked) => setAlignWords(checked)} />
</div>
)}

</CardContent>
</Card>
Expand Down
63 changes: 36 additions & 27 deletions Transcription-Server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def __getattr__(self, attr):

from huggingface_hub import HfApi, HfFolder, login, snapshot_download
from huggingface_hub.utils import RepositoryNotFoundError, HfHubHTTPError
import torch

if getattr(sys, 'frozen', False):
base_path = sys._MEIPASS
Expand Down Expand Up @@ -106,27 +107,34 @@ def __getattr__(self, attr):
"base": "base",
"small": "small",
"medium": "medium",
"large": "large-v3",
"large": "large-v3-turbo",
"tiny.en": "tiny.en",
"base.en": "base.en",
"small.en": "small.en",
"medium.en": "medium.en",
"large.en": "large-v3",
"large.en": "large-v3-turbo",
}

mac_models = {
"tiny": "mlx-community/whisper-tiny-mlx",
"base": "mlx-community/whisper-base-mlx-q4",
"small": "mlx-community/whisper-small-mlx",
"medium": "mlx-community/whisper-medium-mlx",
"large": "mlx-community/distil-whisper-large-v3",
"large": "mlx-community/whisper-large-v3-turbo",
"tiny.en": "mlx-community/whisper-tiny.en-mlx",
"base.en": "mlx-community/whisper-base.en-mlx",
"small.en": "mlx-community/whisper-small.en-mlx",
"medium.en": "mlx-community/whisper-medium.en-mlx",
"large.en": "mlx-community/distil-whisper-large-v3",
"large.en": "mlx-community/whisper-large-v3-turbo",
"large.de": "mlx-community/whisper-large-v3-turbo-german-f16",
}

def sanitize_result(result):
    """Return a copy of *result* containing only JSON-serializable values.

    The result is round-tripped through json.dumps/json.loads; any value the
    encoder cannot handle is deliberately replaced with None (via the
    ``default`` callback) so the structure can always be written to disk.
    """
    encoded = json.dumps(result, default=lambda unsupported: None)
    return json.loads(encoded)

def is_model_cached_locally(model_id, revision=None):
try:
Expand Down Expand Up @@ -227,7 +235,6 @@ def transcribe_audio(audio_file, kwargs, max_words, max_chars, sensitive_words):
else:
result = stable_whisper.transcribe_any(
inference, audio_file, inference_kwargs=kwargs, vad=False, regroup=True)
result.pad()

result = modify_result(result, max_words, max_chars, sensitive_words)

Expand Down Expand Up @@ -429,18 +436,6 @@ class TranscriptionRequest(BaseModel):
async def transcribe(request: TranscriptionRequest):
try:
start_time = time.time()
if request.language == "en":
request.model = request.model + ".en"
task = "transcribe"
else:
task = request.task

if platform.system() == 'Windows':
model = win_models[request.model]
else:
model = mac_models[request.model]

print(model)

file_path = request.file_path
timeline = request.timeline
Expand All @@ -457,15 +452,6 @@ async def transcribe(request: TranscriptionRequest):
else:
print(f"Processing file: {file_path}")

import torch
kwargs = {
"model": model,
"task": task,
"language": request.language,
"align_words": request.align_words,
"device": "cuda" if torch.cuda.is_available() else "cpu"
}

# Select device
if torch.cuda.is_available():
device = torch.device("cuda")
Expand All @@ -476,6 +462,29 @@ async def transcribe(request: TranscriptionRequest):

print(f"Using device: {device}")

if request.language == "en":
request.model = request.model + ".en"
task = "transcribe"
elif request.language == "de" and request.model == "large" and platform.system() != 'Windows':
request.model = request.model + ".de"
else:
task = request.task

if platform.system() == 'Windows':
model = win_models[request.model]
else:
model = mac_models[request.model]

print(model)

kwargs = {
"model": model,
"task": task,
"language": request.language,
"align_words": request.align_words,
"device": "cuda" if torch.cuda.is_available() else "cpu"
}

# Process audio (transcription and optionally diarization)
try:
result = await process_audio(
Expand All @@ -498,7 +507,7 @@ async def transcribe(request: TranscriptionRequest):

# Save the transcription to a JSON file
with open(json_filepath, 'w', encoding='utf-8') as f:
json.dump(result, f, indent=4, ensure_ascii=False)
json.dump(sanitize_result(result), f, indent=4, ensure_ascii=False)

print(f"Transcription saved to: {json_filepath}")
except Exception as e:
Expand Down

0 comments on commit f5c346b

Please sign in to comment.