Skip to content

Commit

Permalink
Merge pull request #175 from jhj0517/feature/integrate-insanely_fast_…
Browse files Browse the repository at this point in the history
…whisper

Integrate with insanely fast whisper
  • Loading branch information
jhj0517 authored Jun 24, 2024
2 parents 2457c38 + 661e83c commit 89df94c
Show file tree
Hide file tree
Showing 6 changed files with 351 additions and 93 deletions.
197 changes: 111 additions & 86 deletions app.py

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion modules/faster_whisper_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def transcribe(self,
"""
start_time = time.time()

params = WhisperValues(*whisper_params)
params = WhisperParameters.post_process(*whisper_params)

if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type:
self.update_model(params.model_size, params.compute_type, progress)
Expand Down
181 changes: 181 additions & 0 deletions modules/insanely_fast_whisper_inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
import os
import time
import numpy as np
from typing import BinaryIO, Union, Tuple, List
import torch
from transformers import pipeline
from transformers.utils import is_flash_attn_2_available
import gradio as gr
from huggingface_hub import hf_hub_download
import whisper
from rich.progress import Progress, TimeElapsedColumn, BarColumn, TextColumn

from modules.whisper_parameter import *
from modules.whisper_base import WhisperBase


class InsanelyFastWhisperInference(WhisperBase):
    """
    Whisper backend that runs transcription through the Hugging Face
    ``transformers`` "automatic-speech-recognition" pipeline
    (the insanely-fast-whisper approach).

    Model files are stored under ``models/Whisper/insanely_fast_whisper/<model_size>``
    and are downloaded from the Hugging Face Hub on first use.

    NOTE(review): ``self.model``, ``self.model_dir``, ``self.device``,
    ``self.current_model_size``, ``self.current_compute_type`` and
    ``self.translatable_models`` are not defined in this class; they are
    presumably provided by ``WhisperBase`` - confirm against that class.
    """

    def __init__(self):
        super().__init__(
            model_dir=os.path.join("models", "Whisper", "insanely_fast_whisper")
        )
        openai_models = whisper.available_models()
        distil_models = ["distil-large-v2", "distil-large-v3", "distil-medium.en", "distil-small.en"]
        self.available_models = openai_models + distil_models
        # Only half precision is offered for this backend.
        self.available_compute_types = ["float16"]

    def transcribe(self,
                   audio: Union[str, np.ndarray, torch.Tensor],
                   progress: gr.Progress,
                   *whisper_params,
                   ) -> Tuple[List[dict], float]:
        """
        transcribe method for insanely-fast-whisper.

        Parameters
        ----------
        audio: Union[str, np.ndarray, torch.Tensor]
            Audio input, forwarded unchanged as ``inputs`` to the pipeline.
        progress: gr.Progress
            Indicator to show progress directly in gradio.
        *whisper_params: tuple
            Positional Whisper parameter values, in the order expected by
            ``WhisperParameters.post_process``.

        Returns
        ----------
        segments_result: List[dict]
            list of dicts that includes start, end timestamps and transcribed text
        elapsed_time: float
            elapsed time for transcription, in seconds (includes model loading
            when the model had to be (re)loaded)
        """
        start_time = time.time()
        params = WhisperParameters.post_process(*whisper_params)

        # (Re)load the pipeline when nothing is loaded yet or the requested
        # model / compute type differs from what is currently loaded.
        if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type:
            self.update_model(params.model_size, params.compute_type, progress)

        if params.lang == "Automatic Detection":
            # ``None`` is passed as the language in that case.
            params.lang = None
        else:
            # Map the language name back to its code using whisper's LANGUAGES table
            # (code -> name), inverted here to name -> code.
            language_code_dict = {value: key for key, value in whisper.tokenizer.LANGUAGES.items()}
            params.lang = language_code_dict[params.lang]

        progress(0, desc="Transcribing...Progress is not shown in insanely-fast-whisper.")
        # The rich progress bar gets its own name so that it does not shadow the
        # gradio ``progress`` argument of this method.
        with Progress(
            TextColumn("[progress.description]{task.description}"),
            BarColumn(style="yellow1", pulse_style="white"),
            TimeElapsedColumn(),
        ) as console_progress:
            # total=None: indeterminate (pulsing) bar, since no per-chunk
            # progress is reported while the pipeline runs.
            console_progress.add_task("[yellow]Transcribing...", total=None)

            segments = self.model(
                inputs=audio,
                return_timestamps=True,
                chunk_length_s=params.chunk_length_s,
                batch_size=params.batch_size,
                generate_kwargs={
                    "language": params.lang,
                    # Translation is only requested for models listed as translatable.
                    "task": "translate" if params.is_translate and self.current_model_size in self.translatable_models else "transcribe",
                    "no_speech_threshold": params.no_speech_threshold,
                    "temperature": params.temperature,
                    "compression_ratio_threshold": params.compression_ratio_threshold
                }
            )

        segments_result = self.format_result(
            transcribed_result=segments,
        )
        elapsed_time = time.time() - start_time
        return segments_result, elapsed_time

    def update_model(self,
                     model_size: str,
                     compute_type: str,
                     progress: gr.Progress,
                     ):
        """
        Update current model setting

        Downloads the model files when the local model directory is missing
        or empty, then builds the ASR pipeline from that directory.

        Parameters
        ----------
        model_size: str
            Size of whisper model
        compute_type: str
            Compute type for transcription. It is passed to the pipeline as
            ``torch_dtype``.
        progress: gr.Progress
            Indicator to show progress directly in gradio.
        """
        progress(0, desc="Initializing Model..")
        model_path = os.path.join(self.model_dir, model_size)
        if not os.path.isdir(model_path) or not os.listdir(model_path):
            self.download_model(
                model_size=model_size,
                download_root=model_path,
                progress=progress
            )

        if is_flash_attn_2_available():
            model_kwargs = {"attn_implementation": "flash_attention_2"}
        else:
            model_kwargs = {"attn_implementation": "sdpa"}

        loaded_pipeline = pipeline(
            "automatic-speech-recognition",
            model=model_path,
            torch_dtype=compute_type,
            device=self.device,
            model_kwargs=model_kwargs,
        )

        # Commit the new state only after the pipeline was built successfully.
        # Otherwise a failed load would leave current_model_size/compute_type
        # describing a model that is not the one held in self.model, and the
        # reload check in transcribe() would keep using the stale pipeline.
        self.model = loaded_pipeline
        self.current_compute_type = compute_type
        self.current_model_size = model_size

    @staticmethod
    def format_result(
        transcribed_result: dict
    ) -> List[dict]:
        """
        Format the transcription result of insanely_fast_whisper as the same with other implementation.

        The chunk dicts are updated in place: "start" and "end" keys are added
        from each chunk's "timestamp" pair, and the same list is returned.

        Parameters
        ----------
        transcribed_result: dict
            Transcription result of the insanely_fast_whisper. Must contain a
            "chunks" list whose items have a "timestamp" (start, end) pair.

        Returns
        ----------
        result: List[dict]
            Formatted result as the same with other implementation
        """
        result = transcribed_result["chunks"]
        for item in result:
            start, end = item["timestamp"][0], item["timestamp"][1]
            if end is None:
                # No end timestamp available: fall back to the start time.
                end = start
            item["start"] = start
            item["end"] = end
        return result

    @staticmethod
    def download_model(
        model_size: str,
        download_root: str,
        progress: gr.Progress
    ):
        """
        Download the model files for ``model_size`` from the Hugging Face Hub.

        Parameters
        ----------
        model_size: str
            Size of whisper model. Names starting with "distil" are fetched
            from the ``distil-whisper`` organization, all others from
            ``openai/whisper-<model_size>``.
        download_root: str
            Local directory the files are written to (created if missing).
        progress: gr.Progress
            Indicator to show progress directly in gradio.
        """
        progress(0, 'Initializing model..')
        print(f'Downloading {model_size} to "{download_root}"....')

        os.makedirs(download_root, exist_ok=True)
        download_list = [
            "model.safetensors",
            "config.json",
            "generation_config.json",
            "preprocessor_config.json",
            "tokenizer.json",
            "tokenizer_config.json",
            "added_tokens.json",
            "special_tokens_map.json",
            "vocab.json",
        ]

        if model_size.startswith("distil"):
            repo_id = f"distil-whisper/{model_size}"
        else:
            repo_id = f"openai/whisper-{model_size}"
        for item in download_list:
            hf_hub_download(repo_id=repo_id, filename=item, local_dir=download_root)
2 changes: 1 addition & 1 deletion modules/whisper_Inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def transcribe(self,
elapsed time for transcription
"""
start_time = time.time()
params = WhisperValues(*whisper_params)
params = WhisperParameters.post_process(*whisper_params)

if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type:
self.update_model(params.model_size, params.compute_type, progress)
Expand Down
56 changes: 52 additions & 4 deletions modules/whisper_parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@


@dataclass
class WhisperGradioComponents:
class WhisperParameters:
model_size: gr.Dropdown
lang: gr.Dropdown
is_translate: gr.Checkbox
Expand All @@ -25,8 +25,12 @@ class WhisperGradioComponents:
min_silence_duration_ms: gr.Number
window_size_sample: gr.Number
speech_pad_ms: gr.Number
chunk_length_s: gr.Number
batch_size: gr.Number
"""
A data class for Gradio components of the Whisper Parameters. Use "before" Gradio pre-processing.
This data class is used to mitigate the key-value problem between Gradio components and function parameters.
Related Gradio issue: https://github.com/gradio-app/gradio/issues/2471
See more about Gradio pre-processing: https://www.gradio.app/docs/components
Attributes
Expand Down Expand Up @@ -111,11 +115,18 @@ class WhisperGradioComponents:
speech_pad_ms: gr.Number
This parameter is related with Silero VAD. Final speech chunks are padded by speech_pad_ms each side
chunk_length_s: gr.Number
This parameter is related with insanely-fast-whisper pipe.
Maximum length of each chunk
batch_size: gr.Number
This parameter is related with insanely-fast-whisper pipe. Batch size to pass to the pipe
"""

def to_list(self) -> list:
"""
Converts the data class attributes into a list. Use "before" Gradio pre-processing.
Converts the data class attributes into a list. Use in the Gradio UI, before Gradio pre-processing.
See more about Gradio pre-processing: https://www.gradio.app/docs/components
Returns
Expand All @@ -124,6 +135,42 @@ def to_list(self) -> list:
"""
return [getattr(self, f.name) for f in fields(self)]

@staticmethod
def post_process(*args) -> 'WhisperValues':
    """
    Pack positional Whisper parameter values into a ``WhisperValues`` instance.

    The values are matched to fields purely by position: ``args[0]`` becomes
    ``model_size``, ``args[1]`` becomes ``lang``, and so on up to ``args[21]``
    (``batch_size``). Any positional values past the 22nd are ignored, and
    passing fewer than 22 raises ``IndexError``.

    NOTE(review): intended for the plain values Gradio hands to an event
    handler (as opposed to the component objects) - confirm against callers.
    See more about Gradio components: https://www.gradio.app/docs/components

    Returns
    ----------
    WhisperValues
        Data class that has values of parameters
    """
    # Field names of WhisperValues in the positional order of ``args``.
    value_field_order = (
        "model_size",
        "lang",
        "is_translate",
        "beam_size",
        "log_prob_threshold",
        "no_speech_threshold",
        "compute_type",
        "best_of",
        "patience",
        "condition_on_previous_text",
        "initial_prompt",
        "temperature",
        "compression_ratio_threshold",
        "vad_filter",
        "threshold",
        "min_speech_duration_ms",
        "max_speech_duration_s",
        "min_silence_duration_ms",
        "window_size_samples",
        "speech_pad_ms",
        "chunk_length_s",
        "batch_size",
    )
    # Index explicitly (rather than zip) so a too-short ``args`` still fails
    # with IndexError at the first missing position.
    keyword_values = {
        field_name: args[position]
        for position, field_name in enumerate(value_field_order)
    }
    return WhisperValues(**keyword_values)


@dataclass
class WhisperValues:
Expand All @@ -147,7 +194,8 @@ class WhisperValues:
min_silence_duration_ms: int
window_size_samples: int
speech_pad_ms: int
chunk_length_s: int
batch_size: int
"""
A data class to use Whisper parameters. Use "after" Gradio pre-processing.
See more about Gradio pre-processing: : https://www.gradio.app/docs/components
A data class to use Whisper parameters.
"""
6 changes: 5 additions & 1 deletion user-start-webui.bat
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ set API_OPEN=
set WHISPER_TYPE=
set WHISPER_MODEL_DIR=
set FASTER_WHISPER_MODEL_DIR=
set INSANELY_FAST_WHISPER_MODEL_DIR=


if not "%SERVER_NAME%"=="" (
Expand Down Expand Up @@ -47,7 +48,10 @@ if not "%WHISPER_MODEL_DIR%"=="" (
if not "%FASTER_WHISPER_MODEL_DIR%"=="" (
set FASTER_WHISPER_MODEL_DIR_ARG=--faster_whisper_model_dir "%FASTER_WHISPER_MODEL_DIR%"
)
if not "%INSANELY_FAST_WHISPER_MODEL_DIR%"=="" (
set INSANELY_FAST_WHISPER_MODEL_DIR_ARG=--insanely_fast_whisper_model_dir "%INSANELY_FAST_WHISPER_MODEL_DIR%"
)

:: Call the original .bat script with optional arguments
start-webui.bat %SERVER_NAME_ARG% %SERVER_PORT_ARG% %USERNAME_ARG% %PASSWORD_ARG% %SHARE_ARG% %THEME_ARG% %API_OPEN% %WHISPER_TYPE_ARG% %WHISPER_MODEL_DIR_ARG% %FASTER_WHISPER_MODEL_DIR_ARG%
start-webui.bat %SERVER_NAME_ARG% %SERVER_PORT_ARG% %USERNAME_ARG% %PASSWORD_ARG% %SHARE_ARG% %THEME_ARG% %API_OPEN% %WHISPER_TYPE_ARG% %WHISPER_MODEL_DIR_ARG% %FASTER_WHISPER_MODEL_DIR_ARG% %INSANELY_FAST_WHISPER_MODEL_DIR_ARG%
pause

0 comments on commit 89df94c

Please sign in to comment.