From d56fc7ebf5377abc96db728eafaffd8bf79a3b81 Mon Sep 17 00:00:00 2001
From: Abhishek Jindal <abjindal@microsoft.com>
Date: Thu, 21 Sep 2023 14:16:41 -0700
Subject: [PATCH] Layer norm fusion deepspeed stage3 changes (#17614)

### Description
Layer norm fusion changes required for DeepSpeed Stage 3; also includes a test case.

### Motivation and Context
This change enables layer norm fusion for DeepSpeed Stage 3, where the scale and bias of the pattern are neither graph inputs nor graph initializers. A test case is added to verify that the fusion works properly in this scenario.

---
 .../core/optimizer/layer_norm_fusion.cc      | 42 ++++-----
 .../graph_transform_test_layernorm.cc        | 34 ++++++++
 .../fusion/layer_norm_fusion_scale_bias.onnx | Bin 0 -> 854 bytes
 .../fusion/layer_norm_fusion_scale_bias.py   | 81 ++++++++++++++++++
 4 files changed, 136 insertions(+), 21 deletions(-)
 create mode 100644 onnxruntime/test/testdata/transform/fusion/layer_norm_fusion_scale_bias.onnx
 create mode 100644 onnxruntime/test/testdata/transform/fusion/layer_norm_fusion_scale_bias.py

diff --git a/onnxruntime/core/optimizer/layer_norm_fusion.cc b/onnxruntime/core/optimizer/layer_norm_fusion.cc
index bf36f11521be2..159e3b23d1ab0 100644
--- a/onnxruntime/core/optimizer/layer_norm_fusion.cc
+++ b/onnxruntime/core/optimizer/layer_norm_fusion.cc
@@ -414,20 +414,20 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
   NodeArg* scale = nullptr;
   NodeArg* bias = nullptr;
   for (size_t i = 0; i < mul_node.MutableInputDefs().size(); i++) {
-    if (graph_utils::NodeArgIsConstant(graph, *(mul_node.MutableInputDefs()[i])) ||
-        graph_utils::IsGraphInput(graph, mul_node.MutableInputDefs()[i])) {
-      if (mul_node.MutableInputDefs()[i]->Shape()->dim_size() == static_cast<int>(axes_values.size())) {
-        scale = mul_node.MutableInputDefs()[i];
-      }
+    if (mul_node.MutableInputDefs()[i]->Shape() == nullptr) {
+      continue;
+    }
+    if (mul_node.MutableInputDefs()[i]->Shape()->dim_size() == static_cast<int>(axes_values.size())) {
+      scale = mul_node.MutableInputDefs()[i];
     }
   }
 
   for (size_t i = 0; i < last_add_node.MutableInputDefs().size(); i++) {
-    if (graph_utils::NodeArgIsConstant(graph, *(last_add_node.MutableInputDefs()[i])) ||
-        graph_utils::IsGraphInput(graph, last_add_node.MutableInputDefs()[i])) {
-      if (last_add_node.MutableInputDefs()[i]->Shape()->dim_size() == static_cast<int>(axes_values.size())) {
-        bias = last_add_node.MutableInputDefs()[i];
-      }
+    if (last_add_node.MutableInputDefs()[i]->Shape() == nullptr) {
+      continue;
+    }
+    if (last_add_node.MutableInputDefs()[i]->Shape()->dim_size() == static_cast<int>(axes_values.size())) {
+      bias = last_add_node.MutableInputDefs()[i];
     }
   }
   if (scale == nullptr || bias == nullptr) {
@@ -667,20 +667,20 @@ Status SimplifiedLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int gr
   // because SkipLayerNorm kernel, for example, has dependency on single dim size
   NodeArg* scale = nullptr;
   for (size_t i = 0; i < mul_node.MutableInputDefs().size(); i++) {
-    if (graph_utils::NodeArgIsConstant(graph, *(mul_node.MutableInputDefs()[i])) ||
-        graph_utils::IsGraphInput(graph, mul_node.MutableInputDefs()[i])) {
+    if (mul_node.MutableInputDefs()[i]->Shape() == nullptr) {
+      continue;
+    }
 #ifdef ENABLE_TRAINING_CORE
-      if (axes_values.empty() ||
-          mul_node.MutableInputDefs()[i]->Shape()->dim_size() == static_cast<int>(axes_values.size())) {
-        scale = mul_node.MutableInputDefs()[i];
-      }
+    if (axes_values.empty() ||
+        mul_node.MutableInputDefs()[i]->Shape()->dim_size() == static_cast<int>(axes_values.size())) {
+      scale = mul_node.MutableInputDefs()[i];
+    }
 #else
-      // Scale must be 1d.
-      if (mul_node.MutableInputDefs()[i]->Shape()->dim_size() == 1) {
-        scale = mul_node.MutableInputDefs()[i];
-      }
-#endif
+    // Scale must be 1d.
+    if (mul_node.MutableInputDefs()[i]->Shape()->dim_size() == 1) {
+      scale = mul_node.MutableInputDefs()[i];
     }
+#endif
   }
 
   if (scale == nullptr) {
diff --git a/onnxruntime/test/optimizer/graph_transform_test_layernorm.cc b/onnxruntime/test/optimizer/graph_transform_test_layernorm.cc
index 1f671e90090ba..a55238396cea3 100755
--- a/onnxruntime/test/optimizer/graph_transform_test_layernorm.cc
+++ b/onnxruntime/test/optimizer/graph_transform_test_layernorm.cc
@@ -429,6 +429,40 @@ TEST_F(GraphTransformationTests, SimplifiedLayerNormFusionTest) {
   }
 }
 
+// Tests the scenario where scale and bias are neither graph inputs nor initializers in the graph.
+// An Identity node is inserted after the scale and bias terms to ensure LayerNormFusion still works properly.
+TEST_F(GraphTransformationTests, LayerNormScaleBiasTest) {
+  constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/layer_norm_fusion_scale_bias.onnx";
+  std::shared_ptr<Model> p_model;
+  ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
+  Graph& graph = p_model->MainGraph();
+
+  onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
+  ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<LayerNormFusion>(), TransformerLevel::Level2));
+  ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_));
+
+  std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
+  ASSERT_EQ(op_to_count["ReduceMean"], 0);
+  ASSERT_EQ(op_to_count["Sub"], 0);
+  ASSERT_EQ(op_to_count["Cast"], 0);
+  ASSERT_EQ(op_to_count["Pow"], 0);
+  ASSERT_EQ(op_to_count["Add"], 0);
+  ASSERT_EQ(op_to_count["Sqrt"], 0);
+  ASSERT_EQ(op_to_count["Div"], 0);
+  ASSERT_EQ(op_to_count["Mul"], 0);
+  ASSERT_EQ(op_to_count["LayerNormalization"], 1);
+
+  for (const Node& node : graph.Nodes()) {
+    if (node.OpType() == "LayerNormalization") {
+      // LayerNormalization should have three inputs.
+      EXPECT_EQ(node.InputDefs().size(), 3u) << "LayerNormalization number of inputs does not equal 3. Got:" << node.InputDefs().size();
+      // LayerNormalization input "scale" should be 1D.
+      const TensorShapeProto* scale_shape = node.InputDefs()[1]->Shape();
+      EXPECT_EQ(scale_shape->dim_size(), 1) << "LayerNormalization scale should be 1D. Got: " << scale_shape->dim_size();
+    }
+  }
+}
+
 // If EP is non-GPU EP or unknown, the sub-graph will be not fused because CPU impl for SimplifiedLayerNormalization
 // doesn't support input and scale having different data types.
TEST_F(GraphTransformationTests, SimplifiedLayerNormWithCastsFusionTest) { diff --git a/onnxruntime/test/testdata/transform/fusion/layer_norm_fusion_scale_bias.onnx b/onnxruntime/test/testdata/transform/fusion/layer_norm_fusion_scale_bias.onnx new file mode 100644 index 0000000000000000000000000000000000000000..ec0f9a97815b888701198d94c92e3f61c581d6dc GIT binary patch literal 854 zcmbVLO;6h}7_J-BxG#kTYZE93gnT22m1Yu$ooG5~nzW*c-nc|*?V^aLfqd|B%VCEd z_!0b!9ry|SC$N(kk+Bn&96fr!@;r}i(*63k1K$A+X(!=+oM*O~2%gWxfWb)##v)ic z9{~q9B0YN23*95r`2gfxhzlM@>6Q$%VOtJ@dJr|!d|FO4Bw)rQpTZvWW<i?ybq2^q zeC>xz-=(HP>i32O%=i^w!!hU}H52Z>Cg;9~+&<_rur`aAl7<+#{``weNx=D_oR1Y^ z#*lN^ftN5P>1C2t1qv}dk>9s!bQLvucvY#9fEnMyE9gV-EQq4O4@;YCBkDS8M){&@ zkboKEd;z<lgJ9Kk5B>SzP?b?MvK3XgqUwV7nl}8kiFTXek@Vf^LOYAAqdEXhvhLB8 zJ7tgC=m2%NeOM_K(1s9uUCR>7EX-~h`N1m$Ln*TIxg<{C$gn@X&P!+h9YMQ4gIkdt z$4TT+3o+bkwT`@(TjOl1*yF?9p4U84XNzD99K0cy*C0`6R*RdWK*euV{6StN>vU7S p0tyxZ+JiQ+<ld1RPi12Czl4XOW%axbb)BNmQ8-KDG@fS`dIpt9{cQjM literal 0 HcmV?d00001 diff --git a/onnxruntime/test/testdata/transform/fusion/layer_norm_fusion_scale_bias.py b/onnxruntime/test/testdata/transform/fusion/layer_norm_fusion_scale_bias.py new file mode 100644 index 0000000000000..a59e263763a3d --- /dev/null +++ b/onnxruntime/test/testdata/transform/fusion/layer_norm_fusion_scale_bias.py @@ -0,0 +1,81 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +import onnx +from onnx import OperatorSetIdProto, TensorProto, helper + + +def GenerateModel(model_name, has_casts=False, has_identity=False): # noqa: N802 + nodes = [ # LayerNorm subgraph + helper.make_node("ReduceMean", ["A"], ["rd_out"], "reduce1", axes=[-1], keepdims=1), + helper.make_node("Sub", ["A", "rd_out"], ["sub_out"], "sub"), + helper.make_node("Pow", ["cast_sub_out" if has_casts else "sub_out", "pow_in_2"], ["pow_out"], "pow"), + helper.make_node("ReduceMean", ["pow_out"], ["rd2_out"], "reduce2", axes=[-1], keepdims=1), + helper.make_node("Add", ["rd2_out", "const_e12_f32"], ["add1_out"], "add1"), + helper.make_node("Sqrt", ["add1_out"], ["sqrt_out"], "sqrt"), + helper.make_node("Div", ["cast_sub_out" if has_casts else "sub_out", "sqrt_out"], ["div_out"], "div"), + helper.make_node( + "Mul", + ["gamma_id_out" if has_identity else "gamma", "cast_div_out" if has_casts else "div_out"], + ["mul_out"], + "mul", + ), + helper.make_node("Add", ["mul_out", "const_e6_f16_out" if has_identity else "const_e6_f16"], ["C"], "add2"), + ] + + if has_casts: + nodes.extend( + [ + helper.make_node("Cast", ["sub_out"], ["cast_sub_out"], "cast_sub", to=1), + helper.make_node("Cast", ["div_out"], ["cast_div_out"], "cast_2", to=10), + ] + ) + + if has_identity: + nodes.extend( + [ + helper.make_node("Identity", ["gamma"], ["gamma_id_out"], "gamma_identity"), + helper.make_node("Identity", ["const_e6_f16"], ["const_e6_f16_out"], "const_e6_f16_identity"), + ] + ) + + initializers = [ # initializers + helper.make_tensor("pow_in_2", TensorProto.FLOAT, [], [2]), + helper.make_tensor("const_e12_f32", TensorProto.FLOAT, [], [1e-12]), + helper.make_tensor("const_e6_f16", TensorProto.FLOAT16, [4], [1e-6, 1e-6, 1e-6, 1e-6]), + helper.make_tensor( + "gamma", + TensorProto.FLOAT16 if has_casts else TensorProto.FLOAT, + [4], + [1, 2, 3, 4], + ), + ] + + input_type = TensorProto.FLOAT16 if has_casts else TensorProto.FLOAT + output_type = TensorProto.FLOAT16 if has_casts else TensorProto.FLOAT + + graph = helper.make_graph( + nodes, + "LayerNorm", # name + [ # inputs + helper.make_tensor_value_info("A", 
input_type, [16, 32, 4]), + ], + [ # outputs + helper.make_tensor_value_info("C", output_type, [16, 32, 4]), + ], + initializers, + ) + + onnxdomain = OperatorSetIdProto() + onnxdomain.version = 12 + # The empty string ("") or absence of this field implies the operator set that is defined as part of the ONNX specification. + onnxdomain.domain = "" + msdomain = OperatorSetIdProto() + msdomain.version = 1 + msdomain.domain = "com.microsoft" + opsets = [onnxdomain, msdomain] + + model = helper.make_model(graph, opset_imports=opsets) + onnx.save(model, model_name) + + +GenerateModel("layer_norm_fusion_scale_bias.onnx", True, True)
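
Reviewer note (not part of the patch): the sketch below shows one way to sanity-check the generated test model outside the gtest suite. It assumes `layer_norm_fusion_scale_bias.py` has already been run in the current directory, that the CPU execution provider applies the Level 2 fusions when extended optimizations are enabled, and the optimized-model file name is made up for illustration.

```python
# A minimal, hedged verification sketch: load the generated model with onnxruntime,
# let it run its Level 2 (extended) graph transformers, save the optimized graph,
# and check that the ReduceMean/Sub/Pow/Div/Mul subgraph collapsed into LayerNormalization.
import onnx
import onnxruntime as ort

session_options = ort.SessionOptions()
session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_EXTENDED
# Hypothetical output path, used only so the optimized graph can be inspected afterwards.
session_options.optimized_model_filepath = "layer_norm_fusion_scale_bias_opt.onnx"
ort.InferenceSession("layer_norm_fusion_scale_bias.onnx", session_options,
                     providers=["CPUExecutionProvider"])

optimized = onnx.load("layer_norm_fusion_scale_bias_opt.onnx")
print([node.op_type for node in optimized.graph.node])  # expect a LayerNormalization node here
```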