From bbc30feb63dbd267304a0cc26020d9230c739749 Mon Sep 17 00:00:00 2001 From: Chi Lo <54722500+chilo-ms@users.noreply.github.com> Date: Thu, 25 Apr 2024 16:07:36 -0700 Subject: [PATCH] Make execution order an option for GraphViewerToProto() (#20411) **Current issue:** Once ORT gets the capability from EP's GetCapability(), it creates a graph viewer based on the capability as below: `viewers.push_back(std::make_unique(graph, *cur_capability.sub_graph));` or see the code [here](https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/framework/graph_partitioner.cc#L458). At this point, the graph viewer has the chance to generate the wrong order of `nodes_in_topological_order_` when calling [Graph::ReverseDFSFrom](https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/graph/graph_viewer.cc#L107), so that during EP Compile(), EP might create the "wrong nodes ordering" model proto from the graph viewer when calling [GraphViewerToProto()](https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/graph/graph_proto_serializer.cc#L37) because of the `nodes_in_topological_order_`. This is a problem for TRT EP to refit weights to the "weightless" engine. Since the engine is built from the model proto provided by TRT EP and the weights is in the original onnx model. The model proto and the orignal onnx model are not the same in terms of node ordering which makes TRT complain when refitting. **The original model (subgraph of ResNet50):** image **The serialized model proto generated by TRT EP:** (The highlighted part has the wrong node order compared to the original model.) image **The solution 1:** Change default comparator to `NodeCompare::operator() {return n1->Index() > n2->Index();}` The root cause of the different node order between original model and EP generated model is from graph viewer [generating ](https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/graph/graph_viewer.cc#L107)the different `nodes_in_topological_order_`. Modifying the `NodeCompare::operator()` for sorting can fix the problem. The `NodeCompare::operator()` will be used in [Graph::ReverseDFSFrom](https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/graph/graph.cc#L1760) where the input nodes of the current node will be [sorted](https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/graph/graph.cc#L1802) based on node index. Due to the sorted nodes will be pushed into a stack which later determines the final topological node order in a "first in, last out" approach, the larger node index should be pushed into the stack first. So that we can get a topological node order aligns with smaller index node comes first. **The solution 2 (This PR uses this solution):** Use priority-based BFS for topological sort in GraphViewerToProto(). --- onnxruntime/core/graph/graph_proto_serializer.cc | 7 ++++--- onnxruntime/core/graph/graph_proto_serializer.h | 6 +++++- .../providers/shared_library/provider_interfaces.h | 6 +++++- .../shared_library/provider_wrappedtypes.h | 7 ++++++- .../tensorrt/tensorrt_execution_provider.cc | 14 ++++++++++++-- onnxruntime/core/session/provider_bridge_ort.cc | 8 ++++++-- 6 files changed, 38 insertions(+), 10 deletions(-) diff --git a/onnxruntime/core/graph/graph_proto_serializer.cc b/onnxruntime/core/graph/graph_proto_serializer.cc index 5b252b5896d23..c8da2461a3e3c 100644 --- a/onnxruntime/core/graph/graph_proto_serializer.cc +++ b/onnxruntime/core/graph/graph_proto_serializer.cc @@ -8,7 +8,8 @@ namespace onnxruntime { void GraphViewerToProto(const GraphViewer& graph_view, ONNX_NAMESPACE::GraphProto& graph_proto, bool include_initializer, - bool include_outer_scope_args) { + bool include_outer_scope_args, + ExecutionOrder order = ExecutionOrder::DEFAULT) { graph_proto.set_name(graph_view.Name()); graph_proto.set_doc_string(graph_view.Description()); @@ -34,7 +35,7 @@ void GraphViewerToProto(const GraphViewer& graph_view, } // Nodes must be sorted in Topological Order in the GraphProto per ONNX spec. - for (auto& node_idx : graph_view.GetNodesInTopologicalOrder()) { + for (auto& node_idx : graph_view.GetNodesInTopologicalOrder(order)) { const gsl::not_null node_proto{graph_proto.add_node()}; const gsl::not_null p_node{graph_view.GetNode(node_idx)}; // we need to update any GraphProto attributes for subgraphs so that any changes made by things @@ -62,7 +63,7 @@ void GraphViewerToProto(const GraphViewer& graph_view, // handle outer scope value which is a constant initializer if (include_outer_scope_args) { - for (auto& node_idx : graph_view.GetNodesInTopologicalOrder()) { + for (auto& node_idx : graph_view.GetNodesInTopologicalOrder(order)) { const auto& node = graph_view.GetNode(node_idx); for (const auto& input : node->InputDefs()) { if (current_scope_initializer_set.find(input->Name()) != current_scope_initializer_set.end()) { diff --git a/onnxruntime/core/graph/graph_proto_serializer.h b/onnxruntime/core/graph/graph_proto_serializer.h index fe88dd547ff04..43027ef704794 100644 --- a/onnxruntime/core/graph/graph_proto_serializer.h +++ b/onnxruntime/core/graph/graph_proto_serializer.h @@ -7,5 +7,9 @@ namespace onnxruntime { -void GraphViewerToProto(const GraphViewer& graph_view, ONNX_NAMESPACE::GraphProto& graph_proto, bool include_initializer, bool include_outer_scope_args); +void GraphViewerToProto(const GraphViewer& graph_view, + ONNX_NAMESPACE::GraphProto& graph_proto, + bool include_initializer, + bool include_outer_scope_args, + ExecutionOrder order); } // namespace onnxruntime diff --git a/onnxruntime/core/providers/shared_library/provider_interfaces.h b/onnxruntime/core/providers/shared_library/provider_interfaces.h index 1824a82995bce..74c67cc79ddd5 100644 --- a/onnxruntime/core/providers/shared_library/provider_interfaces.h +++ b/onnxruntime/core/providers/shared_library/provider_interfaces.h @@ -851,7 +851,11 @@ struct ProviderHost { virtual const std::vector& GraphViewer__GetNodesInTopologicalOrder(const GraphViewer* p) = 0; virtual const std::vector& GraphViewer__GetInputsIncludingInitializers(const GraphViewer* p) noexcept = 0; - virtual void GraphViewer__ToProto(const GraphViewer* p, ONNX_NAMESPACE::GraphProto& graph_proto, bool include_initializers, bool include_outer_scope_args) noexcept = 0; + virtual void GraphViewer__ToProto(const GraphViewer* p, + ONNX_NAMESPACE::GraphProto& graph_proto, + bool include_initializers, + bool include_outer_scope_args, + int execution_order) noexcept = 0; virtual const Node* GraphViewer__GetProducerNode(const GraphViewer* p, const std::string& node_arg_name) const = 0; // Path diff --git a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h index 3bb938c1a3197..b86c3b2ba5ccc 100644 --- a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h +++ b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h @@ -887,7 +887,12 @@ class GraphViewer final { const std::vector& GetNodesInTopologicalOrder() const { return g_host->GraphViewer__GetNodesInTopologicalOrder(this); } const std::vector& GetInputsIncludingInitializers() const noexcept { return g_host->GraphViewer__GetInputsIncludingInitializers(this); } - void ToProto(ONNX_NAMESPACE::GraphProto& graph_proto, bool include_initializers, bool include_outer_scope_args) const { g_host->GraphViewer__ToProto(this, graph_proto, include_initializers, include_outer_scope_args); } + void ToProto(ONNX_NAMESPACE::GraphProto& graph_proto, + bool include_initializers, + bool include_outer_scope_args, + int execution_order = 0) const { + g_host->GraphViewer__ToProto(this, graph_proto, include_initializers, include_outer_scope_args, execution_order); + } const Node* GetProducerNode(const std::string& node_arg_name) const { return g_host->GraphViewer__GetProducerNode(this, node_arg_name); } GraphViewer() = delete; diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index f33e9a968ce95..5e2da4990c90d 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -2103,7 +2103,12 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect auto graph_viewer = graph_build.CreateGraphViewer(); auto model = graph_viewer->CreateModel(*GetLogger()); auto model_proto = model->ToProto(); - graph_viewer->ToProto(*model_proto->mutable_graph(), true, true); + + // ORT's default topological sort is using reversed DFS. + // When creating model proto from graph viewer, let ORT use priority-based topological sort based on node index. + // The reason is, in some cases, for example ResNet50, using default topological sort will end up with generating + // the model proto that has different node ordering compared to original onnx model. + graph_viewer->ToProto(*model_proto->mutable_graph(), true, true, 1 /*priority-based topological sort*/); model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); std::string string_buf; @@ -2499,7 +2504,12 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView // Reconstruct graph proto from fused node's function body auto model = graph_body_viewer.CreateModel(*GetLogger()); auto model_proto = model->ToProto(); - graph_body_viewer.ToProto(*model_proto->mutable_graph(), true, true); + + // ORT's default topological sort is using reversed DFS. + // When creating model proto from graph viewer, let ORT use priority-based topological sort based on node index. + // The reason is, in some cases, for example ResNet50, using default topological sort will end up with generating + // the model proto that has different node ordering compared to original onnx model. + graph_body_viewer.ToProto(*model_proto->mutable_graph(), true, true, 1 /*priority-based topological sort*/); model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); std::string string_buf; model_proto->SerializeToString(string_buf); diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index 9ec6bb0181004..6bd277ec96ac5 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -1088,8 +1088,12 @@ struct ProviderHostImpl : ProviderHost { const std::vector& GraphViewer__GetNodesInTopologicalOrder(const GraphViewer* p) override { return p->GetNodesInTopologicalOrder(); } const std::vector& GraphViewer__GetInputsIncludingInitializers(const GraphViewer* p) noexcept override { return p->GetInputsIncludingInitializers(); } - void GraphViewer__ToProto(const GraphViewer* p, ONNX_NAMESPACE::GraphProto& graph_proto, bool include_initializers, bool include_outer_scope_args) noexcept override { - GraphViewerToProto(*p, graph_proto, include_initializers, include_outer_scope_args); + void GraphViewer__ToProto(const GraphViewer* p, + ONNX_NAMESPACE::GraphProto& graph_proto, + bool include_initializers, + bool include_outer_scope_args, + int execution_order) noexcept override { + GraphViewerToProto(*p, graph_proto, include_initializers, include_outer_scope_args, static_cast(execution_order)); } const Node* GraphViewer__GetProducerNode(const GraphViewer* p, const std::string& node_arg_name) const override { return p->GetProducerNode(node_arg_name); }