Skip to content

Commit

Permalink
Zhijxu/fix toposort (#18705)
Browse files Browse the repository at this point in the history
In training, Shape/Size nodes should be executed as soon as they are ready
to run, so that their input tensors' memory can be released as early as possible.

The toposort logic was enhanced previously for this purpose, but it did not
handle the "Shape -> Size" pattern, which caused the downstream Size op to be
missing from the toposort result.
  • Loading branch information
zhijxu-MS authored Dec 5, 2023
1 parent e066fca commit 2b3050b
Showing 1 changed file with 10 additions and 3 deletions.
13 changes: 10 additions & 3 deletions onnxruntime/core/graph/graph_viewer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,14 @@ GraphViewer::GraphViewer(const Graph& graph, const IndexedSubGraph* filter_info)
: ConstGraphNodes::NodeFilterFunc(nullptr))},
filter_info_{filter_info} {
std::vector<const Node*> leaf_nodes;
#ifdef ENABLE_TRAINING
// Keep the info of shape and size nodes and their parents so that after topological sort, we can move them
// right after their parents. This is to make sure the shape and size nodes are executed right after their parents
// so it's possible the input tensor memory can be released as soon as possible. This is especially important
// for non-CPU devices or for training case where some gradient graphs use only shape/size of tensors from forward.
InlinedHashSet<NodeIndex> shape_size_nodes;
InlinedHashMap<NodeIndex, InlinedVector<NodeIndex>> shape_size_parents;
#endif
for (auto& node : graph_->Nodes()) {
// This is a leaf node (without any output node)
if (node.OutputNodesBegin() == node.OutputNodesEnd()) {
Expand All @@ -72,6 +74,7 @@ GraphViewer::GraphViewer(const Graph& graph, const IndexedSubGraph* filter_info)
if (node.InputEdgesBegin() == node.InputEdgesEnd()) {
root_nodes_.push_back(node.Index());
}
#ifdef ENABLE_TRAINING
if ((node.OpType() == "Shape" || node.OpType() == "Size") && node.InputEdgesBegin() != node.InputEdgesEnd()) {
shape_size_nodes.insert(node.Index());
NodeIndex parent = node.InputNodesBegin()->Index();
Expand All @@ -81,6 +84,7 @@ GraphViewer::GraphViewer(const Graph& graph, const IndexedSubGraph* filter_info)
shape_size_parents[parent].push_back(node.Index());
}
}
#endif
}

graph.ReverseDFSFrom(
Expand All @@ -90,21 +94,24 @@ GraphViewer::GraphViewer(const Graph& graph, const IndexedSubGraph* filter_info)
nodes_in_topological_order_.push_back(n->Index());
},
NodeCompare());

#ifdef ENABLE_TRAINING
auto original = std::move(nodes_in_topological_order_);
nodes_in_topological_order_.reserve(original.size());
InlinedHashSet<NodeIndex> visited;
for (auto& node : original) {
if (shape_size_nodes.find(node) != shape_size_nodes.end()) {
if (visited.find(node) != visited.end()) {
continue;
}
nodes_in_topological_order_.push_back(node);
visited.insert(node);
if (shape_size_parents.find(node) != shape_size_parents.end()) {
for (auto& following_node : shape_size_parents[node]) {
nodes_in_topological_order_.push_back(following_node);
visited.insert(following_node);
}
}
}

#endif
#if !defined(ORT_MINIMAL_BUILD)
graph.KahnsTopologicalSort(
[this](const Node* n) {
Expand Down

0 comments on commit 2b3050b

Please sign in to comment.