Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Multi-partition support for context binary cache feature #18865

Merged
merged 32 commits into from
Feb 1, 2024
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
eba7e03
Support multi-partition for context cache feature
HectorSVC Dec 14, 2023
8ffa12e
2. Load and execute the model with multiple EPContext
HectorSVC Dec 15, 2023
8117368
3. Mode: run with QDQ model + QNN context model
HectorSVC Dec 16, 2023
8a00784
Remove QNN EP options: qnn_context_cache_enable, qnn_context_cache_pa…
HectorSVC Dec 18, 2023
16058a8
update test code
HectorSVC Dec 18, 2023
aacab16
Update test code to reflect the changes which move provider options t…
HectorSVC Dec 19, 2023
a5a9aef
Merge branch 'main' of https://github.com/microsoft/onnxruntime
HectorSVC Dec 20, 2023
1567bf2
merge main
HectorSVC Dec 20, 2023
22b4c93
Fix Linux build
HectorSVC Dec 27, 2023
de53da1
fix some build issues
HectorSVC Dec 29, 2023
c3883b1
Set inputs outputs explicitly to make sure the order is same as the u…
HectorSVC Jan 18, 2024
a457b70
Merge branch 'main' into qnn_ctx_multi_partition_support
HectorSVC Jan 19, 2024
30c1ed7
resolve conflict
HectorSVC Jan 20, 2024
55d10b2
resolved merge conflicts
HectorSVC Jan 21, 2024
ce3c64f
resolve merge conflicts
HectorSVC Jan 21, 2024
8c55f19
remove the validation mode
HectorSVC Jan 22, 2024
e7c0827
clean up some not used code
HectorSVC Jan 22, 2024
d3feaa4
renaming
HectorSVC Jan 22, 2024
33516cd
Update tests
HectorSVC Jan 23, 2024
445bc1b
fix the issue relate to initializer handling
HectorSVC Jan 25, 2024
9c7bdfc
Move QNN context cache related tests to a separate file
HectorSVC Jan 25, 2024
3dfd94b
rename some tests
HectorSVC Jan 25, 2024
3b8e879
Add UT to verify the multi-partition support
HectorSVC Jan 26, 2024
ff2c313
fill some gaps in UT and fix an issue relate to context cache path
HectorSVC Jan 26, 2024
c8ea83d
add one more UT to dumps the context cache model with 2 EPContext nodes
HectorSVC Jan 26, 2024
3dbb95d
formating
HectorSVC Jan 26, 2024
9eb32aa
Use ContribOp FusedMatMul instead of Add op to make sure the float32 …
HectorSVC Jan 26, 2024
1d4fa6f
Merge branch 'main' into qnn_ctx_multi_partition_support
HectorSVC Jan 26, 2024
74c5cef
update according review comments.
HectorSVC Feb 1, 2024
0137172
update according review comments
HectorSVC Feb 1, 2024
b79a256
formating
HectorSVC Feb 1, 2024
9e71147
revert a change
HectorSVC Feb 1, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
192 changes: 128 additions & 64 deletions onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,32 +12,92 @@
namespace onnxruntime {
namespace qnn {

Status IsFusedGraphHasCtxNode(const std::vector<IExecutionProvider::FusedNodeAndGraph>& fused_nodes_and_graphs,
bool& is_qnn_ctx_model) {
is_qnn_ctx_model = false;
bool GraphHasEpContextNode(const onnxruntime::GraphViewer& graph_viewer) {
// It's an Onnx model with Qnn context cache binary if it has a node with EPContext type
for (const auto& node : graph_viewer.Nodes()) {
if (EPCONTEXT_OP == node.OpType()) {
return true;
}
}
return false;
}

bool IsFusedGraphHasCtxNode(const std::vector<IExecutionProvider::FusedNodeAndGraph>& fused_nodes_and_graphs) {
for (const auto& fused_node_graph : fused_nodes_and_graphs) {
const onnxruntime::GraphViewer& graph_viewer(fused_node_graph.filtered_graph);
// It's an Onnx model with Qnn context cache binary if it only has a node with EPContext type
int count = 0;
for (const auto& node : graph_viewer.Nodes()) {
if (EPCONTEXT_OP == node.OpType()) {
is_qnn_ctx_model = true;
}
++count;
bool has_qnn_ep_context_node = GraphHasEpContextNode(graph_viewer);
if (has_qnn_ep_context_node) {
return true;
}
}
return false;
}

Status GetMainContextNode(const std::vector<IExecutionProvider::FusedNodeAndGraph>& fused_nodes_and_graphs,
QnnBackendManager* qnn_backend_manager,
const logging::Logger& logger,
int& main_context_pos,
std::unordered_map<std::string, std::unique_ptr<qnn::QnnModel>>& qnn_models) {
main_context_pos = -1;
for (size_t i = 0; i < fused_nodes_and_graphs.size(); ++i) {
const onnxruntime::GraphViewer& graph_viewer(fused_nodes_and_graphs[i].filtered_graph);
const auto& ep_context_node = graph_viewer.Nodes().begin();
ORT_RETURN_IF_NOT(EPCONTEXT_OP == ep_context_node->OpType(), "Should only filter in the EPContext node.");
qnn_models.emplace(ep_context_node->Name(),
std::make_unique<qnn::QnnModel>(logger, qnn_backend_manager));
NodeAttrHelper node_helper(*ep_context_node);
int64_t is_main_context = node_helper.Get(MAIN_CONTEXT, static_cast<int64_t>(0));
if (1 == is_main_context) {
main_context_pos = static_cast<int>(i);
}
ORT_RETURN_IF(is_qnn_ctx_model && count > 1, "Fused graph should only has 1 single EPContext node.");
}

ORT_RETURN_IF(main_context_pos < 0, "Failed to find the EPContext node with main_context=1");
return Status::OK();
}

bool IsQnnCtxModel(const onnxruntime::GraphViewer& graph_viewer) {
// It's an Onnx model with Qnn context cache binary if it only has a node with EPContext type
for (const auto& node : graph_viewer.Nodes()) {
if (EPCONTEXT_OP == node.OpType()) {
return true;
Status GetContextFromOnnxModel(const std::vector<IExecutionProvider::FusedNodeAndGraph>& fused_nodes_and_graphs,
const onnxruntime::PathString& ctx_onnx_model_path,
QnnBackendManager* qnn_backend_manager,
const logging::Logger& logger,
std::unordered_map<std::string, std::unique_ptr<qnn::QnnModel>>& qnn_models) {
for (const auto& fused_node_and_graph : fused_nodes_and_graphs) {
const Node& fused_node = fused_node_and_graph.fused_node;
qnn_models.emplace(fused_node.Name(),
std::make_unique<qnn::QnnModel>(logger, qnn_backend_manager));
}
using namespace onnxruntime;

Check warning on line 69 in onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc#L69

Do not use namespace using-directives. Use using-declarations instead. [build/namespaces] [5]
Raw output
onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc:69:  Do not use namespace using-directives.  Use using-declarations instead.  [build/namespaces] [5]
std::shared_ptr<Model> model;
ORT_RETURN_IF_ERROR(Model::Load(ctx_onnx_model_path, model, {}, logger));
const auto& graph = GraphViewer(model->MainGraph());

for (const auto& ep_context_node : graph.Nodes()) {
if (EPCONTEXT_OP != ep_context_node.OpType()) {
continue;
}
NodeAttrHelper node_helper(ep_context_node);
int64_t is_main_context = node_helper.Get(MAIN_CONTEXT, static_cast<int64_t>(0));
if (1 == is_main_context) {
return GetEpContextFromMainNode(ep_context_node, ctx_onnx_model_path, qnn_backend_manager, qnn_models);
}
}
return false;

return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, "Failed to find EPContext node with main_context=1.");
}

Status LoadContextFromOnnxModel(const std::vector<IExecutionProvider::FusedNodeAndGraph>& fused_nodes_and_graphs,
const onnxruntime::PathString& ctx_onnx_model_path,
QnnBackendManager* qnn_backend_manager,
const logging::Logger& logger,
std::unordered_map<std::string, std::unique_ptr<qnn::QnnModel>>& qnn_models) {
Status status = GetContextFromOnnxModel(fused_nodes_and_graphs, ctx_onnx_model_path, qnn_backend_manager, logger, qnn_models);

Check warning on line 93 in onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc#L93

Lines should be <= 120 characters long [whitespace/line_length] [2]
Raw output
onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc:93:  Lines should be <= 120 characters long  [whitespace/line_length] [2]

// This is the protocol with customer that status with INVALID_GRAPH will be generated if failed to load context model
if (!status.IsOK()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, "Failed to load from EpContextModel. ", status.ErrorMessage());
}

return Status::OK();
}

Status CreateNodeArgs(const std::vector<std::string>& names,
Expand All @@ -60,32 +120,18 @@
return Status::OK();
}

Status GetEpContextFromModel(const onnxruntime::PathString& ctx_onnx_model_path,
QnnBackendManager* qnn_backend_manager,
QnnModel& qnn_model,
const logging::Logger& logger) {
using namespace onnxruntime;
std::shared_ptr<Model> model;
ORT_RETURN_IF_ERROR(Model::Load(ToPathString(ctx_onnx_model_path), model, {}, logger));
const auto& graph = model->MainGraph();
return GetEpContextFromGraph(GraphViewer(graph),
ctx_onnx_model_path,
qnn_backend_manager,
qnn_model);
}

Status GetEpContextFromGraph(const onnxruntime::GraphViewer& graph_viewer,
const onnxruntime::PathString& ctx_onnx_model_path,
QnnBackendManager* qnn_backend_manager,
QnnModel& qnn_model) {
const auto& node = graph_viewer.Nodes().begin();
NodeAttrHelper node_helper(*node);
Status GetEpContextFromMainNode(const onnxruntime::Node& main_context_node,
const onnxruntime::PathString& ctx_onnx_model_path,
QnnBackendManager* qnn_backend_manager,
std::unordered_map<std::string, std::unique_ptr<qnn::QnnModel>>& qnn_models) {
ORT_RETURN_IF_NOT(EPCONTEXT_OP == main_context_node.OpType(), "Should only filter in the EPContext node.");
NodeAttrHelper node_helper(main_context_node);
bool is_embed_mode = node_helper.Get(EMBED_MODE, true);
if (is_embed_mode) {
const std::string& context_binary = node_helper.Get(EP_CACHE_CONTEXT, "");
return qnn_backend_manager->LoadCachedQnnContextFromBuffer(const_cast<char*>(context_binary.c_str()),
static_cast<uint64_t>(context_binary.length()),
qnn_model);
qnn_models);
}

std::filesystem::path folder_path = std::filesystem::path(ctx_onnx_model_path).parent_path();
Expand Down Expand Up @@ -133,23 +179,16 @@
cache_file.close();
return qnn_backend_manager->LoadCachedQnnContextFromBuffer(buffer.get(),
static_cast<uint64_t>(buffer_size),
qnn_model);
qnn_models);
}

Status LoadQnnCtxFromOnnxModel(const onnxruntime::GraphViewer& graph_viewer,
Status LoadQnnCtxFromOnnxGraph(const onnxruntime::GraphViewer& graph_viewer,
const onnxruntime::PathString& ctx_onnx_model_path,
bool is_qnn_ctx_model,
bool is_ctx_cache_file_exist,
QnnBackendManager* qnn_backend_manager,
QnnModel& qnn_model,
const logging::Logger& logger) {
Status status;
if (is_qnn_ctx_model) {
status = GetEpContextFromGraph(graph_viewer, ctx_onnx_model_path, qnn_backend_manager, qnn_model);
} else if (is_ctx_cache_file_exist) {
status = GetEpContextFromModel(ctx_onnx_model_path, qnn_backend_manager, qnn_model, logger);
}
std::unordered_map<std::string, std::unique_ptr<qnn::QnnModel>>& qnn_models) {
Status status = GetEpContextFromMainNode(*graph_viewer.Nodes().begin(), ctx_onnx_model_path, qnn_backend_manager, qnn_models);

Check warning on line 189 in onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc#L189

Lines should be <= 120 characters long [whitespace/line_length] [2]
Raw output
onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc:189:  Lines should be <= 120 characters long  [whitespace/line_length] [2]

// This is the protocol with customer that status with INVALID_GRAPH will be generated if failed to load context model
if (!status.IsOK()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, "Failed to load from EpContextModel. ", status.ErrorMessage());
}
Expand All @@ -160,19 +199,24 @@
Status GetMetadataFromEpContextModel(const onnxruntime::PathString& ctx_onnx_model_path,
std::string& model_name,
std::string& model_description,
std::string& graph_partition_name,
std::vector<std::string>& graph_partition_names,
std::string& cache_source,
const logging::Logger& logger) {
using namespace onnxruntime;
std::shared_ptr<Model> model;
ORT_RETURN_IF_ERROR(Model::Load(ctx_onnx_model_path, model, {}, logger));
const auto& graph = GraphViewer(model->MainGraph());
const auto& node = graph.Nodes().begin();
NodeAttrHelper node_helper(*node);
model_name = graph.Name();
model_description = graph.Description();
graph_partition_name = node_helper.Get(PARTITION_NAME, "");
cache_source = node_helper.Get(SOURCE, "");

for (const auto& ep_context_node : graph.Nodes()) {
if (EPCONTEXT_OP != ep_context_node.OpType()) {
continue;
}
NodeAttrHelper node_helper(ep_context_node);
cache_source = node_helper.Get(SOURCE, "");
graph_partition_names.push_back(node_helper.Get(PARTITION_NAME, ""));
}

return Status::OK();
}
Expand All @@ -190,23 +234,24 @@
return std::filesystem::is_regular_file(context_cache_path) && std::filesystem::exists(context_cache_path);
}

Status ValidateWithContextFile(const onnxruntime::PathString& context_cache_path,
Status ValidateWithContextFile(const std::vector<IExecutionProvider::FusedNodeAndGraph>& fused_nodes_and_graphs,
const onnxruntime::PathString& context_cache_path,
const std::string& model_name,
const std::string& model_description,
const std::string& graph_partition_name,
const logging::Logger& logger) {
std::string model_name_from_ctx_cache;
std::string model_description_from_ctx_cache;
std::string graph_partition_name_from_ctx_cache;
std::vector<std::string> graph_partition_names;
std::string cache_source;

auto status = GetMetadataFromEpContextModel(context_cache_path,
model_name_from_ctx_cache,
model_description_from_ctx_cache,
graph_partition_name_from_ctx_cache,
graph_partition_names,
cache_source,
logger);
if (!status.IsOK()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, "Failed to get metadata from EpContextModel.");
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, "Failed to get metadata from EpContext model.");
}

// The source attribute from the skeleton onnx file indicate whether it's generated from QNN toolchain or ORT
Expand All @@ -215,15 +260,34 @@
return Status::OK();
}

bool partition_names_matched = true;
for (const auto& fused_node_graph : fused_nodes_and_graphs) {
const Node& fused_node = fused_node_graph.fused_node;
const std::string& graph_meta_id = fused_node.Name();
bool name_found = false;
for (auto partition_name_from_ctx : graph_partition_names) {
if (partition_name_from_ctx == graph_meta_id) {
name_found = true;
break;
}
}

if (!name_found) {
LOGS(logger, ERROR) << "Partition meta_id not found from any EPContext node: " << graph_meta_id;
partition_names_matched = false;
break;
}
}

if (model_name != model_name_from_ctx_cache ||
model_description != model_description_from_ctx_cache ||
graph_partition_name != graph_partition_name_from_ctx_cache) {
!partition_names_matched) {
std::string message = onnxruntime::MakeString("Metadata mismatch. onnx: ",
model_name, " ", model_description, " ", graph_partition_name,
model_name, " ", model_description,
" vs epcontext: ",
model_name_from_ctx_cache, " ",
model_description_from_ctx_cache, " ",
graph_partition_name_from_ctx_cache);
model_description_from_ctx_cache,
" or the partition name not match.");
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, message);
}

Expand Down
47 changes: 28 additions & 19 deletions onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,15 @@ static const std::string EP_SDK_VER = "ep_sdk_version";
static const std::string PARTITION_NAME = "partition_name";
static const std::string SOURCE = "source";

Status IsFusedGraphHasCtxNode(const std::vector<IExecutionProvider::FusedNodeAndGraph>& fused_nodes_and_graphs,
bool& is_qnn_ctx_model);
bool GraphHasEpContextNode(const onnxruntime::GraphViewer& graph_viewer);

bool IsQnnCtxModel(const onnxruntime::GraphViewer& graph_viewer);
bool IsFusedGraphHasCtxNode(const std::vector<IExecutionProvider::FusedNodeAndGraph>& fused_nodes_and_graphs);

Status GetMainContextNode(const std::vector<IExecutionProvider::FusedNodeAndGraph>& fused_nodes_and_graphs,
QnnBackendManager* qnn_backend_manager,
const logging::Logger& logger,
int& main_context_pos,
std::unordered_map<std::string, std::unique_ptr<qnn::QnnModel>>& qnn_models);

Status CreateNodeArgs(const std::vector<std::string>& names,
const std::unordered_map<std::string, OnnxTensorInfo>& tensor_info_table,
Expand All @@ -42,34 +47,38 @@ bool IsContextCacheFileExists(const std::string& customer_context_cache_path,
const onnxruntime::PathString& model_pathstring,
onnxruntime::PathString& context_cache_path);

Status GetEpContextFromModel(const onnxruntime::PathString& ctx_onnx_model_path,
QnnBackendManager* qnn_backend_manager,
QnnModel& qnn_model,
const logging::Logger& logger);
Status GetEpContextFromMainNode(const onnxruntime::Node& main_context_node,
const onnxruntime::PathString& ctx_onnx_model_path,
QnnBackendManager* qnn_backend_manager,
std::unordered_map<std::string, std::unique_ptr<qnn::QnnModel>>& qnn_models);

Status GetEpContextFromGraph(const onnxruntime::GraphViewer& graph_viewer,
const onnxruntime::PathString& ctx_onnx_model_path,
QnnBackendManager* qnn_backend_manager,
QnnModel& qnn_model);
Status LoadQnnCtxFromOnnxGraph(const onnxruntime::GraphViewer& graph_viewer,
const onnxruntime::PathString& ctx_onnx_model_path,
QnnBackendManager* qnn_backend_manager,
std::unordered_map<std::string, std::unique_ptr<qnn::QnnModel>>& qnn_models);

Status LoadQnnCtxFromOnnxModel(const onnxruntime::GraphViewer& graph_viewer,
Status GetContextFromOnnxModel(const std::vector<IExecutionProvider::FusedNodeAndGraph>& fused_nodes_and_graphs,
const onnxruntime::PathString& ctx_onnx_model_path,
bool is_qnn_ctx_model,
bool is_ctx_cache_file_exist,
QnnBackendManager* qnn_backend_manager,
QnnModel& qnn_model,
const logging::Logger& logger);
const logging::Logger& logger,
std::unordered_map<std::string, std::unique_ptr<qnn::QnnModel>>& qnn_models);

Status LoadContextFromOnnxModel(const std::vector<IExecutionProvider::FusedNodeAndGraph>& fused_nodes_and_graphs,
const onnxruntime::PathString& ctx_onnx_model_path,
QnnBackendManager* qnn_backend_manager,
const logging::Logger& logger,
std::unordered_map<std::string, std::unique_ptr<qnn::QnnModel>>& qnn_models);

Status ValidateWithContextFile(const onnxruntime::PathString& context_cache_path,
Status ValidateWithContextFile(const std::vector<IExecutionProvider::FusedNodeAndGraph>& fused_nodes_and_graphs,
const onnxruntime::PathString& context_cache_path,
const std::string& model_name,
const std::string& model_description,
const std::string& graph_partition_name,
const logging::Logger& logger);

Status GetMetadataFromEpContextModel(const onnxruntime::PathString& ctx_onnx_model_path,
std::string& model_name,
std::string& model_description,
std::string& graph_partition_name,
std::vector<std::string>& graph_partition_names,
std::string& cache_source,
const logging::Logger& logger);

Expand Down
Loading
Loading