diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc index 16c7bd5fce960..5015e48fdb7b8 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc @@ -496,6 +496,42 @@ bool LogicalComparisonNodeGroupSelector::Check(const GraphViewer& graph_viewer, return dt_input_1 == dt_input_2; } +bool TopKNodeGroupSelector::Check(const GraphViewer& graph_viewer, + const Node& node, + const std::vector& dq_nodes, + const std::vector& q_nodes) const { + constexpr int num_dq_inputs = 1; + constexpr int num_q_outputs = 1; + if (num_dq_inputs != gsl::narrow_cast(dq_nodes.size())) { + return false; + } + + if (const auto dq_validation_status = QDQ::ValidateNodeGroupDQNodes(graph_viewer, node, dq_nodes); + !dq_validation_status.IsOK()) { + return false; + } + + if (num_q_outputs != gsl::narrow_cast(q_nodes.size())) { + return false; + } + + const Node& dq_node = *dq_nodes.front(); + const Node& q_node = *q_nodes.front(); + + int32_t dt_input = dq_node.InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); + int32_t dt_output = q_node.OutputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); + + if (dt_input != dt_output) { + return false; + } + + auto get_const_initializer = [&graph_viewer](const std::string& initializer_name) { + return graph_viewer.GetConstantInitializer(initializer_name, true); + }; + + return IsQDQPairSupported(q_node, dq_node, get_const_initializer, graph_viewer.ModelPath()); +} + } // namespace QDQ } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h index d8fefdd8dc3d9..be7f7e0288eda 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h @@ -220,6 +220,14 @@ class LogicalComparisonNodeGroupSelector : public NodeGroupSelector { const std::vector& q_nodes) const override; }; +// TopK has 1 DQ input node and 1 Q output node. +// Zero point and scale are constant scalars and must match +class TopKNodeGroupSelector : public NodeGroupSelector { + bool Check(const GraphViewer& graph_viewer, const Node& node, + const std::vector& dq_nodes, + const std::vector& q_nodes) const override; +}; + /* * NodeSelector instances for use in the QDQ::SelectorActionTransformer. */ diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc index 293c885858179..3f1b2f0458bc0 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc @@ -38,7 +38,7 @@ static const OpVersionsAndSelector::OpVersionsMap GetMiscOpVersionsMap() { {"Squeeze", {}}, {"Unsqueeze", {}}, {"Tile", {}}}; - } +} static const OpVersionsAndSelector::OpVersionsMap GetDropDQOpVersionsMap() { return {{"ArgMax", {}}, @@ -129,6 +129,10 @@ static const OpVersionsAndSelector::OpVersionsMap GetPadOpVersionsMap() { return {{"Pad", {}}}; } +static const OpVersionsAndSelector::OpVersionsMap GetTopKOpVersionsMap() { + return {{"TopK", {}}}; +} + /* Selector rules registration related */ void RegisterMiscSelectors(Selectors& qdq_selectors) { /* register selectors for miscellaneous ops */ @@ -229,6 +233,13 @@ void RegisterPadSelectors(Selectors& qdq_selectors) { std::move(selector)); } +void RegisterTopKSelector(Selectors& qdq_selectors) { + /* register selector for TopK op */ + std::unique_ptr selector = std::make_unique(); + qdq_selectors.RegisterSelector(GetTopKOpVersionsMap(), + std::move(selector)); +} + void SelectorManager::CreateSelectors() { RegisterMiscSelectors(qdq_selectors_); RegisterDropDQSelectors(qdq_selectors_); @@ -244,6 +255,7 @@ void SelectorManager::CreateSelectors() { RegisterLogicalComparisonSelectors(qdq_selectors_); RegisterWhereSelectors(qdq_selectors_); RegisterPadSelectors(qdq_selectors_); + RegisterTopKSelector(qdq_selectors_); } void SelectorManager::InitializeSelectorsMap() { diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/topk.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/topk.cc index 6ca36736f2f7f..047972294f78c 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/topk.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/topk.cc @@ -63,9 +63,20 @@ Status TopKOpBuilder::ExplictOpCheck(QnnModelWrapper& qnn_model_wrapper, const N auto rank = input_shape.size(); auto axis = node_helper.Get("axis", -1); - if (-1 == axis && axis != static_cast(rank - 1)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN TopK axis is always the last dimension"); + ORT_RETURN_IF_NOT(axis == -1 || axis == static_cast(rank - 1), + "QNN TopK's axis is always the last dimension"); + + // ONNX TopK outputs int64 indices, but the equivalent QNN op outputs uint32 indices. + // The QNN HTP backend does not generally support the int64 type, but QNN EP can just use the uint32 type + // for TopK ops within the graph. However, if the TopK op **generates** a graph output, + // then we cannot support it on the HTP backend. + bool is_npu_backend = IsNpuBackend(qnn_model_wrapper.GetQnnBackendType()); + if (is_npu_backend) { + const std::string& output_name = node_unit.Outputs()[0].node_arg.Name(); + ORT_RETURN_IF(qnn_model_wrapper.IsGraphOutput(output_name), + "QNN EP does not support TopK ops that generate a graph output."); } + return Status::OK(); } diff --git a/onnxruntime/test/optimizer/qdq_transformer_test.cc b/onnxruntime/test/optimizer/qdq_transformer_test.cc index 2cca44e4d834b..63129ef2fff1e 100644 --- a/onnxruntime/test/optimizer/qdq_transformer_test.cc +++ b/onnxruntime/test/optimizer/qdq_transformer_test.cc @@ -2517,18 +2517,19 @@ TEST(QDQTransformerTests, Clip) { test_case(.04f, static_cast(-97), 1, opset, true); // [-1.24, 8.96] contrib qdq test_case(.02352941176f, static_cast(0), 0, opset); // [0, 6] test_case(.02352941176f, static_cast(0), 0, opset, true); // [0, 6] contrib qdq - test_case(9.15541313801785e-5f, static_cast(0), 0, opset, true); // [0, 6] contrib 16-bit qdq + test_case(9.15541313801785e-5f, static_cast(0), + 0, opset, true); // [0, 6] contrib 16-bit qdq test_case(0.0009f, static_cast(0), 1, opset, true); // [0, 58.98] contrib 16-bit qdq - test_case(.02f, static_cast(0), 0, opset); // [0, 5.1] - test_case(.02f, static_cast(0), 0, opset, true); // [0, 5.1] contrib qdq - test_case(.03f, static_cast(0), 1, opset); // [0, 7.65] - test_case(.03f, static_cast(0), 1, opset, true); // [0, 7.65] contrib qdq - test_case(.02f, static_cast(255), 1, opset); // [-5.1, 0] - test_case(.02f, static_cast(255), 1, opset, true); // [-5.1, 0] contrib qdq - test_case(.02f, static_cast(128), 1, opset); // [-2.56, 2.54] - test_case(.02f, static_cast(128), 1, opset, true); // [-2.56, 2.54] contrib qdq - test_case(.04f, static_cast(31), 1, opset); // [-1.24, 8.96] - test_case(.04f, static_cast(31), 1, opset, true); // [-1.24, 8.96] contrib qdq + test_case(.02f, static_cast(0), 0, opset); // [0, 5.1] + test_case(.02f, static_cast(0), 0, opset, true); // [0, 5.1] contrib qdq + test_case(.03f, static_cast(0), 1, opset); // [0, 7.65] + test_case(.03f, static_cast(0), 1, opset, true); // [0, 7.65] contrib qdq + test_case(.02f, static_cast(255), 1, opset); // [-5.1, 0] + test_case(.02f, static_cast(255), 1, opset, true); // [-5.1, 0] contrib qdq + test_case(.02f, static_cast(128), 1, opset); // [-2.56, 2.54] + test_case(.02f, static_cast(128), 1, opset, true); // [-2.56, 2.54] contrib qdq + test_case(.04f, static_cast(31), 1, opset); // [-1.24, 8.96] + test_case(.04f, static_cast(31), 1, opset, true); // [-1.24, 8.96] contrib qdq } // opset_version = 10 diff --git a/onnxruntime/test/providers/qnn/topk_op_test.cc b/onnxruntime/test/providers/qnn/topk_op_test.cc new file mode 100644 index 0000000000000..93e725af5f20e --- /dev/null +++ b/onnxruntime/test/providers/qnn/topk_op_test.cc @@ -0,0 +1,209 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if !defined(ORT_MINIMAL_BUILD) + +#include + +#include "test/providers/qnn/qnn_test_utils.h" +#include "core/graph/node_attr_utils.h" + +#include "onnx/onnx_pb.h" +#include "gtest/gtest.h" + +namespace onnxruntime { +namespace test { + +// Returns a function that builds a model with a TopK operator. +template +inline GetTestModelFn BuildTopKTestCase(const TestInputDef& input_def, + const TestInputDef& k_def, + const std::vector& attrs, + bool cast_output_indices = true) { + return [input_def, k_def, attrs, cast_output_indices](ModelTestBuilder& builder) { + NodeArg* input = MakeTestInput(builder, input_def); + NodeArg* k_input = MakeTestInput(builder, k_def); + + NodeArg* values_output = builder.MakeOutput(); + NodeArg* indices_output = cast_output_indices ? builder.MakeIntermediate() : builder.MakeOutput(); + Node& topk_node = builder.AddNode("TopK", {input, k_input}, {values_output, indices_output}); + + for (const auto& attr : attrs) { + topk_node.AddAttributeProto(attr); + } + + // Cast indices to uint32 + if (cast_output_indices) { + auto* uint32_indices_output = builder.MakeOutput(); + Node& cast_node = builder.AddNode("Cast", {indices_output}, {uint32_indices_output}); + const auto dst_type = ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT32; + cast_node.AddAttribute("to", static_cast(dst_type)); + } + }; +} + +// Runs a model with a TopK operator on the QNN CPU backend. Checks the graph node assignment +// and that inference outputs for QNN EP and CPU EP match. +template +static void RunTopKTestOnCPU(const TestInputDef& input_def, + const TestInputDef& k_def, + const std::vector& attrs, + ExpectedEPNodeAssignment expected_ep_assignment, + int opset = 19) { + ProviderOptions provider_options; + +#if defined(_WIN32) + provider_options["backend_path"] = "QnnCpu.dll"; +#else + provider_options["backend_path"] = "libQnnCpu.so"; +#endif + + RunQnnModelTest(BuildTopKTestCase(input_def, k_def, attrs, false /*cast_output_indices*/), + provider_options, + opset, + expected_ep_assignment); +} + +// +// CPU tests: +// + +// Test that TopK with a dynamic K input is not supported by QNN EP. +TEST_F(QnnCPUBackendTests, TopK_DynamicK_Unsupported) { + RunTopKTestOnCPU(TestInputDef({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)), + TestInputDef({1}, false /* is_initializer */, {2}), + {}, // Attributes + ExpectedEPNodeAssignment::None); // Should not be assigned to QNN EP. +} + +// Test that TopK with an axis attribute that is not the last dimension is not supported by QNN EP. +TEST_F(QnnCPUBackendTests, TopK_NonLastAxis_Unsupported) { + RunTopKTestOnCPU(TestInputDef({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)), + TestInputDef({1}, true /* is_initializer */, {2}), + {utils::MakeAttribute("axis", static_cast(1))}, + ExpectedEPNodeAssignment::None); // Should not be assigned to QNN EP. +} + +// Test that TopK that returns the top k minimum values is not supported by QNN EP. +TEST_F(QnnCPUBackendTests, TopK_MinValues_Unsupported) { + RunTopKTestOnCPU(TestInputDef({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)), + TestInputDef({1}, true /* is_initializer */, {2}), + {utils::MakeAttribute("largest", static_cast(0))}, + ExpectedEPNodeAssignment::None); // Should not be assigned to QNN EP. +} + +// Test TopK on CPU backend: top 2 largest floats from last axis +TEST_F(QnnCPUBackendTests, TopK_LargestFloats_LastAxis) { + RunTopKTestOnCPU(TestInputDef({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)), + TestInputDef({1}, true /* is_initializer */, {2}), + {}, // Attributes + ExpectedEPNodeAssignment::All); +} + +// Test TopK on CPU backend: top 2 largest int32s from last axis +TEST_F(QnnCPUBackendTests, TopK_LargestInt32s_LastAxis) { + std::vector input_data = {-6, -5, -4, -3, -2, 0, 1, 2, 3, 4, 5, 6}; + RunTopKTestOnCPU(TestInputDef({1, 2, 2, 3}, false, input_data), + TestInputDef({1}, true /* is_initializer */, {2}), + {}, // Attributes + ExpectedEPNodeAssignment::All); +} + +#if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) +// +// HTP tests: +// + +// Returns a function that creates a graph with a QDQ TopK operator. +template +GetTestQDQModelFn BuildQDQTopKTestCase(const TestInputDef& input_def, + const TestInputDef& k_def, + const std::vector& attrs, + bool use_contrib_qdq = false) { + return [input_def, k_def, attrs, use_contrib_qdq](ModelTestBuilder& builder, + std::vector>& output_qparams) { + // input -> Q -> DQ -> + NodeArg* input = MakeTestInput(builder, input_def); + QuantParams input_qparams = GetTestInputQuantParams(input_def); + NodeArg* input_qdq = AddQDQNodePair(builder, input, input_qparams.scale, input_qparams.zero_point, + use_contrib_qdq); + + // K input + NodeArg* k_input = MakeTestInput(builder, k_def); + + // Reshape op + NodeArg* values_output = builder.MakeIntermediate(); + NodeArg* indices_output = builder.MakeIntermediate(); + Node& topk_node = builder.AddNode("TopK", {input_qdq, k_input}, {values_output, indices_output}); + + for (const auto& attr : attrs) { + topk_node.AddAttributeProto(attr); + } + + // op_output -> Q -> DQ -> output + // NOTE: Input and output quantization parameters must be equal for Reshape. + output_qparams[0] = input_qparams; // Overwrite! + AddQDQNodePairWithOutputAsGraphOutput(builder, values_output, input_qparams.scale, + input_qparams.zero_point, use_contrib_qdq); + + // Cast indices to uint32 (HTP backend does not support int64 graph outputs) + auto* uint32_indices_output = builder.MakeOutput(); + Node& cast_node = builder.AddNode("Cast", {indices_output}, {uint32_indices_output}); + const auto dst_type = ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT32; + cast_node.AddAttribute("to", static_cast(dst_type)); + }; +} + +// Runs a QDQ TopK model on the QNN (HTP) EP and the ORT CPU EP. Checks the graph node assignment and that inference +// running the QDQ model on QNN EP is at least as accurate as on ORT CPU EP (compared to the baseline float32 model). +template +static void RunQDQTopKTestOnHTP(const TestInputDef& input_def, + const TestInputDef& k_def, + const std::vector& attrs, + ExpectedEPNodeAssignment expected_ep_assignment, + int opset = 19, + bool use_contrib_qdq = false) { + ProviderOptions provider_options; + +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + + auto f32_model_builder = BuildTopKTestCase(input_def, k_def, attrs, true /*cast_output_indices*/); + auto qdq_model_builder = BuildQDQTopKTestCase(input_def, k_def, attrs, use_contrib_qdq); + TestQDQModelAccuracy(f32_model_builder, + qdq_model_builder, + provider_options, + opset, + expected_ep_assignment); +} + +// Test 8-bit QDQ TopK on HTP backend: top 2 largest floats from last axis +TEST_F(QnnHTPBackendTests, TopK_LargestFloats_U8_LastAxis) { + RunQDQTopKTestOnHTP(TestInputDef({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)), + TestInputDef({1}, true /* is_initializer */, {2}), + {}, // Attributes + ExpectedEPNodeAssignment::All); +} + +// Test 16-bit QDQ TopK on HTP backend: top 2 largest floats from last axis +// TODO: Inaccuracy detected for output 'output_0', element 6. +// Output quant params: scale=0.00061036087572574615, zero_point=32768. +// Expected val: -7.2340402603149414 +// QNN QDQ val: -17.446556091308594 (err 10.212515830993652) +// CPU QDQ val: -7.2339968681335449 (err 4.3392181396484375e-05) +TEST_F(QnnHTPBackendTests, DISABLED_TopK_LargestFloats_U16_LastAxis) { + RunQDQTopKTestOnHTP(TestInputDef({1, 3, 4, 4}, false, GetFloatDataInRange(-20.0f, 20.0f, 48)), + TestInputDef({1}, true /* is_initializer */, {2}), + {}, // Attributes + ExpectedEPNodeAssignment::All, + 19, // opset + true); // Use com.microsoft Q/DQ ops +} + +#endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) +} // namespace test +} // namespace onnxruntime +#endif // !defined(ORT_MINIMAL_BUILD)