-
Notifications
You must be signed in to change notification settings - Fork 3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Support ONNX export of OpenAi Whisper model (#17316)
Build from source, then run the command below. Example, converting whisper-base: ` python -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-base --model_impl openai -e -o -w --chain_model --output ./demo`
- Loading branch information
1 parent
1007d8f
commit 90cf037
Showing
10 changed files
with
474 additions
and
29 deletions.
There are no files selected for viewing
239 changes: 229 additions & 10 deletions
239
onnxruntime/python/tools/transformers/fusion_bart_attention.py
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
76 changes: 76 additions & 0 deletions
76
onnxruntime/python/tools/transformers/models/whisper/whisper_openai_helper.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
# ------------------------------------------------------------------------- | ||
# Copyright (c) Microsoft Corporation. All rights reserved. | ||
# Licensed under the MIT License. See License.txt in the project root for | ||
# license information. | ||
# -------------------------------------------------------------------------- | ||
|
||
import logging | ||
|
||
import torch | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class WhisperDecoderInitOpenai(torch.nn.Module):
    """WhisperDecoderInit for Openai.

    Wraps the openai/whisper decoder so a single forward call consumes and
    produces explicit past/present key-value tensors (ONNX-exportable I/O)
    instead of the hook-populated kv_cache dict the openai implementation
    keeps internally.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        decoder: torch.nn.Module,
    ):
        super().__init__()
        # Full whisper model; needed for install_kv_cache_hooks().
        self.whisper_model = model
        # Decoder sub-module; its .blocks expose attn / cross_attn layers
        # whose key/value Linear modules serve as kv_cache dict keys.
        self.whisper_decoder = decoder
        # Lazily populated on the first forward call via the model's
        # kv-cache hooks; keyed by the attention Linear modules themselves.
        self.kv_cache = {}

    @torch.no_grad()
    def forward(
        self,
        tokens,
        audio_features,
        past=None,
    ):
        # Create a kv_cache for past_values
        past_kv_cache = dict()
        if past is not None:
            # Convert past values from 4D to 3D:
            # assumes each entry is (batch, num_heads, seq, head_dim) and is
            # flattened to (batch, seq, num_heads * head_dim) — TODO confirm
            # the 4D layout against the caller that builds `past`.
            past = [torch.transpose(val, 1, 2) for val in past]
            past = [val.reshape(val.shape[:2] + (-1,)) for val in past]
            half_idx = len(past) // 2
            # Layout of `past`: first half is interleaved self-attn K/V per
            # block, second half is interleaved cross-attn K/V per block.
            for idx, block in enumerate(self.whisper_decoder.blocks):
                past_kv_cache[block.attn.key] = past[2 * idx]
                past_kv_cache[block.attn.value] = past[2 * idx + 1]
                past_kv_cache[block.cross_attn.key] = past[2 * idx + half_idx]
                past_kv_cache[block.cross_attn.value] = past[2 * idx + half_idx + 1]

        if not self.kv_cache:
            # Install the hooks only once; during the decoder call below
            # they write the freshly computed K/V tensors into self.kv_cache.
            self.kv_cache, _ = self.whisper_model.install_kv_cache_hooks()

        logits = self.whisper_decoder(tokens, audio_features, kv_cache=past_kv_cache)

        # Add concat node for past values
        if past is not None:
            for block in self.whisper_decoder.blocks:
                # Concatenate past and new self-attn K/V along the sequence
                # axis (dim=1 in the 3D layout); detach so the traced graph
                # contains a plain Concat node without autograd history.
                self.kv_cache[block.attn.key] = torch.cat(
                    [past_kv_cache[block.attn.key], self.kv_cache[block.attn.key]], dim=1
                ).detach()
                self.kv_cache[block.attn.value] = torch.cat(
                    [past_kv_cache[block.attn.value], self.kv_cache[block.attn.value]], dim=1
                ).detach()

        present_self, present_cross = [], []
        # Group self and cross values
        for block in self.whisper_decoder.blocks:
            present_self.append(self.kv_cache[block.attn.key])
            present_self.append(self.kv_cache[block.attn.value])
            if past is None:
                # Cross-attn K/V depend only on audio_features, so they are
                # emitted just once on the first (no-past) call.
                present_cross.append(self.kv_cache[block.cross_attn.key])
                present_cross.append(self.kv_cache[block.cross_attn.value])

        present_self = present_self + present_cross
        # Add reshape and transpose ops to convert from 3D to 4D
        # NOTE(review): head_dim is hard-coded to 64; this matches every
        # released whisper size (d_model / n_heads == 64 for tiny..large),
        # but verify if a new model variant changes that ratio.
        present_self = [
            present_val.reshape(present_val.shape[:2] + (-1, 64)).transpose(1, 2) for present_val in present_self
        ]
        return logits, present_self
Oops, something went wrong.