# app.py — Gradio demo for Whisper JAX (forked from sanchit-gandhi/whisper-jax).
import logging
import math
import os
import tempfile
import time
import urllib.parse
from multiprocessing import Pool

import gradio as gr
import jax.numpy as jnp
import numpy as np
import yt_dlp as youtube_dl
from jax.experimental.compilation_cache import compilation_cache as cc
from transformers.models.whisper.tokenization_whisper import TO_LANGUAGE_CODE
from transformers.pipelines.audio_utils import ffmpeg_read

from whisper_jax import FlaxWhisperPipline
# Point JAX's persistent compilation cache at ./jax_cache so compiled programs can be reused.
cc.initialize_cache("./jax_cache")

# ---- model / request limits ----
checkpoint = "openai/whisper-large-v2"
BATCH_SIZE = 32  # batch size handed to the pipeline and to every forward call
CHUNK_LENGTH_S = 30  # chunk length (seconds) used when pre-processing long audio
NUM_PROC = 32  # number of processes in the multiprocessing Pool
FILE_LIMIT_MB = 1000  # largest accepted upload, in MiB
YT_LENGTH_LIMIT_S = 2 * 60 * 60  # limit to 2 hour YouTube files

# ---- text shown in the Gradio UI ----
title = "Whisper JAX: The Fastest Whisper API ⚡️"
description = """Whisper JAX is an optimised implementation of the [Whisper model](https://huggingface.co/openai/whisper-large-v2) by OpenAI. It runs on JAX with a TPU v4-8 in the backend. Compared to PyTorch on an A100 GPU, it is over [**70x faster**](https://github.com/sanchit-gandhi/whisper-jax#benchmarks), making it the fastest Whisper API available.
Note that at peak times, you may find yourself in the queue for this demo. When you submit a request, your queue position will be shown in the top right-hand side of the demo pane. Once you reach the front of the queue, your audio file will be transcribed, with the progress displayed through a progress bar.
To skip the queue, you may wish to create your own inference endpoint, details for which can be found in the [Whisper JAX repository](https://github.com/sanchit-gandhi/whisper-jax#creating-an-endpoint).
"""
article = "Whisper large-v2 model by OpenAI. Backend running JAX on a TPU v4-8 through the generous support of the [TRC](https://sites.research.google/trc/about/) programme. Whisper JAX [code](https://github.com/sanchit-gandhi/whisper-jax) and Gradio demo by 🤗 Hugging Face."
# Alphabetical list of the language names Whisper's tokenizer knows about.
language_names = sorted(TO_LANGUAGE_CODE)

# ---- logging: INFO and above, formatted as "timestamp;LEVEL;message" ----
formatter = logging.Formatter(fmt="%(asctime)s;%(levelname)s;%(message)s", datefmt="%Y-%m-%d %H:%M:%S")
ch = logging.StreamHandler()
ch.setLevel(logging.INFO)
ch.setFormatter(formatter)
logger = logging.getLogger("whisper-jax-app")
logger.setLevel(logging.INFO)
logger.addHandler(ch)
def identity(batch):
    """Return ``batch`` unchanged.

    Kept as a named, module-level function (not a lambda) because it is passed to
    ``multiprocessing.Pool.map``, which must pickle the callable to send it to worker processes.
    """
    return batch
# Based on https://github.com/openai/whisper/blob/c09a7ae299c4c34c5839a76380ae407e7d785914/whisper/utils.py#L50
def format_timestamp(seconds: float, always_include_hours: bool = False, decimal_marker: str = "."):
    """Render an offset in seconds as ``[HH:]MM:SS<decimal_marker>mmm``.

    Args:
        seconds: time offset in seconds, rounded to the nearest millisecond. ``None`` is treated
            as a malformed timestamp and handed back unchanged.
        always_include_hours: emit the ``HH:`` field even when the hour count is zero.
        decimal_marker: separator placed between the seconds and milliseconds fields.

    Returns:
        The formatted timestamp string, or ``None`` when ``seconds`` is ``None``.
    """
    if seconds is None:
        # malformed timestamp: pass it through as-is
        return seconds

    remaining_ms = round(seconds * 1000.0)
    hours, remaining_ms = divmod(remaining_ms, 3_600_000)
    minutes, remaining_ms = divmod(remaining_ms, 60_000)
    whole_seconds, millis = divmod(remaining_ms, 1_000)

    show_hours = always_include_hours or hours > 0
    hours_marker = f"{hours:02d}:" if show_hours else ""
    return f"{hours_marker}{minutes:02d}:{whole_seconds:02d}{decimal_marker}{millis:03d}"
if __name__ == "__main__":
    # Script entry point: load the model, warm up the forward pass, define the request handlers
    # and serve them through a Gradio app.
    pipeline = FlaxWhisperPipline(checkpoint, dtype=jnp.bfloat16, batch_size=BATCH_SIZE)
    # Pool used by `tqdm_generate` to materialise the pre-processed batches.
    pool = Pool(NUM_PROC)

    # do a pre-compile step so that the first user to use the demo isn't hit with a long transcription time
    logger.info("compiling forward call...")
    start = time.time()
    random_inputs = {
        "input_features": np.ones(
            (BATCH_SIZE, pipeline.model.config.num_mel_bins, 2 * pipeline.model.config.max_source_positions)
        )
    }
    # Only the compilation side effect matters; the outputs for the dummy batch are discarded.
    pipeline.forward(random_inputs, batch_size=BATCH_SIZE, return_timestamps=True)
    compile_time = time.time() - start
    logger.info(f"compiled in {compile_time}s")

    def tqdm_generate(inputs: dict, task: str, return_timestamps: bool, progress: gr.Progress):
        """Run the model over a decoded audio array, reporting progress to the Gradio UI.

        Args:
            inputs: ``{"array": samples, "sampling_rate": rate}`` as built by ``_load_audio``.
            task: ``"transcribe"`` or ``"translate"``; forwarded to ``pipeline.forward``.
            return_timestamps: if True, return one ``[start -> end] text`` line per chunk instead of
                the plain transcription.
            progress: Gradio progress tracker for the current request.

        Returns:
            ``(text, runtime)``; ``runtime`` is the wall-clock seconds spent in the forward passes.
        """
        dataloader = pipeline.preprocess_batch(inputs, chunk_length_s=CHUNK_LENGTH_S, batch_size=BATCH_SIZE)
        progress(0, desc="Pre-processing audio file...")
        logger.info("pre-processing audio file...")
        # Materialise every pre-processed batch into a list up front.
        # NOTE(review): Pool.map consumes `dataloader` in this process and `identity` merely
        # round-trips each batch through a worker; confirm this actually buys any parallelism.
        dataloader = pool.map(identity, dataloader)
        logger.info("done pre-processing")

        model_outputs = []
        start_time = time.time()
        logger.info("transcribing...")
        # `dataloader` is now a list, so Gradio's tracker can size it directly (it cannot size a
        # generator, see https://github.com/gradio-app/gradio/issues/3841). Iterating the real batches,
        # rather than zipping with a separately estimated batch count, ensures no batch is dropped.
        # Timestamps are always predicted to reduce hallucinations.
        for batch in progress.tqdm(dataloader, desc="Transcribing..."):
            model_outputs.append(pipeline.forward(batch, batch_size=BATCH_SIZE, task=task, return_timestamps=True))
        runtime = time.time() - start_time
        logger.info("done transcription")

        logger.info("post-processing...")
        post_processed = pipeline.postprocess(model_outputs, return_timestamps=True)
        text = post_processed["text"]
        if return_timestamps:
            chunks = post_processed.get("chunks")
            text = "\n".join(
                f"[{format_timestamp(chunk['timestamp'][0])} -> {format_timestamp(chunk['timestamp'][1])}] {chunk['text']}"
                for chunk in chunks
            )
        logger.info("done post-processing")
        return text, runtime

    def _load_audio(filepath):
        """Read ``filepath`` and decode it with ffmpeg at the feature extractor's sampling rate.

        Returns:
            ``{"array": samples, "sampling_rate": rate}``, the input format ``tqdm_generate`` expects.
        """
        with open(filepath, "rb") as f:
            raw_bytes = f.read()
        sampling_rate = pipeline.feature_extractor.sampling_rate
        return {"array": ffmpeg_read(raw_bytes, sampling_rate), "sampling_rate": sampling_rate}

    def transcribe_chunked_audio(inputs, task, return_timestamps, progress=gr.Progress()):
        """Gradio handler for microphone recordings and uploaded audio files.

        Args:
            inputs: path to the submitted audio file (the Audio inputs use ``type="filepath"``),
                or None if nothing was submitted.
            task: ``"transcribe"`` or ``"translate"``.
            return_timestamps: whether to return per-chunk timestamps.
            progress: Gradio progress tracker.

        Returns:
            ``(text, runtime)`` from ``tqdm_generate``.

        Raises:
            gr.Error: if no file was submitted or it is larger than FILE_LIMIT_MB.
        """
        progress(0, desc="Loading audio file...")
        logger.info("loading audio file...")
        if inputs is None:
            logger.warning("No audio file")
            raise gr.Error("No audio file submitted! Please upload an audio file before submitting your request.")
        file_size_mb = os.stat(inputs).st_size / (1024 * 1024)
        if file_size_mb > FILE_LIMIT_MB:
            logger.warning("Max file size exceeded")
            raise gr.Error(
                f"File size exceeds file size limit. Got file of size {file_size_mb:.2f}MB for a limit of {FILE_LIMIT_MB}MB."
            )

        audio = _load_audio(inputs)
        logger.info("done loading")
        text, runtime = tqdm_generate(audio, task=task, return_timestamps=return_timestamps, progress=progress)
        return text, runtime

    def _extract_video_id(yt_url):
        """Best-effort extraction of the video id from a YouTube watch URL or a youtu.be short link."""
        try:
            parsed = urllib.parse.urlsplit(yt_url)
        except ValueError:
            # urlsplit rejects some malformed netlocs (e.g. an unbalanced "["); use the fallback below
            parsed = None
        if parsed is not None:
            # https://www.youtube.com/watch?v=<id>&t=... -> ignore any other query parameters
            video_ids = urllib.parse.parse_qs(parsed.query).get("v")
            if video_ids:
                return video_ids[0]
            # https://youtu.be/<id>
            short_id = parsed.path.strip("/")
            if parsed.netloc.endswith("youtu.be") and short_id:
                return short_id
        # unrecognised format: take everything after "?v="
        return yt_url.split("?v=")[-1]

    def _return_yt_html_embed(yt_url):
        """Return an HTML snippet embedding the YouTube video at ``yt_url`` in an iframe.

        ``yt_url`` is user-supplied, so the extracted id is percent-encoded before it is placed in
        the iframe's ``src`` attribute; it therefore cannot close the attribute or inject markup.
        """
        video_id = urllib.parse.quote(_extract_video_id(yt_url), safe="")
        HTML_str = (
            f'<center> <iframe width="500" height="320" src="https://www.youtube.com/embed/{video_id}"> </iframe>'
            " </center>"
        )
        return HTML_str

    def _format_duration_hms(total_seconds):
        """Format a duration as e.g. ``02H:00M:00S``; unlike ``time.gmtime`` it does not wrap at 24 hours."""
        hours, remainder = divmod(int(total_seconds), 3600)
        minutes, secs = divmod(remainder, 60)
        return f"{hours:02d}H:{minutes:02d}M:{secs:02d}S"

    def download_yt_audio(yt_url, filename):
        """Download the YouTube video at ``yt_url`` to ``filename``, enforcing YT_LENGTH_LIMIT_S.

        Raises:
            gr.Error: if the metadata lookup or the download fails, the video's duration is not
                reported, or the video is longer than YT_LENGTH_LIMIT_S.
        """
        with youtube_dl.YoutubeDL() as info_loader:
            try:
                info = info_loader.extract_info(yt_url, download=False)
            except youtube_dl.utils.DownloadError as err:
                raise gr.Error(str(err)) from err

        # `duration` is the length in seconds; it may be absent from the metadata.
        file_length_s = info.get("duration")
        if file_length_s is None:
            raise gr.Error("Could not determine the duration of this YouTube video.")
        if file_length_s > YT_LENGTH_LIMIT_S:
            yt_length_limit_hms = _format_duration_hms(YT_LENGTH_LIMIT_S)
            file_length_hms = _format_duration_hms(file_length_s)
            raise gr.Error(f"Maximum YouTube length is {yt_length_limit_hms}, got {file_length_hms} YouTube video.")

        ydl_opts = {"outtmpl": filename, "format": "worstvideo[ext=mp4]+bestaudio[ext=m4a]/best[ext=mp4]/best"}
        with youtube_dl.YoutubeDL(ydl_opts) as ydl:
            try:
                ydl.download([yt_url])
            # yt-dlp reports extractor/post-processing failures from download() as DownloadError
            except (youtube_dl.utils.DownloadError, youtube_dl.utils.ExtractorError) as err:
                raise gr.Error(str(err)) from err

    def transcribe_youtube(yt_url, task, return_timestamps, progress=gr.Progress()):
        """Gradio handler for YouTube URLs.

        Returns:
            ``(html_embed, text, runtime)``: an iframe embedding the video plus the ``tqdm_generate`` result.

        Raises:
            gr.Error: propagated from ``download_yt_audio``.
        """
        progress(0, desc="Loading audio file...")
        logger.info("loading youtube file...")
        html_embed_str = _return_yt_html_embed(yt_url)
        # The download only needs to live long enough to be decoded.
        with tempfile.TemporaryDirectory() as tmpdirname:
            filepath = os.path.join(tmpdirname, "video.mp4")
            download_yt_audio(yt_url, filepath)
            audio = _load_audio(filepath)
        logger.info("done loading...")
        text, runtime = tqdm_generate(audio, task=task, return_timestamps=return_timestamps, progress=progress)
        return html_embed_str, text, runtime

    microphone_chunked = gr.Interface(
        fn=transcribe_chunked_audio,
        inputs=[
            gr.inputs.Audio(source="microphone", optional=True, type="filepath"),
            gr.inputs.Radio(["transcribe", "translate"], label="Task", default="transcribe"),
            gr.inputs.Checkbox(default=False, label="Return timestamps"),
        ],
        outputs=[
            gr.outputs.Textbox(label="Transcription").style(show_copy_button=True),
            gr.outputs.Textbox(label="Transcription Time (s)"),
        ],
        allow_flagging="never",
        title=title,
        description=description,
        article=article,
    )

    audio_chunked = gr.Interface(
        fn=transcribe_chunked_audio,
        inputs=[
            gr.inputs.Audio(source="upload", optional=True, label="Audio file", type="filepath"),
            gr.inputs.Radio(["transcribe", "translate"], label="Task", default="transcribe"),
            gr.inputs.Checkbox(default=False, label="Return timestamps"),
        ],
        outputs=[
            gr.outputs.Textbox(label="Transcription").style(show_copy_button=True),
            gr.outputs.Textbox(label="Transcription Time (s)"),
        ],
        allow_flagging="never",
        title=title,
        description=description,
        article=article,
    )

    youtube = gr.Interface(
        fn=transcribe_youtube,
        inputs=[
            gr.inputs.Textbox(lines=1, placeholder="Paste the URL to a YouTube video here", label="YouTube URL"),
            gr.inputs.Radio(["transcribe", "translate"], label="Task", default="transcribe"),
            gr.inputs.Checkbox(default=False, label="Return timestamps"),
        ],
        outputs=[
            gr.outputs.HTML(label="Video"),
            gr.outputs.Textbox(label="Transcription").style(show_copy_button=True),
            gr.outputs.Textbox(label="Transcription Time (s)"),
        ],
        allow_flagging="never",
        title=title,
        examples=[["https://www.youtube.com/watch?v=m8u-18Q0s7I", "transcribe", False]],
        cache_examples=False,
        description=description,
        article=article,
    )

    demo = gr.Blocks()

    with demo:
        gr.TabbedInterface([microphone_chunked, audio_chunked, youtube], ["Microphone", "Audio File", "YouTube"])

    # One request at a time (the model holds the accelerator), with at most 5 requests waiting.
    demo.queue(concurrency_count=1, max_size=5)
    demo.launch(server_name="0.0.0.0", show_api=False)