Skip to content

Commit

Permalink
Optimize onnxruntime::InferenceSession::Initialize with focus on Grap…
Browse files Browse the repository at this point in the history
…hViewer. For large models the speedup of this function can be up to 3x.
  • Loading branch information
mtavenrath committed Jan 10, 2024
1 parent 5f3113e commit 1f7eea4
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 49 deletions.
27 changes: 19 additions & 8 deletions include/onnxruntime/core/graph/graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -294,17 +294,21 @@ class Node {
Class to provide const access to Node instances iterated via an EdgeConstIterator. */
class NodeConstIterator {
public:
NodeConstIterator(EdgeConstIterator p_iter);
NodeConstIterator(EdgeConstIterator p_iter) { m_iter = p_iter; }

bool operator==(const NodeConstIterator& p_other) const;
bool operator==(const NodeConstIterator& p_other) const {
return m_iter == p_other.m_iter;
}

bool operator!=(const NodeConstIterator& p_other) const;
bool operator!=(const NodeConstIterator& p_other) const {
return m_iter != p_other.m_iter;
}

void operator++();
void operator--();
void operator++() { ++m_iter; }
void operator--() { --m_iter; }

const Node& operator*() const;
const Node* operator->() const;
const Node& operator*() const { return (*m_iter).GetNode(); }
const Node* operator->() const { return &(operator*()); };

private:
EdgeConstIterator m_iter;
Expand Down Expand Up @@ -394,6 +398,9 @@ class Node {
/** Gets the Node's attributes. */
const NodeAttributes& GetAttributes() const noexcept { return attributes_; }

/** @returns true if the Node is a forward node, false otherwise. **/
bool isForwardNode() const noexcept { return isForwardNode_; }

#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
/** Remove the specified attribute from this Node */
bool ClearAttribute(const std::string& attr_name);
Expand Down Expand Up @@ -457,7 +464,7 @@ class Node {
std::unordered_map<std::string, gsl::not_null<const Graph*>> GetAttributeNameToSubgraphMap() const;

/** Gets the execution ProviderType that this node will be executed by. */
ProviderType GetExecutionProviderType() const noexcept { return execution_provider_type_; }
ProviderType const& GetExecutionProviderType() const noexcept { return execution_provider_type_; }

/** Sets the execution ProviderType that this Node will be executed by. */
void SetExecutionProviderType(ProviderType execution_provider_type) {
Expand Down Expand Up @@ -626,6 +633,10 @@ class Node {
// Execution priority, lower value for higher priority
int priority_ = 0;

// True is Node is a forwardNode and thus doesn't contain a attribute
// named kBackwardNodeAttributeName. False otherwise.
bool isForwardNode_;

// set from op_->SinceVersion() or via deserialization when OpSchema is not available
int since_version_ = -1;

Expand Down
60 changes: 25 additions & 35 deletions onnxruntime/core/graph/graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -528,34 +528,6 @@ Node::EdgeEnd::EdgeEnd(const Node& node) noexcept
: EdgeEnd(node, INT_MAX, INT_MAX) {
}

Node::NodeConstIterator::NodeConstIterator(EdgeConstIterator p_iter) {
m_iter = p_iter;
}

bool Node::NodeConstIterator::operator==(const NodeConstIterator& p_other) const {
return m_iter == p_other.m_iter;
}

bool Node::NodeConstIterator::operator!=(const NodeConstIterator& p_other) const {
return m_iter != p_other.m_iter;
}

void Node::NodeConstIterator::operator++() {
++m_iter;
}

void Node::NodeConstIterator::operator--() {
--m_iter;
}

const Node& Node::NodeConstIterator::operator*() const {
return (*m_iter).GetNode();
}

const Node* Node::NodeConstIterator::operator->() const {
return &(operator*());
}

void Node::SetPriority(int priority) noexcept {
priority_ = priority;
}
Expand Down Expand Up @@ -878,6 +850,7 @@ void Node::Init(std::string_view name,
gsl::span<NodeArg* const> output_args,
const NodeAttributes* attributes,
std::string_view domain) {
isForwardNode_ = true;
name_ = name;
op_type_ = op_type;
description_ = description;
Expand All @@ -898,7 +871,12 @@ void Node::Init(std::string_view name,
if (attributes) {
attributes_ = *attributes;

isForwardNode_ = true;
for (auto& name_to_attr : attributes_) {
if (!isForwardNode_ && name_to_attr.first == kBackwardNodeAttributeName) {
isForwardNode_ = false;
}

if (utils::HasGraph(name_to_attr.second)) {
#if !defined(ORT_MINIMAL_BUILD)
CreateSubgraph(name_to_attr.first);
Expand Down Expand Up @@ -942,6 +920,9 @@ void Node::CreateSubgraph(const std::string& attr_name) {
#endif // !defined(ORT_MINIMAL_BUILD)

void Node::AddAttributeProto(AttributeProto value) {
if (value.name() == kBackwardNodeAttributeName) {
isForwardNode_ = false;
}
utils::SetNodeAttribute(std::move(value), attributes_);
if (graph_) {
graph_->SetGraphResolveNeeded();
Expand Down Expand Up @@ -978,6 +959,7 @@ ADD_ATTR_IMPLS(TypeProto)
#undef ADD_ATTR_LIST_IMPL
#undef ADD_ATTR_IMPLS

// TODO why isn't attr_name a const&?
void Node::AddAttribute(std::string attr_name, GraphProto value) {
// Do not move attr_name as it is needed below
AttributeProto a = utils::MakeAttribute(attr_name, std::move(value));
Expand All @@ -993,7 +975,11 @@ void Node::AddAttribute(std::string attr_name, GraphProto value) {
bool Node::ClearAttribute(const std::string& attr_name) {
graph_->SetGraphResolveNeeded();
graph_->SetGraphProtoSyncNeeded();
return attributes_.erase(attr_name) > 0;
size_t erased = attributes_.erase(attr_name);
if (erased && attr_name == kBackwardNodeAttributeName) {
isForwardNode_ = true;
}
return erased > 0;
}

#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
Expand All @@ -1003,7 +989,11 @@ int Node::PruneRemovableAttributes(gsl::span<const std::string> removable_attrib
graph_->SetGraphProtoSyncNeeded();
int n_removed = 0;
for (const auto& name : removable_attributes) {
n_removed += static_cast<int>(attributes_.erase(name));
bool erased = attributes_.erase(name);
if (erased && name == kBackwardNodeAttributeName) {
isForwardNode_ = true;
}
n_removed += static_cast<int>(erased);
}
can_be_saved_ = can_be_saved_ && n_removed == 0;
return n_removed;
Expand Down Expand Up @@ -1821,13 +1811,13 @@ void Graph::ReverseDFSFrom(gsl::span<const Node* const> from,
#if !defined(ORT_MINIMAL_BUILD)
void Graph::KahnsTopologicalSort(const std::function<void(const Node*)>& enter,
const std::function<bool(const Node*, const Node*)>& comp) const {
std::unordered_map<NodeIndex, size_t> in_degree;
std::vector<size_t> in_degree(MaxNodeIndex(), 0);
std::priority_queue<const Node*, std::vector<const Node*>, decltype(comp)> to_visit(comp);
std::vector<NodeIndex> topo_order;

for (auto& node : Nodes()) {
size_t input_edge_count = node.GetInputEdgesCount();
in_degree.insert({node.Index(), input_edge_count});
in_degree[node.Index()] = input_edge_count;
if (input_edge_count == 0) {
to_visit.push(&node);
}
Expand Down Expand Up @@ -2044,7 +2034,7 @@ class InferenceContextImpl : public ONNX_NAMESPACE::InferenceContext {
}
}

std::vector<TypeProto> InferredOutputTypes() const { return node_output_types_; }
std::vector<TypeProto> const& InferredOutputTypes() const { return node_output_types_; }

const AttributeProto* getAttribute(const std::string& name) const override {
auto& attribute_value_map = node_.GetAttributes();
Expand Down Expand Up @@ -2240,7 +2230,7 @@ Status Graph::InferAndVerifyTypeMatch(Node& node, const OpSchema& op, const Reso
// Number of inputs corresponding to the i-th argument.
const int arg_count = node.InputArgCount()[i];
// The i-th formal parameter definition.
auto op_formal_parameter = op.inputs()[i];
auto const &op_formal_parameter = op.inputs()[i];

// Check all <arg_count> actual parameters (corresponding to the k-th input)
// match the formal parameter definition (i-th argument).
Expand Down Expand Up @@ -2345,7 +2335,7 @@ Status Graph::InferAndVerifyTypeMatch(Node& node, const OpSchema& op, const Reso

const int num_formal_params = gsl::narrow_cast<int>(op.outputs().size());
auto operand_index = std::min(i, num_formal_params - 1);
auto op_formal_parameter = op.outputs().at(operand_index);
auto const &op_formal_parameter = op.outputs().at(operand_index);

const TypeProto& onnx_inferred_type = onnx_inferred_types[i];
DataType existing_type = output_def->Type();
Expand Down
12 changes: 6 additions & 6 deletions onnxruntime/core/graph/graph_viewer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ bool NodeCompare::operator()(const Node* n1, const Node* n2) const {
struct PriorityNodeCompare {
inline bool IsHighPri(const Node* n) const {
// local statics so we can compare std::strings in the checks
static const std::string shape_op("Shape");
static const std::string size_op("Size");
static constexpr std::string_view shape_op("Shape");
static constexpr std::string_view size_op("Size");

const auto& op_type = n->OpType();
return op_type == shape_op || op_type == size_op;
Expand All @@ -36,11 +36,11 @@ struct PriorityNodeCompare {
}

// nodes of forward pass will be output first
auto n1_attrs = n1->GetAttributes();
auto n2_attrs = n2->GetAttributes();
int64_t n1_is_forward = static_cast<int64_t>(n1_attrs.find(kBackwardNodeAttributeName) == n1_attrs.cend()) ||
auto const& n1_attrs = n1->GetAttributes();
auto const& n2_attrs = n2->GetAttributes();
int64_t n1_is_forward = static_cast<int64_t>(n1->isForwardNode()) ||
(n1_attrs.at(kBackwardNodeAttributeName).i() + 1) % 2;
int64_t n2_is_forward = static_cast<int64_t>(n2_attrs.find(kBackwardNodeAttributeName) == n2_attrs.cend()) ||
int64_t n2_is_forward = static_cast<int64_t>(n2->isForwardNode()) ||
(n2_attrs.at(kBackwardNodeAttributeName).i() + 1) % 2;
if (n1_is_forward != n2_is_forward) {
return n2_is_forward > n1_is_forward;
Expand Down

0 comments on commit 1f7eea4

Please sign in to comment.