diff --git a/onnxruntime/core/optimizer/gelu_fusion.cc b/onnxruntime/core/optimizer/gelu_fusion.cc
index 53e60ec0fd3eb..6ddf5e114ff97 100644
--- a/onnxruntime/core/optimizer/gelu_fusion.cc
+++ b/onnxruntime/core/optimizer/gelu_fusion.cc
@@ -47,17 +47,17 @@ Status GeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, cons
   const auto& version_map = graph.DomainToVersionMap();
   const auto& onnx_version = version_map.find(kOnnxDomain);
   // Gelu is an official ONNX operator as of opset 20, so we can fuse in level 1 if it is available
-  bool gelu_fusion_flag = (onnx_version != version_map.end() && onnx_version->second >= 20);
-  const auto compatible_providers = GetCompatibleExecutionProviders();
-  auto op_domain = optimization_level_ == TransformerLevel::Level1 ? kOnnxDomain : kMSDomain;
+  const bool onnx_gelu_available = (onnx_version != version_map.end() && onnx_version->second >= 20);
+  const bool fuse_in_level_1 = onnx_gelu_available || allow_contrib_op_in_level_1_;
+  const auto op_domain = fuse_in_level_1 && onnx_gelu_available ? kOnnxDomain : kMSDomain;
 
-  if (contrib_flag_) {
-    op_domain = kMSDomain;
-  }
-  else if ((optimization_level_ == TransformerLevel::Level1 && !gelu_fusion_flag) || (optimization_level_ == TransformerLevel::Level2 && gelu_fusion_flag)) {
+  if ((optimization_level_ == TransformerLevel::Level1 && !fuse_in_level_1) ||
+      (optimization_level_ == TransformerLevel::Level2 && fuse_in_level_1)) {
     return Status::OK();
   }
 
+  const auto compatible_providers = GetCompatibleExecutionProviders();
+
   GraphViewer graph_viewer(graph);
   const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
 
diff --git a/onnxruntime/core/optimizer/gelu_fusion.h b/onnxruntime/core/optimizer/gelu_fusion.h
index 0c5cb27243d36..94bd84bcb0d03 100644
--- a/onnxruntime/core/optimizer/gelu_fusion.h
+++ b/onnxruntime/core/optimizer/gelu_fusion.h
@@ -19,7 +19,7 @@ x * 0.5 * (1.0 + erf(x / sqrt(2.0))), where x is the input.
 class GeluFusion : public GraphTransformer {
  private:
   TransformerLevel optimization_level_ = TransformerLevel::Level1;
-  bool contrib_flag_ = false;
+  bool allow_contrib_op_in_level_1_ = false;
   std::string GetGeluFusionName(TransformerLevel level) {
     switch (level) {
       case TransformerLevel::Level1:
@@ -33,8 +33,10 @@ class GeluFusion : public GraphTransformer {
 
  public:
   GeluFusion(const InlinedHashSet<std::string_view>& compatible_execution_providers = {},
-             TransformerLevel level = TransformerLevel::Level1, bool contrib_flag = false) noexcept
-      : GraphTransformer(GetGeluFusionName(level), compatible_execution_providers), optimization_level_(level), contrib_flag_(contrib_flag) {}
+             TransformerLevel level = TransformerLevel::Level1, bool allow_contrib_op_in_level_1 = false) noexcept
+      : GraphTransformer(GetGeluFusionName(level), compatible_execution_providers),
+        optimization_level_(level),
+        allow_contrib_op_in_level_1_(allow_contrib_op_in_level_1) {}
 
   Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
 };
diff --git a/onnxruntime/core/optimizer/layer_norm_fusion.cc b/onnxruntime/core/optimizer/layer_norm_fusion.cc
index 24af5dc611f50..e0aec7062fc74 100644
--- a/onnxruntime/core/optimizer/layer_norm_fusion.cc
+++ b/onnxruntime/core/optimizer/layer_norm_fusion.cc
@@ -142,15 +142,16 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
   const auto& version_map = graph.DomainToVersionMap();
   const auto& onnx_version = version_map.find(kOnnxDomain);
   // LayerNorm is an official ONNX operator as of opset 17, so we can fuse in level 1 if it is available
-  bool layernorm_fusion_flag = (onnx_version != version_map.end() && onnx_version->second >= 17);
-  const auto compatible_providers = GetCompatibleExecutionProviders();
+  const bool onnx_layernorm_available = (onnx_version != version_map.end() && onnx_version->second >= 17);
+  const bool fuse_in_level_1 = onnx_layernorm_available || allow_contrib_op_in_level_1_;
 
-  if (!contrib_flag_) {
-    if ((optimization_level_ == TransformerLevel::Level1 && !layernorm_fusion_flag) || (optimization_level_ == TransformerLevel::Level2 && layernorm_fusion_flag)) {
-      return Status::OK();
-    }
+  if ((optimization_level_ == TransformerLevel::Level1 && !fuse_in_level_1) ||
+      (optimization_level_ == TransformerLevel::Level2 && fuse_in_level_1)) {
+    return Status::OK();
   }
 
+  const auto compatible_providers = GetCompatibleExecutionProviders();
+
   GraphViewer graph_viewer(graph);
   const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
   InlinedVector<std::reference_wrapper<Node>> nodes_to_remove;
diff --git a/onnxruntime/core/optimizer/layer_norm_fusion.h b/onnxruntime/core/optimizer/layer_norm_fusion.h
index a2797014f3dd2..c4673b8f9fbf5 100644
--- a/onnxruntime/core/optimizer/layer_norm_fusion.h
+++ b/onnxruntime/core/optimizer/layer_norm_fusion.h
@@ -19,7 +19,7 @@ The formula corresponding to LayerNorm activation subgraph:
 class LayerNormFusion : public GraphTransformer {
  private:
   TransformerLevel optimization_level_ = TransformerLevel::Level1;
-  bool contrib_flag_ = false;
+  bool allow_contrib_op_in_level_1_ = false;
   std::string GetLayerNormFusionName(TransformerLevel level) {
     switch (level) {
       case TransformerLevel::Level1:
@@ -33,8 +33,10 @@ class LayerNormFusion : public GraphTransformer {
 
  public:
   LayerNormFusion(const InlinedHashSet<std::string_view>& compatible_execution_providers = {},
-                  TransformerLevel level = TransformerLevel::Level1, bool contrib_flag = false) noexcept
-      : GraphTransformer(GetLayerNormFusionName(level), compatible_execution_providers), optimization_level_(level), contrib_flag_(contrib_flag) {}
+                  TransformerLevel level = TransformerLevel::Level1, bool allow_contrib_op_in_level_1 = false) noexcept
+      : GraphTransformer(GetLayerNormFusionName(level), compatible_execution_providers),
+        optimization_level_(level),
+        allow_contrib_op_in_level_1_(allow_contrib_op_in_level_1) {}
 
   Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
 };