From c65e892089e2e6f383e80339334450ac06be5ba0 Mon Sep 17 00:00:00 2001 From: Hariharan Seshadri Date: Wed, 20 Sep 2023 10:35:15 -0700 Subject: [PATCH 1/2] [CUDA] Fix performance bug in DecoderMaskedMultiheadAttention for BeamSearch (#17613) --- ...decoder_masked_multihead_attention_impl.cu | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu index 5827bdfee1ab5..c8877a5e3f872 100644 --- a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu @@ -174,7 +174,6 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentio q = add_vec(q, q_bias); } - T* params_k_cache = reinterpret_cast(params.k_cache); const float inv_sqrt_dh = params.scale; @@ -350,24 +349,22 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentio // The keys loaded from the key cache. K_vec_k k_vec[K_VECS_PER_THREAD]; + if (ti < tlength) { + if (has_beams) { + const int beam_offset = beam_indices[ti] * params.num_heads * params.max_sequence_length * head_size; - if (has_beams) { #pragma unroll - for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) { - int jj = ii * params.max_sequence_length + ti; + for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) { + int jj = ii * params.max_sequence_length + ti; - if (ti < tlength) { - const int beam_offset = beam_indices[ti] * params.num_heads * params.max_sequence_length * head_size; k_vec[ii] = vec_conversion( (*reinterpret_cast(&k_cache_batch[beam_offset + jj * QK_ELTS_IN_16B]))); } - } - } else { + } else { #pragma unroll - for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) { - int jj = ii * params.max_sequence_length + ti; + for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) { + int jj = ii * params.max_sequence_length + ti; - if (ti < tlength) { k_vec[ii] = vec_conversion( (*reinterpret_cast(&k_cache_batch[jj * QK_ELTS_IN_16B]))); } From c55da45e20435b8aa9edb78179b6027502b778b0 Mon Sep 17 00:00:00 2001 From: Adrian Lizarraga Date: Wed, 20 Sep 2023 14:31:01 -0700 Subject: [PATCH 2/2] [QNN EP] Add more op unit tests (fix Clip, TopK, Tile) (#17457) ### Description Adds more operator unit tests (all op types should now have at least 1 unit test): - [x] Reshape - [x] Flatten - [x] Squeeze - [x] Unsqueeze - [x] Gemm - [x] Clip - Enable QDQ Clip on HTP backend (when not optimized away by L1 ClipQuantFusion optimizer) - Add support for 16-bit QDQ Clip to ClipQuantFusion optimizer - [x] Split - [x] Topk - Enable QDQ TopK on HTP backend - [x] Tile - Enable QDQ Tile on HTP backend ### Motivation and Context Increase QNN operator support and test coverage. --- .../qdq_transformer/clip_quantizelinear.cc | 25 +- .../selectors_actions/qdq_selectors.cc | 36 ++ .../selectors_actions/qdq_selectors.h | 8 + .../selectors_actions/shared/utils.cc | 18 +- .../qnn/builder/opbuilder/clip_op_builder.cc | 127 +++--- .../providers/qnn/builder/opbuilder/topk.cc | 15 +- .../test/optimizer/qdq_transformer_test.cc | 7 + .../test/providers/qnn/average_pool_test.cc | 6 +- .../test/providers/qnn/clip_op_test.cc | 188 +++++++++ .../test/providers/qnn/flatten_op_test.cc | 202 +++++++++ .../test/providers/qnn/gather_op_htp_test.cc | 64 +-- .../test/providers/qnn/gemm_op_test.cc | 341 +++++++++++++++ .../providers/qnn/instance_norm_htp_test.cc | 21 +- .../test/providers/qnn/layer_norm_test.cc | 4 +- .../providers/qnn/leakyrelu_op_htp_test.cc | 46 +-- .../test/providers/qnn/max_min_op_test.cc | 9 +- .../test/providers/qnn/pool_op_test.cpp | 19 +- .../test/providers/qnn/qnn_test_utils.cc | 14 +- .../test/providers/qnn/qnn_test_utils.h | 82 ++-- .../test/providers/qnn/reshape_op_test.cc | 225 ++++++++++ .../test/providers/qnn/simple_op_htp_test.cc | 16 +- .../test/providers/qnn/slice_htp_test.cc | 64 +-- .../test/providers/qnn/split_op_test.cc | 387 ++++++++++++++++++ .../qnn/squeeze_unsqueeze_op_test.cc | 324 +++++++++++++++ .../test/providers/qnn/tile_op_test.cc | 132 ++++++ .../test/providers/qnn/topk_op_test.cc | 209 ++++++++++ 26 files changed, 2273 insertions(+), 316 deletions(-) create mode 100644 onnxruntime/test/providers/qnn/clip_op_test.cc create mode 100644 onnxruntime/test/providers/qnn/flatten_op_test.cc create mode 100644 onnxruntime/test/providers/qnn/gemm_op_test.cc create mode 100644 onnxruntime/test/providers/qnn/reshape_op_test.cc create mode 100644 onnxruntime/test/providers/qnn/split_op_test.cc create mode 100644 onnxruntime/test/providers/qnn/squeeze_unsqueeze_op_test.cc create mode 100644 onnxruntime/test/providers/qnn/tile_op_test.cc create mode 100644 onnxruntime/test/providers/qnn/topk_op_test.cc diff --git a/onnxruntime/core/optimizer/qdq_transformer/clip_quantizelinear.cc b/onnxruntime/core/optimizer/qdq_transformer/clip_quantizelinear.cc index a0942c31b0161..50653b368857d 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/clip_quantizelinear.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/clip_quantizelinear.cc @@ -1,8 +1,11 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "core/optimizer/initializer.h" #include "core/optimizer/qdq_transformer/clip_quantizelinear.h" + +#include + +#include "core/optimizer/initializer.h" #include "core/optimizer/qdq_transformer/qdq_util.h" #include "core/optimizer/utils.h" #include "core/graph/graph_utils.h" @@ -50,14 +53,26 @@ static bool GetQConstantLowerUpper(const Graph& graph, const Node& node, float& switch (zp_initializer.data_type()) { case ONNX_NAMESPACE::TensorProto_DataType_INT8: { const int8_t zero_point = zp_initializer.data()[0]; - lower = scale * (-128 - zero_point); - upper = scale * (127 - zero_point); + lower = scale * (std::numeric_limits::lowest() - zero_point); + upper = scale * (std::numeric_limits::max() - zero_point); break; } case ONNX_NAMESPACE::TensorProto_DataType_UINT8: { const uint8_t zero_point = zp_initializer.data()[0]; - lower = scale * (0 - zero_point); - upper = scale * (255 - zero_point); + lower = scale * (std::numeric_limits::lowest() - zero_point); + upper = scale * (std::numeric_limits::max() - zero_point); + break; + } + case ONNX_NAMESPACE::TensorProto_DataType_INT16: { + const int16_t zero_point = zp_initializer.data()[0]; + lower = scale * (std::numeric_limits::lowest() - zero_point); + upper = scale * (std::numeric_limits::max() - zero_point); + break; + } + case ONNX_NAMESPACE::TensorProto_DataType_UINT16: { + const uint16_t zero_point = zp_initializer.data()[0]; + lower = scale * (std::numeric_limits::lowest() - zero_point); + upper = scale * (std::numeric_limits::max() - zero_point); break; } default: diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc index 16c7bd5fce960..5015e48fdb7b8 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc @@ -496,6 +496,42 @@ bool LogicalComparisonNodeGroupSelector::Check(const GraphViewer& graph_viewer, return dt_input_1 == dt_input_2; } +bool TopKNodeGroupSelector::Check(const GraphViewer& graph_viewer, + const Node& node, + const std::vector& dq_nodes, + const std::vector& q_nodes) const { + constexpr int num_dq_inputs = 1; + constexpr int num_q_outputs = 1; + if (num_dq_inputs != gsl::narrow_cast(dq_nodes.size())) { + return false; + } + + if (const auto dq_validation_status = QDQ::ValidateNodeGroupDQNodes(graph_viewer, node, dq_nodes); + !dq_validation_status.IsOK()) { + return false; + } + + if (num_q_outputs != gsl::narrow_cast(q_nodes.size())) { + return false; + } + + const Node& dq_node = *dq_nodes.front(); + const Node& q_node = *q_nodes.front(); + + int32_t dt_input = dq_node.InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); + int32_t dt_output = q_node.OutputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); + + if (dt_input != dt_output) { + return false; + } + + auto get_const_initializer = [&graph_viewer](const std::string& initializer_name) { + return graph_viewer.GetConstantInitializer(initializer_name, true); + }; + + return IsQDQPairSupported(q_node, dq_node, get_const_initializer, graph_viewer.ModelPath()); +} + } // namespace QDQ } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h index d8fefdd8dc3d9..be7f7e0288eda 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h @@ -220,6 +220,14 @@ class LogicalComparisonNodeGroupSelector : public NodeGroupSelector { const std::vector& q_nodes) const override; }; +// TopK has 1 DQ input node and 1 Q output node. +// Zero point and scale are constant scalars and must match +class TopKNodeGroupSelector : public NodeGroupSelector { + bool Check(const GraphViewer& graph_viewer, const Node& node, + const std::vector& dq_nodes, + const std::vector& q_nodes) const override; +}; + /* * NodeSelector instances for use in the QDQ::SelectorActionTransformer. */ diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc index f1bdd7a99c329..3f1b2f0458bc0 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc @@ -36,7 +36,8 @@ static const OpVersionsAndSelector::OpVersionsMap GetMiscOpVersionsMap() { {"Resize", {}}, {"Split", {}}, {"Squeeze", {}}, - {"Unsqueeze", {}}}; + {"Unsqueeze", {}}, + {"Tile", {}}}; } static const OpVersionsAndSelector::OpVersionsMap GetDropDQOpVersionsMap() { @@ -78,7 +79,8 @@ static const OpVersionsAndSelector::OpVersionsMap GetUnaryOpVersionsMap() { {"Abs", {}}, {"Neg", {}}, {"DepthToSpace", {}}, - {"SpaceToDepth", {}}}; + {"SpaceToDepth", {}}, + {"Clip", {}}}; } static const OpVersionsAndSelector::OpVersionsMap GetBinaryOpVersionsMap() { return {{"Add", {}}, @@ -127,6 +129,10 @@ static const OpVersionsAndSelector::OpVersionsMap GetPadOpVersionsMap() { return {{"Pad", {}}}; } +static const OpVersionsAndSelector::OpVersionsMap GetTopKOpVersionsMap() { + return {{"TopK", {}}}; +} + /* Selector rules registration related */ void RegisterMiscSelectors(Selectors& qdq_selectors) { /* register selectors for miscellaneous ops */ @@ -227,6 +233,13 @@ void RegisterPadSelectors(Selectors& qdq_selectors) { std::move(selector)); } +void RegisterTopKSelector(Selectors& qdq_selectors) { + /* register selector for TopK op */ + std::unique_ptr selector = std::make_unique(); + qdq_selectors.RegisterSelector(GetTopKOpVersionsMap(), + std::move(selector)); +} + void SelectorManager::CreateSelectors() { RegisterMiscSelectors(qdq_selectors_); RegisterDropDQSelectors(qdq_selectors_); @@ -242,6 +255,7 @@ void SelectorManager::CreateSelectors() { RegisterLogicalComparisonSelectors(qdq_selectors_); RegisterWhereSelectors(qdq_selectors_); RegisterPadSelectors(qdq_selectors_); + RegisterTopKSelector(qdq_selectors_); } void SelectorManager::InitializeSelectorsMap() { diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/clip_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/clip_op_builder.cc index 92a7feea7fc54..df4c718949269 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/clip_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/clip_op_builder.cc @@ -1,6 +1,9 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include +#include + #include "core/providers/common.h" #include "core/providers/shared/utils/utils.h" #include "core/providers/qnn/builder/qnn_model_wrapper.h" @@ -9,8 +12,6 @@ #include "base_op_builder.h" -#include - namespace onnxruntime { namespace qnn { class ClipOpBuilder : public BaseOpBuilder { @@ -33,8 +34,6 @@ class ClipOpBuilder : public BaseOpBuilder { private: Status ExplictOpCheck(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit) const; - mutable float min_value_ = std::numeric_limits::lowest(); - mutable float max_value_ = std::numeric_limits::max(); }; Status ClipOpBuilder::ExplictOpCheck(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit) const { @@ -61,61 +60,8 @@ Status ClipOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper, if (do_op_validation) { ORT_RETURN_IF_ERROR(ExplictOpCheck(qnn_model_wrapper, node_unit)); } - Qnn_DataType_t qnn_data_type = QNN_DATATYPE_FLOAT_32; - - auto inputs = node_unit.Inputs(); - for (size_t input_i = 0; input_i < inputs.size(); ++input_i) { - Qnn_QuantizeParams_t quantize_param = QNN_QUANTIZE_PARAMS_INIT; - bool is_quantized_tensor = inputs[input_i].quant_param.has_value(); - utils::InitializeQuantizeParam(quantize_param, is_quantized_tensor); - - auto& input_name = inputs[input_i].node_arg.Name(); - if (input_name.empty()) { - // Ignore unspecified/unused optional input - continue; - } - if (qnn_model_wrapper.IsQnnTensorWrapperExist(input_name)) { - LOGS(logger, VERBOSE) << "Tensor already added or the input is not named, skip it: " << input_name; - input_names.push_back(input_name); - continue; - } - - const auto* type_proto = inputs[input_i].node_arg.TypeAsProto(); - ORT_RETURN_IF_ERROR(utils::GetQnnDataType(is_quantized_tensor, type_proto, qnn_data_type)); - - std::vector input_shape; - ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(inputs[input_i].node_arg, input_shape), "Cannot get shape"); - - ORT_RETURN_IF_NOT(qnn_model_wrapper.ProcessQuantizationParameter(inputs[input_i].quant_param, - quantize_param.scaleOffsetEncoding.scale, - quantize_param.scaleOffsetEncoding.offset), - "Cannot get quantization parameter"); - - float* ini_data = nullptr; - std::vector unpacked_tensor; - bool is_initializer_input = qnn_model_wrapper.IsInitializerInput(input_name); - if (is_initializer_input) { - const auto& input_tensor = qnn_model_wrapper.GetInitializerTensors().at(input_name); - ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackInitializerData(*input_tensor, unpacked_tensor)); - ini_data = reinterpret_cast(unpacked_tensor.data()); - if (input_i == 1) { - min_value_ = *ini_data; - continue; - } else if (input_i == 2) { - max_value_ = *ini_data; - continue; - } - } - ORT_ENFORCE(input_i == 0, "QNN ReluMinMax operator expects only one input. Min and max are expected to be parameters, ie. initializer inputs in ONNX model"); - - Qnn_TensorType_t tensor_type = GetInputTensorType(qnn_model_wrapper, input_name); - QnnTensorWrapper input_tensorwrapper(input_name, tensor_type, qnn_data_type, quantize_param, - std::move(input_shape), std::move(unpacked_tensor)); - ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(input_tensorwrapper)), "Failed to add tensor."); - input_names.push_back(input_name); - } - return Status::OK(); + return ProcessInput(qnn_model_wrapper, node_unit.Inputs()[0], logger, input_names); } Status ClipOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper, @@ -123,20 +69,59 @@ Status ClipOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra std::vector&& input_names, const logging::Logger& logger, bool do_op_validation) const { + const auto& inputs = node_unit.Inputs(); + const size_t num_inputs = inputs.size(); + + const Qnn_DataType_t qnn_data_type = QNN_DATATYPE_FLOAT_32; std::vector param_tensor_names; - Qnn_Scalar_t min_qnn_scalar = QNN_SCALAR_INIT; - min_qnn_scalar.dataType = QNN_DATATYPE_FLOAT_32; - min_qnn_scalar.floatValue = min_value_; - QnnParamWrapper min_value_param(node_unit.Index(), node_unit.Name(), QNN_OP_RELU_MIN_MAX_PARAM_MIN_VALUE, min_qnn_scalar); - param_tensor_names.push_back(min_value_param.GetParamTensorName()); - qnn_model_wrapper.AddParamWrapper(std::move(min_value_param)); - - Qnn_Scalar_t max_qnn_scalar = QNN_SCALAR_INIT; - max_qnn_scalar.dataType = QNN_DATATYPE_FLOAT_32; - max_qnn_scalar.floatValue = max_value_; - QnnParamWrapper max_value_param(node_unit.Index(), node_unit.Name(), QNN_OP_RELU_MIN_MAX_PARAM_MAX_VALUE, max_qnn_scalar); - param_tensor_names.push_back(max_value_param.GetParamTensorName()); - qnn_model_wrapper.AddParamWrapper(std::move(max_value_param)); + + auto get_f32_from_bytes = [](const std::vector& bytes, float default_val) -> float { + return bytes.empty() ? default_val : *reinterpret_cast(bytes.data()); + }; + + // Set the 'min' parameter. + { + std::vector min_val_bytes; + + if (num_inputs > 1 && !inputs[1].node_arg.Name().empty()) { + OnnxInputInfo min_input_info = {}; + ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetOnnxInputInfo(inputs[1], min_input_info)); + ORT_RETURN_IF_NOT(min_input_info.qnn_data_type == qnn_data_type, + "QNN EP: The 'min' input of the Clip operator must be of type float32."); + assert(min_input_info.is_initializer); // Checked by ExplicitOpCheck(). + ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackInitializerData(*min_input_info.initializer_tensor, min_val_bytes)); + } + + Qnn_Scalar_t min_qnn_scalar = QNN_SCALAR_INIT; + min_qnn_scalar.dataType = qnn_data_type; + min_qnn_scalar.floatValue = get_f32_from_bytes(min_val_bytes, std::numeric_limits::lowest()); + QnnParamWrapper min_value_param(node_unit.Index(), node_unit.Name(), QNN_OP_RELU_MIN_MAX_PARAM_MIN_VALUE, + min_qnn_scalar); + param_tensor_names.push_back(min_value_param.GetParamTensorName()); + qnn_model_wrapper.AddParamWrapper(std::move(min_value_param)); + } + + // Set the 'max' parameter. + { + std::vector max_val_bytes; + + if (num_inputs > 2 && !inputs[2].node_arg.Name().empty()) { + OnnxInputInfo max_input_info = {}; + ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetOnnxInputInfo(inputs[2], max_input_info)); + ORT_RETURN_IF_NOT(max_input_info.qnn_data_type == qnn_data_type, + "QNN EP: The 'max' input of the Clip operator must of type float32."); + assert(max_input_info.is_initializer); // Checked by ExplicitOpCheck(). + ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackInitializerData(*max_input_info.initializer_tensor, max_val_bytes)); + } + + Qnn_Scalar_t max_qnn_scalar = QNN_SCALAR_INIT; + max_qnn_scalar.dataType = qnn_data_type; + max_qnn_scalar.floatValue = get_f32_from_bytes(max_val_bytes, std::numeric_limits::max()); + QnnParamWrapper max_value_param(node_unit.Index(), node_unit.Name(), QNN_OP_RELU_MIN_MAX_PARAM_MAX_VALUE, + max_qnn_scalar); + param_tensor_names.push_back(max_value_param.GetParamTensorName()); + qnn_model_wrapper.AddParamWrapper(std::move(max_value_param)); + } ORT_RETURN_IF_ERROR(ProcessOutputs(qnn_model_wrapper, node_unit, std::move(input_names), diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/topk.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/topk.cc index 6ca36736f2f7f..047972294f78c 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/topk.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/topk.cc @@ -63,9 +63,20 @@ Status TopKOpBuilder::ExplictOpCheck(QnnModelWrapper& qnn_model_wrapper, const N auto rank = input_shape.size(); auto axis = node_helper.Get("axis", -1); - if (-1 == axis && axis != static_cast(rank - 1)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN TopK axis is always the last dimension"); + ORT_RETURN_IF_NOT(axis == -1 || axis == static_cast(rank - 1), + "QNN TopK's axis is always the last dimension"); + + // ONNX TopK outputs int64 indices, but the equivalent QNN op outputs uint32 indices. + // The QNN HTP backend does not generally support the int64 type, but QNN EP can just use the uint32 type + // for TopK ops within the graph. However, if the TopK op **generates** a graph output, + // then we cannot support it on the HTP backend. + bool is_npu_backend = IsNpuBackend(qnn_model_wrapper.GetQnnBackendType()); + if (is_npu_backend) { + const std::string& output_name = node_unit.Outputs()[0].node_arg.Name(); + ORT_RETURN_IF(qnn_model_wrapper.IsGraphOutput(output_name), + "QNN EP does not support TopK ops that generate a graph output."); } + return Status::OK(); } diff --git a/onnxruntime/test/optimizer/qdq_transformer_test.cc b/onnxruntime/test/optimizer/qdq_transformer_test.cc index a438a61cb9b36..d3616a14d8a5d 100644 --- a/onnxruntime/test/optimizer/qdq_transformer_test.cc +++ b/onnxruntime/test/optimizer/qdq_transformer_test.cc @@ -2497,10 +2497,15 @@ TEST(QDQTransformerTests, Clip) { epsilon); }; + constexpr int16_t int16_min = std::numeric_limits::min(); + constexpr uint16_t uint16_min = std::numeric_limits::min(); + std::vector opsets{12, 18, 19}; for (auto opset : opsets) { test_case(.0235294122248888f, static_cast(-128), 0, opset); // [0, 6] test_case(.0235294122248888f, static_cast(-128), 0, opset, true); // [0, 6] contrib qdq + test_case(9.15541313801785e-5f, int16_min, 0, opset, true); // [0, 6] contrib 16-bit qdq + test_case(0.0009f, int16_min, 1, opset, true); // [0, 58.98] contrib 16-bit qdq test_case(.02f, static_cast(-128), 0, opset); // [0, 5.1] test_case(.02f, static_cast(-128), 0, opset, true); // [0, 5.1] contrib qdq test_case(.03f, static_cast(-128), 1, opset); // [0, 7.65] @@ -2513,6 +2518,8 @@ TEST(QDQTransformerTests, Clip) { test_case(.04f, static_cast(-97), 1, opset, true); // [-1.24, 8.96] contrib qdq test_case(.02352941176f, static_cast(0), 0, opset); // [0, 6] test_case(.02352941176f, static_cast(0), 0, opset, true); // [0, 6] contrib qdq + test_case(9.15541313801785e-5f, uint16_min, 0, opset, true); // [0, 6] contrib 16-bit qdq + test_case(0.0009f, uint16_min, 1, opset, true); // [0, 58.98] contrib 16-bit qdq test_case(.02f, static_cast(0), 0, opset); // [0, 5.1] test_case(.02f, static_cast(0), 0, opset, true); // [0, 5.1] contrib qdq test_case(.03f, static_cast(0), 1, opset); // [0, 7.65] diff --git a/onnxruntime/test/providers/qnn/average_pool_test.cc b/onnxruntime/test/providers/qnn/average_pool_test.cc index 79ec07796c0e8..0ee52f7fec21a 100644 --- a/onnxruntime/test/providers/qnn/average_pool_test.cc +++ b/onnxruntime/test/providers/qnn/average_pool_test.cc @@ -32,7 +32,7 @@ static void RunAveragePoolOpTest(const std::string& op_type, provider_options["backend_path"] = "libQnnCpu.so"; #endif - RunQnnModelTest(BuildOpTestCase(op_type, input_defs, attrs), + RunQnnModelTest(BuildOpTestCase(op_type, input_defs, {}, attrs), provider_options, opset, expected_ep_assignment); @@ -53,8 +53,8 @@ static void RunQDQAveragePoolOpTest(const std::string& op_type, provider_options["backend_path"] = "libQnnHtp.so"; #endif - TestQDQModelAccuracy(BuildOpTestCase(op_type, input_defs, attrs), - BuildQDQOpTestCase(op_type, input_defs, attrs), + TestQDQModelAccuracy(BuildOpTestCase(op_type, input_defs, {}, attrs), + BuildQDQOpTestCase(op_type, input_defs, {}, attrs), provider_options, opset, expected_ep_assignment); diff --git a/onnxruntime/test/providers/qnn/clip_op_test.cc b/onnxruntime/test/providers/qnn/clip_op_test.cc new file mode 100644 index 0000000000000..15ba3b5de2fa1 --- /dev/null +++ b/onnxruntime/test/providers/qnn/clip_op_test.cc @@ -0,0 +1,188 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if !defined(ORT_MINIMAL_BUILD) + +#include + +#include "test/providers/qnn/qnn_test_utils.h" +#include "core/graph/node_attr_utils.h" + +#include "onnx/onnx_pb.h" +#include "gtest/gtest.h" + +namespace onnxruntime { +namespace test { + +// Runs a model with a Clip operator on the QNN CPU backend. Checks the graph node assignment +// and that inference outputs for QNN EP and CPU EP match. +template +static void RunClipTestOnCPU(const TestInputDef& input_def, + const std::vector>& min_max_defs, + ExpectedEPNodeAssignment expected_ep_assignment, + int opset = 13) { + ProviderOptions provider_options; + +#if defined(_WIN32) + provider_options["backend_path"] = "QnnCpu.dll"; +#else + provider_options["backend_path"] = "libQnnCpu.so"; +#endif + + RunQnnModelTest(BuildOpTestCase("Clip", {input_def}, min_max_defs, {}), + provider_options, + opset, + expected_ep_assignment); +} + +// +// CPU tests: +// + +// Test that Clip with a dynamic min or max input is not supported by QNN EP. +TEST_F(QnnCPUBackendTests, Clip_Dynamic_MinMax_Unsupported) { + // Dynamic min input is not supported. + RunClipTestOnCPU(TestInputDef({1, 3, 4, 4}, false, -10.0f, 10.0f), + {TestInputDef({}, false /* is_initializer */, {-5.0f})}, + ExpectedEPNodeAssignment::None); // Should not be assigned to QNN EP. + // Dynamic max input is not supported. + RunClipTestOnCPU(TestInputDef({1, 3, 4, 4}, false, -10.0f, 10.0f), + {TestInputDef({}, true, {-5.0f}), + TestInputDef({}, false, {5.0f})}, + ExpectedEPNodeAssignment::None); // Should not be assigned to QNN EP. +} + +// Test Clip with default min/max. +TEST_F(QnnCPUBackendTests, Clip_4D_f32_DefaultMinMax) { + RunClipTestOnCPU(TestInputDef({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)), + {}, // Don't specify min/max inputs. + ExpectedEPNodeAssignment::All); +} + +// Test Clip with 5D input. +TEST_F(QnnCPUBackendTests, Clip_5D_f32) { + RunClipTestOnCPU(TestInputDef({1, 1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)), + {TestInputDef({}, true, {-5.0f}), + TestInputDef({}, true, {5.0f})}, + ExpectedEPNodeAssignment::All); +} + +#if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) +// +// HTP tests: +// + +// Runs a QDQ Clip model on the QNN (HTP) EP and the ORT CPU EP. Checks the graph node assignment and that inference +// running the QDQ model on QNN EP is at least as accurate as on ORT CPU EP (compared to the baseline float32 model). +template +static void RunQDQClipTestOnHTP(const TestInputDef& input_def, + const std::vector>& min_max_defs, + ExpectedEPNodeAssignment expected_ep_assignment, + int opset = 13, + bool use_contrib_qdq = false) { + ProviderOptions provider_options; + +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + + auto f32_model_builder = BuildOpTestCase("Clip", {input_def}, {min_max_defs}, {}); + auto qdq_model_builder = BuildQDQOpTestCase("Clip", {input_def}, {min_max_defs}, {}, + kOnnxDomain, use_contrib_qdq); + + TestQDQModelAccuracy(f32_model_builder, + qdq_model_builder, + provider_options, + opset, + expected_ep_assignment); +} + +// Test 8-bit QDQ Clip with default min/max. +// NOTE: The Clip operator is *optimized* away during L1 optimizations, so QNN EP does not get a graph with a Clip op. +// Instead, QNN EP will get a graph with a Q -> DQ. +// - Original sequence: Q1 -> DQ1 -> Clip -> Q2 -> DQ2 +// - ClipQuantFusion: Fuses Clip -> QuantizeLinear resulting in Q1 -> DQ1 -> Q2' -> DQ2 +// - DoubleQDQPairsRemover: Simplifies remaining Q1 -> DQ1 -> Q2' -> DQ2 sequence to Q1 -> DQ2. +TEST_F(QnnHTPBackendTests, Clip_U8_DefaultMinMax_Rank4) { + RunQDQClipTestOnHTP(TestInputDef({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)), + {}, // Don't specify min/max inputs. + ExpectedEPNodeAssignment::All); +} + +// Test 16-bit QDQ Clip with default min/max. +// NOTE: The Clip operator is *optimized* away during L1 optimizations, so QNN EP does not get a graph with a Clip op. +// Instead, QNN EP will get a graph with a Q -> DQ. +// - Original sequence: Q1 -> DQ1 -> Clip -> Q2 -> DQ2 +// - ClipQuantFusion: Fuses Clip -> QuantizeLinear resulting in Q1 -> DQ1 -> Q2' -> DQ2 +// - DoubleQDQPairsRemover: Simplifies remaining Q1 -> DQ1 -> Q2' -> DQ2 sequence to Q1 -> DQ2. +TEST_F(QnnHTPBackendTests, Clip_U16_DefaultMinMax_Rank4) { + RunQDQClipTestOnHTP(TestInputDef({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)), + {}, // Don't specify min/max inputs. + ExpectedEPNodeAssignment::All, + 13, // opset + true); // Use com.microsoft Q/DQ ops +} + +// Test 8-bit QDQ Clip with non-default min and max inputs. QNN EP will get a graph with a Clip operator. +TEST_F(QnnHTPBackendTests, Clip_U8_Rank4) { + RunQDQClipTestOnHTP(TestInputDef({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)), + {TestInputDef({}, true, {-5.0f}), + TestInputDef({}, true, {5.0f})}, + ExpectedEPNodeAssignment::All); +} + +// Test 16-bit QDQ Clip with non-default min and max inputs. QNN EP will get a graph with a Clip operator. +TEST_F(QnnHTPBackendTests, Clip_U16_Rank4) { + RunQDQClipTestOnHTP(TestInputDef({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)), + {TestInputDef({}, true, {-5.0f}), + TestInputDef({}, true, {5.0f})}, + ExpectedEPNodeAssignment::All, + 13, // opset + true); // Use com.microsoft Q/DQ ops +} + +// Test QDQ Clip of rank 5. +TEST_F(QnnHTPBackendTests, Clip_U8_Rank5) { + // We can't use the usual model-building functions because they add standalone Quantize and Dequantize nodes + // at the input and output. These Q/DQ ops get lowered to QNN's Quantize and Dequantize operators, which DO NOT + // support rank 5 tensors. Therefore, we have to create a test model that only instantiates the DQ -> Clip -> Q + // QDQ node group, which gets lowered to a single QNN Clip node. + GetTestModelFn model_fn = [](ModelTestBuilder& builder) { + // input (u8) -> DQ -> + NodeArg* quant_input = builder.MakeInput({1, 1, 2, 2, 2}, {0, 1, 6, 10, 20, 100, 128, 255}); + NodeArg* input_dq = builder.MakeIntermediate(); + builder.AddDequantizeLinearNode(quant_input, 1.0f, 0, input_dq); // scale = 1.0, zp = 0 + + // Min/Max initializers + NodeArg* min_input = builder.MakeScalarInitializer(5.0f); + NodeArg* max_input = builder.MakeScalarInitializer(100.0f); + + // Clip -> + NodeArg* clip_output = builder.MakeIntermediate(); + builder.AddNode("Clip", {input_dq, min_input, max_input}, {clip_output}); + + // Q -> output (u8) + NodeArg* output = builder.MakeOutput(); + builder.AddQuantizeLinearNode(clip_output, 1.0f, 0, output); // scale = 1.0, zp = 0 + }; + + ProviderOptions provider_options; + +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + + RunQnnModelTest(model_fn, + provider_options, + 13, // opset + ExpectedEPNodeAssignment::All); +} + +#endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) +} // namespace test +} // namespace onnxruntime +#endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/test/providers/qnn/flatten_op_test.cc b/onnxruntime/test/providers/qnn/flatten_op_test.cc new file mode 100644 index 0000000000000..637d3257ddea7 --- /dev/null +++ b/onnxruntime/test/providers/qnn/flatten_op_test.cc @@ -0,0 +1,202 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if !defined(ORT_MINIMAL_BUILD) + +#include + +#include "test/providers/qnn/qnn_test_utils.h" +#include "core/graph/node_attr_utils.h" + +#include "onnx/onnx_pb.h" +#include "gtest/gtest.h" + +namespace onnxruntime { +namespace test { + +// Runs a model with a Flatten operator on the QNN CPU backend. Checks the graph node assignment +// and that inference outputs for QNN EP and CPU EP match. +template +static void RunFlattenTestOnCPU(const TestInputDef& input_def, + const std::vector& attrs, + ExpectedEPNodeAssignment expected_ep_assignment, + int opset = 13) { + ProviderOptions provider_options; + +#if defined(_WIN32) + provider_options["backend_path"] = "QnnCpu.dll"; +#else + provider_options["backend_path"] = "libQnnCpu.so"; +#endif + + RunQnnModelTest(BuildOpTestCase("Flatten", {input_def}, {}, attrs), + provider_options, + opset, + expected_ep_assignment); +} + +// +// CPU tests: +// + +// Test that Flatten input (rank4) with axis == 0. +TEST_F(QnnCPUBackendTests, Flatten_Rank4_Axis0) { + RunFlattenTestOnCPU(TestInputDef({1, 3, 4, 4}, false, -10.0f, 10.0f), + {utils::MakeAttribute("axis", static_cast(0))}, + ExpectedEPNodeAssignment::All); +} + +// Test that Flatten input (rank4) with axis == -1. +TEST_F(QnnCPUBackendTests, Flatten_Rank4_AxisNeg1) { + RunFlattenTestOnCPU(TestInputDef({1, 3, 4, 4}, false, -10.0f, 10.0f), + {utils::MakeAttribute("axis", static_cast(-1))}, + ExpectedEPNodeAssignment::All); +} + +// Test that Flatten input (rank5) with axis == 2. +TEST_F(QnnCPUBackendTests, Flatten_Rank5_Axis2) { + RunFlattenTestOnCPU(TestInputDef({1, 2, 3, 4, 4}, false, -10.0f, 10.0f), + {utils::MakeAttribute("axis", static_cast(2))}, + ExpectedEPNodeAssignment::All); +} + +#if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) +// +// HTP tests: +// + +// Runs a model with a non-QDQ Flatten operator on the QNN HTP backend. Checks the graph node assignment +// and that inference outputs for QNN EP and CPU EP match. +template +static void RunFlattenTestOnHTP(const TestInputDef& input_def, + const std::vector& attrs, + ExpectedEPNodeAssignment expected_ep_assignment, + int opset = 13) { + ProviderOptions provider_options; + +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + + RunQnnModelTest(BuildOpTestCase("Flatten", {input_def}, {}, attrs), + provider_options, + opset, + expected_ep_assignment); +} + +// Runs a QDQ Flatten model on the QNN (HTP) EP and the ORT CPU EP. Checks the graph node assignment and that inference +// running the QDQ model on QNN EP is at least as accurate as on ORT CPU EP (compared to the baseline float32 model). +template +static void RunQDQFlattenTestOnHTP(const TestInputDef& input_def, + const std::vector& attrs, + ExpectedEPNodeAssignment expected_ep_assignment, + int opset = 13, + bool use_contrib_qdq = false) { + ProviderOptions provider_options; + +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + + auto f32_model_builder = BuildOpTestCase("Flatten", {input_def}, {}, attrs); + auto qdq_model_builder = BuildQDQOpTestCase("Flatten", {input_def}, {}, attrs, kOnnxDomain, use_contrib_qdq); + TestQDQModelAccuracy(f32_model_builder, + qdq_model_builder, + provider_options, + opset, + expected_ep_assignment); +} + +// Test 8-bit QDQ Flatten input (rank4) with axis == 0. +TEST_F(QnnHTPBackendTests, Flatten_Rank4_Axis0) { + RunQDQFlattenTestOnHTP(TestInputDef({1, 3, 4, 4}, false, -10.0f, 10.0f), + {utils::MakeAttribute("axis", static_cast(0))}, + ExpectedEPNodeAssignment::All); +} + +// Test 16-bit QDQ Flatten input (rank4) with axis == 0. +TEST_F(QnnHTPBackendTests, Flatten_Rank4_Axis0_U16) { + RunQDQFlattenTestOnHTP(TestInputDef({1, 3, 4, 4}, false, -10.0f, 10.0f), + {utils::MakeAttribute("axis", static_cast(0))}, + ExpectedEPNodeAssignment::All, + 13, // opset + true); // Use com.microsoft Q/DQ ops +} + +// Test 8-bit QDQ Flatten input (rank4) with axis == -1. +TEST_F(QnnHTPBackendTests, Flatten_Rank4_AxisNeg1) { + RunQDQFlattenTestOnHTP(TestInputDef({1, 3, 4, 4}, false, -10.0f, 10.0f), + {utils::MakeAttribute("axis", static_cast(-1))}, + ExpectedEPNodeAssignment::All); +} + +// Test 16-bit QDQ Flatten input (rank4) with axis == -1. +TEST_F(QnnHTPBackendTests, Flatten_Rank4_AxisNeg1_U16) { + RunQDQFlattenTestOnHTP(TestInputDef({1, 3, 4, 4}, false, -10.0f, 10.0f), + {utils::MakeAttribute("axis", static_cast(-1))}, + ExpectedEPNodeAssignment::All, + 13, // opset + true); // Use com.microsoft Q/DQ ops +} + +// Test 8-bit QDQ Flatten with an input of rank5. +TEST_F(QnnHTPBackendTests, Flatten_QDQ8bit_Rank5) { + // We can't use the usual model-building functions because they add standalone Quantize and Dequantize nodes + // at the input and output. These Q/DQ ops get lowered to QNN's Quantize and Dequantize operators, which DO NOT + // support rank 5 tensors. Therefore, we have to create a test model that only instantiates the DQ -> Flatten -> Q + // QDQ node group, which gets lowered to a single QNN Reshape node. + GetTestModelFn model_fn = [](ModelTestBuilder& builder) { + // input (u8) -> DQ -> + NodeArg* quant_input = builder.MakeInput({1, 2, 3, 4, 5}, 0, 255); + NodeArg* input_dq = builder.MakeIntermediate(); + builder.AddDequantizeLinearNode(quant_input, 1.0f, 0, input_dq); // scale = 1.0, zp = 0 + + // Flatten -> + NodeArg* flatten_output = builder.MakeIntermediate(); + Node& flatten_node = builder.AddNode("Flatten", {input_dq}, {flatten_output}); + flatten_node.AddAttribute("axis", static_cast(2)); + + // Q -> output (u8) + NodeArg* output = builder.MakeOutput(); + builder.AddQuantizeLinearNode(flatten_output, 1.0f, 0, output); // scale = 1.0, zp = 0 + }; + + ProviderOptions provider_options; + +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + + RunQnnModelTest(model_fn, + provider_options, + 13, // opset + ExpectedEPNodeAssignment::All); +} + +// Test that int32 non-QDQ Flatten runs on HTP backend. +TEST_F(QnnHTPBackendTests, Flatten_Int32_Rank4_Axis2) { + std::vector input_data = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}; + RunFlattenTestOnHTP(TestInputDef({1, 3, 2, 2}, false, input_data), + {utils::MakeAttribute("axis", static_cast(2))}, + ExpectedEPNodeAssignment::All); +} + +// Test that rank 5 int32 Flatten runs on HTP backend. +TEST_F(QnnHTPBackendTests, Flatten_Int32_Rank5_Axis2) { + std::vector input_data = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}; + RunFlattenTestOnHTP(TestInputDef({1, 3, 2, 2, 2}, false, input_data), + {utils::MakeAttribute("axis", static_cast(2))}, + ExpectedEPNodeAssignment::All); +} + +#endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) +} // namespace test +} // namespace onnxruntime +#endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/test/providers/qnn/gather_op_htp_test.cc b/onnxruntime/test/providers/qnn/gather_op_htp_test.cc index 5b05b39f34a27..37e0db906d054 100644 --- a/onnxruntime/test/providers/qnn/gather_op_htp_test.cc +++ b/onnxruntime/test/providers/qnn/gather_op_htp_test.cc @@ -5,6 +5,7 @@ #include #include "core/graph/graph.h" +#include "core/graph/node_attr_utils.h" #include "test/providers/qnn/qnn_test_utils.h" @@ -14,47 +15,14 @@ namespace onnxruntime { namespace test { #if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) -// Function that builds a float model with a Gather op. -template -static GetTestModelFn BuildGatherOpTestCase(const TestInputDef& input_def, - const TestInputDef& indices_def, - int64_t axis = 0) { - return [input_def, indices_def, axis](ModelTestBuilder& builder) { - NodeArg* input = MakeTestInput(builder, input_def); - NodeArg* indices = MakeTestInput(builder, indices_def); - NodeArg* output = builder.MakeOutput(); - - Node& gather_node = builder.AddNode("Gather", {input, indices}, {output}); - gather_node.AddAttribute("axis", axis); - }; -} - -// Function that builds a QDQ model with a Gather op. -template -static GetTestQDQModelFn BuildQDQGatherOpTestCase(const TestInputDef& input_def, - const TestInputDef& indices_def, - int64_t axis = 0) { - return [input_def, indices_def, axis](ModelTestBuilder& builder, - std::vector>& output_qparams) { - NodeArg* input = MakeTestInput(builder, input_def); - QuantParams input_qparams = GetTestInputQuantParams(input_def); - NodeArg* input_qdq = AddQDQNodePair(builder, input, input_qparams.scale, input_qparams.zero_point); - - NodeArg* indices = MakeTestInput(builder, indices_def); - - NodeArg* gather_output = builder.MakeIntermediate(); - Node& gather_node = builder.AddNode("Gather", {input_qdq, indices}, {gather_output}); - gather_node.AddAttribute("axis", axis); - - AddQDQNodePairWithOutputAsGraphOutput(builder, gather_output, output_qparams[0].scale, output_qparams[0].zero_point); - }; -} - // Test the accuracy of a QDQ Gather model on QNN EP. Checks if the QDQ model on QNN EP as accurate as the QDQ model on CPU EP // (compared to float32 model). template -static void RunQDQGatherOpTest(const TestInputDef& input_def, const TestInputDef& indices_def, - int64_t axis, int opset, ExpectedEPNodeAssignment expected_ep_assignment) { +static void RunQDQGatherOpTest(const TestInputDef& input_def, + const TestInputDef& indices_def, + const std::vector& attrs, + int opset, + ExpectedEPNodeAssignment expected_ep_assignment) { ProviderOptions provider_options; #if defined(_WIN32) provider_options["backend_path"] = "QnnHtp.dll"; @@ -62,12 +30,14 @@ static void RunQDQGatherOpTest(const TestInputDef& input_def, const TestI provider_options["backend_path"] = "libQnnHtp.so"; #endif - TestQDQModelAccuracy(BuildGatherOpTestCase(input_def, indices_def, axis), - BuildQDQGatherOpTestCase(input_def, indices_def, axis), + auto f32_model_builder = BuildOpTestCase("Gather", {input_def}, {indices_def}, attrs); + auto qdq_model_builder = BuildQDQOpTestCase("Gather", {input_def}, {indices_def}, attrs); + + TestQDQModelAccuracy(f32_model_builder, + qdq_model_builder, provider_options, opset, - expected_ep_assignment, - 1e-5f); + expected_ep_assignment); } // Test creates a DQ -> Gather -> Q -> DQ graph, and checks that all @@ -77,7 +47,7 @@ static void RunQDQGatherOpTest(const TestInputDef& input_def, const TestI TEST_F(QnnHTPBackendTests, GatherOp_IndicesStaticInt64_Axis0) { RunQDQGatherOpTest(TestInputDef({3, 2}, false, {1.0f, 1.2f, 2.3f, 3.4f, 4.5f, 5.7f}), TestInputDef({2, 2}, true, {0, 1, 1, 2}), - 0, + {utils::MakeAttribute("axis", static_cast(0))}, 13, ExpectedEPNodeAssignment::All); } @@ -86,7 +56,7 @@ TEST_F(QnnHTPBackendTests, GatherOp_IndicesStaticInt64_Axis0) { TEST_F(QnnHTPBackendTests, GatherOp_IndicesDynamicInt64_Axis0) { RunQDQGatherOpTest(TestInputDef({3, 2}, false, {1.0f, 1.2f, 2.3f, 3.4f, 4.5f, 5.7f}), TestInputDef({2, 2}, false, {0, 1, 1, 2}), - 0, + {utils::MakeAttribute("axis", static_cast(0))}, 13, ExpectedEPNodeAssignment::None); } @@ -98,7 +68,7 @@ TEST_F(QnnHTPBackendTests, GatherOp_IndicesDynamicInt64_Axis0) { TEST_F(QnnHTPBackendTests, GatherOp_IndicesStaticInt32_Axis0) { RunQDQGatherOpTest(TestInputDef({3, 2}, false, {1.0f, 1.2f, 2.3f, 3.4f, 4.5f, 5.7f}), TestInputDef({2, 2}, true, {0, 1, 1, 2}), - 0, + {utils::MakeAttribute("axis", static_cast(0))}, 13, ExpectedEPNodeAssignment::All); } @@ -110,7 +80,7 @@ TEST_F(QnnHTPBackendTests, GatherOp_IndicesStaticInt32_Axis0) { TEST_F(QnnHTPBackendTests, GatherOp_IndicesDynamicInt32_Axis0) { RunQDQGatherOpTest(TestInputDef({3, 2}, false, {1.0f, 1.2f, 2.3f, 3.4f, 4.5f, 5.7f}), TestInputDef({2, 2}, false, {0, 1, 1, 2}), - 0, + {utils::MakeAttribute("axis", static_cast(0))}, 13, ExpectedEPNodeAssignment::All); } @@ -122,7 +92,7 @@ TEST_F(QnnHTPBackendTests, GatherOp_IndicesDynamicInt32_Axis0) { TEST_F(QnnHTPBackendTests, GatherOp_IndicesStaticInt32_Axis1) { RunQDQGatherOpTest(TestInputDef({3, 3}, false, {1.0f, 1.2f, 1.9f, 2.3f, 3.4f, 3.9f, 4.5f, 5.7f, 5.9f}), TestInputDef({1, 2}, true, {0, 2}), - 1, + {utils::MakeAttribute("axis", static_cast(1))}, 13, ExpectedEPNodeAssignment::All); } diff --git a/onnxruntime/test/providers/qnn/gemm_op_test.cc b/onnxruntime/test/providers/qnn/gemm_op_test.cc new file mode 100644 index 0000000000000..15f26717b06fd --- /dev/null +++ b/onnxruntime/test/providers/qnn/gemm_op_test.cc @@ -0,0 +1,341 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if !defined(ORT_MINIMAL_BUILD) + +#include +#include + +#include "test/providers/qnn/qnn_test_utils.h" +#include "core/graph/node_attr_utils.h" + +#include "onnx/onnx_pb.h" +#include "gtest/gtest.h" + +namespace onnxruntime { +namespace test { + +// Runs a model with a Gemm operator on the QNN CPU backend. Checks the graph node assignment +// and that inference outputs for QNN EP and CPU EP match. +template +static void RunGemmTestOnCPU(const std::vector>& input_defs, + const std::vector& attrs, + ExpectedEPNodeAssignment expected_ep_assignment, + int opset = 13) { + ProviderOptions provider_options; + +#if defined(_WIN32) + provider_options["backend_path"] = "QnnCpu.dll"; +#else + provider_options["backend_path"] = "libQnnCpu.so"; +#endif + + RunQnnModelTest(BuildOpTestCase("Gemm", input_defs, {}, attrs), + provider_options, + opset, + expected_ep_assignment); +} + +// +// CPU tests: +// + +// Test that Gemm with non-default 'alpha' or 'beta' attributes is not supported by QNN EP. +TEST_F(QnnCPUBackendTests, Gemm_NonDefaultAlphaBeta_Unsupported) { + // Check that alpha != 1.0f is not supported. + RunGemmTestOnCPU({TestInputDef({1, 2}, false, -10.0f, 10.0f), + TestInputDef({2, 4}, false, -10.0f, 10.0f)}, + {utils::MakeAttribute("alpha", 1.5f)}, + ExpectedEPNodeAssignment::None); // Should not be assigned to QNN EP. + + // Check that beta != 1.0f is not supported. + RunGemmTestOnCPU({TestInputDef({1, 2}, false, -10.0f, 10.0f), + TestInputDef({2, 4}, false, -10.0f, 10.0f), + TestInputDef({1, 4}, false, -1.0f, 1.0f)}, + {utils::MakeAttribute("beta", 1.2f)}, + ExpectedEPNodeAssignment::None); // Should not be assigned to QNN EP. +} + +// Test that Gemm with general 2D bias (M, N) is NOT supported (unless M == 1). +// QNN's FullyConnected operator only supports `outputVector = ( inputAsVector * weightsMatrix ) + biasesVector` +TEST_F(QnnCPUBackendTests, Gemm_2D_Bias_Unsupported) { + std::vector input_a_data = GetFloatDataInRange(-10.0f, 10.0f, 6); + std::vector input_b_data = GetFloatDataInRange(-5.0f, 5.0f, 12); + + // 2D matrix mul with bias not supported. + RunGemmTestOnCPU({TestInputDef({2, 3}, false, input_a_data), + TestInputDef({3, 4}, false, input_b_data), + TestInputDef({2, 4}, false, -1.0f, 1.0f)}, + {}, + ExpectedEPNodeAssignment::None); // Should not be assigned to QNN EP. + + // However, 2D matrix mul without a bias is supported. Input A's 0th dimension is interpreted as `batch_size`. + RunGemmTestOnCPU({TestInputDef({2, 3}, false, input_a_data), + TestInputDef({3, 4}, false, input_b_data)}, + {}, + ExpectedEPNodeAssignment::All); // Assigned to QNN EP. +} + +// Test Gemm with dynamic (i.e., not initializer) inputs (A, B, Bias). +TEST_F(QnnCPUBackendTests, Gemm_Dynamic_A_B_Bias) { + std::vector input_a_data = GetFloatDataInRange(-10.0f, 10.0f, 6); + std::vector input_b_data = GetFloatDataInRange(-5.0f, 5.0f, 24); + std::vector input_c_data = GetFloatDataInRange(-1.0f, 1.0f, 4); + RunGemmTestOnCPU({TestInputDef({1, 6}, false, input_a_data), + TestInputDef({6, 4}, false, input_b_data), + TestInputDef({1, 4}, false, input_c_data)}, + {}, + ExpectedEPNodeAssignment::All); +} + +// Test Gemm with static B and Bias inputs. +TEST_F(QnnCPUBackendTests, Gemm_Static_B_And_Bias) { + std::vector input_a_data = GetFloatDataInRange(-10.0f, 10.0f, 6); + std::vector input_b_data = GetFloatDataInRange(-5.0f, 5.0f, 24); + std::vector input_c_data = GetFloatDataInRange(-1.0f, 1.0f, 4); + RunGemmTestOnCPU({TestInputDef({1, 6}, false, input_a_data), + TestInputDef({6, 4}, true, input_b_data), + TestInputDef({1, 4}, true, input_c_data)}, + {}, + ExpectedEPNodeAssignment::All); +} + +// Test Gemm with transposed A/B and static B and Bias inputs. +TEST_F(QnnCPUBackendTests, Gemm_TransAB_Static_B_And_Bias) { + std::vector input_a_data = GetFloatDataInRange(-10.0f, 10.0f, 6); + std::vector input_b_data = GetFloatDataInRange(-5.0f, 5.0f, 24); + std::vector input_c_data = GetFloatDataInRange(-1.0f, 1.0f, 4); + RunGemmTestOnCPU({TestInputDef({6, 1}, false, input_a_data), + TestInputDef({4, 6}, true, input_b_data), + TestInputDef({1, 4}, true, input_c_data)}, + {utils::MakeAttribute("transA", static_cast(1)), + utils::MakeAttribute("transB", static_cast(1))}, + ExpectedEPNodeAssignment::All); +} + +// Test Gemm with transposed A/B and dynamic (i.e., not initializer) B and Bias inputs. +TEST_F(QnnCPUBackendTests, Gemm_TransAB_Dynamic_B_And_Bias) { + std::vector input_a_data = GetFloatDataInRange(-10.0f, 10.0f, 6); + std::vector input_b_data = GetFloatDataInRange(-5.0f, 5.0f, 24); + std::vector input_c_data = GetFloatDataInRange(-1.0f, 1.0f, 4); + RunGemmTestOnCPU({TestInputDef({6, 1}, false, input_a_data), + TestInputDef({4, 6}, false, input_b_data), + TestInputDef({1, 4}, false, input_c_data)}, + {utils::MakeAttribute("transA", static_cast(1)), + utils::MakeAttribute("transB", static_cast(1))}, + ExpectedEPNodeAssignment::All); +} + +#if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) +// +// HTP tests: +// + +// Returns a function that builds a model with a QDQ Gemm node. +template +inline GetTestQDQModelFn BuildQDQGemmTestCase(const std::vector>& input_defs, + const std::vector& attrs, + bool use_contrib_qdq = false) { + return [input_defs, attrs, use_contrib_qdq](ModelTestBuilder& builder, + std::vector>& output_qparams) { + const size_t num_inputs = input_defs.size(); + assert(num_inputs == 2 || num_inputs == 3); + + std::vector op_inputs; + op_inputs.reserve(num_inputs); + + // Process input 0 + NodeArg* input0 = MakeTestInput(builder, input_defs[0]); + QuantParams input0_qparams = GetTestInputQuantParams(input_defs[0]); + NodeArg* input0_after_qdq = AddQDQNodePair(builder, input0, input0_qparams.scale, + input0_qparams.zero_point, use_contrib_qdq); + op_inputs.push_back(input0_after_qdq); + + // Process input 1 + NodeArg* input1 = MakeTestInput(builder, input_defs[1]); + QuantParams input1_qparams = GetTestInputQuantParams(input_defs[1]); + NodeArg* input1_after_qdq = AddQDQNodePair(builder, input1, input1_qparams.scale, + input1_qparams.zero_point, use_contrib_qdq); + op_inputs.push_back(input1_after_qdq); + + // Process bias + if (num_inputs == 3) { + NodeArg* bias_input = MakeTestQDQBiasInput(builder, input_defs[2], input0_qparams.scale * input1_qparams.scale, + use_contrib_qdq); + op_inputs.push_back(bias_input); + } + + // Op -> op_output + auto* gemm_output = builder.MakeIntermediate(); + Node& gemm_node = builder.AddNode("Gemm", op_inputs, {gemm_output}); + + for (const auto& attr : attrs) { + gemm_node.AddAttributeProto(attr); + } + + // op_output -> Q -> DQ -> output + AddQDQNodePairWithOutputAsGraphOutput(builder, gemm_output, output_qparams[0].scale, + output_qparams[0].zero_point, use_contrib_qdq); + }; +} + +// Runs a QDQ Gemm model on the QNN (HTP) EP and the ORT CPU EP. Checks the graph node assignment and that inference +// running the QDQ model on QNN EP is at least as accurate as on ORT CPU EP (compared to the baseline float32 model). +template +static void RunQDQGemmTestOnHTP(const std::vector>& input_defs, + const std::vector& attrs, + ExpectedEPNodeAssignment expected_ep_assignment, + int opset = 13, + float f32_abs_err = 1e-4f, + bool use_contrib_qdq = false) { + ProviderOptions provider_options; + +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + auto f32_model_builder = BuildOpTestCase("Gemm", input_defs, {}, attrs); + auto qdq_model_builder = BuildQDQGemmTestCase(input_defs, attrs, use_contrib_qdq); + TestQDQModelAccuracy(f32_model_builder, + qdq_model_builder, + provider_options, + opset, + expected_ep_assignment, + f32_abs_err); +} + +// Test 8-bit QDQ Gemm with dynamic inputs A and Bias. The B input is an initializer. +TEST_F(QnnHTPBackendTests, Gemm_Dynamic_A_Static_B_Dynamic_Bias_U8) { + std::vector input_a_data = GetFloatDataInRange(-10.0f, 10.0f, 6); + std::vector input_b_data = GetFloatDataInRange(-5.0f, 5.0f, 24); + std::vector input_c_data = GetFloatDataInRange(-1.0f, 1.0f, 4); + RunQDQGemmTestOnHTP({TestInputDef({1, 6}, false, input_a_data), + TestInputDef({6, 4}, true, input_b_data), + TestInputDef({1, 4}, false, input_c_data)}, + {}, + ExpectedEPNodeAssignment::All); +} + +// Test 16-bit QDQ Gemm with dynamic inputs A and Bias. The B input is an initializer. +// TODO: Inaccuracy detected for output 'output_0', element 0. +// Output quant params: scale=0.001872879103757441, zero_point=0. +// Expected val: 120.73912048339844 +// QNN QDQ val: 0 (err 120.73912048339844) +// CPU QDQ val: 120.73889923095703 (err 0.00022125244140625) +TEST_F(QnnHTPBackendTests, DISABLED_Gemm_Dynamic_A_Static_B_Dynamic_Bias_U16) { + std::vector input_a_data = GetFloatDataInRange(-10.0f, 10.0f, 6); + std::vector input_b_data = GetFloatDataInRange(-5.0f, 5.0f, 24); + std::vector input_c_data = GetFloatDataInRange(-1.0f, 1.0f, 4); + RunQDQGemmTestOnHTP({TestInputDef({1, 6}, false, input_a_data), + TestInputDef({6, 4}, true, input_b_data), + TestInputDef({1, 4}, false, input_c_data)}, + {}, + ExpectedEPNodeAssignment::All, + 13, // opset + 1e-4f, // f32_abs_err + true); // Use com.microsoft Q/DQ ops +} + +// Test QDQ Gemm (16bit act, 8bit weight) with dynamic inputs A and Bias. The B input is an initializer. +// TODO: Allow small inaccuracies based on % of expected value. +// Inaccuracy detected for output 'output_0', element 0. +// Output quant params: scale=0.001872879103757441, zero_point=0. +// Expected val: 120.73912048339844 +// QNN QDQ val: 120.48043823242188 (err 0.2586822509765625) +// CPU QDQ val: 120.48980712890625 (err 0.2493133544921875) +TEST_F(QnnHTPBackendTests, Gemm_Dynamic_A_Static_B_Dynamic_Bias_U16Act_U8Weight) { + std::vector input_a_data = GetFloatDataInRange(-10.0f, 10.0f, 6); + std::vector input_b_data = GetFloatDataInRange(-5.0f, 5.0f, 24); + std::vector input_c_data = GetFloatDataInRange(-1.0f, 1.0f, 4); + RunQDQGemmTestOnHTP({TestInputDef({1, 6}, false, input_a_data), + TestInputDef({6, 4}, true, input_b_data), + TestInputDef({1, 4}, false, input_c_data)}, + {}, + ExpectedEPNodeAssignment::All, + 13, // opset + 0.15f, // f32_abs_err + true); // Use com.microsoft Q/DQ ops +} + +// Test QDQ Gemm with dynamic A and B inputs. The Bias is static. +// TODO: Inaccuracy detected for output 'output', element 0. +// Output quant params: scale=0.48132994771003723, zero_point=0. +// Expected val: 120.73912048339844 +// QNN QDQ val: 77.012794494628906 (err 43.726325988769531) +// CPU QDQ val: 119.85115814208984 (err 0.88796234130859375) +TEST_F(QnnHTPBackendTests, DISABLED_Gemm_Dynamic_A_B_Static_Bias) { + std::vector input_a_data = GetFloatDataInRange(-10.0f, 10.0f, 6); + std::vector input_b_data = GetFloatDataInRange(-5.0f, 5.0f, 24); + std::vector input_c_data = GetFloatDataInRange(-1.0f, 1.0f, 4); + RunQDQGemmTestOnHTP({TestInputDef({1, 6}, false, input_a_data), + TestInputDef({6, 4}, false, input_b_data), // Dynamic => inaccuracy + TestInputDef({1, 4}, true, input_c_data)}, + {}, + ExpectedEPNodeAssignment::All); +} + +// Test QDQ Gemm with static B and Bias inputs. +TEST_F(QnnHTPBackendTests, Gemm_Static_B_And_Bias) { + std::vector input_a_data = GetFloatDataInRange(-10.0f, 10.0f, 6); + std::vector input_b_data = GetFloatDataInRange(-5.0f, 5.0f, 24); + std::vector input_c_data = GetFloatDataInRange(-1.0f, 1.0f, 4); + RunQDQGemmTestOnHTP({TestInputDef({1, 6}, false, input_a_data), + TestInputDef({6, 4}, true, input_b_data), + TestInputDef({1, 4}, true, input_c_data)}, + {}, + ExpectedEPNodeAssignment::All); +} + +// Test 8-bit QDQ Gemm with transposed A/B and static B and Bias inputs. +TEST_F(QnnHTPBackendTests, Gemm_TransAB_Static_B_And_Bias_U8) { + std::vector input_a_data = GetFloatDataInRange(-10.0f, 10.0f, 6); + std::vector input_b_data = GetFloatDataInRange(-5.0f, 5.0f, 24); + std::vector input_c_data = GetFloatDataInRange(-1.0f, 1.0f, 4); + RunQDQGemmTestOnHTP({TestInputDef({6, 1}, false, input_a_data), + TestInputDef({4, 6}, true, input_b_data), + TestInputDef({1, 4}, true, input_c_data)}, + {utils::MakeAttribute("transA", static_cast(1)), + utils::MakeAttribute("transB", static_cast(1))}, + ExpectedEPNodeAssignment::All); +} + +// Test QDQ Gemm (16bit activation, 8bit weight) with transposed A/B and static B and Bias inputs. +// TODO: Allow small inaccuracies based on % of expected value. +// Inaccuracy detected for output 'output_0', element 0. +// Output quant params: scale=0.00047966410056687891, zero_point=0. +// Expected val: 29.434776306152344 +// QNN QDQ val: 29.191877365112305 (err 0.24289894104003906) +// CPU QDQ val: 29.197153091430664 (err 0.23762321472167969) +TEST_F(QnnHTPBackendTests, Gemm_TransAB_Static_B_And_Bias_U16Act_U8Weight) { + std::vector input_a_data = GetFloatDataInRange(-10.0f, 10.0f, 6); + std::vector input_b_data = GetFloatDataInRange(-5.0f, 5.0f, 24); + std::vector input_c_data = GetFloatDataInRange(-1.0f, 1.0f, 4); + RunQDQGemmTestOnHTP({TestInputDef({6, 1}, false, input_a_data), + TestInputDef({4, 6}, true, input_b_data), + TestInputDef({1, 4}, true, input_c_data)}, + {utils::MakeAttribute("transA", static_cast(1)), + utils::MakeAttribute("transB", static_cast(1))}, + ExpectedEPNodeAssignment::All, + 13, // opset + 0.15f, // f32_abs_err + true); // Use com.microsoft Q/DQ ops +} + +// Test QDQ Gemm with transposed A/B and dynamic (i.e., not initializer) B and Bias inputs. +TEST_F(QnnHTPBackendTests, Gemm_TransAB_Dynamic_B_And_Bias) { + std::vector input_a_data = GetFloatDataInRange(-10.0f, 10.0f, 6); + std::vector input_b_data = GetFloatDataInRange(-5.0f, 5.0f, 24); + std::vector input_c_data = GetFloatDataInRange(-1.0f, 1.0f, 4); + RunQDQGemmTestOnHTP({TestInputDef({6, 1}, false, input_a_data), + TestInputDef({4, 6}, false, input_b_data), + TestInputDef({1, 4}, false, input_c_data)}, + {utils::MakeAttribute("transA", static_cast(1)), + utils::MakeAttribute("transB", static_cast(1))}, + ExpectedEPNodeAssignment::All); +} + +#endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) +} // namespace test +} // namespace onnxruntime +#endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/test/providers/qnn/instance_norm_htp_test.cc b/onnxruntime/test/providers/qnn/instance_norm_htp_test.cc index 594973e37ef0b..f662ac14336f8 100644 --- a/onnxruntime/test/providers/qnn/instance_norm_htp_test.cc +++ b/onnxruntime/test/providers/qnn/instance_norm_htp_test.cc @@ -16,25 +16,6 @@ namespace onnxruntime { namespace test { #if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) -// Function that builds a float32 model with an InstanceNormalization operator. -GetTestModelFn BuildInstanceNormTestCase(const TestInputDef& input_def, - const TestInputDef& scale_def, - const TestInputDef& bias_def, - const std::vector& attrs) { - return [input_def, scale_def, bias_def, attrs](ModelTestBuilder& builder) { - NodeArg* input = MakeTestInput(builder, input_def); - NodeArg* scale = MakeTestInput(builder, scale_def); - NodeArg* bias = MakeTestInput(builder, bias_def); - - NodeArg* output = builder.MakeOutput(); - Node& op_node = builder.AddNode("InstanceNormalization", {input, scale, bias}, {output}); - - for (const auto& attr : attrs) { - op_node.AddAttributeProto(attr); - } - }; -} - // Function that builds a QDQ model with an InstanceNormalization operator. template static GetTestQDQModelFn BuildQDQInstanceNormTestCase(const TestInputDef& input_def, @@ -93,7 +74,7 @@ static void RunInstanceNormQDQTest(const TestInputDef& input_def, #endif // Runs model with DQ-> InstanceNorm -> Q and compares the outputs of the CPU and QNN EPs. - TestQDQModelAccuracy(BuildInstanceNormTestCase(input_def, scale_def, bias_def, attrs), + TestQDQModelAccuracy(BuildOpTestCase("InstanceNormalization", {input_def, scale_def, bias_def}, {}, attrs), BuildQDQInstanceNormTestCase(input_def, scale_def, bias_def, attrs), provider_options, 18, diff --git a/onnxruntime/test/providers/qnn/layer_norm_test.cc b/onnxruntime/test/providers/qnn/layer_norm_test.cc index aa6c6a142e6d1..085454004e5a5 100644 --- a/onnxruntime/test/providers/qnn/layer_norm_test.cc +++ b/onnxruntime/test/providers/qnn/layer_norm_test.cc @@ -29,7 +29,7 @@ static void RunLayerNormCpuTest(const TestInputDef& input_def, provider_options["backend_path"] = "libQnnCpu.so"; #endif - RunQnnModelTest(BuildOpTestCase("LayerNormalization", {input_def, scale_def}, attrs), + RunQnnModelTest(BuildOpTestCase("LayerNormalization", {input_def, scale_def}, {}, attrs), provider_options, 17, expected_ep_assignment); @@ -114,7 +114,7 @@ static void RunLayerNormQDQTest(const TestInputDef& input_def, provider_options["backend_path"] = "libQnnHtp.so"; #endif - TestQDQModelAccuracy(BuildOpTestCase("LayerNormalization", {input_def, scale_def}, attrs), + TestQDQModelAccuracy(BuildOpTestCase("LayerNormalization", {input_def, scale_def}, {}, attrs), BuildQDQLayerNormTestCase(input_def, scale_def, attrs), provider_options, 17, // opset diff --git a/onnxruntime/test/providers/qnn/leakyrelu_op_htp_test.cc b/onnxruntime/test/providers/qnn/leakyrelu_op_htp_test.cc index a8237817c71df..e3077ec569923 100644 --- a/onnxruntime/test/providers/qnn/leakyrelu_op_htp_test.cc +++ b/onnxruntime/test/providers/qnn/leakyrelu_op_htp_test.cc @@ -5,6 +5,7 @@ #include #include "core/graph/graph.h" +#include "core/graph/node_attr_utils.h" #include "test/optimizer/qdq_test_utils.h" #include "test/providers/qnn/qnn_test_utils.h" @@ -15,42 +16,10 @@ namespace onnxruntime { namespace test { #if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) -// Creates a function that builds a model with a LeakyRelu operator. -static GetTestModelFn BuildLeakyReluOpTestCase(const TestInputDef& input_def, float alpha) { - return [input_def, alpha](ModelTestBuilder& builder) { - NodeArg* input = MakeTestInput(builder, input_def); - NodeArg* output = builder.MakeOutput(); - Node& leakyrelu_node = builder.AddNode("LeakyRelu", {input}, {output}); - leakyrelu_node.AddAttribute("alpha", alpha); - }; -} - -// Creates a function that builds a QDQ model with a LeakyRelu operator. -template -static GetTestQDQModelFn BuildQDQLeakyReluOpTestCase(const TestInputDef& input_def, - float alpha) { - return [input_def, alpha](ModelTestBuilder& builder, - std::vector>& output_qparams) { - // input => Q => DQ => - NodeArg* input = MakeTestInput(builder, input_def); - QuantParams input_qparams = GetTestInputQuantParams(input_def); - NodeArg* input_qdq = AddQDQNodePair(builder, input, input_qparams.scale, input_qparams.zero_point); - - // LeakryRelu - auto* leakyrelu_output = builder.MakeIntermediate(); - Node& leakyrelu_node = builder.AddNode("LeakyRelu", {input_qdq}, {leakyrelu_output}); - leakyrelu_node.AddAttribute("alpha", alpha); - - // => Q => DQ -> final output - AddQDQNodePairWithOutputAsGraphOutput(builder, leakyrelu_output, output_qparams[0].scale, - output_qparams[0].zero_point); - }; -} - // Checks the accuracy of a QDQ LeakyRelu model by comparing to ORT CPU EP. template static void RunLeakyReluOpQDQTest(const TestInputDef& input_def, - float alpha, + const std::vector& attrs, int opset, ExpectedEPNodeAssignment expected_ep_assignment) { ProviderOptions provider_options; @@ -60,12 +29,11 @@ static void RunLeakyReluOpQDQTest(const TestInputDef& input_def, provider_options["backend_path"] = "libQnnHtp.so"; #endif - TestQDQModelAccuracy(BuildLeakyReluOpTestCase(input_def, alpha), - BuildQDQLeakyReluOpTestCase(input_def, alpha), + TestQDQModelAccuracy(BuildOpTestCase("LeakyRelu", {input_def}, {}, attrs), + BuildQDQOpTestCase("LeakyRelu", {input_def}, {}, attrs), provider_options, opset, - expected_ep_assignment, - 1e-5f); + expected_ep_assignment); } // Test creates a DQ -> Gather -> Q -> DQ graph, and checks that all @@ -74,7 +42,7 @@ static void RunLeakyReluOpQDQTest(const TestInputDef& input_def, // - Uses uint8 as the quantization type. TEST_F(QnnHTPBackendTests, LeakyReluOpSet15) { RunLeakyReluOpQDQTest(TestInputDef({1, 2, 3}, false, {-40.0f, -20.0f, 0.0f, 10.0f, 30.0f, 40.0f}), - 0.2f, + {utils::MakeAttribute("alpha", 0.2f)}, 15, ExpectedEPNodeAssignment::All); } @@ -85,7 +53,7 @@ TEST_F(QnnHTPBackendTests, LeakyReluOpSet15) { // - Uses uint8 as the quantization type. TEST_F(QnnHTPBackendTests, LeakyReluOpSet16) { RunLeakyReluOpQDQTest(TestInputDef({1, 2, 3}, false, {-40.0f, -20.0f, 0.0f, 10.0f, 30.0f, 40.0f}), - 0.2f, + {utils::MakeAttribute("alpha", 0.2f)}, 16, ExpectedEPNodeAssignment::All); } diff --git a/onnxruntime/test/providers/qnn/max_min_op_test.cc b/onnxruntime/test/providers/qnn/max_min_op_test.cc index 09ea71e5f03eb..3deff121f3c72 100644 --- a/onnxruntime/test/providers/qnn/max_min_op_test.cc +++ b/onnxruntime/test/providers/qnn/max_min_op_test.cc @@ -27,7 +27,7 @@ static void RunCPUMinOrMaxOpTest(const std::string& op_type, provider_options["backend_path"] = "libQnnCpu.so"; #endif - RunQnnModelTest(BuildOpTestCase(op_type, input_defs, {}, kOnnxDomain), + RunQnnModelTest(BuildOpTestCase(op_type, input_defs, {}, {}, kOnnxDomain), provider_options, opset, expected_ep_assignment); @@ -48,12 +48,11 @@ static void RunQDQMinOrMaxOpTest(const std::string& op_type, provider_options["backend_path"] = "libQnnHtp.so"; #endif - TestQDQModelAccuracy(BuildOpTestCase(op_type, input_defs, {}, kOnnxDomain), // baseline float32 model - BuildQDQOpTestCase(op_type, input_defs, {}, kOnnxDomain), // QDQ model + TestQDQModelAccuracy(BuildOpTestCase(op_type, input_defs, {}, {}, kOnnxDomain), // baseline float32 model + BuildQDQOpTestCase(op_type, input_defs, {}, {}, kOnnxDomain), // QDQ model provider_options, opset, - expected_ep_assignment, - 1e-4f); + expected_ep_assignment); } // diff --git a/onnxruntime/test/providers/qnn/pool_op_test.cpp b/onnxruntime/test/providers/qnn/pool_op_test.cpp index fee10a542fb82..7ed9072a95b32 100644 --- a/onnxruntime/test/providers/qnn/pool_op_test.cpp +++ b/onnxruntime/test/providers/qnn/pool_op_test.cpp @@ -17,21 +17,6 @@ namespace onnxruntime { namespace test { -// Returns a function that creates a graph with a single MaxPool operator. -static GetTestModelFn BuildPoolTestCase(const std::string& op_type, - const TestInputDef& input_def, - const std::vector& attrs) { - return [op_type, input_def, attrs](ModelTestBuilder& builder) { - NodeArg* input = MakeTestInput(builder, input_def); - NodeArg* output = builder.MakeOutput(); - Node& pool_node = builder.AddNode(op_type, {input}, {output}); - - for (const auto& attr : attrs) { - pool_node.AddAttributeProto(attr); - } - }; -} - // Returns a function that creates a graph with a QDQ MaxPool operator. template GetTestQDQModelFn BuildPoolQDQTestCase(const std::string& op_type, @@ -74,7 +59,7 @@ static void RunPoolOpTest(const std::string& op_type, provider_options["backend_path"] = "libQnnCpu.so"; #endif - RunQnnModelTest(BuildPoolTestCase(op_type, input_def, attrs), + RunQnnModelTest(BuildOpTestCase(op_type, {input_def}, {}, attrs), provider_options, opset, expected_ep_assignment); @@ -95,7 +80,7 @@ static void RunQDQPoolOpTest(const std::string& op_type, provider_options["backend_path"] = "libQnnHtp.so"; #endif - TestQDQModelAccuracy(BuildPoolTestCase(op_type, input_def, attrs), + TestQDQModelAccuracy(BuildOpTestCase(op_type, {input_def}, {}, attrs), BuildPoolQDQTestCase(op_type, input_def, attrs), provider_options, opset, diff --git a/onnxruntime/test/providers/qnn/qnn_test_utils.cc b/onnxruntime/test/providers/qnn/qnn_test_utils.cc index 724e9a11cd781..51df93f8853ec 100644 --- a/onnxruntime/test/providers/qnn/qnn_test_utils.cc +++ b/onnxruntime/test/providers/qnn/qnn_test_utils.cc @@ -73,7 +73,7 @@ void RunQnnModelTest(const GetTestModelFn& build_test_case, const ProviderOption void InferenceModel(const std::string& model_data, const char* log_id, std::unique_ptr execution_provider, ExpectedEPNodeAssignment expected_ep_assignment, const NameMLValMap& feeds, - std::vector& output_names, std::vector& output_vals) { + std::vector& output_vals) { SessionOptions so; so.session_logid = log_id; RunOptions run_options; @@ -102,14 +102,12 @@ void InferenceModel(const std::string& model_data, const char* log_id, } const auto& outputs = graph.GetOutputs(); + std::vector output_names; - // fetch all outputs if necessary. - if (output_names.empty()) { - output_names.reserve(outputs.size()); - for (const auto* node_arg : outputs) { - if (node_arg->Exists()) { - output_names.push_back(node_arg->Name()); - } + output_names.reserve(outputs.size()); + for (const auto* node_arg : outputs) { + if (node_arg->Exists()) { + output_names.push_back(node_arg->Name()); } } diff --git a/onnxruntime/test/providers/qnn/qnn_test_utils.h b/onnxruntime/test/providers/qnn/qnn_test_utils.h index fd572fa17f2b1..14c62f98f6a3e 100644 --- a/onnxruntime/test/providers/qnn/qnn_test_utils.h +++ b/onnxruntime/test/providers/qnn/qnn_test_utils.h @@ -213,13 +213,12 @@ inline QuantParams GetTestInputQuantParams(const TestInputDef& inp * \param execution_provider The EP on which to run the model. Set to nullptr for CPU EP. * \param expected_ep_assignment Describes "which nodes" should be assigned to the EP. * \param feeds The input feeds. - * \param output_names If empty, the function will write the output names. * \param output_vals Initialized to the inference results. */ void InferenceModel(const std::string& model_data, const char* log_id, std::unique_ptr execution_provider, ExpectedEPNodeAssignment expected_ep_assignment, const NameMLValMap& feeds, - std::vector& output_names, std::vector& output_vals); + std::vector& output_vals); /** * Tests the accuracy of a QDQ model on QNN EP by runnning 3 inferences: @@ -263,9 +262,8 @@ inline void TestQDQModelAccuracy(const GetTestModelFn& f32_model_fn, const GetTe // Run f32 model on CPU EP and collect outputs. std::vector cpu_f32_outputs; - std::vector output_names; InferenceModel(f32_model_data, "f32_model_logger", nullptr, ExpectedEPNodeAssignment::All, - f32_helper.feeds_, output_names, cpu_f32_outputs); + f32_helper.feeds_, cpu_f32_outputs); ASSERT_FALSE(cpu_f32_outputs.empty()); const size_t num_outputs = cpu_f32_outputs.size(); @@ -304,13 +302,13 @@ inline void TestQDQModelAccuracy(const GetTestModelFn& f32_model_fn, const GetTe // Run QDQ model on QNN EP and collect outputs. std::vector qnn_qdq_outputs; InferenceModel(qdq_model_data, "qdq_model_logger", QnnExecutionProviderWithOptions(qnn_options), - expected_ep_assignment, qdq_helper.feeds_, output_names, qnn_qdq_outputs); + expected_ep_assignment, qdq_helper.feeds_, qnn_qdq_outputs); if (expected_ep_assignment != ExpectedEPNodeAssignment::None) { // Run QDQ model on CPU EP and collect outputs. std::vector cpu_qdq_outputs; InferenceModel(qdq_model_data, "qdq_model_logger", nullptr, ExpectedEPNodeAssignment::All, - qdq_helper.feeds_, output_names, cpu_qdq_outputs); + qdq_helper.feeds_, cpu_qdq_outputs); ASSERT_EQ(cpu_qdq_outputs.size(), num_outputs); ASSERT_EQ(qnn_qdq_outputs.size(), num_outputs); @@ -320,7 +318,9 @@ inline void TestQDQModelAccuracy(const GetTestModelFn& f32_model_fn, const GetTe // Compare accuracy of QDQ results with float model. // QNN EP must be at least as accurate as CPU EP when running the QDQ model. + const std::string base_output_name = "output_"; for (size_t i = 0; i < num_outputs; i++) { + std::string debug_output_name = base_output_name + std::to_string(i); auto& cpu_qdq_tensor = cpu_qdq_outputs[i].Get(); auto& qnn_qdq_tensor = qnn_qdq_outputs[i].Get(); @@ -353,8 +353,7 @@ inline void TestQDQModelAccuracy(const GetTestModelFn& f32_model_fn, const GetTe } EXPECT_TRUE(is_as_accurate_as_cpu_qdq) - << "Inaccuracy detected for output '" - << output_names[i] + << "Inaccuracy detected for output '" << debug_output_name << "', element " << j << ".\nOutput quant params: scale=" << output_qparams[i].scale << ", zero_point=" << static_cast(output_qparams[i].zero_point) @@ -363,7 +362,7 @@ inline void TestQDQModelAccuracy(const GetTestModelFn& f32_model_fn, const GetTe << "CPU QDQ val: " << cpu_qdq_val << " (err " << cpu_err << ")"; } } else { - VerifyOutput(output_names[i], cpu_f32_outputs[i].Get(), qnn_qdq_tensor, fp32_abs_err); + VerifyOutput(debug_output_name, cpu_f32_outputs[i].Get(), qnn_qdq_tensor, fp32_abs_err); } } } @@ -438,25 +437,33 @@ NodeArg* MakeTestQDQBiasInput(ModelTestBuilder& builder, const TestInputDef +template inline GetTestModelFn BuildOpTestCase(const std::string& op_type, - const std::vector>& input_defs, + const std::vector>& input_defs_1, + const std::vector>& input_defs_2, const std::vector& attrs, const std::string& op_domain = kOnnxDomain) { - return [op_type, input_defs, attrs, op_domain](ModelTestBuilder& builder) { + return [op_type, input_defs_1, input_defs_2, attrs, op_domain](ModelTestBuilder& builder) { std::vector op_inputs; - op_inputs.reserve(input_defs.size()); + op_inputs.reserve(input_defs_1.size() + input_defs_2.size()); + + for (const auto& input_def : input_defs_1) { + NodeArg* input = MakeTestInput(builder, input_def); + op_inputs.push_back(input); + } - for (const auto& input_def : input_defs) { - NodeArg* input = MakeTestInput(builder, input_def); + for (const auto& input_def : input_defs_2) { + NodeArg* input = MakeTestInput(builder, input_def); op_inputs.push_back(input); } @@ -470,7 +477,8 @@ inline GetTestModelFn BuildOpTestCase(const std::string& op_type, } /** - * Returns a function that builds a model with a single QDQ operator with N inputs of the same element type. + * Returns a function that builds a model with a single QDQ operator with N float (quantizeable) inputs + * and M inputs of a potentially different type. * * \param op_type The operator to instantiate. * \param input_defs List of input definitions. @@ -478,25 +486,33 @@ inline GetTestModelFn BuildOpTestCase(const std::string& op_type, * \param op_domain The operator's domain. Defaults to the ONNX domain (i.e., ""). * \returns A model building function. */ -template -inline GetTestQDQModelFn BuildQDQOpTestCase(const std::string& op_type, - const std::vector>& input_defs, - const std::vector& attrs, - const std::string& op_domain = kOnnxDomain, - bool use_contrib_qdq = false) { - return [op_type, input_defs, attrs, op_domain, - use_contrib_qdq](ModelTestBuilder& builder, std::vector>& output_qparams) { +template +inline GetTestQDQModelFn BuildQDQOpTestCase(const std::string& op_type, + const std::vector>& quant_input_defs, + const std::vector>& non_quant_input_defs, + const std::vector& attrs, + const std::string& op_domain = kOnnxDomain, + bool use_contrib_qdq = false) { + return [op_type, quant_input_defs, non_quant_input_defs, attrs, op_domain, + use_contrib_qdq](ModelTestBuilder& builder, std::vector>& output_qparams) { std::vector op_inputs; - op_inputs.reserve(input_defs.size()); + op_inputs.reserve(quant_input_defs.size() + non_quant_input_defs.size()); - for (const auto& input_def : input_defs) { + // Create QDQ inputs + for (const auto& input_def : quant_input_defs) { NodeArg* input = MakeTestInput(builder, input_def); - QuantParams input_qparams = GetTestInputQuantParams(input_def); - NodeArg* input_after_qdq = AddQDQNodePair(builder, input, input_qparams.scale, - input_qparams.zero_point, use_contrib_qdq); + QuantParams input_qparams = GetTestInputQuantParams(input_def); + NodeArg* input_after_qdq = AddQDQNodePair(builder, input, input_qparams.scale, + input_qparams.zero_point, use_contrib_qdq); op_inputs.push_back(input_after_qdq); } + // Create non-QDQ inputs + for (const auto& input_def : non_quant_input_defs) { + NodeArg* input = MakeTestInput(builder, input_def); + op_inputs.push_back(input); + } + // Op -> op_output auto* op_output = builder.MakeIntermediate(); Node& onnx_node = builder.AddNode(op_type, op_inputs, {op_output}, op_domain); @@ -506,8 +522,8 @@ inline GetTestQDQModelFn BuildQDQOpTestCase(const std::string& op_ty } // op_output -> Q -> DQ -> output - AddQDQNodePairWithOutputAsGraphOutput(builder, op_output, output_qparams[0].scale, - output_qparams[0].zero_point, use_contrib_qdq); + AddQDQNodePairWithOutputAsGraphOutput(builder, op_output, output_qparams[0].scale, + output_qparams[0].zero_point, use_contrib_qdq); }; } diff --git a/onnxruntime/test/providers/qnn/reshape_op_test.cc b/onnxruntime/test/providers/qnn/reshape_op_test.cc new file mode 100644 index 0000000000000..eb495e44ec770 --- /dev/null +++ b/onnxruntime/test/providers/qnn/reshape_op_test.cc @@ -0,0 +1,225 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if !defined(ORT_MINIMAL_BUILD) + +#include + +#include "test/providers/qnn/qnn_test_utils.h" +#include "core/graph/node_attr_utils.h" + +#include "onnx/onnx_pb.h" +#include "gtest/gtest.h" + +namespace onnxruntime { +namespace test { + +// Runs a model with a Reshape operator on the QNN CPU backend. Checks the graph node assignment +// and that inference outputs for QNN EP and CPU EP match. +template +static void RunReshapeTestOnCPU(const TestInputDef& input_def, + const TestInputDef& shape_def, + const std::vector& attrs, + ExpectedEPNodeAssignment expected_ep_assignment, + int opset = 19) { + ProviderOptions provider_options; + +#if defined(_WIN32) + provider_options["backend_path"] = "QnnCpu.dll"; +#else + provider_options["backend_path"] = "libQnnCpu.so"; +#endif + + RunQnnModelTest(BuildOpTestCase("Reshape", {input_def}, {shape_def}, attrs), + provider_options, + opset, + expected_ep_assignment); +} + +// +// CPU tests: +// + +// Test that Reshape with a dynamic shape input is not supported by QNN EP. +TEST_F(QnnCPUBackendTests, Reshape_DynamicShape_Unsupported) { + RunReshapeTestOnCPU(TestInputDef({1, 3, 4, 4}, false, -10.0f, 10.0f), + TestInputDef({2}, false /* is_initializer */, {1, 48}), + {}, // Attributes + ExpectedEPNodeAssignment::None, // Should not be assigned to QNN EP. + 19); // Opset +} + +// Test that Reshape with an enabled 'allowzero' attribute is not supported by QNN EP. +TEST_F(QnnCPUBackendTests, Reshape_AllowZeroAttr_Unsupported) { + RunReshapeTestOnCPU(TestInputDef({1, 3, 4, 4}, false, -10.0f, 10.0f), + TestInputDef({2}, true, {1, 48}), + {utils::MakeAttribute("allowzero", static_cast(1))}, + ExpectedEPNodeAssignment::None, // Should not be assigned to QNN EP. + 19); // Opset +} + +// Test Reshape of rank 4 -> rank 2. +TEST_F(QnnCPUBackendTests, Reshape_4D_f32) { + RunReshapeTestOnCPU(TestInputDef({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)), + TestInputDef({2}, true, {1, 48}), + {}, // Attributes + ExpectedEPNodeAssignment::All, + 19); // Opset +} + +#if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) +// +// HTP tests: +// + +// Returns a function that creates a graph with a QDQ Reshape operator. +template +GetTestQDQModelFn BuildQDQReshapeTestCase(const TestInputDef& input_def, + const TestInputDef& shape_def, + const std::vector& attrs, + bool use_contrib_qdq = false) { + return [input_def, shape_def, attrs, + use_contrib_qdq](ModelTestBuilder& builder, + std::vector>& output_qparams) { + // input -> Q -> DQ -> + NodeArg* input = MakeTestInput(builder, input_def); + QuantParams input_qparams = GetTestInputQuantParams(input_def); + NodeArg* input_qdq = AddQDQNodePair(builder, input, input_qparams.scale, input_qparams.zero_point, + use_contrib_qdq); + + // shape input + NodeArg* shape_input = MakeTestInput(builder, shape_def); + + // Reshape op + NodeArg* reshape_output = builder.MakeIntermediate(); + Node& reshape_node = builder.AddNode("Reshape", {input_qdq, shape_input}, {reshape_output}); + + for (const auto& attr : attrs) { + reshape_node.AddAttributeProto(attr); + } + + // op_output -> Q -> DQ -> output + // NOTE: Input and output quantization parameters must be equal for Reshape. + output_qparams[0] = input_qparams; // Overwrite! + AddQDQNodePairWithOutputAsGraphOutput(builder, reshape_output, input_qparams.scale, + input_qparams.zero_point, use_contrib_qdq); + }; +} + +// Runs a model with a non-QDQ Reshape operator on the QNN HTP backend. Checks the graph node assignment +// and that inference outputs for QNN EP and CPU EP match. +template +static void RunReshapeTestOnHTP(const TestInputDef& input_def, + const TestInputDef& shape_def, + const std::vector& attrs, + ExpectedEPNodeAssignment expected_ep_assignment, + int opset = 19) { + ProviderOptions provider_options; + +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + + RunQnnModelTest(BuildOpTestCase("Reshape", {input_def}, {shape_def}, attrs), + provider_options, + opset, + expected_ep_assignment); +} + +// Runs a QDQ Reshape model on the QNN (HTP) EP and the ORT CPU EP. Checks the graph node assignment and that inference +// running the QDQ model on QNN EP is at least as accurate as on ORT CPU EP (compared to the baseline float32 model). +template +static void RunQDQReshapeTestOnHTP(const TestInputDef& input_def, + const TestInputDef& shape_def, + const std::vector& attrs, + ExpectedEPNodeAssignment expected_ep_assignment, + int opset = 19, + bool use_contrib_qdq = false) { + ProviderOptions provider_options; + +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + + auto f32_model_builder = BuildOpTestCase("Reshape", {input_def}, {shape_def}, attrs); + auto qdq_model_builder = BuildQDQReshapeTestCase(input_def, shape_def, attrs, use_contrib_qdq); + TestQDQModelAccuracy(f32_model_builder, + qdq_model_builder, + provider_options, + opset, + expected_ep_assignment); +} + +// Test that QDQ Reshape with a dynamic shape input is not supported by QNN EP. +TEST_F(QnnHTPBackendTests, Reshape_DynamicShape_Unsupported) { + RunQDQReshapeTestOnHTP(TestInputDef({1, 3, 4, 4}, false, -10.0f, 10.0f), + TestInputDef({2}, false /* is_initializer */, {1, 48}), + {}, // Attributes + ExpectedEPNodeAssignment::None, // Should not be assigned to QNN EP. + 19); // Opset +} + +// Test that QDQ Reshape with an enabled 'allowzero' attribute is not supported by QNN EP. +TEST_F(QnnHTPBackendTests, Reshape_AllowZeroAttr_Unsupported) { + RunQDQReshapeTestOnHTP(TestInputDef({1, 3, 4, 4}, false, -10.0f, 10.0f), + TestInputDef({2}, true, {1, 48}), + {utils::MakeAttribute("allowzero", static_cast(1))}, + ExpectedEPNodeAssignment::None, // Should not be assigned to QNN EP. + 19); // Opset +} + +// Test 8-bit QDQ Reshape of rank 4 -> rank 2. +TEST_F(QnnHTPBackendTests, Reshape_4D_u8) { + RunQDQReshapeTestOnHTP(TestInputDef({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)), + TestInputDef({2}, true, {1, 48}), + {}, // Attributes + ExpectedEPNodeAssignment::All, + 19); // Opset +} + +// Test 16-bit QDQ Reshape of rank 4 -> rank 2. +TEST_F(QnnHTPBackendTests, Reshape_4D_u16) { + RunQDQReshapeTestOnHTP(TestInputDef({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)), + TestInputDef({2}, true, {1, 48}), + {}, // Attributes + ExpectedEPNodeAssignment::All, + 19, // Opset + true); // Use com.microsoft Q/DQ ops +} + +// Test that int32 Reshape runs on HTP backend. +TEST_F(QnnHTPBackendTests, Reshape_4D_int32) { + std::vector input_data = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}; + RunReshapeTestOnHTP(TestInputDef({1, 3, 2, 2}, false, input_data), + TestInputDef({3}, true, {1, 1, 12}), + {}, // Attributes + ExpectedEPNodeAssignment::All, + 19); // Opset +} + +// Test QDQ Reshape with a shape value of 0 (copy dimension from input) +TEST_F(QnnHTPBackendTests, Reshape_4D_0MeansCopy) { + RunQDQReshapeTestOnHTP(TestInputDef({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)), + TestInputDef({3}, true, {1, 0, 16}), // zero means copy => '(1, 3, 16)' + {}, // Attributes + ExpectedEPNodeAssignment::All, + 19); // Opset +} + +// Test QDQ Reshape with a shape value of -1 (dimension is inferred from the expected number of elements) +TEST_F(QnnHTPBackendTests, Reshape_4D_Neg1MeansInfer) { + RunQDQReshapeTestOnHTP(TestInputDef({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)), + TestInputDef({3}, true, {1, 3, -1}), // -1 means infer => '(1, 3, 16)' + {}, // Attributes + ExpectedEPNodeAssignment::All, + 19); // Opset +} + +#endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) +} // namespace test +} // namespace onnxruntime +#endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/test/providers/qnn/simple_op_htp_test.cc b/onnxruntime/test/providers/qnn/simple_op_htp_test.cc index 63498982930f5..f77c098f72116 100644 --- a/onnxruntime/test/providers/qnn/simple_op_htp_test.cc +++ b/onnxruntime/test/providers/qnn/simple_op_htp_test.cc @@ -32,7 +32,7 @@ static void RunOpTestOnCPU(const std::string& op_type, provider_options["backend_path"] = "libQnnCpu.so"; #endif - RunQnnModelTest(BuildOpTestCase(op_type, input_defs, attrs, op_domain), + RunQnnModelTest(BuildOpTestCase(op_type, input_defs, {}, attrs, op_domain), provider_options, opset_version, expected_ep_assignment); @@ -113,8 +113,8 @@ static void RunQDQOpTest(const std::string& op_type, provider_options["backend_path"] = "libQnnHtp.so"; #endif - TestQDQModelAccuracy(BuildOpTestCase(op_type, input_defs, attrs, op_domain), - BuildQDQOpTestCase(op_type, input_defs, attrs, op_domain, use_contrib_qdq), + TestQDQModelAccuracy(BuildOpTestCase(op_type, input_defs, {}, attrs, op_domain), + BuildQDQOpTestCase(op_type, input_defs, {}, attrs, op_domain, use_contrib_qdq), provider_options, opset_version, expected_ep_assignment, @@ -137,7 +137,7 @@ static void RunOpTest(const std::string& op_type, #endif // Runs model with a Q/DQ binary op and compares the outputs of the CPU and QNN EPs. - RunQnnModelTest(BuildOpTestCase(op_type, input_defs, attrs, op_domain), + RunQnnModelTest(BuildOpTestCase(op_type, input_defs, {}, attrs, op_domain), provider_options, opset_version, expected_ep_assignment); @@ -698,8 +698,8 @@ TEST_F(QnnHTPBackendTests, ContextBinaryCacheTest) { // Runs model with DQ-> Atan-> Q and compares the outputs of the CPU and QNN EPs. // 1st run will generate the Qnn context cache binary file - TestQDQModelAccuracy(BuildOpTestCase(op_type, {input_def}, {}), - BuildQDQOpTestCase(op_type, {input_def}, {}), + TestQDQModelAccuracy(BuildOpTestCase(op_type, {input_def}, {}, {}), + BuildQDQOpTestCase(op_type, {input_def}, {}, {}), provider_options, 14, ExpectedEPNodeAssignment::All); @@ -708,8 +708,8 @@ TEST_F(QnnHTPBackendTests, ContextBinaryCacheTest) { EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str())); // 2nd run will load and run from Qnn context cache binary file - TestQDQModelAccuracy(BuildOpTestCase(op_type, {input_def}, {}), - BuildQDQOpTestCase(op_type, {input_def}, {}), + TestQDQModelAccuracy(BuildOpTestCase(op_type, {input_def}, {}, {}), + BuildQDQOpTestCase(op_type, {input_def}, {}, {}), provider_options, 14, ExpectedEPNodeAssignment::All); diff --git a/onnxruntime/test/providers/qnn/slice_htp_test.cc b/onnxruntime/test/providers/qnn/slice_htp_test.cc index f7163f04736a5..edc079dc65276 100644 --- a/onnxruntime/test/providers/qnn/slice_htp_test.cc +++ b/onnxruntime/test/providers/qnn/slice_htp_test.cc @@ -16,51 +16,6 @@ namespace onnxruntime { namespace test { #if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) -// Function that builds a model with a Slice operator. -template -GetTestModelFn BuildSliceTestCase(const TestInputDef& data_def, - const TestInputDef& starts_def, - const TestInputDef& ends_def, - const TestInputDef& axes_def, - const TestInputDef& steps_def) { - return [data_def, starts_def, ends_def, axes_def, steps_def](ModelTestBuilder& builder) { - NodeArg* data = MakeTestInput(builder, data_def); - NodeArg* starts = MakeTestInput(builder, starts_def); - NodeArg* ends = MakeTestInput(builder, ends_def); - NodeArg* axes = MakeTestInput(builder, axes_def); - NodeArg* steps = MakeTestInput(builder, steps_def); - - NodeArg* output = builder.MakeOutput(); - builder.AddNode("Slice", {data, starts, ends, axes, steps}, {output}); - }; -} - -// Function that builds a QDQ model with a Slice operator. -template -static GetTestQDQModelFn BuildQDQSliceTestCase(const TestInputDef& data_def, - const TestInputDef& starts_def, - const TestInputDef& ends_def, - const TestInputDef& axes_def, - const TestInputDef& steps_def) { - return [data_def, starts_def, ends_def, axes_def, steps_def](ModelTestBuilder& builder, - std::vector>& output_qparams) { - NodeArg* data = MakeTestInput(builder, data_def); - QuantParams data_qparams = GetTestInputQuantParams(data_def); - NodeArg* data_qdq = AddQDQNodePair(builder, data, data_qparams.scale, data_qparams.zero_point); - - NodeArg* starts = MakeTestInput(builder, starts_def); - NodeArg* ends = MakeTestInput(builder, ends_def); - NodeArg* axes = MakeTestInput(builder, axes_def); - NodeArg* steps = MakeTestInput(builder, steps_def); - - auto* slice_output = builder.MakeIntermediate(); - builder.AddNode("Slice", {data_qdq, starts, ends, axes, steps}, {slice_output}); - - // Add output -> Q -> output_u8 - AddQDQNodePairWithOutputAsGraphOutput(builder, slice_output, output_qparams[0].scale, output_qparams[0].zero_point); - }; -} - /** * Runs an Slice model on the QNN HTP backend. Checks the graph node assignment, and that inference * outputs for QNN and CPU match. @@ -86,13 +41,14 @@ static void RunSliceQDQTest(const TestInputDef& data_def, provider_options["backend_path"] = "libQnnHtp.so"; #endif - // Runs model with DQ-> Slice -> Q and compares the outputs of the CPU and QNN EPs. - TestQDQModelAccuracy(BuildSliceTestCase(data_def, starts_def, ends_def, axes_def, steps_def), - BuildQDQSliceTestCase(data_def, starts_def, ends_def, axes_def, steps_def), + const std::vector> f32_inputs = {data_def}; + const std::vector> int64_inputs = {starts_def, ends_def, axes_def, steps_def}; + + TestQDQModelAccuracy(BuildOpTestCase("Slice", f32_inputs, int64_inputs, {}), + BuildQDQOpTestCase("Slice", f32_inputs, int64_inputs, {}), provider_options, 18, - expected_ep_assignment, - 1e-5f); + expected_ep_assignment); } /** @@ -119,12 +75,12 @@ static void RunSliceNonQDQOnHTP(const TestInputDef& data_def, #else provider_options["backend_path"] = "libQnnHtp.so"; #endif - - RunQnnModelTest(BuildSliceTestCase(data_def, starts_def, ends_def, axes_def, steps_def), + auto f32_model_builder = BuildOpTestCase("Slice", {data_def}, + {starts_def, ends_def, axes_def, steps_def}, {}); + RunQnnModelTest(f32_model_builder, provider_options, 13, - expected_ep_assignment, - 1e-5f); + expected_ep_assignment); } // Check that QNN compiles DQ -> Slice -> Q as a single unit. diff --git a/onnxruntime/test/providers/qnn/split_op_test.cc b/onnxruntime/test/providers/qnn/split_op_test.cc new file mode 100644 index 0000000000000..57e4b211777bb --- /dev/null +++ b/onnxruntime/test/providers/qnn/split_op_test.cc @@ -0,0 +1,387 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if !defined(ORT_MINIMAL_BUILD) + +#include + +#include "test/providers/qnn/qnn_test_utils.h" + +#include "onnx/onnx_pb.h" +#include "gtest/gtest.h" + +namespace onnxruntime { +namespace test { + +template +GetTestModelFn BuildSplitTestCase(const TestInputDef& input_def, + const std::vector& split, bool split_is_input, + int64_t axis, int64_t num_outputs) { + return [input_def, split, split_is_input, axis, num_outputs](ModelTestBuilder& builder) { + std::vector op_inputs; + + op_inputs.push_back(MakeTestInput(builder, input_def)); + + if (split_is_input && !split.empty()) { + op_inputs.push_back(builder.Make1DInitializer(split)); + } + + // Determine the actual number of outputs from the 'split' or 'num_outputs' arguments. + // In opset 18, the num_outputs attribute or the split input can determine the actual number of outputs. + // In opset 13, the split input determines the number of actual outputs. + // In opsets < 13, the split attribute determines the number of actual outputs. + size_t actual_num_outputs = (num_outputs > -1) ? static_cast(num_outputs) : split.size(); + + std::vector split_outputs; + for (size_t i = 0; i < actual_num_outputs; i++) { + split_outputs.push_back(builder.MakeOutput()); + } + + Node& split_node = builder.AddNode("Split", op_inputs, split_outputs); + + if (!split_is_input && !split.empty()) { + split_node.AddAttribute("split", split); + } + + if (num_outputs > -1) { + split_node.AddAttribute("num_outputs", num_outputs); + } + + split_node.AddAttribute("axis", axis); + }; +} + +template +static void RunSplitOpTestOnCPU(const TestInputDef& input_def, + const std::vector& split, + int64_t axis, + int64_t num_outputs, + int opset, + ExpectedEPNodeAssignment expected_ep_assignment) { + ProviderOptions provider_options; + +#if defined(_WIN32) + provider_options["backend_path"] = "QnnCpu.dll"; +#else + provider_options["backend_path"] = "libQnnCpu.so"; +#endif + + const bool split_is_input = opset >= 13; + RunQnnModelTest(BuildSplitTestCase(input_def, split, split_is_input, axis, num_outputs), + provider_options, + opset, + expected_ep_assignment); +} + +// +// CPU tests: +// + +// Test Split opset 18 on CPU backend: equal split of axis 0 via 'num_outputs' attribute +// and 'split' input. +TEST_F(QnnCPUBackendTests, Split_Equal_Axis0_Opset18) { + // Use 'split' input (initializer). + RunSplitOpTestOnCPU(TestInputDef({4, 2}, false, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.f, 8.f}), + {2, 2}, // split + 0, // axis + -1, // num_outputs + 18, // opset + ExpectedEPNodeAssignment::All); + RunSplitOpTestOnCPU(TestInputDef({4, 2}, false, {1, 2, 3, 4, 5, 6, 7, 8}), + {2, 2}, // split + 0, // axis + -1, // num_outputs + 18, // opset + ExpectedEPNodeAssignment::All); + + // Use 'num_outputs' attribute. + RunSplitOpTestOnCPU(TestInputDef({4, 2}, false, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.f, 8.f}), + {}, // split (use num_outputs instead) + 0, // axis + 2, // num_outputs + 18, // opset + ExpectedEPNodeAssignment::All); + RunSplitOpTestOnCPU(TestInputDef({4, 2}, false, {1, 2, 3, 4, 5, 6, 7, 8}), + {}, // split (use num_outputs instead) + 0, // axis + 2, // num_outputs + 18, // opset + ExpectedEPNodeAssignment::All); +} + +// Test Split opset 13 on CPU backend: equal split of axis 0 +TEST_F(QnnCPUBackendTests, Split_Equal_Axis0_Opset13) { + RunSplitOpTestOnCPU(TestInputDef({4, 2}, false, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.f, 8.f}), + {2, 2}, // split + 0, // axis + -1, // num_outputs (not in opset 13) + 13, // opset + ExpectedEPNodeAssignment::All); + RunSplitOpTestOnCPU(TestInputDef({4, 2}, false, {1, 2, 3, 4, 5, 6, 7, 8}), + {2, 2}, // split + 0, // axis + -1, // num_outputs (not in opset 13) + 13, // opset + ExpectedEPNodeAssignment::All); +} + +// Test Split opset 11 on CPU backend: equal split of axis 0 +TEST_F(QnnCPUBackendTests, Split_Equal_Axis0_Opset11) { + RunSplitOpTestOnCPU(TestInputDef({4, 2}, false, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.f, 8.f}), + {2, 2}, // split + 0, // axis + -1, // num_outputs (not in opset 11) + 11, // opset + ExpectedEPNodeAssignment::All); + RunSplitOpTestOnCPU(TestInputDef({4, 2}, false, {1, 2, 3, 4, 5, 6, 7, 8}), + {2, 2}, // split + 0, // axis + -1, // num_outputs (not in opset 11) + 11, // opset + ExpectedEPNodeAssignment::All); +} + +// Test Split opset 13 on CPU backend: unequal split of axis 1 +TEST_F(QnnCPUBackendTests, Split_Unequal_Axis1_Opset13) { + RunSplitOpTestOnCPU(TestInputDef({2, 4}, false, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.f, 8.f}), + {1, 3}, // split + 1, // axis + -1, // num_outputs (not in opset 13) + 13, // opset + ExpectedEPNodeAssignment::All); + RunSplitOpTestOnCPU(TestInputDef({2, 4}, false, {1, 2, 3, 4, 5, 6, 7, 8}), + {1, 3}, // split + 1, // axis + -1, // num_outputs (not in opset 13) + 13, // opset + ExpectedEPNodeAssignment::All); +} + +// Test Split opset 11 on CPU backend: unequal split of axis 1 +TEST_F(QnnCPUBackendTests, Split_Unequal_Axis1_Opset11) { + RunSplitOpTestOnCPU(TestInputDef({2, 4}, false, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.f, 8.f}), + {1, 3}, // split + 1, // axis + -1, // num_outputs (not in opset 11) + 11, // opset + ExpectedEPNodeAssignment::All); + RunSplitOpTestOnCPU(TestInputDef({2, 4}, false, {1, 2, 3, 4, 5, 6, 7, 8}), + {1, 3}, // split + 1, // axis + -1, // num_outputs (not in opset 11) + 11, // opset + ExpectedEPNodeAssignment::All); +} + +#if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) +// +// HTP tests: +// + +// Return function that builds a model with a QDQ Split. +template +GetTestQDQModelFn BuildQDQSplitTestCase(const TestInputDef& input_def, + const std::vector& split, + bool split_is_input, + int64_t axis, + int64_t num_outputs, + bool use_contrib_qdq = false) { + return [input_def, split, split_is_input, axis, num_outputs, + use_contrib_qdq](ModelTestBuilder& builder, + std::vector>& output_qparams) { + std::vector op_inputs; + + // Add QDQ input + NodeArg* input = MakeTestInput(builder, input_def); + QuantParams input_qparams = GetTestInputQuantParams(input_def); + NodeArg* input_after_qdq = AddQDQNodePair(builder, input, input_qparams.scale, + input_qparams.zero_point, use_contrib_qdq); + op_inputs.push_back(input_after_qdq); + + // Add split input + if (split_is_input && !split.empty()) { + op_inputs.push_back(builder.Make1DInitializer(split)); + } + + // Determine the actual number of outputs from the 'split' or 'num_outputs' arguments. + // In opset 18, the num_outputs attribute or the split input can determine the actual number of outputs. + // In opset 13, the split input determines the number of actual outputs. + // In opsets < 13, the split attribute determines the number of actual outputs. + size_t actual_num_outputs = (num_outputs > -1) ? static_cast(num_outputs) : split.size(); + + std::vector split_outputs; + for (size_t i = 0; i < actual_num_outputs; i++) { + split_outputs.push_back(builder.MakeIntermediate()); + } + + Node& split_node = builder.AddNode("Split", op_inputs, split_outputs); + + if (!split_is_input && !split.empty()) { + split_node.AddAttribute("split", split); + } + + if (num_outputs > -1) { + split_node.AddAttribute("num_outputs", num_outputs); + } + + split_node.AddAttribute("axis", axis); + + // op_output -> Q -> DQ -> output + assert(output_qparams.size() == actual_num_outputs); + for (size_t i = 0; i < actual_num_outputs; i++) { + // NOTE: Input and output quantization parameters must be equal for Split. + output_qparams[i] = input_qparams; + AddQDQNodePairWithOutputAsGraphOutput(builder, split_outputs[i], output_qparams[i].scale, + output_qparams[i].zero_point, use_contrib_qdq); + } + }; +} + +// Runs a non-QDQ Split operator on the HTP backend. +template +static void RunSplitOpTestOnHTP(const TestInputDef& input_def, + const std::vector& split, + int64_t axis, + int64_t num_outputs, + int opset, + ExpectedEPNodeAssignment expected_ep_assignment) { + ProviderOptions provider_options; + +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + + const bool split_is_input = opset >= 13; + RunQnnModelTest(BuildSplitTestCase(input_def, split, split_is_input, axis, num_outputs), + provider_options, + opset, + expected_ep_assignment); +} + +// Runs a QDQ Split operator on the HTP backend. +template +static void RunQDQSplitOpTestOnHTP(const TestInputDef& input_def, + const std::vector& split, + int64_t axis, + int64_t num_outputs, + int opset, + ExpectedEPNodeAssignment expected_ep_assignment, + bool use_contrib_qdq = false) { + ProviderOptions provider_options; + +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + + const bool split_is_input = opset >= 13; + auto f32_model_builder = BuildSplitTestCase(input_def, split, split_is_input, axis, num_outputs); + auto qdq_model_builder = BuildQDQSplitTestCase(input_def, split, split_is_input, axis, num_outputs, + use_contrib_qdq); + TestQDQModelAccuracy(f32_model_builder, + qdq_model_builder, + provider_options, + opset, + expected_ep_assignment); +} + +// Test that HTP can run non-QDQ Split (int32 input). +TEST_F(QnnHTPBackendTests, Split_Int32_Opset13) { + // Equal split. + RunSplitOpTestOnHTP(TestInputDef({4, 2}, false, {1, 2, 3, 4, 5, 6, 7, 8}), + {2, 2}, // split + 0, // axis + -1, // num_outputs (not in opset 13) + 13, // opset + ExpectedEPNodeAssignment::All); +} + +// Test 8-bit QDQ Split opset 18 on HTP backend: equal split of axis 0 via 'num_outputs' attribute +// and 'split' input. +TEST_F(QnnHTPBackendTests, Split_Equal_Axis0_Opset18) { + // Use 'split' input (initializer). + RunQDQSplitOpTestOnHTP(TestInputDef({4, 2}, false, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.f, 8.f}), + {2, 2}, // split + 0, // axis + -1, // num_outputs + 18, // opset + ExpectedEPNodeAssignment::All); + + // Use 'num_outputs' attribute. + RunQDQSplitOpTestOnHTP(TestInputDef({4, 2}, false, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.f, 8.f}), + {}, // split (use num_outputs instead) + 0, // axis + 2, // num_outputs + 18, // opset + ExpectedEPNodeAssignment::All); +} + +// Test 16-bit QDQ Split opset 18 on HTP backend: equal split of axis 0 via 'num_outputs' attribute +// and 'split' input. +TEST_F(QnnHTPBackendTests, Split_Equal_Axis0_Opset18_U16) { + // Use 'split' input (initializer). + RunQDQSplitOpTestOnHTP(TestInputDef({4, 2}, false, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.f, 8.f}), + {2, 2}, // split + 0, // axis + -1, // num_outputs + 18, // opset + ExpectedEPNodeAssignment::All, + true); // Use com.microsoft Q/DQ ops + + // Use 'num_outputs' attribute. + RunQDQSplitOpTestOnHTP(TestInputDef({4, 2}, false, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.f, 8.f}), + {}, // split (use num_outputs instead) + 0, // axis + 2, // num_outputs + 18, // opset + ExpectedEPNodeAssignment::All, + true); // Use com.microsoft Q/DQ ops +} + +// Test QDQ Split op on HTP backend: equal split on axis 0 with opset 13. +TEST_F(QnnHTPBackendTests, Split_Equal_Axis0_Opset13) { + RunQDQSplitOpTestOnHTP(TestInputDef({4, 2}, false, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.f, 8.f}), + {2, 2}, // split + 0, // axis + -1, // num_outputs (not in opset 13) + 13, // opset + ExpectedEPNodeAssignment::All); +} + +// Test QDQ Split op on HTP backend: equal split on axis 0 with opset 11. +TEST_F(QnnHTPBackendTests, Split_Equal_Axis0_Opset11) { + RunQDQSplitOpTestOnHTP(TestInputDef({4, 2}, false, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.f, 8.f}), + {2, 2}, // split + 0, // axis + -1, // num_outputs (not in opset 11) + 11, // opset + ExpectedEPNodeAssignment::All); +} + +// Test Split opset 13 on HTP backend: unequal split of axis 1 +TEST_F(QnnHTPBackendTests, Split_Unequal_Axis1_Opset13) { + RunQDQSplitOpTestOnHTP(TestInputDef({2, 4}, false, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.f, 8.f}), + {1, 3}, // split + 1, // axis + -1, // num_outputs (not in opset 13) + 13, // opset + ExpectedEPNodeAssignment::All); +} + +// Test Split opset 11 on HTP backend: unequal split of axis 1 +TEST_F(QnnHTPBackendTests, Split_Unequal_Axis1_Opset11) { + RunQDQSplitOpTestOnHTP(TestInputDef({2, 4}, false, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.f, 8.f}), + {1, 3}, // split + 1, // axis + -1, // num_outputs (not in opset 11) + 11, // opset + ExpectedEPNodeAssignment::All); +} + +#endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) +} // namespace test +} // namespace onnxruntime +#endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/test/providers/qnn/squeeze_unsqueeze_op_test.cc b/onnxruntime/test/providers/qnn/squeeze_unsqueeze_op_test.cc new file mode 100644 index 0000000000000..33d2f64c0315e --- /dev/null +++ b/onnxruntime/test/providers/qnn/squeeze_unsqueeze_op_test.cc @@ -0,0 +1,324 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if !defined(ORT_MINIMAL_BUILD) + +#include + +#include "test/providers/qnn/qnn_test_utils.h" + +#include "onnx/onnx_pb.h" +#include "gtest/gtest.h" + +namespace onnxruntime { +namespace test { + +// Runs a model with a Squeeze (or Unsqueeze) operator on the QNN CPU backend. Checks the graph node assignment +// and that inference outputs for QNN EP and CPU EP match. +template +static void RunSqueezeTestOnCPU(const std::string& op_type, // Squeeze or Unsqueeze + const TestInputDef& input_def, + const TestInputDef& axes_def, + ExpectedEPNodeAssignment expected_ep_assignment, + int opset = 13) { + ProviderOptions provider_options; + +#if defined(_WIN32) + provider_options["backend_path"] = "QnnCpu.dll"; +#else + provider_options["backend_path"] = "libQnnCpu.so"; +#endif + + RunQnnModelTest(BuildOpTestCase(op_type, {input_def}, {axes_def}, {}), + provider_options, + opset, + expected_ep_assignment); +} + +// +// CPU tests: +// + +// Test that Squeeze with a dynamic axes input is not supported by QNN EP. +TEST_F(QnnCPUBackendTests, Squeeze_DynamicAxes_Unsupported) { + RunSqueezeTestOnCPU("Squeeze", + TestInputDef({1, 3, 4, 4}, false, -10.0f, 10.0f), + TestInputDef({1}, false /* is_initializer */, {0}), + ExpectedEPNodeAssignment::None); // Should not be assigned to QNN EP. +} + +// Test that Unsqueeze with a dynamic axes input is not supported by QNN EP. +TEST_F(QnnCPUBackendTests, Unsqueeze_DynamicAxes_Unsupported) { + RunSqueezeTestOnCPU("Unsqueeze", + TestInputDef({1, 3, 4, 4}, false, -10.0f, 10.0f), + TestInputDef({1}, false /* is_initializer */, {0}), + ExpectedEPNodeAssignment::None); // Should not be assigned to QNN EP. +} + +// Test Squeeze of rank 5 -> rank 2. +TEST_F(QnnCPUBackendTests, Squeeze_Rank5_Rank2_f32) { + RunSqueezeTestOnCPU("Squeeze", + TestInputDef({1, 3, 1, 2, 4}, false, -10.0f, 10.0f), + TestInputDef({2}, true, {0, 2}), // Squeeze axes 0 and 2 => (3, 2, 4) + ExpectedEPNodeAssignment::All); +} + +// Test Squeeze of rank 4 -> rank 3 with a negative axes value. +TEST_F(QnnCPUBackendTests, Squeeze_Rank4_Rank3_NegAxes_f32) { + RunSqueezeTestOnCPU("Squeeze", + TestInputDef({1, 3, 2, 1}, false, -10.0f, 10.0f), + TestInputDef({1}, true, {-1}), // Squeeze last axis => (1, 3, 2) + ExpectedEPNodeAssignment::All); +} + +// Test Unsqueeze of rank 3 -> rank 5. +TEST_F(QnnCPUBackendTests, Unsqueeze_Rank3_Rank5_f32) { + RunSqueezeTestOnCPU("Unsqueeze", + TestInputDef({3, 2, 4}, false, -10.0f, 10.0f), + TestInputDef({2}, true, {0, 2}), // Add 1's => (1, 3, 1, 2, 4) + ExpectedEPNodeAssignment::All); +} + +// Test Unsqueeze of rank 3 -> rank 4 with a negative axes value. +TEST_F(QnnCPUBackendTests, Unsqueeze_Rank3_Rank4_NegAxes_f32) { + RunSqueezeTestOnCPU("Unsqueeze", + TestInputDef({1, 3, 2}, false, -10.0f, 10.0f), + TestInputDef({1}, true, {-1}), // Add 1 as last axis => (1, 3, 2, 1) + ExpectedEPNodeAssignment::All); +} + +#if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) +// +// HTP tests: +// + +// Returns a function that creates a graph with a QDQ (Un)Squeeze operator. +template +GetTestQDQModelFn BuildQDQSqueezeTestCase(const std::string& op_type, // Squeeze or Unsqueeze + const TestInputDef& input_def, + const TestInputDef& axes_def, + bool use_contrib_qdq = false) { + return [op_type, input_def, axes_def, + use_contrib_qdq](ModelTestBuilder& builder, + std::vector>& output_qparams) { + // input -> Q -> DQ -> + NodeArg* input = MakeTestInput(builder, input_def); + QuantParams input_qparams = GetTestInputQuantParams(input_def); + NodeArg* input_qdq = AddQDQNodePair(builder, input, input_qparams.scale, input_qparams.zero_point, + use_contrib_qdq); + + // axes input + NodeArg* axes_input = MakeTestInput(builder, axes_def); + + // (Un)Squeeze op + NodeArg* op_output = builder.MakeIntermediate(); + builder.AddNode(op_type, {input_qdq, axes_input}, {op_output}); + + // op_output -> Q -> DQ -> output + // NOTE: Input and output quantization parameters must be equal for (Un)Squeeze. + output_qparams[0] = input_qparams; // Overwrite! + AddQDQNodePairWithOutputAsGraphOutput(builder, op_output, input_qparams.scale, + input_qparams.zero_point, use_contrib_qdq); + }; +} + +// Runs a model with a non-QDQ (Un)Squeeze operator on the QNN HTP backend. Checks the graph node assignment +// and that inference outputs for QNN EP and CPU EP match. +template +static void RunSqueezeTestOnHTP(const std::string& op_type, // Squeeze or Unsqueeze + const TestInputDef& input_def, + const TestInputDef& axes_def, + ExpectedEPNodeAssignment expected_ep_assignment, + int opset = 13) { + ProviderOptions provider_options; + +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + + RunQnnModelTest(BuildOpTestCase(op_type, {input_def}, {axes_def}, {}), + provider_options, + opset, + expected_ep_assignment); +} + +// Runs a QDQ (Un)Squeeze model on the QNN (HTP) EP and the ORT CPU EP. Checks the graph node assignment and +// that inference running the QDQ model on QNN EP is at least as accurate as on ORT CPU EP +// (when compared to the baseline float32 model). +template +static void RunQDQSqueezeTestOnHTP(const std::string& op_type, + const TestInputDef& input_def, + const TestInputDef& axes_def, + ExpectedEPNodeAssignment expected_ep_assignment, + int opset = 13, + bool use_contrib_qdq = false) { + ProviderOptions provider_options; + +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + + auto f32_model_builder = BuildOpTestCase(op_type, {input_def}, {axes_def}, {}); + auto qdq_model_builder = BuildQDQSqueezeTestCase(op_type, input_def, axes_def, use_contrib_qdq); + + TestQDQModelAccuracy(f32_model_builder, + qdq_model_builder, + provider_options, + opset, + expected_ep_assignment); +} + +// Test that QDQ Squeeze with a dynamic axes input is not supported by QNN EP. +TEST_F(QnnHTPBackendTests, Squeeze_DynamicAxes_Unsupported) { + RunQDQSqueezeTestOnHTP("Squeeze", + TestInputDef({1, 3, 4, 4}, false, -10.0f, 10.0f), + TestInputDef({1}, false /* is_initializer */, {0}), + ExpectedEPNodeAssignment::None); // Should not be assigned to QNN EP. +} + +// Test that Unsqueeze with a dynamic axes input is not supported by QNN EP. +TEST_F(QnnHTPBackendTests, Unsqueeze_DynamicAxes_Unsupported) { + RunQDQSqueezeTestOnHTP("Unsqueeze", + TestInputDef({1, 3, 4, 4}, false, -10.0f, 10.0f), + TestInputDef({1}, false /* is_initializer */, {0}), + ExpectedEPNodeAssignment::None); // Should not be assigned to QNN EP. +} + +// Test Squeeze of rank 5 -> rank 2. +TEST_F(QnnHTPBackendTests, Squeeze_Rank5_Rank2_f32) { + // We can't use the usual model-building functions because they add standalone Quantize and Dequantize nodes + // at the input and output. These Q/DQ ops get lowered to QNN's Quantize and Dequantize operators, which DO NOT + // support rank 5 tensors. Therefore, we have to create a test model that only instantiates the DQ -> Squeeze -> Q + // QDQ node group, which gets lowered to a single QNN Reshape node. + GetTestModelFn model_fn = [](ModelTestBuilder& builder) { + // input (u8) -> DQ -> + NodeArg* quant_input = builder.MakeInput({1, 3, 1, 2, 4}, 0, 255); + NodeArg* input_dq = builder.MakeIntermediate(); + builder.AddDequantizeLinearNode(quant_input, 1.0f, 0, input_dq); // scale = 1.0, zp = 0 + + // axes_input -> + NodeArg* axes_input = builder.Make1DInitializer({0, 2}); // Squeeze axes 0 and 2 => (3, 2, 4) + + // Squeeze -> + NodeArg* squeeze_output = builder.MakeIntermediate(); + builder.AddNode("Squeeze", {input_dq, axes_input}, {squeeze_output}); + + // Q -> output (u8) + NodeArg* output = builder.MakeOutput(); + builder.AddQuantizeLinearNode(squeeze_output, 1.0f, 0, output); // scale = 1.0, zp = 0 + }; + + ProviderOptions provider_options; + +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + + RunQnnModelTest(model_fn, + provider_options, + 13, // opset + ExpectedEPNodeAssignment::All); +} + +// Test 8-bit QDQ Squeeze of rank 4 -> rank 3 with a negative axes value. +TEST_F(QnnHTPBackendTests, Squeeze_Rank4_Rank3_NegAxes_u8) { + RunQDQSqueezeTestOnHTP("Squeeze", + TestInputDef({1, 3, 2, 1}, false, -10.0f, 10.0f), + TestInputDef({1}, true, {-1}), // Squeeze last axis => (1, 3, 2) + ExpectedEPNodeAssignment::All); +} + +// Test 16-bit QDQ Squeeze of rank 4 -> rank 3 with a negative axes value. +TEST_F(QnnHTPBackendTests, Squeeze_Rank4_Rank3_NegAxes_u16) { + RunQDQSqueezeTestOnHTP("Squeeze", + TestInputDef({1, 3, 2, 1}, false, -10.0f, 10.0f), + TestInputDef({1}, true, {-1}), // Squeeze last axis => (1, 3, 2) + ExpectedEPNodeAssignment::All, + 13, // opset + true); // Use com.microsoft Q/DQ ops +} + +// Test QDQ Unsqueeze of rank 3 -> rank 5. +TEST_F(QnnHTPBackendTests, Unsqueeze_Rank3_Rank5_f32) { + // We can't use the usual model-building functions because they add standalone Quantize and Dequantize nodes + // at the input and output. These Q/DQ ops get lowered to QNN's Quantize and Dequantize operators, which DO NOT + // support rank 5 tensors. Therefore, we have to create a test model that only instantiates the DQ -> Unsqueeze -> Q + // QDQ node group, which gets lowered to a single QNN Reshape node. + GetTestModelFn model_fn = [](ModelTestBuilder& builder) { + // input (u8) -> DQ -> + NodeArg* quant_input = builder.MakeInput({3, 2, 4}, 0, 255); + NodeArg* input_dq = builder.MakeIntermediate(); + builder.AddDequantizeLinearNode(quant_input, 1.0f, 0, input_dq); // scale = 1.0, zp = 0 + + // axes_input -> + NodeArg* axes_input = builder.Make1DInitializer({0, 2}); // Add 1's => (1, 3, 1, 2, 4) + + // Unsqueeze -> + NodeArg* unsqueeze_output = builder.MakeIntermediate(); + builder.AddNode("Unsqueeze", {input_dq, axes_input}, {unsqueeze_output}); + + // Q -> output (u8) + NodeArg* output = builder.MakeOutput(); + builder.AddQuantizeLinearNode(unsqueeze_output, 1.0f, 0, output); // scale = 1.0, zp = 0 + }; + + ProviderOptions provider_options; + +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + + RunQnnModelTest(model_fn, + provider_options, + 13, // opset + ExpectedEPNodeAssignment::All); +} + +// Test 8-bit QDQ Unsqueeze of rank 3 -> rank 4 with a negative axes value. +TEST_F(QnnHTPBackendTests, Unsqueeze_Rank3_Rank4_NegAxes_u8) { + RunQDQSqueezeTestOnHTP("Unsqueeze", + TestInputDef({1, 3, 2}, false, -10.0f, 10.0f), + TestInputDef({1}, true, {-1}), // Add 1 as last axis => (1, 3, 2, 1) + ExpectedEPNodeAssignment::All); +} + +// Test 16-bit QDQ Unsqueeze of rank 3 -> rank 4 with a negative axes value. +TEST_F(QnnHTPBackendTests, Unsqueeze_Rank3_Rank4_NegAxes_u16) { + RunQDQSqueezeTestOnHTP("Unsqueeze", + TestInputDef({1, 3, 2}, false, -10.0f, 10.0f), + TestInputDef({1}, true, {-1}), // Add 1 as last axis => (1, 3, 2, 1) + ExpectedEPNodeAssignment::All, + 13, // opset + true); // Use com.microsoft Q/DQ ops +} + +// Test that int32 Squeeze runs on HTP backend. +TEST_F(QnnHTPBackendTests, Squeeze_Int32_Rank4_Rank3) { + std::vector input_data = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}; + RunSqueezeTestOnHTP("Squeeze", + TestInputDef({1, 3, 2, 2}, false, input_data), + TestInputDef({1}, true, {0}), // Squeeze 0th axis => (3, 2, 2) + ExpectedEPNodeAssignment::All); +} + +// Test that int32 Unsqueeze runs on HTP backend. +TEST_F(QnnHTPBackendTests, Unsqueeze_Int32_Rank3_Rank4) { + std::vector input_data = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}; + RunSqueezeTestOnHTP("Unsqueeze", + TestInputDef({3, 2, 2}, false, input_data), + TestInputDef({1}, true, {0}), // Unsqueeze 0th axis => (1, 3, 2, 2) + ExpectedEPNodeAssignment::All); +} + +#endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) +} // namespace test +} // namespace onnxruntime +#endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/test/providers/qnn/tile_op_test.cc b/onnxruntime/test/providers/qnn/tile_op_test.cc new file mode 100644 index 0000000000000..2b35c730ee5fe --- /dev/null +++ b/onnxruntime/test/providers/qnn/tile_op_test.cc @@ -0,0 +1,132 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if !defined(ORT_MINIMAL_BUILD) + +#include + +#include "test/providers/qnn/qnn_test_utils.h" +#include "core/graph/node_attr_utils.h" + +#include "onnx/onnx_pb.h" +#include "gtest/gtest.h" + +namespace onnxruntime { +namespace test { + +// Runs a model with a Tile operator on the QNN CPU backend. Checks the graph node assignment +// and that inference outputs for QNN EP and CPU EP match. +template +static void RunTileTestOnCPU(const TestInputDef& input_def, + const TestInputDef& repeats_def, + ExpectedEPNodeAssignment expected_ep_assignment, + int opset = 13) { + ProviderOptions provider_options; + +#if defined(_WIN32) + provider_options["backend_path"] = "QnnCpu.dll"; +#else + provider_options["backend_path"] = "libQnnCpu.so"; +#endif + + RunQnnModelTest(BuildOpTestCase("Tile", {input_def}, {repeats_def}, {}), + provider_options, + opset, + expected_ep_assignment); +} + +// Test that Tile with a dynamic repeats input is not supported by QNN EP. +TEST_F(QnnCPUBackendTests, Tile_DynamicRepeats_Unsupported) { + RunTileTestOnCPU(TestInputDef({2, 2}, false, {1.0f, 2.0f, 3.0f, 4.0f}), + TestInputDef({2}, false /* is_initializer */, {1, 2}), + ExpectedEPNodeAssignment::None); // Should not be assigned to QNN EP. +} + +// Test that Tile with rank 4 float input. +TEST_F(QnnCPUBackendTests, Tile_F32_Rank4) { + std::vector input_data = {-4.0f, -3.0f, -1.0f, 0.0f, 1.0f, 2.0f, 3.0f, 4.0f}; + RunTileTestOnCPU(TestInputDef({1, 2, 2, 2}, false, input_data), + TestInputDef({4}, true /* is_initializer */, {1, 2, 1, 1}), + ExpectedEPNodeAssignment::All); +} + +#if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) +// +// HTP tests: +// + +// Returns a function that creates a graph with a QDQ Tile operator. +template +GetTestQDQModelFn BuildQDQTileTestCase(const TestInputDef& input_def, + const TestInputDef& repeats_def, + bool use_contrib_qdq = false) { + return [input_def, repeats_def, use_contrib_qdq](ModelTestBuilder& builder, + std::vector>& output_qparams) { + // input -> Q -> DQ -> + NodeArg* input = MakeTestInput(builder, input_def); + QuantParams input_qparams = GetTestInputQuantParams(input_def); + NodeArg* input_qdq = AddQDQNodePair(builder, input, input_qparams.scale, input_qparams.zero_point, + use_contrib_qdq); + + // repeats input + NodeArg* repeats_input = MakeTestInput(builder, repeats_def); + + // Tile op + NodeArg* tile_output = builder.MakeIntermediate(); + builder.AddNode("Tile", {input_qdq, repeats_input}, {tile_output}); + + // op_output -> Q -> DQ -> output + // NOTE: Input and output quantization parameters must be equal for Tile. + output_qparams[0] = input_qparams; // Overwrite! + AddQDQNodePairWithOutputAsGraphOutput(builder, tile_output, input_qparams.scale, + input_qparams.zero_point, use_contrib_qdq); + }; +} + +// Runs a QDQ Tile model on the QNN (HTP) EP and the ORT CPU EP. Checks the graph node assignment and that inference +// running the QDQ model on QNN EP is at least as accurate as on ORT CPU EP (compared to the baseline float32 model). +template +static void RunQDQTileTestOnHTP(const TestInputDef& input_def, + const TestInputDef& repeats_def, + ExpectedEPNodeAssignment expected_ep_assignment, + int opset = 13, + bool use_contrib_qdq = false) { + ProviderOptions provider_options; + +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + + auto f32_model_builder = BuildOpTestCase("Tile", {input_def}, {repeats_def}, {}); + auto qdq_model_builder = BuildQDQTileTestCase(input_def, repeats_def, use_contrib_qdq); + TestQDQModelAccuracy(f32_model_builder, + qdq_model_builder, + provider_options, + opset, + expected_ep_assignment); +} + +// Test 8-bit QDQ Tile with rank 4 input. +TEST_F(QnnHTPBackendTests, Tile_U8_Rank4) { + std::vector input_data = {-4.0f, -3.0f, -1.0f, 0.0f, 1.0f, 2.0f, 3.0f, 4.0f}; + RunQDQTileTestOnHTP(TestInputDef({1, 2, 2, 2}, false, input_data), + TestInputDef({4}, true /* is_initializer */, {1, 2, 1, 1}), + ExpectedEPNodeAssignment::All); +} + +// Test 16-bit QDQ Tile with rank 4 input. +TEST_F(QnnHTPBackendTests, Tile_U16_Rank4) { + std::vector input_data = {-4.0f, -3.0f, -1.0f, 0.0f, 1.0f, 2.0f, 3.0f, 4.0f}; + RunQDQTileTestOnHTP(TestInputDef({1, 2, 2, 2}, false, input_data), + TestInputDef({4}, true /* is_initializer */, {1, 2, 1, 1}), + ExpectedEPNodeAssignment::All, + 13, // opset + true); // Use com.microsoft Q/DQ ops +} + +#endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) +} // namespace test +} // namespace onnxruntime +#endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/test/providers/qnn/topk_op_test.cc b/onnxruntime/test/providers/qnn/topk_op_test.cc new file mode 100644 index 0000000000000..93e725af5f20e --- /dev/null +++ b/onnxruntime/test/providers/qnn/topk_op_test.cc @@ -0,0 +1,209 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if !defined(ORT_MINIMAL_BUILD) + +#include + +#include "test/providers/qnn/qnn_test_utils.h" +#include "core/graph/node_attr_utils.h" + +#include "onnx/onnx_pb.h" +#include "gtest/gtest.h" + +namespace onnxruntime { +namespace test { + +// Returns a function that builds a model with a TopK operator. +template +inline GetTestModelFn BuildTopKTestCase(const TestInputDef& input_def, + const TestInputDef& k_def, + const std::vector& attrs, + bool cast_output_indices = true) { + return [input_def, k_def, attrs, cast_output_indices](ModelTestBuilder& builder) { + NodeArg* input = MakeTestInput(builder, input_def); + NodeArg* k_input = MakeTestInput(builder, k_def); + + NodeArg* values_output = builder.MakeOutput(); + NodeArg* indices_output = cast_output_indices ? builder.MakeIntermediate() : builder.MakeOutput(); + Node& topk_node = builder.AddNode("TopK", {input, k_input}, {values_output, indices_output}); + + for (const auto& attr : attrs) { + topk_node.AddAttributeProto(attr); + } + + // Cast indices to uint32 + if (cast_output_indices) { + auto* uint32_indices_output = builder.MakeOutput(); + Node& cast_node = builder.AddNode("Cast", {indices_output}, {uint32_indices_output}); + const auto dst_type = ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT32; + cast_node.AddAttribute("to", static_cast(dst_type)); + } + }; +} + +// Runs a model with a TopK operator on the QNN CPU backend. Checks the graph node assignment +// and that inference outputs for QNN EP and CPU EP match. +template +static void RunTopKTestOnCPU(const TestInputDef& input_def, + const TestInputDef& k_def, + const std::vector& attrs, + ExpectedEPNodeAssignment expected_ep_assignment, + int opset = 19) { + ProviderOptions provider_options; + +#if defined(_WIN32) + provider_options["backend_path"] = "QnnCpu.dll"; +#else + provider_options["backend_path"] = "libQnnCpu.so"; +#endif + + RunQnnModelTest(BuildTopKTestCase(input_def, k_def, attrs, false /*cast_output_indices*/), + provider_options, + opset, + expected_ep_assignment); +} + +// +// CPU tests: +// + +// Test that TopK with a dynamic K input is not supported by QNN EP. +TEST_F(QnnCPUBackendTests, TopK_DynamicK_Unsupported) { + RunTopKTestOnCPU(TestInputDef({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)), + TestInputDef({1}, false /* is_initializer */, {2}), + {}, // Attributes + ExpectedEPNodeAssignment::None); // Should not be assigned to QNN EP. +} + +// Test that TopK with an axis attribute that is not the last dimension is not supported by QNN EP. +TEST_F(QnnCPUBackendTests, TopK_NonLastAxis_Unsupported) { + RunTopKTestOnCPU(TestInputDef({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)), + TestInputDef({1}, true /* is_initializer */, {2}), + {utils::MakeAttribute("axis", static_cast(1))}, + ExpectedEPNodeAssignment::None); // Should not be assigned to QNN EP. +} + +// Test that TopK that returns the top k minimum values is not supported by QNN EP. +TEST_F(QnnCPUBackendTests, TopK_MinValues_Unsupported) { + RunTopKTestOnCPU(TestInputDef({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)), + TestInputDef({1}, true /* is_initializer */, {2}), + {utils::MakeAttribute("largest", static_cast(0))}, + ExpectedEPNodeAssignment::None); // Should not be assigned to QNN EP. +} + +// Test TopK on CPU backend: top 2 largest floats from last axis +TEST_F(QnnCPUBackendTests, TopK_LargestFloats_LastAxis) { + RunTopKTestOnCPU(TestInputDef({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)), + TestInputDef({1}, true /* is_initializer */, {2}), + {}, // Attributes + ExpectedEPNodeAssignment::All); +} + +// Test TopK on CPU backend: top 2 largest int32s from last axis +TEST_F(QnnCPUBackendTests, TopK_LargestInt32s_LastAxis) { + std::vector input_data = {-6, -5, -4, -3, -2, 0, 1, 2, 3, 4, 5, 6}; + RunTopKTestOnCPU(TestInputDef({1, 2, 2, 3}, false, input_data), + TestInputDef({1}, true /* is_initializer */, {2}), + {}, // Attributes + ExpectedEPNodeAssignment::All); +} + +#if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) +// +// HTP tests: +// + +// Returns a function that creates a graph with a QDQ TopK operator. +template +GetTestQDQModelFn BuildQDQTopKTestCase(const TestInputDef& input_def, + const TestInputDef& k_def, + const std::vector& attrs, + bool use_contrib_qdq = false) { + return [input_def, k_def, attrs, use_contrib_qdq](ModelTestBuilder& builder, + std::vector>& output_qparams) { + // input -> Q -> DQ -> + NodeArg* input = MakeTestInput(builder, input_def); + QuantParams input_qparams = GetTestInputQuantParams(input_def); + NodeArg* input_qdq = AddQDQNodePair(builder, input, input_qparams.scale, input_qparams.zero_point, + use_contrib_qdq); + + // K input + NodeArg* k_input = MakeTestInput(builder, k_def); + + // Reshape op + NodeArg* values_output = builder.MakeIntermediate(); + NodeArg* indices_output = builder.MakeIntermediate(); + Node& topk_node = builder.AddNode("TopK", {input_qdq, k_input}, {values_output, indices_output}); + + for (const auto& attr : attrs) { + topk_node.AddAttributeProto(attr); + } + + // op_output -> Q -> DQ -> output + // NOTE: Input and output quantization parameters must be equal for Reshape. + output_qparams[0] = input_qparams; // Overwrite! + AddQDQNodePairWithOutputAsGraphOutput(builder, values_output, input_qparams.scale, + input_qparams.zero_point, use_contrib_qdq); + + // Cast indices to uint32 (HTP backend does not support int64 graph outputs) + auto* uint32_indices_output = builder.MakeOutput(); + Node& cast_node = builder.AddNode("Cast", {indices_output}, {uint32_indices_output}); + const auto dst_type = ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT32; + cast_node.AddAttribute("to", static_cast(dst_type)); + }; +} + +// Runs a QDQ TopK model on the QNN (HTP) EP and the ORT CPU EP. Checks the graph node assignment and that inference +// running the QDQ model on QNN EP is at least as accurate as on ORT CPU EP (compared to the baseline float32 model). +template +static void RunQDQTopKTestOnHTP(const TestInputDef& input_def, + const TestInputDef& k_def, + const std::vector& attrs, + ExpectedEPNodeAssignment expected_ep_assignment, + int opset = 19, + bool use_contrib_qdq = false) { + ProviderOptions provider_options; + +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + + auto f32_model_builder = BuildTopKTestCase(input_def, k_def, attrs, true /*cast_output_indices*/); + auto qdq_model_builder = BuildQDQTopKTestCase(input_def, k_def, attrs, use_contrib_qdq); + TestQDQModelAccuracy(f32_model_builder, + qdq_model_builder, + provider_options, + opset, + expected_ep_assignment); +} + +// Test 8-bit QDQ TopK on HTP backend: top 2 largest floats from last axis +TEST_F(QnnHTPBackendTests, TopK_LargestFloats_U8_LastAxis) { + RunQDQTopKTestOnHTP(TestInputDef({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)), + TestInputDef({1}, true /* is_initializer */, {2}), + {}, // Attributes + ExpectedEPNodeAssignment::All); +} + +// Test 16-bit QDQ TopK on HTP backend: top 2 largest floats from last axis +// TODO: Inaccuracy detected for output 'output_0', element 6. +// Output quant params: scale=0.00061036087572574615, zero_point=32768. +// Expected val: -7.2340402603149414 +// QNN QDQ val: -17.446556091308594 (err 10.212515830993652) +// CPU QDQ val: -7.2339968681335449 (err 4.3392181396484375e-05) +TEST_F(QnnHTPBackendTests, DISABLED_TopK_LargestFloats_U16_LastAxis) { + RunQDQTopKTestOnHTP(TestInputDef({1, 3, 4, 4}, false, GetFloatDataInRange(-20.0f, 20.0f, 48)), + TestInputDef({1}, true /* is_initializer */, {2}), + {}, // Attributes + ExpectedEPNodeAssignment::All, + 19, // opset + true); // Use com.microsoft Q/DQ ops +} + +#endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) +} // namespace test +} // namespace onnxruntime +#endif // !defined(ORT_MINIMAL_BUILD)