Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Export of Openai Whisper with batched prompts #19854

Merged
merged 9 commits into from
Apr 3, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -506,7 +506,13 @@ def main(argv=None):
# Wrap parity check in try-except to allow export to continue in case this produces an error
try:
with torch.no_grad():
max_diff = WhisperHelper.verify_onnx(args.model_name_or_path, cache_dir, ort_session, device)
# Verify batched decoding with prompts for whisper openai implementation
if args.model_impl == "openai" and args.use_forced_decoder_ids:
max_diff = WhisperHelper.verify_onnx(
args.model_name_or_path, cache_dir, ort_session, device, batch_size=2, prompt_mode=True
)
else:
max_diff = WhisperHelper.verify_onnx(args.model_name_or_path, cache_dir, ort_session, device)
if max_diff > 1e-4:
logger.warning("PyTorch and ONNX Runtime results are NOT close")
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ def create_dummy(
device: torch.device,
float16: bool = False,
use_int32_inputs: bool = False,
model_impl: str = "hf",
): # -> WhisperDecoderInputs:
"""Create dummy inputs for WhisperDecoder.
Expand Down Expand Up @@ -170,7 +171,7 @@ def create_dummy(
cross_attention_past_shape = [
batch_size,
num_attention_heads,
encode_sequence_length,
encode_sequence_length if model_impl == "hf" else past_decode_sequence_length,
head_size,
]

Expand Down Expand Up @@ -228,6 +229,7 @@ def export_onnx(
past_decode_sequence_length=6 if isinstance(decoder, WhisperDecoder) else 0,
device=device,
use_int32_inputs=use_int32_inputs,
model_impl=decoder.model_impl,
)
input_list = inputs.to_list()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
# license information.
# --------------------------------------------------------------------------

import copy
import logging
import os
import tempfile
Expand Down Expand Up @@ -51,12 +50,15 @@ def forward(
self,
encoder_input_ids: torch.Tensor,
decoder_input_ids: torch.Tensor = None,
remove_hooks: bool = False,
):
encoder_hidden_states: torch.FloatTensor = self.whisper_encoder(encoder_input_ids)
# Decoder out: (logits, past_key_values, encoder_hidden_state)
if self.model_impl == "openai":
encoder_hidden_states.unsqueeze(0)
decinit_out, present = self.whisper_decoder_openai_init(decoder_input_ids, encoder_hidden_states)
decinit_out, present = self.whisper_decoder_openai_init(
decoder_input_ids, encoder_hidden_states, remove_hooks=remove_hooks
)
return decinit_out, encoder_hidden_states, present
else:
decinit_out = self.whisper_decoder_init(decoder_input_ids, encoder_hidden_states)
Expand Down Expand Up @@ -132,8 +134,7 @@ def export_onnx(
input_list = inputs.to_list()

# TODO : Investigate whether copy of model if needed
shubhambhokare1 marked this conversation as resolved.
Show resolved Hide resolved
cloned_model = copy.deepcopy(model).to(device)
out = cloned_model(inputs.encoder_input_ids, inputs.decoder_input_ids)
out = model(inputs.encoder_input_ids, inputs.decoder_input_ids, remove_hooks=True)
present = out[2]
present_names = PastKeyValuesHelper.get_input_names(present, encoder=True)

Expand Down
192 changes: 140 additions & 52 deletions onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,22 +314,13 @@ def optimize_onnx(
m.save_model_to_file(optimized_model_path, use_external_data_format, all_tensors_to_one_file=True)

@staticmethod
def verify_onnx(
model_name_or_path: str,
cache_dir: str,
ort_session: InferenceSession,
def pt_transcription_for_verify_onnx(
processor: WhisperProcessor,
pt_model: torch.nn.Module,
device: torch.device,
batch_size: int = 1,
prompt_mode: bool = False,
):
"""Compare the result from PyTorch and ONNX Runtime to verify the ONNX model is good."""
extra_kwargs = {}
if version.parse(transformers_version) >= version.parse("4.36.0"):
extra_kwargs["attn_implementation"] = "eager"
pt_model = WhisperForConditionalGeneration.from_pretrained(
model_name_or_path, cache_dir=cache_dir, **extra_kwargs
).to(device)
processor = WhisperProcessor.from_pretrained(model_name_or_path)
config = WhisperConfig.from_pretrained(model_name_or_path)

# Try to import `datasets` pip package
try:
from datasets import load_dataset
Expand All @@ -342,14 +333,18 @@ def verify_onnx(
from datasets import load_dataset

ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
input_features = processor([ds[0]["audio"]["array"]], return_tensors="pt").input_features

start_id = [config.decoder_start_token_id] # ex: [50258]
prompt_ids = processor.get_decoder_prompt_ids(language="english", task="transcribe")
prompt_ids = list(map(lambda token: token[1], prompt_ids)) # ex: [50259, 50358, 50363]
forced_decoder_ids = start_id + prompt_ids # ex: [50258, 50259, 50358, 50363]

batch_size, max_length, min_length, num_beams, num_return_sequences = 1, 30, 0, 1, 1
input_features_ = []
if batch_size == 1:
input_features = processor([ds[0]["audio"]["array"]], return_tensors="pt").input_features
else:
input_features_ = [
processor([ds[3]["audio"]["array"]], return_tensors="pt").input_features,
processor([ds[3]["audio"]["array"]], return_tensors="pt").input_features,
]
assert len(input_features_) == batch_size
input_features = torch.cat((input_features_[0], input_features_[1]))

max_length, min_length, num_beams, num_return_sequences = 30, 0, 1, 1
length_penalty, repetition_penalty = 1.0, 1.0
inputs = {
"input_features": input_features.to(device),
Expand All @@ -362,10 +357,97 @@ def verify_onnx(
"early_stopping": True,
"use_cache": True,
}
pt_outputs = pt_model.generate(**inputs).detach().cpu().numpy()

if prompt_mode:
prompts = ["John has doubts", "Maria has grave doubts"]
prompt_ids = [processor.get_prompt_ids(p) for p in prompts]
pt_transcription = []
pt_outputs = []
# The looping for model.generate is necessary here due to the limitation as per
# https://huggingface.co/docs/transformers/model_doc/whisper#transformers.WhisperForConditionalGeneration.generate.prompt_ids
# prompt_ids input requires a tensor of rank 1
for i in range(batch_size):
inputs["prompt_ids"] = torch.from_numpy(prompt_ids[i])
inputs["input_features"] = input_features_[i].to(device)
pt_output = pt_model.generate(**inputs).detach().cpu().numpy()
shubhambhokare1 marked this conversation as resolved.
Show resolved Hide resolved
pt_outputs.append(pt_output)
pt_transcription.append(processor.batch_decode(pt_output, skip_special_tokens=True)[0])
inputs["input_features"] = input_features
del inputs["prompt_ids"]
else:
prompt_ids = []
pt_outputs = pt_model.generate(**inputs).detach().cpu().numpy()
pt_transcription = [processor.batch_decode(pt_outputs, skip_special_tokens=True)[0]]
pt_outputs = list(pt_outputs)
del inputs["early_stopping"]
del inputs["use_cache"]
return inputs, pt_transcription, pt_outputs, prompt_ids

@staticmethod
def select_transcription_options(
batch_size: int,
prompt_mode: bool,
):
if batch_size > 1 and prompt_mode is True:
shubhambhokare1 marked this conversation as resolved.
Show resolved Hide resolved
expected_transcription_no_comma_prompt1 = " John has doubts whether Sir Frederick Layton's work is really Greek after all and can discover in it but little of Rocky I"
expected_transcription_misspelled_prompt1 = " John has doubts whether Sir Frederick Latins work is really Greek after all and can discover in it but little of Rocky I"
expected_transcription_no_comma_prompt2 = " Maria has grave doubts whether Sir Frederick Layton's work is really Greek after all and can discover in it but little of Rocky"
expected_transcription_misspelled_prompt2 = " Maria has grave doubts whether Sir Frederick Latins work is really Greek after all and can discover in it but little of Rocky I"
expected_transcription_options = {
expected_transcription_no_comma_prompt1,
expected_transcription_no_comma_prompt2,
expected_transcription_misspelled_prompt1,
expected_transcription_misspelled_prompt2,
}
else:
expected_transcription_no_comma = (
" Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel."
)
expected_transcription_with_comma = (
" Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel."
)
expected_transcription_with_quote_and_comma = (
' "Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.'
)
expected_transcription_options = {
expected_transcription_no_comma,
expected_transcription_with_comma,
expected_transcription_with_quote_and_comma,
}
return expected_transcription_options

@staticmethod
def verify_onnx(
model_name_or_path: str,
cache_dir: str,
ort_session: InferenceSession,
device: torch.device,
batch_size: int = 1,
prompt_mode: bool = False,
):
"""Compare the result from PyTorch and ONNX Runtime to verify the ONNX model is good."""
extra_kwargs = {}
if version.parse(transformers_version) >= version.parse("4.36.0"):
extra_kwargs["attn_implementation"] = "eager"
pt_model = WhisperForConditionalGeneration.from_pretrained(
model_name_or_path, cache_dir=cache_dir, **extra_kwargs
).to(device)
processor = WhisperProcessor.from_pretrained(model_name_or_path)
shubhambhokare1 marked this conversation as resolved.
Show resolved Hide resolved
config = WhisperConfig.from_pretrained(model_name_or_path)
shubhambhokare1 marked this conversation as resolved.
Show resolved Hide resolved

inputs, pt_transcription, pt_outputs, decoder_prompt_ids = WhisperHelper.pt_transcription_for_verify_onnx(
processor,
pt_model,
device,
batch_size=batch_size,
prompt_mode=prompt_mode,
)

start_id = [config.decoder_start_token_id] # ex: [50258]
prompt_ids = processor.get_decoder_prompt_ids(language="english", task="transcribe")
prompt_ids = list(map(lambda token: token[1], prompt_ids)) # ex: [50259, 50358, 50363]
forced_decoder_ids = start_id + prompt_ids # ex: [50258, 50259, 50358, 50363]

ort_names = list(map(lambda entry: entry.name, ort_session.get_inputs()))
ort_dtypes = list(map(lambda entry: entry.type, ort_session.get_inputs()))
ort_to_np = {
Expand All @@ -386,8 +468,24 @@ def verify_onnx(
elif name == "prefix_vocab_mask":
inputs[name] = np.ones((batch_size, config.vocab_size), dtype=ort_to_np[dtype])
elif name == "decoder_input_ids":
raw_input_ids = [start_id] if use_extra_decoding_ids else [forced_decoder_ids]
inputs[name] = np.array(raw_input_ids, dtype=ort_to_np[dtype])
if not prompt_mode:
raw_input_ids = [start_id] if use_extra_decoding_ids else [forced_decoder_ids]
inputs[name] = np.array(raw_input_ids, dtype=ort_to_np[dtype])
else:
# This logic handles the scenario for when prompts are not of the same size
# For example if our prompt ids are [p1_id_1, p1_id_2] and [p2_id_1]
# The final decoder_input_ids will look as such after padding
# [prev_token, p1_id_1, p1_id_2, start_token, lang_token, transcribe_token]
# [prev_token, p2_id_1, PAD_TOKEN, start_token, lang_token, transcribe_token]
ort_prompts = []
for i in range(batch_size):
ort_prompts.append(decoder_prompt_ids[i].tolist())
max_len = max(len(p) for p in ort_prompts)
padded_prompts = []
for p in ort_prompts:
padded_prompt = [*p, *([config.pad_token_id] * (max_len - len(p)))]
padded_prompts.append(padded_prompt + forced_decoder_ids)
inputs[name] = np.array(padded_prompts, dtype=ort_to_np[dtype])
elif name == "logits_processor":
inputs[name] = np.array([1], dtype=ort_to_np[dtype])
elif name == "cross_qk_layer_head":
Expand All @@ -398,36 +496,26 @@ def verify_onnx(
inputs[name] = np.array([1.0], dtype=ort_to_np[dtype])
else:
inputs[name] = np.array([inputs[name]], dtype=ort_to_np[dtype])
ort_outputs = ort_session.run(None, inputs)[0][0]

expected_transcription_no_comma = (
" Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel."
)
expected_transcription_with_comma = (
" Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel."
)
expected_transcription_with_quote_and_comma = (
' "Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.'
)
expected_transcription_options = {
expected_transcription_no_comma,
expected_transcription_with_comma,
expected_transcription_with_quote_and_comma,
}
pt_transcription = processor.batch_decode(pt_outputs, skip_special_tokens=True)[0]
ort_transcription = processor.batch_decode(ort_outputs, skip_special_tokens=True)[0]

parity = (
pt_transcription in expected_transcription_options and ort_transcription in expected_transcription_options
)
ort_outputs = ort_session.run(None, inputs)[0][:, 0, :]
ort_transcription = processor.batch_decode(ort_outputs, skip_special_tokens=True)
expected_transcription_options = WhisperHelper.select_transcription_options(batch_size, prompt_mode)

parity = 1
for i in range(batch_size):
parity *= (
pt_transcription[i] in expected_transcription_options
and ort_transcription[i] in expected_transcription_options
)
max_diff = 0

if not parity:
if pt_outputs.shape != ort_outputs.shape:
diff = pt_outputs - ort_outputs[:, : len(pt_outputs[0])]
else:
diff = pt_outputs - ort_outputs
max_diff = max(diff.min(), diff.max(), key=abs)
for i in range(batch_size):
if pt_outputs[i].shape != ort_outputs[i].shape:
diff = pt_outputs[i] - ort_outputs[i][:, : len(pt_outputs[i])]
else:
diff = pt_outputs[i] - ort_outputs[i]
max_diff_i = max(diff.min(), diff.max(), key=abs)
max_diff = max(max_diff, max_diff_i)

if max_diff != 0:
logger.warning(f"PyTorch outputs: {pt_transcription}")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def forward(
tokens,
audio_features,
past=None,
remove_hooks=False,
):
# Create a kv_cache for past_values
past_kv_cache = dict()
Expand All @@ -44,8 +45,9 @@ def forward(
past_kv_cache[block.cross_attn.key] = past[2 * idx + half_idx]
past_kv_cache[block.cross_attn.value] = past[2 * idx + half_idx + 1]

hooks = None
if not self.kv_cache:
self.kv_cache, _ = self.whisper_model.install_kv_cache_hooks()
self.kv_cache, hooks = self.whisper_model.install_kv_cache_hooks()

logits = self.whisper_decoder(tokens, audio_features, kv_cache=past_kv_cache)

Expand Down Expand Up @@ -73,4 +75,10 @@ def forward(
present_self = [
present_val.reshape(present_val.shape[:2] + (-1, 64)).transpose(1, 2) for present_val in present_self
]

# Remove forward hooks to avoid model cloning step
if hooks is not None and remove_hooks is True:
shubhambhokare1 marked this conversation as resolved.
Show resolved Hide resolved
self.kv_cache = {}
for hook in hooks:
hook.remove()
return logits, present_self
Loading