Skip to content

Commit

Permalink
[QNN EP] Enable QNN HTP VTCM size setting (#18653)
Browse files Browse the repository at this point in the history
### Description
[QNN EP] Enable QNN HTP VTCM size setting
  • Loading branch information
HectorSVC authored Dec 1, 2023
1 parent 9c9e6ad commit ccfea55
Show file tree
Hide file tree
Showing 6 changed files with 76 additions and 55 deletions.
1 change: 1 addition & 0 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -3598,6 +3598,7 @@ struct OrtApi {
* "qnn_context_cache_path": explicitly provide the QNN context cache file. Default to model_file.onnx.bin if not provided.
* "profiling_level": QNN profiling level, options: "off", "basic", "detailed". Default to off.
* "rpc_control_latency": QNN RPC control latency.
* "vtcm_mb": QNN VTCM size in MB. default to 0(not set).
* "htp_performance_mode": QNN performance mode, options: "burst", "balanced", "default", "high_performance",
* "high_power_saver", "low_balanced", "low_power_saver", "power_saver", "sustained_high_performance". Default to "default".
* "qnn_context_embed_mode", 1 means dump the QNN context binary into node attribute EPContext->ep_cache_context in the ONNX skeleton model.
Expand Down
106 changes: 66 additions & 40 deletions onnxruntime/core/providers/qnn/qnn_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,68 +22,70 @@ namespace onnxruntime {

constexpr const char* QNN = "QNN";

void QNNExecutionProvider::ParseProfilingLevel(std::string profiling_level_string) {
static void ParseProfilingLevel(std::string profiling_level_string,
qnn::ProfilingLevel& profiling_level) {
std::transform(profiling_level_string.begin(),
profiling_level_string.end(),
profiling_level_string.begin(),
[](unsigned char c) { return static_cast<unsigned char>(std::tolower(c)); });
LOGS_DEFAULT(VERBOSE) << "profiling_level: " << profiling_level_string;
if (profiling_level_string == "off") {
profiling_level_ = qnn::ProfilingLevel::OFF;
profiling_level = qnn::ProfilingLevel::OFF;
} else if (profiling_level_string == "basic") {
profiling_level_ = qnn::ProfilingLevel::BASIC;
profiling_level = qnn::ProfilingLevel::BASIC;
} else if (profiling_level_string == "detailed") {
profiling_level_ = qnn::ProfilingLevel::DETAILED;
profiling_level = qnn::ProfilingLevel::DETAILED;
} else {
LOGS_DEFAULT(WARNING) << "Profiling level not valid.";
}
}

void QNNExecutionProvider::ParseHtpPerformanceMode(std::string htp_performance_mode_string) {
static void ParseHtpPerformanceMode(std::string htp_performance_mode_string,
qnn::HtpPerformanceMode& htp_performance_mode) {
std::transform(htp_performance_mode_string.begin(),
htp_performance_mode_string.end(),
htp_performance_mode_string.begin(),
[](unsigned char c) { return static_cast<unsigned char>(std::tolower(c)); });
LOGS_DEFAULT(VERBOSE) << "Htp performance mode: " << htp_performance_mode_string;
if (htp_performance_mode_string == "burst") {
htp_performance_mode_ = qnn::HtpPerformanceMode::kHtpBurst;
htp_performance_mode = qnn::HtpPerformanceMode::kHtpBurst;
} else if (htp_performance_mode_string == "balanced") {
htp_performance_mode_ = qnn::HtpPerformanceMode::kHtpBalanced;
htp_performance_mode = qnn::HtpPerformanceMode::kHtpBalanced;
} else if (htp_performance_mode_string == "default") {
htp_performance_mode_ = qnn::HtpPerformanceMode::kHtpDefault;
htp_performance_mode = qnn::HtpPerformanceMode::kHtpDefault;
} else if (htp_performance_mode_string == "high_performance") {
htp_performance_mode_ = qnn::HtpPerformanceMode::kHtpHighPerformance;
htp_performance_mode = qnn::HtpPerformanceMode::kHtpHighPerformance;
} else if (htp_performance_mode_string == "high_power_saver") {
htp_performance_mode_ = qnn::HtpPerformanceMode::kHtpHighPowerSaver;
htp_performance_mode = qnn::HtpPerformanceMode::kHtpHighPowerSaver;
} else if (htp_performance_mode_string == "low_balanced") {
htp_performance_mode_ = qnn::HtpPerformanceMode::kHtpLowBalanced;
htp_performance_mode = qnn::HtpPerformanceMode::kHtpLowBalanced;
} else if (htp_performance_mode_string == "low_power_saver") {
htp_performance_mode_ = qnn::HtpPerformanceMode::kHtpLowPowerSaver;
htp_performance_mode = qnn::HtpPerformanceMode::kHtpLowPowerSaver;
} else if (htp_performance_mode_string == "power_saver") {
htp_performance_mode_ = qnn::HtpPerformanceMode::kHtpPowerSaver;
htp_performance_mode = qnn::HtpPerformanceMode::kHtpPowerSaver;
} else if (htp_performance_mode_string == "sustained_high_performance") {
htp_performance_mode_ = qnn::HtpPerformanceMode::kHtpSustainedHighPerformance;
htp_performance_mode = qnn::HtpPerformanceMode::kHtpSustainedHighPerformance;
} else {
LOGS_DEFAULT(WARNING) << "Htp performance mode not valid.";
}
}

void QNNExecutionProvider::ParseQnnContextPriority(std::string context_priority_string) {
static void ParseQnnContextPriority(std::string context_priority_string, qnn::ContextPriority& context_priority) {
std::transform(context_priority_string.begin(),
context_priority_string.end(),
context_priority_string.begin(),
[](unsigned char c) { return static_cast<unsigned char>(std::tolower(c)); });
LOGS_DEFAULT(VERBOSE) << "QNN context priority: " << context_priority_string;
if (context_priority_string == "low") {
context_priority_ = qnn::ContextPriority::LOW;
context_priority = qnn::ContextPriority::LOW;
} else if (context_priority_string == "normal") {
context_priority_ = qnn::ContextPriority::NORMAL;
context_priority = qnn::ContextPriority::NORMAL;
} else if (context_priority_string == "normal_high") {
context_priority_ = qnn::ContextPriority::NORMAL_HIGH;
context_priority = qnn::ContextPriority::NORMAL_HIGH;
} else if (context_priority_string == "high") {
context_priority_ = qnn::ContextPriority::HIGH;
context_priority = qnn::ContextPriority::HIGH;
} else {
context_priority_ = qnn::ContextPriority::UNDEFINED;
context_priority = qnn::ContextPriority::UNDEFINED;
LOGS_DEFAULT(WARNING) << "QNN context priority: " << context_priority_string << " not valid, set to undefined.";
}
}
Expand Down Expand Up @@ -149,23 +151,25 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio
}

static const std::string PROFILING_LEVEL = "profiling_level";
qnn::ProfilingLevel profiling_level = qnn::ProfilingLevel::OFF;
auto profiling_level_pos = provider_options_map.find(PROFILING_LEVEL);
if (profiling_level_pos != provider_options_map.end()) {
ParseProfilingLevel(profiling_level_pos->second);
ParseProfilingLevel(profiling_level_pos->second, profiling_level);
}

static const std::string RPC_CONTROL_LANTENCY = "rpc_control_latency";
uint32_t rpc_control_latency = 0;
auto latency_pos = provider_options_map.find(RPC_CONTROL_LANTENCY);
if (latency_pos != provider_options_map.end()) {
rpc_control_latency_ = static_cast<uint32_t>(std::stoul(latency_pos->second));
LOGS_DEFAULT(VERBOSE) << "rpc_control_latency: " << rpc_control_latency_;
rpc_control_latency = static_cast<uint32_t>(std::stoul(latency_pos->second));
LOGS_DEFAULT(VERBOSE) << "rpc_control_latency: " << rpc_control_latency;
}

htp_performance_mode_ = qnn::HtpPerformanceMode::kHtpDefault;
qnn::HtpPerformanceMode htp_performance_mode = qnn::HtpPerformanceMode::kHtpDefault;
static const std::string HTP_PERFORMANCE_MODE = "htp_performance_mode";
auto htp_performance_mode_pos = provider_options_map.find(HTP_PERFORMANCE_MODE);
if (htp_performance_mode_pos != provider_options_map.end()) {
ParseHtpPerformanceMode(htp_performance_mode_pos->second);
ParseHtpPerformanceMode(htp_performance_mode_pos->second, htp_performance_mode);
}

htp_graph_finalization_opt_mode_ = qnn::HtpGraphFinalizationOptimizationMode::kDefault;
Expand All @@ -185,17 +189,28 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio
}

static const std::string QNN_CONTEXT_PRIORITY = "qnn_context_priority";
qnn::ContextPriority context_priority = qnn::ContextPriority::NORMAL;
auto qnn_context_priority_pos = provider_options_map.find(QNN_CONTEXT_PRIORITY);
if (qnn_context_priority_pos != provider_options_map.end()) {
ParseQnnContextPriority(qnn_context_priority_pos->second);
ParseQnnContextPriority(qnn_context_priority_pos->second, context_priority);
}

static const std::string QNN_VTCM_MB = "vtcm_mb";
auto qnn_vtcm_mb_pos = provider_options_map.find(QNN_VTCM_MB);
if (qnn_vtcm_mb_pos != provider_options_map.end()) {
vtcm_size_in_mb_ = std::stoi(qnn_vtcm_mb_pos->second);
LOGS_DEFAULT(VERBOSE) << "vtcm_mb: " << vtcm_size_in_mb_;
if (vtcm_size_in_mb_ <= 0) {
LOGS_DEFAULT(WARNING) << "Skip invalid vtcm_mb: " << vtcm_size_in_mb_;
}
}

qnn_backend_manager_ = std::make_unique<qnn::QnnBackendManager>(
std::move(backend_path),
profiling_level_,
rpc_control_latency_,
htp_performance_mode_,
context_priority_,
profiling_level,
rpc_control_latency,
htp_performance_mode,
context_priority,
std::move(qnn_saver_path));
}

Expand Down Expand Up @@ -480,16 +495,27 @@ Status QNNExecutionProvider::CreateComputeFunc(std::vector<NodeComputeInfo>& nod
}

void QNNExecutionProvider::InitQnnGraphConfigs(qnn::QnnGraphConfigsBuilder& configs_builder) const {
if (qnn_backend_manager_->GetQnnBackendType() == qnn::QnnBackendType::HTP &&
htp_graph_finalization_opt_mode_ != qnn::HtpGraphFinalizationOptimizationMode::kDefault) {
QnnHtpGraph_CustomConfig_t& htp_graph_opt_config = configs_builder.PushHtpGraphCustomConfig();
htp_graph_opt_config.option = QNN_HTP_GRAPH_CONFIG_OPTION_OPTIMIZATION;
htp_graph_opt_config.optimizationOption.type = QNN_HTP_GRAPH_OPTIMIZATION_TYPE_FINALIZE_OPTIMIZATION_FLAG;
htp_graph_opt_config.optimizationOption.floatValue = static_cast<float>(htp_graph_finalization_opt_mode_);

QnnGraph_Config_t& graph_opt_config = configs_builder.PushGraphConfig();
graph_opt_config.option = QNN_GRAPH_CONFIG_OPTION_CUSTOM;
graph_opt_config.customConfig = &htp_graph_opt_config;
if (qnn_backend_manager_->GetQnnBackendType() == qnn::QnnBackendType::HTP) {
if (htp_graph_finalization_opt_mode_ != qnn::HtpGraphFinalizationOptimizationMode::kDefault) {
QnnHtpGraph_CustomConfig_t& htp_graph_opt_config = configs_builder.PushHtpGraphCustomConfig();
htp_graph_opt_config.option = QNN_HTP_GRAPH_CONFIG_OPTION_OPTIMIZATION;
htp_graph_opt_config.optimizationOption.type = QNN_HTP_GRAPH_OPTIMIZATION_TYPE_FINALIZE_OPTIMIZATION_FLAG;
htp_graph_opt_config.optimizationOption.floatValue = static_cast<float>(htp_graph_finalization_opt_mode_);

QnnGraph_Config_t& graph_opt_config = configs_builder.PushGraphConfig();
graph_opt_config.option = QNN_GRAPH_CONFIG_OPTION_CUSTOM;
graph_opt_config.customConfig = &htp_graph_opt_config;
}

if (vtcm_size_in_mb_ > 0) {
QnnHtpGraph_CustomConfig_t& htp_graph_opt_config_vtcm = configs_builder.PushHtpGraphCustomConfig();
htp_graph_opt_config_vtcm.option = QNN_HTP_GRAPH_CONFIG_OPTION_VTCM_SIZE;
htp_graph_opt_config_vtcm.vtcmSizeInMB = static_cast<uint32_t>(vtcm_size_in_mb_);

QnnGraph_Config_t& graph_opt_config_vtcm = configs_builder.PushGraphConfig();
graph_opt_config_vtcm.option = QNN_GRAPH_CONFIG_OPTION_CUSTOM;
graph_opt_config_vtcm.customConfig = &htp_graph_opt_config_vtcm;
}
}
}

Expand Down
10 changes: 1 addition & 9 deletions onnxruntime/core/providers/qnn/qnn_execution_provider.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,6 @@ class QNNExecutionProvider : public IExecutionProvider {
DataLayout GetPreferredLayout() const override;

private:
void ParseProfilingLevel(std::string profiling_level_string);

bool IsNodeSupported(qnn::QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit,
std::unordered_map<const NodeUnit*, bool>& node_unit_supported_result,
const logging::Logger& logger) const;
Expand All @@ -55,25 +53,19 @@ class QNNExecutionProvider : public IExecutionProvider {
std::vector<NodeComputeInfo>& node_compute_funcs,
const logging::Logger& logger);

void ParseHtpPerformanceMode(std::string htp_performance_mode_string);
void ParseQnnContextPriority(std::string context_priority_string);

void ParseHtpGraphFinalizationOptimizationMode(const std::string& htp_graph_finalization_opt_mode_string);

void InitQnnGraphConfigs(qnn::QnnGraphConfigsBuilder& configs_holder) const;

private:
qnn::ProfilingLevel profiling_level_ = qnn::ProfilingLevel::OFF;
qnn::HtpPerformanceMode htp_performance_mode_ = qnn::HtpPerformanceMode::kHtpDefault;
qnn::HtpGraphFinalizationOptimizationMode htp_graph_finalization_opt_mode_ = qnn::HtpGraphFinalizationOptimizationMode::kDefault;
std::unique_ptr<qnn::QnnBackendManager> qnn_backend_manager_;
std::unordered_map<std::string, std::unique_ptr<qnn::QnnModel>> qnn_models_;
uint32_t rpc_control_latency_ = 0;
bool context_cache_enabled_ = false;
std::string context_cache_path_cfg_ = "";
bool disable_cpu_ep_fallback_ = false; // True if CPU EP fallback has been disabled for this session.
qnn::ContextPriority context_priority_ = qnn::ContextPriority::NORMAL;
bool qnn_context_embed_mode_ = true;
int32_t vtcm_size_in_mb_ = 0;
};

} // namespace onnxruntime
7 changes: 4 additions & 3 deletions onnxruntime/test/onnx/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ void usage() {
"\t [QNN only] [qnn_context_cache_path]: File path to the qnn context cache. Default to model_file.onnx.bin if not set.\n"
"\t [QNN only] [profiling_level]: QNN profiling level, options: 'basic', 'detailed', default 'off'.\n"
"\t [QNN only] [rpc_control_latency]: QNN rpc control latency. default to 10.\n"
"\t [QNN only] [vtcm_mb]: QNN VTCM size in MB. default to 0(not set).\n"
"\t [QNN only] [htp_performance_mode]: QNN performance mode, options: 'burst', 'balanced', 'default', 'high_performance', \n"
"\t 'high_power_saver', 'low_balanced', 'low_power_saver', 'power_saver', 'sustained_high_performance'. Default to 'default'. \n"
"\t [QNN only] [qnn_context_priority]: QNN context priority, options: 'low', 'normal', 'normal_high', 'high'. Default to 'normal'. \n"
Expand Down Expand Up @@ -476,7 +477,7 @@ int real_main(int argc, char* argv[], Ort::Env& env) {
if (supported_profiling_level.find(value) == supported_profiling_level.end()) {
ORT_THROW("Supported profiling_level: off, basic, detailed");
}
} else if (key == "rpc_control_latency") {
} else if (key == "rpc_control_latency" || key == "vtcm_mb") {
// no validation
} else if (key == "htp_performance_mode") {
std::set<std::string> supported_htp_perf_mode = {"burst", "balanced", "default", "high_performance",
Expand Down Expand Up @@ -507,8 +508,8 @@ int real_main(int argc, char* argv[], Ort::Env& env) {
}
} else {
ORT_THROW(R"(Wrong key type entered. Choose from options: ['backend_path', 'qnn_context_cache_enable',
'qnn_context_cache_path', 'profiling_level', 'rpc_control_latency', 'htp_performance_mode', 'qnn_saver_path',
'htp_graph_finalization_optimization_mode', 'qnn_context_priority'])");
'qnn_context_cache_path', 'profiling_level', 'rpc_control_latency', 'vtcm_mb', 'htp_performance_mode',
'qnn_saver_path', 'htp_graph_finalization_optimization_mode', 'qnn_context_priority'])");
}

qnn_options[key] = value;
Expand Down
1 change: 1 addition & 0 deletions onnxruntime/test/perftest/command_args_parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ namespace perftest {
"\t [QNN only] [qnn_context_cache_path]: File path to the qnn context cache. Default to model_file.onnx.bin if not set.\n"
"\t [QNN only] [profiling_level]: QNN profiling level, options: 'basic', 'detailed', default 'off'.\n"
"\t [QNN only] [rpc_control_latency]: QNN rpc control latency. default to 10.\n"
"\t [QNN only] [vtcm_mb]: QNN VTCM size in MB. default to 0(not set).\n"
"\t [QNN only] [htp_performance_mode]: QNN performance mode, options: 'burst', 'balanced', 'default', 'high_performance', \n"
"\t 'high_power_saver', 'low_balanced', 'low_power_saver', 'power_saver', 'sustained_high_performance'. Default to 'default'. \n"
"\t [QNN only] [qnn_context_priority]: QNN context priority, options: 'low', 'normal', 'normal_high', 'high'. Default to 'normal'. \n"
Expand Down
6 changes: 3 additions & 3 deletions onnxruntime/test/perftest/ort_test_session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device
if (supported_profiling_level.find(value) == supported_profiling_level.end()) {
ORT_THROW("Supported profiling_level: off, basic, detailed");
}
} else if (key == "rpc_control_latency") {
} else if (key == "rpc_control_latency" || key == "vtcm_mb") {
// no validation
} else if (key == "htp_performance_mode") {
std::set<std::string> supported_htp_perf_mode = {"burst", "balanced", "default", "high_performance",
Expand Down Expand Up @@ -374,8 +374,8 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device
}
} else {
ORT_THROW(R"(Wrong key type entered. Choose from options: ['backend_path', 'qnn_context_cache_enable',
'qnn_context_cache_path', 'profiling_level', 'rpc_control_latency', 'htp_performance_mode', 'qnn_saver_path',
'htp_graph_finalization_optimization_mode', 'qnn_context_priority'])");
'qnn_context_cache_path', 'profiling_level', 'rpc_control_latency', 'vtcm_mb', 'htp_performance_mode',
'qnn_saver_path', 'htp_graph_finalization_optimization_mode', 'qnn_context_priority'])");
}

qnn_options[key] = value;
Expand Down

0 comments on commit ccfea55

Please sign in to comment.