
[QNN EP] Enable option to set QNN context priority #18315

Merged: 5 commits, Nov 9, 2023
1 change: 1 addition & 0 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
@@ -3604,6 +3604,7 @@ struct OrtApi {
* "qnn_saver_path": File path to the QNN Saver backend library. If specified, QNN Saver will be enabled and will
* dump QNN API calls to disk for replay/debugging. QNN Saver produces incorrect model inference results and
* may alter model/EP partitioning. Use only for debugging.
* "qnn_context_priority": QNN context priority, options: "low", "normal", "normal_high", "high". Default to "normal".
* "htp_graph_finalization_optimization_mode": Set the optimization mode for graph finalization on the HTP backend. Available options:
* - "0": Default.
* - "1": Faster preparation time, less optimal graph.
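For reference, a minimal sketch of how an application could pass the new option through the C++ API wrapper's AppendExecutionProvider(provider_name, options) overload. The backend library path and model name here are illustrative, not part of this PR:

#include <string>
#include <unordered_map>
#include <onnxruntime_cxx_api.h>

int main() {
  Ort::Env env;
  Ort::SessionOptions session_options;
  // QNN EP options are plain string key/value pairs; the new
  // "qnn_context_priority" key accepts "low", "normal", "normal_high", "high".
  std::unordered_map<std::string, std::string> qnn_options{
      {"backend_path", "QnnHtp.dll"},  // illustrative backend library path
      {"qnn_context_priority", "high"}};
  session_options.AppendExecutionProvider("QNN", qnn_options);
  Ort::Session session(env, ORT_TSTR("model.onnx"), session_options);
  return 0;
}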
42 changes: 40 additions & 2 deletions onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc
@@ -380,15 +380,48 @@ Status QnnBackendManager::ReleaseProfilehandle() {
return Status::OK();
}

Status SetQnnContextConfig(ContextPriority context_priority, QnnContext_Config_t& qnn_context_config) {
qnn_context_config.option = QNN_CONTEXT_CONFIG_OPTION_PRIORITY;
switch (context_priority) {
case ContextPriority::LOW: {
qnn_context_config.priority = QNN_PRIORITY_LOW;
break;
}
case ContextPriority::NORMAL: {
qnn_context_config.priority = QNN_PRIORITY_NORMAL;
break;
}
case ContextPriority::NORMAL_HIGH: {
qnn_context_config.priority = QNN_PRIORITY_NORMAL_HIGH;
break;
}
case ContextPriority::HIGH: {
qnn_context_config.priority = QNN_PRIORITY_HIGH;
break;
}
case ContextPriority::UNDEFINED: {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid Qnn context priority.");
}
default:
qnn_context_config.priority = QNN_PRIORITY_NORMAL;
} // switch

return Status::OK();
}

Status QnnBackendManager::CreateContext() {
if (true == context_created_) {
LOGS_DEFAULT(INFO) << "Context created already.";
return Status::OK();
}

QnnContext_Config_t qnn_context_config = QNN_CONTEXT_CONFIG_INIT;
ORT_RETURN_IF_ERROR(SetQnnContextConfig(context_priority_, qnn_context_config));
const QnnContext_Config_t* context_configs[] = {&qnn_context_config, nullptr};

auto result = qnn_interface_.contextCreate(backend_handle_,
device_handle_,
(const QnnContext_Config_t**)&context_config_,
context_configs,
&context_);

ORT_RETURN_IF(QNN_CONTEXT_NO_ERROR != result, "Failed to create context.");
Expand Down Expand Up @@ -486,9 +519,14 @@ Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t

ORT_RETURN_IF(nullptr == qnn_interface_.contextCreateFromBinary,
"Invalid function pointer for contextCreateFromBinary.");

QnnContext_Config_t qnn_context_config = QNN_CONTEXT_CONFIG_INIT;
ORT_RETURN_IF_ERROR(SetQnnContextConfig(context_priority_, qnn_context_config));
const QnnContext_Config_t* context_configs[] = {&qnn_context_config, nullptr};

rt = qnn_interface_.contextCreateFromBinary(backend_handle_,
device_handle_,
(const QnnContext_Config_t**)&context_config_,
context_configs,
static_cast<void*>(buffer),
buffer_length,
&context_,
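Both creation paths above now build their config list the same way: the QNN context APIs take a null-terminated array of QnnContext_Config_t pointers, which is why a single option is passed as {&config, nullptr} and the old context_config_ member (removed in the header below) is no longer needed. A standalone sketch of the pattern, with types and constants from the QNN SDK's QnnContext.h:

QnnContext_Config_t priority_config = QNN_CONTEXT_CONFIG_INIT;
priority_config.option = QNN_CONTEXT_CONFIG_OPTION_PRIORITY;
priority_config.priority = QNN_PRIORITY_NORMAL_HIGH;
// The API expects a list of config pointers terminated by nullptr;
// additional options would be inserted before the terminator.
const QnnContext_Config_t* configs[] = {&priority_config, nullptr};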
4 changes: 3 additions & 1 deletion onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h
@@ -30,11 +30,13 @@ class QnnBackendManager {
ProfilingLevel profiling_level,
uint32_t rpc_control_latency,
HtpPerformanceMode htp_performance_mode,
ContextPriority context_priority,
std::string&& qnn_saver_path)
: backend_path_(backend_path),
profiling_level_(profiling_level),
rpc_control_latency_(rpc_control_latency),
htp_performance_mode_(htp_performance_mode),
context_priority_(context_priority),
qnn_saver_path_(qnn_saver_path) {
}
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(QnnBackendManager);
@@ -186,7 +188,6 @@ class QnnBackendManager {
Qnn_LogHandle_t log_handle_ = nullptr;
Qnn_DeviceHandle_t device_handle_ = nullptr;
Qnn_ContextHandle_t context_ = nullptr;
QnnContext_Config_t** context_config_ = nullptr;
ProfilingLevel profiling_level_;
bool backend_initialized_ = false;
bool device_created_ = false;
@@ -198,6 +199,7 @@ class QnnBackendManager {
std::vector<std::string> op_package_paths_;
uint32_t rpc_control_latency_ = 0;
HtpPerformanceMode htp_performance_mode_;
ContextPriority context_priority_;
std::string sdk_build_version_ = "";
#ifdef _WIN32
std::set<HMODULE> mod_handles_;
8 changes: 8 additions & 0 deletions onnxruntime/core/providers/qnn/builder/qnn_def.h
@@ -48,6 +48,14 @@ enum class HtpPerformanceMode : uint8_t {
kHtpBalanced,
};

enum class ContextPriority : uint8_t {
LOW = 0,
NORMAL,
NORMAL_HIGH,
HIGH,
UNDEFINED
};

// Defines the graph optimization strategy used by the HTP backend.
enum class HtpGraphFinalizationOptimizationMode : uint8_t {
kDefault = 0,
66 changes: 46 additions & 20 deletions onnxruntime/core/providers/qnn/qnn_execution_provider.cc
@@ -76,6 +76,26 @@
}
}

void QNNExecutionProvider::ParseQnnContextPriority(std::string context_priority_string) {
std::transform(context_priority_string.begin(),
context_priority_string.end(),
context_priority_string.begin(),
[](unsigned char c) { return static_cast<unsigned char>(std::tolower(c)); });

Check warning (GitHub Actions / cpplint) on line 80 in onnxruntime/core/providers/qnn/qnn_execution_provider.cc: Add #include <algorithm> for transform [build/include_what_you_use] [4]
LOGS_DEFAULT(VERBOSE) << "QNN context priority: " << context_priority_string;
if (context_priority_string == "low") {
context_priority_ = qnn::ContextPriority::LOW;
} else if (context_priority_string == "normal") {
context_priority_ = qnn::ContextPriority::NORMAL;
} else if (context_priority_string == "normal_high") {
context_priority_ = qnn::ContextPriority::NORMAL_HIGH;
} else if (context_priority_string == "high") {
context_priority_ = qnn::ContextPriority::HIGH;
} else {
context_priority_ = qnn::ContextPriority::UNDEFINED;
LOGS_DEFAULT(WARNING) << "QNN context priority: " << context_priority_string << " not valid, set to undefined.";
}
}

void QNNExecutionProvider::ParseHtpGraphFinalizationOptimizationMode(const std::string& htp_graph_finalization_opt_mode_string) {
LOGS_DEFAULT(VERBOSE) << "HTP graph finalization optimization mode: "
<< htp_graph_finalization_opt_mode_string;
@@ -96,89 +116,95 @@

QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_options_map,
const SessionOptions* session_options)
: IExecutionProvider{onnxruntime::kQnnExecutionProvider, true},
runtime_options_(provider_options_map) {
: IExecutionProvider{onnxruntime::kQnnExecutionProvider, true} {
if (session_options) {
disable_cpu_ep_fallback_ = session_options->config_options.GetConfigOrDefault(
kOrtSessionOptionsDisableCPUEPFallback, "0") == "1";
}

static const std::string CONTEXT_CACHE_ENABLED = "qnn_context_cache_enable";
auto context_cache_enabled_pos = runtime_options_.find(CONTEXT_CACHE_ENABLED);
if (context_cache_enabled_pos != runtime_options_.end()) {
auto context_cache_enabled_pos = provider_options_map.find(CONTEXT_CACHE_ENABLED);
if (context_cache_enabled_pos != provider_options_map.end()) {
if (context_cache_enabled_pos->second == "1") {
context_cache_enabled_ = true;
LOGS_DEFAULT(VERBOSE) << "Context cache enabled.";
}
}

static const std::string CONTEXT_CACHE_PATH = "qnn_context_cache_path";
auto context_cache_path_pos = runtime_options_.find(CONTEXT_CACHE_PATH);
if (context_cache_path_pos != runtime_options_.end()) {
auto context_cache_path_pos = provider_options_map.find(CONTEXT_CACHE_PATH);
if (context_cache_path_pos != provider_options_map.end()) {
context_cache_path_ = context_cache_path_pos->second;
LOGS_DEFAULT(VERBOSE) << "User specified context cache path: " << context_cache_path_;
}

bool qnn_context_embed_mode = true;
static const std::string CONTEXT_CACHE_EMBED_MODE = "qnn_context_embed_mode";
auto context_cache_embed_mode_pos = runtime_options_.find(CONTEXT_CACHE_EMBED_MODE);
if (context_cache_embed_mode_pos != runtime_options_.end()) {
auto context_cache_embed_mode_pos = provider_options_map.find(CONTEXT_CACHE_EMBED_MODE);
if (context_cache_embed_mode_pos != provider_options_map.end()) {
qnn_context_embed_mode = context_cache_embed_mode_pos->second == "1";
LOGS_DEFAULT(VERBOSE) << "User specified context cache embed mode: " << qnn_context_embed_mode;
}

static const std::string BACKEND_PATH = "backend_path";
auto backend_path_pos = runtime_options_.find(BACKEND_PATH);
auto backend_path_pos = provider_options_map.find(BACKEND_PATH);

std::string backend_path;
if (backend_path_pos != runtime_options_.end()) {
if (backend_path_pos != provider_options_map.end()) {
backend_path = backend_path_pos->second;
LOGS_DEFAULT(VERBOSE) << "Backend path: " << backend_path;
} else {
LOGS_DEFAULT(ERROR) << "No backend path provided.";
}

static const std::string PROFILING_LEVEL = "profiling_level";
auto profiling_level_pos = runtime_options_.find(PROFILING_LEVEL);
if (profiling_level_pos != runtime_options_.end()) {
auto profiling_level_pos = provider_options_map.find(PROFILING_LEVEL);
if (profiling_level_pos != provider_options_map.end()) {
ParseProfilingLevel(profiling_level_pos->second);
}

static const std::string RPC_CONTROL_LANTENCY = "rpc_control_latency";
auto latency_pos = runtime_options_.find(RPC_CONTROL_LANTENCY);
if (latency_pos != runtime_options_.end()) {
auto latency_pos = provider_options_map.find(RPC_CONTROL_LANTENCY);
if (latency_pos != provider_options_map.end()) {
rpc_control_latency_ = static_cast<uint32_t>(std::stoul(latency_pos->second));
LOGS_DEFAULT(VERBOSE) << "rpc_control_latency: " << rpc_control_latency_;
}

htp_performance_mode_ = qnn::HtpPerformanceMode::kHtpDefault;
static const std::string HTP_PERFORMANCE_MODE = "htp_performance_mode";
auto htp_performance_mode_pos = runtime_options_.find(HTP_PERFORMANCE_MODE);
if (htp_performance_mode_pos != runtime_options_.end()) {
auto htp_performance_mode_pos = provider_options_map.find(HTP_PERFORMANCE_MODE);
if (htp_performance_mode_pos != provider_options_map.end()) {
ParseHtpPerformanceMode(htp_performance_mode_pos->second);
}

htp_graph_finalization_opt_mode_ = qnn::HtpGraphFinalizationOptimizationMode::kDefault;
static const std::string HTP_GRAPH_FINALIZATION_OPT_MODE = "htp_graph_finalization_optimization_mode";
auto htp_graph_finalization_opt_mode_pos = runtime_options_.find(HTP_GRAPH_FINALIZATION_OPT_MODE);
if (htp_graph_finalization_opt_mode_pos != runtime_options_.end()) {
auto htp_graph_finalization_opt_mode_pos = provider_options_map.find(HTP_GRAPH_FINALIZATION_OPT_MODE);
if (htp_graph_finalization_opt_mode_pos != provider_options_map.end()) {
ParseHtpGraphFinalizationOptimizationMode(htp_graph_finalization_opt_mode_pos->second);
}

// Enable use of QNN Saver if the user provides a path to the QNN Saver backend library.
static const std::string QNN_SAVER_PATH_KEY = "qnn_saver_path";
std::string qnn_saver_path;
auto qnn_saver_path_pos = runtime_options_.find(QNN_SAVER_PATH_KEY);
if (qnn_saver_path_pos != runtime_options_.end()) {
auto qnn_saver_path_pos = provider_options_map.find(QNN_SAVER_PATH_KEY);
if (qnn_saver_path_pos != provider_options_map.end()) {
qnn_saver_path = qnn_saver_path_pos->second;
LOGS_DEFAULT(VERBOSE) << "User specified QNN Saver path: " << qnn_saver_path;
}

static const std::string QNN_CONTEXT_PRIORITY = "qnn_context_priority";
auto qnn_context_priority_pos = provider_options_map.find(QNN_CONTEXT_PRIORITY);
if (qnn_context_priority_pos != provider_options_map.end()) {
ParseQnnContextPriority(qnn_context_priority_pos->second);
}

qnn_backend_manager_ = std::make_unique<qnn::QnnBackendManager>(
std::move(backend_path),
profiling_level_,
rpc_control_latency_,
htp_performance_mode_,
context_priority_,
std::move(qnn_saver_path));
qnn_cache_model_handler_ = std::make_unique<qnn::QnnCacheModelHandler>(qnn_context_embed_mode);
}
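For embedders or unit tests that construct the EP directly, the new key flows in through the same ProviderOptions map (a std::unordered_map<std::string, std::string>) as the existing options; a minimal sketch with illustrative values:

onnxruntime::ProviderOptions options;
options["backend_path"] = "QnnHtp.dll";  // illustrative
options["qnn_context_priority"] = "normal_high";
// ParseQnnContextPriority() lower-cases the value, so "Normal_High" is also
// accepted; anything outside low/normal/normal_high/high maps to
// ContextPriority::UNDEFINED, and context creation later fails with
// INVALID_ARGUMENT.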
3 changes: 2 additions & 1 deletion onnxruntime/core/providers/qnn/qnn_execution_provider.h
@@ -57,13 +57,13 @@ class QNNExecutionProvider : public IExecutionProvider {
const logging::Logger& logger);

void ParseHtpPerformanceMode(std::string htp_performance_mode_string);
void ParseQnnContextPriority(std::string context_priority_string);

void ParseHtpGraphFinalizationOptimizationMode(const std::string& htp_graph_finalization_opt_mode_string);

void InitQnnGraphConfigs(qnn::QnnGraphConfigsBuilder& configs_holder) const;

private:
ProviderOptions runtime_options_;
qnn::ProfilingLevel profiling_level_ = qnn::ProfilingLevel::OFF;
qnn::HtpPerformanceMode htp_performance_mode_ = qnn::HtpPerformanceMode::kHtpDefault;
qnn::HtpGraphFinalizationOptimizationMode htp_graph_finalization_opt_mode_ = qnn::HtpGraphFinalizationOptimizationMode::kDefault;
@@ -74,6 +74,7 @@ class QNNExecutionProvider : public IExecutionProvider {
std::string context_cache_path_ = "";
bool disable_cpu_ep_fallback_ = false; // True if CPU EP fallback has been disabled for this session.
std::unique_ptr<qnn::QnnCacheModelHandler> qnn_cache_model_handler_;
qnn::ContextPriority context_priority_ = qnn::ContextPriority::NORMAL;
};

} // namespace onnxruntime
8 changes: 7 additions & 1 deletion onnxruntime/test/onnx/main.cc
@@ -56,6 +56,7 @@
"\t [QNN only] [rpc_control_latency]: QNN rpc control latency. default to 10.\n"
"\t [QNN only] [htp_performance_mode]: QNN performance mode, options: 'burst', 'balanced', 'default', 'high_performance', \n"
"\t 'high_power_saver', 'low_balanced', 'low_power_saver', 'power_saver', 'sustained_high_performance'. Default to 'default'. \n"
"\t [QNN only] [qnn_context_priority]: QNN context priority, options: 'low', 'normal', 'normal_high', 'high'. Default to 'normal'. \n"

Check warning (GitHub Actions / cpplint) on line 59 in onnxruntime/test/onnx/main.cc: Lines should be <= 120 characters long [whitespace/line_length] [2]
"\t [QNN only] [qnn_context_embed_mode]: 1 means dump the QNN context binary into the Onnx skeleton model.\n"
"\t 0 means dump the QNN context binary into separate bin file and set the path in the Onnx skeleton model.\n"
"\t [QNN only] [qnn_saver_path]: QNN Saver backend path. e.g '/folderpath/libQnnSaver.so'.\n"
@@ -488,6 +489,11 @@
std::string str = str_stream.str();
ORT_THROW("Wrong value for htp_performance_mode. select from: " + str);
}
} else if (key == "qnn_context_priority") {
std::set<std::string> supported_qnn_context_priority = {"low", "normal", "normal_high", "high"};
if (supported_qnn_context_priority.find(value) == supported_qnn_context_priority.end()) {
ORT_THROW("Supported qnn_context_priority: low, normal, normal_high, high");
}
} else if (key == "qnn_saver_path") {
// no validation
} else if (key == "htp_graph_finalization_optimization_mode") {
@@ -502,7 +508,7 @@
} else {
ORT_THROW(R"(Wrong key type entered. Choose from options: ['backend_path', 'qnn_context_cache_enable',
'qnn_context_cache_path', 'profiling_level', 'rpc_control_latency', 'htp_performance_mode', 'qnn_saver_path',
'htp_graph_finalization_optimization_mode'])");
'htp_graph_finalization_optimization_mode', 'qnn_context_priority'])");
}

qnn_options[key] = value;
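Assuming the key|value pair syntax the -i flag uses for EP runtime options (onnxruntime_perf_test parses the same keys in ort_test_session.cc below), the new option can be exercised from the command line; paths here are illustrative:

onnx_test_runner -e qnn -i "backend_path|/path/to/libQnnHtp.so qnn_context_priority|high" /path/to/test_models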
5 changes: 3 additions & 2 deletions onnxruntime/test/perftest/command_args_parser.cc
@@ -34,8 +34,8 @@
"\t-A: Disable memory arena\n"
"\t-I: Generate tensor input binding (Free dimensions are treated as 1.)\n"
"\t-c [parallel runs]: Specifies the (max) number of runs to invoke simultaneously. Default:1.\n"
"\t-e [cpu|cuda|dnnl|tensorrt|openvino|dml|acl|nnapi|coreml|snpe|rocm|migraphx|xnnpack|vitisai]: Specifies the provider 'cpu','cuda','dnnl','tensorrt', "
"'openvino', 'dml', 'acl', 'nnapi', 'coreml', 'snpe', 'rocm', 'migraphx', 'xnnpack' or 'vitisai'. "
"\t-e [cpu|cuda|dnnl|tensorrt|openvino|dml|acl|nnapi|coreml|qnn|snpe|rocm|migraphx|xnnpack|vitisai]: Specifies the provider 'cpu','cuda','dnnl','tensorrt', "

Check warning (GitHub Actions / cpplint) on line 37 in onnxruntime/test/perftest/command_args_parser.cc: Lines should be <= 120 characters long [whitespace/line_length] [2]
"'openvino', 'dml', 'acl', 'nnapi', 'coreml', 'qnn', 'snpe', 'rocm', 'migraphx', 'xnnpack' or 'vitisai'. "
"Default:'cpu'.\n"
"\t-b [tf|ort]: backend to use. Default:ort\n"
"\t-r [repeated_times]: Specifies the repeated times if running in 'times' test mode.Default:1000.\n"
@@ -71,6 +71,7 @@
"\t [QNN only] [rpc_control_latency]: QNN rpc control latency. default to 10.\n"
"\t [QNN only] [htp_performance_mode]: QNN performance mode, options: 'burst', 'balanced', 'default', 'high_performance', \n"
"\t 'high_power_saver', 'low_balanced', 'low_power_saver', 'power_saver', 'sustained_high_performance'. Default to 'default'. \n"
"\t [QNN only] [qnn_context_priority]: QNN context priority, options: 'low', 'normal', 'normal_high', 'high'. Default to 'normal'. \n"

Check warning (GitHub Actions / cpplint) on line 74 in onnxruntime/test/perftest/command_args_parser.cc: Lines should be <= 120 characters long [whitespace/line_length] [2]
"\t [QNN only] [qnn_saver_path]: QNN Saver backend path. e.g '/folderpath/libQnnSaver.so'.\n"
"\t [QNN only] [htp_graph_finalization_optimization_mode]: QNN graph finalization optimization mode, options: \n"
"\t '0', '1', '2', '3', default is '0'.\n"
7 changes: 6 additions & 1 deletion onnxruntime/test/perftest/ort_test_session.cc
@@ -367,10 +367,15 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device
std::string str = str_stream.str();
ORT_THROW("Wrong value for htp_graph_finalization_optimization_mode. select from: " + str);
}
} else if (key == "qnn_context_priority") {
std::set<std::string> supported_qnn_context_priority = {"low", "normal", "normal_high", "high"};
if (supported_qnn_context_priority.find(value) == supported_qnn_context_priority.end()) {
ORT_THROW("Supported qnn_context_priority: low, normal, normal_high, high");
}
} else {
ORT_THROW(R"(Wrong key type entered. Choose from options: ['backend_path', 'qnn_context_cache_enable',
'qnn_context_cache_path', 'profiling_level', 'rpc_control_latency', 'htp_performance_mode', 'qnn_saver_path',
'htp_graph_finalization_optimization_mode'])");
'htp_graph_finalization_optimization_mode', 'qnn_context_priority'])");
}

qnn_options[key] = value;