diff --git a/onnxruntime/python/tools/transformers/fusion_bart_attention.py b/onnxruntime/python/tools/transformers/fusion_bart_attention.py index 71801401e9d06..ebecc1db24792 100644 --- a/onnxruntime/python/tools/transformers/fusion_bart_attention.py +++ b/onnxruntime/python/tools/transformers/fusion_bart_attention.py @@ -74,13 +74,74 @@ def check_runtime_shape_path( return True + def check_runtime_shape_path_openai( + self, + reshape_qkv_2, + matmul_qkv, + add_qk, + matmul_qk, + add_q, + ): + reshape_qkv_2_path = self.model.match_parent_path( + reshape_qkv_2, ["Concat", "Slice", "Gather", "Shape"], [1, 0, 0, 0] + ) + if reshape_qkv_2_path is None: + return False + else: + if reshape_qkv_2_path[-1].input[0] != matmul_qkv.output[0]: + return False + + matmul_qk_path_1 = self.model.match_parent_path( + matmul_qk, ["Mul", "Pow", "Cast", "Div", "Gather", "Shape"], [0, 1, 0, 0, 0, 0] + ) + matmul_qk_path_2 = self.model.match_parent_path( + matmul_qk, ["Mul", "Pow", "Cast", "Div", "Gather", "Shape"], [1, 1, 0, 0, 0, 0] + ) + if matmul_qk_path_1 is None or matmul_qk_path_2 is None: + return False + + mul_1 = matmul_qk_path_1[0] + mul_2 = matmul_qk_path_2[0] + if mul_1.input[1] != mul_2.input[1]: + return False + if matmul_qk_path_1[-1].input[0] != add_q.output[0] and matmul_qk_path_2[-1].input[0] != add_q.output[0]: + return False + + # For decoder attentions only + if add_qk is not None: + add_qk_path = self.model.match_parent_path(add_qk, ["Slice"], [1]) + if add_qk_path is None: + return False + slice_q_path_1 = self.model.match_parent_path( + add_qk_path[0], ["Slice", "Unsqueeze", "Gather", "Shape"], [0, 2, 0, 0] + ) + slice_q_path_2 = self.model.match_parent_path(add_qk_path[0], ["Unsqueeze", "Gather", "Shape"], [2, 0, 0]) + if slice_q_path_1 is None and slice_q_path_2 is None: + return False + _, unsqueeze_1, _, _ = slice_q_path_1 + unsqueeze_2, _, _ = slice_q_path_2 + if unsqueeze_1.input[0] != unsqueeze_2.input[0]: + return False + if slice_q_path_1[-1].input[0] != add_q.output[0] and slice_q_path_2[-1].input[0] != add_q.output[0]: + return False + + return True + def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): + # Track if fusion is occurring for OpenAI implementation of Whisper + model_impl_openai = False + # SkipLayerNormalization has two inputs, and one of them is the root input for attention. 
qkv_nodes = self.model.match_parent_path( normalize_node, ["Add", "MatMul", "Reshape", "Transpose", "Reshape", "MatMul"], [1, 1, 0, 0, 0, 0], ) + qkv_nodes_openai = self.model.match_parent_path( + normalize_node, + ["Add", "MatMul", "Reshape", "Transpose", "MatMul"], + [1, 1, 0, 0, 0], + ) if qkv_nodes is not None: ( add_out, @@ -90,6 +151,17 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): reshape_qkv_1, matmul_qkv, ) = qkv_nodes + elif qkv_nodes_openai is not None: + qkv_nodes = qkv_nodes_openai + ( + add_out, + matmul_out, + reshape_qkv_2, + transpose_qkv, + matmul_qkv, + ) = qkv_nodes + # Set model implementation to openai + model_impl_openai = True else: return @@ -137,6 +209,11 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): ["Reshape", "Transpose", "Reshape", "Add", "MatMul"], [1, 0, 0, 0, None], ) + v_nodes_openai = self.model.match_parent_path( + matmul_qkv, + ["Transpose", "Reshape", "Add", "MatMul"], + [1, 0, 0, None], + ) v_nodes_with_past_self_attn = self.model.match_parent_path( # Decoder attention with past value concatenated before MatMul matmul_qkv, @@ -149,12 +226,52 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): ["Reshape"], [1], ) + v_nodes_with_past_cross_attn_openai = self.model.match_parent_path( + matmul_qkv, + ["Transpose", "Reshape", "Reshape", "Transpose"], + [1, 0, 0, 0], + ) past_v, present_v = "", "" reshape_v_2, add_v = None, None if v_nodes is not None: (reshape_v_2, transpose_v, reshape_v_1, add_v, matmul_v) = v_nodes # For initial pass through encoder-decoder_with_past to get starting past values (beam search) present_v = transpose_v.output[0] + elif v_nodes_openai is not None: + v_nodes = v_nodes_openai + (transpose_v, reshape_v_1, add_v, matmul_v) = v_nodes + # For initial pass through encoder-decoder_with_past to get starting past values (beam search) + + # Find the child path to access the correct present_v values + # Openai impl provides present/past v values in 3D format + # whereas ort MultiHeadAttention expects v values in 4D, hence the + # additional Reshape and Transpose nodes are added + # For encoder attention types + # Add -> Reshape -> Transpose -> Present_V + reshape_path = self.model.match_child_path( + add_v, + ["Reshape", "Transpose"], + exclude=[reshape_v_1], + ) + # For decoder attention types + # add_v_node Reshape <- Transpose <-Past_V + # \ / + # \ / + # -> Concat <- + # | + # |--> Reshape -> Transpose -> Present_V + concat_path = self.model.match_child_path(add_v, ["Concat", "Reshape", "Transpose"]) + if reshape_path is not None: + (_, transpose_add_v) = reshape_path + if transpose_add_v.output[0] in graph_output_names: + present_v = transpose_add_v.output[0] + if concat_path is not None: + (concat_v, _, transpose_concat_v) = concat_path + if transpose_concat_v.output[0] in graph_output_names: + present_v = transpose_concat_v.output[0] + concat_nodes = self.model.match_parent_path(concat_v, ["Reshape", "Transpose"], [0, 0]) + _, transpose_concat_v_in = concat_nodes + past_v = transpose_concat_v_in.input[0] elif v_nodes_with_past_self_attn is not None: (reshape_v_2, concat_v, transpose_v, reshape_v_1, add_v, matmul_v) = v_nodes_with_past_self_attn v_nodes = v_nodes_with_past_self_attn @@ -171,6 +288,18 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): filter(lambda node: node.op_type == "Identity", self.model.input_name_to_nodes()[past_v]) ) present_v = identity_node_v[0].output[0] if len(identity_node_v) == 1 else "" + 
elif ( + v_nodes_with_past_cross_attn_openai is not None + and v_nodes_with_past_cross_attn_openai[-1].input[0] in graph_input_names + ): + v_nodes = v_nodes_with_past_cross_attn_openai + past_v = v_nodes[-1].input[0] + present_v = v_nodes[-1].output[0] + if present_v not in graph_output_names: + identity_node_v = list( + filter(lambda node: node.op_type == "Identity", self.model.input_name_to_nodes()[past_v]) + ) + present_v = identity_node_v[0].output[0] if len(identity_node_v) == 1 else "" else: logger.debug("fuse_attention: failed to match v path") return @@ -181,12 +310,17 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): qk_nodes_2 = self.model.match_parent_path( matmul_qkv, ["Softmax", "Reshape", "Add", "Reshape", "MatMul"], [0, 0, 0, 0, 0] ) + qk_nodes_2_openai = self.model.match_parent_path(matmul_qkv, ["Softmax", "Add", "MatMul"], [0, 0, 0]) + add_qk = None if qk_nodes_1 is not None: _, matmul_qk = qk_nodes_1 qk_nodes = qk_nodes_1 elif qk_nodes_2 is not None: _, _, add_qk, _, matmul_qk = qk_nodes_2 qk_nodes = qk_nodes_2 + elif qk_nodes_2_openai is not None: + _, add_qk, matmul_qk = qk_nodes_2_openai + qk_nodes = qk_nodes_2_openai else: return @@ -195,8 +329,17 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): ["Reshape", "Transpose", "Reshape", "Mul", "Add", "MatMul"], [0, 0, 0, 0, 0, 1], ) + q_nodes_openai = self.model.match_parent_path( + matmul_qk, + ["Mul", "Transpose", "Reshape", "Add", "MatMul"], + [0, 0, 0, 0, 1], + ) + reshape_q_2 = None if q_nodes is not None: reshape_q_2, transpose_q, reshape_q_1, mul_q, add_q, matmul_q = q_nodes + elif q_nodes_openai is not None: + q_nodes = q_nodes_openai + mul_q, transpose_q, reshape_q_1, add_q, matmul_q = q_nodes else: return @@ -205,6 +348,11 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): ["Transpose", "Reshape", "Transpose", "Reshape", "Add", "MatMul"], [1, 0, 0, 0, 0, 1], ) + k_nodes_with_bias_openai = self.model.match_parent_path( + matmul_qk, + ["Mul", "Transpose", "Reshape", "MatMul"], + [1, 0, 0, 0], + ) k_nodes_no_bias = self.model.match_parent_path( matmul_qk, ["Transpose", "Reshape", "Transpose", "Reshape", "MatMul"], @@ -222,11 +370,52 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): ["Transpose", "Reshape"], [1, 0], ) + k_nodes_no_bias_with_past_cross_attn_openai = self.model.match_parent_path( + # Decoder attention with past key directly used in MatMul + matmul_qk, + ["Mul", "Transpose", "Reshape", "Reshape", "Transpose"], + [1, 0, 0, 0, 0], + ) past_k, present_k = "", "" reshape_k_2, reshape_k_1, matmul_k = None, None, None if k_nodes_with_bias is not None: _, reshape_k_2, transpose_k_1, reshape_k_1, add_k, matmul_k = k_nodes_with_bias k_nodes = k_nodes_with_bias + elif k_nodes_with_bias_openai is not None: + mul_k, transpose_k_1, reshape_k_1, matmul_k = k_nodes_with_bias_openai + k_nodes = k_nodes_with_bias_openai + present_k = matmul_k.output[0] + + # Find the child path to access the correct present_k values + # Openai impl provides present/past k values in 3D format + # whereas ort MultiHeadAttention expects k values in 4D, hence the + # additional Reshape and Transpose nodes are added + # For encoder attention types + # Matmul -> Reshape -> Transpose -> Present_K + reshape_path = self.model.match_child_path( + matmul_k, + ["Reshape", "Transpose"], + exclude=[reshape_k_1], + ) + # For decoder attention types + # matmul_k_node Reshape <- Transpose <- Past_K + # \ / + # \ / + # -> Concat <- + # | + # |--> 
Reshape -> Transpose -> Present_K + concat_path = self.model.match_child_path(matmul_k, ["Concat", "Reshape", "Transpose"]) + if reshape_path is not None: + (_, transpose_matmul_k) = reshape_path + if transpose_matmul_k.output[0] in graph_output_names: + present_k = transpose_matmul_k.output[0] + if concat_path is not None: + (concat_k, _, transpose_concat_k) = concat_path + if transpose_concat_k.output[0] in graph_output_names: + present_k = transpose_concat_k.output[0] + concat_nodes = self.model.match_parent_path(concat_k, ["Reshape", "Transpose"], [0, 0]) + _, transpose_concat_k_in = concat_nodes + past_k = transpose_concat_k_in.input[0] elif k_nodes_no_bias is not None: _, reshape_k_2, transpose_k_1, reshape_k_1, matmul_k = k_nodes_no_bias k_nodes = k_nodes_no_bias @@ -249,12 +438,24 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): filter(lambda node: node.op_type == "Identity", self.model.input_name_to_nodes()[past_k]) ) present_k = identity_node_k[0].output[0] if len(identity_node_k) == 1 else "" + elif ( + k_nodes_no_bias_with_past_cross_attn_openai is not None + and k_nodes_no_bias_with_past_cross_attn_openai[-1].input[0] in graph_input_names + ): + k_nodes = k_nodes_no_bias_with_past_cross_attn_openai + past_k = k_nodes[-1].input[0] + present_k = k_nodes[-1].output[0] + if present_k not in graph_output_names: + identity_node_k = list( + filter(lambda node: node.op_type == "Identity", self.model.input_name_to_nodes()[past_k]) + ) + present_k = identity_node_k[0].output[0] if len(identity_node_k) == 1 else "" else: return past_k = past_k if past_k in graph_input_names else "" present_k = present_k if present_k in graph_output_names else "" - if k_nodes in (k_nodes_no_bias, k_nodes_no_bias_with_past_self_attn): + if k_nodes in (k_nodes_with_bias_openai, k_nodes_no_bias, k_nodes_no_bias_with_past_self_attn): # Create empty Add node for attention graph bias_dim = self.model.get_initializer(add_v.input[0]).dims[0] empty_bias_name = "empty_bias" @@ -270,13 +471,29 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): add_name = self.model.create_node_name("Add") add_k = helper.make_node("Add", [empty_bias_name, matmul_k.output[0]], [reshape_k_1.name], add_name) - if not past_k and not self.check_runtime_shape_path( - reshape_qkv_2, - reshape_qkv_1, - reshape_q_2, - reshape_k_2, - reshape_v_2, - root_input, + if ( + model_impl_openai + and not past_k + and not self.check_runtime_shape_path_openai( + reshape_qkv_2, + matmul_qkv, + add_qk, + matmul_qk, + add_q, + ) + ): + return + elif ( + not model_impl_openai + and not past_k + and not self.check_runtime_shape_path( + reshape_qkv_2, + reshape_qkv_1, + reshape_q_2, + reshape_k_2, + reshape_v_2, + root_input, + ) ): return @@ -301,8 +518,10 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): # 4) Decoder cross attention with two_root_inputs=True and qk_nodes=qk_nodes_1 # 5) Decoder cross attention with past with three_root_inputs=True and qk_nodes=qk_nodes_1 encoder_attention = one_root_input and qk_nodes == qk_nodes_1 - decoder_attention = one_root_input and qk_nodes == qk_nodes_2 - decoder_attention_with_past = encoder_attention and past_k and past_v + decoder_attention = one_root_input and qk_nodes in (qk_nodes_2, qk_nodes_2_openai) + decoder_attention_with_past = ( + (encoder_attention if not model_impl_openai else decoder_attention) and past_k and past_v + ) decoder_cross_attention = two_root_inputs and qk_nodes == qk_nodes_1 decoder_cross_attention_with_past 
= three_root_inputs and qk_nodes == qk_nodes_1 diff --git a/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py index e15a12c07bed7..bb697fe1e1506 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py @@ -38,6 +38,15 @@ def parse_arguments(argv=None): help="Model path, or pretrained model name in the list: " + ", ".join(pretrained_models), ) + parser.add_argument( + "--model_impl", + required=False, + default="hf", + choices=["hf", "openai"], + type=str, + help="Select implementation for export of encoder and decoder subgraphs", + ) + parser.add_argument( "--cache_dir", required=False, @@ -300,6 +309,7 @@ def parse_arguments(argv=None): def export_onnx_models( model_name_or_path, + model_impl, cache_dir, output_dir, use_gpu, @@ -321,7 +331,7 @@ def export_onnx_models( device = torch.device("cuda:0" if use_gpu else "cpu") models = WhisperHelper.load_model( - model_name_or_path, cache_dir, device, merge_encoder_and_decoder_init, state_dict_path + model_name_or_path, model_impl, cache_dir, device, merge_encoder_and_decoder_init, state_dict_path ) config = models["decoder"].config @@ -431,6 +441,7 @@ def main(argv=None): output_paths = export_onnx_models( args.model_name_or_path, + args.model_impl, cache_dir, output_dir, args.use_gpu, diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_decoder.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_decoder.py index eca5ce3de15d3..0d69960a095ac 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_decoder.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_decoder.py @@ -18,6 +18,7 @@ from onnx_model import OnnxModel from torch_onnx_export_helper import torch_onnx_export from transformers import WhisperConfig, file_utils +from whisper_openai_helper import WhisperDecoderInitOpenai from onnxruntime import InferenceSession @@ -67,10 +68,13 @@ def forward( class WhisperDecoder(torch.nn.Module): """A Whisper decoder with past key values""" - def __init__(self, decoder, config): + def __init__(self, decoder, config, model_impl: str = "hf", model: torch.nn.Module = None): super().__init__() self.decoder = decoder self.config = config + self.model_impl = model_impl + if model is not None: + self.whisper_decoder_openai_init = WhisperDecoderInitOpenai(model, decoder) def forward(self, decoder_input_ids, *past): encoder_outputs = file_utils.ModelOutput() @@ -78,6 +82,14 @@ def forward(self, decoder_input_ids, *past): encoder_outputs["last_hidden_state"] = dummy_encoder_hidden_states encoder_outputs["hidden_states"] = dummy_encoder_hidden_states encoder_outputs["attentions"] = None + + if self.model_impl == "openai": + dummy_encoder_hidden_states.unsqueeze(0) + dec_out, present = self.whisper_decoder_openai_init( + decoder_input_ids, dummy_encoder_hidden_states, past=past + ) + return dec_out, present + if len(past) == 0: past_key_values = None else: @@ -158,7 +170,7 @@ def create_dummy( cross_attention_past_shape = [ batch_size, num_attention_heads, - encode_sequence_length, + past_decode_sequence_length, head_size, ] @@ -213,7 +225,7 @@ def export_onnx( decoder.config, batch_size=2, encode_sequence_length=3000, - past_decode_sequence_length=5 if isinstance(decoder, WhisperDecoder) else 0, + past_decode_sequence_length=6 if isinstance(decoder, WhisperDecoder) else 0, device=device, 
use_int32_inputs=use_int32_inputs, ) diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder.py index 826d6e42c0775..93281848a5c9c 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder.py @@ -25,12 +25,15 @@ class WhisperEncoder(torch.nn.Module): """Whisper encoder outputs only the last hidden state""" - def __init__(self, encoder, config: WhisperConfig): + def __init__(self, encoder, config: WhisperConfig, model_impl: str = "hf"): super().__init__() self.encoder = encoder self.config = config + self.model_impl = model_impl def forward(self, input_features): + if self.model_impl == "openai": + return self.encoder(input_features) return self.encoder.model.encoder(input_features)[0] @@ -40,7 +43,11 @@ def __init__(self, input_features): @staticmethod def create_dummy( - batch_size: int, sequence_length: int, feature_size: int, device: torch.device, use_int32_inputs: bool + batch_size: int, + sequence_length: int, + feature_size: int, + device: torch.device, + use_int32_inputs: bool = False, ): """Create dummy inputs for Whisper encoder. @@ -61,9 +68,9 @@ def create_dummy( return WhisperEncoderInputs(input_features) def to_list(self) -> List: - if self.input_features is None: + if self.input_ids is None: return [] - return [self.input_features] + return [self.input_ids] class WhisperEncoderHelper: @@ -74,6 +81,7 @@ def export_onnx( onnx_model_path: str, verbose: bool = True, use_external_data_format: bool = False, + use_int32_inputs: bool = False, ): """Export encoder to ONNX @@ -90,6 +98,7 @@ def export_onnx( sequence_length=3000, feature_size=config.num_mel_bins, device=device, + use_int32_inputs=use_int32_inputs, ) Path(onnx_model_path).parent.mkdir(parents=True, exist_ok=True) diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py index a145178dbf37e..351173f525727 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py @@ -4,6 +4,7 @@ # license information. 
# -------------------------------------------------------------------------- +import copy import logging import os import tempfile @@ -19,6 +20,7 @@ from transformers import WhisperConfig from whisper_decoder import WhisperDecoderInit from whisper_encoder import WhisperEncoder, WhisperEncoderInputs +from whisper_openai_helper import WhisperDecoderInitOpenai from onnxruntime import InferenceSession @@ -34,11 +36,16 @@ def __init__( decoder: torch.nn.Module, config: WhisperConfig, decoder_start_token_id: Optional[int] = None, + model_impl: str = "hf", + model: torch.nn.Module = None, ): super().__init__() self.config = config - self.whisper_encoder = WhisperEncoder(encoder, config) + self.whisper_encoder = WhisperEncoder(encoder, config, model_impl=model_impl) self.whisper_decoder_init = WhisperDecoderInit(decoder, config, decoder_start_token_id) + if model is not None: + self.whisper_decoder_openai_init = WhisperDecoderInitOpenai(model, decoder) + self.model_impl = model_impl def forward( self, @@ -47,9 +54,14 @@ def forward( ): encoder_hidden_states: torch.FloatTensor = self.whisper_encoder(encoder_input_ids) # Decoder out: (logits, past_key_values, encoder_hidden_state) - decinit_out = self.whisper_decoder_init(decoder_input_ids, encoder_hidden_states) - present_self, present_cross = PastKeyValuesHelper.group_by_self_and_cross(decinit_out[1]) - present = present_self + present_cross + if self.model_impl == "openai": + encoder_hidden_states.unsqueeze(0) + decinit_out, present = self.whisper_decoder_openai_init(decoder_input_ids, encoder_hidden_states) + return decinit_out, encoder_hidden_states, present + else: + decinit_out = self.whisper_decoder_init(decoder_input_ids, encoder_hidden_states) + present_self, present_cross = PastKeyValuesHelper.group_by_self_and_cross(decinit_out[1]) + present = present_self + present_cross return decinit_out[0], encoder_hidden_states, present @@ -72,7 +84,6 @@ def create_dummy( sequence_length=3000, feature_size=config.num_mel_bins, device=device, - use_int32_inputs=use_int32_inputs, ) decoder_input_ids = None if use_decoder_input_ids: @@ -120,7 +131,9 @@ def export_onnx( ) input_list = inputs.to_list() - out = model(inputs.encoder_input_ids, inputs.decoder_input_ids) + # TODO : Investigate whether copy of model if needed + cloned_model = copy.deepcopy(model).to(device) + out = cloned_model(inputs.encoder_input_ids, inputs.decoder_input_ids) present = out[2] present_names = PastKeyValuesHelper.get_input_names(present, encoder=True) diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py index f65060f4f93d3..e2dc79ca247ce 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py @@ -72,9 +72,49 @@ def get_onnx_path( directory = os.path.join(output_dir, model_name) if new_folder else output_dir return os.path.join(directory, model_name + ".onnx") + @staticmethod + def load_model_openai( + model_name_or_path: str, + cache_dir: str, + device: torch.device, + ) -> torch.nn.Module: + """Load model given a pretrained name or path, then build models for ONNX conversion. + + Args: + model_name_or_path (str): pretrained model name or path + cache_dir (str): cache directory + device (torch.device): device to run the model + merge_encoder_and_decoder_init (bool, optional): Whether merge encoder and decoder initialization into one ONNX model. Defaults to True. 
+ Returns: + Dict[str, torch.nn.Module]: mapping from name to modules for ONNX conversion. + """ + from whisper import _ALIGNMENT_HEADS, _MODELS, _download + from whisper.model import ModelDimensions, Whisper + + in_memory = False + + model_name = model_name_or_path.split("/")[-1][8:] + checkpoint_file, alignment_heads = None, None + if model_name in _MODELS: + checkpoint_file = _download(_MODELS[model_name], cache_dir, in_memory) + alignment_heads = _ALIGNMENT_HEADS[model_name] + + with open(checkpoint_file, "rb") as fp: + checkpoint = torch.load(fp, map_location=device) + del checkpoint_file + + dims = ModelDimensions(**checkpoint["dims"]) + model = Whisper(dims) + model.load_state_dict(checkpoint["model_state_dict"]) + + if alignment_heads is not None: + model.set_alignment_heads(alignment_heads) + return model.to(device) + @staticmethod def load_model( model_name_or_path: str, + model_impl: str, cache_dir: str, device: torch.device, merge_encoder_and_decoder_init: bool = True, @@ -94,18 +134,29 @@ def load_model( if version.parse(transformers_version) >= version.parse("4.36.0"): extra_kwargs["attn_implementation"] = "eager" model = WhisperForConditionalGeneration.from_pretrained(model_name_or_path, cache_dir=cache_dir, **extra_kwargs) + + if model_impl == "openai": + openai_model = WhisperHelper.load_model_openai(model_name_or_path, cache_dir, device) + model_encoder, model_decoder = openai_model.encoder, openai_model.decoder + passed_model = openai_model + else: + model_encoder, model_decoder = model, model + passed_model = None + if state_dict_path: model.load_state_dict(torch.load(state_dict_path), strict=False) - decoder = WhisperDecoder(model, model.config) + decoder = WhisperDecoder(model_decoder, model.config, model_impl=model_impl, model=passed_model) decoder.eval().to(device) if merge_encoder_and_decoder_init: encoder_decoder_init = WhisperEncoderDecoderInit( - model, - model, + model_encoder, + model_decoder, model.config, decoder_start_token_id=None, + model_impl=model_impl, + model=passed_model, ) return {"encoder_decoder_init": encoder_decoder_init, "decoder": decoder} else: diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_openai_helper.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_openai_helper.py new file mode 100644 index 0000000000000..941f61cf7cc29 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_openai_helper.py @@ -0,0 +1,76 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. 
+# --------------------------------------------------------------------------
+
+import logging
+
+import torch
+
+logger = logging.getLogger(__name__)
+
+
+class WhisperDecoderInitOpenai(torch.nn.Module):
+    """WhisperDecoderInit for the OpenAI implementation of Whisper."""
+
+    def __init__(
+        self,
+        model: torch.nn.Module,
+        decoder: torch.nn.Module,
+    ):
+        super().__init__()
+        self.whisper_model = model
+        self.whisper_decoder = decoder
+        self.kv_cache = {}
+
+    @torch.no_grad()
+    def forward(
+        self,
+        tokens,
+        audio_features,
+        past=None,
+    ):
+        # Create a kv_cache for past_values
+        past_kv_cache = dict()
+        if past is not None:
+            # Convert past values from 4D to 3D
+            past = [torch.transpose(val, 1, 2) for val in past]
+            past = [val.reshape(val.shape[:2] + (-1,)) for val in past]
+            half_idx = len(past) // 2
+            for idx, block in enumerate(self.whisper_decoder.blocks):
+                past_kv_cache[block.attn.key] = past[2 * idx]
+                past_kv_cache[block.attn.value] = past[2 * idx + 1]
+                past_kv_cache[block.cross_attn.key] = past[2 * idx + half_idx]
+                past_kv_cache[block.cross_attn.value] = past[2 * idx + half_idx + 1]
+
+        if not self.kv_cache:
+            self.kv_cache, _ = self.whisper_model.install_kv_cache_hooks()
+
+        logits = self.whisper_decoder(tokens, audio_features, kv_cache=past_kv_cache)
+
+        # Add concat node for past values
+        if past is not None:
+            for block in self.whisper_decoder.blocks:
+                self.kv_cache[block.attn.key] = torch.cat(
+                    [past_kv_cache[block.attn.key], self.kv_cache[block.attn.key]], dim=1
+                ).detach()
+                self.kv_cache[block.attn.value] = torch.cat(
+                    [past_kv_cache[block.attn.value], self.kv_cache[block.attn.value]], dim=1
+                ).detach()
+
+        present_self, present_cross = [], []
+        # Group self and cross values
+        for block in self.whisper_decoder.blocks:
+            present_self.append(self.kv_cache[block.attn.key])
+            present_self.append(self.kv_cache[block.attn.value])
+            if past is None:
+                present_cross.append(self.kv_cache[block.cross_attn.key])
+                present_cross.append(self.kv_cache[block.cross_attn.value])
+
+        present_self = present_self + present_cross
+        # Add reshape and transpose ops to convert from 3D to 4D
+        present_self = [
+            present_val.reshape(present_val.shape[:2] + (-1, 64)).transpose(1, 2) for present_val in present_self
+        ]
+        return logits, present_self
diff --git a/onnxruntime/python/tools/transformers/onnx_model.py b/onnxruntime/python/tools/transformers/onnx_model.py
index e9029f4f620f8..a8fc6e661933e 100644
--- a/onnxruntime/python/tools/transformers/onnx_model.py
+++ b/onnxruntime/python/tools/transformers/onnx_model.py
@@ -430,6 +430,54 @@ def find_first_child_by_type(self, node, child_type, input_name_to_nodes=None, r
         return None
 
+    def match_child_path(
+        self,
+        node,
+        child_op_types,
+        child_output_index=None,
+        return_indice=None,
+        exclude=[],  # noqa: B006
+    ):
+        """
+        Find a sequence of output edges based on constraints on child op_type and child index.
+        When child_output_index is None, we will find the first child node that matches the op_type,
+        and return_indice will be appended with the corresponding child index.
+
+        Args:
+            node (onnx.NodeProto): current node.
+            child_op_types (list): constraint of child node op_type of each output edge.
+            child_output_index (list): constraint of child index of each output edge. None means no constraint.
+            return_indice (list): a list to append the child index
+                                  when there is no constraint on the child index of an edge.
+
+        Returns:
+            children: a list of matched child nodes.
+ """ + if child_output_index is not None: + assert len(child_output_index) == len(child_op_types) + + current_node = node + matched_children = [] + for i, op_type in enumerate(child_op_types): + matched_child = None + node_children = self.get_children(current_node) + for child_i, child in enumerate(node_children): + if child.op_type == op_type and child not in exclude: + if child_output_index is not None and child_output_index[i] != child_i: + logger.debug( + f"Failed to match index={i} child_output_index={child_output_index[i]} op_type={op_type}", + stack_info=True, + ) + return None + matched_child = child + if matched_child is None: + logger.debug(f"Failed to match child op_type={op_type}", stack_info=True) + return None + + matched_children.append(matched_child) + current_node = matched_child + return matched_children + def find_first_parent_by_type(self, node, parent_type, output_name_to_node=None, recursive=True): if output_name_to_node is None: output_name_to_node = self.output_name_to_node() diff --git a/onnxruntime/python/tools/transformers/onnx_model_bart.py b/onnxruntime/python/tools/transformers/onnx_model_bart.py index 2a48722d17a19..61a786d7af60b 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_bart.py +++ b/onnxruntime/python/tools/transformers/onnx_model_bart.py @@ -121,7 +121,7 @@ def fuse(self, reshape_node, input_name_to_nodes, output_name_to_node): class BartOnnxModel(BertOnnxModel): - def __init__(self, model, num_heads, hidden_size): + def __init__(self, model, num_heads, hidden_size, model_impl="hf"): super().__init__(model, num_heads, hidden_size) self.attention_mask = AttentionMask(self) self.attention_fusion = FusionBartAttention(self, self.hidden_size, self.num_heads, self.attention_mask) diff --git a/onnxruntime/test/python/transformers/test_generation.py b/onnxruntime/test/python/transformers/test_generation.py index c9db1fbc02931..40ea8cf774918 100644 --- a/onnxruntime/test/python/transformers/test_generation.py +++ b/onnxruntime/test/python/transformers/test_generation.py @@ -361,7 +361,8 @@ def run_configs(self, optional_arguments): # INT8 CPU arguments = self.base_arguments + self.int8_cpu_arguments + optional_arguments - self.run_export(arguments) + if "--model_impl" not in arguments: + self.run_export(arguments) @pytest.mark.slow def test_required_args(self): @@ -393,6 +394,11 @@ def test_cross_qk_overall(self): ] self.run_configs(decoder_input_ids) + @pytest.mark.slow + def test_openai_impl_whisper(self): + optional_args = ["--model_impl", "openai", "--chain_model", "--use_whisper_beamsearch"] + self.run_configs(optional_args) + if __name__ == "__main__": unittest.main()
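Reviewer sketch (not part of the patch): the layout contract behind both the new fusion paths and WhisperDecoderInitOpenai is that ONNX Runtime's MultiHeadAttention passes KV caches as 4D tensors of shape [batch, num_heads, seq_len, head_size], while the OpenAI decoder blocks cache them as 3D tensors of shape [batch, seq_len, num_heads * head_size]. That mismatch is why the exported graph gains the extra Reshape/Transpose nodes that fuse_bart_attention now follows via match_child_path. A minimal torch illustration of the round trip, with illustrative shapes only (head_size=64 mirrors the value hard-coded in whisper_openai_helper.py; the released Whisper sizes all use 64-dimensional heads):

import torch

batch, num_heads, seq_len, head_size = 2, 8, 6, 64  # illustrative sizes only

# ORT MultiHeadAttention layout: [batch, num_heads, seq_len, head_size]
past_4d = torch.randn(batch, num_heads, seq_len, head_size)

# 4D -> 3D, as done to `past` before it is handed to the OpenAI decoder blocks
past_3d = torch.transpose(past_4d, 1, 2)              # [batch, seq_len, num_heads, head_size]
past_3d = past_3d.reshape(past_3d.shape[:2] + (-1,))  # [batch, seq_len, num_heads * head_size]

# 3D -> 4D, as done to the returned `present` values (head size 64 hard-coded in the patch)
present_4d = past_3d.reshape(past_3d.shape[:2] + (-1, 64)).transpose(1, 2)

assert torch.equal(past_4d, present_4d)  # the round trip preserves values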