-
Notifications
You must be signed in to change notification settings - Fork 0
/
cli.py
33 lines (27 loc) · 920 Bytes
/
cli.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
from interface import ProgressTranscriber
from audio import AudioFile
from typing import TYPE_CHECKING, Union
from whisper import load_model
import numpy as np
import torch
if TYPE_CHECKING:
from whisper.model import Whisper
class InMemoryAudio(AudioFile):
    """``AudioFile`` subclass that only changes the class attribute ``dft_pad``.

    Used by ``audio_tensor`` to load audio given a file name
    (``InMemoryAudio(fname=...).sequential()``).
    """

    # NOTE(review): presumably overrides a padding default declared on
    # ``audio.AudioFile`` (not visible here) — confirm what "dft_pad" enables.
    dft_pad = True
def audio_tensor(audio: Union[str, np.ndarray, torch.Tensor]) -> torch.Tensor:
    """Coerce the supported audio inputs into a single tensor-like value.

    * ``np.ndarray`` -> wrapped with ``torch.from_numpy`` (the tensor shares
      memory with the array; no copy is made).
    * ``str`` -> treated as a file name, loaded through ``InMemoryAudio`` and
      the result of its ``sequential()`` method is returned (presumably a
      tensor — depends on ``audio.AudioFile``).
    * anything else -> returned unchanged (expected to already be a tensor).
    """
    if isinstance(audio, np.ndarray):
        return torch.from_numpy(audio)
    if not isinstance(audio, str):
        # Already a tensor (or some caller-supplied equivalent): pass through.
        return audio
    loader = InMemoryAudio(fname=audio)
    return loader.sequential()
def transcribe(
    model: "Whisper",
    audio: Union[str, np.ndarray, torch.Tensor],
    **kw):
    """Transcribe *audio* with *model* using a ``ProgressTranscriber``.

    This is intended as a drop-in for ``whisper.transcribe.transcribe``; the
    ``__main__`` block of this file injects it into whisper's CLI namespace.

    Args:
        model: a loaded Whisper model, passed to ``ProgressTranscriber``.
        audio: a file name, a NumPy array, or a tensor; normalised by
            ``audio_tensor`` before transcription.
        **kw: forwarded unchanged to the ``ProgressTranscriber`` constructor.

    Returns:
        Whatever calling the ``ProgressTranscriber`` instance returns.
    """
    # Build the transcriber before touching the audio, matching the original
    # evaluation order (constructor errors surface before any file is loaded).
    runner = ProgressTranscriber(model, **kw)
    samples = audio_tensor(audio)
    return runner(samples)
if __name__ == "__main__":
    from whisper.transcribe import cli as whisper_cli

    # Patch the namespace the CLI function actually resolves names in. Going
    # through ``__globals__`` avoids importing ``whisper.transcribe`` as a
    # module attribute, which the package can shadow with the function of the
    # same name.
    whisper_cli.__globals__["transcribe"] = transcribe
    whisper_cli()