diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index cddad732104ed..c41700453a73b 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -3598,6 +3598,7 @@ struct OrtApi { * "qnn_context_cache_path": explicitly provide the QNN context cache file. Default to model_file.onnx.bin if not provided. * "profiling_level": QNN profiling level, options: "off", "basic", "detailed". Default to off. * "rpc_control_latency": QNN RPC control latency. + * "vtcm_mb": QNN VTCM size in MB. default to 0(not set). * "htp_performance_mode": QNN performance mode, options: "burst", "balanced", "default", "high_performance", * "high_power_saver", "low_balanced", "low_power_saver", "power_saver", "sustained_high_performance". Default to "default". * "qnn_context_embed_mode", 1 means dump the QNN context binary into node attribute EPContext->ep_cache_context in the ONNX skeleton model. 
diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc index c7b309ae471c9..60f7bbe08cb6a 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc @@ -22,68 +22,70 @@ namespace onnxruntime { constexpr const char* QNN = "QNN"; -void QNNExecutionProvider::ParseProfilingLevel(std::string profiling_level_string) { +static void ParseProfilingLevel(std::string profiling_level_string, + qnn::ProfilingLevel& profiling_level) { std::transform(profiling_level_string.begin(), profiling_level_string.end(), profiling_level_string.begin(), [](unsigned char c) { return static_cast<unsigned char>(std::tolower(c)); }); LOGS_DEFAULT(VERBOSE) << "profiling_level: " << profiling_level_string; if (profiling_level_string == "off") { - profiling_level_ = qnn::ProfilingLevel::OFF; + profiling_level = qnn::ProfilingLevel::OFF; } else if (profiling_level_string == "basic") { - profiling_level_ = qnn::ProfilingLevel::BASIC; + profiling_level = qnn::ProfilingLevel::BASIC; } else if (profiling_level_string == "detailed") { - profiling_level_ = qnn::ProfilingLevel::DETAILED; + profiling_level = qnn::ProfilingLevel::DETAILED; } else { LOGS_DEFAULT(WARNING) << "Profiling level not valid."; } } -void QNNExecutionProvider::ParseHtpPerformanceMode(std::string htp_performance_mode_string) { +static void ParseHtpPerformanceMode(std::string htp_performance_mode_string, + qnn::HtpPerformanceMode& htp_performance_mode) { std::transform(htp_performance_mode_string.begin(), htp_performance_mode_string.end(), htp_performance_mode_string.begin(), [](unsigned char c) { return static_cast<unsigned char>(std::tolower(c)); }); LOGS_DEFAULT(VERBOSE) << "Htp performance mode: " << htp_performance_mode_string; if (htp_performance_mode_string == "burst") { - htp_performance_mode_ = qnn::HtpPerformanceMode::kHtpBurst; + htp_performance_mode = qnn::HtpPerformanceMode::kHtpBurst; } else if 
(htp_performance_mode_string == "balanced") { - htp_performance_mode_ = qnn::HtpPerformanceMode::kHtpBalanced; + htp_performance_mode = qnn::HtpPerformanceMode::kHtpBalanced; } else if (htp_performance_mode_string == "default") { - htp_performance_mode_ = qnn::HtpPerformanceMode::kHtpDefault; + htp_performance_mode = qnn::HtpPerformanceMode::kHtpDefault; } else if (htp_performance_mode_string == "high_performance") { - htp_performance_mode_ = qnn::HtpPerformanceMode::kHtpHighPerformance; + htp_performance_mode = qnn::HtpPerformanceMode::kHtpHighPerformance; } else if (htp_performance_mode_string == "high_power_saver") { - htp_performance_mode_ = qnn::HtpPerformanceMode::kHtpHighPowerSaver; + htp_performance_mode = qnn::HtpPerformanceMode::kHtpHighPowerSaver; } else if (htp_performance_mode_string == "low_balanced") { - htp_performance_mode_ = qnn::HtpPerformanceMode::kHtpLowBalanced; + htp_performance_mode = qnn::HtpPerformanceMode::kHtpLowBalanced; } else if (htp_performance_mode_string == "low_power_saver") { - htp_performance_mode_ = qnn::HtpPerformanceMode::kHtpLowPowerSaver; + htp_performance_mode = qnn::HtpPerformanceMode::kHtpLowPowerSaver; } else if (htp_performance_mode_string == "power_saver") { - htp_performance_mode_ = qnn::HtpPerformanceMode::kHtpPowerSaver; + htp_performance_mode = qnn::HtpPerformanceMode::kHtpPowerSaver; } else if (htp_performance_mode_string == "sustained_high_performance") { - htp_performance_mode_ = qnn::HtpPerformanceMode::kHtpSustainedHighPerformance; + htp_performance_mode = qnn::HtpPerformanceMode::kHtpSustainedHighPerformance; } else { LOGS_DEFAULT(WARNING) << "Htp performance mode not valid."; } } -void QNNExecutionProvider::ParseQnnContextPriority(std::string context_priority_string) { +static void ParseQnnContextPriority(std::string context_priority_string, qnn::ContextPriority& context_priority) { std::transform(context_priority_string.begin(), context_priority_string.end(), context_priority_string.begin(), [](unsigned 
char c) { return static_cast<unsigned char>(std::tolower(c)); }); LOGS_DEFAULT(VERBOSE) << "QNN context priority: " << context_priority_string; if (context_priority_string == "low") { - context_priority_ = qnn::ContextPriority::LOW; + context_priority = qnn::ContextPriority::LOW; } else if (context_priority_string == "normal") { - context_priority_ = qnn::ContextPriority::NORMAL; + context_priority = qnn::ContextPriority::NORMAL; } else if (context_priority_string == "normal_high") { - context_priority_ = qnn::ContextPriority::NORMAL_HIGH; + context_priority = qnn::ContextPriority::NORMAL_HIGH; } else if (context_priority_string == "high") { - context_priority_ = qnn::ContextPriority::HIGH; + context_priority = qnn::ContextPriority::HIGH; } else { - context_priority_ = qnn::ContextPriority::UNDEFINED; + context_priority = qnn::ContextPriority::UNDEFINED; LOGS_DEFAULT(WARNING) << "QNN context priority: " << context_priority_string << " not valid, set to undefined."; } } @@ -149,23 +151,25 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio } static const std::string PROFILING_LEVEL = "profiling_level"; + qnn::ProfilingLevel profiling_level = qnn::ProfilingLevel::OFF; auto profiling_level_pos = provider_options_map.find(PROFILING_LEVEL); if (profiling_level_pos != provider_options_map.end()) { - ParseProfilingLevel(profiling_level_pos->second); + ParseProfilingLevel(profiling_level_pos->second, profiling_level); } static const std::string RPC_CONTROL_LANTENCY = "rpc_control_latency"; + uint32_t rpc_control_latency = 0; auto latency_pos = provider_options_map.find(RPC_CONTROL_LANTENCY); if (latency_pos != provider_options_map.end()) { - rpc_control_latency_ = static_cast<uint32_t>(std::stoul(latency_pos->second)); - LOGS_DEFAULT(VERBOSE) << "rpc_control_latency: " << rpc_control_latency_; + rpc_control_latency = static_cast<uint32_t>(std::stoul(latency_pos->second)); + LOGS_DEFAULT(VERBOSE) << "rpc_control_latency: " << rpc_control_latency; } - htp_performance_mode_ = 
qnn::HtpPerformanceMode::kHtpDefault; + qnn::HtpPerformanceMode htp_performance_mode = qnn::HtpPerformanceMode::kHtpDefault; static const std::string HTP_PERFORMANCE_MODE = "htp_performance_mode"; auto htp_performance_mode_pos = provider_options_map.find(HTP_PERFORMANCE_MODE); if (htp_performance_mode_pos != provider_options_map.end()) { - ParseHtpPerformanceMode(htp_performance_mode_pos->second); + ParseHtpPerformanceMode(htp_performance_mode_pos->second, htp_performance_mode); } htp_graph_finalization_opt_mode_ = qnn::HtpGraphFinalizationOptimizationMode::kDefault; @@ -185,17 +189,28 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio } static const std::string QNN_CONTEXT_PRIORITY = "qnn_context_priority"; + qnn::ContextPriority context_priority = qnn::ContextPriority::NORMAL; auto qnn_context_priority_pos = provider_options_map.find(QNN_CONTEXT_PRIORITY); if (qnn_context_priority_pos != provider_options_map.end()) { - ParseQnnContextPriority(qnn_context_priority_pos->second); + ParseQnnContextPriority(qnn_context_priority_pos->second, context_priority); + } + + static const std::string QNN_VTCM_MB = "vtcm_mb"; + auto qnn_vtcm_mb_pos = provider_options_map.find(QNN_VTCM_MB); + if (qnn_vtcm_mb_pos != provider_options_map.end()) { + vtcm_size_in_mb_ = std::stoi(qnn_vtcm_mb_pos->second); + LOGS_DEFAULT(VERBOSE) << "vtcm_mb: " << vtcm_size_in_mb_; + if (vtcm_size_in_mb_ <= 0) { + LOGS_DEFAULT(WARNING) << "Skip invalid vtcm_mb: " << vtcm_size_in_mb_; + } } qnn_backend_manager_ = std::make_unique<qnn::QnnBackendManager>( std::move(backend_path), - profiling_level_, - rpc_control_latency_, - htp_performance_mode_, - context_priority_, + profiling_level, + rpc_control_latency, + htp_performance_mode, + context_priority, std::move(qnn_saver_path)); } @@ -480,16 +495,27 @@ Status QNNExecutionProvider::CreateComputeFunc(std::vector<NodeComputeInfo>& node_compute_funcs, } void QNNExecutionProvider::InitQnnGraphConfigs(qnn::QnnGraphConfigsBuilder& configs_builder) const { - if 
(qnn_backend_manager_->GetQnnBackendType() == qnn::QnnBackendType::HTP && - htp_graph_finalization_opt_mode_ != qnn::HtpGraphFinalizationOptimizationMode::kDefault) { - QnnHtpGraph_CustomConfig_t& htp_graph_opt_config = configs_builder.PushHtpGraphCustomConfig(); - htp_graph_opt_config.option = QNN_HTP_GRAPH_CONFIG_OPTION_OPTIMIZATION; - htp_graph_opt_config.optimizationOption.type = QNN_HTP_GRAPH_OPTIMIZATION_TYPE_FINALIZE_OPTIMIZATION_FLAG; - htp_graph_opt_config.optimizationOption.floatValue = static_cast<float>(htp_graph_finalization_opt_mode_); - - QnnGraph_Config_t& graph_opt_config = configs_builder.PushGraphConfig(); - graph_opt_config.option = QNN_GRAPH_CONFIG_OPTION_CUSTOM; - graph_opt_config.customConfig = &htp_graph_opt_config; + if (qnn_backend_manager_->GetQnnBackendType() == qnn::QnnBackendType::HTP) { + if (htp_graph_finalization_opt_mode_ != qnn::HtpGraphFinalizationOptimizationMode::kDefault) { + QnnHtpGraph_CustomConfig_t& htp_graph_opt_config = configs_builder.PushHtpGraphCustomConfig(); + htp_graph_opt_config.option = QNN_HTP_GRAPH_CONFIG_OPTION_OPTIMIZATION; + htp_graph_opt_config.optimizationOption.type = QNN_HTP_GRAPH_OPTIMIZATION_TYPE_FINALIZE_OPTIMIZATION_FLAG; + htp_graph_opt_config.optimizationOption.floatValue = static_cast<float>(htp_graph_finalization_opt_mode_); + + QnnGraph_Config_t& graph_opt_config = configs_builder.PushGraphConfig(); + graph_opt_config.option = QNN_GRAPH_CONFIG_OPTION_CUSTOM; + graph_opt_config.customConfig = &htp_graph_opt_config; + } + + if (vtcm_size_in_mb_ > 0) { + QnnHtpGraph_CustomConfig_t& htp_graph_opt_config_vtcm = configs_builder.PushHtpGraphCustomConfig(); + htp_graph_opt_config_vtcm.option = QNN_HTP_GRAPH_CONFIG_OPTION_VTCM_SIZE; + htp_graph_opt_config_vtcm.vtcmSizeInMB = static_cast<uint32_t>(vtcm_size_in_mb_); + + QnnGraph_Config_t& graph_opt_config_vtcm = configs_builder.PushGraphConfig(); + graph_opt_config_vtcm.option = QNN_GRAPH_CONFIG_OPTION_CUSTOM; + graph_opt_config_vtcm.customConfig = &htp_graph_opt_config_vtcm; + 
} } } diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.h b/onnxruntime/core/providers/qnn/qnn_execution_provider.h index 8c99a916a6f69..8b5d0929209ee 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.h +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.h @@ -36,8 +36,6 @@ class QNNExecutionProvider : public IExecutionProvider { DataLayout GetPreferredLayout() const override; private: - void ParseProfilingLevel(std::string profiling_level_string); - bool IsNodeSupported(qnn::QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit, std::unordered_map<const NodeUnit*, bool>& node_unit_supported_result, const logging::Logger& logger) const; @@ -55,25 +53,19 @@ class QNNExecutionProvider : public IExecutionProvider { std::vector<NodeComputeInfo>& node_compute_funcs, const logging::Logger& logger); - void ParseHtpPerformanceMode(std::string htp_performance_mode_string); - void ParseQnnContextPriority(std::string context_priority_string); - void ParseHtpGraphFinalizationOptimizationMode(const std::string& htp_graph_finalization_opt_mode_string); void InitQnnGraphConfigs(qnn::QnnGraphConfigsBuilder& configs_holder) const; private: - qnn::ProfilingLevel profiling_level_ = qnn::ProfilingLevel::OFF; - qnn::HtpPerformanceMode htp_performance_mode_ = qnn::HtpPerformanceMode::kHtpDefault; qnn::HtpGraphFinalizationOptimizationMode htp_graph_finalization_opt_mode_ = qnn::HtpGraphFinalizationOptimizationMode::kDefault; std::unique_ptr<qnn::QnnBackendManager> qnn_backend_manager_; std::unordered_map<std::string, std::unique_ptr<qnn::QnnModel>> qnn_models_; - uint32_t rpc_control_latency_ = 0; bool context_cache_enabled_ = false; std::string context_cache_path_cfg_ = ""; bool disable_cpu_ep_fallback_ = false; // True if CPU EP fallback has been disabled for this session. 
- qnn::ContextPriority context_priority_ = qnn::ContextPriority::NORMAL; bool qnn_context_embed_mode_ = true; + int32_t vtcm_size_in_mb_ = 0; }; } // namespace onnxruntime diff --git a/onnxruntime/test/onnx/main.cc b/onnxruntime/test/onnx/main.cc index 2c0804397cfe8..646ff7c95b229 100644 --- a/onnxruntime/test/onnx/main.cc +++ b/onnxruntime/test/onnx/main.cc @@ -54,6 +54,7 @@ void usage() { "\t [QNN only] [qnn_context_cache_path]: File path to the qnn context cache. Default to model_file.onnx.bin if not set.\n" "\t [QNN only] [profiling_level]: QNN profiling level, options: 'basic', 'detailed', default 'off'.\n" "\t [QNN only] [rpc_control_latency]: QNN rpc control latency. default to 10.\n" + "\t [QNN only] [vtcm_mb]: QNN VTCM size in MB. default to 0(not set).\n" "\t [QNN only] [htp_performance_mode]: QNN performance mode, options: 'burst', 'balanced', 'default', 'high_performance', \n" "\t 'high_power_saver', 'low_balanced', 'low_power_saver', 'power_saver', 'sustained_high_performance'. Default to 'default'. \n" "\t [QNN only] [qnn_context_priority]: QNN context priority, options: 'low', 'normal', 'normal_high', 'high'. Default to 'normal'. \n" @@ -476,7 +477,7 @@ int real_main(int argc, char* argv[], Ort::Env& env) { if (supported_profiling_level.find(value) == supported_profiling_level.end()) { ORT_THROW("Supported profiling_level: off, basic, detailed"); } - } else if (key == "rpc_control_latency") { + } else if (key == "rpc_control_latency" || key == "vtcm_mb") { // no validation } else if (key == "htp_performance_mode") { std::set<std::string> supported_htp_perf_mode = {"burst", "balanced", "default", "high_performance", @@ -507,8 +508,8 @@ } } else { ORT_THROW(R"(Wrong key type entered. 
Choose from options: ['backend_path', 'qnn_context_cache_enable', -'qnn_context_cache_path', 'profiling_level', 'rpc_control_latency', 'htp_performance_mode', 'qnn_saver_path', -'htp_graph_finalization_optimization_mode', 'qnn_context_priority'])"); +'qnn_context_cache_path', 'profiling_level', 'rpc_control_latency', 'vtcm_mb', 'htp_performance_mode', +'qnn_saver_path', 'htp_graph_finalization_optimization_mode', 'qnn_context_priority'])"); } qnn_options[key] = value; diff --git a/onnxruntime/test/perftest/command_args_parser.cc b/onnxruntime/test/perftest/command_args_parser.cc index a72a0d105eefc..27e26fe0b3c45 100644 --- a/onnxruntime/test/perftest/command_args_parser.cc +++ b/onnxruntime/test/perftest/command_args_parser.cc @@ -69,6 +69,7 @@ namespace perftest { "\t [QNN only] [qnn_context_cache_path]: File path to the qnn context cache. Default to model_file.onnx.bin if not set.\n" "\t [QNN only] [profiling_level]: QNN profiling level, options: 'basic', 'detailed', default 'off'.\n" "\t [QNN only] [rpc_control_latency]: QNN rpc control latency. default to 10.\n" + "\t [QNN only] [vtcm_mb]: QNN VTCM size in MB. default to 0(not set).\n" "\t [QNN only] [htp_performance_mode]: QNN performance mode, options: 'burst', 'balanced', 'default', 'high_performance', \n" "\t 'high_power_saver', 'low_balanced', 'low_power_saver', 'power_saver', 'sustained_high_performance'. Default to 'default'. \n" "\t [QNN only] [qnn_context_priority]: QNN context priority, options: 'low', 'normal', 'normal_high', 'high'. Default to 'normal'. 
\n" diff --git a/onnxruntime/test/perftest/ort_test_session.cc b/onnxruntime/test/perftest/ort_test_session.cc index c2dd81ec9f359..eb2a77c07f803 100644 --- a/onnxruntime/test/perftest/ort_test_session.cc +++ b/onnxruntime/test/perftest/ort_test_session.cc @@ -343,7 +343,7 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device if (supported_profiling_level.find(value) == supported_profiling_level.end()) { ORT_THROW("Supported profiling_level: off, basic, detailed"); } - } else if (key == "rpc_control_latency") { + } else if (key == "rpc_control_latency" || key == "vtcm_mb") { // no validation } else if (key == "htp_performance_mode") { std::set<std::string> supported_htp_perf_mode = {"burst", "balanced", "default", "high_performance", @@ -374,8 +374,8 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device } } else { ORT_THROW(R"(Wrong key type entered. Choose from options: ['backend_path', 'qnn_context_cache_enable', -'qnn_context_cache_path', 'profiling_level', 'rpc_control_latency', 'htp_performance_mode', 'qnn_saver_path', -'htp_graph_finalization_optimization_mode', 'qnn_context_priority'])"); +'qnn_context_cache_path', 'profiling_level', 'rpc_control_latency', 'vtcm_mb', 'htp_performance_mode', +'qnn_saver_path', 'htp_graph_finalization_optimization_mode', 'qnn_context_priority'])"); } qnn_options[key] = value;