diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc index 943dc8128a133..d3aafcbecd322 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc @@ -404,10 +404,6 @@ QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer } } - if (num_of_partitions > 1) { - ORT_ENFORCE(!context_cache_enabled_, "Only support single partition for context cache feature."); - } - const auto summary_msg = MakeString("Number of partitions supported by QNN EP: ", num_of_partitions, ", number of nodes in the graph: ", num_nodes_in_graph, ", number of nodes supported by QNN: ", num_of_supported_nodes); @@ -485,7 +481,7 @@ Status QNNExecutionProvider::Compile(const std::vector& fused bool is_ctx_file_exist = qnn_cache_model_handler_->GetIsContextCacheFileExists(); if (is_qnn_ctx_model || (context_cache_enabled_ && is_ctx_file_exist)) { - ORT_ENFORCE(fused_nodes_and_graphs.size() == 1, "Only support single partition for context cache feature."); + ORT_RETURN_IF(fused_nodes_and_graphs.size() != 1, "Only support single partition for context cache feature."); std::unique_ptr qnn_model = std::make_unique(logger, qnn_backend_manager_.get()); // Load and execute from cached context if exist ORT_RETURN_IF_ERROR(qnn_cache_model_handler_->LoadQnnCtxFromOnnxModel(graph_viewer, @@ -509,7 +505,7 @@ Status QNNExecutionProvider::Compile(const std::vector& fused ORT_RETURN_IF_ERROR(CompileFromOrtGraph(fused_nodes_and_graphs, node_compute_funcs, logger)); if (context_cache_enabled_ && !is_qnn_ctx_model) { - ORT_ENFORCE(fused_nodes_and_graphs.size() == 1, "Only support single partition for context cache feature."); + ORT_RETURN_IF(fused_nodes_and_graphs.size() != 1, "Only support single partition for context cache feature."); uint64_t buffer_size(0); auto context_buffer = qnn_backend_manager_->GetContextBinaryBuffer(buffer_size); ORT_RETURN_IF_ERROR(qnn_cache_model_handler_->GenerateCtxCacheOnnxModel(context_buffer.get(),