From ad34c67a444d15132fce1c85612f16db7d6fa74e Mon Sep 17 00:00:00 2001 From: Hector Li Date: Mon, 6 Nov 2023 16:28:11 -0800 Subject: [PATCH] [QNN EP] Enable Expand op (#18234) ### Description Enable Expand Op. There no directly mapping from Onnx Expand op to QNN. Need to use ElementWiseMultiply to do the data broadcast. Basically create the 2nd input with value 1.0 and use the shape data from Expand op. --- .../selectors_actions/shared/utils.cc | 1 + .../qnn/builder/op_builder_factory.cc | 4 + .../qnn/builder/op_builder_factory.h | 2 + .../qnn/builder/opbuilder/base_op_builder.h | 4 +- .../opbuilder/batch_norm_op_builder.cc | 116 ++----- .../builder/opbuilder/expand_op_builder.cc | 139 ++++++++ .../builder/opbuilder/gather_op_builder.cc | 1 - .../qnn/builder/opbuilder/pad_op_builder.cc | 42 ++- .../qnn/builder/qnn_backend_manager.cc | 6 +- .../core/providers/qnn/builder/qnn_utils.cc | 69 ++++ .../core/providers/qnn/builder/qnn_utils.h | 34 ++ .../providers/qnn/reshape_expand_op_test.cc | 313 ++++++++++++++++++ .../test/providers/qnn/reshape_op_test.cc | 225 ------------- 13 files changed, 609 insertions(+), 347 deletions(-) create mode 100644 onnxruntime/core/providers/qnn/builder/opbuilder/expand_op_builder.cc create mode 100644 onnxruntime/test/providers/qnn/reshape_expand_op_test.cc delete mode 100644 onnxruntime/test/providers/qnn/reshape_op_test.cc diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc index 3f1b2f0458bc0..1a4d3a0c18151 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc @@ -30,6 +30,7 @@ void Selectors::RegisterSelector(const OpVersionsAndSelector::OpVersionsMap& ops static const OpVersionsAndSelector::OpVersionsMap GetMiscOpVersionsMap() { return {{"Gather", {}}, {"Reshape", {}}, + {"Expand", {}}, {"Flatten", {}}, {"Transpose", {}}, {"MaxPool", {12}}, diff --git a/onnxruntime/core/providers/qnn/builder/op_builder_factory.cc b/onnxruntime/core/providers/qnn/builder/op_builder_factory.cc index 17ce9b078b790..d5c3e4619f263 100644 --- a/onnxruntime/core/providers/qnn/builder/op_builder_factory.cc +++ b/onnxruntime/core/providers/qnn/builder/op_builder_factory.cc @@ -161,6 +161,10 @@ OpBuilderRegistrations::OpBuilderRegistrations() { { CreatePadOpBuilder("Pad", *this); } + + { + CreateExpandOpBuilder("Expand", *this); + } } const IOpBuilder* GetOpBuilder(const std::string& onnx_op_type) { diff --git a/onnxruntime/core/providers/qnn/builder/op_builder_factory.h b/onnxruntime/core/providers/qnn/builder/op_builder_factory.h index c2c9345e109a9..d95e2baa9457f 100644 --- a/onnxruntime/core/providers/qnn/builder/op_builder_factory.h +++ b/onnxruntime/core/providers/qnn/builder/op_builder_factory.h @@ -92,5 +92,7 @@ void CreateTransposeOpBuilder(const std::string& op_type, OpBuilderRegistrations void CreatePadOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); +void CreateExpandOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); + } // namespace qnn } // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.h b/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.h index 0431d605bc843..75a3a6ff2ff46 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.h +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.h @@ -164,7 +164,9 @@ class BaseOpBuilder : public IOpBuilder { {"LRN", QNN_OP_LRN}, - {"Pad", QNN_OP_PAD}}; + {"Pad", QNN_OP_PAD}, + + {"Expand", QNN_OP_ELEMENT_WISE_MULTIPLY}}; auto it = onnx_op_type_to_qnn_op_type.find(onnx_op_type); ORT_ENFORCE(it != onnx_op_type_to_qnn_op_type.end()); return it->second; diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/batch_norm_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/batch_norm_op_builder.cc index 3e17fb157b160..8febf09f0e26d 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/batch_norm_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/batch_norm_op_builder.cc @@ -6,7 +6,6 @@ #include #include "core/providers/common.h" -#include "core/util/qmath.h" #include "core/providers/shared/utils/utils.h" #include "core/framework/tensorprotoutils.h" #include "core/providers/qnn/builder/qnn_model_wrapper.h" @@ -32,57 +31,6 @@ class BatchNormOpBuilder : public BaseOpBuilder { const NodeUnit& node_unit, const logging::Logger& logger) const override final ORT_MUST_USE_RESULT; - std::pair CheckMinMax(float rmin, float rmax) const { - // Ensure a minimum range of 0.0001 (required by QNN) - rmax = std::max(rmax, rmin + 0.0001f); - - // Both QNN and ORT require the range to include 0.0f - rmin = std::min(rmin, 0.0f); - rmax = std::max(rmax, 0.0f); - - return std::make_pair(rmin, rmax); - } - - template - Status GetQminQmax(const Qnn_DataType_t qnn_data_type, - T& qmin, - T& qmax) const { - if (qnn_data_type == QNN_DATATYPE_SFIXED_POINT_8) { - qmin = static_cast(std::numeric_limits::min()); - qmax = static_cast(std::numeric_limits::max()); - } else if (qnn_data_type == QNN_DATATYPE_UFIXED_POINT_8) { - qmin = static_cast(std::numeric_limits::min()); - qmax = static_cast(std::numeric_limits::max()); - } else if (qnn_data_type == QNN_DATATYPE_SFIXED_POINT_16) { - qmin = static_cast(std::numeric_limits::min()); - qmax = static_cast(std::numeric_limits::max()); - } else if (qnn_data_type == QNN_DATATYPE_UFIXED_POINT_16) { - qmin = static_cast(std::numeric_limits::min()); - qmax = static_cast(std::numeric_limits::max()); - } else { - ORT_RETURN_IF(true, "Qnn Data Type: %d not supported yet.", qnn_data_type); - } - return Status::OK(); - } - - Status GetQuantParams(float rmin, - float rmax, - const Qnn_DataType_t qnn_data_type, - float& scale, - int& zero_point) const { - std::tie(rmin, rmax) = CheckMinMax(rmin, rmax); - float qmin = 0.0f; - float qmax = 255.0f; - ORT_RETURN_IF_ERROR(GetQminQmax(qnn_data_type, qmin, qmax)); - - scale = (rmax - rmin) / (qmax - qmin); - const float initial_zero_point = qmin - (rmin / scale); - zero_point = static_cast(RoundHalfToEven(Saturate(qmax, qmin, initial_zero_point))); - // To match QNN quantization definition - zero_point = 0 - zero_point; - return Status::OK(); - } - inline Status GetValueOnQnnDataType(const Qnn_DataType_t qnn_data_type, const uint8_t* raw_ptr, double& value, @@ -303,38 +251,6 @@ class BatchNormOpBuilder : public BaseOpBuilder { return Status::OK(); } - inline double Dequantize(const OnnxInputInfo& info, - const double quant_value) const { - auto offset = static_cast(info.quant_param.scaleOffsetEncoding.offset); - auto scale = static_cast(info.quant_param.scaleOffsetEncoding.scale); - return (quant_value + offset) * scale; - } - - template - inline T Saturate(const T qmax, - const T qmin, - const T quant_value) const { - if (quant_value > qmax) { - return qmax; - } else if (quant_value < qmin) { - return qmin; - } else { - return quant_value; - } - } - - inline Status Quantize(const double double_value, - const float scale, - const int zero_point, - const Qnn_DataType_t qnn_data_type, - int& quant_value) const { - int qmin = 0; - int qmax = 255; - ORT_RETURN_IF_ERROR(GetQminQmax(qnn_data_type, qmin, qmax)); - quant_value = Saturate(qmax, qmin, static_cast(std::round((double_value / scale) - zero_point))); - return Status::OK(); - } - Status PreprocessMean(const OnnxInputInfo& mean_info, const bool is_npu_backend, const uint8_t* mean_raw_ptr, @@ -349,7 +265,10 @@ class BatchNormOpBuilder : public BaseOpBuilder { for (; i < static_cast(channel); ++i) { double mean_value = 0.0; ORT_RETURN_IF_ERROR(GetValueOnQnnDataType(mean_info.qnn_data_type, mean_raw_ptr + offset, mean_value, offset)); - mean_out[i] = (is_npu_backend) ? Dequantize(mean_info, mean_value) : mean_value; + mean_out[i] = (is_npu_backend) ? utils::Dequantize(mean_info.quant_param.scaleOffsetEncoding.offset, + mean_info.quant_param.scaleOffsetEncoding.scale, + mean_value) + : mean_value; } return Status::OK(); } @@ -369,7 +288,10 @@ class BatchNormOpBuilder : public BaseOpBuilder { for (; i < static_cast(channel); ++i) { double var_value = 0.0; ORT_RETURN_IF_ERROR(GetValueOnQnnDataType(var_info.qnn_data_type, var_raw_ptr + offset, var_value, offset)); - std_out[i] = (is_npu_backend) ? Dequantize(var_info, var_value) : var_value; + std_out[i] = (is_npu_backend) ? utils::Dequantize(var_info.quant_param.scaleOffsetEncoding.offset, + var_info.quant_param.scaleOffsetEncoding.scale, + var_value) + : var_value; std_out[i] = std::sqrt(std_out[i] + static_cast(epsilon)); } return Status::OK(); @@ -392,7 +314,10 @@ class BatchNormOpBuilder : public BaseOpBuilder { for (; i < static_cast(channel); ++i) { double scale_value = 0.0; ORT_RETURN_IF_ERROR(GetValueOnQnnDataType(scale_info.qnn_data_type, scale_raw_ptr + offset, scale_value, offset)); - scale_out[i] = (is_npu_backend) ? Dequantize(scale_info, scale_value) : scale_value; + scale_out[i] = (is_npu_backend) ? utils::Dequantize(scale_info.quant_param.scaleOffsetEncoding.offset, + scale_info.quant_param.scaleOffsetEncoding.scale, + scale_value) + : scale_value; scale_out[i] = scale_out[i] / std_double_tensor[i]; rmax = std::max(rmax, scale_out[i]); rmin = std::min(rmin, scale_out[i]); @@ -418,7 +343,10 @@ class BatchNormOpBuilder : public BaseOpBuilder { for (; i < static_cast(channel); ++i) { double bias_value = 0.0; ORT_RETURN_IF_ERROR(GetValueOnQnnDataType(bias_info.qnn_data_type, bias_raw_ptr + offset, bias_value, offset)); - bias_out[i] = (is_npu_backend) ? Dequantize(bias_info, bias_value) : bias_value; + bias_out[i] = (is_npu_backend) ? utils::Dequantize(bias_info.quant_param.scaleOffsetEncoding.offset, + bias_info.quant_param.scaleOffsetEncoding.scale, + bias_value) + : bias_value; bias_out[i] = bias_out[i] - (mean_double_tensor[i] * scale_double_tensor[i]); rmax = std::max(rmax, bias_out[i]); rmin = std::min(rmin, bias_out[i]); @@ -437,17 +365,17 @@ class BatchNormOpBuilder : public BaseOpBuilder { raw_tensor.resize(double_tensor.size()); float scale = 0.0f; int zero_point = 0; - ORT_RETURN_IF_ERROR(GetQuantParams(static_cast(rmin), - static_cast(rmax), - info.qnn_data_type, - scale, - zero_point)); + ORT_RETURN_IF_ERROR(utils::GetQuantParams(static_cast(rmin), + static_cast(rmax), + info.qnn_data_type, + scale, + zero_point)); quant_param = QNN_QUANTIZE_PARAMS_INIT; utils::InitializeQuantizeParam(quant_param, true, scale, zero_point); for (size_t i = 0; i < double_tensor.size(); ++i) { // onnx only supports 8 bits quantization int quant_value_int = 0; - ORT_RETURN_IF_ERROR(Quantize(double_tensor[i], scale, zero_point, info.qnn_data_type, quant_value_int)); + ORT_RETURN_IF_ERROR(utils::Quantize(double_tensor[i], scale, zero_point, info.qnn_data_type, quant_value_int)); if (info.qnn_data_type == QNN_DATATYPE_UFIXED_POINT_8) { raw_tensor[i] = static_cast(quant_value_int); } else if (info.qnn_data_type == QNN_DATATYPE_SFIXED_POINT_8) { diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/expand_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/expand_op_builder.cc new file mode 100644 index 0000000000000..7a44060b751cf --- /dev/null +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/expand_op_builder.cc @@ -0,0 +1,139 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/common.h" +#include "core/providers/shared/utils/utils.h" +#include "core/providers/qnn/builder/qnn_model_wrapper.h" +#include "core/providers/qnn/builder/op_builder_factory.h" +#include "core/providers/qnn/builder/qnn_utils.h" +#include "core/common/safeint.h" + +#include "base_op_builder.h" + +namespace onnxruntime { +namespace qnn { + +class ExpandOpBuilder : public BaseOpBuilder { + public: + ExpandOpBuilder() : BaseOpBuilder("ExpandOpBuilder") {} + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(ExpandOpBuilder); + + protected: + Status ProcessInputs(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + const logging::Logger& logger, + std::vector& input_names, + bool do_op_validation) const override ORT_MUST_USE_RESULT; +}; + +template +void FillShapeInputData(std::vector& shape_data, int shape_size, T ini_value) { + shape_data.resize(shape_size * sizeof(T)); + T* shape_data_float = reinterpret_cast(shape_data.data()); + std::fill(shape_data_float, shape_data_float + shape_size, ini_value); +} + +// Use ElementWiseMultiply to implement data broadcast +// Get the shape data, and create a initializer input with value 1 and same shape +// input[0] * input[1] +Status ExpandOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + const logging::Logger& logger, + std::vector& input_names, + bool do_op_validation) const { + ORT_UNUSED_PARAMETER(do_op_validation); + const auto& inputs = node_unit.Inputs(); + ORT_RETURN_IF(inputs.size() != 2, "Expand should has 2 inputs!"); + + ORT_RETURN_IF_ERROR(ProcessInput(qnn_model_wrapper, inputs[0], logger, input_names)); + + // Process shape input + const auto& input_name = inputs[1].node_arg.Name(); + bool is_initializer_input = qnn_model_wrapper.IsInitializerInput(input_name); + ORT_RETURN_IF_NOT(is_initializer_input, "QNN doesn't support dynamic shape."); + + std::vector shape; + ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(inputs[1].node_arg, shape), "Cannot get shape"); + uint32_t shape_rank = shape[0]; + std::vector unpacked_tensor; + const auto& input_tensor = qnn_model_wrapper.GetInitializerTensors().at(input_name); + ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackInitializerData(*input_tensor, unpacked_tensor)); + const int64_t* shape_data_int64 = reinterpret_cast(unpacked_tensor.data()); + std::vector input_shape(shape_rank, 0); + std::transform(shape_data_int64, shape_data_int64 + shape_rank, input_shape.begin(), + [](int64_t item) { return SafeInt(item); }); + int shape_size = std::accumulate(input_shape.begin(), input_shape.end(), 1, std::multiplies()); + + std::vector shape_data; + bool is_quantized_tensor = inputs[0].quant_param.has_value(); + Qnn_DataType_t qnn_data_type = QNN_DATATYPE_FLOAT_32; + const auto* type_proto = inputs[0].node_arg.TypeAsProto(); + Qnn_QuantizeParams_t quantize_param = QNN_QUANTIZE_PARAMS_INIT; + if (is_quantized_tensor) { + ORT_RETURN_IF_ERROR(utils::GetQnnDataType(true, type_proto, qnn_data_type)); + float scale = 0.0f; + int zero_point = 0; + float rmax = 1.0f; + float rmin = 1.0f; + ORT_RETURN_IF_ERROR(utils::GetQuantParams(rmin, + rmax, + qnn_data_type, + scale, + zero_point)); + utils::InitializeQuantizeParam(quantize_param, true, scale, zero_point); + int quant_value_int = 0; + double ini_value = 1.0; + ORT_RETURN_IF_ERROR(utils::Quantize(ini_value, scale, zero_point, qnn_data_type, quant_value_int)); + switch (qnn_data_type) { + case QNN_DATATYPE_SFIXED_POINT_8: { + FillShapeInputData(shape_data, shape_size, static_cast(quant_value_int)); + break; + } + case QNN_DATATYPE_UFIXED_POINT_8: { + FillShapeInputData(shape_data, shape_size, static_cast(quant_value_int)); + break; + } + case QNN_DATATYPE_UFIXED_POINT_16: { + FillShapeInputData(shape_data, shape_size, static_cast(quant_value_int)); + break; + } + default: + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Type not supported."); + } // switch + } else { + ORT_RETURN_IF_ERROR(utils::GetQnnDataType(false, type_proto, qnn_data_type)); + switch (qnn_data_type) { + case QNN_DATATYPE_FLOAT_32: { + FillShapeInputData(shape_data, shape_size, static_cast(1.0)); + break; + } + case QNN_DATATYPE_INT_32: { + FillShapeInputData(shape_data, shape_size, static_cast(1)); + break; + } + case QNN_DATATYPE_UINT_32: { + FillShapeInputData(shape_data, shape_size, static_cast(1)); + break; + } + default: + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Type not supported."); + } // switch + } // if-else + + const std::string& output_name = node_unit.Outputs()[0].node_arg.Name(); + std::string shape_input_name(input_name + "_" + output_name); + QnnTensorWrapper input_tensorwrapper(shape_input_name, QNN_TENSOR_TYPE_STATIC, qnn_data_type, quantize_param, + std::move(input_shape), std::move(shape_data)); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(input_tensorwrapper)), "Failed to add tensor."); + + input_names.push_back(shape_input_name); + + return Status::OK(); +} + +void CreateExpandOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { + op_registrations.AddOpBuilder(op_type, std::make_unique()); +} + +} // namespace qnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/gather_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/gather_op_builder.cc index e203667576447..c441fe331df3a 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/gather_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/gather_op_builder.cc @@ -37,7 +37,6 @@ Status GatherOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper, const logging::Logger& logger, std::vector& input_names, bool do_op_validation) const { - ORT_UNUSED_PARAMETER(do_op_validation); const auto& inputs = node_unit.Inputs(); ORT_RETURN_IF(inputs.size() != 2, "Gather should has 2 inputs at least!"); ORT_RETURN_IF_ERROR(ProcessInput(qnn_model_wrapper, inputs[0], logger, input_names)); diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/pad_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/pad_op_builder.cc index fc8c5c357682c..523095fac9aaf 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/pad_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/pad_op_builder.cc @@ -10,6 +10,7 @@ #include "core/common/safeint.h" #include "core/providers/qnn/builder/opbuilder/base_op_builder.h" +#include "core/providers/qnn/builder/qnn_utils.h" namespace onnxruntime { namespace qnn { @@ -62,11 +63,6 @@ Status PadOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper, return Status::OK(); } -template -float DequantizeValue(T value, int32_t offset, float scale) { - return static_cast(static_cast(value) - offset) * scale; -} - Status ProcessConstantValue(QnnModelWrapper& qnn_model_wrapper, std::vector& param_tensor_names, const NodeUnit& node_unit, @@ -86,43 +82,43 @@ Status ProcessConstantValue(QnnModelWrapper& qnn_model_wrapper, switch (input_info.qnn_data_type) { case QNN_DATATYPE_SFIXED_POINT_8: { auto int8_span = ReinterpretAsSpan(gsl::make_span(unpacked_tensor)); - constant_value = DequantizeValue(int8_span.data()[0], - input_info.quant_param.scaleOffsetEncoding.offset, - input_info.quant_param.scaleOffsetEncoding.scale); + constant_value = static_cast(utils::Dequantize(input_info.quant_param.scaleOffsetEncoding.offset, + input_info.quant_param.scaleOffsetEncoding.scale, + static_cast(int8_span.data()[0]))); break; } case QNN_DATATYPE_SFIXED_POINT_16: { auto int16_span = ReinterpretAsSpan(gsl::make_span(unpacked_tensor)); - constant_value = DequantizeValue(int16_span.data()[0], - input_info.quant_param.scaleOffsetEncoding.offset, - input_info.quant_param.scaleOffsetEncoding.scale); + constant_value = static_cast(utils::Dequantize(input_info.quant_param.scaleOffsetEncoding.offset, + input_info.quant_param.scaleOffsetEncoding.scale, + static_cast(int16_span.data()[0]))); break; } case QNN_DATATYPE_SFIXED_POINT_32: { auto int32_span = ReinterpretAsSpan(gsl::make_span(unpacked_tensor)); - constant_value = DequantizeValue(int32_span.data()[0], - input_info.quant_param.scaleOffsetEncoding.offset, - input_info.quant_param.scaleOffsetEncoding.scale); + constant_value = static_cast(utils::Dequantize(input_info.quant_param.scaleOffsetEncoding.offset, + input_info.quant_param.scaleOffsetEncoding.scale, + static_cast(int32_span.data()[0]))); break; } case QNN_DATATYPE_UFIXED_POINT_8: { - constant_value = DequantizeValue(unpacked_tensor.data()[0], - input_info.quant_param.scaleOffsetEncoding.offset, - input_info.quant_param.scaleOffsetEncoding.scale); + constant_value = static_cast(utils::Dequantize(input_info.quant_param.scaleOffsetEncoding.offset, + input_info.quant_param.scaleOffsetEncoding.scale, + static_cast(unpacked_tensor.data()[0]))); break; } case QNN_DATATYPE_UFIXED_POINT_16: { auto uint16_span = ReinterpretAsSpan(gsl::make_span(unpacked_tensor)); - constant_value = DequantizeValue(uint16_span.data()[0], - input_info.quant_param.scaleOffsetEncoding.offset, - input_info.quant_param.scaleOffsetEncoding.scale); + constant_value = static_cast(utils::Dequantize(input_info.quant_param.scaleOffsetEncoding.offset, + input_info.quant_param.scaleOffsetEncoding.scale, + static_cast(uint16_span.data()[0]))); break; } case QNN_DATATYPE_UFIXED_POINT_32: { auto uint32_span = ReinterpretAsSpan(gsl::make_span(unpacked_tensor)); - constant_value = DequantizeValue(uint32_span.data()[0], - input_info.quant_param.scaleOffsetEncoding.offset, - input_info.quant_param.scaleOffsetEncoding.scale); + constant_value = static_cast(utils::Dequantize(input_info.quant_param.scaleOffsetEncoding.offset, + input_info.quant_param.scaleOffsetEncoding.scale, + static_cast(uint32_span.data()[0]))); break; } default: diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc index f8ee0f225fe46..fa859ce81be98 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc @@ -272,7 +272,7 @@ void QnnBackendManager::InitializeQnnLog() { Status QnnBackendManager::InitializeBackend() { if (true == backend_initialized_) { - LOGS_DEFAULT(INFO) << "Backend intialized already."; + LOGS_DEFAULT(INFO) << "Backend initialized already."; return Status::OK(); } @@ -312,7 +312,7 @@ bool QnnBackendManager::IsDevicePropertySupported() { Status QnnBackendManager::CreateDevice() { if (true == device_created_) { - LOGS_DEFAULT(INFO) << "Device intialized already."; + LOGS_DEFAULT(INFO) << "Device initialized already."; return Status::OK(); } @@ -797,7 +797,7 @@ Status QnnBackendManager::ExtractProfilingSubEvents(QnnProfile_EventId_t profile Status QnnBackendManager::ExtractProfilingEvent(QnnProfile_EventId_t profile_event_id) { QnnProfile_EventData_t event_data; auto result = qnn_interface_.profileGetEventData(profile_event_id, &event_data); - ORT_RETURN_IF(QNN_PROFILE_NO_ERROR != result, "Failed to get provile event data."); + ORT_RETURN_IF(QNN_PROFILE_NO_ERROR != result, "Failed to get profile event data."); LOGS(*logger_, VERBOSE) << "Profiling Event Info - Event Type: " << event_data.type << ", Event Value: " << event_data.value diff --git a/onnxruntime/core/providers/qnn/builder/qnn_utils.cc b/onnxruntime/core/providers/qnn/builder/qnn_utils.cc index dd202c87c0a77..e4074fa6fb60b 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_utils.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_utils.cc @@ -423,6 +423,75 @@ bool OnnxDataTypeToQnnDataType(const int32_t onnx_data_type, Qnn_DataType_t& qnn } } +std::pair CheckMinMax(float rmin, float rmax) { + // Ensure a minimum range of 0.0001 (required by QNN) + rmax = std::max(rmax, rmin + 0.0001f); + + // Both QNN and ORT require the range to include 0.0f + rmin = std::min(rmin, 0.0f); + rmax = std::max(rmax, 0.0f); + + return std::make_pair(rmin, rmax); +} + +template +Status GetQminQmax(const Qnn_DataType_t qnn_data_type, + T& qmin, + T& qmax) { + if (qnn_data_type == QNN_DATATYPE_SFIXED_POINT_8) { + qmin = static_cast(std::numeric_limits::min()); + qmax = static_cast(std::numeric_limits::max()); + } else if (qnn_data_type == QNN_DATATYPE_UFIXED_POINT_8) { + qmin = static_cast(std::numeric_limits::min()); + qmax = static_cast(std::numeric_limits::max()); + } else if (qnn_data_type == QNN_DATATYPE_SFIXED_POINT_16) { + qmin = static_cast(std::numeric_limits::min()); + qmax = static_cast(std::numeric_limits::max()); + } else if (qnn_data_type == QNN_DATATYPE_UFIXED_POINT_16) { + qmin = static_cast(std::numeric_limits::min()); + qmax = static_cast(std::numeric_limits::max()); + } else { + ORT_RETURN_IF(true, "Qnn Data Type: %d not supported yet.", qnn_data_type); + } + return Status::OK(); +} + +Status GetQuantParams(float rmin, + float rmax, + const Qnn_DataType_t qnn_data_type, + float& scale, + int& zero_point) { + std::tie(rmin, rmax) = CheckMinMax(rmin, rmax); + float qmin = 0.0f; + float qmax = 255.0f; + ORT_RETURN_IF_ERROR(GetQminQmax(qnn_data_type, qmin, qmax)); + + scale = (rmax - rmin) / (qmax - qmin); + const float initial_zero_point = qmin - (rmin / scale); + zero_point = static_cast(RoundHalfToEven(Saturate(qmax, qmin, initial_zero_point))); + // To match QNN quantization definition + zero_point = 0 - zero_point; + return Status::OK(); +} + +double Dequantize(int32_t offset, float scale, const double quant_value) { + double offset_d = static_cast(offset); + double scale_d = static_cast(scale); + return (quant_value + offset_d) * scale_d; +} + +Status Quantize(const double double_value, + const float scale, + const int zero_point, + const Qnn_DataType_t qnn_data_type, + int& quant_value) { + int qmin = 0; + int qmax = 255; + ORT_RETURN_IF_ERROR(GetQminQmax(qnn_data_type, qmin, qmax)); + quant_value = Saturate(qmax, qmin, static_cast(std::round((double_value / scale) - zero_point))); + return Status::OK(); +} + } // namespace utils } // namespace qnn } // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_utils.h b/onnxruntime/core/providers/qnn/builder/qnn_utils.h index a54e0c8276e71..edbef7ae92ee0 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_utils.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_utils.h @@ -9,6 +9,8 @@ #include #include +#include "core/util/qmath.h" + namespace onnxruntime { namespace qnn { class QnnOpConfigWrapper; @@ -48,6 +50,38 @@ static bool ArrayHasString(const std::array& strings, std:: return false; } +std::pair CheckMinMax(float rmin, float rmax); + +template +Status GetQminQmax(const Qnn_DataType_t qnn_data_type, T& qmin, T& qmax); + +template +inline T Saturate(const T qmax, + const T qmin, + const T quant_value) { + if (quant_value > qmax) { + return qmax; + } else if (quant_value < qmin) { + return qmin; + } else { + return quant_value; + } +} + +Status GetQuantParams(float rmin, + float rmax, + const Qnn_DataType_t qnn_data_type, + float& scale, + int& zero_point); + +double Dequantize(int32_t offset, float scale, const double quant_value); + +Status Quantize(const double double_value, + const float scale, + const int zero_point, + const Qnn_DataType_t qnn_data_type, + int& quant_value); + } // namespace utils } // namespace qnn } // namespace onnxruntime diff --git a/onnxruntime/test/providers/qnn/reshape_expand_op_test.cc b/onnxruntime/test/providers/qnn/reshape_expand_op_test.cc new file mode 100644 index 0000000000000..3964edc11461b --- /dev/null +++ b/onnxruntime/test/providers/qnn/reshape_expand_op_test.cc @@ -0,0 +1,313 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if !defined(ORT_MINIMAL_BUILD) + +#include + +#include "test/providers/qnn/qnn_test_utils.h" +#include "core/graph/node_attr_utils.h" + +#include "onnx/onnx_pb.h" +#include "gtest/gtest.h" + +namespace onnxruntime { +namespace test { + +// Runs a model with a Reshape/Expand operator on the QNN CPU backend. Checks the graph node assignment +// and that inference outputs for QNN EP and CPU EP match. +template +static void RunReshapeExpandTestOnCPU(const std::string& op_type, + const TestInputDef& input_def, + const TestInputDef& shape_def, + const std::vector& attrs, + ExpectedEPNodeAssignment expected_ep_assignment, + int opset = 19) { + ProviderOptions provider_options; + +#if defined(_WIN32) + provider_options["backend_path"] = "QnnCpu.dll"; +#else + provider_options["backend_path"] = "libQnnCpu.so"; +#endif + + RunQnnModelTest(BuildOpTestCase(op_type, {input_def}, {shape_def}, attrs), + provider_options, + opset, + expected_ep_assignment); +} + +// +// CPU tests: +// + +// Test that Reshape with a dynamic shape input is not supported by QNN EP. +TEST_F(QnnCPUBackendTests, Reshape_DynamicShape_Unsupported) { + RunReshapeExpandTestOnCPU("Reshape", + TestInputDef({1, 3, 4, 4}, false, -10.0f, 10.0f), + TestInputDef({2}, false /* is_initializer */, {1, 48}), + {}, // Attributes + ExpectedEPNodeAssignment::None, // Should not be assigned to QNN EP. + 19); // Opset +} + +// Test that Reshape with an enabled 'allowzero' attribute is not supported by QNN EP. +TEST_F(QnnCPUBackendTests, Reshape_AllowZeroAttr_Unsupported) { + RunReshapeExpandTestOnCPU("Reshape", TestInputDef({1, 3, 4, 4}, false, -10.0f, 10.0f), + TestInputDef({2}, true, {1, 48}), + {utils::MakeAttribute("allowzero", static_cast(1))}, + ExpectedEPNodeAssignment::None, // Should not be assigned to QNN EP. + 19); // Opset +} + +// Test Reshape of rank 4 -> rank 2. +TEST_F(QnnCPUBackendTests, Reshape_4D_f32) { + RunReshapeExpandTestOnCPU("Reshape", TestInputDef({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)), + TestInputDef({2}, true, {1, 48}), + {}, // Attributes + ExpectedEPNodeAssignment::All, + 19); // Opset +} + +// Test Expand with non-initializer shape input, not supported. +TEST_F(QnnCPUBackendTests, Expand_NonIniShape) { + RunReshapeExpandTestOnCPU("Expand", TestInputDef({1}, false, {1.0f}), + TestInputDef({2}, false, {2, 2}), + {}, // Attributes + ExpectedEPNodeAssignment::None, + 19); // Opset +} + +// Test Expand with initializer shape input. +TEST_F(QnnCPUBackendTests, Expand_IniShape) { + RunReshapeExpandTestOnCPU("Expand", TestInputDef({1}, false, {1.0f}), + TestInputDef({2}, true, {2, 3}), + {}, // Attributes + ExpectedEPNodeAssignment::All, + 19); // Opset +} + +// Test Expand with initializer shape input. +TEST_F(QnnCPUBackendTests, Expand_Uint32) { + RunReshapeExpandTestOnCPU("Expand", TestInputDef({1}, false, {1}), + TestInputDef({2}, true, {2, 3}), + {}, // Attributes + ExpectedEPNodeAssignment::All, + 19); // Opset +} + +// Test Expand with 6D output. +TEST_F(QnnCPUBackendTests, Expand_6D) { + RunReshapeExpandTestOnCPU("Expand", TestInputDef({3}, false, {1.0f, 2.0f, 3.0f}), + TestInputDef({6}, true, {1, 2, 3, 4, 5, 3}), + {}, // Attributes + ExpectedEPNodeAssignment::All, + 19); // Opset +} + +#if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) +// +// HTP tests: +// + +// Returns a function that creates a graph with a QDQ Reshape/Expand operator. +template +GetTestQDQModelFn BuildQDQReshapeExpandTestCase(const std::string& op_type, + const TestInputDef& input_def, + const TestInputDef& shape_def, + const std::vector& attrs, + bool use_contrib_qdq = false) { + return [input_def, shape_def, attrs, + use_contrib_qdq, op_type](ModelTestBuilder& builder, + std::vector>& output_qparams) { + // input -> Q -> DQ -> + NodeArg* input = MakeTestInput(builder, input_def); + QuantParams input_qparams = GetTestInputQuantParams(input_def); + NodeArg* input_qdq = AddQDQNodePair(builder, input, input_qparams.scale, input_qparams.zero_point, + use_contrib_qdq); + + // shape input + NodeArg* shape_input = MakeTestInput(builder, shape_def); + + // Reshape op + NodeArg* reshape_output = builder.MakeIntermediate(); + Node& reshape_node = builder.AddNode(op_type, {input_qdq, shape_input}, {reshape_output}); + + for (const auto& attr : attrs) { + reshape_node.AddAttributeProto(attr); + } + + // op_output -> Q -> DQ -> output + // NOTE: Input and output quantization parameters must be equal for Reshape. + output_qparams[0] = input_qparams; // Overwrite! + AddQDQNodePairWithOutputAsGraphOutput(builder, reshape_output, input_qparams.scale, + input_qparams.zero_point, use_contrib_qdq); + }; +} + +// Runs a model with a non-QDQ Reshape operator on the QNN HTP backend. Checks the graph node assignment +// and that inference outputs for QNN EP and CPU EP match. +template +static void RunReshapeExpandTestOnHTP(const std::string& op_type, + const TestInputDef& input_def, + const TestInputDef& shape_def, + const std::vector& attrs, + ExpectedEPNodeAssignment expected_ep_assignment, + int opset = 19) { + ProviderOptions provider_options; + +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + + RunQnnModelTest(BuildOpTestCase(op_type, {input_def}, {shape_def}, attrs), + provider_options, + opset, + expected_ep_assignment); +} + +// Runs a QDQ Reshape model on the QNN (HTP) EP and the ORT CPU EP. Checks the graph node assignment and that inference +// running the QDQ model on QNN EP is at least as accurate as on ORT CPU EP (compared to the baseline float32 model). +template +static void RunQDQReshapeExpandTestOnHTP(const std::string& op_type, + const TestInputDef& input_def, + const TestInputDef& shape_def, + const std::vector& attrs, + ExpectedEPNodeAssignment expected_ep_assignment, + int opset = 19, + bool use_contrib_qdq = false) { + ProviderOptions provider_options; + +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + + auto f32_model_builder = BuildOpTestCase(op_type, {input_def}, {shape_def}, attrs); + auto qdq_model_builder = BuildQDQReshapeExpandTestCase(op_type, input_def, shape_def, attrs, use_contrib_qdq); + TestQDQModelAccuracy(f32_model_builder, + qdq_model_builder, + provider_options, + opset, + expected_ep_assignment); +} + +// Test that QDQ Reshape with a dynamic shape input is not supported by QNN EP. +TEST_F(QnnHTPBackendTests, Reshape_DynamicShape_Unsupported) { + RunQDQReshapeExpandTestOnHTP("Reshape", + TestInputDef({1, 3, 4, 4}, false, -10.0f, 10.0f), + TestInputDef({2}, false /* is_initializer */, {1, 48}), + {}, // Attributes + ExpectedEPNodeAssignment::None, // Should not be assigned to QNN EP. + 19); // Opset +} + +// Test that QDQ Reshape with an enabled 'allowzero' attribute is not supported by QNN EP. +TEST_F(QnnHTPBackendTests, Reshape_AllowZeroAttr_Unsupported) { + RunQDQReshapeExpandTestOnHTP("Reshape", + TestInputDef({1, 3, 4, 4}, false, -10.0f, 10.0f), + TestInputDef({2}, true, {1, 48}), + {utils::MakeAttribute("allowzero", static_cast(1))}, + ExpectedEPNodeAssignment::None, // Should not be assigned to QNN EP. + 19); // Opset +} + +// Test 8-bit QDQ Reshape of rank 4 -> rank 2. +TEST_F(QnnHTPBackendTests, Reshape_4D_u8) { + RunQDQReshapeExpandTestOnHTP("Reshape", + TestInputDef({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)), + TestInputDef({2}, true, {1, 48}), + {}, // Attributes + ExpectedEPNodeAssignment::All, + 19); // Opset +} + +// Test 16-bit QDQ Reshape of rank 4 -> rank 2. +TEST_F(QnnHTPBackendTests, Reshape_4D_u16) { + RunQDQReshapeExpandTestOnHTP("Reshape", + TestInputDef({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)), + TestInputDef({2}, true, {1, 48}), + {}, // Attributes + ExpectedEPNodeAssignment::All, + 19, // Opset + true); // Use com.microsoft Q/DQ ops +} + +// Test that int32 Reshape runs on HTP backend. +TEST_F(QnnHTPBackendTests, Reshape_4D_int32) { + std::vector input_data = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}; + RunReshapeExpandTestOnHTP("Reshape", + TestInputDef({1, 3, 2, 2}, false, input_data), + TestInputDef({3}, true, {1, 1, 12}), + {}, // Attributes + ExpectedEPNodeAssignment::All, + 19); // Opset +} + +// Test QDQ Reshape with a shape value of 0 (copy dimension from input) +TEST_F(QnnHTPBackendTests, Reshape_4D_0MeansCopy) { + RunQDQReshapeExpandTestOnHTP("Reshape", + TestInputDef({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)), + TestInputDef({3}, true, {1, 0, 16}), // zero means copy => '(1, 3, 16)' + {}, // Attributes + ExpectedEPNodeAssignment::All, + 19); // Opset +} + +// Test QDQ Reshape with a shape value of -1 (dimension is inferred from the expected number of elements) +TEST_F(QnnHTPBackendTests, Reshape_4D_Neg1MeansInfer) { + RunQDQReshapeExpandTestOnHTP("Reshape", + TestInputDef({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)), + TestInputDef({3}, true, {1, 3, -1}), // -1 means infer => '(1, 3, 16)' + {}, // Attributes + ExpectedEPNodeAssignment::All, + 19); // Opset +} + +// Test that int32 Expand runs on HTP backend. +TEST_F(QnnHTPBackendTests, Expand_HTP_int32) { + RunReshapeExpandTestOnHTP("Expand", + TestInputDef({1}, false, {1}), + TestInputDef({3}, true, {1, 2, 3}), + {}, // Attributes + ExpectedEPNodeAssignment::All, + 19); // Opset +} + +// Test QDQ Expand +TEST_F(QnnHTPBackendTests, Expand_4D) { + RunQDQReshapeExpandTestOnHTP("Expand", + TestInputDef({3}, false, {1.0f, 2.0f, 3.0f}), + TestInputDef({4}, true, {3, 2, 2, 1}), + {}, // Attributes + ExpectedEPNodeAssignment::All, + 19); // Opset +} + +// Test QDQ Expand +TEST_F(QnnHTPBackendTests, Expand_5D) { + RunQDQReshapeExpandTestOnHTP("Expand", + TestInputDef({1, 3}, false, {1.0f, 2.0f, 3.0f}), + TestInputDef({5}, true, {3, 2, 2, 2, 1}), + {}, // Attributes + ExpectedEPNodeAssignment::All, + 19); // Opset +} + +// Test QDQ Expand 6D not supported for HTP backend according to QNN doc +TEST_F(QnnHTPBackendTests, Expand_6D) { + RunQDQReshapeExpandTestOnHTP("Expand", + TestInputDef({1, 3}, false, {1.0f, 2.0f, 3.0f}), + TestInputDef({6}, true, {3, 2, 2, 2, 2, 1}), + {}, // Attributes + ExpectedEPNodeAssignment::None, + 19); // Opset +} + +#endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) +} // namespace test +} // namespace onnxruntime +#endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/test/providers/qnn/reshape_op_test.cc b/onnxruntime/test/providers/qnn/reshape_op_test.cc deleted file mode 100644 index eb495e44ec770..0000000000000 --- a/onnxruntime/test/providers/qnn/reshape_op_test.cc +++ /dev/null @@ -1,225 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#if !defined(ORT_MINIMAL_BUILD) - -#include - -#include "test/providers/qnn/qnn_test_utils.h" -#include "core/graph/node_attr_utils.h" - -#include "onnx/onnx_pb.h" -#include "gtest/gtest.h" - -namespace onnxruntime { -namespace test { - -// Runs a model with a Reshape operator on the QNN CPU backend. Checks the graph node assignment -// and that inference outputs for QNN EP and CPU EP match. -template -static void RunReshapeTestOnCPU(const TestInputDef& input_def, - const TestInputDef& shape_def, - const std::vector& attrs, - ExpectedEPNodeAssignment expected_ep_assignment, - int opset = 19) { - ProviderOptions provider_options; - -#if defined(_WIN32) - provider_options["backend_path"] = "QnnCpu.dll"; -#else - provider_options["backend_path"] = "libQnnCpu.so"; -#endif - - RunQnnModelTest(BuildOpTestCase("Reshape", {input_def}, {shape_def}, attrs), - provider_options, - opset, - expected_ep_assignment); -} - -// -// CPU tests: -// - -// Test that Reshape with a dynamic shape input is not supported by QNN EP. -TEST_F(QnnCPUBackendTests, Reshape_DynamicShape_Unsupported) { - RunReshapeTestOnCPU(TestInputDef({1, 3, 4, 4}, false, -10.0f, 10.0f), - TestInputDef({2}, false /* is_initializer */, {1, 48}), - {}, // Attributes - ExpectedEPNodeAssignment::None, // Should not be assigned to QNN EP. - 19); // Opset -} - -// Test that Reshape with an enabled 'allowzero' attribute is not supported by QNN EP. -TEST_F(QnnCPUBackendTests, Reshape_AllowZeroAttr_Unsupported) { - RunReshapeTestOnCPU(TestInputDef({1, 3, 4, 4}, false, -10.0f, 10.0f), - TestInputDef({2}, true, {1, 48}), - {utils::MakeAttribute("allowzero", static_cast(1))}, - ExpectedEPNodeAssignment::None, // Should not be assigned to QNN EP. - 19); // Opset -} - -// Test Reshape of rank 4 -> rank 2. -TEST_F(QnnCPUBackendTests, Reshape_4D_f32) { - RunReshapeTestOnCPU(TestInputDef({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)), - TestInputDef({2}, true, {1, 48}), - {}, // Attributes - ExpectedEPNodeAssignment::All, - 19); // Opset -} - -#if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) -// -// HTP tests: -// - -// Returns a function that creates a graph with a QDQ Reshape operator. -template -GetTestQDQModelFn BuildQDQReshapeTestCase(const TestInputDef& input_def, - const TestInputDef& shape_def, - const std::vector& attrs, - bool use_contrib_qdq = false) { - return [input_def, shape_def, attrs, - use_contrib_qdq](ModelTestBuilder& builder, - std::vector>& output_qparams) { - // input -> Q -> DQ -> - NodeArg* input = MakeTestInput(builder, input_def); - QuantParams input_qparams = GetTestInputQuantParams(input_def); - NodeArg* input_qdq = AddQDQNodePair(builder, input, input_qparams.scale, input_qparams.zero_point, - use_contrib_qdq); - - // shape input - NodeArg* shape_input = MakeTestInput(builder, shape_def); - - // Reshape op - NodeArg* reshape_output = builder.MakeIntermediate(); - Node& reshape_node = builder.AddNode("Reshape", {input_qdq, shape_input}, {reshape_output}); - - for (const auto& attr : attrs) { - reshape_node.AddAttributeProto(attr); - } - - // op_output -> Q -> DQ -> output - // NOTE: Input and output quantization parameters must be equal for Reshape. - output_qparams[0] = input_qparams; // Overwrite! - AddQDQNodePairWithOutputAsGraphOutput(builder, reshape_output, input_qparams.scale, - input_qparams.zero_point, use_contrib_qdq); - }; -} - -// Runs a model with a non-QDQ Reshape operator on the QNN HTP backend. Checks the graph node assignment -// and that inference outputs for QNN EP and CPU EP match. -template -static void RunReshapeTestOnHTP(const TestInputDef& input_def, - const TestInputDef& shape_def, - const std::vector& attrs, - ExpectedEPNodeAssignment expected_ep_assignment, - int opset = 19) { - ProviderOptions provider_options; - -#if defined(_WIN32) - provider_options["backend_path"] = "QnnHtp.dll"; -#else - provider_options["backend_path"] = "libQnnHtp.so"; -#endif - - RunQnnModelTest(BuildOpTestCase("Reshape", {input_def}, {shape_def}, attrs), - provider_options, - opset, - expected_ep_assignment); -} - -// Runs a QDQ Reshape model on the QNN (HTP) EP and the ORT CPU EP. Checks the graph node assignment and that inference -// running the QDQ model on QNN EP is at least as accurate as on ORT CPU EP (compared to the baseline float32 model). -template -static void RunQDQReshapeTestOnHTP(const TestInputDef& input_def, - const TestInputDef& shape_def, - const std::vector& attrs, - ExpectedEPNodeAssignment expected_ep_assignment, - int opset = 19, - bool use_contrib_qdq = false) { - ProviderOptions provider_options; - -#if defined(_WIN32) - provider_options["backend_path"] = "QnnHtp.dll"; -#else - provider_options["backend_path"] = "libQnnHtp.so"; -#endif - - auto f32_model_builder = BuildOpTestCase("Reshape", {input_def}, {shape_def}, attrs); - auto qdq_model_builder = BuildQDQReshapeTestCase(input_def, shape_def, attrs, use_contrib_qdq); - TestQDQModelAccuracy(f32_model_builder, - qdq_model_builder, - provider_options, - opset, - expected_ep_assignment); -} - -// Test that QDQ Reshape with a dynamic shape input is not supported by QNN EP. -TEST_F(QnnHTPBackendTests, Reshape_DynamicShape_Unsupported) { - RunQDQReshapeTestOnHTP(TestInputDef({1, 3, 4, 4}, false, -10.0f, 10.0f), - TestInputDef({2}, false /* is_initializer */, {1, 48}), - {}, // Attributes - ExpectedEPNodeAssignment::None, // Should not be assigned to QNN EP. - 19); // Opset -} - -// Test that QDQ Reshape with an enabled 'allowzero' attribute is not supported by QNN EP. -TEST_F(QnnHTPBackendTests, Reshape_AllowZeroAttr_Unsupported) { - RunQDQReshapeTestOnHTP(TestInputDef({1, 3, 4, 4}, false, -10.0f, 10.0f), - TestInputDef({2}, true, {1, 48}), - {utils::MakeAttribute("allowzero", static_cast(1))}, - ExpectedEPNodeAssignment::None, // Should not be assigned to QNN EP. - 19); // Opset -} - -// Test 8-bit QDQ Reshape of rank 4 -> rank 2. -TEST_F(QnnHTPBackendTests, Reshape_4D_u8) { - RunQDQReshapeTestOnHTP(TestInputDef({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)), - TestInputDef({2}, true, {1, 48}), - {}, // Attributes - ExpectedEPNodeAssignment::All, - 19); // Opset -} - -// Test 16-bit QDQ Reshape of rank 4 -> rank 2. -TEST_F(QnnHTPBackendTests, Reshape_4D_u16) { - RunQDQReshapeTestOnHTP(TestInputDef({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)), - TestInputDef({2}, true, {1, 48}), - {}, // Attributes - ExpectedEPNodeAssignment::All, - 19, // Opset - true); // Use com.microsoft Q/DQ ops -} - -// Test that int32 Reshape runs on HTP backend. -TEST_F(QnnHTPBackendTests, Reshape_4D_int32) { - std::vector input_data = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}; - RunReshapeTestOnHTP(TestInputDef({1, 3, 2, 2}, false, input_data), - TestInputDef({3}, true, {1, 1, 12}), - {}, // Attributes - ExpectedEPNodeAssignment::All, - 19); // Opset -} - -// Test QDQ Reshape with a shape value of 0 (copy dimension from input) -TEST_F(QnnHTPBackendTests, Reshape_4D_0MeansCopy) { - RunQDQReshapeTestOnHTP(TestInputDef({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)), - TestInputDef({3}, true, {1, 0, 16}), // zero means copy => '(1, 3, 16)' - {}, // Attributes - ExpectedEPNodeAssignment::All, - 19); // Opset -} - -// Test QDQ Reshape with a shape value of -1 (dimension is inferred from the expected number of elements) -TEST_F(QnnHTPBackendTests, Reshape_4D_Neg1MeansInfer) { - RunQDQReshapeTestOnHTP(TestInputDef({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)), - TestInputDef({3}, true, {1, 3, -1}), // -1 means infer => '(1, 3, 16)' - {}, // Attributes - ExpectedEPNodeAssignment::All, - 19); // Opset -} - -#endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) -} // namespace test -} // namespace onnxruntime -#endif // !defined(ORT_MINIMAL_BUILD)