Skip to content

Commit

Permalink
Return INVALID_GRAPH if failed to load from QNN context binary
Browse files Browse the repository at this point in the history
  • Loading branch information
HectorSVC committed Nov 17, 2023
1 parent 9a8092f commit 8253a2d
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 29 deletions.
51 changes: 29 additions & 22 deletions onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -119,11 +119,17 @@ Status LoadQnnCtxFromOnnxModel(const onnxruntime::GraphViewer& graph_viewer,
QnnBackendManager* qnn_backend_manager,
QnnModel& qnn_model,
const logging::Logger& logger) {
Status status;
if (is_qnn_ctx_model) {
return GetEpContextFromGraph(graph_viewer, ctx_onnx_model_path, qnn_backend_manager, qnn_model);
status = GetEpContextFromGraph(graph_viewer, ctx_onnx_model_path, qnn_backend_manager, qnn_model);
} else if (is_ctx_cache_file_exist) {
return GetEpContextFromModel(ctx_onnx_model_path, qnn_backend_manager, qnn_model, logger);
status = GetEpContextFromModel(ctx_onnx_model_path, qnn_backend_manager, qnn_model, logger);
}

if (Status::OK() != status) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, "Failed to load from EpContextModel. ", status.ErrorMessage());
}

return Status::OK();
}

Expand Down Expand Up @@ -169,32 +175,33 @@ Status ValidateWithContextFile(const std::string& context_cache_path,
std::string model_description_from_ctx_cache;
std::string graph_partition_name_from_ctx_cache;
std::string cache_source;
ORT_RETURN_IF_ERROR(GetMetadataFromEpContextModel(context_cache_path,
model_name_from_ctx_cache,
model_description_from_ctx_cache,
graph_partition_name_from_ctx_cache,
cache_source,
logger));
auto status = GetMetadataFromEpContextModel(context_cache_path,
model_name_from_ctx_cache,
model_description_from_ctx_cache,
graph_partition_name_from_ctx_cache,
cache_source,
logger);
if (Status::OK() != status) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, "Failed to get metadata from EpContextModel.");
}

// The source attribute from the skeleton onnx file indicates whether it was generated by the QNN toolchain or by ORT
if (cache_source != kQnnExecutionProvider) {
LOGS(logger, VERBOSE) << "Context binary cache is not generated by Ort.";
return Status::OK();
}

ORT_RETURN_IF(model_name != model_name_from_ctx_cache,
"Model file name from context cache metadata: " + model_name_from_ctx_cache +
" is different with target: " + model_name +
". Please make sure the context binary file matches the model.");

ORT_RETURN_IF(model_description != model_description_from_ctx_cache,
"Model description from context cache metadata: " + model_description_from_ctx_cache +
" is different with target: " + model_description +
". Please make sure the context binary file matches the model.");

ORT_RETURN_IF(graph_partition_name != graph_partition_name_from_ctx_cache,
"Graph name from context cache metadata: " + graph_partition_name_from_ctx_cache +
" is different with target: " + graph_partition_name +
". You may need to re-generate the context binary file.");
if (model_name != model_name_from_ctx_cache ||
model_description != model_description_from_ctx_cache ||
graph_partition_name != graph_partition_name_from_ctx_cache) {
std::string message = onnxruntime::MakeString("Metadata from Onnx file: ",
model_name, " ", model_description, " ", graph_partition_name,
" vs metadata from context cache Onnx file",
model_name_from_ctx_cache, " ",
model_description_from_ctx_cache, " ",
graph_partition_name_from_ctx_cache);
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, message);
}

return Status::OK();
}
Expand Down
13 changes: 6 additions & 7 deletions onnxruntime/core/providers/qnn/qnn_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -541,13 +541,13 @@ Status QNNExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused
graph_viewer.ModelPath().ToPathString(),
context_cache_path);
const std::string& model_name = graph_viewer.GetGraph().Name();
LOGS(logger, VERBOSE) << "graph_viewer.GetGraph().Name(): " << model_name;
LOGS(logger, VERBOSE) << "graph_viewer.GetGraph().Description(): " << graph_viewer.GetGraph().Description();
const std::string& model_description = graph_viewer.GetGraph().Description();
const std::string& graph_meta_id = fused_node.Name();
if (fused_nodes_and_graphs.size() == 1 && !is_qnn_ctx_model && is_ctx_file_exist) {
ORT_RETURN_IF_ERROR(qnn::ValidateWithContextFile(context_cache_path,
model_name,
graph_viewer.GetGraph().Description(),
graph_viewer.Name(),
model_description,
graph_meta_id,
logger));
}

Expand All @@ -567,8 +567,7 @@ Status QNNExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused

// fused node name is QNNExecutionProvider_QNN_[hash_id]_[id]
// the name here should be the same as context->node_name in compute_info
LOGS(logger, VERBOSE) << "fused node name: " << fused_node.Name();
qnn_models_.emplace(fused_node.Name(), std::move(qnn_model));
qnn_models_.emplace(graph_meta_id, std::move(qnn_model));

Check warning on line 570 in onnxruntime/core/providers/qnn/qnn_execution_provider.cc

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/core/providers/qnn/qnn_execution_provider.cc#L570

Add #include <utility> for move [build/include_what_you_use] [4]
Raw output
onnxruntime/core/providers/qnn/qnn_execution_provider.cc:570:  Add #include <utility> for move  [build/include_what_you_use] [4]

ORT_RETURN_IF_ERROR(CreateComputeFunc(node_compute_funcs, logger));
return Status::OK();
Expand All @@ -580,7 +579,7 @@ Status QNNExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused
uint64_t buffer_size(0);
auto context_buffer = qnn_backend_manager_->GetContextBinaryBuffer(buffer_size);
ORT_RETURN_IF_ERROR(qnn::GenerateCtxCacheOnnxModel(model_name,
graph_viewer.GetGraph().Description(),
model_description,
context_buffer.get(),
buffer_size,
qnn_backend_manager_->GetSdkVersion(),
Expand Down

0 comments on commit 8253a2d

Please sign in to comment.