From 8fc0c41462af468876da76bce5efebd5b7dcf107 Mon Sep 17 00:00:00 2001 From: "Peng Wang(AI FWK)" Date: Wed, 27 Mar 2024 00:43:07 -0700 Subject: [PATCH 01/18] keep original name as much as possible during fusion --- onnxruntime/core/optimizer/gather_fusion.cc | 2 +- onnxruntime/core/optimizer/gemm_transpose_fusion.cc | 2 +- onnxruntime/core/optimizer/layer_norm_fusion.cc | 4 ++-- onnxruntime/core/optimizer/matmul_scale_fusion.cc | 2 +- onnxruntime/core/optimizer/matmul_transpose_fusion.cc | 6 +++--- onnxruntime/core/optimizer/quick_gelu_fusion.cc | 2 +- .../orttraining/core/optimizer/concat_replacement.cc | 2 +- 7 files changed, 10 insertions(+), 10 deletions(-) diff --git a/onnxruntime/core/optimizer/gather_fusion.cc b/onnxruntime/core/optimizer/gather_fusion.cc index 90cabff88122c..e8b10f6d9c289 100644 --- a/onnxruntime/core/optimizer/gather_fusion.cc +++ b/onnxruntime/core/optimizer/gather_fusion.cc @@ -273,7 +273,7 @@ Status GatherSliceToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int gra split_initializer_proto.add_dims(static_cast(split_values.size())); split_initializer_proto.mutable_int64_data()->Add(split_values.begin(), split_values.end()); NodeArg* split_initializer_arg = &graph_utils::AddInitializer(graph, split_initializer_proto); - Node& split_node = graph.AddNode(graph.GenerateNodeName("Split"), "Split", "Split for Fused Gather nodes", + Node& split_node = graph.AddNode(nodes_to_fuse[0]->Name() + "/GatherSliceToSplitFusion/", "Split", "Split for Fused Gather nodes", {graph.GetNodeArg(node_arg->Name()), split_initializer_arg}, split_outputs); split_node.AddAttribute("axis", axis); split_node.SetExecutionProviderType(nodes_to_fuse[0].get().GetExecutionProviderType()); diff --git a/onnxruntime/core/optimizer/gemm_transpose_fusion.cc b/onnxruntime/core/optimizer/gemm_transpose_fusion.cc index b97cce9c2e785..a52517d23db86 100644 --- a/onnxruntime/core/optimizer/gemm_transpose_fusion.cc +++ b/onnxruntime/core/optimizer/gemm_transpose_fusion.cc @@ -75,7 +75,7 @@ Status GemmTransposeFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& m nodes_to_remove.push_back(output_node); } - Node& new_gemm_node = graph.AddNode(graph.GenerateNodeName(gemm_node.Name() + "_transformed"), + Node& new_gemm_node = graph.AddNode(graph.GenerateNodeName(gemm_node.Name() + "/GemmTransposeFusion/"), gemm_node.OpType(), "Fused Gemm with Transpose", new_gemm_input_defs, diff --git a/onnxruntime/core/optimizer/layer_norm_fusion.cc b/onnxruntime/core/optimizer/layer_norm_fusion.cc index ce696154adb6d..48edf4854fbbb 100644 --- a/onnxruntime/core/optimizer/layer_norm_fusion.cc +++ b/onnxruntime/core/optimizer/layer_norm_fusion.cc @@ -455,7 +455,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, } InlinedVector layer_norm_input_defs{x_input, scale, bias}; - Node& layer_norm_node = graph.AddNode(graph.GenerateNodeName("LayerNormalization"), + Node& layer_norm_node = graph.AddNode(graph.GenerateNodeName(mul_node.Name() + "/LayerNormFusion/"), "LayerNormalization", "fused LayerNorm subgraphs ", layer_norm_input_defs, @@ -705,7 +705,7 @@ Status SimplifiedLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int gr InlinedVector layer_norm_input_defs{x_input, scale}; Node& layer_norm_node = - graph.AddNode(graph.GenerateNodeName("SimplifiedLayerNormalization"), "SimplifiedLayerNormalization", + graph.AddNode(graph.GenerateNodeName(mul_node.Name() + "/SimplifiedLayerNormFusion/"), "SimplifiedLayerNormalization", "fused LayerNorm subgraphs ", layer_norm_input_defs, 
{}, {}, kOnnxDomain); // Get constant "epsilon" from "Add" node if available. Else, default value will be used. diff --git a/onnxruntime/core/optimizer/matmul_scale_fusion.cc b/onnxruntime/core/optimizer/matmul_scale_fusion.cc index b04d794cc9469..e4cdeadbf54d7 100644 --- a/onnxruntime/core/optimizer/matmul_scale_fusion.cc +++ b/onnxruntime/core/optimizer/matmul_scale_fusion.cc @@ -245,7 +245,7 @@ Status ProcessNode( } Node& matmul_scale_node = graph.AddNode( - graph.GenerateNodeName(node.Name() + "_FusedMatMulAndScale"), + graph.GenerateNodeName(node.Name() + "/MatMulScaleFusion/"), "FusedMatMul", "Fused MatMul and Scale", fused_node_inputs, diff --git a/onnxruntime/core/optimizer/matmul_transpose_fusion.cc b/onnxruntime/core/optimizer/matmul_transpose_fusion.cc index 789466778edc6..8eb224013618d 100644 --- a/onnxruntime/core/optimizer/matmul_transpose_fusion.cc +++ b/onnxruntime/core/optimizer/matmul_transpose_fusion.cc @@ -154,14 +154,14 @@ static Node* ReorderCastAndTranspose(Graph& graph, Node* cast, const ONNX_NAMESPACE::TensorProto_DataType element_type = static_cast(cast_output->TypeAsProto()->tensor_type().elem_type()); new_cast_output_type_proto.mutable_tensor_type()->set_elem_type(element_type); - auto& new_cast_output = graph.GetOrCreateNodeArg(cast_output->Name() + "_transformed", &new_cast_output_type_proto); + auto& new_cast_output = graph.GetOrCreateNodeArg(cast_output->Name() + "/MatmulTransposeFusion/", &new_cast_output_type_proto); const std::array new_cast_input_defs{transpose_input}; const std::array new_cast_output_defs{&new_cast_output}; const std::array new_transpose_input_defs = {&new_cast_output}; const std::array new_transpose_output_defs = {cast_output}; - Node& new_cast = graph.AddNode(graph.GenerateNodeName(cast->Name() + "_transformed"), + Node& new_cast = graph.AddNode(graph.GenerateNodeName(cast->Name() + "/MatmulTransposeFusion/"), cast->OpType(), "Created a new Cast node to interchange Cast and Transpose nodes", new_cast_input_defs, @@ -385,7 +385,7 @@ Status MatmulTransposeFusion::ApplyImpl(Graph& graph, bool& modified, int graph_ const std::array input_defs{left_input, right_input}; const std::array output_defs{node.MutableOutputDefs()[0]}; - Node& matmul_node = graph.AddNode(graph.GenerateNodeName("MatMul_With_Transpose"), + Node& matmul_node = graph.AddNode(graph.GenerateNodeName(node.Name() + "/MatmulTransposeFusion/"), "FusedMatMul", "fused MatMul and Transpose ", input_defs, diff --git a/onnxruntime/core/optimizer/quick_gelu_fusion.cc b/onnxruntime/core/optimizer/quick_gelu_fusion.cc index 6e5eb5612a701..b09ef1c460b8e 100644 --- a/onnxruntime/core/optimizer/quick_gelu_fusion.cc +++ b/onnxruntime/core/optimizer/quick_gelu_fusion.cc @@ -88,7 +88,7 @@ Status QuickGeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, NodeArg* quick_gelu_output_arg = mul_node.MutableOutputDefs()[0]; Node& quick_gelu_node = - graph.AddNode(graph.GenerateNodeName("QuickGelu"), "QuickGelu", "QuickGelu", std::array{quick_gelu_input_arg}, + graph.AddNode(graph.GenerateNodeName(mul_node.Name() + "/QuickGeluFusion/"), "QuickGelu", "QuickGelu", std::array{quick_gelu_input_arg}, std::array{quick_gelu_output_arg}, {}, kMSDomain); quick_gelu_node.AddAttribute("alpha", alpha); quick_gelu_node.SetExecutionProviderType(node.GetExecutionProviderType()); diff --git a/orttraining/orttraining/core/optimizer/concat_replacement.cc b/orttraining/orttraining/core/optimizer/concat_replacement.cc index 37d302765cda8..2c919591ec081 100644 --- 
a/orttraining/orttraining/core/optimizer/concat_replacement.cc +++ b/orttraining/orttraining/core/optimizer/concat_replacement.cc @@ -23,7 +23,7 @@ Status ConcatReplacement::Apply(Graph& graph, Node& concat_node, RewriteRuleEffe concat_outputs.push_back(&ip_shape_op); - Node& concat_training_node = graph.AddNode(graph.GenerateNodeName("ConcatTraining"), + Node& concat_training_node = graph.AddNode(graph.GenerateNodeName(concat_node.Name() + "/ConcatReplacement/"), "ConcatTraining", "Concat with extra output", concat_inputs, From bf572c4a76bfe51b7a84e228cf0cd4d6a79dd95e Mon Sep 17 00:00:00 2001 From: "Peng Wang(AI FWK)" Date: Wed, 27 Mar 2024 00:46:30 -0700 Subject: [PATCH 02/18] minor --- onnxruntime/core/optimizer/gather_fusion.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/optimizer/gather_fusion.cc b/onnxruntime/core/optimizer/gather_fusion.cc index e8b10f6d9c289..1f2b31526c6b8 100644 --- a/onnxruntime/core/optimizer/gather_fusion.cc +++ b/onnxruntime/core/optimizer/gather_fusion.cc @@ -273,7 +273,7 @@ Status GatherSliceToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int gra split_initializer_proto.add_dims(static_cast(split_values.size())); split_initializer_proto.mutable_int64_data()->Add(split_values.begin(), split_values.end()); NodeArg* split_initializer_arg = &graph_utils::AddInitializer(graph, split_initializer_proto); - Node& split_node = graph.AddNode(nodes_to_fuse[0]->Name() + "/GatherSliceToSplitFusion/", "Split", "Split for Fused Gather nodes", + Node& split_node = graph.AddNode(nodes_to_fuse[0].get().Name() + "/GatherSliceToSplitFusion/", "Split", "Split for Fused Gather nodes", {graph.GetNodeArg(node_arg->Name()), split_initializer_arg}, split_outputs); split_node.AddAttribute("axis", axis); split_node.SetExecutionProviderType(nodes_to_fuse[0].get().GetExecutionProviderType()); From 27b872f2420b646c053e8a870845b9804a1d68d2 Mon Sep 17 00:00:00 2001 From: "Peng Wang(AI FWK)" Date: Wed, 27 Mar 2024 01:43:38 -0700 Subject: [PATCH 03/18] fix boudary detect --- .../memory_optimizer/memory_insight.cc | 2 +- .../memory_optimizer/recompute_analysis.cc | 33 +++++++-- .../memory_optimizer/recompute_analysis.h | 2 +- .../memory_optimizer/transformer_specific.cc | 70 +++++++++++-------- .../memory_optimizer/transformer_specific.h | 2 +- 5 files changed, 70 insertions(+), 39 deletions(-) diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.cc b/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.cc index 54c49db0597c7..596742ddb9846 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.cc +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.cc @@ -257,7 +257,7 @@ Status FindORTModuleMemoryOpportunity(const GraphViewer& graph_viewer, is_forward_nodes, logger)); - InlinedHashSet layer_boundary_ln_nodes; + InlinedVector layer_boundary_ln_nodes; FindLayerBoundaryLayerNormNodes(graph_viewer, logger, node_index_to_its_order_in_topological_sort_map, yield_op_order_in_topological_sort, layer_boundary_ln_nodes); diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc b/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc index b421eb2ab32da..9edfb148a019d 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc @@ -386,6 +386,26 @@ const InlinedHashMap& GetAllowedRecompu 
{1, {}}, }, }, + { + utils::GetFullQualifiedOpName("SimplifiedLayerNormalization", kOnnxDomain), + { + // Opset 1 in ONNX official does not have SimplifiedLayerNormalization, + // while our contrib op defined SimplifiedLayerNormalization in opset 1 in ONNX domain. + {1, {}}, + }, + }, + { + utils::GetFullQualifiedOpName("SkipLayerNormalization", kMSDomain), + { + {1, {}}, + }, + }, + { + utils::GetFullQualifiedOpName("SkipSimplifiedLayerNormalization", kMSDomain), + { + {1, {}}, + }, + }, { utils::GetFullQualifiedOpName("Softmax", kOnnxDomain), { @@ -691,7 +711,7 @@ std::unique_ptr CheckNodeForRecompute(const GraphViewer& grap node_index_to_its_order_in_topological_sort_map, const InlinedHashMap>& candidate_output_args_map, - const InlinedHashSet& layer_boundary_ln_nodes, + const InlinedVector& layer_boundary_ln_nodes, const logging::Logger& logger, bool compromise_stashed_activation, bool& can_compromise_stashed_activation) { @@ -709,13 +729,14 @@ std::unique_ptr CheckNodeForRecompute(const GraphViewer& grap auto output_name = node.OutputDefs()[output_index]->Name(); auto consumers = graph_viewer.GetConsumerNodes(output_name); for (auto& consumer : consumers) { - if (layer_boundary_ln_nodes.find(consumer) != layer_boundary_ln_nodes.end()) { + if (std::find(layer_boundary_ln_nodes.begin(), layer_boundary_ln_nodes.end(), consumer) != + layer_boundary_ln_nodes.end()) { int dest_in_index = optimizer_utils::IndexOfNodeInput(*consumer, *node.OutputDefs()[output_index]); if (dest_in_index == 0) { - LOGS(logger, INFO) << "Node " << node.Name() << "(" << node.OpType() - << ") is a Attention+MLP layer boundary node, " - << "its stashed activation outputs are used by LayerNormalization's inputs, " - << "we don't need to recompute it."; + std::cout << "Node " << node.Name() << "(" << node.OpType() + << ") is a Attention+MLP layer boundary node, " + << "its stashed activation outputs are used by LayerNormalization's inputs, " + << "we don't need to recompute it." << std::endl; return nullptr; } } diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.h b/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.h index ab114d970191e..ac1021f5eb83b 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.h +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.h @@ -164,7 +164,7 @@ std::unique_ptr CheckNodeForRecompute(const GraphViewer& grap node_index_to_its_order_in_topological_sort_map, const InlinedHashMap>& candidate_output_args_map, - const InlinedHashSet& layer_boundary_ln_nodes, + const InlinedVector& layer_boundary_ln_nodes, const logging::Logger& logger, bool compromise_stashed_activation, bool& can_compromise_stashed_activation); diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/transformer_specific.cc b/orttraining/orttraining/core/optimizer/memory_optimizer/transformer_specific.cc index c88a0f05d36b8..99680c3cd0bee 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/transformer_specific.cc +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/transformer_specific.cc @@ -22,14 +22,22 @@ void FindLayerBoundaryLayerNormNodes( const InlinedHashMap& node_index_to_its_order_in_topological_sort_map, const ptrdiff_t& yield_op_order_in_topological_sort, - InlinedHashSet& layer_boundary_ln_nodes) { + InlinedVector& layer_boundary_ln_nodes) { // Loop all nodes to find LayerNormalization nodes. 
// For each LayerNormalization node, keep checking its output nodes, // until find a node that is Softmax or BiasSoftmax or another LayerNormalization. // If the found node is Softmax or BiasSoftmax, the LayerNormalization node as ATTENTION. // If the found node is another LayerNormalization, the LayerNormalization node as MLP. - const InlinedHashSet softmax_ops{"Softmax", "BiasSoftmax"}; - const InlinedHashSet layernorm_ops{"LayerNormalization", "SkipLayerNormalization"}; + const InlinedHashSet softmax_ops{ + "Softmax", + "BiasSoftmax", + }; + const InlinedHashSet layernorm_ops{ + "LayerNormalization", + "SkipLayerNormalization", + "SimplifiedLayerNormalization", + "SkipSimplifiedLayerNormalization", + }; layer_boundary_ln_nodes.clear(); const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(ExecutionOrder::PRIORITY_BASED); @@ -49,10 +57,11 @@ void FindLayerBoundaryLayerNormNodes( } } - bool unexpected_failure = false; + // For a perfect layer match, all those three flags should be true after exiting while loop. + // except the layernorm after the last layer, which expects no more softmax and layernorm are found. bool found_softmax = false; - bool found_layernorm = false; - ptrdiff_t next_layernorm_execution_oder = -1; + bool found_mid_layernorm = false; + bool found_end_layernorm = false; while (!nodes_to_check.empty()) { const Node* next_node = nodes_to_check.front(); nodes_to_check.pop_front(); @@ -65,37 +74,38 @@ void FindLayerBoundaryLayerNormNodes( if (softmax_ops.find(next_node->OpType()) != softmax_ops.end()) { found_softmax = true; } else if (layernorm_ops.find(next_node->OpType()) != layernorm_ops.end()) { - if (found_layernorm) { - // If we found another LayerNormalization node, we would report as warning, and do nothing for layer boundary detection. - unexpected_failure = true; - break; - } - found_layernorm = true; // don't trace further - next_layernorm_execution_oder = node_index_to_its_order_in_topological_sort_map.at(next_node->Index()); - continue; - } else { - for (auto node_it = next_node->OutputNodesBegin(); node_it != next_node->OutputNodesEnd(); ++node_it) { - // Stop if the node is after next Layernorm node in execution order. - if (found_layernorm && - node_index_to_its_order_in_topological_sort_map.at(node_it->Index()) >= next_layernorm_execution_oder) { - continue; - } - nodes_to_check.push_back(&(*node_it)); + // mid layernorm MUST appear before end layernorm, because it is a single in and out connector (despite of its weights) + if (found_mid_layernorm) { + found_end_layernorm = true; + // std::cout << "Found end LayerNormalization node 3333." << next_node->Name() << std::endl; + break; // exit the while loop since we found the end LayerNormalization node. + } else { + // std::cout << "Found mid LayerNormalization node 4444." << next_node->Name() << std::endl; + found_mid_layernorm = true; } } - } - if (unexpected_failure) { - layer_boundary_ln_nodes.clear(); - break; + for (auto node_it = next_node->OutputNodesBegin(); node_it != next_node->OutputNodesEnd(); ++node_it) { + // Stop if the node is after YieldOp. + if (node_index_to_its_order_in_topological_sort_map.at(node_it->Index()) >= yield_op_order_in_topological_sort) { + continue; + } + nodes_to_check.push_back(&(*node_it)); + } } - if (found_softmax) { - layer_boundary_ln_nodes.insert(&node); - } else if (!found_layernorm) { + if (found_softmax && found_mid_layernorm && found_end_layernorm) { + // std::cout << "Found LayerNormalization node 1111." 
<< std::endl; + if (std::find(layer_boundary_ln_nodes.begin(), layer_boundary_ln_nodes.end(), &node) == layer_boundary_ln_nodes.end()) { + layer_boundary_ln_nodes.push_back(&node); + } + } else if (!found_softmax && !found_mid_layernorm && !found_end_layernorm) { + // std::cout << "Found LayerNormalization node 2222." << found_mid_layernorm << found_mid_layernorm << found_end_layernorm << std::endl; // If no Softmax found, and no other LayerNormalization found, this should be the last LayerNormalization node, // we also consider it as boundary node. - layer_boundary_ln_nodes.insert(&node); + if (std::find(layer_boundary_ln_nodes.begin(), layer_boundary_ln_nodes.end(), &node) == layer_boundary_ln_nodes.end()) { + layer_boundary_ln_nodes.push_back(&node); + } } } } diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/transformer_specific.h b/orttraining/orttraining/core/optimizer/memory_optimizer/transformer_specific.h index b58d822124f43..a72e5a0af92d3 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/transformer_specific.h +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/transformer_specific.h @@ -23,6 +23,6 @@ void FindLayerBoundaryLayerNormNodes(const GraphViewer& graph_viewer, const InlinedHashMap& node_index_to_its_order_in_topological_sort_map, const ptrdiff_t& yield_op_order_in_topological_sort, - InlinedHashSet& layer_boundary_ln_nodes); + InlinedVector& layer_boundary_ln_nodes); } // namespace onnxruntime::optimizer::memory_optimizer From 1330a40f76dfa307103c0a93595376baf47dc20d Mon Sep 17 00:00:00 2001 From: "Peng Wang(AI FWK)" Date: Wed, 27 Mar 2024 09:58:10 +0000 Subject: [PATCH 04/18] Add warning if the boudary node is not found if memory optimizer level = 1 (layer wise recompute) --- .../memory_optimizer/memory_insight.cc | 5 +++ .../memory_optimizer/recompute_analysis.cc | 8 ++-- .../test/optimizer/memory_optimizer_test.cc | 42 +++++++++++++++++++ 3 files changed, 51 insertions(+), 4 deletions(-) diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.cc b/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.cc index 596742ddb9846..3d0fa942fd2d4 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.cc +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.cc @@ -261,6 +261,11 @@ Status FindORTModuleMemoryOpportunity(const GraphViewer& graph_viewer, FindLayerBoundaryLayerNormNodes(graph_viewer, logger, node_index_to_its_order_in_topological_sort_map, yield_op_order_in_topological_sort, layer_boundary_ln_nodes); + if (probe_config.enable_transformer_layer_as_boundary && layer_boundary_ln_nodes.size() == 0) { + LOGS(logger, WARNING) << "No transformer layer boundary nodes found, this might cause memory optimization " + "not working as expected. Please check the model and the configuration."; + } + // The first pass - find the candidate subgraphs. 
for (int i = static_cast(node_ids.size()) - 1; i >= 0; --i) { const Node* p_node = graph_viewer.GetNode(node_ids[i]); diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc b/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc index 9edfb148a019d..37ac1c4950ecd 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc @@ -733,10 +733,10 @@ std::unique_ptr CheckNodeForRecompute(const GraphViewer& grap layer_boundary_ln_nodes.end()) { int dest_in_index = optimizer_utils::IndexOfNodeInput(*consumer, *node.OutputDefs()[output_index]); if (dest_in_index == 0) { - std::cout << "Node " << node.Name() << "(" << node.OpType() - << ") is a Attention+MLP layer boundary node, " - << "its stashed activation outputs are used by LayerNormalization's inputs, " - << "we don't need to recompute it." << std::endl; + MO_LOG_DEBUG_INFO(logger, "Node " + node.Name() + "(" + node.OpType() + + ") is a Attention+MLP layer boundary node, " + + "its stashed activation outputs are used by LayerNormalization's inputs, " + + "we don't need to recompute it."); return nullptr; } } diff --git a/orttraining/orttraining/test/optimizer/memory_optimizer_test.cc b/orttraining/orttraining/test/optimizer/memory_optimizer_test.cc index 22f1da1327547..32304d4ffd828 100644 --- a/orttraining/orttraining/test/optimizer/memory_optimizer_test.cc +++ b/orttraining/orttraining/test/optimizer/memory_optimizer_test.cc @@ -29,6 +29,7 @@ #include "orttraining/core/optimizer/memory_optimizer/common.h" #include "orttraining/core/optimizer/memory_optimizer/memory_optimizer.h" #include "orttraining/core/optimizer/memory_optimizer/memory_insight.h" +#include "orttraining/core/optimizer/memory_optimizer/transformer_specific.h" using namespace std; using namespace ONNX_NAMESPACE; @@ -312,5 +313,46 @@ TEST(MemoryOptimizerTests, TransformerPerLayerRecompute) { } } +TEST(MemoryOptimizerTests, TransformerLayerDetectionTest) { + const logging::Logger* logger = &logging::LoggingManager::DefaultLogger(); + auto model_uri = MODEL_FOLDER "3layer_bloom_optimized_training.onnx"; + std::shared_ptr model; + ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, *logger)); + Graph& graph = model->MainGraph(); + GraphViewer graph_viewer(graph); + + InlinedHashMap node_index_to_its_order_in_topological_sort_map; + const auto& node_ids = graph_viewer.GetNodesInTopologicalOrder(ExecutionOrder::PRIORITY_BASED); + + // Find boundary ops between forward and backward pass, currently, it's limited to YieldOp. 
+ ptrdiff_t yield_op_order_in_topological_sort = -1; + for (size_t i = 0; i < node_ids.size(); ++i) { + const Node* p_node = graph_viewer.GetNode(node_ids[i]); + if (p_node == nullptr) { /* skip removed nodes*/ + continue; + } + + if (p_node->OpType() == "YieldOp") { + // There are multiple YieldOps in the graph。 + ASSERT_TRUE(yield_op_order_in_topological_sort == -1); + yield_op_order_in_topological_sort = static_cast(i); + } + + node_index_to_its_order_in_topological_sort_map[p_node->Index()] = static_cast(i); + } + + InlinedVector layer_boundary_ln_node; + optimizer::memory_optimizer::FindLayerBoundaryLayerNormNodes(graph_viewer, *logger, + node_index_to_its_order_in_topological_sort_map, + yield_op_order_in_topological_sort, + layer_boundary_ln_node); + + ASSERT_TRUE(layer_boundary_ln_node.size() == 4); + ASSERT_TRUE(layer_boundary_ln_node[0]->Name() == "LayerNormalization_token_0"); + ASSERT_TRUE(layer_boundary_ln_node[1]->Name() == "LayerNormalization_token_6"); + ASSERT_TRUE(layer_boundary_ln_node[2]->Name() == "LayerNormalization_token_12"); + ASSERT_TRUE(layer_boundary_ln_node[3]->Name() == "LayerNormalization_token_18"); +} + } // namespace test } // namespace onnxruntime From aee8365bbf43f32191a6a6f543a62d2a1255c510 Mon Sep 17 00:00:00 2001 From: "Peng Wang(AI FWK)" Date: Wed, 27 Mar 2024 10:02:02 +0000 Subject: [PATCH 05/18] minor --- .../core/optimizer/memory_optimizer/transformer_specific.cc | 5 ----- 1 file changed, 5 deletions(-) diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/transformer_specific.cc b/orttraining/orttraining/core/optimizer/memory_optimizer/transformer_specific.cc index 99680c3cd0bee..39dce684fe014 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/transformer_specific.cc +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/transformer_specific.cc @@ -11,7 +11,6 @@ #include "core/optimizer/utils.h" #include "core/graph/graph_viewer.h" #include "core/framework/tensorprotoutils.h" - #include "core/common/string_utils.h" namespace onnxruntime::optimizer::memory_optimizer { @@ -77,10 +76,8 @@ void FindLayerBoundaryLayerNormNodes( // mid layernorm MUST appear before end layernorm, because it is a single in and out connector (despite of its weights) if (found_mid_layernorm) { found_end_layernorm = true; - // std::cout << "Found end LayerNormalization node 3333." << next_node->Name() << std::endl; break; // exit the while loop since we found the end LayerNormalization node. } else { - // std::cout << "Found mid LayerNormalization node 4444." << next_node->Name() << std::endl; found_mid_layernorm = true; } } @@ -95,12 +92,10 @@ void FindLayerBoundaryLayerNormNodes( } if (found_softmax && found_mid_layernorm && found_end_layernorm) { - // std::cout << "Found LayerNormalization node 1111." << std::endl; if (std::find(layer_boundary_ln_nodes.begin(), layer_boundary_ln_nodes.end(), &node) == layer_boundary_ln_nodes.end()) { layer_boundary_ln_nodes.push_back(&node); } } else if (!found_softmax && !found_mid_layernorm && !found_end_layernorm) { - // std::cout << "Found LayerNormalization node 2222." << found_mid_layernorm << found_mid_layernorm << found_end_layernorm << std::endl; // If no Softmax found, and no other LayerNormalization found, this should be the last LayerNormalization node, // we also consider it as boundary node. 
if (std::find(layer_boundary_ln_nodes.begin(), layer_boundary_ln_nodes.end(), &node) == layer_boundary_ln_nodes.end()) { From 73be0395c89686c518f026b990145a4baa184f3a Mon Sep 17 00:00:00 2001 From: "Peng Wang(AI FWK)" Date: Wed, 27 Mar 2024 13:25:14 +0000 Subject: [PATCH 06/18] fix --- .../memory_optimizer/transformer_specific.cc | 159 +++++++++++++----- .../test/optimizer/memory_optimizer_test.cc | 3 +- 2 files changed, 116 insertions(+), 46 deletions(-) diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/transformer_specific.cc b/orttraining/orttraining/core/optimizer/memory_optimizer/transformer_specific.cc index 39dce684fe014..535ac77588104 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/transformer_specific.cc +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/transformer_specific.cc @@ -11,13 +11,95 @@ #include "core/optimizer/utils.h" #include "core/graph/graph_viewer.h" #include "core/framework/tensorprotoutils.h" + #include "core/common/string_utils.h" namespace onnxruntime::optimizer::memory_optimizer { +namespace { + +bool IsLayerNormNode(const Node& node) { + const static std::set layer_norm_ops = { + "LayerNormalization", + "SkipLayerNormalization", + "SimplifiedLayerNormalization", + "SkipSimplifiedLayerNormalization", + }; + return layer_norm_ops.find(node.OpType()) != layer_norm_ops.end(); +} + +bool IsSoftmaxNode(const Node& node) { + const static std::set softmax_ops = { + "Softmax", + "BiasSoftmax", + }; + return softmax_ops.find(node.OpType()) != softmax_ops.end(); +} + +std::tuple IsResidualNodeArg(const GraphViewer& graph_viewer, const NodeArg* node_arg) { + auto consumers = graph_viewer.GetConsumerNodes(node_arg->Name()); + if (2 > consumers.size()) { + return std::make_tuple(false, nullptr, nullptr); + } + + // Find the Add node from the consumer list. + const Node* add_node = nullptr; + const Node* other_node = nullptr; + for (const auto* consumer : consumers) { + if (consumer->OpType() == "Add") { + add_node = consumer; + } else if (IsLayerNormNode(*consumer)) { + other_node = consumer; + } + } + + return std::make_tuple(add_node != nullptr && other_node != nullptr, add_node, other_node); +} +} // namespace + +/* + One classical layer includes 1). input layer norm, 2). attention, 3). residual add + (input layer norm input + attention output), 4). post attention layer norm feedforward, and 5). residual add + (post attention layer norm input + feedforward out). + + The pattern graph looks like below for each transformer layer (taking the example of MistralDecoderLayer): + | + Embedding + | + | + ----------------------| + | | + | | + | SimplifiedLayerNormalization (layer boudary node) + | | + | | + | MistralAttention + | | + | | + |____________________Add + | + ----------------------| + | | + | | + | SimplifiedLayerNormalization + | | + | | + | MultipleLayerPerception + | | + | | + |____________________Add + | + (new layer) + ----------------------| + | | + | | + | SimplifiedLayerNormalization + .... + +*/ void FindLayerBoundaryLayerNormNodes( const GraphViewer& graph_viewer, - const logging::Logger&, + const logging::Logger& logger, const InlinedHashMap& node_index_to_its_order_in_topological_sort_map, const ptrdiff_t& yield_op_order_in_topological_sort, @@ -27,40 +109,46 @@ void FindLayerBoundaryLayerNormNodes( // until find a node that is Softmax or BiasSoftmax or another LayerNormalization. // If the found node is Softmax or BiasSoftmax, the LayerNormalization node as ATTENTION. 
// If the found node is another LayerNormalization, the LayerNormalization node as MLP. - const InlinedHashSet softmax_ops{ - "Softmax", - "BiasSoftmax", - }; - const InlinedHashSet layernorm_ops{ - "LayerNormalization", - "SkipLayerNormalization", - "SimplifiedLayerNormalization", - "SkipSimplifiedLayerNormalization", - }; layer_boundary_ln_nodes.clear(); + const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(ExecutionOrder::PRIORITY_BASED); for (auto node_index : node_topology_list) { auto& node = *graph_viewer.GetNode(node_index); - if (layernorm_ops.find(node.OpType()) == layernorm_ops.end()) { + if (!IsLayerNormNode(node)) { + continue; + } + const NodeArg* input_arg = node.InputDefs()[0]; + + // IsResidualNodeArg checks input_arg + auto [is_residual_node_arg, add_node, other_node] = IsResidualNodeArg(graph_viewer, input_arg); + if (!is_residual_node_arg) { + MO_LOG_DEBUG_INFO(logger, "Not a residual node arg " + input_arg->Name()); continue; } + // At this point, there should not be any recompute node, so we don't need check the node existence in + // node_index_to_its_order_in_topological_sort_map. + ptrdiff_t attention_residual_add_node_order = node_index_to_its_order_in_topological_sort_map.at(add_node->Index()); + ptrdiff_t attention_residual_ln_node_order = node_index_to_its_order_in_topological_sort_map.at(other_node->Index()); + if (attention_residual_add_node_order >= yield_op_order_in_topological_sort || + attention_residual_ln_node_order >= yield_op_order_in_topological_sort) { + MO_LOG_DEBUG_INFO(logger, "Not a valid residual node arg " + input_arg->Name()); + continue; + } + + // Search all forward nodes that is before `add_node` in topo order, unless we find a softmax or nodes_to_check is empty. std::deque nodes_to_check; std::set visited_nodes; for (auto node_it = node.OutputNodesBegin(); node_it != node.OutputNodesEnd(); ++node_it) { // Ignore those nodes after YieldOp. - if (node_index_to_its_order_in_topological_sort_map.at(node_it->Index()) < yield_op_order_in_topological_sort) { + auto order = node_index_to_its_order_in_topological_sort_map.at(node_it->Index()); + if (order < yield_op_order_in_topological_sort && order < attention_residual_add_node_order) { nodes_to_check.push_back(&(*node_it)); } } - // For a perfect layer match, all those three flags should be true after exiting while loop. - // except the layernorm after the last layer, which expects no more softmax and layernorm are found. - bool found_softmax = false; - bool found_mid_layernorm = false; - bool found_end_layernorm = false; while (!nodes_to_check.empty()) { const Node* next_node = nodes_to_check.front(); nodes_to_check.pop_front(); @@ -70,36 +158,19 @@ void FindLayerBoundaryLayerNormNodes( } visited_nodes.insert(next_node); - if (softmax_ops.find(next_node->OpType()) != softmax_ops.end()) { - found_softmax = true; - } else if (layernorm_ops.find(next_node->OpType()) != layernorm_ops.end()) { - // mid layernorm MUST appear before end layernorm, because it is a single in and out connector (despite of its weights) - if (found_mid_layernorm) { - found_end_layernorm = true; - break; // exit the while loop since we found the end LayerNormalization node. 
- } else { - found_mid_layernorm = true; - } + if (IsSoftmaxNode(*next_node)) { + MO_LOG_DEBUG_INFO(logger, "Found layer boundary node " + node.Name() + " with its input arg: " + + input_arg->Name()); + layer_boundary_ln_nodes.push_back(&node); + break; } for (auto node_it = next_node->OutputNodesBegin(); node_it != next_node->OutputNodesEnd(); ++node_it) { - // Stop if the node is after YieldOp. - if (node_index_to_its_order_in_topological_sort_map.at(node_it->Index()) >= yield_op_order_in_topological_sort) { - continue; + // Stop if the node is after next Layernorm node in execution order. + auto order = node_index_to_its_order_in_topological_sort_map.at(node_it->Index()); + if (order < yield_op_order_in_topological_sort && order < attention_residual_add_node_order) { + nodes_to_check.push_back(&(*node_it)); } - nodes_to_check.push_back(&(*node_it)); - } - } - - if (found_softmax && found_mid_layernorm && found_end_layernorm) { - if (std::find(layer_boundary_ln_nodes.begin(), layer_boundary_ln_nodes.end(), &node) == layer_boundary_ln_nodes.end()) { - layer_boundary_ln_nodes.push_back(&node); - } - } else if (!found_softmax && !found_mid_layernorm && !found_end_layernorm) { - // If no Softmax found, and no other LayerNormalization found, this should be the last LayerNormalization node, - // we also consider it as boundary node. - if (std::find(layer_boundary_ln_nodes.begin(), layer_boundary_ln_nodes.end(), &node) == layer_boundary_ln_nodes.end()) { - layer_boundary_ln_nodes.push_back(&node); } } } diff --git a/orttraining/orttraining/test/optimizer/memory_optimizer_test.cc b/orttraining/orttraining/test/optimizer/memory_optimizer_test.cc index 32304d4ffd828..8604cef333d41 100644 --- a/orttraining/orttraining/test/optimizer/memory_optimizer_test.cc +++ b/orttraining/orttraining/test/optimizer/memory_optimizer_test.cc @@ -347,11 +347,10 @@ TEST(MemoryOptimizerTests, TransformerLayerDetectionTest) { yield_op_order_in_topological_sort, layer_boundary_ln_node); - ASSERT_TRUE(layer_boundary_ln_node.size() == 4); + ASSERT_TRUE(layer_boundary_ln_node.size() == 3); ASSERT_TRUE(layer_boundary_ln_node[0]->Name() == "LayerNormalization_token_0"); ASSERT_TRUE(layer_boundary_ln_node[1]->Name() == "LayerNormalization_token_6"); ASSERT_TRUE(layer_boundary_ln_node[2]->Name() == "LayerNormalization_token_12"); - ASSERT_TRUE(layer_boundary_ln_node[3]->Name() == "LayerNormalization_token_18"); } } // namespace test From ee82c2b7418a090ce236c210a802bc5213caee20 Mon Sep 17 00:00:00 2001 From: "Peng Wang(AI FWK)" Date: Thu, 28 Mar 2024 04:28:19 +0000 Subject: [PATCH 07/18] recompute run with its critical path impact factor --- include/onnxruntime/core/graph/constants.h | 4 +- onnxruntime/core/graph/graph_viewer.cc | 21 ++-- .../memory_optimizer/memory_insight.cc | 46 ------- .../memory_optimizer/memory_insight.h | 10 -- .../memory_optimizer/memory_optimizer.cc | 119 +++++++++++++++++- .../memory_optimizer/recompute_analysis.cc | 59 +++++---- 6 files changed, 164 insertions(+), 95 deletions(-) diff --git a/include/onnxruntime/core/graph/constants.h b/include/onnxruntime/core/graph/constants.h index 8e04050d089a0..b4f6617734e53 100644 --- a/include/onnxruntime/core/graph/constants.h +++ b/include/onnxruntime/core/graph/constants.h @@ -56,7 +56,7 @@ constexpr const char* kAzureExecutionProvider = "AzureExecutionProvider"; constexpr const char* kExecutionProviderSharedLibraryPath = "shared_lib_path"; constexpr const char* kExecutionProviderSharedLibraryEntry = "provider_factory_entry_point"; -// For Priority based 
graph topology sorting. -constexpr const char* kBackwardNodeAttributeName = "__backwardpass"; +// For memory optimizer's priority based graph topology sorting. +constexpr const char* kRecomputeNodeCriticalPathImpact = "__recompute_critical_path_impact"; } // namespace onnxruntime diff --git a/onnxruntime/core/graph/graph_viewer.cc b/onnxruntime/core/graph/graph_viewer.cc index 119d420066a84..f134a3f79434a 100644 --- a/onnxruntime/core/graph/graph_viewer.cc +++ b/onnxruntime/core/graph/graph_viewer.cc @@ -40,15 +40,18 @@ struct PriorityNodeCompare { } #ifdef ENABLE_TRAINING - // nodes of forward pass will be output first - auto n1_attrs = n1->GetAttributes(); - auto n2_attrs = n2->GetAttributes(); - int64_t n1_is_forward = static_cast(n1_attrs.find(kBackwardNodeAttributeName) == n1_attrs.cend()) || - (n1_attrs.at(kBackwardNodeAttributeName).i() + 1) % 2; - int64_t n2_is_forward = static_cast(n2_attrs.find(kBackwardNodeAttributeName) == n2_attrs.cend()) || - (n2_attrs.at(kBackwardNodeAttributeName).i() + 1) % 2; - if (n1_is_forward != n2_is_forward) { - return n2_is_forward > n1_is_forward; + + // nodes of bigger impact pass will be output first + const auto& n1_attrs = n1->GetAttributes(); + const auto& n2_attrs = n2->GetAttributes(); + int64_t n1_impact = (n1_attrs.find(kRecomputeNodeCriticalPathImpact) != n1_attrs.cend()) + ? static_cast(n1_attrs.at(kRecomputeNodeCriticalPathImpact).i()) + : -1; + int64_t n2_impact = (n2_attrs.find(kRecomputeNodeCriticalPathImpact) != n2_attrs.cend()) + ? static_cast(n2_attrs.at(kRecomputeNodeCriticalPathImpact).i()) + : -1; + if (n1_impact != -1 && n2_impact != -1) { + return n2_impact > n1_impact; } #endif diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.cc b/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.cc index 3d0fa942fd2d4..3ec784bfb884d 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.cc +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.cc @@ -171,52 +171,6 @@ Status GetStashedActivationCandidates(const GraphViewer& graph_viewer, return Status::OK(); } -Status ResetNodeBackwardPassAttribute(Graph& graph, bool& modified) { - // Find the YieldOp node. - Node* yield_op_node = nullptr; - for (auto& node : graph.Nodes()) { - if (node.OpType() == "YieldOp") { - yield_op_node = &node; - break; - } - } - - if (yield_op_node == nullptr) { - return Status::OK(); - } - - // Reverse BFS from YieldOp to find all "forward" nodes. - std::vector fw_nodes; - std::vector end_nodes{yield_op_node}; - graph.ReverseDFSFrom( - end_nodes, - nullptr, - [&fw_nodes](const Node* n) { - fw_nodes.push_back(n); - }, - nullptr); - - // Set the attribute to true for all backward nodes. 
- for (auto& node : graph.Nodes()) { - if (std::find(fw_nodes.begin(), fw_nodes.end(), &node) == fw_nodes.end()) { - auto& attrs = node.GetAttributes(); - if (attrs.count(kBackwardNodeAttributeName)) { - continue; - } - node.AddAttribute(kBackwardNodeAttributeName, static_cast(1)); - modified = true; - } else { - auto& attrs = node.GetAttributes(); - if (attrs.count(kBackwardNodeAttributeName)) { - node.ClearAttribute(kBackwardNodeAttributeName); - modified = true; - } - } - } - - return Status::OK(); -} - Status FindORTModuleMemoryOpportunity(const GraphViewer& graph_viewer, const ProbeConfig& probe_config, const logging::Logger& logger, diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.h b/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.h index 3f0a1a9a96f88..ca1df0633eb8f 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.h +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.h @@ -57,16 +57,6 @@ class MemoryRecord { int freq = 0; }; -/** - * @brief Reset `__backwardpass` attribute for all backward nodes in the graph. - * `__backwardpass` is used by Priority-Based topology sorting. - * - * @param graph To be scanned and modified. - * @param modified Whether the graph is modified. - * @return Status - */ -Status ResetNodeBackwardPassAttribute(Graph& graph, bool& modified); - /** * @brief Iterate the graph and find all possible memory optimization opportunities for related nodes. * diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/memory_optimizer.cc b/orttraining/orttraining/core/optimizer/memory_optimizer/memory_optimizer.cc index ac619bdc390d3..f7872258e4f37 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/memory_optimizer.cc +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/memory_optimizer.cc @@ -146,9 +146,6 @@ bool MemoryOptimizer::ModifyGraph(Graph& graph, Status MemoryOptimizer::ApplyImpl(Graph& graph, bool& modified, int /*graph_level*/, const logging::Logger& logger) const { - // Reset the backward pass attribute for all nodes. - ORT_RETURN_IF_ERROR(optimizer::memory_optimizer::ResetNodeBackwardPassAttribute(graph, modified)); - LOGS(logger, VERBOSE) << "Memory optimization config: " << optimizer_config_ << ", probe level: " << static_cast(recompute_probe_config_.probe_level) << ", enable_transformer_layer_as_boundary:" @@ -256,6 +253,122 @@ Status MemoryOptimizer::ApplyImpl(Graph& graph, bool& modified, int /*graph_leve PrintSummary(memory_opt_planner, node_to_apply_context_map, logger); + if (modified) { + ORT_ENFORCE(graph.Resolve().IsOK()); + } + + auto is_recompute_node = [](const Node* n) { + std::string_view name1 = n->Name(); + constexpr std::string_view recompute_suffix = "_recompute"; + if (name1.size() < recompute_suffix.size()) { + return false; + } + + return name1.compare(name1.size() - recompute_suffix.size(), recompute_suffix.size(), recompute_suffix) == 0; + }; + + // Loop through all the nodes in the graph and find the recompute nodes, categorize them into two groups: + // group 1: recompute nodes that are used only by other non-recompute nodes. 
+ // group 2: recompute nodes that are used by some(>=0) non-recompute nodes and some(>=1) recompute nodes + InlinedHashMap> group1_recompute_nodes; + InlinedHashMap, InlinedVector>> group2_recompute_nodes; + InlinedHashMap recompute_node_to_its_critical_path_node_order_map; + for (const auto& n1 : graph.Nodes()) { + if (is_recompute_node(&n1)) { + InlinedVector non_recompute_consumers; + InlinedVector recompute_consumers; + for (auto o_iter = n1.OutputEdgesBegin(); o_iter != n1.OutputEdgesEnd(); ++o_iter) { + const auto& output = *o_iter; + const Node* consumer = graph.GetNode(output.GetNode().Index()); + if (is_recompute_node(consumer)) { + recompute_consumers.push_back(consumer); + } else { + non_recompute_consumers.push_back(consumer); + } + } + + if (recompute_consumers.empty()) { + group1_recompute_nodes.insert({&n1, non_recompute_consumers}); + } else { + group2_recompute_nodes.insert({&n1, {non_recompute_consumers, recompute_consumers}}); + } + } + } + + // Loop group1_recompute_nodes, get the minimal value of execution order from node_index_to_its_order_in_topological_sort_map for its output nodes. + for (const auto& [non_recompute_node, non_recompute_consumers] : group1_recompute_nodes) { + int64_t max_impact = 0; + for (const auto& consumer : non_recompute_consumers) { + auto it = node_index_to_its_order_in_topological_sort_map.find(consumer->Index()); + if (it != node_index_to_its_order_in_topological_sort_map.end()) { + // The smaller the order, then the bigger impact it has. + max_impact = std::max(max_impact, std::numeric_limits::max() - static_cast(it->second)); + } + } + + std::cout << ">>>Recompute node: " << non_recompute_node->Name() << " max_impact: " << max_impact << std::endl; + + recompute_node_to_its_critical_path_node_order_map.insert({non_recompute_node, max_impact}); + } + + // Then at this point, all "boudary" recompute nodes are marked for its critical path node order. + // Next, loop group2_recompute_nodes, 1). for each recompute node's non-recompute consumers, find the minimal value of + // execution order from node_index_to_its_order_in_topological_sort_map; 2). for each recompute node's recompute consumers, + // find the minimal value of execution order from recompute_node_to_its_critical_path_node_order_map. + // Loop the node in reversed topological order. + GraphViewer updated_graph_viewer(graph); + const auto& updated_node_ids = updated_graph_viewer.GetNodesInTopologicalOrder(ExecutionOrder::PRIORITY_BASED); + for (int i = static_cast(updated_node_ids.size()) - 1; i >= 0; --i) { + Node* p_node = graph.GetNode(updated_node_ids[i]); + + if (p_node == nullptr) { + continue; + } + + if (group2_recompute_nodes.find(p_node) != group2_recompute_nodes.end()) { + const auto& [non_recompute_consumers, recompute_consumers] = group2_recompute_nodes.at(p_node); + int64_t max_impact = 0; + for (const auto& consumer : non_recompute_consumers) { + auto it = node_index_to_its_order_in_topological_sort_map.find(consumer->Index()); + ORT_ENFORCE(it != node_index_to_its_order_in_topological_sort_map.end(), "Cannot find the order for non-recompute consumer node: ", consumer->Name()); + // The smaller the order, then the bigger impact it has. 
+ max_impact = std::max(max_impact, std::numeric_limits::max() - static_cast(it->second)); + } + + for (const auto& consumer : recompute_consumers) { + auto it = recompute_node_to_its_critical_path_node_order_map.find(consumer); + ORT_ENFORCE(it != recompute_node_to_its_critical_path_node_order_map.end(), "Cannot find the critical path order for recompute node: ", consumer->Name()); + max_impact = std::max(max_impact, it->second); + } + + std::cout << ">>>!!!Recompute node: " << p_node->Name() << " max_impact: " << max_impact << std::endl; + + recompute_node_to_its_critical_path_node_order_map.insert({p_node, max_impact}); + } + } + + // Finally, loop through recompute_node_to_its_critical_path_node_order_map, add attribute "__critical_execution_order" + // for each recompute node, which will be used for priority based graph ordering. + for (const auto& [recompute_node, order] : recompute_node_to_its_critical_path_node_order_map) { + Node* mutable_node = graph.GetNode(recompute_node->Index()); + mutable_node->AddAttribute(kRecomputeNodeCriticalPathImpact, static_cast(order)); + modified = true; + } + + // sort the recompute nodes based on the critical path order and print to stdout + // std::vector> + // recompute_nodes_sorted_by_critical_path_order; + // for (const auto& [recompute_node, order] : recompute_node_to_its_critical_path_node_order_map) { + // recompute_nodes_sorted_by_critical_path_order.push_back({recompute_node->Name(), order}); + // } + // std::sort(recompute_nodes_sorted_by_critical_path_order.begin(), recompute_nodes_sorted_by_critical_path_order.end(), + // [](const std::pair& a, const std::pair& b) { + // return a.second < b.second; + // }); + // for (const auto& [recompute_node_name, order] : recompute_nodes_sorted_by_critical_path_order) { + // std::cout << ">>>" << order << "Recompute node: " << recompute_node_name << std::endl; + // } + return Status::OK(); } diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc b/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc index 37ac1c4950ecd..d11403b99300e 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc @@ -144,6 +144,20 @@ const InlinedHashMap& GetAllowedRecompu {20, {0}}, }, }, + { + utils::GetFullQualifiedOpName("Cos", kOnnxDomain), + { + {7, {}}, + }, + }, + { + utils::GetFullQualifiedOpName("CumSum", kOnnxDomain), + { + // The axis input is trivial + {11, {1}}, + {14, {1}}, + }, + }, { utils::GetFullQualifiedOpName("Dropout", kOnnxDomain), { @@ -162,27 +176,6 @@ const InlinedHashMap& GetAllowedRecompu {14, {}}, }, }, - { - utils::GetFullQualifiedOpName("Expand", kOnnxDomain), - { - {8, {1}}, // Ignore the shape. - {13, {1}}, - }, - }, - { - utils::GetFullQualifiedOpName("Cos", kOnnxDomain), - { - {7, {}}, - }, - }, - { - utils::GetFullQualifiedOpName("CumSum", kOnnxDomain), - { - // The axis input is trivial - {11, {1}}, - {14, {1}}, - }, - }, { utils::GetFullQualifiedOpName("Einsum", kOnnxDomain), { @@ -199,6 +192,13 @@ const InlinedHashMap& GetAllowedRecompu {19, {}}, }, }, + { + utils::GetFullQualifiedOpName("Expand", kOnnxDomain), + { + {8, {1}}, // Ignore the shape. 
+ {13, {1}}, + }, + }, { utils::GetFullQualifiedOpName("FastGelu", kMSDomain), { @@ -244,6 +244,14 @@ const InlinedHashMap& GetAllowedRecompu {14, {}}, }, }, + { + utils::GetFullQualifiedOpName("Neg", kOnnxDomain), + { + {1, {}}, + {6, {}}, + {13, {}}, + }, + }, { utils::GetFullQualifiedOpName("Range", kOnnxDomain), { @@ -733,10 +741,11 @@ std::unique_ptr CheckNodeForRecompute(const GraphViewer& grap layer_boundary_ln_nodes.end()) { int dest_in_index = optimizer_utils::IndexOfNodeInput(*consumer, *node.OutputDefs()[output_index]); if (dest_in_index == 0) { - MO_LOG_DEBUG_INFO(logger, "Node " + node.Name() + "(" + node.OpType() + - ") is a Attention+MLP layer boundary node, " + - "its stashed activation outputs are used by LayerNormalization's inputs, " + - "we don't need to recompute it."); + std::cout << "Node " + node.Name() + "(" + node.OpType() + + ") is a Attention+MLP layer boundary node, " + + "its stashed activation outputs are used by LayerNormalization's inputs, " + + "we don't need to recompute it." + << std::endl; return nullptr; } } From 19f83be48b5f8486670cdbe953cc841b5bd511f6 Mon Sep 17 00:00:00 2001 From: "Peng Wang(AI FWK)" Date: Thu, 28 Mar 2024 04:57:10 +0000 Subject: [PATCH 08/18] refine code structure a bit --- .../core/graph/recompute_graph_utils.h | 4 +- .../memory_optimizer/memory_optimizer.cc | 156 +----------------- .../memory_optimizer/recompute_analysis.cc | 113 +++++++++++++ .../memory_optimizer/recompute_analysis.h | 13 ++ .../test/optimizer/memory_optimizer_test.cc | 12 +- 5 files changed, 141 insertions(+), 157 deletions(-) diff --git a/orttraining/orttraining/core/graph/recompute_graph_utils.h b/orttraining/orttraining/core/graph/recompute_graph_utils.h index f4d7e88a072f5..1bf2e11cdd27d 100644 --- a/orttraining/orttraining/core/graph/recompute_graph_utils.h +++ b/orttraining/orttraining/core/graph/recompute_graph_utils.h @@ -6,8 +6,10 @@ namespace onnxruntime { namespace graph_utils { +constexpr const char* kRecomputeFlag = "_recompute"; + inline std::string RecomputeName(const std::string& name) { - return name + "_recompute"; + return name + kRecomputeFlag; } } // namespace graph_utils diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/memory_optimizer.cc b/orttraining/orttraining/core/optimizer/memory_optimizer/memory_optimizer.cc index f7872258e4f37..7d33332062d82 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/memory_optimizer.cc +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/memory_optimizer.cc @@ -190,40 +190,7 @@ Status MemoryOptimizer::ApplyImpl(Graph& graph, bool& modified, int /*graph_leve // The reason we do reversed topological order is that we want the later layers' recompute nodes can be appended // earlier than the earlier layers, in this way, the execution order of later layers will be in front of the earlier // layers. - // - // Note 2: Here we use default typo order (which tries to BFS from the outputs, - // so the nearest node to graph output will be visited last). So in reversed default typo order, - // the neareast node to graph output will be visited first. 
- // Imagine there is a such subgraph - // input1 input2 input3 - // \ | / - // multiple layers - // | - // node M - // labels-------|----- - // \ | | - // node1 | | - // \ | | - // node2 / | - // \ / | - // node loss / - // | / - // YieldOp node1_recompute - // | / - // \ node2 recompute - // \ / - // node loss_grad - // | - // critical grad path - // - // In PriorityBased order, node1 will be visited first, so it's recompute node node1_recompute will be added - // at last because we do this following reversed topological order. Then node1_recompute node will have lowest - // priority to execute, as a result, if at this time, the queue to visit contains only recompute nodes, then - // node1_recompute will be run at last, affecting the backward critical path, which is not what we want. - // Current workaround is to use default order, which will execute node1_recompute earlier than other recompute nodes - // in this case. - - const auto& node_ids = graph_viewer.GetNodesInTopologicalOrder(ExecutionOrder::DEFAULT); + const auto& node_ids = graph_viewer.GetNodesInTopologicalOrder(ExecutionOrder::PRIORITY_BASED); for (int i = static_cast(node_ids.size()) - 1; i >= 0; --i) { Node* p_node = graph.GetNode(node_ids[i]); if (p_node == nullptr) { @@ -248,127 +215,14 @@ Status MemoryOptimizer::ApplyImpl(Graph& graph, bool& modified, int /*graph_leve } if (recomputed_node_count > 0) { + ORT_ENFORCE(optimizer::memory_optimizer::SetCriticalPathImpact( + graph, node_index_to_its_order_in_topological_sort_map), + "Failed to set critical path impact attribute."); LOGS(logger, INFO) << "Total number of recomputed nodes: " << recomputed_node_count; } PrintSummary(memory_opt_planner, node_to_apply_context_map, logger); - if (modified) { - ORT_ENFORCE(graph.Resolve().IsOK()); - } - - auto is_recompute_node = [](const Node* n) { - std::string_view name1 = n->Name(); - constexpr std::string_view recompute_suffix = "_recompute"; - if (name1.size() < recompute_suffix.size()) { - return false; - } - - return name1.compare(name1.size() - recompute_suffix.size(), recompute_suffix.size(), recompute_suffix) == 0; - }; - - // Loop through all the nodes in the graph and find the recompute nodes, categorize them into two groups: - // group 1: recompute nodes that are used only by other non-recompute nodes. - // group 2: recompute nodes that are used by some(>=0) non-recompute nodes and some(>=1) recompute nodes - InlinedHashMap> group1_recompute_nodes; - InlinedHashMap, InlinedVector>> group2_recompute_nodes; - InlinedHashMap recompute_node_to_its_critical_path_node_order_map; - for (const auto& n1 : graph.Nodes()) { - if (is_recompute_node(&n1)) { - InlinedVector non_recompute_consumers; - InlinedVector recompute_consumers; - for (auto o_iter = n1.OutputEdgesBegin(); o_iter != n1.OutputEdgesEnd(); ++o_iter) { - const auto& output = *o_iter; - const Node* consumer = graph.GetNode(output.GetNode().Index()); - if (is_recompute_node(consumer)) { - recompute_consumers.push_back(consumer); - } else { - non_recompute_consumers.push_back(consumer); - } - } - - if (recompute_consumers.empty()) { - group1_recompute_nodes.insert({&n1, non_recompute_consumers}); - } else { - group2_recompute_nodes.insert({&n1, {non_recompute_consumers, recompute_consumers}}); - } - } - } - - // Loop group1_recompute_nodes, get the minimal value of execution order from node_index_to_its_order_in_topological_sort_map for its output nodes. 
- for (const auto& [non_recompute_node, non_recompute_consumers] : group1_recompute_nodes) { - int64_t max_impact = 0; - for (const auto& consumer : non_recompute_consumers) { - auto it = node_index_to_its_order_in_topological_sort_map.find(consumer->Index()); - if (it != node_index_to_its_order_in_topological_sort_map.end()) { - // The smaller the order, then the bigger impact it has. - max_impact = std::max(max_impact, std::numeric_limits::max() - static_cast(it->second)); - } - } - - std::cout << ">>>Recompute node: " << non_recompute_node->Name() << " max_impact: " << max_impact << std::endl; - - recompute_node_to_its_critical_path_node_order_map.insert({non_recompute_node, max_impact}); - } - - // Then at this point, all "boudary" recompute nodes are marked for its critical path node order. - // Next, loop group2_recompute_nodes, 1). for each recompute node's non-recompute consumers, find the minimal value of - // execution order from node_index_to_its_order_in_topological_sort_map; 2). for each recompute node's recompute consumers, - // find the minimal value of execution order from recompute_node_to_its_critical_path_node_order_map. - // Loop the node in reversed topological order. - GraphViewer updated_graph_viewer(graph); - const auto& updated_node_ids = updated_graph_viewer.GetNodesInTopologicalOrder(ExecutionOrder::PRIORITY_BASED); - for (int i = static_cast(updated_node_ids.size()) - 1; i >= 0; --i) { - Node* p_node = graph.GetNode(updated_node_ids[i]); - - if (p_node == nullptr) { - continue; - } - - if (group2_recompute_nodes.find(p_node) != group2_recompute_nodes.end()) { - const auto& [non_recompute_consumers, recompute_consumers] = group2_recompute_nodes.at(p_node); - int64_t max_impact = 0; - for (const auto& consumer : non_recompute_consumers) { - auto it = node_index_to_its_order_in_topological_sort_map.find(consumer->Index()); - ORT_ENFORCE(it != node_index_to_its_order_in_topological_sort_map.end(), "Cannot find the order for non-recompute consumer node: ", consumer->Name()); - // The smaller the order, then the bigger impact it has. - max_impact = std::max(max_impact, std::numeric_limits::max() - static_cast(it->second)); - } - - for (const auto& consumer : recompute_consumers) { - auto it = recompute_node_to_its_critical_path_node_order_map.find(consumer); - ORT_ENFORCE(it != recompute_node_to_its_critical_path_node_order_map.end(), "Cannot find the critical path order for recompute node: ", consumer->Name()); - max_impact = std::max(max_impact, it->second); - } - - std::cout << ">>>!!!Recompute node: " << p_node->Name() << " max_impact: " << max_impact << std::endl; - - recompute_node_to_its_critical_path_node_order_map.insert({p_node, max_impact}); - } - } - - // Finally, loop through recompute_node_to_its_critical_path_node_order_map, add attribute "__critical_execution_order" - // for each recompute node, which will be used for priority based graph ordering. 
- for (const auto& [recompute_node, order] : recompute_node_to_its_critical_path_node_order_map) { - Node* mutable_node = graph.GetNode(recompute_node->Index()); - mutable_node->AddAttribute(kRecomputeNodeCriticalPathImpact, static_cast(order)); - modified = true; - } - - // sort the recompute nodes based on the critical path order and print to stdout - // std::vector> - // recompute_nodes_sorted_by_critical_path_order; - // for (const auto& [recompute_node, order] : recompute_node_to_its_critical_path_node_order_map) { - // recompute_nodes_sorted_by_critical_path_order.push_back({recompute_node->Name(), order}); - // } - // std::sort(recompute_nodes_sorted_by_critical_path_order.begin(), recompute_nodes_sorted_by_critical_path_order.end(), - // [](const std::pair& a, const std::pair& b) { - // return a.second < b.second; - // }); - // for (const auto& [recompute_node_name, order] : recompute_nodes_sorted_by_critical_path_order) { - // std::cout << ">>>" << order << "Recompute node: " << recompute_node_name << std::endl; - // } - return Status::OK(); } @@ -432,7 +286,7 @@ Status MemoryOptimizer::CreateRecomputeGraph(Graph& graph, self_contained_outputs_map[output] = new_output_args.back(); } - Node& recompute_node = graph.AddNode(node_to_duplicate->Name() + "_recompute", + Node& recompute_node = graph.AddNode(node_to_duplicate->Name() + graph_utils::kRecomputeFlag, node_to_duplicate->OpType(), "Recompute of " + node_to_duplicate->Name(), new_input_args, diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc b/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc index d11403b99300e..936bad3b87107 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc @@ -11,6 +11,7 @@ #include "orttraining/core/optimizer/memory_optimizer/common.h" #include "orttraining/core/optimizer/memory_optimizer/transformer_specific.h" #include "orttraining/core/optimizer/memory_optimizer/recompute_analysis.h" +#include "orttraining/core/graph/recompute_graph_utils.h" #include "core/common/string_utils.h" #include "core/framework/data_types.h" #include "core/optimizer/utils.h" @@ -684,6 +685,20 @@ void NodesInTopoOrderToString(gsl::span nodes_in_topological_ } } +/** + * @brief Check whether a node is a recompute node added by MemoryOptimizer or not. + */ +bool IsRecomputeNode(const Node* n) { + static size_t recompute_suffix_len = std::string(graph_utils::kRecomputeFlag).size(); + + std::string_view name = n->Name(); + if (name.size() < recompute_suffix_len) { + return false; + } + + return name.compare(name.size() - recompute_suffix_len, recompute_suffix_len, graph_utils::kRecomputeFlag) == 0; +} + } // namespace Status ParseProbeConfigFromString(std::string_view recompute_probe_config, ProbeConfig& probe_config) { @@ -806,4 +821,102 @@ std::string NodeRecomputePlan::GetNodesInTopoOrderStr() const { return subgraph_str_representation; } +bool SetCriticalPathImpact(Graph& graph, + const InlinedHashMap& + node_index_to_its_order_in_topological_sort_map) { + // Loop through all the nodes in the graph and find the recompute nodes, categorize them into two groups: + // group 1: recompute nodes that are used only by other non-recompute nodes. 
+ // group 2: recompute nodes that are used by some(>=0) non-recompute nodes and some(>=1) recompute nodes + InlinedHashMap> group1_recompute_nodes; + InlinedHashMap, InlinedVector>> group2_recompute_nodes; + InlinedHashMap recompute_node_to_its_critical_path_node_order_map; + for (const auto& n1 : graph.Nodes()) { + if (IsRecomputeNode(&n1)) { + InlinedVector non_recompute_consumers; + InlinedVector recompute_consumers; + for (auto o_iter = n1.OutputEdgesBegin(); o_iter != n1.OutputEdgesEnd(); ++o_iter) { + const auto& output = *o_iter; + const Node* consumer = graph.GetNode(output.GetNode().Index()); + if (IsRecomputeNode(consumer)) { + recompute_consumers.push_back(consumer); + } else { + non_recompute_consumers.push_back(consumer); + } + } + + if (recompute_consumers.empty()) { + group1_recompute_nodes.insert({&n1, non_recompute_consumers}); + } else { + group2_recompute_nodes.insert({&n1, {non_recompute_consumers, recompute_consumers}}); + } + } + } + + // Loop group1_recompute_nodes, get the minimal value of execution order + // from node_index_to_its_order_in_topological_sort_map for its output nodes. + for (const auto& [non_recompute_node, non_recompute_consumers] : group1_recompute_nodes) { + int64_t max_impact = 0; + for (const auto& consumer : non_recompute_consumers) { + auto it = node_index_to_its_order_in_topological_sort_map.find(consumer->Index()); + ORT_ENFORCE(it != node_index_to_its_order_in_topological_sort_map.end(), + "Cannot find the order for non-recompute consumer node: ", consumer->Name()); + // The smaller the order, then the bigger impact it has. + max_impact = std::max(max_impact, std::numeric_limits::max() - static_cast(it->second)); + } + + recompute_node_to_its_critical_path_node_order_map.insert({non_recompute_node, max_impact}); + } + + // Then at this point, all "boundary" recompute nodes are marked with their critical path node order. + // Next, loop group2_recompute_nodes: + // 1). for each recompute node's non-recompute consumers, find the minimal value of + // execution order from node_index_to_its_order_in_topological_sort_map; + // 2). for each recompute node's recompute consumers, find the minimal value of execution order from + // recompute_node_to_its_critical_path_node_order_map. + // + // Note that we loop the nodes in reversed topological order, to make sure "boundary" recompute nodes' + // parent recompute nodes are processed first, then those parent recompute nodes' parent recompute nodes. + GraphViewer updated_graph_viewer(graph); + const auto& updated_node_ids = updated_graph_viewer.GetNodesInTopologicalOrder(ExecutionOrder::PRIORITY_BASED); + for (int i = static_cast(updated_node_ids.size()) - 1; i >= 0; --i) { + Node* p_node = graph.GetNode(updated_node_ids[i]); + + if (p_node == nullptr) { + continue; + } + + if (group2_recompute_nodes.find(p_node) != group2_recompute_nodes.end()) { + const auto& [non_recompute_consumers, recompute_consumers] = group2_recompute_nodes.at(p_node); + int64_t max_impact = 0; + for (const auto& consumer : non_recompute_consumers) { + auto it = node_index_to_its_order_in_topological_sort_map.find(consumer->Index()); + ORT_ENFORCE(it != node_index_to_its_order_in_topological_sort_map.end(), + "Cannot find the order for non-recompute consumer node: ", consumer->Name()); + // The smaller the order, then the bigger impact it has. 
+ max_impact = std::max(max_impact, std::numeric_limits::max() - static_cast(it->second)); + } + + for (const auto& consumer : recompute_consumers) { + auto it = recompute_node_to_its_critical_path_node_order_map.find(consumer); + ORT_ENFORCE(it != recompute_node_to_its_critical_path_node_order_map.end(), + "Cannot find the critical path order for recompute node: ", consumer->Name()); + max_impact = std::max(max_impact, it->second); + } + + recompute_node_to_its_critical_path_node_order_map.insert({p_node, max_impact}); + } + } + + // Finally, loop through recompute_node_to_its_critical_path_node_order_map, add the attribute + // for each recompute node, which will be used for priority based graph ordering. + bool modified = false; + for (const auto& [recompute_node, order] : recompute_node_to_its_critical_path_node_order_map) { + Node* mutable_node = graph.GetNode(recompute_node->Index()); + mutable_node->AddAttribute(kRecomputeNodeCriticalPathImpact, static_cast(order)); + modified = true; + } + + return modified; +} + } // namespace onnxruntime::optimizer::memory_optimizer diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.h b/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.h index ac1021f5eb83b..319829fdec7a0 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.h +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.h @@ -169,4 +169,17 @@ std::unique_ptr CheckNodeForRecompute(const GraphViewer& grap bool compromise_stashed_activation, bool& can_compromise_stashed_activation); +/** + * @brief Set the critical path impact for recompute nodes as an attribute. + * The impact is an int64_t value, which will be respected during Priority-Based topological sort. + * The bigger it is, the earlier the node will be executed. + * + * @param graph The graph to iterate and update. + * @param node_index_to_its_order_in_topological_sort_map The mapping of node index to its order in topological sort. + * @return true if the graph is modified, false otherwise. 
+ */ +bool SetCriticalPathImpact(Graph& graph, + const InlinedHashMap& + node_index_to_its_order_in_topological_sort_map); + } // namespace onnxruntime::optimizer::memory_optimizer diff --git a/orttraining/orttraining/test/optimizer/memory_optimizer_test.cc b/orttraining/orttraining/test/optimizer/memory_optimizer_test.cc index 8604cef333d41..08fdbbdd68f3d 100644 --- a/orttraining/orttraining/test/optimizer/memory_optimizer_test.cc +++ b/orttraining/orttraining/test/optimizer/memory_optimizer_test.cc @@ -18,6 +18,7 @@ #include "core/graph/graph_utils.h" #include "core/graph/graph_viewer.h" #include "core/graph/model.h" + #include "core/optimizer/utils.h" #include "core/platform/env.h" #include "core/session/inference_session.h" @@ -26,6 +27,7 @@ #include "test/capturing_sink.h" #include "test/test_environment.h" #include "test/util/include/asserts.h" +#include "orttraining/core/graph/recompute_graph_utils.h" #include "orttraining/core/optimizer/memory_optimizer/common.h" #include "orttraining/core/optimizer/memory_optimizer/memory_optimizer.h" #include "orttraining/core/optimizer/memory_optimizer/memory_insight.h" @@ -223,20 +225,20 @@ TEST(MemoryOptimizerTests, TransformerPerLayerRecompute) { for (auto& consumer : consumers) { if (consumer->OpType().compare("LayerNormalization") == 0) { - if (consumer->Name().find("_recompute") != std::string::npos) { + if (consumer->Name().find(graph_utils::kRecomputeFlag) != std::string::npos) { recompute_ln_node = consumer; ASSERT_EQ(consumer->Priority(), static_cast(ExecutionPriority::LOCAL_LOW)); recompute_ln_node_parent_add_or_ln_node = graph.GetProducerNode(consumer->InputDefs()[0]->Name()); ASSERT_TRUE(recompute_ln_node_parent_add_or_ln_node != nullptr); ASSERT_EQ(recompute_ln_node_parent_add_or_ln_node->Priority(), static_cast(ExecutionPriority::DEFAULT)); - ASSERT_TRUE(recompute_ln_node_parent_add_or_ln_node->Name().find("_recompute") == std::string::npos); + ASSERT_TRUE(recompute_ln_node_parent_add_or_ln_node->Name().find(graph_utils::kRecomputeFlag) == std::string::npos); } else { original_ln_node = consumer; ASSERT_EQ(consumer->Priority(), static_cast(ExecutionPriority::DEFAULT)); original_ln_node_parent_add_or_ln_node = graph.GetProducerNode(consumer->InputDefs()[0]->Name()); ASSERT_TRUE(original_ln_node_parent_add_or_ln_node); ASSERT_EQ(original_ln_node_parent_add_or_ln_node->Priority(), static_cast(ExecutionPriority::DEFAULT)); - ASSERT_TRUE(original_ln_node_parent_add_or_ln_node->Name().find("_recompute") == std::string::npos); + ASSERT_TRUE(original_ln_node_parent_add_or_ln_node->Name().find(graph_utils::kRecomputeFlag) == std::string::npos); } } else if (consumer->OpType().compare("LayerNormalizationGrad") == 0) { input_layer_norm_grad_node = consumer; @@ -262,14 +264,14 @@ TEST(MemoryOptimizerTests, TransformerPerLayerRecompute) { for (auto& consumer : consumers) { if (consumer->OpType().compare("LayerNormalization") == 0) { - if (consumer->Name().find("_recompute") != std::string::npos) { + if (consumer->Name().find(graph_utils::kRecomputeFlag) != std::string::npos) { recompute_ln_node = consumer; ASSERT_EQ(consumer->Priority(), static_cast(ExecutionPriority::LOCAL_LOW)); recompute_ln_node_parent_add_node = graph.GetProducerNode(consumer->InputDefs()[0]->Name()); ASSERT_TRUE(recompute_ln_node_parent_add_node); ASSERT_EQ(recompute_ln_node_parent_add_node->OpType(), "Add"); ASSERT_EQ(recompute_ln_node_parent_add_node->Priority(), static_cast(ExecutionPriority::LOCAL_LOW)); - 
ASSERT_TRUE(recompute_ln_node_parent_add_node->Name().find("_recompute") != std::string::npos); + ASSERT_TRUE(recompute_ln_node_parent_add_node->Name().find(graph_utils::kRecomputeFlag) != std::string::npos); } else { original_ln_node = consumer; ASSERT_EQ(consumer->Priority(), static_cast(ExecutionPriority::DEFAULT)); From 02c225ea3ffde986528ace3f2d9f919573fa6575 Mon Sep 17 00:00:00 2001 From: "Peng Wang(AI FWK)" Date: Thu, 28 Mar 2024 05:10:38 +0000 Subject: [PATCH 09/18] make padding removal work with memory recompute --- .../optimizer/memory_optimizer/recompute_analysis.cc | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc b/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc index 936bad3b87107..2dfcc8d6207be 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc @@ -214,6 +214,12 @@ const InlinedHashMap& GetAllowedRecompu {13, {1}}, }, }, + { + utils::GetFullQualifiedOpName("FlattenAndUnpad", kMSDomain), + { + {1, {1}}, // ignore the indices + }, + }, { utils::GetFullQualifiedOpName("Gelu", kOnnxDomain), { @@ -253,6 +259,12 @@ const InlinedHashMap& GetAllowedRecompu {13, {}}, }, }, + { + utils::GetFullQualifiedOpName("PadAndUnflatten", kMSDomain), + { + {1, {1, 2}}, // ignore the indices and unflatten_dims + }, + }, { utils::GetFullQualifiedOpName("Range", kOnnxDomain), { From e5e30ff2b4093c39d875a044f2da4d1e2adb5d90 Mon Sep 17 00:00:00 2001 From: "Peng Wang(AI FWK)" Date: Thu, 28 Mar 2024 02:37:40 -0700 Subject: [PATCH 10/18] refine codes --- include/onnxruntime/core/graph/constants.h | 6 +- onnxruntime/core/graph/graph_viewer.cc | 64 +++++++++++++++---- .../core/optimizer/memory_optimizer/common.h | 2 +- .../memory_optimizer/memory_insight.cc | 46 +++++++++++++ .../memory_optimizer/memory_insight.h | 10 +++ .../memory_optimizer/memory_optimizer.cc | 35 ++++++++++ 6 files changed, 149 insertions(+), 14 deletions(-) diff --git a/include/onnxruntime/core/graph/constants.h b/include/onnxruntime/core/graph/constants.h index b4f6617734e53..0a06a350bb774 100644 --- a/include/onnxruntime/core/graph/constants.h +++ b/include/onnxruntime/core/graph/constants.h @@ -55,8 +55,10 @@ constexpr const char* kAzureExecutionProvider = "AzureExecutionProvider"; constexpr const char* kExecutionProviderSharedLibraryPath = "shared_lib_path"; constexpr const char* kExecutionProviderSharedLibraryEntry = "provider_factory_entry_point"; - -// For memory optimizer's priority based graph topology sorting. +#ifdef ENABLE_TRAINING +// For priority based graph topology sorting. +constexpr const char* kBackwardNodeAttributeName = "__backwardpass"; constexpr const char* kRecomputeNodeCriticalPathImpact = "__recompute_critical_path_impact"; +#endif } // namespace onnxruntime diff --git a/onnxruntime/core/graph/graph_viewer.cc b/onnxruntime/core/graph/graph_viewer.cc index f134a3f79434a..9945baeefa09d 100644 --- a/onnxruntime/core/graph/graph_viewer.cc +++ b/onnxruntime/core/graph/graph_viewer.cc @@ -41,18 +41,60 @@ struct PriorityNodeCompare { #ifdef ENABLE_TRAINING - // nodes of bigger impact pass will be output first - const auto& n1_attrs = n1->GetAttributes(); - const auto& n2_attrs = n2->GetAttributes(); - int64_t n1_impact = (n1_attrs.find(kRecomputeNodeCriticalPathImpact) != n1_attrs.cend()) - ? 
static_cast(n1_attrs.at(kRecomputeNodeCriticalPathImpact).i()) - : -1; - int64_t n2_impact = (n2_attrs.find(kRecomputeNodeCriticalPathImpact) != n2_attrs.cend()) - ? static_cast(n2_attrs.at(kRecomputeNodeCriticalPathImpact).i()) - : -1; - if (n1_impact != -1 && n2_impact != -1) { - return n2_impact > n1_impact; + // Sorting factors for training scenarios. + if (n1_priority == static_cast(ExecutionPriority::DEFAULT)) { + // If both nodes are normal, prioritize outputting the forward pass node. + // + // Note 1: This preference arises from producer-consumer node pairs not separated by "YieldOp". + // The producer (forward pass, contributing to YieldOp inputs) and consumer (backward pass, + // used for gradient computation) should output in forward order to save memory. + // + // Note 2: MemoryOptimizer marks nodes as forward by backtracking from YieldOp's inputs. + // Nodes reached by this backtracking, identified through their inputs, are tagged as forward. + // + // The nodes of forward pass will be output first + auto n1_attrs = n1->GetAttributes(); + auto n2_attrs = n2->GetAttributes(); + int64_t n1_is_forward = static_cast(n1_attrs.find(kBackwardNodeAttributeName) == n1_attrs.cend()) || + (n1_attrs.at(kBackwardNodeAttributeName).i() + 1) % 2; + int64_t n2_is_forward = static_cast(n2_attrs.find(kBackwardNodeAttributeName) == n2_attrs.cend()) || + (n2_attrs.at(kBackwardNodeAttributeName).i() + 1) % 2; + if (n1_is_forward != n2_is_forward) { + return n2_is_forward > n1_is_forward; + } + } else if (n1_priority == static_cast(ExecutionPriority::LOCAL_LOW)) { + // If both are low priority nodes, we prefer to output nodes with bigger impact first. + // Only recompute scenarios will set the critical path impact attribute. + // + // Note 1: Importance of Critical Path Impact in Topological Sorting + // In recompute scenarios, it's crucial to identify which node to execute to unblock the + // critical path. This ensures nodes in the critical path are executed without delay. + // For more details, refer to MemoryOptimizer's implementation. + // + // Note 2: Defining Critical Path Impact + // Critical path impact is a value set during MemoryOptimizer's operation to prioritize + // node execution. It's calculated based on the topological order of nodes and their + // dependencies, ensuring timely execution of critical nodes. For more details, refer + // to MemoryOptimizer's implementation. + // + // Note 3: This trick is not necessarily bound to LOCAL_LOW priority nodes, but we are using it for + // recompute in MemoryOptimizer, so we add the check there. Feel free to revisit the check if it is + // useful for other priorities. + // + // The nodes with bigger impact will be output first + const auto& n1_attrs = n1->GetAttributes(); + const auto& n2_attrs = n2->GetAttributes(); + int64_t n1_impact = (n1_attrs.find(kRecomputeNodeCriticalPathImpact) != n1_attrs.cend()) + ? static_cast(n1_attrs.at(kRecomputeNodeCriticalPathImpact).i()) + : -1; + int64_t n2_impact = (n2_attrs.find(kRecomputeNodeCriticalPathImpact) != n2_attrs.cend()) + ? 
static_cast(n2_attrs.at(kRecomputeNodeCriticalPathImpact).i()) + : -1; + if (n1_impact != -1 && n2_impact != -1) { + return n2_impact > n1_impact; + } } + #endif // otherwise, nodes with lower index will be output first diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/common.h b/orttraining/orttraining/core/optimizer/memory_optimizer/common.h index 268ed84f7a85f..560cd88f18265 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/common.h +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/common.h @@ -18,7 +18,7 @@ namespace onnxruntime::optimizer::memory_optimizer { // Uncomment for debugging Memory optimizer (MO). -// #define MO_NEED_LOG_DEBUG_INFO 1 +#define MO_NEED_LOG_DEBUG_INFO 1 #ifndef MO_LOG_DEBUG_INFO #ifdef MO_NEED_LOG_DEBUG_INFO diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.cc b/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.cc index 3ec784bfb884d..3d0fa942fd2d4 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.cc +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.cc @@ -171,6 +171,52 @@ Status GetStashedActivationCandidates(const GraphViewer& graph_viewer, return Status::OK(); } +Status ResetNodeBackwardPassAttribute(Graph& graph, bool& modified) { + // Find the YieldOp node. + Node* yield_op_node = nullptr; + for (auto& node : graph.Nodes()) { + if (node.OpType() == "YieldOp") { + yield_op_node = &node; + break; + } + } + + if (yield_op_node == nullptr) { + return Status::OK(); + } + + // Reverse BFS from YieldOp to find all "forward" nodes. + std::vector fw_nodes; + std::vector end_nodes{yield_op_node}; + graph.ReverseDFSFrom( + end_nodes, + nullptr, + [&fw_nodes](const Node* n) { + fw_nodes.push_back(n); + }, + nullptr); + + // Set the attribute to true for all backward nodes. + for (auto& node : graph.Nodes()) { + if (std::find(fw_nodes.begin(), fw_nodes.end(), &node) == fw_nodes.end()) { + auto& attrs = node.GetAttributes(); + if (attrs.count(kBackwardNodeAttributeName)) { + continue; + } + node.AddAttribute(kBackwardNodeAttributeName, static_cast(1)); + modified = true; + } else { + auto& attrs = node.GetAttributes(); + if (attrs.count(kBackwardNodeAttributeName)) { + node.ClearAttribute(kBackwardNodeAttributeName); + modified = true; + } + } + } + + return Status::OK(); +} + Status FindORTModuleMemoryOpportunity(const GraphViewer& graph_viewer, const ProbeConfig& probe_config, const logging::Logger& logger, diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.h b/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.h index ca1df0633eb8f..3f0a1a9a96f88 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.h +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.h @@ -57,6 +57,16 @@ class MemoryRecord { int freq = 0; }; +/** + * @brief Reset `__backwardpass` attribute for all backward nodes in the graph. + * `__backwardpass` is used by Priority-Based topology sorting. + * + * @param graph To be scanned and modified. + * @param modified Whether the graph is modified. + * @return Status + */ +Status ResetNodeBackwardPassAttribute(Graph& graph, bool& modified); + /** * @brief Iterate the graph and find all possible memory optimization opportunities for related nodes. 
* diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/memory_optimizer.cc b/orttraining/orttraining/core/optimizer/memory_optimizer/memory_optimizer.cc index 7d33332062d82..2d3dd495cfa11 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/memory_optimizer.cc +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/memory_optimizer.cc @@ -146,6 +146,9 @@ bool MemoryOptimizer::ModifyGraph(Graph& graph, Status MemoryOptimizer::ApplyImpl(Graph& graph, bool& modified, int /*graph_level*/, const logging::Logger& logger) const { + // Reset the backward pass attribute for all nodes. + ORT_RETURN_IF_ERROR(optimizer::memory_optimizer::ResetNodeBackwardPassAttribute(graph, modified)); + LOGS(logger, VERBOSE) << "Memory optimization config: " << optimizer_config_ << ", probe level: " << static_cast(recompute_probe_config_.probe_level) << ", enable_transformer_layer_as_boundary:" @@ -215,6 +218,38 @@ Status MemoryOptimizer::ApplyImpl(Graph& graph, bool& modified, int /*graph_leve } if (recomputed_node_count > 0) { + // Note 1: Critical Path Impact in Priority-Based Topological Sort + // + // Setting and comparing critical path impact is essential in scenarios where all nodes in the priority queue + // are of low priority. This comparison helps determine which recompute node to select to unblock the backward + // critical path. Consider a scenario where: + // - A recompute node subgraph, NodeSubgraph-L, exists within transformer layer N (e.g., NodeSubgraph-5 in + // layer 5, NodeSubgraph-3 in layer 3). + // - Node-A-IN-5 within NodeSubgraph-5 depends on Node-B-IN-0 within NodeSubgraph-0. + // - The priority queue contains nodes from NodeSubgraph-0 to NodeSubgraph-5. + // In MemoryOptimizer recompute scenarios, we append nodes starting from NodeSubgraph-5 down to NodeSubgraph-0. + // Relying solely on node index for comparison could lead to: + // 1) Sequential output of nodes in NodeSubgraph-5 (sorted by ascending node index within the subgraph). + // 2) Blocking of Node-A-IN-5's execution until Node-B-IN-0 is executed. + // 3) Sequential output of nodes from NodeSubgraph-4 to NodeSubgraph-1. + // 4) Execution of NodeSubgraph-0 nodes, allowing Node-A-IN-5 and subsequent NodeSubgraph-5 nodes to execute. + // 5) Execution of remaining NodeSubgraph-0 nodes. + // + // This process can significantly delay the execution of Node-A-IN-5, blocking other NodeSubgraph-5 nodes. + // Since NodeSubgraph-5 nodes are on the critical path, triggering their dependencies timely is crucial to + // ensure their execution as early as possible, ahead of other layers. This necessity led to the introduction + // of critical path impact. + // + // Note 2: Defining Critical Path Impact + // Critical path impact is a metric representing a node's influence on the critical path. It is determined + // during MemoryOptimizer's operation as follows: + // 1) Sort graphs without recompute optimization to establish a baseline topological order. + // 2) Apply recompute optimization. + // 3) Identify recompute boundary nodes (recompute nodes not consumed by others). + // 4) For each boundary node, calculate the minimum topological order of all output nodes. + // The minimum value indicates the earliest need for the recompute node's execution. + // We assign std::numeric_limits::max() - min_topological_order as the critical path impact. + // 5) For other recompute nodes, assign the maximum critical path impact of their output nodes. 
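Purely as an illustration of the impact value defined in Note 2 above (this is not part of the patch), here is a minimal sketch assuming a simplified list of consumer topological orders; the helper name and types are hypothetical:

    #include <algorithm>
    #include <cstddef>
    #include <cstdint>
    #include <limits>
    #include <vector>

    // Impact of a boundary recompute node, given the topological orders of its
    // non-recompute consumers: the smaller the earliest consumer's order, the
    // larger the impact, so the priority-based sort unblocks the node sooner.
    std::int64_t BoundaryRecomputeImpact(const std::vector<std::size_t>& consumer_topo_orders) {
      std::int64_t max_impact = 0;
      for (std::size_t order : consumer_topo_orders) {
        max_impact = std::max<std::int64_t>(
            max_impact,
            std::numeric_limits<std::int64_t>::max() - static_cast<std::int64_t>(order));
      }
      return max_impact;
    }

With this definition, BoundaryRecomputeImpact({120, 37}) is larger than BoundaryRecomputeImpact({500}), so the recompute node whose earliest consumer sits at order 37 is scheduled first, matching step 4) and 5) above.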
ORT_ENFORCE(optimizer::memory_optimizer::SetCriticalPathImpact( graph, node_index_to_its_order_in_topological_sort_map), "Failed to set critical path impact attribute."); From 2ee7caa738d925f94587c517e0cb738fe7776f7f Mon Sep 17 00:00:00 2001 From: "Peng Wang(AI FWK)" Date: Fri, 29 Mar 2024 02:54:43 -0700 Subject: [PATCH 11/18] fix --- .../memory_optimizer/recompute_analysis.cc | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc b/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc index 2dfcc8d6207be..336781bd382b5 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc @@ -207,17 +207,17 @@ const InlinedHashMap& GetAllowedRecompu }, }, { - utils::GetFullQualifiedOpName("Gather", kOnnxDomain), + utils::GetFullQualifiedOpName("FlattenAndUnpad", kMSDomain), { {1, {1}}, // ignore the indices - {11, {1}}, - {13, {1}}, }, }, { - utils::GetFullQualifiedOpName("FlattenAndUnpad", kMSDomain), + utils::GetFullQualifiedOpName("Gather", kOnnxDomain), { {1, {1}}, // ignore the indices + {11, {1}}, + {13, {1}}, }, }, { @@ -232,6 +232,17 @@ const InlinedHashMap& GetAllowedRecompu {1, {}}, }, }, + { + utils::GetFullQualifiedOpName("Gemm", kOnnxDomain), + { + {1, {}}, + {6, {}}, + {7, {}}, + {9, {}}, + {11, {}}, + {13, {}}, + }, + }, { utils::GetFullQualifiedOpName("Less", kOnnxDomain), { From 35183bf34a8786b09b8390aa2316b77ecd3ad4ad Mon Sep 17 00:00:00 2001 From: "Peng Wang(AI FWK)" Date: Fri, 29 Mar 2024 07:09:07 -0700 Subject: [PATCH 12/18] add NonZero --- .../core/optimizer/memory_optimizer/recompute_analysis.cc | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc b/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc index 86379620e775f..a957c42ababd6 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc @@ -270,6 +270,13 @@ const InlinedHashMap& GetAllowedRecompu {13, {}}, }, }, + { + utils::GetFullQualifiedOpName("NonZero", kOnnxDomain), + { + {9, {}}, + {13, {}}, + }, + }, { utils::GetFullQualifiedOpName("PadAndUnflatten", kMSDomain), { From c33cb8380f989f919320bec88daccb4bf8917506 Mon Sep 17 00:00:00 2001 From: "Peng Wang(AI FWK)" Date: Mon, 1 Apr 2024 22:48:46 -0700 Subject: [PATCH 13/18] cast propogation && gemm transpose fuse --- .../core/optimizer/gemm_transpose_fusion.cc | 1 + .../core/optimizer/graph_transformer_utils.cc | 6 ++++++ onnxruntime/core/optimizer/propagate_cast_ops.cc | 15 ++++++++++++++- 3 files changed, 21 insertions(+), 1 deletion(-) diff --git a/onnxruntime/core/optimizer/gemm_transpose_fusion.cc b/onnxruntime/core/optimizer/gemm_transpose_fusion.cc index a52517d23db86..613b4bb59ea7a 100644 --- a/onnxruntime/core/optimizer/gemm_transpose_fusion.cc +++ b/onnxruntime/core/optimizer/gemm_transpose_fusion.cc @@ -86,6 +86,7 @@ Status GemmTransposeFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& m new_gemm_node.AddAttribute("transB", static_cast(transB)); new_gemm_node.AddAttribute("alpha", gemm_node.GetAttributes().at("alpha").f()); new_gemm_node.AddAttribute("beta", gemm_node.GetAttributes().at("beta").f()); + new_gemm_node.SetExecutionProviderType(gemm_node.GetExecutionProviderType()); 
graph_utils::FinalizeNodeFusion(graph, nodes_to_remove, new_gemm_node); diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index 63612c47f9c56..7e977855c357e 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -136,6 +136,7 @@ InlinedVector> GenerateRewriteRules( break; case TransformerLevel::Level2: + rules.push_back(std::make_unique()); // No level2 rules available today break; @@ -251,6 +252,11 @@ InlinedVector> GenerateTransformers( } break; case TransformerLevel::Level2: { + auto rule_transformer = GenerateRuleBasedGraphTransformer(level, rules_and_transformers_to_disable, {}); + if (rule_transformer != nullptr) { + transformers.emplace_back(std::move(rule_transformer)); + } + // we run TransposeOptimizer again in Level2 for some CPU EP specific optimizations that can only be // applied once nodes are assigned to the CPU EP (which happens between level 1 and level 2). transformers.emplace_back(std::make_unique(std::move(cpu_allocator), kCpuExecutionProvider)); diff --git a/onnxruntime/core/optimizer/propagate_cast_ops.cc b/onnxruntime/core/optimizer/propagate_cast_ops.cc index e4f34e066851f..b2a78781d9589 100644 --- a/onnxruntime/core/optimizer/propagate_cast_ops.cc +++ b/onnxruntime/core/optimizer/propagate_cast_ops.cc @@ -171,7 +171,20 @@ static bool IsFP16Allow(const Node* node, size_t level, const FP16AllowOps& fp16 using OpsSetType = InlinedHashSet; static const OpsSetType level1_fp16_allow_set = - {"Expand", "Transpose", "Relu", "Reshape", "Split", "Tanh", "Squeeze", "Unsqueeze", "Gelu"}; + { + /* Layout change operators */ + "Expand", + "PadAndUnflatten", + "Reshape", + "Split", + "Squeeze", + "Transpose", + "Unsqueeze", + /* Revisited element-wise operators. */ + "Gelu", + "Relu", + "Tanh", + }; static const OpsSetType level2_fp16_allow_set = { "Add", "BiasGelu", "Dropout", "FastGelu", "Gather", "LayerNormalization", "Where"}; From 3e03d1efcf4d049a1200b25561f661aaee0e3a2c Mon Sep 17 00:00:00 2001 From: "Peng Wang(AI FWK)" Date: Wed, 3 Apr 2024 00:38:53 -0700 Subject: [PATCH 14/18] timestamped priprity based ordering --- include/onnxruntime/core/graph/graph.h | 4 +- onnxruntime/core/graph/graph.cc | 19 ++++++--- onnxruntime/core/graph/graph_viewer.cc | 53 ++++++++++++++++++++++++-- 3 files changed, 66 insertions(+), 10 deletions(-) diff --git a/include/onnxruntime/core/graph/graph.h b/include/onnxruntime/core/graph/graph.h index b16d52dbdab68..d82d356210d74 100644 --- a/include/onnxruntime/core/graph/graph.h +++ b/include/onnxruntime/core/graph/graph.h @@ -54,6 +54,7 @@ class OpSignature; #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) class RuntimeOptimizationRecordContainer; +using TimeStampedEntry = std::pair; #endif namespace fbs { @@ -1078,12 +1079,13 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi const std::function& stop) const; #if !defined(ORT_MINIMAL_BUILD) + /** Performs topological sort with Kahn's algorithm on the graph/s. @param enter Visit function that will be invoked on a node when it is visited. @param comp Comparison function to stabilize the traversal order by making Node ordering deterministic. 
*/ void KahnsTopologicalSort(const std::function& enter, - const std::function& comp) const; + const std::function& comp) const; #endif diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index 305122c56b865..0ddc9a3c8f26c 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -1820,7 +1820,7 @@ void Graph::ReverseDFSFrom(gsl::span from, template struct VisitorPriorityQueue { - using ComparatorType = std::function; + using ComparatorType = std::function; std::list list_; const ComparatorType comparator_ = nullptr; VisitorPriorityQueue(const ComparatorType& comp) : comparator_(comp) {} @@ -1837,10 +1837,16 @@ struct VisitorPriorityQueue { #if !defined(ORT_MINIMAL_BUILD) void Graph::KahnsTopologicalSort(const std::function& enter, - const std::function& comp) const { + const std::function& comp) const { InlinedVector in_degree(MaxNodeIndex(), 0); InlinedVector topo_order; - VisitorPriorityQueue to_visit(comp); + VisitorPriorityQueue to_visit(comp); + + float time_stamp = 0.0f; + auto get_time_stamp = [&time_stamp]() -> float { + time_stamp += 1; + return time_stamp; + }; auto number_of_nodes = NumberOfNodes(); topo_order.reserve(number_of_nodes); @@ -1849,12 +1855,13 @@ void Graph::KahnsTopologicalSort(const std::function& enter, size_t input_edge_count = node.GetInputEdgesCount(); in_degree[node.Index()] = input_edge_count; if (input_edge_count == 0) { - to_visit.push(&node); + to_visit.push(&node, get_time_stamp()); } } while (!to_visit.empty()) { - const Node* current = to_visit.top(); + const TimeStampedEntry ts_entry = to_visit.top(); + cosnt Node* current = ts_entry.first; to_visit.pop(); if (!current) continue; @@ -1868,7 +1875,7 @@ void Graph::KahnsTopologicalSort(const std::function& enter, node_in_degree--; if (node_in_degree == 0) { - to_visit.push(&*node_it); + to_visit.push(&*node_it, get_time_stamp()); } } topo_order.push_back(current->Index()); diff --git a/onnxruntime/core/graph/graph_viewer.cc b/onnxruntime/core/graph/graph_viewer.cc index 9945baeefa09d..01f6b4f9d3814 100644 --- a/onnxruntime/core/graph/graph_viewer.cc +++ b/onnxruntime/core/graph/graph_viewer.cc @@ -24,7 +24,9 @@ struct PriorityNodeCompare { // Used for std::priority_queue // If return false, n1 will be output first // If return true, n2 will be output first - bool operator()(const Node* n1, const Node* n2) const { + bool operator()(const TimeStampedEntry& entry1, const TimeStampedEntry& entry2) const { + const Node* n1 = entry1.first; + const Node* n2 = entry2.first; // nodes in global high priority list will be output first const bool isN1HighPri = IsHighPri(n1); const bool isN2HighPri = IsHighPri(n2); @@ -95,6 +97,12 @@ struct PriorityNodeCompare { } } + const float ts1 = entry1.second; + const float ts2 = entry2.second; + if (ts1 != ts2) { + return ts1 > ts2; + } + #endif // otherwise, nodes with lower index will be output first @@ -175,11 +183,50 @@ GraphViewer::GraphViewer(const Graph& graph, const IndexedSubGraph* filter_info) } #endif #if !defined(ORT_MINIMAL_BUILD) + std::vector nodes_in_topological_order_with_priority; graph.KahnsTopologicalSort( - [this](const Node* n) { - nodes_in_topological_order_with_priority_.push_back(n->Index()); + [&nodes_in_topological_order_with_priority](const Node* n) { + nodes_in_topological_order_with_priority.push_back(n->Index()); }, PriorityNodeCompare()); + + // Tune the order a bit. 
+ // If a node is used by a consumer node, but the execution order of the consumer node is later than the producer node, + // we should move the producer node to right before the consumer node. In this case, the producer node will only be executed + // when needed and the memory can be released earlier. We do this in reversed topological order to hanlde the single-input-single_output + // node chains. + InlinedVector node_in_reversed_order; + node_in_reversed_order.reserve(nodes_in_topological_order_with_priority.size()); + for (auto it = nodes_in_topological_order_with_priority.rbegin(); it != nodes_in_topological_order_with_priority.rend(); ++it) { + const Node* node = graph_->GetNode(*it); + + if (node->GetOutputEdgesCount() != 1) { + // Don't need tune, just add it to the front of reversed_order. + node_in_reversed_order.push_back(node->Index()); + continue; + } + + // Handle the "High priority nodes" differently + // So, it may break the computation order also when recompute is enabled. + // But as ShapeInputMerge is introduced, there is much less chance to let recompute subgraph consumed by a normal Shape + // or Size node. (TODO: pengwa): observe from real models and see if we need to handle this case. + if (node->OpType() == "Shape" || node->OpType() == "Size") { + node_in_reversed_order.push_back(node->Index()); + continue; + } + + const Node* consumer = &(*(node->OutputNodesBegin())); + // Insert the consumer node right after the producer node. (Remember the order is reversed here). + auto it_consumer = std::find(node_in_reversed_order.begin(), node_in_reversed_order.end(), consumer->Index()); + ORT_ENFORCE(it_consumer != node_in_reversed_order.end()); + node_in_reversed_order.insert(it_consumer + 1, node->Index()); // Then node is inserted right after the consumer node. + } + + nodes_in_topological_order_with_priority_.insert( + nodes_in_topological_order_with_priority_.end(), + node_in_reversed_order.rbegin(), + node_in_reversed_order.rend()); + #endif if (filter_info_) { From 4b25ad61c21abe8e70cc951a4231245e94d2ff85 Mon Sep 17 00:00:00 2001 From: "Peng Wang(AI FWK)" Date: Wed, 3 Apr 2024 01:10:30 -0700 Subject: [PATCH 15/18] Revert "timestamped priprity based ordering" This reverts commit 3e03d1efcf4d049a1200b25561f661aaee0e3a2c. --- include/onnxruntime/core/graph/graph.h | 4 +- onnxruntime/core/graph/graph.cc | 19 +++------ onnxruntime/core/graph/graph_viewer.cc | 53 ++------------------------ 3 files changed, 10 insertions(+), 66 deletions(-) diff --git a/include/onnxruntime/core/graph/graph.h b/include/onnxruntime/core/graph/graph.h index d82d356210d74..b16d52dbdab68 100644 --- a/include/onnxruntime/core/graph/graph.h +++ b/include/onnxruntime/core/graph/graph.h @@ -54,7 +54,6 @@ class OpSignature; #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) class RuntimeOptimizationRecordContainer; -using TimeStampedEntry = std::pair; #endif namespace fbs { @@ -1079,13 +1078,12 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi const std::function& stop) const; #if !defined(ORT_MINIMAL_BUILD) - /** Performs topological sort with Kahn's algorithm on the graph/s. @param enter Visit function that will be invoked on a node when it is visited. @param comp Comparison function to stabilize the traversal order by making Node ordering deterministic. 
*/ void KahnsTopologicalSort(const std::function& enter, - const std::function& comp) const; + const std::function& comp) const; #endif diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index 0ddc9a3c8f26c..305122c56b865 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -1820,7 +1820,7 @@ void Graph::ReverseDFSFrom(gsl::span from, template struct VisitorPriorityQueue { - using ComparatorType = std::function; + using ComparatorType = std::function; std::list list_; const ComparatorType comparator_ = nullptr; VisitorPriorityQueue(const ComparatorType& comp) : comparator_(comp) {} @@ -1837,16 +1837,10 @@ struct VisitorPriorityQueue { #if !defined(ORT_MINIMAL_BUILD) void Graph::KahnsTopologicalSort(const std::function& enter, - const std::function& comp) const { + const std::function& comp) const { InlinedVector in_degree(MaxNodeIndex(), 0); InlinedVector topo_order; - VisitorPriorityQueue to_visit(comp); - - float time_stamp = 0.0f; - auto get_time_stamp = [&time_stamp]() -> float { - time_stamp += 1; - return time_stamp; - }; + VisitorPriorityQueue to_visit(comp); auto number_of_nodes = NumberOfNodes(); topo_order.reserve(number_of_nodes); @@ -1855,13 +1849,12 @@ void Graph::KahnsTopologicalSort(const std::function& enter, size_t input_edge_count = node.GetInputEdgesCount(); in_degree[node.Index()] = input_edge_count; if (input_edge_count == 0) { - to_visit.push(&node, get_time_stamp()); + to_visit.push(&node); } } while (!to_visit.empty()) { - const TimeStampedEntry ts_entry = to_visit.top(); - cosnt Node* current = ts_entry.first; + const Node* current = to_visit.top(); to_visit.pop(); if (!current) continue; @@ -1875,7 +1868,7 @@ void Graph::KahnsTopologicalSort(const std::function& enter, node_in_degree--; if (node_in_degree == 0) { - to_visit.push(&*node_it, get_time_stamp()); + to_visit.push(&*node_it); } } topo_order.push_back(current->Index()); diff --git a/onnxruntime/core/graph/graph_viewer.cc b/onnxruntime/core/graph/graph_viewer.cc index 01f6b4f9d3814..9945baeefa09d 100644 --- a/onnxruntime/core/graph/graph_viewer.cc +++ b/onnxruntime/core/graph/graph_viewer.cc @@ -24,9 +24,7 @@ struct PriorityNodeCompare { // Used for std::priority_queue // If return false, n1 will be output first // If return true, n2 will be output first - bool operator()(const TimeStampedEntry& entry1, const TimeStampedEntry& entry2) const { - const Node* n1 = entry1.first; - const Node* n2 = entry2.first; + bool operator()(const Node* n1, const Node* n2) const { // nodes in global high priority list will be output first const bool isN1HighPri = IsHighPri(n1); const bool isN2HighPri = IsHighPri(n2); @@ -97,12 +95,6 @@ struct PriorityNodeCompare { } } - const float ts1 = entry1.second; - const float ts2 = entry2.second; - if (ts1 != ts2) { - return ts1 > ts2; - } - #endif // otherwise, nodes with lower index will be output first @@ -183,50 +175,11 @@ GraphViewer::GraphViewer(const Graph& graph, const IndexedSubGraph* filter_info) } #endif #if !defined(ORT_MINIMAL_BUILD) - std::vector nodes_in_topological_order_with_priority; graph.KahnsTopologicalSort( - [&nodes_in_topological_order_with_priority](const Node* n) { - nodes_in_topological_order_with_priority.push_back(n->Index()); + [this](const Node* n) { + nodes_in_topological_order_with_priority_.push_back(n->Index()); }, PriorityNodeCompare()); - - // Tune the order a bit. 
- // If a node is used by a consumer node, but the execution order of the consumer node is later than the producer node, - // we should move the producer node to right before the consumer node. In this case, the producer node will only be executed - // when needed and the memory can be released earlier. We do this in reversed topological order to hanlde the single-input-single_output - // node chains. - InlinedVector node_in_reversed_order; - node_in_reversed_order.reserve(nodes_in_topological_order_with_priority.size()); - for (auto it = nodes_in_topological_order_with_priority.rbegin(); it != nodes_in_topological_order_with_priority.rend(); ++it) { - const Node* node = graph_->GetNode(*it); - - if (node->GetOutputEdgesCount() != 1) { - // Don't need tune, just add it to the front of reversed_order. - node_in_reversed_order.push_back(node->Index()); - continue; - } - - // Handle the "High priority nodes" differently - // So, it may break the computation order also when recompute is enabled. - // But as ShapeInputMerge is introduced, there is much less chance to let recompute subgraph consumed by a normal Shape - // or Size node. (TODO: pengwa): observe from real models and see if we need to handle this case. - if (node->OpType() == "Shape" || node->OpType() == "Size") { - node_in_reversed_order.push_back(node->Index()); - continue; - } - - const Node* consumer = &(*(node->OutputNodesBegin())); - // Insert the consumer node right after the producer node. (Remember the order is reversed here). - auto it_consumer = std::find(node_in_reversed_order.begin(), node_in_reversed_order.end(), consumer->Index()); - ORT_ENFORCE(it_consumer != node_in_reversed_order.end()); - node_in_reversed_order.insert(it_consumer + 1, node->Index()); // Then node is inserted right after the consumer node. - } - - nodes_in_topological_order_with_priority_.insert( - nodes_in_topological_order_with_priority_.end(), - node_in_reversed_order.rbegin(), - node_in_reversed_order.rend()); #endif if (filter_info_) { From cabb44aca69560b9be0fa20fb97081a80b563310 Mon Sep 17 00:00:00 2001 From: "Peng Wang(AI FWK)" Date: Wed, 3 Apr 2024 01:14:41 -0700 Subject: [PATCH 16/18] Tune single-in-single_out-node-chain for delay execution especially for input leaf nodes --- onnxruntime/core/graph/graph_viewer.cc | 55 +++++++++++++++++++++++++- 1 file changed, 53 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/graph/graph_viewer.cc b/onnxruntime/core/graph/graph_viewer.cc index 9945baeefa09d..59ef5be094ee6 100644 --- a/onnxruntime/core/graph/graph_viewer.cc +++ b/onnxruntime/core/graph/graph_viewer.cc @@ -175,11 +175,62 @@ GraphViewer::GraphViewer(const Graph& graph, const IndexedSubGraph* filter_info) } #endif #if !defined(ORT_MINIMAL_BUILD) + std::vector nodes_in_topological_order_with_priority; graph.KahnsTopologicalSort( - [this](const Node* n) { - nodes_in_topological_order_with_priority_.push_back(n->Index()); + [&nodes_in_topological_order_with_priority](const Node* n) { + nodes_in_topological_order_with_priority.push_back(n->Index()); }, PriorityNodeCompare()); + + // Tune the order a bit. + // If a node is used by a consumer node, but the execution order of the consumer node is later than the producer node, + // we should move the producer node to right before the consumer node. In this case, the producer node will only be executed + // when needed and the memory can be released earlier. We do this in reversed topological order to handle the single-input-single-output + // node chains. 
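Purely as an illustration of the reordering idea described in the comment just above (the loop that follows in this hunk implements the same effect by walking the order in reverse), here is a minimal standalone sketch expressed in forward order; the function name and integer node ids are hypothetical, not ONNX Runtime APIs:

    #include <algorithm>
    #include <vector>

    // Remove the producer from the order, then re-insert it immediately before its
    // only consumer, so its output is materialized as late as possible and the
    // buffer it holds can be released sooner.
    std::vector<int> DelayProducer(std::vector<int> order, int producer, int consumer) {
      order.erase(std::remove(order.begin(), order.end(), producer), order.end());
      auto it = std::find(order.begin(), order.end(), consumer);
      order.insert(it, producer);
      return order;
    }

For example, DelayProducer({1, 2, 3, 4}, /*producer*/ 1, /*consumer*/ 4) yields {2, 3, 1, 4}: producer 1 now runs right before its single consumer 4 instead of at the very beginning.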
+ InlinedVector node_in_reversed_order; + node_in_reversed_order.reserve(nodes_in_topological_order_with_priority.size()); + for (auto it = nodes_in_topological_order_with_priority.rbegin(); it != nodes_in_topological_order_with_priority.rend(); ++it) { + const Node* node = graph_->GetNode(*it); + + if (node->GetOutputEdgesCount() != 1) { + // Don't need tune, just add it to the front of reversed_order. + node_in_reversed_order.push_back(node->Index()); + continue; + } + + // Handle the "High priority nodes" differently + // So, it may break the computation order also when recompute is enabled. + // But as ShapeInputMerge is introduced, there is much less chance to let recompute subgraph consumed by a normal Shape + // or Size node. (TODO: pengwa): observe from real models and see if we need to handle this case. + if (node->OpType() == "Shape" || node->OpType() == "Size") { + node_in_reversed_order.push_back(node->Index()); + continue; + } + + // If the node is a PythonOpGrad whose func_name contains "onnxruntime.training.ortmodule._mem_efficient_grad_mgmt.ParamRetrievalFunction", + // we skip the delay tuning for it, to make sure the weight accumulation nodes are executed as early as possible (free buffer and unblock subsequent CPU work). + if (node->OpType() == "PythonOpGrad") { + const auto& attrs = node->GetAttributes(); + auto it = attrs.find("func_name"); + ORT_ENFORCE(it != attrs.end()); + if (it->second.s().find("onnxruntime.training.ortmodule._mem_efficient_grad_mgmt.ParamRetrievalFunction") != std::string::npos) { + std::cout << "Skip node " << node->Name() << " with func_name " << it->second.s() << std::endl; + node_in_reversed_order.push_back(node->Index()); + continue; + } + } + + const Node* consumer = &(*(node->OutputNodesBegin())); + // Insert the consumer node right after the producer node. (Remember the order is reversed here). + auto it_consumer = std::find(node_in_reversed_order.begin(), node_in_reversed_order.end(), consumer->Index()); + ORT_ENFORCE(it_consumer != node_in_reversed_order.end()); + node_in_reversed_order.insert(it_consumer + 1, node->Index()); // Then node is inserted right after the consumer node. 
+ } + + nodes_in_topological_order_with_priority_.insert( + nodes_in_topological_order_with_priority_.end(), + node_in_reversed_order.rbegin(), + node_in_reversed_order.rend()); #endif if (filter_info_) { From 1f197dc8d522dfcf94b5dd8dbea950d7e3e22159 Mon Sep 17 00:00:00 2001 From: "Peng Wang(AI FWK)" Date: Wed, 3 Apr 2024 01:25:23 -0700 Subject: [PATCH 17/18] remove log --- onnxruntime/core/graph/graph_viewer.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/onnxruntime/core/graph/graph_viewer.cc b/onnxruntime/core/graph/graph_viewer.cc index 59ef5be094ee6..4f4d5851c99db 100644 --- a/onnxruntime/core/graph/graph_viewer.cc +++ b/onnxruntime/core/graph/graph_viewer.cc @@ -214,7 +214,6 @@ GraphViewer::GraphViewer(const Graph& graph, const IndexedSubGraph* filter_info) auto it = attrs.find("func_name"); ORT_ENFORCE(it != attrs.end()); if (it->second.s().find("onnxruntime.training.ortmodule._mem_efficient_grad_mgmt.ParamRetrievalFunction") != std::string::npos) { - std::cout << "Skip node " << node->Name() << " with func_name " << it->second.s() << std::endl; node_in_reversed_order.push_back(node->Index()); continue; } From 150cd6e7c9fc338162fcae38341465892f8565c8 Mon Sep 17 00:00:00 2001 From: "Peng Wang(AI FWK)" Date: Wed, 3 Apr 2024 01:36:49 -0700 Subject: [PATCH 18/18] disable logging --- .../orttraining/core/optimizer/memory_optimizer/common.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/common.h b/orttraining/orttraining/core/optimizer/memory_optimizer/common.h index 560cd88f18265..268ed84f7a85f 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/common.h +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/common.h @@ -18,7 +18,7 @@ namespace onnxruntime::optimizer::memory_optimizer { // Uncomment for debugging Memory optimizer (MO). -#define MO_NEED_LOG_DEBUG_INFO 1 +// #define MO_NEED_LOG_DEBUG_INFO 1 #ifndef MO_LOG_DEBUG_INFO #ifdef MO_NEED_LOG_DEBUG_INFO