support/optimize ASR on HPU (#280)
* optimize asr on hpu

Signed-off-by: Spycsh <[email protected]>
Spycsh authored Jun 12, 2024
1 parent 2405879 commit 2a48601
Showing 2 changed files with 41 additions and 22 deletions.
61 changes: 40 additions & 21 deletions AudioQnA/audio/docker/asr/asr.py
@@ -8,6 +8,7 @@
import contextlib
import os
import time
import urllib.request

import numpy as np
import torch
@@ -19,7 +20,13 @@
class AudioSpeechRecognition:
"""Convert audio to text."""

def __init__(self, model_name_or_path="openai/whisper-small", bf16=False, language=None, device="cpu"):
def __init__(self, model_name_or_path="openai/whisper-small", bf16=False, language="english", device="cpu"):
if device == "hpu":
# Explicitly link HPU with Torch
from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi

adapt_transformers_to_gaudi()

self.device = device
asr_model_name_or_path = os.environ.get("ASR_MODEL_PATH", model_name_or_path)
print("Downloading model: {}".format(asr_model_name_or_path))
@@ -33,6 +40,12 @@ def __init__(self, model_name_or_path="openai/whisper-small", bf16=False, langua
self.model = ipex.optimize(self.model, dtype=torch.bfloat16)
self.language = language

if device == "hpu":
# do hpu graph warmup with a long enough input audio
# whisper has a receptive field of 30 seconds
# here we select a relatively long audio (~15 sec) to quickly warmup
self._warmup_whisper_hpu_graph("https://paddlespeech.bj.bcebos.com/Parakeet/docs/demos/labixiaoxin.wav")

def _audiosegment_to_librosawav(self, audiosegment):
# https://github.com/jiaaro/pydub/blob/master/API.markdown#audiosegmentget_array_of_samples
# This way is faster than librosa.load or HuggingFace Dataset wrapper
@@ -45,16 +58,27 @@ def _audiosegment_to_librosawav(self, audiosegment):

return fp_arr

def _warmup_whisper_hpu_graph(self, url):
print("[ASR] fetch warmup audio...")
urllib.request.urlretrieve(
url,
"warmup.wav",
)
print("[ASR] warmup...")
waveform = AudioSegment.from_file("warmup.wav").set_frame_rate(16000)
waveform = self._audiosegment_to_librosawav(waveform)
# pylint: disable=E1101
inputs = self.processor.feature_extractor(
waveform, return_tensors="pt", sampling_rate=16_000
).input_features.to(self.device)
_ = self.model.generate(inputs, language="chinese")

def audio2text(self, audio_path):
"""Convert audio to text.
audio_path: the path to the input audio, e.g. ~/xxx.mp3
"""
start = time.time()
if audio_path.split(".")[-1] in ["flac", "ogg", "aac", "m4a"]:
audio_path = self._convert_audio_type(audio_path)
elif audio_path.split(".")[-1] not in ["mp3", "wav"]:
raise Exception("[ASR ERROR] Audio format not supported!")

try:
waveform = AudioSegment.from_file(audio_path).set_frame_rate(16000)
@@ -69,20 +93,10 @@ def audio2text(self, audio_path):
waveform, return_tensors="pt", sampling_rate=16_000
).input_features.to(self.device)
with torch.cpu.amp.autocast() if self.bf16 else contextlib.nullcontext():
if self.language is None:
predicted_ids = self.model.generate(inputs)
elif self.language == "auto":
self.model.config.forced_decoder_ids = None
predicted_ids = self.model.generate(inputs)
else:
self.forced_decoder_ids = self.processor.get_decoder_prompt_ids(
language=self.language, task="transcribe"
)
self.model.config.forced_decoder_ids = self.forced_decoder_ids
predicted_ids = self.model.generate(inputs)
predicted_ids = self.model.generate(inputs, language=self.language)
# pylint: disable=E1101
result = self.processor.tokenizer.batch_decode(predicted_ids, skip_special_tokens=True, normalize=True)[0]
if self.language == "auto" or self.language == "zh":
if self.language in ["chinese", "mandarin"]:
from zhconv import convert

result = convert(result, "zh-cn")
@@ -91,15 +105,20 @@


if __name__ == "__main__":
asr = AudioSpeechRecognition(language="auto")
import urllib.request
asr = AudioSpeechRecognition(language="english")

# Test multilanguage asr
urllib.request.urlretrieve(
"https://paddlespeech.bj.bcebos.com/Parakeet/docs/demos/labixiaoxin.wav",
"sample.wav",
)
asr.language = "chinese"
text = asr.audio2text("sample.wav")

urllib.request.urlretrieve(
"https://github.com/intel/intel-extension-for-transformers/raw/main/intel_extension_for_transformers/neural_chat/assets/audio/sample.wav",
"sample.wav",
)
text = asr.audio2text("sample.wav")
import os

os.remove("sample.wav")
print(text)
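
A minimal usage sketch of the updated class on Gaudi (an illustration under stated assumptions: an HPU host with optimum-habana installed, network access for the warmup clip, and that the module is importable as asr):

from asr import AudioSpeechRecognition  # illustrative import; the class lives in AudioQnA/audio/docker/asr/asr.py

# device="hpu" triggers adapt_transformers_to_gaudi() and the HPU graph warmup in __init__
asr = AudioSpeechRecognition(model_name_or_path="openai/whisper-small", language="english", device="hpu")
text = asr.audio2text("sample.wav")  # mp3/wav are read directly; flac/ogg/aac/m4a are converted first
print(text)
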
2 changes: 1 addition & 1 deletion AudioQnA/audio/docker/asr/asr_server.py
@@ -58,7 +58,7 @@ async def audio_to_text(file: UploadFile = File(...)):
parser.add_argument("--port", type=int, default=8008)
parser.add_argument("--model_name_or_path", type=str, default="openai/whisper-tiny")
parser.add_argument("--bf16", default=False, action="store_true")
parser.add_argument("--language", type=str, default="auto")
parser.add_argument("--language", type=str, default="english")
parser.add_argument("--device", type=str, default="cpu")

args = parser.parse_args()
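
With the new defaults, the service could be launched on Gaudi along these lines (a sketch using only the flags defined above; assumes the Habana PyTorch stack and the server's Python dependencies are installed):

python asr_server.py --device hpu --model_name_or_path openai/whisper-small --language english --port 8008
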
