support/optimize ASR on HPU #280

Merged · 2 commits · Jun 12, 2024
61 changes: 40 additions & 21 deletions AudioQnA/audio/docker/asr/asr.py
@@ -8,6 +8,7 @@
import contextlib
import os
import time
import urllib.request

import numpy as np
import torch
@@ -19,7 +20,13 @@
class AudioSpeechRecognition:
    """Convert audio to text."""

    def __init__(self, model_name_or_path="openai/whisper-small", bf16=False, language=None, device="cpu"):
    def __init__(self, model_name_or_path="openai/whisper-small", bf16=False, language="english", device="cpu"):
        if device == "hpu":
            # Explicitly link HPU with Torch
            from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi

            adapt_transformers_to_gaudi()

        self.device = device
        asr_model_name_or_path = os.environ.get("ASR_MODEL_PATH", model_name_or_path)
        print("Downloading model: {}".format(asr_model_name_or_path))
@@ -33,6 +40,12 @@ def __init__(self, model_name_or_path="openai/whisper-small", bf16=False, langua
            self.model = ipex.optimize(self.model, dtype=torch.bfloat16)
        self.language = language

        if device == "hpu":
            # do HPU graph warmup with a long enough input audio;
            # Whisper has a receptive field of 30 seconds,
            # so we pick a relatively long clip (~15 sec) to warm up quickly
            self._warmup_whisper_hpu_graph("https://paddlespeech.bj.bcebos.com/Parakeet/docs/demos/labixiaoxin.wav")

    def _audiosegment_to_librosawav(self, audiosegment):
        # https://github.com/jiaaro/pydub/blob/master/API.markdown#audiosegmentget_array_of_samples
        # This way is faster than librosa.load or HuggingFace Dataset wrapper
@@ -45,16 +58,27 @@ def _audiosegment_to_librosawav(self, audiosegment):

        return fp_arr

    def _warmup_whisper_hpu_graph(self, url):
        print("[ASR] fetch warmup audio...")
        urllib.request.urlretrieve(
            url,
            "warmup.wav",
        )
        print("[ASR] warmup...")
        waveform = AudioSegment.from_file("warmup.wav").set_frame_rate(16000)
        waveform = self._audiosegment_to_librosawav(waveform)
        # pylint: disable=E1101
        inputs = self.processor.feature_extractor(
            waveform, return_tensors="pt", sampling_rate=16_000
        ).input_features.to(self.device)
        _ = self.model.generate(inputs, language="chinese")

    def audio2text(self, audio_path):
        """Convert audio to text.

        audio_path: the path to the input audio, e.g. ~/xxx.mp3
        """
        start = time.time()
        if audio_path.split(".")[-1] in ["flac", "ogg", "aac", "m4a"]:
            audio_path = self._convert_audio_type(audio_path)
        elif audio_path.split(".")[-1] not in ["mp3", "wav"]:
            raise Exception("[ASR ERROR] Audio format not supported!")

        try:
            waveform = AudioSegment.from_file(audio_path).set_frame_rate(16000)
@@ -69,20 +93,10 @@ def audio2text(self, audio_path):
            waveform, return_tensors="pt", sampling_rate=16_000
        ).input_features.to(self.device)
        with torch.cpu.amp.autocast() if self.bf16 else contextlib.nullcontext():
            if self.language is None:
                predicted_ids = self.model.generate(inputs)
            elif self.language == "auto":
                self.model.config.forced_decoder_ids = None
                predicted_ids = self.model.generate(inputs)
            else:
                self.forced_decoder_ids = self.processor.get_decoder_prompt_ids(
                    language=self.language, task="transcribe"
                )
                self.model.config.forced_decoder_ids = self.forced_decoder_ids
                predicted_ids = self.model.generate(inputs)
            predicted_ids = self.model.generate(inputs, language=self.language)
        # pylint: disable=E1101
        result = self.processor.tokenizer.batch_decode(predicted_ids, skip_special_tokens=True, normalize=True)[0]
        if self.language == "auto" or self.language == "zh":
        if self.language in ["chinese", "mandarin"]:
            from zhconv import convert

            result = convert(result, "zh-cn")
@@ -91,15 +105,20 @@


if __name__ == "__main__":
    asr = AudioSpeechRecognition(language="auto")
    import urllib.request
    asr = AudioSpeechRecognition(language="english")

    # Test multilanguage asr
    urllib.request.urlretrieve(
        "https://paddlespeech.bj.bcebos.com/Parakeet/docs/demos/labixiaoxin.wav",
        "sample.wav",
    )
    asr.language = "chinese"
    text = asr.audio2text("sample.wav")

    urllib.request.urlretrieve(
        "https://github.com/intel/intel-extension-for-transformers/raw/main/intel_extension_for_transformers/neural_chat/assets/audio/sample.wav",
        "sample.wav",
    )
    text = asr.audio2text("sample.wav")
    import os

    os.remove("sample.wav")
    print(text)
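
For reference, a minimal usage sketch of the updated class on Gaudi (not part of this diff; it assumes the file above is importable as asr.py and that optimum-habana and the Torch HPU stack are installed):

# Hypothetical example: device="hpu" triggers adapt_transformers_to_gaudi()
# and the HPU graph warmup added in this PR.
from asr import AudioSpeechRecognition

engine = AudioSpeechRecognition(model_name_or_path="openai/whisper-small", language="english", device="hpu")
print(engine.audio2text("sample.wav"))  # any local .wav or .mp3 file
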
2 changes: 1 addition & 1 deletion AudioQnA/audio/docker/asr/asr_server.py
@@ -58,7 +58,7 @@ async def audio_to_text(file: UploadFile = File(...)):
parser.add_argument("--port", type=int, default=8008)
parser.add_argument("--model_name_or_path", type=str, default="openai/whisper-tiny")
parser.add_argument("--bf16", default=False, action="store_true")
parser.add_argument("--language", type=str, default="auto")
parser.add_argument("--language", type=str, default="english")
parser.add_argument("--device", type=str, default="cpu")

args = parser.parse_args()
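
With these options, the service could presumably be launched on Gaudi along the lines of the following (hypothetical invocation, not shown in this PR):

python asr_server.py --device hpu --language english --model_name_or_path openai/whisper-small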