Improve priority-based ordering for recompute #20117

Closed
wants to merge 21 commits into from
6 changes: 4 additions & 2 deletions include/onnxruntime/core/graph/constants.h
@@ -55,8 +55,10 @@ constexpr const char* kAzureExecutionProvider = "AzureExecutionProvider";

constexpr const char* kExecutionProviderSharedLibraryPath = "shared_lib_path";
constexpr const char* kExecutionProviderSharedLibraryEntry = "provider_factory_entry_point";

// For Priority based graph topology sorting.
#ifdef ENABLE_TRAINING
// For priority based graph topology sorting.
constexpr const char* kBackwardNodeAttributeName = "__backwardpass";
constexpr const char* kRecomputeNodeCriticalPathImpact = "__recompute_critical_path_impact";
#endif

} // namespace onnxruntime
117 changes: 106 additions & 11 deletions onnxruntime/core/graph/graph_viewer.cc
@@ -40,16 +40,61 @@
}

#ifdef ENABLE_TRAINING
// nodes of forward pass will be output first
auto n1_attrs = n1->GetAttributes();
auto n2_attrs = n2->GetAttributes();
int64_t n1_is_forward = static_cast<int64_t>(n1_attrs.find(kBackwardNodeAttributeName) == n1_attrs.cend()) ||
(n1_attrs.at(kBackwardNodeAttributeName).i() + 1) % 2;
int64_t n2_is_forward = static_cast<int64_t>(n2_attrs.find(kBackwardNodeAttributeName) == n2_attrs.cend()) ||
(n2_attrs.at(kBackwardNodeAttributeName).i() + 1) % 2;
if (n1_is_forward != n2_is_forward) {
return n2_is_forward > n1_is_forward;

// Sorting factors for training scenarios.
if (n1_priority == static_cast<int>(ExecutionPriority::DEFAULT)) {
// If both nodes are normal, prioritize outputting the forward pass node.
//
// Note 1: This preference arises from producer-consumer node pairs not separated by "YieldOp".
// The producer (forward pass, contributing to YieldOp inputs) and consumer (backward pass,
// used for gradient computation) should output in forward order to save memory.
//
// Note 2: MemoryOptimizer marks nodes as forward by backtracking from YieldOp's inputs.
// Nodes reached by this backtracking, identified through their inputs, are tagged as forward.
//
// The nodes of forward pass will be output first
auto n1_attrs = n1->GetAttributes();
auto n2_attrs = n2->GetAttributes();
int64_t n1_is_forward = static_cast<int64_t>(n1_attrs.find(kBackwardNodeAttributeName) == n1_attrs.cend()) ||
(n1_attrs.at(kBackwardNodeAttributeName).i() + 1) % 2;
int64_t n2_is_forward = static_cast<int64_t>(n2_attrs.find(kBackwardNodeAttributeName) == n2_attrs.cend()) ||
(n2_attrs.at(kBackwardNodeAttributeName).i() + 1) % 2;
if (n1_is_forward != n2_is_forward) {
return n2_is_forward > n1_is_forward;
}
} else if (n1_priority == static_cast<int>(ExecutionPriority::LOCAL_LOW)) {
// If both are low priority nodes, we prefer to output nodes with bigger impact first.
// Only recompute scenarios will set the critical path impact attribute.
//
// Note 1: Importance of Critical Path Impact in Topological Sorting
// In recompute scenarios, it's crucial to identify which node to execute to unblock the
// critical path. This ensures nodes in the critical path are executed without delay.
// For more details, refer to MemoryOptimizer's implementation.
//
// Note 2: Defining Critical Path Impact
// Critical path impact is a value set during MemoryOptimizer's operation to prioritize
// node execution. It's calculated based on the topological order of nodes and their
// dependencies, ensuring timely execution of critical nodes. For more details, refer
// to MemoryOptimizer's implementation.
//
// Note 3: This trick is not necessarily bound to LOCAL_LOW priority nodes, but we are using it for
// recompute in MemoryOptimizer, so we add the check there. Feel free to revisit the check if it is
// useful for other priorities.
//
// The nodes of bigger impact pass will be output first
const auto& n1_attrs = n1->GetAttributes();
const auto& n2_attrs = n2->GetAttributes();
int64_t n1_impact = (n1_attrs.find(kRecomputeNodeCriticalPathImpact) != n1_attrs.cend())
? static_cast<int64_t>(n1_attrs.at(kRecomputeNodeCriticalPathImpact).i())
: -1;
int64_t n2_impact = (n2_attrs.find(kRecomputeNodeCriticalPathImpact) != n2_attrs.cend())
? static_cast<int64_t>(n2_attrs.at(kRecomputeNodeCriticalPathImpact).i())
: -1;
if (n1_impact != -1 && n2_impact != -1) {
return n2_impact > n1_impact;
}
}

#endif

// otherwise, nodes with lower index will be output first
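Stripped of the ONNX Runtime types, the forward-pass-first comparison in this hunk can be sketched as a plain comparator over a toy node struct (the `ToyNode` struct, `IsForward` helper, and attribute map below are illustrative stand-ins, not the real `Node`/`AttributeProto` API):

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <map>
#include <string>
#include <vector>

// Illustrative stand-in for a graph node: an index plus integer attributes.
struct ToyNode {
  int index;
  std::map<std::string, int64_t> attrs;
};

constexpr const char* kBackwardNodeAttributeName = "__backwardpass";

// Mirrors the (attribute-missing || (i + 1) % 2) test above: a node counts as
// "forward" when it has no __backwardpass attribute or the attribute is even.
bool IsForward(const ToyNode& n) {
  auto it = n.attrs.find(kBackwardNodeAttributeName);
  return it == n.attrs.end() || (it->second + 1) % 2 == 1;
}

// Forward-pass nodes are ordered first; ties fall back to ascending index.
bool ForwardFirstLess(const ToyNode& a, const ToyNode& b) {
  const bool a_fwd = IsForward(a);
  const bool b_fwd = IsForward(b);
  if (a_fwd != b_fwd) return a_fwd;  // forward node outranks backward node
  return a.index < b.index;          // otherwise lower index is output first
}
```

Note the real `PriorityNodeCompare` feeds a priority queue, where returning true means the first argument is popped later; that is why the production code compares in the opposite direction (`n2_is_forward > n1_is_forward`) while producing the same ordering.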
@@ -130,11 +175,61 @@
}
#endif
#if !defined(ORT_MINIMAL_BUILD)
std::vector<NodeIndex> nodes_in_topological_order_with_priority;
graph.KahnsTopologicalSort(
[this](const Node* n) {
nodes_in_topological_order_with_priority_.push_back(n->Index());
[&nodes_in_topological_order_with_priority](const Node* n) {
nodes_in_topological_order_with_priority.push_back(n->Index());
},
PriorityNodeCompare());

// Tune the order a bit.
// If a node is used by a consumer node, but the execution order of the consumer node is later than the
// producer node, we should move the producer node to right before the consumer node. In this case, the
// producer node will only be executed when needed and the memory can be released earlier. We do this in
// reversed topological order to handle the single-input-single-output node chains.
InlinedVector<NodeIndex> node_in_reversed_order;
node_in_reversed_order.reserve(nodes_in_topological_order_with_priority.size());
for (auto it = nodes_in_topological_order_with_priority.rbegin();
     it != nodes_in_topological_order_with_priority.rend(); ++it) {
const Node* node = graph_->GetNode(*it);

if (node->GetOutputEdgesCount() != 1) {
// Don't need tune, just add it to the front of reversed_order.
node_in_reversed_order.push_back(node->Index());
continue;
}

// Handle the high-priority "Shape" and "Size" nodes differently, since tuning them may also break the
// computation order when recompute is enabled. But as ShapeInputMerge is introduced, there is much less
// chance for a recompute subgraph to be consumed by a normal Shape or Size node.
// TODO(pengwa): observe from real models and see if we need to handle this case.
if (node->OpType() == "Shape" || node->OpType() == "Size") {
node_in_reversed_order.push_back(node->Index());
continue;
}

// If the node is PythonOpGrad and its func_name attribute refers to
// "onnxruntime.training.ortmodule._mem_efficient_grad_mgmt.ParamRetrievalFunction", we skip tuning it to
// make sure the weight accumulation nodes are executed as early as possible (free buffer and unblock
// subsequent CPU work).
if (node->OpType() == "PythonOpGrad") {
const auto& attrs = node->GetAttributes();
auto it = attrs.find("func_name");
ORT_ENFORCE(it != attrs.end());
if (it->second.s().find("onnxruntime.training.ortmodule._mem_efficient_grad_mgmt.ParamRetrievalFunction") !=
    std::string::npos) {
node_in_reversed_order.push_back(node->Index());
continue;
}
}

const Node* consumer = &(*(node->OutputNodesBegin()));
// Insert the producer node right after its single consumer node. (Remember the order is reversed here.)
auto it_consumer = std::find(node_in_reversed_order.begin(), node_in_reversed_order.end(), consumer->Index());
ORT_ENFORCE(it_consumer != node_in_reversed_order.end());
node_in_reversed_order.insert(it_consumer + 1, node->Index());
}

nodes_in_topological_order_with_priority_.insert(
nodes_in_topological_order_with_priority_.end(),
node_in_reversed_order.rbegin(),
node_in_reversed_order.rend());
#endif
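The tuning pass above — walk the priority order in reverse and re-insert each single-consumer producer immediately after its consumer in the reversed list, so it lands immediately before it in the final order — can be sketched on plain integer node ids (the `ConsumerMap` shape is an illustrative stand-in for the real graph edges, and the special cases for Shape/Size/PythonOpGrad nodes are omitted):

```cpp
#include <algorithm>
#include <cassert>
#include <map>
#include <vector>

// consumers[n] lists the node ids that consume n's outputs.
using ConsumerMap = std::map<int, std::vector<int>>;

// Move every node with exactly one consumer to sit right before that consumer.
// Working in reversed topological order handles single-input-single-output
// chains: each link of the chain is re-attached as it is visited.
std::vector<int> MoveProducersToConsumers(const std::vector<int>& order,
                                          const ConsumerMap& consumers) {
  std::vector<int> reversed;
  reversed.reserve(order.size());
  for (auto it = order.rbegin(); it != order.rend(); ++it) {
    const auto c = consumers.find(*it);
    if (c == consumers.end() || c->second.size() != 1) {
      reversed.push_back(*it);  // no tuning needed for this node
      continue;
    }
    // The consumer was already emitted (it comes later in topological order),
    // so insert the producer right after it in the reversed list.
    const auto pos = std::find(reversed.begin(), reversed.end(), c->second.front());
    assert(pos != reversed.end());
    reversed.insert(pos + 1, *it);
  }
  return std::vector<int>(reversed.rbegin(), reversed.rend());
}
```

For order {0, 1, 2, 3} with node 0 consumed only by 3 and node 1 only by 2, this yields {1, 2, 0, 3}: each producer is delayed until just before its consumer, so its output buffer stays alive for as short a span as possible.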

if (filter_info_) {
1 change: 1 addition & 0 deletions onnxruntime/core/optimizer/gemm_transpose_fusion.cc
@@ -86,6 +86,7 @@ Status GemmTransposeFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& m
new_gemm_node.AddAttribute("transB", static_cast<int64_t>(transB));
new_gemm_node.AddAttribute("alpha", gemm_node.GetAttributes().at("alpha").f());
new_gemm_node.AddAttribute("beta", gemm_node.GetAttributes().at("beta").f());
new_gemm_node.SetExecutionProviderType(gemm_node.GetExecutionProviderType());

graph_utils::FinalizeNodeFusion(graph, nodes_to_remove, new_gemm_node);

6 changes: 6 additions & 0 deletions onnxruntime/core/optimizer/graph_transformer_utils.cc
@@ -136,6 +136,7 @@ InlinedVector<std::unique_ptr<RewriteRule>> GenerateRewriteRules(
break;

case TransformerLevel::Level2:
rules.push_back(std::make_unique<GemmTransposeFusion>());
// No level2 rules available today
break;

@@ -251,6 +252,11 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
} break;

case TransformerLevel::Level2: {
auto rule_transformer = GenerateRuleBasedGraphTransformer(level, rules_and_transformers_to_disable, {});
if (rule_transformer != nullptr) {
transformers.emplace_back(std::move(rule_transformer));
}

// we run TransposeOptimizer again in Level2 for some CPU EP specific optimizations that can only be
// applied once nodes are assigned to the CPU EP (which happens between level 1 and level 2).
transformers.emplace_back(std::make_unique<TransposeOptimizer>(std::move(cpu_allocator), kCpuExecutionProvider));
15 changes: 14 additions & 1 deletion onnxruntime/core/optimizer/propagate_cast_ops.cc
@@ -171,7 +171,20 @@

using OpsSetType = InlinedHashSet<std::string_view>;
static const OpsSetType level1_fp16_allow_set =
{"Expand", "Transpose", "Relu", "Reshape", "Split", "Tanh", "Squeeze", "Unsqueeze", "Gelu"};
{
/* Layout change operators */
"Expand",
"PadAndUnflatten",
"Reshape",
"Split",
"Squeeze",
"Transpose",
"Unsqueeze",
/* Revisited element-wise operators. */
"Gelu",
"Relu",
"Tanh",
};
static const OpsSetType level2_fp16_allow_set = {
"Add", "BiasGelu", "Dropout", "FastGelu", "Gather", "LayerNormalization", "Where"};

4 changes: 3 additions & 1 deletion orttraining/orttraining/core/graph/recompute_graph_utils.h
@@ -6,8 +6,10 @@
namespace onnxruntime {
namespace graph_utils {

constexpr const char* kRecomputeFlag = "_recompute";

inline std::string RecomputeName(const std::string& name) {
return name + "_recompute";
return name + kRecomputeFlag;
}

} // namespace graph_utils
@@ -193,40 +193,7 @@ Status MemoryOptimizer::ApplyImpl(Graph& graph, bool& modified, int /*graph_leve
// The reason we do reversed topological order is that we want the later layers' recompute nodes can be appended
// earlier than the earlier layers, in this way, the execution order of later layers will be in front of the earlier
// layers.
//
// Note 2: Here we use default topological order (which tries to BFS from the outputs,
// so the nearest node to graph output will be visited last). So in reversed default topological order,
// the nearest node to graph output will be visited first.
// Imagine there is a such subgraph
// input1 input2 input3
// \ | /
// multiple layers
// |
// node M
// labels-------|-----
// \ | |
// node1 | |
// \ | |
// node2 / |
// \ / |
// node loss /
// | /
// YieldOp node1_recompute
// | /
// \ node2 recompute
// \ /
// node loss_grad
// |
// critical grad path
//
// In PriorityBased order, node1 will be visited first, so it's recompute node node1_recompute will be added
// at last because we do this following reversed topological order. Then node1_recompute node will have lowest
// priority to execute, as a result, if at this time, the queue to visit contains only recompute nodes, then
// node1_recompute will be run at last, affecting the backward critical path, which is not what we want.
// Current workaround is to use default order, which will execute node1_recompute earlier than other recompute nodes
// in this case.

const auto& node_ids = graph_viewer.GetNodesInTopologicalOrder(ExecutionOrder::DEFAULT);
const auto& node_ids = graph_viewer.GetNodesInTopologicalOrder(ExecutionOrder::PRIORITY_BASED);
for (int i = static_cast<int>(node_ids.size()) - 1; i >= 0; --i) {
Node* p_node = graph.GetNode(node_ids[i]);
if (p_node == nullptr) {
@@ -251,6 +218,41 @@ Status MemoryOptimizer::ApplyImpl(Graph& graph, bool& modified, int /*graph_leve
}

if (recomputed_node_count > 0) {
// Note 1: Critical Path Impact in Priority-Based Topological Sort
//
// Setting and comparing critical path impact is essential in scenarios where all nodes in the priority queue
// are of low priority. This comparison helps determine which recompute node to select to unblock the backward
// critical path. Consider a scenario where:
// - A recompute node subgraph, NodeSubgraph-L, exists within transformer layer N (e.g., NodeSubgraph-5 in
// layer 5, NodeSubgraph-3 in layer 3).
// - Node-A-IN-5 within NodeSubgraph-5 depends on Node-B-IN-0 within NodeSubgraph-0.
// - The priority queue contains nodes from NodeSubgraph-0 to NodeSubgraph-5.
// In MemoryOptimizer recompute scenarios, we append nodes starting from NodeSubgraph-5 down to NodeSubgraph-0.
// Relying solely on node index for comparison could lead to:
// 1) Sequential output of nodes in NodeSubgraph-5 (sorted by ascending node index within the subgraph).
// 2) Blocking of Node-A-IN-5's execution until Node-B-IN-0 is executed.
// 3) Sequential output of nodes from NodeSubgraph-4 to NodeSubgraph-1.
// 4) Execution of NodeSubgraph-0 nodes, allowing Node-A-IN-5 and subsequent NodeSubgraph-5 nodes to execute.
// 5) Execution of remaining NodeSubgraph-0 nodes.
//
// This process can significantly delay the execution of Node-A-IN-5, blocking other NodeSubgraph-5 nodes.
// Since NodeSubgraph-5 nodes are on the critical path, triggering their dependencies timely is crucial to
// ensure their execution as early as possible, ahead of other layers. This necessity led to the introduction
// of critical path impact.
//
// Note 2: Defining Critical Path Impact
// Critical path impact is a metric representing a node's influence on the critical path. It is determined
// during MemoryOptimizer's operation as follows:
// 1) Sort graphs without recompute optimization to establish a baseline topological order.
// 2) Apply recompute optimization.
// 3) Identify recompute boundary nodes (recompute nodes not consumed by others).
// 4) For each boundary node, calculate the minimum topological order of all output nodes.
// The minimum value indicates the earliest need for the recompute node's execution.
// We assign std::numeric_limits<int64_t>::max() - min_topological_order as the critical path impact.
// 5) For other recompute nodes, assign the maximum critical path impact of their output nodes.
ORT_ENFORCE(optimizer::memory_optimizer::SetCriticalPathImpact(
graph, node_index_to_its_order_in_topological_sort_map),
"Failed to set critical path impact attribute.");
LOGS(logger, INFO) << "Total number of recomputed nodes: " << recomputed_node_count;
}
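Steps 3–5 of Note 2 can be sketched as follows. The `RecomputeNode` struct and `SetCriticalPathImpactSketch` function below are illustrative stand-ins under the assumption that the baseline topological order was computed before recompute nodes were inserted; the real `SetCriticalPathImpact` operates on the `Graph` and writes the `__recompute_critical_path_impact` attribute instead of returning a map.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <limits>
#include <map>
#include <vector>

struct RecomputeNode {
  int id;
  std::vector<int> recompute_consumers;  // recompute nodes consuming this node
  std::vector<int64_t> consumer_orders;  // baseline topo order of other consumers
};

// Boundary recompute nodes (no recompute consumers) get
// int64 max - min(baseline order of their consumers): the earlier a consumer
// appears in the baseline order, the larger the impact. Interior recompute
// nodes inherit the max impact of their recompute consumers. `nodes` must be
// in topological order so consumers are handled first in the reverse walk.
std::map<int, int64_t> SetCriticalPathImpactSketch(const std::vector<RecomputeNode>& nodes) {
  std::map<int, int64_t> impact;
  for (auto it = nodes.rbegin(); it != nodes.rend(); ++it) {
    if (it->recompute_consumers.empty()) {
      const int64_t min_order =
          *std::min_element(it->consumer_orders.begin(), it->consumer_orders.end());
      impact[it->id] = std::numeric_limits<int64_t>::max() - min_order;
    } else {
      int64_t best = 0;
      for (const int c : it->recompute_consumers) best = std::max(best, impact.at(c));
      impact[it->id] = best;
    }
  }
  return impact;
}
```

In the priority-based comparator, a larger impact then wins among LOCAL_LOW nodes, so the recompute node whose consumer sits earliest in the baseline order is unblocked first.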

@@ -319,7 +321,7 @@ Status MemoryOptimizer::CreateRecomputeGraph(Graph& graph,
self_contained_outputs_map[output] = new_output_args.back();
}

Node& recompute_node = graph.AddNode(node_to_duplicate->Name() + "_recompute",
Node& recompute_node = graph.AddNode(node_to_duplicate->Name() + graph_utils::kRecomputeFlag,
node_to_duplicate->OpType(),
"Recompute of " + node_to_duplicate->Name(),
new_input_args,