Skip to content

Commit

Permalink
Avoid call to Node::ToProto on first Graph::Resolve to improve sessio…
Browse files Browse the repository at this point in the history
…n creation performance. (microsoft#20296)

### Description
<!-- Describe your changes. -->
The first call to Graph::Resolve occurs when creating the Graph instance
when loading an existing model from ModelProto. As the Node instance
will exactly match the source NodeProto there's no need to call
Node::ToProto in this case.

Add a temporary reference to the original NodeProto to avoid the call on
the first Graph::Resolve.

### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
Better alternative to microsoft#19469
  • Loading branch information
skottmckay authored and Ted Themistokleous committed May 7, 2024
1 parent 78f148a commit c3c54d0
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 17 deletions.
16 changes: 16 additions & 0 deletions include/onnxruntime/core/graph/graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -621,6 +621,22 @@ class Node {

// Reference to the function template defined in the model.
const FunctionTemplate* func_template_ = nullptr;

// set/clear NodeProto that the Node was created from.
// Set by Graph ctor when loading a model from file.
// Cleared after first call to onnx::check_node in VerifyNodeAndOpMatch when the first Graph::Resolve runs.
void SetOriginalNodeProto(const ONNX_NAMESPACE::NodeProto* node_proto) {
original_node_proto_ = node_proto;
}

const ONNX_NAMESPACE::NodeProto* GetOriginalNodeProto() const {
return original_node_proto_;
}

// NodeProto that the Node was created from. We temporarily set this as a performance optimization to avoid calling
// Node::ToProto when running onnx::check_node in the first Graph::Resolve. At that point we know all the nodes are
// unchanged from the original model.
const ONNX_NAMESPACE::NodeProto* original_node_proto_ = nullptr;
#endif

// Execution priority, lower value for higher priority
Expand Down
40 changes: 30 additions & 10 deletions onnxruntime/core/graph/graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2583,9 +2583,17 @@ Status Graph::VerifyNodeAndOpMatch(const ResolveOptions& options) {
{
auto status = Status::OK();
ORT_TRY {
NodeProto node_proto;
node.ToProto(node_proto);
checker::check_node(node_proto, ctx, lsc);
// if this is first Graph::Resolve call, we may have a NodeProto that was set on the Node so we can skip
// the ToProto call.
if (const NodeProto* orig_node_proto = node.GetOriginalNodeProto(); orig_node_proto) {
checker::check_node(*orig_node_proto, ctx, lsc);
// clear original as we don't know if the node will be modified once the Graph::Resolve completes.
node.SetOriginalNodeProto(nullptr);
} else {
NodeProto node_proto;
node.ToProto(node_proto);
checker::check_node(node_proto, ctx, lsc);
}
}
ORT_CATCH(const std::exception& ex) {
ORT_HANDLE_EXCEPTION([&]() {
Expand Down Expand Up @@ -3123,13 +3131,25 @@ Node& Graph::AddNode(const NodeProto& node_proto,
attributes[attr.name()] = attr;
}

return AddNode(node_proto.name(),
node_proto.op_type(),
node_proto.doc_string(),
input_defs,
output_defs,
&attributes,
node_proto.domain());
Node& new_node = AddNode(node_proto.name(),
node_proto.op_type(),
node_proto.doc_string(),
input_defs,
output_defs,
&attributes,
node_proto.domain());

// Perf optimization: temporarily set NodeProto in Node so we don't need to call Node::ToProto prior to
// calling onnx::check_node
// NOTE: We don't handle a node with kOnnxDomainAlias. The entry in schema_registry_ uses kOnnxDomain,
// and that's what onnx::check_node uses during validation.
// The Node ctor automatically converts kOnnxDomainAlias to kOnnxDomain to handle this.
// node_proto is const so we can't do the same here.
if (node_proto.domain() != kOnnxDomainAlias) {
new_node.SetOriginalNodeProto(&node_proto);
}

return new_node;
}

static flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>>
Expand Down
11 changes: 4 additions & 7 deletions onnxruntime/test/framework/allocation_planner_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1956,12 +1956,9 @@ TEST_F(PlannerTest, TestCpuIf) {
sess_opt.graph_optimization_level = TransformerLevel::Default;

InferenceSession sess(sess_opt, GetEnvironment(), ORT_TSTR("./testdata/multi_stream_models/cpu_if.onnx"));
auto status = sess.RegisterExecutionProvider(DefaultCudaExecutionProvider());
ASSERT_TRUE(status.IsOK());
status = sess.Load();
ASSERT_TRUE(status.IsOK());
status = sess.Initialize();
ASSERT_TRUE(status.IsOK());
ASSERT_STATUS_OK(sess.RegisterExecutionProvider(DefaultCudaExecutionProvider()));
ASSERT_STATUS_OK(sess.Load());
ASSERT_STATUS_OK(sess.Initialize());

auto& sess_state = const_cast<onnxruntime::SessionState&>(sess.GetSessionState());
const auto& exe_plan = sess_state.GetExecutionPlan()->execution_plan;
Expand All @@ -1971,7 +1968,7 @@ TEST_F(PlannerTest, TestCpuIf) {
exe_plan[1]->steps_[7]->GetNodeIndex() == 7) {
// there must be a wait before cpu If node
static const std::string WaitOnEPStep = "WaitOnEPStep";
ASSERT_TRUE(exe_plan[1]->steps_[6]->ToString().substr(0, WaitOnEPStep.size()) == WaitOnEPStep);
ASSERT_EQ(exe_plan[1]->steps_[6]->ToString().substr(0, WaitOnEPStep.size()), WaitOnEPStep);
}
}

Expand Down
Binary file modified onnxruntime/test/testdata/multi_stream_models/cpu_if.onnx
Binary file not shown.

0 comments on commit c3c54d0

Please sign in to comment.