Support ONNX export of OpenAI Whisper model (#17316)
Build from source and run the command below.

Example: converting whisper-base

`python -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-base --model_impl openai -e -o -w --chain_model --output ./demo`
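After the export completes, the ONNX files are written under the `--output` directory. As a quick sanity check you can open one of them with ONNX Runtime and list its inputs; the file name below is a placeholder, since the actual names depend on the model and export options.

```python
# Sanity-check sketch (not part of this commit). The file name is a
# placeholder -- substitute whatever the converter actually wrote into ./demo.
import onnxruntime as ort

sess = ort.InferenceSession("./demo/whisper-base.onnx", providers=["CPUExecutionProvider"])
for inp in sess.get_inputs():
    print(inp.name, inp.shape, inp.type)
```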
shubhambhokare1 authored Feb 9, 2024
1 parent 1007d8f commit 90cf037
Showing 10 changed files with 474 additions and 29 deletions.
239 changes: 229 additions & 10 deletions onnxruntime/python/tools/transformers/fusion_bart_attention.py

Large diffs are not rendered by default.

onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py
@@ -38,6 +38,15 @@ def parse_arguments(argv=None):
help="Model path, or pretrained model name in the list: " + ", ".join(pretrained_models),
)

parser.add_argument(
"--model_impl",
required=False,
default="hf",
choices=["hf", "openai"],
type=str,
help="Select implementation for export of encoder and decoder subgraphs",
)

parser.add_argument(
"--cache_dir",
required=False,
@@ -300,6 +309,7 @@ def parse_arguments(argv=None):

def export_onnx_models(
model_name_or_path,
model_impl,
cache_dir,
output_dir,
use_gpu,
@@ -321,7 +331,7 @@ def export_onnx_models(
device = torch.device("cuda:0" if use_gpu else "cpu")

models = WhisperHelper.load_model(
model_name_or_path, cache_dir, device, merge_encoder_and_decoder_init, state_dict_path
model_name_or_path, model_impl, cache_dir, device, merge_encoder_and_decoder_init, state_dict_path
)
config = models["decoder"].config

@@ -431,6 +441,7 @@ def main(argv=None):

output_paths = export_onnx_models(
args.model_name_or_path,
args.model_impl,
cache_dir,
output_dir,
args.use_gpu,
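The hunks above add a `--model_impl` flag (default `hf`) and thread it from `parse_arguments` into `export_onnx_models`. For scripted use, a rough equivalent of the CLI command in the commit message is to pass the same argument list to `main` (a sketch, assuming a source build where the module is importable and that `main` forwards `argv` to `parse_arguments`):

```python
# Sketch: programmatic equivalent of the CLI invocation above.
from onnxruntime.transformers.models.whisper import convert_to_onnx

convert_to_onnx.main([
    "-m", "openai/whisper-base",
    "--model_impl", "openai",
    "-e", "-o", "-w",
    "--chain_model",
    "--output", "./demo",
])
```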
onnxruntime/python/tools/transformers/models/whisper/whisper_decoder.py
@@ -18,6 +18,7 @@
from onnx_model import OnnxModel
from torch_onnx_export_helper import torch_onnx_export
from transformers import WhisperConfig, file_utils
from whisper_openai_helper import WhisperDecoderInitOpenai

from onnxruntime import InferenceSession

@@ -67,17 +68,28 @@ def forward(
class WhisperDecoder(torch.nn.Module):
"""A Whisper decoder with past key values"""

def __init__(self, decoder, config):
def __init__(self, decoder, config, model_impl: str = "hf", model: torch.nn.Module = None):
super().__init__()
self.decoder = decoder
self.config = config
self.model_impl = model_impl
if model is not None:
self.whisper_decoder_openai_init = WhisperDecoderInitOpenai(model, decoder)

def forward(self, decoder_input_ids, *past):
encoder_outputs = file_utils.ModelOutput()
dummy_encoder_hidden_states = torch.randn((decoder_input_ids.shape[0], 3000, int(self.config.d_model)))
encoder_outputs["last_hidden_state"] = dummy_encoder_hidden_states
encoder_outputs["hidden_states"] = dummy_encoder_hidden_states
encoder_outputs["attentions"] = None

if self.model_impl == "openai":
dummy_encoder_hidden_states.unsqueeze(0)
dec_out, present = self.whisper_decoder_openai_init(
decoder_input_ids, dummy_encoder_hidden_states, past=past
)
return dec_out, present

if len(past) == 0:
past_key_values = None
else:
@@ -158,7 +170,7 @@ def create_dummy(
cross_attention_past_shape = [
batch_size,
num_attention_heads,
encode_sequence_length,
past_decode_sequence_length,
head_size,
]

@@ -213,7 +225,7 @@ def export_onnx(
decoder.config,
batch_size=2,
encode_sequence_length=3000,
past_decode_sequence_length=5 if isinstance(decoder, WhisperDecoder) else 0,
past_decode_sequence_length=6 if isinstance(decoder, WhisperDecoder) else 0,
device=device,
use_int32_inputs=use_int32_inputs,
)
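The dummy-input change above sizes the cross-attention past buffers by `past_decode_sequence_length` (now 6 for the standalone decoder) instead of the 3000-frame `encode_sequence_length`. A small sketch of the resulting 4D layout, with illustrative values for the other dimensions:

```python
# Illustrative shapes only; the real values come from the model config.
import torch

batch_size, num_attention_heads, head_size = 2, 8, 64
past_decode_sequence_length = 6

self_attention_past = torch.rand(
    batch_size, num_attention_heads, past_decode_sequence_length, head_size
)
cross_attention_past = torch.rand(
    batch_size, num_attention_heads, past_decode_sequence_length, head_size
)
print(self_attention_past.shape)   # torch.Size([2, 8, 6, 64])
print(cross_attention_past.shape)  # torch.Size([2, 8, 6, 64])
```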
onnxruntime/python/tools/transformers/models/whisper/whisper_encoder.py
@@ -25,12 +25,15 @@
class WhisperEncoder(torch.nn.Module):
"""Whisper encoder outputs only the last hidden state"""

def __init__(self, encoder, config: WhisperConfig):
def __init__(self, encoder, config: WhisperConfig, model_impl: str = "hf"):
super().__init__()
self.encoder = encoder
self.config = config
self.model_impl = model_impl

def forward(self, input_features):
if self.model_impl == "openai":
return self.encoder(input_features)
return self.encoder.model.encoder(input_features)[0]


@@ -40,7 +43,11 @@ def __init__(self, input_features):

@staticmethod
def create_dummy(
batch_size: int, sequence_length: int, feature_size: int, device: torch.device, use_int32_inputs: bool
batch_size: int,
sequence_length: int,
feature_size: int,
device: torch.device,
use_int32_inputs: bool = False,
):
"""Create dummy inputs for Whisper encoder.
@@ -61,9 +68,9 @@ def create_dummy(
return WhisperEncoderInputs(input_features)

def to_list(self) -> List:
if self.input_features is None:
if self.input_ids is None:
return []
return [self.input_features]
return [self.input_ids]


class WhisperEncoderHelper:
@@ -74,6 +81,7 @@ def export_onnx(
onnx_model_path: str,
verbose: bool = True,
use_external_data_format: bool = False,
use_int32_inputs: bool = False,
):
"""Export encoder to ONNX
@@ -90,6 +98,7 @@
sequence_length=3000,
feature_size=config.num_mel_bins,
device=device,
use_int32_inputs=use_int32_inputs,
)

Path(onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
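With `model_impl="openai"` the wrapper calls the OpenAI encoder directly; the default `hf` path still routes through `self.encoder.model.encoder(...)[0]`. A minimal sketch of the `hf` path, assuming this module is importable as `whisper_encoder` and using a small checkpoint:

```python
# Sketch of the default "hf" path; model size and shapes are for illustration.
import torch
from transformers import WhisperForConditionalGeneration
from whisper_encoder import WhisperEncoder

hf_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
encoder = WhisperEncoder(hf_model, hf_model.config)  # model_impl defaults to "hf"
encoder.eval()

# Whisper log-mel input: (batch, num_mel_bins, 3000 frames)
mel = torch.randn(2, hf_model.config.num_mel_bins, 3000)
with torch.no_grad():
    hidden = encoder(mel)
print(hidden.shape)  # (batch, encoder frames, d_model)
```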
onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py
@@ -4,6 +4,7 @@
# license information.
# --------------------------------------------------------------------------

import copy
import logging
import os
import tempfile
@@ -19,6 +20,7 @@
from transformers import WhisperConfig
from whisper_decoder import WhisperDecoderInit
from whisper_encoder import WhisperEncoder, WhisperEncoderInputs
from whisper_openai_helper import WhisperDecoderInitOpenai

from onnxruntime import InferenceSession

@@ -34,11 +36,16 @@ def __init__(
decoder: torch.nn.Module,
config: WhisperConfig,
decoder_start_token_id: Optional[int] = None,
model_impl: str = "hf",
model: torch.nn.Module = None,
):
super().__init__()
self.config = config
self.whisper_encoder = WhisperEncoder(encoder, config)
self.whisper_encoder = WhisperEncoder(encoder, config, model_impl=model_impl)
self.whisper_decoder_init = WhisperDecoderInit(decoder, config, decoder_start_token_id)
if model is not None:
self.whisper_decoder_openai_init = WhisperDecoderInitOpenai(model, decoder)
self.model_impl = model_impl

def forward(
self,
@@ -47,9 +54,14 @@ def forward(
):
encoder_hidden_states: torch.FloatTensor = self.whisper_encoder(encoder_input_ids)
# Decoder out: (logits, past_key_values, encoder_hidden_state)
decinit_out = self.whisper_decoder_init(decoder_input_ids, encoder_hidden_states)
present_self, present_cross = PastKeyValuesHelper.group_by_self_and_cross(decinit_out[1])
present = present_self + present_cross
if self.model_impl == "openai":
encoder_hidden_states.unsqueeze(0)
decinit_out, present = self.whisper_decoder_openai_init(decoder_input_ids, encoder_hidden_states)
return decinit_out, encoder_hidden_states, present
else:
decinit_out = self.whisper_decoder_init(decoder_input_ids, encoder_hidden_states)
present_self, present_cross = PastKeyValuesHelper.group_by_self_and_cross(decinit_out[1])
present = present_self + present_cross
return decinit_out[0], encoder_hidden_states, present


@@ -72,7 +84,6 @@ def create_dummy(
sequence_length=3000,
feature_size=config.num_mel_bins,
device=device,
use_int32_inputs=use_int32_inputs,
)
decoder_input_ids = None
if use_decoder_input_ids:
@@ -120,7 +131,9 @@ def export_onnx(
)
input_list = inputs.to_list()

out = model(inputs.encoder_input_ids, inputs.decoder_input_ids)
# TODO: Investigate whether a copy of the model is needed
cloned_model = copy.deepcopy(model).to(device)
out = cloned_model(inputs.encoder_input_ids, inputs.decoder_input_ids)
present = out[2]
present_names = PastKeyValuesHelper.get_input_names(present, encoder=True)

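For the `openai` implementation, the merged module receives the OpenAI encoder and decoder plus the full OpenAI model (needed for the KV-cache hooks), while the Hugging Face config still supplies shapes and special tokens. A construction sketch, assuming the `openai-whisper` package is installed and this module is importable as `whisper_encoder_decoder_init`:

```python
# Sketch only; argument wiring mirrors WhisperHelper.load_model below.
import torch
import whisper  # openai-whisper package
from transformers import WhisperConfig
from whisper_encoder_decoder_init import WhisperEncoderDecoderInit

oai = whisper.load_model("base", device="cpu")
config = WhisperConfig.from_pretrained("openai/whisper-base")

enc_dec_init = WhisperEncoderDecoderInit(
    oai.encoder,
    oai.decoder,
    config,
    decoder_start_token_id=None,
    model_impl="openai",
    model=oai,
).eval()

mel = torch.randn(1, config.num_mel_bins, 3000)           # log-mel features
tokens = torch.tensor([[config.decoder_start_token_id]])  # <|startoftranscript|>
with torch.no_grad():
    logits, encoder_hidden_states, present = enc_dec_init(mel, tokens)
```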
onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py
@@ -72,9 +72,49 @@ def get_onnx_path(
directory = os.path.join(output_dir, model_name) if new_folder else output_dir
return os.path.join(directory, model_name + ".onnx")

@staticmethod
def load_model_openai(
model_name_or_path: str,
cache_dir: str,
device: torch.device,
) -> torch.nn.Module:
"""Load model given a pretrained name or path, then build models for ONNX conversion.
Args:
model_name_or_path (str): pretrained model name or path
cache_dir (str): cache directory
device (torch.device): device to run the model
merge_encoder_and_decoder_init (bool, optional): Whether merge encoder and decoder initialization into one ONNX model. Defaults to True.
Returns:
Dict[str, torch.nn.Module]: mapping from name to modules for ONNX conversion.
"""
from whisper import _ALIGNMENT_HEADS, _MODELS, _download
from whisper.model import ModelDimensions, Whisper

in_memory = False

model_name = model_name_or_path.split("/")[-1][8:]
checkpoint_file, alignment_heads = None, None
if model_name in _MODELS:
checkpoint_file = _download(_MODELS[model_name], cache_dir, in_memory)
alignment_heads = _ALIGNMENT_HEADS[model_name]

with open(checkpoint_file, "rb") as fp:
checkpoint = torch.load(fp, map_location=device)
del checkpoint_file

dims = ModelDimensions(**checkpoint["dims"])
model = Whisper(dims)
model.load_state_dict(checkpoint["model_state_dict"])

if alignment_heads is not None:
model.set_alignment_heads(alignment_heads)
return model.to(device)

@staticmethod
def load_model(
model_name_or_path: str,
model_impl: str,
cache_dir: str,
device: torch.device,
merge_encoder_and_decoder_init: bool = True,
@@ -94,18 +134,29 @@ def load_model(
if version.parse(transformers_version) >= version.parse("4.36.0"):
extra_kwargs["attn_implementation"] = "eager"
model = WhisperForConditionalGeneration.from_pretrained(model_name_or_path, cache_dir=cache_dir, **extra_kwargs)

if model_impl == "openai":
openai_model = WhisperHelper.load_model_openai(model_name_or_path, cache_dir, device)
model_encoder, model_decoder = openai_model.encoder, openai_model.decoder
passed_model = openai_model
else:
model_encoder, model_decoder = model, model
passed_model = None

if state_dict_path:
model.load_state_dict(torch.load(state_dict_path), strict=False)

decoder = WhisperDecoder(model, model.config)
decoder = WhisperDecoder(model_decoder, model.config, model_impl=model_impl, model=passed_model)
decoder.eval().to(device)

if merge_encoder_and_decoder_init:
encoder_decoder_init = WhisperEncoderDecoderInit(
model,
model,
model_encoder,
model_decoder,
model.config,
decoder_start_token_id=None,
model_impl=model_impl,
model=passed_model,
)
return {"encoder_decoder_init": encoder_decoder_init, "decoder": decoder}
else:
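A sketch of calling the updated `load_model` entry point directly; the argument order follows the call site in `convert_to_onnx.py` above, and the module name `whisper_helper` is assumed from the imports in this directory:

```python
# Sketch: loading the OpenAI-implementation modules for export.
import torch
from whisper_helper import WhisperHelper

device = torch.device("cpu")
models = WhisperHelper.load_model(
    "openai/whisper-base",  # model_name_or_path
    "openai",               # model_impl ("hf" or "openai")
    "./cache",              # cache_dir
    device,
    merge_encoder_and_decoder_init=True,
)
# With merge_encoder_and_decoder_init=True this returns
# {"encoder_decoder_init": ..., "decoder": ...}
print(list(models.keys()))
```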
onnxruntime/python/tools/transformers/models/whisper/whisper_openai_helper.py (new file)
@@ -0,0 +1,76 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------

import logging

import torch

logger = logging.getLogger(__name__)


class WhisperDecoderInitOpenai(torch.nn.Module):
"""WhisperDecoderInit for Openai."""

def __init__(
self,
model: torch.nn.Module,
decoder: torch.nn.Module,
):
super().__init__()
self.whisper_model = model
self.whisper_decoder = decoder
self.kv_cache = {}

@torch.no_grad()
def forward(
self,
tokens,
audio_features,
past=None,
):
# Create a kv_cache for past_values
past_kv_cache = dict()
if past is not None:
# Convert past values from 4D to 3D
past = [torch.transpose(val, 1, 2) for val in past]
past = [val.reshape(val.shape[:2] + (-1,)) for val in past]
half_idx = len(past) // 2
for idx, block in enumerate(self.whisper_decoder.blocks):
past_kv_cache[block.attn.key] = past[2 * idx]
past_kv_cache[block.attn.value] = past[2 * idx + 1]
past_kv_cache[block.cross_attn.key] = past[2 * idx + half_idx]
past_kv_cache[block.cross_attn.value] = past[2 * idx + half_idx + 1]

if not self.kv_cache:
self.kv_cache, _ = self.whisper_model.install_kv_cache_hooks()

logits = self.whisper_decoder(tokens, audio_features, kv_cache=past_kv_cache)

# Add concat node for past values
if past is not None:
for block in self.whisper_decoder.blocks:
self.kv_cache[block.attn.key] = torch.cat(
[past_kv_cache[block.attn.key], self.kv_cache[block.attn.key]], dim=1
).detach()
self.kv_cache[block.attn.value] = torch.cat(
[past_kv_cache[block.attn.value], self.kv_cache[block.attn.value]], dim=1
).detach()

present_self, present_cross = [], []
# Group self and cross values
for block in self.whisper_decoder.blocks:
present_self.append(self.kv_cache[block.attn.key])
present_self.append(self.kv_cache[block.attn.value])
if past is None:
present_cross.append(self.kv_cache[block.cross_attn.key])
present_cross.append(self.kv_cache[block.cross_attn.value])

present_self = present_self + present_cross
# Add reshape and transpose ops to convert from 3D to 4D
present_self = [
present_val.reshape(present_val.shape[:2] + (-1, 64)).transpose(1, 2) for present_val in present_self
]
return logits, present_self
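The wrapper relies on `install_kv_cache_hooks()` from the OpenAI `Whisper` model to capture self- and cross-attention key/value tensors as the decoder runs, then reshapes them into the 4D layout the ONNX graph expects. A minimal sketch of the no-past (decoder-init) path, assuming the `openai-whisper` package is installed:

```python
# Sketch: one decoder-init step through the OpenAI wrapper.
import torch
import whisper  # openai-whisper package
from whisper_openai_helper import WhisperDecoderInitOpenai

oai = whisper.load_model("tiny", device="cpu")
decoder_init = WhisperDecoderInitOpenai(oai, oai.decoder)

mel = torch.randn(1, oai.dims.n_mels, 3000)
with torch.no_grad():
    audio_features = oai.encoder(mel)            # (1, 1500, d_model)
tokens = torch.tensor([[oai.dims.n_vocab - 1]])  # any valid token id

logits, present = decoder_init(tokens, audio_features)  # past=None -> init path
# present holds self-attention then cross-attention K/V tensors, reshaped to
# (batch, num_heads, seq_len, 64)
print(logits.shape, len(present))
```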