
Optimize onnxruntime::InferenceSession::Initialize with focus on GraphViewer. For large models the speedup of this function can be up to 3x. #19080

Closed

Changes from 1 commit
27 changes: 19 additions & 8 deletions include/onnxruntime/core/graph/graph.h
@@ -294,17 +294,21 @@ class Node {
Class to provide const access to Node instances iterated via an EdgeConstIterator. */
class NodeConstIterator {
public:
NodeConstIterator(EdgeConstIterator p_iter);
NodeConstIterator(EdgeConstIterator p_iter) { m_iter = p_iter; }

bool operator==(const NodeConstIterator& p_other) const;
bool operator==(const NodeConstIterator& p_other) const {
return m_iter == p_other.m_iter;
}

bool operator!=(const NodeConstIterator& p_other) const;
bool operator!=(const NodeConstIterator& p_other) const {
return m_iter != p_other.m_iter;
}

void operator++();
void operator--();
void operator++() { ++m_iter; }
void operator--() { --m_iter; }

const Node& operator*() const;
const Node* operator->() const;
const Node& operator*() const { return (*m_iter).GetNode(); }
const Node* operator->() const { return &(operator*()); };

private:
EdgeConstIterator m_iter;
@@ -394,6 +398,9 @@ class Node {
/** Gets the Node's attributes. */
const NodeAttributes& GetAttributes() const noexcept { return attributes_; }

/** @returns true if the Node is a forward node, false otherwise. **/
bool isForwardNode() const noexcept { return isForwardNode_; }
Contributor:

Can the default compiled binaries even train a model, or is a specialized compilation needed?

@skottmckay (Contributor), Jan 18, 2024:

Is this training-specific? Could the additional code to track whether it's a forward node or not be inside #if defined(ENABLE_TRAINING)?

nit: IsForwardNode/is_forward_node_ would be consistent with the coding standards.

@mtavenrath (Contributor, Author):

This is for the PriorityNodeCompare class in graph_viewer.cc.

Reading graph_viewer and this comment

    auto const& n1_attrs = n1->GetAttributes();
    auto const& n2_attrs = n2->GetAttributes();
    int64_t n1_is_forward = static_cast<int64_t>(n1->isForwardNode()) ||

makes me curious: is this purely for visualization? The graph_viewer is used by a lot of operations in TransformGraph. Does 'will be output first' mean for printing, or is it really required for the graph transformation?

Is the information about training available at runtime as well? In that case we could pass the information to the graph viewer and skip this expensive portion of code.

Researching further into avoiding hash computation, it would probably be best to have a special key type which precomputes the hash to avoid the hash value computation, e.g.

https://www.open-std.org/jtc1/sc22/wg21/docs/papers/2019/p1661r1.html
https://www.open-std.org/jtc1/sc22/wg21/docs/papers/2019/p0920r2.html
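A minimal sketch of that idea, with hypothetical names (not ORT or standard-library API): a key type that computes its string hash once at construction, paired with a hasher that just returns the stored value, so repeated lookups for a constant key like kBackwardNodeAttributeName never rehash the string.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <string>
#include <unordered_map>

// Hypothetical precomputed-hash key; illustrative only, not actual ORT code.
struct PrehashedKey {
  std::string value;
  std::size_t hash;
  // Members initialize in declaration order, so `value` is valid when hashed.
  explicit PrehashedKey(std::string v)
      : value(std::move(v)), hash(std::hash<std::string>{}(value)) {}
  bool operator==(const PrehashedKey& other) const { return value == other.value; }
};

struct PrehashedKeyHash {
  // Returns the stored hash instead of rehashing the string on every lookup.
  std::size_t operator()(const PrehashedKey& k) const { return k.hash; }
};

using AttrMap = std::unordered_map<PrehashedKey, int, PrehashedKeyHash>;
```

A static const PrehashedKey for the backward-pass attribute name would then amortize the hash across every find/at call.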


#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
/** Remove the specified attribute from this Node */
bool ClearAttribute(const std::string& attr_name);
@@ -457,7 +464,7 @@ class Node {
std::unordered_map<std::string, gsl::not_null<const Graph*>> GetAttributeNameToSubgraphMap() const;

/** Gets the execution ProviderType that this node will be executed by. */
ProviderType GetExecutionProviderType() const noexcept { return execution_provider_type_; }
ProviderType const& GetExecutionProviderType() const noexcept { return execution_provider_type_; }

/** Sets the execution ProviderType that this Node will be executed by. */
void SetExecutionProviderType(ProviderType execution_provider_type) {
@@ -626,6 +633,10 @@ class Node {
// Execution priority, lower value for higher priority
int priority_ = 0;

// True if the Node is a forward node and thus doesn't contain an attribute
// named kBackwardNodeAttributeName. False otherwise.
bool isForwardNode_;

// set from op_->SinceVersion() or via deserialization when OpSchema is not available
int since_version_ = -1;

60 changes: 25 additions & 35 deletions onnxruntime/core/graph/graph.cc
@@ -528,34 +528,6 @@ Node::EdgeEnd::EdgeEnd(const Node& node) noexcept
: EdgeEnd(node, INT_MAX, INT_MAX) {
}

Node::NodeConstIterator::NodeConstIterator(EdgeConstIterator p_iter) {
m_iter = p_iter;
}

bool Node::NodeConstIterator::operator==(const NodeConstIterator& p_other) const {
return m_iter == p_other.m_iter;
}

bool Node::NodeConstIterator::operator!=(const NodeConstIterator& p_other) const {
return m_iter != p_other.m_iter;
}

void Node::NodeConstIterator::operator++() {
++m_iter;
}

void Node::NodeConstIterator::operator--() {
--m_iter;
}

const Node& Node::NodeConstIterator::operator*() const {
return (*m_iter).GetNode();
}

const Node* Node::NodeConstIterator::operator->() const {
return &(operator*());
}

void Node::SetPriority(int priority) noexcept {
priority_ = priority;
}
@@ -878,6 +850,7 @@ void Node::Init(std::string_view name,
gsl::span<NodeArg* const> output_args,
const NodeAttributes* attributes,
std::string_view domain) {
isForwardNode_ = true;
@mtavenrath (Contributor, Author), Jan 10, 2024:

Updating the forward node on each change is kind of ugly and risky. With regard to performance it should be unproblematic, since the string compare is quite efficient, doing a length check first before actually comparing bytes.

The reason for caching isForwardNode_ is that each check for kBackwardNodeAttributeName within PriorityNodeCompare actually computes the hash of the string, which is the costly part.

Alternate solutions would change the container, precompute the hash of kBackwardNodeAttributeName, and incorporate it into the key. The downside of this solution is that it would work only for cases where collisions are handled by lists instead of multiple iterations of hash keys.

Member:

isForwardNode_ = true;

Seems to be a duplicate of the same below.

name_ = name;
op_type_ = op_type;
description_ = description;
@@ -898,7 +871,12 @@
if (attributes) {
attributes_ = *attributes;

isForwardNode_ = true;
@mtavenrath (Contributor, Author), Jan 18, 2024:

I am not particularly happy tracking kBackwardNodeAttributeName at multiple places. Is there a chance to have a common pair of functions to add/remove nodes, to minimize breaking changes in the future?

for (auto& name_to_attr : attributes_) {
if (isForwardNode_ && name_to_attr.first == kBackwardNodeAttributeName) {
isForwardNode_ = false;
}

if (utils::HasGraph(name_to_attr.second)) {
#if !defined(ORT_MINIMAL_BUILD)
CreateSubgraph(name_to_attr.first);
@@ -942,6 +920,9 @@ void Node::CreateSubgraph(const std::string& attr_name) {
#endif // !defined(ORT_MINIMAL_BUILD)

void Node::AddAttributeProto(AttributeProto value) {
if (value.name() == kBackwardNodeAttributeName) {
isForwardNode_ = false;
}
utils::SetNodeAttribute(std::move(value), attributes_);
if (graph_) {
graph_->SetGraphResolveNeeded();
@@ -978,6 +959,7 @@ ADD_ATTR_IMPLS(TypeProto)
#undef ADD_ATTR_LIST_IMPL
#undef ADD_ATTR_IMPLS

// TODO why isn't attr_name a const&?
void Node::AddAttribute(std::string attr_name, GraphProto value) {
// Do not move attr_name as it is needed below
AttributeProto a = utils::MakeAttribute(attr_name, std::move(value));
@@ -993,7 +975,11 @@ void Node::AddAttribute(std::string attr_name, GraphProto value) {
bool Node::ClearAttribute(const std::string& attr_name) {
graph_->SetGraphResolveNeeded();
graph_->SetGraphProtoSyncNeeded();
return attributes_.erase(attr_name) > 0;
size_t erased = attributes_.erase(attr_name);
if (erased && attr_name == kBackwardNodeAttributeName) {
isForwardNode_ = true;
}
return erased > 0;
@skottmckay (Contributor), Jan 18, 2024:

I'm not familiar with the training code and the usage of this attribute. When is the attribute added, and once added is it actually ever removed?

@mtavenrath (Contributor, Author):

I'm not familiar with training either. IMHO it doesn't matter whether this attribute gets removed or not in normal workflows. What matters is that it can be removed, and not checking for removal would change the behaviour and potentially even introduce bugs.

Contributor:

In this case it doesn't seem like a standard attribute, so I would like to understand how it's used before adding more stuff to the production code to handle it theoretically changing. There's a binary size and maintenance cost, and if it's a purely internal value with specific usage it may be better to validate the usage remains as expected via unit tests.

Based on this GitHub code search, it has a very internal name, only seems to be set in a training optimizer, and only seems to be read in the graph_viewer code.

Unless there's some external usage of this magic value outside of ORT, it seems like it would be simpler for a Node to have a bool member that is directly set by the training optimizer instead of the indirect costly usage of a specially named attribute.

@askhade do you know if this special value is used outside of ORT and must be set in the Node attributes?

@askhade (Contributor):

I checked the code and IMO it is OK to make this a bool member of the Node class instead of making it a named attribute. There is widespread usage of this in the code so the change should not be very cumbersome. @mtavenrath let me know if you have any questions regarding this. Tagging @pengwa to validate this.

@mtavenrath what timeline are you targeting for this change? We may need a couple of days to get this reviewed from Peng.

Contributor:

One note - when finding usage to update you need to search for both the kBackwardNodeAttributeName constant as well as the string "__backwardpass". Ideally we can make all places use the constant.

https://github.com/search?q=repo%3Amicrosoft%2Fonnxruntime%20kBackwardNodeAttributeName&type=code
https://github.com/search?q=repo%3Amicrosoft%2Fonnxruntime%20__backwardpass&type=code
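The bool-member alternative suggested above could look roughly like this; the names are illustrative only, not the actual Node interface:

```cpp
#include <cassert>

// Hypothetical sketch of the suggestion: a plain bool on the node that the
// training optimizer sets directly, replacing the magic "__backwardpass"
// attribute and its per-lookup string-hash cost.
class NodeSketch {
 public:
  void SetForwardPass(bool is_forward) noexcept { is_forward_ = is_forward; }
  bool IsForwardPass() const noexcept { return is_forward_; }

 private:
  bool is_forward_ = true;  // every node starts in the forward pass
};
```

The training optimizer that currently writes the attribute would call SetForwardPass(false) instead, and PriorityNodeCompare would read IsForwardPass() directly.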

@mtavenrath (Contributor, Author):

@askhade I'm happy with any timeline, as long as it is not a month+ out.

@pengwa (Contributor), Feb 18, 2024:

(Quoting @askhade's comment above.)

FYI @askhade @skottmckay Yes, it is a purely internally used attribute, and

  1. it was first introduced as a backward-data-range-specific perf improvement in https://github.com/microsoft/onnxruntime/blame/dfeda9019cfed2d6df5bcacc54269c7de481bdee/onnxruntime/core/providers/rocm/rocm_kernel.h#L29.

  2. A second usage pattern: in priority-based topo ordering in the training code path, we consider a backward-tagged node to have lower priority than a forward node:

    Status ResetNodeBackwardPassAttribute(Graph& graph, bool& modified) {
    .

Both usages have similar logic when setting the backward pass to true, e.g.:

  for (auto node_index : node_topology_list) {
    auto& node = *graph.GetNode(node_index);

    if (node.OpType() == "YieldOp") {
      is_backward_pass = true;
    }

in d5d6924#diff-8d8d103ec215ba8edb8ab23e876080adfd60f6f377084ffeca041c8b4f189a2cR13

and

// Find the YieldOp node.
  Node* yield_op_node = nullptr;
  for (auto& node : graph.Nodes()) {
    if (node.OpType() == "YieldOp") {
      yield_op_node = &node;
      break;
    }
  }

in

Status ResetNodeBackwardPassAttribute(Graph& graph, bool& modified) {

Plus, "YieldOp" is used in ORTModule (e.g. --enable_training) builds only. So I think it is possible to restrict the forward getter/setter to the ENABLE_TRAINING macro. This may need some change in onnxruntime/core/session/provider_bridge_ort.cc to wrap the new bool property (not sure whether you tried building/running your local code with ROCM, but I feel that change would be needed to make the ROCM EP code work).

@pengwa (Contributor), Feb 18, 2024:

One thing I have noticed for a while: both training and inference builds can use priority-based order. It is indeed needed for some training features, while I don't know how much value it brings for model inferencing.

If most inference users don't have such a need, maybe we can load nodes_in_topological_order_with_priority_ in lazy mode, e.g. only initialize it the first time a user needs it via GetNodesInTopologicalOrder(ExecutionOrder::PRIORITY_BASED).

@skottmckay @askhade
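A rough sketch of that lazy-mode idea, with illustrative names (not the real Graph members): the priority-based order is computed only on first access and cached afterwards.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical sketch: the priority-based topological order is computed on
// first request instead of unconditionally during graph resolution.
class LazyPriorityOrder {
 public:
  const std::vector<std::size_t>& Get() {
    // Single-threaded sketch; real code would need a thread-safety guard.
    if (!computed_) {
      order_ = Compute();
      computed_ = true;
    }
    return order_;
  }

 private:
  static std::vector<std::size_t> Compute() {
    // Stand-in for the real priority-based Kahn's sort over the graph.
    return {0, 1, 2};
  }
  bool computed_ = false;
  std::vector<std::size_t> order_;
};
```

Inference sessions that never request ExecutionOrder::PRIORITY_BASED would then skip the sort entirely.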

}

#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
@@ -1003,7 +989,11 @@ int Node::PruneRemovableAttributes(gsl::span<const std::string> removable_attrib
graph_->SetGraphProtoSyncNeeded();
int n_removed = 0;
for (const auto& name : removable_attributes) {
n_removed += static_cast<int>(attributes_.erase(name));
bool erased = attributes_.erase(name);
if (erased && name == kBackwardNodeAttributeName) {
isForwardNode_ = true;
}
n_removed += static_cast<int>(erased);
}
can_be_saved_ = can_be_saved_ && n_removed == 0;
return n_removed;
@@ -1821,13 +1811,13 @@ void Graph::ReverseDFSFrom(gsl::span<const Node* const> from,
#if !defined(ORT_MINIMAL_BUILD)
void Graph::KahnsTopologicalSort(const std::function<void(const Node*)>& enter,
const std::function<bool(const Node*, const Node*)>& comp) const {
std::unordered_map<NodeIndex, size_t> in_degree;
std::vector<size_t> in_degree(MaxNodeIndex(), 0);
Member:

std::vector<size_t>

Prefer InlinedVector

std::priority_queue<const Node*, std::vector<const Node*>, decltype(comp)> to_visit(comp);
std::vector<NodeIndex> topo_order;

for (auto& node : Nodes()) {
size_t input_edge_count = node.GetInputEdgesCount();
in_degree.insert({node.Index(), input_edge_count});
in_degree[node.Index()] = input_edge_count;
if (input_edge_count == 0) {
to_visit.push(&node);
}
@@ -2044,7 +2034,7 @@ class InferenceContextImpl : public ONNX_NAMESPACE::InferenceContext {
}
}

std::vector<TypeProto> InferredOutputTypes() const { return node_output_types_; }
std::vector<TypeProto> const& InferredOutputTypes() const { return node_output_types_; }

const AttributeProto* getAttribute(const std::string& name) const override {
auto& attribute_value_map = node_.GetAttributes();
@@ -2240,7 +2230,7 @@ Status Graph::InferAndVerifyTypeMatch(Node& node, const OpSchema& op, const Reso
// Number of inputs corresponding to the i-th argument.
const int arg_count = node.InputArgCount()[i];
// The i-th formal parameter definition.
auto op_formal_parameter = op.inputs()[i];
auto const &op_formal_parameter = op.inputs()[i];

// Check all <arg_count> actual parameters (corresponding to the k-th input)
// match the formal parameter definition (i-th argument).
@@ -2345,7 +2335,7 @@

const int num_formal_params = gsl::narrow_cast<int>(op.outputs().size());
auto operand_index = std::min(i, num_formal_params - 1);
auto op_formal_parameter = op.outputs().at(operand_index);
auto const &op_formal_parameter = op.outputs().at(operand_index);

const TypeProto& onnx_inferred_type = onnx_inferred_types[i];
DataType existing_type = output_def->Type();
12 changes: 6 additions & 6 deletions onnxruntime/core/graph/graph_viewer.cc
@@ -14,8 +14,8 @@ bool NodeCompare::operator()(const Node* n1, const Node* n2) const {
struct PriorityNodeCompare {
inline bool IsHighPri(const Node* n) const {
// local statics so we can compare std::strings in the checks
static const std::string shape_op("Shape");
static const std::string size_op("Size");
static constexpr std::string_view shape_op("Shape");
static constexpr std::string_view size_op("Size");

const auto& op_type = n->OpType();
return op_type == shape_op || op_type == size_op;
@@ -36,11 +36,11 @@ struct PriorityNodeCompare {
}

// nodes of forward pass will be output first
auto n1_attrs = n1->GetAttributes();
auto n2_attrs = n2->GetAttributes();
int64_t n1_is_forward = static_cast<int64_t>(n1_attrs.find(kBackwardNodeAttributeName) == n1_attrs.cend()) ||
auto const& n1_attrs = n1->GetAttributes();
auto const& n2_attrs = n2->GetAttributes();
int64_t n1_is_forward = static_cast<int64_t>(n1->isForwardNode()) ||
@mtavenrath (Contributor, Author):

This change shaves off nearly 2s from the original 20s of TransformGraph and will be maximally efficient when loading models for inference.

For training models the hash computation is still done once. Before, the hash was potentially computed twice: once for find and once for at. By fetching the iterator from std::find and using it later to get the value, the number of hash computations could be halved (saving ~1s instead of 2s).

If it's known that the model is used for inference only, it would be great if PriorityNodeCompare could skip this test altogether. This could be achieved most efficiently by making this a template class and specializing it for inference / training.
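One possible shape for that specialization, using a stand-in node type (hypothetical names, not ORT code): an if constexpr branch lets the inference instantiation drop the forward/backward check at compile time.

```cpp
#include <cassert>

// Stand-in for onnxruntime::Node; only the fields the comparator needs.
struct NodeLike {
  int priority = 0;      // lower value means higher priority
  bool is_forward = true;
};

// Hypothetical sketch: comp(a, b) == true means a is ordered after b,
// matching the std::priority_queue convention used in the topological sort.
template <bool IsTraining>
struct PriorityNodeCompareT {
  bool operator()(const NodeLike* n1, const NodeLike* n2) const {
    if (n1->priority != n2->priority) {
      return n1->priority > n2->priority;
    }
    if constexpr (IsTraining) {
      // Only the training instantiation pays for the forward/backward check.
      if (n1->is_forward != n2->is_forward) {
        return n2->is_forward;  // forward nodes are output first
      }
    }
    return false;  // equal priority: keep existing order
  }
};
```

The session would pick PriorityNodeCompareT&lt;false&gt; for inference-only models, so the branch (and any attribute/flag access) compiles away entirely.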

(n1_attrs.at(kBackwardNodeAttributeName).i() + 1) % 2;
int64_t n2_is_forward = static_cast<int64_t>(n2_attrs.find(kBackwardNodeAttributeName) == n2_attrs.cend()) ||
int64_t n2_is_forward = static_cast<int64_t>(n2->isForwardNode()) ||
(n2_attrs.at(kBackwardNodeAttributeName).i() + 1) % 2;
if (n1_is_forward != n2_is_forward) {
return n2_is_forward > n1_is_forward;