Skip to content

Commit

Permalink
Multi-partition support for context binary cache feature (#18865)
Browse files Browse the repository at this point in the history
### Description
Multi-partition support for context binary cache feature
1. In the QNN EP, create the list of EPContext nodes if ep_context_enable is enabled, so that the model can be dumped with multiple partitions
2. Extend context loading part to support multiple EPContext nodes

### Motivation and Context
It only supported a single partition before this change. There is no graph-partition limitation for the context cache feature after this change.
  • Loading branch information
HectorSVC authored Feb 1, 2024
1 parent eb0ce86 commit 0fa88bc
Show file tree
Hide file tree
Showing 11 changed files with 919 additions and 764 deletions.
112 changes: 55 additions & 57 deletions onnxruntime/core/framework/graph_partitioner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -645,6 +645,10 @@ static Status CreateEpContextModel(const ExecutionProviders& execution_providers
all_ep_context_nodes.insert(all_ep_context_nodes.begin(), ep_context_nodes.begin(), ep_context_nodes.end());
}

if (all_ep_context_nodes.size() < 1) {
return Status::OK();
}

auto get_ep_context_node = [&all_ep_context_nodes](const std::string& node_name) -> std::pair<bool, const Node*> {
for (auto& node : all_ep_context_nodes) {
if (node_name == node->Name()) {
Expand All @@ -656,76 +660,70 @@ static Status CreateEpContextModel(const ExecutionProviders& execution_providers

onnxruntime::PathString context_cache_path;
PathString model_pathstring = graph.ModelPath().ToPathString();
if (all_ep_context_nodes.size() > 0) {
if (!ep_context_path.empty()) {
context_cache_path = ToPathString(ep_context_path);
} else if (!model_pathstring.empty()) {
context_cache_path = model_pathstring + ToPathString("_ctx.onnx");
}

{
if (!ep_context_path.empty()) {
context_cache_path = ToPathString(ep_context_path);
} else if (!model_pathstring.empty()) {
context_cache_path = model_pathstring + ToPathString("_ctx.onnx");
}

{
#ifdef _WIN32
std::wifstream fs(context_cache_path);
std::wifstream fs(context_cache_path);
#else
std::ifstream fs(context_cache_path);
std::ifstream fs(context_cache_path);
#endif
ORT_RETURN_IF(fs.good(), "Failed to generate EP context model since the file exist already.");
}
ORT_RETURN_IF(fs.good(), "Failed to generate EP context model since the file exist already.");
}

Model ep_context_model(graph.Name(), false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(),
graph.DomainToVersionMap(), {}, logger);
auto& ep_graph = ep_context_model.MainGraph();
ep_graph.SetDescription(graph.Description());

// Set inputs outputs explicitly to make sure the order is same as the user model.
auto inputs = graph.GetInputs();
auto outputs = graph.GetOutputs();

InlinedVector<const NodeArg*> ep_graph_inputs;
ep_graph_inputs.reserve(inputs.size());
for (auto& input : inputs) {
auto input_arg = graph.GetNodeArg(input->Name());
auto& ep_graph_input_arg = ep_graph.GetOrCreateNodeArg(input_arg->Name(), input_arg->TypeAsProto());
ep_graph_inputs.push_back(&ep_graph_input_arg);
}
Model ep_context_model(graph.Name(), false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(),
graph.DomainToVersionMap(), {}, logger);
auto& ep_graph = ep_context_model.MainGraph();
ep_graph.SetDescription(graph.Description());

InlinedVector<const NodeArg*> ep_graph_outputs;
ep_graph_outputs.reserve(outputs.size());
for (auto& output : outputs) {
auto output_arg = graph.GetNodeArg(output->Name());
auto& ep_graph_output_arg = ep_graph.GetOrCreateNodeArg(output_arg->Name(), output_arg->TypeAsProto());
ep_graph_outputs.push_back(&ep_graph_output_arg);
}
// Set inputs outputs explicitly to make sure the order is same as the user model.
auto inputs = graph.GetInputs();
auto outputs = graph.GetOutputs();

ep_graph.SetInputs(ep_graph_inputs);
ep_graph.SetOutputs(ep_graph_outputs);
InlinedVector<const NodeArg*> ep_graph_inputs;
ep_graph_inputs.reserve(inputs.size());
for (auto& input : inputs) {
auto input_arg = graph.GetNodeArg(input->Name());
auto& ep_graph_input_arg = ep_graph.GetOrCreateNodeArg(input_arg->Name(), input_arg->TypeAsProto());
ep_graph_inputs.push_back(&ep_graph_input_arg);
}

for (const auto& node : graph.Nodes()) {
// the fused node and EPContext node has same node name
auto ep_context_node = get_ep_context_node(node.Name());
// Use EpContext node created by the EPs if name matched, otherwise use node from original model
if (ep_context_node.first) {
ep_graph.AddNode(*ep_context_node.second);
} else {
ep_graph.AddNode(node);
}
}
InlinedVector<const NodeArg*> ep_graph_outputs;
ep_graph_outputs.reserve(outputs.size());
for (auto& output : outputs) {
auto output_arg = graph.GetNodeArg(output->Name());
auto& ep_graph_output_arg = ep_graph.GetOrCreateNodeArg(output_arg->Name(), output_arg->TypeAsProto());
ep_graph_outputs.push_back(&ep_graph_output_arg);
}

// handle initializers
for (const auto& input : graph.GetInputsIncludingInitializers()) {
const ONNX_NAMESPACE::TensorProto* initializer = nullptr;
if (graph.GetInitializedTensor(input->Name(), initializer)) {
// There initializer could have duplicates so make sure we only add once
const ONNX_NAMESPACE::TensorProto* subgraph_initializer = nullptr;
if (!ep_graph.GetInitializedTensor(input->Name(), subgraph_initializer)) {
ep_graph.AddInitializedTensor(*initializer);
}
}
ep_graph.SetInputs(ep_graph_inputs);
ep_graph.SetOutputs(ep_graph_outputs);

for (const auto& node : graph.Nodes()) {
// the fused node and EPContext node has same node name
auto ep_context_node = get_ep_context_node(node.Name());
// Use EpContext node created by the EPs if name matched, otherwise use node from original model
if (ep_context_node.first) {
ep_graph.AddNode(*ep_context_node.second);
} else {
ep_graph.AddNode(node);
}
}

ORT_RETURN_IF_ERROR(Model::Save(ep_context_model, context_cache_path));
// handle initializers
for (const auto& initialized_tensor : graph.GetAllInitializedTensors()) {
if (ep_graph.GetNodeArg(initialized_tensor.first) != nullptr) {
ep_graph.AddInitializedTensor(*initialized_tensor.second);
}
}

ORT_RETURN_IF_ERROR(Model::Save(ep_context_model, context_cache_path));

return Status::OK();
}

Expand Down
204 changes: 79 additions & 125 deletions onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,34 +12,60 @@
namespace onnxruntime {
namespace qnn {

Status IsFusedGraphHasCtxNode(const std::vector<IExecutionProvider::FusedNodeAndGraph>& fused_nodes_and_graphs,
bool& is_qnn_ctx_model) {
is_qnn_ctx_model = false;
for (const auto& fused_node_graph : fused_nodes_and_graphs) {
const onnxruntime::GraphViewer& graph_viewer(fused_node_graph.filtered_graph);
// It's an Onnx model with Qnn context cache binary if it only has a node with EPContext type
int count = 0;
for (const auto& node : graph_viewer.Nodes()) {
if (EPCONTEXT_OP == node.OpType()) {
is_qnn_ctx_model = true;
// Returns true if the graph contains an EPContext node whose "source" attribute
// identifies it as produced by the QNN EP. The comparison is case-insensitive:
// both "QNN" (QNN toolchain) and "QNNExecutionProvider" (ORT-generated) match.
bool GraphHasEpContextNode(const onnxruntime::GraphViewer& graph_viewer) {
  for (const auto& node : graph_viewer.Nodes()) {
    if (EPCONTEXT_OP == node.OpType()) {
      NodeAttrHelper node_helper(node);
      std::string cache_source = node_helper.Get(SOURCE, "");

      // Lower-case the attribute so the match is case-insensitive.
      // The unsigned char cast keeps std::tolower well-defined for
      // bytes outside [0, 127].
      std::transform(cache_source.begin(),
                     cache_source.end(),
                     cache_source.begin(),
                     [](unsigned char c) { return static_cast<unsigned char>(std::tolower(c)); });

      if (cache_source == "qnnexecutionprovider" || cache_source == "qnn") {
        return true;
      }
    }
  }
  return false;
}

bool IsQnnCtxModel(const onnxruntime::GraphViewer& graph_viewer) {
// It's an Onnx model with Qnn context cache binary if it only has a node with EPContext type
for (const auto& node : graph_viewer.Nodes()) {
if (EPCONTEXT_OP == node.OpType()) {
// True when at least one of the fused graphs carries a QNN-generated
// EPContext node (as decided by GraphHasEpContextNode).
bool IsFusedGraphHasCtxNode(const std::vector<IExecutionProvider::FusedNodeAndGraph>& fused_nodes_and_graphs) {
  for (const auto& entry : fused_nodes_and_graphs) {
    const onnxruntime::GraphViewer& viewer(entry.filtered_graph);
    if (GraphHasEpContextNode(viewer)) {
      return true;
    }
  }
  return false;
}

// Walks every fused graph (one per partition), registers a QnnModel placeholder
// keyed by each EPContext node's name, and reports the index of the partition
// whose EPContext node carries main_context=1.
//
// fused_nodes_and_graphs: one filtered graph per partition; each is expected to
//                         contain the EPContext node as its first node.
// qnn_backend_manager:    backend manager passed to every created QnnModel.
// main_context_pos:       out-param; index into fused_nodes_and_graphs of the
//                         main-context node. If several nodes claim
//                         main_context=1, the last one wins.
// qnn_models:             out-param; EPContext node name -> fresh QnnModel.
//
// Fails when the first node of any graph is not an EPContext node, or when no
// node has main_context=1.
Status GetMainContextNode(const std::vector<IExecutionProvider::FusedNodeAndGraph>& fused_nodes_and_graphs,
                          QnnBackendManager* qnn_backend_manager,
                          const logging::Logger& logger,
                          int& main_context_pos,
                          std::unordered_map<std::string, std::unique_ptr<qnn::QnnModel>>& qnn_models) {
  main_context_pos = -1;
  for (size_t i = 0; i < fused_nodes_and_graphs.size(); ++i) {
    const onnxruntime::GraphViewer& graph_viewer(fused_nodes_and_graphs[i].filtered_graph);
    // NOTE(review): only the first node is inspected — assumes each filtered
    // graph is non-empty and was filtered down to the EPContext node; the
    // check below enforces the node type but not non-emptiness.
    const auto& ep_context_node = graph_viewer.Nodes().begin();
    ORT_RETURN_IF_NOT(EPCONTEXT_OP == ep_context_node->OpType(), "Should only filter in the EPContext node.");
    qnn_models.emplace(ep_context_node->Name(),
                       std::make_unique<qnn::QnnModel>(logger, qnn_backend_manager));
    NodeAttrHelper node_helper(*ep_context_node);
    // main_context defaults to 0; exactly one node is expected to set it to 1
    // (that node is later handed to GetEpContextFromMainNode).
    int64_t is_main_context = node_helper.Get(MAIN_CONTEXT, static_cast<int64_t>(0));
    if (1 == is_main_context) {
      main_context_pos = static_cast<int>(i);
    }
  }

  ORT_RETURN_IF(main_context_pos < 0, "Failed to find the EPContext node with main_context=1");
  return Status::OK();
}

Status CreateNodeArgs(const std::vector<std::string>& names,
const std::unordered_map<std::string, OnnxTensorInfo>& tensor_info_table,
std::vector<NodeArg*>& node_args,
Expand All @@ -60,32 +86,18 @@ Status CreateNodeArgs(const std::vector<std::string>& names,
return Status::OK();
}

Status GetEpContextFromModel(const onnxruntime::PathString& ctx_onnx_model_path,
QnnBackendManager* qnn_backend_manager,
QnnModel& qnn_model,
const logging::Logger& logger) {
using namespace onnxruntime;
std::shared_ptr<Model> model;
ORT_RETURN_IF_ERROR(Model::Load(ToPathString(ctx_onnx_model_path), model, {}, logger));
const auto& graph = model->MainGraph();
return GetEpContextFromGraph(GraphViewer(graph),
ctx_onnx_model_path,
qnn_backend_manager,
qnn_model);
}

Status GetEpContextFromGraph(const onnxruntime::GraphViewer& graph_viewer,
const onnxruntime::PathString& ctx_onnx_model_path,
QnnBackendManager* qnn_backend_manager,
QnnModel& qnn_model) {
const auto& node = graph_viewer.Nodes().begin();
NodeAttrHelper node_helper(*node);
Status GetEpContextFromMainNode(const onnxruntime::Node& main_context_node,
const onnxruntime::PathString& ctx_onnx_model_path,
QnnBackendManager* qnn_backend_manager,
std::unordered_map<std::string, std::unique_ptr<qnn::QnnModel>>& qnn_models) {
ORT_RETURN_IF_NOT(EPCONTEXT_OP == main_context_node.OpType(), "Should only filter in the EPContext node.");
NodeAttrHelper node_helper(main_context_node);
bool is_embed_mode = node_helper.Get(EMBED_MODE, true);
if (is_embed_mode) {
const std::string& context_binary = node_helper.Get(EP_CACHE_CONTEXT, "");
return qnn_backend_manager->LoadCachedQnnContextFromBuffer(const_cast<char*>(context_binary.c_str()),
static_cast<uint64_t>(context_binary.length()),
qnn_model);
qnn_models);
}

std::filesystem::path folder_path = std::filesystem::path(ctx_onnx_model_path).parent_path();
Expand Down Expand Up @@ -133,112 +145,54 @@ Status GetEpContextFromGraph(const onnxruntime::GraphViewer& graph_viewer,
cache_file.close();
return qnn_backend_manager->LoadCachedQnnContextFromBuffer(buffer.get(),
static_cast<uint64_t>(buffer_size),
qnn_model);
qnn_models);
}

// Loads the QNN context binary referenced by the graph's first (EPContext)
// node and populates qnn_models with the deserialized graphs.
//
// graph_viewer:        graph whose first node is the main EPContext node.
// ctx_onnx_model_path: path of the context cache model; used to resolve a
//                      relative binary-file path in non-embed mode.
// qnn_models:          out-param filled by GetEpContextFromMainNode.
//
// Any failure is converted to INVALID_GRAPH — this is the agreed protocol
// with customers for "failed to load the context model".
Status LoadQnnCtxFromOnnxGraph(const onnxruntime::GraphViewer& graph_viewer,
                               const onnxruntime::PathString& ctx_onnx_model_path,
                               QnnBackendManager* qnn_backend_manager,
                               std::unordered_map<std::string, std::unique_ptr<qnn::QnnModel>>& qnn_models) {
  Status status = GetEpContextFromMainNode(*graph_viewer.Nodes().begin(), ctx_onnx_model_path,
                                           qnn_backend_manager, qnn_models);

  // This is the protocol with customer that status with INVALID_GRAPH will be generated if failed to load context model
  if (!status.IsOK()) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, "Failed to load from EpContextModel. ", status.ErrorMessage());
  }

  return Status::OK();
}

// Loads the ONNX model at ctx_onnx_model_path and extracts identifying
// metadata: the main graph's name and description, plus the partition_name
// and source attributes read from the graph's first node.
//
// NOTE(review): only the first node is inspected — assumes the EPContext node
// comes first in the graph; confirm this holds for multi-node cache models.
Status GetMetadataFromEpContextModel(const onnxruntime::PathString& ctx_onnx_model_path,
                                     std::string& model_name,
                                     std::string& model_description,
                                     std::string& graph_partition_name,
                                     std::string& cache_source,
                                     const logging::Logger& logger) {
  using namespace onnxruntime;
  std::shared_ptr<Model> model;
  ORT_RETURN_IF_ERROR(Model::Load(ctx_onnx_model_path, model, {}, logger));
  const auto& graph = GraphViewer(model->MainGraph());
  const auto& node = graph.Nodes().begin();
  NodeAttrHelper node_helper(*node);
  model_name = graph.Name();
  model_description = graph.Description();
  graph_partition_name = node_helper.Get(PARTITION_NAME, "");
  cache_source = node_helper.Get(SOURCE, "");

  return Status::OK();
}

bool IsContextCacheFileExists(const std::string& customer_context_cache_path,
const onnxruntime::PathString& model_pathstring,
onnxruntime::PathString& context_cache_path) {
// Use user provided context cache file path if exist, otherwise try model_file.onnx_ctx.onnx by default
// Figures out the real context cache file path and returns true if that file
// already exists on disk.
//
// Resolution order:
//   1. the user-provided path (the only way to set it when the model is
//      loaded from memory);
//   2. for a model loaded from file: the model path itself when it already is
//      a context cache model, otherwise model_path + "_ctx.onnx" as the
//      default generation target.
bool ValidateContextCacheFilePath(bool is_qnn_ctx_model,
                                  const std::string& customer_context_cache_path,
                                  const onnxruntime::PathString& model_pathstring,
                                  onnxruntime::PathString& context_cache_path) {
  // always try the path set by user first, it's the only way to set it if load model from memory
  if (!customer_context_cache_path.empty()) {
    context_cache_path = ToPathString(customer_context_cache_path);
  } else if (!model_pathstring.empty()) {  // model loaded from file
    if (is_qnn_ctx_model) {
      // it's a context cache model, just use the model path
      context_cache_path = model_pathstring;
    } else {
      // not a context cache model and no customer path: create a default path
      // for generation: model_path + _ctx.onnx
      context_cache_path = model_pathstring + ToPathString("_ctx.onnx");
    }
  }

  return std::filesystem::is_regular_file(context_cache_path) && std::filesystem::exists(context_cache_path);
}

// Validates that the on-disk EP context cache model matches the model being
// loaded. Succeeds trivially when the cache was not generated by this EP
// (source != kQnnExecutionProvider); otherwise returns INVALID_GRAPH if the
// model name, description, or graph partition name disagree.
Status ValidateWithContextFile(const onnxruntime::PathString& context_cache_path,
                               const std::string& model_name,
                               const std::string& model_description,
                               const std::string& graph_partition_name,
                               const logging::Logger& logger) {
  std::string model_name_from_ctx_cache;
  std::string model_description_from_ctx_cache;
  std::string graph_partition_name_from_ctx_cache;
  std::string cache_source;
  auto status = GetMetadataFromEpContextModel(context_cache_path,
                                              model_name_from_ctx_cache,
                                              model_description_from_ctx_cache,
                                              graph_partition_name_from_ctx_cache,
                                              cache_source,
                                              logger);
  if (!status.IsOK()) {
    // Any metadata-read failure is surfaced as INVALID_GRAPH.
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, "Failed to get metadata from EpContextModel.");
  }

  // The source attribute in the skeleton onnx file indicates whether the cache
  // was generated by the QNN toolchain or by ORT; only ORT-generated caches
  // carry metadata worth comparing.
  if (cache_source != kQnnExecutionProvider) {
    LOGS(logger, VERBOSE) << "Context binary cache is not generated by Ort.";
    return Status::OK();
  }

  if (model_name != model_name_from_ctx_cache ||
      model_description != model_description_from_ctx_cache ||
      graph_partition_name != graph_partition_name_from_ctx_cache) {
    std::string message = onnxruntime::MakeString("Metadata mismatch. onnx: ",
                                                  model_name, " ", model_description, " ", graph_partition_name,
                                                  " vs epcontext: ",
                                                  model_name_from_ctx_cache, " ",
                                                  model_description_from_ctx_cache, " ",
                                                  graph_partition_name_from_ctx_cache);
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, message);
  }

  return Status::OK();
}

Status GenerateCtxCacheOnnxModel(Model* model,
unsigned char* buffer,
uint64_t buffer_size,
const std::string& sdk_build_version,
const std::vector<IExecutionProvider::FusedNodeAndGraph>& fused_nodes_and_graphs,
const std::unordered_map<std::string, std::unique_ptr<QnnModel>>& qnn_models,
const onnxruntime::PathString& context_cache_path,
bool qnn_context_embed_mode,
const logging::Logger& logger) {
Status CreateEPContextNodes(Model* model,
unsigned char* buffer,
uint64_t buffer_size,
const std::string& sdk_build_version,
const std::vector<IExecutionProvider::FusedNodeAndGraph>& fused_nodes_and_graphs,
const std::unordered_map<std::string, std::unique_ptr<QnnModel>>& qnn_models,
const onnxruntime::PathString& context_cache_path,
bool qnn_context_embed_mode,
const logging::Logger& logger) {
auto& graph = model->MainGraph();

using namespace ONNX_NAMESPACE;
Expand Down
Loading

0 comments on commit 0fa88bc

Please sign in to comment.