From e220693eca7a724aa3d738a24e6c638a5696c587 Mon Sep 17 00:00:00 2001 From: Chi Lo <54722500+chilo-ms@users.noreply.github.com> Date: Thu, 16 Nov 2023 19:56:05 -0800 Subject: [PATCH] [TensorRT EP] Fix bug for no nodes in subgraph at GetCapability (#18449) It's possible that subgraph of the "If" control flow op has no nodes. TRT EP should consider this kind of subgraph is fully supported by TRT. The faster rcnn model mentioned in this issue https://github.com/microsoft/onnxruntime/issues/17434 is the case. --- .../tensorrt/tensorrt_execution_provider.cc | 22 ++++++++++++++----- 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index cd4aa45f83bc8..79f84864a5788 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -1829,6 +1829,10 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph, if (sub_graphs.size() != 0) { bool all_subgraphs_are_supported = true; for (auto sub_graph : sub_graphs) { + // TRT EP should consider the empty subgraph is fully supported by TRT. + if (sub_graph->CreateGraphViewer()->NumberOfNodes() == 0) { + continue; + } if (!AllNodesAssignedToSpecificEP(*(sub_graph->CreateGraphViewer()), kTensorrtExecutionProvider)) { all_subgraphs_are_supported = false; break; @@ -1896,27 +1900,33 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph, auto sub_graphs = graph.ParentNode()->GetSubgraphs(); for (auto sub_graph : sub_graphs) { if (sub_graph.get() != &graph.GetGraph()) { - auto sub_graph_veiwer = sub_graph->CreateGraphViewer(); - const int number_of_ort_subgraph_nodes = sub_graph_veiwer->NumberOfNodes(); + auto sub_graph_viewer = sub_graph->CreateGraphViewer(); + const int number_of_ort_subgraph_nodes = sub_graph_viewer->NumberOfNodes(); std::vector subgraph_nodes_vector(number_of_ort_subgraph_nodes); std::iota(std::begin(subgraph_nodes_vector), std::end(subgraph_nodes_vector), 0); SubGraphCollection_t parser_subgraph_nodes_vector = {{subgraph_nodes_vector, false}}; bool subgraph_early_termination = false; - // Another subgraph of "If" control flow has been parsed by GetCapability before and all subgraph's nodes assigned to TRT EP. - if (AllNodesAssignedToSpecificEP(*sub_graph_veiwer, kTensorrtExecutionProvider)) { + // Another subgraph of "If" control flow op has no nodes. + // In this case, TRT EP should consider this empty subgraph is fully supported by TRT. + if (sub_graph_viewer->NumberOfNodes() == 0) { + all_subgraphs_are_supported = true; + break; + } + // Another subgraph of "If" control flow op has been parsed by GetCapability before and all subgraph's nodes assigned to TRT EP. + else if (AllNodesAssignedToSpecificEP(*sub_graph_viewer, kTensorrtExecutionProvider)) { all_subgraphs_are_supported = true; break; } // Another subgraph of "If" control flow has been parsed by GetCapability and not all subgraph's nodes assigned to TRT EP. // (Note: GetExecutionProviderType() returns "" meaning node has not yet been assigned to any EPs) - else if (!AllNodesAssignedToSpecificEP(*sub_graph_veiwer, "")) { + else if (!AllNodesAssignedToSpecificEP(*sub_graph_viewer, "")) { all_subgraphs_are_supported = false; break; } // Another subgraph of "If" control flow has not yet been parsed by GetCapability. - subgraph_supported_nodes_vector = GetSupportedList(parser_subgraph_nodes_vector, 0, max_partition_iterations_, *sub_graph_veiwer, &subgraph_early_termination); + subgraph_supported_nodes_vector = GetSupportedList(parser_subgraph_nodes_vector, 0, max_partition_iterations_, *sub_graph_viewer, &subgraph_early_termination); all_subgraphs_are_supported = IsSubGraphFullySupported(subgraph_supported_nodes_vector, number_of_ort_subgraph_nodes); break; }