Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

QNN EP: Fuse DQ -> Q sequences into a QNN Convert op #19511

Merged
merged 1 commit into from
Feb 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions onnxruntime/core/optimizer/qdq_transformer/qdq_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,49 @@ bool IsQDQPairSupported(
}
}

bool IsDQQConversion(
const Node& dq_node, const Node& q_node,
const GetConstantInitializerFn& get_const_initializer,
const Path& model_path) {
ConstPointerContainer<std::vector<NodeArg*>> dq_input_defs = dq_node.InputDefs();
ConstPointerContainer<std::vector<NodeArg*>> q_input_defs = q_node.InputDefs();

// Q/DQ contains optional input is not supported
// non-scalar Q/DQ scale and zero point needs are not supported
if (dq_input_defs.size() != InputIndex::TOTAL_COUNT ||
q_input_defs.size() != InputIndex::TOTAL_COUNT ||
!optimizer_utils::IsScalar(*q_input_defs[InputIndex::SCALE_ID]) ||
!optimizer_utils::IsScalar(*q_input_defs[InputIndex::ZERO_POINT_ID]) ||
!optimizer_utils::IsScalar(*dq_input_defs[InputIndex::SCALE_ID]) ||
!optimizer_utils::IsScalar(*dq_input_defs[InputIndex::ZERO_POINT_ID])) {
return false;
}

// if Q/DQ scale and zero point are not constant, return false
const ONNX_NAMESPACE::TensorProto* dq_scale_tensor_proto =
get_const_initializer(dq_input_defs[InputIndex::SCALE_ID]->Name());
const ONNX_NAMESPACE::TensorProto* q_scale_tensor_proto =
get_const_initializer(q_input_defs[InputIndex::SCALE_ID]->Name());
const ONNX_NAMESPACE::TensorProto* dq_zp_tensor_proto =
get_const_initializer(dq_input_defs[InputIndex::ZERO_POINT_ID]->Name());
const ONNX_NAMESPACE::TensorProto* q_zp_tensor_proto =
get_const_initializer(q_input_defs[InputIndex::ZERO_POINT_ID]->Name());
if (nullptr == q_zp_tensor_proto ||
nullptr == dq_zp_tensor_proto ||
nullptr == q_scale_tensor_proto ||
nullptr == dq_scale_tensor_proto) {
return false;
}

// check Q/DQ have same scale type and different zero point type
Initializer q_zp(*q_zp_tensor_proto, model_path);
Initializer q_scale(*q_scale_tensor_proto, model_path);
Initializer dq_zp(*dq_zp_tensor_proto, model_path);
Initializer dq_scale(*dq_scale_tensor_proto, model_path);

return (dq_zp.data_type() != q_zp.data_type()) && (dq_scale.data_type() == q_scale.data_type());
}

bool IsDQSupported(const Node& dq_node, const GetConstantInitializerFn& get_const_initializer) {
bool zero_point_exists = false;
if (!QOrDQNodeHasConstantScalarScaleAndZeroPoint(dq_node, get_const_initializer, zero_point_exists)) {
Expand Down
12 changes: 12 additions & 0 deletions onnxruntime/core/optimizer/qdq_transformer/qdq_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,18 @@ bool IsQDQPairSupported(
const GetConstantInitializerFn& get_const_initializer,
const Path& model_path);

// Check if a DQ -> Q sequence represents a conversion in quantization data type.
// Example of uint8 to uint16:
// Dequantize (uint8 to float) -> Quantize (float to uint16)
// Requires:
// 1. Q/DQ doesn't have optional input.
// 2. scale and zero-point are constant scalars.
// 3. Q and DQ have the same scale *type* and different zero-point *types*.
// \param dq_node The DequantizeLinear node that starts the sequence.
// \param q_node The QuantizeLinear node consuming the DQ's output.
// \param get_const_initializer Callback used to look up constant initializers by name.
// \param model_path Path of the model; used to load external initializer data.
// \return true if the DQ -> Q pair meets all of the above requirements.
bool IsDQQConversion(
    const Node& dq_node, const Node& q_node,
    const GetConstantInitializerFn& get_const_initializer,
    const Path& model_path);

// Check if DQ is supported in extended level QDQ transformers. It requires:
// 1. DQ doesn't have optional input.
// 2. scale and zero point is constant scalar
Expand Down
23 changes: 23 additions & 0 deletions onnxruntime/core/providers/qnn/builder/op_builder_factory.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,5 +94,28 @@

void CreateExpandOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);

// Result of attempting to fuse a standalone DQ -> Q sequence into a QNN Convert op.
// See TryHandleConvertSequence().
struct HandleConvertResult {
  Status status;  // Indicates an unexpected error. Check if q_node_unit != nullptr to determine
                  // whether a DQ -> Q sequence was successfully merged into a Convert.
  const NodeUnit* q_node_unit;  // Non-null if successfully merged DQ -> Q sequence.
                                // Set to nullptr if this node unit could not be merged into a Convert.
};

/**
 * Tries to merge a DQ -> Q sequence into a QNN Convert operator. The DQ -> Q must be converting from
 * one quantization type (e.g., uint8_t) to another (e.g., uint16_t).
 *
 * \param qnn_model_wrapper The QNN model that is being built.
 * \param maybe_dq_node_unit The node unit that could potentially start the DQ -> Q sequence.
 * \param node_unit_map Maps each graph node to its owning NodeUnit; used to find the Q node's unit.
 * \param logger The logger.
 * \param do_op_validation True if should call QNN operator validation APIs.
 * \return An qnn::HandleConvertResult object that indicates success/failure and provides a pointer
 *         to the Q node unit that was successfully merged with the provided DQ node unit.
 */
HandleConvertResult TryHandleConvertSequence(QnnModelWrapper& qnn_model_wrapper,
                                             const NodeUnit& maybe_dq_node_unit,
                                             const std::unordered_map<const Node*, const NodeUnit*>& node_unit_map,
                                             const logging::Logger& logger,
                                             bool do_op_validation);
} // namespace qnn
} // namespace onnxruntime
103 changes: 103 additions & 0 deletions onnxruntime/core/providers/qnn/builder/opbuilder/convert_op_builder.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/graph/graph_utils.h"
#include "core/optimizer/qdq_transformer/qdq_util.h"
#include "core/providers/qnn/builder/opbuilder/base_op_builder.h"
#include "core/providers/shared/utils/utils.h"
#include "core/providers/qnn/builder/qnn_model_wrapper.h"
#include "core/providers/qnn/builder/op_builder_factory.h"
#include "core/common/safeint.h"
#include "onnx/defs/data_type_utils.h"

#include "QnnOpDef.h" // From QNN SDK: contains QNN constants (e.g., op names, param values).

namespace onnxruntime {
namespace qnn {

// Op builder that emits a single QNN "Convert" operator for a fused DQ -> Q sequence.
// Unlike other op builders, it is driven from two node units (the DQ and the Q) rather than one.
class ConvertOpBuilder : public BaseOpBuilder {
 public:
  ConvertOpBuilder() : BaseOpBuilder("ConvertOpBuilder") {}
  ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(ConvertOpBuilder);

  // Adds a QNN Convert node whose input comes from dq_node_unit and whose output
  // (quantization parameters included) comes from q_node_unit.
  Status AddConvertToModelBuilder(QnnModelWrapper& qnn_model_wrapper,
                                  const NodeUnit& dq_node_unit,
                                  const NodeUnit& q_node_unit,
                                  const logging::Logger& logger,
                                  bool do_op_validation) const ORT_MUST_USE_RESULT;
};

Status ConvertOpBuilder::AddConvertToModelBuilder(QnnModelWrapper& qnn_model_wrapper,
                                                  const NodeUnit& dq_node_unit,
                                                  const NodeUnit& q_node_unit,
                                                  const logging::Logger& logger,
                                                  bool do_op_validation) const {
  std::vector<std::string> input_names;

  // The Convert op's input tensor is the DQ node's (quantized) input.
  ORT_RETURN_IF_ERROR(ProcessInput(qnn_model_wrapper, dq_node_unit.Inputs()[0], logger, input_names));

  // The Convert op's output tensor (and quantization params) come from the Q node.
  // Override the QNN operator type to "Convert".
  ORT_RETURN_IF_ERROR(ProcessOutputs(qnn_model_wrapper, q_node_unit, std::move(input_names), {},
                                     logger, do_op_validation, QNN_OP_CONVERT));
  return Status::OK();
}

HandleConvertResult TryHandleConvertSequence(QnnModelWrapper& qnn_model_wrapper,
                                             const NodeUnit& maybe_dq_node_unit,
                                             const std::unordered_map<const Node*, const NodeUnit*>& node_unit_map,
                                             const logging::Logger& logger,
                                             bool do_op_validation) {
  const GraphViewer& graph_viewer = qnn_model_wrapper.GetGraphViewer();

  // The sequence must begin with a standalone DQ (i.e., not already part of a QDQ node unit).
  const bool starts_with_standalone_dq = maybe_dq_node_unit.OpType() == QDQ::DQOpName &&
                                         maybe_dq_node_unit.UnitType() == NodeUnit::Type::SingleNode;
  if (!starts_with_standalone_dq) {
    return {};
  }

  const Node& dq_node = maybe_dq_node_unit.GetNode();

  // The DQ must feed exactly one Q node and nothing else: a single output edge
  // to a Q child, and no graph output.
  const auto q_children = graph_utils::FindChildrenByType(dq_node, QDQ::QOpName);
  const bool dq_feeds_single_q = q_children.size() == 1 &&
                                 dq_node.GetOutputEdgesCount() == 1 &&
                                 !graph_viewer.NodeProducesGraphOutput(dq_node);
  if (!dq_feeds_single_q) {
    return {};
  }

  const Node& q_node = *q_children[0];
  const auto q_it = node_unit_map.find(&q_node);
  if (q_it == node_unit_map.end()) {
    return {ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Node does not have a corresponding NodeUnit"), nullptr};
  }

  const NodeUnit* q_node_unit = q_it->second;

  // The Q child must be standalone as well (not already fused into a QDQ node unit).
  if (q_node_unit->UnitType() != NodeUnit::Type::SingleNode) {
    return {};
  }

  const auto get_const_initializer = [&graph_viewer](const std::string& initializer_name) {
    return graph_viewer.GetConstantInitializer(initializer_name, true);
  };

  // The pair must represent a type conversion: equal scale types, different zero-point types.
  if (!QDQ::IsDQQConversion(dq_node, q_node, get_const_initializer, graph_viewer.ModelPath())) {
    return {};
  }

  LOGS(logger, VERBOSE) << " Adding QNN Convert. dq_node name: [" << dq_node.Name()
                        << "] dq_node optype: [" << dq_node.OpType()
                        << "] q_node name: [" << q_node_unit->Name()
                        << "] q_node optype: [" << q_node_unit->OpType()
                        << "]";

  ConvertOpBuilder convert_builder;
  const Status status = convert_builder.AddConvertToModelBuilder(qnn_model_wrapper, maybe_dq_node_unit,
                                                                 *q_node_unit, logger, do_op_validation);
  if (!status.IsOK()) {
    return {status, nullptr};
  }
  return {status, q_node_unit};
}

} // namespace qnn
} // namespace onnxruntime
35 changes: 30 additions & 5 deletions onnxruntime/core/providers/qnn/builder/qnn_model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,8 @@
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to initialize qnn_model_wrapper.");
}

std::unordered_set<const NodeUnit*> handled_node_units;

Check warning on line 117 in onnxruntime/core/providers/qnn/builder/qnn_model.cc

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <unordered_set> for unordered_set<> [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/qnn/builder/qnn_model.cc:117: Add #include <unordered_set> for unordered_set<> [build/include_what_you_use] [4]

// Op builder
const auto& node_indices = graph_viewer.GetNodesInTopologicalOrder();
for (size_t i = 0; i < node_indices.size(); i++) {
Expand All @@ -122,20 +124,43 @@
// Check whether it's part of NodeUnit
const NodeUnit& node_unit = GetNodeUnit(node, node_unit_map);
// Q, DQ nodes in the node unit only carry the quantization parameters
// Add the QNN node when it is the target node (It's a normal node or a singel Q/DQ node)
// Add the QNN node when it is the target node (It's a normal node or a single Q/DQ node)
const std::string& op_type = node_unit.OpType();

if (node != &node_unit.GetNode()) {
continue;
}

if (handled_node_units.count(&node_unit) != 0) {
continue; // Already handled.
}

// Try to convert particular DQ -> Q sequences into QNN Convert op
auto convert_result = TryHandleConvertSequence(qnn_model_wrapper,
node_unit,
node_unit_map,
logger_,
false /*do_op_validation*/);
ORT_RETURN_IF_ERROR(convert_result.status);

if (convert_result.q_node_unit) {
// Successfully merged DQ -> Q sequence into a QNN Convert op.
// Mark both of these node units as handled.
handled_node_units.insert(&node_unit);
handled_node_units.insert(convert_result.q_node_unit);
continue;
}

LOGS(logger_, VERBOSE) << " node name: [" << node->Name()
<< "] node optype: [" << op_type
<< "] as part of the NodeUnit type: [" << node_unit.OpType()
<< "] name: [" << node_unit.Name()
<< "]";
if (node != &node_unit.GetNode()) {
continue;
}

if (const auto* op_builder = GetOpBuilder(op_type)) {
ORT_RETURN_IF_ERROR(op_builder->AddToModelBuilder(qnn_model_wrapper, node_unit, logger_));
}

handled_node_units.insert(&node_unit);
}

ORT_RETURN_IF_NOT(qnn_model_wrapper.ComposeQnnGraph(), "Failed to compose Qnn graph.");
Expand Down
88 changes: 53 additions & 35 deletions onnxruntime/core/providers/qnn/qnn_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -286,33 +286,24 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio
}

bool QNNExecutionProvider::IsNodeSupported(qnn::QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit,
std::unordered_map<const NodeUnit*, bool>& node_unit_supported_result,
const logging::Logger& logger) const {
// If we have visited one of the nodes in the node_unit, use the result directly
const auto it = node_unit_supported_result.find(&node_unit);
if (it != node_unit_supported_result.cend()) {
return it->second;
const std::string& op_type = node_unit.OpType();
bool supported = false;
const auto* op_builder = qnn::GetOpBuilder(op_type);
if (op_builder == nullptr) {
LOGS(logger, WARNING) << "Operators of type `" << node_unit.OpType() << "` are not supported by QNN EP."
<< node_unit.OpType() << " node `" << node_unit.Name()
<< "` will not be assigned to QNN EP.";
} else {
const std::string& op_type = node_unit.OpType();

bool supported = false;
const auto* op_builder = qnn::GetOpBuilder(op_type);
if (op_builder == nullptr) {
LOGS(logger, WARNING) << "Operators of type `" << node_unit.OpType() << "` are not supported by QNN EP."
<< node_unit.OpType() << " node `" << node_unit.Name()
<< "` will not be assigned to QNN EP.";
} else {
auto status = op_builder->IsOpSupported(qnn_model_wrapper,
node_unit, logger);
if (Status::OK() != status) {
LOGS(logger, WARNING) << node_unit.OpType() << " node `" << node_unit.Name()
<< "` is not supported: " << status.ErrorMessage();
}
supported = (Status::OK() == status);
auto status = op_builder->IsOpSupported(qnn_model_wrapper,
node_unit, logger);
if (Status::OK() != status) {
LOGS(logger, WARNING) << node_unit.OpType() << " node `" << node_unit.Name()
<< "` is not supported: " << status.ErrorMessage();
}
node_unit_supported_result[&node_unit] = supported;
return supported;
supported = (Status::OK() == status);
}
return supported;
}

std::unordered_set<const Node*>
Expand Down Expand Up @@ -391,24 +382,51 @@ QNNExecutionProvider::GetSupportedNodes(const GraphViewer& graph_viewer,
if (node != &node_unit->GetNode()) {
continue;
}
const bool supported = IsNodeSupported(qnn_model_wrapper,
*node_unit,
node_unit_supported_result,
logger);
LOGS(logger, VERBOSE) << "Node supported: [" << supported
<< "] index: [" << node->Index()
<< "] name: [" << node->Name()
<< "] Operator type: [" << node->OpType()
<< "] as part of the NodeUnit type: [" << node_unit->OpType()
<< "] index: [" << node_unit->Index()
<< "] name: [" << node_unit->Name()
<< "]";

if (node_unit_supported_result.count(node_unit) != 0) {
continue; // Already handled this node unit
}

// Try to convert certain standalone DQ -> Q sequences into QNN Convert op
auto convert_result = TryHandleConvertSequence(qnn_model_wrapper,
*node_unit,
node_unit_map,
logger,
true /*do_op_validation*/);
if (!convert_result.status.IsOK()) {
LOGS(logger, WARNING) << "Failed to convert DQ -> Q sequence to QNN Convert. "
<< "Type: " << node_unit->OpType() << ", Node name: " << node_unit->Name() << ", "
<< "Message: " << convert_result.status.ErrorMessage();
}

bool supported = false;

if (convert_result.status.IsOK() && convert_result.q_node_unit) { // Merged DQ -> Q sequence into QNN Convert op
supported = true;

// Mark the Q node unit as handled and supported here so that we don't try to process it again.
node_unit_supported_result.insert({convert_result.q_node_unit, true});
supported_nodes.insert(&convert_result.q_node_unit->GetNode());
} else {
supported = IsNodeSupported(qnn_model_wrapper, *node_unit, logger);
LOGS(logger, VERBOSE) << "Node supported: [" << supported
<< "] index: [" << node->Index()
<< "] name: [" << node->Name()
<< "] Operator type: [" << node->OpType()
<< "] as part of the NodeUnit type: [" << node_unit->OpType()
<< "] index: [" << node_unit->Index()
<< "] name: [" << node_unit->Name()
<< "]";
}

if (supported) {
// If the node_unit is supported, add all of its nodes to the supported list.
for (const auto* node_in_group : node_unit->GetAllNodesInGroup()) {
supported_nodes.insert(node_in_group);
}
}

node_unit_supported_result.insert({node_unit, supported});
}

return supported_nodes;
Expand Down
1 change: 0 additions & 1 deletion onnxruntime/core/providers/qnn/qnn_execution_provider.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ class QNNExecutionProvider : public IExecutionProvider {

private:
bool IsNodeSupported(qnn::QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit,
std::unordered_map<const NodeUnit*, bool>& node_unit_supported_result,
const logging::Logger& logger) const;

std::unordered_set<const Node*> GetSupportedNodes(const GraphViewer& graph_viewer,
Expand Down
Loading
Loading