From 2b3050bb0c89537d67e213f657ec56a7ec21d47e Mon Sep 17 00:00:00 2001 From: zhijiang <43435212+zhijxu-MS@users.noreply.github.com> Date: Tue, 5 Dec 2023 17:36:00 +0800 Subject: [PATCH] Zhijxu/fix toposort (#18705) In training, Shape/Size ops need to be executed as soon as they are ready so that, where possible, input tensor memory can be released early. The toposort logic was previously enhanced for this, but it did not handle the "Shape->Size" pattern, which caused the downstream Size op to be missing from the toposort result. --- onnxruntime/core/graph/graph_viewer.cc | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/onnxruntime/core/graph/graph_viewer.cc b/onnxruntime/core/graph/graph_viewer.cc index 98f4897552a14..b1e07714cd3c8 100644 --- a/onnxruntime/core/graph/graph_viewer.cc +++ b/onnxruntime/core/graph/graph_viewer.cc @@ -57,12 +57,14 @@ GraphViewer::GraphViewer(const Graph& graph, const IndexedSubGraph* filter_info) : ConstGraphNodes::NodeFilterFunc(nullptr))}, filter_info_{filter_info} { std::vector leaf_nodes; +#ifdef ENABLE_TRAINING // Keep the info of shape and size nodes and their parents so that after topological sort, we can move them // right after their parents. This is to make sure the shape and size nodes are executed right after their parents // so it's possible the input tensor memory can be released as soon as possible. This is especially important // for non-CPU devices or for training case where some gradient graphs use only shape/size of tensors from forward. 
InlinedHashSet shape_size_nodes; InlinedHashMap> shape_size_parents; +#endif for (auto& node : graph_->Nodes()) { // This is a leaf node (without any output node) if (node.OutputNodesBegin() == node.OutputNodesEnd()) { @@ -72,6 +74,7 @@ GraphViewer::GraphViewer(const Graph& graph, const IndexedSubGraph* filter_info) if (node.InputEdgesBegin() == node.InputEdgesEnd()) { root_nodes_.push_back(node.Index()); } +#ifdef ENABLE_TRAINING if ((node.OpType() == "Shape" || node.OpType() == "Size") && node.InputEdgesBegin() != node.InputEdgesEnd()) { shape_size_nodes.insert(node.Index()); NodeIndex parent = node.InputNodesBegin()->Index(); @@ -81,6 +84,7 @@ GraphViewer::GraphViewer(const Graph& graph, const IndexedSubGraph* filter_info) shape_size_parents[parent].push_back(node.Index()); } } +#endif } graph.ReverseDFSFrom( @@ -90,21 +94,24 @@ GraphViewer::GraphViewer(const Graph& graph, const IndexedSubGraph* filter_info) nodes_in_topological_order_.push_back(n->Index()); }, NodeCompare()); - +#ifdef ENABLE_TRAINING auto original = std::move(nodes_in_topological_order_); nodes_in_topological_order_.reserve(original.size()); + InlinedHashSet visited; for (auto& node : original) { - if (shape_size_nodes.find(node) != shape_size_nodes.end()) { + if (visited.find(node) != visited.end()) { continue; } nodes_in_topological_order_.push_back(node); + visited.insert(node); if (shape_size_parents.find(node) != shape_size_parents.end()) { for (auto& following_node : shape_size_parents[node]) { nodes_in_topological_order_.push_back(following_node); + visited.insert(following_node); } } } - +#endif #if !defined(ORT_MINIMAL_BUILD) graph.KahnsTopologicalSort( [this](const Node* n) {