From 449c106462ab9872726f340bc90091f4a6a9d7d7 Mon Sep 17 00:00:00 2001
From: Sihan Chen <39623753+Spycsh@users.noreply.github.com>
Date: Wed, 12 Jun 2024 17:42:31 +0800
Subject: [PATCH] support/optimize ASR on HPU (#280)

* optimize asr on hpu

Signed-off-by: Spycsh
---
 AudioQnA/audio/docker/asr/asr.py        | 61 ++++++++++++++++---------
 AudioQnA/audio/docker/asr/asr_server.py |  2 +-
 2 files changed, 41 insertions(+), 22 deletions(-)

diff --git a/AudioQnA/audio/docker/asr/asr.py b/AudioQnA/audio/docker/asr/asr.py
index e5e74edee0..17b4f456c7 100644
--- a/AudioQnA/audio/docker/asr/asr.py
+++ b/AudioQnA/audio/docker/asr/asr.py
@@ -8,6 +8,7 @@
 import contextlib
 import os
 import time
+import urllib.request

 import numpy as np
 import torch
@@ -19,7 +20,13 @@ class AudioSpeechRecognition:
     """Convert audio to text."""

-    def __init__(self, model_name_or_path="openai/whisper-small", bf16=False, language=None, device="cpu"):
+    def __init__(self, model_name_or_path="openai/whisper-small", bf16=False, language="english", device="cpu"):
+        if device == "hpu":
+            # Explicitly link HPU with Torch
+            from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
+
+            adapt_transformers_to_gaudi()
+
         self.device = device
         asr_model_name_or_path = os.environ.get("ASR_MODEL_PATH", model_name_or_path)
         print("Downloading model: {}".format(asr_model_name_or_path))
@@ -33,6 +40,12 @@ def __init__(self, model_name_or_path="openai/whisper-small", bf16=False, langua
             self.model = ipex.optimize(self.model, dtype=torch.bfloat16)
         self.language = language

+        if device == "hpu":
+            # do hpu graph warmup with a long enough input audio
+            # whisper has a receptive field of 30 seconds
+            # here we select a relatively long audio (~15 sec) to quickly warmup
+            self._warmup_whisper_hpu_graph("https://paddlespeech.bj.bcebos.com/Parakeet/docs/demos/labixiaoxin.wav")
+
     def _audiosegment_to_librosawav(self, audiosegment):
         # https://github.com/jiaaro/pydub/blob/master/API.markdown#audiosegmentget_array_of_samples
         # This way is faster than librosa.load or HuggingFace Dataset wrapper
@@ -45,16 +58,27 @@ def _audiosegment_to_librosawav(self, audiosegment):

         return fp_arr

+    def _warmup_whisper_hpu_graph(self, url):
+        print("[ASR] fetch warmup audio...")
+        urllib.request.urlretrieve(
+            url,
+            "warmup.wav",
+        )
+        print("[ASR] warmup...")
+        waveform = AudioSegment.from_file("warmup.wav").set_frame_rate(16000)
+        waveform = self._audiosegment_to_librosawav(waveform)
+        # pylint: disable=E1101
+        inputs = self.processor.feature_extractor(
+            waveform, return_tensors="pt", sampling_rate=16_000
+        ).input_features.to(self.device)
+        _ = self.model.generate(inputs, language="chinese")
+
     def audio2text(self, audio_path):
         """Convert audio to text.

         audio_path: the path to the input audio, e.g. ~/xxx.mp3
         """
         start = time.time()
-        if audio_path.split(".")[-1] in ["flac", "ogg", "aac", "m4a"]:
-            audio_path = self._convert_audio_type(audio_path)
-        elif audio_path.split(".")[-1] not in ["mp3", "wav"]:
-            raise Exception("[ASR ERROR] Audio format not supported!")

         try:
             waveform = AudioSegment.from_file(audio_path).set_frame_rate(16000)
@@ -69,20 +93,10 @@ def audio2text(self, audio_path):
             waveform, return_tensors="pt", sampling_rate=16_000
         ).input_features.to(self.device)
         with torch.cpu.amp.autocast() if self.bf16 else contextlib.nullcontext():
-            if self.language is None:
-                predicted_ids = self.model.generate(inputs)
-            elif self.language == "auto":
-                self.model.config.forced_decoder_ids = None
-                predicted_ids = self.model.generate(inputs)
-            else:
-                self.forced_decoder_ids = self.processor.get_decoder_prompt_ids(
-                    language=self.language, task="transcribe"
-                )
-                self.model.config.forced_decoder_ids = self.forced_decoder_ids
-                predicted_ids = self.model.generate(inputs)
+            predicted_ids = self.model.generate(inputs, language=self.language)
         # pylint: disable=E1101
         result = self.processor.tokenizer.batch_decode(predicted_ids, skip_special_tokens=True, normalize=True)[0]
-        if self.language == "auto" or self.language == "zh":
+        if self.language in ["chinese", "mandarin"]:
             from zhconv import convert

             result = convert(result, "zh-cn")
@@ -91,15 +105,20 @@

 if __name__ == "__main__":
-    asr = AudioSpeechRecognition(language="auto")
-    import urllib.request
+    asr = AudioSpeechRecognition(language="english")
+
+    # Test multilanguage asr
+    urllib.request.urlretrieve(
+        "https://paddlespeech.bj.bcebos.com/Parakeet/docs/demos/labixiaoxin.wav",
+        "sample.wav",
+    )
+    asr.language = "chinese"
+    text = asr.audio2text("sample.wav")

     urllib.request.urlretrieve(
         "https://github.com/intel/intel-extension-for-transformers/raw/main/intel_extension_for_transformers/neural_chat/assets/audio/sample.wav",
         "sample.wav",
     )
     text = asr.audio2text("sample.wav")
-    import os

     os.remove("sample.wav")
-
     print(text)
diff --git a/AudioQnA/audio/docker/asr/asr_server.py b/AudioQnA/audio/docker/asr/asr_server.py
index a2c4ec6655..4eadb1c8e9 100644
--- a/AudioQnA/audio/docker/asr/asr_server.py
+++ b/AudioQnA/audio/docker/asr/asr_server.py
@@ -58,7 +58,7 @@ async def audio_to_text(file: UploadFile = File(...)):
     parser.add_argument("--port", type=int, default=8008)
     parser.add_argument("--model_name_or_path", type=str, default="openai/whisper-tiny")
     parser.add_argument("--bf16", default=False, action="store_true")
-    parser.add_argument("--language", type=str, default="auto")
+    parser.add_argument("--language", type=str, default="english")
     parser.add_argument("--device", type=str, default="cpu")

     args = parser.parse_args()
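
Usage note (not part of the patch): a minimal sketch of how the patched class might be exercised on a Gaudi (HPU) host. It assumes optimum-habana and the Habana PyTorch bridge are installed, that the script runs from AudioQnA/audio/docker/asr/ so "from asr import ..." resolves, and that a local sample.wav exists; these names are illustrative, not part of the change.

    # Sketch only: exercise the HPU code path added by this patch.
    # Assumes a Gaudi host with optimum-habana installed; "sample.wav" is a placeholder file.
    from asr import AudioSpeechRecognition  # AudioQnA/audio/docker/asr/asr.py

    asr = AudioSpeechRecognition(
        model_name_or_path="openai/whisper-small",
        language="english",
        device="hpu",  # triggers adapt_transformers_to_gaudi() and the HPU graph warmup in __init__
    )
    print(asr.audio2text("sample.wav"))

The server from asr_server.py can likewise be started against HPU with the flags shown in the diff, e.g.:

    python asr_server.py --model_name_or_path openai/whisper-small --language english --device hpu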