Skip to content

Commit

Permalink
QNN EP: Fuse certain DQ -> Q sequences into a single QNN Convert oper…
Browse files Browse the repository at this point in the history
…ator.
  • Loading branch information
adrianlizarraga committed Feb 13, 2024
1 parent 5c7e6b2 commit 5ced778
Show file tree
Hide file tree
Showing 8 changed files with 319 additions and 41 deletions.
43 changes: 43 additions & 0 deletions onnxruntime/core/optimizer/qdq_transformer/qdq_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,49 @@ bool IsQDQPairSupported(
}
}

bool IsDQQConversion(
const Node& dq_node, const Node& q_node,
const GetConstantInitializerFn& get_const_initializer,
const Path& model_path) {
ConstPointerContainer<std::vector<NodeArg*>> dq_input_defs = dq_node.InputDefs();
ConstPointerContainer<std::vector<NodeArg*>> q_input_defs = q_node.InputDefs();

// Q/DQ contains optional input is not supported
// non-scalar Q/DQ scale and zero point needs are not supported
if (dq_input_defs.size() != InputIndex::TOTAL_COUNT ||
q_input_defs.size() != InputIndex::TOTAL_COUNT ||
!optimizer_utils::IsScalar(*q_input_defs[InputIndex::SCALE_ID]) ||
!optimizer_utils::IsScalar(*q_input_defs[InputIndex::ZERO_POINT_ID]) ||
!optimizer_utils::IsScalar(*dq_input_defs[InputIndex::SCALE_ID]) ||
!optimizer_utils::IsScalar(*dq_input_defs[InputIndex::ZERO_POINT_ID])) {
return false;
}

// if Q/DQ scale and zero point are not constant, return false
const ONNX_NAMESPACE::TensorProto* dq_scale_tensor_proto =
get_const_initializer(dq_input_defs[InputIndex::SCALE_ID]->Name());
const ONNX_NAMESPACE::TensorProto* q_scale_tensor_proto =
get_const_initializer(q_input_defs[InputIndex::SCALE_ID]->Name());
const ONNX_NAMESPACE::TensorProto* dq_zp_tensor_proto =
get_const_initializer(dq_input_defs[InputIndex::ZERO_POINT_ID]->Name());
const ONNX_NAMESPACE::TensorProto* q_zp_tensor_proto =
get_const_initializer(q_input_defs[InputIndex::ZERO_POINT_ID]->Name());
if (nullptr == q_zp_tensor_proto ||
nullptr == dq_zp_tensor_proto ||
nullptr == q_scale_tensor_proto ||
nullptr == dq_scale_tensor_proto) {
return false;
}

// check Q/DQ have same scale type and different zero point type
Initializer q_zp(*q_zp_tensor_proto, model_path);
Initializer q_scale(*q_scale_tensor_proto, model_path);
Initializer dq_zp(*dq_zp_tensor_proto, model_path);
Initializer dq_scale(*dq_scale_tensor_proto, model_path);

return (dq_zp.data_type() != q_zp.data_type()) && (dq_scale.data_type() == q_scale.data_type());
}

bool IsDQSupported(const Node& dq_node, const GetConstantInitializerFn& get_const_initializer) {
bool zero_point_exists = false;
if (!QOrDQNodeHasConstantScalarScaleAndZeroPoint(dq_node, get_const_initializer, zero_point_exists)) {
Expand Down
12 changes: 12 additions & 0 deletions onnxruntime/core/optimizer/qdq_transformer/qdq_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,18 @@ bool IsQDQPairSupported(
const GetConstantInitializerFn& get_const_initializer,
const Path& model_path);

// Check if a DQ -> Q sequence represents a conversion in quantization data type.
// Example of uint8 to uint16:
// Dequantize (uint8 to float) -> Quantize (float to uint16)
// Requires:
// 1. Q/DQ doesn't have optional input.
// 2. scale and zero-point are constant scalars.
// 3. Q and DQ have the same scale *type* and different zero-point *types*.
bool IsDQQConversion(
const Node& dq_node, const Node& q_node,
const GetConstantInitializerFn& get_const_initializer,
const Path& model_path);

// Check if DQ is supported in extended level QDQ transformers. It requires:
// 1. DQ doesn't have optional input.
// 2. scale and zero point is constant scalar
Expand Down
23 changes: 23 additions & 0 deletions onnxruntime/core/providers/qnn/builder/op_builder_factory.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,5 +94,28 @@ void CreatePadOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_r

void CreateExpandOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);

// Result of attempting to merge a standalone DQ -> Q sequence into a single
// QNN Convert operator (see TryHandleConvertSequence).
struct HandleConvertResult {
  // Indicates an unexpected error while handling the sequence. Note: an OK
  // status with q_node_unit == nullptr simply means the node unit was not a
  // convertible DQ -> Q sequence (not an error).
  Status status;
  // Non-null if the DQ -> Q sequence was successfully merged into a Convert.
  // Set to nullptr if this node unit could not be merged into a Convert.
  const NodeUnit* q_node_unit;
};

/**
* Tries to merge a DQ -> Q sequence into a QNN Convert operator. The DQ -> Q must be converting from
* one quantization type (e.g., uint8_t) to another (e.g., uint16_t).
*
* \param qnn_model_wrapper The QNN model that is being built.
* \param maybe_dq_node_unit The node unit that could potentially start the DQ -> Q sequence.
* \param logger The logger.
* \param do_op_validation True if should call QNN operator validation APIs.
* \return A qnn::HandleConvertResult object that indicates success/failure and provides a pointer
* to the Q node unit that was successfully merged with the provided DQ node unit.
*/
HandleConvertResult TryHandleConvertSequence(QnnModelWrapper& qnn_model_wrapper,
const NodeUnit& maybe_dq_node_unit,
const std::unordered_map<const Node*, const NodeUnit*>& node_unit_map,

Check warning on line 117 in onnxruntime/core/providers/qnn/builder/op_builder_factory.h

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <unordered_map> for unordered_map<> [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/qnn/builder/op_builder_factory.h:117: Add #include <unordered_map> for unordered_map<> [build/include_what_you_use] [4]
const logging::Logger& logger,
bool do_op_validation);
} // namespace qnn
} // namespace onnxruntime
103 changes: 103 additions & 0 deletions onnxruntime/core/providers/qnn/builder/opbuilder/convert_op_builder.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/graph/graph_utils.h"
#include "core/optimizer/qdq_transformer/qdq_util.h"
#include "core/providers/qnn/builder/opbuilder/base_op_builder.h"
#include "core/providers/shared/utils/utils.h"
#include "core/providers/qnn/builder/qnn_model_wrapper.h"
#include "core/providers/qnn/builder/op_builder_factory.h"
#include "core/common/safeint.h"
#include "onnx/defs/data_type_utils.h"

#include "QnnOpDef.h" // From QNN SDK: contains QNN constants (e.g., op names, param values).

namespace onnxruntime {
namespace qnn {

// Op builder that emits a QNN Convert operator for a merged DQ -> Q node
// sequence. Instantiated directly by TryHandleConvertSequence() in this file.
class ConvertOpBuilder : public BaseOpBuilder {
 public:
  ConvertOpBuilder() : BaseOpBuilder("ConvertOpBuilder") {}
  ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(ConvertOpBuilder);

  // Adds a single QNN Convert node to the model wrapper: its input comes from
  // the DQ node unit's input, and its output (with quantization parameters)
  // comes from the Q node unit's output.
  Status AddConvertToModelBuilder(QnnModelWrapper& qnn_model_wrapper,
                                  const NodeUnit& dq_node_unit,
                                  const NodeUnit& q_node_unit,
                                  const logging::Logger& logger,
                                  bool do_op_validation) const ORT_MUST_USE_RESULT;
};

// Emits one QNN Convert node: input taken from the DQ node unit, output (and
// quantization parameters) taken from the Q node unit.
Status ConvertOpBuilder::AddConvertToModelBuilder(QnnModelWrapper& qnn_model_wrapper,
                                                  const NodeUnit& dq_node_unit,
                                                  const NodeUnit& q_node_unit,
                                                  const logging::Logger& logger,
                                                  bool do_op_validation) const {
  std::vector<std::string> convert_input_names;

  // The Convert's single input is the DQ node's quantized input tensor.
  ORT_RETURN_IF_ERROR(ProcessInput(qnn_model_wrapper, dq_node_unit.Inputs()[0], logger, convert_input_names));

  // Wire the Q node's output as the Convert's output, overriding the QNN
  // operator type to "Convert".
  ORT_RETURN_IF_ERROR(ProcessOutputs(qnn_model_wrapper, q_node_unit, std::move(convert_input_names), {},
                                     logger, do_op_validation, QNN_OP_CONVERT));

  return Status::OK();
}

// Attempts to merge a standalone DQ -> Q sequence (representing a quantization
// data-type conversion) into a single QNN Convert operator. Returns a result
// whose q_node_unit is non-null only on a successful merge.
HandleConvertResult TryHandleConvertSequence(QnnModelWrapper& qnn_model_wrapper,
                                             const NodeUnit& maybe_dq_node_unit,
                                             const std::unordered_map<const Node*, const NodeUnit*>& node_unit_map,
                                             const logging::Logger& logger,
                                             bool do_op_validation) {
  // The sequence must start with a standalone DQ node unit.
  const bool starts_with_standalone_dq = maybe_dq_node_unit.OpType() == QDQ::DQOpName &&
                                         maybe_dq_node_unit.UnitType() == NodeUnit::Type::SingleNode;
  if (!starts_with_standalone_dq) {
    return {};
  }

  const GraphViewer& graph_viewer = qnn_model_wrapper.GetGraphViewer();
  const Node& dq_node = maybe_dq_node_unit.GetNode();

  // The DQ must feed exactly one Q node and must not produce a graph output.
  auto q_children = graph_utils::FindChildrenByType(dq_node, QDQ::QOpName);
  const bool has_single_q_child = q_children.size() == 1 &&
                                  dq_node.GetOutputEdgesCount() == 1 &&
                                  !graph_viewer.NodeProducesGraphOutput(dq_node);
  if (!has_single_q_child) {
    return {};
  }

  const Node& q_node = *q_children[0];
  const auto q_node_unit_it = node_unit_map.find(&q_node);
  if (q_node_unit_it == node_unit_map.end()) {
    return {ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Node does not have a corresponding NodeUnit"), nullptr};
  }

  const NodeUnit* q_node_unit = q_node_unit_it->second;

  // The Q child must itself be standalone (not already part of a QDQ node unit).
  if (q_node_unit->UnitType() != NodeUnit::Type::SingleNode) {
    return {};
  }

  auto get_const_initializer = [&graph_viewer](const std::string& initializer_name) {
    return graph_viewer.GetConstantInitializer(initializer_name, true);
  };

  // The DQ -> Q pair must be a genuine type conversion: equal scale types and
  // differing zero-point types.
  if (!QDQ::IsDQQConversion(dq_node, q_node, get_const_initializer, graph_viewer.ModelPath())) {
    return {};
  }

  LOGS(logger, VERBOSE) << " Adding QNN Convert. dq_node name: [" << dq_node.Name()
                        << "] dq_node optype: [" << dq_node.OpType()
                        << "] q_node name: [" << q_node_unit->Name()
                        << "] q_node optype: [" << q_node_unit->OpType()
                        << "]";

  ConvertOpBuilder op_builder;
  Status status = op_builder.AddConvertToModelBuilder(qnn_model_wrapper, maybe_dq_node_unit, *q_node_unit,
                                                      logger, do_op_validation);

  // Report the merged Q node unit only when the Convert was actually added.
  return {status, status.IsOK() ? q_node_unit : nullptr};
}

} // namespace qnn
} // namespace onnxruntime
35 changes: 30 additions & 5 deletions onnxruntime/core/providers/qnn/builder/qnn_model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,8 @@ Status QnnModel::ComposeGraph(const GraphViewer& graph_viewer,
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to initialize qnn_model_wrapper.");
}

std::unordered_set<const NodeUnit*> handled_node_units;

Check warning on line 117 in onnxruntime/core/providers/qnn/builder/qnn_model.cc

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <unordered_set> for unordered_set<> [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/qnn/builder/qnn_model.cc:117: Add #include <unordered_set> for unordered_set<> [build/include_what_you_use] [4]

// Op builder
const auto& node_indices = graph_viewer.GetNodesInTopologicalOrder();
for (size_t i = 0; i < node_indices.size(); i++) {
Expand All @@ -122,20 +124,43 @@ Status QnnModel::ComposeGraph(const GraphViewer& graph_viewer,
// Check whether it's part of NodeUnit
const NodeUnit& node_unit = GetNodeUnit(node, node_unit_map);
// Q, DQ nodes in the node unit only carry the quantization parameters
// Add the QNN node when it is the target node (It's a normal node or a singel Q/DQ node)
// Add the QNN node when it is the target node (It's a normal node or a single Q/DQ node)
const std::string& op_type = node_unit.OpType();

if (node != &node_unit.GetNode()) {
continue;
}

if (handled_node_units.count(&node_unit) != 0) {
continue; // Already handled.
}

// Try to convert particular DQ -> Q sequences into QNN Convert op
auto convert_result = TryHandleConvertSequence(qnn_model_wrapper,
node_unit,
node_unit_map,
logger_,
false /*do_op_validation*/);
ORT_RETURN_IF_ERROR(convert_result.status);

if (convert_result.q_node_unit) {
// Successfully merged DQ -> Q sequence into a QNN Convert op.
// Mark both of these node units as handled.
handled_node_units.insert(&node_unit);
handled_node_units.insert(convert_result.q_node_unit);
continue;
}

LOGS(logger_, VERBOSE) << " node name: [" << node->Name()
<< "] node optype: [" << op_type
<< "] as part of the NodeUnit type: [" << node_unit.OpType()
<< "] name: [" << node_unit.Name()
<< "]";
if (node != &node_unit.GetNode()) {
continue;
}

if (const auto* op_builder = GetOpBuilder(op_type)) {
ORT_RETURN_IF_ERROR(op_builder->AddToModelBuilder(qnn_model_wrapper, node_unit, logger_));
}

handled_node_units.insert(&node_unit);
}

ORT_RETURN_IF_NOT(qnn_model_wrapper.ComposeQnnGraph(), "Failed to compose Qnn graph.");
Expand Down
88 changes: 53 additions & 35 deletions onnxruntime/core/providers/qnn/qnn_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -286,33 +286,24 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio
}

bool QNNExecutionProvider::IsNodeSupported(qnn::QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit,
std::unordered_map<const NodeUnit*, bool>& node_unit_supported_result,
const logging::Logger& logger) const {
// If we have visited one of the nodes in the node_unit, use the result directly
const auto it = node_unit_supported_result.find(&node_unit);
if (it != node_unit_supported_result.cend()) {
return it->second;
const std::string& op_type = node_unit.OpType();
bool supported = false;
const auto* op_builder = qnn::GetOpBuilder(op_type);
if (op_builder == nullptr) {
LOGS(logger, WARNING) << "Operators of type `" << node_unit.OpType() << "` are not supported by QNN EP."
<< node_unit.OpType() << " node `" << node_unit.Name()
<< "` will not be assigned to QNN EP.";
} else {
const std::string& op_type = node_unit.OpType();

bool supported = false;
const auto* op_builder = qnn::GetOpBuilder(op_type);
if (op_builder == nullptr) {
LOGS(logger, WARNING) << "Operators of type `" << node_unit.OpType() << "` are not supported by QNN EP."
<< node_unit.OpType() << " node `" << node_unit.Name()
<< "` will not be assigned to QNN EP.";
} else {
auto status = op_builder->IsOpSupported(qnn_model_wrapper,
node_unit, logger);
if (Status::OK() != status) {
LOGS(logger, WARNING) << node_unit.OpType() << " node `" << node_unit.Name()
<< "` is not supported: " << status.ErrorMessage();
}
supported = (Status::OK() == status);
auto status = op_builder->IsOpSupported(qnn_model_wrapper,
node_unit, logger);
if (Status::OK() != status) {
LOGS(logger, WARNING) << node_unit.OpType() << " node `" << node_unit.Name()
<< "` is not supported: " << status.ErrorMessage();
}
node_unit_supported_result[&node_unit] = supported;
return supported;
supported = (Status::OK() == status);
}
return supported;
}

std::unordered_set<const Node*>
Expand Down Expand Up @@ -391,24 +382,51 @@ QNNExecutionProvider::GetSupportedNodes(const GraphViewer& graph_viewer,
if (node != &node_unit->GetNode()) {
continue;
}
const bool supported = IsNodeSupported(qnn_model_wrapper,
*node_unit,
node_unit_supported_result,
logger);
LOGS(logger, VERBOSE) << "Node supported: [" << supported
<< "] index: [" << node->Index()
<< "] name: [" << node->Name()
<< "] Operator type: [" << node->OpType()
<< "] as part of the NodeUnit type: [" << node_unit->OpType()
<< "] index: [" << node_unit->Index()
<< "] name: [" << node_unit->Name()
<< "]";

if (node_unit_supported_result.count(node_unit) != 0) {
continue; // Already handled this node unit
}

// Try to convert certain standalone DQ -> Q sequences into QNN Convert op
auto convert_result = TryHandleConvertSequence(qnn_model_wrapper,
*node_unit,
node_unit_map,
logger,
true /*do_op_validation*/);
if (!convert_result.status.IsOK()) {
LOGS(logger, WARNING) << "Failed to convert DQ -> Q sequence to QNN Convert. "
<< "Type: " << node_unit->OpType() << ", Node name: " << node_unit->Name() << ", "
<< "Message: " << convert_result.status.ErrorMessage();
}

bool supported = false;

if (convert_result.status.IsOK() && convert_result.q_node_unit) { // Merged DQ -> Q sequence into QNN Convert op
supported = true;

// Mark the Q node unit as handled and supported here so that we don't try to process it again.
node_unit_supported_result.insert({convert_result.q_node_unit, true});
supported_nodes.insert(&convert_result.q_node_unit->GetNode());
} else {
supported = IsNodeSupported(qnn_model_wrapper, *node_unit, logger);
LOGS(logger, VERBOSE) << "Node supported: [" << supported
<< "] index: [" << node->Index()
<< "] name: [" << node->Name()
<< "] Operator type: [" << node->OpType()
<< "] as part of the NodeUnit type: [" << node_unit->OpType()
<< "] index: [" << node_unit->Index()
<< "] name: [" << node_unit->Name()
<< "]";
}

if (supported) {
// If the node_unit is supported, add all of its nodes to the supported list.
for (const auto* node_in_group : node_unit->GetAllNodesInGroup()) {
supported_nodes.insert(node_in_group);
}
}

node_unit_supported_result.insert({node_unit, supported});
}

return supported_nodes;
Expand Down
1 change: 0 additions & 1 deletion onnxruntime/core/providers/qnn/qnn_execution_provider.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ class QNNExecutionProvider : public IExecutionProvider {

private:
bool IsNodeSupported(qnn::QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit,
std::unordered_map<const NodeUnit*, bool>& node_unit_supported_result,
const logging::Logger& logger) const;

std::unordered_set<const Node*> GetSupportedNodes(const GraphViewer& graph_viewer,
Expand Down
Loading

0 comments on commit 5ced778

Please sign in to comment.