diff --git a/AutoSubs-App/src/GlobalContext.tsx b/AutoSubs-App/src/GlobalContext.tsx
index fd49d84..fc427c2 100644
--- a/AutoSubs-App/src/GlobalContext.tsx
+++ b/AutoSubs-App/src/GlobalContext.tsx
@@ -234,7 +234,7 @@ export function GlobalProvider({ children }: React.PropsWithChildren<{}>) {
           } else {
             console.log(`Transcription Server INFO: "${line}"`);
           }
-        } else if (line.includes("model.bin:")) {
+        } else if (line.includes("model.bin:") || line.includes("weights.safetensors:")) {
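+          // Download progress lines look like "model.bin:  42%|..." or "weights.safetensors:  42%|...".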
           const percentageMatch = line.match(/(\d+)%/);
           if (percentageMatch && percentageMatch[1]) {
             const percentage = parseInt(percentageMatch[1], 10);
diff --git a/AutoSubs-App/src/pages/home-page.tsx b/AutoSubs-App/src/pages/home-page.tsx
index d8c9e02..9c838b4 100644
--- a/AutoSubs-App/src/pages/home-page.tsx
+++ b/AutoSubs-App/src/pages/home-page.tsx
@@ -1,4 +1,5 @@
-import { useState } from "react"
+import { useState } from "react";
+import { platform } from '@tauri-apps/plugin-os';
 import { cn } from "@/lib/utils"
 import {
   Bird,
@@ -605,7 +606,7 @@ export function HomePage() {
                     <div className="grid gap-0.5">
                       <p>
                         Whisper{" "}
                         <span className="font-medium text-foreground">
-                          Large-V3
+                          Large
                         </span>
                       </p>
@@ -683,18 +684,21 @@ export function HomePage() {
             <Input type="number" value={maxChars} onChange={(e) => setMaxChars(Math.abs(Number.parseInt(e.target.value)))} />
           </div>
-          <div className="flex items-center justify-between">
-            <div className="space-y-0.5">
-              <Label>
-                Force Align Words
-              </Label>
-              <p className="text-sm text-muted-foreground">
-                Improve word level timing
-              </p>
-            </div>
-            <Switch checked={alignWords} onCheckedChange={(checked) => setAlignWords(checked)} />
-          </div>
+          {/* Forced alignment is only supported by the Windows backend, so the toggle is hidden elsewhere. */}
+          {platform() === 'windows' && (
+            <div className="flex items-center justify-between">
+              <div className="space-y-0.5">
+                <Label>
+                  Force Align Words
+                </Label>
+                <p className="text-sm text-muted-foreground">
+                  Improve word level timing
+                </p>
+              </div>
+              <Switch checked={alignWords} onCheckedChange={(checked) => setAlignWords(checked)} />
+            </div>
+          )}
diff --git a/Transcription-Server/server.py b/Transcription-Server/server.py
index 2098563..3d25ce1 100644
--- a/Transcription-Server/server.py
+++ b/Transcription-Server/server.py
@@ -71,6 +71,7 @@ def __getattr__(self, attr):
from huggingface_hub import HfApi, HfFolder, login, snapshot_download
from huggingface_hub.utils import RepositoryNotFoundError, HfHubHTTPError
+import torch
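+# torch provides the CUDA availability check used for device selection in transcribe().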
 if getattr(sys, 'frozen', False):
     base_path = sys._MEIPASS
@@ -106,12 +107,12 @@ def __getattr__(self, attr):
"base": "base",
"small": "small",
"medium": "medium",
- "large": "large-v3",
+ "large": "large-v3-turbo",
"tiny.en": "tiny.en",
"base.en": "base.en",
"small.en": "small.en",
"medium.en": "medium.en",
- "large.en": "large-v3",
+ "large.en": "large-v3-turbo",
}
mac_models = {
@@ -119,14 +120,21 @@ def __getattr__(self, attr):
"base": "mlx-community/whisper-base-mlx-q4",
"small": "mlx-community/whisper-small-mlx",
"medium": "mlx-community/whisper-medium-mlx",
- "large": "mlx-community/distil-whisper-large-v3",
+ "large": "mlx-community/whisper-large-v3-turbo",
"tiny.en": "mlx-community/whisper-tiny.en-mlx",
"base.en": "mlx-community/whisper-base.en-mlx",
"small.en": "mlx-community/whisper-small.en-mlx",
"medium.en": "mlx-community/whisper-medium.en-mlx",
- "large.en": "mlx-community/distil-whisper-large-v3",
+ "large.en": "mlx-community/whisper-large-v3-turbo",
+ "large.de": "mlx-community/whisper-large-v3-turbo-german-f16",
}
+def sanitize_result(result):
+    # Round-trip the result through JSON, replacing anything the json
+    # module cannot serialize (e.g. custom objects) with None.
+    result_json = json.dumps(result, default=lambda o: None)
+    sanitized_result = json.loads(result_json)
+    return sanitized_result
def is_model_cached_locally(model_id, revision=None):
     try:
@@ -227,7 +235,6 @@ def transcribe_audio(audio_file, kwargs, max_words, max_chars, sensitive_words):
     else:
         result = stable_whisper.transcribe_any(
             inference, audio_file, inference_kwargs=kwargs, vad=False, regroup=True)
-        result.pad()
 
     result = modify_result(result, max_words, max_chars, sensitive_words)
@@ -429,18 +436,6 @@ class TranscriptionRequest(BaseModel):
 async def transcribe(request: TranscriptionRequest):
     try:
         start_time = time.time()
-        if request.language == "en":
-            request.model = request.model + ".en"
-            task = "transcribe"
-        else:
-            task = request.task
-
-        if platform.system() == 'Windows':
-            model = win_models[request.model]
-        else:
-            model = mac_models[request.model]
-
-        print(model)
 
         file_path = request.file_path
         timeline = request.timeline
@@ -457,15 +452,6 @@ async def transcribe(request: TranscriptionRequest):
         else:
             print(f"Processing file: {file_path}")
 
-        import torch
-        kwargs = {
-            "model": model,
-            "task": task,
-            "language": request.language,
-            "align_words": request.align_words,
-            "device": "cuda" if torch.cuda.is_available() else "cpu"
-        }
-
         # Select device
         if torch.cuda.is_available():
             device = torch.device("cuda")
@@ -476,6 +462,29 @@ async def transcribe(request: TranscriptionRequest):
print(f"Using device: {device}")
+ if request.language == "en":
+ request.model = request.model + ".en"
+ task = "transcribe"
+ elif request.language == "de" and request.model == "large" and platform.system() != 'Windows':
+ request.model = request.model + ".de"
+ else:
+ task = request.task
+
+        if platform.system() == 'Windows':
+            model = win_models[request.model]
+        else:
+            model = mac_models[request.model]
+
+        print(model)
+
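+        # Options passed through to the transcription backend.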
+        kwargs = {
+            "model": model,
+            "task": task,
+            "language": request.language,
+            "align_words": request.align_words,
+            "device": "cuda" if torch.cuda.is_available() else "cpu"
+        }
+
         # Process audio (transcription and optionally diarization)
         try:
             result = await process_audio(
@@ -498,7 +507,7 @@ async def transcribe(request: TranscriptionRequest):
         # Save the transcription to a JSON file
         with open(json_filepath, 'w', encoding='utf-8') as f:
-            json.dump(result, f, indent=4, ensure_ascii=False)
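+            # sanitize_result() replaces values json cannot serialize with null.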
+            json.dump(sanitize_result(result), f, indent=4, ensure_ascii=False)
 
         print(f"Transcription saved to: {json_filepath}")
 
     except Exception as e: