From 84d48b6ad66fdc259c33876af4f83c0fef3b6d87 Mon Sep 17 00:00:00 2001 From: Adrian Lizarraga Date: Wed, 16 Oct 2024 15:00:53 -0700 Subject: [PATCH] [QNN EP] Add provider option to offload graph I/O quantization/dequantization to the CPU EP (#22436) ### Description Adds QNN provider option `offload_graph_io_quantization` to offload graph input quantization and graph output dequantization to the CPU EP. Option is disabled by default to maintain current behavior. ### Motivation and Context Offloading the handling of I/O quantization to the CPU EP significantly improves inference latency for many models. --- .../core/session/onnxruntime_c_api.h | 18 +++-- .../builder/opbuilder/simple_op_builder.cc | 10 +++ .../core/providers/qnn/builder/qnn_model.cc | 4 +- .../core/providers/qnn/builder/qnn_model.h | 1 + .../providers/qnn/builder/qnn_model_wrapper.h | 13 +++- .../providers/qnn/qnn_execution_provider.cc | 32 +++++++- .../providers/qnn/qnn_execution_provider.h | 1 + onnxruntime/test/onnx/main.cc | 8 +- .../test/perftest/command_args_parser.cc | 2 + onnxruntime/test/perftest/ort_test_session.cc | 6 +- .../test/providers/qnn/qnn_basic_test.cc | 75 +++++++++++++++++++ .../test/providers/qnn/qnn_test_utils.cc | 7 +- .../test/providers/qnn/qnn_test_utils.h | 11 ++- .../test/qnn_ctx_gen/command_args_parser.cc | 8 +- 14 files changed, 172 insertions(+), 24 deletions(-) diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 9e71997c1e442..bde27df94ed1c 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -3651,13 +3651,17 @@ struct OrtApi { * - "73" * - "75" * "device_id": The ID of the device to use when setting 'htp_arch'. Defaults to "0" (for single device). - "enable_htp_fp16_precision": Used for float32 model for HTP backend. - Enable the float32 model to be inferenced with fp16 precision. Otherwise, it will be fp32 precision. - - "0": With fp32 precision. - - "1": Default. With fp16 precision. - "enable_htp_weight_sharing": Enable QNN weight sharing feature while compiling multiple graphs into one QNN context. - - "0": Default. Disabled. - - "1": Enabled. + * "enable_htp_fp16_precision": Used for float32 model for HTP backend. + * Enable the float32 model to be inferenced with fp16 precision. Otherwise, it will be fp32 precision. + * - "0": With fp32 precision. + * - "1": Default. With fp16 precision. + * "enable_htp_weight_sharing": Enable QNN weight sharing feature while compiling multiple graphs into one QNN context. + * - "0": Default. Disabled. + * - "1": Enabled. + * "offload_graph_io_quantization": Offload graph input quantization and graph output dequantization to another + * execution provider (typically CPU EP). + * - "0": Default. Disabled. QNN EP will handle quantization and dequantization of graph I/O. + * - "1": Enabled. * * SNPE supported keys: * "runtime": SNPE runtime engine, options: "CPU", "CPU_FLOAT32", "GPU", "GPU_FLOAT32_16_HYBRID", "GPU_FLOAT16", diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc index 0358fae3c2115..a6c4203ad92e4 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc @@ -164,6 +164,11 @@ Status SimpleOpBuilder::ExplicitOpCheck(QnnModelWrapper& qnn_model_wrapper, int64_t quant_axis = 0; ORT_RETURN_IF_ERROR(qnn_model_wrapper.IsPerChannelQuantized(node_unit.Inputs()[0], is_per_chan_quant, quant_axis)); ORT_RETURN_IF(is_per_chan_quant, "QNN EP does not support a standalone DQ op with per-channel quantization"); + + if (qnn_model_wrapper.GetModelSettings().offload_graph_io_quantization) { + ORT_RETURN_IF(qnn_model_wrapper.IsGraphOutput(node_unit.Outputs()[0].node_arg.Name()), + "QNN EP is configured to not take DQ nodes that generate a graph output."); + } } if (op_type == "QuantizeLinear") { @@ -171,6 +176,11 @@ Status SimpleOpBuilder::ExplicitOpCheck(QnnModelWrapper& qnn_model_wrapper, int64_t quant_axis = 0; ORT_RETURN_IF_ERROR(qnn_model_wrapper.IsPerChannelQuantized(node_unit.Outputs()[0], is_per_chan_quant, quant_axis)); ORT_RETURN_IF(is_per_chan_quant, "QNN EP does not support a standalone Q op with per-channel quantization"); + + if (qnn_model_wrapper.GetModelSettings().offload_graph_io_quantization) { + ORT_RETURN_IF(qnn_model_wrapper.IsGraphInput(node_unit.Inputs()[0].node_arg.Name()), + "QNN EP is configured to not take Q nodes that consume a graph input."); + } } return Status::OK(); diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model.cc b/onnxruntime/core/providers/qnn/builder/qnn_model.cc index f322456e0c8f0..b09ff51b666c7 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_model.cc @@ -95,6 +95,7 @@ const NodeUnit& QnnModel::GetNodeUnit(const Node* node, Status QnnModel::ComposeGraph(const GraphViewer& graph_viewer, const onnxruntime::Node& fused_node, + const qnn::ModelSettings& model_settings, const logging::Logger& logger, const QnnGraph_Config_t** graph_configs) { LOGS(logger, VERBOSE) << "ComposeGraph Graph name: " << graph_viewer.Name(); @@ -115,7 +116,8 @@ Status QnnModel::ComposeGraph(const GraphViewer& graph_viewer, model_input_index_map_, model_output_index_map_, initializer_inputs_, - qnn_backend_manager_->GetQnnBackendType()); + qnn_backend_manager_->GetQnnBackendType(), + model_settings); bool rt = true; rt = qnn_model_wrapper.CreateQnnGraph(qnn_backend_manager_->GetQnnContext(), graph_name, graph_configs); if (!rt) { diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model.h b/onnxruntime/core/providers/qnn/builder/qnn_model.h index 83cf8f9f08fb0..d9682cc3b3222 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_model.h @@ -35,6 +35,7 @@ class QnnModel { Status ComposeGraph(const GraphViewer& graph_viewer, const onnxruntime::Node& fused_node, + const qnn::ModelSettings& model_settings, const logging::Logger& logger, const QnnGraph_Config_t** graph_configs = nullptr); diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.h b/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.h index 9ab122b7f8e28..f3e52050e79e0 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.h @@ -29,6 +29,10 @@ struct TensorInfo { const ONNX_NAMESPACE::TensorProto* initializer_tensor; }; +struct ModelSettings { + bool offload_graph_io_quantization = false; +}; + class QnnModelWrapper { public: QnnModelWrapper(const GraphViewer& graph_viewer, @@ -38,7 +42,8 @@ class QnnModelWrapper { const std::unordered_map& input_index_map, const std::unordered_map& output_index_map, const std::unordered_set& initializer_lookup, - QnnBackendType qnn_backend_type) + QnnBackendType qnn_backend_type, + const ModelSettings& model_settings) : graph_viewer_(graph_viewer), logger_(logger), qnn_interface_(qnn_interface), @@ -46,12 +51,15 @@ class QnnModelWrapper { input_index_map_(input_index_map), output_index_map_(output_index_map), initializer_lookup_(initializer_lookup), - qnn_backend_type_(qnn_backend_type) { + qnn_backend_type_(qnn_backend_type), + model_settings_(model_settings) { } ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(QnnModelWrapper); ~QnnModelWrapper() = default; + const ModelSettings& GetModelSettings() const { return model_settings_; } + bool CreateQnnGraph(const Qnn_ContextHandle_t& context, const std::string& graph_name, const QnnGraph_Config_t** graph_configs = nullptr); @@ -279,6 +287,7 @@ class QnnModelWrapper { const std::unordered_map& output_index_map_; const std::unordered_set& initializer_lookup_; QnnBackendType qnn_backend_type_ = QnnBackendType::CPU; + ModelSettings model_settings_ = {}; }; // QnnModelWrapper } // namespace qnn diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc index 24132b98e3757..4cd5d403e95b8 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc @@ -161,6 +161,23 @@ static void ParseHtpArchitecture(const std::string& htp_arch_string, QnnHtpDevic } } +static bool ParseBoolOption(const std::string& key, bool default_value, + const std::unordered_map& options) { + bool result = default_value; + auto it = options.find(key); + if (it != options.end()) { + if ("1" == it->second) { + result = true; + } else if ("0" == it->second) { + result = false; + } else { + LOGS_DEFAULT(VERBOSE) << "Invalid value for " << key << " (" << it->second << "). Only 0 or 1 allowed."; + } + LOGS_DEFAULT(VERBOSE) << "Using " << key << ": " << result; + } + return result; +} + qnn::ProfilingLevel QNNExecutionProvider::GetProfilingLevelFromETWLevel(unsigned char level) { if (level == 5) { LOGS_DEFAULT(INFO) << "Overriding profiling to basic based on ETW level: " << static_cast(level); @@ -403,6 +420,15 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio LOGS_DEFAULT(VERBOSE) << "User specified enable_htp_weight_sharing: " << enable_htp_weight_sharing_; } + model_settings_.offload_graph_io_quantization = ParseBoolOption("offload_graph_io_quantization", false, + provider_options_map); + + if (disable_cpu_ep_fallback_ && model_settings_.offload_graph_io_quantization) { + LOGS_DEFAULT(WARNING) << "Fallback to CPU EP is disabled, but user configured QNN EP to offload graph I/O " + << "quantization/dequantization to another EP. Session creation will fail if the CPU EP " + << "handles the graph I/O quantization/dequantization."; + } + qnn_backend_manager_ = std::make_unique( std::move(backend_path), profiling_level_etw, @@ -499,7 +525,8 @@ QNNExecutionProvider::GetSupportedNodes(const GraphViewer& graph_viewer, model_input_index_map, model_output_index_map, initializer_input_lookup, - qnn_backend_manager_->GetQnnBackendType()); + qnn_backend_manager_->GetQnnBackendType(), + model_settings_); std::vector> qnn_node_groups; qnn_node_groups.reserve(node_unit_size); @@ -845,7 +872,8 @@ Status QNNExecutionProvider::CompileFromOrtGraph(const std::vectorComposeGraph(graph_viewer, fused_node, logger, graph_configs_builder.GetQnnConfigs())); + ORT_RETURN_IF_ERROR(qnn_model->ComposeGraph(graph_viewer, fused_node, model_settings_, logger, + graph_configs_builder.GetQnnConfigs())); ORT_RETURN_IF_ERROR(qnn_model->FinalizeGraphs(logger)); ORT_RETURN_IF_ERROR(qnn_model->SetupQnnInputOutput(logger)); diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.h b/onnxruntime/core/providers/qnn/qnn_execution_provider.h index e0eaf31c94a36..246ab1d5a6608 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.h +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.h @@ -153,6 +153,7 @@ class QNNExecutionProvider : public IExecutionProvider { #ifdef _WIN32 onnxruntime::logging::EtwRegistrationManager::EtwInternalCallback callback_ETWSink_provider_; #endif + qnn::ModelSettings model_settings_ = {}; class PerThreadContext final { public: diff --git a/onnxruntime/test/onnx/main.cc b/onnxruntime/test/onnx/main.cc index 6d86e4c35af85..93a1bf9f30651 100644 --- a/onnxruntime/test/onnx/main.cc +++ b/onnxruntime/test/onnx/main.cc @@ -77,6 +77,8 @@ void usage() { "\t [QNN only] [device_id]: The ID of the device to use when setting 'htp_arch'. Defaults to '0' (for single device). \n" "\t [QNN only] [enable_htp_fp16_precision]: Enable the HTP_FP16 precision so that the float32 model will be inferenced with fp16 precision. \n" "\t Otherwise, it will be fp32 precision. Works for float32 model for HTP backend. Defaults to '1' (with FP16 precision.). \n" + "\t [QNN only] [offload_graph_io_quantization]: Offload graph input quantization and graph output dequantization to another EP (typically CPU EP). \n" + "\t Defaults to '0' (QNN EP handles the graph I/O quantization and dequantization). \n" "\t [Usage]: -e -i '| |' \n\n" "\t [Example] [For QNN EP] -e qnn -i \"profiling_level|detailed backend_path|/folderpath/libQnnCpu.so\" \n\n" "\t [SNPE only] [runtime]: SNPE runtime, options: 'CPU', 'GPU', 'GPU_FLOAT16', 'DSP', 'AIP_FIXED_TF'. \n" @@ -587,20 +589,20 @@ int real_main(int argc, char* argv[], Ort::Env& env) { std::string str = str_stream.str(); ORT_THROW("Wrong value for htp_arch. select from: " + str); } - } else if (key == "enable_htp_fp16_precision") { + } else if (key == "enable_htp_fp16_precision" || key == "offload_graph_io_quantization") { std::unordered_set supported_options = {"0", "1"}; if (supported_options.find(value) == supported_options.end()) { std::ostringstream str_stream; std::copy(supported_options.begin(), supported_options.end(), std::ostream_iterator(str_stream, ",")); std::string str = str_stream.str(); - ORT_THROW("Wrong value for enable_htp_fp16_precision. select from: " + str); + ORT_THROW("Wrong value for ", key, ". select from: ", str); } } else { ORT_THROW(R"(Wrong key type entered. Choose from options: ['backend_path', 'profiling_level', 'profiling_file_path', 'rpc_control_latency', 'vtcm_mb', 'htp_performance_mode', 'qnn_saver_path', 'htp_graph_finalization_optimization_mode', 'qnn_context_priority', -'soc_model', 'htp_arch', 'device_id', 'enable_htp_fp16_precision'])"); +'soc_model', 'htp_arch', 'device_id', 'enable_htp_fp16_precision', 'offload_graph_io_quantization'])"); } qnn_options[key] = value; diff --git a/onnxruntime/test/perftest/command_args_parser.cc b/onnxruntime/test/perftest/command_args_parser.cc index 94945c0393d08..e40544d950ed7 100644 --- a/onnxruntime/test/perftest/command_args_parser.cc +++ b/onnxruntime/test/perftest/command_args_parser.cc @@ -98,6 +98,8 @@ namespace perftest { "\t [QNN only] [device_id]: The ID of the device to use when setting 'htp_arch'. Defaults to '0' (for single device). \n" "\t [QNN only] [enable_htp_fp16_precision]: Enable the HTP_FP16 precision so that the float32 model will be inferenced with fp16 precision. \n" "\t Otherwise, it will be fp32 precision. Works for float32 model for HTP backend. Defaults to '1' (with FP16 precision.). \n" + "\t [QNN only] [offload_graph_io_quantization]: Offload graph input quantization and graph output dequantization to another EP (typically CPU EP). \n" + "\t Defaults to '0' (QNN EP handles the graph I/O quantization and dequantization). \n" "\t [Example] [For QNN EP] -e qnn -i \"backend_path|/folderpath/libQnnCpu.so\" \n" "\n" "\t [TensorRT only] [trt_max_partition_iterations]: Maximum iterations for TensorRT parser to get capability.\n" diff --git a/onnxruntime/test/perftest/ort_test_session.cc b/onnxruntime/test/perftest/ort_test_session.cc index fcdef48eda56c..e69c87b2540e5 100644 --- a/onnxruntime/test/perftest/ort_test_session.cc +++ b/onnxruntime/test/perftest/ort_test_session.cc @@ -302,20 +302,20 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device std::string str = str_stream.str(); ORT_THROW("Wrong value for htp_arch. select from: " + str); } - } else if (key == "enable_htp_fp16_precision") { + } else if (key == "enable_htp_fp16_precision" || key == "offload_graph_io_quantization") { std::unordered_set supported_options = {"0", "1"}; if (supported_options.find(value) == supported_options.end()) { std::ostringstream str_stream; std::copy(supported_options.begin(), supported_options.end(), std::ostream_iterator(str_stream, ",")); std::string str = str_stream.str(); - ORT_THROW("Wrong value for " + key + ". select from: " + str); + ORT_THROW("Wrong value for ", key, ". select from: ", str); } } else { ORT_THROW(R"(Wrong key type entered. Choose from options: ['backend_path', 'profiling_level', 'profiling_file_path', 'rpc_control_latency', 'vtcm_mb', 'htp_performance_mode', 'qnn_saver_path', 'htp_graph_finalization_optimization_mode', 'qnn_context_priority', 'soc_model', -'htp_arch', 'device_id', 'enable_htp_fp16_precision'])"); +'htp_arch', 'device_id', 'enable_htp_fp16_precision', 'offload_graph_io_quantization'])"); } qnn_options[key] = value; diff --git a/onnxruntime/test/providers/qnn/qnn_basic_test.cc b/onnxruntime/test/providers/qnn/qnn_basic_test.cc index 236b66a2d8a78..e8282dbad9f72 100644 --- a/onnxruntime/test/providers/qnn/qnn_basic_test.cc +++ b/onnxruntime/test/providers/qnn/qnn_basic_test.cc @@ -1023,6 +1023,81 @@ TEST_F(QnnHTPBackendTests, EPRejectsDynamicShapesF32) { &ep_graph_checker); } +// Test option for offloading quantization of graph inputs and dequantization of graph outputs to the CPU EP. +TEST_F(QnnHTPBackendTests, EPOffloadsGraphIOQuantDequant) { + // Returns a function that checks that the Q/DQ ops at the graph IO boundary are offloaded to CPU + // if the corresponding provider option is enabled. + auto graph_checker_builder = [](bool offload_graph_io_quantization) -> std::function { + return [offload_graph_io_quantization](const Graph& graph) { + size_t num_q = 0; + size_t num_dq = 0; + size_t num_qnn_fused_node = 0; + + for (const Node& node : graph.Nodes()) { + const std::string& ep_name = node.GetExecutionProviderType(); + const std::string& op_type = node.OpType(); + + if (offload_graph_io_quantization && op_type == "QuantizeLinear") { + const bool consumes_graph_input = graph.IsInputsIncludingInitializers(node.InputDefs()[0]); + EXPECT_EQ(ep_name, kCpuExecutionProvider); + EXPECT_TRUE(consumes_graph_input); + num_q += 1; + } else if (offload_graph_io_quantization && op_type == "DequantizeLinear") { + const bool produces_graph_output = graph.IsOutput(node.OutputDefs()[0]); + EXPECT_EQ(ep_name, kCpuExecutionProvider); + EXPECT_TRUE(produces_graph_output); + num_dq += 1; + } else { + EXPECT_EQ(ep_name, kQnnExecutionProvider); + num_qnn_fused_node += 1; + } + } + + EXPECT_EQ(num_q, static_cast(offload_graph_io_quantization)); + EXPECT_EQ(num_dq, static_cast(offload_graph_io_quantization)); + EXPECT_EQ(num_qnn_fused_node, 1); + }; + }; + + ProviderOptions provider_options; +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + const std::vector op_types = { + "Sigmoid", + "Transpose", + "Softmax", + "Sqrt", + "Elu", + }; + + // Test various QDQ ops with offloading of I/O quantization enabled and disabled. + for (auto op_type : op_types) { + for (int offload_io_quant = 0; offload_io_quant <= 1; offload_io_quant++) { + provider_options["offload_graph_io_quantization"] = offload_io_quant ? "1" : "0"; + auto graph_checker = graph_checker_builder(offload_io_quant); + auto expected_ep_assignment = offload_io_quant ? ExpectedEPNodeAssignment::Some : ExpectedEPNodeAssignment::All; + + float min_val = (op_type == "Sqrt") ? 0.0f : -10.0f; + TestInputDef input_def({1, 2, 2, 2}, false, GetFloatDataInRange(min_val, 10.0f, 8)); + auto f32_model_build_fn = BuildOpTestCase(op_type, {input_def}, {}, {}); + auto qdq_model_build_fn = BuildQDQOpTestCase(op_type, {input_def}, {}, {}); + TestQDQModelAccuracy(f32_model_build_fn, + qdq_model_build_fn, + provider_options, + /*opset*/ 21, + expected_ep_assignment, + /*abs_err*/ QDQTolerance(), + logging::Severity::kERROR, + /*qnn_ctx_model_path*/ "", + /*session_option_pairs*/ {}, + &graph_checker); + } + } +} + #endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) #endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/test/providers/qnn/qnn_test_utils.cc b/onnxruntime/test/providers/qnn/qnn_test_utils.cc index 8a4f7f2a1f6b5..79e7d39e85518 100644 --- a/onnxruntime/test/providers/qnn/qnn_test_utils.cc +++ b/onnxruntime/test/providers/qnn/qnn_test_utils.cc @@ -134,7 +134,8 @@ void InferenceModel(const std::string& model_data, const char* log_id, ExpectedEPNodeAssignment expected_ep_assignment, const NameMLValMap& feeds, std::vector& output_vals, bool is_qnn_ep, - const std::unordered_map& session_option_pairs) { + const std::unordered_map& session_option_pairs, + std::function* graph_checker) { SessionOptions so; so.session_logid = log_id; for (auto key_value : session_option_pairs) { @@ -166,6 +167,10 @@ void InferenceModel(const std::string& model_data, const char* log_id, ASSERT_GT(ep_nodes, 0) << "No nodes were assigned to " << provider_type; } + if (graph_checker) { + (*graph_checker)(graph); + } + const auto& outputs = graph.GetOutputs(); std::vector output_names; diff --git a/onnxruntime/test/providers/qnn/qnn_test_utils.h b/onnxruntime/test/providers/qnn/qnn_test_utils.h index 7f55a44c748b6..a8670252ff9e0 100644 --- a/onnxruntime/test/providers/qnn/qnn_test_utils.h +++ b/onnxruntime/test/providers/qnn/qnn_test_utils.h @@ -457,13 +457,15 @@ DEF_QUANTIZE_VALUES_INT4_FUNC(UInt4x2, ParQuantizeLinearStdU4) * \param output_vals Initialized to the inference results. * \param is_qnn_ep Ture: QNN EP is used. False: CPU EP is used (default). * \param session_option_pairs extra session options. + * \param graph_checker Function called on the Graph. */ void InferenceModel(const std::string& model_data, const char* log_id, const ProviderOptions& provider_options, ExpectedEPNodeAssignment expected_ep_assignment, const NameMLValMap& feeds, std::vector& output_vals, bool is_qnn_ep = false, - const std::unordered_map& session_option_pairs = {}); + const std::unordered_map& session_option_pairs = {}, + std::function* graph_checker = nullptr); /** * If the ORT_UNIT_TEST_ENABLE_QNN_SAVER environment variable is enabled (set to 1), this function modifies @@ -515,6 +517,8 @@ struct QDQTolerance { * \param tolerance The percent tolerance (as fraction) QNN EP results are allowed to differ from the QDQ model * on CPU EP. This tolerance is a percentage of the output range. * \param log_severity The logger's severity setting. + * \param ep_graph_checker Function called on the Graph generated for the QNN EP's session. Used to check node + * EP assignment. */ template inline void TestQDQModelAccuracy(const GetTestModelFn& f32_model_fn, const GetTestQDQModelFn& qdq_model_fn, @@ -523,7 +527,8 @@ inline void TestQDQModelAccuracy(const GetTestModelFn& f32_model_fn, const GetTe QDQTolerance tolerance = QDQTolerance(), logging::Severity log_severity = logging::Severity::kERROR, const std::string& qnn_ctx_model_path = "", - const std::unordered_map& session_option_pairs = {}) { + const std::unordered_map& session_option_pairs = {}, + std::function* qnn_ep_graph_checker = nullptr) { // Add kMSDomain to cover contrib op like Gelu const std::unordered_map domain_to_version = {{"", opset_version}, {kMSDomain, 1}}; @@ -607,7 +612,7 @@ inline void TestQDQModelAccuracy(const GetTestModelFn& f32_model_fn, const GetTe // Run QDQ model on QNN EP and collect outputs. // Only need to apply the extra session options to this QDQ model inference on QNN EP InferenceModel(qdq_model_data, "qdq_model_logger", qnn_options, expected_ep_assignment, - qdq_helper.feeds_, qnn_qdq_outputs, is_qnn_ep, session_option_pairs); + qdq_helper.feeds_, qnn_qdq_outputs, is_qnn_ep, session_option_pairs, qnn_ep_graph_checker); } if (expected_ep_assignment != ExpectedEPNodeAssignment::None) { diff --git a/onnxruntime/test/qnn_ctx_gen/command_args_parser.cc b/onnxruntime/test/qnn_ctx_gen/command_args_parser.cc index 102846e08ac5f..5b3720992c542 100644 --- a/onnxruntime/test/qnn_ctx_gen/command_args_parser.cc +++ b/onnxruntime/test/qnn_ctx_gen/command_args_parser.cc @@ -48,6 +48,8 @@ namespace qnnctxgen { "\t [enable_htp_fp16_precision]: Enable the HTP_FP16 precision so that the float32 model will be inferenced with fp16 precision. \n" "\t Otherwise, it will be fp32 precision. Works for float32 model for HTP backend. Defaults to '1' (with FP16 precision.). \n" "\t [enable_htp_weight_sharing]: Allows common weights across graphs to be shared and stored in a single context binary. Defaults to '1' (enabled).\n" + "\t [offload_graph_io_quantization]: Offload graph input quantization and graph output dequantization to another EP (typically CPU EP). \n" + "\t Defaults to '0' (QNN EP handles the graph I/O quantization and dequantization). \n" "\t [Example] -i \"vtcm_mb|8 htp_arch|73\" \n" "\n" "\t-h: help\n"); @@ -143,7 +145,8 @@ static bool ParseSessionConfigs(const std::string& configs_string, std::string str = str_stream.str(); ORT_THROW("Wrong value for htp_graph_finalization_optimization_mode. select from: " + str); } - } else if (key == "enable_htp_fp16_precision" || key == "enable_htp_weight_sharing") { + } else if (key == "enable_htp_fp16_precision" || key == "enable_htp_weight_sharing" || + key == "offload_graph_io_quantization") { std::unordered_set supported_options = {"0", "1"}; if (supported_options.find(value) == supported_options.end()) { std::ostringstream str_stream; @@ -154,7 +157,8 @@ static bool ParseSessionConfigs(const std::string& configs_string, } } else { ORT_THROW(R"(Wrong key type entered. Choose from options: ['backend_path', 'vtcm_mb', 'htp_performance_mode', - 'htp_graph_finalization_optimization_mode', 'soc_model', 'htp_arch', 'enable_htp_fp16_precision', 'enable_htp_weight_sharing'])"); + 'htp_graph_finalization_optimization_mode', 'soc_model', 'htp_arch', 'enable_htp_fp16_precision', 'enable_htp_weight_sharing', + 'offload_graph_io_quantization'])"); } test_config.run_config.qnn_options[key] = value;