From 2cdc05f189bb34259deb3e3daef3289f1565558c Mon Sep 17 00:00:00 2001 From: Peishen Yan Date: Mon, 9 Sep 2024 11:27:52 +0800 Subject: [PATCH] Move Gelu and LayerNorm fusion to L1 optimization (#21332) According to https://github.com/microsoft/onnxruntime/issues/20915, we move the Gelu and LayerNorm fusion to L1 with a condition on the ONNX opset the model imports (LayerNorm requires opset 17+ and Gelu requires opset 20+). If the opset version doesn't meet the requirements, the fusion is delayed to L2 optimization since the internal contrib op doesn't have a requirement for any specific ONNX opset. --------- Co-authored-by: Scott McKay Co-authored-by: Edward Chen <18449977+edgchen1@users.noreply.github.com> --- onnxruntime/core/optimizer/gelu_fusion.cc | 18 ++- onnxruntime/core/optimizer/gelu_fusion.h | 21 +++- .../core/optimizer/graph_transformer_utils.cc | 7 +- .../core/optimizer/layer_norm_fusion.cc | 15 +++ .../core/optimizer/layer_norm_fusion.h | 21 +++- .../test/optimizer/graph_transform_test.cc | 58 ++++++++- .../graph_transform_test_layernorm.cc | 115 ++++++++++++++---- .../transform/fusion/gelu_opset20.onnx | Bin 0 -> 486 bytes .../compute_optimizer/padding_elimination.cc | 1 + .../core/optimizer/graph_transformer_utils.cc | 4 +- .../test/optimizer/graph_transform_test.cc | 6 +- 11 files changed, 225 insertions(+), 41 deletions(-) create mode 100644 onnxruntime/test/testdata/transform/fusion/gelu_opset20.onnx diff --git a/onnxruntime/core/optimizer/gelu_fusion.cc b/onnxruntime/core/optimizer/gelu_fusion.cc index d09f0c9f027e2..641bfbf388623 100644 --- a/onnxruntime/core/optimizer/gelu_fusion.cc +++ b/onnxruntime/core/optimizer/gelu_fusion.cc @@ -44,6 +44,22 @@ static bool IsSupportedDataType(const Node& node) { [root]--> Gelu ==> */ Status GeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const { + const auto& version_map = graph.DomainToVersionMap(); + const auto& onnx_version = 
version_map.find(kOnnxDomain); + // Gelu is an official ONNX operator as of opset 20, so we can fuse in level 1 if it is available + const bool onnx_gelu_available = (onnx_version != version_map.end() && onnx_version->second >= 20); + const bool fuse_in_level_1 = onnx_gelu_available || allow_contrib_op_in_level_1_; + const auto op_domain = fuse_in_level_1 && onnx_gelu_available ? kOnnxDomain : kMSDomain; + + if ((optimization_level_ == TransformerLevel::Level1 && !fuse_in_level_1) || + // The following check assumes that there is a GeluFusion instance registered in Level1 that may have + // already done this fusion, in which case we don't need to do it again. + (optimization_level_ == TransformerLevel::Level2 && fuse_in_level_1)) { + return Status::OK(); + } + + const auto compatible_providers = GetCompatibleExecutionProviders(); + GraphViewer graph_viewer(graph); const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(); @@ -162,7 +178,7 @@ Status GeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, cons "Gelu", "fused Gelu subgraphs ", gelu_input_defs, - {}, {}, kMSDomain); + {}, {}, op_domain); // Assign provider to this new node. Provider should be same as the provider for old node. gelu_node.SetExecutionProviderType(div.GetExecutionProviderType()); diff --git a/onnxruntime/core/optimizer/gelu_fusion.h b/onnxruntime/core/optimizer/gelu_fusion.h index 7573993cccad9..94bd84bcb0d03 100644 --- a/onnxruntime/core/optimizer/gelu_fusion.h +++ b/onnxruntime/core/optimizer/gelu_fusion.h @@ -17,9 +17,26 @@ x * 0.5 * (1.0 + erf(x / sqrt(2.0))), where x is the input. 
*/ class GeluFusion : public GraphTransformer { + private: + TransformerLevel optimization_level_ = TransformerLevel::Level1; + bool allow_contrib_op_in_level_1_ = false; + std::string GetGeluFusionName(TransformerLevel level) { + switch (level) { + case TransformerLevel::Level1: + return "GeluFusionL1"; + case TransformerLevel::Level2: + return "GeluFusionL2"; + default: + return "GeluFusion"; + } + } + public: - GeluFusion(const InlinedHashSet& compatible_execution_providers = {}) noexcept - : GraphTransformer("GeluFusion", compatible_execution_providers) {} + GeluFusion(const InlinedHashSet& compatible_execution_providers = {}, + TransformerLevel level = TransformerLevel::Level1, bool allow_contrib_op_in_level_1 = false) noexcept + : GraphTransformer(GetGeluFusionName(level), compatible_execution_providers), + optimization_level_(level), + allow_contrib_op_in_level_1_(allow_contrib_op_in_level_1) {} Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override; }; diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index 08284e67277e1..0530ab771e0be 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -235,6 +235,9 @@ InlinedVector> GenerateTransformers( transformers.emplace_back(std::make_unique( session_options.free_dimension_overrides)); + transformers.emplace_back(std::make_unique()); + transformers.emplace_back(std::make_unique()); + if (!disable_quant_qdq) { transformers.emplace_back(std::make_unique()); @@ -325,8 +328,8 @@ InlinedVector> GenerateTransformers( transformers.emplace_back(std::make_unique(cpu_rocm_acl_armnn_js_eps)); - transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); - transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); + transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps, level)); + 
transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps, level)); transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps)); transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); diff --git a/onnxruntime/core/optimizer/layer_norm_fusion.cc b/onnxruntime/core/optimizer/layer_norm_fusion.cc index 48edf4854fbbb..3f19fb46e5ade 100644 --- a/onnxruntime/core/optimizer/layer_norm_fusion.cc +++ b/onnxruntime/core/optimizer/layer_norm_fusion.cc @@ -139,6 +139,21 @@ data are casted to float/double to calculate for precision, so if there is any C Such Cast Op can be the input of the sub-graph, or an Cast Op between the Div and Mul nodes. */ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const { + const auto& version_map = graph.DomainToVersionMap(); + const auto& onnx_version = version_map.find(kOnnxDomain); + // LayerNorm is an official ONNX operator as of opset 17, so we can fuse in level 1 if it is available + const bool onnx_layernorm_available = (onnx_version != version_map.end() && onnx_version->second >= 17); + const bool fuse_in_level_1 = onnx_layernorm_available || allow_contrib_op_in_level_1_; + + if ((optimization_level_ == TransformerLevel::Level1 && !fuse_in_level_1) || + // The following check assumes that there is a LayerNormFusion instance registered in Level1 that may have + // already done this fusion, in which case we don't need to do it again. 
+ (optimization_level_ == TransformerLevel::Level2 && fuse_in_level_1)) { + return Status::OK(); + } + + const auto compatible_providers = GetCompatibleExecutionProviders(); + GraphViewer graph_viewer(graph); const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(); InlinedVector> nodes_to_remove; diff --git a/onnxruntime/core/optimizer/layer_norm_fusion.h b/onnxruntime/core/optimizer/layer_norm_fusion.h index 18b176a802c15..c4673b8f9fbf5 100644 --- a/onnxruntime/core/optimizer/layer_norm_fusion.h +++ b/onnxruntime/core/optimizer/layer_norm_fusion.h @@ -17,9 +17,26 @@ The formula corresponding to LayerNorm activation subgraph: */ class LayerNormFusion : public GraphTransformer { + private: + TransformerLevel optimization_level_ = TransformerLevel::Level1; + bool allow_contrib_op_in_level_1_ = false; + std::string GetLayerNormFusionName(TransformerLevel level) { + switch (level) { + case TransformerLevel::Level1: + return "LayerNormFusionL1"; + case TransformerLevel::Level2: + return "LayerNormFusionL2"; + default: + return "LayerNormFusion"; + } + } + public: - LayerNormFusion(const InlinedHashSet& compatible_execution_providers = {}) noexcept - : GraphTransformer("LayerNormFusion", compatible_execution_providers) {} + LayerNormFusion(const InlinedHashSet& compatible_execution_providers = {}, + TransformerLevel level = TransformerLevel::Level1, bool allow_contrib_op_in_level_1 = false) noexcept + : GraphTransformer(GetLayerNormFusionName(level), compatible_execution_providers), + optimization_level_(level), + allow_contrib_op_in_level_1_(allow_contrib_op_in_level_1) {} Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override; }; diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index 7ab0268c3509c..6ae66e35e7853 100755 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ 
b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -4434,7 +4434,11 @@ TEST_F(GraphTransformationTests, GeluFusionTest) { Graph& graph = p_model->MainGraph(); onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; - ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level2)); + const InlinedHashSet no_limit_empty_ep_list = {}; + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level1)); + ASSERT_STATUS_OK(graph_transformation_mgr.Register( + std::make_unique(no_limit_empty_ep_list, TransformerLevel::Level2), TransformerLevel::Level2)); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_)); std::map op_to_count = CountOpsInGraph(graph); @@ -4445,6 +4449,28 @@ TEST_F(GraphTransformationTests, GeluFusionTest) { ASSERT_TRUE(op_to_count["com.microsoft.Gelu"] == 1); } +TEST_F(GraphTransformationTests, GeluFusionTest_Opset20) { + constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/gelu_opset20.onnx"; + std::shared_ptr p_model; + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); + Graph& graph = p_model->MainGraph(); + + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + const InlinedHashSet no_limit_empty_ep_list = {}; + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level1)); + ASSERT_STATUS_OK(graph_transformation_mgr.Register( + std::make_unique(no_limit_empty_ep_list, TransformerLevel::Level2), TransformerLevel::Level2)); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_)); + + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_TRUE(op_to_count["Div"] == 0); + 
ASSERT_TRUE(op_to_count["Add"] == 0); + ASSERT_TRUE(op_to_count["Erf"] == 0); + ASSERT_TRUE(op_to_count["Mul"] == 0); + ASSERT_TRUE(op_to_count["Gelu"] == 1); +} + TEST_F(GraphTransformationTests, GeluFusionTestSwitchOrderFormat2) { constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/gelu_format2_0.onnx"; std::shared_ptr p_model; @@ -4452,7 +4478,11 @@ TEST_F(GraphTransformationTests, GeluFusionTestSwitchOrderFormat2) { Graph& graph = p_model->MainGraph(); onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; - ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level2)); + const InlinedHashSet no_limit_empty_ep_list = {}; + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level1)); + ASSERT_STATUS_OK(graph_transformation_mgr.Register( + std::make_unique(no_limit_empty_ep_list, TransformerLevel::Level2), TransformerLevel::Level2)); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_)); std::map op_to_count = CountOpsInGraph(graph); @@ -4470,7 +4500,11 @@ TEST_F(GraphTransformationTests, GeluFusionTestFormat2) { Graph& graph = p_model->MainGraph(); onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; - ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level2)); + const InlinedHashSet no_limit_empty_ep_list = {}; + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level1)); + ASSERT_STATUS_OK(graph_transformation_mgr.Register( + std::make_unique(no_limit_empty_ep_list, TransformerLevel::Level2), TransformerLevel::Level2)); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, 
*logger_)); std::map op_to_count = CountOpsInGraph(graph); @@ -4488,7 +4522,11 @@ TEST_F(GraphTransformationTests, GeluFusionTestFormat2GraphInput) { Graph& graph = p_model->MainGraph(); onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; - ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level2)); + const InlinedHashSet no_limit_empty_ep_list = {}; + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level1)); + ASSERT_STATUS_OK(graph_transformation_mgr.Register( + std::make_unique(no_limit_empty_ep_list, TransformerLevel::Level2), TransformerLevel::Level2)); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_)); std::map op_to_count = CountOpsInGraph(graph); @@ -4506,8 +4544,12 @@ TEST_F(GraphTransformationTests, GeluFusionTestFormat2GraphOutput) { Graph& graph = p_model->MainGraph(); onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; - ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level2)); + const InlinedHashSet no_limit_empty_ep_list = {}; + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level1)); + ASSERT_STATUS_OK(graph_transformation_mgr.Register( + std::make_unique(no_limit_empty_ep_list, TransformerLevel::Level2), TransformerLevel::Level2)); ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level2)); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_)); std::map op_to_count = CountOpsInGraph(graph); @@ -4522,8 +4564,12 @@ TEST_F(GraphTransformationTests, BiasGeluTest) { Graph& graph = p_model->MainGraph(); 
onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; - ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level2)); + const InlinedHashSet no_limit_empty_ep_list = {}; + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level1)); + ASSERT_STATUS_OK(graph_transformation_mgr.Register( + std::make_unique(no_limit_empty_ep_list, TransformerLevel::Level2), TransformerLevel::Level2)); ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level2)); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_)); std::map op_to_count = CountOpsInGraph(graph); diff --git a/onnxruntime/test/optimizer/graph_transform_test_layernorm.cc b/onnxruntime/test/optimizer/graph_transform_test_layernorm.cc index a55238396cea3..2320a2321f8ff 100755 --- a/onnxruntime/test/optimizer/graph_transform_test_layernorm.cc +++ b/onnxruntime/test/optimizer/graph_transform_test_layernorm.cc @@ -41,7 +41,11 @@ TEST_F(GraphTransformationTests, LayerNormFusionTest) { Graph& graph = p_model->MainGraph(); onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; - ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level2)); + const InlinedHashSet no_limit_empty_ep_list = {}; + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level1)); + ASSERT_STATUS_OK(graph_transformation_mgr.Register( + std::make_unique(no_limit_empty_ep_list, TransformerLevel::Level2), TransformerLevel::Level2)); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_)); std::map op_to_count = CountOpsInGraph(graph); @@ -79,7 
+83,11 @@ TEST_F(GraphTransformationTests, TwoLayerNormShareSameInput) { Graph& graph = p_model->MainGraph(); onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; - ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level2)); + const InlinedHashSet no_limit_empty_ep_list = {}; + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level1)); + ASSERT_STATUS_OK(graph_transformation_mgr.Register( + std::make_unique(no_limit_empty_ep_list, TransformerLevel::Level2), TransformerLevel::Level2)); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_)); std::map op_to_count = CountOpsInGraph(graph); @@ -94,7 +102,11 @@ TEST_F(GraphTransformationTests, LayerNormWithCastFusionTest) { Graph& graph = p_model->MainGraph(); onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; - ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level2)); + const InlinedHashSet no_limit_empty_ep_list = {}; + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level1)); + ASSERT_STATUS_OK(graph_transformation_mgr.Register( + std::make_unique(no_limit_empty_ep_list, TransformerLevel::Level2), TransformerLevel::Level2)); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_)); std::map op_to_count = CountOpsInGraph(graph); @@ -115,7 +127,11 @@ TEST_F(GraphTransformationTests, LayerNormWithCastFusionTest_2) { Graph& graph = p_model->MainGraph(); onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; - ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level2)); + const 
InlinedHashSet no_limit_empty_ep_list = {}; + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level1)); + ASSERT_STATUS_OK(graph_transformation_mgr.Register( + std::make_unique(no_limit_empty_ep_list, TransformerLevel::Level2), TransformerLevel::Level2)); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_)); std::map op_to_count = CountOpsInGraph(graph); @@ -131,7 +147,11 @@ TEST_F(GraphTransformationTests, LayerNormWithCastFusionTest_3) { Graph& graph = p_model->MainGraph(); onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; - ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level2)); + const InlinedHashSet no_limit_empty_ep_list = {}; + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level1)); + ASSERT_STATUS_OK(graph_transformation_mgr.Register( + std::make_unique(no_limit_empty_ep_list, TransformerLevel::Level2), TransformerLevel::Level2)); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_)); std::map op_to_count = CountOpsInGraph(graph); @@ -147,7 +167,11 @@ TEST_F(GraphTransformationTests, LayerNormWithCastFusionTest_4) { Graph& graph = p_model->MainGraph(); onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; - ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level2)); + const InlinedHashSet no_limit_empty_ep_list = {}; + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level1)); + ASSERT_STATUS_OK(graph_transformation_mgr.Register( + std::make_unique(no_limit_empty_ep_list, TransformerLevel::Level2), 
TransformerLevel::Level2)); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_)); std::map op_to_count = CountOpsInGraph(graph); @@ -169,7 +193,11 @@ TEST_F(GraphTransformationTests, LayerNormWithSubDupFusionTest) { Graph& graph = p_model->MainGraph(); onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; - ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level2)); + const InlinedHashSet no_limit_empty_ep_list = {}; + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level1)); + ASSERT_STATUS_OK(graph_transformation_mgr.Register( + std::make_unique(no_limit_empty_ep_list, TransformerLevel::Level2), TransformerLevel::Level2)); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_)); std::map op_to_count = CountOpsInGraph(graph); @@ -290,9 +318,14 @@ TEST_F(GraphTransformationTests, LayerNormWithCastFusionTest_5) { return Status::OK(); }; - std::unique_ptr transformer = std::make_unique(); - ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, {14, 18}, *logger_, std::move(transformer), TransformerLevel::Level1, - 1, pre_graph_checker, post_graph_checker)); + const InlinedHashSet no_limit_empty_ep_list = {}; + std::unique_ptr transformer_1 = std::make_unique(); + std::unique_ptr transformer_2 = + std::make_unique(no_limit_empty_ep_list, TransformerLevel::Level2); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 18, *logger_, std::move(transformer_1), + TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer_2), + TransformerLevel::Level2, 1, 
pre_graph_checker, post_graph_checker)); } TEST_F(GraphTransformationTests, LayerNormWithCastFusionTest_6) { @@ -314,9 +347,14 @@ TEST_F(GraphTransformationTests, LayerNormWithCastFusionTest_6) { return Status::OK(); }; - std::unique_ptr transformer = std::make_unique(); - ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, {14, 18}, *logger_, std::move(transformer), TransformerLevel::Level1, - 1, nullptr, post_graph_checker)); + const InlinedHashSet no_limit_empty_ep_list = {}; + std::unique_ptr transformer_1 = std::make_unique(); + std::unique_ptr transformer_2 = + std::make_unique(no_limit_empty_ep_list, TransformerLevel::Level2); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 18, *logger_, std::move(transformer_1), + TransformerLevel::Level1, 1, nullptr, post_graph_checker)); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer_2), + TransformerLevel::Level2, 1, nullptr, post_graph_checker)); } TEST_F(GraphTransformationTests, LayerNormWithCastFusionTest_7) { @@ -341,9 +379,14 @@ TEST_F(GraphTransformationTests, LayerNormWithCastFusionTest_7) { return Status::OK(); }; - std::unique_ptr transformer = std::make_unique(); - ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, {14, 18}, *logger_, std::move(transformer), TransformerLevel::Level1, - 1, nullptr, post_graph_checker)); + const InlinedHashSet no_limit_empty_ep_list = {}; + std::unique_ptr transformer_1 = std::make_unique(); + std::unique_ptr transformer_2 = + std::make_unique(no_limit_empty_ep_list, TransformerLevel::Level2); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 18, *logger_, std::move(transformer_1), + TransformerLevel::Level1, 1, nullptr, post_graph_checker)); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer_2), + TransformerLevel::Level2, 1, nullptr, post_graph_checker)); } TEST_F(GraphTransformationTests, LayerNormWithCastFusionTest_8) { @@ -365,9 +408,14 @@ 
TEST_F(GraphTransformationTests, LayerNormWithCastFusionTest_8) { return Status::OK(); }; - std::unique_ptr transformer = std::make_unique(); - ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, {14, 18}, *logger_, std::move(transformer), TransformerLevel::Level1, - 1, nullptr, post_graph_checker)); + const InlinedHashSet no_limit_empty_ep_list = {}; + std::unique_ptr transformer_1 = std::make_unique(); + std::unique_ptr transformer_2 = + std::make_unique(no_limit_empty_ep_list, TransformerLevel::Level2); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 18, *logger_, std::move(transformer_1), + TransformerLevel::Level1, 1, nullptr, post_graph_checker)); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer_2), + TransformerLevel::Level2, 1, nullptr, post_graph_checker)); } TEST_F(GraphTransformationTests, LayerNormWithCastFusionTest_9) { @@ -393,9 +441,14 @@ TEST_F(GraphTransformationTests, LayerNormWithCastFusionTest_9) { return Status::OK(); }; - std::unique_ptr transformer = std::make_unique(); - ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, {14, 18}, *logger_, std::move(transformer), TransformerLevel::Level1, - 1, nullptr, post_graph_checker)); + const InlinedHashSet no_limit_empty_ep_list = {}; + std::unique_ptr transformer_1 = std::make_unique(); + std::unique_ptr transformer_2 = + std::make_unique(no_limit_empty_ep_list, TransformerLevel::Level2); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 18, *logger_, std::move(transformer_1), + TransformerLevel::Level1, 1, nullptr, post_graph_checker)); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer_2), + TransformerLevel::Level2, 1, nullptr, post_graph_checker)); } TEST_F(GraphTransformationTests, SimplifiedLayerNormFusionTest) { @@ -438,7 +491,11 @@ TEST_F(GraphTransformationTests, LayerNormScaleBiasTest) { Graph& graph = p_model->MainGraph(); onnxruntime::GraphTransformerManager 
graph_transformation_mgr{5}; - ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level2)); + const InlinedHashSet no_limit_empty_ep_list = {}; + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level1)); + ASSERT_STATUS_OK(graph_transformation_mgr.Register( + std::make_unique(no_limit_empty_ep_list, TransformerLevel::Level2), TransformerLevel::Level2)); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_)); std::map op_to_count = CountOpsInGraph(graph); @@ -529,8 +586,12 @@ static void TestSkipLayerNormFusion(const std::basic_string& file_pat Graph& graph = p_model->MainGraph(); onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; - ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level2)); + const InlinedHashSet no_limit_empty_ep_list = {}; + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level1)); + ASSERT_STATUS_OK(graph_transformation_mgr.Register( + std::make_unique(no_limit_empty_ep_list, TransformerLevel::Level2), TransformerLevel::Level2)); ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level2)); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger)); ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger)); std::map op_to_count = CountOpsInGraph(graph); @@ -579,8 +640,12 @@ static void TestSkipLayerNormFusionInputOutputCheck(const std::basic_stringMainGraph(); onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; - ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level2)); + const InlinedHashSet no_limit_empty_ep_list = {}; + 
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level1)); + ASSERT_STATUS_OK(graph_transformation_mgr.Register( + std::make_unique(no_limit_empty_ep_list, TransformerLevel::Level2), TransformerLevel::Level2)); ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level2)); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger)); ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger)); for (Node& node : graph.Nodes()) { diff --git a/onnxruntime/test/testdata/transform/fusion/gelu_opset20.onnx b/onnxruntime/test/testdata/transform/fusion/gelu_opset20.onnx new file mode 100644 index 0000000000000000000000000000000000000000..7566fe3b75d38358bb032246ec9adf58de470ad1 GIT binary patch literal 486 zcmd;J5n?Y%Gs@4)tB_(f)HBmFusY4mCC$Yc!NqK7X(9w9Or_XeGRxu(jZBr8febDQ zkc=rv#!QOcwI~fJW2VFmWWcoX0EJ&Lfh%g69TS&1xrlbI+ zEtHsn42ZNbNZJAfag7-m9Lnh~wim}r$2M5uIP;bIUF0RT!mTulG~ literal 0 HcmV?d00001 diff --git a/orttraining/orttraining/core/optimizer/compute_optimizer/padding_elimination.cc b/orttraining/orttraining/core/optimizer/compute_optimizer/padding_elimination.cc index 4b6a9a6e594cd..d0895843eee7e 100644 --- a/orttraining/orttraining/core/optimizer/compute_optimizer/padding_elimination.cc +++ b/orttraining/orttraining/core/optimizer/compute_optimizer/padding_elimination.cc @@ -285,6 +285,7 @@ void IterateSubgraphFromNode(Graph& graph, PushAllOutputNode(graph, to_visit, cur, visited); } else if (graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "Cast", {9, 13}) || graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "FastGelu", {1}, kMSDomain) || + graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "Gelu", {20}) || graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "Gelu", {1}, kMSDomain) || graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "QuickGelu", {1}, kMSDomain) || 
graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "Sqrt", {6, 13})) { diff --git a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc index 589e7be455dbc..29a309920c74b 100644 --- a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc +++ b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc @@ -121,7 +121,7 @@ std::vector> GeneratePreTrainingTransformers( // CSE will not merge them. transformers.emplace_back(std::make_unique(compatible_eps)); // LayerNormFusion must be applied before CommonSubexpressionElimination as the latter will break the pattern when 2 LayerNormFusion share the same input. - transformers.emplace_back(std::make_unique(compatible_eps)); + transformers.emplace_back(std::make_unique(compatible_eps, level, true)); // Remove duplicate nodes. Must be applied before any recompute transformations. if (config.gelu_recompute || config.attn_dropout_recompute || config.transformer_layer_recompute) { transformers.emplace_back(std::make_unique(compatible_eps)); @@ -129,7 +129,7 @@ std::vector> GeneratePreTrainingTransformers( transformers.emplace_back(std::make_unique(compatible_eps)); } - transformers.emplace_back(std::make_unique(compatible_eps)); + transformers.emplace_back(std::make_unique(compatible_eps, level, true)); #if defined(USE_CUDA) || defined(USE_ROCM) transformers.emplace_back(std::make_unique(compatible_eps, true /* skip_device_check*/)); diff --git a/orttraining/orttraining/test/optimizer/graph_transform_test.cc b/orttraining/orttraining/test/optimizer/graph_transform_test.cc index b2ab4891f2e1e..2ec77c96dc2d5 100644 --- a/orttraining/orttraining/test/optimizer/graph_transform_test.cc +++ b/orttraining/orttraining/test/optimizer/graph_transform_test.cc @@ -919,9 +919,13 @@ TEST_F(GraphTransformationTests, BiasGeluRecomputeTest) { Graph& graph = p_model->MainGraph(); onnxruntime::GraphTransformerManager 
graph_transformation_mgr{5}; - ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level2)); + const InlinedHashSet no_limit_empty_ep_list = {}; + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level1)); + ASSERT_STATUS_OK(graph_transformation_mgr.Register( + std::make_unique(no_limit_empty_ep_list, TransformerLevel::Level2), TransformerLevel::Level2)); ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level2)); ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level2)); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_)); std::map op_to_count = CountOpsInGraph(graph);