From 5ced778b0bc7a2e7a9e37da1792f7c632027701d Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Tue, 13 Feb 2024 14:40:48 -0800 Subject: [PATCH] QNN EP: Fuse certain DQ -> Q sequences into a single QNN Convert operator. --- .../optimizer/qdq_transformer/qdq_util.cc | 43 ++++++++ .../core/optimizer/qdq_transformer/qdq_util.h | 12 ++ .../qnn/builder/op_builder_factory.h | 23 ++++ .../builder/opbuilder/convert_op_builder.cc | 103 ++++++++++++++++++ .../core/providers/qnn/builder/qnn_model.cc | 35 +++++- .../providers/qnn/qnn_execution_provider.cc | 88 +++++++++------ .../providers/qnn/qnn_execution_provider.h | 1 - .../test/providers/qnn/simple_op_htp_test.cc | 55 ++++++++++ 8 files changed, 319 insertions(+), 41 deletions(-) create mode 100644 onnxruntime/core/providers/qnn/builder/opbuilder/convert_op_builder.cc diff --git a/onnxruntime/core/optimizer/qdq_transformer/qdq_util.cc b/onnxruntime/core/optimizer/qdq_transformer/qdq_util.cc index b1ab641a23256..4e3dff705bd41 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/qdq_util.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/qdq_util.cc @@ -76,6 +76,49 @@ bool IsQDQPairSupported( } } +bool IsDQQConversion( + const Node& dq_node, const Node& q_node, + const GetConstantInitializerFn& get_const_initializer, + const Path& model_path) { + ConstPointerContainer> dq_input_defs = dq_node.InputDefs(); + ConstPointerContainer> q_input_defs = q_node.InputDefs(); + + // Q/DQ contains optional input is not supported + // non-scalar Q/DQ scale and zero point needs are not supported + if (dq_input_defs.size() != InputIndex::TOTAL_COUNT || + q_input_defs.size() != InputIndex::TOTAL_COUNT || + !optimizer_utils::IsScalar(*q_input_defs[InputIndex::SCALE_ID]) || + !optimizer_utils::IsScalar(*q_input_defs[InputIndex::ZERO_POINT_ID]) || + !optimizer_utils::IsScalar(*dq_input_defs[InputIndex::SCALE_ID]) || + !optimizer_utils::IsScalar(*dq_input_defs[InputIndex::ZERO_POINT_ID])) { + return false; + } + + // if Q/DQ scale and zero point are not constant, return false + const ONNX_NAMESPACE::TensorProto* dq_scale_tensor_proto = + get_const_initializer(dq_input_defs[InputIndex::SCALE_ID]->Name()); + const ONNX_NAMESPACE::TensorProto* q_scale_tensor_proto = + get_const_initializer(q_input_defs[InputIndex::SCALE_ID]->Name()); + const ONNX_NAMESPACE::TensorProto* dq_zp_tensor_proto = + get_const_initializer(dq_input_defs[InputIndex::ZERO_POINT_ID]->Name()); + const ONNX_NAMESPACE::TensorProto* q_zp_tensor_proto = + get_const_initializer(q_input_defs[InputIndex::ZERO_POINT_ID]->Name()); + if (nullptr == q_zp_tensor_proto || + nullptr == dq_zp_tensor_proto || + nullptr == q_scale_tensor_proto || + nullptr == dq_scale_tensor_proto) { + return false; + } + + // check Q/DQ have same scale type and different zero point type + Initializer q_zp(*q_zp_tensor_proto, model_path); + Initializer q_scale(*q_scale_tensor_proto, model_path); + Initializer dq_zp(*dq_zp_tensor_proto, model_path); + Initializer dq_scale(*dq_scale_tensor_proto, model_path); + + return (dq_zp.data_type() != q_zp.data_type()) && (dq_scale.data_type() == q_scale.data_type()); +} + bool IsDQSupported(const Node& dq_node, const GetConstantInitializerFn& get_const_initializer) { bool zero_point_exists = false; if (!QOrDQNodeHasConstantScalarScaleAndZeroPoint(dq_node, get_const_initializer, zero_point_exists)) { diff --git a/onnxruntime/core/optimizer/qdq_transformer/qdq_util.h b/onnxruntime/core/optimizer/qdq_transformer/qdq_util.h index bb0bf9438cfcb..8333168b0093f 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/qdq_util.h +++ b/onnxruntime/core/optimizer/qdq_transformer/qdq_util.h @@ -38,6 +38,18 @@ bool IsQDQPairSupported( const GetConstantInitializerFn& get_const_initializer, const Path& model_path); +// Check if a DQ -> Q sequence represents a conversion in quantization data type. +// Example of uint8 to uint16: +// Dequantize (uint8 to float) -> Quantize (float to uint16) +// Requires: +// 1. Q/DQ doesn't have optional input. +// 2. scale and zero-point are constant scalars. +// 3. Q and DQ have the same scale *type* and different zero-point *types*. +bool IsDQQConversion( + const Node& dq_node, const Node& q_node, + const GetConstantInitializerFn& get_const_initializer, + const Path& model_path); + // Check if DQ is supported in extended level QDQ transformers. It requires: // 1. DQ doesn't have optional input. // 2. scale and zero point is constant scalar diff --git a/onnxruntime/core/providers/qnn/builder/op_builder_factory.h b/onnxruntime/core/providers/qnn/builder/op_builder_factory.h index d95e2baa9457f..4a9106f0c06af 100644 --- a/onnxruntime/core/providers/qnn/builder/op_builder_factory.h +++ b/onnxruntime/core/providers/qnn/builder/op_builder_factory.h @@ -94,5 +94,28 @@ void CreatePadOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_r void CreateExpandOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); +struct HandleConvertResult { + Status status; // Indicates an unexpected error. Check if q_node_unit != nullptr to determine + // whether a DQ -> Q sequence was successfully merged into a Convert. + const NodeUnit* q_node_unit; // Non-null if successfully merged DQ -> Q sequence. + // Set to nullptr if this node unit could not be merged into a Convert. +}; + +/** + * Tries to merge a DQ -> Q sequence into a QNN Convert operator. The DQ -> Q must be converting from + * one quantization type (e.g., uint8_t) to another (e.g., uint16_t). + * + * \param qnn_model_wrapper The QNN model that is being built. + * \param maybe_dq_node_unit The node unit that could potentially start the DQ -> Q sequence. + * \param logger The logger. + * \param do_op_validation True if should call QNN operator validation APIs. + * \return An qnn::HandleConvertResult object that indicates success/failure and provides a pointer + * to the Q node unit that was successfully merged with the provided DQ node unit. + */ +HandleConvertResult TryHandleConvertSequence(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& maybe_dq_node_unit, + const std::unordered_map& node_unit_map, + const logging::Logger& logger, + bool do_op_validation); } // namespace qnn } // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/convert_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/convert_op_builder.cc new file mode 100644 index 0000000000000..977a9e0b3d9d0 --- /dev/null +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/convert_op_builder.cc @@ -0,0 +1,103 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/graph/graph_utils.h" +#include "core/optimizer/qdq_transformer/qdq_util.h" +#include "core/providers/qnn/builder/opbuilder/base_op_builder.h" +#include "core/providers/shared/utils/utils.h" +#include "core/providers/qnn/builder/qnn_model_wrapper.h" +#include "core/providers/qnn/builder/op_builder_factory.h" +#include "core/common/safeint.h" +#include "onnx/defs/data_type_utils.h" + +#include "QnnOpDef.h" // From QNN SDK: contains QNN constants (e.g., op names, param values). + +namespace onnxruntime { +namespace qnn { + +class ConvertOpBuilder : public BaseOpBuilder { + public: + ConvertOpBuilder() : BaseOpBuilder("ConvertOpBuilder") {} + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(ConvertOpBuilder); + + Status AddConvertToModelBuilder(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& dq_node_unit, + const NodeUnit& q_node_unit, + const logging::Logger& logger, + bool do_op_validation) const ORT_MUST_USE_RESULT; +}; + +Status ConvertOpBuilder::AddConvertToModelBuilder(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& dq_node_unit, + const NodeUnit& q_node_unit, + const logging::Logger& logger, + bool do_op_validation) const { + std::vector input_names; + + // Process the input from the DQ node + ORT_RETURN_IF_ERROR(ProcessInput(qnn_model_wrapper, dq_node_unit.Inputs()[0], logger, input_names)); + + // Process the output from the Q node. Override the QNN operator type to "Convert". + ORT_RETURN_IF_ERROR(ProcessOutputs(qnn_model_wrapper, q_node_unit, std::move(input_names), {}, + logger, do_op_validation, QNN_OP_CONVERT)); + return Status::OK(); +} + +HandleConvertResult TryHandleConvertSequence(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& maybe_dq_node_unit, + const std::unordered_map& node_unit_map, + const logging::Logger& logger, + bool do_op_validation) { + const GraphViewer& graph_viewer = qnn_model_wrapper.GetGraphViewer(); + + // Looking for a standalone DQ to start the sequence. + if (maybe_dq_node_unit.OpType() != QDQ::DQOpName || maybe_dq_node_unit.UnitType() != NodeUnit::Type::SingleNode) { + return {}; + } + + const Node& dq_node = maybe_dq_node_unit.GetNode(); + + // DQ must have a single Q child. DQ must not produce a graph output. + auto children = graph_utils::FindChildrenByType(dq_node, QDQ::QOpName); + if (children.size() != 1 || dq_node.GetOutputEdgesCount() != 1 || graph_viewer.NodeProducesGraphOutput(dq_node)) { + return {}; + } + + const Node& q_node = *children[0]; + const auto q_node_unit_it = node_unit_map.find(&q_node); + + if (q_node_unit_it == node_unit_map.end()) { + return {ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Node does not have a corresponding NodeUnit"), nullptr}; + } + + const NodeUnit* q_node_unit = q_node_unit_it->second; + + // Q child must not already be part of a QDQ NodeUnit (i.e., be standalone). + if (q_node_unit->UnitType() != NodeUnit::Type::SingleNode) { + return {}; + } + + auto get_const_initializer = [&graph_viewer](const std::string& initializer_name) { + return graph_viewer.GetConstantInitializer(initializer_name, true); + }; + + // DQ and Q must have equal scale type and different zp type. + if (!QDQ::IsDQQConversion(dq_node, q_node, get_const_initializer, graph_viewer.ModelPath())) { + return {}; + } + + ConvertOpBuilder op_builder; + + LOGS(logger, VERBOSE) << " Adding QNN Convert. dq_node name: [" << dq_node.Name() + << "] dq_node optype: [" << dq_node.OpType() + << "] q_node name: [" << q_node_unit->Name() + << "] q_node optype: [" << q_node_unit->OpType() + << "]"; + + auto status = op_builder.AddConvertToModelBuilder(qnn_model_wrapper, maybe_dq_node_unit, *q_node_unit, logger, + do_op_validation); + return status.IsOK() ? HandleConvertResult{status, q_node_unit} : HandleConvertResult{status, nullptr}; +} + +} // namespace qnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model.cc b/onnxruntime/core/providers/qnn/builder/qnn_model.cc index 314cab4a36ca9..dc91b9dfa199e 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_model.cc @@ -114,6 +114,8 @@ Status QnnModel::ComposeGraph(const GraphViewer& graph_viewer, return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to initialize qnn_model_wrapper."); } + std::unordered_set handled_node_units; + // Op builer const auto& node_indices = graph_viewer.GetNodesInTopologicalOrder(); for (size_t i = 0; i < node_indices.size(); i++) { @@ -122,20 +124,43 @@ Status QnnModel::ComposeGraph(const GraphViewer& graph_viewer, // Check whether it's part of NodeUnit const NodeUnit& node_unit = GetNodeUnit(node, node_unit_map); // Q, DQ nodes in the node unit only carry the quantization parameters - // Add the QNN node when it is the target node (It's a normal node or a singel Q/DQ node) + // Add the QNN node when it is the target node (It's a normal node or a single Q/DQ node) const std::string& op_type = node_unit.OpType(); + + if (node != &node_unit.GetNode()) { + continue; + } + + if (handled_node_units.count(&node_unit) != 0) { + continue; // Already handled. + } + + // Try to convert particular DQ -> Q sequences into QNN Convert op + auto convert_result = TryHandleConvertSequence(qnn_model_wrapper, + node_unit, + node_unit_map, + logger_, + false /*do_op_validation*/); + ORT_RETURN_IF_ERROR(convert_result.status); + + if (convert_result.q_node_unit) { + // Successfully merged DQ -> Q sequence into a QNN Convert op. + // Mark both of these node units as handled. + handled_node_units.insert(&node_unit); + handled_node_units.insert(convert_result.q_node_unit); + continue; + } + LOGS(logger_, VERBOSE) << " node name: [" << node->Name() << "] node optype: [" << op_type << "] as part of the NodeUnit type: [" << node_unit.OpType() << "] name: [" << node_unit.Name() << "]"; - if (node != &node_unit.GetNode()) { - continue; - } - if (const auto* op_builder = GetOpBuilder(op_type)) { ORT_RETURN_IF_ERROR(op_builder->AddToModelBuilder(qnn_model_wrapper, node_unit, logger_)); } + + handled_node_units.insert(&node_unit); } ORT_RETURN_IF_NOT(qnn_model_wrapper.ComposeQnnGraph(), "Failed to compose Qnn graph."); diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc index b58f6e10df94c..f5a166d36b15a 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc @@ -286,33 +286,24 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio } bool QNNExecutionProvider::IsNodeSupported(qnn::QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit, - std::unordered_map& node_unit_supported_result, const logging::Logger& logger) const { - // If we have visited one of the nodes in the node_unit, use the result directly - const auto it = node_unit_supported_result.find(&node_unit); - if (it != node_unit_supported_result.cend()) { - return it->second; + const std::string& op_type = node_unit.OpType(); + bool supported = false; + const auto* op_builder = qnn::GetOpBuilder(op_type); + if (op_builder == nullptr) { + LOGS(logger, WARNING) << "Operators of type `" << node_unit.OpType() << "` are not supported by QNN EP." + << node_unit.OpType() << " node `" << node_unit.Name() + << "` will not be assigned to QNN EP."; } else { - const std::string& op_type = node_unit.OpType(); - - bool supported = false; - const auto* op_builder = qnn::GetOpBuilder(op_type); - if (op_builder == nullptr) { - LOGS(logger, WARNING) << "Operators of type `" << node_unit.OpType() << "` are not supported by QNN EP." - << node_unit.OpType() << " node `" << node_unit.Name() - << "` will not be assigned to QNN EP."; - } else { - auto status = op_builder->IsOpSupported(qnn_model_wrapper, - node_unit, logger); - if (Status::OK() != status) { - LOGS(logger, WARNING) << node_unit.OpType() << " node `" << node_unit.Name() - << "` is not supported: " << status.ErrorMessage(); - } - supported = (Status::OK() == status); + auto status = op_builder->IsOpSupported(qnn_model_wrapper, + node_unit, logger); + if (Status::OK() != status) { + LOGS(logger, WARNING) << node_unit.OpType() << " node `" << node_unit.Name() + << "` is not supported: " << status.ErrorMessage(); } - node_unit_supported_result[&node_unit] = supported; - return supported; + supported = (Status::OK() == status); } + return supported; } std::unordered_set @@ -391,24 +382,51 @@ QNNExecutionProvider::GetSupportedNodes(const GraphViewer& graph_viewer, if (node != &node_unit->GetNode()) { continue; } - const bool supported = IsNodeSupported(qnn_model_wrapper, - *node_unit, - node_unit_supported_result, - logger); - LOGS(logger, VERBOSE) << "Node supported: [" << supported - << "] index: [" << node->Index() - << "] name: [" << node->Name() - << "] Operator type: [" << node->OpType() - << "] as part of the NodeUnit type: [" << node_unit->OpType() - << "] index: [" << node_unit->Index() - << "] name: [" << node_unit->Name() - << "]"; + + if (node_unit_supported_result.count(node_unit) != 0) { + continue; // Already handled this node unit + } + + // Try to convert certain standalone DQ -> Q sequences into QNN Convert op + auto convert_result = TryHandleConvertSequence(qnn_model_wrapper, + *node_unit, + node_unit_map, + logger, + true /*do_op_validation*/); + if (!convert_result.status.IsOK()) { + LOGS(logger, WARNING) << "Failed to convert DQ -> Q sequence to QNN Convert. " + << "Type: " << node_unit->OpType() << ", Node name: " << node_unit->Name() << ", " + << "Message: " << convert_result.status.ErrorMessage(); + } + + bool supported = false; + + if (convert_result.status.IsOK() && convert_result.q_node_unit) { // Merged DQ -> Q sequence into QNN Convert op + supported = true; + + // Mark the Q node unit as handled and supported here so that we don't try to process it again. + node_unit_supported_result.insert({convert_result.q_node_unit, true}); + supported_nodes.insert(&convert_result.q_node_unit->GetNode()); + } else { + supported = IsNodeSupported(qnn_model_wrapper, *node_unit, logger); + LOGS(logger, VERBOSE) << "Node supported: [" << supported + << "] index: [" << node->Index() + << "] name: [" << node->Name() + << "] Operator type: [" << node->OpType() + << "] as part of the NodeUnit type: [" << node_unit->OpType() + << "] index: [" << node_unit->Index() + << "] name: [" << node_unit->Name() + << "]"; + } + if (supported) { // If the node_unit is supported, add all of its nodes to the supported list. for (const auto* node_in_group : node_unit->GetAllNodesInGroup()) { supported_nodes.insert(node_in_group); } } + + node_unit_supported_result.insert({node_unit, supported}); } return supported_nodes; diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.h b/onnxruntime/core/providers/qnn/qnn_execution_provider.h index 09bcb24db4dc2..0bcaa39b22f6d 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.h +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.h @@ -42,7 +42,6 @@ class QNNExecutionProvider : public IExecutionProvider { private: bool IsNodeSupported(qnn::QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit, - std::unordered_map& node_unit_supported_result, const logging::Logger& logger) const; std::unordered_set GetSupportedNodes(const GraphViewer& graph_viewer, diff --git a/onnxruntime/test/providers/qnn/simple_op_htp_test.cc b/onnxruntime/test/providers/qnn/simple_op_htp_test.cc index 2f3b0e84a123e..a6422407d79fd 100644 --- a/onnxruntime/test/providers/qnn/simple_op_htp_test.cc +++ b/onnxruntime/test/providers/qnn/simple_op_htp_test.cc @@ -1110,6 +1110,61 @@ TEST_F(QnnHTPBackendTests, LpNormalization_u16_rank4) { kOnnxDomain, true); } + +static GetTestQDQModelFn BuildQDQConvertAddTestCase(const TestInputDef& input0_def, + const TestInputDef& input1_def) { + return [input0_def, input1_def](ModelTestBuilder& builder, std::vector>& output_qparams) { + constexpr bool use_contrib_qdq = true; + + // Input0 -> Quantize(u8) -> Dequantize(u8 to float) -> input0_after_qdq + NodeArg* input0 = MakeTestInput(builder, input0_def); + QuantParams input0_u8_qparams = GetTestInputQuantParams(input0_def); + NodeArg* input0_after_qdq = AddQDQNodePair(builder, input0, input0_u8_qparams.scale, + input0_u8_qparams.zero_point, use_contrib_qdq); + + // input0_after_qdq -> Quantize(u16) -> Dequantize(u16 to float) + QuantParams input0_u16_qparams = GetTestInputQuantParams(input0_def); + NodeArg* input0_after_convert = AddQDQNodePair(builder, input0_after_qdq, input0_u16_qparams.scale, + input0_u16_qparams.zero_point, use_contrib_qdq); + + // Input1 -> Quantize(u16) -> Dequantize(u16 to float) -> input1_after_qdq + NodeArg* input1 = MakeTestInput(builder, input1_def); + QuantParams input1_qparams = GetTestInputQuantParams(input1_def); + NodeArg* input1_after_qdq = AddQDQNodePair(builder, input1, input1_qparams.scale, + input1_qparams.zero_point, use_contrib_qdq); + + // Add op -> op_output + auto* op_output = builder.MakeIntermediate(); + builder.AddNode("Add", {input0_after_convert, input1_after_qdq}, {op_output}); + + // op_output -> Q -> DQ -> output + AddQDQNodePairWithOutputAsGraphOutput(builder, op_output, output_qparams[0].scale, + output_qparams[0].zero_point, use_contrib_qdq); + }; +} + +// Test quantization type conversion (mixed precision) with Add. +// First input is converted from uint8_t to uint16_t. +TEST_F(QnnHTPBackendTests, Add_U8_U16_Convert) { + std::vector input0_data = GetFloatDataInRange(-10.0f, 10.0f, 8); + std::vector input1_data = GetFloatDataInRange(-20.0f, 20.0f, 8); + TestInputDef input0_def({1, 2, 2, 2}, false, input0_data); + TestInputDef input1_def({1, 2, 2, 2}, false, input1_data); + + ProviderOptions provider_options; +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + + TestQDQModelAccuracy(BuildOpTestCase("Add", {input0_def, input1_def}, {}, {}, kOnnxDomain), + BuildQDQConvertAddTestCase(input0_def, input1_def), + provider_options, + 18, + ExpectedEPNodeAssignment::All); +} + #endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) } // namespace test