From 02fc6eb8235a5772627e77a0675df7c172bc847e Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Thu, 12 Oct 2023 10:53:55 -0700 Subject: [PATCH] Fix topological node unit traversal during validation --- .../qnn/builder/opbuilder/simple_op_builder.cc | 6 +++++- .../core/providers/qnn/qnn_execution_provider.cc | 13 ++++++++++++- 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc index acdcfdc66bf34..4022307b93ff8 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc @@ -371,10 +371,14 @@ Status SimpleOpBuilder::ProcessSigmoidOrTanhOutput(QnnModelWrapper& qnn_model_wr const float scale = output_info.quant_param.scaleOffsetEncoding.scale; LOGS(logger, VERBOSE) << "QNN requires that 16-bit quantized " << op_type << " operators use offset/scale values " - << "of <" << offset << ", " << scale << ">. QNN EP will override the original values."; + << "of <" << offset << ", " << scale << ">. QNN EP will override the original values for output " + << output_name; } } + ORT_RETURN_IF(qnn_model_wrapper.IsQnnTensorWrapperExist(output_name), + "QNN EP is unable to override output quantization parameters for ", op_type.c_str(), + " operator. Node name: ", node_unit.Name().c_str(), ", output name: ", output_name.c_str()); Qnn_TensorType_t tensor_type = qnn_model_wrapper.IsGraphOutput(output_name) ? QNN_TENSOR_TYPE_APP_READ : QNN_TENSOR_TYPE_NATIVE; QnnTensorWrapper output_tensorwrapper(output_name, tensor_type, output_info.qnn_data_type, output_info.quant_param, diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc index d497bc1c069d2..b456be92412e5 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc @@ -242,7 +242,15 @@ QNNExecutionProvider::GetSupportedNodes(const GraphViewer& graph_viewer, for (size_t i = 0; i < node_indices.size(); i++) { gsl::not_null node(graph_viewer.GetNode(node_indices[i])); + // Get the node_unit associated with the node. Note that the node may not be the node_unit's target node. const NodeUnit* node_unit = node_unit_map.at(node); + + // Visiting 'nodes' in topological order does not guarantee that 'node_units' are + // also visited in topological order. Skip this node if it is not the node_unit's target node + // to ensure 'node_units' are visited in topological order. + if (node != &node_unit->GetNode()) { + continue; + } const bool supported = IsNodeSupported(qnn_model_wrapper, *node_unit, node_unit_supported_result, @@ -256,7 +264,10 @@ QNNExecutionProvider::GetSupportedNodes(const GraphViewer& graph_viewer, << "] name: [" << node_unit->Name() << "]"; if (supported) { - supported_nodes.insert(node); + // If the node_unit is supported, add all of its nodes to the supported list. + for (const auto* node_in_group : node_unit->GetAllNodesInGroup()) { + supported_nodes.insert(node_in_group); + } } }