Skip to content

Commit

Permalink
Make execution order an option for GraphViewerToProto() (microsoft#20411
Browse files Browse the repository at this point in the history
)

**Current issue:**

Once ORT gets the capability from EP's GetCapability(), it creates a
graph viewer based on the capability as below:
`viewers.push_back(std::make_unique<GraphViewer>(graph,
*cur_capability.sub_graph));` or see the code
[here](https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/framework/graph_partitioner.cc#L458).

At this point, the graph viewer has the chance to generate the wrong
order of `nodes_in_topological_order_` when calling
[Graph::ReverseDFSFrom](https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/graph/graph_viewer.cc#L107),
so that during EP Compile(), EP might create the "wrong nodes ordering"
model proto from the graph viewer when calling
[GraphViewerToProto()](https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/graph/graph_proto_serializer.cc#L37)
because of the `nodes_in_topological_order_`.

This is a problem for TRT EP to refit weights to the "weightless"
engine. Since the engine is built from the model proto provided by TRT
EP and the weights is in the original onnx model. The model proto and
the orignal onnx model are not the same in terms of node ordering which
makes TRT complain when refitting.

**The original model (subgraph of ResNet50):**
<img width="442" alt="image"
src="https://github.com/microsoft/onnxruntime/assets/54722500/bb9a641d-f2f2-46c3-aebf-4084a08ff289">

**The serialized model proto generated by TRT EP:**
(The highlighted part has the wrong node order compared to the original
model.)
<img width="340" alt="image"
src="https://github.com/microsoft/onnxruntime/assets/54722500/bbc6bf34-f960-4753-9474-a18ebc2dc48b">

**The solution 1:**
Change default comparator to `NodeCompare::operator() {return
n1->Index() > n2->Index();}`

The root cause of the different node order between original model and EP
generated model is from graph viewer [generating
](https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/graph/graph_viewer.cc#L107)the
different `nodes_in_topological_order_`. Modifying the
`NodeCompare::operator()` for sorting can fix the problem.

The `NodeCompare::operator()` will be used in
[Graph::ReverseDFSFrom](https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/graph/graph.cc#L1760)
where the input nodes of the current node will be
[sorted](https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/graph/graph.cc#L1802)
based on node index.
Due to the sorted nodes will be pushed into a stack which later
determines the final topological node order in a "first in, last out"
approach, the larger node index should be pushed into the stack first.
So that we can get a topological node order aligns with smaller index
node comes first.

**The solution 2 (This PR uses this solution):**
Use priority-based BFS for topological sort in GraphViewerToProto().
  • Loading branch information
chilo-ms authored Apr 25, 2024
1 parent 21b3cbc commit bbc30fe
Show file tree
Hide file tree
Showing 6 changed files with 38 additions and 10 deletions.
7 changes: 4 additions & 3 deletions onnxruntime/core/graph/graph_proto_serializer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ namespace onnxruntime {
void GraphViewerToProto(const GraphViewer& graph_view,
ONNX_NAMESPACE::GraphProto& graph_proto,
bool include_initializer,
bool include_outer_scope_args) {
bool include_outer_scope_args,
ExecutionOrder order = ExecutionOrder::DEFAULT) {
graph_proto.set_name(graph_view.Name());
graph_proto.set_doc_string(graph_view.Description());

Expand All @@ -34,7 +35,7 @@ void GraphViewerToProto(const GraphViewer& graph_view,
}

// Nodes must be sorted in Topological Order in the GraphProto per ONNX spec.
for (auto& node_idx : graph_view.GetNodesInTopologicalOrder()) {
for (auto& node_idx : graph_view.GetNodesInTopologicalOrder(order)) {
const gsl::not_null<ONNX_NAMESPACE::NodeProto*> node_proto{graph_proto.add_node()};
const gsl::not_null<const Node*> p_node{graph_view.GetNode(node_idx)};
// we need to update any GraphProto attributes for subgraphs so that any changes made by things
Expand Down Expand Up @@ -62,7 +63,7 @@ void GraphViewerToProto(const GraphViewer& graph_view,

// handle outer scope value which is a constant initializer
if (include_outer_scope_args) {
for (auto& node_idx : graph_view.GetNodesInTopologicalOrder()) {
for (auto& node_idx : graph_view.GetNodesInTopologicalOrder(order)) {
const auto& node = graph_view.GetNode(node_idx);
for (const auto& input : node->InputDefs()) {
if (current_scope_initializer_set.find(input->Name()) != current_scope_initializer_set.end()) {
Expand Down
6 changes: 5 additions & 1 deletion onnxruntime/core/graph/graph_proto_serializer.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,9 @@

namespace onnxruntime {

void GraphViewerToProto(const GraphViewer& graph_view, ONNX_NAMESPACE::GraphProto& graph_proto, bool include_initializer, bool include_outer_scope_args);
void GraphViewerToProto(const GraphViewer& graph_view,
ONNX_NAMESPACE::GraphProto& graph_proto,
bool include_initializer,
bool include_outer_scope_args,
ExecutionOrder order);
} // namespace onnxruntime
Original file line number Diff line number Diff line change
Expand Up @@ -851,7 +851,11 @@ struct ProviderHost {
virtual const std::vector<NodeIndex>& GraphViewer__GetNodesInTopologicalOrder(const GraphViewer* p) = 0;
virtual const std::vector<const NodeArg*>& GraphViewer__GetInputsIncludingInitializers(const GraphViewer* p) noexcept = 0;

virtual void GraphViewer__ToProto(const GraphViewer* p, ONNX_NAMESPACE::GraphProto& graph_proto, bool include_initializers, bool include_outer_scope_args) noexcept = 0;
virtual void GraphViewer__ToProto(const GraphViewer* p,
ONNX_NAMESPACE::GraphProto& graph_proto,
bool include_initializers,
bool include_outer_scope_args,
int execution_order) noexcept = 0;
virtual const Node* GraphViewer__GetProducerNode(const GraphViewer* p, const std::string& node_arg_name) const = 0;

// Path
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -887,7 +887,12 @@ class GraphViewer final {
const std::vector<NodeIndex>& GetNodesInTopologicalOrder() const { return g_host->GraphViewer__GetNodesInTopologicalOrder(this); }
const std::vector<const NodeArg*>& GetInputsIncludingInitializers() const noexcept { return g_host->GraphViewer__GetInputsIncludingInitializers(this); }

void ToProto(ONNX_NAMESPACE::GraphProto& graph_proto, bool include_initializers, bool include_outer_scope_args) const { g_host->GraphViewer__ToProto(this, graph_proto, include_initializers, include_outer_scope_args); }
void ToProto(ONNX_NAMESPACE::GraphProto& graph_proto,
bool include_initializers,
bool include_outer_scope_args,
int execution_order = 0) const {
g_host->GraphViewer__ToProto(this, graph_proto, include_initializers, include_outer_scope_args, execution_order);
}
const Node* GetProducerNode(const std::string& node_arg_name) const { return g_host->GraphViewer__GetProducerNode(this, node_arg_name); }

GraphViewer() = delete;
Expand Down
14 changes: 12 additions & 2 deletions onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2103,7 +2103,12 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect
auto graph_viewer = graph_build.CreateGraphViewer();
auto model = graph_viewer->CreateModel(*GetLogger());
auto model_proto = model->ToProto();
graph_viewer->ToProto(*model_proto->mutable_graph(), true, true);

// ORT's default topological sort is using reversed DFS.
// When creating model proto from graph viewer, let ORT use priority-based topological sort based on node index.
// The reason is, in some cases, for example ResNet50, using default topological sort will end up with generating
// the model proto that has different node ordering compared to original onnx model.
graph_viewer->ToProto(*model_proto->mutable_graph(), true, true, 1 /*priority-based topological sort*/);
model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION);

std::string string_buf;
Expand Down Expand Up @@ -2499,7 +2504,12 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView
// Reconstruct graph proto from fused node's function body
auto model = graph_body_viewer.CreateModel(*GetLogger());
auto model_proto = model->ToProto();
graph_body_viewer.ToProto(*model_proto->mutable_graph(), true, true);

// ORT's default topological sort is using reversed DFS.
// When creating model proto from graph viewer, let ORT use priority-based topological sort based on node index.
// The reason is, in some cases, for example ResNet50, using default topological sort will end up with generating
// the model proto that has different node ordering compared to original onnx model.
graph_body_viewer.ToProto(*model_proto->mutable_graph(), true, true, 1 /*priority-based topological sort*/);
model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION);
std::string string_buf;
model_proto->SerializeToString(string_buf);
Expand Down
8 changes: 6 additions & 2 deletions onnxruntime/core/session/provider_bridge_ort.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1088,8 +1088,12 @@ struct ProviderHostImpl : ProviderHost {

const std::vector<NodeIndex>& GraphViewer__GetNodesInTopologicalOrder(const GraphViewer* p) override { return p->GetNodesInTopologicalOrder(); }
const std::vector<const NodeArg*>& GraphViewer__GetInputsIncludingInitializers(const GraphViewer* p) noexcept override { return p->GetInputsIncludingInitializers(); }
void GraphViewer__ToProto(const GraphViewer* p, ONNX_NAMESPACE::GraphProto& graph_proto, bool include_initializers, bool include_outer_scope_args) noexcept override {
GraphViewerToProto(*p, graph_proto, include_initializers, include_outer_scope_args);
void GraphViewer__ToProto(const GraphViewer* p,
ONNX_NAMESPACE::GraphProto& graph_proto,
bool include_initializers,
bool include_outer_scope_args,
int execution_order) noexcept override {
GraphViewerToProto(*p, graph_proto, include_initializers, include_outer_scope_args, static_cast<ExecutionOrder>(execution_order));
}
const Node* GraphViewer__GetProducerNode(const GraphViewer* p, const std::string& node_arg_name) const override { return p->GetProducerNode(node_arg_name); }

Expand Down

0 comments on commit bbc30fe

Please sign in to comment.