From 6b61dcb046c870f106d1289890eabaf679cb8116 Mon Sep 17 00:00:00 2001 From: Shubham Bhokare <32080845+shubhambhokare1@users.noreply.github.com> Date: Tue, 2 Apr 2024 17:01:48 -0700 Subject: [PATCH] Export of Openai Whisper with batched prompts (#19854) Adds an example to demonstrate the export of openai whipser implemenation with batch_size > 1 and addition of prompts for each audio snippet. Also handles the scenario for when prompts are not of the same size. For example if our prompt ids are [p1_id_1, p1_id_2] and [p2_id_1], the final decoder_input_ids will look as such after padding: `[prev_token, p1_id_1, p1_id_2, start_token, lang_token, transcribe_token] [prev_token, p2_id_1, PAD_TOKEN, start_token, lang_token, transcribe_token]` --------- Co-authored-by: kunal-vaishnavi <115581922+kunal-vaishnavi@users.noreply.github.com> --- .../models/whisper/convert_to_onnx.py | 8 +- .../models/whisper/whisper_chain.py | 4 +- .../models/whisper/whisper_decoder.py | 4 +- .../whisper/whisper_encoder_decoder_init.py | 10 +- .../models/whisper/whisper_helper.py | 192 +++++++++++++----- .../models/whisper/whisper_openai_helper.py | 10 +- 6 files changed, 166 insertions(+), 62 deletions(-) diff --git a/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py index 5921e4ed42936..bdd49b9f70a4d 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py @@ -506,7 +506,13 @@ def main(argv=None): # Wrap parity check in try-except to allow export to continue in case this produces an error try: with torch.no_grad(): - max_diff = WhisperHelper.verify_onnx(args.model_name_or_path, cache_dir, ort_session, device) + # Verify batched decoding with prompts for whisper openai implementation + if args.model_impl == "openai" and args.use_forced_decoder_ids: + max_diff = WhisperHelper.verify_onnx( + args.model_name_or_path, cache_dir, ort_session, device, batch_size=2, prompt_mode=True + ) + else: + max_diff = WhisperHelper.verify_onnx(args.model_name_or_path, cache_dir, ort_session, device) if max_diff > 1e-4: logger.warning("PyTorch and ONNX Runtime results are NOT close") else: diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py index 0b128f122e0f4..be05ebc9d5dac 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py @@ -51,8 +51,8 @@ def chain_model(args): decoder_model = onnx.load_model(args.decoder_path, load_external_data=True) decoder_model.graph.name = "decoder subgraph" - config = WhisperConfig.from_pretrained(args.model_name_or_path) - tokenizer = WhisperTokenizer.from_pretrained(args.model_name_or_path) + config = WhisperConfig.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir) + tokenizer = WhisperTokenizer.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir) # Create inputs/outputs for WhisperBeamSearch op temperature_name = "temperature_fp16" if args.precision == Precision.FLOAT16 else "temperature" diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_decoder.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_decoder.py index 93fd64c9eb7d3..5da235d72ca0b 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_decoder.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_decoder.py @@ -126,6 +126,7 @@ def create_dummy( device: torch.device, float16: bool = False, use_int32_inputs: bool = False, + model_impl: str = "hf", ): # -> WhisperDecoderInputs: """Create dummy inputs for WhisperDecoder. @@ -170,7 +171,7 @@ def create_dummy( cross_attention_past_shape = [ batch_size, num_attention_heads, - encode_sequence_length, + encode_sequence_length if model_impl == "hf" else past_decode_sequence_length, head_size, ] @@ -228,6 +229,7 @@ def export_onnx( past_decode_sequence_length=6 if isinstance(decoder, WhisperDecoder) else 0, device=device, use_int32_inputs=use_int32_inputs, + model_impl=decoder.model_impl, ) input_list = inputs.to_list() diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py index 832f692e9980d..fab2a2aa4c8a8 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py @@ -4,7 +4,6 @@ # license information. # -------------------------------------------------------------------------- -import copy import logging import os import tempfile @@ -51,12 +50,15 @@ def forward( self, encoder_input_ids: torch.Tensor, decoder_input_ids: torch.Tensor = None, + remove_hooks: bool = False, ): encoder_hidden_states: torch.FloatTensor = self.whisper_encoder(encoder_input_ids) # Decoder out: (logits, past_key_values, encoder_hidden_state) if self.model_impl == "openai": encoder_hidden_states.unsqueeze(0) - decinit_out, present = self.whisper_decoder_openai_init(decoder_input_ids, encoder_hidden_states) + decinit_out, present = self.whisper_decoder_openai_init( + decoder_input_ids, encoder_hidden_states, remove_hooks=remove_hooks + ) return decinit_out, encoder_hidden_states, present else: decinit_out = self.whisper_decoder_init(decoder_input_ids, encoder_hidden_states) @@ -131,9 +133,7 @@ def export_onnx( ) input_list = inputs.to_list() - # TODO : Investigate whether copy of model if needed - cloned_model = copy.deepcopy(model).to(device) - out = cloned_model(inputs.encoder_input_ids, inputs.decoder_input_ids) + out = model(inputs.encoder_input_ids, inputs.decoder_input_ids, remove_hooks=True) present = out[2] present_names = PastKeyValuesHelper.get_input_names(present, encoder=True) diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py index adf7f69470ae7..9fb51dd9b43c0 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py @@ -314,22 +314,13 @@ def optimize_onnx( m.save_model_to_file(optimized_model_path, use_external_data_format, all_tensors_to_one_file=True) @staticmethod - def verify_onnx( - model_name_or_path: str, - cache_dir: str, - ort_session: InferenceSession, + def pt_transcription_for_verify_onnx( + processor: WhisperProcessor, + pt_model: torch.nn.Module, device: torch.device, + batch_size: int = 1, + prompt_mode: bool = False, ): - """Compare the result from PyTorch and ONNX Runtime to verify the ONNX model is good.""" - extra_kwargs = {} - if version.parse(transformers_version) >= version.parse("4.36.0"): - extra_kwargs["attn_implementation"] = "eager" - pt_model = WhisperForConditionalGeneration.from_pretrained( - model_name_or_path, cache_dir=cache_dir, **extra_kwargs - ).to(device) - processor = WhisperProcessor.from_pretrained(model_name_or_path) - config = WhisperConfig.from_pretrained(model_name_or_path) - # Try to import `datasets` pip package try: from datasets import load_dataset @@ -342,14 +333,18 @@ def verify_onnx( from datasets import load_dataset ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") - input_features = processor([ds[0]["audio"]["array"]], return_tensors="pt").input_features - - start_id = [config.decoder_start_token_id] # ex: [50258] - prompt_ids = processor.get_decoder_prompt_ids(language="english", task="transcribe") - prompt_ids = list(map(lambda token: token[1], prompt_ids)) # ex: [50259, 50358, 50363] - forced_decoder_ids = start_id + prompt_ids # ex: [50258, 50259, 50358, 50363] - - batch_size, max_length, min_length, num_beams, num_return_sequences = 1, 30, 0, 1, 1 + input_features_ = [] + if batch_size == 1: + input_features = processor([ds[0]["audio"]["array"]], return_tensors="pt").input_features + else: + input_features_ = [ + processor([ds[3]["audio"]["array"]], return_tensors="pt").input_features, + processor([ds[3]["audio"]["array"]], return_tensors="pt").input_features, + ] + assert len(input_features_) == batch_size + input_features = torch.cat((input_features_[0], input_features_[1])) + + max_length, min_length, num_beams, num_return_sequences = 30, 0, 1, 1 length_penalty, repetition_penalty = 1.0, 1.0 inputs = { "input_features": input_features.to(device), @@ -362,10 +357,97 @@ def verify_onnx( "early_stopping": True, "use_cache": True, } - pt_outputs = pt_model.generate(**inputs).detach().cpu().numpy() + if prompt_mode: + prompts = ["John has doubts", "Maria has grave doubts"] + prompt_ids = [processor.get_prompt_ids(p) for p in prompts] + pt_transcription = [] + pt_outputs = [] + # The looping for model.generate is necessary here due to the limitation as per + # https://huggingface.co/docs/transformers/model_doc/whisper#transformers.WhisperForConditionalGeneration.generate.prompt_ids + # prompt_ids input requires a tensor of rank 1 + for i in range(batch_size): + inputs["prompt_ids"] = torch.from_numpy(prompt_ids[i]) + inputs["input_features"] = input_features_[i].to(device) + pt_output = pt_model.generate(**inputs).detach().cpu().numpy() + pt_outputs.append(pt_output) + pt_transcription.append(processor.batch_decode(pt_output, skip_special_tokens=True)[0]) + inputs["input_features"] = input_features + del inputs["prompt_ids"] + else: + prompt_ids = [] + pt_outputs = pt_model.generate(**inputs).detach().cpu().numpy() + pt_transcription = [processor.batch_decode(pt_outputs, skip_special_tokens=True)[0]] + pt_outputs = list(pt_outputs) del inputs["early_stopping"] del inputs["use_cache"] + return inputs, pt_transcription, pt_outputs, prompt_ids + + @staticmethod + def select_transcription_options( + batch_size: int, + prompt_mode: bool, + ): + if batch_size > 1 and prompt_mode: + expected_transcription_no_comma_prompt1 = " John has doubts whether Sir Frederick Layton's work is really Greek after all and can discover in it but little of Rocky I" + expected_transcription_misspelled_prompt1 = " John has doubts whether Sir Frederick Latins work is really Greek after all and can discover in it but little of Rocky I" + expected_transcription_no_comma_prompt2 = " Maria has grave doubts whether Sir Frederick Layton's work is really Greek after all and can discover in it but little of Rocky" + expected_transcription_misspelled_prompt2 = " Maria has grave doubts whether Sir Frederick Latins work is really Greek after all and can discover in it but little of Rocky I" + expected_transcription_options = { + expected_transcription_no_comma_prompt1, + expected_transcription_no_comma_prompt2, + expected_transcription_misspelled_prompt1, + expected_transcription_misspelled_prompt2, + } + else: + expected_transcription_no_comma = ( + " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel." + ) + expected_transcription_with_comma = ( + " Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel." + ) + expected_transcription_with_quote_and_comma = ( + ' "Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.' + ) + expected_transcription_options = { + expected_transcription_no_comma, + expected_transcription_with_comma, + expected_transcription_with_quote_and_comma, + } + return expected_transcription_options + + @staticmethod + def verify_onnx( + model_name_or_path: str, + cache_dir: str, + ort_session: InferenceSession, + device: torch.device, + batch_size: int = 1, + prompt_mode: bool = False, + ): + """Compare the result from PyTorch and ONNX Runtime to verify the ONNX model is good.""" + extra_kwargs = {} + if version.parse(transformers_version) >= version.parse("4.36.0"): + extra_kwargs["attn_implementation"] = "eager" + pt_model = WhisperForConditionalGeneration.from_pretrained( + model_name_or_path, cache_dir=cache_dir, **extra_kwargs + ).to(device) + processor = WhisperProcessor.from_pretrained(model_name_or_path, cache_dir=cache_dir) + config = WhisperConfig.from_pretrained(model_name_or_path, cache_dir=cache_dir) + + inputs, pt_transcription, pt_outputs, decoder_prompt_ids = WhisperHelper.pt_transcription_for_verify_onnx( + processor, + pt_model, + device, + batch_size=batch_size, + prompt_mode=prompt_mode, + ) + + start_id = [config.decoder_start_token_id] # ex: [50258] + prompt_ids = processor.get_decoder_prompt_ids(language="english", task="transcribe") + prompt_ids = list(map(lambda token: token[1], prompt_ids)) # ex: [50259, 50358, 50363] + forced_decoder_ids = start_id + prompt_ids # ex: [50258, 50259, 50358, 50363] + ort_names = list(map(lambda entry: entry.name, ort_session.get_inputs())) ort_dtypes = list(map(lambda entry: entry.type, ort_session.get_inputs())) ort_to_np = { @@ -386,8 +468,24 @@ def verify_onnx( elif name == "prefix_vocab_mask": inputs[name] = np.ones((batch_size, config.vocab_size), dtype=ort_to_np[dtype]) elif name == "decoder_input_ids": - raw_input_ids = [start_id] if use_extra_decoding_ids else [forced_decoder_ids] - inputs[name] = np.array(raw_input_ids, dtype=ort_to_np[dtype]) + if not prompt_mode: + raw_input_ids = [start_id] if use_extra_decoding_ids else [forced_decoder_ids] + inputs[name] = np.array(raw_input_ids, dtype=ort_to_np[dtype]) + else: + # This logic handles the scenario for when prompts are not of the same size + # For example if our prompt ids are [p1_id_1, p1_id_2] and [p2_id_1] + # The final decoder_input_ids will look as such after padding + # [prev_token, p1_id_1, p1_id_2, start_token, lang_token, transcribe_token] + # [prev_token, p2_id_1, PAD_TOKEN, start_token, lang_token, transcribe_token] + ort_prompts = [] + for i in range(batch_size): + ort_prompts.append(decoder_prompt_ids[i].tolist()) + max_len = max(len(p) for p in ort_prompts) + padded_prompts = [] + for p in ort_prompts: + padded_prompt = [*p, *([config.pad_token_id] * (max_len - len(p)))] + padded_prompts.append(padded_prompt + forced_decoder_ids) + inputs[name] = np.array(padded_prompts, dtype=ort_to_np[dtype]) elif name == "logits_processor": inputs[name] = np.array([1], dtype=ort_to_np[dtype]) elif name == "cross_qk_layer_head": @@ -398,36 +496,26 @@ def verify_onnx( inputs[name] = np.array([1.0], dtype=ort_to_np[dtype]) else: inputs[name] = np.array([inputs[name]], dtype=ort_to_np[dtype]) - ort_outputs = ort_session.run(None, inputs)[0][0] - - expected_transcription_no_comma = ( - " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel." - ) - expected_transcription_with_comma = ( - " Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel." - ) - expected_transcription_with_quote_and_comma = ( - ' "Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.' - ) - expected_transcription_options = { - expected_transcription_no_comma, - expected_transcription_with_comma, - expected_transcription_with_quote_and_comma, - } - pt_transcription = processor.batch_decode(pt_outputs, skip_special_tokens=True)[0] - ort_transcription = processor.batch_decode(ort_outputs, skip_special_tokens=True)[0] - - parity = ( - pt_transcription in expected_transcription_options and ort_transcription in expected_transcription_options - ) + ort_outputs = ort_session.run(None, inputs)[0][:, 0, :] + ort_transcription = processor.batch_decode(ort_outputs, skip_special_tokens=True) + expected_transcription_options = WhisperHelper.select_transcription_options(batch_size, prompt_mode) + + parity = 1 + for i in range(batch_size): + parity *= ( + pt_transcription[i] in expected_transcription_options + and ort_transcription[i] in expected_transcription_options + ) max_diff = 0 if not parity: - if pt_outputs.shape != ort_outputs.shape: - diff = pt_outputs - ort_outputs[:, : len(pt_outputs[0])] - else: - diff = pt_outputs - ort_outputs - max_diff = max(diff.min(), diff.max(), key=abs) + for i in range(batch_size): + if pt_outputs[i].shape != ort_outputs[i].shape: + diff = pt_outputs[i] - ort_outputs[i][:, : len(pt_outputs[i])] + else: + diff = pt_outputs[i] - ort_outputs[i] + max_diff_i = max(diff.min(), diff.max(), key=abs) + max_diff = max(max_diff, max_diff_i) if max_diff != 0: logger.warning(f"PyTorch outputs: {pt_transcription}") diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_openai_helper.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_openai_helper.py index 941f61cf7cc29..849c3059f21f7 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_openai_helper.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_openai_helper.py @@ -30,6 +30,7 @@ def forward( tokens, audio_features, past=None, + remove_hooks=False, ): # Create a kv_cache for past_values past_kv_cache = dict() @@ -44,8 +45,9 @@ def forward( past_kv_cache[block.cross_attn.key] = past[2 * idx + half_idx] past_kv_cache[block.cross_attn.value] = past[2 * idx + half_idx + 1] + hooks = None if not self.kv_cache: - self.kv_cache, _ = self.whisper_model.install_kv_cache_hooks() + self.kv_cache, hooks = self.whisper_model.install_kv_cache_hooks() logits = self.whisper_decoder(tokens, audio_features, kv_cache=past_kv_cache) @@ -73,4 +75,10 @@ def forward( present_self = [ present_val.reshape(present_val.shape[:2] + (-1, 64)).transpose(1, 2) for present_val in present_self ] + + # Remove forward hooks to avoid model cloning step + if hooks is not None and remove_hooks: + self.kv_cache = {} + for hook in hooks: + hook.remove() return logits, present_self