[QNN EP] Fix topological node unit traversal during validation (#17913)
### Description
We need to ensure that tensors are first created and validated by their
producers. Otherwise, builders that need to modify their outputs may be
unable to do so when consumers are processed first (because tensors are
cached). For example, the Tanh builder may need to override its output
quantization parameter for 16-bit QDQ. I encountered a scenario (while
working on a partner model) where the override was not applied correctly
because of the graph traversal order.

I tried to fix this bug in a previous PR (#17877), but that fix was incorrect.
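
As an illustration only, here is a simplified, self-contained sketch of the traversal this change implements: nodes are walked in topological order, but a node unit is processed only when the visited node is the unit's target node, so each unit is handled exactly once and after its producers. The `Node`/`NodeUnit` structs and the `is_unit_supported` callback below are stand-ins for this sketch, not the actual ONNX Runtime types.

```cpp
#include <unordered_map>
#include <unordered_set>
#include <vector>

// Stand-in types for illustration only (not the real ORT API).
struct Node {
  int id;
};

struct NodeUnit {
  const Node* target;              // the unit's target node (e.g., the core op of a QDQ group)
  std::vector<const Node*> group;  // all nodes that belong to the unit (target plus Q/DQ wrappers)
};

// Walk nodes in topological order, but validate each NodeUnit only once, at its target node,
// so that producer units are always validated before their consumers.
std::unordered_set<const Node*> GetSupportedNodesSketch(
    const std::vector<const Node*>& topo_sorted_nodes,
    const std::unordered_map<const Node*, const NodeUnit*>& node_unit_map,
    bool (*is_unit_supported)(const NodeUnit&)) {
  std::unordered_set<const Node*> supported_nodes;

  for (const Node* node : topo_sorted_nodes) {
    const NodeUnit* unit = node_unit_map.at(node);

    // Skip nodes that are not the unit's target node; the unit is handled when its target is visited.
    if (node != unit->target) {
      continue;
    }

    if (is_unit_supported(*unit)) {
      // If the unit is supported, every node in the unit is supported (Q/DQ wrappers included).
      for (const Node* n : unit->group) {
        supported_nodes.insert(n);
      }
    }
  }

  return supported_nodes;
}
```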
adrianlizarraga authored and jchen351 committed Oct 18, 2023
1 parent 815f66c commit 78984d2
Showing 2 changed files with 17 additions and 2 deletions.
@@ -371,10 +371,14 @@ Status SimpleOpBuilder::ProcessSigmoidOrTanhOutput(QnnModelWrapper& qnn_model_wr
       const float scale = output_info.quant_param.scaleOffsetEncoding.scale;
 
       LOGS(logger, VERBOSE) << "QNN requires that 16-bit quantized " << op_type << " operators use offset/scale values "
-                            << "of <" << offset << ", " << scale << ">. QNN EP will override the original values.";
+                            << "of <" << offset << ", " << scale << ">. QNN EP will override the original values for output "
+                            << output_name;
     }
   }
 
+  ORT_RETURN_IF(qnn_model_wrapper.IsQnnTensorWrapperExist(output_name),
+                "QNN EP is unable to override output quantization parameters for ", op_type.c_str(),
+                " operator. Node name: ", node_unit.Name().c_str(), ", output name: ", output_name.c_str());
   Qnn_TensorType_t tensor_type = qnn_model_wrapper.IsGraphOutput(output_name) ? QNN_TENSOR_TYPE_APP_READ
                                                                               : QNN_TENSOR_TYPE_NATIVE;
   QnnTensorWrapper output_tensorwrapper(output_name, tensor_type, output_info.qnn_data_type, output_info.quant_param,
13 changes: 12 additions & 1 deletion onnxruntime/core/providers/qnn/qnn_execution_provider.cc
@@ -242,7 +242,15 @@ QNNExecutionProvider::GetSupportedNodes(const GraphViewer& graph_viewer,
   for (size_t i = 0; i < node_indices.size(); i++) {
     gsl::not_null<const onnxruntime::Node*> node(graph_viewer.GetNode(node_indices[i]));
 
+    // Get the node_unit associated with the node. Note that the node may not be the node_unit's target node.
     const NodeUnit* node_unit = node_unit_map.at(node);
+
+    // Visiting 'nodes' in topological order does not guarantee that 'node_units' are
+    // also visited in topological order. Skip this node if it is not the node_unit's target node
+    // to ensure 'node_units' are visited in topological order.
+    if (node != &node_unit->GetNode()) {
+      continue;
+    }
     const bool supported = IsNodeSupported(qnn_model_wrapper,
                                            *node_unit,
                                            node_unit_supported_result,
@@ -256,7 +264,10 @@ QNNExecutionProvider::GetSupportedNodes(const GraphViewer& graph_viewer,
                         << "] name: [" << node_unit->Name()
                         << "]";
     if (supported) {
-      supported_nodes.insert(node);
+      // If the node_unit is supported, add all of its nodes to the supported list.
+      for (const auto* node_in_group : node_unit->GetAllNodesInGroup()) {
+        supported_nodes.insert(node_in_group);
+      }
     }
   }
