Skip to content

Commit

Permalink
Merge pull request #165 from jhj0517/feature/enable-finetuning-models
Browse files Browse the repository at this point in the history
Enable fine-tuned faster whisper model
  • Loading branch information
jhj0517 authored Jun 7, 2024
2 parents 6e51e1b + 91dee77 commit acfa385
Showing 1 changed file with 30 additions and 5 deletions.
35 changes: 30 additions & 5 deletions modules/faster_whisper_inference.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import os

import time
import numpy as np
from typing import BinaryIO, Union, Tuple, List
Expand All @@ -24,16 +23,17 @@
class FasterWhisperInference(BaseInterface):
def __init__(self):
    """
    Set up the faster-whisper inference state.

    Records the model directory, the selectable models (official names plus
    whatever ``get_model_paths`` finds on disk), the available languages,
    and the device / compute type defaults. No model is loaded here;
    ``self.model`` stays ``None`` until a model is explicitly loaded.
    """
    super().__init__()
    # Must be assigned before get_model_paths(), which lists this directory.
    self.model_dir = os.path.join("models", "Whisper", "faster-whisper")
    self.current_model_size = None
    self.model = None
    # Maps a selectable model name to the value handed to WhisperModel.
    self.model_paths = self.get_model_paths()
    self.available_models = self.model_paths.keys()
    self.available_langs = sorted(whisper.tokenizer.LANGUAGES.values())
    self.translatable_models = ["large", "large-v1", "large-v2", "large-v3"]
    self.device = "cuda" if torch.cuda.is_available() else "cpu"
    # self.device is exactly "cuda" or "cpu", so it can be passed straight through.
    self.available_compute_types = ctranslate2.get_supported_compute_types(self.device)
    self.current_compute_type = "float16" if self.device == "cuda" else "float32"

def transcribe_file(self,
files: list,
Expand Down Expand Up @@ -317,15 +317,40 @@ def update_model(self,
Indicator to show progress directly in gradio.
"""
progress(0, desc="Initializing Model..")
self.current_model_size = model_size
self.current_model_size = self.model_paths[model_size]
self.current_compute_type = compute_type
self.model = faster_whisper.WhisperModel(
device=self.device,
model_size_or_path=model_size,
model_size_or_path=self.current_model_size,
download_root=self.model_dir,
compute_type=self.current_compute_type
)

def get_model_paths(self):
    """
    Get available models from models path including fine-tuned model.

    Every name reported by ``whisper.available_models()`` maps to itself.
    Every other entry found in ``self.model_dir`` (apart from ``.locks``)
    is added under its own name, with the
    ``models--Systran--faster-whisper-`` cache prefix stripped, and maps to
    a path built from the current working directory, ``self.model_dir``
    and that name.

    Returns
    ----------
    dict
        Mapping of model name to the model name or path to load it from.
        Official models come first; extra entries follow in sorted order.
    """
    official_models = whisper.available_models()
    official_names = set(official_models)
    model_paths = {model: model for model in official_models}
    faster_whisper_prefix = "models--Systran--faster-whisper-"
    wrong_dirs = {".locks"}

    try:
        existing_models = os.listdir(self.model_dir)
    except FileNotFoundError:
        # Nothing has been downloaded yet; only official names are offered.
        existing_models = []

    webui_dir = os.getcwd()

    # Sorted so the order of extra models is stable between runs.
    for model_name in sorted(set(existing_models) - wrong_dirs):
        # Only strip the cache prefix when it really is a prefix.
        if model_name.startswith(faster_whisper_prefix):
            model_name = model_name[len(faster_whisper_prefix):]

        if model_name not in official_names:
            # NOTE(review): for a prefixed cache directory this path uses the
            # stripped name, which is not the on-disk directory name.
            model_paths[model_name] = os.path.join(webui_dir, self.model_dir, model_name)
    return model_paths

@staticmethod
def generate_and_write_file(file_name: str,
transcribed_segments: list,
Expand Down

0 comments on commit acfa385

Please sign in to comment.