# test.py: Whisper (original) test inference script
import os
from timeit import default_timer as timer
import argparse
parser = argparse.ArgumentParser(description="Running Whisper (original) test inference.")
parser.add_argument("-f", "--folder", default="../test-files/", help="Folder with WAV input files")
parser.add_argument("-m", "--model", default="tiny", help="Model name")
parser.add_argument("-l", "--lang", default="en", help="Language used (default: en)")
parser.add_argument("-b", "--beamsize", default=1, help="Beam size used (default: 1)")
args = parser.parse_args()
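# Example invocation (folder, model, and beam size below are illustrative, not required values):
#   python test.py -f ../test-files/ -m tiny -l auto -b 5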
print("Importing whisper")
import whisper
model_name = args.model
beam_size = int(args.beamsize)
print(f"Loading model {model_name}")
model = whisper.load_model(model_name)
print("Threads: max (this version always tries to use all available threads)")
def transcribe(audio_file):
    inference_start = timer()
    # load audio and pad/trim it to fit 30 seconds
    print(f"\nLoading audio {audio_file} ...")
    audio = whisper.load_audio(audio_file)
    audio = whisper.pad_or_trim(audio)
    # make log-Mel spectrogram and move to the same device as the model
    print("Calculating mel spectrogram ...")
    mel = whisper.log_mel_spectrogram(audio).to(model.device)
    # detect spoken language
    auto_lang = None
    if ".en" in model_name:
        auto_lang = "en"
        print("Language fixed to 'en'")
    elif args.lang == "auto":
        print("Detecting language ...")
        _, probs = model.detect_language(mel)
        auto_lang = max(probs, key=probs.get)
        print(f"Detected language: {auto_lang}")
    # decode audio
    print("Decoding audio ...")
    if auto_lang is not None:
        options = whisper.DecodingOptions(fp16=False, language=auto_lang, beam_size=beam_size)
    else:
        options = whisper.DecodingOptions(fp16=False, language=args.lang, beam_size=beam_size)
    #print(options)
    #options: task='transcribe', language=None, temperature=0.0, sample_len=None, best_of=None, beam_size=None, patience=None, length_penalty=None, prompt=None, prefix=None, suppress_blank=True, suppress_to...
    result = whisper.decode(model, mel, options)
    # print the recognized text
    print("Result:")
    print(result.text)
    print("\nInference took {:.2f}s.".format(timer() - inference_start))
# transcribe every WAV file in the input folder
test_files = os.listdir(args.folder)
for file in test_files:
    if file.endswith(".wav"):
        transcribe(os.path.join(args.folder, file))