diff --git a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc
index caef6a3d06f17..99a91886a1ac9 100644
--- a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc
+++ b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc
@@ -119,11 +119,17 @@ Status LoadQnnCtxFromOnnxModel(const onnxruntime::GraphViewer& graph_viewer,
                                QnnBackendManager* qnn_backend_manager,
                                QnnModel& qnn_model,
                                const logging::Logger& logger) {
+  Status status;
   if (is_qnn_ctx_model) {
-    return GetEpContextFromGraph(graph_viewer, ctx_onnx_model_path, qnn_backend_manager, qnn_model);
+    status = GetEpContextFromGraph(graph_viewer, ctx_onnx_model_path, qnn_backend_manager, qnn_model);
   } else if (is_ctx_cache_file_exist) {
-    return GetEpContextFromModel(ctx_onnx_model_path, qnn_backend_manager, qnn_model, logger);
+    status = GetEpContextFromModel(ctx_onnx_model_path, qnn_backend_manager, qnn_model, logger);
   }
+
+  if (Status::OK() != status) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, "Failed to load from EpContextModel. ", status.ErrorMessage());
+  }
+  return Status::OK();
 }
 
@@ -169,32 +175,33 @@ Status ValidateWithContextFile(const std::string& context_cache_path,
   std::string model_description_from_ctx_cache;
   std::string graph_partition_name_from_ctx_cache;
   std::string cache_source;
-  ORT_RETURN_IF_ERROR(GetMetadataFromEpContextModel(context_cache_path,
-                                                    model_name_from_ctx_cache,
-                                                    model_description_from_ctx_cache,
-                                                    graph_partition_name_from_ctx_cache,
-                                                    cache_source,
-                                                    logger));
+  auto status = GetMetadataFromEpContextModel(context_cache_path,
+                                              model_name_from_ctx_cache,
+                                              model_description_from_ctx_cache,
+                                              graph_partition_name_from_ctx_cache,
+                                              cache_source,
+                                              logger);
+  if (Status::OK() != status) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, "Failed to get metadata from EpContextModel.");
+  }
 
   // The source attribute from the skeleton onnx file indicate whether it's generated from QNN toolchain or ORT
   if (cache_source != kQnnExecutionProvider) {
+    LOGS(logger, VERBOSE) << "Context binary cache is not generated by Ort.";
     return Status::OK();
   }
 
-  ORT_RETURN_IF(model_name != model_name_from_ctx_cache,
-                "Model file name from context cache metadata: " + model_name_from_ctx_cache +
-                    " is different with target: " + model_name +
-                    ". Please make sure the context binary file matches the model.");
-
-  ORT_RETURN_IF(model_description != model_description_from_ctx_cache,
-                "Model description from context cache metadata: " + model_description_from_ctx_cache +
-                    " is different with target: " + model_description +
-                    ". Please make sure the context binary file matches the model.");
-
-  ORT_RETURN_IF(graph_partition_name != graph_partition_name_from_ctx_cache,
-                "Graph name from context cache metadata: " + graph_partition_name_from_ctx_cache +
-                    " is different with target: " + graph_partition_name +
-                    ". You may need to re-generate the context binary file.");
+  if (model_name != model_name_from_ctx_cache ||
+      model_description != model_description_from_ctx_cache ||
+      graph_partition_name != graph_partition_name_from_ctx_cache) {
+    std::string message = onnxruntime::MakeString("Metadata from Onnx file: ",
+                                                  model_name, " ", model_description, " ", graph_partition_name,
+                                                  " vs metadata from context cache Onnx file",
+                                                  model_name_from_ctx_cache, " ",
+                                                  model_description_from_ctx_cache, " ",
+                                                  graph_partition_name_from_ctx_cache);
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, message);
+  }
 
   return Status::OK();
 }
diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc
index 4772da104c389..4bf6c46dbee54 100644
--- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc
+++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc
@@ -541,13 +541,13 @@ Status QNNExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs,
                                                 graph_viewer.ModelPath().ToPathString(), context_cache_path);
     const std::string& model_name = graph_viewer.GetGraph().Name();
-    LOGS(logger, VERBOSE) << "graph_viewer.GetGraph().Name(): " << model_name;
-    LOGS(logger, VERBOSE) << "graph_viewer.GetGraph().Description(): " << graph_viewer.GetGraph().Description();
+    const std::string& model_description = graph_viewer.GetGraph().Description();
+    const std::string& graph_meta_id = fused_node.Name();
     if (fused_nodes_and_graphs.size() == 1 && !is_qnn_ctx_model && is_ctx_file_exist) {
       ORT_RETURN_IF_ERROR(qnn::ValidateWithContextFile(context_cache_path,
                                                        model_name,
-                                                       graph_viewer.GetGraph().Description(),
-                                                       graph_viewer.Name(),
+                                                       model_description,
+                                                       graph_meta_id,
                                                        logger));
     }
 
@@ -567,8 +567,7 @@ Status QNNExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs,
     // fused node name is QNNExecutionProvider_QNN_[hash_id]_[id]
     // the name here should be same with context->node_name in compute_info
-    LOGS(logger, VERBOSE) << "fused node name: " << fused_node.Name();
-    qnn_models_.emplace(fused_node.Name(), std::move(qnn_model));
+    qnn_models_.emplace(graph_meta_id, std::move(qnn_model));
 
     ORT_RETURN_IF_ERROR(CreateComputeFunc(node_compute_funcs, logger));
     return Status::OK();
@@ -580,7 +579,7 @@ Status QNNExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs,
     uint64_t buffer_size(0);
     auto context_buffer = qnn_backend_manager_->GetContextBinaryBuffer(buffer_size);
     ORT_RETURN_IF_ERROR(qnn::GenerateCtxCacheOnnxModel(model_name,
-                                                       graph_viewer.GetGraph().Description(),
+                                                       model_description,
                                                        context_buffer.get(),
                                                        buffer_size,
                                                        qnn_backend_manager_->GetSdkVersion(),