diff --git a/onnxruntime/core/optimizer/matmul_add_fusion.cc b/onnxruntime/core/optimizer/matmul_add_fusion.cc
index ae23777fa55ec..c1a776a9a1b8e 100644
--- a/onnxruntime/core/optimizer/matmul_add_fusion.cc
+++ b/onnxruntime/core/optimizer/matmul_add_fusion.cc
@@ -15,6 +15,9 @@ Status MatMulAddFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
   GraphViewer graph_viewer(graph);
   const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
 
+  std::unordered_set<const Node*> attn_ln_nodes;
+  std::unordered_set<const Node*> attn_add_nodes;
+
   for (auto node_index : node_topology_list) {
     auto* node_ptr = graph.GetNode(node_index);
     if (!node_ptr)
@@ -74,6 +77,38 @@ Status MatMulAddFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
     std::vector<int64_t> shape_values;
     int64_t m = 0, k = 0, n = 0;
     if (need_reshape) {
+      // Skip the Attention pattern; AttentionFusion will handle it. In that case there are 4 MatMul-Add pairs:
+      // 3 of them follow the LayerNormalization, and the other one produces an output that is added to LN's output.
+      const Node* parent_node = graph.GetProducerNode(matmul_input_defs[0]->Name());
+      if (attn_ln_nodes.count(parent_node) > 0 || attn_add_nodes.count(&next_node) > 0) {
+        continue;
+      }
+      if (parent_node && parent_node->OpType() == "LayerNormalization") {
+        unsigned int add_count = 0;
+        unsigned int matmul_count = 0;
+        unsigned int shape_count = 0;
+        const Node* ln_add_node = nullptr;
+        for (auto it = parent_node->OutputNodesBegin(); it != parent_node->OutputNodesEnd(); ++it) {
+          std::string op_type = (*it).OpType();
+          if (op_type == "Add") {
+            ln_add_node = &(*it);
+            add_count++;
+          } else if (op_type == "MatMul") {
+            matmul_count++;
+          } else if (op_type == "Shape") {
+            shape_count++;
+          }
+        }
+        if (add_count == 1 && matmul_count == 3 && shape_count == parent_node->GetOutputEdgesCount() - 4) {
+          const Node* attn_add_node = graph.GetProducerNode(ln_add_node->InputDefs()[0]->Name());
+          if (attn_add_node && attn_add_node->OpType() == "Add") {
+            attn_ln_nodes.insert(parent_node);
+            attn_add_nodes.insert(attn_add_node);
+            continue;
+          }
+        }
+      }
+
       // Logically we can use Shape-Concat to produce shape input for Reshape, to keep it simple, we require
       // both inputs have concrete shape for now, we can add dynamic shape support in future.
       bool is_concrete_shape = true;
@@ -121,10 +156,9 @@ Status MatMulAddFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
     };
 
     bool valid = ((bias_shape.dim_size() == 1 && bias_shape.dim(0) == dim_n) ||
-                  (bias_shape.dim_size() == 2 && dim_has_value_1(bias_shape.dim(0)) && bias_shape.dim(1) == dim_n) ||
-                  (bias_shape.dim_size() == 2 &&
-                   ((!need_reshape && bias_shape.dim(0) == matmul_a_shape->dim(0)) ||
-                    (need_reshape && utils::HasDimValue(bias_shape.dim(0)) && bias_shape.dim(0).dim_value() == m)) &&
+                  (!need_reshape && bias_shape.dim_size() == 2 && dim_has_value_1(bias_shape.dim(0)) &&
+                   bias_shape.dim(1) == dim_n) ||
+                  (!need_reshape && bias_shape.dim_size() == 2 && bias_shape.dim(0) == matmul_a_shape->dim(0) &&
                    (dim_has_value_1(bias_shape.dim(1)) || bias_shape.dim(1) == dim_n)));
     if (!valid) {
       continue;
diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc
index 4221d2eaaae68..6448961df9331 100755
--- a/onnxruntime/test/optimizer/graph_transform_test.cc
+++ b/onnxruntime/test/optimizer/graph_transform_test.cc
@@ -2414,10 +2414,7 @@ TEST_F(GraphTransformationTests, MatMulAddFusion_three_input) {
   ASSERT_TRUE(op_to_count["Gemm"] == 1);
 }
 
-// Matmul+Add with shape [k]*[k,N]+[N], won't do the fusion
-// We can do the fusion by changing shape to [1,k]*[k,N]+[1,N], then add a reshape [1,N]=>[N]
-// This will bring extra cost. And there's only very limited gain to fuse Matmul+Add to Gemm
-// Since the basic implementation is almost same
+// Matmul+Add with concrete shapes [k]*[k,N]+[N] is now fused into Reshape nodes and a Gemm.
 TEST_F(GraphTransformationTests, MatMulAddFusion_negitive_case) {
   constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "matmul_add_fusion/3Input/neg_model.onnx";
@@ -2430,9 +2427,9 @@ TEST_F(GraphTransformationTests, MatMulAddFusion_negitive_case) {
   ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
   std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
-  ASSERT_TRUE(op_to_count["MatMul"] == 1);
-  ASSERT_TRUE(op_to_count["Add"] == 1);
-  ASSERT_TRUE(op_to_count["Gemm"] == 0);
+  ASSERT_TRUE(op_to_count["MatMul"] == 0);
+  ASSERT_TRUE(op_to_count["Add"] == 0);
+  ASSERT_TRUE(op_to_count["Gemm"] == 1);
 }
 
 // Matmul+Add with shape [M,k]*[k,N]+[1,4], won't do the fusion
@@ -2506,7 +2503,7 @@ TEST_F(GraphTransformationTests, MatMulAddFusion_NeedReshape_3D) {
   auto build_test_case = [&](ModelTestBuilder& builder) {
     auto* input_arg = builder.MakeInput<float>({{8, 16, 32}});
     auto* weight_arg = builder.MakeInput<float>({{32, 768}});
-    auto* bias_arg = builder.MakeInput<float>({{1, 768}});
+    auto* bias_arg = builder.MakeInput<float>({{768}});
     auto* matmul_out = builder.MakeIntermediate();
     auto* output_arg = builder.MakeOutput();
     builder.AddNode("MatMul", {input_arg, weight_arg}, {matmul_out});
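
For reference, the bias-shape rule enforced by the updated `valid` expression can be illustrated with a minimal standalone sketch. This is not the ONNX Runtime code: dimensions are plain int64_t values (symbolic dims and the TensorShapeProto comparisons are elided), and BiasShapeValid / DimHasValue1 are hypothetical names mirroring the diff's `valid` expression and `dim_has_value_1` lambda.

// Minimal standalone sketch (not the ONNX Runtime API) of the bias-shape rule
// in the updated `valid` expression above.
#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical helper mirroring the diff's dim_has_value_1 lambda.
bool DimHasValue1(int64_t dim) { return dim == 1; }

// Hypothetical predicate: is the Add's bias acceptable for fusing onto the
// Gemm output (M x N) under the rules kept by this change?
bool BiasShapeValid(const std::vector<int64_t>& bias_shape, int64_t m, int64_t n,
                    bool need_reshape) {
  // [N] biases are always accepted.
  if (bias_shape.size() == 1 && bias_shape[0] == n) return true;
  // [1, N] biases are accepted only when the MatMul input is already 2-D.
  if (!need_reshape && bias_shape.size() == 2 && DimHasValue1(bias_shape[0]) &&
      bias_shape[1] == n) {
    return true;
  }
  // [M, 1] or [M, N] biases are likewise accepted only without a reshape.
  if (!need_reshape && bias_shape.size() == 2 && bias_shape[0] == m &&
      (DimHasValue1(bias_shape[1]) || bias_shape[1] == n)) {
    return true;
  }
  return false;
}

int main() {
  // 3-D input [8, 16, 32] flattened to M = 8 * 16 before Gemm, N = 768.
  std::cout << BiasShapeValid({768}, 8 * 16, 768, /*need_reshape=*/true) << "\n";     // 1
  std::cout << BiasShapeValid({1, 768}, 8 * 16, 768, /*need_reshape=*/true) << "\n";  // 0
  // 2-D input: [1, N] and [M, N] / [M, 1] biases still fuse.
  std::cout << BiasShapeValid({1, 768}, 16, 768, /*need_reshape=*/false) << "\n";     // 1
  std::cout << BiasShapeValid({16, 1}, 16, 768, /*need_reshape=*/false) << "\n";      // 1
  return 0;
}

In other words, when need_reshape is true (MatMul input rank above 2), only a 1-D [N] bias is accepted, which is why the MatMulAddFusion_NeedReshape_3D test's bias shape changes from {1, 768} to {768}.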