Skip to content

Commit

Permalink
address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
peishenyan committed Sep 4, 2024
1 parent ba663d1 commit 019b4af
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 19 deletions.
14 changes: 7 additions & 7 deletions onnxruntime/core/optimizer/gelu_fusion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,17 +47,17 @@ Status GeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, cons
const auto& version_map = graph.DomainToVersionMap();
const auto& onnx_version = version_map.find(kOnnxDomain);
// Gelu is an official ONNX operator as of opset 20, so we can fuse in level 1 if it is available
bool gelu_fusion_flag = (onnx_version != version_map.end() && onnx_version->second >= 20);
const auto compatible_providers = GetCompatibleExecutionProviders();
auto op_domain = optimization_level_ == TransformerLevel::Level1 ? kOnnxDomain : kMSDomain;
const bool onnx_gelu_available = (onnx_version != version_map.end() && onnx_version->second >= 20);
const bool fuse_in_level_1 = onnx_gelu_available || allow_contrib_op_in_level_1_;
const auto op_domain = fuse_in_level_1 && onnx_gelu_available ? kOnnxDomain : kMSDomain;

if (contrib_flag_) {
op_domain = kMSDomain;
}
else if ((optimization_level_ == TransformerLevel::Level1 && !gelu_fusion_flag) || (optimization_level_ == TransformerLevel::Level2 && gelu_fusion_flag)) {
if ((optimization_level_ == TransformerLevel::Level1 && !fuse_in_level_1) ||
(optimization_level_ == TransformerLevel::Level2 && fuse_in_level_1)) {
return Status::OK();
}

const auto compatible_providers = GetCompatibleExecutionProviders();

GraphViewer graph_viewer(graph);
const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();

Expand Down
8 changes: 5 additions & 3 deletions onnxruntime/core/optimizer/gelu_fusion.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ x * 0.5 * (1.0 + erf(x / sqrt(2.0))), where x is the input.
class GeluFusion : public GraphTransformer {
private:
TransformerLevel optimization_level_ = TransformerLevel::Level1;
bool contrib_flag_ = false;
bool allow_contrib_op_in_level_1_ = false;
std::string GetGeluFusionName(TransformerLevel level) {
switch (level) {
case TransformerLevel::Level1:
Expand All @@ -33,8 +33,10 @@ class GeluFusion : public GraphTransformer {

public:
GeluFusion(const InlinedHashSet<std::string_view>& compatible_execution_providers = {},
TransformerLevel level = TransformerLevel::Level1, bool contrib_flag = false) noexcept
: GraphTransformer(GetGeluFusionName(level), compatible_execution_providers), optimization_level_(level), contrib_flag_(contrib_flag) {}
TransformerLevel level = TransformerLevel::Level1, bool allow_contrib_op_in_level_1 = false) noexcept
: GraphTransformer(GetGeluFusionName(level), compatible_execution_providers),
optimization_level_(level),
allow_contrib_op_in_level_1_(allow_contrib_op_in_level_1) {}

Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
};
Expand Down
13 changes: 7 additions & 6 deletions onnxruntime/core/optimizer/layer_norm_fusion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -142,15 +142,16 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
const auto& version_map = graph.DomainToVersionMap();
const auto& onnx_version = version_map.find(kOnnxDomain);
// LayerNorm is an official ONNX operator as of opset 17, so we can fuse in level 1 if it is available
bool layernorm_fusion_flag = (onnx_version != version_map.end() && onnx_version->second >= 17);
const auto compatible_providers = GetCompatibleExecutionProviders();
const bool onnx_layernorm_available = (onnx_version != version_map.end() && onnx_version->second >= 17);
const bool fuse_in_level_1 = onnx_layernorm_available || allow_contrib_op_in_level_1_;

if (!contrib_flag_) {
if ((optimization_level_ == TransformerLevel::Level1 && !layernorm_fusion_flag) || (optimization_level_ == TransformerLevel::Level2 && layernorm_fusion_flag)) {
return Status::OK();
}
if ((optimization_level_ == TransformerLevel::Level1 && !fuse_in_level_1) ||
(optimization_level_ == TransformerLevel::Level2 && fuse_in_level_1)) {
return Status::OK();
}

const auto compatible_providers = GetCompatibleExecutionProviders();

GraphViewer graph_viewer(graph);
const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
InlinedVector<std::reference_wrapper<Node>> nodes_to_remove;
Expand Down
8 changes: 5 additions & 3 deletions onnxruntime/core/optimizer/layer_norm_fusion.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ The formula corresponding to LayerNorm activation subgraph:
class LayerNormFusion : public GraphTransformer {
private:
TransformerLevel optimization_level_ = TransformerLevel::Level1;
bool contrib_flag_ = false;
bool allow_contrib_op_in_level_1_ = false;
std::string GetLayerNormFusionName(TransformerLevel level) {
switch (level) {
case TransformerLevel::Level1:
Expand All @@ -33,8 +33,10 @@ class LayerNormFusion : public GraphTransformer {

public:
LayerNormFusion(const InlinedHashSet<std::string_view>& compatible_execution_providers = {},
TransformerLevel level = TransformerLevel::Level1, bool contrib_flag = false) noexcept
: GraphTransformer(GetLayerNormFusionName(level), compatible_execution_providers), optimization_level_(level), contrib_flag_(contrib_flag) {}
TransformerLevel level = TransformerLevel::Level1, bool allow_contrib_op_in_level_1 = false) noexcept
: GraphTransformer(GetLayerNormFusionName(level), compatible_execution_providers),
optimization_level_(level),
allow_contrib_op_in_level_1_(allow_contrib_op_in_level_1) {}

Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
};
Expand Down

0 comments on commit 019b4af

Please sign in to comment.