Update Whisper export with beam search (#19322)
### Description
This PR updates the Whisper export with beam search by making the following
changes.

- Fixes a bug when running `DecoderMaskedMultiHeadAttention` in the
Whisper with beam search model
- Sets the default PyTorch attention implementation to `eager` to allow
existing attention fusions to continue working
- Re-uses the cache directory when loading the PyTorch model to reduce
the disk space used
- Adds `--disable_auto_mixed_precision` to the example FP16 export
command

### Motivation and Context
- [This PR](#19112) added
the `is_unidirectional` parameter to `CheckInputs`, but it was not
provided when checking the inputs in `DecoderMaskedMultiHeadAttention`.
- [This PR](#19200) explains why `eager` is used to load the
`WhisperAttention` class (a minimal sketch of the resulting model-loading
logic follows this list).
- By re-using the cache directory for loading the PyTorch model, only
one copy of the PyTorch model is saved on disk instead of two copies.
- Providing this flag results in fewer Cast nodes in the Whisper
with beam search model for switching between FP16 and FP32 precision.
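
As a minimal sketch of the `eager` and cache-directory points above (not the exact repository code; the real change lives in `load_model` and `verify_onnx` in the diff below, and the helper name here is made up for illustration), loading the PyTorch model could look like this:

```python
from packaging import version
from transformers import WhisperForConditionalGeneration
from transformers import __version__ as transformers_version


def load_whisper_for_export(model_name_or_path: str, cache_dir: str) -> WhisperForConditionalGeneration:
    # transformers >= 4.36.0 defaults to the "sdpa" attention implementation,
    # which the existing attention fusions do not match, so request "eager".
    extra_kwargs = {}
    if version.parse(transformers_version) >= version.parse("4.36.0"):
        extra_kwargs["attn_implementation"] = "eager"

    # Passing the same cache_dir for every load keeps a single copy of the
    # PyTorch weights on disk instead of two.
    return WhisperForConditionalGeneration.from_pretrained(
        model_name_or_path, cache_dir=cache_dir, **extra_kwargs
    )


# Example usage (hypothetical cache path):
# model = load_whisper_for_export("openai/whisper-tiny", cache_dir="./cache")
```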
kunal-vaishnavi authored and rachguo committed Jan 30, 2024
1 parent ed8b7a6 commit 5c8fada
Showing 4 changed files with 18 additions and 5 deletions.
@@ -74,6 +74,7 @@ Status DecoderMaskedMultiHeadAttention<T1, T2>::ComputeInternal(OpKernelContext*
  parameters.kv_data_in_flight = ParseEnvironmentVariableWithDefault<bool>(
      attention::kDecoderMaskedAttentionLoadKVDataInFlight, false);

+  bool is_unidirectional = false;
  bool is_dmmha_packing = (key == nullptr && value == nullptr);
  ORT_RETURN_IF_ERROR(multihead_attention_helper::CheckInputs<Tensor>(query,
                                                                      key,
@@ -88,6 +89,7 @@ Status DecoderMaskedMultiHeadAttention<T1, T2>::ComputeInternal(OpKernelContext*
                                                                      num_heads_,
                                                                      mask_filter_value_,
                                                                      scale_,
+                                                                      is_unidirectional,
                                                                      past_present_share_buffer_,
                                                                      is_dmmha_packing,  // dmmha_packing
                                                                      device_prop.maxThreadsPerBlock));
@@ -60,10 +60,10 @@
Export + Optimize for FP16 and GPU
```
# From source:
-$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format --optimize_onnx --precision fp16 --use_gpu --provider cuda
+$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format --optimize_onnx --precision fp16 --use_gpu --provider cuda --disable_auto_mixed_precision
# From wheel:
-$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format --optimize_onnx --precision fp16 --use_gpu --provider cuda
+$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format --optimize_onnx --precision fp16 --use_gpu --provider cuda --disable_auto_mixed_precision
```
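
To see the effect of `--disable_auto_mixed_precision` on the exported FP16 model (the "fewer Cast nodes" point in the motivation above), one rough check is to count `Cast` nodes before and after exporting with the flag; this is a sketch, and the output path below is a placeholder rather than the converter's actual file name:

```python
import onnx


def count_cast_nodes(onnx_path: str) -> int:
    # onnx.load also pulls in external data stored next to the .onnx file.
    model = onnx.load(onnx_path)
    # In an FP16 export, most Cast nodes switch tensors between FP16 and FP32.
    return sum(1 for node in model.graph.node if node.op_type == "Cast")


# Placeholder path; substitute the model produced by the export command above.
print(count_cast_nodes("whispertiny/model.onnx"))
```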

Export + Quantize for INT8
@@ -478,7 +478,7 @@ def main(argv=None):
    # Wrap parity check in try-except to allow export to continue in case this produces an error
    try:
        with torch.no_grad():
-            max_diff = WhisperHelper.verify_onnx(args.model_name_or_path, ort_session, device)
+            max_diff = WhisperHelper.verify_onnx(args.model_name_or_path, cache_dir, ort_session, device)
            if max_diff > 1e-4:
                logger.warning("PyTorch and ONNX Runtime results are NOT close")
            else:
@@ -12,7 +12,9 @@

import numpy as np
import torch
+from packaging import version
from transformers import WhisperConfig, WhisperForConditionalGeneration, WhisperProcessor
+from transformers import __version__ as transformers_version
from whisper_decoder import WhisperDecoder, WhisperDecoderHelper, WhisperDecoderInit
from whisper_encoder import WhisperEncoder, WhisperEncoderHelper
from whisper_encoder_decoder_init import WhisperEncoderDecoderInit, WhisperEncoderDecoderInitHelper
@@ -88,7 +90,10 @@ def load_model(
        Returns:
            Dict[str, torch.nn.Module]: mapping from name to modules for ONNX conversion.
        """
-        model = WhisperForConditionalGeneration.from_pretrained(model_name_or_path, cache_dir=cache_dir)
+        extra_kwargs = {}
+        if version.parse(transformers_version) >= version.parse("4.36.0"):
+            extra_kwargs["attn_implementation"] = "eager"
+        model = WhisperForConditionalGeneration.from_pretrained(model_name_or_path, cache_dir=cache_dir, **extra_kwargs)
        if state_dict_path:
            model.load_state_dict(torch.load(state_dict_path), strict=False)

@@ -262,11 +267,17 @@ def optimize_onnx(
    @staticmethod
    def verify_onnx(
        model_name_or_path: str,
+        cache_dir: str,
        ort_session: InferenceSession,
        device: torch.device,
    ):
        """Compare the result from PyTorch and ONNX Runtime to verify the ONNX model is good."""
-        pt_model = WhisperForConditionalGeneration.from_pretrained(model_name_or_path).to(device)
+        extra_kwargs = {}
+        if version.parse(transformers_version) >= version.parse("4.36.0"):
+            extra_kwargs["attn_implementation"] = "eager"
+        pt_model = WhisperForConditionalGeneration.from_pretrained(
+            model_name_or_path, cache_dir=cache_dir, **extra_kwargs
+        ).to(device)
        processor = WhisperProcessor.from_pretrained(model_name_or_path)
        config = WhisperConfig.from_pretrained(model_name_or_path)

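
For context, a stripped-down version of the parity check that `verify_onnx` performs could look like the sketch below. It is not the repository implementation: the beam-search model's input names, the dummy mel-spectrogram shape (80 mel bins, 3000 frames), and the output layout are assumptions, and a real check would decode actual audio with `WhisperProcessor`:

```python
import numpy as np
import torch
from onnxruntime import InferenceSession
from transformers import WhisperForConditionalGeneration


def parity_check(model_name_or_path: str, cache_dir: str, ort_session: InferenceSession, device: torch.device) -> int:
    # Load the PyTorch reference the same way the export does: eager attention
    # (transformers >= 4.36) plus the shared cache_dir.
    pt_model = WhisperForConditionalGeneration.from_pretrained(
        model_name_or_path, cache_dir=cache_dir, attn_implementation="eager"
    ).to(device)

    # Assumed dummy input; use float16 features for an FP16 export.
    input_features = torch.randn(1, 80, 3000, device=device)
    max_length = 128

    with torch.no_grad():
        pt_ids = pt_model.generate(input_features, max_length=max_length, num_beams=1)

    # Assumed input names for the ONNX beam-search model.
    ort_inputs = {
        "input_features": input_features.cpu().numpy().astype(np.float32),
        "max_length": np.array([max_length], dtype=np.int32),
        "min_length": np.array([0], dtype=np.int32),
        "num_beams": np.array([1], dtype=np.int32),
        "num_return_sequences": np.array([1], dtype=np.int32),
        "length_penalty": np.array([1.0], dtype=np.float32),
        "repetition_penalty": np.array([1.0], dtype=np.float32),
    }
    # Assumed output layout: (batch_size, num_return_sequences, sequence_length).
    ort_ids = ort_session.run(None, ort_inputs)[0]

    # Compare generated token ids; convert_to_onnx warns when the result exceeds 1e-4.
    seq_len = min(pt_ids.shape[-1], ort_ids.shape[-1])
    return int(np.max(np.abs(pt_ids.cpu().numpy()[0, :seq_len] - ort_ids[0, 0, :seq_len])))
```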
