From 0e708de4fcec8496cc11650c2dc089d10bc052e0 Mon Sep 17 00:00:00 2001 From: Adrian Lizarraga Date: Fri, 2 Aug 2024 11:02:22 -0700 Subject: [PATCH] [QNN EP] Support Conv + Clip/Relu fusion (#21537) ### Description - Supports quantized Conv + Activation on the HTP backend: - Translates `DQs -> Conv -> Relu/Clip -> Q` into a single QNN Conv operator if the Relu (or Clip) are redundant. ### Motivation and Context Expands support for QDQ models created with tools that do not wrap Relu or Clip with QDQ nodes. This PR introduces the `IQnnNodeGroup` class. In the same way that a `NodeUnit` represents a collection of `Nodes`, a `IQnnNodeGroup` can represent one or more `NodeUnits` that are translated into a QNN operator. QNN EP parses the ONNX graph to create a list of `IQnnNodeGroup` objects, each representing a single `NodeUnit` or a fusion of multiple `NodeUnits`. --- onnxruntime/core/framework/node_unit.cc | 15 + onnxruntime/core/framework/node_unit.h | 4 + .../core/providers/qnn/builder/qnn_fusions.cc | 294 ----------- .../core/providers/qnn/builder/qnn_fusions.h | 38 -- .../core/providers/qnn/builder/qnn_model.cc | 51 +- .../qnn/builder/qnn_model_wrapper.cc | 2 + .../providers/qnn/builder/qnn_node_group.h | 68 +++ .../qnn_node_group/conv_activation_fusion.cc | 480 ++++++++++++++++++ .../qnn_node_group/conv_activation_fusion.h | 63 +++ .../qnn/builder/qnn_node_group/dq_q_fusion.cc | 179 +++++++ .../qnn/builder/qnn_node_group/dq_q_fusion.h | 57 +++ .../qnn_node_group/hardsigmoid_mul_fusion.cc | 144 ++++++ .../qnn_node_group/hardsigmoid_mul_fusion.h | 57 +++ .../builder/qnn_node_group/qnn_node_group.cc | 221 ++++++++ .../qnn/builder/qnn_node_group/utils.cc | 66 +++ .../qnn/builder/qnn_node_group/utils.h | 40 ++ .../providers/qnn/qnn_execution_provider.cc | 121 ++--- .../providers/qnn/qnn_execution_provider.h | 3 - onnxruntime/test/providers/qnn/conv_test.cc | 252 +++++++-- 19 files changed, 1670 insertions(+), 485 deletions(-) delete mode 100644 onnxruntime/core/providers/qnn/builder/qnn_fusions.cc delete mode 100644 onnxruntime/core/providers/qnn/builder/qnn_fusions.h create mode 100644 onnxruntime/core/providers/qnn/builder/qnn_node_group.h create mode 100644 onnxruntime/core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.cc create mode 100644 onnxruntime/core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.h create mode 100644 onnxruntime/core/providers/qnn/builder/qnn_node_group/dq_q_fusion.cc create mode 100644 onnxruntime/core/providers/qnn/builder/qnn_node_group/dq_q_fusion.h create mode 100644 onnxruntime/core/providers/qnn/builder/qnn_node_group/hardsigmoid_mul_fusion.cc create mode 100644 onnxruntime/core/providers/qnn/builder/qnn_node_group/hardsigmoid_mul_fusion.h create mode 100644 onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc create mode 100644 onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.cc create mode 100644 onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.h diff --git a/onnxruntime/core/framework/node_unit.cc b/onnxruntime/core/framework/node_unit.cc index e2c06fbdfa621..850cb167a3ece 100644 --- a/onnxruntime/core/framework/node_unit.cc +++ b/onnxruntime/core/framework/node_unit.cc @@ -4,6 +4,7 @@ #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) #include "node_unit.h" +#include #include "core/graph/graph_viewer.h" namespace onnxruntime { @@ -272,6 +273,20 @@ NodeUnit::NodeUnit(const GraphViewer& graph_viewer, const QDQ::NodeGroup& node_g } } +NodeUnit::NodeUnit(gsl::span dq_nodes, const Node& target_node, + gsl::span q_nodes, Type unit_type, + gsl::span inputs, gsl::span outputs, + size_t input_edge_count, Node::EdgeSet output_edges) + : dq_nodes_(dq_nodes.begin(), dq_nodes.end()), + target_node_(target_node), + q_nodes_(q_nodes.begin(), q_nodes.end()), + type_(unit_type), + inputs_(inputs.begin(), inputs.end()), + outputs_(outputs.begin(), outputs.end()), + input_edge_count_(input_edge_count), + output_edges_(std::move(output_edges)) { +} + const std::string& NodeUnit::Domain() const noexcept { return target_node_.Domain(); } const std::string& NodeUnit::OpType() const noexcept { return target_node_.OpType(); } const std::string& NodeUnit::Name() const noexcept { return target_node_.Name(); } diff --git a/onnxruntime/core/framework/node_unit.h b/onnxruntime/core/framework/node_unit.h index e84e62479162f..50bd423d2f547 100644 --- a/onnxruntime/core/framework/node_unit.h +++ b/onnxruntime/core/framework/node_unit.h @@ -68,6 +68,10 @@ class NodeUnit { public: explicit NodeUnit(const Node& node); explicit NodeUnit(const GraphViewer& graph_viewer, const QDQ::NodeGroup& node_group); + NodeUnit(gsl::span dq_nodes, const Node& target_node, + gsl::span q_nodes, Type unit_type, + gsl::span inputs, gsl::span outputs, + size_t input_edge_count, Node::EdgeSet output_edges); Type UnitType() const noexcept { return type_; } diff --git a/onnxruntime/core/providers/qnn/builder/qnn_fusions.cc b/onnxruntime/core/providers/qnn/builder/qnn_fusions.cc deleted file mode 100644 index b04075f11203c..0000000000000 --- a/onnxruntime/core/providers/qnn/builder/qnn_fusions.cc +++ /dev/null @@ -1,294 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/providers/qnn/builder/qnn_fusions.h" - -#include -#include -#include -#include -#include -#include -#include "core/graph/graph_utils.h" -#include "core/optimizer/qdq_transformer/qdq_util.h" -#include "core/framework/node_unit.h" -#include "core/providers/qnn/builder/qnn_utils.h" -#include "core/providers/qnn/builder/qnn_model_wrapper.h" -#include "core/providers/qnn/builder/op_builder_factory.h" - -#define QNN_RETURN_OK_IF_ERROR(expr, logger) \ - do { \ - auto _status = (expr); \ - if ((!_status.IsOK())) { \ - LOGS((logger), VERBOSE) << _status.ErrorMessage(); \ - return Status::OK(); \ - } \ - } while (0) - -namespace onnxruntime { -namespace qnn { - -/** - * Tries to merge a DQ -> Q sequence into a QNN Convert operator. The DQ -> Q must be converting from - * one quantization type (e.g., uint8_t) to another (e.g., uint16_t). - * - * \param fused_nodes Output list of node units that were fused. Remains empty if fusion is not applied. - * \param qnn_model_wrapper The QNN model that is being built. - * \param start_node_unit The node unit that could potentially start the DQ -> Q sequence. - * \param node_unit_map Maps a node to its node unit. - * \param handled_node_units Set of node units that have already been processed. Fusion will not fuse nodes - * in this set. - * \param logger The logger. - * \param do_op_validation True if should call QNN operator validation APIs. - * \return An onnxruntime::Status - */ -static Status TryHandleConvertSequence(std::vector& fused_nodes, - QnnModelWrapper& qnn_model_wrapper, - const NodeUnit& start_node_unit, - const std::unordered_map& node_unit_map, - const std::unordered_set& handled_node_units, - const logging::Logger& logger, - bool do_op_validation) { - const GraphViewer& graph_viewer = qnn_model_wrapper.GetGraphViewer(); - - // Looking for a standalone DQ to start the sequence. - if (start_node_unit.OpType() != QDQ::DQOpName || start_node_unit.UnitType() != NodeUnit::Type::SingleNode) { - return Status::OK(); - } - - const Node& dq_node = start_node_unit.GetNode(); - - // DQ must have a single Q child. DQ must not produce a graph output. - auto children = graph_utils::FindChildrenByType(dq_node, QDQ::QOpName); - if (children.size() != 1 || dq_node.GetOutputEdgesCount() != 1 || graph_viewer.NodeProducesGraphOutput(dq_node)) { - return Status::OK(); - } - - const Node& q_node = *children[0]; - const auto q_node_unit_it = node_unit_map.find(&q_node); - - ORT_RETURN_IF(q_node_unit_it == node_unit_map.end(), "Node does not have a corresponding NodeUnit"); - - const NodeUnit* q_node_unit = q_node_unit_it->second; - - // Check if Q node has already been handled. Should not be the case if this - // fusion function has been called in topological order, but check to be safe. - if (handled_node_units.count(q_node_unit) != 0) { - return Status::OK(); - } - - // Q child must not already be part of a QDQ NodeUnit (i.e., be standalone). - if (q_node_unit->UnitType() != NodeUnit::Type::SingleNode) { - return Status::OK(); - } - - auto get_const_initializer = [&graph_viewer](const std::string& initializer_name) { - return graph_viewer.GetConstantInitializer(initializer_name, true); - }; - - // DQ and Q must have equal scale type and different zp type. - if (!QDQ::IsDQQConversion(dq_node, q_node, get_const_initializer, graph_viewer.ModelPath())) { - return Status::OK(); - } - - const auto& node_name = utils::GetNodeName(start_node_unit); - const NodeUnitIODef& input_def = start_node_unit.Inputs()[0]; - const NodeUnitIODef& output_def = q_node_unit->Outputs()[0]; - - QnnTensorWrapper input_tensor; - QnnTensorWrapper output_tensor; - - // Run QNN validation on the final fused node before committing to doing a fusion. - // Importantly, this validation process does not modify the qnn_model_wrapper. - // If validation fails here, we return Status::OK() to allow QNN EP to use the normal OpBuilder workflow. - QNN_RETURN_OK_IF_ERROR(qnn_model_wrapper.MakeTensorWrapper(input_def, input_tensor), logger); - QNN_RETURN_OK_IF_ERROR(qnn_model_wrapper.MakeTensorWrapper(output_def, output_tensor), logger); - QNN_RETURN_OK_IF_ERROR(qnn_model_wrapper.ValidateQnnNode(node_name, - QNN_OP_PACKAGE_NAME_QTI_AISW, - QNN_OP_CONVERT, - {input_tensor.GetQnnTensor()}, - {output_tensor.GetQnnTensor()}, - {}), - logger); - - // Validation passed, so we're now committed to doing a fusion. The following statements modify qnn_model_wrapper. - // If we encounter an error, we return it directly to caller. - LOGS(logger, VERBOSE) << " Adding QNN Convert via fusion. dq_node name: [" << dq_node.Name() - << "] dq_node optype: [" << dq_node.OpType() - << "] q_node name: [" << q_node_unit->Name() - << "] q_node optype: [" << q_node_unit->OpType() - << "]"; - - // Add a QNN Convert to the model. Get the input from the DQ node, and the output from the Q node. - ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(input_tensor)), "Failed to add input"); - ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(output_tensor)), "Failed to add output"); - ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(utils::GetNodeName(*q_node_unit), - QNN_OP_PACKAGE_NAME_QTI_AISW, - QNN_OP_CONVERT, - {input_def.node_arg.Name()}, - {output_def.node_arg.Name()}, - {}, - do_op_validation), - "Failed to add fused Convert node."); - - fused_nodes.push_back(&start_node_unit); - fused_nodes.push_back(q_node_unit); - - return Status::OK(); -} - -/** - * Tries to fuse the sequence `x * HardSigmoid(x)` into a single HardSwish(x) operator. - * Should be called in a topologically ordered iteration of node units. - * - * \param fused_nodes Output list of node units that were fused. Remains empty if fusion was not applied. - * \param qnn_model_wrapper The QNN model that is being built. - * \param starting_node The node unit that could potentially start the sequence. - * \param node_unit_map Maps a node to its node unit. - * \param handled_node_units Set of node units that have already been processed. Fusion will not fuse nodes - * in this set. - * \param logger The logger. - * \param do_op_validation True if should call QNN operator validation APIs. - * \return A Status indicating a potential failure. - */ -static Status TryHandleHardSigmoidSequence(std::vector& fused_nodes, - QnnModelWrapper& qnn_model_wrapper, - const NodeUnit& start_node_unit, - const std::unordered_map& node_unit_map, - const std::unordered_set& handled_node_units, - const logging::Logger& logger, - bool do_op_validation) { - // Looking for a standalone HardSigmoid to start the sequence. - if (start_node_unit.OpType() != "HardSigmoid" || start_node_unit.UnitType() != NodeUnit::Type::SingleNode) { - return Status::OK(); - } - - NodeAttrHelper hs_attr_helper(start_node_unit); - float alpha = hs_attr_helper.Get("alpha", 0.2f); - float beta = hs_attr_helper.Get("beta", 0.5f); - constexpr float req_alpha = 1.0f / 6.0f; - constexpr float req_beta = 0.5f; - constexpr float alpha_eps = std::numeric_limits::epsilon() * req_alpha; - constexpr float beta_eps = std::numeric_limits::epsilon() * req_beta; - - // Check for explicit values of alpha and beta. - if (std::abs(alpha - req_alpha) > alpha_eps || std::abs(beta - req_beta) > beta_eps) { - return Status::OK(); - } - - const GraphViewer& graph_viewer = qnn_model_wrapper.GetGraphViewer(); - const Node& hs_node = start_node_unit.GetNode(); - - // HardSigmoid must have a single Mul child. HardSigmoid must not produce a graph output. - auto children = graph_utils::FindChildrenByType(hs_node, "Mul"); - if (children.size() != 1 || hs_node.GetOutputEdgesCount() != 1 || graph_viewer.NodeProducesGraphOutput(hs_node)) { - return Status::OK(); - } - - const Node& mul_node = *children[0]; - const auto mul_node_unit_it = node_unit_map.find(&mul_node); - ORT_RETURN_IF(mul_node_unit_it == node_unit_map.end(), "Node does not have a corresponding NodeUnit"); - const NodeUnit* mul_node_unit = mul_node_unit_it->second; - - // Check if Mul node has already been handled. Should not be the case if this - // fusion function has been called in topological order, but check to be safe. - if (handled_node_units.count(mul_node_unit) != 0) { - return Status::OK(); - } - - // Mul child must not already be part of a QDQ NodeUnit (i.e., be standalone). - if (mul_node_unit->UnitType() != NodeUnit::Type::SingleNode) { - return Status::OK(); - } - - // Input to HardSigmoid must also be the other input to the Mul. - auto& hs_input_name = start_node_unit.Inputs()[0].node_arg.Name(); - const bool same_root_input = mul_node.InputDefs()[0]->Name() == hs_input_name || - mul_node.InputDefs()[1]->Name() == hs_input_name; - - if (!same_root_input) { - return Status::OK(); - } - - const auto& node_name = utils::GetNodeName(start_node_unit); - const NodeUnitIODef& input_def = start_node_unit.Inputs()[0]; - const NodeUnitIODef& output_def = mul_node_unit->Outputs()[0]; - - QnnTensorWrapper input_tensor; - QnnTensorWrapper output_tensor; - - // Run QNN validation on the final fused node before committing to doing a fusion. - // Importantly, this validation process does not modify the qnn_model_wrapper. - // If validation fails here, we return Status::OK() to allow QNN EP to use the normal OpBuilder workflow. - QNN_RETURN_OK_IF_ERROR(qnn_model_wrapper.MakeTensorWrapper(input_def, input_tensor), logger); - QNN_RETURN_OK_IF_ERROR(qnn_model_wrapper.MakeTensorWrapper(output_def, output_tensor), logger); - QNN_RETURN_OK_IF_ERROR(qnn_model_wrapper.ValidateQnnNode(node_name, - QNN_OP_PACKAGE_NAME_QTI_AISW, - QNN_OP_HARD_SWISH, - {input_tensor.GetQnnTensor()}, - {output_tensor.GetQnnTensor()}, - {}), - logger); - - // Validation passed, so we're now committed to doing a fusion. The following statements modify qnn_model_wrapper. - // If we encounter an error, we return it directly to caller. - LOGS(logger, VERBOSE) << " Adding QNN HardSwish via fusion. HardSigmoid name: [" << start_node_unit.Name() - << "] Mul name: [" << mul_node_unit->Name() << "]"; - - ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(input_tensor)), "Failed to add input"); - ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(output_tensor)), "Failed to add output"); - ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(node_name, - QNN_OP_PACKAGE_NAME_QTI_AISW, - QNN_OP_HARD_SWISH, - {input_def.node_arg.Name()}, - {output_def.node_arg.Name()}, - {}, - do_op_validation), - "Failed to add fused HardSwish node."); - - fused_nodes.push_back(&start_node_unit); - fused_nodes.push_back(mul_node_unit); - - return Status::OK(); -} - -using FusionFunc = Status (*)(std::vector&, - QnnModelWrapper&, - const NodeUnit&, - const std::unordered_map&, - const std::unordered_set&, - const logging::Logger&, - bool); - -Status TryFusions(/*out*/ std::vector& fused_nodes, - QnnModelWrapper& qnn_model_wrapper, - const NodeUnit& starting_node, - const std::unordered_map& node_unit_map, - const std::unordered_set& handled_node_units, - const logging::Logger& logger, - bool validate) { - // Maps a starting operator type to the fusion function. - static std::unordered_map fusions = { - {"DequantizeLinear", TryHandleConvertSequence}, - {"HardSigmoid", TryHandleHardSigmoidSequence}, - }; - - // For now, all fusions involve standalone node units (i.e., no wrapping DQ/Q nodes). - if (starting_node.UnitType() != NodeUnit::Type::SingleNode) { - return Status::OK(); - } - - auto iter = fusions.find(starting_node.OpType()); - if (iter != fusions.end()) { - fused_nodes.clear(); - - FusionFunc fusion_func = iter->second; - ORT_RETURN_IF_ERROR(fusion_func(fused_nodes, qnn_model_wrapper, starting_node, node_unit_map, - handled_node_units, logger, validate)); - } - - return Status::OK(); -} - -} // namespace qnn -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_fusions.h b/onnxruntime/core/providers/qnn/builder/qnn_fusions.h deleted file mode 100644 index 39e2e71c01d8c..0000000000000 --- a/onnxruntime/core/providers/qnn/builder/qnn_fusions.h +++ /dev/null @@ -1,38 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include -#include -#include - -#include "core/framework/node_unit.h" -#include "core/providers/qnn/builder/qnn_model_wrapper.h" - -namespace onnxruntime { -namespace qnn { - -/** - * Tries to fuse a node sequence starting from the given starting node. Should be called in a topologically ordered - * walk of node units. - * - * \param fused_nodes Output list of node units that were fused. Remains empty if fusion was not applied. - * \param qnn_model_wrapper The QNN model that is being built. - * \param starting_node The node unit that could potentially start the sequence. - * \param node_unit_map Maps a node to its node unit. - * \param handled_node_units Set of node units that have already been processed. Fusion will not fuse nodes - * in this set. - * \param logger The logger. - * \param do_op_validation True if should call QNN operator validation APIs. - * \return A Status indicating a potential failure. - */ -Status TryFusions(/*out*/ std::vector& fused_nodes, - QnnModelWrapper& qnn_model_wrapper, - const NodeUnit& starting_node, - const std::unordered_map& node_unit_map, - const std::unordered_set& handled_node_units, - const logging::Logger& logger, - bool do_op_validation); -} // namespace qnn -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model.cc b/onnxruntime/core/providers/qnn/builder/qnn_model.cc index 503943dfb636b..83f9184d33611 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_model.cc @@ -7,7 +7,7 @@ #include "QnnOpDef.h" #include "core/providers/qnn/builder/op_builder_factory.h" -#include "core/providers/qnn/builder/qnn_fusions.h" +#include "core/providers/qnn/builder/qnn_node_group.h" #include "core/providers/shared/utils/utils.h" #include "core/framework/utils.h" #include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h" @@ -117,49 +117,20 @@ Status QnnModel::ComposeGraph(const GraphViewer& graph_viewer, return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to initialize qnn_model_wrapper."); } - std::unordered_set handled_node_units; + std::vector> qnn_node_groups; + qnn_node_groups.reserve(node_unit_holder.size()); - // Op builer - const auto& node_indices = graph_viewer.GetNodesInTopologicalOrder(); - for (size_t i = 0; i < node_indices.size(); i++) { - const auto* node(graph_viewer.GetNode(node_indices[i])); + ORT_RETURN_IF_ERROR(qnn::GetQnnNodeGroups(qnn_node_groups, qnn_model_wrapper, node_unit_map, + node_unit_holder.size(), logger_)); - // Check whether it's part of NodeUnit - const NodeUnit& node_unit = GetNodeUnit(node, node_unit_map); - // Q, DQ nodes in the node unit only carry the quantization parameters - // Add the QNN node when it is the target node (It's a normal node or a single Q/DQ node) - const std::string& op_type = node_unit.OpType(); + for (const std::unique_ptr& qnn_node_group : qnn_node_groups) { + Status status = qnn_node_group->AddToModelBuilder(qnn_model_wrapper, logger_); - if (node != &node_unit.GetNode()) { - continue; + if (!status.IsOK()) { + LOGS(logger_, ERROR) << "[QNN EP] Failed to add supported node to QNN graph during EP's compile call: " + << status.ErrorMessage() << std::endl; + return status; } - - if (handled_node_units.count(&node_unit) != 0) { - continue; // Already handled. - } - - // Try to see if this node unit can be fused. - std::vector fused_nodes; - ORT_RETURN_IF_ERROR(TryFusions(fused_nodes, qnn_model_wrapper, node_unit, node_unit_map, - handled_node_units, logger_, false /*do_op_validation*/)); - - if (!fused_nodes.empty()) { - for (auto fused_node_unit : fused_nodes) { - handled_node_units.insert(fused_node_unit); - } - continue; - } - - LOGS(logger_, VERBOSE) << " node name: [" << node->Name() - << "] node optype: [" << op_type - << "] as part of the NodeUnit type: [" << node_unit.OpType() - << "] name: [" << node_unit.Name() - << "]"; - if (const auto* op_builder = GetOpBuilder(op_type)) { - ORT_RETURN_IF_ERROR(op_builder->AddToModelBuilder(qnn_model_wrapper, node_unit, logger_)); - } - - handled_node_units.insert(&node_unit); } ORT_RETURN_IF_NOT(qnn_model_wrapper.ComposeQnnGraph(), "Failed to compose Qnn graph."); diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc b/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc index 9d3f460572d84..657224f68f71b 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc @@ -239,6 +239,8 @@ bool QnnModelWrapper::CreateQnnNode(const std::string& qnn_node_name, std::string error_msg; bool rt = op_config_wrapper.QnnGraphOpValidation(qnn_interface_, backend_handle_, error_msg); if (!rt) { + // TODO(adrianlizarraga): Return a Status with the error message so that aggregated logs show a more + // specific validation error (instead of "failed to add node"). LOGS(logger_, WARNING) << error_msg; } return rt; diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group.h b/onnxruntime/core/providers/qnn/builder/qnn_node_group.h new file mode 100644 index 0000000000000..f9ef01411310f --- /dev/null +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group.h @@ -0,0 +1,68 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include +#include + +#include "core/common/logging/logging.h" +#include "core/framework/node_unit.h" + +namespace onnxruntime { +namespace qnn { + +class QnnModelWrapper; + +/// +/// Represents a group of NodeUnits that QNN EP translates into a core QNN operator. Can represent a single NodeUnit +/// or a fusion of multiple NodeUnits (e.g., DQ* -> Conv -> Relu -> Q). +/// +class IQnnNodeGroup { + public: + virtual ~IQnnNodeGroup() = default; + + // Returns an OK status if this IQnnNodeGroup is supported by QNN. + virtual Status IsSupported(QnnModelWrapper& qnn_model_wrapper, const logging::Logger& logger) const = 0; + + // Adds this IQnnNodeGroup to the QNN model wrapper. + virtual Status AddToModelBuilder(QnnModelWrapper& qnn_model_wrapper, const logging::Logger& logger) const = 0; + + // Returns a list of NodeUnits contained by this IQnnNodeGroup. + virtual gsl::span GetNodeUnits() const = 0; + + /// + /// Returns the "target" NodeUnit of the group. This is important for topological ordering of IQnnNodeGroups. + /// The target should be the first NodeUnit where all input paths (of the IQnnNodeGroup) converge. + /// For example, "Conv" should be the target NodeUnit for the following IQnnNodeGroup with 6 NodeUnits. + /// input0 -> DQ -> Conv -> Relu -> Q + /// ^ + /// | + /// input1 -> DQ ----+ + /// + /// Target NodeUnit in IQnnNodeGroup + virtual const NodeUnit* GetTargetNodeUnit() const = 0; + + // Returns a string representation of the IQnnNodeGroup's type. + virtual std::string_view Type() const = 0; +}; + +/// +/// Traverses the ONNX graph to create IQnnNodeGroup objects, each containing one or more NodeUnits. +/// The returned IQnnNodeGroup objects are sorted in topological order. +/// +/// Output vector into which the resulting IQnnNodeGroup objects are stored. +/// Contains reference to the ONNX GraphViewer and used for validaton on QNN +/// Maps a Node* to a NodeUnit* +/// The number of NodeUnits in the ONNX graph. +/// Logger +/// Status with potential error +Status GetQnnNodeGroups(/*out*/ std::vector>& qnn_node_groups, + QnnModelWrapper& qnn_model_wrapper, + const std::unordered_map& node_to_node_unit, + size_t num_node_units, + const logging::Logger& logger); +} // namespace qnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.cc b/onnxruntime/core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.cc new file mode 100644 index 0000000000000..813bba8a5952b --- /dev/null +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.cc @@ -0,0 +1,480 @@ +#include "core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.h" + +#include +#include +#include +#include +#include +#include +#include "core/graph/graph_utils.h" +#include "core/framework/node_unit.h" +#include "core/providers/shared/utils/utils.h" +#include "core/providers/qnn/builder/qnn_utils.h" +#include "core/providers/qnn/builder/op_builder_factory.h" +#include "core/providers/qnn/builder/qnn_node_group/utils.h" +#include "core/providers/qnn/builder/qnn_model_wrapper.h" + +namespace onnxruntime { +namespace qnn { + +// Gets the scale, zero-point, and zero-point type for a QuantizeLinear node that uses per-tensor quantization. +static bool GetQScalarScaleZeroPoint(const QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& q_node_unit, + /*out*/ float& scale, + /*out*/ int32_t& zero_point, + /*out*/ int32_t& zp_data_type) { + assert(q_node_unit.OpType() == QUANTIZE_LINEAR); + const auto& q_inputs = q_node_unit.GetNode().InputDefs(); + + // Require an explicit zero-point input for now. + if (q_inputs.size() != 3 || !q_inputs[QDQ_ZERO_POINT_INPUT_IDX]->Exists()) { + return false; + } + + std::vector zero_points; + Status status = qnn_model_wrapper.UnpackZeroPoints(q_inputs[QDQ_ZERO_POINT_INPUT_IDX]->Name(), + zero_points, zp_data_type); + + // Should only have one zero-point (per-tensor). + if (!status.IsOK() || zero_points.size() != 1) { + return false; + } + zero_point = -zero_points[0]; // QNN zero-points are negated. + + std::vector scales; + status = qnn_model_wrapper.UnpackScales(q_inputs[QDQ_SCALE_INPUT_IDX]->Name(), scales); + + // Should only have one scale (per-tensor). + if (!status.IsOK() || scales.size() != 1) { + return false; + } + + scale = scales[0]; + return true; +} + +// Computes the floating point range (rmin, rmax) from a QuantizeLinear node's scale/zero-point. +static bool GetQRminRmax(const QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& q_node_unit, + /*out*/ float& rmin, + /*out*/ float& rmax) { + int32_t zp_data_type = ONNX_NAMESPACE::TensorProto::DataType::TensorProto_DataType_UNDEFINED; + int32_t zero_point = 0; + float scale = 0.0f; + + if (!GetQScalarScaleZeroPoint(qnn_model_wrapper, q_node_unit, scale, zero_point, zp_data_type)) { + return false; + } + + switch (zp_data_type) { + case ONNX_NAMESPACE::TensorProto_DataType_INT8: { + rmin = scale * (std::numeric_limits::lowest() - zero_point); + rmax = scale * (std::numeric_limits::max() - zero_point); + break; + } + case ONNX_NAMESPACE::TensorProto_DataType_UINT8: { + rmin = scale * (std::numeric_limits::lowest() - zero_point); + rmax = scale * (std::numeric_limits::max() - zero_point); + break; + } + case ONNX_NAMESPACE::TensorProto_DataType_INT16: { + rmin = scale * (std::numeric_limits::lowest() - zero_point); + rmax = scale * (std::numeric_limits::max() - zero_point); + break; + } + case ONNX_NAMESPACE::TensorProto_DataType_UINT16: { + rmin = scale * (std::numeric_limits::lowest() - zero_point); + rmax = scale * (std::numeric_limits::max() - zero_point); + break; + } + default: + return false; + } + + return true; +} + +// Returns true if the Clip in the sequence (Clip -> Q) can be removed because it is made redundant by the Q. +static bool CanClipBeRemoved(const QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& clip_node_unit, + const NodeUnit& q_node_unit, + const logging::Logger& logger) { + assert(clip_node_unit.OpType() == "Clip" && q_node_unit.OpType() == QUANTIZE_LINEAR); + float rmin = 0.0f; + float rmax = 0.0f; + + if (!GetQRminRmax(qnn_model_wrapper, q_node_unit, rmin, rmax)) { + return false; + } + + float clip_min = std::numeric_limits::lowest(); + float clip_max = std::numeric_limits::max(); + + if (!onnxruntime::GetClipMinMax(qnn_model_wrapper.GetGraphViewer(), clip_node_unit.GetNode(), + clip_min, clip_max, logger)) { + return false; + } + + // The clip range must entirely overlap the quantization range (quantization can be smaller). + // Clip range: [------------------] + // Quant range: [-------------] + constexpr float epsilon = std::numeric_limits::epsilon(); + if ((epsilon < clip_min - rmin) || (epsilon < rmax - clip_max)) { + return false; + } + + return true; +} + +// Returns true if the Relu in the sequence (Relu -> Q) can be removed because it is made redundant by the Q. +static bool CanQRelaceRelu(const QnnModelWrapper& qnn_model_wrapper, const NodeUnit& q_node_unit) { + assert(q_node_unit.OpType() == QUANTIZE_LINEAR); + int32_t zp_data_type = ONNX_NAMESPACE::TensorProto::DataType::TensorProto_DataType_UNDEFINED; + int32_t zero_point = 0; + float scale = 0.0f; + + if (!GetQScalarScaleZeroPoint(qnn_model_wrapper, q_node_unit, scale, zero_point, zp_data_type)) { + return false; + } + + // Relu is redundant if the zero-point is set to the smallest quantized value. + switch (zp_data_type) { + case ONNX_NAMESPACE::TensorProto::DataType::TensorProto_DataType_INT8: + return zero_point == static_cast(std::numeric_limits::lowest()); + case ONNX_NAMESPACE::TensorProto::DataType::TensorProto_DataType_UINT8: + return zero_point == static_cast(std::numeric_limits::lowest()); + case ONNX_NAMESPACE::TensorProto::DataType::TensorProto_DataType_INT16: + return zero_point == static_cast(std::numeric_limits::lowest()); + case ONNX_NAMESPACE::TensorProto::DataType::TensorProto_DataType_UINT16: + return zero_point == static_cast(std::numeric_limits::lowest()); + default: + return false; + } +} + +// Returns true if the Clip/Relu in the sequence (Clip/Relu -> Q) can be removed because it is made redundant by the Q. +static bool CanActivationBeRemoved(const QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& activation_node_unit, + const NodeUnit& q_node_unit, + const logging::Logger& logger) { + const std::string& activation_type = activation_node_unit.OpType(); + + if (activation_type == "Relu") { + return CanQRelaceRelu(qnn_model_wrapper, q_node_unit); + } + + if (activation_type == "Clip") { + return CanClipBeRemoved(qnn_model_wrapper, activation_node_unit, q_node_unit, logger); + } + + return false; +} + +// Returns the parent DQ nodes for a given node. +static std::vector FindParentDQNodes(const GraphViewer& graph_viewer, const Node& node) { + // Get all parent DQ nodes sorted by destination argument index. + std::vector parents(node.InputDefs().size(), nullptr); + for (auto it = node.InputEdgesBegin(); it != node.InputEdgesEnd(); it++) { + if (it->GetNode().OpType().compare(DEQUANTIZE_LINEAR) == 0) { + parents[it->GetDstArgIndex()] = &(it->GetNode()); + } + } + + // Remove all the nodes which are not in the graph_viewer + parents.erase(std::remove_if(parents.begin(), parents.end(), + [&graph_viewer](const Node* _node) { + return _node == nullptr || graph_viewer.GetNode(_node->Index()) == nullptr; + }), + parents.end()); + + return parents; +} + +// Gets the parent DQ nodes for the given Conv node. This fuction checks that the DQs are not a part of +// any other NodeUnit and that every Conv input comes from a parent DQ. +static bool GetConvDQs( + const GraphViewer& graph_viewer, + const std::unordered_map& node_to_node_unit, + const std::unordered_map& node_unit_to_qnn_node_group, + const Node& conv_node, + /*out*/ std::array& dq_node_units) { + if (conv_node.OpType() != "Conv" && conv_node.OpType() != "ConvTranspose") { + return false; + } + + // Count number of inputs to Conv node. + const auto& conv_inputs = conv_node.InputDefs(); + const size_t num_conv_inputs = std::count_if(conv_inputs.cbegin(), conv_inputs.cend(), + [](const NodeArg* input) { return input && input->Exists(); }); + + // Get the Conv's parent DQ nodes. + std::vector dq_nodes = FindParentDQNodes(graph_viewer, conv_node); + const size_t num_dqs = dq_nodes.size(); + + // Within a QDQ node group, a target node input is the only consumer of each DQ. + if ((num_conv_inputs != num_dqs) || (num_dqs > dq_node_units.size())) { + return false; + } + + dq_node_units.fill(nullptr); + for (size_t i = 0; i < num_dqs; i++) { + const Node* dq_node = dq_nodes[i]; + + // DQ must not produce a graph output. + if (!dq_node || graph_viewer.NodeProducesGraphOutput(*dq_node)) { + return false; + } + + // Conv should be the only consumer of a parent DQ. + const bool dq_has_single_output_edge_to_target = + dq_node->GetOutputEdgesCount() == 1 && + dq_node->OutputEdgesBegin()->GetNode().Index() == conv_node.Index(); + if (!dq_has_single_output_edge_to_target) { + return false; + } + + // DQ node must be part of a "standalone" NodeUnit. + const auto it = node_to_node_unit.find(dq_node); + if (it == node_to_node_unit.end()) { + return false; + } + const NodeUnit* dq_node_unit = it->second; + if (!dq_node_unit || node_unit_to_qnn_node_group.count(dq_node_unit) != 0) { + return false; + } + if (dq_node_unit->UnitType() != NodeUnit::Type::SingleNode) { + return false; + } + + dq_node_units[i] = dq_node_unit; + } + + return true; +} + +// Checks that the input and output data types are valid for a QDQ Conv. +static bool CheckQDQConvDataTypes(std::array& dq_node_units, + gsl::not_null q_node_unit) { + assert(q_node_unit->OpType() == QUANTIZE_LINEAR); + // input and output types need to be same + int32_t dt_input = dq_node_units[0]->GetNode().InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); + int32_t dt_weight = dq_node_units[1]->GetNode().InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); + int32_t dt_output = q_node_unit->GetNode().OutputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); + if (dt_input != dt_output) { + return false; + } + + if (dt_input == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8) { + if (dt_weight != dt_input) { + return false; + } + } + + if (dq_node_units[2] != nullptr) { // has bias + int32_t dt_bias = dq_node_units[2]->GetNode().InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); + if (dt_bias != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32) { + return false; + } + } + + return true; +} + +// Utility function to either validate or create a quantized QNN Conv node. The function creates a temporary +// custom NodeUnit that excludes the Clip/Relu because it is redundant. This custom NodeUnit is passed to our +// existing Conv OpBuilder for creation or validation via QNN APIs. +#define ValidateOnQnn(qnn_model_wrapper, dq_node_units, conv_node_unit, q_node_unit, logger) \ + CreateOrValidateOnQnn((qnn_model_wrapper), (dq_node_units), (conv_node_unit), (q_node_unit), (logger), true) +#define CreateOnQnn(qnn_model_wrapper, dq_node_units, conv_node_unit, q_node_unit, logger) \ + CreateOrValidateOnQnn((qnn_model_wrapper), (dq_node_units), (conv_node_unit), (q_node_unit), (logger), false) +static Status CreateOrValidateOnQnn(QnnModelWrapper& qnn_model_wrapper, + gsl::span dq_node_units, + const NodeUnit* conv_node_unit, + const NodeUnit* q_node_unit, + const logging::Logger& logger, + bool validate) { + const size_t num_dqs = dq_node_units.size(); + constexpr size_t max_num_dqs = 3; + ORT_RETURN_IF_NOT(num_dqs == 2 || num_dqs == max_num_dqs, "QDQ Conv should have 2 or 3 DQs"); + ORT_RETURN_IF_NOT(conv_node_unit->OpType() == "Conv" && q_node_unit->OpType() == QUANTIZE_LINEAR, + "Expected Conv/ConvTranspose and QuantizeLinear but got ", conv_node_unit->OpType(), " and ", + q_node_unit->OpType()); + + std::array dq_nodes_buf = {}; + for (size_t i = 0; i < num_dqs; i++) { + dq_nodes_buf[i] = &dq_node_units[i]->GetNode(); + } + gsl::span dq_nodes(dq_nodes_buf.data(), num_dqs); + + std::array q_nodes = {&q_node_unit->GetNode()}; + const Node& target_node = conv_node_unit->GetNode(); + + // Populate NodeUnit inputs + std::vector inputs; + inputs.reserve(num_dqs); + for (const Node* dq_node : dq_nodes) { + const auto dq_inputs = dq_node->InputDefs(); + const auto& dq_attrs = dq_node->GetAttributes(); + + std::optional axis; + if (auto entry = dq_attrs.find("axis"); entry != dq_attrs.end()) { + axis = entry->second.i(); + } + + // quantization scale and zp are always the input[1, 2] + NodeUnitIODef::QuantParam quant_param{*dq_inputs[1], dq_inputs.size() == 3 ? dq_inputs[2] : nullptr, axis}; + inputs.push_back(NodeUnitIODef{*dq_inputs[0], quant_param}); + } + + // Populate NodeUnit outputs and output edges + std::vector outputs; + Node::EdgeSet output_edges; + for (const Node* q_node : q_nodes) { + const auto q_inputs = q_node->InputDefs(); + const auto& q_attrs = q_node->GetAttributes(); + const auto q_outputs = q_node->OutputDefs(); + + std::optional axis; + if (auto entry = q_attrs.find("axis"); entry != q_attrs.end()) { + axis = entry->second.i(); + } + + // quantization scale and zp are always the input[1, 2] + NodeUnitIODef::QuantParam quant_param{*q_inputs[1], q_inputs.size() == 3 ? q_inputs[2] : nullptr, axis}; + outputs.push_back(NodeUnitIODef{*q_outputs[0], quant_param}); + + // Gather output edges out of the Q node. + auto q_cur_edge = q_node->OutputEdgesBegin(); + auto q_end_edge = q_node->OutputEdgesEnd(); + for (; q_cur_edge != q_end_edge; ++q_cur_edge) { + output_edges.insert(Node::EdgeEnd{q_cur_edge->GetNode(), 0, q_cur_edge->GetDstArgIndex()}); + } + } + + NodeUnit custom_node_unit(dq_nodes, target_node, q_nodes, NodeUnit::Type::QDQGroup, + inputs, outputs, num_dqs, output_edges); + const auto* conv_op_builder = qnn::GetOpBuilder(custom_node_unit.OpType()); + if (conv_op_builder == nullptr) { + return Status::OK(); + } + + if (validate) { + return conv_op_builder->IsOpSupported(qnn_model_wrapper, custom_node_unit, logger); + } + + return conv_op_builder->AddToModelBuilder(qnn_model_wrapper, custom_node_unit, logger, validate); +} + +// Traverses graph to check if the given NodeUnit is part of a valid DQ* -> Conv -> Relu/Clip -> Q sequence. +// If so, returns a IQnnNodeGroup that contains the constituent NodeUnits. +std::unique_ptr ConvActivationFusion::TryFusion( + QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& conv_node_unit, + const std::unordered_map& node_to_node_unit, + const std::unordered_map& node_unit_to_qnn_node_group, + const logging::Logger& logger) { + // Expect that this function is called with a standalone Conv or ConvTranspose. + const auto& conv_type = conv_node_unit.OpType(); + + if ((conv_type != "Conv" && conv_type != "ConvTranspose") || + (conv_node_unit.UnitType() != NodeUnit::Type::SingleNode)) { + return nullptr; + } + + const GraphViewer& graph_viewer = qnn_model_wrapper.GetGraphViewer(); + + // Conv must have a single Relu or Clip child. + const std::array activation_op_types = {"Relu", "Clip"}; + const NodeUnit* activation_node_unit = GetOnlyChildOfType(graph_viewer, conv_node_unit, activation_op_types, + node_to_node_unit, node_unit_to_qnn_node_group); + if (activation_node_unit == nullptr) { + return nullptr; + } + + // Relu/Clip must have a single Q child. + const std::array q_op_types = {QUANTIZE_LINEAR}; + const NodeUnit* q_node_unit = GetOnlyChildOfType(graph_viewer, *activation_node_unit, q_op_types, + node_to_node_unit, node_unit_to_qnn_node_group); + + if (q_node_unit == nullptr) { + return nullptr; + } + + // Check if Clip/Relu can be removed because the Q node provides an equivalent effect. + if (!CanActivationBeRemoved(qnn_model_wrapper, *activation_node_unit, *q_node_unit, logger)) { + return nullptr; + } + + // Create a QDQ node group with DQ* -> Conv -> Q + const Node& conv_node = conv_node_unit.GetNode(); + std::array dq_node_units = {}; + if (!GetConvDQs(graph_viewer, + node_to_node_unit, + node_unit_to_qnn_node_group, + conv_node, dq_node_units)) { + return nullptr; + } + + if (!CheckQDQConvDataTypes(dq_node_units, q_node_unit)) { + return nullptr; + } + + return std::make_unique(*dq_node_units[0], + *dq_node_units[1], + dq_node_units[2], + conv_node_unit, + *activation_node_unit, + *q_node_unit); +} + +ConvActivationFusion::ConvActivationFusion(const NodeUnit& dq_node_unit_0, + const NodeUnit& dq_node_unit_1, + const NodeUnit* dq_node_unit_2, + const NodeUnit& conv_node_unit, + const NodeUnit& activation_node_unit, + const NodeUnit& q_node_unit) + : node_units_{} { + size_t i = 0; + node_units_[i++] = &dq_node_unit_0; + node_units_[i++] = &dq_node_unit_1; + if (dq_node_unit_2 != nullptr) { + node_units_[i++] = dq_node_unit_2; + } + node_units_[i++] = &conv_node_unit; + node_units_[i++] = &activation_node_unit; + node_units_[i++] = &q_node_unit; + assert((!dq_node_unit_2 && i == 5) || (dq_node_unit_2 && i == 6)); +} + +Status ConvActivationFusion::IsSupported(QnnModelWrapper& qmw, const logging::Logger& logger) const { + const size_t num_dqs = node_units_.back() != nullptr ? 3 : 2; + gsl::span dq_node_units(node_units_.data(), num_dqs); + + return ValidateOnQnn(qmw, dq_node_units, + node_units_[num_dqs], // Conv + node_units_[num_dqs + 2], // Q + logger); +} + +Status ConvActivationFusion::AddToModelBuilder(QnnModelWrapper& qmw, const logging::Logger& logger) const { + const size_t num_dqs = node_units_.back() != nullptr ? 3 : 2; + gsl::span dq_node_units(node_units_.data(), num_dqs); + + return CreateOnQnn(qmw, dq_node_units, + node_units_[num_dqs], // Conv + node_units_[num_dqs + 2], // Q + logger); +} + +gsl::span ConvActivationFusion::GetNodeUnits() const { + const size_t num_node_units = node_units_.back() != nullptr ? 6 : 5; + return gsl::make_span(node_units_.data(), num_node_units); +} + +const NodeUnit* ConvActivationFusion::GetTargetNodeUnit() const { + const size_t conv_index = node_units_.back() != nullptr ? 3 : 2; + return node_units_[conv_index]; +} + +} // namespace qnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.h b/onnxruntime/core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.h new file mode 100644 index 0000000000000..b604b25e943e6 --- /dev/null +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.h @@ -0,0 +1,63 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include +#include +#include + +#include "core/framework/node_unit.h" +#include "core/providers/qnn/builder/qnn_node_group.h" + +namespace onnxruntime { +namespace qnn { + +class QnnModelWrapper; + +/// +/// Represents a fusion of a DQ* -> Conv -> Relu/Clip -> Q sequence where the Relu (or Clip) is redundant +/// due to the quantization effects of the Q. This sequence is translated to a quantized QNN Conv. +/// All contained NodeUnits are of type SingleNode since they are not a part of an existing QDQ node unit. +/// +class ConvActivationFusion : public IQnnNodeGroup { + public: + ConvActivationFusion(const NodeUnit& dq_node_unit_0, + const NodeUnit& dq_node_unit_1, + const NodeUnit* dq_node_unit_2, + const NodeUnit& conv_node_unit, + const NodeUnit& activation_node_unit, + const NodeUnit& q_node_unit); + ORT_DISALLOW_COPY_AND_ASSIGNMENT(ConvActivationFusion); + + Status IsSupported(QnnModelWrapper& qmw, const logging::Logger& logger) const override; + Status AddToModelBuilder(QnnModelWrapper& qmw, const logging::Logger& logger) const override; + gsl::span GetNodeUnits() const override; + const NodeUnit* GetTargetNodeUnit() const override; + std::string_view Type() const override { return "ConvActivationFusion"; } + + /// + /// Traverses graph to check if the given NodeUnit is part of a valid DQ* -> Conv -> Relu/Clip -> Q sequence. + /// If so, returns a IQnnNodeGroup that contains the constituent NodeUnits. + /// + /// Used for validation and to traverse/query the graph + /// Conv node unit (type SingleNode) that be part of the sequence. + /// Maps a Node to a NodeUnit. + /// Maps a NodeUnit to a IQnnNodeGroup. + /// + /// A valid IQnnNodeGroup on success or an empty std::unique_ptr otherwise + static std::unique_ptr TryFusion( + QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& conv_node_unit, + const std::unordered_map& node_to_node_unit, + const std::unordered_map& node_unit_to_qnn_node_group, + const logging::Logger& logger); + + private: + std::array node_units_; // Last elem is nullptr if the optional bias DQ is missing. +}; + +} // namespace qnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/dq_q_fusion.cc b/onnxruntime/core/providers/qnn/builder/qnn_node_group/dq_q_fusion.cc new file mode 100644 index 0000000000000..ce87ac4a3d21c --- /dev/null +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/dq_q_fusion.cc @@ -0,0 +1,179 @@ +#include "core/providers/qnn/builder/qnn_node_group/dq_q_fusion.h" + +#include +#include +#include +#include +#include +#include +#include "core/graph/graph_utils.h" +#include "core/framework/node_unit.h" +#include "core/providers/shared/utils/utils.h" +#include "core/providers/qnn/builder/qnn_utils.h" +#include "core/providers/qnn/builder/op_builder_factory.h" +#include "core/providers/qnn/builder/qnn_node_group/utils.h" +#include "core/providers/qnn/builder/qnn_model_wrapper.h" + +namespace onnxruntime { +namespace qnn { + +// Forward declarations. +#define ValidateOnQnn(qnn_model_wrapper, dq_node_unit, q_node_unit) \ + CreateOrValidateOnQnn((qnn_model_wrapper), (dq_node_unit), (q_node_unit), true) +#define CreateOnQnn(qnn_model_wrapper, dq_node_unit, q_node_unit) \ + CreateOrValidateOnQnn((qnn_model_wrapper), (dq_node_unit), (q_node_unit), false) +static Status CreateOrValidateOnQnn(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& dq_node_unit, + const NodeUnit& q_node_unit, bool validate); +static bool IsDQQConversion(const GraphViewer& graph_viewer, const Node& dq_node, const Node& q_node); + +std::unique_ptr DQQFusion::TryFusion( + QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& dq_node_unit, + const std::unordered_map& node_to_node_unit, + const std::unordered_map& node_unit_to_qnn_node_group, + const logging::Logger& logger) { + ORT_UNUSED_PARAMETER(logger); + // Expect that this function is called with a standalone DQ. + if (dq_node_unit.OpType() != DEQUANTIZE_LINEAR || dq_node_unit.UnitType() != NodeUnit::Type::SingleNode) { + return nullptr; + } + + const GraphViewer& graph_viewer = qnn_model_wrapper.GetGraphViewer(); + const Node& dq_node = dq_node_unit.GetNode(); + + // DQ must have a single Q child (1 output edge) and must not produce a graph output. + const std::array child_types = {QUANTIZE_LINEAR}; + const NodeUnit* q_node_unit = GetOnlyChildOfType(graph_viewer, dq_node_unit, child_types, + node_to_node_unit, node_unit_to_qnn_node_group); + + if (q_node_unit == nullptr) { + return nullptr; + } + + // DQ and Q must have equal scale type and different zp type. + if (!IsDQQConversion(graph_viewer, dq_node, q_node_unit->GetNode())) { + return nullptr; + } + + if (Status status = ValidateOnQnn(qnn_model_wrapper, dq_node_unit, *q_node_unit); + !status.IsOK()) { + return nullptr; + } + + return std::make_unique(dq_node_unit, *q_node_unit); +} + +DQQFusion::DQQFusion(const NodeUnit& dq_node_unit, const NodeUnit& q_node_unit) + : node_units_{&dq_node_unit, &q_node_unit} { +} + +Status DQQFusion::IsSupported(QnnModelWrapper& qmw, const logging::Logger& logger) const { + ORT_UNUSED_PARAMETER(logger); + return ValidateOnQnn(qmw, *node_units_[0], *node_units_[1]); +} + +Status DQQFusion::AddToModelBuilder(QnnModelWrapper& qmw, const logging::Logger& logger) const { + ORT_UNUSED_PARAMETER(logger); + return CreateOnQnn(qmw, *node_units_[0], *node_units_[1]); +} + +gsl::span DQQFusion::GetNodeUnits() const { + return node_units_; +} + +const NodeUnit* DQQFusion::GetTargetNodeUnit() const { + return node_units_[0]; +} + +static Status CreateOrValidateOnQnn(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& dq_node_unit, + const NodeUnit& q_node_unit, + bool validate) { + assert(dq_node_unit.OpType() == DEQUANTIZE_LINEAR && q_node_unit.OpType() == QUANTIZE_LINEAR); + const auto& node_name = utils::GetNodeName(dq_node_unit); + const NodeUnitIODef& input_def = dq_node_unit.Inputs()[0]; + const NodeUnitIODef& output_def = q_node_unit.Outputs()[0]; + + QnnTensorWrapper input_tensor; + QnnTensorWrapper output_tensor; + + ORT_RETURN_IF_ERROR(qnn_model_wrapper.MakeTensorWrapper(input_def, input_tensor)); + ORT_RETURN_IF_ERROR(qnn_model_wrapper.MakeTensorWrapper(output_def, output_tensor)); + + if (validate) { + ORT_RETURN_IF_ERROR(qnn_model_wrapper.ValidateQnnNode(node_name, + QNN_OP_PACKAGE_NAME_QTI_AISW, + QNN_OP_CONVERT, + {input_tensor.GetQnnTensor()}, + {output_tensor.GetQnnTensor()}, + {})); + } else { + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(input_tensor)), "Failed to add input"); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(output_tensor)), "Failed to add output"); + ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(utils::GetNodeName(q_node_unit), + QNN_OP_PACKAGE_NAME_QTI_AISW, + QNN_OP_CONVERT, + {input_def.node_arg.Name()}, + {output_def.node_arg.Name()}, + {}, + validate), + "Failed to add fused Convert node."); + } + + return Status::OK(); +} + +static bool IsDQQConversion(const GraphViewer& graph_viewer, const Node& dq_node, const Node& q_node) { + ConstPointerContainer> dq_input_defs = dq_node.InputDefs(); + ConstPointerContainer> q_input_defs = q_node.InputDefs(); + + auto is_scalar_shape = [](const NodeArg& input_arg) -> bool { + auto shape = input_arg.Shape(); + if (shape == nullptr) { + return false; + } + + auto dim_size = shape->dim_size(); + return dim_size == 0 || (dim_size == 1 && shape->dim(0).has_dim_value() && shape->dim(0).dim_value() == 1); + }; + + // Q/DQ contains optional input is not supported + // non-scalar Q/DQ scale and zero point needs are not supported + if (dq_input_defs.size() != QDQ_MAX_NUM_INPUTS || + q_input_defs.size() != QDQ_MAX_NUM_INPUTS || + !is_scalar_shape(*q_input_defs[QDQ_SCALE_INPUT_IDX]) || + !is_scalar_shape(*q_input_defs[QDQ_ZERO_POINT_INPUT_IDX]) || + !is_scalar_shape(*dq_input_defs[QDQ_SCALE_INPUT_IDX]) || + !is_scalar_shape(*dq_input_defs[QDQ_ZERO_POINT_INPUT_IDX])) { + return false; + } + + // if Q/DQ scale and zero point are not constant, return false + const ONNX_NAMESPACE::TensorProto* dq_scale_tensor_proto = + graph_viewer.GetConstantInitializer(dq_input_defs[QDQ_SCALE_INPUT_IDX]->Name()); + const ONNX_NAMESPACE::TensorProto* q_scale_tensor_proto = + graph_viewer.GetConstantInitializer(q_input_defs[QDQ_SCALE_INPUT_IDX]->Name()); + const ONNX_NAMESPACE::TensorProto* dq_zp_tensor_proto = + graph_viewer.GetConstantInitializer(dq_input_defs[QDQ_ZERO_POINT_INPUT_IDX]->Name()); + const ONNX_NAMESPACE::TensorProto* q_zp_tensor_proto = + graph_viewer.GetConstantInitializer(q_input_defs[QDQ_ZERO_POINT_INPUT_IDX]->Name()); + if (nullptr == q_zp_tensor_proto || + nullptr == dq_zp_tensor_proto || + nullptr == q_scale_tensor_proto || + nullptr == dq_scale_tensor_proto) { + return false; + } + + // All TensorProtos must have a data type + if (!q_zp_tensor_proto->has_data_type() || !dq_zp_tensor_proto->has_data_type() || + !q_scale_tensor_proto->has_data_type() || !dq_scale_tensor_proto->has_data_type()) { + return false; + } + + // check Q/DQ have same scale type and different zero point type + return (dq_zp_tensor_proto->data_type() != q_zp_tensor_proto->data_type()) && + (dq_scale_tensor_proto->data_type() == q_scale_tensor_proto->data_type()); +} + +} // namespace qnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/dq_q_fusion.h b/onnxruntime/core/providers/qnn/builder/qnn_node_group/dq_q_fusion.h new file mode 100644 index 0000000000000..90fe44c3af059 --- /dev/null +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/dq_q_fusion.h @@ -0,0 +1,57 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include + +#include "core/common/common.h" +#include "core/framework/node_unit.h" +#include "core/providers/qnn/builder/qnn_node_group.h" + +namespace onnxruntime { +namespace qnn { + +class QnnModelWrapper; + +/// +/// Represents a fusion of a DQ -> Q sequence that converts from one quantization type (e.g., uint8_t) to +/// another (e.g., uint16_t). This is translated into a QNN Convert operator, which is much faster than individual +/// ops. The DQ and Q are standalone NodeUnits that are not part of a QDQ node unit. +/// +class DQQFusion : public IQnnNodeGroup { + public: + DQQFusion(const NodeUnit& dq_node_unit, const NodeUnit& q_node_unit); + ORT_DISALLOW_COPY_AND_ASSIGNMENT(DQQFusion); + + Status IsSupported(QnnModelWrapper& qmw, const logging::Logger& logger) const override; + Status AddToModelBuilder(QnnModelWrapper& qmw, const logging::Logger& logger) const override; + gsl::span GetNodeUnits() const override; + const NodeUnit* GetTargetNodeUnit() const override; + std::string_view Type() const override { return "DQQFusion"; } + + /// + /// Traverses graph to check if the given starting NodeUnit is part of a valid DQ -> Q sequence. + /// If so, returns a IQnnNodeGroup that contains the DQ and Q NodeUnits. + /// + /// Used for validation and traverse/query the graph + /// DQ node unit that could start the DQ -> Q sequence + /// Maps a Node to a NodeUnit. + /// Maps a NodeUnit to a IQnnNodeGroup. + /// + /// A valid IQnnNodeGroup on success or an empty std::unique_ptr otherwise + static std::unique_ptr TryFusion( + QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& dq_node_unit, + const std::unordered_map& node_to_node_unit, + const std::unordered_map& node_unit_to_qnn_node_group, + const logging::Logger& logger); + + private: + std::array node_units_; +}; + +} // namespace qnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/hardsigmoid_mul_fusion.cc b/onnxruntime/core/providers/qnn/builder/qnn_node_group/hardsigmoid_mul_fusion.cc new file mode 100644 index 0000000000000..76b1726646486 --- /dev/null +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/hardsigmoid_mul_fusion.cc @@ -0,0 +1,144 @@ +#include "core/providers/qnn/builder/qnn_node_group/hardsigmoid_mul_fusion.h" + +#include +#include +#include +#include +#include +#include +#include "core/graph/graph_utils.h" +#include "core/framework/node_unit.h" +#include "core/providers/shared/utils/utils.h" +#include "core/providers/qnn/builder/qnn_utils.h" +#include "core/providers/qnn/builder/op_builder_factory.h" +#include "core/providers/qnn/builder/qnn_model_wrapper.h" +#include "core/providers/qnn/builder/qnn_node_group/utils.h" + +namespace onnxruntime { +namespace qnn { + +// Forward declarations. +#define ValidateOnQnn(qnn_model_wrapper, hardsigmoid_node_unit, mul_node_unit) \ + CreateOrValidateOnQnn((qnn_model_wrapper), (hardsigmoid_node_unit), (mul_node_unit), true) +#define CreateOnQnn(qnn_model_wrapper, hardsigmoid_node_unit, mul_node_unit) \ + CreateOrValidateOnQnn((qnn_model_wrapper), (hardsigmoid_node_unit), (mul_node_unit), false) +static Status CreateOrValidateOnQnn(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& hardsigmoid_node_unit, + const NodeUnit& mul_node_unit, bool validate); + +std::unique_ptr HardSigmoidMulFusion::TryFusion( + QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& hardsigmoid_node_unit, + const std::unordered_map& node_to_node_unit, + const std::unordered_map& node_unit_to_qnn_node_group, + const logging::Logger& logger) { + ORT_UNUSED_PARAMETER(logger); + + // Looking for a standalone HardSigmoid to start the sequence. + if (hardsigmoid_node_unit.OpType() != "HardSigmoid" || + hardsigmoid_node_unit.UnitType() != NodeUnit::Type::SingleNode) { + return nullptr; + } + + NodeAttrHelper hs_attr_helper(hardsigmoid_node_unit); + float alpha = hs_attr_helper.Get("alpha", 0.2f); + float beta = hs_attr_helper.Get("beta", 0.5f); + constexpr float req_alpha = 1.0f / 6.0f; + constexpr float req_beta = 0.5f; + constexpr float alpha_eps = std::numeric_limits::epsilon() * req_alpha; + constexpr float beta_eps = std::numeric_limits::epsilon() * req_beta; + + // Check for explicit values of alpha and beta. + if (std::abs(alpha - req_alpha) > alpha_eps || std::abs(beta - req_beta) > beta_eps) { + return nullptr; + } + + // HardSigmoid must have a single Mul child (1 output edge) and must not produce a graph output. + const GraphViewer& graph_viewer = qnn_model_wrapper.GetGraphViewer(); + const std::array child_types = {"Mul"}; + const NodeUnit* mul_node_unit = GetOnlyChildOfType(graph_viewer, hardsigmoid_node_unit, child_types, + node_to_node_unit, node_unit_to_qnn_node_group); + + if (mul_node_unit == nullptr) { + return nullptr; + } + + // Input to HardSigmoid must also be the other input to the Mul. + const Node& mul_node = mul_node_unit->GetNode(); + auto& hs_input_name = hardsigmoid_node_unit.Inputs()[0].node_arg.Name(); + const bool same_root_input = mul_node.InputDefs()[0]->Name() == hs_input_name || + mul_node.InputDefs()[1]->Name() == hs_input_name; + + if (!same_root_input) { + return nullptr; + } + + if (Status status = ValidateOnQnn(qnn_model_wrapper, hardsigmoid_node_unit, *mul_node_unit); + !status.IsOK()) { + return nullptr; + } + + return std::make_unique(hardsigmoid_node_unit, *mul_node_unit); +} + +HardSigmoidMulFusion::HardSigmoidMulFusion(const NodeUnit& hardsigmoid_node_unit, const NodeUnit& mul_node_unit) + : node_units_{&hardsigmoid_node_unit, &mul_node_unit} { +} + +Status HardSigmoidMulFusion::IsSupported(QnnModelWrapper& qmw, const logging::Logger& logger) const { + ORT_UNUSED_PARAMETER(logger); + return ValidateOnQnn(qmw, *node_units_[0], *node_units_[1]); +} + +Status HardSigmoidMulFusion::AddToModelBuilder(QnnModelWrapper& qmw, const logging::Logger& logger) const { + ORT_UNUSED_PARAMETER(logger); + return CreateOnQnn(qmw, *node_units_[0], *node_units_[1]); +} + +gsl::span HardSigmoidMulFusion::GetNodeUnits() const { + return node_units_; +} + +const NodeUnit* HardSigmoidMulFusion::GetTargetNodeUnit() const { + return node_units_[0]; +} + +static Status CreateOrValidateOnQnn(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& hardsigmoid_node_unit, + const NodeUnit& mul_node_unit, + bool validate) { + assert(hardsigmoid_node_unit.OpType() == "HardSigmoid" && mul_node_unit.OpType() == "Mul"); + const auto& node_name = utils::GetNodeName(hardsigmoid_node_unit); + const NodeUnitIODef& input_def = hardsigmoid_node_unit.Inputs()[0]; + const NodeUnitIODef& output_def = mul_node_unit.Outputs()[0]; + + QnnTensorWrapper input_tensor; + QnnTensorWrapper output_tensor; + + ORT_RETURN_IF_ERROR(qnn_model_wrapper.MakeTensorWrapper(input_def, input_tensor)); + ORT_RETURN_IF_ERROR(qnn_model_wrapper.MakeTensorWrapper(output_def, output_tensor)); + + if (validate) { + ORT_RETURN_IF_ERROR(qnn_model_wrapper.ValidateQnnNode(node_name, + QNN_OP_PACKAGE_NAME_QTI_AISW, + QNN_OP_HARD_SWISH, + {input_tensor.GetQnnTensor()}, + {output_tensor.GetQnnTensor()}, + {})); + } else { + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(input_tensor)), "Failed to add input"); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(output_tensor)), "Failed to add output"); + ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(node_name, + QNN_OP_PACKAGE_NAME_QTI_AISW, + QNN_OP_HARD_SWISH, + {input_def.node_arg.Name()}, + {output_def.node_arg.Name()}, + {}, + validate), + "Failed to add fused HardSwish node."); + } + + return Status::OK(); +} + +} // namespace qnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/hardsigmoid_mul_fusion.h b/onnxruntime/core/providers/qnn/builder/qnn_node_group/hardsigmoid_mul_fusion.h new file mode 100644 index 0000000000000..3b67f13492a46 --- /dev/null +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/hardsigmoid_mul_fusion.h @@ -0,0 +1,57 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include + +#include "core/common/common.h" +#include "core/framework/node_unit.h" +#include "core/providers/qnn/builder/qnn_node_group.h" + +namespace onnxruntime { +namespace qnn { + +class QnnModelWrapper; + +/// +/// Represents a fusion of a HardSigmoid -> Mul sequence that computes `x * HardSigmoid(x)`. +/// This is translated into a QNN HardSwish operator. +/// The contained NodeUnits are of type SingleNode since they are not a part of a QDQ node unit. +/// +class HardSigmoidMulFusion : public IQnnNodeGroup { + public: + HardSigmoidMulFusion(const NodeUnit& hardsigmoid_node_unit, const NodeUnit& mul_node_unit); + ORT_DISALLOW_COPY_AND_ASSIGNMENT(HardSigmoidMulFusion); + + Status IsSupported(QnnModelWrapper& qmw, const logging::Logger& logger) const override; + Status AddToModelBuilder(QnnModelWrapper& qmw, const logging::Logger& logger) const override; + gsl::span GetNodeUnits() const override; + const NodeUnit* GetTargetNodeUnit() const override; + std::string_view Type() const override { return "HardSigmoidMulFusion"; } + + /// + /// Traverses graph to check if the given starting NodeUnit is part of a valid HardSigmoid -> Mul sequence. + /// If so, returns a IQnnNodeGroup that contains the HardSigmoid and Mul NodeUnits. + /// + /// Used for validation and traverse/query the graph + /// HardSigmoid node unit that could start the sequence + /// Maps a Node to a NodeUnit. + /// Maps a NodeUnit to a IQnnNodeGroup. + /// + /// A valid IQnnNodeGroup on success or an empty std::unique_ptr otherwise + static std::unique_ptr TryFusion( + QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& hardsigmoid_node_unit, + const std::unordered_map& node_to_node_unit, + const std::unordered_map& node_unit_to_qnn_node_group, + const logging::Logger& logger); + + private: + std::array node_units_; +}; + +} // namespace qnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc b/onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc new file mode 100644 index 0000000000000..9fb9e815321c0 --- /dev/null +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc @@ -0,0 +1,221 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/qnn/builder/qnn_node_group.h" + +#include +#include +#include +#include +#include +#include +#include +#include "core/graph/graph_utils.h" +#include "core/framework/node_unit.h" +#include "core/providers/qnn/builder/qnn_utils.h" +#include "core/providers/qnn/builder/qnn_model_wrapper.h" +#include "core/providers/qnn/builder/op_builder_factory.h" +#include "core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.h" +#include "core/providers/qnn/builder/qnn_node_group/dq_q_fusion.h" +#include "core/providers/qnn/builder/qnn_node_group/hardsigmoid_mul_fusion.h" + +namespace onnxruntime { +namespace qnn { + +/// +/// A IQnnNodeGroup class that wraps a single NodeUnit. Most NodeUnits in the ONNX graph will +/// be wrapped by this class. +/// +class QnnNodeUnitWrapper : public IQnnNodeGroup { + public: + explicit QnnNodeUnitWrapper(const NodeUnit& node_unit) : node_unit_(&node_unit) {} + ORT_DISALLOW_COPY_AND_ASSIGNMENT(QnnNodeUnitWrapper); + + Status IsSupported(QnnModelWrapper& qmw, const logging::Logger& logger) const override { + const std::string& op_type = node_unit_->OpType(); + const auto* op_builder = qnn::GetOpBuilder(op_type); + ORT_RETURN_IF_NOT(op_builder != nullptr, "Operators of type `", op_type, + "` are not supported by QNN EP.", op_type, " node `", + node_unit_->Name(), "` will not be assigned to QNN EP."); + + return op_builder->IsOpSupported(qmw, *node_unit_, logger); + } + + Status AddToModelBuilder(QnnModelWrapper& qmw, const logging::Logger& logger) const override { + const std::string& op_type = node_unit_->OpType(); + const auto* op_builder = qnn::GetOpBuilder(op_type); + ORT_RETURN_IF_NOT(op_builder != nullptr, "[QNN EP]: Missing OpBuilder for OpType ", op_type); + return op_builder->AddToModelBuilder(qmw, *node_unit_, logger, /*do_op_validation*/ false); + } + + gsl::span GetNodeUnits() const override { + return gsl::span{&node_unit_, 1ULL}; + } + + const NodeUnit* GetTargetNodeUnit() const override { return node_unit_; } + std::string_view Type() const override { return "NodeUnit"; } + + private: + const NodeUnit* node_unit_; +}; + +/// +/// The type of a function that tries to fuse NodeUnits into a IQnnNodeGroup. +/// +using FusionFunc = std::unique_ptr (*)( + QnnModelWrapper&, + const NodeUnit&, + const std::unordered_map&, + const std::unordered_map&, + const logging::Logger&); + +/// +/// Given a starting NodeUnit, this function tries all possible fusions that start with that NodeUnit. +/// If successful, returns a IQnnNodeGroup object that represents the fusion of various NodeUnits. +/// Currently only handles standalone NodeUnits that are not in a QDQ unit but that can change in the future. +/// +/// QnnModelWrapper that contains the ONNX GraphViewer. Used for validation. +/// NodeUnit that potentially starts a fusion. +/// Maps a Node* to a NodeUnit* +/// Maps a NodeUnit* to a IQnnNodeGroup* +/// +/// IQnnNodeGroup representing the fusion or an empty std::unique_ptr +static std::unique_ptr TryQnnFusions( + QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& starting_node_unit, + const std::unordered_map& node_to_node_unit, + const std::unordered_map& node_unit_to_qnn_node_group, + const logging::Logger& logger) { + // Maps a starting operator type to the fusion function. + static std::unordered_map fusions = { + {"DequantizeLinear", DQQFusion::TryFusion}, + {"HardSigmoid", HardSigmoidMulFusion::TryFusion}, + {"Conv", ConvActivationFusion::TryFusion}, + {"ConvTranspose", ConvActivationFusion::TryFusion}, + }; + + // For now, all fusions involve standalone node units (i.e., no wrapping DQ/Q nodes). + if (starting_node_unit.UnitType() != NodeUnit::Type::SingleNode) { + return nullptr; + } + + auto iter = fusions.find(starting_node_unit.OpType()); + if (iter != fusions.end()) { + FusionFunc fusion_func = iter->second; + return fusion_func(qnn_model_wrapper, starting_node_unit, node_to_node_unit, + node_unit_to_qnn_node_group, logger); + } + return nullptr; +} + +// Traverses the ONNX Graph and groups NodeUnits into IQnnNodeGroup objects. Some IQnnNodeGroup objects +// represent a fusion of various NodeUnits. This function generates a vector of indices that +// represent the topological order of the qnn_node_groups. +static Status GetQnnNodeGroupsImpl(/*out*/ std::vector>& qnn_node_groups, + /*out*/ std::vector& sorted_qnn_node_group_indices, + QnnModelWrapper& qnn_model_wrapper, + const std::unordered_map& node_to_node_unit, + const size_t num_node_units, + const logging::Logger& logger) { + const GraphViewer& graph_viewer = qnn_model_wrapper.GetGraphViewer(); + const std::vector sorted_node_indices = graph_viewer.GetNodesInTopologicalOrder(); + + sorted_qnn_node_group_indices.reserve(num_node_units); + qnn_node_groups.reserve(num_node_units); + + std::unordered_map node_unit_to_qnn_node_group; + std::unordered_map fused_qnn_node_group_indices; + std::vector> sorted_node_units; + sorted_node_units.reserve(num_node_units); + + // Process just the fusions of NodeUnits first to ensure a correct topological order of all IQnnNodeGroups. + // This is the same approach taken by ORT utilities for grouping Nodes into NodeUnits. + for (NodeIndex node_index : sorted_node_indices) { + gsl::not_null node = graph_viewer.GetNode(node_index); + + // Get the NodeUnit associated with the node. + const auto node_unit_it = node_to_node_unit.find(node); + ORT_RETURN_IF_NOT(node_unit_it != node_to_node_unit.end(), "Could not find NodeUnit for Node ", node->Name()); + gsl::not_null node_unit = node_unit_it->second; + + // Skip this node if it is not the NodeUnit's target node to ensure NodeUnits are visited in topological order. + if (node != &node_unit->GetNode()) { + continue; + } + + sorted_node_units.push_back(node_unit); + + if (node_unit_to_qnn_node_group.count(node_unit) != 0) { + continue; // Already handled this node unit + } + + std::unique_ptr fused_node_group = TryQnnFusions(qnn_model_wrapper, *node_unit, + node_to_node_unit, node_unit_to_qnn_node_group, + logger); + + if (fused_node_group) { + const size_t index = qnn_node_groups.size(); + fused_qnn_node_group_indices[fused_node_group.get()] = index; + + for (const NodeUnit* fused_node_unit : fused_node_group->GetNodeUnits()) { + assert(fused_node_unit != nullptr); + node_unit_to_qnn_node_group.insert({fused_node_unit, fused_node_group.get()}); + } + + qnn_node_groups.push_back(std::move(fused_node_group)); + } + } + + // Create IQnnNodeGroups for the leftover NodeUnits that were not fused. + for (gsl::not_null node_unit : sorted_node_units) { + const auto it = node_unit_to_qnn_node_group.find(node_unit); + + if (it != node_unit_to_qnn_node_group.end()) { + // Already added this NodeUnit to a IQnnNodeGroup, so we'll skip it. + // However, if this NodeUnit is the "target" for the IQnnNodeGroup, then add its index to + // the sorted list of indices. + gsl::not_null fused_qnn_node_group = it->second; + if (node_unit == fused_qnn_node_group->GetTargetNodeUnit()) { + sorted_qnn_node_group_indices.push_back(fused_qnn_node_group_indices[fused_qnn_node_group]); + } + continue; + } + + const size_t index = qnn_node_groups.size(); + auto qnn_node_group = std::make_unique(*node_unit); + + node_unit_to_qnn_node_group.insert({node_unit, qnn_node_group.get()}); + qnn_node_groups.push_back(std::move(qnn_node_group)); + sorted_qnn_node_group_indices.push_back(index); + } + + assert(qnn_node_groups.size() == sorted_qnn_node_group_indices.size()); + + return Status::OK(); +} + +Status GetQnnNodeGroups(/*out*/ std::vector>& qnn_node_groups, + QnnModelWrapper& qnn_model_wrapper, + const std::unordered_map& node_to_node_unit, + const size_t num_node_units, + const logging::Logger& logger) { + std::vector sorted_qnn_node_group_indices; + std::vector> qnn_node_groups_holder; + ORT_RETURN_IF_ERROR(GetQnnNodeGroupsImpl(qnn_node_groups_holder, sorted_qnn_node_group_indices, qnn_model_wrapper, + node_to_node_unit, num_node_units, logger)); + + // Move IQnnNodeGroups to the output std::vector in sorted (topological) order. + qnn_node_groups.resize(0); + qnn_node_groups.reserve(qnn_node_groups_holder.size()); + for (auto index : sorted_qnn_node_group_indices) { + assert(index < qnn_node_groups_holder.size()); + std::unique_ptr qnn_node_group = std::move(qnn_node_groups_holder[index]); + qnn_node_groups.push_back(std::move(qnn_node_group)); + } + + assert(qnn_node_groups.size() == sorted_qnn_node_group_indices.size()); + + return Status::OK(); +} +} // namespace qnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.cc b/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.cc new file mode 100644 index 0000000000000..5548d7d37c378 --- /dev/null +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.cc @@ -0,0 +1,66 @@ +#include "core/providers/qnn/builder/qnn_node_group/utils.h" + +#include +#include +#include + +#include "core/graph/graph_viewer.h" +#include "core/framework/node_unit.h" +#include "core/providers/qnn/builder/qnn_node_group.h" + +namespace onnxruntime { +namespace qnn { + +const NodeUnit* GetOnlyChildOfType(const GraphViewer& graph_viewer, + const NodeUnit& parent_node_unit, + gsl::span child_op_types, + const std::unordered_map& node_unit_map, + const std::unordered_map& qnn_node_group_map) { + const Node& parent_node = parent_node_unit.GetNode(); + + // Parent must have a single child (1 output edge) and must not produce a graph output. + if (parent_node.GetOutputEdgesCount() != 1 || graph_viewer.NodeProducesGraphOutput(parent_node)) { + return nullptr; + } + + // Child must be of a valid type. + const Node& child_node = parent_node.OutputEdgesBegin()->GetNode(); + if (graph_viewer.GetNode(child_node.Index()) == nullptr) { + return nullptr; // Node is not in this GraphViewer + } + const std::string& child_type = child_node.OpType(); + bool is_valid_child_type = false; + + for (const auto& valid_op_type : child_op_types) { + if (valid_op_type == child_type) { + is_valid_child_type = true; + break; + } + } + + if (!is_valid_child_type) { + return nullptr; + } + + const auto child_node_unit_it = node_unit_map.find(&child_node); + if (child_node_unit_it == node_unit_map.end()) { + return nullptr; + } + const NodeUnit* child_node_unit = child_node_unit_it->second; + + // Check if child node has already been handled. Should not be the case if the calling + // fusion function has been called in topological order, but check to be safe. + if (qnn_node_group_map.count(child_node_unit) != 0) { + return nullptr; + } + + // child must not already be part of a QDQ NodeUnit (i.e., be standalone). + if (child_node_unit->UnitType() != NodeUnit::Type::SingleNode) { + return nullptr; + } + + return child_node_unit; +} + +} // namespace qnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.h b/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.h new file mode 100644 index 0000000000000..0d11d21906ccb --- /dev/null +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.h @@ -0,0 +1,40 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include + +#include "core/graph/graph_viewer.h" +#include "core/framework/node_unit.h" +#include "core/providers/qnn/builder/qnn_node_group.h" + +namespace onnxruntime { +namespace qnn { +constexpr const char* QUANTIZE_LINEAR = "QuantizeLinear"; +constexpr const char* DEQUANTIZE_LINEAR = "DequantizeLinear"; +constexpr size_t QDQ_MAX_NUM_INPUTS = 3; +constexpr size_t QDQ_SCALE_INPUT_IDX = 1; +constexpr size_t QDQ_ZERO_POINT_INPUT_IDX = 2; + +/// +/// Utility function to get a child NodeUnit. The returned NodeUnit must be the parent's only child, must be +/// of the expected type, and must not be a part of another IQnnNodeGroup. +/// +/// GraphViewer containing all Nodes +/// Parent NodeUnit +/// Valid child types +/// Maps a Node to its NodeUnit +/// Maps a NodeUnit to its IQnnNodeGroup. +/// Used to check that the child has not already been added to another IQnnNodeGroup. +/// +const NodeUnit* GetOnlyChildOfType(const GraphViewer& graph_viewer, + const NodeUnit& parent_node_unit, + gsl::span child_op_types, + const std::unordered_map& node_unit_map, + const std::unordered_map& node_unit_to_qnn_node_group); + +} // namespace qnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc index c56a47e67497e..fc64d63ede338 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc @@ -16,10 +16,10 @@ #include "core/platform/env.h" #include "core/providers/common.h" #include "core/providers/partitioning_utils.h" -#include "core/providers/qnn/builder/qnn_fusions.h" #include "core/providers/partitioning_utils.h" #include "core/providers/qnn/builder/qnn_model_wrapper.h" #include "core/providers/qnn/builder/op_builder_factory.h" +#include "core/providers/qnn/builder/qnn_node_group.h" #include "core/providers/qnn/builder/qnn_def.h" #include "core/providers/qnn/builder/onnx_ctx_model_helper.h" #include "core/framework/run_options.h" @@ -412,25 +412,35 @@ QNNExecutionProvider::~QNNExecutionProvider() { #endif } -bool QNNExecutionProvider::IsNodeSupported(qnn::QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit, - const logging::Logger& logger) const { - const std::string& op_type = node_unit.OpType(); - bool supported = false; - const auto* op_builder = qnn::GetOpBuilder(op_type); - if (op_builder == nullptr) { - LOGS(logger, WARNING) << "Operators of type `" << node_unit.OpType() << "` are not supported by QNN EP." - << node_unit.OpType() << " node `" << node_unit.Name() - << "` will not be assigned to QNN EP."; - } else { - auto status = op_builder->IsOpSupported(qnn_model_wrapper, - node_unit, logger); - if (Status::OK() != status) { - LOGS(logger, WARNING) << node_unit.OpType() << " node `" << node_unit.Name() - << "` is not supported: " << status.ErrorMessage(); +// Logs information about the supported/unsupported nodes. +static void LogNodeSupport(const logging::Logger& logger, + logging::Severity log_severity, + logging::DataType log_data_type, + const onnxruntime::CodeLocation& call_site, + const qnn::IQnnNodeGroup& qnn_node_group, + Status support_status) { + if (!logger.OutputIsEnabled(log_severity, log_data_type)) { + return; + } + + std::ostringstream oss; + oss << (support_status.IsOK() ? "Validation PASSED " : "Validation FAILED ") << "for nodes (" + << qnn_node_group.Type() << "):" << std::endl; + for (const NodeUnit* node_unit : qnn_node_group.GetNodeUnits()) { + for (const Node* node : node_unit->GetAllNodesInGroup()) { + oss << "\tOperator type: " << node->OpType() + << " Node name: " << node->Name() + << " Node index: " << node->Index() << std::endl; } - supported = (Status::OK() == status); } - return supported; + if (!support_status.IsOK()) { + oss << "\tREASON : " << support_status.ErrorMessage() << std::endl; + } + + logging::Capture(logger, log_severity, logging::Category::onnxruntime, + log_data_type, call_site) + .Stream() + << oss.str(); } std::unordered_set @@ -469,68 +479,33 @@ QNNExecutionProvider::GetSupportedNodes(const GraphViewer& graph_viewer, initializer_input_lookup, qnn_backend_manager_->GetQnnBackendType()); - std::unordered_set handled_node_units; - handled_node_units.reserve(node_unit_size); - - auto add_supported_nodes = [](std::unordered_set& supported_nodes, const NodeUnit* node_unit) { - for (const auto* node_in_group : node_unit->GetAllNodesInGroup()) { - supported_nodes.insert(node_in_group); - } - }; - - const auto& node_indices = graph_viewer.GetNodesInTopologicalOrder(); - for (size_t i = 0; i < node_indices.size(); i++) { - gsl::not_null node(graph_viewer.GetNode(node_indices[i])); - - // Get the node_unit associated with the node. Note that the node may not be the node_unit's target node. - const NodeUnit* node_unit = node_unit_map.at(node); - - // Visiting 'nodes' in topological order does not guarantee that 'node_units' are - // also visited in topological order. Skip this node if it is not the node_unit's target node - // to ensure 'node_units' are visited in topological order. - if (node != &node_unit->GetNode()) { - continue; - } + std::vector> qnn_node_groups; + qnn_node_groups.reserve(node_unit_size); - if (handled_node_units.count(node_unit) != 0) { - continue; // Already handled this node unit - } + if (Status status = qnn::GetQnnNodeGroups(qnn_node_groups, qnn_model_wrapper, + node_unit_map, node_unit_size, logger); + !status.IsOK()) { + LOGS(logger, ERROR) << status.ErrorMessage(); + return {}; + } - // Try to see if this node unit can be fused. - std::vector fused_nodes; - Status fusion_status = TryFusions(fused_nodes, qnn_model_wrapper, *node_unit, node_unit_map, - handled_node_units, logger, true /*do_op_validation*/); + for (const std::unique_ptr& qnn_node_group : qnn_node_groups) { + Status status = qnn_node_group->IsSupported(qnn_model_wrapper, logger); + const bool supported = status.IsOK(); - if (!fusion_status.IsOK()) { - LOGS(logger, WARNING) << "Failed to apply fusion: " << fusion_status.ErrorMessage(); - handled_node_units.insert(node_unit); - continue; - } - - if (!fused_nodes.empty()) { - for (auto fused_node_unit : fused_nodes) { - handled_node_units.insert(fused_node_unit); - add_supported_nodes(supported_nodes, fused_node_unit); - } - continue; + constexpr auto log_severity = logging::Severity::kVERBOSE; + constexpr auto log_data_type = logging::DataType::SYSTEM; + if (logger.OutputIsEnabled(log_severity, log_data_type)) { + LogNodeSupport(logger, log_severity, log_data_type, ORT_WHERE, *qnn_node_group, status); } - // Couldn't fuse the node unit. See if it is supported by itself. - const bool supported = IsNodeSupported(qnn_model_wrapper, *node_unit, logger); - LOGS(logger, VERBOSE) << "Node supported: [" << supported - << "] index: [" << node->Index() - << "] name: [" << node->Name() - << "] Operator type: [" << node->OpType() - << "] as part of the NodeUnit type: [" << node_unit->OpType() - << "] index: [" << node_unit->Index() - << "] name: [" << node_unit->Name() - << "]"; - if (supported) { - add_supported_nodes(supported_nodes, node_unit); + for (const NodeUnit* node_unit : qnn_node_group->GetNodeUnits()) { + for (const Node* node : node_unit->GetAllNodesInGroup()) { + supported_nodes.insert(node); + } + } } - - handled_node_units.insert(node_unit); } return supported_nodes; diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.h b/onnxruntime/core/providers/qnn/qnn_execution_provider.h index f00ffb6cfdb96..4c48370492ef7 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.h +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.h @@ -53,9 +53,6 @@ class QNNExecutionProvider : public IExecutionProvider { Status OnRunEnd(bool sync_stream, const onnxruntime::RunOptions& run_options) override; private: - bool IsNodeSupported(qnn::QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit, - const logging::Logger& logger) const; - std::unordered_set GetSupportedNodes(const GraphViewer& graph_viewer, const std::unordered_map& node_unit_map, const size_t node_unit_size, diff --git a/onnxruntime/test/providers/qnn/conv_test.cc b/onnxruntime/test/providers/qnn/conv_test.cc index 35889c9fa2307..95673586677ef 100644 --- a/onnxruntime/test/providers/qnn/conv_test.cc +++ b/onnxruntime/test/providers/qnn/conv_test.cc @@ -15,6 +15,12 @@ namespace onnxruntime { namespace test { +// Information for activation node placed between the Conv and Q. +struct OutputActivationInfo { + std::string op_type; // Relu or Clip + std::vector const_inputs; +}; + // Creates a graph with a single float32 Conv operator. Used for testing CPU backend. static GetTestModelFn BuildF32ConvTestCase(const std::string& conv_op_type, const TestInputDef& input_def, const TestInputDef& weights_def, @@ -23,9 +29,10 @@ static GetTestModelFn BuildF32ConvTestCase(const std::string& conv_op_type, cons const std::vector& pads, const std::vector& dilations, std::optional group, - const std::string& auto_pad = "NOTSET") { + const std::string& auto_pad = "NOTSET", + std::optional output_activation = std::nullopt) { return [conv_op_type, input_def, weights_def, bias_def, strides, pads, - dilations, group, auto_pad](ModelTestBuilder& builder) { + dilations, group, auto_pad, output_activation](ModelTestBuilder& builder) { std::vector conv_inputs = { MakeTestInput(builder, input_def), MakeTestInput(builder, weights_def)}; @@ -34,9 +41,9 @@ static GetTestModelFn BuildF32ConvTestCase(const std::string& conv_op_type, cons conv_inputs.push_back(MakeTestInput(builder, bias_def)); } - auto* output = builder.MakeOutput(); + auto* conv_output = output_activation.has_value() ? builder.MakeIntermediate() : builder.MakeOutput(); - Node& conv_node = builder.AddNode(conv_op_type, conv_inputs, {output}); + Node& conv_node = builder.AddNode(conv_op_type, conv_inputs, {conv_output}); conv_node.AddAttribute("auto_pad", auto_pad); if (group.has_value()) { @@ -54,6 +61,15 @@ static GetTestModelFn BuildF32ConvTestCase(const std::string& conv_op_type, cons if (!dilations.empty()) { conv_node.AddAttribute("dilations", dilations); } + + if (output_activation.has_value()) { + NodeArg* output = builder.MakeOutput(); + std::vector activation_inputs = {conv_output}; + for (auto val : output_activation->const_inputs) { + activation_inputs.push_back(builder.MakeScalarInitializer(val)); + } + builder.AddNode(output_activation->op_type, activation_inputs, {output}); + } }; } @@ -88,19 +104,22 @@ static void RunCPUConvOpTest(const std::string& conv_op_type, const TestInputDef // Creates a graph with a single Q/DQ Conv operator. Used for testing HTP backend. template -static GetTestQDQModelFn BuildQDQConvTestCase(const std::string& conv_op_type, - const TestInputDef& input_def, - const TestInputDef& weights_def, - const TestInputDef& bias_def, - const std::vector& strides, - const std::vector& pads, - const std::vector& dilations, - std::optional group, - const std::string& auto_pad = "NOTSET", - bool use_contrib_qdq = false) { +static GetTestQDQModelFn BuildQDQConvTestCase( + const std::string& conv_op_type, + const TestInputDef& input_def, + const TestInputDef& weights_def, + const TestInputDef& bias_def, + const std::vector& strides, + const std::vector& pads, + const std::vector& dilations, + std::optional group, + const std::string& auto_pad = "NOTSET", + bool use_contrib_qdq = false, + std::optional output_activation = std::nullopt) { return [conv_op_type, input_def, weights_def, bias_def, strides, pads, - dilations, group, auto_pad, use_contrib_qdq](ModelTestBuilder& builder, - std::vector>& output_qparams) { + dilations, group, auto_pad, + use_contrib_qdq, output_activation](ModelTestBuilder& builder, + std::vector>& output_qparams) { std::vector conv_inputs; // input -> Q/DQ -> @@ -144,27 +163,39 @@ static GetTestQDQModelFn BuildQDQConvTestCase(const std::string conv_node.AddAttribute("dilations", dilations); } - AddQDQNodePairWithOutputAsGraphOutput(builder, conv_output, output_qparams[0].scale, + NodeArg* q_input = conv_output; + if (output_activation.has_value()) { + q_input = builder.MakeIntermediate(); + std::vector activation_inputs = {conv_output}; + for (auto val : output_activation->const_inputs) { + activation_inputs.push_back(builder.MakeScalarInitializer(val)); + } + builder.AddNode(output_activation->op_type, activation_inputs, {q_input}); + } + + AddQDQNodePairWithOutputAsGraphOutput(builder, q_input, output_qparams[0].scale, output_qparams[0].zero_point, use_contrib_qdq); }; } template -static GetTestQDQModelFn BuildQDQPerChannelConvTestCase(const std::string& conv_op_type, - const TestInputDef& input_def, - const TestInputDef& weights_def, - const TestInputDef& bias_def, - int64_t weight_quant_axis, - const std::vector& strides, - const std::vector& pads, - const std::vector& dilations, - std::optional group, - const std::string& auto_pad = "NOTSET", - bool use_contrib_qdq = false) { +static GetTestQDQModelFn BuildQDQPerChannelConvTestCase( + const std::string& conv_op_type, + const TestInputDef& input_def, + const TestInputDef& weights_def, + const TestInputDef& bias_def, + int64_t weight_quant_axis, + const std::vector& strides, + const std::vector& pads, + const std::vector& dilations, + std::optional group, + const std::string& auto_pad = "NOTSET", + bool use_contrib_qdq = false, + std::optional output_activation = std::nullopt) { return [conv_op_type, input_def, weights_def, bias_def, strides, pads, dilations, group, auto_pad, use_contrib_qdq, - weight_quant_axis](ModelTestBuilder& builder, - std::vector>& output_qparams) { + weight_quant_axis, output_activation](ModelTestBuilder& builder, + std::vector>& output_qparams) { std::vector conv_inputs; // input -> Q/DQ -> @@ -248,7 +279,17 @@ static GetTestQDQModelFn BuildQDQPerChannelConvTestCase(const s conv_node.AddAttribute("dilations", dilations); } - AddQDQNodePairWithOutputAsGraphOutput(builder, conv_output, output_qparams[0].scale, + NodeArg* q_input = conv_output; + if (output_activation.has_value()) { + q_input = builder.MakeIntermediate(); + std::vector activation_inputs = {conv_output}; + for (auto val : output_activation->const_inputs) { + activation_inputs.push_back(builder.MakeScalarInitializer(val)); + } + builder.AddNode(output_activation->op_type, activation_inputs, {q_input}); + } + + AddQDQNodePairWithOutputAsGraphOutput(builder, q_input, output_qparams[0].scale, output_qparams[0].zero_point, use_contrib_qdq); }; } @@ -267,7 +308,8 @@ static void RunHTPConvOpTest(const std::string& conv_op_type, const TestInputDef ExpectedEPNodeAssignment expected_ep_assignment, bool use_contrib_qdq = false, int opset = 13, - QDQTolerance tolerance = QDQTolerance()) { + QDQTolerance tolerance = QDQTolerance(), + std::optional output_activation = std::nullopt) { ProviderOptions provider_options; #if defined(_WIN32) @@ -277,10 +319,11 @@ static void RunHTPConvOpTest(const std::string& conv_op_type, const TestInputDef #endif TestQDQModelAccuracy(BuildF32ConvTestCase(conv_op_type, input_def, weights_def, bias_def, strides, pads, dilations, - group, auto_pad), + group, auto_pad, output_activation), BuildQDQConvTestCase(conv_op_type, input_def, weights_def, bias_def, strides, pads, dilations, - group, auto_pad, use_contrib_qdq), + group, auto_pad, use_contrib_qdq, + output_activation), provider_options, opset, expected_ep_assignment, @@ -302,7 +345,8 @@ static void RunHTPConvOpPerChannelTest(const std::string& conv_op_type, const Te ExpectedEPNodeAssignment expected_ep_assignment, bool use_contrib_qdq = false, int opset = 13, - QDQTolerance tolerance = QDQTolerance()) { + QDQTolerance tolerance = QDQTolerance(), + std::optional output_activation = std::nullopt) { ProviderOptions provider_options; #if defined(_WIN32) @@ -312,11 +356,11 @@ static void RunHTPConvOpPerChannelTest(const std::string& conv_op_type, const Te #endif auto f32_fn = BuildF32ConvTestCase(conv_op_type, input_def, weights_def, bias_def, strides, pads, dilations, - group, auto_pad); + group, auto_pad, output_activation); auto qdq_fn = BuildQDQPerChannelConvTestCase(conv_op_type, input_def, weights_def, bias_def, weight_quant_axis, strides, pads, dilations, group, auto_pad, - use_contrib_qdq); + use_contrib_qdq, output_activation); TestQDQModelAccuracy(f32_fn, qdq_fn, provider_options, opset, expected_ep_assignment, tolerance); } @@ -764,6 +808,140 @@ TEST_F(QnnHTPBackendTests, ConvU16S4S32_PerChannel) { 21); // opset } +// Test fusion of DQs -> Conv -> Relu/Clip -> Q. +// User per-tensor quantization. +TEST_F(QnnHTPBackendTests, ConvU8U8S32_ReluClipFusion) { + std::vector input_shape = {1, 2, 4, 4}; + std::vector weight_shape = {3, 2, 2, 2}; + std::vector bias_shape = {3}; + + TestInputDef input_def(input_shape, false, + GetFloatDataInRange(0.0f, 1.0f, TensorShape(input_shape).Size())); + TestInputDef weight_def(weight_shape, true, + GetFloatDataInRange(-1.0f, 5.0f, TensorShape(weight_shape).Size())); + TestInputDef bias_def(bias_shape, true, + GetFloatDataInRange(-1.0f, 1.0f, TensorShape(bias_shape).Size())); + + // DQs -> Conv (w/ bias) -> Relu -> Q + OutputActivationInfo relu_info = {"Relu", {}}; + RunHTPConvOpTest("Conv", + input_def, + weight_def, + bias_def, + {1, 1}, // Strides + {0, 0, 0, 0}, // Pads + {1, 1}, // Dilations + 1, // default group + "NOTSET", + ExpectedEPNodeAssignment::All, + false, // use_qdq_contrib_ops + 21, // opset + QDQTolerance(), + relu_info); + + // DQs -> Conv (NO bias) -> Relu -> Q + RunHTPConvOpTest("Conv", + input_def, + weight_def, + TestInputDef(), + {1, 1}, // Strides + {0, 0, 0, 0}, // Pads + {1, 1}, // Dilations + 1, // default group + "NOTSET", + ExpectedEPNodeAssignment::All, + false, // use_qdq_contrib_ops + 21, // opset + QDQTolerance(), + relu_info); + + // DQs -> Conv (w/ bias) -> Clip -> Q + // Opset 6 Clip uses attributes for min/max + OutputActivationInfo clip_info = {"Clip", {0.0f, 2.0f}}; + RunHTPConvOpTest("Conv", + input_def, + weight_def, + bias_def, + {1, 1}, // Strides + {0, 0, 0, 0}, // Pads + {1, 1}, // Dilations + 1, // default group + "NOTSET", + ExpectedEPNodeAssignment::All, + false, // use_qdq_contrib_ops + 19, // opset + QDQTolerance(), + clip_info); + + // DQs -> Conv (NO bias) -> Clip -> Q + OutputActivationInfo clip_info_2 = {"Clip", {-6.0f, 6.0f}}; + RunHTPConvOpTest("Conv", + input_def, + weight_def, + TestInputDef(), + {1, 1}, // Strides + {0, 0, 0, 0}, // Pads + {1, 1}, // Dilations + 1, // default group + "NOTSET", + ExpectedEPNodeAssignment::All, + false, // use_qdq_contrib_ops + 21, // opset + QDQTolerance(), + clip_info_2); +} + +// Test fusion of DQs -> Conv -> Relu/Clip -> Q. +// User per-channel quantization. +TEST_F(QnnHTPBackendTests, ConvS8S8S32_PerChannel_ReluClipFusion) { + std::vector input_shape = {1, 2, 4, 4}; + std::vector weight_shape = {3, 2, 2, 2}; + std::vector bias_shape = {3}; + + TestInputDef input_def(input_shape, false, + GetFloatDataInRange(0.0f, 1.0f, TensorShape(input_shape).Size())); + TestInputDef weight_def(weight_shape, true, + GetFloatDataInRange(-1.0f, 5.0f, TensorShape(weight_shape).Size())); + TestInputDef bias_def(bias_shape, true, + GetFloatDataInRange(-1.0f, 1.0f, TensorShape(bias_shape).Size())); + + // DQs -> Conv (w/ bias) -> Relu -> Q + OutputActivationInfo relu_info = {"Relu", {}}; + RunHTPConvOpPerChannelTest("Conv", + input_def, + weight_def, + bias_def, + 0, // weight quant axis + {1, 1}, // Strides + {0, 0, 0, 0}, // Pads + {1, 1}, // Dilations + 1, // default group + "NOTSET", + ExpectedEPNodeAssignment::All, + false, // use_qdq_contrib_ops + 21, // opset + QDQTolerance(), + relu_info); + + // DQs -> Conv (w/ bias) -> Clip -> Q + OutputActivationInfo clip_info = {"Clip", {0.0f, 6.0f}}; + RunHTPConvOpPerChannelTest("Conv", + input_def, + weight_def, + bias_def, + 0, // weight quant axis + {1, 1}, // Strides + {0, 0, 0, 0}, // Pads + {1, 1}, // Dilations + 1, // default group + "NOTSET", + ExpectedEPNodeAssignment::All, + false, // use_qdq_contrib_ops + 21, // opset + QDQTolerance(), + clip_info); +} + // Test per-channel QDQ Conv with INT4 weights and a negative weight quantization axis that still points to dimension 0. TEST_F(QnnHTPBackendTests, ConvU16S4S32_PerChannel_NegativeWeightQuantAxis) { std::vector input_shape = {1, 2, 4, 4};