From 439211cd0294e5e4a34c5ed0dd207aec012b064e Mon Sep 17 00:00:00 2001 From: shubhambhokare1 Date: Mon, 11 Mar 2024 18:44:59 +0000 Subject: [PATCH 1/9] Add logic for padding prompts --- .../models/whisper/convert_to_onnx.py | 8 +- .../models/whisper/whisper_decoder.py | 4 +- .../models/whisper/whisper_helper.py | 146 ++++++++++++++++++ 3 files changed, 156 insertions(+), 2 deletions(-) diff --git a/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py index 35211aab272e4..2d084a3017c2f 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py @@ -506,7 +506,13 @@ def main(argv=None): # Wrap parity check in try-except to allow export to continue in case this produces an error try: with torch.no_grad(): - max_diff = WhisperHelper.verify_onnx(args.model_name_or_path, cache_dir, ort_session, device) + # Verify batched decoding with prompts for whisper openai implementation + if args.model_impl == "openai" and args.use_forced_decoder_ids: + max_diff = WhisperHelper.verify_onnx_multi_batch( + args.model_name_or_path, cache_dir, ort_session, device + ) + else: + max_diff = WhisperHelper.verify_onnx(args.model_name_or_path, cache_dir, ort_session, device) if max_diff > 1e-4: logger.warning("PyTorch and ONNX Runtime results are NOT close") else: diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_decoder.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_decoder.py index 93fd64c9eb7d3..5da235d72ca0b 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_decoder.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_decoder.py @@ -126,6 +126,7 @@ def create_dummy( device: torch.device, float16: bool = False, use_int32_inputs: bool = False, + model_impl: str = "hf", ): # -> WhisperDecoderInputs: """Create dummy inputs for WhisperDecoder. 
@@ -170,7 +171,7 @@ def create_dummy( cross_attention_past_shape = [ batch_size, num_attention_heads, - encode_sequence_length, + encode_sequence_length if model_impl == "hf" else past_decode_sequence_length, head_size, ] @@ -228,6 +229,7 @@ def export_onnx( past_decode_sequence_length=6 if isinstance(decoder, WhisperDecoder) else 0, device=device, use_int32_inputs=use_int32_inputs, + model_impl=decoder.model_impl, ) input_list = inputs.to_list() diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py index 1b47b9426d983..003526c53c6fd 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py @@ -434,3 +434,149 @@ def verify_onnx( logger.warning(f"ONNX Runtime outputs: {ort_transcription}") return max_diff + + @staticmethod + def verify_onnx_multi_batch( + model_name_or_path: str, + cache_dir: str, + ort_session: InferenceSession, + device: torch.device, + ): + """Compare the result from PyTorch and ONNX Runtime to verify the ONNX model is good.""" + extra_kwargs = {} + if version.parse(transformers_version) >= version.parse("4.36.0"): + extra_kwargs["attn_implementation"] = "eager" + pt_model = WhisperForConditionalGeneration.from_pretrained( + model_name_or_path, cache_dir=cache_dir, **extra_kwargs + ).to(device) + processor = WhisperProcessor.from_pretrained(model_name_or_path) + config = WhisperConfig.from_pretrained(model_name_or_path) + + # Try to import `datasets` pip package + try: + from datasets import load_dataset + except Exception as e: + logger.error(f"An error occurred while importing `datasets`: {e}", exc_info=True) + install_cmd = "pip install datasets" + logger.warning(f"Could not import `datasets`. 
Attempting to install `datasets` via `{install_cmd}`.") + os.system(install_cmd) + + from datasets import load_dataset + + ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + input_features_ = [ + processor([ds[3]["audio"]["array"]], return_tensors="pt").input_features, + processor([ds[3]["audio"]["array"]], return_tensors="pt").input_features, + ] + input_features = torch.cat((input_features_[0], input_features_[1])).to(device) + + start_id = [config.decoder_start_token_id] # ex: [50258] + prompt_ids = processor.get_decoder_prompt_ids(language="english", task="transcribe") + prompt_ids = list(map(lambda token: token[1], prompt_ids)) # ex: [50259, 50358, 50363] + forced_decoder_ids = start_id + prompt_ids # ex: [50258, 50259, 50358, 50363] + + batch_size, max_length, min_length, num_beams, num_return_sequences = 2, 30, 0, 1, 1 + length_penalty, repetition_penalty = 1.0, 1.0 + inputs = { + "input_features": input_features.to(device), + "max_length": max_length, + "min_length": min_length, + "num_beams": num_beams, + "num_return_sequences": num_return_sequences, + "length_penalty": length_penalty, + "repetition_penalty": repetition_penalty, + "early_stopping": True, + "use_cache": True, + } + prompts = ["John has doubts", "Maria has grave doubts"] + prompt_ids = [processor.get_prompt_ids(p) for p in prompts] + pt_transcription = [] + for i in range(batch_size): + inputs["prompt_ids"] = torch.from_numpy(prompt_ids[i]) + inputs["input_features"] = input_features_[i].to(device) + pt_outputs = pt_model.generate(**inputs).detach().cpu().numpy() + pt_transcription.append(processor.batch_decode(pt_outputs, skip_special_tokens=True)[0]) + inputs["input_features"] = input_features + del inputs["prompt_ids"] + del inputs["early_stopping"] + del inputs["use_cache"] + ort_names = list(map(lambda entry: entry.name, ort_session.get_inputs())) + ort_dtypes = list(map(lambda entry: entry.type, ort_session.get_inputs())) + ort_to_np = { + "tensor(float)": np.float32, + "tensor(float16)": np.float16, + "tensor(int64)": np.int64, + "tensor(int32)": np.int32, + "tensor(int8)": np.int8, + "tensor(uint8)": np.uint8, + } + + for name, dtype in zip(ort_names, ort_dtypes): + if name == "input_features": + inputs[name] = inputs[name].detach().cpu().numpy() + elif name == "vocab_mask": + inputs[name] = np.ones(config.vocab_size, dtype=ort_to_np[dtype]) + elif name == "prefix_vocab_mask": + inputs[name] = np.ones((batch_size, config.vocab_size), dtype=ort_to_np[dtype]) + elif name == "decoder_input_ids": + # This logic handles the scenario for when prompts are not of the same size + # For example if our prompt ids are [p1_id_1, p1_id_2] and [p2_id_1] + # The final decoder_input_ids will look as such after padding + # [prev_token, p1_id_1, p1_id_2, start_token, lang_token, transcribe_token] + # [prev_token, p2_id_1, PAD_TOKEN, start_token, lang_token, transcribe_token] + ort_prompts = [] + for i in range(batch_size): + ort_prompts.append(prompt_ids[i].tolist()) + max_len = max(len(p) for p in ort_prompts) + padded_prompts = [] + for p in ort_prompts: + padded_prompt = [*p, *([config.pad_token_id] * (max_len - len(p)))] + padded_prompts.append(padded_prompt + forced_decoder_ids) + inputs[name] = np.array(padded_prompts, dtype=ort_to_np[dtype]) + elif name == "logits_processor": + inputs[name] = np.array([1], dtype=ort_to_np[dtype]) + elif name == "cross_qk_layer_head": + inputs[name] = np.array([[0, 0]], dtype=ort_to_np[dtype]) + elif name == "extra_decoding_ids": + inputs[name] 
= np.repeat(np.array([prompt_ids], dtype=ort_to_np[dtype]), batch_size, 0) + elif name == "temperature": + inputs[name] = np.array([1.0], dtype=ort_to_np[dtype]) + else: + inputs[name] = np.array([inputs[name]], dtype=ort_to_np[dtype]) + ort_outputs = ort_session.run(None, inputs)[0] + + expected_transcription_no_comma_prompt1 = " John has doubts whether Sir Frederick Layton's work is really Greek after all and can discover in it but little of Rocky I" + expected_transcription_mispelled_prompt1 = " John has doubts whether Sir Frederick Latins work is really Greek after all and can discover in it but little of Rocky I" + expected_transcription_no_comma_prompt2 = " Maria has grave doubts whether Sir Frederick Layton's work is really Greek after all and can discover in it but little of Rocky" + expected_transcription_mispelled_prompt2 = " Maria has grave doubts whether Sir Frederick Latins work is really Greek after all and can discover in it but little of Rocky I" + expected_transcription_options = { + expected_transcription_no_comma_prompt1, + expected_transcription_no_comma_prompt2, + expected_transcription_mispelled_prompt1, + expected_transcription_mispelled_prompt2, + } + ort_outputs = ort_session.run(None, inputs)[0] + ort_transcription = [] + for o in ort_outputs: + ort_transcription.append(processor.batch_decode(o, skip_special_tokens=True)[0]) + + parity = 1 + for i in range(batch_size): + parity *= ( + pt_transcription[i] in expected_transcription_options + and ort_transcription[i] in expected_transcription_options + ) + max_diff = 0 + + if not parity: + if pt_outputs.shape != ort_outputs.shape: + diff = pt_outputs - ort_outputs[:, : len(pt_outputs[0])] + else: + diff = pt_outputs - ort_outputs + max_diff = max(diff.min(), diff.max(), key=abs) + + if max_diff != 0: + logger.warning(f"PyTorch outputs: {pt_transcription}") + logger.warning(f"ONNX Runtime outputs: {ort_transcription}") + + return max_diff From 9ffae5b0503199ac4a84e6537622d5d3d062088d Mon Sep 17 00:00:00 2001 From: shubhambhokare1 Date: Mon, 11 Mar 2024 18:51:02 +0000 Subject: [PATCH 2/9] Fix lint --- .../tools/transformers/models/whisper/whisper_helper.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py index 003526c53c6fd..5acf4acf667d7 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py @@ -546,14 +546,14 @@ def verify_onnx_multi_batch( ort_outputs = ort_session.run(None, inputs)[0] expected_transcription_no_comma_prompt1 = " John has doubts whether Sir Frederick Layton's work is really Greek after all and can discover in it but little of Rocky I" - expected_transcription_mispelled_prompt1 = " John has doubts whether Sir Frederick Latins work is really Greek after all and can discover in it but little of Rocky I" + expected_transcription_misspelled_prompt1 = " John has doubts whether Sir Frederick Latins work is really Greek after all and can discover in it but little of Rocky I" expected_transcription_no_comma_prompt2 = " Maria has grave doubts whether Sir Frederick Layton's work is really Greek after all and can discover in it but little of Rocky" - expected_transcription_mispelled_prompt2 = " Maria has grave doubts whether Sir Frederick Latins work is really Greek after all and can discover in it but little of Rocky I" + 
expected_transcription_misspelled_prompt2 = " Maria has grave doubts whether Sir Frederick Latins work is really Greek after all and can discover in it but little of Rocky I" expected_transcription_options = { expected_transcription_no_comma_prompt1, expected_transcription_no_comma_prompt2, - expected_transcription_mispelled_prompt1, - expected_transcription_mispelled_prompt2, + expected_transcription_misspelled_prompt1, + expected_transcription_misspelled_prompt2, } ort_outputs = ort_session.run(None, inputs)[0] ort_transcription = [] From b34a04af6f79598ee8968c5ac0d82155c55de7ce Mon Sep 17 00:00:00 2001 From: shubhambhokare1 Date: Fri, 15 Mar 2024 19:29:09 +0000 Subject: [PATCH 3/9] Modularize verify_onnx --- .../models/whisper/convert_to_onnx.py | 4 +- .../models/whisper/whisper_helper.py | 272 +++++++----------- 2 files changed, 110 insertions(+), 166 deletions(-) diff --git a/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py index 2d084a3017c2f..112e2fbb1bfdf 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py @@ -508,8 +508,8 @@ def main(argv=None): with torch.no_grad(): # Verify batched decoding with prompts for whisper openai implementation if args.model_impl == "openai" and args.use_forced_decoder_ids: - max_diff = WhisperHelper.verify_onnx_multi_batch( - args.model_name_or_path, cache_dir, ort_session, device + max_diff = WhisperHelper.verify_onnx( + args.model_name_or_path, cache_dir, ort_session, device, batch_size=2, prompt_mode=True ) else: max_diff = WhisperHelper.verify_onnx(args.model_name_or_path, cache_dir, ort_session, device) diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py index 5acf4acf667d7..392e16cb307fe 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py @@ -9,6 +9,7 @@ from pathlib import Path from typing import Dict, Tuple, Union +import datasets import numpy as np import torch from float16 import float_to_float16_max_diff @@ -314,42 +315,26 @@ def optimize_onnx( m.save_model_to_file(optimized_model_path, use_external_data_format, all_tensors_to_one_file=True) @staticmethod - def verify_onnx( - model_name_or_path: str, - cache_dir: str, - ort_session: InferenceSession, + def pt_transcription_for_verify_onnx( + ds: Union[datasets.DatasetDict, datasets.Dataset, datasets.IterableDatasetDict, datasets.IterableDataset], + processor: WhisperProcessor, + pt_model: torch.nn.Module, device: torch.device, + batch_size: int = 1, + prompt_mode: bool = False, ): - """Compare the result from PyTorch and ONNX Runtime to verify the ONNX model is good.""" - extra_kwargs = {} - if version.parse(transformers_version) >= version.parse("4.36.0"): - extra_kwargs["attn_implementation"] = "eager" - pt_model = WhisperForConditionalGeneration.from_pretrained( - model_name_or_path, cache_dir=cache_dir, **extra_kwargs - ).to(device) - processor = WhisperProcessor.from_pretrained(model_name_or_path) - config = WhisperConfig.from_pretrained(model_name_or_path) - - # Try to import `datasets` pip package - try: - from datasets import load_dataset - except Exception as e: - logger.error(f"An error occurred while importing `datasets`: {e}", exc_info=True) - install_cmd = "pip 
install datasets" - logger.warning(f"Could not import `datasets`. Attempting to install `datasets` via `{install_cmd}`.") - os.system(install_cmd) - - from datasets import load_dataset - - ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") - input_features = processor([ds[0]["audio"]["array"]], return_tensors="pt").input_features - - start_id = [config.decoder_start_token_id] # ex: [50258] - prompt_ids = processor.get_decoder_prompt_ids(language="english", task="transcribe") - prompt_ids = list(map(lambda token: token[1], prompt_ids)) # ex: [50259, 50358, 50363] - forced_decoder_ids = start_id + prompt_ids # ex: [50258, 50259, 50358, 50363] - - batch_size, max_length, min_length, num_beams, num_return_sequences = 1, 30, 0, 1, 1 + input_features_ = [] + if batch_size == 1: + input_features = processor([ds[0]["audio"]["array"]], return_tensors="pt").input_features + else: + input_features_ = [ + processor([ds[3]["audio"]["array"]], return_tensors="pt").input_features, + processor([ds[3]["audio"]["array"]], return_tensors="pt").input_features, + ] + assert len(input_features_) == batch_size + input_features = torch.cat((input_features_[0], input_features_[1])).to(device) + + max_length, min_length, num_beams, num_return_sequences = 30, 0, 1, 1 length_penalty, repetition_penalty = 1.0, 1.0 inputs = { "input_features": input_features.to(device), @@ -362,85 +347,70 @@ def verify_onnx( "early_stopping": True, "use_cache": True, } - pt_outputs = pt_model.generate(**inputs).detach().cpu().numpy() + if prompt_mode: + prompts = ["John has doubts", "Maria has grave doubts"] + prompt_ids = [processor.get_prompt_ids(p) for p in prompts] + pt_transcription = [] + pt_outputs = [] + for i in range(batch_size): + inputs["prompt_ids"] = torch.from_numpy(prompt_ids[i]) + inputs["input_features"] = input_features_[i].to(device) + pt_output = pt_model.generate(**inputs).detach().cpu().numpy() + pt_outputs.append(pt_output) + pt_transcription.append(processor.batch_decode(pt_output, skip_special_tokens=True)[0]) + inputs["input_features"] = input_features + del inputs["prompt_ids"] + else: + prompt_ids = [] + pt_outputs = pt_model.generate(**inputs).detach().cpu().numpy() + pt_transcription = [processor.batch_decode(pt_outputs, skip_special_tokens=True)[0]] + pt_outputs = list(pt_outputs) del inputs["early_stopping"] del inputs["use_cache"] - ort_names = list(map(lambda entry: entry.name, ort_session.get_inputs())) - ort_dtypes = list(map(lambda entry: entry.type, ort_session.get_inputs())) - ort_to_np = { - "tensor(float)": np.float32, - "tensor(float16)": np.float16, - "tensor(int64)": np.int64, - "tensor(int32)": np.int32, - "tensor(int8)": np.int8, - "tensor(uint8)": np.uint8, - } - - use_extra_decoding_ids = "extra_decoding_ids" in ort_names - for name, dtype in zip(ort_names, ort_dtypes): - if name == "input_features": - inputs[name] = inputs[name].detach().cpu().numpy() - elif name == "vocab_mask": - inputs[name] = np.ones(config.vocab_size, dtype=ort_to_np[dtype]) - elif name == "prefix_vocab_mask": - inputs[name] = np.ones((batch_size, config.vocab_size), dtype=ort_to_np[dtype]) - elif name == "decoder_input_ids": - raw_input_ids = [start_id] if use_extra_decoding_ids else [forced_decoder_ids] - inputs[name] = np.array(raw_input_ids, dtype=ort_to_np[dtype]) - elif name == "logits_processor": - inputs[name] = np.array([1], dtype=ort_to_np[dtype]) - elif name == "cross_qk_layer_head": - inputs[name] = np.array([[0, 0]], dtype=ort_to_np[dtype]) - elif name == 
"extra_decoding_ids": - inputs[name] = np.repeat(np.array([prompt_ids], dtype=ort_to_np[dtype]), batch_size, 0) - elif name == "temperature": - inputs[name] = np.array([1.0], dtype=ort_to_np[dtype]) - else: - inputs[name] = np.array([inputs[name]], dtype=ort_to_np[dtype]) - ort_outputs = ort_session.run(None, inputs)[0][0] - - expected_transcription_no_comma = ( - " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel." - ) - expected_transcription_with_comma = ( - " Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel." - ) - expected_transcription_with_quote_and_comma = ( - ' "Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.' - ) - expected_transcription_options = { - expected_transcription_no_comma, - expected_transcription_with_comma, - expected_transcription_with_quote_and_comma, - } - pt_transcription = processor.batch_decode(pt_outputs, skip_special_tokens=True)[0] - ort_transcription = processor.batch_decode(ort_outputs, skip_special_tokens=True)[0] + return inputs, pt_transcription, pt_outputs, prompt_ids - parity = ( - pt_transcription in expected_transcription_options and ort_transcription in expected_transcription_options - ) - max_diff = 0 - - if not parity: - if pt_outputs.shape != ort_outputs.shape: - diff = pt_outputs - ort_outputs[:, : len(pt_outputs[0])] - else: - diff = pt_outputs - ort_outputs - max_diff = max(diff.min(), diff.max(), key=abs) - - if max_diff != 0: - logger.warning(f"PyTorch outputs: {pt_transcription}") - logger.warning(f"ONNX Runtime outputs: {ort_transcription}") - - return max_diff + @staticmethod + def select_transcription_options( + batch_size: int, + prompt_mode: bool, + ): + if batch_size > 1 and prompt_mode is True: + expected_transcription_no_comma_prompt1 = " John has doubts whether Sir Frederick Layton's work is really Greek after all and can discover in it but little of Rocky I" + expected_transcription_misspelled_prompt1 = " John has doubts whether Sir Frederick Latins work is really Greek after all and can discover in it but little of Rocky I" + expected_transcription_no_comma_prompt2 = " Maria has grave doubts whether Sir Frederick Layton's work is really Greek after all and can discover in it but little of Rocky" + expected_transcription_misspelled_prompt2 = " Maria has grave doubts whether Sir Frederick Latins work is really Greek after all and can discover in it but little of Rocky I" + expected_transcription_options = { + expected_transcription_no_comma_prompt1, + expected_transcription_no_comma_prompt2, + expected_transcription_misspelled_prompt1, + expected_transcription_misspelled_prompt2, + } + else: + expected_transcription_no_comma = ( + " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel." + ) + expected_transcription_with_comma = ( + " Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel." + ) + expected_transcription_with_quote_and_comma = ( + ' "Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.' 
+ ) + expected_transcription_options = { + expected_transcription_no_comma, + expected_transcription_with_comma, + expected_transcription_with_quote_and_comma, + } + return expected_transcription_options @staticmethod - def verify_onnx_multi_batch( + def verify_onnx( model_name_or_path: str, cache_dir: str, ort_session: InferenceSession, device: torch.device, + batch_size: int = 1, + prompt_mode: bool = False, ): """Compare the result from PyTorch and ONNX Runtime to verify the ONNX model is good.""" extra_kwargs = {} @@ -464,42 +434,20 @@ def verify_onnx_multi_batch( from datasets import load_dataset ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") - input_features_ = [ - processor([ds[3]["audio"]["array"]], return_tensors="pt").input_features, - processor([ds[3]["audio"]["array"]], return_tensors="pt").input_features, - ] - input_features = torch.cat((input_features_[0], input_features_[1])).to(device) + inputs, pt_transcription, pt_outputs, decoder_prompt_ids = WhisperHelper.pt_transcription_for_verify_onnx( + ds, + processor, + pt_model, + device, + batch_size=batch_size, + prompt_mode=prompt_mode, + ) start_id = [config.decoder_start_token_id] # ex: [50258] prompt_ids = processor.get_decoder_prompt_ids(language="english", task="transcribe") prompt_ids = list(map(lambda token: token[1], prompt_ids)) # ex: [50259, 50358, 50363] forced_decoder_ids = start_id + prompt_ids # ex: [50258, 50259, 50358, 50363] - batch_size, max_length, min_length, num_beams, num_return_sequences = 2, 30, 0, 1, 1 - length_penalty, repetition_penalty = 1.0, 1.0 - inputs = { - "input_features": input_features.to(device), - "max_length": max_length, - "min_length": min_length, - "num_beams": num_beams, - "num_return_sequences": num_return_sequences, - "length_penalty": length_penalty, - "repetition_penalty": repetition_penalty, - "early_stopping": True, - "use_cache": True, - } - prompts = ["John has doubts", "Maria has grave doubts"] - prompt_ids = [processor.get_prompt_ids(p) for p in prompts] - pt_transcription = [] - for i in range(batch_size): - inputs["prompt_ids"] = torch.from_numpy(prompt_ids[i]) - inputs["input_features"] = input_features_[i].to(device) - pt_outputs = pt_model.generate(**inputs).detach().cpu().numpy() - pt_transcription.append(processor.batch_decode(pt_outputs, skip_special_tokens=True)[0]) - inputs["input_features"] = input_features - del inputs["prompt_ids"] - del inputs["early_stopping"] - del inputs["use_cache"] ort_names = list(map(lambda entry: entry.name, ort_session.get_inputs())) ort_dtypes = list(map(lambda entry: entry.type, ort_session.get_inputs())) ort_to_np = { @@ -511,6 +459,7 @@ def verify_onnx_multi_batch( "tensor(uint8)": np.uint8, } + use_extra_decoding_ids = "extra_decoding_ids" in ort_names for name, dtype in zip(ort_names, ort_dtypes): if name == "input_features": inputs[name] = inputs[name].detach().cpu().numpy() @@ -519,20 +468,24 @@ def verify_onnx_multi_batch( elif name == "prefix_vocab_mask": inputs[name] = np.ones((batch_size, config.vocab_size), dtype=ort_to_np[dtype]) elif name == "decoder_input_ids": - # This logic handles the scenario for when prompts are not of the same size - # For example if our prompt ids are [p1_id_1, p1_id_2] and [p2_id_1] - # The final decoder_input_ids will look as such after padding - # [prev_token, p1_id_1, p1_id_2, start_token, lang_token, transcribe_token] - # [prev_token, p2_id_1, PAD_TOKEN, start_token, lang_token, transcribe_token] - ort_prompts = [] - for i in range(batch_size): - 
ort_prompts.append(prompt_ids[i].tolist()) - max_len = max(len(p) for p in ort_prompts) - padded_prompts = [] - for p in ort_prompts: - padded_prompt = [*p, *([config.pad_token_id] * (max_len - len(p)))] - padded_prompts.append(padded_prompt + forced_decoder_ids) - inputs[name] = np.array(padded_prompts, dtype=ort_to_np[dtype]) + if not prompt_mode: + raw_input_ids = [start_id] if use_extra_decoding_ids else [forced_decoder_ids] + inputs[name] = np.array(raw_input_ids, dtype=ort_to_np[dtype]) + else: + # This logic handles the scenario for when prompts are not of the same size + # For example if our prompt ids are [p1_id_1, p1_id_2] and [p2_id_1] + # The final decoder_input_ids will look as such after padding + # [prev_token, p1_id_1, p1_id_2, start_token, lang_token, transcribe_token] + # [prev_token, p2_id_1, PAD_TOKEN, start_token, lang_token, transcribe_token] + ort_prompts = [] + for i in range(batch_size): + ort_prompts.append(decoder_prompt_ids[i].tolist()) + max_len = max(len(p) for p in ort_prompts) + padded_prompts = [] + for p in ort_prompts: + padded_prompt = [*p, *([config.pad_token_id] * (max_len - len(p)))] + padded_prompts.append(padded_prompt + forced_decoder_ids) + inputs[name] = np.array(padded_prompts, dtype=ort_to_np[dtype]) elif name == "logits_processor": inputs[name] = np.array([1], dtype=ort_to_np[dtype]) elif name == "cross_qk_layer_head": @@ -544,21 +497,10 @@ def verify_onnx_multi_batch( else: inputs[name] = np.array([inputs[name]], dtype=ort_to_np[dtype]) ort_outputs = ort_session.run(None, inputs)[0] - - expected_transcription_no_comma_prompt1 = " John has doubts whether Sir Frederick Layton's work is really Greek after all and can discover in it but little of Rocky I" - expected_transcription_misspelled_prompt1 = " John has doubts whether Sir Frederick Latins work is really Greek after all and can discover in it but little of Rocky I" - expected_transcription_no_comma_prompt2 = " Maria has grave doubts whether Sir Frederick Layton's work is really Greek after all and can discover in it but little of Rocky" - expected_transcription_misspelled_prompt2 = " Maria has grave doubts whether Sir Frederick Latins work is really Greek after all and can discover in it but little of Rocky I" - expected_transcription_options = { - expected_transcription_no_comma_prompt1, - expected_transcription_no_comma_prompt2, - expected_transcription_misspelled_prompt1, - expected_transcription_misspelled_prompt2, - } - ort_outputs = ort_session.run(None, inputs)[0] ort_transcription = [] for o in ort_outputs: ort_transcription.append(processor.batch_decode(o, skip_special_tokens=True)[0]) + expected_transcription_options = WhisperHelper.select_transcription_options(batch_size, prompt_mode) parity = 1 for i in range(batch_size): @@ -569,11 +511,13 @@ def verify_onnx_multi_batch( max_diff = 0 if not parity: - if pt_outputs.shape != ort_outputs.shape: - diff = pt_outputs - ort_outputs[:, : len(pt_outputs[0])] - else: - diff = pt_outputs - ort_outputs - max_diff = max(diff.min(), diff.max(), key=abs) + for i in range(batch_size): + if pt_outputs[i].shape != ort_outputs[i].shape: + diff = pt_outputs[i] - ort_outputs[i][:, : len(pt_outputs[i])] + else: + diff = pt_outputs[i] - ort_outputs[i] + max_diff_i = max(diff.min(), diff.max(), key=abs) + max_diff = max(max_diff, max_diff_i) if max_diff != 0: logger.warning(f"PyTorch outputs: {pt_transcription}") From 44efe35303b6bef7b5c7e82e92aa6c3b0b0792fb Mon Sep 17 00:00:00 2001 From: shubhambhokare1 Date: Thu, 28 Mar 2024 07:05:58 +0000 Subject: 
[PATCH 4/9] minor edits --- .../models/whisper/whisper_helper.py | 31 +++++++++---------- 1 file changed, 14 insertions(+), 17 deletions(-) diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py index 392e16cb307fe..936289e3b8518 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py @@ -9,7 +9,6 @@ from pathlib import Path from typing import Dict, Tuple, Union -import datasets import numpy as np import torch from float16 import float_to_float16_max_diff @@ -316,13 +315,24 @@ def optimize_onnx( @staticmethod def pt_transcription_for_verify_onnx( - ds: Union[datasets.DatasetDict, datasets.Dataset, datasets.IterableDatasetDict, datasets.IterableDataset], processor: WhisperProcessor, pt_model: torch.nn.Module, device: torch.device, batch_size: int = 1, prompt_mode: bool = False, ): + # Try to import `datasets` pip package + try: + from datasets import load_dataset + except Exception as e: + logger.error(f"An error occurred while importing `datasets`: {e}", exc_info=True) + install_cmd = "pip install datasets" + logger.warning(f"Could not import `datasets`. Attempting to install `datasets` via `{install_cmd}`.") + os.system(install_cmd) + + from datasets import load_dataset + + ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") input_features_ = [] if batch_size == 1: input_features = processor([ds[0]["audio"]["array"]], return_tensors="pt").input_features @@ -332,7 +342,7 @@ def pt_transcription_for_verify_onnx( processor([ds[3]["audio"]["array"]], return_tensors="pt").input_features, ] assert len(input_features_) == batch_size - input_features = torch.cat((input_features_[0], input_features_[1])).to(device) + input_features = torch.cat((input_features_[0], input_features_[1])) max_length, min_length, num_beams, num_return_sequences = 30, 0, 1, 1 length_penalty, repetition_penalty = 1.0, 1.0 @@ -422,20 +432,7 @@ def verify_onnx( processor = WhisperProcessor.from_pretrained(model_name_or_path) config = WhisperConfig.from_pretrained(model_name_or_path) - # Try to import `datasets` pip package - try: - from datasets import load_dataset - except Exception as e: - logger.error(f"An error occurred while importing `datasets`: {e}", exc_info=True) - install_cmd = "pip install datasets" - logger.warning(f"Could not import `datasets`. 
Attempting to install `datasets` via `{install_cmd}`.") - os.system(install_cmd) - - from datasets import load_dataset - - ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") inputs, pt_transcription, pt_outputs, decoder_prompt_ids = WhisperHelper.pt_transcription_for_verify_onnx( - ds, processor, pt_model, device, @@ -517,7 +514,7 @@ def verify_onnx( else: diff = pt_outputs[i] - ort_outputs[i] max_diff_i = max(diff.min(), diff.max(), key=abs) - max_diff = max(max_diff, max_diff_i) + max_diff = max(max_diff, max_diff_i) if max_diff != 0: logger.warning(f"PyTorch outputs: {pt_transcription}") From d225e4d5fecdba0e86c37e15ae8e3ff749cdef06 Mon Sep 17 00:00:00 2001 From: shubhambhokare1 Date: Fri, 29 Mar 2024 18:30:43 +0000 Subject: [PATCH 5/9] remove looping logic --- .../tools/transformers/models/whisper/whisper_helper.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py index 936289e3b8518..988735f4e211d 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py @@ -363,6 +363,9 @@ def pt_transcription_for_verify_onnx( prompt_ids = [processor.get_prompt_ids(p) for p in prompts] pt_transcription = [] pt_outputs = [] + # The looping for model.generate is necessary here due to the limitation as per + # https://huggingface.co/docs/transformers/model_doc/whisper#transformers.WhisperForConditionalGeneration.generate.prompt_ids + # prompt_ids input requires a tensor of rank 1 for i in range(batch_size): inputs["prompt_ids"] = torch.from_numpy(prompt_ids[i]) inputs["input_features"] = input_features_[i].to(device) @@ -493,10 +496,8 @@ def verify_onnx( inputs[name] = np.array([1.0], dtype=ort_to_np[dtype]) else: inputs[name] = np.array([inputs[name]], dtype=ort_to_np[dtype]) - ort_outputs = ort_session.run(None, inputs)[0] - ort_transcription = [] - for o in ort_outputs: - ort_transcription.append(processor.batch_decode(o, skip_special_tokens=True)[0]) + ort_outputs = ort_session.run(None, inputs)[0][:, 0, :] + ort_transcription = processor.batch_decode(ort_outputs, skip_special_tokens=True) expected_transcription_options = WhisperHelper.select_transcription_options(batch_size, prompt_mode) parity = 1 From 94a0ddb9eb2b526f243807a37f8ffb4b6e1e50ae Mon Sep 17 00:00:00 2001 From: shubhambhokare1 Date: Sat, 30 Mar 2024 00:15:26 +0000 Subject: [PATCH 6/9] Avoid model cloning --- .../models/whisper/whisper_encoder_decoder_init.py | 7 +++---- .../models/whisper/whisper_openai_helper.py | 10 +++++++++- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py index 832f692e9980d..705764630ab1a 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py @@ -4,7 +4,6 @@ # license information. 
# -------------------------------------------------------------------------- -import copy import logging import os import tempfile @@ -51,12 +50,13 @@ def forward( self, encoder_input_ids: torch.Tensor, decoder_input_ids: torch.Tensor = None, + remove_hooks: bool = False, ): encoder_hidden_states: torch.FloatTensor = self.whisper_encoder(encoder_input_ids) # Decoder out: (logits, past_key_values, encoder_hidden_state) if self.model_impl == "openai": encoder_hidden_states.unsqueeze(0) - decinit_out, present = self.whisper_decoder_openai_init(decoder_input_ids, encoder_hidden_states) + decinit_out, present = self.whisper_decoder_openai_init(decoder_input_ids, encoder_hidden_states, remove_hooks=remove_hooks) return decinit_out, encoder_hidden_states, present else: decinit_out = self.whisper_decoder_init(decoder_input_ids, encoder_hidden_states) @@ -132,8 +132,7 @@ def export_onnx( input_list = inputs.to_list() # TODO : Investigate whether copy of model if needed - cloned_model = copy.deepcopy(model).to(device) - out = cloned_model(inputs.encoder_input_ids, inputs.decoder_input_ids) + out = model(inputs.encoder_input_ids, inputs.decoder_input_ids, remove_hooks=True) present = out[2] present_names = PastKeyValuesHelper.get_input_names(present, encoder=True) diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_openai_helper.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_openai_helper.py index 941f61cf7cc29..a69b27a74cc29 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_openai_helper.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_openai_helper.py @@ -30,6 +30,7 @@ def forward( tokens, audio_features, past=None, + remove_hooks=False, ): # Create a kv_cache for past_values past_kv_cache = dict() @@ -44,8 +45,9 @@ def forward( past_kv_cache[block.cross_attn.key] = past[2 * idx + half_idx] past_kv_cache[block.cross_attn.value] = past[2 * idx + half_idx + 1] + hooks = None if not self.kv_cache: - self.kv_cache, _ = self.whisper_model.install_kv_cache_hooks() + self.kv_cache, hooks = self.whisper_model.install_kv_cache_hooks() logits = self.whisper_decoder(tokens, audio_features, kv_cache=past_kv_cache) @@ -73,4 +75,10 @@ def forward( present_self = [ present_val.reshape(present_val.shape[:2] + (-1, 64)).transpose(1, 2) for present_val in present_self ] + + # Remove forward hooks to avoid model cloning step + if hooks is not None and remove_hooks is True: + self.kv_cache = {} + for hook in hooks: + hook.remove() return logits, present_self From e7a48be6ba733ef498f9b86cd92df05d564eff36 Mon Sep 17 00:00:00 2001 From: shubhambhokare1 Date: Sat, 30 Mar 2024 00:17:38 +0000 Subject: [PATCH 7/9] lint --- .../models/whisper/whisper_encoder_decoder_init.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py index 705764630ab1a..c1b4d278fdf0c 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py @@ -56,7 +56,9 @@ def forward( # Decoder out: (logits, past_key_values, encoder_hidden_state) if self.model_impl == "openai": encoder_hidden_states.unsqueeze(0) - decinit_out, present = self.whisper_decoder_openai_init(decoder_input_ids, encoder_hidden_states, remove_hooks=remove_hooks) + decinit_out, present = 
self.whisper_decoder_openai_init( + decoder_input_ids, encoder_hidden_states, remove_hooks=remove_hooks + ) return decinit_out, encoder_hidden_states, present else: decinit_out = self.whisper_decoder_init(decoder_input_ids, encoder_hidden_states) From 068f673b35dbc46341eeb4e7718bffb2ce924f8b Mon Sep 17 00:00:00 2001 From: Shubham Bhokare <32080845+shubhambhokare1@users.noreply.github.com> Date: Mon, 1 Apr 2024 09:38:21 -0700 Subject: [PATCH 8/9] Apply suggestions from code review Co-authored-by: kunal-vaishnavi <115581922+kunal-vaishnavi@users.noreply.github.com> --- .../models/whisper/whisper_encoder_decoder_init.py | 1 - .../tools/transformers/models/whisper/whisper_helper.py | 6 +++--- .../transformers/models/whisper/whisper_openai_helper.py | 2 +- 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py index c1b4d278fdf0c..fab2a2aa4c8a8 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py @@ -133,7 +133,6 @@ def export_onnx( ) input_list = inputs.to_list() - # TODO : Investigate whether copy of model if needed out = model(inputs.encoder_input_ids, inputs.decoder_input_ids, remove_hooks=True) present = out[2] present_names = PastKeyValuesHelper.get_input_names(present, encoder=True) diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py index 988735f4e211d..b3dcbf0f5289a 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py @@ -388,7 +388,7 @@ def select_transcription_options( batch_size: int, prompt_mode: bool, ): - if batch_size > 1 and prompt_mode is True: + if batch_size > 1 and prompt_mode: expected_transcription_no_comma_prompt1 = " John has doubts whether Sir Frederick Layton's work is really Greek after all and can discover in it but little of Rocky I" expected_transcription_misspelled_prompt1 = " John has doubts whether Sir Frederick Latins work is really Greek after all and can discover in it but little of Rocky I" expected_transcription_no_comma_prompt2 = " Maria has grave doubts whether Sir Frederick Layton's work is really Greek after all and can discover in it but little of Rocky" @@ -432,8 +432,8 @@ def verify_onnx( pt_model = WhisperForConditionalGeneration.from_pretrained( model_name_or_path, cache_dir=cache_dir, **extra_kwargs ).to(device) - processor = WhisperProcessor.from_pretrained(model_name_or_path) - config = WhisperConfig.from_pretrained(model_name_or_path) + processor = WhisperProcessor.from_pretrained(model_name_or_path, cache_dir=cache_dir) + config = WhisperConfig.from_pretrained(model_name_or_path, cache_dir=cache_dir) inputs, pt_transcription, pt_outputs, decoder_prompt_ids = WhisperHelper.pt_transcription_for_verify_onnx( processor, diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_openai_helper.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_openai_helper.py index a69b27a74cc29..849c3059f21f7 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_openai_helper.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_openai_helper.py @@ -77,7 +77,7 @@ def forward( ] # 
Remove forward hooks to avoid model cloning step - if hooks is not None and remove_hooks is True: + if hooks is not None and remove_hooks: self.kv_cache = {} for hook in hooks: hook.remove() From d33fb3a2945f6aeb6df2f37c3780f6cc1d342e67 Mon Sep 17 00:00:00 2001 From: Shubham Bhokare Date: Tue, 2 Apr 2024 20:23:32 +0000 Subject: [PATCH 9/9] Add cache dir --- .../python/tools/transformers/models/whisper/whisper_chain.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py index 14691da4ad643..b52cab755db78 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py @@ -51,8 +51,8 @@ def chain_model(args): decoder_model = onnx.load_model(args.decoder_path, load_external_data=True) decoder_model.graph.name = "decoder subgraph" - config = WhisperConfig.from_pretrained(args.model_name_or_path) - tokenizer = WhisperTokenizer.from_pretrained(args.model_name_or_path) + config = WhisperConfig.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir) + tokenizer = WhisperTokenizer.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir) # Create inputs/outputs for WhisperBeamSearch op temperature_name = "temperature_fp16" if args.precision == Precision.FLOAT16 else "temperature"
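
For reference, the prompt-padding layout that this patch series adds to `decoder_input_ids` (the branch in `verify_onnx` with the "prompts are not of the same size" comment) can be illustrated with a small standalone sketch. This is not code from the patch: the token IDs are placeholder values, and in the real exporter they come from `WhisperProcessor.get_prompt_ids` and `WhisperConfig.pad_token_id`; only the forced decoder IDs follow the `# ex:` comments shown in the diff.

    # Illustrative sketch only -- not part of the patch. Token IDs are assumed
    # placeholders; in the exporter they come from WhisperProcessor / WhisperConfig.
    import numpy as np

    pad_token_id = 50257                               # assumed pad token
    forced_decoder_ids = [50258, 50259, 50358, 50363]  # decoder start + language/task tokens
                                                       # (example values from the patch's "# ex:" comments)

    # Two tokenized prompts of different lengths (hypothetical IDs, standing in for
    # what processor.get_prompt_ids would return, including the "previous text" token).
    prompt_ids = [
        [50361, 3000, 4000],   # e.g. "John has doubts"
        [50361, 5000],         # e.g. "Maria has grave doubts"
    ]

    max_len = max(len(p) for p in prompt_ids)
    padded_prompts = []
    for p in prompt_ids:
        # Right-pad the shorter prompt with the pad token, then append the forced decoder IDs.
        padded = [*p, *([pad_token_id] * (max_len - len(p)))]
        padded_prompts.append(padded + forced_decoder_ids)

    decoder_input_ids = np.array(padded_prompts, dtype=np.int32)
    # decoder_input_ids is now rectangular:
    # [[50361  3000  4000 50258 50259 50358 50363]
    #  [50361  5000 50257 50258 50259 50358 50363]]

Each row right-pads the shorter prompt before appending the forced decoder IDs, which mirrors the `decoder_input_ids` handling added to the verification path above.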