Move Gelu and LayerNorm fusion to L1 optimization #21332

Merged: 16 commits, Sep 9, 2024
18 changes: 17 additions & 1 deletion onnxruntime/core/optimizer/gelu_fusion.cc
@@ -44,6 +44,22 @@ static bool IsSupportedDataType(const Node& node) {
 [root]--> Gelu ==>
 */
 Status GeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const {
+  const auto& version_map = graph.DomainToVersionMap();
+  const auto& onnx_version = version_map.find(kOnnxDomain);
+  // Gelu is an official ONNX operator as of opset 20, so we can fuse in level 1 if it is available
+  const bool onnx_gelu_available = (onnx_version != version_map.end() && onnx_version->second >= 20);
+  const bool fuse_in_level_1 = onnx_gelu_available || allow_contrib_op_in_level_1_;
+  const auto op_domain = fuse_in_level_1 && onnx_gelu_available ? kOnnxDomain : kMSDomain;
+
+  if ((optimization_level_ == TransformerLevel::Level1 && !fuse_in_level_1) ||
+      // The following check assumes that there is a GeluFusion instance registered in Level1 that may have
+      // already done this fusion, in which case we don't need to do it again.
+      (optimization_level_ == TransformerLevel::Level2 && fuse_in_level_1)) {
+    return Status::OK();
+  }
+
+  const auto compatible_providers = GetCompatibleExecutionProviders();
+
   GraphViewer graph_viewer(graph);
   const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();

@@ -162,7 +178,7 @@ Status GeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, cons
                             "Gelu",
                             "fused Gelu subgraphs ",
                             gelu_input_defs,
-                            {}, {}, kMSDomain);
+                            {}, {}, op_domain);
 
   // Assign provider to this new node. Provider should be same as the provider for old node.
   gelu_node.SetExecutionProviderType(div.GetExecutionProviderType());
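The two registrations (see graph_transformer_utils.cc below) rely on this gate to make the fusion run exactly once per graph. A minimal standalone restatement of the predicate — hypothetical code, not part of the PR; `Level`, `ShouldFuse`, and the parameter names are illustrative:

```cpp
#include <cstdint>

enum class Level { L1, L2 };

// Which instance performs Gelu fusion for a given model?
// `opset` is the model's ONNX opset import; `allow_contrib_in_l1` mirrors allow_contrib_op_in_level_1_.
bool ShouldFuse(Level instance_level, int64_t opset, bool allow_contrib_in_l1) {
  const bool onnx_gelu_available = opset >= 20;  // Gelu is an official ONNX op from opset 20
  const bool fuse_in_level_1 = onnx_gelu_available || allow_contrib_in_l1;
  // L1 runs the fusion when it can; otherwise L2 picks it up — never both.
  return instance_level == Level::L1 ? fuse_in_level_1 : !fuse_in_level_1;
}
```

For an opset >= 20 model, the L1 instance fuses to the ONNX-domain Gelu (op_domain resolves to kOnnxDomain); for older models, the L2 instance fuses to com.microsoft.Gelu as before.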
21 changes: 19 additions & 2 deletions onnxruntime/core/optimizer/gelu_fusion.h
@@ -17,9 +17,26 @@

 */
 class GeluFusion : public GraphTransformer {
+ private:
+  TransformerLevel optimization_level_ = TransformerLevel::Level1;
+  bool allow_contrib_op_in_level_1_ = false;
+  std::string GetGeluFusionName(TransformerLevel level) {
+    switch (level) {
+      case TransformerLevel::Level1:
+        return "GeluFusionL1";
+      case TransformerLevel::Level2:
+        return "GeluFusionL2";
+      default:
+        return "GeluFusion";
+    }
+  }
+
  public:
-  GeluFusion(const InlinedHashSet<std::string_view>& compatible_execution_providers = {}) noexcept
-      : GraphTransformer("GeluFusion", compatible_execution_providers) {}
+  GeluFusion(const InlinedHashSet<std::string_view>& compatible_execution_providers = {},
+             TransformerLevel level = TransformerLevel::Level1, bool allow_contrib_op_in_level_1 = false) noexcept
+      : GraphTransformer(GetGeluFusionName(level), compatible_execution_providers),
+        optimization_level_(level),
+        allow_contrib_op_in_level_1_(allow_contrib_op_in_level_1) {}
 
   Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
 };

cpplint (via reviewdog): Add #include &lt;string&gt; for string [build/include_what_you_use] — flagged at gelu_fusion.h:23
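The name helper gives each instance a level-specific name, so the L1 and L2 instances remain distinguishable when both are registered. A short sketch of the resulting behavior, assuming the constructor defaults shown above (hypothetical snippet, not in the PR):

```cpp
const InlinedHashSet<std::string_view> empty_eps = {};
GeluFusion l1_fusion;                                       // defaults: Level1 -> named "GeluFusionL1"
GeluFusion l2_fusion(empty_eps, TransformerLevel::Level2);  // -> named "GeluFusionL2"
// Both can now coexist in one GraphTransformerManager without a name clash;
// GraphTransformer::Name() should report "GeluFusionL1" / "GeluFusionL2" respectively.
```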
7 changes: 5 additions & 2 deletions onnxruntime/core/optimizer/graph_transformer_utils.cc
@@ -235,6 +235,9 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
       transformers.emplace_back(std::make_unique<FreeDimensionOverrideTransformer>(
           session_options.free_dimension_overrides));
 
+      transformers.emplace_back(std::make_unique<GeluFusion>());
+      transformers.emplace_back(std::make_unique<LayerNormFusion>());
+
       if (!disable_quant_qdq) {
         transformers.emplace_back(std::make_unique<QDQPropagationTransformer>());
 
@@ -325,8 +328,8 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(

       transformers.emplace_back(std::make_unique<ConvActivationFusion>(cpu_rocm_acl_armnn_js_eps));
 
-      transformers.emplace_back(std::make_unique<GeluFusion>(cpu_cuda_dml_rocm_eps));
-      transformers.emplace_back(std::make_unique<LayerNormFusion>(cpu_cuda_dml_rocm_eps));
+      transformers.emplace_back(std::make_unique<GeluFusion>(cpu_cuda_dml_rocm_eps, level));
+      transformers.emplace_back(std::make_unique<LayerNormFusion>(cpu_cuda_dml_rocm_eps, level));
       transformers.emplace_back(std::make_unique<SimplifiedLayerNormFusion>(cpu_cuda_rocm_eps));
       transformers.emplace_back(std::make_unique<AttentionFusion>(cpu_cuda_dml_rocm_eps));
       transformers.emplace_back(std::make_unique<EmbedLayerNormFusion>(cpu_cuda_dml_rocm_eps));
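In GenerateTransformers, Level1 now unconditionally registers default-constructed GeluFusion/LayerNormFusion instances, while the Level2 registrations receive `level` so those instances stand down when Level1 already fused. From a user's perspective, the visible effect is that basic optimization now covers these fusions for newer opsets. A hedged usage sketch against the public C++ API (the model path is hypothetical; ORT_ENABLE_BASIC corresponds to Level1):

```cpp
#include <onnxruntime_cxx_api.h>

int main() {
  Ort::Env env;
  Ort::SessionOptions so;
  // Basic (Level1) optimizations only. With this PR, a model that imports ONNX opset >= 20
  // already gets its Gelu/LayerNorm subgraphs fused here, to the official ONNX ops,
  // instead of waiting for extended (Level2) optimization and contrib ops.
  so.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_BASIC);
  Ort::Session session(env, ORT_TSTR("model_opset20.onnx"), so);
  return 0;
}
```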
15 changes: 15 additions & 0 deletions onnxruntime/core/optimizer/layer_norm_fusion.cc
@@ -139,6 +139,21 @@ data are casted to float/double to calculate for precision, so if there is any C
 Such Cast Op can be the input of the sub-graph, or an Cast Op between the Div and Mul nodes.
 */
 Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const {
+  const auto& version_map = graph.DomainToVersionMap();
+  const auto& onnx_version = version_map.find(kOnnxDomain);
+  // LayerNorm is an official ONNX operator as of opset 17, so we can fuse in level 1 if it is available
+  const bool onnx_layernorm_available = (onnx_version != version_map.end() && onnx_version->second >= 17);
+  const bool fuse_in_level_1 = onnx_layernorm_available || allow_contrib_op_in_level_1_;
+
+  if ((optimization_level_ == TransformerLevel::Level1 && !fuse_in_level_1) ||
+      // The following check assumes that there is a LayerNormFusion instance registered in Level1 that may have
+      // already done this fusion, in which case we don't need to do it again.
+      (optimization_level_ == TransformerLevel::Level2 && fuse_in_level_1)) {
+    return Status::OK();
+  }
+
+  const auto compatible_providers = GetCompatibleExecutionProviders();
+
   GraphViewer graph_viewer(graph);
   const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
   InlinedVector<std::reference_wrapper<Node>> nodes_to_remove;
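LayerNormFusion applies the same two-instance gate as GeluFusion; the only difference is the opset threshold, since LayerNormalization became an official ONNX op in opset 17 while Gelu had to wait for opset 20. A hypothetical helper (not in the PR) that exposes the gate as a reusable predicate — the include paths are assumptions:

```cpp
#include "core/graph/constants.h"
#include "core/graph/graph.h"

// True when LayerNorm fusion is expected to run already at Level1 for this graph.
bool LayerNormFusesInLevel1(const onnxruntime::Graph& graph, bool allow_contrib_in_l1) {
  const auto& version_map = graph.DomainToVersionMap();
  const auto it = version_map.find(onnxruntime::kOnnxDomain);
  const bool onnx_op_available = it != version_map.end() && it->second >= 17;  // opset 17+
  return onnx_op_available || allow_contrib_in_l1;
}
```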
21 changes: 19 additions & 2 deletions onnxruntime/core/optimizer/layer_norm_fusion.h
@@ -17,9 +17,26 @@

 */
 class LayerNormFusion : public GraphTransformer {
+ private:
+  TransformerLevel optimization_level_ = TransformerLevel::Level1;
+  bool allow_contrib_op_in_level_1_ = false;
+  std::string GetLayerNormFusionName(TransformerLevel level) {
+    switch (level) {
+      case TransformerLevel::Level1:
+        return "LayerNormFusionL1";
+      case TransformerLevel::Level2:
+        return "LayerNormFusionL2";
+      default:
+        return "LayerNormFusion";
+    }
+  }
+
  public:
-  LayerNormFusion(const InlinedHashSet<std::string_view>& compatible_execution_providers = {}) noexcept
-      : GraphTransformer("LayerNormFusion", compatible_execution_providers) {}
+  LayerNormFusion(const InlinedHashSet<std::string_view>& compatible_execution_providers = {},
+                  TransformerLevel level = TransformerLevel::Level1, bool allow_contrib_op_in_level_1 = false) noexcept
+      : GraphTransformer(GetLayerNormFusionName(level), compatible_execution_providers),
+        optimization_level_(level),
+        allow_contrib_op_in_level_1_(allow_contrib_op_in_level_1) {}
 
   Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
 };

cpplint (via reviewdog): Add #include &lt;string&gt; for string [build/include_what_you_use] — flagged at layer_norm_fusion.h:23
58 changes: 52 additions & 6 deletions onnxruntime/test/optimizer/graph_transform_test.cc
@@ -4434,7 +4434,11 @@ TEST_F(GraphTransformationTests, GeluFusionTest) {
   Graph& graph = p_model->MainGraph();
 
   onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
-  ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<GeluFusion>(), TransformerLevel::Level2));
+  const InlinedHashSet<std::string_view> no_limit_empty_ep_list = {};
+  ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<GeluFusion>(), TransformerLevel::Level1));
+  ASSERT_STATUS_OK(graph_transformation_mgr.Register(
+      std::make_unique<GeluFusion>(no_limit_empty_ep_list, TransformerLevel::Level2), TransformerLevel::Level2));
+  ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
   ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_));
 
   std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
@@ -4445,14 +4449,40 @@ TEST_F(GraphTransformationTests, GeluFusionTest) {
   ASSERT_TRUE(op_to_count["com.microsoft.Gelu"] == 1);
 }
 
+TEST_F(GraphTransformationTests, GeluFusionTest_Opset20) {
+  constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/gelu_opset20.onnx";
+  std::shared_ptr<Model> p_model;
+  ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
+  Graph& graph = p_model->MainGraph();
+
+  onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
+  const InlinedHashSet<std::string_view> no_limit_empty_ep_list = {};
+  ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<GeluFusion>(), TransformerLevel::Level1));
+  ASSERT_STATUS_OK(graph_transformation_mgr.Register(
+      std::make_unique<GeluFusion>(no_limit_empty_ep_list, TransformerLevel::Level2), TransformerLevel::Level2));
+  ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
+  ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_));
+
+  std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
+  ASSERT_TRUE(op_to_count["Div"] == 0);
+  ASSERT_TRUE(op_to_count["Add"] == 0);
+  ASSERT_TRUE(op_to_count["Erf"] == 0);
+  ASSERT_TRUE(op_to_count["Mul"] == 0);
+  ASSERT_TRUE(op_to_count["Gelu"] == 1);
+}
+
 TEST_F(GraphTransformationTests, GeluFusionTestSwitchOrderFormat2) {
   constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/gelu_format2_0.onnx";
   std::shared_ptr<Model> p_model;
   ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
   Graph& graph = p_model->MainGraph();
 
   onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
-  ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<GeluFusion>(), TransformerLevel::Level2));
+  const InlinedHashSet<std::string_view> no_limit_empty_ep_list = {};
+  ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<GeluFusion>(), TransformerLevel::Level1));
+  ASSERT_STATUS_OK(graph_transformation_mgr.Register(
+      std::make_unique<GeluFusion>(no_limit_empty_ep_list, TransformerLevel::Level2), TransformerLevel::Level2));
+  ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
+  ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_));
 
   std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
@@ -4470,7 +4500,11 @@ TEST_F(GraphTransformationTests, GeluFusionTestFormat2) {
   Graph& graph = p_model->MainGraph();
 
   onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
-  ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<GeluFusion>(), TransformerLevel::Level2));
+  const InlinedHashSet<std::string_view> no_limit_empty_ep_list = {};
+  ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<GeluFusion>(), TransformerLevel::Level1));
+  ASSERT_STATUS_OK(graph_transformation_mgr.Register(
+      std::make_unique<GeluFusion>(no_limit_empty_ep_list, TransformerLevel::Level2), TransformerLevel::Level2));
+  ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
   ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_));
 
   std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
@@ -4488,7 +4522,11 @@ TEST_F(GraphTransformationTests, GeluFusionTestFormat2GraphInput) {
   Graph& graph = p_model->MainGraph();
 
   onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
-  ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<GeluFusion>(), TransformerLevel::Level2));
+  const InlinedHashSet<std::string_view> no_limit_empty_ep_list = {};
+  ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<GeluFusion>(), TransformerLevel::Level1));
+  ASSERT_STATUS_OK(graph_transformation_mgr.Register(
+      std::make_unique<GeluFusion>(no_limit_empty_ep_list, TransformerLevel::Level2), TransformerLevel::Level2));
+  ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
   ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_));
 
   std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
@@ -4506,8 +4544,12 @@ TEST_F(GraphTransformationTests, GeluFusionTestFormat2GraphOutput) {
   Graph& graph = p_model->MainGraph();
 
   onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
-  ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<GeluFusion>(), TransformerLevel::Level2));
+  const InlinedHashSet<std::string_view> no_limit_empty_ep_list = {};
+  ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<GeluFusion>(), TransformerLevel::Level1));
+  ASSERT_STATUS_OK(graph_transformation_mgr.Register(
+      std::make_unique<GeluFusion>(no_limit_empty_ep_list, TransformerLevel::Level2), TransformerLevel::Level2));
   ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<BiasGeluFusion>(), TransformerLevel::Level2));
+  ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
   ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_));
 
   std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
@@ -4522,8 +4564,12 @@ TEST_F(GraphTransformationTests, BiasGeluTest) {
   Graph& graph = p_model->MainGraph();
 
   onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
-  ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<GeluFusion>(), TransformerLevel::Level2));
+  const InlinedHashSet<std::string_view> no_limit_empty_ep_list = {};
+  ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<GeluFusion>(), TransformerLevel::Level1));
+  ASSERT_STATUS_OK(graph_transformation_mgr.Register(
+      std::make_unique<GeluFusion>(no_limit_empty_ep_list, TransformerLevel::Level2), TransformerLevel::Level2));
   ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<BiasGeluFusion>(), TransformerLevel::Level2));
+  ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
   ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_));
 
   std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
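Each updated test repeats the same dual-registration boilerplate. A hypothetical helper (not in the PR) that would factor it out, using the same macros and types as the tests above:

```cpp
// Register GeluFusion at both levels and run both passes, mirroring the updated tests.
void RegisterAndApplyGeluFusion(onnxruntime::GraphTransformerManager& mgr, Graph& graph,
                                const logging::Logger& logger) {
  const InlinedHashSet<std::string_view> no_limit_empty_ep_list = {};
  ASSERT_STATUS_OK(mgr.Register(std::make_unique<GeluFusion>(), TransformerLevel::Level1));
  ASSERT_STATUS_OK(mgr.Register(
      std::make_unique<GeluFusion>(no_limit_empty_ep_list, TransformerLevel::Level2), TransformerLevel::Level2));
  ASSERT_STATUS_OK(mgr.ApplyTransformers(graph, TransformerLevel::Level1, logger));
  ASSERT_STATUS_OK(mgr.ApplyTransformers(graph, TransformerLevel::Level2, logger));
}
```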