From d56fc7ebf5377abc96db728eafaffd8bf79a3b81 Mon Sep 17 00:00:00 2001
From: Abhishek Jindal <abjindal@microsoft.com>
Date: Thu, 21 Sep 2023 14:16:41 -0700
Subject: [PATCH] Layer norm fusion deepspeed stage3 changes (#17614)

### Description
Layer norm fusion changes required for DeepSpeed Stage 3; a test case is included.
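
Not part of the change itself, but a rough way to observe the fusion end to end (assuming a build or wheel that contains this fix and that the test model has been generated with the script added below; the file names here are only illustrative): run ONNX Runtime's offline optimization over the model and inspect which ops survive. Whether the Level-2 fusion fires can also depend on the execution provider.

```python
import onnx
import onnxruntime as ort

so = ort.SessionOptions()
# LayerNormFusion is a Level-2 transformer, so extended optimization is sufficient.
so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_EXTENDED
so.optimized_model_filepath = "layer_norm_fusion_scale_bias_opt.onnx"
ort.InferenceSession("layer_norm_fusion_scale_bias.onnx", so, providers=["CPUExecutionProvider"])

ops = [n.op_type for n in onnx.load("layer_norm_fusion_scale_bias_opt.onnx").graph.node]
print(ops)  # a single LayerNormalization is expected if the fusion applied
```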


### Motivation and Context
With DeepSpeed Stage 3, the layer norm scale and bias can be neither graph inputs nor
initializers in the graph, which previously prevented the fusion. This change relaxes
that check so the fusion still applies, and adds a test case covering the scenario.
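
For illustration only (this is not the checked-in test model; the names and shapes below are made up), the relaxed check targets subgraphs like the following: a standard LayerNorm pattern in which gamma and beta reach the final Mul/Add through Identity nodes, so the matcher sees inputs that are neither constant initializers nor graph inputs and now relies on shape information alone.

```python
import onnx
from onnx import TensorProto, helper

nodes = [
    helper.make_node("ReduceMean", ["X"], ["mean"], axes=[-1], keepdims=1),
    helper.make_node("Sub", ["X", "mean"], ["centered"]),
    helper.make_node("Pow", ["centered", "two"], ["sq"]),
    helper.make_node("ReduceMean", ["sq"], ["var"], axes=[-1], keepdims=1),
    helper.make_node("Add", ["var", "eps"], ["var_eps"]),
    helper.make_node("Sqrt", ["var_eps"], ["std"]),
    helper.make_node("Div", ["centered", "std"], ["normed"]),
    # gamma/beta are produced by other nodes, not read directly as constants/graph inputs
    helper.make_node("Identity", ["gamma"], ["gamma_out"]),
    helper.make_node("Identity", ["beta"], ["beta_out"]),
    helper.make_node("Mul", ["gamma_out", "normed"], ["scaled"]),
    helper.make_node("Add", ["scaled", "beta_out"], ["Y"]),
]
initializers = [
    helper.make_tensor("two", TensorProto.FLOAT, [], [2.0]),
    helper.make_tensor("eps", TensorProto.FLOAT, [], [1e-12]),
    helper.make_tensor("gamma", TensorProto.FLOAT, [4], [1.0, 1.0, 1.0, 1.0]),
    helper.make_tensor("beta", TensorProto.FLOAT, [4], [0.0, 0.0, 0.0, 0.0]),
]
graph = helper.make_graph(
    nodes,
    "LayerNormScaleBiasExample",
    [helper.make_tensor_value_info("X", TensorProto.FLOAT, [2, 3, 4])],
    [helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 3, 4])],
    initializers,
)
onnx.save(
    helper.make_model(graph, opset_imports=[helper.make_opsetid("", 12)]),
    "ln_scale_bias_example.onnx",
)
```

The test model added below follows the same structure, with fp16 inputs and Cast nodes around the normalization.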
---
 .../core/optimizer/layer_norm_fusion.cc       |  42 ++++-----
 .../graph_transform_test_layernorm.cc         |  34 ++++++++
 .../fusion/layer_norm_fusion_scale_bias.onnx  | Bin 0 -> 854 bytes
 .../fusion/layer_norm_fusion_scale_bias.py    |  81 ++++++++++++++++++
 4 files changed, 136 insertions(+), 21 deletions(-)
 create mode 100644 onnxruntime/test/testdata/transform/fusion/layer_norm_fusion_scale_bias.onnx
 create mode 100644 onnxruntime/test/testdata/transform/fusion/layer_norm_fusion_scale_bias.py

diff --git a/onnxruntime/core/optimizer/layer_norm_fusion.cc b/onnxruntime/core/optimizer/layer_norm_fusion.cc
index bf36f11521be2..159e3b23d1ab0 100644
--- a/onnxruntime/core/optimizer/layer_norm_fusion.cc
+++ b/onnxruntime/core/optimizer/layer_norm_fusion.cc
@@ -414,20 +414,20 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
     NodeArg* scale = nullptr;
     NodeArg* bias = nullptr;
     for (size_t i = 0; i < mul_node.MutableInputDefs().size(); i++) {
-      if (graph_utils::NodeArgIsConstant(graph, *(mul_node.MutableInputDefs()[i])) ||
-          graph_utils::IsGraphInput(graph, mul_node.MutableInputDefs()[i])) {
-        if (mul_node.MutableInputDefs()[i]->Shape()->dim_size() == static_cast<int>(axes_values.size())) {
-          scale = mul_node.MutableInputDefs()[i];
-        }
+      if (mul_node.MutableInputDefs()[i]->Shape() == nullptr) {
+        continue;
+      }
+      if (mul_node.MutableInputDefs()[i]->Shape()->dim_size() == static_cast<int>(axes_values.size())) {
+        scale = mul_node.MutableInputDefs()[i];
       }
     }
 
     for (size_t i = 0; i < last_add_node.MutableInputDefs().size(); i++) {
-      if (graph_utils::NodeArgIsConstant(graph, *(last_add_node.MutableInputDefs()[i])) ||
-          graph_utils::IsGraphInput(graph, last_add_node.MutableInputDefs()[i])) {
-        if (last_add_node.MutableInputDefs()[i]->Shape()->dim_size() == static_cast<int>(axes_values.size())) {
-          bias = last_add_node.MutableInputDefs()[i];
-        }
+      if (last_add_node.MutableInputDefs()[i]->Shape() == nullptr) {
+        continue;
+      }
+      if (last_add_node.MutableInputDefs()[i]->Shape()->dim_size() == static_cast<int>(axes_values.size())) {
+        bias = last_add_node.MutableInputDefs()[i];
       }
     }
     if (scale == nullptr || bias == nullptr) {
@@ -667,20 +667,20 @@ Status SimplifiedLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int gr
     // because SkipLayerNorm kernel, for example, has dependency on single dim size
     NodeArg* scale = nullptr;
     for (size_t i = 0; i < mul_node.MutableInputDefs().size(); i++) {
-      if (graph_utils::NodeArgIsConstant(graph, *(mul_node.MutableInputDefs()[i])) ||
-          graph_utils::IsGraphInput(graph, mul_node.MutableInputDefs()[i])) {
+      if (mul_node.MutableInputDefs()[i]->Shape() == nullptr) {
+        continue;
+      }
 #ifdef ENABLE_TRAINING_CORE
-        if (axes_values.empty() ||
-            mul_node.MutableInputDefs()[i]->Shape()->dim_size() == static_cast<int>(axes_values.size())) {
-          scale = mul_node.MutableInputDefs()[i];
-        }
+      if (axes_values.empty() ||
+          mul_node.MutableInputDefs()[i]->Shape()->dim_size() == static_cast<int>(axes_values.size())) {
+        scale = mul_node.MutableInputDefs()[i];
+      }
 #else
-        // Scale must be 1d.
-        if (mul_node.MutableInputDefs()[i]->Shape()->dim_size() == 1) {
-          scale = mul_node.MutableInputDefs()[i];
-        }
-#endif
+      // Scale must be 1d.
+      if (mul_node.MutableInputDefs()[i]->Shape()->dim_size() == 1) {
+        scale = mul_node.MutableInputDefs()[i];
       }
+#endif
     }
 
     if (scale == nullptr) {
diff --git a/onnxruntime/test/optimizer/graph_transform_test_layernorm.cc b/onnxruntime/test/optimizer/graph_transform_test_layernorm.cc
index 1f671e90090ba..a55238396cea3 100755
--- a/onnxruntime/test/optimizer/graph_transform_test_layernorm.cc
+++ b/onnxruntime/test/optimizer/graph_transform_test_layernorm.cc
@@ -429,6 +429,40 @@ TEST_F(GraphTransformationTests, SimplifiedLayerNormFusionTest) {
   }
 }
 
+// Tests the scenario where scale and bias are neither graph inputs nor initializers in the graph.
+// An Identity node is added after the scale and bias terms to verify that LayerNormFusion still applies.
+TEST_F(GraphTransformationTests, LayerNormScaleBiasTest) {
+  constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/layer_norm_fusion_scale_bias.onnx";
+  std::shared_ptr<Model> p_model;
+  ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
+  Graph& graph = p_model->MainGraph();
+
+  onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
+  ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<LayerNormFusion>(), TransformerLevel::Level2));
+  ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_));
+
+  std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
+  ASSERT_EQ(op_to_count["ReduceMean"], 0);
+  ASSERT_EQ(op_to_count["Sub"], 0);
+  ASSERT_EQ(op_to_count["Cast"], 0);
+  ASSERT_EQ(op_to_count["Pow"], 0);
+  ASSERT_EQ(op_to_count["Add"], 0);
+  ASSERT_EQ(op_to_count["Sqrt"], 0);
+  ASSERT_EQ(op_to_count["Div"], 0);
+  ASSERT_EQ(op_to_count["Mul"], 0);
+  ASSERT_EQ(op_to_count["LayerNormalization"], 1);
+
+  for (const Node& node : graph.Nodes()) {
+    if (node.OpType() == "LayerNormalization") {
+      // LayerNormalization should have three inputs.
+      EXPECT_EQ(node.InputDefs().size(), 3u) << "LayerNormalization should have 3 inputs. Got: " << node.InputDefs().size();
+      // LayerNormalization scale input should be 1D.
+      const TensorShapeProto* scale_shape = node.InputDefs()[1]->Shape();
+      EXPECT_EQ(scale_shape->dim_size(), 1) << "LayerNormalization scale should be 1D. Got: " << scale_shape->dim_size();
+    }
+  }
+}
+
 // If EP is non-GPU EP or unknown, the sub-graph will be not fused because CPU impl for SimplifiedLayerNormalization
 // doesn't support input and scale having different data types.
 TEST_F(GraphTransformationTests, SimplifiedLayerNormWithCastsFusionTest) {
diff --git a/onnxruntime/test/testdata/transform/fusion/layer_norm_fusion_scale_bias.onnx b/onnxruntime/test/testdata/transform/fusion/layer_norm_fusion_scale_bias.onnx
new file mode 100644
index 0000000000000000000000000000000000000000..ec0f9a97815b888701198d94c92e3f61c581d6dc
GIT binary patch
literal 854
zcmbVLO;6h}7_J-BxG#kTYZE93gnT22m1Yu$ooG5~nzW*c-nc|*?V^aLfqd|B%VCEd
z_!0b!9ry|SC$N(kk+Bn&96fr!@;r}i(*63k1K$A+X(!=+oM*O~2%gWxfWb)##v)ic
z9{~q9B0YN23*95r`2gfxhzlM@>6Q$%VOtJ@dJr|!d|FO4Bw)rQpTZvWW<i?ybq2^q
zeC>xz-=(HP>i32O%=i^w!!hU}H52Z>Cg;9~+&<_rur`aAl7<+#{``weNx=D_oR1Y^
z#*lN^ftN5P>1C2t1qv}dk>9s!bQLvucvY#9fEnMyE9gV-EQq4O4@;YCBkDS8M){&@
zkboKEd;z<lgJ9Kk5B>SzP?b?MvK3XgqUwV7nl}8kiFTXek@Vf^LOYAAqdEXhvhLB8
zJ7tgC=m2%NeOM_K(1s9uUCR>7EX-~h`N1m$Ln*TIxg<{C$gn@X&P!+h9YMQ4gIkdt
z$4TT+3o+bkwT`@(TjOl1*yF?9p4U84XNzD99K0cy*C0`6R*RdWK*euV{6StN>vU7S
p0tyxZ+JiQ+<ld1RPi12Czl4XOW%axbb)BNmQ8-KDG@fS`dIpt9{cQjM

literal 0
HcmV?d00001

diff --git a/onnxruntime/test/testdata/transform/fusion/layer_norm_fusion_scale_bias.py b/onnxruntime/test/testdata/transform/fusion/layer_norm_fusion_scale_bias.py
new file mode 100644
index 0000000000000..a59e263763a3d
--- /dev/null
+++ b/onnxruntime/test/testdata/transform/fusion/layer_norm_fusion_scale_bias.py
@@ -0,0 +1,81 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+import onnx
+from onnx import OperatorSetIdProto, TensorProto, helper
+
+
+def GenerateModel(model_name, has_casts=False, has_identity=False):  # noqa: N802
+    nodes = [  # LayerNorm subgraph
+        helper.make_node("ReduceMean", ["A"], ["rd_out"], "reduce1", axes=[-1], keepdims=1),
+        helper.make_node("Sub", ["A", "rd_out"], ["sub_out"], "sub"),
+        helper.make_node("Pow", ["cast_sub_out" if has_casts else "sub_out", "pow_in_2"], ["pow_out"], "pow"),
+        helper.make_node("ReduceMean", ["pow_out"], ["rd2_out"], "reduce2", axes=[-1], keepdims=1),
+        helper.make_node("Add", ["rd2_out", "const_e12_f32"], ["add1_out"], "add1"),
+        helper.make_node("Sqrt", ["add1_out"], ["sqrt_out"], "sqrt"),
+        helper.make_node("Div", ["cast_sub_out" if has_casts else "sub_out", "sqrt_out"], ["div_out"], "div"),
+        helper.make_node(
+            "Mul",
+            ["gamma_id_out" if has_identity else "gamma", "cast_div_out" if has_casts else "div_out"],
+            ["mul_out"],
+            "mul",
+        ),
+        helper.make_node("Add", ["mul_out", "const_e6_f16_out" if has_identity else "const_e6_f16"], ["C"], "add2"),
+    ]
+
+    if has_casts:
+        nodes.extend(
+            [
+                helper.make_node("Cast", ["sub_out"], ["cast_sub_out"], "cast_sub", to=1),
+                helper.make_node("Cast", ["div_out"], ["cast_div_out"], "cast_2", to=10),
+            ]
+        )
+
+    if has_identity:
+        nodes.extend(
+            [
+                helper.make_node("Identity", ["gamma"], ["gamma_id_out"], "gamma_identity"),
+                helper.make_node("Identity", ["const_e6_f16"], ["const_e6_f16_out"], "const_e6_f16_identity"),
+            ]
+        )
+
+    initializers = [  # initializers
+        helper.make_tensor("pow_in_2", TensorProto.FLOAT, [], [2]),
+        helper.make_tensor("const_e12_f32", TensorProto.FLOAT, [], [1e-12]),
+        helper.make_tensor("const_e6_f16", TensorProto.FLOAT16, [4], [1e-6, 1e-6, 1e-6, 1e-6]),
+        helper.make_tensor(
+            "gamma",
+            TensorProto.FLOAT16 if has_casts else TensorProto.FLOAT,
+            [4],
+            [1, 2, 3, 4],
+        ),
+    ]
+
+    input_type = TensorProto.FLOAT16 if has_casts else TensorProto.FLOAT
+    output_type = TensorProto.FLOAT16 if has_casts else TensorProto.FLOAT
+
+    graph = helper.make_graph(
+        nodes,
+        "LayerNorm",  # name
+        [  # inputs
+            helper.make_tensor_value_info("A", input_type, [16, 32, 4]),
+        ],
+        [  # outputs
+            helper.make_tensor_value_info("C", output_type, [16, 32, 4]),
+        ],
+        initializers,
+    )
+
+    onnxdomain = OperatorSetIdProto()
+    onnxdomain.version = 12
+    # The empty string ("") or absence of this field implies the operator set that is defined as part of the ONNX specification.
+    onnxdomain.domain = ""
+    msdomain = OperatorSetIdProto()
+    msdomain.version = 1
+    msdomain.domain = "com.microsoft"
+    opsets = [onnxdomain, msdomain]
+
+    model = helper.make_model(graph, opset_imports=opsets)
+    onnx.save(model, model_name)
+
+
+GenerateModel("layer_norm_fusion_scale_bias.onnx", True, True)