diff --git a/include/onnxruntime/core/graph/graph.h b/include/onnxruntime/core/graph/graph.h index b16d52dbdab68..3b417a362d2cc 100644 --- a/include/onnxruntime/core/graph/graph.h +++ b/include/onnxruntime/core/graph/graph.h @@ -621,6 +621,22 @@ class Node { // Reference to the function template defined in the model. const FunctionTemplate* func_template_ = nullptr; + + // set/clear NodeProto that the Node was created from. + // Set by Graph ctor when loading a model from file. + // Cleared after first call to onnx::check_node in VerifyNodeAndOpMatch when the first Graph::Resolve runs. + void SetOriginalNodeProto(const ONNX_NAMESPACE::NodeProto* node_proto) { + original_node_proto_ = node_proto; + } + + const ONNX_NAMESPACE::NodeProto* GetOriginalNodeProto() const { + return original_node_proto_; + } + + // NodeProto that the Node was created from. We temporarily set this as a performance optimization to avoid calling + // Node::ToProto when running onnx::check_node in the first Graph::Resolve. At that point we know all the nodes are + // unchanged from the original model. + const ONNX_NAMESPACE::NodeProto* original_node_proto_ = nullptr; #endif // Execution priority, lower value for higher priority diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index 305122c56b865..2220b9cd1db70 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -2583,9 +2583,17 @@ Status Graph::VerifyNodeAndOpMatch(const ResolveOptions& options) { { auto status = Status::OK(); ORT_TRY { - NodeProto node_proto; - node.ToProto(node_proto); - checker::check_node(node_proto, ctx, lsc); + // if this is first Graph::Resolve call, we may have a NodeProto that was set on the Node so we can skip + // the ToProto call. 
+ if (const NodeProto* orig_node_proto = node.GetOriginalNodeProto(); orig_node_proto) { + checker::check_node(*orig_node_proto, ctx, lsc); + // clear original as we don't know if the node will be modified once the Graph::Resolve completes. + node.SetOriginalNodeProto(nullptr); + } else { + NodeProto node_proto; + node.ToProto(node_proto); + checker::check_node(node_proto, ctx, lsc); + } } ORT_CATCH(const std::exception& ex) { ORT_HANDLE_EXCEPTION([&]() { @@ -3123,13 +3131,25 @@ Node& Graph::AddNode(const NodeProto& node_proto, attributes[attr.name()] = attr; } - return AddNode(node_proto.name(), - node_proto.op_type(), - node_proto.doc_string(), - input_defs, - output_defs, - &attributes, - node_proto.domain()); + Node& new_node = AddNode(node_proto.name(), + node_proto.op_type(), + node_proto.doc_string(), + input_defs, + output_defs, + &attributes, + node_proto.domain()); + + // Perf optimization: temporarily set NodeProto in Node so we don't need to call Node::ToProto prior to + // calling onnx::check_node + // NOTE: We don't handle a node with kOnnxDomainAlias. The entry in schema_registry_ uses kOnnxDomain, + // and that's what onnx::check_node uses during validation. + // The Node ctor automatically converts kOnnxDomainAlias to kOnnxDomain to handle this. + // node_proto is const so we can't do the same here. 
+ if (node_proto.domain() != kOnnxDomainAlias) { + new_node.SetOriginalNodeProto(&node_proto); + } + + return new_node; } static flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> diff --git a/onnxruntime/test/framework/allocation_planner_test.cc b/onnxruntime/test/framework/allocation_planner_test.cc index 3e0d94e94e48c..3a01f2c8d95ad 100644 --- a/onnxruntime/test/framework/allocation_planner_test.cc +++ b/onnxruntime/test/framework/allocation_planner_test.cc @@ -1956,12 +1956,9 @@ TEST_F(PlannerTest, TestCpuIf) { sess_opt.graph_optimization_level = TransformerLevel::Default; InferenceSession sess(sess_opt, GetEnvironment(), ORT_TSTR("./testdata/multi_stream_models/cpu_if.onnx")); - auto status = sess.RegisterExecutionProvider(DefaultCudaExecutionProvider()); - ASSERT_TRUE(status.IsOK()); - status = sess.Load(); - ASSERT_TRUE(status.IsOK()); - status = sess.Initialize(); - ASSERT_TRUE(status.IsOK()); + ASSERT_STATUS_OK(sess.RegisterExecutionProvider(DefaultCudaExecutionProvider())); + ASSERT_STATUS_OK(sess.Load()); + ASSERT_STATUS_OK(sess.Initialize()); auto& sess_state = const_cast<onnxruntime::SessionState&>(sess.GetSessionState()); const auto& exe_plan = sess_state.GetExecutionPlan()->execution_plan; @@ -1971,7 +1968,7 @@ TEST_F(PlannerTest, TestCpuIf) { exe_plan[1]->steps_[7]->GetNodeIndex() == 7) { // there must be a wait before cpu If node static const std::string WaitOnEPStep = "WaitOnEPStep"; - ASSERT_TRUE(exe_plan[1]->steps_[6]->ToString().substr(0, WaitOnEPStep.size()) == WaitOnEPStep); + ASSERT_EQ(exe_plan[1]->steps_[6]->ToString().substr(0, WaitOnEPStep.size()), WaitOnEPStep); } } diff --git a/onnxruntime/test/testdata/multi_stream_models/cpu_if.onnx b/onnxruntime/test/testdata/multi_stream_models/cpu_if.onnx index b9374feff46f9..e97a8bffa7860 100644 Binary files a/onnxruntime/test/testdata/multi_stream_models/cpu_if.onnx and b/onnxruntime/test/testdata/multi_stream_models/cpu_if.onnx differ