From ae8561979f494029c863dafb67bae05639ebff60 Mon Sep 17 00:00:00 2001 From: Sumit Agarwal Date: Tue, 24 Oct 2023 19:41:10 -0700 Subject: [PATCH] Introduce new optimizer MatMul + BatchNormalization (#17915) ### Description Introduce new ORT L1 optimizer under RewriteRule category to fuse MatMul + BatchNormalization node. This optimizer look for a specific pattern observed in one of the impacting customer models and fuse the Matmul and Batchnormalization node into a Gemm node. For details on the pattern matching and fusion please refer to the comment section of `matmul_bn_fusion.cc`. To visualize, this optimizer will replace following subgraph to a Gemm node.
               MatMul                  GEMM
                 |                       |
              Reshape ^     --->      Reshape ^
                 |                       |
            Transpose ^             Transpose ^
                 |
       BatchNormalization
Note: ^ means there can be >=0 occurrence(s) of that node.
Few example fusable pattern:
* - MatMul -> Reshape -> Transpose -> BatchNormalization ---> GEMM ->
Reshape -> Transpose
* - MatMul -> Reshape -> BatchNormalization ---> GEMM -> Reshape
* - MatMul -> Transpose -> BatchNormalization ---> GEMM -> Transpose
* - MatMul -> Reshape -> Reshape -> BatchNormalization ---> GEMM ->
Reshape -> Reshape
* - MatMul -> Reshape -> Transpose -> Reshape -> BatchNormalization --->
GEMM -> Reshape -> Transpose -> Reshape
* - MatMul -> BatchNormalization ---> GEMM
Note: This optimizer may evolve in the future to be more generic in terms of the pattern matching. ### Motivation and Context - Why is this change required? What problem does it solve? One of the user of ORT+DML ep needs this to better target the model to DML. But this transformation applies more broadly, so added L1 optimizer. --- .../core/optimizer/graph_transformer_utils.cc | 2 + onnxruntime/core/optimizer/initializer.cc | 28 +- onnxruntime/core/optimizer/initializer.h | 2 +- .../core/optimizer/matmul_bn_fusion.cc | 230 +++++++++++++++ onnxruntime/core/optimizer/matmul_bn_fusion.h | 27 ++ .../test/optimizer/graph_transform_test.cc | 263 ++++++++++++++++++ .../fusion/fuse-matmul-bn-directly.onnx | Bin 0 -> 513 bytes .../fuse-matmul-bn-non-ignorable-node.onnx | Bin 0 -> 593 bytes .../fusion/fuse-matmul-bn-only-reshape.onnx | Bin 0 -> 639 bytes .../fusion/fuse-matmul-bn-only-transpose.onnx | Bin 0 -> 579 bytes .../fusion/fuse-matmul-bn-with-reshape.onnx | Bin 0 -> 709 bytes 11 files changed, 543 insertions(+), 9 deletions(-) create mode 100644 onnxruntime/core/optimizer/matmul_bn_fusion.cc create mode 100644 onnxruntime/core/optimizer/matmul_bn_fusion.h create mode 100644 onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-directly.onnx create mode 100644 onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-non-ignorable-node.onnx create mode 100644 onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-only-reshape.onnx create mode 100644 onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-only-transpose.onnx create mode 100644 onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-with-reshape.onnx diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index c4416068e2457..5a441b1d1701e 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -50,6 +50,7 @@ #include "core/optimizer/matmul_integer_to_float.h" #include "core/optimizer/matmul_scale_fusion.h" #include "core/optimizer/matmul_transpose_fusion.h" +#include "core/optimizer/matmul_bn_fusion.h" #include "core/optimizer/nchwc_transformer.h" #include "core/optimizer/noop_elimination.h" #include "core/optimizer/not_where_fusion.h" @@ -127,6 +128,7 @@ InlinedVector> GenerateRewriteRules( rules.push_back(std::make_unique()); rules.push_back(std::make_unique()); rules.push_back(std::make_unique()); + rules.push_back(std::make_unique()); rules.push_back(std::make_unique()); rules.push_back(std::make_unique()); break; diff --git a/onnxruntime/core/optimizer/initializer.cc b/onnxruntime/core/optimizer/initializer.cc index c8da15f65a6d7..9e807ddc7be59 100644 --- a/onnxruntime/core/optimizer/initializer.cc +++ b/onnxruntime/core/optimizer/initializer.cc @@ -291,7 +291,11 @@ Initializer& Initializer::sqrt() { namespace { template struct ScaleByAxis { - void operator()(Tensor& data, const Tensor& scalers, const size_t block_size, const size_t num_blocks) const { + void operator()(Tensor& data, + const Tensor& scalers, + const size_t block_size, + const size_t num_blocks, + const bool column_major) const { ToNumeric to_numeric; const auto scaler_size = scalers.Shape().Size(); T* dst = data.MutableData(); @@ -303,24 +307,32 @@ struct ScaleByAxis { } } else { for (size_t block_offset = 0, i = 0; i < num_blocks; i++) { - const auto numeric_scaler = to_numeric(scalers_data[i]); - for (size_t j = 0; j < block_size; ++j, ++block_offset) { - dst[block_offset] = T(to_numeric(dst[block_offset]) * numeric_scaler); + if (column_major) { + for (size_t j = 0; j < block_size; ++j, ++block_offset) { + const auto numeric_scaler = to_numeric(scalers_data[j]); + dst[block_offset] = T(to_numeric(dst[block_offset]) * numeric_scaler); + } + } else { + const auto numeric_scaler = to_numeric(scalers_data[i]); + for (size_t j = 0; j < block_size; ++j, ++block_offset) { + dst[block_offset] = T(to_numeric(dst[block_offset]) * numeric_scaler); + } } } } } }; - } // namespace -void Initializer::scale_by_axis(const Initializer& scalers, int axis) { +void Initializer::scale_by_axis(const Initializer& scalers, int axis, bool column_major) { ORT_ENFORCE(axis >= 0, "Axis must be non-negative"); const size_t block_size = narrow(data_.Shape().SizeFromDimension(gsl::narrow_cast(axis))); const size_t num_blocks = size() / block_size; - ORT_ENFORCE(scalers.size() == 1 || scalers.size() == num_blocks, "Invalid other(scalers) size"); + ORT_ENFORCE(scalers.size() == 1 || + (column_major ? scalers.size() == block_size : scalers.size() == num_blocks), + "Invalid other(scalers) size"); utils::MLTypeCallDispatcher t_disp(data_.GetElementType()); - t_disp.Invoke(data_, scalers.data_, block_size, num_blocks); + t_disp.Invoke(data_, scalers.data_, block_size, num_blocks, column_major); } #endif // ORT_EXTENDED_MINIMAL_BUILD } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/initializer.h b/onnxruntime/core/optimizer/initializer.h index dfe054ba1aced..78e3fd6a3d24e 100644 --- a/onnxruntime/core/optimizer/initializer.h +++ b/onnxruntime/core/optimizer/initializer.h @@ -86,7 +86,7 @@ class Initializer final { Initializer& sqrt(); - void scale_by_axis(const Initializer& other, int axis); + void scale_by_axis(const Initializer& other, int axis, bool column_major = false); #endif // ORT_EXTENDED_MINIMAL_BUILD private: std::string name_; diff --git a/onnxruntime/core/optimizer/matmul_bn_fusion.cc b/onnxruntime/core/optimizer/matmul_bn_fusion.cc new file mode 100644 index 0000000000000..e944522c9c338 --- /dev/null +++ b/onnxruntime/core/optimizer/matmul_bn_fusion.cc @@ -0,0 +1,230 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/optimizer/matmul_bn_fusion.h" +#include "core/graph/graph_utils.h" +#include "core/optimizer/initializer.h" +#include "core/optimizer/utils.h" + +namespace onnxruntime { + +namespace { +const std::vector>> ignorable_nodes{ + {"Reshape", {1, 5, 13, 14, 19}}, + {"Transpose", {1, 13}}}; +const std::pair> dest = {"BatchNormalization", {1, 6, 7, 9, 14, 15}}; +} // namespace + +bool NodeIsIgnorable(const Graph& graph, const Node& root_node, NodeIndex curr_node_index) { + const Node* curr_node = graph.GetNode(curr_node_index); + + // curr_node has different execution provider then it's parent or + // has output edge != 1 (this condition will handle the case when ignorable node + // is graph output i.e. a graph like this "MatMul->Transpose") + if (curr_node->GetExecutionProviderType() != root_node.GetExecutionProviderType() || + curr_node->GetOutputEdgesCount() != 1) { + return false; + } + + // curr_node can be any of the ignorable_nodes. + for (size_t index = 0; index < ignorable_nodes.size(); index++) { + if (graph_utils::IsSupportedOptypeVersionAndDomain(*curr_node, ignorable_nodes[index].first, ignorable_nodes[index].second)) { + return true; + } + } + + return false; +} + +std::optional MatchPath(const Graph& graph, const Node& root_node, NodeIndex curr_node_index) { + while (NodeIsIgnorable(graph, root_node, curr_node_index)) { + curr_node_index = graph.GetNode(curr_node_index)->OutputNodesBegin()->Index(); + } + + // curr_node is neither ignorable nor dest + const Node* curr_node = graph.GetNode(curr_node_index); + if (curr_node->OpType() != dest.first) { + return std::nullopt; + } + + if (curr_node->GetExecutionProviderType() == root_node.GetExecutionProviderType() && + graph_utils::IsSupportedOptypeVersionAndDomain(*curr_node, dest.first, dest.second)) { + return curr_node_index; + } + + // either curr_node has different execution provider or + // has invalid opset. + return std::nullopt; +} + +/* + * Given a MatMul node, it will verify the following pattern. + * MatMul GEMM + * | | + * Reshape ^ ---> Reshape ^ + * | | + * Transpose ^ Transpose ^ + * | + * BatchNormalization + * Note: ^ means there can be 0 or any occurrences of that node. + * Few example fusable pattern: + * - MatMul -> Reshape -> Transpose -> BatchNormalization ---> GEMM -> Reshape -> Transpose + * - MatMul -> Reshape -> BatchNormalization ---> GEMM -> Reshape + * - MatMul -> Transpose -> BatchNormalization ---> GEMM -> Transpose + * - MatMul -> Reshape -> Reshape -> BatchNormalization ---> GEMM -> Reshape -> Reshape + * - MatMul -> Reshape -> Transpose -> Reshape -> BatchNormalization ---> GEMM -> Reshape -> Transpose -> Reshape + * - MatMul -> BatchNormalization ---> GEMM + * Other Conditions: + * - B tensor of MatMul should be constant. + * - scale, B, mean, var tensors of BatchNormalization should be constant. + * - Every node in the path, except the BatchNormalization, should have only 1 output edge. + */ +bool MatmulBNFusion::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger&) const { + if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "MatMul", {1, 9, 13}) || + node.GetOutputEdgesCount() != 1) { + return false; + } + + if (graph.NodeProducesGraphOutput(node)) { + return false; + } + + // because is not producing graph output, it means it will have a child node + NodeIndex child_node_index = node.OutputNodesBegin()->Index(); + std::optional batch_norm_index = MatchPath(graph, node, child_node_index); + if (!batch_norm_index.has_value()) { + return false; + } + + const Node* batch_norm_node = graph.GetNode(*batch_norm_index); + + // Check that the appropriate inputs to the Matmul and BN nodes are constants. + if (!graph_utils::NodeArgIsConstant(graph, *node.InputDefs()[1]) || + !graph_utils::NodeArgIsConstant(graph, *batch_norm_node->InputDefs()[1]) || + !graph_utils::NodeArgIsConstant(graph, *batch_norm_node->InputDefs()[2]) || + !graph_utils::NodeArgIsConstant(graph, *batch_norm_node->InputDefs()[3]) || + !graph_utils::NodeArgIsConstant(graph, *batch_norm_node->InputDefs()[4])) { + return false; + } + + // First output from BN is required. Others are optional. If any optional outputs exist we can't fuse. + const auto& output_defs = batch_norm_node->OutputDefs(); + if (output_defs.size() > 1) { + for (size_t i = 1, end = output_defs.size(); i < end; ++i) { + if (output_defs[i] != nullptr && output_defs[i]->Exists()) { + return false; + } + } + } + + return true; +} + +/* + * BatchNormalization: [https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_batch_normalization_operator_desc] + * Scale * ((Input - Mean) / sqrt(Variance + Epsilon)) + Bias // ignore the FusedActivation in the above definition, that's very specific to DML + * Expanding out the terms: + * Output = (Scale / sqrt(Variance + Epsilon)) * Input + (Scale / sqrt(Variance + Epsilon)) * -Mean + Bias + * Here, + * [Scale/sqrt(Variance + Epsilon)] is constant, and let's call it `alpha` + * [(Scale / sqrt(Variance + Epsilon)) * -Mean + Bias] is also constant, and let's call it `beta` + * Output = alpha * Input + beta, Input = B tensor of MatMul. + * + */ +Status MatmulBNFusion::Apply(Graph& graph, Node& matmul_node, RewriteRuleEffect& rule_effect, const logging::Logger&) const { + NodeIndex child_node_index = matmul_node.OutputNodesBegin()->Index(); + NodeIndex batch_norm_node_index = MatchPath(graph, matmul_node, child_node_index).value(); + + Node& batch_norm_node = *graph.GetNode(batch_norm_node_index); // need mutable node, that's why extracting node from graph + + // only perform fusion if epsilon is present and is of float_32 type + auto epsilon_attribute = batch_norm_node.GetAttributes().find("epsilon"); + if (epsilon_attribute == batch_norm_node.GetAttributes().end() || + epsilon_attribute->second.type() != ONNX_NAMESPACE::AttributeProto_AttributeType_FLOAT) { + return Status::OK(); + } + const float epsilon = epsilon_attribute->second.f(); + + const onnx::TensorProto* scale_tensor = graph_utils::GetConstantInitializer(graph, batch_norm_node.InputDefs()[1]->Name()); + ORT_ENFORCE(scale_tensor); + const onnx::TensorProto* bias_tensor = graph_utils::GetConstantInitializer(graph, batch_norm_node.InputDefs()[2]->Name()); + ORT_ENFORCE(bias_tensor); + const onnx::TensorProto* mean_tensor = graph_utils::GetConstantInitializer(graph, batch_norm_node.InputDefs()[3]->Name()); + ORT_ENFORCE(mean_tensor); + const onnx::TensorProto* var_tensor = graph_utils::GetConstantInitializer(graph, batch_norm_node.InputDefs()[4]->Name()); + ORT_ENFORCE(var_tensor); + const onnx::TensorProto* matmul_b_tensor = graph_utils::GetConstantInitializer(graph, matmul_node.InputDefs()[1]->Name()); + ORT_ENFORCE(matmul_b_tensor); + + if (!optimizer_utils::IsFloatingPointDataType(*matmul_b_tensor) || + !optimizer_utils::IsFloatingPointDataType(*scale_tensor) || + !optimizer_utils::IsFloatingPointDataType(*bias_tensor) || + !optimizer_utils::IsFloatingPointDataType(*mean_tensor) || + !optimizer_utils::IsFloatingPointDataType(*var_tensor) || + scale_tensor->dims_size() != 1 || + bias_tensor->dims_size() != 1 || + mean_tensor->dims_size() != 1 || + var_tensor->dims_size() != 1 || + scale_tensor->dims(0) != matmul_b_tensor->dims(1) || + bias_tensor->dims(0) != matmul_b_tensor->dims(1) || + mean_tensor->dims(0) != matmul_b_tensor->dims(1) || + var_tensor->dims(0) != matmul_b_tensor->dims(1)) { + return Status::OK(); + } + + /* + * temp = scale / sqrt(var + epsilon) + * output = (temp * Input) - ((temp * mean) + bias) + */ + Initializer scale(*scale_tensor, graph.ModelPath()); + Initializer bias(*bias_tensor, graph.ModelPath()); + Initializer mean(*mean_tensor, graph.ModelPath()); + Initializer var(*var_tensor, graph.ModelPath()); + Initializer matmul_b(*matmul_b_tensor, graph.ModelPath()); + + var.add(epsilon); + var.sqrt(); + scale.div(var); // this is the temp + matmul_b.scale_by_axis(scale, 1, true); + + mean.mul(scale); + bias.sub(mean); + + // create B tensorProto for new Gemm node from initializer. + ONNX_NAMESPACE::TensorProto new_gemm_b_tensor(*matmul_b_tensor); + matmul_b.ToProto(new_gemm_b_tensor); + const std::string new_gemm_b_name = graph.GenerateNodeArgName("MatMulBnFusion_GemmB_" + matmul_b_tensor->name()); + new_gemm_b_tensor.set_name(new_gemm_b_name); + NodeArg& new_gemm_b_node_arg = graph_utils::AddInitializer(graph, new_gemm_b_tensor); + + // create bias tensorProto for new Gemm node from initializer. + ONNX_NAMESPACE::TensorProto new_gemm_bias_tensor(*bias_tensor); + bias.ToProto(new_gemm_bias_tensor); + const std::string new_gemm_bias_name = graph.GenerateNodeArgName("MatMulBnFusion_GemmBias"); + new_gemm_bias_tensor.set_name(new_gemm_bias_name); + NodeArg& new_gemm_bias_node_arg = graph_utils::AddInitializer(graph, new_gemm_bias_tensor); + + Node& gemm_node = graph.AddNode( + graph.GenerateNodeArgName("MatMulBnFusion_Gemm"), + "Gemm", + "Generated from Matmul BatchNormalization fusion", + {matmul_node.MutableInputDefs()[0], &new_gemm_b_node_arg, &new_gemm_bias_node_arg}, + matmul_node.MutableOutputDefs(), + nullptr, + kOnnxDomain); + + // Remove MatMul node. + Node* node = graph.GetNode(matmul_node.Index()); + graph_utils::RemoveNodeOutputEdges(graph, *node); + graph.RemoveNode(matmul_node.Index()); + + // Delete optional empty output defs. + // Delete BatchNormalization node and update the input of the child of BatchNormalization + batch_norm_node.MutableOutputDefs().resize(1); + NodeIndex batch_norm_parent_index = graph.GetNode(child_node_index)->OpType() == "BatchNormalization" ? gemm_node.Index() : batch_norm_node.InputNodesBegin()->Index(); + graph_utils::FinalizeNodeFusion(graph, *graph.GetNode(batch_norm_parent_index), batch_norm_node); + + rule_effect = RewriteRuleEffect::kRemovedCurrentNode; + return Status::OK(); +} +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/core/optimizer/matmul_bn_fusion.h b/onnxruntime/core/optimizer/matmul_bn_fusion.h new file mode 100644 index 0000000000000..7a43483cf37d4 --- /dev/null +++ b/onnxruntime/core/optimizer/matmul_bn_fusion.h @@ -0,0 +1,27 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/optimizer/rewrite_rule.h" + +namespace onnxruntime { +/* + * This fusion submerges a BatchNormalization operator to it's super + * precedding MatMul operator, if and only if MatmulBNFusion::SatisfyCondition() + * is true. + */ +class MatmulBNFusion : public RewriteRule { + public: + MatmulBNFusion() : RewriteRule("MatMul_BatchNormalization_Fusion") {} + + std::vector TargetOpTypes() const noexcept override { + return {"MatMul"}; + } + + private: + bool SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const override; + + Status Apply(Graph& graph, Node& matmul_node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const override; +}; +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index 6acf631d53cd9..46b95a127b75c 100755 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -31,6 +31,7 @@ #include "core/optimizer/conv_add_act_fusion.h" #include "core/optimizer/conv_add_fusion.h" #include "core/optimizer/conv_bn_fusion.h" +#include "core/optimizer/matmul_bn_fusion.h" #include "core/optimizer/conv_mul_fusion.h" #include "core/optimizer/div_mul_fusion.h" #include "core/optimizer/dropout_elimination.h" @@ -1079,6 +1080,268 @@ TEST_F(GraphTransformationTests, FuseConvBNNoBias) { } } +TEST_F(GraphTransformationTests, FuseMatmulBNWithInBetweenNodes) { + constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-matmul-bn-with-reshape.onnx"; + + std::shared_ptr p_model; + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); + Graph& graph = p_model->MainGraph(); + + std::string expected_output_name; + GraphViewer graphViewer(graph); + for (auto& node_index : graphViewer.GetNodesInTopologicalOrder()) { + auto& node = *graph.GetNode(node_index); + if (node.OpType() == "MatMul") { + expected_output_name = node.OutputDefs()[0]->Name(); + } + } + + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + auto rule_transformer_L1 = std::make_unique("RuleTransformerL1"); + ASSERT_STATUS_OK(rule_transformer_L1->Register(std::make_unique())); + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1)); + + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); + + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_EQ(op_to_count["BatchNormalization"], 0); + ASSERT_EQ(op_to_count["MatMul"], 0); + ASSERT_EQ(op_to_count["Gemm"], 1); + + for (auto& node : graph.Nodes()) { + if (node.OpType() == "Gemm") { + ASSERT_EQ(node.OutputDefs()[0]->Name(), expected_output_name) + << "fusion should produce the same output name as the MatMul node"; + } + } +} + +TEST_F(GraphTransformationTests, FuseMatmulBNWithEmptyOptionalOutputWithInBetweenNodes) { + constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-matmul-bn-with-reshape.onnx"; + + std::shared_ptr p_model; + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); + Graph& graph = p_model->MainGraph(); + + std::string expected_output_name; + GraphViewer graphViewer(graph); + for (auto& node_index : graphViewer.GetNodesInTopologicalOrder()) { + auto& node = *graph.GetNode(node_index); + if (node.OpType() == "MatMul") { + expected_output_name = node.OutputDefs()[0]->Name(); + } else if (node.OpType() == "BatchNormalization") { + node.MutableOutputDefs().push_back(&graph.GetOrCreateNodeArg("", nullptr)); + } + } + + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + auto rule_transformer_L1 = std::make_unique("RuleTransformerL1"); + ASSERT_STATUS_OK(rule_transformer_L1->Register(std::make_unique())); + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1)); + + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); + + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_EQ(op_to_count["BatchNormalization"], 0); + ASSERT_EQ(op_to_count["MatMul"], 0); + ASSERT_EQ(op_to_count["Gemm"], 1); + + for (auto& node : graph.Nodes()) { + if (node.OpType() == "Gemm") { + ASSERT_EQ(node.OutputDefs()[0]->Name(), expected_output_name) + << "fusion should produce the same output name as the MatMul node"; + } + } +} + +// should not fuse +TEST_F(GraphTransformationTests, FuseMatmulBNWithOptionalOutputWithInBetweenNodes) { + constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-matmul-bn-with-reshape.onnx"; + + std::shared_ptr p_model; + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); + Graph& graph = p_model->MainGraph(); + + GraphViewer graphViewer(graph); + for (auto& node_index : graphViewer.GetNodesInTopologicalOrder()) { + auto& node = *graph.GetNode(node_index); + if (node.OpType() == "BatchNormalization") { + // additional non-empty output to batchNormalization + ONNX_NAMESPACE::TypeProto optional_output_tensor_type; + optional_output_tensor_type.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TypeProto::kTensorType); + auto& arg = graph.GetOrCreateNodeArg("bn_optional_output", &optional_output_tensor_type); + node.MutableOutputDefs().push_back(&arg); + } + } + + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + auto rule_transformer_L1 = std::make_unique("RuleTransformerL1"); + ASSERT_STATUS_OK(rule_transformer_L1->Register(std::make_unique())); + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1)); + + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); + + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_EQ(op_to_count["BatchNormalization"], 1); + ASSERT_EQ(op_to_count["MatMul"], 1); + ASSERT_EQ(op_to_count["Gemm"], 0); +} + +TEST_F(GraphTransformationTests, FuseMatmulBNDirectly) { + constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-matmul-bn-directly.onnx"; + + std::shared_ptr p_model; + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); + Graph& graph = p_model->MainGraph(); + + std::string expected_output_name; + GraphViewer graphViewer(graph); + for (auto& node_index : graphViewer.GetNodesInTopologicalOrder()) { + auto& node = *graph.GetNode(node_index); + if (node.OpType() == "BatchNormalization") { + expected_output_name = node.OutputDefs()[0]->Name(); + } + } + + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + auto rule_transformer_L1 = std::make_unique("RuleTransformerL1"); + ASSERT_STATUS_OK(rule_transformer_L1->Register(std::make_unique())); + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1)); + + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); + + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_EQ(op_to_count["BatchNormalization"], 0); + ASSERT_EQ(op_to_count["MatMul"], 0); + ASSERT_EQ(op_to_count["Gemm"], 1); + + for (auto& node : graph.Nodes()) { + if (node.OpType() == "Gemm") { + ASSERT_EQ(node.OutputDefs()[0]->Name(), expected_output_name) + << "fusion should produce the same output name as the last node"; + } + } +} + +TEST_F(GraphTransformationTests, FuseMatmulBNWithOnlyReshape) { + constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-matmul-bn-only-reshape.onnx"; + + std::shared_ptr p_model; + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); + Graph& graph = p_model->MainGraph(); + + std::string expected_output_name; + GraphViewer graphViewer(graph); + for (auto& node_index : graphViewer.GetNodesInTopologicalOrder()) { + auto& node = *graph.GetNode(node_index); + if (node.OpType() == "MatMul") { + expected_output_name = node.OutputDefs()[0]->Name(); + } + } + + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + auto rule_transformer_L1 = std::make_unique("RuleTransformerL1"); + ASSERT_STATUS_OK(rule_transformer_L1->Register(std::make_unique())); + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1)); + + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); + + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_EQ(op_to_count["BatchNormalization"], 0); + ASSERT_EQ(op_to_count["MatMul"], 0); + ASSERT_EQ(op_to_count["Gemm"], 1); + + for (auto& node : graph.Nodes()) { + if (node.OpType() == "Gemm") { + ASSERT_EQ(node.OutputDefs()[0]->Name(), expected_output_name) + << "fusion should produce the same output name as the MatMul node"; + } + } +} + +TEST_F(GraphTransformationTests, FuseMatmulBNWithOnlyTranspose) { + constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-matmul-bn-only-transpose.onnx"; + + std::shared_ptr p_model; + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); + Graph& graph = p_model->MainGraph(); + + std::string expected_output_name; + GraphViewer graphViewer(graph); + for (auto& node_index : graphViewer.GetNodesInTopologicalOrder()) { + auto& node = *graph.GetNode(node_index); + if (node.OpType() == "MatMul") { + expected_output_name = node.OutputDefs()[0]->Name(); + } + } + + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + auto rule_transformer_L1 = std::make_unique("RuleTransformerL1"); + ASSERT_STATUS_OK(rule_transformer_L1->Register(std::make_unique())); + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1)); + + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); + + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_EQ(op_to_count["BatchNormalization"], 0); + ASSERT_EQ(op_to_count["MatMul"], 0); + ASSERT_EQ(op_to_count["Gemm"], 1); + + for (auto& node : graph.Nodes()) { + if (node.OpType() == "Gemm") { + ASSERT_EQ(node.OutputDefs()[0]->Name(), expected_output_name) + << "fusion should produce the same output name as the MatMul node"; + } + } +} + +TEST_F(GraphTransformationTests, FuseMatmulBNWithoutBatchNormalization) { + constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-matmul-bn-only-transpose.onnx"; + + std::shared_ptr p_model; + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); + Graph& graph = p_model->MainGraph(); + + GraphViewer graphViewer(graph); + for (auto& node_index : graphViewer.GetNodesInTopologicalOrder()) { + auto& node = *graph.GetNode(node_index); + if (node.OpType() == "BatchNormalization") { + graph_utils::RemoveNode(graph, node); + } + } + + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + auto rule_transformer_L1 = std::make_unique("RuleTransformerL1"); + ASSERT_STATUS_OK(rule_transformer_L1->Register(std::make_unique())); + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1)); + + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); + + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_EQ(op_to_count["MatMul"], 1); +} + +// should not fuse +TEST_F(GraphTransformationTests, FuseMatmulBNWithNonIgnorableNode) { + constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-matmul-bn-non-ignorable-node.onnx"; + + std::shared_ptr p_model; + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); + Graph& graph = p_model->MainGraph(); + + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + auto rule_transformer_L1 = std::make_unique("RuleTransformerL1"); + ASSERT_STATUS_OK(rule_transformer_L1->Register(std::make_unique())); + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1)); + + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); + + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_EQ(op_to_count["BatchNormalization"], 1); + ASSERT_EQ(op_to_count["MatMul"], 1); + ASSERT_EQ(op_to_count["Gemm"], 0); +} + TEST_F(GraphTransformationTests, DontFuseConvWithBNWithOptionalOutputs) { constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-conv-bn-no-bias.onnx"; diff --git a/onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-directly.onnx b/onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-directly.onnx new file mode 100644 index 0000000000000000000000000000000000000000..fa11adaac8d95db4772990bac6f5b0b072f4d5c1 GIT binary patch literal 513 zcmd zSDarY#0wS3FD(JeE3x?|miU(DaQSngN^o%`<;52#C+4Jbu>)C2nTf? zq%5&Whz)9pkW*qwa)w`iQEp;RW>sPd&@n<{FpKkG&I7wutCoY6gH?c0DW&9@y}fs& zzrATnsojapNc+btLhLz8((JXXI_#F_bK0x1?6p(Ye{Q$n%?jIp7hc)TKdWx{F3r&H z+?8o_>{ zSDarY#0wS3FD(JeE3x?|miU(DaA|R&N(k|1rljVTWR_IMLsfEkLIt=&xX>lJIFj<> zi<1*`Qn}cHtfb7uVlX2&H8GEi4JcBUSR}*=q@iXBIVF}PXZYn8u2svy%E2nYsFc!j%G0hwo57}ap`u+&w6Oi|3GViX_qf@! zybQG0I2>y4k;iL)_kNE3_wx$&tYe@Bh!we$Lz!TZ@_d?5;5D+V6{= zV3%pwVCNR)V}GFJiJcSoRJ-K@CU%z=KCp8_4^Ax=u;n5Q3=Q_^*a;peTFMYrb=SLX z+22jE+ccNa&c@Z%zU4-=y%Q|JV2aw9{@H~1McZ4hIcs}X#K69_o&}-^6qs5{5R;0# z3hgU&)9ra*iP{&LMB2BRci1^W13F3^8u>zeTs$0%LL6L79E?EBnk2ym4Oes-Cl)RS G0S*9lCcbw7 literal 0 HcmV?d00001 diff --git a/onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-only-reshape.onnx b/onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-only-reshape.onnx new file mode 100644 index 0000000000000000000000000000000000000000..c361a42700a30b12eeeacefe8da3825a1e9b3ab7 GIT binary patch literal 639 zcmd zSDarY#0wS3FD(JeE3x?|miU(Da2au-N^r3kXCxM+#v2In7o|d&5FG|e>_HF#t`IJC zIWCT*y!hhe#GF(vb|5P$GqD)V$W2Ym<6;AflqD7kaRO%ZT9OGc24NQswKn0BEW9N z`yU1v85vofU}31G0(P1R14Dy7I(C8wnU*p{)fxLfyYog2_W4&F?IlH6?J8~<+B?C5 z6Q;!T*3o{AR=O=~<8iytTpoxbP*7_rK}?#|t7^aKc8I;q6Gyw$c{gn>(pBu8 ypwSQ|4oxCLd|W&nj6xh-OdO0r%$lUY1r2o}F)k6Hi~v*yBnwvI#KOfOzySb5e#_?o literal 0 HcmV?d00001 diff --git a/onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-only-transpose.onnx b/onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-only-transpose.onnx new file mode 100644 index 0000000000000000000000000000000000000000..f70ae2e6229e7a89eb90cb8d68360e35e4c32291 GIT binary patch literal 579 zcmd zSDarY#0wS3FD(JeE3x?|miU(DaM^I7N(c#-6eZ>r7vvYG#zT}EC~=0sgtU0MSPD{$ zavc~Q7#1+Ha|Lr@D(B)z%8M^fPRvQ=Vh6I4G82ozjNH`3JT5k%NLgZ$5F6APLQaV# z$r*n6MY)MNnN^7;K<@~N!JM84^B~x_TD2Uk9IOJ2N-3i2U)s&qRAr3Ak4n`nmO_Jb( RhA@&0D@fLfg^NLe0|4BqxTOF9 literal 0 HcmV?d00001 diff --git a/onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-with-reshape.onnx b/onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-with-reshape.onnx new file mode 100644 index 0000000000000000000000000000000000000000..8e4bc49514548c604b36029f780f7ff1a56db148 GIT binary patch literal 709 zcmd zSDarY#0wS3FD(JeE3x?|miU(Da2au-N^r3kXCxM+#v2In7o|d&5FG|e>_HF#E+;N@ zIU&K4qQt!7g8bstc$jj|5SWmbAQwwPYEiBOg9EbzqXW|dMs}_+E=(<497%cc#mR{| zsa)(pR#IkSF_@8?nwZDM1{5hvEE3`b(ojDLIVF}PXZYn8{d>;u)k9_-R68Ut6h7pqW!TWKOYnLXR9R84eZ!b}QchFu=&j$m#@( z2`v?{(?l2;8tl=r6FkVYlp(5YTTJaU71!F`j(KX=uE%OG<|u9N1Pe}>qNNu#?VDGY z+kf6~Y%f=BV{d;x(hjT$6x3Qu5R=rr>+N}uHrviU9%8RkaM$M1!XSGmXf#BLLlc@1 o9~Tb?qYwud69*#@vnDBUK|@_gj7tP4BLI_u(u__lTnqvn0In0|n*aa+ literal 0 HcmV?d00001