diff --git a/cmake/onnxruntime_python.cmake b/cmake/onnxruntime_python.cmake index cdfb2139730ad..345ef2b504aa4 100644 --- a/cmake/onnxruntime_python.cmake +++ b/cmake/onnxruntime_python.cmake @@ -436,6 +436,9 @@ if (onnxruntime_BUILD_UNIT_TESTS) file(GLOB onnxruntime_python_transformers_testdata_whisper CONFIGURE_DEPENDS "${ONNXRUNTIME_ROOT}/test/python/transformers/test_data/models/whisper/*.onnx" ) + file(GLOB onnxruntime_python_transformers_testdata_conformer CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/test/python/transformers/test_data/models/conformer/*.onnx" + ) endif() file(GLOB onnxruntime_python_tools_srcs CONFIGURE_DEPENDS @@ -549,6 +552,7 @@ add_custom_command( COMMAND ${CMAKE_COMMAND} -E make_directory $/transformers/test_data/models COMMAND ${CMAKE_COMMAND} -E make_directory $/transformers/test_data/models/whisper COMMAND ${CMAKE_COMMAND} -E make_directory $/eager_test + COMMAND ${CMAKE_COMMAND} -E make_directory $/transformers/test_data/models/conformer COMMAND ${CMAKE_COMMAND} -E copy ${ONNXRUNTIME_ROOT}/__init__.py $/onnxruntime/ @@ -701,6 +705,9 @@ if (onnxruntime_BUILD_UNIT_TESTS) COMMAND ${CMAKE_COMMAND} -E copy ${onnxruntime_python_transformers_testdata_whisper} $/transformers/test_data/models/whisper/ + COMMAND ${CMAKE_COMMAND} -E copy + ${onnxruntime_python_transformers_testdata_conformer} + $/transformers/test_data/models/conformer/ ) endif() diff --git a/onnxruntime/python/tools/transformers/fusion_attention.py b/onnxruntime/python/tools/transformers/fusion_attention.py index c1b241aa1a5ec..d11cb91d98b0c 100644 --- a/onnxruntime/python/tools/transformers/fusion_attention.py +++ b/onnxruntime/python/tools/transformers/fusion_attention.py @@ -657,7 +657,6 @@ def create_multihead_attention_node( return None graph_input_names = set([node.name for node in self.model.graph().input]) - graph_output_names = set([node.name for node in self.model.graph().output]) mha_node_name = self.model.create_node_name("Attention") # Add initial Q/K/V inputs for MHA @@ -693,12 +692,15 @@ def create_multihead_attention_node( mha_inputs.append("") # Add optional inputs for MHA - if past_k and past_v and past_k in graph_input_names and past_v in graph_input_names: + + if past_k and past_v: mha_inputs.extend([key_padding_mask, add_qk, past_k, past_v]) + elif key_padding_mask or add_qk: + mha_inputs.extend([key_padding_mask, add_qk]) # Add outputs for MHA mha_outputs = [output] - if present_k and present_v and present_k in graph_output_names and present_v in graph_output_names: + if present_k and present_v: mha_outputs.extend([present_k, present_v]) mha_node = helper.make_node( diff --git a/onnxruntime/python/tools/transformers/fusion_conformer_attention.py b/onnxruntime/python/tools/transformers/fusion_conformer_attention.py new file mode 100644 index 0000000000000..6bc681c57444e --- /dev/null +++ b/onnxruntime/python/tools/transformers/fusion_conformer_attention.py @@ -0,0 +1,143 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +import logging + +from fusion_attention import AttentionMask, FusionAttention +from onnx_model import OnnxModel + +logger = logging.getLogger(__name__) + + +class FusionConformerAttention(FusionAttention): + """ + Fuse Conformer Attention subgraph into one MultiHeadAttention node. + """ + + def __init__( + self, + model: OnnxModel, + hidden_size: int, + num_heads: int, + attention_mask: AttentionMask, + ): + super().__init__(model, hidden_size, num_heads, attention_mask) + + def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): + # SkipLayerNormalization has two inputs, and one of them is the root input for attention. + qkv_nodes = self.model.match_parent_path( + normalize_node, + ["Add", "MatMul", "Reshape", "Transpose", "MatMul"], + [1, 1, 0, 0, 0], + ) + if qkv_nodes is not None: + ( + _, + _, + reshape_qkv, + transpose_qkv, + matmul_qkv, + ) = qkv_nodes + else: + logger.debug("fuse_conformer_attention: failed to match qkv path") + return + + v_nodes = self.model.match_parent_path( + matmul_qkv, + ["Concat", "Transpose", "Reshape", "Add", "MatMul"], + [1, 1, 0, 0, 1], + ) + + add_v = None + if v_nodes is not None: + (concat_v, _, _, add_v, matmul_v) = v_nodes + concat_parent = self.model.get_parent(concat_v, 0, None) + present_v = concat_v.output[0] + past_v = concat_parent.output[0] + else: + logger.debug("fuse_conformer_attention: failed to match v path") + return + + qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "Add", "MatMul"], [0, 0, 0]) + + if qk_nodes is not None: + _, add_qk, matmul_qk = qk_nodes + else: + logger.debug("fuse_conformer_attention: failed to match qk path") + return + + q_nodes = self.model.match_parent_path( + matmul_qk, + ["Div", "Transpose", "Reshape", "Add", "MatMul"], + [0, 0, 0, 0, 1], + ) + if q_nodes is not None: + _, _, reshape_q, add_q, matmul_q = q_nodes + else: + logger.debug("fuse_conformer_attention: failed to match q path") + return + + k_nodes = self.model.match_parent_path( + matmul_qk, + ["Transpose", "Concat", "Transpose", "Reshape", "Add", "MatMul"], + [1, 0, 1, 0, 0, 1], + ) + + matmul_k = None + if k_nodes is not None: + _, concat_k, _, _, add_k, matmul_k = k_nodes + concat_parent = self.model.get_parent(concat_k, 0, None) + past_k = concat_parent.output[0] + present_k = concat_k.output[0] + else: + logger.debug("fuse_conformer_attention: failed to match k path") + return + + attention_last_node = reshape_qkv + num_heads, hidden_size = self.get_num_heads_and_hidden_size(reshape_q) + + if num_heads <= 0 or hidden_size <= 0 or (hidden_size % num_heads) != 0: + logger.debug("fuse_conformer_attention: failed to detect num_heads or hidden_size") + return + + new_node = self.create_multihead_attention_node( + matmul_q, + matmul_k, + matmul_v, + add_q, + add_k, + add_v, + num_heads, + hidden_size, + attention_last_node.output[0], + add_qk=add_qk.input[1], + past_k=past_k, + past_v=past_v, + present_k=present_k, + present_v=present_v, + ) + + if new_node is None: + logger.debug("fuse_conformer_attention: MultiHeadAttention node creation failed") + return + + self.nodes_to_add.append(new_node) + self.node_name_to_graph_name[new_node.name] = self.this_graph_name + + self.nodes_to_remove.extend([attention_last_node, transpose_qkv, matmul_qkv]) + self.nodes_to_remove.extend(qk_nodes) + + # When using multihead attention, keep MatMul nodes in original graph + if q_nodes[-1].op_type == "MatMul": + q_nodes.pop() + if k_nodes[-1].op_type == "MatMul": + k_nodes.pop() + if v_nodes[-1].op_type == "MatMul": + v_nodes.pop() + + self.nodes_to_remove.extend(k_nodes) + self.nodes_to_remove.extend(v_nodes) + + # Use prune graph to remove mask nodes since they are shared by all attention nodes. + self.prune_graph = True diff --git a/onnxruntime/python/tools/transformers/onnx_model_conformer.py b/onnxruntime/python/tools/transformers/onnx_model_conformer.py new file mode 100644 index 0000000000000..1506d85f53fd4 --- /dev/null +++ b/onnxruntime/python/tools/transformers/onnx_model_conformer.py @@ -0,0 +1,33 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +import logging +from typing import Optional + +from fusion_attention import AttentionMask +from fusion_conformer_attention import FusionConformerAttention +from fusion_options import FusionOptions +from onnx_model_bert import BertOnnxModel + +logger = logging.getLogger(__name__) + + +class ConformerOnnxModel(BertOnnxModel): + def __init__(self, model, num_heads, hidden_size): + super().__init__(model, num_heads, hidden_size) + self.attention_mask = AttentionMask(self) + self.attention_fusion = FusionConformerAttention(self, self.hidden_size, self.num_heads, self.attention_mask) + + def optimize(self, options: Optional[FusionOptions] = None, add_dynamic_axes: bool = False): + self.attention_fusion.use_multi_head_attention = False if options is None else options.use_multi_head_attention + self.attention_fusion.disable_multi_head_attention_bias = ( + False if options is None else options.disable_multi_head_attention_bias + ) + super().optimize(options, add_dynamic_axes) + + def fuse_attention(self): + self.attention_fusion.apply() + + def preprocess(self): + self.adjust_reshape_and_expand() diff --git a/onnxruntime/python/tools/transformers/optimizer.py b/onnxruntime/python/tools/transformers/optimizer.py index 94a757320e598..6842a97fe0c77 100644 --- a/onnxruntime/python/tools/transformers/optimizer.py +++ b/onnxruntime/python/tools/transformers/optimizer.py @@ -32,6 +32,7 @@ from onnx_model_bert_keras import BertOnnxModelKeras from onnx_model_bert_tf import BertOnnxModelTF from onnx_model_clip import ClipOnnxModel +from onnx_model_conformer import ConformerOnnxModel from onnx_model_gpt2 import Gpt2OnnxModel from onnx_model_t5 import T5OnnxModel from onnx_model_tnlr import TnlrOnnxModel @@ -56,6 +57,7 @@ "unet": (UnetOnnxModel, "pytorch", 1), # UNet in Stable Diffusion "vae": (VaeOnnxModel, "pytorch", 1), # UAE in Stable Diffusion "vit": (BertOnnxModel, "pytorch", 1), + "conformer": (ConformerOnnxModel, "pytorch", 1), } diff --git a/onnxruntime/test/python/transformers/conformer_model_generator.py b/onnxruntime/test/python/transformers/conformer_model_generator.py new file mode 100644 index 0000000000000..71e4f2b63cf4f --- /dev/null +++ b/onnxruntime/test/python/transformers/conformer_model_generator.py @@ -0,0 +1,543 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +from typing import List + +import numpy as np +import onnx +from bert_model_generator import float_tensor +from onnx import TensorProto, helper, numpy_helper + + +# Adapted from bert_model_generator.py +def get_tensor_and_weight(name: str, shape: List[int], random=False, zeros=False): + low = 0.0 + high = 1.0 + total_elements = 1 + for x in shape: + total_elements *= x + weights = ( + [np.random.uniform(low, high) for _ in range(total_elements)] + if random + else [0.0] * total_elements + if zeros + else [1.0] * total_elements + ) + return helper.make_tensor(name, TensorProto.FLOAT, shape, weights), weights + + +def create_conformer_attention( + hidden_size=512, + num_heads=8, + epsilon=0.000009999999747378752, + add_before_layernorm=False, + fused=False, +): + # Get head size and ensure head size is an integer + assert hidden_size % num_heads == 0 + head_size = hidden_size // num_heads + + # Construct input and output nodes + inputs = [ + helper.make_tensor_value_info("input_0", TensorProto.FLOAT, ["batch_size", 8, 512]), + helper.make_tensor_value_info("input_1", TensorProto.FLOAT, ["batch_size", 8, 512]), + helper.make_tensor_value_info("inp_cache_k", TensorProto.FLOAT, [24, "batch_size", 8, 72, head_size]), + helper.make_tensor_value_info("inp_cache_v", TensorProto.FLOAT, [24, "batch_size", 8, 72, head_size]), + ] + outputs = [ + helper.make_tensor_value_info("output_0", TensorProto.FLOAT, ["batch_size", 8, hidden_size]), + helper.make_tensor_value_info("output_1", TensorProto.FLOAT, ["batch_size", 8, 512]), + helper.make_tensor_value_info("oup_cache_k", TensorProto.FLOAT, ["batch_size", 8, 80, 64]), + helper.make_tensor_value_info("oup_cache_v", TensorProto.FLOAT, ["batch_size", 8, 80, 64]), + ] + nodes = [] + + # Create layernorm (Add + LayerNorm or SkipLayerNorm) + if add_before_layernorm: + nodes.extend( + [ + helper.make_node( + "Add", ["input_0", "input_1"], ["layernorm_output_to_skiplayernorm"], "add_before_layernorm" + ), + helper.make_node( + "LayerNormalization", + ["layernorm_output_to_skiplayernorm", "layernorm_weight", "layernorm_bias"], + ["layernorm_add_output_to_matmul"], + "layernorm", + epsilon=epsilon, + ), + ] + ) + else: + nodes.append( + helper.make_node( + "SkipLayerNormalization", + ["input_0", "input_1", "layernorm_weight", "layernorm_bias"], + ["layernorm_add_output_to_matmul", "", "", "layernorm_add_output_to_skiplayernorm"], + "skiplayernorm", + domain="com.microsoft", + epsilon=epsilon, + ) + ) + + if fused: + fused_q_nodes = [ + helper.make_node( + "MatMul", + ["layernorm_add_output_to_matmul", "q_weight"], + ["q_matmul_output"], + "q_path_matmul", + ), + helper.make_node("Add", ["q_bias", "q_matmul_output"], ["q_add_output"], "q_path_add"), + helper.make_node( + "Reshape", ["q_add_output", "k_attn_heads_output"], ["q_4d_bsnh"], "q_reshape_to_4d", allowzero=0 + ), + helper.make_node("Transpose", ["q_4d_bsnh"], ["q_4d_bnsh"], "q_transpose_to_bnsh", perm=[0, 2, 1, 3]), + helper.make_node( + "Div", + ["q_4d_bnsh", "q_scale"], + ["q_div_output"], + "q_div_by_sqrt_head_size", + ), + ] + nodes.extend(fused_q_nodes) + nodes.extend( + [ + helper.make_node( + "MatMul", + ["layernorm_add_output_to_matmul", "k_weight"], + ["k_matmul_output"], + "k_path_matmul", + ), + helper.make_node( + "MatMul", + ["layernorm_add_output_to_matmul", "v_weight"], + ["v_matmul_output"], + "v_path_matmul", + ), + helper.make_node( + "Reshape", ["q_div_output", "position_embed_output"], ["reshape_pos_emb"], "r_pos_emb", allowzero=0 + ), + helper.make_node( + "Transpose", ["reshape_pos_emb"], ["transpose_reshape_pos_emb"], "p_transpose", perm=[1, 0, 2] + ), + helper.make_node( + "MatMul", + ["transpose_reshape_pos_emb", "transpose_reshape_pos_emb"], + ["pos_matmul"], + "pos_embed_matmul", + ), + helper.make_node( + "Transpose", ["pos_matmul"], ["transpose_pos_matmul"], "p_matmul_transpose", perm=[1, 0, 2] + ), + helper.make_node( + "Reshape", + ["transpose_pos_matmul", "position_embed_output"], + ["reshape_position_emb"], + "final_reshape_pos_emb", + allowzero=0, + ), + helper.make_node( + "MultiHeadAttention", + [ + "q_matmul_output", + "k_matmul_output", + "v_matmul_output", + "Attention_0_qkv_bias", + "", + "reshape_position_emb", + "gather_past_k_output", + "gather_past_v_output", + ], + ["attn_output", "oup_cache_k", "oup_cache_v"], + "Attention_0", + domain="com.microsoft", + num_heads=num_heads, + ), + ] + ) + # Create nodes used with qkv concats, reshapes, and transposes + nodes.extend( + [ + helper.make_node("Shape", ["layernorm_add_output_to_matmul"], ["shape_output"], "shape", start=0), + helper.make_node("Gather", ["shape_output", "idx_0"], ["gather_0_output"], "gather_0", axis=0), + helper.make_node( + "Mul", + ["gather_0_output", "num_heads_int"], + ["mul_attn_heads_output"], + "mul_num_heads", + ), + helper.make_node( + "Unsqueeze", + ["mul_attn_heads_output", "unsqueeze_axes_input"], + ["unsqueeze_position_embed"], + "unsqueeze_position_embed", + ), + helper.make_node( + "Concat", + ["unsqueeze_position_embed", "neg_one", "head_size"], + ["position_embed_output"], + "position_embed_concat_output", + axis=0, + ), + helper.make_node( + "Unsqueeze", + ["gather_0_output", "unsqueeze_axes_input"], + ["unsqueeze_attn_heads_output"], + "unsqueeze_num_heads", + ), + helper.make_node( + "Concat", + ["unsqueeze_attn_heads_output", "neg_one", "head_size", "q_bsnh_reshape"], + ["k_attn_heads_output"], + "k_num_heads", + axis=0, + ), + ] + ) + + nodes.extend( + [ + helper.make_node("Gather", ["inp_cache_v", "idx_0"], ["gather_past_v_output"], "gather_past_v", axis=0), + helper.make_node("Gather", ["inp_cache_k", "idx_0"], ["gather_past_k_output"], "gather_past_k", axis=0), + ] + ) + else: + # Create nodes for Q/K/V paths + q_nodes = [ + helper.make_node( + "MatMul", ["layernorm_add_output_to_matmul", "q_weight"], ["q_matmul_output"], "q_path_matmul" + ), + helper.make_node("Add", ["q_bias", "q_matmul_output"], ["q_add_output"], "q_path_add"), + helper.make_node("Reshape", ["q_add_output", "q_attn_heads_output"], ["q_4d_bsnh"], "q_reshape_to_4d"), + helper.make_node("Transpose", ["q_4d_bsnh"], ["q_4d_bnsh"], "q_transpose_to_bnsh", perm=[0, 2, 1, 3]), + helper.make_node( + "Div", + ["q_4d_bnsh", "q_scale"], + ["q_div_output"], + "q_div_by_sqrt_head_size", + ), + ] + k_nodes = [ + helper.make_node( + "MatMul", + ["layernorm_add_output_to_matmul", "k_weight"], + ["k_matmul_output"], + "k_path_matmul", + ), + helper.make_node("Add", ["k_bias", "k_matmul_output"], ["k_add_output"], "k_path_add"), + helper.make_node("Reshape", ["k_add_output", "k_attn_heads_output"], ["k_4d_bsnh"], "k_reshape_to_4d"), + helper.make_node("Transpose", ["k_4d_bsnh"], ["k_4d_bnsh"], "k_transpose_to_bnsh", perm=[0, 2, 1, 3]), + helper.make_node( + "Concat", + ["gather_past_k_output", "k_4d_bnsh"], + ["oup_cache_k"], + "concat_past_k_and_curr_k", + axis=2, + ), + helper.make_node( + "Transpose", + ["oup_cache_k"], + ["k_output_transpose"], + "k_transpose_last_two_dims", + perm=[0, 1, 3, 2], + ), + ] + v_nodes = [ + helper.make_node( + "MatMul", + ["layernorm_add_output_to_matmul", "v_weight"], + ["v_matmul_output"], + "v_path_matmul", + ), + helper.make_node("Add", ["v_bias", "v_matmul_output"], ["v_add_output"], "v_path_add"), + helper.make_node("Reshape", ["v_add_output", "v_attn_heads_output"], ["v_4d_bsnh"], "v_reshape_to_4d"), + helper.make_node("Transpose", ["v_4d_bsnh"], ["v_4d_bnsh"], "v_transpose_to_bnsh", perm=[0, 2, 1, 3]), + helper.make_node( + "Concat", + ["gather_past_v_output", "v_4d_bnsh"], + ["oup_cache_v"], + "concat_past_v_and_curr_v", + axis=2, + ), + ] + pos_embed = [ + helper.make_node("Reshape", ["q_div_output", "position_embed_output"], ["reshape_pos_emb"], "r_pos_emb"), + helper.make_node( + "Transpose", ["reshape_pos_emb"], ["transpose_reshape_pos_emb"], "p_transpose", perm=[1, 0, 2] + ), + helper.make_node( + "MatMul", + ["transpose_reshape_pos_emb", "transpose_reshape_pos_emb"], + ["pos_matmul"], + "pos_embed_matmul", + ), + helper.make_node( + "Transpose", ["pos_matmul"], ["transpose_pos_matmul"], "p_matmul_transpose", perm=[1, 0, 2] + ), + helper.make_node( + "Reshape", + ["transpose_pos_matmul", "position_embed_output"], + ["reshape_position_emb"], + "final_reshape_pos_emb", + ), + ] + nodes.extend(q_nodes) + nodes.extend(k_nodes) + nodes.extend(v_nodes) + nodes.extend(pos_embed) + + # Create nodes used with qkv concats, reshapes, and transposes + nodes.extend( + [ + helper.make_node("Shape", ["layernorm_add_output_to_matmul"], ["shape_output"], "shape", start=0), + helper.make_node("Gather", ["shape_output", "idx_0"], ["gather_0_output"], "gather_0", axis=0), + helper.make_node( + "Mul", + ["gather_0_output", "num_heads_int"], + ["mul_attn_heads_output"], + "mul_num_heads", + ), + helper.make_node( + "Unsqueeze", + ["mul_attn_heads_output", "unsqueeze_axes_input"], + ["unsqueeze_position_embed"], + "unsqueeze_position_embed", + ), + helper.make_node( + "Concat", + ["unsqueeze_position_embed", "neg_one", "head_size"], + ["position_embed_output"], + "position_embed_concat_output", + axis=0, + ), + helper.make_node( + "Unsqueeze", + ["gather_0_output", "unsqueeze_axes_input"], + ["unsqueeze_attn_heads_output"], + "unsqueeze_num_heads", + ), + helper.make_node( + "Concat", + ["unsqueeze_attn_heads_output", "neg_one", "head_size", "q_bsnh_reshape"], + ["q_attn_heads_output"], + "q_num_heads", + axis=0, + ), + helper.make_node( + "Concat", + ["unsqueeze_attn_heads_output", "neg_one", "head_size", "q_bsnh_reshape"], + ["k_attn_heads_output"], + "k_num_heads", + axis=0, + ), + helper.make_node( + "Concat", + ["unsqueeze_attn_heads_output", "neg_one", "head_size", "q_bsnh_reshape"], + ["v_attn_heads_output"], + "v_num_heads", + axis=0, + ), + helper.make_node( + "Concat", + ["unsqueeze_attn_heads_output", "neg_one", "head_size"], + ["bsd_format"], + axis=0, + ), + helper.make_node( + "Constant", + inputs=[], + outputs=["q_bsnh_reshape"], + value=numpy_helper.from_array( + np.array([0, 0, num_heads, head_size], dtype="int64"), name="const_tensor" + ), + ), + ] + ) + + nodes.extend( + [ + helper.make_node("Gather", ["inp_cache_v", "idx_0"], ["gather_past_v_output"], "gather_past_v", axis=0), + helper.make_node("Gather", ["inp_cache_k", "idx_0"], ["gather_past_k_output"], "gather_past_k", axis=0), + ] + ) + + # Compute Q x K' + nodes.extend( + [ + helper.make_node( + "MatMul", + [ + "q_div_output", + "k_output_transpose", + ], + ["qk_output"], + "matmul_qk", + ) + ] + ) + + # Create nodes for computing softmax(Q x K') x V + nodes.extend( + [ + helper.make_node( + "Add", + [ + "qk_output", + "reshape_position_emb", + ], + ["add_qk_output"], + "add_qk", + ), + helper.make_node( + "Softmax", + ["add_qk_output"], + ["softmax_output"], + "softmax_qk", + axis=2, + ), + helper.make_node( + "MatMul", + ["softmax_output", "oup_cache_v"], + ["qkv_output_(num_heads*batch_size,seq_len,head_size)"], + "matmul_qkv", + ), + helper.make_node( + "Transpose", + ["qkv_output_(num_heads*batch_size,seq_len,head_size)"], + ["qkv_bsnh"], + "transpose_bnsh_to_bsnh", + perm=[0, 2, 1, 3], + ), + helper.make_node("Reshape", ["qkv_bsnh", "bsd_format"], ["attn_output"], "qkv_bsd"), + ] + ) + + # Create final nodes to conclude attention + nodes.append( + helper.make_node( + "MatMul", + ["attn_output", "matmul_after_attn_initializer"], + ["matmul_after_attn_output"], + "matmul_after_attn", + ), + ) + if not fused: + next_sln_inputs = [ + "layernorm_add_output_to_skiplayernorm", + "add_after_attn_output", + "layernorm_weight", + "layernorm_bias", + ] + nodes.extend( + [ + helper.make_node( + "Add", + ["add_after_attn_initializer", "matmul_after_attn_output"], + ["add_after_attn_output"], + "add_after_attn", + ), + helper.make_node( + "SkipLayerNormalization", + next_sln_inputs, + ["output_0", "", "", "output_1"], + "next_skiplayernorm", + domain="com.microsoft", + epsilon=epsilon, + ), + ] + ) + else: + next_sln_inputs = [ + "matmul_after_attn_output", + "layernorm_add_output_to_skiplayernorm", + "layernorm_weight", + "layernorm_bias", + "add_after_attn_initializer", + ] + nodes.append( + helper.make_node( + "SkipLayerNormalization", + next_sln_inputs, + ["output_0", "", "", "output_1"], + "SkipLayerNorm_AddBias_0", + domain="com.microsoft", + epsilon=epsilon, + ) + ) + + # Create initializers + v_weight, v_weight_data = get_tensor_and_weight("v_weight", [hidden_size, hidden_size]) + v_bias, v_bias_data = get_tensor_and_weight("v_bias", [hidden_size]) + q_weight, q_weight_data = get_tensor_and_weight("q_weight", [hidden_size, hidden_size]) + q_bias, q_bias_data = get_tensor_and_weight("q_bias", [hidden_size]) + k_weight, k_weight_data = get_tensor_and_weight("k_weight", [hidden_size, hidden_size]) + k_bias, k_bias_data = get_tensor_and_weight("k_bias", [hidden_size]) + + qkv_bias = helper.make_tensor( + "Attention_0_qkv_bias", + TensorProto.FLOAT, + [3 * hidden_size], + q_bias_data + k_bias_data + v_bias_data, + ) + initializers = [ + float_tensor("layernorm_weight", [hidden_size]), + float_tensor("layernorm_bias", [hidden_size]), + float_tensor("matmul_after_attn_initializer", [hidden_size, hidden_size]), + float_tensor("add_after_attn_initializer", [hidden_size]), + ] + + # Add Q/K/V weight tensors as initializers + if fused: + initializers.extend([q_weight, k_weight, v_weight]) + initializers.extend([q_bias]) + initializers.append(qkv_bias) + initializers.extend( + [ + numpy_helper.from_array(np.array(num_heads, dtype="int64"), name="num_heads_int"), + numpy_helper.from_array(np.array([head_size], dtype="int64"), name="head_size"), + numpy_helper.from_array(np.array(1 / np.sqrt(head_size), dtype="float32"), name="q_scale"), + numpy_helper.from_array(np.array(0, dtype="int64"), name="idx_0"), + numpy_helper.from_array(np.array([-1], dtype="int64"), name="neg_one"), + numpy_helper.from_array(np.array([0], dtype="int64"), name="unsqueeze_axes_input"), + numpy_helper.from_array(np.array([0, 0, num_heads, head_size], dtype="int64"), name="q_bsnh_reshape"), + ] + ) + else: + initializers.extend([q_weight, k_weight, v_weight]) + + initializers.extend([q_bias, k_bias, v_bias]) + + initializers.extend( + [ + numpy_helper.from_array(np.array(num_heads, dtype="int64"), name="num_heads_int"), + numpy_helper.from_array(np.array([num_heads], dtype="int64"), name="num_heads"), + numpy_helper.from_array(np.array([head_size], dtype="int64"), name="head_size"), + numpy_helper.from_array(np.array([hidden_size], dtype="int64"), name="hidden_size"), + numpy_helper.from_array(np.array(1 / np.sqrt(head_size), dtype="float32"), name="q_scale"), + numpy_helper.from_array(np.array(0, dtype="int64"), name="idx_0"), + numpy_helper.from_array(np.array(1, dtype="int64"), name="idx_1"), + numpy_helper.from_array(np.array([-1], dtype="int64"), name="neg_one"), + numpy_helper.from_array(np.array([0], dtype="int64"), name="unsqueeze_axes_input"), + ] + ) + + # Construct graph + graph = helper.make_graph(nodes, "conformer_self_mha_graph", inputs, outputs, initializers, doc_string="conformer") + opsetid = helper.make_opsetid("ai.onnx", min(onnx.defs.onnx_opset_version(), 16)) + return helper.make_model(graph, opset_imports=(opsetid,)) + + +if __name__ == "__main__": + np.random.seed(2) + num_heads = 8 + hidden_size = 512 + + model = create_conformer_attention(num_heads=num_heads, hidden_size=hidden_size) + onnx.save(model, "conformer_self_mha.onnx") + + model = create_conformer_attention(num_heads=num_heads, hidden_size=hidden_size, fused=True) + onnx.save(model, "./test_data/models/conformer/conformer_self_mha_fused.onnx") diff --git a/onnxruntime/test/python/transformers/test_conformer.py b/onnxruntime/test/python/transformers/test_conformer.py new file mode 100644 index 0000000000000..471ba9756bcf8 --- /dev/null +++ b/onnxruntime/test/python/transformers/test_conformer.py @@ -0,0 +1,69 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import os +import unittest + +import onnx +from conformer_model_generator import create_conformer_attention +from parity_utilities import find_transformers_source + +if find_transformers_source(): + from fusion_options import FusionOptions + from onnx_model import OnnxModel + from optimizer import optimize_model +else: + from onnxruntime.transformers.fusion_options import FusionOptions + from onnxruntime.transformers.onnx_model import OnnxModel + from onnxruntime.transformers.optimizer import optimize_model + + +class TestFusion(unittest.TestCase): + def verify_fusion(self, optimized_model, expected_model_filename): + optimized_model.topological_sort(is_deterministic=True) + + expected_model_path = os.path.join( + os.path.dirname(__file__), "test_data", "models", "conformer", expected_model_filename + ) + print("Expected model path = ", expected_model_path) + expected_model = OnnxModel(onnx.load(expected_model_path)) + expected_model.topological_sort(is_deterministic=True) + + nodes = optimized_model.model.graph.node + self.assertEqual(len(nodes), len(expected_model.model.graph.node)) + + for i in range(len(nodes)): + self.assertEqual(nodes[i], expected_model.model.graph.node[i]) + + for expected_initializer in expected_model.model.graph.initializer: + print("Expected initializer initial = ", expected_initializer.name) + self.assertTrue( + OnnxModel.has_same_value( + optimized_model.get_initializer(expected_initializer.name), expected_initializer + ) + ) + + def test_ct_mha_fusion(self): + num_heads = 8 + hidden_size = 512 + model = create_conformer_attention(num_heads=num_heads, hidden_size=hidden_size, add_before_layernorm=False) + dir = "." + model_path = os.path.join(dir, "conformer_self_mha.onnx") + onnx.save(model, model_path) + options = FusionOptions("conformer") + optimized_model = optimize_model( + model_path, + model_type="conformer", + num_heads=num_heads, + hidden_size=hidden_size, + optimization_options=options, + ) + os.remove(model_path) + self.verify_fusion(optimized_model, "conformer_self_mha_fused.onnx") + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxruntime/test/python/transformers/test_data/models/conformer/conformer_self_mha_fused.onnx b/onnxruntime/test/python/transformers/test_data/models/conformer/conformer_self_mha_fused.onnx new file mode 100644 index 0000000000000..9d882751db265 Binary files /dev/null and b/onnxruntime/test/python/transformers/test_data/models/conformer/conformer_self_mha_fused.onnx differ