From 8a5f299a130c0b2b82fbe61f215e23daace68b15 Mon Sep 17 00:00:00 2001
From: Sumit Agarwal <sumitagarwal330@gmail.com>
Date: Tue, 24 Oct 2023 19:41:10 -0700
Subject: [PATCH] Introduce new optimizer MatMul + BatchNormalization (#17915)

### Description
Introduce a new ORT L1 optimizer under the RewriteRule category to fuse
a MatMul + BatchNormalization pair of nodes. The optimizer looks for a
specific pattern observed in one of the high-impact customer models and
fuses the MatMul and BatchNormalization nodes into a Gemm node. For
details on the pattern matching and fusion, please refer to the comments
in `matmul_bn_fusion.cc`.

To visualize, this optimizer replaces the following subgraph with a Gemm
node.
<pre>
               MatMul                  GEMM
                 |                       |
              Reshape ^     --->      Reshape ^
                 |                       |
            Transpose ^             Transpose ^
                 |
       BatchNormalization
Note: ^ means there can be 0 or more occurrences of that node.
A few example fusable patterns:
 - MatMul -> Reshape -> Transpose -> BatchNormalization            ---> GEMM -> Reshape -> Transpose
 - MatMul -> Reshape -> BatchNormalization                         ---> GEMM -> Reshape
 - MatMul -> Transpose -> BatchNormalization                       ---> GEMM -> Transpose
 - MatMul -> Reshape -> Reshape -> BatchNormalization              ---> GEMM -> Reshape -> Reshape
 - MatMul -> Reshape -> Transpose -> Reshape -> BatchNormalization ---> GEMM -> Reshape -> Transpose -> Reshape
 - MatMul -> BatchNormalization                                    ---> GEMM
</pre>

Note: This optimizer may evolve in the future to support more generic
pattern matching.
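
For reference, here is the algebra behind the fusion (derived in full in
the comments of `matmul_bn_fusion.cc`): the BatchNormalization folds
entirely into the constant B tensor of the MatMul plus a new bias, which
is exactly a Gemm.
<pre>
alpha = Scale / sqrt(Variance + Epsilon)      // constant
beta  = Bias - alpha * Mean                   // constant
BN(X @ B) = alpha * (X @ B) + beta
          = X @ (B * alpha) + beta            // scale each column of B by alpha
          = Gemm(X, B * alpha, beta)
</pre>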

### Motivation and Context
One of the users of the ORT + DML EP needs this fusion to better target
their model to DML. But the transformation applies more broadly, so it
is added as an L1 optimizer.
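
As a usage note, the rule is registered in the Level 1 RewriteRule set,
so it runs whenever basic graph optimizations are enabled. A minimal
sketch with the C++ API ("model.onnx" is a placeholder path):
<pre>
#include "onnxruntime_cxx_api.h"

int main() {
  Ort::Env env;
  Ort::SessionOptions session_options;
  // ORT_ENABLE_BASIC enables the Level 1 (basic) optimizers, which include this fusion.
  session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_BASIC);
  Ort::Session session(env, ORT_TSTR("model.onnx"), session_options);
  return 0;
}
</pre>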
---
 .../core/optimizer/graph_transformer_utils.cc |   2 +
 onnxruntime/core/optimizer/initializer.cc     |  28 +-
 onnxruntime/core/optimizer/initializer.h      |   2 +-
 .../core/optimizer/matmul_bn_fusion.cc        | 230 +++++++++++++++
 onnxruntime/core/optimizer/matmul_bn_fusion.h |  27 ++
 .../test/optimizer/graph_transform_test.cc    | 263 ++++++++++++++++++
 .../fusion/fuse-matmul-bn-directly.onnx       | Bin 0 -> 513 bytes
 .../fuse-matmul-bn-non-ignorable-node.onnx    | Bin 0 -> 593 bytes
 .../fusion/fuse-matmul-bn-only-reshape.onnx   | Bin 0 -> 639 bytes
 .../fusion/fuse-matmul-bn-only-transpose.onnx | Bin 0 -> 579 bytes
 .../fusion/fuse-matmul-bn-with-reshape.onnx   | Bin 0 -> 709 bytes
 11 files changed, 543 insertions(+), 9 deletions(-)
 create mode 100644 onnxruntime/core/optimizer/matmul_bn_fusion.cc
 create mode 100644 onnxruntime/core/optimizer/matmul_bn_fusion.h
 create mode 100644 onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-directly.onnx
 create mode 100644 onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-non-ignorable-node.onnx
 create mode 100644 onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-only-reshape.onnx
 create mode 100644 onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-only-transpose.onnx
 create mode 100644 onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-with-reshape.onnx

diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc
index 54511aa02a57c..13ffca70bb214 100644
--- a/onnxruntime/core/optimizer/graph_transformer_utils.cc
+++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc
@@ -50,6 +50,7 @@
 #include "core/optimizer/matmul_integer_to_float.h"
 #include "core/optimizer/matmul_scale_fusion.h"
 #include "core/optimizer/matmul_transpose_fusion.h"
+#include "core/optimizer/matmul_bn_fusion.h"
 #include "core/optimizer/nchwc_transformer.h"
 #include "core/optimizer/noop_elimination.h"
 #include "core/optimizer/not_where_fusion.h"
@@ -127,6 +128,7 @@ InlinedVector<std::unique_ptr<RewriteRule>> GenerateRewriteRules(
       rules.push_back(std::make_unique<ConvAddFusion>());
       rules.push_back(std::make_unique<ConvMulFusion>());
       rules.push_back(std::make_unique<ConvBNFusion>());
+      rules.push_back(std::make_unique<MatmulBNFusion>());
       rules.push_back(std::make_unique<ClipQuantFusion>());
       rules.push_back(std::make_unique<ReluQuantFusion>());
       break;
diff --git a/onnxruntime/core/optimizer/initializer.cc b/onnxruntime/core/optimizer/initializer.cc
index 9cdc0d9ef0473..5d09971ad8100 100644
--- a/onnxruntime/core/optimizer/initializer.cc
+++ b/onnxruntime/core/optimizer/initializer.cc
@@ -289,7 +289,11 @@ Initializer& Initializer::sqrt() {
 namespace {
 template <typename T>
 struct ScaleByAxis {
-  void operator()(Tensor& data, const Tensor& scalers, const size_t block_size, const size_t num_blocks) const {
+  void operator()(Tensor& data,
+                  const Tensor& scalers,
+                  const size_t block_size,
+                  const size_t num_blocks,
+                  const bool column_major) const {
     ToNumeric<T> to_numeric;
     const auto scaler_size = scalers.Shape().Size();
     T* dst = data.MutableData<T>();
@@ -301,24 +305,32 @@ struct ScaleByAxis {
       }
     } else {
       for (size_t block_offset = 0, i = 0; i < num_blocks; i++) {
-        const auto numeric_scaler = to_numeric(scalers_data[i]);
-        for (size_t j = 0; j < block_size; ++j, ++block_offset) {
-          dst[block_offset] = T(to_numeric(dst[block_offset]) * numeric_scaler);
+        if (column_major) {
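+          // column_major: scalers has block_size entries, one per position within a block,
+          // so element j of every block is scaled by scalers_data[j] (per-column scaling).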
+          for (size_t j = 0; j < block_size; ++j, ++block_offset) {
+            const auto numeric_scaler = to_numeric(scalers_data[j]);
+            dst[block_offset] = T(to_numeric(dst[block_offset]) * numeric_scaler);
+          }
+        } else {
+          const auto numeric_scaler = to_numeric(scalers_data[i]);
+          for (size_t j = 0; j < block_size; ++j, ++block_offset) {
+            dst[block_offset] = T(to_numeric(dst[block_offset]) * numeric_scaler);
+          }
         }
       }
     }
   }
 };
-
 }  // namespace
 
-void Initializer::scale_by_axis(const Initializer& scalers, int axis) {
+void Initializer::scale_by_axis(const Initializer& scalers, int axis, bool column_major) {
   ORT_ENFORCE(axis >= 0, "Axis must be non-negative");
   const size_t block_size = narrow<size_t>(data_.Shape().SizeFromDimension(gsl::narrow_cast<size_t>(axis)));
   const size_t num_blocks = size() / block_size;
-  ORT_ENFORCE(scalers.size() == 1 || scalers.size() == num_blocks, "Invalid other(scalers) size");
+  ORT_ENFORCE(scalers.size() == 1 ||
+                  (column_major ? scalers.size() == block_size : scalers.size() == num_blocks),
+              "Invalid other(scalers) size");
   utils::MLTypeCallDispatcher<MLFloat16, BFloat16, float, double, int32_t, int64_t> t_disp(data_.GetElementType());
-  t_disp.Invoke<ScaleByAxis>(data_, scalers.data_, block_size, num_blocks);
+  t_disp.Invoke<ScaleByAxis>(data_, scalers.data_, block_size, num_blocks, column_major);
 }
 #endif  // ORT_EXTENDED_MINIMAL_BUILD
 }  // namespace onnxruntime
diff --git a/onnxruntime/core/optimizer/initializer.h b/onnxruntime/core/optimizer/initializer.h
index dfe054ba1aced..78e3fd6a3d24e 100644
--- a/onnxruntime/core/optimizer/initializer.h
+++ b/onnxruntime/core/optimizer/initializer.h
@@ -86,7 +86,7 @@ class Initializer final {
 
   Initializer& sqrt();
 
-  void scale_by_axis(const Initializer& other, int axis);
+  void scale_by_axis(const Initializer& other, int axis, bool column_major = false);
 #endif  // ORT_EXTENDED_MINIMAL_BUILD
  private:
   std::string name_;
diff --git a/onnxruntime/core/optimizer/matmul_bn_fusion.cc b/onnxruntime/core/optimizer/matmul_bn_fusion.cc
new file mode 100644
index 0000000000000..e944522c9c338
--- /dev/null
+++ b/onnxruntime/core/optimizer/matmul_bn_fusion.cc
@@ -0,0 +1,230 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "core/optimizer/matmul_bn_fusion.h"
+#include "core/graph/graph_utils.h"
+#include "core/optimizer/initializer.h"
+#include "core/optimizer/utils.h"
+
+namespace onnxruntime {
+
+namespace {
+const std::vector<std::pair<std::string, InlinedVector<ONNX_NAMESPACE::OperatorSetVersion>>> ignorable_nodes{
+    {"Reshape", {1, 5, 13, 14, 19}},
+    {"Transpose", {1, 13}}};
+const std::pair<std::string, InlinedVector<ONNX_NAMESPACE::OperatorSetVersion>> dest = {"BatchNormalization", {1, 6, 7, 9, 14, 15}};
+}  // namespace
+
+bool NodeIsIgnorable(const Graph& graph, const Node& root_node, NodeIndex curr_node_index) {
+  const Node* curr_node = graph.GetNode(curr_node_index);
+
+  // Reject curr_node if it has a different execution provider than its parent or
+  // if its output edge count != 1 (the latter also handles the case where an
+  // ignorable node is a graph output, i.e. a graph like "MatMul->Transpose").
+  if (curr_node->GetExecutionProviderType() != root_node.GetExecutionProviderType() ||
+      curr_node->GetOutputEdgesCount() != 1) {
+    return false;
+  }
+
+  // curr_node can be any of the ignorable_nodes.
+  for (size_t index = 0; index < ignorable_nodes.size(); index++) {
+    if (graph_utils::IsSupportedOptypeVersionAndDomain(*curr_node, ignorable_nodes[index].first, ignorable_nodes[index].second)) {
+      return true;
+    }
+  }
+
+  return false;
+}
+
+std::optional<NodeIndex> MatchPath(const Graph& graph, const Node& root_node, NodeIndex curr_node_index) {
+  while (NodeIsIgnorable(graph, root_node, curr_node_index)) {
+    curr_node_index = graph.GetNode(curr_node_index)->OutputNodesBegin()->Index();
+  }
+
+  // curr_node is not ignorable; the path matches only if it is the destination node
+  const Node* curr_node = graph.GetNode(curr_node_index);
+  if (curr_node->OpType() != dest.first) {
+    return std::nullopt;
+  }
+
+  if (curr_node->GetExecutionProviderType() == root_node.GetExecutionProviderType() &&
+      graph_utils::IsSupportedOptypeVersionAndDomain(*curr_node, dest.first, dest.second)) {
+    return curr_node_index;
+  }
+
+  // Either curr_node has a different execution provider or an unsupported opset version.
+  return std::nullopt;
+}
+
+/*
+ *   Given a MatMul node, it will verify the following pattern.
+ *                MatMul                  GEMM
+ *                  |                       |
+ *               Reshape ^     --->      Reshape ^
+ *                  |                       |
+ *             Transpose ^             Transpose ^
+ *                  |
+ *        BatchNormalization
+ * Note: ^ means there can be 0 or more occurrences of that node.
+ * A few example fusable patterns:
+ *  - MatMul -> Reshape -> Transpose -> BatchNormalization              ---> GEMM -> Reshape -> Transpose
+ *  - MatMul -> Reshape -> BatchNormalization                           ---> GEMM -> Reshape
+ *  - MatMul -> Transpose -> BatchNormalization                         ---> GEMM -> Transpose
+ *  - MatMul -> Reshape -> Reshape -> BatchNormalization                ---> GEMM -> Reshape -> Reshape
+ *  - MatMul -> Reshape -> Transpose -> Reshape -> BatchNormalization   ---> GEMM -> Reshape -> Transpose -> Reshape
+ *  - MatMul -> BatchNormalization                                      ---> GEMM
+ * Other Conditions:
+ *   - B tensor of MatMul should be constant.
+ *   - scale, B, mean, var tensors of BatchNormalization should be constant.
+ *   - Every node in the path, except the BatchNormalization, should have only 1 output edge.
+ */
+bool MatmulBNFusion::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger&) const {
+  if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "MatMul", {1, 9, 13}) ||
+      node.GetOutputEdgesCount() != 1) {
+    return false;
+  }
+
+  if (graph.NodeProducesGraphOutput(node)) {
+    return false;
+  }
+
+  // Because <node> does not produce a graph output, it must have a child node.
+  NodeIndex child_node_index = node.OutputNodesBegin()->Index();
+  std::optional<NodeIndex> batch_norm_index = MatchPath(graph, node, child_node_index);
+  if (!batch_norm_index.has_value()) {
+    return false;
+  }
+
+  const Node* batch_norm_node = graph.GetNode(*batch_norm_index);
+
+  // Check that the appropriate inputs to the Matmul and BN nodes are constants.
+  if (!graph_utils::NodeArgIsConstant(graph, *node.InputDefs()[1]) ||
+      !graph_utils::NodeArgIsConstant(graph, *batch_norm_node->InputDefs()[1]) ||
+      !graph_utils::NodeArgIsConstant(graph, *batch_norm_node->InputDefs()[2]) ||
+      !graph_utils::NodeArgIsConstant(graph, *batch_norm_node->InputDefs()[3]) ||
+      !graph_utils::NodeArgIsConstant(graph, *batch_norm_node->InputDefs()[4])) {
+    return false;
+  }
+
+  // First output from BN is required. Others are optional. If any optional outputs exist we can't fuse.
+  const auto& output_defs = batch_norm_node->OutputDefs();
+  if (output_defs.size() > 1) {
+    for (size_t i = 1, end = output_defs.size(); i < end; ++i) {
+      if (output_defs[i] != nullptr && output_defs[i]->Exists()) {
+        return false;
+      }
+    }
+  }
+
+  return true;
+}
+
+/*
+ * BatchNormalization: [https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_batch_normalization_operator_desc]
+ *   Scale * ((Input - Mean) / sqrt(Variance + Epsilon)) + Bias // ignore the FusedActivation in the above definition, that's very specific to DML
+ * Expanding out the terms:
+ *   Output = (Scale / sqrt(Variance + Epsilon)) * Input + (Scale / sqrt(Variance + Epsilon)) * -Mean + Bias
+ * Here,
+ *   [Scale/sqrt(Variance + Epsilon)] is constant, and let's call it `alpha`
+ *   [(Scale / sqrt(Variance + Epsilon)) * -Mean + Bias] is also constant, and let's call it `beta`
+ * Output = alpha * Input + beta, where Input is the B tensor of the MatMul.
+ *
+ */
+Status MatmulBNFusion::Apply(Graph& graph, Node& matmul_node, RewriteRuleEffect& rule_effect, const logging::Logger&) const {
+  NodeIndex child_node_index = matmul_node.OutputNodesBegin()->Index();
+  NodeIndex batch_norm_node_index = MatchPath(graph, matmul_node, child_node_index).value();
+
+  Node& batch_norm_node = *graph.GetNode(batch_norm_node_index);  // fetch the node from the graph to get a mutable reference
+
+  // only perform the fusion if epsilon is present and is of float32 type
+  auto epsilon_attribute = batch_norm_node.GetAttributes().find("epsilon");
+  if (epsilon_attribute == batch_norm_node.GetAttributes().end() ||
+      epsilon_attribute->second.type() != ONNX_NAMESPACE::AttributeProto_AttributeType_FLOAT) {
+    return Status::OK();
+  }
+  const float epsilon = epsilon_attribute->second.f();
+
+  const onnx::TensorProto* scale_tensor = graph_utils::GetConstantInitializer(graph, batch_norm_node.InputDefs()[1]->Name());
+  ORT_ENFORCE(scale_tensor);
+  const onnx::TensorProto* bias_tensor = graph_utils::GetConstantInitializer(graph, batch_norm_node.InputDefs()[2]->Name());
+  ORT_ENFORCE(bias_tensor);
+  const onnx::TensorProto* mean_tensor = graph_utils::GetConstantInitializer(graph, batch_norm_node.InputDefs()[3]->Name());
+  ORT_ENFORCE(mean_tensor);
+  const onnx::TensorProto* var_tensor = graph_utils::GetConstantInitializer(graph, batch_norm_node.InputDefs()[4]->Name());
+  ORT_ENFORCE(var_tensor);
+  const onnx::TensorProto* matmul_b_tensor = graph_utils::GetConstantInitializer(graph, matmul_node.InputDefs()[1]->Name());
+  ORT_ENFORCE(matmul_b_tensor);
+
+  if (!optimizer_utils::IsFloatingPointDataType(*matmul_b_tensor) ||
+      !optimizer_utils::IsFloatingPointDataType(*scale_tensor) ||
+      !optimizer_utils::IsFloatingPointDataType(*bias_tensor) ||
+      !optimizer_utils::IsFloatingPointDataType(*mean_tensor) ||
+      !optimizer_utils::IsFloatingPointDataType(*var_tensor) ||
+      scale_tensor->dims_size() != 1 ||
+      bias_tensor->dims_size() != 1 ||
+      mean_tensor->dims_size() != 1 ||
+      var_tensor->dims_size() != 1 ||
+      scale_tensor->dims(0) != matmul_b_tensor->dims(1) ||
+      bias_tensor->dims(0) != matmul_b_tensor->dims(1) ||
+      mean_tensor->dims(0) != matmul_b_tensor->dims(1) ||
+      var_tensor->dims(0) != matmul_b_tensor->dims(1)) {
+    return Status::OK();
+  }
+
+  /*
+   * temp = scale / sqrt(var + epsilon)
+   * output = (temp * Input) + (bias - (temp * mean))
+   */
+  Initializer scale(*scale_tensor, graph.ModelPath());
+  Initializer bias(*bias_tensor, graph.ModelPath());
+  Initializer mean(*mean_tensor, graph.ModelPath());
+  Initializer var(*var_tensor, graph.ModelPath());
+  Initializer matmul_b(*matmul_b_tensor, graph.ModelPath());
+
+  var.add(epsilon);
+  var.sqrt();
+  scale.div(var);  // this is the temp
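+  // matmul_b is [K, N] and scale has N entries, so column-major scaling multiplies
+  // each column j of B by scale[j], folding alpha into the weights.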
+  matmul_b.scale_by_axis(scale, 1, true);
+
+  mean.mul(scale);
+  bias.sub(mean);
+
+  // create the B TensorProto for the new Gemm node from the <matmul_b> initializer.
+  ONNX_NAMESPACE::TensorProto new_gemm_b_tensor(*matmul_b_tensor);
+  matmul_b.ToProto(new_gemm_b_tensor);
+  const std::string new_gemm_b_name = graph.GenerateNodeArgName("MatMulBnFusion_GemmB_" + matmul_b_tensor->name());
+  new_gemm_b_tensor.set_name(new_gemm_b_name);
+  NodeArg& new_gemm_b_node_arg = graph_utils::AddInitializer(graph, new_gemm_b_tensor);
+
+  // create the bias TensorProto for the new Gemm node from the <bias> initializer.
+  ONNX_NAMESPACE::TensorProto new_gemm_bias_tensor(*bias_tensor);
+  bias.ToProto(new_gemm_bias_tensor);
+  const std::string new_gemm_bias_name = graph.GenerateNodeArgName("MatMulBnFusion_GemmBias");
+  new_gemm_bias_tensor.set_name(new_gemm_bias_name);
+  NodeArg& new_gemm_bias_node_arg = graph_utils::AddInitializer(graph, new_gemm_bias_tensor);
+
+  Node& gemm_node = graph.AddNode(
+      graph.GenerateNodeArgName("MatMulBnFusion_Gemm"),
+      "Gemm",
+      "Generated from Matmul BatchNormalization fusion",
+      {matmul_node.MutableInputDefs()[0], &new_gemm_b_node_arg, &new_gemm_bias_node_arg},
+      matmul_node.MutableOutputDefs(),
+      nullptr,
+      kOnnxDomain);
+
+  // Remove MatMul node.
+  Node* node = graph.GetNode(matmul_node.Index());
+  graph_utils::RemoveNodeOutputEdges(graph, *node);
+  graph.RemoveNode(matmul_node.Index());
+
+  // Drop the empty optional output defs, then delete the BatchNormalization node,
+  // moving its output edges to its parent in the fused path.
+  batch_norm_node.MutableOutputDefs().resize(1);
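+  // If BatchNormalization directly follows the MatMul, its parent in the fused graph is the
+  // new Gemm node; otherwise it is the last ignorable (Reshape/Transpose) node on the path.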
+  NodeIndex batch_norm_parent_index = graph.GetNode(child_node_index)->OpType() == "BatchNormalization" ? gemm_node.Index() : batch_norm_node.InputNodesBegin()->Index();
+  graph_utils::FinalizeNodeFusion(graph, *graph.GetNode(batch_norm_parent_index), batch_norm_node);
+
+  rule_effect = RewriteRuleEffect::kRemovedCurrentNode;
+  return Status::OK();
+}
+}  // namespace onnxruntime
\ No newline at end of file
diff --git a/onnxruntime/core/optimizer/matmul_bn_fusion.h b/onnxruntime/core/optimizer/matmul_bn_fusion.h
new file mode 100644
index 0000000000000..7a43483cf37d4
--- /dev/null
+++ b/onnxruntime/core/optimizer/matmul_bn_fusion.h
@@ -0,0 +1,27 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "core/optimizer/rewrite_rule.h"
+
+namespace onnxruntime {
+/*
+ *   This fusion folds a BatchNormalization operator into its preceding
+ *   MatMul operator, if and only if MatmulBNFusion::SatisfyCondition()
+ *   is true.
+ */
+class MatmulBNFusion : public RewriteRule {
+ public:
+  MatmulBNFusion() : RewriteRule("MatMul_BatchNormalization_Fusion") {}
+
+  std::vector<std::string> TargetOpTypes() const noexcept override {
+    return {"MatMul"};
+  }
+
+ private:
+  bool SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const override;
+
+  Status Apply(Graph& graph, Node& matmul_node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const override;
+};
+}  // namespace onnxruntime
\ No newline at end of file
diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc
index b75d9707d2481..b921bbbc31cd6 100755
--- a/onnxruntime/test/optimizer/graph_transform_test.cc
+++ b/onnxruntime/test/optimizer/graph_transform_test.cc
@@ -31,6 +31,7 @@
 #include "core/optimizer/conv_add_act_fusion.h"
 #include "core/optimizer/conv_add_fusion.h"
 #include "core/optimizer/conv_bn_fusion.h"
+#include "core/optimizer/matmul_bn_fusion.h"
 #include "core/optimizer/conv_mul_fusion.h"
 #include "core/optimizer/div_mul_fusion.h"
 #include "core/optimizer/dropout_elimination.h"
@@ -964,6 +965,268 @@ TEST_F(GraphTransformationTests, FuseConvBNNoBias) {
   }
 }
 
+TEST_F(GraphTransformationTests, FuseMatmulBNWithInBetweenNodes) {
+  constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-matmul-bn-with-reshape.onnx";
+
+  std::shared_ptr<Model> p_model;
+  ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
+  Graph& graph = p_model->MainGraph();
+
+  std::string expected_output_name;
+  GraphViewer graphViewer(graph);
+  for (auto& node_index : graphViewer.GetNodesInTopologicalOrder()) {
+    auto& node = *graph.GetNode(node_index);
+    if (node.OpType() == "MatMul") {
+      expected_output_name = node.OutputDefs()[0]->Name();
+    }
+  }
+
+  onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
+  auto rule_transformer_L1 = std::make_unique<RuleBasedGraphTransformer>("RuleTransformerL1");
+  ASSERT_STATUS_OK(rule_transformer_L1->Register(std::make_unique<MatmulBNFusion>()));
+  ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1));
+
+  ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
+
+  std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
+  ASSERT_EQ(op_to_count["BatchNormalization"], 0);
+  ASSERT_EQ(op_to_count["MatMul"], 0);
+  ASSERT_EQ(op_to_count["Gemm"], 1);
+
+  for (auto& node : graph.Nodes()) {
+    if (node.OpType() == "Gemm") {
+      ASSERT_EQ(node.OutputDefs()[0]->Name(), expected_output_name)
+          << "fusion should produce the same output name as the MatMul node";
+    }
+  }
+}
+
+TEST_F(GraphTransformationTests, FuseMatmulBNWithEmptyOptionalOutputWithInBetweenNodes) {
+  constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-matmul-bn-with-reshape.onnx";
+
+  std::shared_ptr<Model> p_model;
+  ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
+  Graph& graph = p_model->MainGraph();
+
+  std::string expected_output_name;
+  GraphViewer graphViewer(graph);
+  for (auto& node_index : graphViewer.GetNodesInTopologicalOrder()) {
+    auto& node = *graph.GetNode(node_index);
+    if (node.OpType() == "MatMul") {
+      expected_output_name = node.OutputDefs()[0]->Name();
+    } else if (node.OpType() == "BatchNormalization") {
+      node.MutableOutputDefs().push_back(&graph.GetOrCreateNodeArg("", nullptr));
+    }
+  }
+
+  onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
+  auto rule_transformer_L1 = std::make_unique<RuleBasedGraphTransformer>("RuleTransformerL1");
+  ASSERT_STATUS_OK(rule_transformer_L1->Register(std::make_unique<MatmulBNFusion>()));
+  ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1));
+
+  ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
+
+  std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
+  ASSERT_EQ(op_to_count["BatchNormalization"], 0);
+  ASSERT_EQ(op_to_count["MatMul"], 0);
+  ASSERT_EQ(op_to_count["Gemm"], 1);
+
+  for (auto& node : graph.Nodes()) {
+    if (node.OpType() == "Gemm") {
+      ASSERT_EQ(node.OutputDefs()[0]->Name(), expected_output_name)
+          << "fusion should produce the same output name as the MatMul node";
+    }
+  }
+}
+
+// should not fuse
+TEST_F(GraphTransformationTests, FuseMatmulBNWithOptionalOutputWithInBetweenNodes) {
+  constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-matmul-bn-with-reshape.onnx";
+
+  std::shared_ptr<Model> p_model;
+  ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
+  Graph& graph = p_model->MainGraph();
+
+  GraphViewer graphViewer(graph);
+  for (auto& node_index : graphViewer.GetNodesInTopologicalOrder()) {
+    auto& node = *graph.GetNode(node_index);
+    if (node.OpType() == "BatchNormalization") {
+      // add an additional non-empty optional output to the BatchNormalization node
+      ONNX_NAMESPACE::TypeProto optional_output_tensor_type;
+      optional_output_tensor_type.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TypeProto::kTensorType);
+      auto& arg = graph.GetOrCreateNodeArg("bn_optional_output", &optional_output_tensor_type);
+      node.MutableOutputDefs().push_back(&arg);
+    }
+  }
+
+  onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
+  auto rule_transformer_L1 = std::make_unique<RuleBasedGraphTransformer>("RuleTransformerL1");
+  ASSERT_STATUS_OK(rule_transformer_L1->Register(std::make_unique<MatmulBNFusion>()));
+  ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1));
+
+  ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
+
+  std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
+  ASSERT_EQ(op_to_count["BatchNormalization"], 1);
+  ASSERT_EQ(op_to_count["MatMul"], 1);
+  ASSERT_EQ(op_to_count["Gemm"], 0);
+}
+
+TEST_F(GraphTransformationTests, FuseMatmulBNDirectly) {
+  constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-matmul-bn-directly.onnx";
+
+  std::shared_ptr<Model> p_model;
+  ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
+  Graph& graph = p_model->MainGraph();
+
+  std::string expected_output_name;
+  GraphViewer graphViewer(graph);
+  for (auto& node_index : graphViewer.GetNodesInTopologicalOrder()) {
+    auto& node = *graph.GetNode(node_index);
+    if (node.OpType() == "BatchNormalization") {
+      expected_output_name = node.OutputDefs()[0]->Name();
+    }
+  }
+
+  onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
+  auto rule_transformer_L1 = std::make_unique<RuleBasedGraphTransformer>("RuleTransformerL1");
+  ASSERT_STATUS_OK(rule_transformer_L1->Register(std::make_unique<MatmulBNFusion>()));
+  ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1));
+
+  ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
+
+  std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
+  ASSERT_EQ(op_to_count["BatchNormalization"], 0);
+  ASSERT_EQ(op_to_count["MatMul"], 0);
+  ASSERT_EQ(op_to_count["Gemm"], 1);
+
+  for (auto& node : graph.Nodes()) {
+    if (node.OpType() == "Gemm") {
+      ASSERT_EQ(node.OutputDefs()[0]->Name(), expected_output_name)
+          << "fusion should produce the same output name as the last node";
+    }
+  }
+}
+
+TEST_F(GraphTransformationTests, FuseMatmulBNWithOnlyReshape) {
+  constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-matmul-bn-only-reshape.onnx";
+
+  std::shared_ptr<Model> p_model;
+  ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
+  Graph& graph = p_model->MainGraph();
+
+  std::string expected_output_name;
+  GraphViewer graphViewer(graph);
+  for (auto& node_index : graphViewer.GetNodesInTopologicalOrder()) {
+    auto& node = *graph.GetNode(node_index);
+    if (node.OpType() == "MatMul") {
+      expected_output_name = node.OutputDefs()[0]->Name();
+    }
+  }
+
+  onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
+  auto rule_transformer_L1 = std::make_unique<RuleBasedGraphTransformer>("RuleTransformerL1");
+  ASSERT_STATUS_OK(rule_transformer_L1->Register(std::make_unique<MatmulBNFusion>()));
+  ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1));
+
+  ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
+
+  std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
+  ASSERT_EQ(op_to_count["BatchNormalization"], 0);
+  ASSERT_EQ(op_to_count["MatMul"], 0);
+  ASSERT_EQ(op_to_count["Gemm"], 1);
+
+  for (auto& node : graph.Nodes()) {
+    if (node.OpType() == "Gemm") {
+      ASSERT_EQ(node.OutputDefs()[0]->Name(), expected_output_name)
+          << "fusion should produce the same output name as the MatMul node";
+    }
+  }
+}
+
+TEST_F(GraphTransformationTests, FuseMatmulBNWithOnlyTranspose) {
+  constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-matmul-bn-only-transpose.onnx";
+
+  std::shared_ptr<Model> p_model;
+  ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
+  Graph& graph = p_model->MainGraph();
+
+  std::string expected_output_name;
+  GraphViewer graphViewer(graph);
+  for (auto& node_index : graphViewer.GetNodesInTopologicalOrder()) {
+    auto& node = *graph.GetNode(node_index);
+    if (node.OpType() == "MatMul") {
+      expected_output_name = node.OutputDefs()[0]->Name();
+    }
+  }
+
+  onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
+  auto rule_transformer_L1 = std::make_unique<RuleBasedGraphTransformer>("RuleTransformerL1");
+  ASSERT_STATUS_OK(rule_transformer_L1->Register(std::make_unique<MatmulBNFusion>()));
+  ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1));
+
+  ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
+
+  std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
+  ASSERT_EQ(op_to_count["BatchNormalization"], 0);
+  ASSERT_EQ(op_to_count["MatMul"], 0);
+  ASSERT_EQ(op_to_count["Gemm"], 1);
+
+  for (auto& node : graph.Nodes()) {
+    if (node.OpType() == "Gemm") {
+      ASSERT_EQ(node.OutputDefs()[0]->Name(), expected_output_name)
+          << "fusion should produce the same output name as the MatMul node";
+    }
+  }
+}
+
+TEST_F(GraphTransformationTests, FuseMatmulBNWithoutBatchNormalization) {
+  constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-matmul-bn-only-transpose.onnx";
+
+  std::shared_ptr<Model> p_model;
+  ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
+  Graph& graph = p_model->MainGraph();
+
+  GraphViewer graphViewer(graph);
+  for (auto& node_index : graphViewer.GetNodesInTopologicalOrder()) {
+    auto& node = *graph.GetNode(node_index);
+    if (node.OpType() == "BatchNormalization") {
+      graph_utils::RemoveNode(graph, node);
+    }
+  }
+
+  onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
+  auto rule_transformer_L1 = std::make_unique<RuleBasedGraphTransformer>("RuleTransformerL1");
+  ASSERT_STATUS_OK(rule_transformer_L1->Register(std::make_unique<MatmulBNFusion>()));
+  ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1));
+
+  ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
+
+  std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
+  ASSERT_EQ(op_to_count["MatMul"], 1);
+}
+
+// should not fuse
+TEST_F(GraphTransformationTests, FuseMatmulBNWithNonIgnorableNode) {
+  constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-matmul-bn-non-ignorable-node.onnx";
+
+  std::shared_ptr<Model> p_model;
+  ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
+  Graph& graph = p_model->MainGraph();
+
+  onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
+  auto rule_transformer_L1 = std::make_unique<RuleBasedGraphTransformer>("RuleTransformerL1");
+  ASSERT_STATUS_OK(rule_transformer_L1->Register(std::make_unique<MatmulBNFusion>()));
+  ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1));
+
+  ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
+
+  std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
+  ASSERT_EQ(op_to_count["BatchNormalization"], 1);
+  ASSERT_EQ(op_to_count["MatMul"], 1);
+  ASSERT_EQ(op_to_count["Gemm"], 0);
+}
+
 TEST_F(GraphTransformationTests, DontFuseConvWithBNWithOptionalOutputs) {
   constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-conv-bn-no-bias.onnx";
 
diff --git a/onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-directly.onnx b/onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-directly.onnx
new file mode 100644
index 0000000000000000000000000000000000000000..fa11adaac8d95db4772990bac6f5b0b072f4d5c1
GIT binary patch
literal 513
zcmd<!6cR}<N-W5TPb)3X%+HHYtw_u*$Vs(&z|5u3#hRH{P+G#ppPN{cTbdJ}6kn2>
zSDarY#0wS3FD(JeE3x?|miU(DaQSngN^o%`<;52#C+4Jbu>)C2nTf?<Ms8|i9v2%>
zq%5&Whz)9pkW*qwa)w`iQEp;RW>sPd&@n<{FpKkG&I7wutCoY6gH?c0DW&9@y}fs&
zzrATnsojapNc+btLhLz8((JXXI_#F_bK0x1?6p(Ye{Q$n%?jIp7hc)TKdWx{F3r&H
z+?8<q_tM*J!WPTgYka(IH>o_>{<QII+kIkdZT$<v?JkS<**-htWItn)n7tFa_q9~O
zmWwbjG}xnKC%Bijlp(4L{`|HLQ9EmAkrHn2XjNywe*Q^2Cs=U66iM%#X6MZD$S(5F
z3p<rvIrhpj2KHb@prF!Hf|&GUeWCrfR8@P$uqu0o8=Gwl435}2L4z?$92!|dd|W&n
bj6xh-OdO0r%$g*@1r1?z87CGl1_2HLb}goB

literal 0
HcmV?d00001

diff --git a/onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-non-ignorable-node.onnx b/onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-non-ignorable-node.onnx
new file mode 100644
index 0000000000000000000000000000000000000000..1050a7285b4a6e9e13d5be17f20316f9c57a2aac
GIT binary patch
literal 593
zcmd<!6cR}<N-W5TPb)3X%+HHYtw_u*$Vs)@z`~`^#hRH{P+G#ppPN{cTbdJ}6kn2>
zSDarY#0wS3FD(JeE3x?|miU(DaA|R&N(k|1rljVTWR_IMLsfEkLIt=&xX>lJIFj<>
zi<1*`Qn}cHtfb7uVlX2&H8GEi4JcBUSR}*=q@iXBIVF}PXZYn8<tFB2Rwb4IUC$-U
zg-{^GhHj~l7|bJiFz<r>u2svy%E2nYsFc!j%G0hwo57}ap`u+&w6Oi|3GViX_qf@!
zybQG0I2>y4k;iL)_kNE3_wx$&tY<Rq9=}wvE!b>e@Bh!we$Lz!TZ@_d?5;5D+V6{=
zV3%pwVCNR)V}GFJiJcSoRJ-K@CU%z=KCp8_4^Ax=u;n5Q3=Q_^*a;peTFMYrb=SLX
z+22jE+ccNa&c@Z%zU4-=y%Q|JV2aw9{@H~1McZ4hIcs}X#K69_o&}-^6qs5{5R;0#
z3hgU&)9ra*iP{&LMB2BRci1^W13F3^8u>zeTs$0%LL6L79E?EBnk2ym4Oes-Cl)RS
G0S*9lCcbw7

literal 0
HcmV?d00001

diff --git a/onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-only-reshape.onnx b/onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-only-reshape.onnx
new file mode 100644
index 0000000000000000000000000000000000000000..c361a42700a30b12eeeacefe8da3825a1e9b3ab7
GIT binary patch
literal 639
zcmd<!6cR}<N-W5TPb)3X%+HHYtw_u*$Vs)j$HJx0#hRH{P+G#ppPN{cTbdJ}6kn2>
zSDarY#0wS3FD(JeE3x?|miU(Da2au-N^r3kXCxM+#v2In7o|d&5FG|e>_HF#t`IJC
zIWCT*y!hhe#GF(vb|5P$GqD)V$W2Ym<6;AflqD7kaRO<m?Ltn8CCM3n`9-;jIhj?7
zB|u+r$#NkS2(dvGa(Jes=9OfYR0@f~Jd_9XHaHBlYB^XrSOplBQWkNm*&ozUu)n3e
z%5LKmXZwb2?)IAIn)c}(cJ}jY<?a8chTH#b_p$vwv%ya4kbs@XE)}~A5)bU6zpB`;
zkLIyk(CuM=jYY)f=(7uUJG6T2&iQlNyUr`L<8k_IJK?>%ZT9OGc24NQswKn0BEW9N
z`yU1v85vofU}31G0(P1R14Dy7I(C8wnU*p{)fxLfyYog2_W4&F?IlH6?J8~<+B?C5
z6Q;<Qr^>!T*3o{AR=O=~<8iytTpoxbP*7_rK}?#|t7^aKc8I;q6Gyw$c{gn>(pBu8
ypwSQ|4oxCLd|W&nj6xh-OdO0r%$lUY1r2o}F)k6Hi~v*yBnwvI#KOfOzySb5e#_?o

literal 0
HcmV?d00001

diff --git a/onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-only-transpose.onnx b/onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-only-transpose.onnx
new file mode 100644
index 0000000000000000000000000000000000000000..f70ae2e6229e7a89eb90cb8d68360e35e4c32291
GIT binary patch
literal 579
zcmd<!6cR}<N-W5TPb)3X%+HHYtw_u*$Vs(Y#KNV|#hRH{P+G#ppPN{cTbdJ}6kn2>
zSDarY#0wS3FD(JeE3x?|miU(DaM^I7N(c#-6eZ>r7vvYG#zT}EC~=0sgtU0MSPD{$
zavc~Q7#1+Ha|Lr@D(B)z%8M^fPRvQ=Vh6I4G82ozjNH`3JT5k%NLgZ$5F6APLQaV#
z$r*n6MY)MNnN^7;K<@~N!JM84^B~x_TD2Uk9IOJ2N-3i2U)s&qR<YkQOV4_e$9lV<
z`Kk8&UBdRZ`wQ)*?`qmc@71z*c_eJFJMn;BNMMiM&a3t|XV?wxx1{vgi5!l!bHA2l
z|K+`hy?m3F{ZW~E`@c7t?HjnN?RQ+yw~gGTVE?Ew)!qp`0JT)WmWwbjG}xnKC%Bij
zlp(5un`G@J?~2<mQ%kZh-YjgpyxGj&2^JhMMXRS=viWkTz<yW9W4q}eQ|$M%`$H6g
zf=WvXVp5ZYwSAIJm7UM2bo;+diguC<ee9i}!5AeDjcy@6E*=g>Ar3Ak4n`nmO_Jb(
RhA@&0D@fLfg^NLe0|4BqxTOF9

literal 0
HcmV?d00001

diff --git a/onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-with-reshape.onnx b/onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-with-reshape.onnx
new file mode 100644
index 0000000000000000000000000000000000000000..8e4bc49514548c604b36029f780f7ff1a56db148
GIT binary patch
literal 709
zcmd<!6cR}<N-W5TPb)3X%+HHYtw_u*$Vs(Y!pf!3#hRH{P+G#ppPN{cTbdJ}6kn2>
zSDarY#0wS3FD(JeE3x?|miU(Da2au-N^r3kXCxM+#v2In7o|d&5FG|e>_HF#E+;N@
zIU&K4qQt!7g8bstc$jj|5SWmbAQwwPYEiBOg9EbzqXW|dMs}_+E=(<497%cc#mR{|
zsa)(pR#IkSF_@8?nwZDM1{5hvEE3`b(ojDLIVF}PXZYn8<tFB2Rwb4I!-7ke3!y-W
z4XTjCGbJ^zB(tPaNDLM%d9c6(hon|52P+4w0Hab$-!m6`jRR@+Zxv+hr(R;S@6LT?
z#~6@k?|97E-cIU)?e-H3>{d>;u)k9_-R68Ut6h7pqW!<lUv?LI4D6r&NwH7voMe|C
zRAJx#BEvrGrK`Q!YZ06Ovn=iJcyZa!i&L>TWKOYnLXR9R84eZ!b}QchFu=&j$m#@(
z2`v?{(?l2;8tl=r6FkVYlp(5YTTJaU71!F`j(KX=uE%OG<|u9N1Pe}>qNNu#?VDGY
z+kf6~Y%f=BV{d;x(hjT$6x3Qu5R=rr>+N}uHrviU9%8RkaM$M1!XSGmXf#BLLlc@1
o9~Tb?qYwud69*#@vnDBUK|@_gj7tP4BLI_u(u__lTnqvn0In0|n*aa+

literal 0
HcmV?d00001