diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc
index 82d71bb3e9dde..5e2b3f6113b28 100644
--- a/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc
+++ b/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc
@@ -223,6 +223,7 @@ Status ProcessAlphaAttributeAsInput(QnnModelWrapper& qnn_model_wrapper,
   Qnn_DataType_t qnn_data_type = QNN_DATATYPE_FLOAT_32;
   union {
     float alpha;
+    uint16_t alpha_fp16;
     uint8_t unpack[sizeof(float)];
   } tensor_data;
   tensor_data.alpha = node_helper.Get("alpha", 0.01f);
@@ -240,7 +241,17 @@ Status ProcessAlphaAttributeAsInput(QnnModelWrapper& qnn_model_wrapper,
     quantize_param = QnnQuantParamsWrapper(scale, static_cast<int32_t>(zero_point));
     qnn_data_type = QNN_DATATYPE_UFIXED_POINT_8;
   } else {
-    unpacked_data.assign(tensor_data.unpack, tensor_data.unpack + sizeof(float));
+    const auto& inputs = node_unit.Inputs();
+    TensorInfo input_info = {};
+    ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(inputs[0], input_info));
+    // QNN requires alpha is fp16 when input is fp16
+    if (input_info.qnn_data_type == QNN_DATATYPE_FLOAT_16) {
+      tensor_data.alpha_fp16 = MLFloat16(tensor_data.alpha).val;
+      qnn_data_type = QNN_DATATYPE_FLOAT_16;
+      unpacked_data.assign(tensor_data.unpack, tensor_data.unpack + sizeof(MLFloat16));
+    } else {
+      unpacked_data.assign(tensor_data.unpack, tensor_data.unpack + sizeof(float));
+    }
   }
   std::vector<uint32_t> input_shape{1};
   Qnn_TensorType_t tensor_type = QNN_TENSOR_TYPE_STATIC;
diff --git a/onnxruntime/test/providers/qnn/leakyrelu_op_htp_test.cc b/onnxruntime/test/providers/qnn/leakyrelu_op_htp_test.cc
index e3077ec569923..ece8d91d53648 100644
--- a/onnxruntime/test/providers/qnn/leakyrelu_op_htp_test.cc
+++ b/onnxruntime/test/providers/qnn/leakyrelu_op_htp_test.cc
@@ -58,6 +58,25 @@ TEST_F(QnnHTPBackendTests, LeakyReluOpSet16) {
                   ExpectedEPNodeAssignment::All);
 }
 
+// Test Leaky Relu where input is FP16 and alpha is FP32
+TEST_F(QnnHTPBackendTests, LeakyReluFP16OpSet16) {
+  ProviderOptions provider_options;
+#if defined(_WIN32)
+  provider_options["backend_path"] = "QnnHtp.dll";
+#else
+  provider_options["backend_path"] = "libQnnHtp.so";
+#endif
+
+  auto input_def = TestInputDef<float>({1, 2, 3}, false, {-40.0f, -20.0f, 1.0f, 10.0f, 30.0f, 40.0f});
+  TestInputDef<MLFloat16> input_fp16_def = ConvertToFP16InputDef(input_def);
+  auto attrs = {utils::MakeAttribute("alpha", 0.2f)};
+  TestFp16ModelAccuracy(BuildOpTestCase<float>("LeakyRelu", {input_def}, {}, attrs),
+                        BuildOpTestCase<MLFloat16>("LeakyRelu", {input_fp16_def}, {}, attrs),
+                        provider_options,
+                        16,
+                        ExpectedEPNodeAssignment::All);
+}
+
 #endif  // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__)
 }  // namespace test
 }  // namespace onnxruntime