diff --git a/cmake/onnxruntime_python.cmake b/cmake/onnxruntime_python.cmake
index a9a78668b4810..b7b78189d54cd 100644
--- a/cmake/onnxruntime_python.cmake
+++ b/cmake/onnxruntime_python.cmake
@@ -443,6 +443,9 @@ if (onnxruntime_BUILD_UNIT_TESTS)
   file(GLOB onnxruntime_python_transformers_testdata_whisper CONFIGURE_DEPENDS
     "${ONNXRUNTIME_ROOT}/test/python/transformers/test_data/models/whisper/*.onnx"
   )
+  file(GLOB onnxruntime_python_transformers_testdata_conformer CONFIGURE_DEPENDS
+    "${ONNXRUNTIME_ROOT}/test/python/transformers/test_data/models/conformer/*.onnx"
+  )
 endif()
 
 file(GLOB onnxruntime_python_tools_srcs CONFIGURE_DEPENDS
@@ -556,6 +559,7 @@ add_custom_command(
   COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/transformers/test_data/models
   COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/transformers/test_data/models/whisper
   COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/eager_test
+  COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/transformers/test_data/models/conformer
   COMMAND ${CMAKE_COMMAND} -E copy
       ${ONNXRUNTIME_ROOT}/__init__.py
       $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/
@@ -711,6 +715,9 @@ if (onnxruntime_BUILD_UNIT_TESTS)
   COMMAND ${CMAKE_COMMAND} -E copy
       ${onnxruntime_python_transformers_testdata_whisper}
       $<TARGET_FILE_DIR:${build_output_target}>/transformers/test_data/models/whisper/
+  COMMAND ${CMAKE_COMMAND} -E copy
+      ${onnxruntime_python_transformers_testdata_conformer}
+      $<TARGET_FILE_DIR:${build_output_target}>/transformers/test_data/models/conformer/
 )
 endif()
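These build rules mirror the existing whisper test-data handling: the conformer ONNX models are globbed at configure time and copied next to the built Python test tree. A minimal sketch of how a test could then resolve a copied model, assuming the test script runs from the build output directory; the `conformer_test_data` helper below is illustrative, not part of onnxruntime:

```python
# Illustrative only: resolve a conformer test model relative to the test
# script, matching the transformers/test_data/models/conformer/ layout that
# the CMake copy commands above create in the build output tree.
import os


def conformer_test_data(filename: str) -> str:
    # Hypothetical helper; the directory layout is the one set up above.
    models_dir = os.path.join(os.path.dirname(__file__), "test_data", "models", "conformer")
    return os.path.join(models_dir, filename)


fused_model_path = conformer_test_data("conformer_self_mha_fused.onnx")
```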
""" def __init__( @@ -40,6 +40,7 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): matmul_qkv, ) = qkv_nodes else: + logger.debug("fuse_conformer_attention: failed to match qkv path") return v_nodes = self.model.match_parent_path( @@ -55,7 +56,7 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): present_v = concat_v.output[0] past_v = concat_parent.output[0] else: - logger.debug("fuse_attention: failed to match v path") + logger.debug("fuse_conformer_attention: failed to match v path") return qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "Add", "MatMul"], [0, 0, 0]) @@ -63,6 +64,7 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): if qk_nodes is not None: _, add_qk, matmul_qk = qk_nodes else: + logger.debug("fuse_conformer_attention: failed to match qk path") return q_nodes = self.model.match_parent_path( @@ -73,6 +75,7 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): if q_nodes is not None: _, _, reshape_q, add_q, matmul_q = q_nodes else: + logger.debug("fuse_conformer_attention: failed to match q path") return k_nodes = self.model.match_parent_path( @@ -88,16 +91,16 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): past_k = concat_parent.output[0] present_k = concat_k.output[0] else: + logger.debug("fuse_conformer_attention: failed to match k path") return attention_last_node = reshape_qkv num_heads, hidden_size = self.get_num_heads_and_hidden_size(reshape_q) if num_heads <= 0 or hidden_size <= 0 or (hidden_size % num_heads) != 0: - logger.debug("fuse_attention: failed to detect num_heads or hidden_size") + logger.debug("fuse_conformer_attention: failed to detect num_heads or hidden_size") return - new_node = None new_node = self.create_multihead_attention_node( matmul_q, matmul_k, @@ -116,6 +119,7 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): ) if new_node is None: + logger.debug("fuse_conformer_attention: MultiHeadAttention node creation failed") return self.nodes_to_add.append(new_node) diff --git a/onnxruntime/test/python/transformers/conformer_model_generator.py b/onnxruntime/test/python/transformers/conformer_model_generator.py index b40bdaed1e1bf..71e4f2b63cf4f 100644 --- a/onnxruntime/test/python/transformers/conformer_model_generator.py +++ b/onnxruntime/test/python/transformers/conformer_model_generator.py @@ -44,13 +44,12 @@ def create_conformer_attention( inputs = [ helper.make_tensor_value_info("input_0", TensorProto.FLOAT, ["batch_size", 8, 512]), helper.make_tensor_value_info("input_1", TensorProto.FLOAT, ["batch_size", 8, 512]), - helper.make_tensor_value_info("inp_cache_k", TensorProto.FLOAT, ["batch_size", 8, 72, head_size]), - helper.make_tensor_value_info("inp_cache_v", TensorProto.FLOAT, ["batch_size", 8, 72, head_size]), + helper.make_tensor_value_info("inp_cache_k", TensorProto.FLOAT, [24, "batch_size", 8, 72, head_size]), + helper.make_tensor_value_info("inp_cache_v", TensorProto.FLOAT, [24, "batch_size", 8, 72, head_size]), ] outputs = [ helper.make_tensor_value_info("output_0", TensorProto.FLOAT, ["batch_size", 8, hidden_size]), helper.make_tensor_value_info("output_1", TensorProto.FLOAT, ["batch_size", 8, 512]), - helper.make_tensor_value_info("pos_k_output", TensorProto.FLOAT, ["batch_size", 8, 8, 80]), helper.make_tensor_value_info("oup_cache_k", TensorProto.FLOAT, ["batch_size", 8, 80, 64]), helper.make_tensor_value_info("oup_cache_v", TensorProto.FLOAT, ["batch_size", 8, 80, 
diff --git a/onnxruntime/test/python/transformers/conformer_model_generator.py b/onnxruntime/test/python/transformers/conformer_model_generator.py
index b40bdaed1e1bf..71e4f2b63cf4f 100644
--- a/onnxruntime/test/python/transformers/conformer_model_generator.py
+++ b/onnxruntime/test/python/transformers/conformer_model_generator.py
@@ -44,13 +44,12 @@ def create_conformer_attention(
     inputs = [
         helper.make_tensor_value_info("input_0", TensorProto.FLOAT, ["batch_size", 8, 512]),
         helper.make_tensor_value_info("input_1", TensorProto.FLOAT, ["batch_size", 8, 512]),
-        helper.make_tensor_value_info("inp_cache_k", TensorProto.FLOAT, ["batch_size", 8, 72, head_size]),
-        helper.make_tensor_value_info("inp_cache_v", TensorProto.FLOAT, ["batch_size", 8, 72, head_size]),
+        helper.make_tensor_value_info("inp_cache_k", TensorProto.FLOAT, [24, "batch_size", 8, 72, head_size]),
+        helper.make_tensor_value_info("inp_cache_v", TensorProto.FLOAT, [24, "batch_size", 8, 72, head_size]),
     ]
     outputs = [
         helper.make_tensor_value_info("output_0", TensorProto.FLOAT, ["batch_size", 8, hidden_size]),
         helper.make_tensor_value_info("output_1", TensorProto.FLOAT, ["batch_size", 8, 512]),
-        helper.make_tensor_value_info("pos_k_output", TensorProto.FLOAT, ["batch_size", 8, 8, 80]),
         helper.make_tensor_value_info("oup_cache_k", TensorProto.FLOAT, ["batch_size", 8, 80, 64]),
         helper.make_tensor_value_info("oup_cache_v", TensorProto.FLOAT, ["batch_size", 8, 80, 64]),
     ]
@@ -85,43 +84,73 @@ def create_conformer_attention(
     )
 
     if fused:
+        fused_q_nodes = [
+            helper.make_node(
+                "MatMul",
+                ["layernorm_add_output_to_matmul", "q_weight"],
+                ["q_matmul_output"],
+                "q_path_matmul",
+            ),
+            helper.make_node("Add", ["q_bias", "q_matmul_output"], ["q_add_output"], "q_path_add"),
+            helper.make_node(
+                "Reshape", ["q_add_output", "k_attn_heads_output"], ["q_4d_bsnh"], "q_reshape_to_4d", allowzero=0
+            ),
+            helper.make_node("Transpose", ["q_4d_bsnh"], ["q_4d_bnsh"], "q_transpose_to_bnsh", perm=[0, 2, 1, 3]),
+            helper.make_node(
+                "Div",
+                ["q_4d_bnsh", "q_scale"],
+                ["q_div_output"],
+                "q_div_by_sqrt_head_size",
+            ),
+        ]
+        nodes.extend(fused_q_nodes)
         nodes.extend(
             [
                 helper.make_node(
                     "MatMul",
-                    ["layernorm_add_output_to_matmul", "MatMul_q_weight"],
-                    ["MatMul_q_out"],
-                    "MatMul_q",
+                    ["layernorm_add_output_to_matmul", "k_weight"],
+                    ["k_matmul_output"],
+                    "k_path_matmul",
                 ),
                 helper.make_node(
                     "MatMul",
-                    ["layernorm_add_output_to_matmul", "MatMul_k_weight"],
-                    ["MatMul_k_out"],
-                    "MatMul_k",
+                    ["layernorm_add_output_to_matmul", "v_weight"],
+                    ["v_matmul_output"],
+                    "v_path_matmul",
+                ),
+                helper.make_node(
+                    "Reshape", ["q_div_output", "position_embed_output"], ["reshape_pos_emb"], "r_pos_emb", allowzero=0
+                ),
+                helper.make_node(
+                    "Transpose", ["reshape_pos_emb"], ["transpose_reshape_pos_emb"], "p_transpose", perm=[1, 0, 2]
                 ),
                 helper.make_node(
                     "MatMul",
-                    ["layernorm_add_output_to_matmul", "MatMul_v_weight"],
-                    ["MatMul_v_out"],
-                    "MatMul_v",
+                    ["transpose_reshape_pos_emb", "transpose_reshape_pos_emb"],
+                    ["pos_matmul"],
+                    "pos_embed_matmul",
+                ),
+                helper.make_node(
+                    "Transpose", ["pos_matmul"], ["transpose_pos_matmul"], "p_matmul_transpose", perm=[1, 0, 2]
                 ),
                 helper.make_node(
                     "Reshape",
-                    ["layernorm_add_output_to_matmul", "concat_reshape"],
-                    ["pos_k_output"],
-                    "Reshape_out_pos_k",
+                    ["transpose_pos_matmul", "position_embed_output"],
+                    ["reshape_position_emb"],
+                    "final_reshape_pos_emb",
+                    allowzero=0,
                 ),
                 helper.make_node(
                     "MultiHeadAttention",
                     [
-                        "MatMul_q_out",
-                        "MatMul_k_out",
-                        "MatMul_v_out",
+                        "q_matmul_output",
+                        "k_matmul_output",
+                        "v_matmul_output",
                         "Attention_0_qkv_bias",
                         "",
-                        "pos_k_output",
-                        "inp_cache_k",
-                        "inp_cache_v",
+                        "reshape_position_emb",
+                        "gather_past_k_output",
+                        "gather_past_v_output",
                     ],
                     ["attn_output", "oup_cache_k", "oup_cache_v"],
                     "Attention_0",
@@ -130,6 +159,52 @@ def create_conformer_attention(
                 ),
             ]
         )
+        # Create nodes used with qkv concats, reshapes, and transposes
+        nodes.extend(
+            [
+                helper.make_node("Shape", ["layernorm_add_output_to_matmul"], ["shape_output"], "shape", start=0),
+                helper.make_node("Gather", ["shape_output", "idx_0"], ["gather_0_output"], "gather_0", axis=0),
+                helper.make_node(
+                    "Mul",
+                    ["gather_0_output", "num_heads_int"],
+                    ["mul_attn_heads_output"],
+                    "mul_num_heads",
+                ),
+                helper.make_node(
+                    "Unsqueeze",
+                    ["mul_attn_heads_output", "unsqueeze_axes_input"],
+                    ["unsqueeze_position_embed"],
+                    "unsqueeze_position_embed",
+                ),
+                helper.make_node(
+                    "Concat",
+                    ["unsqueeze_position_embed", "neg_one", "head_size"],
+                    ["position_embed_output"],
+                    "position_embed_concat_output",
+                    axis=0,
+                ),
+                helper.make_node(
+                    "Unsqueeze",
+                    ["gather_0_output", "unsqueeze_axes_input"],
+                    ["unsqueeze_attn_heads_output"],
+                    "unsqueeze_num_heads",
+                ),
+                helper.make_node(
+                    "Concat",
+                    ["unsqueeze_attn_heads_output", "neg_one", "head_size", "q_bsnh_reshape"],
+                    ["k_attn_heads_output"],
+                    "k_num_heads",
+                    axis=0,
+                ),
+            ]
+        )
+
+        nodes.extend(
+            [
+                helper.make_node("Gather", ["inp_cache_v", "idx_0"], ["gather_past_v_output"], "gather_past_v", axis=0),
+                helper.make_node("Gather", ["inp_cache_k", "idx_0"], ["gather_past_k_output"], "gather_past_k", axis=0),
+            ]
+        )
     else:
         # Create nodes for Q/K/V paths
         q_nodes = [
["gather_past_v_output"], "gather_past_v", axis=0), + helper.make_node("Gather", ["inp_cache_k", "idx_0"], ["gather_past_k_output"], "gather_past_k", axis=0), + ] + ) else: # Create nodes for Q/K/V paths q_nodes = [ @@ -137,11 +212,11 @@ def create_conformer_attention( "MatMul", ["layernorm_add_output_to_matmul", "q_weight"], ["q_matmul_output"], "q_path_matmul" ), helper.make_node("Add", ["q_bias", "q_matmul_output"], ["q_add_output"], "q_path_add"), - helper.make_node("Reshape", ["q_add_output", "q_bsnh_reshape"], ["q_4d_bsnh"], "q_reshape_to_4d"), + helper.make_node("Reshape", ["q_add_output", "q_attn_heads_output"], ["q_4d_bsnh"], "q_reshape_to_4d"), helper.make_node("Transpose", ["q_4d_bsnh"], ["q_4d_bnsh"], "q_transpose_to_bnsh", perm=[0, 2, 1, 3]), helper.make_node( "Div", - ["q_4d_bnsh", "q_attn_heads_output"], + ["q_4d_bnsh", "q_scale"], ["q_div_output"], "q_div_by_sqrt_head_size", ), @@ -149,16 +224,16 @@ def create_conformer_attention( k_nodes = [ helper.make_node( "MatMul", - ["layernorm_add_output_to_matmul", "q_weight"], + ["layernorm_add_output_to_matmul", "k_weight"], ["k_matmul_output"], "k_path_matmul", ), helper.make_node("Add", ["k_bias", "k_matmul_output"], ["k_add_output"], "k_path_add"), - helper.make_node("Reshape", ["k_add_output", "kv_bsnh_reshape"], ["k_4d_bsnh"], "k_reshape_to_4d"), + helper.make_node("Reshape", ["k_add_output", "k_attn_heads_output"], ["k_4d_bsnh"], "k_reshape_to_4d"), helper.make_node("Transpose", ["k_4d_bsnh"], ["k_4d_bnsh"], "k_transpose_to_bnsh", perm=[0, 2, 1, 3]), helper.make_node( "Concat", - ["inp_cache_k", "k_4d_bnsh"], + ["gather_past_k_output", "k_4d_bnsh"], ["oup_cache_k"], "concat_past_k_and_curr_k", axis=2, @@ -179,33 +254,46 @@ def create_conformer_attention( "v_path_matmul", ), helper.make_node("Add", ["v_bias", "v_matmul_output"], ["v_add_output"], "v_path_add"), - helper.make_node("Reshape", ["v_add_output", "kv_bsnh_reshape"], ["v_4d_bsnh"], "v_reshape_to_4d"), + helper.make_node("Reshape", ["v_add_output", "v_attn_heads_output"], ["v_4d_bsnh"], "v_reshape_to_4d"), helper.make_node("Transpose", ["v_4d_bsnh"], ["v_4d_bnsh"], "v_transpose_to_bnsh", perm=[0, 2, 1, 3]), helper.make_node( "Concat", - ["inp_cache_v", "v_4d_bnsh"], + ["gather_past_v_output", "v_4d_bnsh"], ["oup_cache_v"], "concat_past_v_and_curr_v", axis=2, ), ] - pos_k_reshape_node = [ + pos_embed = [ + helper.make_node("Reshape", ["q_div_output", "position_embed_output"], ["reshape_pos_emb"], "r_pos_emb"), + helper.make_node( + "Transpose", ["reshape_pos_emb"], ["transpose_reshape_pos_emb"], "p_transpose", perm=[1, 0, 2] + ), + helper.make_node( + "MatMul", + ["transpose_reshape_pos_emb", "transpose_reshape_pos_emb"], + ["pos_matmul"], + "pos_embed_matmul", + ), + helper.make_node( + "Transpose", ["pos_matmul"], ["transpose_pos_matmul"], "p_matmul_transpose", perm=[1, 0, 2] + ), helper.make_node( "Reshape", - ["layernorm_add_output_to_matmul", "pos_k_concat"], - ["pos_k_output"], - "Reshape_out_pos_k", - ) + ["transpose_pos_matmul", "position_embed_output"], + ["reshape_position_emb"], + "final_reshape_pos_emb", + ), ] nodes.extend(q_nodes) nodes.extend(k_nodes) nodes.extend(v_nodes) - nodes.extend(pos_k_reshape_node) + nodes.extend(pos_embed) # Create nodes used with qkv concats, reshapes, and transposes nodes.extend( [ - helper.make_node("Shape", ["layernorm_add_output_to_matmul"], ["shape_output"], "shape"), + helper.make_node("Shape", ["layernorm_add_output_to_matmul"], ["shape_output"], "shape", start=0), helper.make_node("Gather", ["shape_output", "idx_0"], 
["gather_0_output"], "gather_0", axis=0), helper.make_node( "Mul", @@ -216,84 +304,64 @@ def create_conformer_attention( helper.make_node( "Unsqueeze", ["mul_attn_heads_output", "unsqueeze_axes_input"], + ["unsqueeze_position_embed"], + "unsqueeze_position_embed", + ), + helper.make_node( + "Concat", + ["unsqueeze_position_embed", "neg_one", "head_size"], + ["position_embed_output"], + "position_embed_concat_output", + axis=0, + ), + helper.make_node( + "Unsqueeze", + ["gather_0_output", "unsqueeze_axes_input"], ["unsqueeze_attn_heads_output"], "unsqueeze_num_heads", ), helper.make_node( "Concat", - ["unsqueeze_attn_heads_output", "neg_one", "head_size"], + ["unsqueeze_attn_heads_output", "neg_one", "head_size", "q_bsnh_reshape"], ["q_attn_heads_output"], "q_num_heads", axis=0, ), helper.make_node( "Concat", - ["unsqueeze_attn_heads_output", "neg_one", "head_size"], + ["unsqueeze_attn_heads_output", "neg_one", "head_size", "q_bsnh_reshape"], ["k_attn_heads_output"], "k_num_heads", axis=0, ), helper.make_node( "Concat", - ["unsqueeze_attn_heads_output", "neg_one", "head_size"], + ["unsqueeze_attn_heads_output", "neg_one", "head_size", "q_bsnh_reshape"], ["v_attn_heads_output"], "v_num_heads", axis=0, ), helper.make_node( - "Constant", - inputs=[], - outputs=["q_bsnh_reshape"], - value=numpy_helper.from_array( - np.array([0, 0, num_heads, head_size], dtype="int64"), name="const_tensor" - ), + "Concat", + ["unsqueeze_attn_heads_output", "neg_one", "head_size"], + ["bsd_format"], + axis=0, ), helper.make_node( "Constant", inputs=[], - outputs=["kv_bsnh_reshape"], + outputs=["q_bsnh_reshape"], value=numpy_helper.from_array( - np.array([0, -1, num_heads, head_size], dtype="int64"), name="const_tensor" + np.array([0, 0, num_heads, head_size], dtype="int64"), name="const_tensor" ), ), - helper.make_node( - "Concat", - ["input_0"], - ["concat_pos_k"], - "pos_k_concat", - axis=0, - ), ] ) - # Create nodes used with Q x K' and softmax(Q x K') x V nodes.extend( [ - helper.make_node("Gather", ["shape_output", "idx_1"], ["gather_1_output"], "gather_1", axis=0), - helper.make_node( - "Unsqueeze", - ["gather_0_output", "unsqueeze_axes_input"], - ["unsqueeze_0_output"], - "unsqueeze_0", - ), - helper.make_node( - "Unsqueeze", - ["gather_1_output", "unsqueeze_axes_input"], - ["unsqueeze_1_output"], - "unsqueeze_1", - ), - helper.make_node( - "Concat", - ["unsqueeze_0_output", "num_heads", "unsqueeze_1_output", "head_size"], - ["bnsh_format"], - axis=0, - ), - helper.make_node( - "Concat", - ["unsqueeze_0_output", "unsqueeze_1_output", "hidden_size"], - ["bsd_format"], - axis=0, - ), + helper.make_node("Gather", ["inp_cache_v", "idx_0"], ["gather_past_v_output"], "gather_past_v", axis=0), + helper.make_node("Gather", ["inp_cache_k", "idx_0"], ["gather_past_k_output"], "gather_past_k", axis=0), ] ) @@ -319,7 +387,7 @@ def create_conformer_attention( "Add", [ "qk_output", - "pos_k_output", + "reshape_position_emb", ], ["add_qk_output"], "add_qk", @@ -338,12 +406,12 @@ def create_conformer_attention( "matmul_qkv", ), helper.make_node( - "Reshape", - ["qkv_output_(num_heads*batch_size,seq_len,head_size)", "bnsh_format"], - ["qkv_bnsh"], - "reshape_qkv_to_bnsh", + "Transpose", + ["qkv_output_(num_heads*batch_size,seq_len,head_size)"], + ["qkv_bsnh"], + "transpose_bnsh_to_bsnh", + perm=[0, 2, 1, 3], ), - helper.make_node("Transpose", ["qkv_bnsh"], ["qkv_bsnh"], "transpose_bnsh_to_bsnh", perm=[0, 2, 1, 3]), helper.make_node("Reshape", ["qkv_bsnh", "bsd_format"], ["attn_output"], "qkv_bsd"), ] ) @@ -408,24 +476,7 @@ 
diff --git a/onnxruntime/test/python/transformers/test_data/models/conformer/conformer_self_mha_fused.onnx b/onnxruntime/test/python/transformers/test_data/models/conformer/conformer_self_mha_fused.onnx
index c3c4b6a7d5a67..9d882751db265 100644
Binary files a/onnxruntime/test/python/transformers/test_data/models/conformer/conformer_self_mha_fused.onnx and b/onnxruntime/test/python/transformers/test_data/models/conformer/conformer_self_mha_fused.onnx differ
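The binary change to conformer_self_mha_fused.onnx tracks the generator rewrite above: the checked-in reference must match what the updated fused branch emits. A sketch of how the file might be regenerated; the save path and the call are assumptions, since the diff only records that the binary changed:

```python
# Assumed regeneration flow for the checked-in fused reference model.
import onnx

from conformer_model_generator import create_conformer_attention

fused = create_conformer_attention(fused=True)  # fused branch of the generator; signature assumed
onnx.save(fused, "test_data/models/conformer/conformer_self_mha_fused.onnx")
```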