Move Gelu and LayerNorm fusion to L1 optimization #21332

Merged · 16 commits · Sep 9, 2024
Changes from 8 commits

12 changes: 11 additions & 1 deletion onnxruntime/core/optimizer/gelu_fusion.cc
@@ -44,6 +44,15 @@
[root]--> Gelu ==>
*/
Status GeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const {
const auto& version_map = graph.DomainToVersionMap();
const auto& onnx_version = version_map.find(kOnnxDomain);
// Gelu is an official ONNX operator as of opset 20, so we can fuse in level 1 if it is available
bool gelu_fusion_flag = (onnx_version != version_map.end() && onnx_version->second >= 20);
const auto compatible_providers = GetCompatibleExecutionProviders();
if ((optimization_level_ == TransformerLevel::Level1 && !gelu_fusion_flag) || (optimization_level_ == TransformerLevel::Level2 && gelu_fusion_flag)) {

Check warning (GitHub Actions / Optional Lint C++, cpplint): Lines should be <= 120 characters long [whitespace/line_length] [2] at onnxruntime/core/optimizer/gelu_fusion.cc:52
return Status::OK();
}

GraphViewer graph_viewer(graph);
const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();

@@ -157,12 +166,13 @@
p_mul2_node = &mul2_node;
}

auto op_domain = optimization_level_ == TransformerLevel::Level1 ? kOnnxDomain : kMSDomain;
const std::array gelu_input_defs{div.MutableInputDefs()[0]};
Node& gelu_node = graph.AddNode(graph.GenerateNodeName("Gelu"),
"Gelu",
"fused Gelu subgraphs ",
gelu_input_defs,
{}, {}, kMSDomain);
{}, {}, op_domain);

// Assign provider to this new node. Provider should be same as the provider for old node.
gelu_node.SetExecutionProviderType(div.GetExecutionProviderType());
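
Taken together, the opset gate at the top of ApplyImpl and the op_domain switch at the node-creation site implement one rule: fuse to the official ONNX Gelu in the Level 1 pass when the model imports opset >= 20, and otherwise leave the pattern to the Level 2 pass, which emits the com.microsoft contrib op. A self-contained sketch of that rule (simplified stand-ins for Graph::DomainToVersionMap() and the domain constants; illustrative, not the PR's code):

```cpp
// Sketch of the level/opset gating introduced by this PR (simplified types).
#include <cassert>
#include <map>
#include <string>

enum class TransformerLevel { Level1, Level2 };

// domain -> opset version, mirroring Graph::DomainToVersionMap();
// "" stands in for kOnnxDomain here.
using DomainVersionMap = std::map<std::string, int>;

// Returns the domain the fused Gelu node would be created in, or "" if this
// instance of the pass should return early and leave the graph alone.
std::string FuseDomain(const DomainVersionMap& imports, TransformerLevel level) {
  const auto it = imports.find("");
  const bool onnx_gelu = (it != imports.end() && it->second >= 20);  // opset 20 adds Gelu
  if (level == TransformerLevel::Level1 && !onnx_gelu) return "";    // defer to Level 2
  if (level == TransformerLevel::Level2 && onnx_gelu) return "";     // Level 1 already fused
  return level == TransformerLevel::Level1 ? "ai.onnx" : "com.microsoft";
}

int main() {
  assert(FuseDomain({{"", 20}}, TransformerLevel::Level1) == "ai.onnx");
  assert(FuseDomain({{"", 20}}, TransformerLevel::Level2).empty());
  assert(FuseDomain({{"", 13}}, TransformerLevel::Level1).empty());
  assert(FuseDomain({{"", 13}}, TransformerLevel::Level2) == "com.microsoft");
}
```

The same shape applies to LayerNormFusion below, with the threshold at opset 17 instead of 20.
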
18 changes: 16 additions & 2 deletions onnxruntime/core/optimizer/gelu_fusion.h
@@ -17,9 +17,23 @@

*/
class GeluFusion : public GraphTransformer {
private:
TransformerLevel optimization_level_ = TransformerLevel::Level1;
std::string GetGeluFusionName(TransformerLevel level) {

Check warning (GitHub Actions / Optional Lint C++, cpplint): Add #include <string> for string [build/include_what_you_use] [4] at onnxruntime/core/optimizer/gelu_fusion.h:22
switch (level) {
case TransformerLevel::Level1:
return "GeluFusionL1";
case TransformerLevel::Level2:
return "GeluFusionL2";
default:
return "GeluFusion";
}
}

public:
GeluFusion(const InlinedHashSet<std::string_view>& compatible_execution_providers = {}) noexcept
: GraphTransformer("GeluFusion", compatible_execution_providers) {}
GeluFusion(const InlinedHashSet<std::string_view>& compatible_execution_providers = {},
TransformerLevel level = TransformerLevel::Level1) noexcept
: GraphTransformer(GetGeluFusionName(level), compatible_execution_providers), optimization_level_(level) {}

Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
};
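
Deriving the registered name from the level lets one class be registered once per optimization level without the two registrations sharing a name; that the names must be distinct is an assumption made here, not something the diff states outright. A minimal sketch:

```cpp
// Sketch: level-dependent names keep the two registrations distinct.
#include <cassert>
#include <set>
#include <string>

enum class TransformerLevel { Level1, Level2 };

static std::string GetGeluFusionName(TransformerLevel level) {
  switch (level) {
    case TransformerLevel::Level1: return "GeluFusionL1";
    case TransformerLevel::Level2: return "GeluFusionL2";
    default: return "GeluFusion";
  }
}

int main() {
  std::set<std::string> registered;
  assert(registered.insert(GetGeluFusionName(TransformerLevel::Level1)).second);
  assert(registered.insert(GetGeluFusionName(TransformerLevel::Level2)).second);  // no collision
}
```
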
7 changes: 5 additions & 2 deletions onnxruntime/core/optimizer/graph_transformer_utils.cc
@@ -232,6 +232,9 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
transformers.emplace_back(std::make_unique<FreeDimensionOverrideTransformer>(
session_options.free_dimension_overrides));

transformers.emplace_back(std::make_unique<GeluFusion>());
transformers.emplace_back(std::make_unique<LayerNormFusion>());

if (!disable_quant_qdq) {
transformers.emplace_back(std::make_unique<QDQPropagationTransformer>());

@@ -309,8 +312,8 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(

transformers.emplace_back(std::make_unique<ConvActivationFusion>(cpu_cuda_rocm_acl_armnn_js_eps));

transformers.emplace_back(std::make_unique<GeluFusion>(cpu_cuda_dml_rocm_eps));
transformers.emplace_back(std::make_unique<LayerNormFusion>(cpu_cuda_dml_rocm_eps));
transformers.emplace_back(std::make_unique<GeluFusion>(cpu_cuda_dml_rocm_eps, level));
transformers.emplace_back(std::make_unique<LayerNormFusion>(cpu_cuda_dml_rocm_eps, level));
transformers.emplace_back(std::make_unique<SimplifiedLayerNormFusion>(cpu_cuda_rocm_eps));
transformers.emplace_back(std::make_unique<AttentionFusion>(cpu_cuda_dml_rocm_eps));
transformers.emplace_back(std::make_unique<EmbedLayerNormFusion>(cpu_cuda_dml_rocm_eps));
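
The net effect of this wiring: every session now carries both a Level 1 and a Level 2 instance of each fusion, and the opset gate guarantees at most one of them rewrites any given model. A condensed sketch of the pairing (stub types, not ORT's real classes):

```cpp
// Sketch of the registration pairing after this PR (stub types; illustrative).
#include <memory>
#include <vector>

enum class TransformerLevel { Level1, Level2 };

struct GraphTransformer { virtual ~GraphTransformer() = default; };
struct GeluFusion : GraphTransformer {
  explicit GeluFusion(TransformerLevel /*level*/ = TransformerLevel::Level1) {}
};
struct LayerNormFusion : GraphTransformer {
  explicit LayerNormFusion(TransformerLevel /*level*/ = TransformerLevel::Level1) {}
};

int main() {
  std::vector<std::unique_ptr<GraphTransformer>> level1, level2;
  // Level 1 (basic, hardware-agnostic): fuse to official ONNX ops when the
  // model's opset allows it (Gelu >= 20, LayerNormalization >= 17).
  level1.emplace_back(std::make_unique<GeluFusion>());
  level1.emplace_back(std::make_unique<LayerNormFusion>());
  // Level 2 (extended, EP-restricted): fall back to com.microsoft contrib
  // ops for older opsets; the opset gate makes each pair mutually exclusive.
  level2.emplace_back(std::make_unique<GeluFusion>(TransformerLevel::Level2));
  level2.emplace_back(std::make_unique<LayerNormFusion>(TransformerLevel::Level2));
}
```
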
9 changes: 9 additions & 0 deletions onnxruntime/core/optimizer/layer_norm_fusion.cc
@@ -139,6 +139,15 @@
Such a Cast Op can be the input of the sub-graph, or a Cast Op between the Div and Mul nodes.
*/
Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const {
const auto& version_map = graph.DomainToVersionMap();
const auto& onnx_version = version_map.find(kOnnxDomain);
// LayerNorm is an official ONNX operator as of opset 17, so we can fuse in level 1 if it is available
bool layernorm_fusion_flag = (onnx_version != version_map.end() && onnx_version->second >= 17);
const auto compatible_providers = GetCompatibleExecutionProviders();
if ((optimization_level_ == TransformerLevel::Level1 && !layernorm_fusion_flag) || (optimization_level_ == TransformerLevel::Level2 && layernorm_fusion_flag)) {

Check warning (GitHub Actions / Optional Lint C++, cpplint): Lines should be <= 120 characters long [whitespace/line_length] [2] at onnxruntime/core/optimizer/layer_norm_fusion.cc:147
return Status::OK();
}

GraphViewer graph_viewer(graph);
const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
InlinedVector<std::reference_wrapper<Node>> nodes_to_remove;
Expand Down
18 changes: 16 additions & 2 deletions onnxruntime/core/optimizer/layer_norm_fusion.h
@@ -17,9 +17,23 @@

*/
class LayerNormFusion : public GraphTransformer {
private:
TransformerLevel optimization_level_ = TransformerLevel::Level1;
std::string GetLayerNormFusionName(TransformerLevel level) {

Check warning (GitHub Actions / Optional Lint C++, cpplint): Add #include <string> for string [build/include_what_you_use] [4] at onnxruntime/core/optimizer/layer_norm_fusion.h:22
switch (level) {
case TransformerLevel::Level1:
return "LayerNormFusionL1";
case TransformerLevel::Level2:
return "LayerNormFusionL2";
default:
return "LayerNormFusion";
}
}

public:
LayerNormFusion(const InlinedHashSet<std::string_view>& compatible_execution_providers = {}) noexcept
: GraphTransformer("LayerNormFusion", compatible_execution_providers) {}
LayerNormFusion(const InlinedHashSet<std::string_view>& compatible_execution_providers = {},
TransformerLevel level = TransformerLevel::Level1) noexcept
: GraphTransformer(GetLayerNormFusionName(level), compatible_execution_providers), optimization_level_(level) {}

Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
};
58 changes: 52 additions & 6 deletions onnxruntime/test/optimizer/graph_transform_test.cc
@@ -4484,7 +4484,11 @@ TEST_F(GraphTransformationTests, GeluFusionTest) {
Graph& graph = p_model->MainGraph();

onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<GeluFusion>(), TransformerLevel::Level2));
const InlinedHashSet<std::string_view> no_limit_empty_ep_list = {};
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<GeluFusion>(), TransformerLevel::Level1));
ASSERT_STATUS_OK(graph_transformation_mgr.Register(
std::make_unique<GeluFusion>(no_limit_empty_ep_list, TransformerLevel::Level2), TransformerLevel::Level2));
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_));

std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
@@ -4495,14 +4499,40 @@ TEST_F(GraphTransformationTests, GeluFusionTest) {
ASSERT_TRUE(op_to_count["com.microsoft.Gelu"] == 1);
}

TEST_F(GraphTransformationTests, GeluFusionTest_Opset20) {
constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/gelu_opset20.onnx";
std::shared_ptr<Model> p_model;
ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
Graph& graph = p_model->MainGraph();

onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
const InlinedHashSet<std::string_view> no_limit_empty_ep_list = {};
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<GeluFusion>(), TransformerLevel::Level1));
ASSERT_STATUS_OK(graph_transformation_mgr.Register(
std::make_unique<GeluFusion>(no_limit_empty_ep_list, TransformerLevel::Level2), TransformerLevel::Level2));
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_));

std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
ASSERT_TRUE(op_to_count["Div"] == 0);
ASSERT_TRUE(op_to_count["Add"] == 0);
ASSERT_TRUE(op_to_count["Erf"] == 0);
ASSERT_TRUE(op_to_count["Mul"] == 0);
ASSERT_TRUE(op_to_count["Gelu"] == 1);
}

TEST_F(GraphTransformationTests, GeluFusionTestSwitchOrderFormat2) {
constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/gelu_format2_0.onnx";
std::shared_ptr<Model> p_model;
ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
Graph& graph = p_model->MainGraph();

onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<GeluFusion>(), TransformerLevel::Level2));
const InlinedHashSet<std::string_view> no_limit_empty_ep_list = {};
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<GeluFusion>(), TransformerLevel::Level1));
ASSERT_STATUS_OK(graph_transformation_mgr.Register(
std::make_unique<GeluFusion>(no_limit_empty_ep_list, TransformerLevel::Level2), TransformerLevel::Level2));
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_));

std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
@@ -4520,7 +4550,11 @@ TEST_F(GraphTransformationTests, GeluFusionTestFormat2) {
Graph& graph = p_model->MainGraph();

onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<GeluFusion>(), TransformerLevel::Level2));
const InlinedHashSet<std::string_view> no_limit_empty_ep_list = {};
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<GeluFusion>(), TransformerLevel::Level1));
ASSERT_STATUS_OK(graph_transformation_mgr.Register(
std::make_unique<GeluFusion>(no_limit_empty_ep_list, TransformerLevel::Level2), TransformerLevel::Level2));
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_));

std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
@@ -4538,7 +4572,11 @@ TEST_F(GraphTransformationTests, GeluFusionTestFormat2GraphInput) {
Graph& graph = p_model->MainGraph();

onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<GeluFusion>(), TransformerLevel::Level2));
const InlinedHashSet<std::string_view> no_limit_empty_ep_list = {};
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<GeluFusion>(), TransformerLevel::Level1));
ASSERT_STATUS_OK(graph_transformation_mgr.Register(
std::make_unique<GeluFusion>(no_limit_empty_ep_list, TransformerLevel::Level2), TransformerLevel::Level2));
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_));

std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
@@ -4556,8 +4594,12 @@ TEST_F(GraphTransformationTests, GeluFusionTestFormat2GraphOutput) {
Graph& graph = p_model->MainGraph();

onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<GeluFusion>(), TransformerLevel::Level2));
const InlinedHashSet<std::string_view> no_limit_empty_ep_list = {};
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<GeluFusion>(), TransformerLevel::Level1));
ASSERT_STATUS_OK(graph_transformation_mgr.Register(
std::make_unique<GeluFusion>(no_limit_empty_ep_list, TransformerLevel::Level2), TransformerLevel::Level2));
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<BiasGeluFusion>(), TransformerLevel::Level2));
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_));

std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
@@ -4572,8 +4614,12 @@ TEST_F(GraphTransformationTests, BiasGeluTest) {
Graph& graph = p_model->MainGraph();

onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<GeluFusion>(), TransformerLevel::Level2));
const InlinedHashSet<std::string_view> no_limit_empty_ep_list = {};
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<GeluFusion>(), TransformerLevel::Level1));
ASSERT_STATUS_OK(graph_transformation_mgr.Register(
std::make_unique<GeluFusion>(no_limit_empty_ep_list, TransformerLevel::Level2), TransformerLevel::Level2));
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<BiasGeluFusion>(), TransformerLevel::Level2));
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_));

std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
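
Each updated test repeats the same four-step setup: register a default (Level 1) GeluFusion, register an EP-unrestricted Level 2 instance, then apply both passes in order. A hypothetical helper (not part of the PR) that would condense the repetition, built only from calls that appear verbatim in the diff:

```cpp
// Hypothetical test helper (not in the PR): register GeluFusion at both
// levels, then run the Level 1 and Level 2 passes. ASSERT_* macros require
// a void-returning function.
void RegisterAndApplyGeluFusion(onnxruntime::GraphTransformerManager& mgr,
                                Graph& graph, const logging::Logger& logger) {
  const InlinedHashSet<std::string_view> no_limit_empty_ep_list = {};
  ASSERT_STATUS_OK(mgr.Register(std::make_unique<GeluFusion>(), TransformerLevel::Level1));
  ASSERT_STATUS_OK(mgr.Register(
      std::make_unique<GeluFusion>(no_limit_empty_ep_list, TransformerLevel::Level2),
      TransformerLevel::Level2));
  ASSERT_STATUS_OK(mgr.ApplyTransformers(graph, TransformerLevel::Level1, logger));
  ASSERT_STATUS_OK(mgr.ApplyTransformers(graph, TransformerLevel::Level2, logger));
}
```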