diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc
index c4416068e2457..5a441b1d1701e 100644
--- a/onnxruntime/core/optimizer/graph_transformer_utils.cc
+++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc
@@ -50,6 +50,7 @@
 #include "core/optimizer/matmul_integer_to_float.h"
 #include "core/optimizer/matmul_scale_fusion.h"
 #include "core/optimizer/matmul_transpose_fusion.h"
+#include "core/optimizer/matmul_bn_fusion.h"
 #include "core/optimizer/nchwc_transformer.h"
 #include "core/optimizer/noop_elimination.h"
 #include "core/optimizer/not_where_fusion.h"
@@ -127,6 +128,7 @@ InlinedVector<std::unique_ptr<RewriteRule>> GenerateRewriteRules(
       rules.push_back(std::make_unique<ConvAddFusion>());
       rules.push_back(std::make_unique<ConvMulFusion>());
      rules.push_back(std::make_unique<ConvBNFusion>());
+      rules.push_back(std::make_unique<MatmulBNFusion>());
       rules.push_back(std::make_unique<ClipQuantFusion>());
       rules.push_back(std::make_unique<ReluQuantFusion>());
       break;
diff --git a/onnxruntime/core/optimizer/initializer.cc b/onnxruntime/core/optimizer/initializer.cc
index c8da15f65a6d7..9e807ddc7be59 100644
--- a/onnxruntime/core/optimizer/initializer.cc
+++ b/onnxruntime/core/optimizer/initializer.cc
@@ -291,7 +291,11 @@ Initializer& Initializer::sqrt() {
 namespace {
 template <typename T>
 struct ScaleByAxis {
-  void operator()(Tensor& data, const Tensor& scalers, const size_t block_size, const size_t num_blocks) const {
+  void operator()(Tensor& data,
+                  const Tensor& scalers,
+                  const size_t block_size,
+                  const size_t num_blocks,
+                  const bool column_major) const {
     ToNumeric<T> to_numeric;
     const auto scaler_size = scalers.Shape().Size();
     T* dst = data.MutableData<T>();
@@ -303,24 +307,32 @@ struct ScaleByAxis {
       }
     } else {
       for (size_t block_offset = 0, i = 0; i < num_blocks; i++) {
-        const auto numeric_scaler = to_numeric(scalers_data[i]);
-        for (size_t j = 0; j < block_size; ++j, ++block_offset) {
-          dst[block_offset] = T(to_numeric(dst[block_offset]) * numeric_scaler);
+        if (column_major) {
+          for (size_t j = 0; j < block_size; ++j, ++block_offset) {
+            const auto numeric_scaler = to_numeric(scalers_data[j]);
+            dst[block_offset] = T(to_numeric(dst[block_offset]) * numeric_scaler);
+          }
+        } else {
+          const auto numeric_scaler = to_numeric(scalers_data[i]);
+          for (size_t j = 0; j < block_size; ++j, ++block_offset) {
+            dst[block_offset] = T(to_numeric(dst[block_offset]) * numeric_scaler);
+          }
         }
       }
     }
   }
 };
-
 }  // namespace

-void Initializer::scale_by_axis(const Initializer& scalers, int axis) {
+void Initializer::scale_by_axis(const Initializer& scalers, int axis, bool column_major) {
   ORT_ENFORCE(axis >= 0, "Axis must be non-negative");
   const size_t block_size = narrow<size_t>(data_.Shape().SizeFromDimension(gsl::narrow_cast<size_t>(axis)));
   const size_t num_blocks = size() / block_size;
-  ORT_ENFORCE(scalers.size() == 1 || scalers.size() == num_blocks, "Invalid other(scalers) size");
+  ORT_ENFORCE(scalers.size() == 1 ||
+                  (column_major ? scalers.size() == block_size : scalers.size() == num_blocks),
+              "Invalid other(scalers) size");
   utils::MLTypeCallDispatcher<MLFloat16, BFloat16, float, double> t_disp(data_.GetElementType());
-  t_disp.Invoke<ScaleByAxis>(data_, scalers.data_, block_size, num_blocks);
+  t_disp.Invoke<ScaleByAxis>(data_, scalers.data_, block_size, num_blocks, column_major);
 }
 #endif  // ORT_EXTENDED_MINIMAL_BUILD
 }  // namespace onnxruntime
diff --git a/onnxruntime/core/optimizer/initializer.h b/onnxruntime/core/optimizer/initializer.h
index dfe054ba1aced..78e3fd6a3d24e 100644
--- a/onnxruntime/core/optimizer/initializer.h
+++ b/onnxruntime/core/optimizer/initializer.h
@@ -86,7 +86,7 @@ class Initializer final {

   Initializer& sqrt();

-  void scale_by_axis(const Initializer& other, int axis);
+  void scale_by_axis(const Initializer& other, int axis, bool column_major = false);
 #endif  // ORT_EXTENDED_MINIMAL_BUILD
  private:
   std::string name_;
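Note on the `scale_by_axis` change above: the new `column_major` flag selects which dimension of the tensor the scalers run along. A minimal standalone sketch of the two modes, using plain floats instead of ORT's `Tensor`/`ToNumeric` machinery (illustrative only, not part of the patch):

    #include <cstddef>
    #include <vector>

    // `data` holds `num_blocks` contiguous blocks of `block_size` elements
    // (for a row-major [num_blocks, block_size] matrix, a block is a row).
    // row-major mode:    element (i, j) is scaled by scalers[i] (one scaler per block)
    // column-major mode: element (i, j) is scaled by scalers[j] (one scaler per column)
    void scale_by_axis_sketch(std::vector<float>& data, const std::vector<float>& scalers,
                              std::size_t block_size, std::size_t num_blocks, bool column_major) {
      for (std::size_t i = 0, offset = 0; i < num_blocks; ++i) {
        for (std::size_t j = 0; j < block_size; ++j, ++offset) {
          data[offset] *= column_major ? scalers[j] : scalers[i];
        }
      }
    }

The fusion below depends on the column-major mode: MatMul's B tensor has shape [K, N] while the BatchNormalization parameters have length N, so each column j of B must be scaled by its own per-channel factor, which is what `matmul_b.scale_by_axis(scale, 1, true)` does.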
diff --git a/onnxruntime/core/optimizer/matmul_bn_fusion.cc b/onnxruntime/core/optimizer/matmul_bn_fusion.cc
new file mode 100644
index 0000000000000..e944522c9c338
--- /dev/null
+++ b/onnxruntime/core/optimizer/matmul_bn_fusion.cc
@@ -0,0 +1,230 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "core/optimizer/matmul_bn_fusion.h"
+#include "core/graph/graph_utils.h"
+#include "core/optimizer/initializer.h"
+#include "core/optimizer/utils.h"
+
+namespace onnxruntime {
+
+namespace {
+const std::vector<std::pair<std::string, std::vector<ONNX_NAMESPACE::OperatorSetVersion>>> ignorable_nodes{
+    {"Reshape", {1, 5, 13, 14, 19}},
+    {"Transpose", {1, 13}}};
+const std::pair<std::string, std::vector<ONNX_NAMESPACE::OperatorSetVersion>> dest = {"BatchNormalization", {1, 6, 7, 9, 14, 15}};
+}  // namespace
+
+bool NodeIsIgnorable(const Graph& graph, const Node& root_node, NodeIndex curr_node_index) {
+  const Node* curr_node = graph.GetNode(curr_node_index);
+
+  // curr_node has a different execution provider than its parent, or its
+  // output edge count != 1 (this condition handles the case where an ignorable
+  // node produces a graph output, i.e. a graph like "MatMul->Transpose").
+  if (curr_node->GetExecutionProviderType() != root_node.GetExecutionProviderType() ||
+      curr_node->GetOutputEdgesCount() != 1) {
+    return false;
+  }
+
+  // curr_node can be any of the ignorable_nodes.
+  for (size_t index = 0; index < ignorable_nodes.size(); index++) {
+    if (graph_utils::IsSupportedOptypeVersionAndDomain(*curr_node, ignorable_nodes[index].first, ignorable_nodes[index].second)) {
+      return true;
+    }
+  }
+
+  return false;
+}
+
+std::optional<NodeIndex> MatchPath(const Graph& graph, const Node& root_node, NodeIndex curr_node_index) {
+  while (NodeIsIgnorable(graph, root_node, curr_node_index)) {
+    curr_node_index = graph.GetNode(curr_node_index)->OutputNodesBegin()->Index();
+  }
+
+  // curr_node is neither ignorable nor dest.
+  const Node* curr_node = graph.GetNode(curr_node_index);
+  if (curr_node->OpType() != dest.first) {
+    return std::nullopt;
+  }
+
+  if (curr_node->GetExecutionProviderType() == root_node.GetExecutionProviderType() &&
+      graph_utils::IsSupportedOptypeVersionAndDomain(*curr_node, dest.first, dest.second)) {
+    return curr_node_index;
+  }
+
+  // Either curr_node has a different execution provider or an unsupported opset.
+  return std::nullopt;
+}
+
+/*
+ * Given a MatMul node, verify the following pattern:
+ *     MatMul                      GEMM
+ *       |                           |
+ *     Reshape ^        --->       Reshape ^
+ *       |                           |
+ *     Transpose ^                 Transpose ^
+ *       |
+ *     BatchNormalization
+ * Note: ^ means there can be 0 or more occurrences of that node.
+ * A few example fusable patterns:
+ * - MatMul -> Reshape -> Transpose -> BatchNormalization            ---> GEMM -> Reshape -> Transpose
+ * - MatMul -> Reshape -> BatchNormalization                         ---> GEMM -> Reshape
+ * - MatMul -> Transpose -> BatchNormalization                       ---> GEMM -> Transpose
+ * - MatMul -> Reshape -> Reshape -> BatchNormalization              ---> GEMM -> Reshape -> Reshape
+ * - MatMul -> Reshape -> Transpose -> Reshape -> BatchNormalization ---> GEMM -> Reshape -> Transpose -> Reshape
+ * - MatMul -> BatchNormalization                                    ---> GEMM
+ * Other conditions:
+ * - The B tensor of MatMul should be constant.
+ * - The scale, B, mean, and var tensors of BatchNormalization should be constant.
+ * - Every node in the path, except the BatchNormalization, should have only 1 output edge.
+ */
+bool MatmulBNFusion::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger&) const {
+  if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "MatMul", {1, 9, 13}) ||
+      node.GetOutputEdgesCount() != 1) {
+    return false;
+  }
+
+  if (graph.NodeProducesGraphOutput(node)) {
+    return false;
+  }
+
+  // Because the MatMul does not produce a graph output, it must have a child node.
+  NodeIndex child_node_index = node.OutputNodesBegin()->Index();
+  std::optional<NodeIndex> batch_norm_index = MatchPath(graph, node, child_node_index);
+  if (!batch_norm_index.has_value()) {
+    return false;
+  }
+
+  const Node* batch_norm_node = graph.GetNode(*batch_norm_index);
+
+  // Check that the appropriate inputs to the MatMul and BN nodes are constants.
+  if (!graph_utils::NodeArgIsConstant(graph, *node.InputDefs()[1]) ||
+      !graph_utils::NodeArgIsConstant(graph, *batch_norm_node->InputDefs()[1]) ||
+      !graph_utils::NodeArgIsConstant(graph, *batch_norm_node->InputDefs()[2]) ||
+      !graph_utils::NodeArgIsConstant(graph, *batch_norm_node->InputDefs()[3]) ||
+      !graph_utils::NodeArgIsConstant(graph, *batch_norm_node->InputDefs()[4])) {
+    return false;
+  }
+
+  // The first output from BN is required. Others are optional. If any optional output exists, we can't fuse.
+  const auto& output_defs = batch_norm_node->OutputDefs();
+  if (output_defs.size() > 1) {
+    for (size_t i = 1, end = output_defs.size(); i < end; ++i) {
+      if (output_defs[i] != nullptr && output_defs[i]->Exists()) {
+        return false;
+      }
+    }
+  }
+
+  return true;
+}
+
+/*
+ * BatchNormalization: [https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_batch_normalization_operator_desc]
+ *   Output = Scale * ((Input - Mean) / sqrt(Variance + Epsilon)) + Bias  // ignore the FusedActivation in the above definition; it is specific to DML
+ * Expanding out the terms:
+ *   Output = (Scale / sqrt(Variance + Epsilon)) * Input + (Scale / sqrt(Variance + Epsilon)) * -Mean + Bias
+ * Here,
+ *   [Scale / sqrt(Variance + Epsilon)] is constant; call it `alpha`.
+ *   [(Scale / sqrt(Variance + Epsilon)) * -Mean + Bias] is also constant; call it `beta`.
+ * So Output = alpha * Input + beta, where Input is the MatMul output X * B,
+ * and alpha can therefore be folded into the B tensor of the MatMul.
+ */
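Restating the folding the comment above describes, with X of shape [M, K], B of shape [K, N], and length-N BN parameters (no new behavior, just the algebra written out in the notation the comments already use):

    alpha = scale / sqrt(var + epsilon)              // length N
    beta  = bias - alpha * mean                      // length N
    BN(X @ B) = alpha * (X @ B) + beta
              = X @ B' + beta,   where B'[:, j] = alpha[j] * B[:, j]
              = Gemm(X, B', beta)

`Apply` below computes exactly this: `scale.div(var)` (after adding epsilon and taking the square root) produces alpha, `matmul_b.scale_by_axis(scale, 1, true)` produces B', and `bias.sub(mean)` after `mean.mul(scale)` produces beta.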
+Status MatmulBNFusion::Apply(Graph& graph, Node& matmul_node, RewriteRuleEffect& rule_effect, const logging::Logger&) const {
+  NodeIndex child_node_index = matmul_node.OutputNodesBegin()->Index();
+  NodeIndex batch_norm_node_index = MatchPath(graph, matmul_node, child_node_index).value();
+
+  Node& batch_norm_node = *graph.GetNode(batch_norm_node_index);  // need a mutable node, which is why we fetch it from the graph
+
+  // Only perform the fusion if epsilon is present and is of float (float32) type.
+  auto epsilon_attribute = batch_norm_node.GetAttributes().find("epsilon");
+  if (epsilon_attribute == batch_norm_node.GetAttributes().end() ||
+      epsilon_attribute->second.type() != ONNX_NAMESPACE::AttributeProto_AttributeType_FLOAT) {
+    return Status::OK();
+  }
+  const float epsilon = epsilon_attribute->second.f();
+
+  const onnx::TensorProto* scale_tensor = graph_utils::GetConstantInitializer(graph, batch_norm_node.InputDefs()[1]->Name());
+  ORT_ENFORCE(scale_tensor);
+  const onnx::TensorProto* bias_tensor = graph_utils::GetConstantInitializer(graph, batch_norm_node.InputDefs()[2]->Name());
+  ORT_ENFORCE(bias_tensor);
+  const onnx::TensorProto* mean_tensor = graph_utils::GetConstantInitializer(graph, batch_norm_node.InputDefs()[3]->Name());
+  ORT_ENFORCE(mean_tensor);
+  const onnx::TensorProto* var_tensor = graph_utils::GetConstantInitializer(graph, batch_norm_node.InputDefs()[4]->Name());
+  ORT_ENFORCE(var_tensor);
+  const onnx::TensorProto* matmul_b_tensor = graph_utils::GetConstantInitializer(graph, matmul_node.InputDefs()[1]->Name());
+  ORT_ENFORCE(matmul_b_tensor);
+
+  if (!optimizer_utils::IsFloatingPointDataType(*matmul_b_tensor) ||
+      !optimizer_utils::IsFloatingPointDataType(*scale_tensor) ||
+      !optimizer_utils::IsFloatingPointDataType(*bias_tensor) ||
+      !optimizer_utils::IsFloatingPointDataType(*mean_tensor) ||
+      !optimizer_utils::IsFloatingPointDataType(*var_tensor) ||
+      scale_tensor->dims_size() != 1 ||
+      bias_tensor->dims_size() != 1 ||
+      mean_tensor->dims_size() != 1 ||
+      var_tensor->dims_size() != 1 ||
+      scale_tensor->dims(0) != matmul_b_tensor->dims(1) ||
+      bias_tensor->dims(0) != matmul_b_tensor->dims(1) ||
+      mean_tensor->dims(0) != matmul_b_tensor->dims(1) ||
+      var_tensor->dims(0) != matmul_b_tensor->dims(1)) {
+    return Status::OK();
+  }
+
+  /*
+   * temp = scale / sqrt(var + epsilon)
+   * output = (temp * Input) + (bias - (temp * mean))
+   */
+  Initializer scale(*scale_tensor, graph.ModelPath());
+  Initializer bias(*bias_tensor, graph.ModelPath());
+  Initializer mean(*mean_tensor, graph.ModelPath());
+  Initializer var(*var_tensor, graph.ModelPath());
+  Initializer matmul_b(*matmul_b_tensor, graph.ModelPath());
+
+  var.add(epsilon);
+  var.sqrt();
+  scale.div(var);  // this is temp
+  matmul_b.scale_by_axis(scale, 1, true);
+
+  mean.mul(scale);
+  bias.sub(mean);
+
+  // Create the B tensorProto for the new Gemm node from the initializer.
+  ONNX_NAMESPACE::TensorProto new_gemm_b_tensor(*matmul_b_tensor);
+  matmul_b.ToProto(new_gemm_b_tensor);
+  const std::string new_gemm_b_name = graph.GenerateNodeArgName("MatMulBnFusion_GemmB_" + matmul_b_tensor->name());
+  new_gemm_b_tensor.set_name(new_gemm_b_name);
+  NodeArg& new_gemm_b_node_arg = graph_utils::AddInitializer(graph, new_gemm_b_tensor);
+
+  // Create the bias tensorProto for the new Gemm node from the initializer.
+  ONNX_NAMESPACE::TensorProto new_gemm_bias_tensor(*bias_tensor);
+  bias.ToProto(new_gemm_bias_tensor);
+  const std::string new_gemm_bias_name = graph.GenerateNodeArgName("MatMulBnFusion_GemmBias");
+  new_gemm_bias_tensor.set_name(new_gemm_bias_name);
+  NodeArg& new_gemm_bias_node_arg = graph_utils::AddInitializer(graph, new_gemm_bias_tensor);
+
+  Node& gemm_node = graph.AddNode(
+      graph.GenerateNodeArgName("MatMulBnFusion_Gemm"),
+      "Gemm",
+      "Generated from Matmul BatchNormalization fusion",
+      {matmul_node.MutableInputDefs()[0], &new_gemm_b_node_arg, &new_gemm_bias_node_arg},
+      matmul_node.MutableOutputDefs(),
+      nullptr,
+      kOnnxDomain);
+
+  // Remove the MatMul node.
+  Node* node = graph.GetNode(matmul_node.Index());
+  graph_utils::RemoveNodeOutputEdges(graph, *node);
+  graph.RemoveNode(matmul_node.Index());
+
+  // Delete the optional empty output defs, then delete the BatchNormalization node
+  // and update the input of the child of BatchNormalization.
+  batch_norm_node.MutableOutputDefs().resize(1);
+  NodeIndex batch_norm_parent_index = graph.GetNode(child_node_index)->OpType() == "BatchNormalization" ? gemm_node.Index() : batch_norm_node.InputNodesBegin()->Index();
+  graph_utils::FinalizeNodeFusion(graph, *graph.GetNode(batch_norm_parent_index), batch_norm_node);
+
+  rule_effect = RewriteRuleEffect::kRemovedCurrentNode;
+  return Status::OK();
+}
+}  // namespace onnxruntime
\ No newline at end of file
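A quick standalone check of the algebra the rule implements, on a tiny 1x2 * 2x2 example with made-up numbers (an illustrative sketch, independent of ORT; all values here are arbitrary):

    #include <cassert>
    #include <cmath>
    #include <cstdio>

    int main() {
      const float X[2] = {1.0f, 2.0f};                      // input, shape [1, 2]
      const float B[2][2] = {{3.0f, -1.0f}, {0.5f, 4.0f}};  // MatMul B, shape [2, 2]
      const float scale[2] = {0.7f, 1.3f}, bias[2] = {0.1f, -0.2f};
      const float mean[2] = {0.4f, 2.5f}, var[2] = {0.25f, 4.0f};
      const float epsilon = 1e-5f;

      for (int j = 0; j < 2; ++j) {
        // Reference: MatMul followed by BatchNormalization on output channel j.
        const float y = X[0] * B[0][j] + X[1] * B[1][j];
        const float reference = scale[j] * (y - mean[j]) / std::sqrt(var[j] + epsilon) + bias[j];

        // Fused: fold alpha into B's column j, fold beta into the Gemm bias.
        const float alpha = scale[j] / std::sqrt(var[j] + epsilon);
        const float beta = bias[j] - alpha * mean[j];
        const float fused = X[0] * (B[0][j] * alpha) + X[1] * (B[1][j] * alpha) + beta;

        assert(std::fabs(reference - fused) < 1e-4f);
      }
      std::printf("MatMul+BN matches Gemm(X, B', beta)\n");
      return 0;
    }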
diff --git a/onnxruntime/core/optimizer/matmul_bn_fusion.h b/onnxruntime/core/optimizer/matmul_bn_fusion.h
new file mode 100644
index 0000000000000..7a43483cf37d4
--- /dev/null
+++ b/onnxruntime/core/optimizer/matmul_bn_fusion.h
@@ -0,0 +1,27 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "core/optimizer/rewrite_rule.h"
+
+namespace onnxruntime {
+/*
+ * This fusion merges a BatchNormalization operator into the MatMul operator
+ * that precedes it, if and only if MatmulBNFusion::SatisfyCondition() is true.
+ */
+class MatmulBNFusion : public RewriteRule {
+ public:
+  MatmulBNFusion() : RewriteRule("MatMul_BatchNormalization_Fusion") {}
+
+  std::vector<std::string> TargetOpTypes() const noexcept override {
+    return {"MatMul"};
+  }
+
+ private:
+  bool SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const override;
+
+  Status Apply(Graph& graph, Node& matmul_node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const override;
+};
+}  // namespace onnxruntime
\ No newline at end of file
diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc
index 6acf631d53cd9..46b95a127b75c 100755
--- a/onnxruntime/test/optimizer/graph_transform_test.cc
+++ b/onnxruntime/test/optimizer/graph_transform_test.cc
@@ -31,6 +31,7 @@
 #include "core/optimizer/conv_add_act_fusion.h"
 #include "core/optimizer/conv_add_fusion.h"
 #include "core/optimizer/conv_bn_fusion.h"
+#include "core/optimizer/matmul_bn_fusion.h"
 #include "core/optimizer/conv_mul_fusion.h"
 #include "core/optimizer/div_mul_fusion.h"
 #include "core/optimizer/dropout_elimination.h"
@@ -1079,6 +1080,268 @@ TEST_F(GraphTransformationTests, FuseConvBNNoBias) {
   }
 }
 
+TEST_F(GraphTransformationTests, FuseMatmulBNWithInBetweenNodes) {
+  constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-matmul-bn-with-reshape.onnx";
+
+  std::shared_ptr<Model> p_model;
+  ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
+  Graph& graph = p_model->MainGraph();
+
+  std::string expected_output_name;
+  GraphViewer graphViewer(graph);
+  for (auto& node_index : graphViewer.GetNodesInTopologicalOrder()) {
+    auto& node = *graph.GetNode(node_index);
+    if (node.OpType() == "MatMul") {
+      expected_output_name = node.OutputDefs()[0]->Name();
+    }
+  }
+
+  onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
+  auto rule_transformer_L1 = std::make_unique<RuleBasedGraphTransformer>("RuleTransformerL1");
+  ASSERT_STATUS_OK(rule_transformer_L1->Register(std::make_unique<MatmulBNFusion>()));
+  ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1));
+
+  ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
+
+  std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
+  ASSERT_EQ(op_to_count["BatchNormalization"], 0);
+  ASSERT_EQ(op_to_count["MatMul"], 0);
+  ASSERT_EQ(op_to_count["Gemm"], 1);
+
+  for (auto& node : graph.Nodes()) {
+    if (node.OpType() == "Gemm") {
+      ASSERT_EQ(node.OutputDefs()[0]->Name(), expected_output_name)
+          << "fusion should produce the same output name as the MatMul node";
+    }
+  }
+}
+
+TEST_F(GraphTransformationTests, FuseMatmulBNWithEmptyOptionalOutputWithInBetweenNodes) {
+  constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-matmul-bn-with-reshape.onnx";
+
+  std::shared_ptr<Model> p_model;
+  ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
+  Graph& graph = p_model->MainGraph();
+
+  std::string expected_output_name;
+  GraphViewer graphViewer(graph);
+  for (auto& node_index : graphViewer.GetNodesInTopologicalOrder()) {
+    auto& node = *graph.GetNode(node_index);
+    if (node.OpType() == "MatMul") {
+      expected_output_name = node.OutputDefs()[0]->Name();
+    } else if (node.OpType() == "BatchNormalization") {
+      node.MutableOutputDefs().push_back(&graph.GetOrCreateNodeArg("", nullptr));
+    }
+  }
+
+  onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
+  auto rule_transformer_L1 = std::make_unique<RuleBasedGraphTransformer>("RuleTransformerL1");
+  ASSERT_STATUS_OK(rule_transformer_L1->Register(std::make_unique<MatmulBNFusion>()));
+  ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1));
+
+  ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
+
+  std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
+  ASSERT_EQ(op_to_count["BatchNormalization"], 0);
+  ASSERT_EQ(op_to_count["MatMul"], 0);
+  ASSERT_EQ(op_to_count["Gemm"], 1);
+
+  for (auto& node : graph.Nodes()) {
+    if (node.OpType() == "Gemm") {
+      ASSERT_EQ(node.OutputDefs()[0]->Name(), expected_output_name)
+          << "fusion should produce the same output name as the MatMul node";
+    }
+  }
+}
+
+// should not fuse
+TEST_F(GraphTransformationTests, FuseMatmulBNWithOptionalOutputWithInBetweenNodes) {
+  constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-matmul-bn-with-reshape.onnx";
+
+  std::shared_ptr<Model> p_model;
+  ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
+  Graph& graph = p_model->MainGraph();
+
+  GraphViewer graphViewer(graph);
+  for (auto& node_index : graphViewer.GetNodesInTopologicalOrder()) {
+    auto& node = *graph.GetNode(node_index);
+    if (node.OpType() == "BatchNormalization") {
+      // Add an additional, non-empty output to BatchNormalization.
+      ONNX_NAMESPACE::TypeProto optional_output_tensor_type;
+      optional_output_tensor_type.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TypeProto::kTensorType);
+      auto& arg = graph.GetOrCreateNodeArg("bn_optional_output", &optional_output_tensor_type);
+      node.MutableOutputDefs().push_back(&arg);
+    }
+  }
+
+  onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
+  auto rule_transformer_L1 = std::make_unique<RuleBasedGraphTransformer>("RuleTransformerL1");
+  ASSERT_STATUS_OK(rule_transformer_L1->Register(std::make_unique<MatmulBNFusion>()));
+  ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1));
+
+  ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
+
+  std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
+  ASSERT_EQ(op_to_count["BatchNormalization"], 1);
+  ASSERT_EQ(op_to_count["MatMul"], 1);
+  ASSERT_EQ(op_to_count["Gemm"], 0);
+}
+
+TEST_F(GraphTransformationTests, FuseMatmulBNDirectly) {
+  constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-matmul-bn-directly.onnx";
+
+  std::shared_ptr<Model> p_model;
+  ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
+  Graph& graph = p_model->MainGraph();
+
+  std::string expected_output_name;
+  GraphViewer graphViewer(graph);
+  for (auto& node_index : graphViewer.GetNodesInTopologicalOrder()) {
+    auto& node = *graph.GetNode(node_index);
+    if (node.OpType() == "BatchNormalization") {
+      expected_output_name = node.OutputDefs()[0]->Name();
+    }
+  }
+
+  onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
+  auto rule_transformer_L1 = std::make_unique<RuleBasedGraphTransformer>("RuleTransformerL1");
+  ASSERT_STATUS_OK(rule_transformer_L1->Register(std::make_unique<MatmulBNFusion>()));
+  ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1));
+
+  ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
+
+  std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
+  ASSERT_EQ(op_to_count["BatchNormalization"], 0);
+  ASSERT_EQ(op_to_count["MatMul"], 0);
+  ASSERT_EQ(op_to_count["Gemm"], 1);
+
+  for (auto& node : graph.Nodes()) {
+    if (node.OpType() == "Gemm") {
+      ASSERT_EQ(node.OutputDefs()[0]->Name(), expected_output_name)
+          << "fusion should produce the same output name as the last node";
+    }
+  }
+}
+
+TEST_F(GraphTransformationTests, FuseMatmulBNWithOnlyReshape) {
+  constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-matmul-bn-only-reshape.onnx";
+
+  std::shared_ptr<Model> p_model;
+  ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
+  Graph& graph = p_model->MainGraph();
+
+  std::string expected_output_name;
+  GraphViewer graphViewer(graph);
+  for (auto& node_index : graphViewer.GetNodesInTopologicalOrder()) {
+    auto& node = *graph.GetNode(node_index);
+    if (node.OpType() == "MatMul") {
+      expected_output_name = node.OutputDefs()[0]->Name();
+    }
+  }
+
+  onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
+  auto rule_transformer_L1 = std::make_unique<RuleBasedGraphTransformer>("RuleTransformerL1");
+  ASSERT_STATUS_OK(rule_transformer_L1->Register(std::make_unique<MatmulBNFusion>()));
+  ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1));
+
+  ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
+
+  std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
+  ASSERT_EQ(op_to_count["BatchNormalization"], 0);
+  ASSERT_EQ(op_to_count["MatMul"], 0);
+  ASSERT_EQ(op_to_count["Gemm"], 1);
+
+  for (auto& node : graph.Nodes()) {
+    if (node.OpType() == "Gemm") {
+      ASSERT_EQ(node.OutputDefs()[0]->Name(), expected_output_name)
+          << "fusion should produce the same output name as the MatMul node";
+    }
+  }
+}
+
+TEST_F(GraphTransformationTests, FuseMatmulBNWithOnlyTranspose) {
+  constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-matmul-bn-only-transpose.onnx";
+
+  std::shared_ptr<Model> p_model;
+  ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
+  Graph& graph = p_model->MainGraph();
+
+  std::string expected_output_name;
+  GraphViewer graphViewer(graph);
+  for (auto& node_index : graphViewer.GetNodesInTopologicalOrder()) {
+    auto& node = *graph.GetNode(node_index);
+    if (node.OpType() == "MatMul") {
+      expected_output_name = node.OutputDefs()[0]->Name();
+    }
+  }
+
+  onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
+  auto rule_transformer_L1 = std::make_unique<RuleBasedGraphTransformer>("RuleTransformerL1");
+  ASSERT_STATUS_OK(rule_transformer_L1->Register(std::make_unique<MatmulBNFusion>()));
+  ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1));
+
+  ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
+
+  std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
+  ASSERT_EQ(op_to_count["BatchNormalization"], 0);
+  ASSERT_EQ(op_to_count["MatMul"], 0);
+  ASSERT_EQ(op_to_count["Gemm"], 1);
+
+  for (auto& node : graph.Nodes()) {
+    if (node.OpType() == "Gemm") {
+      ASSERT_EQ(node.OutputDefs()[0]->Name(), expected_output_name)
+          << "fusion should produce the same output name as the MatMul node";
+    }
+  }
+}
+
+TEST_F(GraphTransformationTests, FuseMatmulBNWithoutBatchNormalization) {
+  constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-matmul-bn-only-transpose.onnx";
+
+  std::shared_ptr<Model> p_model;
+  ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
+  Graph& graph = p_model->MainGraph();
+
+  GraphViewer graphViewer(graph);
+  for (auto& node_index : graphViewer.GetNodesInTopologicalOrder()) {
+    auto& node = *graph.GetNode(node_index);
+    if (node.OpType() == "BatchNormalization") {
+      graph_utils::RemoveNode(graph, node);
+    }
+  }
+
+  onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
+  auto rule_transformer_L1 =
+      std::make_unique<RuleBasedGraphTransformer>("RuleTransformerL1");
+  ASSERT_STATUS_OK(rule_transformer_L1->Register(std::make_unique<MatmulBNFusion>()));
+  ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1));
+
+  ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
+
+  std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
+  ASSERT_EQ(op_to_count["MatMul"], 1);
+}
+
+// should not fuse
+TEST_F(GraphTransformationTests, FuseMatmulBNWithNonIgnorableNode) {
+  constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-matmul-bn-non-ignorable-node.onnx";
+
+  std::shared_ptr<Model> p_model;
+  ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
+  Graph& graph = p_model->MainGraph();
+
+  onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
+  auto rule_transformer_L1 = std::make_unique<RuleBasedGraphTransformer>("RuleTransformerL1");
+  ASSERT_STATUS_OK(rule_transformer_L1->Register(std::make_unique<MatmulBNFusion>()));
+  ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1));
+
+  ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
+
+  std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
+  ASSERT_EQ(op_to_count["BatchNormalization"], 1);
+  ASSERT_EQ(op_to_count["MatMul"], 1);
+  ASSERT_EQ(op_to_count["Gemm"], 0);
+}
+
 TEST_F(GraphTransformationTests, DontFuseConvWithBNWithOptionalOutputs) {
   constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-conv-bn-no-bias.onnx";
diff --git a/onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-directly.onnx b/onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-directly.onnx
new file mode 100644
index 0000000000000..fa11adaac8d95
Binary files /dev/null and b/onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-directly.onnx differ
diff --git a/onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-non-ignorable-node.onnx b/onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-non-ignorable-node.onnx
new file mode 100644
index 0000000000000..1050a7285b4a6
Binary files /dev/null and b/onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-non-ignorable-node.onnx differ
diff --git a/onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-only-reshape.onnx b/onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-only-reshape.onnx
new file mode 100644
index 0000000000000..c361a42700a30
Binary files /dev/null and b/onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-only-reshape.onnx differ
diff --git a/onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-only-transpose.onnx b/onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-only-transpose.onnx
new file mode 100644
index 0000000000000..f70ae2e6229e7
Binary files /dev/null and b/onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-only-transpose.onnx differ
diff --git a/onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-with-reshape.onnx b/onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-with-reshape.onnx
new file mode 100644
index 0000000000000..8e4bc49514548
Binary files /dev/null and b/onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-with-reshape.onnx differ
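For context on how the new rule is reached outside these unit tests: MatmulBNFusion is registered with the Level1 rewrite rules in graph_transformer_utils.cc above, and Level1 transformers run at the basic graph optimization level and higher. A minimal sketch with the public C++ API (the model path here is a placeholder):

    #include <onnxruntime_cxx_api.h>

    int main() {
      Ort::Env env{ORT_LOGGING_LEVEL_WARNING, "matmul_bn_fusion_demo"};
      Ort::SessionOptions session_options;
      // Level1 rewrite rules run at ORT_ENABLE_BASIC and above
      // (the default, ORT_ENABLE_ALL, includes them).
      session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_BASIC);
      // Creating the session applies the transformers; an eligible
      // MatMul -> [Reshape/Transpose...] -> BatchNormalization chain is fused into Gemm.
      Ort::Session session{env, ORT_TSTR("model_with_matmul_bn.onnx"), session_options};
      return 0;
    }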