diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h
index 1d02b72342722..c7d4a236bcf89 100644
--- a/include/onnxruntime/core/session/onnxruntime_c_api.h
+++ b/include/onnxruntime/core/session/onnxruntime_c_api.h
@@ -3604,6 +3604,7 @@ struct OrtApi {
    *   "qnn_saver_path": File path to the QNN Saver backend library. If specified, QNN Saver will be enabled and will
    *   dump QNN API calls to disk for replay/debugging. QNN Saver produces incorrect model inference results and
    *   may alter model/EP partitioning. Use only for debugging.
+   *   "qnn_context_priority": QNN context priority, options: "low", "normal", "normal_high", "high". Default to "normal".
    *   "htp_graph_finalization_optimization_mode": Set the optimization mode for graph finalization on the HTP backend. Available options:
    *   - "0": Default.
    *   - "1": Faster preparation time, less optimal graph.
diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc
index dd56731ac9f7f..03d6b46c528c3 100644
--- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc
+++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc
@@ -380,15 +380,48 @@ Status QnnBackendManager::ReleaseProfilehandle() {
   return Status::OK();
 }
 
+Status SetQnnContextConfig(ContextPriority context_priority, QnnContext_Config_t& qnn_context_config) {
+  qnn_context_config.option = QNN_CONTEXT_CONFIG_OPTION_PRIORITY;
+  switch (context_priority) {
+    case ContextPriority::LOW: {
+      qnn_context_config.priority = QNN_PRIORITY_LOW;
+      break;
+    }
+    case ContextPriority::NORMAL: {
+      qnn_context_config.priority = QNN_PRIORITY_NORMAL;
+      break;
+    }
+    case ContextPriority::NORMAL_HIGH: {
+      qnn_context_config.priority = QNN_PRIORITY_NORMAL_HIGH;
+      break;
+    }
+    case ContextPriority::HIGH: {
+      qnn_context_config.priority = QNN_PRIORITY_HIGH;
+      break;
+    }
+    case ContextPriority::UNDEFINED: {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid Qnn context priority.");
+    }
+    default:
+      qnn_context_config.priority = QNN_PRIORITY_NORMAL;
+  }  // switch
+
+  return Status::OK();
+}
+
 Status QnnBackendManager::CreateContext() {
   if (true == context_created_) {
     LOGS_DEFAULT(INFO) << "Context created already.";
     return Status::OK();
   }
 
+  QnnContext_Config_t qnn_context_config = QNN_CONTEXT_CONFIG_INIT;
+  ORT_RETURN_IF_ERROR(SetQnnContextConfig(context_priority_, qnn_context_config));
+  const QnnContext_Config_t* context_configs[] = {&qnn_context_config, nullptr};
+
   auto result = qnn_interface_.contextCreate(backend_handle_,
                                              device_handle_,
-                                             (const QnnContext_Config_t**)&context_config_,
+                                             context_configs,
                                              &context_);
 
   ORT_RETURN_IF(QNN_CONTEXT_NO_ERROR != result, "Failed to create context.");
@@ -486,9 +519,14 @@ Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t
   ORT_RETURN_IF(nullptr == qnn_interface_.contextCreateFromBinary,
                 "Invalid function pointer for contextCreateFromBinary.");
 
+  QnnContext_Config_t qnn_context_config = QNN_CONTEXT_CONFIG_INIT;
+  ORT_RETURN_IF_ERROR(SetQnnContextConfig(context_priority_, qnn_context_config));
+  const QnnContext_Config_t* context_configs[] = {&qnn_context_config, nullptr};
+
   rt = qnn_interface_.contextCreateFromBinary(backend_handle_,
                                               device_handle_,
-                                              (const QnnContext_Config_t**)&context_config_,
+                                              context_configs,
                                               static_cast<void*>(buffer),
                                               buffer_length,
                                               &context_,
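Note on the hunks above: the old code passed `(const QnnContext_Config_t**)&context_config_`, a cast of a member pointer that was never populated, while the new code builds a stack array terminated by `nullptr`, which is the shape the QNN context APIs expect. A minimal sketch of the same pattern with more than one config, assuming only the QNN SDK types already used in this diff (`other_config` is illustrative, not part of this change):

    QnnContext_Config_t priority_config = QNN_CONTEXT_CONFIG_INIT;
    priority_config.option = QNN_CONTEXT_CONFIG_OPTION_PRIORITY;
    priority_config.priority = QNN_PRIORITY_HIGH;
    QnnContext_Config_t other_config = QNN_CONTEXT_CONFIG_INIT;  // hypothetical second config
    // The list itself is null-terminated; the API receives the array of pointers.
    const QnnContext_Config_t* configs[] = {&priority_config, &other_config, nullptr};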
diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h
index de5ccb5a28389..aac82c89d6f49 100644
--- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h
+++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h
@@ -30,11 +30,13 @@ class QnnBackendManager {
                     ProfilingLevel profiling_level,
                     uint32_t rpc_control_latency,
                     HtpPerformanceMode htp_performance_mode,
+                    ContextPriority context_priority,
                     std::string&& qnn_saver_path)
       : backend_path_(backend_path),
         profiling_level_(profiling_level),
         rpc_control_latency_(rpc_control_latency),
         htp_performance_mode_(htp_performance_mode),
+        context_priority_(context_priority),
         qnn_saver_path_(qnn_saver_path) {
   }
   ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(QnnBackendManager);
@@ -186,7 +188,6 @@ class QnnBackendManager {
   Qnn_LogHandle_t log_handle_ = nullptr;
   Qnn_DeviceHandle_t device_handle_ = nullptr;
   Qnn_ContextHandle_t context_ = nullptr;
-  QnnContext_Config_t** context_config_ = nullptr;
   ProfilingLevel profiling_level_;
   bool backend_initialized_ = false;
   bool device_created_ = false;
@@ -198,6 +199,7 @@ class QnnBackendManager {
   std::vector<std::string> op_package_paths_;
   uint32_t rpc_control_latency_ = 0;
   HtpPerformanceMode htp_performance_mode_;
+  ContextPriority context_priority_;
   std::string sdk_build_version_ = "";
 #ifdef _WIN32
   std::set<HMODULE> mod_handles_;
diff --git a/onnxruntime/core/providers/qnn/builder/qnn_def.h b/onnxruntime/core/providers/qnn/builder/qnn_def.h
index 6080c63b555a8..66154fcf346ee 100644
--- a/onnxruntime/core/providers/qnn/builder/qnn_def.h
+++ b/onnxruntime/core/providers/qnn/builder/qnn_def.h
@@ -48,6 +48,14 @@ enum class HtpPerformanceMode : uint8_t {
   kHtpBalanced,
 };
 
+enum class ContextPriority : uint8_t {
+  LOW = 0,
+  NORMAL,
+  NORMAL_HIGH,
+  HIGH,
+  UNDEFINED
+};
+
 // Defines the graph optimization strategy used by the HTP backend.
 enum class HtpGraphFinalizationOptimizationMode : uint8_t {
   kDefault = 0,
diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc
index 6cb276378a09c..8acd0d68b71d0 100644
--- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc
+++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc
@@ -76,6 +76,26 @@ void QNNExecutionProvider::ParseHtpPerformanceMode(std::string htp_performance_m
   }
 }
 
+void QNNExecutionProvider::ParseQnnContextPriority(std::string context_priority_string) {
+  std::transform(context_priority_string.begin(),
+                 context_priority_string.end(),
+                 context_priority_string.begin(),
+                 [](unsigned char c) { return static_cast<unsigned char>(std::tolower(c)); });
+  LOGS_DEFAULT(VERBOSE) << "QNN context priority: " << context_priority_string;
+  if (context_priority_string == "low") {
+    context_priority_ = qnn::ContextPriority::LOW;
+  } else if (context_priority_string == "normal") {
+    context_priority_ = qnn::ContextPriority::NORMAL;
+  } else if (context_priority_string == "normal_high") {
+    context_priority_ = qnn::ContextPriority::NORMAL_HIGH;
+  } else if (context_priority_string == "high") {
+    context_priority_ = qnn::ContextPriority::HIGH;
+  } else {
+    context_priority_ = qnn::ContextPriority::UNDEFINED;
+    LOGS_DEFAULT(WARNING) << "QNN context priority: " << context_priority_string << " not valid, set to undefined.";
+  }
+}
+
 void QNNExecutionProvider::ParseHtpGraphFinalizationOptimizationMode(const std::string& htp_graph_finalization_opt_mode_string) {
   LOGS_DEFAULT(VERBOSE) << "HTP graph finalization optimization mode: "
                         << htp_graph_finalization_opt_mode_string;
@@ -96,16 +116,15 @@ void QNNExecutionProvider::ParseHtpGraphFinalizationOptimizationMode(const std::
 
 QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_options_map,
                                            const SessionOptions* session_options)
-    : IExecutionProvider{onnxruntime::kQnnExecutionProvider, true},
-      runtime_options_(provider_options_map) {
+    : IExecutionProvider{onnxruntime::kQnnExecutionProvider, true} {
   if (session_options) {
     disable_cpu_ep_fallback_ = session_options->config_options.GetConfigOrDefault(
         kOrtSessionOptionsDisableCPUEPFallback, "0") == "1";
   }
 
   static const std::string CONTEXT_CACHE_ENABLED = "qnn_context_cache_enable";
-  auto context_cache_enabled_pos = runtime_options_.find(CONTEXT_CACHE_ENABLED);
-  if (context_cache_enabled_pos != runtime_options_.end()) {
+  auto context_cache_enabled_pos = provider_options_map.find(CONTEXT_CACHE_ENABLED);
+  if (context_cache_enabled_pos != provider_options_map.end()) {
     if (context_cache_enabled_pos->second == "1") {
       context_cache_enabled_ = true;
       LOGS_DEFAULT(VERBOSE) << "Context cache enabled.";
@@ -113,25 +132,25 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio
   }
 
   static const std::string CONTEXT_CACHE_PATH = "qnn_context_cache_path";
-  auto context_cache_path_pos = runtime_options_.find(CONTEXT_CACHE_PATH);
-  if (context_cache_path_pos != runtime_options_.end()) {
+  auto context_cache_path_pos = provider_options_map.find(CONTEXT_CACHE_PATH);
+  if (context_cache_path_pos != provider_options_map.end()) {
     context_cache_path_ = context_cache_path_pos->second;
     LOGS_DEFAULT(VERBOSE) << "User specified context cache path: " << context_cache_path_;
   }
 
   bool qnn_context_embed_mode = true;
   static const std::string CONTEXT_CACHE_EMBED_MODE = "qnn_context_embed_mode";
-  auto context_cache_embed_mode_pos = runtime_options_.find(CONTEXT_CACHE_EMBED_MODE);
-  if (context_cache_embed_mode_pos != runtime_options_.end()) {
+  auto context_cache_embed_mode_pos = provider_options_map.find(CONTEXT_CACHE_EMBED_MODE);
+  if (context_cache_embed_mode_pos != provider_options_map.end()) {
     qnn_context_embed_mode = context_cache_embed_mode_pos->second == "1";
     LOGS_DEFAULT(VERBOSE) << "User specified context cache embed mode: " << qnn_context_embed_mode;
   }
 
   static const std::string BACKEND_PATH = "backend_path";
-  auto backend_path_pos = runtime_options_.find(BACKEND_PATH);
+  auto backend_path_pos = provider_options_map.find(BACKEND_PATH);
   std::string backend_path;
-  if (backend_path_pos != runtime_options_.end()) {
+  if (backend_path_pos != provider_options_map.end()) {
     backend_path = backend_path_pos->second;
     LOGS_DEFAULT(VERBOSE) << "Backend path: " << backend_path;
   } else {
@@ -139,46 +158,53 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio
   }
 
   static const std::string PROFILING_LEVEL = "profiling_level";
-  auto profiling_level_pos = runtime_options_.find(PROFILING_LEVEL);
-  if (profiling_level_pos != runtime_options_.end()) {
+  auto profiling_level_pos = provider_options_map.find(PROFILING_LEVEL);
+  if (profiling_level_pos != provider_options_map.end()) {
     ParseProfilingLevel(profiling_level_pos->second);
   }
 
   static const std::string RPC_CONTROL_LANTENCY = "rpc_control_latency";
-  auto latency_pos = runtime_options_.find(RPC_CONTROL_LANTENCY);
-  if (latency_pos != runtime_options_.end()) {
+  auto latency_pos = provider_options_map.find(RPC_CONTROL_LANTENCY);
+  if (latency_pos != provider_options_map.end()) {
     rpc_control_latency_ = static_cast<uint32_t>(std::stoul(latency_pos->second));
     LOGS_DEFAULT(VERBOSE) << "rpc_control_latency: " << rpc_control_latency_;
   }
 
   htp_performance_mode_ = qnn::HtpPerformanceMode::kHtpDefault;
   static const std::string HTP_PERFORMANCE_MODE = "htp_performance_mode";
-  auto htp_performance_mode_pos = runtime_options_.find(HTP_PERFORMANCE_MODE);
-  if (htp_performance_mode_pos != runtime_options_.end()) {
+  auto htp_performance_mode_pos = provider_options_map.find(HTP_PERFORMANCE_MODE);
+  if (htp_performance_mode_pos != provider_options_map.end()) {
     ParseHtpPerformanceMode(htp_performance_mode_pos->second);
   }
 
   htp_graph_finalization_opt_mode_ = qnn::HtpGraphFinalizationOptimizationMode::kDefault;
   static const std::string HTP_GRAPH_FINALIZATION_OPT_MODE = "htp_graph_finalization_optimization_mode";
-  auto htp_graph_finalization_opt_mode_pos = runtime_options_.find(HTP_GRAPH_FINALIZATION_OPT_MODE);
-  if (htp_graph_finalization_opt_mode_pos != runtime_options_.end()) {
+  auto htp_graph_finalization_opt_mode_pos = provider_options_map.find(HTP_GRAPH_FINALIZATION_OPT_MODE);
+  if (htp_graph_finalization_opt_mode_pos != provider_options_map.end()) {
     ParseHtpGraphFinalizationOptimizationMode(htp_graph_finalization_opt_mode_pos->second);
   }
 
   // Enable use of QNN Saver if the user provides a path to the QNN Saver backend library.
   static const std::string QNN_SAVER_PATH_KEY = "qnn_saver_path";
   std::string qnn_saver_path;
-  auto qnn_saver_path_pos = runtime_options_.find(QNN_SAVER_PATH_KEY);
-  if (qnn_saver_path_pos != runtime_options_.end()) {
+  auto qnn_saver_path_pos = provider_options_map.find(QNN_SAVER_PATH_KEY);
+  if (qnn_saver_path_pos != provider_options_map.end()) {
     qnn_saver_path = qnn_saver_path_pos->second;
     LOGS_DEFAULT(VERBOSE) << "User specified QNN Saver path: " << qnn_saver_path;
   }
 
+  static const std::string QNN_CONTEXT_PRIORITY = "qnn_context_priority";
+  auto qnn_context_priority_pos = provider_options_map.find(QNN_CONTEXT_PRIORITY);
+  if (qnn_context_priority_pos != provider_options_map.end()) {
+    ParseQnnContextPriority(qnn_context_priority_pos->second);
+  }
+
   qnn_backend_manager_ = std::make_unique<qnn::QnnBackendManager>(
       std::move(backend_path),
       profiling_level_,
       rpc_control_latency_,
       htp_performance_mode_,
+      context_priority_,
       std::move(qnn_saver_path));
   qnn_cache_model_handler_ = std::make_unique<qnn::QnnCacheModelHandler>(qnn_context_embed_mode);
 }
diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.h b/onnxruntime/core/providers/qnn/qnn_execution_provider.h
index a01b828531555..cf0bff8890d0c 100644
--- a/onnxruntime/core/providers/qnn/qnn_execution_provider.h
+++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.h
@@ -57,13 +57,13 @@ class QNNExecutionProvider : public IExecutionProvider {
                          const logging::Logger& logger);
 
   void ParseHtpPerformanceMode(std::string htp_performance_mode_string);
+  void ParseQnnContextPriority(std::string context_priority_string);
   void ParseHtpGraphFinalizationOptimizationMode(const std::string& htp_graph_finalization_opt_mode_string);
 
   void InitQnnGraphConfigs(qnn::QnnGraphConfigsBuilder& configs_holder) const;
 
  private:
-  ProviderOptions runtime_options_;
   qnn::ProfilingLevel profiling_level_ = qnn::ProfilingLevel::OFF;
   qnn::HtpPerformanceMode htp_performance_mode_ = qnn::HtpPerformanceMode::kHtpDefault;
   qnn::HtpGraphFinalizationOptimizationMode htp_graph_finalization_opt_mode_ = qnn::HtpGraphFinalizationOptimizationMode::kDefault;
@@ -74,6 +74,7 @@ class QNNExecutionProvider : public IExecutionProvider {
   std::string context_cache_path_ = "";
   bool disable_cpu_ep_fallback_ = false;  // True if CPU EP fallback has been disabled for this session.
   std::unique_ptr<qnn::QnnCacheModelHandler> qnn_cache_model_handler_;
+  qnn::ContextPriority context_priority_ = qnn::ContextPriority::NORMAL;
 };
 
 }  // namespace onnxruntime
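With the provider changes above in place, application code selects the new option through the same `AppendExecutionProvider("QNN", ...)` call the tests below use. A minimal usage sketch; the model and backend paths are placeholders:

    #include <string>
    #include <unordered_map>
    #include "onnxruntime_cxx_api.h"

    Ort::Env env;
    Ort::SessionOptions so;
    std::unordered_map<std::string, std::string> qnn_options;
    qnn_options["backend_path"] = "QnnHtp.dll";    // placeholder backend library path
    qnn_options["qnn_context_priority"] = "high";  // option introduced by this change
    so.AppendExecutionProvider("QNN", qnn_options);
    Ort::Session session(env, ORT_TSTR("model.onnx"), so);

Note the failure mode this buys: an unrecognized value parses to `ContextPriority::UNDEFINED`, and `SetQnnContextConfig` then fails context creation with `INVALID_ARGUMENT`, so a typo surfaces at session-creation time instead of silently running at normal priority.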
diff --git a/onnxruntime/test/onnx/main.cc b/onnxruntime/test/onnx/main.cc
index 98646058eec3d..2c0804397cfe8 100644
--- a/onnxruntime/test/onnx/main.cc
+++ b/onnxruntime/test/onnx/main.cc
@@ -56,6 +56,7 @@ void usage() {
       "\t [QNN only] [rpc_control_latency]: QNN rpc control latency. default to 10.\n"
       "\t [QNN only] [htp_performance_mode]: QNN performance mode, options: 'burst', 'balanced', 'default', 'high_performance', \n"
       "\t 'high_power_saver', 'low_balanced', 'low_power_saver', 'power_saver', 'sustained_high_performance'. Default to 'default'. \n"
+      "\t [QNN only] [qnn_context_priority]: QNN context priority, options: 'low', 'normal', 'normal_high', 'high'. Default to 'normal'. \n"
       "\t [QNN only] [qnn_context_embed_mode]: 1 means dump the QNN context binary into the Onnx skeleton model.\n"
       "\t 0 means dump the QNN context binary into separate bin file and set the path in the Onnx skeleton model.\n"
       "\t [QNN only] [qnn_saver_path]: QNN Saver backend path. e.g '/folderpath/libQnnSaver.so'.\n"
@@ -488,6 +489,11 @@ int real_main(int argc, char* argv[], Ort::Env& env) {
               std::string str = str_stream.str();
              ORT_THROW("Wrong value for htp_performance_mode. select from: " + str);
             }
+          } else if (key == "qnn_context_priority") {
+            std::set<std::string> supported_qnn_context_priority = {"low", "normal", "normal_high", "high"};
+            if (supported_qnn_context_priority.find(value) == supported_qnn_context_priority.end()) {
+              ORT_THROW("Supported qnn_context_priority: low, normal, normal_high, high");
+            }
           } else if (key == "qnn_saver_path") {
             // no validation
           } else if (key == "htp_graph_finalization_optimization_mode") {
@@ -502,7 +508,7 @@ int real_main(int argc, char* argv[], Ort::Env& env) {
           } else {
             ORT_THROW(R"(Wrong key type entered. Choose from options: ['backend_path', 'qnn_context_cache_enable',
'qnn_context_cache_path', 'profiling_level', 'rpc_control_latency', 'htp_performance_mode', 'qnn_saver_path',
-'htp_graph_finalization_optimization_mode'])");
+'htp_graph_finalization_optimization_mode', 'qnn_context_priority'])");
           }
 
           qnn_options[key] = value;
diff --git a/onnxruntime/test/perftest/command_args_parser.cc b/onnxruntime/test/perftest/command_args_parser.cc
index 72472e5798792..a72a0d105eefc 100644
--- a/onnxruntime/test/perftest/command_args_parser.cc
+++ b/onnxruntime/test/perftest/command_args_parser.cc
@@ -34,8 +34,8 @@ namespace perftest {
      "\t-A: Disable memory arena\n"
       "\t-I: Generate tensor input binding (Free dimensions are treated as 1.)\n"
       "\t-c [parallel runs]: Specifies the (max) number of runs to invoke simultaneously. Default:1.\n"
-      "\t-e [cpu|cuda|dnnl|tensorrt|openvino|dml|acl|nnapi|coreml|snpe|rocm|migraphx|xnnpack|vitisai]: Specifies the provider 'cpu','cuda','dnnl','tensorrt', "
-      "'openvino', 'dml', 'acl', 'nnapi', 'coreml', 'snpe', 'rocm', 'migraphx', 'xnnpack' or 'vitisai'. "
+      "\t-e [cpu|cuda|dnnl|tensorrt|openvino|dml|acl|nnapi|coreml|qnn|snpe|rocm|migraphx|xnnpack|vitisai]: Specifies the provider 'cpu','cuda','dnnl','tensorrt', "
+      "'openvino', 'dml', 'acl', 'nnapi', 'coreml', 'qnn', 'snpe', 'rocm', 'migraphx', 'xnnpack' or 'vitisai'. "
       "Default:'cpu'.\n"
       "\t-b [tf|ort]: backend to use. Default:ort\n"
       "\t-r [repeated_times]: Specifies the repeated times if running in 'times' test mode.Default:1000.\n"
@@ -71,6 +71,7 @@ namespace perftest {
       "\t [QNN only] [rpc_control_latency]: QNN rpc control latency. default to 10.\n"
       "\t [QNN only] [htp_performance_mode]: QNN performance mode, options: 'burst', 'balanced', 'default', 'high_performance', \n"
       "\t 'high_power_saver', 'low_balanced', 'low_power_saver', 'power_saver', 'sustained_high_performance'. Default to 'default'. \n"
+      "\t [QNN only] [qnn_context_priority]: QNN context priority, options: 'low', 'normal', 'normal_high', 'high'. Default to 'normal'. \n"
       "\t [QNN only] [qnn_saver_path]: QNN Saver backend path. e.g '/folderpath/libQnnSaver.so'.\n"
       "\t [QNN only] [htp_graph_finalization_optimization_mode]: QNN graph finalization optimization mode, options: \n"
       "\t '0', '1', '2', '3', default is '0'.\n"
diff --git a/onnxruntime/test/perftest/ort_test_session.cc b/onnxruntime/test/perftest/ort_test_session.cc
index f3ea188043dbe..c2dd81ec9f359 100644
--- a/onnxruntime/test/perftest/ort_test_session.cc
+++ b/onnxruntime/test/perftest/ort_test_session.cc
@@ -367,10 +367,15 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device
         std::string str = str_stream.str();
         ORT_THROW("Wrong value for htp_graph_finalization_optimization_mode. select from: " + str);
       }
+    } else if (key == "qnn_context_priority") {
+      std::set<std::string> supported_qnn_context_priority = {"low", "normal", "normal_high", "high"};
+      if (supported_qnn_context_priority.find(value) == supported_qnn_context_priority.end()) {
+        ORT_THROW("Supported qnn_context_priority: low, normal, normal_high, high");
+      }
     } else {
       ORT_THROW(R"(Wrong key type entered. Choose from options: ['backend_path', 'qnn_context_cache_enable',
'qnn_context_cache_path', 'profiling_level', 'rpc_control_latency', 'htp_performance_mode', 'qnn_saver_path',
-'htp_graph_finalization_optimization_mode'])");
+'htp_graph_finalization_optimization_mode', 'qnn_context_priority'])");
     }
 
     qnn_options[key] = value;
diff --git a/onnxruntime/test/providers/qnn/qnn_basic_test.cc b/onnxruntime/test/providers/qnn/qnn_basic_test.cc
index 02ff834169b2b..2e2acb36e8071 100644
--- a/onnxruntime/test/providers/qnn/qnn_basic_test.cc
+++ b/onnxruntime/test/providers/qnn/qnn_basic_test.cc
@@ -174,7 +174,8 @@ TEST(QnnEP, TestDisableCPUFallback_ConflictingConfig) {
 // shape inferencing issues on QNN. Thus, the models are expected to have a specific input/output
 // types and shapes.
 static void RunNHWCResizeModel(const ORTCHAR_T* ort_model_path, bool use_htp, bool enable_qnn_saver = false,
-                               std::string htp_graph_finalization_opt_mode = "") {
+                               std::string htp_graph_finalization_opt_mode = "",
+                               std::string qnn_context_priority = "") {
   Ort::SessionOptions so;
 
   // Ensure all type/shape inference warnings result in errors!
@@ -199,6 +200,10 @@ static void RunNHWCResizeModel(const ORTCHAR_T* ort_model_path, bool use_htp, bo
     options["htp_graph_finalization_optimization_mode"] = std::move(htp_graph_finalization_opt_mode);
   }
 
+  if (!qnn_context_priority.empty()) {
+    options["qnn_context_priority"] = std::move(qnn_context_priority);
+  }
+
   so.AppendExecutionProvider("QNN", options);
 
   Ort::Session session(*ort_env, ort_model_path, so);
@@ -322,6 +327,15 @@ TEST_F(QnnHTPBackendTests, HTPGraphFinalizationOptimizationModes) {
   }
 }
 
+// Test that models run with high QNN context priority.
+TEST_F(QnnHTPBackendTests, QnnContextPriorityHigh) {
+  RunNHWCResizeModel(ORT_MODEL_FOLDER "nhwc_resize_sizes_opset18.quant.onnx",
+                     true,     // use_htp
+                     false,    // enable_qnn_saver
+                     "",       // htp_graph_finalization_opt_mode
+                     "high");  // qnn_context_priority
+}
+
 #endif  // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__)
 
 #endif  // !defined(ORT_MINIMAL_BUILD)
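For completeness, the perf-test harness accepts EP options as `key|value` pairs through its `-i` flag; that flag is outside this diff, so treat the exact invocation below as an assumption, with the backend library and model paths as placeholders:

    onnxruntime_perf_test -e qnn -i "backend_path|QnnHtp.dll qnn_context_priority|high" model.onnx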