3. Mode: run with QDQ model + QNN context model
-- Validate the QNN context model against the graph partition result from the QDQ model.
-- In Compile(), load the QNN context model, collect all of its EPContext nodes, create the QNN context from the context binary, create the QNN graphs from that binary, and execute them (a sketch of this flow follows the changed-file summary below).
HectorSVC committed Dec 16, 2023
1 parent 8ffa12e commit 8117368
Showing 3 changed files with 190 additions and 114 deletions.
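To make the Compile() flow in the second bullet concrete, here is a minimal sketch of how the helpers introduced in this commit could be wired together when the fused graph already consists of EPContext nodes. Only the helper signatures and the filtered_graph access come from this diff; the enclosing function, its name, and the surrounding variable names are illustrative assumptions, and the per-node compute-function registration that the real Compile() performs is elided.

// Sketch only: assumes it sits in the same namespace as the helpers declared in
// onnx_ctx_model_helper.h; this is not the provider's actual Compile() implementation.
#include "core/providers/qnn/builder/onnx_ctx_model_helper.h"

Status CompileFromQnnCtxModel(const std::vector<IExecutionProvider::FusedNodeAndGraph>& fused_nodes_and_graphs,
                              const onnxruntime::PathString& context_cache_path,
                              QnnBackendManager* qnn_backend_manager,
                              const logging::Logger& logger,
                              std::unordered_map<std::string, std::unique_ptr<qnn::QnnModel>>& qnn_models) {
  // Locate the EPContext node carrying main_context=1; this also creates one
  // QnnModel per EPContext node, keyed by the node's name.
  int main_context_pos = -1;
  ORT_RETURN_IF_ERROR(GetMainContextNode(fused_nodes_and_graphs, qnn_backend_manager,
                                         logger, main_context_pos, qnn_models));

  // Create the QNN context from the main node's cached context binary and build
  // a QNN graph for every partition contained in that binary.
  const onnxruntime::GraphViewer& main_ctx_graph_viewer(fused_nodes_and_graphs[main_context_pos].filtered_graph);
  ORT_RETURN_IF_ERROR(LoadQnnCtxFromOnnxGraph(main_ctx_graph_viewer, context_cache_path,
                                              qnn_backend_manager, qnn_models));

  // Each fused node is then bound to its QnnModel so that the compute function
  // registered for it can execute the corresponding QNN graph at inference time
  // (input/output setup and NodeComputeInfo registration are elided here).
  return Status::OK();
}

For the first bullet (a QDQ model compiled while a context cache file already exists), the provider would instead call the updated ValidateWithContextFile(fused_nodes_and_graphs, context_cache_path, model_name, model_description, logger), which now checks each fused partition's meta_id against the PARTITION_NAME attributes collected from the EPContext nodes.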
161 changes: 116 additions & 45 deletions onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc
@@ -33,6 +33,73 @@ bool IsFusedGraphHasCtxNode(const std::vector<IExecutionProvider::FusedNodeAndGr
return false;
}

Status GetMainContextNode(const std::vector<IExecutionProvider::FusedNodeAndGraph>& fused_nodes_and_graphs,
QnnBackendManager* qnn_backend_manager,
const logging::Logger& logger,
int& main_context_pos,
std::unordered_map<std::string, std::unique_ptr<qnn::QnnModel>>& qnn_models) {
main_context_pos = -1;
for (size_t i = 0; i < fused_nodes_and_graphs.size(); ++i) {
const onnxruntime::GraphViewer& graph_viewer(fused_nodes_and_graphs[i].filtered_graph);
const auto& ep_context_node = graph_viewer.Nodes().begin();
ORT_RETURN_IF_NOT(EPCONTEXT_OP == ep_context_node->OpType(), "Should only filter in the EPContext node.");
qnn_models.emplace(ep_context_node->Name(),
std::make_unique<qnn::QnnModel>(logger, qnn_backend_manager));
NodeAttrHelper node_helper(*ep_context_node);
int64_t is_main_context = node_helper.Get(MAIN_CONTEXT, static_cast<int64_t>(0));
if (1 == is_main_context) {
main_context_pos = static_cast<int>(i);
}
}

ORT_RETURN_IF(main_context_pos < 0, "Failed to find the EPContext node with main_context=1");
return Status::OK();
}

Status GetContextFromOnnxModel(const std::vector<IExecutionProvider::FusedNodeAndGraph>& fused_nodes_and_graphs,
const onnxruntime::PathString& ctx_onnx_model_path,
QnnBackendManager* qnn_backend_manager,
const logging::Logger& logger,
std::unordered_map<std::string, std::unique_ptr<qnn::QnnModel>>& qnn_models) {
for (const auto& fused_node_and_graph : fused_nodes_and_graphs) {
const Node& fused_node = fused_node_and_graph.fused_node;
qnn_models.emplace(fused_node.Name(),
std::make_unique<qnn::QnnModel>(logger, qnn_backend_manager));
}
using namespace onnxruntime;
std::shared_ptr<Model> model;
ORT_RETURN_IF_ERROR(Model::Load(ctx_onnx_model_path, model, {}, logger));
const auto& graph = GraphViewer(model->MainGraph());

for (const auto& ep_context_node : graph.Nodes()) {
if (EPCONTEXT_OP != ep_context_node.OpType()) {
continue;
}
NodeAttrHelper node_helper(ep_context_node);
int64_t is_main_context = node_helper.Get(MAIN_CONTEXT, static_cast<int64_t>(0));
if (1 == is_main_context) {
return GetEpContextFromMainNode(ep_context_node, ctx_onnx_model_path, qnn_backend_manager, qnn_models);
}
}

return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, "Failed to find EPContext node with main_context=1.");
}

Status LoadContextFromOnnxModel(const std::vector<IExecutionProvider::FusedNodeAndGraph>& fused_nodes_and_graphs,
const onnxruntime::PathString& ctx_onnx_model_path,
QnnBackendManager* qnn_backend_manager,
const logging::Logger& logger,
std::unordered_map<std::string, std::unique_ptr<qnn::QnnModel>>& qnn_models) {
Status status = GetContextFromOnnxModel(fused_nodes_and_graphs, ctx_onnx_model_path, qnn_backend_manager, logger, qnn_models);

// By contract with the customer, a status with code INVALID_GRAPH is returned if the context model fails to load
if (!status.IsOK()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, "Failed to load from EpContextModel. ", status.ErrorMessage());
}

return Status::OK();
}

Status CreateNodeArgs(const std::vector<std::string>& names,
const std::unordered_map<std::string, OnnxTensorInfo>& tensor_info_table,
std::vector<NodeArg*>& node_args,
@@ -53,26 +120,12 @@ Status CreateNodeArgs(const std::vector<std::string>& names,
return Status::OK();
}

Status GetEpContextFromModel(const onnxruntime::PathString& ctx_onnx_model_path,
QnnBackendManager* qnn_backend_manager,
std::unordered_map<std::string, std::unique_ptr<qnn::QnnModel>>& qnn_models,
const logging::Logger& logger) {
using namespace onnxruntime;
std::shared_ptr<Model> model;
ORT_RETURN_IF_ERROR(Model::Load(ToPathString(ctx_onnx_model_path), model, {}, logger));
const auto& graph = model->MainGraph();
return GetEpContextFromGraph(GraphViewer(graph),
ctx_onnx_model_path,
qnn_backend_manager,
qnn_models);
}

Status GetEpContextFromGraph(const onnxruntime::GraphViewer& graph_viewer,
const onnxruntime::PathString& ctx_onnx_model_path,
QnnBackendManager* qnn_backend_manager,
std::unordered_map<std::string, std::unique_ptr<qnn::QnnModel>>& qnn_models) {
const auto& node = graph_viewer.Nodes().begin();
NodeAttrHelper node_helper(*node);
Status GetEpContextFromMainNode(const onnxruntime::Node& main_context_node,
const onnxruntime::PathString& ctx_onnx_model_path,
QnnBackendManager* qnn_backend_manager,
std::unordered_map<std::string, std::unique_ptr<qnn::QnnModel>>& qnn_models) {
ORT_RETURN_IF_NOT(EPCONTEXT_OP == main_context_node.OpType(), "Should only filter in the EPContext node.");
NodeAttrHelper node_helper(main_context_node);
bool is_embed_mode = node_helper.Get(EMBED_MODE, true);
if (is_embed_mode) {
const std::string& context_binary = node_helper.Get(EP_CACHE_CONTEXT, "");
@@ -105,20 +158,13 @@ Status GetEpContextFromGraph(const onnxruntime::GraphViewer& graph_viewer,
qnn_models);
}

Status LoadQnnCtxFromOnnxModel(const onnxruntime::GraphViewer& graph_viewer,
Status LoadQnnCtxFromOnnxGraph(const onnxruntime::GraphViewer& graph_viewer,
const onnxruntime::PathString& ctx_onnx_model_path,
bool is_qnn_ctx_model,
bool is_ctx_cache_file_exist,
QnnBackendManager* qnn_backend_manager,
std::unordered_map<std::string, std::unique_ptr<qnn::QnnModel>>& qnn_models,
const logging::Logger& logger) {
Status status;
if (is_qnn_ctx_model) {
status = GetEpContextFromGraph(graph_viewer, ctx_onnx_model_path, qnn_backend_manager, qnn_models);
} else if (is_ctx_cache_file_exist) {
status = GetEpContextFromModel(ctx_onnx_model_path, qnn_backend_manager, qnn_models, logger);
}
std::unordered_map<std::string, std::unique_ptr<qnn::QnnModel>>& qnn_models) {
Status status = GetEpContextFromMainNode(*graph_viewer.Nodes().begin(), ctx_onnx_model_path, qnn_backend_manager, qnn_models);

// By contract with the customer, a status with code INVALID_GRAPH is returned if the context model fails to load
if (!status.IsOK()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, "Failed to load from EpContextModel. ", status.ErrorMessage());
}
@@ -129,19 +175,24 @@ Status LoadQnnCtxFromOnnxModel(const onnxruntime::GraphViewer& graph_viewer,
Status GetMetadataFromEpContextModel(const onnxruntime::PathString& ctx_onnx_model_path,
std::string& model_name,
std::string& model_description,
std::string& graph_partition_name,
std::vector<std::string>& graph_partition_names,
std::string& cache_source,
const logging::Logger& logger) {
using namespace onnxruntime;
std::shared_ptr<Model> model;
ORT_RETURN_IF_ERROR(Model::Load(ctx_onnx_model_path, model, {}, logger));
const auto& graph = GraphViewer(model->MainGraph());
const auto& node = graph.Nodes().begin();
NodeAttrHelper node_helper(*node);
model_name = graph.Name();
model_description = graph.Description();
graph_partition_name = node_helper.Get(PARTITION_NAME, "");
cache_source = node_helper.Get(SOURCE, "");

for (const auto& ep_context_node : graph.Nodes()) {
if (EPCONTEXT_OP != ep_context_node.OpType()) {
continue;
}
NodeAttrHelper node_helper(ep_context_node);
cache_source = node_helper.Get(SOURCE, "");
graph_partition_names.push_back(node_helper.Get(PARTITION_NAME, ""));
}

return Status::OK();
}
@@ -159,23 +210,24 @@ bool IsContextCacheFileExists(const std::string& customer_context_cache_path,
return std::filesystem::is_regular_file(context_cache_path) && std::filesystem::exists(context_cache_path);
}

Status ValidateWithContextFile(const onnxruntime::PathString& context_cache_path,
Status ValidateWithContextFile(const std::vector<IExecutionProvider::FusedNodeAndGraph>& fused_nodes_and_graphs,
const onnxruntime::PathString& context_cache_path,
const std::string& model_name,
const std::string& model_description,
const std::string& graph_partition_name,
const logging::Logger& logger) {
std::string model_name_from_ctx_cache;
std::string model_description_from_ctx_cache;
std::string graph_partition_name_from_ctx_cache;
std::vector<std::string> graph_partition_names;
std::string cache_source;

auto status = GetMetadataFromEpContextModel(context_cache_path,
model_name_from_ctx_cache,
model_description_from_ctx_cache,
graph_partition_name_from_ctx_cache,
graph_partition_names,
cache_source,
logger);
if (!status.IsOK()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, "Failed to get metadata from EpContextModel.");
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, "Failed to get metadata from EpContext model.");
}

// The source attribute from the skeleton ONNX file indicates whether it was generated by the QNN toolchain or by ORT
@@ -184,15 +236,34 @@ Status ValidateWithContextFile(const onnxruntime::PathString& context_cache_path
return Status::OK();
}

bool partition_names_matched = true;
for (const auto& fused_node_graph : fused_nodes_and_graphs) {
const Node& fused_node = fused_node_graph.fused_node;
const std::string& graph_meta_id = fused_node.Name();
bool name_found = false;
for (const auto& partition_name_from_ctx : graph_partition_names) {
if (partition_name_from_ctx == graph_meta_id) {
name_found = true;
break;
}
}

if (!name_found) {
LOGS(logger, ERROR) << "Partition meta_id not found in any EPContext node: " << graph_meta_id;
partition_names_matched = false;
break;
}
}

if (model_name != model_name_from_ctx_cache ||
model_description != model_description_from_ctx_cache ||
graph_partition_name != graph_partition_name_from_ctx_cache) {
!partition_names_matched) {
std::string message = onnxruntime::MakeString("Metadata mismatch. onnx: ",
model_name, " ", model_description, " ", graph_partition_name,
model_name, " ", model_description,
" vs epcontext: ",
model_name_from_ctx_cache, " ",
model_description_from_ctx_cache, " ",
graph_partition_name_from_ctx_cache);
model_description_from_ctx_cache,
" or the partition name not match.");
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, message);
}

42 changes: 26 additions & 16 deletions onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h
@@ -32,6 +32,12 @@ bool GraphHasEpContextNode(const onnxruntime::GraphViewer& graph_viewer);

bool IsFusedGraphHasCtxNode(const std::vector<IExecutionProvider::FusedNodeAndGraph>& fused_nodes_and_graphs);

Status GetMainContextNode(const std::vector<IExecutionProvider::FusedNodeAndGraph>& fused_nodes_and_graphs,
QnnBackendManager* qnn_backend_manager,
const logging::Logger& logger,
int& main_context_pos,
std::unordered_map<std::string, std::unique_ptr<qnn::QnnModel>>& qnn_models);

Status CreateNodeArgs(const std::vector<std::string>& names,
const std::unordered_map<std::string, OnnxTensorInfo>& tensor_info_table,
std::vector<NodeArg*>& node_args,
@@ -41,34 +47,38 @@ bool IsContextCacheFileExists(const std::string& customer_context_cache_path,
const onnxruntime::PathString& model_pathstring,
onnxruntime::PathString& context_cache_path);

Status GetEpContextFromModel(const onnxruntime::PathString& ctx_onnx_model_path,
QnnBackendManager* qnn_backend_manager,
std::unordered_map<std::string, std::unique_ptr<qnn::QnnModel>>& qnn_models,
const logging::Logger& logger);
Status GetEpContextFromMainNode(const onnxruntime::Node& main_context_node,
const onnxruntime::PathString& ctx_onnx_model_path,
QnnBackendManager* qnn_backend_manager,
std::unordered_map<std::string, std::unique_ptr<qnn::QnnModel>>& qnn_models);

Status GetEpContextFromGraph(const onnxruntime::GraphViewer& graph_viewer,
const onnxruntime::PathString& ctx_onnx_model_path,
QnnBackendManager* qnn_backend_manager,
std::unordered_map<std::string, std::unique_ptr<qnn::QnnModel>>& qnn_models);
Status LoadQnnCtxFromOnnxGraph(const onnxruntime::GraphViewer& graph_viewer,
const onnxruntime::PathString& ctx_onnx_model_path,
QnnBackendManager* qnn_backend_manager,
std::unordered_map<std::string, std::unique_ptr<qnn::QnnModel>>& qnn_models);

Status LoadQnnCtxFromOnnxModel(const onnxruntime::GraphViewer& graph_viewer,
Status GetContextFromOnnxModel(const std::vector<IExecutionProvider::FusedNodeAndGraph>& fused_nodes_and_graphs,
const onnxruntime::PathString& ctx_onnx_model_path,
bool is_qnn_ctx_model,
bool is_ctx_cache_file_exist,
QnnBackendManager* qnn_backend_manager,
std::unordered_map<std::string, std::unique_ptr<qnn::QnnModel>>& qnn_models,
const logging::Logger& logger);
const logging::Logger& logger,
std::unordered_map<std::string, std::unique_ptr<qnn::QnnModel>>& qnn_models);

Status LoadContextFromOnnxModel(const std::vector<IExecutionProvider::FusedNodeAndGraph>& fused_nodes_and_graphs,
const onnxruntime::PathString& ctx_onnx_model_path,
QnnBackendManager* qnn_backend_manager,
const logging::Logger& logger,
std::unordered_map<std::string, std::unique_ptr<qnn::QnnModel>>& qnn_models);

Status ValidateWithContextFile(const onnxruntime::PathString& context_cache_path,
Status ValidateWithContextFile(const std::vector<IExecutionProvider::FusedNodeAndGraph>& fused_nodes_and_graphs,
const onnxruntime::PathString& context_cache_path,
const std::string& model_name,
const std::string& model_description,
const std::string& graph_partition_name,
const logging::Logger& logger);

Status GetMetadataFromEpContextModel(const onnxruntime::PathString& ctx_onnx_model_path,
std::string& model_name,
std::string& model_description,
std::string& graph_partition_name,
std::vector<std::string>& graph_partition_names,
std::string& cache_source,
const logging::Logger& logger);
