Enable QNN HTP spill fill buffer setting to save RAM usage. (microsoft#22853)

### Description
Enable QNN HTP spill fill buffer setting to save RAM usage.
This feature is available starting with QNN 2.28; the QNN context binary needs to be regenerated to use it.

https://docs.qualcomm.com/bundle/publicresource/topics/80-63442-50/htp_backend.html#qnn-htp-backend-api

Requirements:
1. Regenerate the ONNX model with the QNN context binary by setting the EP option enable_htp_spill_fill_buffer = 1 (see the sketch below).
2. Works for a model with multiple context binaries; the two ONNX models with context binaries need to be merged manually into one ONNX model.
3. Generating the context binary offline requires a Linux platform, since the QnnSystem lib is not available for the Windows x86_64 platform.
No extra steps are needed when running model inference.
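
For illustration, a minimal sketch of step 1 using the ONNX Runtime C++ API. The model file names and the HTP backend library name are placeholder assumptions, and `ep.context_enable` / `ep.context_embed_mode` are the session config keys used for EPContext model generation:

```cpp
#include <string>
#include <unordered_map>
#include "onnxruntime_cxx_api.h"

int main() {
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "qnn_ctx_gen");
  Ort::SessionOptions so;

  // Dump an ONNX model wrapping the compiled QNN context binary.
  so.AddConfigEntry("ep.context_enable", "1");
  so.AddConfigEntry("ep.context_embed_mode", "0");  // keep the binary outside the ONNX file

  // QNN EP options; enable_htp_spill_fill_buffer is the flag added by this change.
  std::unordered_map<std::string, std::string> qnn_options{
      {"backend_path", "libQnnHtp.so"},  // assumed HTP backend library
      {"enable_htp_spill_fill_buffer", "1"},
  };
  so.AppendExecutionProvider("QNN", qnn_options);

  // Creating the session compiles the graph and writes the _ctx.onnx model
  // whose EPContext node carries the max_size attribute.
  Ort::Session session(env, "model.onnx", so);
  return 0;
}
```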

The generated EPContext node will have a max_size attribute with the maximum spill fill buffer size for the context binary:

![EPContext node with the max_size attribute](https://github.com/user-attachments/assets/a3bf48be-a8da-4381-8a1d-3f2558eea37d)
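
At inference time, by contrast, a plain session over the generated context model is enough; a sketch under the same assumptions as above:

```cpp
// The spill fill group is registered automatically while the cached QNN
// contexts are loaded; max_size is read from the EPContext node.
Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "qnn_ctx_run");
Ort::SessionOptions so;
so.AppendExecutionProvider("QNN", {{"backend_path", "libQnnHtp.so"}});
Ort::Session session(env, "model_ctx.onnx", so);
```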

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
2 people authored and ankitm3k committed Dec 11, 2024
1 parent 9bf04ec commit ddb6e65
Showing 12 changed files with 208 additions and 50 deletions.
2 changes: 2 additions & 0 deletions docs/ContribOperators.md
@@ -1596,6 +1596,8 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
<dd>(Optional) Hardware architecture.</dd>
<dt><tt>main_context</tt> : int</dt>
<dd>Usually each single EPContext node is associated with one graph partition. But in some cases, like QNN, a single EPContext node contains all partitions. In that case, the node with ep_cache_context should set main_context=1; other nodes set main_context=0 and skip ep_cache_context. The path is relative to this ONNX file. Default is 1.</dd>
<dt><tt>max_size</tt> : int</dt>
<dd>Max size in the context. Usage depends on the EP.</dd>
<dt><tt>notes</tt> : string</dt>
<dd>(Optional) Some notes for the model</dd>
<dt><tt>onnx_model_filename</tt> : string</dt>
3 changes: 3 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
@@ -3667,6 +3667,9 @@ struct OrtApi {
* execution provider (typically CPU EP).
* - "0": Default. Disabled. QNN EP will handle quantization and dequantization of graph I/O.
* - "1": Enabled.
* "enable_htp_spill_fill_buffer": Enable HTP spill fill buffer setting. The flag is used while generating context binary.
* - "0": Default. Disabled.
* - "1": Enabled.
*
* SNPE supported keys:
* "runtime": SNPE runtime engine, options: "CPU", "CPU_FLOAT32", "GPU", "GPU_FLOAT32_16_HYBRID", "GPU_FLOAT16",
5 changes: 5 additions & 0 deletions onnxruntime/core/graph/contrib_ops/contrib_defs.cc
@@ -3335,6 +3335,11 @@ void RegisterContribSchemas() {
AttributeProto::STRING,
OPTIONAL_VALUE)
.Attr("notes", "(Optional) Some notes for the model", AttributeProto::STRING, OPTIONAL_VALUE)
.Attr(
"max_size",
"max size in the context. Usage depend on the EP.",
AttributeProto::INT,
static_cast<int64_t>(0))
.AllowUncheckedAttributes()
.Input(
0,
43 changes: 38 additions & 5 deletions onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc
@@ -87,7 +87,8 @@ Status CreateNodeArgs(const std::vector<std::string>& names,
Status GetEpContextFromMainNode(const onnxruntime::Node& main_context_node,
const onnxruntime::PathString& ctx_onnx_model_path,
QnnBackendManager* qnn_backend_manager,
QnnModelLookupTable& qnn_models) {
QnnModelLookupTable& qnn_models,
int64_t max_spill_fill_size) {
ORT_RETURN_IF_NOT(EPCONTEXT_OP == main_context_node.OpType(), "Should only filter in the EPContext node.");
NodeAttrHelper node_helper(main_context_node);
bool is_embed_mode = node_helper.Get(EMBED_MODE, true);
@@ -96,7 +97,8 @@ Status GetEpContextFromMainNode(const onnxruntime::Node& main_context_node,
return qnn_backend_manager->LoadCachedQnnContextFromBuffer(const_cast<char*>(context_binary.c_str()),
static_cast<uint64_t>(context_binary.length()),
main_context_node.Name(),
qnn_models);
qnn_models,
max_spill_fill_size);
}

std::filesystem::path folder_path = std::filesystem::path(ctx_onnx_model_path).parent_path();
@@ -145,17 +147,46 @@ Status GetEpContextFromMainNode(const onnxruntime::Node& main_context_node,
return qnn_backend_manager->LoadCachedQnnContextFromBuffer(buffer.get(),
static_cast<uint64_t>(buffer_size),
main_context_node.Name(),
qnn_models);
qnn_models,
max_spill_fill_size);
}

Status TryGetMaxSpillFillSize(const std::vector<IExecutionProvider::FusedNodeAndGraph>& fused_nodes_and_graphs,
uint32_t total_context_size,
int64_t& max_spill_fill_size,
std::vector<int>& main_context_pos_list) {
max_spill_fill_size = 0;
int max_size_index = 0;
for (uint32_t i = 0; i < total_context_size; ++i) {
auto index = main_context_pos_list[i];
const onnxruntime::GraphViewer& main_ctx_graph_viewer(fused_nodes_and_graphs[index].filtered_graph);
ORT_RETURN_IF(main_ctx_graph_viewer.NumberOfNodes() != 1, "One filtered graph should have only one EPContext node!");
const auto& ep_context_node = main_ctx_graph_viewer.Nodes().begin();
NodeAttrHelper node_helper(*ep_context_node);
int64_t max_size = node_helper.Get(MAX_SIZE, static_cast<int64_t>(0));
if (max_size > max_spill_fill_size) {
max_spill_fill_size = max_size;
max_size_index = i;
}
}
if (0 != max_size_index) {
int tmp_index = main_context_pos_list[0];
main_context_pos_list[0] = main_context_pos_list[max_size_index];
main_context_pos_list[max_size_index] = tmp_index;
}

return Status::OK();
}

Status LoadQnnCtxFromOnnxGraph(const onnxruntime::GraphViewer& graph_viewer,
const onnxruntime::PathString& ctx_onnx_model_path,
QnnBackendManager* qnn_backend_manager,
QnnModelLookupTable& qnn_models,
const logging::Logger& logger) {
const logging::Logger& logger,
int64_t max_spill_fill_size) {
ORT_RETURN_IF(graph_viewer.NumberOfNodes() != 1, "One filtered graph should have only one EPContext node!");
Status status = GetEpContextFromMainNode(*graph_viewer.Nodes().begin(), ctx_onnx_model_path, qnn_backend_manager,
qnn_models);
qnn_models, max_spill_fill_size);

// This is the protocol with customer that status with INVALID_GRAPH will be generated if failed to load context model
if (!status.IsOK()) {
@@ -196,6 +227,7 @@ Status CreateEPContextNodes(Model* model,
const QnnModelLookupTable& qnn_models,
const onnxruntime::PathString& context_cache_path,
bool qnn_context_embed_mode,
uint64_t max_spill_fill_buffer_size,
const logging::Logger& logger) {
auto& graph = model->MainGraph();

@@ -238,6 +270,7 @@ }
}
of_stream.write(reinterpret_cast<char*>(buffer), buffer_size);
ep_node.AddAttribute(EP_CACHE_CONTEXT, context_cache_name);
ep_node.AddAttribute(MAX_SIZE, static_cast<int64_t>(max_spill_fill_buffer_size));
}
} else {
ep_node.AddAttribute(MAIN_CONTEXT, static_cast<int64_t>(0));
13 changes: 11 additions & 2 deletions onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h
@@ -28,6 +28,7 @@ static const std::string EP_CACHE_CONTEXT = "ep_cache_context";
static const std::string EP_SDK_VER = "ep_sdk_version";
static const std::string PARTITION_NAME = "partition_name";
static const std::string SOURCE = "source";
static const std::string MAX_SIZE = "max_size";

bool GraphHasEpContextNode(const onnxruntime::GraphViewer& graph_viewer);

@@ -49,13 +50,20 @@ bool ValidateContextCacheFilePath(bool is_qnn_ctx_model,
Status GetEpContextFromMainNode(const onnxruntime::Node& main_context_node,
const onnxruntime::PathString& ctx_onnx_model_path,
QnnBackendManager* qnn_backend_manager,
QnnModelLookupTable& qnn_models);
QnnModelLookupTable& qnn_models,
int64_t max_spill_fill_size);

Status TryGetMaxSpillFillSize(const std::vector<IExecutionProvider::FusedNodeAndGraph>& fused_nodes_and_graphs,
uint32_t total_context_size,
int64_t& max_spill_fill_size,
std::vector<int>& main_context_pos_list);

Status LoadQnnCtxFromOnnxGraph(const onnxruntime::GraphViewer& graph_viewer,
const onnxruntime::PathString& ctx_onnx_model_path,
QnnBackendManager* qnn_backend_manager,
QnnModelLookupTable& qnn_models,
const logging::Logger& logger);
const logging::Logger& logger,
int64_t max_spill_fill_size);

Status CreateEPContextNodes(Model* model,
unsigned char* buffer,
@@ -65,6 +73,7 @@ Status CreateEPContextNodes(Model* model,
const std::unordered_map<std::string, std::unique_ptr<QnnModel>>& qnn_models,
const onnxruntime::PathString& context_cache_path,
bool qnn_context_embed_mode,
uint64_t max_spill_fill_buffer_size,
const logging::Logger& logger);
} // namespace qnn
} // namespace onnxruntime
127 changes: 100 additions & 27 deletions onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc
@@ -8,6 +8,7 @@
#include <string>
#include "QnnOpDef.h"
#include "HTP/QnnHtpPerfInfrastructure.h"
#include "HTP/QnnHtpSystemContext.h"
#include "CPU/QnnCpuCommon.h"
// TODO: not exist for Windows yet
// #include "GPU/QnnGpuCommon.h"
@@ -532,11 +533,11 @@ Status QnnBackendManager::CreateContext() {
}

QnnContext_Config_t context_config_weight_sharing = QNN_CONTEXT_CONFIG_INIT;
QnnHtpContext_CustomConfig_t customConfig;
customConfig.option = QNN_HTP_CONTEXT_CONFIG_OPTION_WEIGHT_SHARING_ENABLED;
customConfig.weightSharingEnabled = enable_htp_weight_sharing_;
QnnHtpContext_CustomConfig_t custom_config;
custom_config.option = QNN_HTP_CONTEXT_CONFIG_OPTION_WEIGHT_SHARING_ENABLED;
custom_config.weightSharingEnabled = enable_htp_weight_sharing_;
context_config_weight_sharing.option = QNN_CONTEXT_CONFIG_OPTION_CUSTOM;
context_config_weight_sharing.customConfig = &customConfig;
context_config_weight_sharing.customConfig = &custom_config;

QnnContext_Config_t context_priority_config = QNN_CONTEXT_CONFIG_INIT;
ORT_RETURN_IF_ERROR(SetQnnContextConfig(context_priority_, context_priority_config));
@@ -615,9 +616,71 @@ std::unique_ptr<unsigned char[]> QnnBackendManager::GetContextBinaryBuffer(uint64_t
return context_buffer;
}

Status QnnBackendManager::GetMaxSpillFillBufferSize(unsigned char* buffer,
uint64_t buffer_length,
uint64_t& max_spill_fill_buffer_size) {
bool result = nullptr == qnn_sys_interface_.systemContextCreate ||
nullptr == qnn_sys_interface_.systemContextGetBinaryInfo ||
nullptr == qnn_sys_interface_.systemContextFree;
ORT_RETURN_IF(result, "Failed to get valid function pointer.");

QnnSystemContext_Handle_t sys_ctx_handle = nullptr;
auto rt = qnn_sys_interface_.systemContextCreate(&sys_ctx_handle);
ORT_RETURN_IF(QNN_SUCCESS != rt, "Failed to create system handle.");

const QnnSystemContext_BinaryInfo_t* binary_info = nullptr;
Qnn_ContextBinarySize_t binary_info_size{0};
rt = qnn_sys_interface_.systemContextGetBinaryInfo(sys_ctx_handle,
static_cast<void*>(buffer),
buffer_length,
&binary_info,
&binary_info_size);
ORT_RETURN_IF(QNN_SUCCESS != rt, "Failed to get context binary info.");

// binary_info life cycle is here
// Binary info to graph info
// retrieve Qnn graph info from binary info
ORT_RETURN_IF(nullptr == binary_info, "Qnn cached binary info is nullptr.");
uint32_t graph_count = 0;
QnnSystemContext_GraphInfo_t* graphs_info = nullptr;
if (binary_info->version == QNN_SYSTEM_CONTEXT_BINARY_INFO_VERSION_3) {
graph_count = binary_info->contextBinaryInfoV3.numGraphs;
graphs_info = binary_info->contextBinaryInfoV3.graphs;
} else if (binary_info->version == QNN_SYSTEM_CONTEXT_BINARY_INFO_VERSION_2) {
graph_count = binary_info->contextBinaryInfoV2.numGraphs;
graphs_info = binary_info->contextBinaryInfoV2.graphs;
} else if (binary_info->version == QNN_SYSTEM_CONTEXT_BINARY_INFO_VERSION_1) {
graph_count = binary_info->contextBinaryInfoV1.numGraphs;
graphs_info = binary_info->contextBinaryInfoV1.graphs;
} else {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported context binary info version.");
}

for (uint32_t i = 0; i < graph_count; ++i) {
if (graphs_info[i].version == QNN_SYSTEM_CONTEXT_GRAPH_INFO_VERSION_3) {
auto htp_graph_info = reinterpret_cast<QnnHtpSystemContext_GraphBlobInfo_t*>(graphs_info[i].graphInfoV3.graphBlobInfo);
if (htp_graph_info->version == QNN_SYSTEM_CONTEXT_HTP_GRAPH_INFO_BLOB_VERSION_V1) {
auto spill_fill_buffer_size = htp_graph_info->contextBinaryGraphBlobInfoV1.spillFillBufferSize;
max_spill_fill_buffer_size = spill_fill_buffer_size > max_spill_fill_buffer_size ? spill_fill_buffer_size : max_spill_fill_buffer_size;
} else {
LOGS(*logger_, VERBOSE) << "Unknown context binary graph info blob version.";
}
} else if (graphs_info[i].version == QNN_SYSTEM_CONTEXT_GRAPH_INFO_VERSION_2 ||
graphs_info[i].version == QNN_SYSTEM_CONTEXT_GRAPH_INFO_VERSION_1) {
LOGS(*logger_, VERBOSE) << "Skip retrieve spill file buffer size, it is not supported with graph info v1 & v2.";
} else {
LOGS(*logger_, VERBOSE) << "Unknown context binary graph info version.";
}
}

LOGS(*logger_, VERBOSE) << "Get max spill fill buffer size completed.";
return Status::OK();
}

Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t buffer_length,
std::string node_name,
QnnModelLookupTable& qnn_models) {
QnnModelLookupTable& qnn_models,
int64_t max_spill_fill_size) {
bool result = nullptr == qnn_sys_interface_.systemContextCreate ||
nullptr == qnn_sys_interface_.systemContextGetBinaryInfo ||
nullptr == qnn_sys_interface_.systemContextFree;
@@ -638,7 +701,7 @@ Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t

// binary_info life cycle is here
// Binary info to graph info
// retrieve Qnn graph infor from binary info
// retrieve Qnn graph info from binary info
ORT_RETURN_IF(nullptr == binary_info, "Qnn cached binary info is nullptr.");
uint32_t graph_count = 0;
QnnSystemContext_GraphInfo_t* graphs_info = nullptr;
@@ -658,13 +721,33 @@ Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t
ORT_RETURN_IF(graph_count < 1 || graphs_info == nullptr, "Failed to get graph info from Qnn cached context.");
LOGS(*logger_, VERBOSE) << "Graph count from QNN context: " << graph_count;

ORT_RETURN_IF(nullptr == qnn_interface_.contextCreateFromBinary,
"Invalid function pointer for contextCreateFromBinary.");

QnnContext_Config_t qnn_context_config = QNN_CONTEXT_CONFIG_INIT;
ORT_RETURN_IF_ERROR(SetQnnContextConfig(context_priority_, qnn_context_config));
const QnnContext_Config_t* context_configs[] = {&qnn_context_config, nullptr};

// Register spill fill buffer for multi context
QnnContext_Config_t spill_fill_config = QNN_CONTEXT_CONFIG_INIT;

// The spill fill buffer is available since QNN 2.28; the QNN API version starts from 2.21
#if QNN_API_VERSION_MAJOR == 2 && (QNN_API_VERSION_MINOR >= 21)
QnnHtpContext_CustomConfig_t custom_config;
custom_config.option = QNN_HTP_CONTEXT_CONFIG_OPTION_REGISTER_MULTI_CONTEXTS;
QnnHtpContext_GroupRegistration_t group_info;
size_t current_contexts_size = GetQnnContextSize();
// set to 0x0 (new group) if this is the first context, otherwise point to the first context handle
// note that we already move the context with max spill fill size to the beginning of the list
group_info.firstGroupHandle = (max_spill_fill_size > 0 && current_contexts_size > 0) ? GetQnnContext(0) : 0x0;
group_info.maxSpillFillBuffer = max_spill_fill_size; // Max spill-fill buffer across contexts. Must be >0
custom_config.groupRegistration = group_info;
spill_fill_config.option = QNN_CONTEXT_CONFIG_OPTION_CUSTOM;
spill_fill_config.customConfig = &custom_config;
#endif
QnnContext_Config_t* spill_fill_config_pointer = max_spill_fill_size > 0 ? &spill_fill_config : nullptr;
LOGS(*logger_, VERBOSE) << "Max spill fill buffer size:" << max_spill_fill_size;

const QnnContext_Config_t* context_configs[] = {&qnn_context_config, spill_fill_config_pointer, nullptr};

ORT_RETURN_IF(nullptr == qnn_interface_.contextCreateFromBinary,
"Invalid function pointer for contextCreateFromBinary.");
Qnn_ContextHandle_t context = nullptr;
rt = qnn_interface_.contextCreateFromBinary(backend_handle_,
device_handle_,
Expand All @@ -673,7 +756,7 @@ Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t
buffer_length,
&context,
profile_backend_handle_);
ORT_RETURN_IF(QNN_SUCCESS != rt, "Failed to create context from binary.");
ORT_RETURN_IF(QNN_SUCCESS != rt, "Failed to create context from binary. Error code: ", rt);
contexts_.push_back(context);
if (1 == graph_count) {
// in case the EPContext node is generated from script
@@ -699,7 +782,11 @@ Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t
return Status::OK();
}

Status QnnBackendManager::SetupBackend(const logging::Logger& logger, bool load_from_cached_context) {
// Need to load the system lib if loading from a QNN context binary,
// or if QNN context binary generation is enabled -- to get the max spill fill buffer size
Status QnnBackendManager::SetupBackend(const logging::Logger& logger,
bool load_from_cached_context,
bool need_load_system_lib) {
std::lock_guard<std::mutex> lock(logger_mutex_);
if (backend_setup_completed_) {
LOGS(logger, VERBOSE) << "Backend setup already!";
@@ -714,7 +801,7 @@ Status QnnBackendManager::SetupBackend(const logging::Logger& logger, bool load_from_cached_context) {

LOGS(logger, VERBOSE) << "LoadBackend succeed.";

if (load_from_cached_context) {
if (load_from_cached_context || need_load_system_lib) {
ORT_RETURN_IF_ERROR(LoadQnnSystemLib());
}

Expand Down Expand Up @@ -933,20 +1020,6 @@ Status QnnBackendManager::SetRpcControlLatency(uint32_t htp_power_config_client_
return Status::OK();
}

void QnnBackendManager::Split(std::vector<std::string>& split_string,
const std::string& tokenized_string,
const char separator) {
split_string.clear();
std::istringstream tokenized_string_stream(tokenized_string);
while (!tokenized_string_stream.eof()) {
std::string value;
getline(tokenized_string_stream, value, separator);
if (!value.empty()) {
split_string.push_back(value);
}
}
}

Status QnnBackendManager::DestroyHTPPowerConfigID(uint32_t htp_power_config_id) {
QnnDevice_Infrastructure_t qnn_device_infra = nullptr;
auto status = qnn_interface_.deviceGetInfrastructure(&qnn_device_infra);
