diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/dq_q_fusion.cc b/onnxruntime/core/providers/qnn/builder/qnn_node_group/dq_q_fusion.cc index ce87ac4a3d21c..caf4725626338 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_node_group/dq_q_fusion.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/dq_q_fusion.cc @@ -170,9 +170,11 @@ static bool IsDQQConversion(const GraphViewer& graph_viewer, const Node& dq_node return false; } - // check Q/DQ have same scale type and different zero point type - return (dq_zp_tensor_proto->data_type() != q_zp_tensor_proto->data_type()) && - (dq_scale_tensor_proto->data_type() == q_scale_tensor_proto->data_type()); + // For scale, ensure that the Q/DQ have same scale type. + // + // For zero-point: we previously only fused (DQ -> Q) into a Convert op if the quantization types differed. + // However, a single Convert op is faster than (DQ -> Q), so we should always fuse regardless of the zero-point type. + return (dq_scale_tensor_proto->data_type() == q_scale_tensor_proto->data_type()); } } // namespace qnn diff --git a/onnxruntime/test/providers/qnn/qnn_test_utils.h b/onnxruntime/test/providers/qnn/qnn_test_utils.h index bb77c92668853..7f55a44c748b6 100644 --- a/onnxruntime/test/providers/qnn/qnn_test_utils.h +++ b/onnxruntime/test/providers/qnn/qnn_test_utils.h @@ -33,16 +33,37 @@ struct QuantParams { float scale; QType zero_point; + inline std::pair CalcRminRmax() const { + constexpr float qmin = static_cast(std::numeric_limits::min()); + constexpr float qmax = static_cast(std::numeric_limits::max()); + const float qrange = (qmax - qmin); + const float rrange = this->scale * qrange; + const float rmin = -(static_cast(this->zero_point) - qmin) * this->scale; + const float rmax = rrange + rmin; + + return {rmin, rmax}; + } + + inline bool IsSymmetric() const { + constexpr float qmin = static_cast(std::numeric_limits::min()); + constexpr float qmax = static_cast(std::numeric_limits::max()); + float init_zero_point = (qmin + qmax) / 2.0; + const QType symm_zero_point = static_cast(RoundHalfToEven( + std::max(qmin, std::min(qmax, init_zero_point)))); + + return this->zero_point == symm_zero_point; + } + static QuantParams Compute(float rmin, float rmax, bool symmetric = false) { return Compute( rmin, rmax, - static_cast(std::numeric_limits::min()), - static_cast(std::numeric_limits::max()), + std::numeric_limits::min(), + std::numeric_limits::max(), symmetric); } - static QuantParams Compute(float rmin, float rmax, float qmin, float qmax, bool symmetric = false) { + static QuantParams Compute(float rmin, float rmax, QType qmin, QType qmax, bool symmetric = false) { // Ensure a minimum range of 0.0001 (required by QNN) rmax = std::max(rmax, rmin + 0.0001f); @@ -56,8 +77,8 @@ struct QuantParams { rmin = -abs_max; } - float qmin_flt = qmin; - float qmax_flt = qmax; + const float qmin_flt = qmin; + const float qmax_flt = qmax; const float scale = (rmax - rmin) / (qmax_flt - qmin_flt); float initial_zero_point = 0.0f; @@ -76,6 +97,13 @@ struct QuantParams { } }; +// Utitity that converts quantization parameters from one type to another (e.g., uint8 to uint16). +template +inline QuantParams ConvertQuantParams(QuantParams src_qparams) { + std::pair src_rmin_rmax = src_qparams.CalcRminRmax(); + return QuantParams::Compute(src_rmin_rmax.first, src_rmin_rmax.second, src_qparams.IsSymmetric()); +} + // Signature for function that builds a QDQ model. // The parameter `output_qparams` contains quantization parameters that *can* be used for the QDQ model output. // These output quantization parameters are computed by first running the float32 model and determining the diff --git a/onnxruntime/test/providers/qnn/simple_op_htp_test.cc b/onnxruntime/test/providers/qnn/simple_op_htp_test.cc index 83899ec6ef17b..8de414dbb4a62 100644 --- a/onnxruntime/test/providers/qnn/simple_op_htp_test.cc +++ b/onnxruntime/test/providers/qnn/simple_op_htp_test.cc @@ -157,8 +157,6 @@ static void RunOpTest(const std::string& op_type, if (enable_htp_fp16_precision) { provider_options["enable_htp_fp16_precision"] = "1"; - } else { - provider_options["enable_htp_fp16_precision"] = "0"; // enabled in QNN EP by default } // Runs model with a Q/DQ binary op and compares the outputs of the CPU and QNN EPs. @@ -1208,6 +1206,80 @@ TEST_F(QnnHTPBackendTests, Add_U8_U16_Convert) { ExpectedEPNodeAssignment::All); } +// Builds a graph where a (DQ -> Q) sequence at the graph's output is fuse into a QNN Convert operator. +// ONNX Graph: DQ -> Add -> Q -> DQ -> Q -> graph_output +// QNN Graph: DQ -> Add -> Q -> Convert -> graph_output +template +static GetTestModelFn BuildDQQConvertAtOutputTestCase(const TestInputDef& input0_def, + const TestInputDef& input1_def, + const QuantParams& output_qparams) { + return [input0_def, input1_def, output_qparams](ModelTestBuilder& builder) { + // Input0 -> Quantize(InQuantType) -> Dequantize(InQuantType to float) -> input0_after_qdq + NodeArg* input0 = MakeTestInput(builder, input0_def); + QuantParams input0_qparams = GetTestInputQuantParams(input0_def); + NodeArg* input0_after_qdq = AddQDQNodePair(builder, input0, input0_qparams.scale, + input0_qparams.zero_point); + + // Input1 -> Quantize(InQuantType) -> Dequantize(InQuantType to float) -> input1_after_qdq + NodeArg* input1 = MakeTestInput(builder, input1_def); + QuantParams input1_qparams = GetTestInputQuantParams(input1_def); + NodeArg* input1_after_qdq = AddQDQNodePair(builder, input1, input1_qparams.scale, + input1_qparams.zero_point); + + // Add op -> op_output + auto* op_output = builder.MakeIntermediate(); + builder.AddNode("Add", {input0_after_qdq, input1_after_qdq}, {op_output}); + + // op_output -> Quantize(InQuantType) -> add_out_q + QuantParams add_out_qparams = ConvertQuantParams(output_qparams); + add_out_qparams.scale *= 1.01f; // Make qparams slightly different so DQ->Q are not optimized out. + NodeArg* add_out_q = builder.MakeIntermediate(); + builder.AddQuantizeLinearNode(op_output, add_out_qparams.scale, + add_out_qparams.zero_point, add_out_q); + + // Add DQ + NodeArg* add_out_dq = builder.MakeIntermediate(); + builder.AddDequantizeLinearNode(add_out_q, add_out_qparams.scale, + add_out_qparams.zero_point, add_out_dq); + + // Add a Q to quantize to OutQuantType + // The previous DQ and this Q will be fused into a QNN Convert. + NodeArg* q_conv_out = builder.MakeOutput(); + builder.AddQuantizeLinearNode(add_out_dq, output_qparams.scale, output_qparams.zero_point, + q_conv_out); + }; +} + +// Test fusion of (DQ -> Q) into QNN's Convert op using the same quant type. +TEST_F(QnnHTPBackendTests, DQ_Q_ConvertFusion_SameType) { + std::vector input0_data = {-8.0f, -6.0, -2.0f, 0.0f, 2.0f, 4.0f, 6.0f, 8.0f}; + std::vector input1_data = {-8.0f, -6.0, -2.0f, 0.0f, 2.0f, 4.0f, 6.0f, 8.0f}; + TestInputDef input0_def({1, 2, 2, 2}, false, input0_data); + TestInputDef input1_def({1, 2, 2, 2}, false, input1_data); + + ProviderOptions provider_options; +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + + QuantParams out_qparams_u8 = {1.0f, 128}; + QuantParams out_qparams_u16 = {1.0f, 32768}; + + // QNN Convert op converts uint8 to uint8 at the graph output. Slightly different scale values. + RunQnnModelTest(BuildDQQConvertAtOutputTestCase(input0_def, input1_def, out_qparams_u8), + provider_options, + 21, + ExpectedEPNodeAssignment::All); + + // QNN Convert op converts uint16 to uint16 at the graph output. Slightly different scale values. + RunQnnModelTest(BuildDQQConvertAtOutputTestCase(input0_def, input1_def, out_qparams_u16), + provider_options, + 21, + ExpectedEPNodeAssignment::All); +} + TEST_F(QnnHTPBackendTests, UnaryOp_HardSigmoid_QU8) { RunQDQOpTest("HardSigmoid", {TestInputDef({1, 2, 3}, false, GetFloatDataInRange(-10.0f, 10.0f, 6))}, diff --git a/onnxruntime/test/util/test_utils.cc b/onnxruntime/test/util/test_utils.cc index 6bc0f8d105495..b118c8faec0f7 100644 --- a/onnxruntime/test/util/test_utils.cc +++ b/onnxruntime/test/util/test_utils.cc @@ -38,6 +38,10 @@ void VerifyOutput(const std::string& output_name, EXPECT_TRUE(SpanEq(expected_tensor.DataAsSpan(), tensor.DataAsSpan())) << " mismatch for " << output_name; break; + case ONNX_NAMESPACE::TensorProto_DataType_UINT16: + EXPECT_TRUE(SpanEq(expected_tensor.DataAsSpan(), tensor.DataAsSpan())) + << " mismatch for " << output_name; + break; case ONNX_NAMESPACE::TensorProto_DataType_UINT8: EXPECT_TRUE(SpanEq(expected_tensor.DataAsSpan(), tensor.DataAsSpan())) << " mismatch for " << output_name;