Skip to content

Commit

Permalink
Fix topological node unit traversal during validation
Browse files Browse the repository at this point in the history
  • Loading branch information
adrianlizarraga committed Oct 12, 2023
1 parent 809c890 commit 02fc6eb
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -371,10 +371,14 @@ Status SimpleOpBuilder::ProcessSigmoidOrTanhOutput(QnnModelWrapper& qnn_model_wr
const float scale = output_info.quant_param.scaleOffsetEncoding.scale;

LOGS(logger, VERBOSE) << "QNN requires that 16-bit quantized " << op_type << " operators use offset/scale values "
<< "of <" << offset << ", " << scale << ">. QNN EP will override the original values.";
<< "of <" << offset << ", " << scale << ">. QNN EP will override the original values for output "
<< output_name;
}
}

ORT_RETURN_IF(qnn_model_wrapper.IsQnnTensorWrapperExist(output_name),
"QNN EP is unable to override output quantization parameters for ", op_type.c_str(),
" operator. Node name: ", node_unit.Name().c_str(), ", output name: ", output_name.c_str());
Qnn_TensorType_t tensor_type = qnn_model_wrapper.IsGraphOutput(output_name) ? QNN_TENSOR_TYPE_APP_READ
: QNN_TENSOR_TYPE_NATIVE;
QnnTensorWrapper output_tensorwrapper(output_name, tensor_type, output_info.qnn_data_type, output_info.quant_param,
Expand Down
13 changes: 12 additions & 1 deletion onnxruntime/core/providers/qnn/qnn_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,15 @@ QNNExecutionProvider::GetSupportedNodes(const GraphViewer& graph_viewer,
for (size_t i = 0; i < node_indices.size(); i++) {
gsl::not_null<const onnxruntime::Node*> node(graph_viewer.GetNode(node_indices[i]));

// Get the node_unit associated with the node. Note that the node may not be the node_unit's target node.
const NodeUnit* node_unit = node_unit_map.at(node);

// Visiting 'nodes' in topological order does not guarantee that 'node_units' are
// also visited in topological order. Skip this node if it is not the node_unit's target node
// to ensure 'node_units' are visited in topological order.
if (node != &node_unit->GetNode()) {
continue;
}
const bool supported = IsNodeSupported(qnn_model_wrapper,
*node_unit,
node_unit_supported_result,
Expand All @@ -256,7 +264,10 @@ QNNExecutionProvider::GetSupportedNodes(const GraphViewer& graph_viewer,
<< "] name: [" << node_unit->Name()
<< "]";
if (supported) {
supported_nodes.insert(node);
// If the node_unit is supported, add all of its nodes to the supported list.
for (const auto* node_in_group : node_unit->GetAllNodesInGroup()) {
supported_nodes.insert(node_in_group);
}
}
}

Expand Down

0 comments on commit 02fc6eb

Please sign in to comment.