From 1f7eea4e28becded678aad5dc27aa5f7a33eda66 Mon Sep 17 00:00:00 2001 From: Markus Tavenrath Date: Wed, 10 Jan 2024 17:42:29 +0100 Subject: [PATCH] Optimize onnxruntime::InferenceSession::Initialize with focus on GraphViewer. For large models the speedup of this function can be up to 3x. --- include/onnxruntime/core/graph/graph.h | 27 ++++++++---- onnxruntime/core/graph/graph.cc | 60 +++++++++++--------------- onnxruntime/core/graph/graph_viewer.cc | 12 +++--- 3 files changed, 50 insertions(+), 49 deletions(-) diff --git a/include/onnxruntime/core/graph/graph.h b/include/onnxruntime/core/graph/graph.h index 22827d43b200f..534429aca4417 100644 --- a/include/onnxruntime/core/graph/graph.h +++ b/include/onnxruntime/core/graph/graph.h @@ -294,17 +294,21 @@ class Node { Class to provide const access to Node instances iterated via an EdgeConstIterator. */ class NodeConstIterator { public: - NodeConstIterator(EdgeConstIterator p_iter); + NodeConstIterator(EdgeConstIterator p_iter) { m_iter = p_iter; } - bool operator==(const NodeConstIterator& p_other) const; + bool operator==(const NodeConstIterator& p_other) const { + return m_iter == p_other.m_iter; + } - bool operator!=(const NodeConstIterator& p_other) const; + bool operator!=(const NodeConstIterator& p_other) const { + return m_iter != p_other.m_iter; + } - void operator++(); - void operator--(); + void operator++() { ++m_iter; } + void operator--() { --m_iter; } - const Node& operator*() const; - const Node* operator->() const; + const Node& operator*() const { return (*m_iter).GetNode(); } + const Node* operator->() const { return &(operator*()); }; private: EdgeConstIterator m_iter; @@ -394,6 +398,9 @@ class Node { /** Gets the Node's attributes. */ const NodeAttributes& GetAttributes() const noexcept { return attributes_; } + /** @returns true if the Node is a forward node, false otherwise. 
**/ + bool isForwardNode() const noexcept { return isForwardNode_; } + #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) /** Remove the specified attribute from this Node */ bool ClearAttribute(const std::string& attr_name); @@ -457,7 +464,7 @@ class Node { std::unordered_map> GetAttributeNameToSubgraphMap() const; /** Gets the execution ProviderType that this node will be executed by. */ - ProviderType GetExecutionProviderType() const noexcept { return execution_provider_type_; } + ProviderType const& GetExecutionProviderType() const noexcept { return execution_provider_type_; } /** Sets the execution ProviderType that this Node will be executed by. */ void SetExecutionProviderType(ProviderType execution_provider_type) { @@ -626,6 +633,10 @@ class Node { // Execution priority, lower value for higher priority int priority_ = 0; + // True if the Node is a forward node, i.e. it does not contain an attribute + // named kBackwardNodeAttributeName. False otherwise. + bool isForwardNode_ = true; + // set from op_->SinceVersion() or via deserialization when OpSchema is not available int since_version_ = -1; diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index f71b7ecebcf1a..da33b267eb53f 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -528,34 +528,6 @@ Node::EdgeEnd::EdgeEnd(const Node& node) noexcept : EdgeEnd(node, INT_MAX, INT_MAX) { } -Node::NodeConstIterator::NodeConstIterator(EdgeConstIterator p_iter) { - m_iter = p_iter; -} - -bool Node::NodeConstIterator::operator==(const NodeConstIterator& p_other) const { - return m_iter == p_other.m_iter; -} - -bool Node::NodeConstIterator::operator!=(const NodeConstIterator& p_other) const { - return m_iter != p_other.m_iter; -} - -void Node::NodeConstIterator::operator++() { - ++m_iter; -} - -void Node::NodeConstIterator::operator--() { - --m_iter; -} - -const Node& Node::NodeConstIterator::operator*() const { - return (*m_iter).GetNode(); -} - 
-const Node* Node::NodeConstIterator::operator->() const { - return &(operator*()); -} - void Node::SetPriority(int priority) noexcept { priority_ = priority; } @@ -878,6 +850,7 @@ void Node::Init(std::string_view name, gsl::span output_args, const NodeAttributes* attributes, std::string_view domain) { + isForwardNode_ = true; name_ = name; op_type_ = op_type; description_ = description; @@ -898,7 +871,12 @@ void Node::Init(std::string_view name, if (attributes) { attributes_ = *attributes; + isForwardNode_ = true; for (auto& name_to_attr : attributes_) { + if (isForwardNode_ && name_to_attr.first == kBackwardNodeAttributeName) { + isForwardNode_ = false; + } + if (utils::HasGraph(name_to_attr.second)) { #if !defined(ORT_MINIMAL_BUILD) CreateSubgraph(name_to_attr.first); @@ -942,6 +920,9 @@ void Node::CreateSubgraph(const std::string& attr_name) { #endif // !defined(ORT_MINIMAL_BUILD) void Node::AddAttributeProto(AttributeProto value) { + if (value.name() == kBackwardNodeAttributeName) { + isForwardNode_ = false; + } utils::SetNodeAttribute(std::move(value), attributes_); if (graph_) { graph_->SetGraphResolveNeeded(); @@ -978,6 +959,7 @@ ADD_ATTR_IMPLS(TypeProto) #undef ADD_ATTR_LIST_IMPL #undef ADD_ATTR_IMPLS +// TODO why isn't attr_name a const&? 
void Node::AddAttribute(std::string attr_name, GraphProto value) { // Do not move attr_name as it is needed below AttributeProto a = utils::MakeAttribute(attr_name, std::move(value)); @@ -993,7 +975,11 @@ void Node::AddAttribute(std::string attr_name, GraphProto value) { bool Node::ClearAttribute(const std::string& attr_name) { graph_->SetGraphResolveNeeded(); graph_->SetGraphProtoSyncNeeded(); - return attributes_.erase(attr_name) > 0; + size_t erased = attributes_.erase(attr_name); + if (erased && attr_name == kBackwardNodeAttributeName) { + isForwardNode_ = true; + } + return erased > 0; } #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) @@ -1003,7 +989,11 @@ int Node::PruneRemovableAttributes(gsl::span removable_attrib graph_->SetGraphProtoSyncNeeded(); int n_removed = 0; for (const auto& name : removable_attributes) { - n_removed += static_cast(attributes_.erase(name)); + bool erased = attributes_.erase(name); + if (erased && name == kBackwardNodeAttributeName) { + isForwardNode_ = true; + } + n_removed += static_cast(erased); } can_be_saved_ = can_be_saved_ && n_removed == 0; return n_removed; @@ -1821,13 +1811,13 @@ void Graph::ReverseDFSFrom(gsl::span from, #if !defined(ORT_MINIMAL_BUILD) void Graph::KahnsTopologicalSort(const std::function& enter, const std::function& comp) const { - std::unordered_map in_degree; + std::vector in_degree(MaxNodeIndex(), 0); std::priority_queue, decltype(comp)> to_visit(comp); std::vector topo_order; for (auto& node : Nodes()) { size_t input_edge_count = node.GetInputEdgesCount(); - in_degree.insert({node.Index(), input_edge_count}); + in_degree[node.Index()] = input_edge_count; if (input_edge_count == 0) { to_visit.push(&node); } @@ -2044,7 +2034,7 @@ class InferenceContextImpl : public ONNX_NAMESPACE::InferenceContext { } } - std::vector InferredOutputTypes() const { return node_output_types_; } + std::vector const& InferredOutputTypes() const { return node_output_types_; } const AttributeProto* 
getAttribute(const std::string& name) const override { auto& attribute_value_map = node_.GetAttributes(); @@ -2240,7 +2230,7 @@ Status Graph::InferAndVerifyTypeMatch(Node& node, const OpSchema& op, const Reso // Number of inputs corresponding to the i-th argument. const int arg_count = node.InputArgCount()[i]; // The i-th formal parameter definition. - auto op_formal_parameter = op.inputs()[i]; + auto const &op_formal_parameter = op.inputs()[i]; // Check all actual parameters (corresponding to the k-th input) // match the formal parameter definition (i-th argument). @@ -2345,7 +2335,7 @@ Status Graph::InferAndVerifyTypeMatch(Node& node, const OpSchema& op, const Reso const int num_formal_params = gsl::narrow_cast(op.outputs().size()); auto operand_index = std::min(i, num_formal_params - 1); - auto op_formal_parameter = op.outputs().at(operand_index); + auto const &op_formal_parameter = op.outputs().at(operand_index); const TypeProto& onnx_inferred_type = onnx_inferred_types[i]; DataType existing_type = output_def->Type(); diff --git a/onnxruntime/core/graph/graph_viewer.cc b/onnxruntime/core/graph/graph_viewer.cc index cf78040ea5ac6..3a4d721466667 100644 --- a/onnxruntime/core/graph/graph_viewer.cc +++ b/onnxruntime/core/graph/graph_viewer.cc @@ -14,8 +14,8 @@ bool NodeCompare::operator()(const Node* n1, const Node* n2) const { struct PriorityNodeCompare { inline bool IsHighPri(const Node* n) const { // local statics so we can compare std::strings in the checks - static const std::string shape_op("Shape"); - static const std::string size_op("Size"); + static constexpr std::string_view shape_op("Shape"); + static constexpr std::string_view size_op("Size"); const auto& op_type = n->OpType(); return op_type == shape_op || op_type == size_op; @@ -36,11 +36,11 @@ struct PriorityNodeCompare { } // nodes of forward pass will be output first - auto n1_attrs = n1->GetAttributes(); - auto n2_attrs = n2->GetAttributes(); - int64_t n1_is_forward = 
static_cast(n1_attrs.find(kBackwardNodeAttributeName) == n1_attrs.cend()) || + auto const& n1_attrs = n1->GetAttributes(); + auto const& n2_attrs = n2->GetAttributes(); + int64_t n1_is_forward = static_cast(n1->isForwardNode()) || (n1_attrs.at(kBackwardNodeAttributeName).i() + 1) % 2; - int64_t n2_is_forward = static_cast(n2_attrs.find(kBackwardNodeAttributeName) == n2_attrs.cend()) || + int64_t n2_is_forward = static_cast(n2->isForwardNode()) || (n2_attrs.at(kBackwardNodeAttributeName).i() + 1) % 2; if (n1_is_forward != n2_is_forward) { return n2_is_forward > n1_is_forward;