Skip to content

Commit

Permalink
[QNN EP] Session option for graph optimization (#18262)
Browse files Browse the repository at this point in the history
### Description
Adds the QNN session option `htp_graph_finalization_optimization_mode`
to enable QNN graph optimizations at the expense of longer preparation
time.

### Motivation and Context
Allow enabling QNN graph optimizations per app/model.
  • Loading branch information
adrianlizarraga authored Nov 8, 2023
1 parent c8def0c commit a0eeeaf
Show file tree
Hide file tree
Showing 15 changed files with 270 additions and 37 deletions.
9 changes: 7 additions & 2 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -3598,12 +3598,17 @@ struct OrtApi {
* "rpc_control_latency": QNN RPC control latency.
* "htp_performance_mode": QNN performance mode, options: "burst", "balanced", "default", "high_performance",
* "high_power_saver", "low_balanced", "low_power_saver", "power_saver", "sustained_high_performance". Default to "default".
* "qnn_context_embed_mode", 1 means dump the QNN context binary into node attribute EPContext->ep_cache_context in the Onnx skeleton model.
* "qnn_context_embed_mode", 1 means dump the QNN context binary into node attribute EPContext->ep_cache_context in the ONNX skeleton model.
* 0 means dump the QNN context binary into separate bin file and set the path to EPContext->ep_cache_context.
* The path is relative path to the Onnx skeleton model file.
* The path is relative path to the ONNX skeleton model file.
* "qnn_saver_path": File path to the QNN Saver backend library. If specified, QNN Saver will be enabled and will
* dump QNN API calls to disk for replay/debugging. QNN Saver produces incorrect model inference results and
* may alter model/EP partitioning. Use only for debugging.
* "htp_graph_finalization_optimization_mode": Set the optimization mode for graph finalization on the HTP backend. Available options:
* - "0": Default.
* - "1": Faster preparation time, less optimal graph.
* - "2": Longer preparation time, more optimal graph.
* - "3": Longest preparation time, most likely even more optimal graph. See QNN SDK documentation for specific details.
*
* SNPE supported keys:
* "runtime": SNPE runtime engine, options: "CPU", "CPU_FLOAT32", "GPU", "GPU_FLOAT32_16_HYBRID", "GPU_FLOAT16",
Expand Down
42 changes: 20 additions & 22 deletions onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -86,32 +86,30 @@ Status QnnCacheModelHandler::GetEpContextFromGraph(const onnxruntime::GraphViewe
return qnn_backend_manager->LoadCachedQnnContextFromBuffer(const_cast<char*>(context_binary.c_str()),
static_cast<uint64_t>(context_binary.length()),
qnn_model);
} else {
std::string external_qnn_context_binary_file_name = node_helper.Get(EP_CACHE_CONTEXT, "");
}

std::string context_binary_path(std::filesystem::path(ctx_onnx_model_path).parent_path().string() +
"/" + external_qnn_context_binary_file_name);
size_t buffer_size{0};
std::ifstream cache_file(context_binary_path.c_str(), std::ifstream::binary);
ORT_RETURN_IF(!cache_file || !cache_file.good(), "Failed to open cache file.");
std::string external_qnn_context_binary_file_name = node_helper.Get(EP_CACHE_CONTEXT, "");

cache_file.seekg(0, cache_file.end);
buffer_size = static_cast<size_t>(cache_file.tellg());
ORT_RETURN_IF(0 == buffer_size, "Empty cache file encountered.");
std::string context_binary_path(std::filesystem::path(ctx_onnx_model_path).parent_path().string() +
"/" + external_qnn_context_binary_file_name);
size_t buffer_size{0};
std::ifstream cache_file(context_binary_path.c_str(), std::ifstream::binary);
ORT_RETURN_IF(!cache_file || !cache_file.good(), "Failed to open cache file.");

cache_file.seekg(0, cache_file.beg);
std::unique_ptr<char[]> buffer = std::make_unique<char[]>(buffer_size);
ORT_RETURN_IF(nullptr == buffer, "Failed to allocate memory for cache file.");
// Load file into buffer
const auto& read_result = cache_file.read(buffer.get(), buffer_size);
ORT_RETURN_IF(!read_result, "Failed to read contents from cached context file.");
cache_file.close();
return qnn_backend_manager->LoadCachedQnnContextFromBuffer(buffer.get(),
static_cast<uint64_t>(buffer_size),
qnn_model);
}
cache_file.seekg(0, cache_file.end);
buffer_size = static_cast<size_t>(cache_file.tellg());
ORT_RETURN_IF(0 == buffer_size, "Empty cache file encountered.");

return Status::OK();
cache_file.seekg(0, cache_file.beg);
std::unique_ptr<char[]> buffer = std::make_unique<char[]>(buffer_size);
ORT_RETURN_IF(nullptr == buffer, "Failed to allocate memory for cache file.");
// Load file into buffer
const auto& read_result = cache_file.read(buffer.get(), buffer_size);
ORT_RETURN_IF(!read_result, "Failed to read contents from cached context file.");
cache_file.close();
return qnn_backend_manager->LoadCachedQnnContextFromBuffer(buffer.get(),
static_cast<uint64_t>(buffer_size),
qnn_model);
}

Status QnnCacheModelHandler::GetMetadataFromEpContextModel(const std::string& ctx_onnx_model_path,
Expand Down
35 changes: 29 additions & 6 deletions onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -570,8 +570,7 @@ Status QnnBackendManager::SetHtpPowerConfig() {
"HTP infra type = ", htp_infra->infraType, ", which is not perf infra type.");
QnnHtpDevice_PerfInfrastructure_t& htp_perf_infra = htp_infra->perfInfra;
// Get power client id
uint32_t powerconfig_client_id = 0;
status = htp_perf_infra.createPowerConfigId(/*device_id=*/0, /*core_id=*/0, &powerconfig_client_id);
status = htp_perf_infra.createPowerConfigId(/*device_id=*/0, /*core_id=*/0, &htp_power_config_client_id_);
ORT_RETURN_IF(QNN_SUCCESS != status, "createPowerConfigId failed.");

constexpr const int kNumConfigs = 1;
Expand All @@ -580,7 +579,7 @@ Status QnnBackendManager::SetHtpPowerConfig() {
QnnHtpPerfInfrastructure_PowerConfig_t& dcvs_config = power_configs[0];
dcvs_config.option = QNN_HTP_PERF_INFRASTRUCTURE_POWER_CONFIGOPTION_DCVS_V3;
QnnHtpPerfInfrastructure_DcvsV3_t& dcvs_v3 = dcvs_config.dcvsV3Config;
dcvs_v3.contextId = powerconfig_client_id;
dcvs_v3.contextId = htp_power_config_client_id_;
dcvs_v3.setSleepDisable = 0;
dcvs_v3.sleepDisable = 0;
dcvs_v3.setDcvsEnable = 1;
Expand Down Expand Up @@ -678,7 +677,7 @@ Status QnnBackendManager::SetHtpPowerConfig() {
break;
}
std::vector<const QnnHtpPerfInfrastructure_PowerConfig_t*> perf_power_configs_ptr_ = ObtainNullTermPtrVector(power_configs);
status = htp_perf_infra.setPowerConfig(powerconfig_client_id, perf_power_configs_ptr_.data());
status = htp_perf_infra.setPowerConfig(htp_power_config_client_id_, perf_power_configs_ptr_.data());
ORT_RETURN_IF(QNN_SUCCESS != status, "setPowerConfig failed for HTP performance mode.");

// Set rpc control latency here, but note that v68 doesn't support rpc polling mode.
Expand All @@ -692,7 +691,7 @@ Status QnnBackendManager::SetHtpPowerConfig() {
rpc_polling_time.option = QNN_HTP_PERF_INFRASTRUCTURE_POWER_CONFIGOPTION_RPC_POLLING_TIME;
rpc_control_latency.rpcControlLatencyConfig = rpc_control_latency_;
perf_power_configs_ptr_ = ObtainNullTermPtrVector(rpc_power_configs);
status = htp_perf_infra.setPowerConfig(powerconfig_client_id, perf_power_configs_ptr_.data());
status = htp_perf_infra.setPowerConfig(htp_power_config_client_id_, perf_power_configs_ptr_.data());
ORT_RETURN_IF(QNN_SUCCESS != status, "setPowerConfig failed for RPC control latency.");
}

Expand All @@ -713,12 +712,36 @@ void QnnBackendManager::Split(std::vector<std::string>& split_string,
}
}

// Destroys the HTP power config ID that SetHtpPowerConfig() created
// (stored in htp_power_config_client_id_).
//
// No-op when the performance mode is the default, mirroring the guard in
// SetHtpPowerConfig(): no power config ID was created in that case.
//
// \return Status::OK() on success; an error status if the device
//         infrastructure cannot be obtained, is not the perf infra type,
//         or destroyPowerConfigId() fails.
Status QnnBackendManager::DestroyHTPPowerConfigID() {
  if (htp_performance_mode_ == HtpPerformanceMode::kHtpDefault) {
    return Status::OK();
  }

  QnnDevice_Infrastructure_t qnn_device_infra = nullptr;
  auto status = qnn_interface_.deviceGetInfrastructure(&qnn_device_infra);
  // Bug fix: the message previously said "backendGetPerfInfrastructure failed.",
  // which names the wrong API and misleads anyone debugging this failure.
  ORT_RETURN_IF(QNN_SUCCESS != status, "deviceGetInfrastructure failed.");

  auto* htp_infra = static_cast<QnnHtpDevice_Infrastructure_t*>(qnn_device_infra);
  ORT_RETURN_IF(QNN_HTP_DEVICE_INFRASTRUCTURE_TYPE_PERF != htp_infra->infraType,
                "HTP infra type = ", htp_infra->infraType, ", which is not perf infra type.");
  QnnHtpDevice_PerfInfrastructure_t& htp_perf_infra = htp_infra->perfInfra;

  Qnn_ErrorHandle_t destroy_ret = htp_perf_infra.destroyPowerConfigId(htp_power_config_client_id_);
  ORT_RETURN_IF(QNN_SUCCESS != destroy_ret, "destroyPowerConfigId failed.");
  return Status::OK();
}

void QnnBackendManager::ReleaseResources() {
if (!backend_setup_completed_) {
return;
}

auto result = ReleaseContext();
auto result = DestroyHTPPowerConfigID();
if (Status::OK() != result) {
ORT_THROW("Failed to DestroyHTPPowerConfigID.");
}

result = ReleaseContext();
if (Status::OK() != result) {
ORT_THROW("Failed to ReleaseContext.");
}
Expand Down
3 changes: 3 additions & 0 deletions onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,8 @@ class QnnBackendManager {

Status UnloadLib(void* handle);

Status DestroyHTPPowerConfigID();

void* LibFunction(void* handle, const char* symbol, std::string& error_msg);

template <class T>
Expand Down Expand Up @@ -201,6 +203,7 @@ class QnnBackendManager {
std::set<HMODULE> mod_handles_;
#endif
const std::string qnn_saver_path_;
uint32_t htp_power_config_client_id_ = 0;
};

} // namespace qnn
Expand Down
8 changes: 8 additions & 0 deletions onnxruntime/core/providers/qnn/builder/qnn_def.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,14 @@ enum class HtpPerformanceMode : uint8_t {
kHtpBalanced,
};

// Graph-finalization optimization strategy for the QNN HTP backend.
// Selected via the "htp_graph_finalization_optimization_mode" session option;
// the numeric value is forwarded to QNN's graph optimization config.
// Higher modes trade longer graph preparation time for a more optimized graph.
enum class HtpGraphFinalizationOptimizationMode : uint8_t {
  kDefault = 0,  // No explicit optimization config is set.
  kMode1 = 1,    // Faster preparation time, less optimal graph.
  kMode2 = 2,    // Longer preparation time, more optimal graph.
  kMode3 = 3,    // Longest preparation time, most likely even more optimal graph.
};

enum class QnnBackendType : uint8_t {
CPU = 0,
GPU,
Expand Down
43 changes: 43 additions & 0 deletions onnxruntime/core/providers/qnn/builder/qnn_graph_configs_helper.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/qnn/builder/qnn_graph_configs_helper.h"

#include "HTP/QnnHtpGraph.h"

namespace onnxruntime {
namespace qnn {

// Returns the null-terminated array of graph config pointers for graphCreate(),
// or nullptr when no configurations have been pushed.
const QnnGraph_Config_t** QnnGraphConfigsBuilder::GetQnnGraphConfigs() {
  const bool has_configs = !graph_config_ptrs_.empty();
  if (!has_configs) {
    // QNN's graphCreate() accepts nullptr to mean "no custom configs".
    return nullptr;
  }

  // Append the nullptr sentinel exactly once. The list is non-empty here,
  // so checking the last element is equivalent to IsNullTerminated().
  if (graph_config_ptrs_.back() != nullptr) {
    graph_config_ptrs_.push_back(nullptr);
  }

  return graph_config_ptrs_.data();
}

// Appends a new HTP custom graph config, initialized to the QNN-recommended
// defaults, and returns a reference for the caller to fill in.
QnnHtpGraph_CustomConfig_t& QnnGraphConfigsBuilder::PushHtpGraphCustomConfig() {
  QnnHtpGraph_CustomConfig_t default_config = QNN_HTP_GRAPH_CUSTOM_CONFIG_INIT;
  htp_custom_graph_configs_.push_back(default_config);
  return htp_custom_graph_configs_.back();
}

// Appends a new graph config (initialized to the QNN-recommended default) and
// registers its address in the pointer list returned by GetQnnGraphConfigs().
QnnGraph_Config_t& QnnGraphConfigsBuilder::PushGraphConfig() {
  graph_configs_.push_back(QNN_GRAPH_CONFIG_INIT);
  QnnGraph_Config_t& new_config = graph_configs_.back();

  // Keep the pointer list in sync with graph_configs_.
  if (!IsNullTerminated()) {
    graph_config_ptrs_.push_back(&new_config);
  } else {
    // A nullptr terminator was already appended (GetQnnGraphConfigs() was
    // called earlier); overwrite it so the sentinel stays last.
    graph_config_ptrs_.back() = &new_config;
  }

  return new_config;
}

} // namespace qnn
} // namespace onnxruntime
56 changes: 56 additions & 0 deletions onnxruntime/core/providers/qnn/builder/qnn_graph_configs_helper.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include <core/common/inlined_containers_fwd.h>

#include "HTP/QnnHtpGraph.h"

namespace onnxruntime {
namespace qnn {

/**
* Helper class for building a null-terminated list of QNN Graph configurations.
* A QNN configuration consists of multiple objects with references to each other. This
* class ensures that all configuration objects have the same lifetime, so that they remain valid
* across the call to graphCreate().
*/
/**
 * Helper class for building a null-terminated list of QNN graph configurations.
 * A QNN configuration consists of multiple objects with references to each other. This
 * class ensures that all configuration objects have the same lifetime, so that they remain valid
 * across the call to graphCreate().
 *
 * NOTE(review): the references returned by the Push*() methods point into
 * InlinedVectors owned by this builder; pushing further configs may reallocate
 * the storage and invalidate earlier references/pointers once the inline
 * capacity is exceeded. Current usage pushes a single config pair — confirm
 * before callers push more.
 */
class QnnGraphConfigsBuilder {
 public:
  /**
   * Returns a pointer to the beginning of a null-terminated array of QNN Graph configurations.
   * This result is passed to QNN's graphCreate() API.
   *
   * \return Pointer to a null-terminated QnnGraph_Config_t* array, or nullptr if no
   *         configurations have been added.
   */
  const QnnGraph_Config_t** GetQnnGraphConfigs();

  /**
   * Creates and returns a reference to a new HTP graph configuration object. The object is initialized to
   * the QNN recommended default value. The caller is meant to override fields in this object.
   *
   * \return A reference to a default QnnHtpGraph_CustomConfig_t object.
   */
  QnnHtpGraph_CustomConfig_t& PushHtpGraphCustomConfig();

  /**
   * Creates and returns a reference to a new graph configuration object. The object is initialized to
   * the QNN recommended default value. The caller is meant to override fields in this object.
   *
   * \return A reference to a default QnnGraph_Config_t object.
   */
  QnnGraph_Config_t& PushGraphConfig();

 private:
  // True when graph_config_ptrs_ already ends with the nullptr sentinel
  // required by graphCreate().
  bool IsNullTerminated() const {
    return !graph_config_ptrs_.empty() && graph_config_ptrs_.back() == nullptr;
  }

  InlinedVector<QnnHtpGraph_CustomConfig_t> htp_custom_graph_configs_;  // Storage for HTP custom configs.
  InlinedVector<QnnGraph_Config_t> graph_configs_;                      // Storage for generic graph configs.
  InlinedVector<const QnnGraph_Config_t*> graph_config_ptrs_;           // Null-terminated pointer list handed to QNN.
};

} // namespace qnn
} // namespace onnxruntime
5 changes: 3 additions & 2 deletions onnxruntime/core/providers/qnn/builder/qnn_model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,8 @@ const NodeUnit& QnnModel::GetNodeUnit(const Node* node,
}

Status QnnModel::ComposeGraph(const GraphViewer& graph_viewer,
const onnxruntime::Node& fused_node) {
const onnxruntime::Node& fused_node,
const QnnGraph_Config_t** graph_configs) {
LOGS(logger_, VERBOSE) << "ComposeGraph Graph name: " << graph_viewer.Name();

// Holder for the NodeUnits in the graph, this will guarantee the NodeUnits is
Expand All @@ -107,7 +108,7 @@ Status QnnModel::ComposeGraph(const GraphViewer& graph_viewer,
initializer_inputs_,
qnn_backend_manager_->GetQnnBackendType());
bool rt = true;
rt = qnn_model_wrapper.CreateQnnGraph(qnn_backend_manager_->GetQnnContext(), graph_name);
rt = qnn_model_wrapper.CreateQnnGraph(qnn_backend_manager_->GetQnnContext(), graph_name, graph_configs);
if (!rt) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to initialize qnn_model_wrapper.");
}
Expand Down
3 changes: 2 additions & 1 deletion onnxruntime/core/providers/qnn/builder/qnn_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ class QnnModel {
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(QnnModel);

Status ComposeGraph(const GraphViewer& graph_viewer,
const onnxruntime::Node& fused_node);
const onnxruntime::Node& fused_node,
const QnnGraph_Config_t** graph_configs = nullptr);

Status FinalizeGraphs();

Expand Down
44 changes: 43 additions & 1 deletion onnxruntime/core/providers/qnn/qnn_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,24 @@ void QNNExecutionProvider::ParseHtpPerformanceMode(std::string htp_performance_m
}
}

// Parses the "htp_graph_finalization_optimization_mode" provider option string
// and stores the result in htp_graph_finalization_opt_mode_.
// Accepted values: "" or "0" (default), "1", "2", "3". Any other value logs a
// warning and leaves the current mode unchanged.
void QNNExecutionProvider::ParseHtpGraphFinalizationOptimizationMode(const std::string& htp_graph_finalization_opt_mode_string) {
  LOGS_DEFAULT(VERBOSE) << "HTP graph finalization optimization mode: "
                        << htp_graph_finalization_opt_mode_string;

  // An empty option string selects the default mode.
  if (htp_graph_finalization_opt_mode_string.empty()) {
    htp_graph_finalization_opt_mode_ = qnn::HtpGraphFinalizationOptimizationMode::kDefault;
    return;
  }

  // All valid non-empty values are single digits "0".."3".
  if (htp_graph_finalization_opt_mode_string.size() == 1) {
    switch (htp_graph_finalization_opt_mode_string[0]) {
      case '0':
        htp_graph_finalization_opt_mode_ = qnn::HtpGraphFinalizationOptimizationMode::kDefault;
        return;
      case '1':
        htp_graph_finalization_opt_mode_ = qnn::HtpGraphFinalizationOptimizationMode::kMode1;
        return;
      case '2':
        htp_graph_finalization_opt_mode_ = qnn::HtpGraphFinalizationOptimizationMode::kMode2;
        return;
      case '3':
        htp_graph_finalization_opt_mode_ = qnn::HtpGraphFinalizationOptimizationMode::kMode3;
        return;
      default:
        break;
    }
  }

  LOGS_DEFAULT(WARNING) << "Invalid HTP graph finalization optimization mode: "
                        << htp_graph_finalization_opt_mode_string;
}

QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_options_map,
const SessionOptions* session_options)
: IExecutionProvider{onnxruntime::kQnnExecutionProvider, true},
Expand Down Expand Up @@ -140,6 +158,13 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio
ParseHtpPerformanceMode(htp_performance_mode_pos->second);
}

htp_graph_finalization_opt_mode_ = qnn::HtpGraphFinalizationOptimizationMode::kDefault;
static const std::string HTP_GRAPH_FINALIZATION_OPT_MODE = "htp_graph_finalization_optimization_mode";
auto htp_graph_finalization_opt_mode_pos = runtime_options_.find(HTP_GRAPH_FINALIZATION_OPT_MODE);
if (htp_graph_finalization_opt_mode_pos != runtime_options_.end()) {
ParseHtpGraphFinalizationOptimizationMode(htp_graph_finalization_opt_mode_pos->second);
}

// Enable use of QNN Saver if the user provides a path the QNN Saver backend library.
static const std::string QNN_SAVER_PATH_KEY = "qnn_saver_path";
std::string qnn_saver_path;
Expand Down Expand Up @@ -448,6 +473,20 @@ Status QNNExecutionProvider::CreateComputeFunc(std::vector<NodeComputeInfo>& nod
return Status::OK();
}

// Populates graph configurations for graphCreate(). Currently only adds an HTP
// graph-finalization optimization config when a non-default mode was requested
// via the "htp_graph_finalization_optimization_mode" session option.
void QNNExecutionProvider::InitQnnGraphConfigs(qnn::QnnGraphConfigsBuilder& configs_builder) const {
  // Graph finalization optimization applies only to the HTP backend.
  if (qnn_backend_manager_->GetQnnBackendType() != qnn::QnnBackendType::HTP) {
    return;
  }

  // The default mode requires no explicit configuration.
  if (htp_graph_finalization_opt_mode_ == qnn::HtpGraphFinalizationOptimizationMode::kDefault) {
    return;
  }

  // HTP-specific custom config carrying the optimization level (QNN expects a float).
  QnnHtpGraph_CustomConfig_t& htp_graph_opt_config = configs_builder.PushHtpGraphCustomConfig();
  htp_graph_opt_config.option = QNN_HTP_GRAPH_CONFIG_OPTION_OPTIMIZATION;
  htp_graph_opt_config.optimizationOption.type = QNN_HTP_GRAPH_OPTIMIZATION_TYPE_FINALIZE_OPTIMIZATION_FLAG;
  htp_graph_opt_config.optimizationOption.floatValue = static_cast<float>(htp_graph_finalization_opt_mode_);

  // Generic graph config wrapping the HTP custom config; the builder keeps both
  // alive across the graphCreate() call.
  QnnGraph_Config_t& graph_opt_config = configs_builder.PushGraphConfig();
  graph_opt_config.option = QNN_GRAPH_CONFIG_OPTION_CUSTOM;
  graph_opt_config.customConfig = &htp_graph_opt_config;
}

Status QNNExecutionProvider::CompileFromOrtGraph(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs,
std::vector<NodeComputeInfo>& node_compute_funcs,
const logging::Logger& logger) {
Expand All @@ -458,7 +497,10 @@ Status QNNExecutionProvider::CompileFromOrtGraph(const std::vector<FusedNodeAndG
std::unique_ptr<qnn::QnnModel> qnn_model = std::make_unique<qnn::QnnModel>(logger,
qnn_backend_manager_.get());

ORT_RETURN_IF_ERROR(qnn_model->ComposeGraph(graph_viewer, fused_node));
qnn::QnnGraphConfigsBuilder graph_configs_builder;
InitQnnGraphConfigs(graph_configs_builder);

ORT_RETURN_IF_ERROR(qnn_model->ComposeGraph(graph_viewer, fused_node, graph_configs_builder.GetQnnGraphConfigs()));
ORT_RETURN_IF_ERROR(qnn_model->FinalizeGraphs());
ORT_RETURN_IF_ERROR(qnn_model->SetupQnnInputOutput());

Expand Down
Loading

0 comments on commit a0eeeaf

Please sign in to comment.