From 5c8fadabe86d22e3a635a4df2481948d44871249 Mon Sep 17 00:00:00 2001 From: kunal-vaishnavi <115581922+kunal-vaishnavi@users.noreply.github.com> Date: Tue, 30 Jan 2024 11:59:15 -0800 Subject: [PATCH] Update Whisper export with beam search (#19322) ### Description This PR updates the Whisper export with beam search by adding the following. - Fixes a bug when running `DecoderMaskedMultiHeadAttention` in the Whisper with beam search model - Sets the default PyTorch attention implementation to `eager` to allow existing attention fusions to continue working - Re-uses the cache directory when loading the PyTorch model to reduce memory used on disk - Adds `--disable_auto_mixed_precision` to the example FP16 export command ### Motivation and Context - [This PR](https://github.com/microsoft/onnxruntime/pull/19112) added the `is_unidirectional` parameter to `CheckInputs`, but it was not provided when checking the inputs in `DecoderMaskedMultiHeadAttention`. - [This PR](https://github.com/microsoft/onnxruntime/pull/19200) explains the reasoning behind why `eager` is used to load the `WhisperAttention` class. - By re-using the cache directory for loading the PyTorch model, only one copy of the PyTorch model is saved on disk instead of two copies. - By providing this flag, there will be less Cast nodes in the Whisper with beam search model to switch between FP16 and FP32 precision. --- .../bert/decoder_masked_multihead_attention.cc | 2 ++ .../tools/transformers/models/whisper/README.md | 4 ++-- .../models/whisper/convert_to_onnx.py | 2 +- .../transformers/models/whisper/whisper_helper.py | 15 +++++++++++++-- 4 files changed, 18 insertions(+), 5 deletions(-) diff --git a/onnxruntime/contrib_ops/cuda/bert/decoder_masked_multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/decoder_masked_multihead_attention.cc index a9b60da0c96ca..66c0aceaed1e7 100644 --- a/onnxruntime/contrib_ops/cuda/bert/decoder_masked_multihead_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/decoder_masked_multihead_attention.cc @@ -74,6 +74,7 @@ Status DecoderMaskedMultiHeadAttention::ComputeInternal(OpKernelContext* parameters.kv_data_in_flight = ParseEnvironmentVariableWithDefault( attention::kDecoderMaskedAttentionLoadKVDataInFlight, false); + bool is_unidirectional = false; bool is_dmmha_packing = (key == nullptr && value == nullptr); ORT_RETURN_IF_ERROR(multihead_attention_helper::CheckInputs(query, key, @@ -88,6 +89,7 @@ Status DecoderMaskedMultiHeadAttention::ComputeInternal(OpKernelContext* num_heads_, mask_filter_value_, scale_, + is_unidirectional, past_present_share_buffer_, is_dmmha_packing, // dmmha_packing device_prop.maxThreadsPerBlock)); diff --git a/onnxruntime/python/tools/transformers/models/whisper/README.md b/onnxruntime/python/tools/transformers/models/whisper/README.md index 8ff5c8a6e1de0..02100266200f8 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/README.md +++ b/onnxruntime/python/tools/transformers/models/whisper/README.md @@ -60,10 +60,10 @@ $ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/w Export + Optimize for FP16 and GPU ``` # From source: -$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format --optimize_onnx --precision fp16 --use_gpu --provider cuda +$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format --optimize_onnx --precision fp16 --use_gpu --provider cuda --disable_auto_mixed_precision # From wheel: -$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format --optimize_onnx --precision fp16 --use_gpu --provider cuda +$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format --optimize_onnx --precision fp16 --use_gpu --provider cuda --disable_auto_mixed_precision ``` Export + Quantize for INT8 diff --git a/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py index 50637b772c233..e15a12c07bed7 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py @@ -478,7 +478,7 @@ def main(argv=None): # Wrap parity check in try-except to allow export to continue in case this produces an error try: with torch.no_grad(): - max_diff = WhisperHelper.verify_onnx(args.model_name_or_path, ort_session, device) + max_diff = WhisperHelper.verify_onnx(args.model_name_or_path, cache_dir, ort_session, device) if max_diff > 1e-4: logger.warning("PyTorch and ONNX Runtime results are NOT close") else: diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py index 8c22cd5e745b3..a4bef1f06b4fe 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py @@ -12,7 +12,9 @@ import numpy as np import torch +from packaging import version from transformers import WhisperConfig, WhisperForConditionalGeneration, WhisperProcessor +from transformers import __version__ as transformers_version from whisper_decoder import WhisperDecoder, WhisperDecoderHelper, WhisperDecoderInit from whisper_encoder import WhisperEncoder, WhisperEncoderHelper from whisper_encoder_decoder_init import WhisperEncoderDecoderInit, WhisperEncoderDecoderInitHelper @@ -88,7 +90,10 @@ def load_model( Returns: Dict[str, torch.nn.Module]: mapping from name to modules for ONNX conversion. """ - model = WhisperForConditionalGeneration.from_pretrained(model_name_or_path, cache_dir=cache_dir) + extra_kwargs = {} + if version.parse(transformers_version) >= version.parse("4.36.0"): + extra_kwargs["attn_implementation"] = "eager" + model = WhisperForConditionalGeneration.from_pretrained(model_name_or_path, cache_dir=cache_dir, **extra_kwargs) if state_dict_path: model.load_state_dict(torch.load(state_dict_path), strict=False) @@ -262,11 +267,17 @@ def optimize_onnx( @staticmethod def verify_onnx( model_name_or_path: str, + cache_dir: str, ort_session: InferenceSession, device: torch.device, ): """Compare the result from PyTorch and ONNX Runtime to verify the ONNX model is good.""" - pt_model = WhisperForConditionalGeneration.from_pretrained(model_name_or_path).to(device) + extra_kwargs = {} + if version.parse(transformers_version) >= version.parse("4.36.0"): + extra_kwargs["attn_implementation"] = "eager" + pt_model = WhisperForConditionalGeneration.from_pretrained( + model_name_or_path, cache_dir=cache_dir, **extra_kwargs + ).to(device) processor = WhisperProcessor.from_pretrained(model_name_or_path) config = WhisperConfig.from_pretrained(model_name_or_path)