diff --git a/onnxruntime/python/tools/transformers/fusion_bart_attention_openai.py b/onnxruntime/python/tools/transformers/fusion_bart_attention_openai.py
index e738de0b6e9e3..812cdd1d70b9f 100644
--- a/onnxruntime/python/tools/transformers/fusion_bart_attention_openai.py
+++ b/onnxruntime/python/tools/transformers/fusion_bart_attention_openai.py
@@ -127,8 +127,8 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
                 root_input = output
                 break
 
-        graph_input_names = set([node.name for node in self.model.graph().input])
-        graph_output_names = set([node.name for node in self.model.graph().output])
+        graph_input_names = set([node.name for node in self.model.graph().input])
+        graph_output_names = set([node.name for node in self.model.graph().output])
 
         v_nodes = self.model.match_parent_path(
             matmul_qkv,
@@ -152,13 +152,10 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
         if v_nodes is not None:
             (transpose_v, reshape_v_1, add_v, matmul_v) = v_nodes
             # For initial pass through encoder-decoder_with_past to get starting past values (beam search)
-            #present_v = add_v.output[0]
             add_v_children = self.model.get_children(add_v)
             for child in add_v_children:
                 if child.op_type == "Reshape":
-                    #if child.output[0] in graph_output_names:
-                    #present_v = child.output[0]
                     reshape_v_children = self.model.get_children(child)
                     for reshape_child in reshape_v_children:
                         if reshape_child.op_type == "Transpose":
@@ -205,9 +202,7 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
         present_v = present_v if present_v in graph_output_names else ""
 
         qk_nodes_1 = self.model.match_parent_path(matmul_qkv, ["Softmax", "MatMul"], [0, 0])
-        qk_nodes_2 = self.model.match_parent_path(
-            matmul_qkv, ["Softmax", "Add", "MatMul"], [0, 0, 0]
-        )
+        qk_nodes_2 = self.model.match_parent_path(matmul_qkv, ["Softmax", "Add", "MatMul"], [0, 0, 0])
         if qk_nodes_1 is not None:
             _, matmul_qk = qk_nodes_1
             qk_nodes = qk_nodes_1
@@ -256,13 +251,10 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
             k_nodes = k_nodes_with_bias
             present_k = matmul_k.output[0]
             mat_k_out_tmp = matmul_k.output[0] + "_temp"
-            #matmul_k.output[0] = matmul_k.output[0] + "_temp"
             matmul_k_children = self.model.get_children(matmul_k)
             for child in matmul_k_children:
                 if child.op_type == "Reshape":
-                    #if child.output[0] in graph_output_names:
-                    # present_k = child.output[0]
                     reshape_k_children = self.model.get_children(child)
                     for reshape_child in reshape_k_children:
                         if reshape_child.op_type == "Transpose":
@@ -285,8 +277,6 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
                             if reshape_parent.op_type == "Transpose":
                                 if reshape_parent.input[0] in graph_input_names:
                                     past_k = reshape_parent.input[0]
-                                #else:
-                                #    matmul_k.output[0] = mat_k_out_tmp
 
         elif k_nodes_no_bias is not None:
@@ -328,7 +318,7 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
             add_name = self.model.create_node_name("Add")
             add_k = helper.make_node("Add", [empty_bias_name, matmul_k.output[0]], [reshape_k_1.name], add_name)
 
-        '''
+        """
         if not past_k and not self.check_runtime_shape_path(
             reshape_qkv_2,
             reshape_qkv_1,
@@ -338,7 +328,7 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
             root_input,
         ):
             return
-        '''
+        """
 
         three_root_inputs = past_k and past_v and matmul_k is None and "matmul_v" not in locals()
         one_root_input = (
@@ -381,7 +371,7 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
         )
         if mask_nodes_whisper is not None:
             pass
-            #mask_index = mask_nodes_whisper[0].output[-1]
+            # mask_index = mask_nodes_whisper[0].output[-1]
         elif mask_nodes_bart is not None:
             mask_index = mask_nodes_bart[0].output[-1]
diff --git a/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py
index bda6bbcf3398a..31ab1a4b4d252 100644
--- a/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py
+++ b/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py
@@ -331,12 +331,8 @@ def export_onnx_models(
     device = torch.device("cuda:0" if use_gpu else "cpu")
 
     models = WhisperHelper.load_model(
-        model_name_or_path,
-        model_impl,
-        cache_dir,
-        device,
-        merge_encoder_and_decoder_init,
-        state_dict_path
+        model_name_or_path, model_impl, cache_dir, device,
+        merge_encoder_and_decoder_init, state_dict_path
     )
     config = models["decoder"].config
diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_decoder.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_decoder.py
index 5ecdedaab0252..2590abd4dcdfd 100644
--- a/onnxruntime/python/tools/transformers/models/whisper/whisper_decoder.py
+++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_decoder.py
@@ -68,7 +68,7 @@ def forward(
 class WhisperDecoder(torch.nn.Module):
     """A Whisper decoder with past key values"""
 
-    def __init__(self, decoder, config, model_impl: str = 'hf', model=None):
+    def __init__(self, decoder, config, model_impl: str = "hf", model=None):
         super().__init__()
         self.decoder = decoder
         self.config = config
@@ -83,9 +83,11 @@ def forward(self, decoder_input_ids, *past):
         encoder_outputs["hidden_states"] = dummy_encoder_hidden_states
         encoder_outputs["attentions"] = None
 
-        if self.model_impl == 'openai':
+        if self.model_impl == "openai":
             dummy_encoder_hidden_states.unsqueeze(0)
-            dec_out, present = self.whisper_decoder_openai_init(decoder_input_ids, dummy_encoder_hidden_states, past=past)
+            dec_out, present = self.whisper_decoder_openai_init(
+                decoder_input_ids, dummy_encoder_hidden_states, past=past
+            )
             return dec_out, present
 
         if len(past) == 0:
diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder.py
index 6c8e0d21a2a0b..93281848a5c9c 100644
--- a/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder.py
+++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder.py
@@ -25,7 +25,7 @@ class WhisperEncoder(torch.nn.Module):
     """Whisper encoder outputs only the last hidden state"""
 
-    def __init__(self, encoder, config: WhisperConfig, model_impl: str = 'hf'):
+    def __init__(self, encoder, config: WhisperConfig, model_impl: str = "hf"):
         super().__init__()
         self.encoder = encoder
         self.config = config
diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py
index 6886a18726a0e..296f44d3d66c8 100644
--- a/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py
+++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py
@@ -4,10 +4,10 @@
 # license information.
 # --------------------------------------------------------------------------
 
+import copy
 import logging
 import os
 import tempfile
-import copy
 from pathlib import Path
 from typing import List, Optional
 
@@ -36,7 +36,7 @@ def __init__(
         decoder: torch.nn.Module,
         config: WhisperConfig,
         decoder_start_token_id: Optional[int] = None,
-        model_impl: str = 'hf',
+        model_impl: str = "hf",
         model: torch.nn.Module = None,
     ):
         super().__init__()
@@ -55,7 +55,7 @@ def forward(
     ):
         encoder_hidden_states: torch.FloatTensor = self.whisper_encoder(encoder_input_ids)
         # Decoder out: (logits, past_key_values, encoder_hidden_state)
-        if self.model_impl == 'openai':
+        if self.model_impl == "openai":
             encoder_hidden_states.unsqueeze(0)
             decinit_out, present = self.whisper_decoder_openai_init(decoder_input_ids, encoder_hidden_states)
             return decinit_out, encoder_hidden_states, present
diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py
index 853af21599c43..3955d095ec632 100644
--- a/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py
+++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py
@@ -4,9 +4,9 @@
 # license information.
 # --------------------------------------------------------------------------
 
+import io
 import logging
 import os
-import io
 import sys
 from pathlib import Path
 from typing import Dict, Tuple, Union
@@ -14,14 +14,12 @@
 import numpy as np
 import torch
 from transformers import WhisperConfig, WhisperForConditionalGeneration, WhisperProcessor
+from whisper import _MODELS, _ALIGNMENT_HEADS, _download
+from whisper.model import Whisper, ModelDimensions
 from whisper_decoder import WhisperDecoder, WhisperDecoderHelper, WhisperDecoderInit
 from whisper_encoder import WhisperEncoder, WhisperEncoderHelper
 from whisper_encoder_decoder_init import WhisperEncoderDecoderInit, WhisperEncoderDecoderInitHelper
-from whisper.model import Whisper, ModelDimensions
-from whisper import _MODELS, _ALIGNMENT_HEADS
-from whisper import _download
-
 from onnxruntime import InferenceSession
 
 sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))
@@ -94,16 +92,14 @@ def load_model_openai(
 
         in_memory = False
 
-        model_name = model_name_or_path.split('/')[-1][8:]
+        model_name = model_name_or_path.split("/")[-1][8:]
         checkpoint_file = None
         if model_name in _MODELS:
             checkpoint_file = _download(_MODELS[model_name], cache_dir, in_memory)
             alignment_heads = _ALIGNMENT_HEADS[model_name]
 
-        with (
-            io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")
-        ) as fp:
+        with io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb") as fp:
             checkpoint = torch.load(fp, map_location=device)
         del checkpoint_file
diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_openai_helper.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_openai_helper.py
index 9606c920b996e..2be5a327e15e5 100644
--- a/onnxruntime/python/tools/transformers/models/whisper/whisper_openai_helper.py
+++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_openai_helper.py
@@ -26,6 +26,7 @@ class WhisperDecoderInitOpenai(torch.nn.Module):
     """WhisperDecoderInit for Openai."""
+
     def __init__(
         self,
         model: torch.nn.Module,
@@ -43,13 +44,12 @@ def forward(
         audio_features,
         past=None,
     ):
-
         # Create a kv_cache for past_values
         past_kv_cache = dict()
         if past is not None:
             # Convert past values from 4D to 3D
             past = [torch.transpose(val, 1, 2) for val in past]
-            past = [val.reshape(val.shape[:2] + (-1, )) for val in past]
+            past = [val.reshape(val.shape[:2] + (-1,)) for val in past]
             half_idx = len(past) // 2
             for idx, block in enumerate(self.whisper_decoder.blocks):
                 past_kv_cache[block.attn.key] = past[2 * idx]
@@ -65,8 +65,12 @@ def forward(
         # Add concat node for past values
         if past is not None:
             for idx, block in enumerate(self.whisper_decoder.blocks):
-                self.kv_cache[block.attn.key] = torch.cat([past_kv_cache[block.attn.key], self.kv_cache[block.attn.key]], dim=1).detach()
-                self.kv_cache[block.attn.value] = torch.cat([past_kv_cache[block.attn.value], self.kv_cache[block.attn.value]], dim=1).detach()
+                self.kv_cache[block.attn.key] = torch.cat(
+                    [past_kv_cache[block.attn.key], self.kv_cache[block.attn.key]], dim=1
+                ).detach()
+                self.kv_cache[block.attn.value] = torch.cat(
+                    [past_kv_cache[block.attn.value], self.kv_cache[block.attn.value]], dim=1
+                ).detach()
 
         present_self, present_cross = [], []
         # Group self and cross values
@@ -79,7 +83,7 @@ def forward(
         present_self = present_self + present_cross
         # Add reshape and transpose ops to convert from 3D to 4D
-        present_self = [present_val.reshape(
-            present_val.shape[:2] + (-1, 64)
-        ).transpose(1, 2) for present_val in present_self]
+        present_self = [
+            present_val.reshape(present_val.shape[:2] + (-1, 64)).transpose(1, 2) for present_val in present_self
+        ]
         return logits, present_self
diff --git a/onnxruntime/python/tools/transformers/onnx_model_bart.py b/onnxruntime/python/tools/transformers/onnx_model_bart.py
index 1ef6a4329cb28..de0a418ae4daa 100644
--- a/onnxruntime/python/tools/transformers/onnx_model_bart.py
+++ b/onnxruntime/python/tools/transformers/onnx_model_bart.py
@@ -127,10 +127,7 @@ def __init__(self, model, num_heads, hidden_size, model_impl="hf"):
         self.attention_mask = AttentionMask(self)
         if model_impl == "openai":
            self.attention_fusion = FusionBartAttentionOpenai(
-                self,
-                self.hidden_size,
-                self.num_heads,
-                self.attention_mask
+                self, self.hidden_size, self.num_heads, self.attention_mask
             )
         else:
             self.attention_fusion = FusionBartAttention(self, self.hidden_size, self.num_heads, self.attention_mask)
diff --git a/onnxruntime/python/tools/transformers/onnx_model_bert.py b/onnxruntime/python/tools/transformers/onnx_model_bert.py
index 9710777d5c105..defb6a854737d 100644
--- a/onnxruntime/python/tools/transformers/onnx_model_bert.py
+++ b/onnxruntime/python/tools/transformers/onnx_model_bert.py
@@ -414,7 +414,11 @@ def optimize(self, options: Optional[FusionOptions] = None, add_dynamic_axes: bo
             if options is not None:
                 self.attention_mask.set_mask_format(options.attention_mask_format)
-                if options.use_multi_head_attention and not isinstance(self.attention_fusion, FusionBartAttention) and not isinstance(self.attention_fusion, FusionBartAttentionOpenai):
+                if (
+                    options.use_multi_head_attention
+                    and not isinstance(self.attention_fusion, FusionBartAttention)
+                    and not isinstance(self.attention_fusion, FusionBartAttentionOpenai)
+                ):
                     self.attention_fusion = FusionAttention(
                         self, self.hidden_size, self.num_heads, self.attention_mask, options.use_multi_head_attention
                     )
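
Editor's note: the reformatted reshape/transpose lines in whisper_openai_helper.py convert the ONNX-style 4D past key/value tensors (batch, num_heads, seq_len, head_dim) into the 3D (batch, seq_len, num_heads * head_dim) layout used by the OpenAI Whisper decoder's kv_cache, and back again for the returned present values (with head_dim hardcoded to 64 in that file). Below is a minimal standalone sketch of that round trip, not part of the diff; the shapes are illustrative assumptions.

import torch

# Illustrative shapes only; the helper takes them from the model, and head_dim of 64 is hardcoded there.
batch, num_heads, seq_len, head_dim = 1, 6, 3, 64
past_4d = torch.randn(batch, num_heads, seq_len, head_dim)

# 4D -> 3D, as done before seeding past_kv_cache:
val = torch.transpose(past_4d, 1, 2)          # (batch, seq_len, num_heads, head_dim)
past_3d = val.reshape(val.shape[:2] + (-1,))  # (batch, seq_len, num_heads * head_dim)

# 3D -> 4D, as done for the returned present values:
present_4d = past_3d.reshape(past_3d.shape[:2] + (-1, 64)).transpose(1, 2)

assert torch.equal(present_4d, past_4d)  # the layout conversion round-trips losslessly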