# [QNN EP] Support prelu fp16 (#20428)
### Description
Previously, PRelu on QNN would fail when the input is fp16 but alpha is fp32, since QNN requires alpha to be fp16 whenever the input is fp16. This is resolved by casting alpha to fp16 before passing it to QNN.
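
The cast itself is just an fp32 -> fp16 bit conversion. Below is a minimal standalone sketch of that conversion; ONNX Runtime does it via `MLFloat16(alpha).val` (as in the diff further down), and `FloatToFp16Bits` here is a hypothetical stand-in written only for illustration:

```cpp
#include <cstdint>
#include <cstdio>
#include <cstring>

// Convert an IEEE-754 float32 to float16 bits, rounding to nearest even.
// A stand-in for what MLFloat16(alpha).val produces in ONNX Runtime.
static uint16_t FloatToFp16Bits(float value) {
  uint32_t f;
  std::memcpy(&f, &value, sizeof(f));
  const uint32_t sign = (f >> 16) & 0x8000u;
  f &= 0x7FFFFFFFu;  // drop the sign bit

  if (f >= 0x7F800000u) {  // Inf or NaN: force NaN payloads to be quiet
    const uint32_t mant = (f > 0x7F800000u) ? 0x0200u : 0u;
    return static_cast<uint16_t>(sign | 0x7C00u | mant);
  }
  if (f < 0x38800000u) {  // too small for an fp16 normal: subnormal or zero
    const int shift = 126 - static_cast<int>(f >> 23);
    if (shift > 24) {
      return static_cast<uint16_t>(sign);  // rounds to (signed) zero
    }
    const uint32_t mant = (f & 0x007FFFFFu) | 0x00800000u;  // implicit leading 1
    const uint32_t rounded =
        mant + (1u << (shift - 1)) - 1u + ((mant >> shift) & 1u);
    return static_cast<uint16_t>(sign | (rounded >> shift));
  }
  // Normal case: round the 13 dropped mantissa bits, then rebias the exponent.
  f += 0x00000FFFu + ((f >> 13) & 1u);
  if (f >= 0x47800000u) {  // overflowed the fp16 range after rounding
    return static_cast<uint16_t>(sign | 0x7C00u);
  }
  return static_cast<uint16_t>(sign | ((f - 0x38000000u) >> 13));
}

int main() {
  // PRelu/LeakyRelu alpha values from this PR's code and test.
  std::printf("alpha=0.01f -> fp16 bits 0x%04X\n", FloatToFp16Bits(0.01f));
  std::printf("alpha=0.2f  -> fp16 bits 0x%04X\n", FloatToFp16Bits(0.2f));
  return 0;
}
```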

### Motivation and Context
Enables PRelu on the QNN EP to support the fp16 case.

---------

Co-authored-by: Hector Li <[email protected]>
winskuo-quic and HectorSVC authored Apr 29, 2024
1 parent 358f5bb commit 509cbca
Showing 2 changed files with 31 additions and 1 deletion.
```diff
@@ -223,6 +223,7 @@ Status ProcessAlphaAttributeAsInput(QnnModelWrapper& qnn_model_wrapper,
   Qnn_DataType_t qnn_data_type = QNN_DATATYPE_FLOAT_32;
   union {
     float alpha;
+    uint16_t alpha_fp16;
     uint8_t unpack[sizeof(float)];
   } tensor_data;
   tensor_data.alpha = node_helper.Get("alpha", 0.01f);
@@ -240,7 +241,17 @@ Status ProcessAlphaAttributeAsInput(QnnModelWrapper& qnn_model_wrapper,
     quantize_param = QnnQuantParamsWrapper(scale, static_cast<int32_t>(zero_point));
     qnn_data_type = QNN_DATATYPE_UFIXED_POINT_8;
   } else {
-    unpacked_data.assign(tensor_data.unpack, tensor_data.unpack + sizeof(float));
+    const auto& inputs = node_unit.Inputs();
+    TensorInfo input_info = {};
+    ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(inputs[0], input_info));
+    // QNN requires alpha is fp16 when input is fp16
+    if (input_info.qnn_data_type == QNN_DATATYPE_FLOAT_16) {
+      tensor_data.alpha_fp16 = MLFloat16(tensor_data.alpha).val;
+      qnn_data_type = QNN_DATATYPE_FLOAT_16;
+      unpacked_data.assign(tensor_data.unpack, tensor_data.unpack + sizeof(MLFloat16));
+    } else {
+      unpacked_data.assign(tensor_data.unpack, tensor_data.unpack + sizeof(float));
+    }
   }
   std::vector<uint32_t> input_shape{1};
   Qnn_TensorType_t tensor_type = QNN_TENSOR_TYPE_STATIC;
```
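
The change reuses the existing union so the same byte buffer can be viewed as fp32, fp16, or raw bytes, and only two bytes are copied in the fp16 case. A self-contained sketch of that pattern, with a plain `uint16_t` standing in for `MLFloat16` and an assumed fp16 bit pattern for the alpha value:

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // Same trick as the production code: one buffer, three views.
  // (Type punning through a union is compiler-supported here, as in the
  // original source; MLFloat16 is replaced by uint16_t to stay standalone.)
  union {
    float alpha;
    uint16_t alpha_fp16;
    uint8_t unpack[sizeof(float)];
  } tensor_data;

  std::vector<uint8_t> unpacked_data;

  // fp32 path: serialize all four bytes of the float.
  tensor_data.alpha = 0.01f;
  unpacked_data.assign(tensor_data.unpack, tensor_data.unpack + sizeof(float));
  std::printf("fp32 payload: %zu bytes\n", unpacked_data.size());

  // fp16 path: overwrite with the half-precision bits, copy only two bytes.
  tensor_data.alpha_fp16 = 0x211F;  // assumed fp16 bit pattern for ~0.01
  unpacked_data.assign(tensor_data.unpack, tensor_data.unpack + sizeof(uint16_t));
  std::printf("fp16 payload: %zu bytes (%02X %02X)\n", unpacked_data.size(),
              static_cast<unsigned>(unpacked_data[0]),
              static_cast<unsigned>(unpacked_data[1]));
  return 0;
}
```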
19 changes: 19 additions & 0 deletions onnxruntime/test/providers/qnn/leakyrelu_op_htp_test.cc
```diff
@@ -58,6 +58,25 @@ TEST_F(QnnHTPBackendTests, LeakyReluOpSet16) {
                    ExpectedEPNodeAssignment::All);
 }
 
+// Test Leaky Relu where input is FP16 and alpha is FP32
+TEST_F(QnnHTPBackendTests, LeakyReluFP16OpSet16) {
+  ProviderOptions provider_options;
+#if defined(_WIN32)
+  provider_options["backend_path"] = "QnnHtp.dll";
+#else
+  provider_options["backend_path"] = "libQnnHtp.so";
+#endif
+
+  auto input_def = TestInputDef<float>({1, 2, 3}, false, {-40.0f, -20.0f, 1.0f, 10.0f, 30.0f, 40.0f});
+  TestInputDef<MLFloat16> input_fp16_def = ConvertToFP16InputDef(input_def);
+  auto attrs = {utils::MakeAttribute("alpha", 0.2f)};
+  TestFp16ModelAccuracy(BuildOpTestCase<float>("LeakyRelu", {input_def}, {}, attrs),
+                        BuildOpTestCase<MLFloat16>("LeakyRelu", {input_fp16_def}, {}, attrs),
+                        provider_options,
+                        16,
+                        ExpectedEPNodeAssignment::All);
+}
+
 #endif  // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__)
 }  // namespace test
 }  // namespace onnxruntime
```
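
As a sanity check on the numbers above: LeakyRelu computes `x >= 0 ? x : alpha * x`, so with `alpha = 0.2f` the two negative inputs scale to -8 and -4 while the rest pass through. A quick reference computation over the test's inputs (illustrative only, not part of the PR):

```cpp
#include <cstdio>

// Reference LeakyRelu over the test inputs above, alpha = 0.2.
int main() {
  const float alpha = 0.2f;
  const float input[] = {-40.0f, -20.0f, 1.0f, 10.0f, 30.0f, 40.0f};
  for (float x : input) {
    const float y = (x >= 0.0f) ? x : alpha * x;
    std::printf("LeakyRelu(%5.1f) = %5.1f\n", x, y);
  }
  return 0;
}
```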
