From f5c346b683a7cf2daa0fa92969318cb07cf8c067 Mon Sep 17 00:00:00 2001
From: Tom Moroney
Date: Wed, 11 Dec 2024 23:41:50 +0000
Subject: [PATCH] Switched large model to v3-turbo

---
 AutoSubs-App/src/GlobalContext.tsx   |  2 +-
 AutoSubs-App/src/pages/home-page.tsx | 30 +++++++------
 Transcription-Server/server.py       | 63 ++++++++++++++++------------
 3 files changed, 54 insertions(+), 41 deletions(-)

diff --git a/AutoSubs-App/src/GlobalContext.tsx b/AutoSubs-App/src/GlobalContext.tsx
index fd49d84..fc427c2 100644
--- a/AutoSubs-App/src/GlobalContext.tsx
+++ b/AutoSubs-App/src/GlobalContext.tsx
@@ -234,7 +234,7 @@ export function GlobalProvider({ children }: React.PropsWithChildren<{}>) {
             } else {
                 console.log(`Transcription Server INFO: "${line}"`);
             }
-        } else if (line.includes("model.bin:")) {
+        } else if (line.includes("model.bin:") || line.includes("weights.safetensors:")) {
            const percentageMatch = line.match(/(\d+)%/);
            if (percentageMatch && percentageMatch[1]) {
                const percentage = parseInt(percentageMatch[1], 10);
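A note on the hunk above: "model.bin:" and "weights.safetensors:" are checkpoint filenames in the download progress lines the transcription server prints to stdout (weights.safetensors presumably being the filename the new v3-turbo checkpoints ship under), and the (\d+)% match pulls the percentage out of those lines. Below is a minimal Python sketch of the same extraction; the exact tqdm-style line layout shown here is an assumption for illustration, not taken from the patch.

import re

# Progress lines roughly as huggingface_hub's tqdm bars print them;
# the precise layout is assumed, only the filename prefixes and the
# percentage figure matter to the parser.
lines = [
    "model.bin:  42%|####2     | 1.24G/2.95G [00:31<00:45, 38.2MB/s]",
    "weights.safetensors:  87%|########7 | 1.41G/1.62G [01:02<00:09, 24.1MB/s]",
]

for line in lines:
    # Same trigger substrings and percentage regex as the handler above.
    if "model.bin:" in line or "weights.safetensors:" in line:
        match = re.search(r"(\d+)%", line)
        if match:
            print(f"download progress: {int(match.group(1))}%")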

Whisper{" "} - Large-V3 + Large

@@ -683,18 +684,21 @@ export function HomePage() { setMaxChars(Math.abs(Number.parseInt(e.target.value)))} /> -

- -
-

- Force Align Words -

-

- Improve word level timing -

+ + {platform() === 'windows' && ( +
+ +
+

+ Force Align Words +

+

+ Improve word level timing +

+
+ setAlignWords(checked)} />
- setAlignWords(checked)} /> -
diff --git a/Transcription-Server/server.py b/Transcription-Server/server.py
index 2098563..3d25ce1 100644
--- a/Transcription-Server/server.py
+++ b/Transcription-Server/server.py
@@ -71,6 +71,7 @@ def __getattr__(self, attr):
 from huggingface_hub import HfApi, HfFolder, login, snapshot_download
 from huggingface_hub.utils import RepositoryNotFoundError, HfHubHTTPError
+import torch
 
 if getattr(sys, 'frozen', False):
     base_path = sys._MEIPASS
@@ -106,12 +107,12 @@ def __getattr__(self, attr):
     "base": "base",
     "small": "small",
     "medium": "medium",
-    "large": "large-v3",
+    "large": "large-v3-turbo",
     "tiny.en": "tiny.en",
     "base.en": "base.en",
     "small.en": "small.en",
     "medium.en": "medium.en",
-    "large.en": "large-v3",
+    "large.en": "large-v3-turbo",
 }
 
 mac_models = {
@@ -119,14 +120,21 @@
     "tiny": "mlx-community/whisper-tiny-mlx-q4",
     "base": "mlx-community/whisper-base-mlx-q4",
     "small": "mlx-community/whisper-small-mlx",
     "medium": "mlx-community/whisper-medium-mlx",
-    "large": "mlx-community/distil-whisper-large-v3",
+    "large": "mlx-community/whisper-large-v3-turbo",
     "tiny.en": "mlx-community/whisper-tiny.en-mlx",
     "base.en": "mlx-community/whisper-base.en-mlx",
     "small.en": "mlx-community/whisper-small.en-mlx",
     "medium.en": "mlx-community/whisper-medium.en-mlx",
-    "large.en": "mlx-community/distil-whisper-large-v3",
+    "large.en": "mlx-community/whisper-large-v3-turbo",
+    "large.de": "mlx-community/whisper-large-v3-turbo-german-f16",
 }
 
+def sanitize_result(result):
+    # Convert the result to a JSON string
+    result_json = json.dumps(result, default=lambda o: None)
+    # Parse the JSON string back to a dictionary
+    sanitized_result = json.loads(result_json)
+    return sanitized_result
 
 def is_model_cached_locally(model_id, revision=None):
     try:
@@ -227,7 +235,6 @@
     else:
         result = stable_whisper.transcribe_any(
             inference, audio_file, inference_kwargs=kwargs, vad=False, regroup=True)
-        result.pad()
 
     result = modify_result(result, max_words, max_chars, sensitive_words)
@@ -429,18 +436,6 @@ class TranscriptionRequest(BaseModel):
 async def transcribe(request: TranscriptionRequest):
     try:
         start_time = time.time()
-        if request.language == "en":
-            request.model = request.model + ".en"
-            task = "transcribe"
-        else:
-            task = request.task
-
-        if platform.system() == 'Windows':
-            model = win_models[request.model]
-        else:
-            model = mac_models[request.model]
-
-        print(model)
 
         file_path = request.file_path
         timeline = request.timeline
@@ -457,15 +452,6 @@
         else:
             print(f"Processing file: {file_path}")
 
-        import torch
-        kwargs = {
-            "model": model,
-            "task": task,
-            "language": request.language,
-            "align_words": request.align_words,
-            "device": "cuda" if torch.cuda.is_available() else "cpu"
-        }
-
         # Select device
         if torch.cuda.is_available():
             device = torch.device("cuda")
@@ -476,6 +462,29 @@
         print(f"Using device: {device}")
 
+        if request.language == "en":
+            request.model = request.model + ".en"
+            task = "transcribe"
+        elif request.language == "de" and request.model == "large" and platform.system() != 'Windows':
+            request.model = request.model + ".de"
+        else:
+            task = request.task
+
+        if platform.system() == 'Windows':
+            model = win_models[request.model]
+        else:
+            model = mac_models[request.model]
+
+        print(model)
+
+        kwargs = {
+            "model": model,
+            "task": task,
+            "language": request.language,
+            "align_words": request.align_words,
+            "device": "cuda" if torch.cuda.is_available() else "cpu"
+        }
+
         # Process audio (transcription and optionally diarization)
         try:
             result = await process_audio(
@@ -498,7 +507,7 @@
 
         # Save the transcription to a JSON file
         with open(json_filepath, 'w', encoding='utf-8') as f:
-            json.dump(result, f, indent=4, ensure_ascii=False)
+            json.dump(sanitize_result(result), f, indent=4, ensure_ascii=False)
 
         print(f"Transcription saved to: {json_filepath}")
     except Exception as e:
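For reviewers who want to poke at the two server-side pieces this patch touches without launching the app, here is a self-contained sketch of the model-key resolution and the sanitize_result round-trip. The abridged model tables and the suffix logic are copied from the diff; the fake transcription payload and the WordTiming stand-in for a non-serializable object are illustrative assumptions.

import json
import platform

# Model tables from the patch, abridged to the "large" entries.
win_models = {"large": "large-v3-turbo", "large.en": "large-v3-turbo"}
mac_models = {
    "large": "mlx-community/whisper-large-v3-turbo",
    "large.en": "mlx-community/whisper-large-v3-turbo",
    "large.de": "mlx-community/whisper-large-v3-turbo-german-f16",
}

def resolve_model(model, language):
    # Same suffix logic as the new block in transcribe(): English requests
    # pick the ".en" variant; a German "large" request off Windows picks the
    # dedicated v3-turbo German fine-tune.
    if language == "en":
        model += ".en"
    elif language == "de" and model == "large" and platform.system() != "Windows":
        model += ".de"
    return (win_models if platform.system() == "Windows" else mac_models)[model]

def sanitize_result(result):
    # JSON round-trip as added by the patch: any value json.dumps cannot
    # serialize is replaced with null instead of raising TypeError.
    return json.loads(json.dumps(result, default=lambda o: None))

class WordTiming:
    # Hypothetical stand-in for a non-JSON-serializable object in a result.
    pass

result = {"text": "hello", "words": [{"word": "hello", "timing": WordTiming()}]}
print(resolve_model("large", "de"))  # the German turbo repo when not on Windows
print(sanitize_result(result))       # {'text': 'hello', 'words': [{'word': 'hello', 'timing': None}]}

The round-trip is presumably why the final json.dump now receives sanitize_result(result): a result dict can carry objects that plain json.dump would choke on, and nulling them out keeps the export step from failing.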