Commit 0a05430
fix ut failure
centwang committed Oct 21, 2024
1 parent ac36682 commit 0a05430
Showing 2 changed files with 43 additions and 12 deletions.
42 changes: 38 additions & 4 deletions onnxruntime/core/optimizer/matmul_add_fusion.cc
@@ -15,6 +15,9 @@ Status MatMulAddFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
   GraphViewer graph_viewer(graph);
   const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
 
+  std::unordered_set<const Node*> attn_ln_nodes;
+  std::unordered_set<const Node*> attn_add_nodes;
+
Warning on line 19 of onnxruntime/core/optimizer/matmul_add_fusion.cc (GitHub Actions / Optional Lint C++): [cpplint] reported by reviewdog 🐶 Add #include <unordered_set> for unordered_set<> [build/include_what_you_use] [4]
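The fix the linter asks for is a one-line addition to the file's include block, which sits above line 15 and therefore outside the lines shown in this diff:

    #include <unordered_set>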

   for (auto node_index : node_topology_list) {
     auto* node_ptr = graph.GetNode(node_index);
     if (!node_ptr)
@@ -74,6 +77,38 @@ Status MatMulAddFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
     std::vector<int64_t> shape_values;
     int64_t m = 0, k = 0, n = 0;
     if (need_reshape) {
+      // Skip Attention pattern, AttentionFusion will handle it. In such case, there are 4 MatMul-Add pairs,
+      // 3 of them are following LN, the other one produces output which is added by LN's output.
+      const Node* parent_node = graph.GetProducerNode(matmul_input_defs[0]->Name());
+      if (attn_ln_nodes.count(parent_node) > 0 || attn_add_nodes.count(&next_node) > 0) {
+        continue;
+      }
+      if (parent_node && parent_node->OpType() == "LayerNormalization") {
+        unsigned int add_count = 0;
+        unsigned int matmul_count = 0;
+        unsigned int shape_count = 0;
+        const Node* ln_add_node = nullptr;
+        for (auto it = parent_node->OutputNodesBegin(); it != parent_node->OutputNodesEnd(); ++it) {
+          std::string op_type = (*it).OpType();
+          if (op_type == "Add") {
+            ln_add_node = &(*it);
+            add_count++;
+          } else if (op_type == "MatMul") {
+            matmul_count++;
+          } else if (op_type == "Shape") {
+            shape_count++;
+          }
+        }
+        if (add_count == 1 && matmul_count == 3 && shape_count == parent_node->GetOutputEdgesCount() - 4) {
+          const Node* attn_add_node = graph.GetProducerNode(ln_add_node->InputDefs()[0]->Name());
+          if (attn_add_node && attn_add_node->OpType() == "Add") {
+            attn_ln_nodes.insert(parent_node);
+            attn_add_nodes.insert(attn_add_node);
+            continue;
+          }
+        }
+      }
+
       // Logically we can use Shape-Concat to produce shape input for Reshape, to keep it simple, we require
       // both inputs have concrete shape for now, we can add dynamic shape support in future.
       bool is_concrete_shape = true;
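For context, the subgraph being skipped is BERT-style post-LayerNormalization attention: the LN output feeds the three Q/K/V MatMul+Add projections plus one residual Add, and a fourth MatMul+Add produces the value that residual Add consumes. Below is a minimal sketch of that topology using the ModelTestBuilder helpers seen in graph_transform_test.cc; all shapes, initializer values, and the build_attention_like name are illustrative assumptions, and the attention math between the projections and the output projection is elided:

    // Illustrative sketch only -- not a test added by this commit.
    auto build_attention_like = [](ModelTestBuilder& builder) {
      auto* input = builder.MakeInput<float>({{8, 16, 32}});
      auto* scale = builder.MakeInitializer<float>({32}, std::vector<float>(32, 1.0f));
      auto* bias = builder.MakeInitializer<float>({32}, std::vector<float>(32, 0.0f));
      auto* ln_out = builder.MakeIntermediate();
      builder.AddNode("LayerNormalization", {input, scale, bias}, {ln_out});
      // Q/K/V projections: three MatMul+Add pairs fed by the LN output.
      NodeArg* qkv[3];
      for (auto*& proj : qkv) {
        auto* w = builder.MakeInitializer<float>({32, 32}, std::vector<float>(32 * 32, 0.1f));
        auto* mm = builder.MakeIntermediate();
        proj = builder.MakeIntermediate();
        builder.AddNode("MatMul", {ln_out, w}, {mm});
        builder.AddNode("Add", {mm, bias}, {proj});
      }
      // (Attention math over qkv elided.) A fourth MatMul+Add projects the
      // attention result, and the residual Add joins it with the LN output, so
      // the LN output has exactly 3 MatMul consumers and 1 Add consumer.
      auto* w_out = builder.MakeInitializer<float>({32, 32}, std::vector<float>(32 * 32, 0.1f));
      auto* mm_out = builder.MakeIntermediate();
      auto* dense_out = builder.MakeIntermediate();
      builder.AddNode("MatMul", {qkv[0], w_out}, {mm_out});  // stand-in for the attention result
      builder.AddNode("Add", {mm_out, bias}, {dense_out});
      builder.AddNode("Add", {dense_out, ln_out}, {builder.MakeOutput()});
    };

With this shape, the new bookkeeping records the LN node in attn_ln_nodes and the output-projection Add in attn_add_nodes, so none of the four MatMul+Add pairs is folded into Gemm before AttentionFusion runs.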
@@ -121,10 +156,9 @@ Status MatMulAddFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
     };
 
     bool valid = ((bias_shape.dim_size() == 1 && bias_shape.dim(0) == dim_n) ||
-                  (bias_shape.dim_size() == 2 && dim_has_value_1(bias_shape.dim(0)) && bias_shape.dim(1) == dim_n) ||
-                  (bias_shape.dim_size() == 2 &&
-                   ((!need_reshape && bias_shape.dim(0) == matmul_a_shape->dim(0)) ||
-                    (need_reshape && utils::HasDimValue(bias_shape.dim(0)) && bias_shape.dim(0).dim_value() == m)) &&
+                  (!need_reshape && bias_shape.dim_size() == 2 && dim_has_value_1(bias_shape.dim(0)) &&
+                   bias_shape.dim(1) == dim_n) ||
+                  (!need_reshape && bias_shape.dim_size() == 2 && bias_shape.dim(0) == matmul_a_shape->dim(0) &&
                    (dim_has_value_1(bias_shape.dim(1)) || bias_shape.dim(1) == dim_n)));
     if (!valid) {
       continue;
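The net effect of the rewritten condition: a 1-D (N) bias remains fusable with or without the inserted Reshape, while a 2-D bias is now fusable only when no Reshape is needed. A standalone restatement under the assumption that all dims are concrete (IsBiasFusable is an illustrative helper, not a function in the codebase; symbolic-dim handling is omitted):

    #include <cstdint>
    #include <vector>

    // Sketch of the updated `valid` predicate, with m/n the GEMM output dims
    // after any flattening of the MatMul A input.
    bool IsBiasFusable(bool need_reshape, const std::vector<int64_t>& bias,
                       int64_t m, int64_t n) {
      if (bias.size() == 1) return bias[0] == n;  // (N): valid with or without Reshape
      if (bias.size() == 2 && !need_reshape) {
        return (bias[0] == 1 && bias[1] == n) ||                 // (1, N)
               (bias[0] == m && (bias[1] == 1 || bias[1] == n));  // (M, 1) or (M, N)
      }
      return false;  // 2-D bias combined with need_reshape is no longer fused
    }

For example, IsBiasFusable(/*need_reshape*/ true, {1, 768}, 128, 768) is now false, which is exactly why the NeedReshape_3D test below switches its bias from {1, 768} to {768}.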
13 changes: 5 additions & 8 deletions onnxruntime/test/optimizer/graph_transform_test.cc
@@ -2414,10 +2414,7 @@ TEST_F(GraphTransformationTests, MatMulAddFusion_three_input) {
   ASSERT_TRUE(op_to_count["Gemm"] == 1);
 }
 
-// Matmul+Add with shape [k]*[k,N]+[N], won't do the fusion
-// We can do the fusion by changing shape to [1,k]*[k,N]+[1,N], then add a reshape [1,N]=>[N]
-// This will bring extra cost. And there's only very limited gain to fuse Matmul+Add to Gemm
-// Since the basic implementation is almost same
+// Matmul+Add with concrete shape [k]*[k,N]+[N], will fuse to Reshape nodes and Gemm.
 TEST_F(GraphTransformationTests, MatMulAddFusion_negitive_case) {
   constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "matmul_add_fusion/3Input/neg_model.onnx";
 
@@ -2430,9 +2427,9 @@ TEST_F(GraphTransformationTests, MatMulAddFusion_negitive_case) {
   ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
 
   std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
-  ASSERT_TRUE(op_to_count["MatMul"] == 1);
-  ASSERT_TRUE(op_to_count["Add"] == 1);
-  ASSERT_TRUE(op_to_count["Gemm"] == 0);
+  ASSERT_TRUE(op_to_count["MatMul"] == 0);
+  ASSERT_TRUE(op_to_count["Add"] == 0);
+  ASSERT_TRUE(op_to_count["Gemm"] == 1);
 }
 
 // Matmul+Add with shape [M,k]*[k,N]+[1,4], won't do the fusion
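Note on the assertion flip in the hunk above: with the Reshape-based fusion from the parent commit, the [k]*[k,N]+[N] case is no longer a negative case. The 1-D input is reshaped to [1,k], the MatMul+Add pair folds into a single Gemm producing [1,N] (the 1-D (N) bias broadcasts as Gemm's C input), and a trailing Reshape restores [N], so the transformed graph contains zero MatMul/Add nodes and one Gemm. The test name, negitive_case, is now historical.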
@@ -2506,7 +2503,7 @@ TEST_F(GraphTransformationTests, MatMulAddFusion_NeedReshape_3D) {
   auto build_test_case = [&](ModelTestBuilder& builder) {
     auto* input_arg = builder.MakeInput<float>({{8, 16, 32}});
     auto* weight_arg = builder.MakeInput<float>({{32, 768}});
-    auto* bias_arg = builder.MakeInput<float>({{1, 768}});
+    auto* bias_arg = builder.MakeInput<float>({{768}});
     auto* matmul_out = builder.MakeIntermediate();
     auto* output_arg = builder.MakeOutput();
     builder.AddNode("MatMul", {input_arg, weight_arg}, {matmul_out});
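The bias input here drops from {1, 768} to {768} because, per the tightened check in matmul_add_fusion.cc above, a 2-D (1, N) bias no longer passes validation when a Reshape is required; only a 1-D (N) bias does, which is what this test now exercises.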
