Skip to content

Commit

Permalink
Avoid call to Node::ToProto on first Graph::Resolve.
Browse files Browse the repository at this point in the history
Better alternative to #19469
  • Loading branch information
skottmckay committed Apr 12, 2024
1 parent 327fb1f commit 10d526e
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 10 deletions.
16 changes: 16 additions & 0 deletions include/onnxruntime/core/graph/graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -621,6 +621,22 @@ class Node {

// Reference to the function template defined in the model.
const FunctionTemplate* func_template_ = nullptr;

// set/clear NodeProto that the Node was created from.
// Set by Graph ctor when loading a model from file.
// Cleared after first call to onnx::check_node in VerifyNodeAndOpMatch when the first Graph::Resolve runs.
void SetOriginalNodeProto(const ONNX_NAMESPACE::NodeProto* node_proto) {
original_node_proto_ = node_proto;
}

const ONNX_NAMESPACE::NodeProto* GetOriginalNodeProto() const {
return original_node_proto_;
}

// NodeProto that the Node was created from. We temporarily set this as a performance optimization to avoid calling
// Node::ToProto when running onnx::check_node in the first Graph::Resolve. At that point we know all the nodes are
// unchanged from the original model.
const ONNX_NAMESPACE::NodeProto* original_node_proto_ = nullptr;
#endif

// Execution priority, lower value for higher priority
Expand Down
40 changes: 30 additions & 10 deletions onnxruntime/core/graph/graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2583,9 +2583,17 @@ Status Graph::VerifyNodeAndOpMatch(const ResolveOptions& options) {
{
auto status = Status::OK();
ORT_TRY {
NodeProto node_proto;
node.ToProto(node_proto);
checker::check_node(node_proto, ctx, lsc);
// if this is first Graph::Resolve call, we may have a NodeProto that was set on the Node so we can skip
// the ToProto call.
if (const NodeProto* orig_node_proto = node.GetOriginalNodeProto(); orig_node_proto) {
checker::check_node(*orig_node_proto, ctx, lsc);
// clear original as we don't know if the node will be modified once the Graph::Resolve completes.
node.SetOriginalNodeProto(nullptr);
} else {
NodeProto node_proto;
node.ToProto(node_proto);
checker::check_node(node_proto, ctx, lsc);
}
}
ORT_CATCH(const std::exception& ex) {
ORT_HANDLE_EXCEPTION([&]() {
Expand Down Expand Up @@ -3123,13 +3131,25 @@ Node& Graph::AddNode(const NodeProto& node_proto,
attributes[attr.name()] = attr;
}

return AddNode(node_proto.name(),
node_proto.op_type(),
node_proto.doc_string(),
input_defs,
output_defs,
&attributes,
node_proto.domain());
Node& new_node = AddNode(node_proto.name(),
node_proto.op_type(),
node_proto.doc_string(),
input_defs,
output_defs,
&attributes,
node_proto.domain());

// Perf optimization: temporarily set NodeProto in Node so we don't need to call Node::ToProto prior to
// calling onnx::check_node
// NOTE: We don't handle a node with kOnnxDomainAlias. The entry in schema_registry_ uses kOnnxDomain,
// and that's what onnx::check_node uses during validation.
// The Node ctor automatically converts kOnnxDomainAlias to kOnnxDomain to handle this.
// node_proto is const so we can't do the same here.
if (node_proto.domain() != kOnnxDomainAlias) {
new_node.SetOriginalNodeProto(&node_proto);
}

return new_node;
}

static flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>>
Expand Down

0 comments on commit 10d526e

Please sign in to comment.