From 60ad6c6409479d2c4c0c829b5b88bdb9d2248872 Mon Sep 17 00:00:00 2001
From: Hector Li
Date: Wed, 13 Mar 2024 08:35:21 -0700
Subject: [PATCH] Enable float32 model with FP16 precision for QNN HTP backend
 (#19863)

### Description
Enable float32 models to run with FP16 precision on the QNN HTP backend. This adds a new QNN EP provider option, `enable_htp_fp16_precision` ("0" by default; "1" runs the float32 model with fp16 precision), documents it in the C API header, validates it in onnx_test_runner and onnxruntime_perf_test, applies it to the QNN graph via the HTP precision config, and adds an HTP backend unit test.
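For reviewers, a minimal usage sketch follows; it is illustrative only and not part of this change. It assumes the generic `Ort::SessionOptions::AppendExecutionProvider` overload that takes the EP name `"QNN"` plus a string key/value map; the model path and log ID are placeholders.

```cpp
// Minimal sketch: run a float32 model at fp16 on the QNN HTP backend.
// Assumes the generic AppendExecutionProvider("QNN", ...) overload;
// "model.fp32.onnx" is a placeholder model path.
#include <string>
#include <unordered_map>

#include "onnxruntime_cxx_api.h"

int main() {
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "qnn_htp_fp16");
  Ort::SessionOptions session_options;

  std::unordered_map<std::string, std::string> qnn_options;
  qnn_options["backend_path"] = "QnnHtp.dll";      // libQnnHtp.so on Linux/Android
  qnn_options["enable_htp_fp16_precision"] = "1";  // new option: fp16 execution of a float32 model

  session_options.AppendExecutionProvider("QNN", qnn_options);

  Ort::Session session(env, ORT_TSTR("model.fp32.onnx"), session_options);
  // ... create input tensors and call session.Run() as usual ...
  return 0;
}
```

With the perftest changes below, the equivalent command line should be something like `onnxruntime_perf_test -e qnn -i "backend_path|QnnHtp.dll enable_htp_fp16_precision|1" model.fp32.onnx`.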
---
 .../core/session/onnxruntime_c_api.h          |  4 ++++
 .../providers/qnn/qnn_execution_provider.cc   | 23 +++++++++++++++++++
 .../providers/qnn/qnn_execution_provider.h    |  1 +
 onnxruntime/test/onnx/main.cc                 | 13 ++++++++++-
 .../test/perftest/command_args_parser.cc      |  2 ++
 onnxruntime/test/perftest/ort_test_session.cc | 11 ++++++++-
 .../test/providers/qnn/qnn_basic_test.cc      | 19 +++++++++++++++
 7 files changed, 71 insertions(+), 2 deletions(-)

diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h
index 5577c840c5379..144ee1205ee1a 100644
--- a/include/onnxruntime/core/session/onnxruntime_c_api.h
+++ b/include/onnxruntime/core/session/onnxruntime_c_api.h
@@ -3619,6 +3619,10 @@ struct OrtApi {
    * - "73"
    * - "75"
    *   "device_id": The ID of the device to use when setting 'htp_arch'. Defaults to "0" (for single device).
+   *   "enable_htp_fp16_precision": Only used for float32 models.
+   *   Enable a float32 model to be inferenced with fp16 precision. Otherwise it runs with fp32 precision.
+   *   - "0": Default. Run with fp32 precision.
+   *   - "1": Run with fp16 precision.
    *
    * SNPE supported keys:
    *   "runtime": SNPE runtime engine, options: "CPU", "CPU_FLOAT32", "GPU", "GPU_FLOAT32_16_HYBRID", "GPU_FLOAT16",
diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc
index 5c4fa3e0fb88b..ef90b1f629b26 100644
--- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc
+++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc
@@ -300,6 +300,19 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio
     }
   }
 
+  static const std::string QNN_HTP_FP16_MODE = "enable_htp_fp16_precision";
+  auto htp_fp16_mode_pos = provider_options_map.find(QNN_HTP_FP16_MODE);
+  if (htp_fp16_mode_pos != provider_options_map.end()) {
+    if ("1" == htp_fp16_mode_pos->second) {
+      enable_HTP_FP16_precision_ = true;
+    } else if ("0" == htp_fp16_mode_pos->second) {
+      enable_HTP_FP16_precision_ = false;
+    } else {
+      LOGS_DEFAULT(VERBOSE) << "Invalid enable_htp_fp16_precision: " << htp_fp16_mode_pos->second << "; only 0 or 1 allowed. Set to 0.";
+    }
+    LOGS_DEFAULT(VERBOSE) << "User specified enable_htp_fp16_precision: " << enable_HTP_FP16_precision_;
+  }
+
   qnn_backend_manager_ = std::make_unique<qnn::QnnBackendManager>(
       std::move(backend_path),
       profiling_level,
@@ -637,6 +650,16 @@ void QNNExecutionProvider::InitQnnGraphConfigs(qnn::QnnConfigsBuilder<QnnGraph_C
       QnnGraph_Config_t& graph_opt_config = configs_builder.PushConfig();
       graph_opt_config.option = QNN_GRAPH_CONFIG_OPTION_CUSTOM;
       graph_opt_config.customConfig = &htp_graph_opt_config;
     }
+
+    if (enable_HTP_FP16_precision_) {
+      QnnHtpGraph_CustomConfig_t& htp_graph_precision_config = configs_builder.PushCustomConfig();
+      htp_graph_precision_config = QNN_HTP_GRAPH_CUSTOM_CONFIG_INIT;
+      htp_graph_precision_config.option = QNN_HTP_GRAPH_CONFIG_OPTION_PRECISION;
+      htp_graph_precision_config.precision = QNN_PRECISION_FLOAT16;
+      QnnGraph_Config_t& graph_precision_config = configs_builder.PushConfig();
+      graph_precision_config.option = QNN_GRAPH_CONFIG_OPTION_CUSTOM;
+      graph_precision_config.customConfig = &htp_graph_precision_config;
+    }
   }
 }
diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.h b/onnxruntime/core/providers/qnn/qnn_execution_provider.h
--- a/onnxruntime/core/providers/qnn/qnn_execution_provider.h
+++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.h
@@ -77,6 +77,7 @@ class QNNExecutionProvider : public IExecutionProvider {
   uint32_t device_id_ = 0;
   qnn::HtpPerformanceMode default_htp_performance_mode_ = qnn::HtpPerformanceMode::kHtpDefault;
   uint32_t default_rpc_control_latency_ = 0;
+  bool enable_HTP_FP16_precision_ = false;
 
   class PerThreadContext final {
    public:
diff --git a/onnxruntime/test/onnx/main.cc b/onnxruntime/test/onnx/main.cc
--- a/onnxruntime/test/onnx/main.cc
+++ b/onnxruntime/test/onnx/main.cc
@@ -74,6 +74,8 @@ void usage() {
       "\t    [QNN only] [htp_arch]: The minimum HTP architecture. The driver will use ops compatible with this architecture. \n"
       "\t    Options are '0', '68', '69', '73', '75'. Defaults to '0' (none). \n"
       "\t    [QNN only] [device_id]: The ID of the device to use when setting 'htp_arch'. Defaults to '0' (for single device). \n"
+      "\t    [QNN only] [enable_htp_fp16_precision]: Enable HTP FP16 precision so that float32 models are inferenced with fp16 precision. \n"
+      "\t    Otherwise, inference runs with fp32 precision. Only applies to float32 models. Defaults to '0' (fp32 precision). \n"
       "\t [Usage]: -e <provider_name> -i '<key1>|<value1> <key2>|<value2>' \n\n"
       "\t [Example] [For QNN EP] -e qnn -i \"profiling_level|detailed backend_path|/folderpath/libQnnCpu.so\" \n\n"
       "\t    [SNPE only] [runtime]: SNPE runtime, options: 'CPU', 'GPU', 'GPU_FLOAT16', 'DSP', 'AIP_FIXED_TF'. \n"
@@ -525,11 +527,20 @@ int real_main(int argc, char* argv[], Ort::Env& env) {
           std::string str = str_stream.str();
           ORT_THROW("Wrong value for htp_arch. select from: " + str);
         }
+      } else if (key == "enable_htp_fp16_precision") {
+        std::unordered_set<std::string> supported_options = {"0", "1"};
+        if (supported_options.find(value) == supported_options.end()) {
+          std::ostringstream str_stream;
+          std::copy(supported_options.begin(), supported_options.end(),
+                    std::ostream_iterator<std::string>(str_stream, ","));
+          std::string str = str_stream.str();
+          ORT_THROW("Wrong value for enable_htp_fp16_precision. select from: " + str);
+        }
       } else {
         ORT_THROW(R"(Wrong key type entered. Choose from options: ['backend_path', 'profiling_level', 'rpc_control_latency',
 'vtcm_mb', 'htp_performance_mode', 'qnn_saver_path',
 'htp_graph_finalization_optimization_mode', 'qnn_context_priority',
-'soc_model', 'htp_arch', 'device_id'])");
+'soc_model', 'htp_arch', 'device_id', 'enable_htp_fp16_precision'])");
       }
 
       qnn_options[key] = value;
diff --git a/onnxruntime/test/perftest/command_args_parser.cc b/onnxruntime/test/perftest/command_args_parser.cc
index 16c90c39f300f..93e44fd8e8d2d 100644
--- a/onnxruntime/test/perftest/command_args_parser.cc
+++ b/onnxruntime/test/perftest/command_args_parser.cc
@@ -94,6 +94,8 @@ namespace perftest {
       "\t    [QNN only] [htp_arch]: The minimum HTP architecture. The driver will use ops compatible with this architecture. \n"
       "\t    Options are '0', '68', '69', '73', '75'. Defaults to '0' (none). \n"
       "\t    [QNN only] [device_id]: The ID of the device to use when setting 'htp_arch'. Defaults to '0' (for single device). \n"
+      "\t    [QNN only] [enable_htp_fp16_precision]: Enable HTP FP16 precision so that float32 models are inferenced with fp16 precision. \n"
+      "\t    Otherwise, inference runs with fp32 precision. Only applies to float32 models. Defaults to '0' (fp32 precision). \n"
       "\t [Example] [For QNN EP] -e qnn -i \"backend_path|/folderpath/libQnnCpu.so\" \n"
       "\n"
       "\t [TensorRT only] [trt_max_partition_iterations]: Maximum iterations for TensorRT parser to get capability.\n"
diff --git a/onnxruntime/test/perftest/ort_test_session.cc b/onnxruntime/test/perftest/ort_test_session.cc
index 71d260a18ce7b..6e10763d8f293 100644
--- a/onnxruntime/test/perftest/ort_test_session.cc
+++ b/onnxruntime/test/perftest/ort_test_session.cc
@@ -382,11 +382,20 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device
         std::string str = str_stream.str();
         ORT_THROW("Wrong value for htp_arch. select from: " + str);
       }
+    } else if (key == "enable_htp_fp16_precision") {
+      std::unordered_set<std::string> supported_options = {"0", "1"};
+      if (supported_options.find(value) == supported_options.end()) {
+        std::ostringstream str_stream;
+        std::copy(supported_options.begin(), supported_options.end(),
+                  std::ostream_iterator<std::string>(str_stream, ","));
+        std::string str = str_stream.str();
+        ORT_THROW("Wrong value for enable_htp_fp16_precision. select from: " + str);
+      }
     } else {
       ORT_THROW(R"(Wrong key type entered. Choose from options: ['backend_path', 'profiling_level', 'rpc_control_latency',
 'vtcm_mb', 'htp_performance_mode', 'qnn_saver_path',
 'htp_graph_finalization_optimization_mode', 'qnn_context_priority', 'soc_model',
-'htp_arch', 'device_id'])");
+'htp_arch', 'device_id', 'enable_htp_fp16_precision'])");
     }

     qnn_options[key] = value;
diff --git a/onnxruntime/test/providers/qnn/qnn_basic_test.cc b/onnxruntime/test/providers/qnn/qnn_basic_test.cc
index 8f07c2ce77e77..4f294f899c170 100644
--- a/onnxruntime/test/providers/qnn/qnn_basic_test.cc
+++ b/onnxruntime/test/providers/qnn/qnn_basic_test.cc
@@ -815,6 +815,25 @@ TEST_F(QnnHTPBackendTests, DISABLED_CastAddHTPAccuracyTest) {
                  ExpectedEPNodeAssignment::All);
 }
 
+// Test that a float32 model runs on HTP with FP16 precision enabled.
+TEST_F(QnnHTPBackendTests, Float32ModelWithFP16PrecisionTest) {
+  ProviderOptions provider_options;
+#if defined(_WIN32)
+  provider_options["backend_path"] = "QnnHtp.dll";
+#else
+  provider_options["backend_path"] = "libQnnHtp.so";
+#endif
+  provider_options["enable_htp_fp16_precision"] = "1";
+
+  auto input_defs = {TestInputDef<float>({1, 2, 2, 2}, false, -10.0f, 10.0f),
+                     TestInputDef<float>({1, 2, 2, 2}, false, -10.0f, 10.0f)};
+  RunQnnModelTest(BuildOpTestCase<float>("Add", input_defs, {}, {}, kOnnxDomain),
+                  provider_options,
+                  13,  // opset version
+                  ExpectedEPNodeAssignment::All,
+                  0.008f);  // fp32 absolute error tolerance
+}
+
 #endif  // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__)
 
 #endif  // !defined(ORT_MINIMAL_BUILD)