diff --git a/onnxruntime/core/graph/graph_proto_serializer.cc b/onnxruntime/core/graph/graph_proto_serializer.cc index 5b252b5896d23..c8da2461a3e3c 100644 --- a/onnxruntime/core/graph/graph_proto_serializer.cc +++ b/onnxruntime/core/graph/graph_proto_serializer.cc @@ -8,7 +8,8 @@ namespace onnxruntime { void GraphViewerToProto(const GraphViewer& graph_view, ONNX_NAMESPACE::GraphProto& graph_proto, bool include_initializer, - bool include_outer_scope_args) { + bool include_outer_scope_args, + ExecutionOrder order = ExecutionOrder::DEFAULT) { graph_proto.set_name(graph_view.Name()); graph_proto.set_doc_string(graph_view.Description()); @@ -34,7 +35,7 @@ void GraphViewerToProto(const GraphViewer& graph_view, } // Nodes must be sorted in Topological Order in the GraphProto per ONNX spec. - for (auto& node_idx : graph_view.GetNodesInTopologicalOrder()) { + for (auto& node_idx : graph_view.GetNodesInTopologicalOrder(order)) { const gsl::not_null node_proto{graph_proto.add_node()}; const gsl::not_null p_node{graph_view.GetNode(node_idx)}; // we need to update any GraphProto attributes for subgraphs so that any changes made by things @@ -62,7 +63,7 @@ void GraphViewerToProto(const GraphViewer& graph_view, // handle outer scope value which is a constant initializer if (include_outer_scope_args) { - for (auto& node_idx : graph_view.GetNodesInTopologicalOrder()) { + for (auto& node_idx : graph_view.GetNodesInTopologicalOrder(order)) { const auto& node = graph_view.GetNode(node_idx); for (const auto& input : node->InputDefs()) { if (current_scope_initializer_set.find(input->Name()) != current_scope_initializer_set.end()) { diff --git a/onnxruntime/core/graph/graph_proto_serializer.h b/onnxruntime/core/graph/graph_proto_serializer.h index fe88dd547ff04..43027ef704794 100644 --- a/onnxruntime/core/graph/graph_proto_serializer.h +++ b/onnxruntime/core/graph/graph_proto_serializer.h @@ -7,5 +7,9 @@ namespace onnxruntime { -void GraphViewerToProto(const GraphViewer& graph_view, ONNX_NAMESPACE::GraphProto& graph_proto, bool include_initializer, bool include_outer_scope_args); +void GraphViewerToProto(const GraphViewer& graph_view, + ONNX_NAMESPACE::GraphProto& graph_proto, + bool include_initializer, + bool include_outer_scope_args, + ExecutionOrder order); } // namespace onnxruntime diff --git a/onnxruntime/core/providers/shared_library/provider_interfaces.h b/onnxruntime/core/providers/shared_library/provider_interfaces.h index 1824a82995bce..74c67cc79ddd5 100644 --- a/onnxruntime/core/providers/shared_library/provider_interfaces.h +++ b/onnxruntime/core/providers/shared_library/provider_interfaces.h @@ -851,7 +851,11 @@ struct ProviderHost { virtual const std::vector& GraphViewer__GetNodesInTopologicalOrder(const GraphViewer* p) = 0; virtual const std::vector& GraphViewer__GetInputsIncludingInitializers(const GraphViewer* p) noexcept = 0; - virtual void GraphViewer__ToProto(const GraphViewer* p, ONNX_NAMESPACE::GraphProto& graph_proto, bool include_initializers, bool include_outer_scope_args) noexcept = 0; + virtual void GraphViewer__ToProto(const GraphViewer* p, + ONNX_NAMESPACE::GraphProto& graph_proto, + bool include_initializers, + bool include_outer_scope_args, + int execution_order) noexcept = 0; virtual const Node* GraphViewer__GetProducerNode(const GraphViewer* p, const std::string& node_arg_name) const = 0; // Path diff --git a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h index 3bb938c1a3197..b86c3b2ba5ccc 100644 --- a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h +++ b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h @@ -887,7 +887,12 @@ class GraphViewer final { const std::vector& GetNodesInTopologicalOrder() const { return g_host->GraphViewer__GetNodesInTopologicalOrder(this); } const std::vector& GetInputsIncludingInitializers() const noexcept { return g_host->GraphViewer__GetInputsIncludingInitializers(this); } - void ToProto(ONNX_NAMESPACE::GraphProto& graph_proto, bool include_initializers, bool include_outer_scope_args) const { g_host->GraphViewer__ToProto(this, graph_proto, include_initializers, include_outer_scope_args); } + void ToProto(ONNX_NAMESPACE::GraphProto& graph_proto, + bool include_initializers, + bool include_outer_scope_args, + int execution_order = 0) const { + g_host->GraphViewer__ToProto(this, graph_proto, include_initializers, include_outer_scope_args, execution_order); + } const Node* GetProducerNode(const std::string& node_arg_name) const { return g_host->GraphViewer__GetProducerNode(this, node_arg_name); } GraphViewer() = delete; diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index f33e9a968ce95..5e2da4990c90d 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -2103,7 +2103,12 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect auto graph_viewer = graph_build.CreateGraphViewer(); auto model = graph_viewer->CreateModel(*GetLogger()); auto model_proto = model->ToProto(); - graph_viewer->ToProto(*model_proto->mutable_graph(), true, true); + + // ORT's default topological sort is using reversed DFS. + // When creating model proto from graph viewer, let ORT use priority-based topological sort based on node index. + // The reason is, in some cases, for example ResNet50, using default topological sort will end up with generating + // the model proto that has different node ordering compared to original onnx model. + graph_viewer->ToProto(*model_proto->mutable_graph(), true, true, 1 /*priority-based topological sort*/); model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); std::string string_buf; @@ -2499,7 +2504,12 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView // Reconstruct graph proto from fused node's function body auto model = graph_body_viewer.CreateModel(*GetLogger()); auto model_proto = model->ToProto(); - graph_body_viewer.ToProto(*model_proto->mutable_graph(), true, true); + + // ORT's default topological sort is using reversed DFS. + // When creating model proto from graph viewer, let ORT use priority-based topological sort based on node index. + // The reason is, in some cases, for example ResNet50, using default topological sort will end up with generating + // the model proto that has different node ordering compared to original onnx model. + graph_body_viewer.ToProto(*model_proto->mutable_graph(), true, true, 1 /*priority-based topological sort*/); model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); std::string string_buf; model_proto->SerializeToString(string_buf); diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index 9ec6bb0181004..6bd277ec96ac5 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -1088,8 +1088,12 @@ struct ProviderHostImpl : ProviderHost { const std::vector& GraphViewer__GetNodesInTopologicalOrder(const GraphViewer* p) override { return p->GetNodesInTopologicalOrder(); } const std::vector& GraphViewer__GetInputsIncludingInitializers(const GraphViewer* p) noexcept override { return p->GetInputsIncludingInitializers(); } - void GraphViewer__ToProto(const GraphViewer* p, ONNX_NAMESPACE::GraphProto& graph_proto, bool include_initializers, bool include_outer_scope_args) noexcept override { - GraphViewerToProto(*p, graph_proto, include_initializers, include_outer_scope_args); + void GraphViewer__ToProto(const GraphViewer* p, + ONNX_NAMESPACE::GraphProto& graph_proto, + bool include_initializers, + bool include_outer_scope_args, + int execution_order) noexcept override { + GraphViewerToProto(*p, graph_proto, include_initializers, include_outer_scope_args, static_cast(execution_order)); } const Node* GraphViewer__GetProducerNode(const GraphViewer* p, const std::string& node_arg_name) const override { return p->GetProducerNode(node_arg_name); }