Skip to content

Commit

Permalink
[QNN EP] Enable option to set QNN context priority (#18315)
Browse files Browse the repository at this point in the history
Enable option qnn_context_priority to set QNN context priority, options:
"low", "normal", "normal_high", "high".

### Description
Enable option qnn_context_priority to set QNN context priority, options:
"low", "normal", "normal_high", "high".

This feature guarantees the model inference with higher priority. Tested
with onnxruntime_perf_test tool using same model.
1. Run the model on the NPU with single instance, the latency is 300ms.
2. Run the same model on NPU with 2 instance at same time.
   Case 1:   
   both with same priority (high ) -- latency is 600ms
   Case 2:   
   1 with low priority -- latency is 30,000ms
   1 with high priority --  latency is 300ms
   Case 3:   
   1 with normal priority -- latency is 15,000ms
   1 with high priority --  latency is 300ms
  • Loading branch information
HectorSVC authored Nov 9, 2023
1 parent 7a3da45 commit 55c19d6
Show file tree
Hide file tree
Showing 10 changed files with 131 additions and 29 deletions.
1 change: 1 addition & 0 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -3604,6 +3604,7 @@ struct OrtApi {
* "qnn_saver_path": File path to the QNN Saver backend library. If specified, QNN Saver will be enabled and will
* dump QNN API calls to disk for replay/debugging. QNN Saver produces incorrect model inference results and
* may alter model/EP partitioning. Use only for debugging.
* "qnn_context_priority": QNN context priority, options: "low", "normal", "normal_high", "high". Default to "normal".
* "htp_graph_finalization_optimization_mode": Set the optimization mode for graph finalization on the HTP backend. Available options:
* - "0": Default.
* - "1": Faster preparation time, less optimal graph.
Expand Down
42 changes: 40 additions & 2 deletions onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -380,15 +380,48 @@ Status QnnBackendManager::ReleaseProfilehandle() {
return Status::OK();
}

Status SetQnnContextConfig(ContextPriority context_priority, QnnContext_Config_t& qnn_context_config) {
qnn_context_config.option = QNN_CONTEXT_CONFIG_OPTION_PRIORITY;
switch (context_priority) {
case ContextPriority::LOW: {
qnn_context_config.priority = QNN_PRIORITY_LOW;
break;
}
case ContextPriority::NORMAL: {
qnn_context_config.priority = QNN_PRIORITY_NORMAL;
break;
}
case ContextPriority::NORMAL_HIGH: {
qnn_context_config.priority = QNN_PRIORITY_NORMAL_HIGH;
break;
}
case ContextPriority::HIGH: {
qnn_context_config.priority = QNN_PRIORITY_HIGH;
break;
}
case ContextPriority::UNDEFINED: {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid Qnn context priority.");
}
default:
qnn_context_config.priority = QNN_PRIORITY_NORMAL;
} // switch

return Status::OK();
}

Status QnnBackendManager::CreateContext() {
if (true == context_created_) {
LOGS_DEFAULT(INFO) << "Context created already.";
return Status::OK();
}

QnnContext_Config_t qnn_context_config = QNN_CONTEXT_CONFIG_INIT;
ORT_RETURN_IF_ERROR(SetQnnContextConfig(context_priority_, qnn_context_config));
const QnnContext_Config_t* context_configs[] = {&qnn_context_config, nullptr};

auto result = qnn_interface_.contextCreate(backend_handle_,
device_handle_,
(const QnnContext_Config_t**)&context_config_,
context_configs,
&context_);

ORT_RETURN_IF(QNN_CONTEXT_NO_ERROR != result, "Failed to create context.");
Expand Down Expand Up @@ -486,9 +519,14 @@ Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t

ORT_RETURN_IF(nullptr == qnn_interface_.contextCreateFromBinary,
"Invalid function pointer for contextCreateFromBinary.");

QnnContext_Config_t qnn_context_config = QNN_CONTEXT_CONFIG_INIT;
ORT_RETURN_IF_ERROR(SetQnnContextConfig(context_priority_, qnn_context_config));
const QnnContext_Config_t* context_configs[] = {&qnn_context_config, nullptr};

rt = qnn_interface_.contextCreateFromBinary(backend_handle_,
device_handle_,
(const QnnContext_Config_t**)&context_config_,
context_configs,
static_cast<void*>(buffer),
buffer_length,
&context_,
Expand Down
4 changes: 3 additions & 1 deletion onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,13 @@ class QnnBackendManager {
ProfilingLevel profiling_level,
uint32_t rpc_control_latency,
HtpPerformanceMode htp_performance_mode,
ContextPriority context_priority,
std::string&& qnn_saver_path)
: backend_path_(backend_path),
profiling_level_(profiling_level),
rpc_control_latency_(rpc_control_latency),
htp_performance_mode_(htp_performance_mode),
context_priority_(context_priority),
qnn_saver_path_(qnn_saver_path) {
}
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(QnnBackendManager);
Expand Down Expand Up @@ -186,7 +188,6 @@ class QnnBackendManager {
Qnn_LogHandle_t log_handle_ = nullptr;
Qnn_DeviceHandle_t device_handle_ = nullptr;
Qnn_ContextHandle_t context_ = nullptr;
QnnContext_Config_t** context_config_ = nullptr;
ProfilingLevel profiling_level_;
bool backend_initialized_ = false;
bool device_created_ = false;
Expand All @@ -198,6 +199,7 @@ class QnnBackendManager {
std::vector<std::string> op_package_paths_;
uint32_t rpc_control_latency_ = 0;
HtpPerformanceMode htp_performance_mode_;
ContextPriority context_priority_;
std::string sdk_build_version_ = "";
#ifdef _WIN32
std::set<HMODULE> mod_handles_;
Expand Down
8 changes: 8 additions & 0 deletions onnxruntime/core/providers/qnn/builder/qnn_def.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,14 @@ enum class HtpPerformanceMode : uint8_t {
kHtpBalanced,
};

enum class ContextPriority : uint8_t {
LOW = 0,
NORMAL,
NORMAL_HIGH,
HIGH,
UNDEFINED
};

// Defines the graph optimization strategy used by the HTP backend.
enum class HtpGraphFinalizationOptimizationMode : uint8_t {
kDefault = 0,
Expand Down
66 changes: 46 additions & 20 deletions onnxruntime/core/providers/qnn/qnn_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,26 @@ void QNNExecutionProvider::ParseHtpPerformanceMode(std::string htp_performance_m
}
}

void QNNExecutionProvider::ParseQnnContextPriority(std::string context_priority_string) {
std::transform(context_priority_string.begin(),
context_priority_string.end(),
context_priority_string.begin(),
[](unsigned char c) { return static_cast<unsigned char>(std::tolower(c)); });
LOGS_DEFAULT(VERBOSE) << "QNN context priority: " << context_priority_string;
if (context_priority_string == "low") {
context_priority_ = qnn::ContextPriority::LOW;
} else if (context_priority_string == "normal") {
context_priority_ = qnn::ContextPriority::NORMAL;
} else if (context_priority_string == "normal_high") {
context_priority_ = qnn::ContextPriority::NORMAL_HIGH;
} else if (context_priority_string == "high") {
context_priority_ = qnn::ContextPriority::HIGH;
} else {
context_priority_ = qnn::ContextPriority::UNDEFINED;
LOGS_DEFAULT(WARNING) << "QNN context priority: " << context_priority_string << " not valid, set to undefined.";
}
}

void QNNExecutionProvider::ParseHtpGraphFinalizationOptimizationMode(const std::string& htp_graph_finalization_opt_mode_string) {
LOGS_DEFAULT(VERBOSE) << "HTP graph finalization optimization mode: "
<< htp_graph_finalization_opt_mode_string;
Expand All @@ -96,89 +116,95 @@ void QNNExecutionProvider::ParseHtpGraphFinalizationOptimizationMode(const std::

QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_options_map,
const SessionOptions* session_options)
: IExecutionProvider{onnxruntime::kQnnExecutionProvider, true},
runtime_options_(provider_options_map) {
: IExecutionProvider{onnxruntime::kQnnExecutionProvider, true} {
if (session_options) {
disable_cpu_ep_fallback_ = session_options->config_options.GetConfigOrDefault(
kOrtSessionOptionsDisableCPUEPFallback, "0") == "1";
}

static const std::string CONTEXT_CACHE_ENABLED = "qnn_context_cache_enable";
auto context_cache_enabled_pos = runtime_options_.find(CONTEXT_CACHE_ENABLED);
if (context_cache_enabled_pos != runtime_options_.end()) {
auto context_cache_enabled_pos = provider_options_map.find(CONTEXT_CACHE_ENABLED);
if (context_cache_enabled_pos != provider_options_map.end()) {
if (context_cache_enabled_pos->second == "1") {
context_cache_enabled_ = true;
LOGS_DEFAULT(VERBOSE) << "Context cache enabled.";
}
}

static const std::string CONTEXT_CACHE_PATH = "qnn_context_cache_path";
auto context_cache_path_pos = runtime_options_.find(CONTEXT_CACHE_PATH);
if (context_cache_path_pos != runtime_options_.end()) {
auto context_cache_path_pos = provider_options_map.find(CONTEXT_CACHE_PATH);
if (context_cache_path_pos != provider_options_map.end()) {
context_cache_path_ = context_cache_path_pos->second;
LOGS_DEFAULT(VERBOSE) << "User specified context cache path: " << context_cache_path_;
}

bool qnn_context_embed_mode = true;
static const std::string CONTEXT_CACHE_EMBED_MODE = "qnn_context_embed_mode";
auto context_cache_embed_mode_pos = runtime_options_.find(CONTEXT_CACHE_EMBED_MODE);
if (context_cache_embed_mode_pos != runtime_options_.end()) {
auto context_cache_embed_mode_pos = provider_options_map.find(CONTEXT_CACHE_EMBED_MODE);
if (context_cache_embed_mode_pos != provider_options_map.end()) {
qnn_context_embed_mode = context_cache_embed_mode_pos->second == "1";
LOGS_DEFAULT(VERBOSE) << "User specified context cache embed mode: " << qnn_context_embed_mode;
}

static const std::string BACKEND_PATH = "backend_path";
auto backend_path_pos = runtime_options_.find(BACKEND_PATH);
auto backend_path_pos = provider_options_map.find(BACKEND_PATH);

std::string backend_path;
if (backend_path_pos != runtime_options_.end()) {
if (backend_path_pos != provider_options_map.end()) {
backend_path = backend_path_pos->second;
LOGS_DEFAULT(VERBOSE) << "Backend path: " << backend_path;
} else {
LOGS_DEFAULT(ERROR) << "No backend path provided.";
}

static const std::string PROFILING_LEVEL = "profiling_level";
auto profiling_level_pos = runtime_options_.find(PROFILING_LEVEL);
if (profiling_level_pos != runtime_options_.end()) {
auto profiling_level_pos = provider_options_map.find(PROFILING_LEVEL);
if (profiling_level_pos != provider_options_map.end()) {
ParseProfilingLevel(profiling_level_pos->second);
}

static const std::string RPC_CONTROL_LANTENCY = "rpc_control_latency";
auto latency_pos = runtime_options_.find(RPC_CONTROL_LANTENCY);
if (latency_pos != runtime_options_.end()) {
auto latency_pos = provider_options_map.find(RPC_CONTROL_LANTENCY);
if (latency_pos != provider_options_map.end()) {
rpc_control_latency_ = static_cast<uint32_t>(std::stoul(latency_pos->second));
LOGS_DEFAULT(VERBOSE) << "rpc_control_latency: " << rpc_control_latency_;
}

htp_performance_mode_ = qnn::HtpPerformanceMode::kHtpDefault;
static const std::string HTP_PERFORMANCE_MODE = "htp_performance_mode";
auto htp_performance_mode_pos = runtime_options_.find(HTP_PERFORMANCE_MODE);
if (htp_performance_mode_pos != runtime_options_.end()) {
auto htp_performance_mode_pos = provider_options_map.find(HTP_PERFORMANCE_MODE);
if (htp_performance_mode_pos != provider_options_map.end()) {
ParseHtpPerformanceMode(htp_performance_mode_pos->second);
}

htp_graph_finalization_opt_mode_ = qnn::HtpGraphFinalizationOptimizationMode::kDefault;
static const std::string HTP_GRAPH_FINALIZATION_OPT_MODE = "htp_graph_finalization_optimization_mode";
auto htp_graph_finalization_opt_mode_pos = runtime_options_.find(HTP_GRAPH_FINALIZATION_OPT_MODE);
if (htp_graph_finalization_opt_mode_pos != runtime_options_.end()) {
auto htp_graph_finalization_opt_mode_pos = provider_options_map.find(HTP_GRAPH_FINALIZATION_OPT_MODE);
if (htp_graph_finalization_opt_mode_pos != provider_options_map.end()) {
ParseHtpGraphFinalizationOptimizationMode(htp_graph_finalization_opt_mode_pos->second);
}

// Enable use of QNN Saver if the user provides a path the QNN Saver backend library.
static const std::string QNN_SAVER_PATH_KEY = "qnn_saver_path";
std::string qnn_saver_path;
auto qnn_saver_path_pos = runtime_options_.find(QNN_SAVER_PATH_KEY);
if (qnn_saver_path_pos != runtime_options_.end()) {
auto qnn_saver_path_pos = provider_options_map.find(QNN_SAVER_PATH_KEY);
if (qnn_saver_path_pos != provider_options_map.end()) {
qnn_saver_path = qnn_saver_path_pos->second;
LOGS_DEFAULT(VERBOSE) << "User specified QNN Saver path: " << qnn_saver_path;
}

static const std::string QNN_CONTEXT_PRIORITY = "qnn_context_priority";
auto qnn_context_priority_pos = provider_options_map.find(QNN_CONTEXT_PRIORITY);
if (qnn_context_priority_pos != provider_options_map.end()) {
ParseQnnContextPriority(qnn_context_priority_pos->second);
}

qnn_backend_manager_ = std::make_unique<qnn::QnnBackendManager>(
std::move(backend_path),
profiling_level_,
rpc_control_latency_,
htp_performance_mode_,
context_priority_,
std::move(qnn_saver_path));
qnn_cache_model_handler_ = std::make_unique<qnn::QnnCacheModelHandler>(qnn_context_embed_mode);
}
Expand Down
3 changes: 2 additions & 1 deletion onnxruntime/core/providers/qnn/qnn_execution_provider.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,13 @@ class QNNExecutionProvider : public IExecutionProvider {
const logging::Logger& logger);

void ParseHtpPerformanceMode(std::string htp_performance_mode_string);
void ParseQnnContextPriority(std::string context_priority_string);

void ParseHtpGraphFinalizationOptimizationMode(const std::string& htp_graph_finalization_opt_mode_string);

void InitQnnGraphConfigs(qnn::QnnGraphConfigsBuilder& configs_holder) const;

private:
ProviderOptions runtime_options_;
qnn::ProfilingLevel profiling_level_ = qnn::ProfilingLevel::OFF;
qnn::HtpPerformanceMode htp_performance_mode_ = qnn::HtpPerformanceMode::kHtpDefault;
qnn::HtpGraphFinalizationOptimizationMode htp_graph_finalization_opt_mode_ = qnn::HtpGraphFinalizationOptimizationMode::kDefault;
Expand All @@ -74,6 +74,7 @@ class QNNExecutionProvider : public IExecutionProvider {
std::string context_cache_path_ = "";
bool disable_cpu_ep_fallback_ = false; // True if CPU EP fallback has been disabled for this session.
std::unique_ptr<qnn::QnnCacheModelHandler> qnn_cache_model_handler_;
qnn::ContextPriority context_priority_ = qnn::ContextPriority::NORMAL;
};

} // namespace onnxruntime
8 changes: 7 additions & 1 deletion onnxruntime/test/onnx/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ void usage() {
"\t [QNN only] [rpc_control_latency]: QNN rpc control latency. default to 10.\n"
"\t [QNN only] [htp_performance_mode]: QNN performance mode, options: 'burst', 'balanced', 'default', 'high_performance', \n"
"\t 'high_power_saver', 'low_balanced', 'low_power_saver', 'power_saver', 'sustained_high_performance'. Default to 'default'. \n"
"\t [QNN only] [qnn_context_priority]: QNN context priority, options: 'low', 'normal', 'normal_high', 'high'. Default to 'normal'. \n"
"\t [QNN only] [qnn_context_embed_mode]: 1 means dump the QNN context binary into the Onnx skeleton model.\n"
"\t 0 means dump the QNN context binary into separate bin file and set the path in the Onnx skeleton model.\n"
"\t [QNN only] [qnn_saver_path]: QNN Saver backend path. e.g '/folderpath/libQnnSaver.so'.\n"
Expand Down Expand Up @@ -488,6 +489,11 @@ int real_main(int argc, char* argv[], Ort::Env& env) {
std::string str = str_stream.str();
ORT_THROW("Wrong value for htp_performance_mode. select from: " + str);
}
} else if (key == "qnn_context_priority") {
std::set<std::string> supported_qnn_context_priority = {"low", "normal", "normal_high", "high"};
if (supported_qnn_context_priority.find(value) == supported_qnn_context_priority.end()) {
ORT_THROW("Supported qnn_context_priority: low, normal, normal_high, high");
}
} else if (key == "qnn_saver_path") {
// no validation
} else if (key == "htp_graph_finalization_optimization_mode") {
Expand All @@ -502,7 +508,7 @@ int real_main(int argc, char* argv[], Ort::Env& env) {
} else {
ORT_THROW(R"(Wrong key type entered. Choose from options: ['backend_path', 'qnn_context_cache_enable',
'qnn_context_cache_path', 'profiling_level', 'rpc_control_latency', 'htp_performance_mode', 'qnn_saver_path',
'htp_graph_finalization_optimization_mode'])");
'htp_graph_finalization_optimization_mode', 'qnn_context_priority'])");
}

qnn_options[key] = value;
Expand Down
5 changes: 3 additions & 2 deletions onnxruntime/test/perftest/command_args_parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ namespace perftest {
"\t-A: Disable memory arena\n"
"\t-I: Generate tensor input binding (Free dimensions are treated as 1.)\n"
"\t-c [parallel runs]: Specifies the (max) number of runs to invoke simultaneously. Default:1.\n"
"\t-e [cpu|cuda|dnnl|tensorrt|openvino|dml|acl|nnapi|coreml|snpe|rocm|migraphx|xnnpack|vitisai]: Specifies the provider 'cpu','cuda','dnnl','tensorrt', "
"'openvino', 'dml', 'acl', 'nnapi', 'coreml', 'snpe', 'rocm', 'migraphx', 'xnnpack' or 'vitisai'. "
"\t-e [cpu|cuda|dnnl|tensorrt|openvino|dml|acl|nnapi|coreml|qnn|snpe|rocm|migraphx|xnnpack|vitisai]: Specifies the provider 'cpu','cuda','dnnl','tensorrt', "
"'openvino', 'dml', 'acl', 'nnapi', 'coreml', 'qnn', 'snpe', 'rocm', 'migraphx', 'xnnpack' or 'vitisai'. "
"Default:'cpu'.\n"
"\t-b [tf|ort]: backend to use. Default:ort\n"
"\t-r [repeated_times]: Specifies the repeated times if running in 'times' test mode.Default:1000.\n"
Expand Down Expand Up @@ -71,6 +71,7 @@ namespace perftest {
"\t [QNN only] [rpc_control_latency]: QNN rpc control latency. default to 10.\n"
"\t [QNN only] [htp_performance_mode]: QNN performance mode, options: 'burst', 'balanced', 'default', 'high_performance', \n"
"\t 'high_power_saver', 'low_balanced', 'low_power_saver', 'power_saver', 'sustained_high_performance'. Default to 'default'. \n"
"\t [QNN only] [qnn_context_priority]: QNN context priority, options: 'low', 'normal', 'normal_high', 'high'. Default to 'normal'. \n"
"\t [QNN only] [qnn_saver_path]: QNN Saver backend path. e.g '/folderpath/libQnnSaver.so'.\n"
"\t [QNN only] [htp_graph_finalization_optimization_mode]: QNN graph finalization optimization mode, options: \n"
"\t '0', '1', '2', '3', default is '0'.\n"
Expand Down
7 changes: 6 additions & 1 deletion onnxruntime/test/perftest/ort_test_session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -367,10 +367,15 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device
std::string str = str_stream.str();
ORT_THROW("Wrong value for htp_graph_finalization_optimization_mode. select from: " + str);
}
} else if (key == "qnn_context_priority") {
std::set<std::string> supported_qnn_context_priority = {"low", "normal", "normal_high", "high"};
if (supported_qnn_context_priority.find(value) == supported_qnn_context_priority.end()) {
ORT_THROW("Supported qnn_context_priority: low, normal, normal_high, high");
}
} else {
ORT_THROW(R"(Wrong key type entered. Choose from options: ['backend_path', 'qnn_context_cache_enable',
'qnn_context_cache_path', 'profiling_level', 'rpc_control_latency', 'htp_performance_mode', 'qnn_saver_path',
'htp_graph_finalization_optimization_mode'])");
'htp_graph_finalization_optimization_mode', 'qnn_context_priority'])");
}

qnn_options[key] = value;
Expand Down
Loading

0 comments on commit 55c19d6

Please sign in to comment.