Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Multi-partition support for context binary cache feature #18865

Merged
merged 32 commits into from
Feb 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
eba7e03
Support multi-partition for context cache feature
HectorSVC Dec 14, 2023
8ffa12e
2. Load and execute the model with multiple EPContext
HectorSVC Dec 15, 2023
8117368
3. Mode: run with QDQ model + QNN context model
HectorSVC Dec 16, 2023
8a00784
Remove QNN EP options: qnn_context_cache_enable, qnn_context_cache_pa…
HectorSVC Dec 18, 2023
16058a8
update test code
HectorSVC Dec 18, 2023
aacab16
Update test code to reflect the changes which move provider options t…
HectorSVC Dec 19, 2023
a5a9aef
Merge branch 'main' of https://github.com/microsoft/onnxruntime
HectorSVC Dec 20, 2023
1567bf2
merge main
HectorSVC Dec 20, 2023
22b4c93
Fix Linux build
HectorSVC Dec 27, 2023
de53da1
fix some build issues
HectorSVC Dec 29, 2023
c3883b1
Set inputs outputs explicitly to make sure the order is same as the u…
HectorSVC Jan 18, 2024
a457b70
Merge branch 'main' into qnn_ctx_multi_partition_support
HectorSVC Jan 19, 2024
30c1ed7
resolve conflict
HectorSVC Jan 20, 2024
55d10b2
resolved merge conflicts
HectorSVC Jan 21, 2024
ce3c64f
resolve merge conflicts
HectorSVC Jan 21, 2024
8c55f19
remove the validation mode
HectorSVC Jan 22, 2024
e7c0827
clean up some not used code
HectorSVC Jan 22, 2024
d3feaa4
renaming
HectorSVC Jan 22, 2024
33516cd
Update tests
HectorSVC Jan 23, 2024
445bc1b
fix the issue relate to initializer handling
HectorSVC Jan 25, 2024
9c7bdfc
Move QNN context cache related tests to a separate file
HectorSVC Jan 25, 2024
3dfd94b
rename some tests
HectorSVC Jan 25, 2024
3b8e879
Add UT to verify the multi-partition support
HectorSVC Jan 26, 2024
ff2c313
fill some gaps in UT and fix an issue relate to context cache path
HectorSVC Jan 26, 2024
c8ea83d
add one more UT to dumps the context cache model with 2 EPContext nodes
HectorSVC Jan 26, 2024
3dbb95d
formating
HectorSVC Jan 26, 2024
9eb32aa
Use ContribOp FusedMatMul instead of Add op to make sure the float32 …
HectorSVC Jan 26, 2024
1d4fa6f
Merge branch 'main' into qnn_ctx_multi_partition_support
HectorSVC Jan 26, 2024
74c5cef
update according review comments.
HectorSVC Feb 1, 2024
0137172
update according review comments
HectorSVC Feb 1, 2024
b79a256
formating
HectorSVC Feb 1, 2024
9e71147
revert a change
HectorSVC Feb 1, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 55 additions & 57 deletions onnxruntime/core/framework/graph_partitioner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -645,6 +645,10 @@ static Status CreateEpContextModel(const ExecutionProviders& execution_providers
all_ep_context_nodes.insert(all_ep_context_nodes.begin(), ep_context_nodes.begin(), ep_context_nodes.end());
}

if (all_ep_context_nodes.size() < 1) {
return Status::OK();
}

auto get_ep_context_node = [&all_ep_context_nodes](const std::string& node_name) -> std::pair<bool, const Node*> {
for (auto& node : all_ep_context_nodes) {
if (node_name == node->Name()) {
Expand All @@ -656,76 +660,70 @@ static Status CreateEpContextModel(const ExecutionProviders& execution_providers

onnxruntime::PathString context_cache_path;
PathString model_pathstring = graph.ModelPath().ToPathString();
if (all_ep_context_nodes.size() > 0) {
if (!ep_context_path.empty()) {
context_cache_path = ToPathString(ep_context_path);
} else if (!model_pathstring.empty()) {
context_cache_path = model_pathstring + ToPathString("_ctx.onnx");
}

{
if (!ep_context_path.empty()) {
context_cache_path = ToPathString(ep_context_path);
} else if (!model_pathstring.empty()) {
context_cache_path = model_pathstring + ToPathString("_ctx.onnx");
}

{
#ifdef _WIN32
std::wifstream fs(context_cache_path);
std::wifstream fs(context_cache_path);
#else
std::ifstream fs(context_cache_path);
std::ifstream fs(context_cache_path);
#endif
ORT_RETURN_IF(fs.good(), "Failed to generate EP context model since the file exist already.");
}
ORT_RETURN_IF(fs.good(), "Failed to generate EP context model since the file exist already.");
}

Model ep_context_model(graph.Name(), false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(),
graph.DomainToVersionMap(), {}, logger);
auto& ep_graph = ep_context_model.MainGraph();
ep_graph.SetDescription(graph.Description());

// Set inputs outputs explicitly to make sure the order is same as the user model.
auto inputs = graph.GetInputs();
auto outputs = graph.GetOutputs();

InlinedVector<const NodeArg*> ep_graph_inputs;
ep_graph_inputs.reserve(inputs.size());
for (auto& input : inputs) {
auto input_arg = graph.GetNodeArg(input->Name());
auto& ep_graph_input_arg = ep_graph.GetOrCreateNodeArg(input_arg->Name(), input_arg->TypeAsProto());
ep_graph_inputs.push_back(&ep_graph_input_arg);
}
Model ep_context_model(graph.Name(), false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(),
graph.DomainToVersionMap(), {}, logger);
auto& ep_graph = ep_context_model.MainGraph();
ep_graph.SetDescription(graph.Description());

InlinedVector<const NodeArg*> ep_graph_outputs;
ep_graph_outputs.reserve(outputs.size());
for (auto& output : outputs) {
auto output_arg = graph.GetNodeArg(output->Name());
auto& ep_graph_output_arg = ep_graph.GetOrCreateNodeArg(output_arg->Name(), output_arg->TypeAsProto());
ep_graph_outputs.push_back(&ep_graph_output_arg);
}
// Set inputs outputs explicitly to make sure the order is same as the user model.
auto inputs = graph.GetInputs();
auto outputs = graph.GetOutputs();

ep_graph.SetInputs(ep_graph_inputs);
ep_graph.SetOutputs(ep_graph_outputs);
InlinedVector<const NodeArg*> ep_graph_inputs;
ep_graph_inputs.reserve(inputs.size());
for (auto& input : inputs) {
auto input_arg = graph.GetNodeArg(input->Name());
auto& ep_graph_input_arg = ep_graph.GetOrCreateNodeArg(input_arg->Name(), input_arg->TypeAsProto());
ep_graph_inputs.push_back(&ep_graph_input_arg);
}

for (const auto& node : graph.Nodes()) {
// the fused node and EPContext node has same node name
auto ep_context_node = get_ep_context_node(node.Name());
// Use EpContext node created by the EPs if name matched, otherwise use node from original model
if (ep_context_node.first) {
ep_graph.AddNode(*ep_context_node.second);
} else {
ep_graph.AddNode(node);
}
}
InlinedVector<const NodeArg*> ep_graph_outputs;
ep_graph_outputs.reserve(outputs.size());
for (auto& output : outputs) {
auto output_arg = graph.GetNodeArg(output->Name());
auto& ep_graph_output_arg = ep_graph.GetOrCreateNodeArg(output_arg->Name(), output_arg->TypeAsProto());
ep_graph_outputs.push_back(&ep_graph_output_arg);
}

// handle initializers
for (const auto& input : graph.GetInputsIncludingInitializers()) {
const ONNX_NAMESPACE::TensorProto* initializer = nullptr;
if (graph.GetInitializedTensor(input->Name(), initializer)) {
// There initializer could have duplicates so make sure we only add once
const ONNX_NAMESPACE::TensorProto* subgraph_initializer = nullptr;
if (!ep_graph.GetInitializedTensor(input->Name(), subgraph_initializer)) {
ep_graph.AddInitializedTensor(*initializer);
}
}
ep_graph.SetInputs(ep_graph_inputs);
ep_graph.SetOutputs(ep_graph_outputs);

for (const auto& node : graph.Nodes()) {
// the fused node and EPContext node has same node name
auto ep_context_node = get_ep_context_node(node.Name());
// Use EpContext node created by the EPs if name matched, otherwise use node from original model
if (ep_context_node.first) {
ep_graph.AddNode(*ep_context_node.second);
} else {
ep_graph.AddNode(node);
}
}

ORT_RETURN_IF_ERROR(Model::Save(ep_context_model, context_cache_path));
// handle initializers
for (const auto& initialized_tensor : graph.GetAllInitializedTensors()) {
if (ep_graph.GetNodeArg(initialized_tensor.first) != nullptr) {
ep_graph.AddInitializedTensor(*initialized_tensor.second);
}
}

ORT_RETURN_IF_ERROR(Model::Save(ep_context_model, context_cache_path));

return Status::OK();
}

Expand Down
204 changes: 79 additions & 125 deletions onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,34 +12,60 @@
namespace onnxruntime {
namespace qnn {

Status IsFusedGraphHasCtxNode(const std::vector<IExecutionProvider::FusedNodeAndGraph>& fused_nodes_and_graphs,
bool& is_qnn_ctx_model) {
is_qnn_ctx_model = false;
for (const auto& fused_node_graph : fused_nodes_and_graphs) {
const onnxruntime::GraphViewer& graph_viewer(fused_node_graph.filtered_graph);
// It's an Onnx model with Qnn context cache binary if it only has a node with EPContext type
int count = 0;
for (const auto& node : graph_viewer.Nodes()) {
if (EPCONTEXT_OP == node.OpType()) {
is_qnn_ctx_model = true;
bool GraphHasEpContextNode(const onnxruntime::GraphViewer& graph_viewer) {
// It's an Onnx model with Qnn context cache binary if it has a node with EPContext type and the source is QNN or QNNExecutionProvider.

Check warning on line 16 in onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc:16: Lines should be <= 120 characters long [whitespace/line_length] [2]
for (const auto& node : graph_viewer.Nodes()) {
if (EPCONTEXT_OP == node.OpType()) {
NodeAttrHelper node_helper(node);
std::string cache_source = node_helper.Get(SOURCE, "");

std::transform(cache_source.begin(),

Check warning on line 22 in onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <algorithm> for transform [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc:22: Add #include <algorithm> for transform [build/include_what_you_use] [4]
cache_source.end(),
cache_source.begin(),
[](unsigned char c) { return static_cast<unsigned char>(std::tolower(c)); });
adrianlizarraga marked this conversation as resolved.
Show resolved Hide resolved

if (cache_source == "qnnexecutionprovider" || cache_source == "qnn") {
return true;
}
++count;
}
ORT_RETURN_IF(is_qnn_ctx_model && count > 1, "Fused graph should only has 1 single EPContext node.");
}
return Status::OK();
return false;
}

bool IsQnnCtxModel(const onnxruntime::GraphViewer& graph_viewer) {
// It's an Onnx model with Qnn context cache binary if it only has a node with EPContext type
for (const auto& node : graph_viewer.Nodes()) {
if (EPCONTEXT_OP == node.OpType()) {
bool IsFusedGraphHasCtxNode(const std::vector<IExecutionProvider::FusedNodeAndGraph>& fused_nodes_and_graphs) {
for (const auto& fused_node_graph : fused_nodes_and_graphs) {
const onnxruntime::GraphViewer& graph_viewer(fused_node_graph.filtered_graph);
bool has_qnn_ep_context_node = GraphHasEpContextNode(graph_viewer);
if (has_qnn_ep_context_node) {
return true;
}
}
return false;
}

Status GetMainContextNode(const std::vector<IExecutionProvider::FusedNodeAndGraph>& fused_nodes_and_graphs,
QnnBackendManager* qnn_backend_manager,
const logging::Logger& logger,
int& main_context_pos,
std::unordered_map<std::string, std::unique_ptr<qnn::QnnModel>>& qnn_models) {
main_context_pos = -1;
for (size_t i = 0; i < fused_nodes_and_graphs.size(); ++i) {
const onnxruntime::GraphViewer& graph_viewer(fused_nodes_and_graphs[i].filtered_graph);
const auto& ep_context_node = graph_viewer.Nodes().begin();
ORT_RETURN_IF_NOT(EPCONTEXT_OP == ep_context_node->OpType(), "Should only filter in the EPContext node.");
qnn_models.emplace(ep_context_node->Name(),
std::make_unique<qnn::QnnModel>(logger, qnn_backend_manager));
NodeAttrHelper node_helper(*ep_context_node);
int64_t is_main_context = node_helper.Get(MAIN_CONTEXT, static_cast<int64_t>(0));
if (1 == is_main_context) {
main_context_pos = static_cast<int>(i);
}
}

ORT_RETURN_IF(main_context_pos < 0, "Failed to find the EPContext node with main_context=1");
return Status::OK();
}

Status CreateNodeArgs(const std::vector<std::string>& names,
const std::unordered_map<std::string, OnnxTensorInfo>& tensor_info_table,
std::vector<NodeArg*>& node_args,
Expand All @@ -60,32 +86,18 @@
return Status::OK();
}

Status GetEpContextFromModel(const onnxruntime::PathString& ctx_onnx_model_path,
QnnBackendManager* qnn_backend_manager,
QnnModel& qnn_model,
const logging::Logger& logger) {
using namespace onnxruntime;
std::shared_ptr<Model> model;
ORT_RETURN_IF_ERROR(Model::Load(ToPathString(ctx_onnx_model_path), model, {}, logger));
const auto& graph = model->MainGraph();
return GetEpContextFromGraph(GraphViewer(graph),
ctx_onnx_model_path,
qnn_backend_manager,
qnn_model);
}

Status GetEpContextFromGraph(const onnxruntime::GraphViewer& graph_viewer,
const onnxruntime::PathString& ctx_onnx_model_path,
QnnBackendManager* qnn_backend_manager,
QnnModel& qnn_model) {
const auto& node = graph_viewer.Nodes().begin();
NodeAttrHelper node_helper(*node);
Status GetEpContextFromMainNode(const onnxruntime::Node& main_context_node,
const onnxruntime::PathString& ctx_onnx_model_path,
QnnBackendManager* qnn_backend_manager,
std::unordered_map<std::string, std::unique_ptr<qnn::QnnModel>>& qnn_models) {
ORT_RETURN_IF_NOT(EPCONTEXT_OP == main_context_node.OpType(), "Should only filter in the EPContext node.");
NodeAttrHelper node_helper(main_context_node);
bool is_embed_mode = node_helper.Get(EMBED_MODE, true);
if (is_embed_mode) {
const std::string& context_binary = node_helper.Get(EP_CACHE_CONTEXT, "");
return qnn_backend_manager->LoadCachedQnnContextFromBuffer(const_cast<char*>(context_binary.c_str()),
static_cast<uint64_t>(context_binary.length()),
qnn_model);
qnn_models);
}

std::filesystem::path folder_path = std::filesystem::path(ctx_onnx_model_path).parent_path();
Expand Down Expand Up @@ -133,112 +145,54 @@
cache_file.close();
return qnn_backend_manager->LoadCachedQnnContextFromBuffer(buffer.get(),
static_cast<uint64_t>(buffer_size),
qnn_model);
qnn_models);
}

Status LoadQnnCtxFromOnnxModel(const onnxruntime::GraphViewer& graph_viewer,
Status LoadQnnCtxFromOnnxGraph(const onnxruntime::GraphViewer& graph_viewer,
const onnxruntime::PathString& ctx_onnx_model_path,
bool is_qnn_ctx_model,
bool is_ctx_cache_file_exist,
QnnBackendManager* qnn_backend_manager,
QnnModel& qnn_model,
const logging::Logger& logger) {
Status status;
if (is_qnn_ctx_model) {
status = GetEpContextFromGraph(graph_viewer, ctx_onnx_model_path, qnn_backend_manager, qnn_model);
} else if (is_ctx_cache_file_exist) {
status = GetEpContextFromModel(ctx_onnx_model_path, qnn_backend_manager, qnn_model, logger);
}
std::unordered_map<std::string, std::unique_ptr<qnn::QnnModel>>& qnn_models) {
Status status = GetEpContextFromMainNode(*graph_viewer.Nodes().begin(), ctx_onnx_model_path, qnn_backend_manager, qnn_models);

Check warning on line 155 in onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc:155: Lines should be <= 120 characters long [whitespace/line_length] [2]

// This is the protocol with customer that status with INVALID_GRAPH will be generated if failed to load context model
if (!status.IsOK()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, "Failed to load from EpContextModel. ", status.ErrorMessage());
}

return Status::OK();
}

Status GetMetadataFromEpContextModel(const onnxruntime::PathString& ctx_onnx_model_path,
std::string& model_name,
std::string& model_description,
std::string& graph_partition_name,
std::string& cache_source,
const logging::Logger& logger) {
using namespace onnxruntime;
std::shared_ptr<Model> model;
ORT_RETURN_IF_ERROR(Model::Load(ctx_onnx_model_path, model, {}, logger));
const auto& graph = GraphViewer(model->MainGraph());
const auto& node = graph.Nodes().begin();
NodeAttrHelper node_helper(*node);
model_name = graph.Name();
model_description = graph.Description();
graph_partition_name = node_helper.Get(PARTITION_NAME, "");
cache_source = node_helper.Get(SOURCE, "");

return Status::OK();
}

bool IsContextCacheFileExists(const std::string& customer_context_cache_path,
const onnxruntime::PathString& model_pathstring,
onnxruntime::PathString& context_cache_path) {
// Use user provided context cache file path if exist, otherwise try model_file.onnx_ctx.onnx by default
// Figure out the real context cache file path
// return true if context cache file exists
bool ValidateContextCacheFilePath(bool is_qnn_ctx_model,
const std::string& customer_context_cache_path,
const onnxruntime::PathString& model_pathstring,
onnxruntime::PathString& context_cache_path) {
// always try the path set by user first, it's the only way to set it if load model from memory
if (!customer_context_cache_path.empty()) {
context_cache_path = ToPathString(customer_context_cache_path);
} else if (!model_pathstring.empty()) {
context_cache_path = model_pathstring + ToPathString("_ctx.onnx");
} else if (!model_pathstring.empty()) { // model loaded from file
if (is_qnn_ctx_model) {
// it's a context cache model, just use the model path
context_cache_path = model_pathstring;
} else if (!model_pathstring.empty()) {
// this is not a normal Onnx model, no customer path, create a default path for generation: model_path + _ctx.onnx
context_cache_path = model_pathstring + ToPathString("_ctx.onnx");
}
}

return std::filesystem::is_regular_file(context_cache_path) && std::filesystem::exists(context_cache_path);
}

Status ValidateWithContextFile(const onnxruntime::PathString& context_cache_path,
const std::string& model_name,
const std::string& model_description,
const std::string& graph_partition_name,
const logging::Logger& logger) {
std::string model_name_from_ctx_cache;
std::string model_description_from_ctx_cache;
std::string graph_partition_name_from_ctx_cache;
std::string cache_source;
auto status = GetMetadataFromEpContextModel(context_cache_path,
model_name_from_ctx_cache,
model_description_from_ctx_cache,
graph_partition_name_from_ctx_cache,
cache_source,
logger);
if (!status.IsOK()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, "Failed to get metadata from EpContextModel.");
}

// The source attribute from the skeleton onnx file indicate whether it's generated from QNN toolchain or ORT
if (cache_source != kQnnExecutionProvider) {
LOGS(logger, VERBOSE) << "Context binary cache is not generated by Ort.";
return Status::OK();
}

if (model_name != model_name_from_ctx_cache ||
model_description != model_description_from_ctx_cache ||
graph_partition_name != graph_partition_name_from_ctx_cache) {
std::string message = onnxruntime::MakeString("Metadata mismatch. onnx: ",
model_name, " ", model_description, " ", graph_partition_name,
" vs epcontext: ",
model_name_from_ctx_cache, " ",
model_description_from_ctx_cache, " ",
graph_partition_name_from_ctx_cache);
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, message);
}

return Status::OK();
}

Status GenerateCtxCacheOnnxModel(Model* model,
unsigned char* buffer,
uint64_t buffer_size,
const std::string& sdk_build_version,
const std::vector<IExecutionProvider::FusedNodeAndGraph>& fused_nodes_and_graphs,
const std::unordered_map<std::string, std::unique_ptr<QnnModel>>& qnn_models,
const onnxruntime::PathString& context_cache_path,
bool qnn_context_embed_mode,
const logging::Logger& logger) {
Status CreateEPContextNodes(Model* model,
unsigned char* buffer,
uint64_t buffer_size,
const std::string& sdk_build_version,
const std::vector<IExecutionProvider::FusedNodeAndGraph>& fused_nodes_and_graphs,
const std::unordered_map<std::string, std::unique_ptr<QnnModel>>& qnn_models,

Check warning on line 192 in onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <memory> for unique_ptr<> [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc:192: Add #include <memory> for unique_ptr<> [build/include_what_you_use] [4]

Check warning on line 192 in onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <unordered_map> for unordered_map<> [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc:192: Add #include <unordered_map> for unordered_map<> [build/include_what_you_use] [4]
const onnxruntime::PathString& context_cache_path,
bool qnn_context_embed_mode,
const logging::Logger& logger) {
auto& graph = model->MainGraph();

using namespace ONNX_NAMESPACE;
Expand Down
Loading
Loading