From a1a68ede5943e9d0c1b7323c80bd5fd9bfca7692 Mon Sep 17 00:00:00 2001 From: Markus Tavenrath Date: Thu, 25 Jan 2024 15:14:29 +0100 Subject: [PATCH] Replace __backward attribute by isForwardNode --- include/onnxruntime/core/graph/constants.h | 3 -- include/onnxruntime/core/graph/graph.h | 12 ++++---- onnxruntime/core/graph/graph.cc | 28 ++++--------------- onnxruntime/core/graph/graph_viewer.cc | 8 ++---- .../core/optimizer/matmul_scale_fusion.cc | 6 +--- .../core/optimizer/matmul_transpose_fusion.cc | 5 +--- .../core/optimizer/rocm_blas_alt_impl.cc | 2 +- onnxruntime/core/providers/rocm/rocm_kernel.h | 2 +- .../memory_optimizer/memory_insight.cc | 13 ++++----- 9 files changed, 24 insertions(+), 55 deletions(-) diff --git a/include/onnxruntime/core/graph/constants.h b/include/onnxruntime/core/graph/constants.h index 9b26ba914c7dd..7e59aad80cc47 100644 --- a/include/onnxruntime/core/graph/constants.h +++ b/include/onnxruntime/core/graph/constants.h @@ -55,7 +55,4 @@ constexpr const char* kAzureExecutionProvider = "AzureExecutionProvider"; constexpr const char* kExecutionProviderSharedLibraryPath = "shared_lib_path"; constexpr const char* kExecutionProviderSharedLibraryEntry = "provider_factory_entry_point"; -// For Priority based graph topology sorting. -constexpr const char* kBackwardNodeAttributeName = "__backwardpass"; - } // namespace onnxruntime diff --git a/include/onnxruntime/core/graph/graph.h b/include/onnxruntime/core/graph/graph.h index 534429aca4417..fd0aa7793cd5a 100644 --- a/include/onnxruntime/core/graph/graph.h +++ b/include/onnxruntime/core/graph/graph.h @@ -399,7 +399,10 @@ class Node { const NodeAttributes& GetAttributes() const noexcept { return attributes_; } /** @returns true if the Node is a forward node, false otherwise. **/ - bool isForwardNode() const noexcept { return isForwardNode_; } + bool isForwardNode() const noexcept { return is_forward_node_; } + + /* Sets the forward node status */ + void setForwardNode(bool is_forward_node) noexcept { is_forward_node_ = is_forward_node; } #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) /** Remove the specified attribute from this Node */ @@ -464,7 +467,7 @@ class Node { std::unordered_map> GetAttributeNameToSubgraphMap() const; /** Gets the execution ProviderType that this node will be executed by. */ - ProviderType const& GetExecutionProviderType() const noexcept { return execution_provider_type_; } + ProviderType GetExecutionProviderType() const noexcept { return execution_provider_type_; } /** Sets the execution ProviderType that this Node will be executed by. */ void SetExecutionProviderType(ProviderType execution_provider_type) { @@ -633,9 +636,8 @@ class Node { // Execution priority, lower value for higher priority int priority_ = 0; - // True is Node is a forwardNode and thus doesn't contain a attribute - // named kBackwardNodeAttributeName. False otherwise. - bool isForwardNode_; + // This node is a forward node if value, otherwise it is a backward node. + bool is_forward_node_; // set from op_->SinceVersion() or via deserialization when OpSchema is not available int since_version_ = -1; diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index da33b267eb53f..fdafda93bb900 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -850,7 +850,7 @@ void Node::Init(std::string_view name, gsl::span output_args, const NodeAttributes* attributes, std::string_view domain) { - isForwardNode_ = true; + is_forward_node_ = true; name_ = name; op_type_ = op_type; description_ = description; @@ -871,12 +871,8 @@ void Node::Init(std::string_view name, if (attributes) { attributes_ = *attributes; - isForwardNode_ = true; + is_forward_node_ = true; for (auto& name_to_attr : attributes_) { - if (!isForwardNode_ && name_to_attr.first == kBackwardNodeAttributeName) { - isForwardNode_ = false; - } - if (utils::HasGraph(name_to_attr.second)) { #if !defined(ORT_MINIMAL_BUILD) CreateSubgraph(name_to_attr.first); @@ -920,9 +916,6 @@ void Node::CreateSubgraph(const std::string& attr_name) { #endif // !defined(ORT_MINIMAL_BUILD) void Node::AddAttributeProto(AttributeProto value) { - if (value.name() == kBackwardNodeAttributeName) { - isForwardNode_ = false; - } utils::SetNodeAttribute(std::move(value), attributes_); if (graph_) { graph_->SetGraphResolveNeeded(); @@ -959,7 +952,6 @@ ADD_ATTR_IMPLS(TypeProto) #undef ADD_ATTR_LIST_IMPL #undef ADD_ATTR_IMPLS -// TODO why isn't attr_name a const&? void Node::AddAttribute(std::string attr_name, GraphProto value) { // Do not move attr_name as it is needed below AttributeProto a = utils::MakeAttribute(attr_name, std::move(value)); @@ -975,11 +967,7 @@ void Node::AddAttribute(std::string attr_name, GraphProto value) { bool Node::ClearAttribute(const std::string& attr_name) { graph_->SetGraphResolveNeeded(); graph_->SetGraphProtoSyncNeeded(); - size_t erased = attributes_.erase(attr_name); - if (erased && attr_name == kBackwardNodeAttributeName) { - isForwardNode_ = true; - } - return erased > 0; + return attributes_.erase(attr_name) > 0; } #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) @@ -989,11 +977,7 @@ int Node::PruneRemovableAttributes(gsl::span removable_attrib graph_->SetGraphProtoSyncNeeded(); int n_removed = 0; for (const auto& name : removable_attributes) { - bool erased = attributes_.erase(name); - if (erased && name == kBackwardNodeAttributeName) { - isForwardNode_ = true; - } - n_removed += static_cast(erased); + n_removed += static_cast(attributes_.erase(name)); } can_be_saved_ = can_be_saved_ && n_removed == 0; return n_removed; @@ -1813,7 +1797,7 @@ void Graph::KahnsTopologicalSort(const std::function& enter, const std::function& comp) const { std::vector in_degree(MaxNodeIndex(), 0); std::priority_queue, decltype(comp)> to_visit(comp); - std::vector topo_order; + InlinedVector topo_order; for (auto& node : Nodes()) { size_t input_edge_count = node.GetInputEdgesCount(); @@ -2034,7 +2018,7 @@ class InferenceContextImpl : public ONNX_NAMESPACE::InferenceContext { } } - std::vector const& InferredOutputTypes() const { return node_output_types_; } + std::vector const& InferredOutputTypes() const noexcept { return node_output_types_; } const AttributeProto* getAttribute(const std::string& name) const override { auto& attribute_value_map = node_.GetAttributes(); diff --git a/onnxruntime/core/graph/graph_viewer.cc b/onnxruntime/core/graph/graph_viewer.cc index 3a4d721466667..cd5dde88daf3b 100644 --- a/onnxruntime/core/graph/graph_viewer.cc +++ b/onnxruntime/core/graph/graph_viewer.cc @@ -36,12 +36,8 @@ struct PriorityNodeCompare { } // nodes of forward pass will be output first - auto const& n1_attrs = n1->GetAttributes(); - auto const& n2_attrs = n2->GetAttributes(); - int64_t n1_is_forward = static_cast(n1->isForwardNode()) || - (n1_attrs.at(kBackwardNodeAttributeName).i() + 1) % 2; - int64_t n2_is_forward = static_cast(n2->isForwardNode()) || - (n2_attrs.at(kBackwardNodeAttributeName).i() + 1) % 2; + int64_t n1_is_forward = n1->isForwardNode(); + int64_t n2_is_forward = n2->isForwardNode(); if (n1_is_forward != n2_is_forward) { return n2_is_forward > n1_is_forward; } diff --git a/onnxruntime/core/optimizer/matmul_scale_fusion.cc b/onnxruntime/core/optimizer/matmul_scale_fusion.cc index b04d794cc9469..675d983ba5402 100644 --- a/onnxruntime/core/optimizer/matmul_scale_fusion.cc +++ b/onnxruntime/core/optimizer/matmul_scale_fusion.cc @@ -255,11 +255,7 @@ Status ProcessNode( matmul_scale_node.SetExecutionProviderType(node.GetExecutionProviderType()); #ifdef USE_ROCM - // forward the __backwardpass, if present - auto& attrs = node.GetAttributes(); - if (attrs.count("__backwardpass")) { - matmul_scale_node.AddAttribute("__backwardpass", static_cast(attrs.at("__backwardpass").i())); - } + matmul_scale_node.setForwardNode(node.GetForwardNode()); #endif { diff --git a/onnxruntime/core/optimizer/matmul_transpose_fusion.cc b/onnxruntime/core/optimizer/matmul_transpose_fusion.cc index 789466778edc6..b7943ed046835 100644 --- a/onnxruntime/core/optimizer/matmul_transpose_fusion.cc +++ b/onnxruntime/core/optimizer/matmul_transpose_fusion.cc @@ -407,10 +407,7 @@ Status MatmulTransposeFusion::ApplyImpl(Graph& graph, bool& modified, int graph_ matmul_node.SetExecutionProviderType(node.GetExecutionProviderType()); #ifdef USE_ROCM // forward the __backwardpass, if present - auto& attrs = node.GetAttributes(); - if (attrs.count("__backwardpass")) { - matmul_node.AddAttribute("__backwardpass", static_cast(attrs.at("__backwardpass").i())); - } + malmul_node.setForwardPass(node.getForwardPass()); #endif graph_utils::FinalizeNodeFusion(graph, matmul_node, node); diff --git a/onnxruntime/core/optimizer/rocm_blas_alt_impl.cc b/onnxruntime/core/optimizer/rocm_blas_alt_impl.cc index decb25f565efe..db539ef5bec71 100644 --- a/onnxruntime/core/optimizer/rocm_blas_alt_impl.cc +++ b/onnxruntime/core/optimizer/rocm_blas_alt_impl.cc @@ -26,7 +26,7 @@ Status RocmBlasAltImpl::ApplyImpl(Graph& graph, bool& modified, int graph_level, ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger)); if (is_backward_pass) { - node.AddAttribute(std::string("__backwardpass"), static_cast(1)); + node.setForwardNode(false); modified = true; } } diff --git a/onnxruntime/core/providers/rocm/rocm_kernel.h b/onnxruntime/core/providers/rocm/rocm_kernel.h index c0b7d4722d3e4..f5a654c0a6352 100644 --- a/onnxruntime/core/providers/rocm/rocm_kernel.h +++ b/onnxruntime/core/providers/rocm/rocm_kernel.h @@ -25,7 +25,7 @@ class RocmKernel : public OpKernel { Status Compute(OpKernelContext* p_op_kernel_context) const override { Status s; - auto is_backward_pass = Info().GetAttrOrDefault("__backwardpass", 0); + auto is_backward_pass = !Node().isForwardNode(); if (is_backward_pass) { BackwardPassGuard guard; s = ComputeInternal(p_op_kernel_context); diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.cc b/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.cc index 9b77832abb6f1..c8266c50c2892 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.cc +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.cc @@ -197,16 +197,13 @@ Status ResetNodeBackwardPassAttribute(Graph& graph, bool& modified) { // Set the attribute to true for all backward nodes. for (auto& node : graph.Nodes()) { if (std::find(fw_nodes.begin(), fw_nodes.end(), &node) == fw_nodes.end()) { - auto& attrs = node.GetAttributes(); - if (attrs.count(kBackwardNodeAttributeName)) { - continue; + if (node.isForwardNode()) { + node.setForwardNode(false); + modified = true; } - node.AddAttribute(kBackwardNodeAttributeName, static_cast(1)); - modified = true; } else { - auto& attrs = node.GetAttributes(); - if (attrs.count(kBackwardNodeAttributeName)) { - node.ClearAttribute(kBackwardNodeAttributeName); + if (!node.isForwardNode()) { + node.setForwardNode(true); modified = true; } }