diff --git a/include/onnxruntime/core/graph/constants.h b/include/onnxruntime/core/graph/constants.h
index 8e04050d089a0..0a06a350bb774 100644
--- a/include/onnxruntime/core/graph/constants.h
+++ b/include/onnxruntime/core/graph/constants.h
@@ -55,8 +55,10 @@ constexpr const char* kAzureExecutionProvider = "AzureExecutionProvider";
 constexpr const char* kExecutionProviderSharedLibraryPath = "shared_lib_path";
 constexpr const char* kExecutionProviderSharedLibraryEntry = "provider_factory_entry_point";
-
-// For Priority based graph topology sorting.
+#ifdef ENABLE_TRAINING
+// For priority based graph topology sorting.
 constexpr const char* kBackwardNodeAttributeName = "__backwardpass";
+constexpr const char* kRecomputeNodeCriticalPathImpact = "__recompute_critical_path_impact";
+#endif

 }  // namespace onnxruntime
diff --git a/onnxruntime/core/graph/graph_viewer.cc b/onnxruntime/core/graph/graph_viewer.cc
index 119d420066a84..4f4d5851c99db 100644
--- a/onnxruntime/core/graph/graph_viewer.cc
+++ b/onnxruntime/core/graph/graph_viewer.cc
@@ -40,16 +40,61 @@ struct PriorityNodeCompare {
     }

 #ifdef ENABLE_TRAINING
-    // nodes of forward pass will be output first
-    auto n1_attrs = n1->GetAttributes();
-    auto n2_attrs = n2->GetAttributes();
-    int64_t n1_is_forward = static_cast<int64_t>(n1_attrs.find(kBackwardNodeAttributeName) == n1_attrs.cend()) ||
-                            (n1_attrs.at(kBackwardNodeAttributeName).i() + 1) % 2;
-    int64_t n2_is_forward = static_cast<int64_t>(n2_attrs.find(kBackwardNodeAttributeName) == n2_attrs.cend()) ||
-                            (n2_attrs.at(kBackwardNodeAttributeName).i() + 1) % 2;
-    if (n1_is_forward != n2_is_forward) {
-      return n2_is_forward > n1_is_forward;
+
+    // Sorting factors for training scenarios.
+    if (n1_priority == static_cast<int>(ExecutionPriority::DEFAULT)) {
+      // If both nodes have normal priority, prefer to output the forward-pass node first.
+      //
+      // Note 1: This preference arises from producer-consumer node pairs not separated by "YieldOp".
+      // The producer (forward pass, contributing to YieldOp inputs) and consumer (backward pass,
+      // used for gradient computation) should be output in forward order to save memory.
+      //
+      // Note 2: MemoryOptimizer marks nodes as forward by backtracking from YieldOp's inputs.
+      // Nodes reached by this backtracking, identified through their inputs, are tagged as forward.
+      //
+      // The nodes of the forward pass will be output first.
+      auto n1_attrs = n1->GetAttributes();
+      auto n2_attrs = n2->GetAttributes();
+      int64_t n1_is_forward = static_cast<int64_t>(n1_attrs.find(kBackwardNodeAttributeName) == n1_attrs.cend()) ||
+                              (n1_attrs.at(kBackwardNodeAttributeName).i() + 1) % 2;
+      int64_t n2_is_forward = static_cast<int64_t>(n2_attrs.find(kBackwardNodeAttributeName) == n2_attrs.cend()) ||
+                              (n2_attrs.at(kBackwardNodeAttributeName).i() + 1) % 2;
+      if (n1_is_forward != n2_is_forward) {
+        return n2_is_forward > n1_is_forward;
+      }
+    } else if (n1_priority == static_cast<int>(ExecutionPriority::LOCAL_LOW)) {
+      // If both are low-priority nodes, we prefer to output the node with the bigger impact first.
+      // Only recompute scenarios set the critical path impact attribute.
+      //
+      // Note 1: Importance of Critical Path Impact in Topological Sorting
+      // In recompute scenarios, it's crucial to identify which node to execute to unblock the
+      // critical path. This ensures nodes on the critical path are executed without delay.
+      // For more details, refer to MemoryOptimizer's implementation.
+      //
+      // Note 2: Defining Critical Path Impact
+      // Critical path impact is a value set during MemoryOptimizer's operation to prioritize
+      // node execution. It's calculated based on the topological order of nodes and their
+      // dependencies, ensuring timely execution of critical nodes. For more details, refer
+      // to MemoryOptimizer's implementation.
+      //
+      // Note 3: This trick is not necessarily bound to LOCAL_LOW priority nodes, but we are using it for
+      // recompute in MemoryOptimizer, so we add the check there. Feel free to revisit the check if it is
+      // useful for other priorities.
+      //
+      // The node with the bigger impact will be output first.
+      const auto& n1_attrs = n1->GetAttributes();
+      const auto& n2_attrs = n2->GetAttributes();
+      int64_t n1_impact = (n1_attrs.find(kRecomputeNodeCriticalPathImpact) != n1_attrs.cend())
+                              ? static_cast<int64_t>(n1_attrs.at(kRecomputeNodeCriticalPathImpact).i())
+                              : -1;
+      int64_t n2_impact = (n2_attrs.find(kRecomputeNodeCriticalPathImpact) != n2_attrs.cend())
+                              ? static_cast<int64_t>(n2_attrs.at(kRecomputeNodeCriticalPathImpact).i())
+                              : -1;
+      if (n1_impact != -1 && n2_impact != -1) {
+        return n2_impact > n1_impact;
+      }
     }
+
 #endif

     // otherwise, nodes with lower index will be output first
@@ -130,11 +175,61 @@ GraphViewer::GraphViewer(const Graph& graph, const IndexedSubGraph* filter_info)
   }
 #endif
 #if !defined(ORT_MINIMAL_BUILD)
+  std::vector<NodeIndex> nodes_in_topological_order_with_priority;
   graph.KahnsTopologicalSort(
-      [this](const Node* n) {
-        nodes_in_topological_order_with_priority_.push_back(n->Index());
+      [&nodes_in_topological_order_with_priority](const Node* n) {
+        nodes_in_topological_order_with_priority.push_back(n->Index());
       },
       PriorityNodeCompare());
+
+  // Tune the order a bit.
+  // If a node is used by a consumer node but the consumer's execution order is later than the producer's,
+  // we move the producer node to right before the consumer node. This way the producer node is only executed
+  // when needed, and its memory can be released earlier. We do this in reversed topological order to handle
+  // single-input-single-output node chains.
+  InlinedVector<NodeIndex> node_in_reversed_order;
+  node_in_reversed_order.reserve(nodes_in_topological_order_with_priority.size());
+  for (auto it = nodes_in_topological_order_with_priority.rbegin(); it != nodes_in_topological_order_with_priority.rend(); ++it) {
+    const Node* node = graph_->GetNode(*it);
+
+    if (node->GetOutputEdgesCount() != 1) {
+      // No tuning needed; just append it to node_in_reversed_order.
+      node_in_reversed_order.push_back(node->Index());
+      continue;
+    }
+
+    // Handle the "high priority" nodes (Shape/Size) differently and keep them in place.
+    // This may break the intended computation order when recompute is enabled, but as ShapeInputMerge is
+    // introduced, there is much less chance that a recompute subgraph is consumed by a normal Shape
+    // or Size node. (TODO: pengwa): observe real models and see whether we need to handle this case.
+    if (node->OpType() == "Shape" || node->OpType() == "Size") {
+      node_in_reversed_order.push_back(node->Index());
+      continue;
+    }
+
+    // If the node is a PythonOpGrad whose "func_name" attribute refers to
+    // "onnxruntime.training.ortmodule._mem_efficient_grad_mgmt.ParamRetrievalFunction", skip the tuning,
+    // to make sure the weight accumulation nodes are executed as early as possible (freeing buffers and
+    // unblocking subsequent CPU work).
+    if (node->OpType() == "PythonOpGrad") {
+      const auto& attrs = node->GetAttributes();
+      auto it = attrs.find("func_name");
+      ORT_ENFORCE(it != attrs.end());
+      if (it->second.s().find("onnxruntime.training.ortmodule._mem_efficient_grad_mgmt.ParamRetrievalFunction") != std::string::npos) {
+        node_in_reversed_order.push_back(node->Index());
+        continue;
+      }
+    }
+
+    const Node* consumer = &(*(node->OutputNodesBegin()));
+    // Insert the producer node right after its consumer node (remember the order is reversed here),
+    // so it ends up right before the consumer in forward order.
+    auto it_consumer = std::find(node_in_reversed_order.begin(), node_in_reversed_order.end(), consumer->Index());
+    ORT_ENFORCE(it_consumer != node_in_reversed_order.end());
+    node_in_reversed_order.insert(it_consumer + 1, node->Index());
+  }
+
+  nodes_in_topological_order_with_priority_.insert(
+      nodes_in_topological_order_with_priority_.end(),
+      node_in_reversed_order.rbegin(),
+      node_in_reversed_order.rend());
 #endif

   if (filter_info_) {
diff --git a/onnxruntime/core/optimizer/gemm_transpose_fusion.cc b/onnxruntime/core/optimizer/gemm_transpose_fusion.cc
index a52517d23db86..613b4bb59ea7a 100644
--- a/onnxruntime/core/optimizer/gemm_transpose_fusion.cc
+++ b/onnxruntime/core/optimizer/gemm_transpose_fusion.cc
@@ -86,6 +86,7 @@ Status GemmTransposeFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& m
   new_gemm_node.AddAttribute("transB", static_cast<int64_t>(transB));
   new_gemm_node.AddAttribute("alpha", gemm_node.GetAttributes().at("alpha").f());
   new_gemm_node.AddAttribute("beta", gemm_node.GetAttributes().at("beta").f());
+  new_gemm_node.SetExecutionProviderType(gemm_node.GetExecutionProviderType());

   graph_utils::FinalizeNodeFusion(graph, nodes_to_remove, new_gemm_node);
diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc
index 63612c47f9c56..7e977855c357e 100644
--- a/onnxruntime/core/optimizer/graph_transformer_utils.cc
+++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc
@@ -136,6 +136,7 @@ InlinedVector<std::unique_ptr<RewriteRule>> GenerateRewriteRules(
       break;

     case TransformerLevel::Level2:
+      rules.push_back(std::make_unique<GemmTransposeFusion>());
       // No level2 rules available today
       break;
@@ -251,6 +252,11 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
       }
       break;
     case TransformerLevel::Level2: {
+      auto rule_transformer = GenerateRuleBasedGraphTransformer(level, rules_and_transformers_to_disable, {});
+      if (rule_transformer != nullptr) {
+        transformers.emplace_back(std::move(rule_transformer));
+      }
+
       // we run TransposeOptimizer again in Level2 for some CPU EP specific optimizations that can only be
       // applied once nodes are assigned to the CPU EP (which happens between level 1 and level 2).
       transformers.emplace_back(std::make_unique<TransposeOptimizer>(std::move(cpu_allocator), kCpuExecutionProvider));
diff --git a/onnxruntime/core/optimizer/propagate_cast_ops.cc b/onnxruntime/core/optimizer/propagate_cast_ops.cc
index e4f34e066851f..b2a78781d9589 100644
--- a/onnxruntime/core/optimizer/propagate_cast_ops.cc
+++ b/onnxruntime/core/optimizer/propagate_cast_ops.cc
@@ -171,7 +171,20 @@ static bool IsFP16Allow(const Node* node, size_t level, const FP16AllowOps& fp16
   using OpsSetType = InlinedHashSet<std::string>;
   static const OpsSetType level1_fp16_allow_set =
-      {"Expand", "Transpose", "Relu", "Reshape", "Split", "Tanh", "Squeeze", "Unsqueeze", "Gelu"};
+      {
+          /* Layout-change operators */
+          "Expand",
+          "PadAndUnflatten",
+          "Reshape",
+          "Split",
+          "Squeeze",
+          "Transpose",
+          "Unsqueeze",
+          /* Revisited element-wise operators */
+          "Gelu",
+          "Relu",
+          "Tanh",
+      };

   static const OpsSetType level2_fp16_allow_set =
       {
           "Add", "BiasGelu", "Dropout", "FastGelu", "Gather", "LayerNormalization", "Where"};
diff --git a/orttraining/orttraining/core/graph/recompute_graph_utils.h b/orttraining/orttraining/core/graph/recompute_graph_utils.h
index f4d7e88a072f5..1bf2e11cdd27d 100644
--- a/orttraining/orttraining/core/graph/recompute_graph_utils.h
+++ b/orttraining/orttraining/core/graph/recompute_graph_utils.h
@@ -6,8 +6,10 @@
 namespace onnxruntime {
 namespace graph_utils {

+constexpr const char* kRecomputeFlag = "_recompute";
+
 inline std::string RecomputeName(const std::string& name) {
-  return name + "_recompute";
+  return name + kRecomputeFlag;
 }

 }  // namespace graph_utils
diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/memory_optimizer.cc b/orttraining/orttraining/core/optimizer/memory_optimizer/memory_optimizer.cc
index ac619bdc390d3..2d3dd495cfa11 100644
--- a/orttraining/orttraining/core/optimizer/memory_optimizer/memory_optimizer.cc
+++ b/orttraining/orttraining/core/optimizer/memory_optimizer/memory_optimizer.cc
@@ -193,40 +193,7 @@ Status MemoryOptimizer::ApplyImpl(Graph& graph, bool& modified, int /*graph_leve
   // The reason we do reversed topological order is that we want the later layers' recompute nodes can be appended
   // earlier than the earlier layers, in this way, the execution order of later layers will be in front of the earlier
   // layers.
-  //
-  // Note 2: Here we use default typo order (which tries to BFS from the outputs,
-  // so the nearest node to graph output will be visited last). So in reversed default typo order,
-  // the neareast node to graph output will be visited first.
-  // Imagine there is a such subgraph
-  //            input1  input2  input3
-  //                \      |      /
-  //               multiple layers
-  //                       |
-  //                     node M
-  // labels-------|-----
-  //    \         |     |
-  //     node1    |     |
-  //       \      |     |
-  //      node2   /     |
-  //         \   /      |
-  //       node loss    /
-  //           |       /
-  //       YieldOp  node1_recompute
-  //           |    /
-  //            \  node2 recompute
-  //             \  /
-  //          node loss_grad
-  //               |
-  //        critical grad path
-  //
-  // In PriorityBased order, node1 will be visited first, so it's recompute node node1_recompute will be added
-  // at last because we do this following reversed topological order. Then node1_recompute node will have lowest
-  // priority to execute, as a result, if at this time, the queue to visit contains only recompute nodes, then
-  // node1_recompute will be run at last, affecting the backward critical path, which is not what we want.
-  // Current workaround is to use default order, which will execute node1_recompute earlier than other recompute nodes
-  // in this case.
-
-  const auto& node_ids = graph_viewer.GetNodesInTopologicalOrder(ExecutionOrder::DEFAULT);
+  const auto& node_ids = graph_viewer.GetNodesInTopologicalOrder(ExecutionOrder::PRIORITY_BASED);
   for (int i = static_cast<int>(node_ids.size()) - 1; i >= 0; --i) {
     Node* p_node = graph.GetNode(node_ids[i]);
     if (p_node == nullptr) {
@@ -251,6 +218,41 @@ Status MemoryOptimizer::ApplyImpl(Graph& graph, bool& modified, int /*graph_leve
   }

   if (recomputed_node_count > 0) {
+    // Note 1: Critical Path Impact in Priority-Based Topological Sort
+    //
+    // Setting and comparing critical path impact is essential in scenarios where all nodes in the priority queue
+    // are of low priority. This comparison helps determine which recompute node to select to unblock the backward
+    // critical path. Consider a scenario where:
+    // - A recompute node subgraph, NodeSubgraph-L, exists within transformer layer L (e.g., NodeSubgraph-5 in
+    //   layer 5, NodeSubgraph-3 in layer 3).
+    // - Node-A-IN-5 within NodeSubgraph-5 depends on Node-B-IN-0 within NodeSubgraph-0.
+    // - The priority queue contains nodes from NodeSubgraph-0 to NodeSubgraph-5.
+    // In MemoryOptimizer recompute scenarios, we append nodes starting from NodeSubgraph-5 down to NodeSubgraph-0.
+    // Relying solely on node index for comparison could lead to:
+    //   1) Sequential output of nodes in NodeSubgraph-5 (sorted by ascending node index within the subgraph).
+    //   2) Blocking of Node-A-IN-5's execution until Node-B-IN-0 is executed.
+    //   3) Sequential output of nodes from NodeSubgraph-4 to NodeSubgraph-1.
+    //   4) Execution of NodeSubgraph-0 nodes, allowing Node-A-IN-5 and subsequent NodeSubgraph-5 nodes to execute.
+    //   5) Execution of the remaining NodeSubgraph-0 nodes.
+    //
+    // This process can significantly delay the execution of Node-A-IN-5, blocking other NodeSubgraph-5 nodes.
+    // Since NodeSubgraph-5 nodes are on the critical path, triggering their dependencies in a timely manner is
+    // crucial to ensure they execute as early as possible, ahead of other layers. This necessity led to the
+    // introduction of critical path impact.
+    //
+    // Note 2: Defining Critical Path Impact
+    // Critical path impact is a metric representing a node's influence on the critical path. It is determined
+    // during MemoryOptimizer's operation as follows:
+    //   1) Sort the graph without recompute optimization to establish a baseline topological order.
+    //   2) Apply the recompute optimization.
+    //   3) Identify recompute boundary nodes (recompute nodes not consumed by other recompute nodes).
+    //   4) For each boundary node, calculate the minimum topological order of all its output nodes.
+    //      The minimum value indicates the earliest point at which the recompute node's execution is needed.
+    //      We assign std::numeric_limits<int64_t>::max() - min_topological_order as the critical path impact.
+    //   5) For other recompute nodes, assign the maximum critical path impact of their output nodes.
+    ORT_ENFORCE(optimizer::memory_optimizer::SetCriticalPathImpact(
+                    graph, node_index_to_its_order_in_topological_sort_map),
+                "Failed to set critical path impact attribute.");
     LOGS(logger, INFO) << "Total number of recomputed nodes: " << recomputed_node_count;
   }
@@ -319,7 +321,7 @@ Status MemoryOptimizer::CreateRecomputeGraph(Graph& graph,
     self_contained_outputs_map[output] = new_output_args.back();
   }

-  Node& recompute_node = graph.AddNode(node_to_duplicate->Name() + "_recompute",
+  Node& recompute_node = graph.AddNode(node_to_duplicate->Name() + graph_utils::kRecomputeFlag,
                                        node_to_duplicate->OpType(),
                                        "Recompute of " + node_to_duplicate->Name(),
                                        new_input_args,
diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc b/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc
index 37ac1c4950ecd..a957c42ababd6 100644
--- a/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc
+++ b/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc
@@ -11,6 +11,7 @@
 #include "orttraining/core/optimizer/memory_optimizer/common.h"
 #include "orttraining/core/optimizer/memory_optimizer/transformer_specific.h"
 #include "orttraining/core/optimizer/memory_optimizer/recompute_analysis.h"
+#include "orttraining/core/graph/recompute_graph_utils.h"
 #include "core/common/string_utils.h"
 #include "core/framework/data_types.h"
 #include "core/optimizer/utils.h"
@@ -144,6 +145,20 @@ const InlinedHashMap& GetAllowedRecompu
               {20, {0}},
           },
       },
+      {
+          utils::GetFullQualifiedOpName("Cos", kOnnxDomain),
+          {
+              {7, {}},
+          },
+      },
+      {
+          utils::GetFullQualifiedOpName("CumSum", kOnnxDomain),
+          {
+              // The axis input is trivial
+              {11, {1}},
+              {14, {1}},
+          },
+      },
       {
           utils::GetFullQualifiedOpName("Dropout", kOnnxDomain),
           {
@@ -162,27 +177,6 @@ const InlinedHashMap& GetAllowedRecompu
               {14, {}},
           },
       },
-      {
-          utils::GetFullQualifiedOpName("Expand", kOnnxDomain),
-          {
-              {8, {1}},  // Ignore the shape.
-              {13, {1}},
-          },
-      },
-      {
-          utils::GetFullQualifiedOpName("Cos", kOnnxDomain),
-          {
-              {7, {}},
-          },
-      },
-      {
-          utils::GetFullQualifiedOpName("CumSum", kOnnxDomain),
-          {
-              // The axis input is trivial
-              {11, {1}},
-              {14, {1}},
-          },
-      },
       {
           utils::GetFullQualifiedOpName("Einsum", kOnnxDomain),
           {
@@ -199,12 +193,25 @@ const InlinedHashMap& GetAllowedRecompu
               {19, {}},
           },
       },
+      {
+          utils::GetFullQualifiedOpName("Expand", kOnnxDomain),
+          {
+              {8, {1}},  // Ignore the shape.
+              {13, {1}},
+          },
+      },
       {
           utils::GetFullQualifiedOpName("FastGelu", kMSDomain),
           {
               {1, {}},
           },
       },
+      {
+          utils::GetFullQualifiedOpName("FlattenAndUnpad", kMSDomain),
+          {
+              {1, {1}},  // ignore the indices
+          },
+      },
       {
           utils::GetFullQualifiedOpName("Gather", kOnnxDomain),
           {
@@ -225,6 +232,17 @@ const InlinedHashMap& GetAllowedRecompu
               {1, {}},
           },
       },
+      {
+          utils::GetFullQualifiedOpName("Gemm", kOnnxDomain),
+          {
+              {1, {}},
+              {6, {}},
+              {7, {}},
+              {9, {}},
+              {11, {}},
+              {13, {}},
+          },
+      },
       {
           utils::GetFullQualifiedOpName("Less", kOnnxDomain),
           {
@@ -244,6 +262,27 @@ const InlinedHashMap& GetAllowedRecompu
               {14, {}},
           },
       },
+      {
+          utils::GetFullQualifiedOpName("Neg", kOnnxDomain),
+          {
+              {1, {}},
+              {6, {}},
+              {13, {}},
+          },
+      },
+      {
+          utils::GetFullQualifiedOpName("NonZero", kOnnxDomain),
+          {
+              {9, {}},
+              {13, {}},
+          },
+      },
+      {
+          utils::GetFullQualifiedOpName("PadAndUnflatten", kMSDomain),
+          {
+              {1, {1, 2}},  // ignore the indices and unflatten_dims
+          },
+      },
       {
           utils::GetFullQualifiedOpName("Range", kOnnxDomain),
           {
@@ -676,6 +715,20 @@ void NodesInTopoOrderToString(gsl::span<const Node* const> nodes_in_topological_
   }
 }

+/**
+ * @brief Check whether a node is a recompute node added by MemoryOptimizer or not.
+ */
+bool IsRecomputeNode(const Node* n) {
+  static size_t recompute_suffix_len = std::string(graph_utils::kRecomputeFlag).size();
+
+  std::string_view name = n->Name();
+  if (name.size() < recompute_suffix_len) {
+    return false;
+  }
+
+  return name.compare(name.size() - recompute_suffix_len, recompute_suffix_len, graph_utils::kRecomputeFlag) == 0;
+}
+
 }  // namespace

 Status ParseProbeConfigFromString(std::string_view recompute_probe_config, ProbeConfig& probe_config) {
@@ -797,4 +850,102 @@ std::string NodeRecomputePlan::GetNodesInTopoOrderStr() const {
   return subgraph_str_representation;
 }

+bool SetCriticalPathImpact(Graph& graph,
+                           const InlinedHashMap<NodeIndex, ptrdiff_t>&
+                               node_index_to_its_order_in_topological_sort_map) {
+  // Loop through all the nodes in the graph, find the recompute nodes, and categorize them into two groups:
+  //   group 1: recompute nodes that are used only by non-recompute nodes.
+  //   group 2: recompute nodes that are used by some (>= 0) non-recompute nodes and some (>= 1) recompute nodes.
+  InlinedHashMap<const Node*, InlinedVector<const Node*>> group1_recompute_nodes;
+  InlinedHashMap<const Node*, std::pair<InlinedVector<const Node*>, InlinedVector<const Node*>>> group2_recompute_nodes;
+  InlinedHashMap<const Node*, int64_t> recompute_node_to_its_critical_path_node_order_map;
+  for (const auto& n1 : graph.Nodes()) {
+    if (IsRecomputeNode(&n1)) {
+      InlinedVector<const Node*> non_recompute_consumers;
+      InlinedVector<const Node*> recompute_consumers;
+      for (auto o_iter = n1.OutputEdgesBegin(); o_iter != n1.OutputEdgesEnd(); ++o_iter) {
+        const auto& output = *o_iter;
+        const Node* consumer = graph.GetNode(output.GetNode().Index());
+        if (IsRecomputeNode(consumer)) {
+          recompute_consumers.push_back(consumer);
+        } else {
+          non_recompute_consumers.push_back(consumer);
+        }
+      }
+
+      if (recompute_consumers.empty()) {
+        group1_recompute_nodes.insert({&n1, non_recompute_consumers});
+      } else {
+        group2_recompute_nodes.insert({&n1, {non_recompute_consumers, recompute_consumers}});
+      }
+    }
+  }
+
+  // Loop over group1_recompute_nodes and take the minimal execution order
+  // from node_index_to_its_order_in_topological_sort_map across each node's output nodes.
+  for (const auto& [recompute_node, non_recompute_consumers] : group1_recompute_nodes) {
+    int64_t max_impact = 0;
+    for (const auto& consumer : non_recompute_consumers) {
+      auto it = node_index_to_its_order_in_topological_sort_map.find(consumer->Index());
+      ORT_ENFORCE(it != node_index_to_its_order_in_topological_sort_map.end(),
+                  "Cannot find the order for non-recompute consumer node: ", consumer->Name());
+      // The smaller the order, the bigger the impact it has.
+      max_impact = std::max(max_impact, std::numeric_limits<int64_t>::max() - static_cast<int64_t>(it->second));
+    }
+
+    recompute_node_to_its_critical_path_node_order_map.insert({recompute_node, max_impact});
+  }
+
+  // At this point, all "boundary" recompute nodes have been marked with their critical path node order.
+  // Next, loop over group2_recompute_nodes:
+  //   1) for each recompute node's non-recompute consumers, find the minimal value of
+  //      execution order from node_index_to_its_order_in_topological_sort_map;
+  //   2) for each recompute node's recompute consumers, find the minimal value of execution order from
+  //      recompute_node_to_its_critical_path_node_order_map.
+  //
+  // Note that we loop over the nodes in reversed topological order, to make sure the "boundary" recompute nodes'
+  // parent recompute nodes are processed first, then those parent recompute nodes' parent recompute nodes.
+  GraphViewer updated_graph_viewer(graph);
+  const auto& updated_node_ids = updated_graph_viewer.GetNodesInTopologicalOrder(ExecutionOrder::PRIORITY_BASED);
+  for (int i = static_cast<int>(updated_node_ids.size()) - 1; i >= 0; --i) {
+    Node* p_node = graph.GetNode(updated_node_ids[i]);
+
+    if (p_node == nullptr) {
+      continue;
+    }
+
+    if (group2_recompute_nodes.find(p_node) != group2_recompute_nodes.end()) {
+      const auto& [non_recompute_consumers, recompute_consumers] = group2_recompute_nodes.at(p_node);
+      int64_t max_impact = 0;
+      for (const auto& consumer : non_recompute_consumers) {
+        auto it = node_index_to_its_order_in_topological_sort_map.find(consumer->Index());
+        ORT_ENFORCE(it != node_index_to_its_order_in_topological_sort_map.end(),
+                    "Cannot find the order for non-recompute consumer node: ", consumer->Name());
+        // The smaller the order, the bigger the impact it has.
+        max_impact = std::max(max_impact, std::numeric_limits<int64_t>::max() - static_cast<int64_t>(it->second));
+      }
+
+      for (const auto& consumer : recompute_consumers) {
+        auto it = recompute_node_to_its_critical_path_node_order_map.find(consumer);
+        ORT_ENFORCE(it != recompute_node_to_its_critical_path_node_order_map.end(),
+                    "Cannot find the critical path order for recompute node: ", consumer->Name());
+        max_impact = std::max(max_impact, it->second);
+      }
+
+      recompute_node_to_its_critical_path_node_order_map.insert({p_node, max_impact});
+    }
+  }
+
+  // Finally, loop through recompute_node_to_its_critical_path_node_order_map and add the attribute
+  // to each recompute node, which will be used for priority-based graph ordering.
+  bool modified = false;
+  for (const auto& [recompute_node, order] : recompute_node_to_its_critical_path_node_order_map) {
+    Node* mutable_node = graph.GetNode(recompute_node->Index());
+    mutable_node->AddAttribute(kRecomputeNodeCriticalPathImpact, static_cast<int64_t>(order));
+    modified = true;
+  }
+
+  return modified;
+}
+
 }  // namespace onnxruntime::optimizer::memory_optimizer
diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.h b/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.h
index ac1021f5eb83b..319829fdec7a0 100644
--- a/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.h
+++ b/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.h
@@ -169,4 +169,17 @@ std::unique_ptr<NodeRecomputePlan> CheckNodeForRecompute(const GraphViewer& grap
                                                          bool compromise_stashed_activation,
                                                          bool& can_compromise_stashed_activation);

+/**
+ * @brief Set the critical path impact for recompute nodes as an attribute.
+ * The impact is an int64_t value, which is respected during priority-based topological sorting.
+ * The bigger it is, the earlier the node will be executed.
+ *
+ * @param graph The graph to iterate and update.
+ * @param node_index_to_its_order_in_topological_sort_map The mapping of node index to its order in the topological sort.
+ * @return true if the graph is modified, false otherwise.
+ */
+bool SetCriticalPathImpact(Graph& graph,
+                           const InlinedHashMap<NodeIndex, ptrdiff_t>&
+                               node_index_to_its_order_in_topological_sort_map);
+
 }  // namespace onnxruntime::optimizer::memory_optimizer
diff --git a/orttraining/orttraining/test/optimizer/memory_optimizer_test.cc b/orttraining/orttraining/test/optimizer/memory_optimizer_test.cc
index 360095dea6697..2869598f0a221 100644
--- a/orttraining/orttraining/test/optimizer/memory_optimizer_test.cc
+++ b/orttraining/orttraining/test/optimizer/memory_optimizer_test.cc
@@ -18,6 +18,7 @@
 #include "core/graph/graph_utils.h"
 #include "core/graph/graph_viewer.h"
 #include "core/graph/model.h"
+
 #include "core/optimizer/utils.h"
 #include "core/platform/env.h"
 #include "core/session/inference_session.h"
@@ -26,6 +27,7 @@
 #include "test/capturing_sink.h"
 #include "test/test_environment.h"
 #include "test/util/include/asserts.h"
+#include "orttraining/core/graph/recompute_graph_utils.h"
 #include "orttraining/core/optimizer/memory_optimizer/common.h"
 #include "orttraining/core/optimizer/memory_optimizer/memory_optimizer.h"
 #include "orttraining/core/optimizer/memory_optimizer/memory_insight.h"
@@ -223,20 +225,20 @@ TEST(MemoryOptimizerTests, TransformerPerLayerRecompute) {

   for (auto& consumer : consumers) {
     if (consumer->OpType().compare("LayerNormalization") == 0) {
-      if (consumer->Name().find("_recompute") != std::string::npos) {
+      if (consumer->Name().find(graph_utils::kRecomputeFlag) != std::string::npos) {
         recompute_ln_node = consumer;
         ASSERT_EQ(consumer->Priority(), static_cast<int>(ExecutionPriority::LOCAL_LOW));
         recompute_ln_node_parent_add_or_ln_node = graph.GetProducerNode(consumer->InputDefs()[0]->Name());
         ASSERT_TRUE(recompute_ln_node_parent_add_or_ln_node != nullptr);
         ASSERT_EQ(recompute_ln_node_parent_add_or_ln_node->Priority(), static_cast<int>(ExecutionPriority::DEFAULT));
-        ASSERT_TRUE(recompute_ln_node_parent_add_or_ln_node->Name().find("_recompute") == std::string::npos);
+        ASSERT_TRUE(recompute_ln_node_parent_add_or_ln_node->Name().find(graph_utils::kRecomputeFlag) == std::string::npos);
       } else {
         original_ln_node = consumer;
         ASSERT_EQ(consumer->Priority(), static_cast<int>(ExecutionPriority::DEFAULT));
         original_ln_node_parent_add_or_ln_node = graph.GetProducerNode(consumer->InputDefs()[0]->Name());
         ASSERT_TRUE(original_ln_node_parent_add_or_ln_node);
         ASSERT_EQ(original_ln_node_parent_add_or_ln_node->Priority(), static_cast<int>(ExecutionPriority::DEFAULT));
-        ASSERT_TRUE(original_ln_node_parent_add_or_ln_node->Name().find("_recompute") == std::string::npos);
+        ASSERT_TRUE(original_ln_node_parent_add_or_ln_node->Name().find(graph_utils::kRecomputeFlag) == std::string::npos);
       }
     } else if (consumer->OpType().compare("LayerNormalizationGrad") == 0) {
       input_layer_norm_grad_node = consumer;
@@ -262,14 +264,14 @@ TEST(MemoryOptimizerTests, TransformerPerLayerRecompute) {

   for (auto& consumer : consumers) {
     if (consumer->OpType().compare("LayerNormalization") == 0) {
-      if (consumer->Name().find("_recompute") != std::string::npos) {
+      if (consumer->Name().find(graph_utils::kRecomputeFlag) != std::string::npos) {
         recompute_ln_node = consumer;
         ASSERT_EQ(consumer->Priority(), static_cast<int>(ExecutionPriority::LOCAL_LOW));
         recompute_ln_node_parent_add_node = graph.GetProducerNode(consumer->InputDefs()[0]->Name());
         ASSERT_TRUE(recompute_ln_node_parent_add_node);
         ASSERT_EQ(recompute_ln_node_parent_add_node->OpType(), "Add");
         ASSERT_EQ(recompute_ln_node_parent_add_node->Priority(), static_cast<int>(ExecutionPriority::LOCAL_LOW));
-        ASSERT_TRUE(recompute_ln_node_parent_add_node->Name().find("_recompute") != std::string::npos);
+        ASSERT_TRUE(recompute_ln_node_parent_add_node->Name().find(graph_utils::kRecomputeFlag) != std::string::npos);
       } else {
         original_ln_node = consumer;
         ASSERT_EQ(consumer->Priority(), static_cast<int>(ExecutionPriority::DEFAULT));
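
For illustration only (not part of the patch): the critical-path-impact bookkeeping above reduces to "impact = INT64_MAX - earliest consumer order in the baseline sort; among low-priority nodes, the larger impact is scheduled first". Below is a minimal standalone C++ sketch of that arithmetic. The NodeInfo struct and the helper names (ComputeCriticalPathImpact, ShouldOutputSecondNodeFirst) are hypothetical and are not the ONNX Runtime API; only the comparison logic mirrors the diff.

// Standalone sketch of the "critical path impact" idea; types and names are illustrative only.
#include <algorithm>
#include <cstdint>
#include <limits>
#include <vector>

struct NodeInfo {
  int64_t index;                        // node index assigned at graph construction time
  std::vector<size_t> consumer_orders;  // topological orders of consumers in the baseline (no-recompute) sort
  int64_t impact = -1;                  // -1 means "impact attribute not set"
};

// Mirrors steps 3)/4) of Note 2: a boundary recompute node gets
// std::numeric_limits<int64_t>::max() - min(consumer topological order),
// i.e. the earlier it is needed, the bigger its impact.
int64_t ComputeCriticalPathImpact(const NodeInfo& n) {
  int64_t impact = 0;
  for (size_t order : n.consumer_orders) {
    impact = std::max(impact, std::numeric_limits<int64_t>::max() - static_cast<int64_t>(order));
  }
  return impact;
}

// Mirrors the LOCAL_LOW branch of PriorityNodeCompare: among low-priority nodes,
// return true when the second node should be output before the first one.
bool ShouldOutputSecondNodeFirst(const NodeInfo& n1, const NodeInfo& n2) {
  if (n1.impact != -1 && n2.impact != -1) {
    return n2.impact > n1.impact;
  }
  return n2.index < n1.index;  // fall back to node index, as the comparator otherwise does
}

int main() {
  NodeInfo early_consumer_node{/*index*/ 10, /*consumer_orders*/ {5}, /*impact*/ -1};
  NodeInfo late_consumer_node{/*index*/ 3, /*consumer_orders*/ {50}, /*impact*/ -1};
  early_consumer_node.impact = ComputeCriticalPathImpact(early_consumer_node);
  late_consumer_node.impact = ComputeCriticalPathImpact(late_consumer_node);
  // Even though late_consumer_node has the smaller node index, early_consumer_node is
  // needed sooner (consumer order 5 < 50), so it should be scheduled first.
  return ShouldOutputSecondNodeFirst(late_consumer_node, early_consumer_node) ? 0 : 1;
}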