Skip to content

Commit

Permalink
Layer norm fusion deepspeed stage3 changes (#17614)
Browse files Browse the repository at this point in the history
### Description
<!-- Describe your changes. -->
Layer norm fusion changes required for deepspeed stage 3, also includes
test case.


### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
It enables layer norm fusion for DeepSpeed Stage 3. Added a test case
that ensures the fusion works properly in this scenario.
  • Loading branch information
ajindal1 authored Sep 21, 2023
1 parent f299016 commit d56fc7e
Show file tree
Hide file tree
Showing 4 changed files with 136 additions and 21 deletions.
42 changes: 21 additions & 21 deletions onnxruntime/core/optimizer/layer_norm_fusion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -414,20 +414,20 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
NodeArg* scale = nullptr;
NodeArg* bias = nullptr;
for (size_t i = 0; i < mul_node.MutableInputDefs().size(); i++) {
if (graph_utils::NodeArgIsConstant(graph, *(mul_node.MutableInputDefs()[i])) ||
graph_utils::IsGraphInput(graph, mul_node.MutableInputDefs()[i])) {
if (mul_node.MutableInputDefs()[i]->Shape()->dim_size() == static_cast<int>(axes_values.size())) {
scale = mul_node.MutableInputDefs()[i];
}
if (mul_node.MutableInputDefs()[i]->Shape() == nullptr) {
continue;
}
if (mul_node.MutableInputDefs()[i]->Shape()->dim_size() == static_cast<int>(axes_values.size())) {
scale = mul_node.MutableInputDefs()[i];
}
}

for (size_t i = 0; i < last_add_node.MutableInputDefs().size(); i++) {
if (graph_utils::NodeArgIsConstant(graph, *(last_add_node.MutableInputDefs()[i])) ||
graph_utils::IsGraphInput(graph, last_add_node.MutableInputDefs()[i])) {
if (last_add_node.MutableInputDefs()[i]->Shape()->dim_size() == static_cast<int>(axes_values.size())) {
bias = last_add_node.MutableInputDefs()[i];
}
if (last_add_node.MutableInputDefs()[i]->Shape() == nullptr) {
continue;
}
if (last_add_node.MutableInputDefs()[i]->Shape()->dim_size() == static_cast<int>(axes_values.size())) {
bias = last_add_node.MutableInputDefs()[i];
}
}
if (scale == nullptr || bias == nullptr) {
Expand Down Expand Up @@ -667,20 +667,20 @@ Status SimplifiedLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int gr
// because SkipLayerNorm kernel, for example, has dependency on single dim size
NodeArg* scale = nullptr;
for (size_t i = 0; i < mul_node.MutableInputDefs().size(); i++) {
if (graph_utils::NodeArgIsConstant(graph, *(mul_node.MutableInputDefs()[i])) ||
graph_utils::IsGraphInput(graph, mul_node.MutableInputDefs()[i])) {
if (mul_node.MutableInputDefs()[i]->Shape() == nullptr) {
continue;
}
#ifdef ENABLE_TRAINING_CORE
if (axes_values.empty() ||
mul_node.MutableInputDefs()[i]->Shape()->dim_size() == static_cast<int>(axes_values.size())) {
scale = mul_node.MutableInputDefs()[i];
}
if (axes_values.empty() ||
mul_node.MutableInputDefs()[i]->Shape()->dim_size() == static_cast<int>(axes_values.size())) {
scale = mul_node.MutableInputDefs()[i];
}
#else
// Scale must be 1d.
if (mul_node.MutableInputDefs()[i]->Shape()->dim_size() == 1) {
scale = mul_node.MutableInputDefs()[i];
}
#endif
// Scale must be 1d.
if (mul_node.MutableInputDefs()[i]->Shape()->dim_size() == 1) {
scale = mul_node.MutableInputDefs()[i];
}
#endif
}

if (scale == nullptr) {
Expand Down
34 changes: 34 additions & 0 deletions onnxruntime/test/optimizer/graph_transform_test_layernorm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,40 @@ TEST_F(GraphTransformationTests, SimplifiedLayerNormFusionTest) {
}
}

// Tests the scenario where scale and bias are neither graph inputs nor initializers in the graph.
// To exercise this, an Identity node is inserted after the scale and bias terms, and the test verifies LayerNormFusion still works properly.
TEST_F(GraphTransformationTests, LayerNormScaleBiasTest) {
  // Model whose scale/bias reach the fusion subgraph through Identity nodes,
  // i.e. they are neither graph inputs nor initializers.
  constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/layer_norm_fusion_scale_bias.onnx";
  std::shared_ptr<Model> model;
  ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, *logger_));
  Graph& graph = model->MainGraph();

  // Run only the Level2 LayerNormFusion transformer.
  onnxruntime::GraphTransformerManager transformer_mgr{5};
  ASSERT_STATUS_OK(transformer_mgr.Register(std::make_unique<LayerNormFusion>(), TransformerLevel::Level2));
  ASSERT_STATUS_OK(transformer_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_));

  // After fusion, every op of the original decomposed subgraph must be gone,
  // replaced by a single LayerNormalization node.
  std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
  for (const char* fused_op : {"ReduceMean", "Sub", "Cast", "Pow", "Add", "Sqrt", "Div", "Mul"}) {
    ASSERT_EQ(op_to_count[fused_op], 0);
  }
  ASSERT_EQ(op_to_count["LayerNormalization"], 1);

  for (const Node& node : graph.Nodes()) {
    if (node.OpType() != "LayerNormalization") {
      continue;
    }
    // The fused node must carry all three inputs: X, scale, and bias.
    EXPECT_EQ(node.InputDefs().size(), 3u) << "LayerNormalization number of inputs does not equal to 3. Got:" << node.InputDefs().size();
    // Scale (and bias, which matches it) must have collapsed to a 1-D tensor.
    const TensorShapeProto* scale_shape = node.InputDefs()[1]->Shape();
    EXPECT_EQ(scale_shape->dim_size(), 1) << "LayerNormalization scale should be 1D. Got: " << scale_shape->dim_size();
  }
}

// If EP is non-GPU EP or unknown, the sub-graph will be not fused because CPU impl for SimplifiedLayerNormalization
// doesn't support input and scale having different data types.
TEST_F(GraphTransformationTests, SimplifiedLayerNormWithCastsFusionTest) {
Expand Down
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import onnx
from onnx import OperatorSetIdProto, TensorProto, helper


def GenerateModel(model_name, has_casts=False, has_identity=False):  # noqa: N802
    """Build and save an ONNX model containing a decomposed LayerNorm subgraph.

    The subgraph is ReduceMean -> Sub -> Pow -> ReduceMean -> Add -> Sqrt -> Div
    followed by a scale (Mul) and bias (Add) step, as matched by LayerNormFusion.

    Args:
        model_name: Path the serialized model is written to.
        has_casts: Insert Cast nodes so Pow/Div run in float32 while the rest of
            the graph is float16.
        has_identity: Route the scale and bias tensors through Identity nodes so
            they are neither graph inputs nor direct initializer inputs of the
            fused nodes (the DeepSpeed Stage 3 scenario).
    """
    # Tensor names that depend on the optional Cast / Identity nodes.
    centered = "cast_sub_out" if has_casts else "sub_out"
    normalized = "cast_div_out" if has_casts else "div_out"
    gamma_in = "gamma_id_out" if has_identity else "gamma"
    beta_in = "const_e6_f16_out" if has_identity else "const_e6_f16"

    nodes = [  # LayerNorm subgraph
        helper.make_node("ReduceMean", ["A"], ["rd_out"], "reduce1", axes=[-1], keepdims=1),
        helper.make_node("Sub", ["A", "rd_out"], ["sub_out"], "sub"),
        helper.make_node("Pow", [centered, "pow_in_2"], ["pow_out"], "pow"),
        helper.make_node("ReduceMean", ["pow_out"], ["rd2_out"], "reduce2", axes=[-1], keepdims=1),
        helper.make_node("Add", ["rd2_out", "const_e12_f32"], ["add1_out"], "add1"),
        helper.make_node("Sqrt", ["add1_out"], ["sqrt_out"], "sqrt"),
        helper.make_node("Div", [centered, "sqrt_out"], ["div_out"], "div"),
        helper.make_node("Mul", [gamma_in, normalized], ["mul_out"], "mul"),
        helper.make_node("Add", ["mul_out", beta_in], ["C"], "add2"),
    ]

    if has_casts:
        # to=1 casts to float32 for the variance computation; to=10 casts back to float16.
        nodes.append(helper.make_node("Cast", ["sub_out"], ["cast_sub_out"], "cast_sub", to=1))
        nodes.append(helper.make_node("Cast", ["div_out"], ["cast_div_out"], "cast_2", to=10))

    if has_identity:
        # Identity nodes detach scale/bias from being direct initializer inputs.
        nodes.append(helper.make_node("Identity", ["gamma"], ["gamma_id_out"], "gamma_identity"))
        nodes.append(helper.make_node("Identity", ["const_e6_f16"], ["const_e6_f16_out"], "const_e6_f16_identity"))

    # Graph input/output (and gamma) element type tracks the cast configuration.
    elem_type = TensorProto.FLOAT16 if has_casts else TensorProto.FLOAT

    initializers = [
        helper.make_tensor("pow_in_2", TensorProto.FLOAT, [], [2]),
        helper.make_tensor("const_e12_f32", TensorProto.FLOAT, [], [1e-12]),
        helper.make_tensor("const_e6_f16", TensorProto.FLOAT16, [4], [1e-6, 1e-6, 1e-6, 1e-6]),
        helper.make_tensor("gamma", elem_type, [4], [1, 2, 3, 4]),
    ]

    graph = helper.make_graph(
        nodes,
        "LayerNorm",  # graph name
        [helper.make_tensor_value_info("A", elem_type, [16, 32, 4])],
        [helper.make_tensor_value_info("C", elem_type, [16, 32, 4])],
        initializers,
    )

    onnx_opset = OperatorSetIdProto()
    onnx_opset.version = 12
    # The empty string means the default ONNX operator-set domain.
    onnx_opset.domain = ""
    ms_opset = OperatorSetIdProto()
    ms_opset.version = 1
    ms_opset.domain = "com.microsoft"

    model = helper.make_model(graph, opset_imports=[onnx_opset, ms_opset])
    onnx.save(model, model_name)


# Generate the fixture exercising both the Cast and the Identity (DeepSpeed Stage 3) paths.
GenerateModel("layer_norm_fusion_scale_bias.onnx", has_casts=True, has_identity=True)

0 comments on commit d56fc7e

Please sign in to comment.