From eba7e03318abccd3f8d1f674ee3e53f8309a8d50 Mon Sep 17 00:00:00 2001
From: Hector Li
Date: Wed, 13 Dec 2023 16:03:59 -0800
Subject: [PATCH 01/28] Support multi-partition for context cache feature

1. dump model part
---
 .../core/framework/execution_provider.h       |  8 ++++++
 .../core/framework/graph_partitioner.cc       | 27 +++++++++++++++++++
 .../qnn/builder/onnx_ctx_model_helper.cc      | 11 ++------
 .../qnn/builder/onnx_ctx_model_helper.h       |  3 +--
 .../providers/qnn/qnn_execution_provider.cc   | 20 +++++++++++---
 .../providers/qnn/qnn_execution_provider.h    |  3 +++
 onnxruntime/core/session/inference_session.cc |  2 ++
 7 files changed, 59 insertions(+), 15 deletions(-)

diff --git a/include/onnxruntime/core/framework/execution_provider.h b/include/onnxruntime/core/framework/execution_provider.h
index ea4f52f99649d..90ce57e1787d4 100644
--- a/include/onnxruntime/core/framework/execution_provider.h
+++ b/include/onnxruntime/core/framework/execution_provider.h
@@ -326,6 +326,14 @@ class IExecutionProvider {
    */
   virtual std::vector<AllocatorPtr> CreatePreferredAllocators() { return std::vector<AllocatorPtr>(); };

+  /**
+   * Get the array of pointers for EPContext nodes.
+   * By default, returns an empty vector if not provided by the Execution Provider.
+   */
+  virtual const std::vector<const Node*> GetEpContextNodes() const {
+    return std::vector<const Node*>();
+  }
+
  private:
   const std::string type_;

diff --git a/onnxruntime/core/framework/graph_partitioner.cc b/onnxruntime/core/framework/graph_partitioner.cc
index e4fe0c7564548..5f4ba4c373a2a 100644
--- a/onnxruntime/core/framework/graph_partitioner.cc
+++ b/onnxruntime/core/framework/graph_partitioner.cc
@@ -510,6 +510,33 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, FuncManager& func_mgr,
     ORT_RETURN_IF_ERROR(graph.Resolve());
   }

+  const std::vector<const Node*> ep_context_nodes = current_ep.GetEpContextNodes();
+  auto get_ep_context_node = [&ep_context_nodes](const std::string& node_name) -> std::pair<bool, const Node*> {
+    for (auto& node : ep_context_nodes) {
+      if (node_name == node->Name()) {
+        return std::make_pair(true, node);
+      }
+    }
+    return std::make_pair(false, static_cast<const Node*>(nullptr));
+  };
+
+  if (ep_context_nodes.size() > 0) {
+    Model ep_model(graph.Name(), false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(),
+                   graph.DomainToVersionMap(), {}, *current_ep.GetLogger());
+    auto& ep_graph = ep_model.MainGraph();
+    for (const auto& node : graph.Nodes()) {
+      // the fused node and the EPContext node have the same node name
+      auto ep_context_node = get_ep_context_node(node.Name());
+      // Use the EPContext node created by the current EP if the name matched, otherwise use the original node
+      if (ep_context_node.first) {
+        ep_graph.AddNode(*ep_context_node.second);
+      } else {
+        ep_graph.AddNode(node);
+      }
+    }
+    ORT_RETURN_IF_ERROR(Model::Save(ep_model, "ep_partition.onnx"));
+  }
+
   // For some cases, like fp16 on cpu, right now we don't have any kernel support that.
   // But we will insert cast op to run the model, so skip the error checking here.
   // If after graph transform phase, the node still not assigned, we will report error
diff --git a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc
index 234b957816662..0dc3d6f56697e 100644
--- a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc
+++ b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc
@@ -206,8 +206,7 @@ Status ValidateWithContextFile(const onnxruntime::PathString& context_cache_path,
   return Status::OK();
 }

-Status GenerateCtxCacheOnnxModel(const std::string model_name,
-                                 const std::string model_description,
+Status GenerateCtxCacheOnnxModel(Model* model,
                                  unsigned char* buffer,
                                  uint64_t buffer_size,
                                  const std::string& sdk_build_version,
@@ -216,11 +215,7 @@ Status GenerateCtxCacheOnnxModel(const std::string model_name,
                                  const onnxruntime::PathString& context_cache_path,
                                  bool qnn_context_embed_mode,
                                  const logging::Logger& logger) {
-  std::unordered_map<std::string, int> domain_to_version = {{kOnnxDomain, 11}, {kMSDomain, 1}};
-  Model model(model_name, false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(),
-              domain_to_version, {}, logger);
-  auto& graph = model.MainGraph();
-  graph.SetDescription(model_description);
+  auto& graph = model->MainGraph();

   using namespace ONNX_NAMESPACE;
   int index = 0;
@@ -272,8 +267,6 @@ Status GenerateCtxCacheOnnxModel(const std::string model_name,
     ep_node.AddAttribute(SOURCE, kQnnExecutionProvider);
     ++index;
   }
-  ORT_RETURN_IF_ERROR(graph.Resolve());
-  ORT_RETURN_IF_ERROR(Model::Save(model, context_cache_path));
   return Status::OK();
 }

diff --git a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h
index 0011d0f43f5bc..ba6fe23ecd56e 100644
--- a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h
+++ b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h
@@ -73,8 +73,7 @@ Status GetMetadataFromEpContextModel(const onnxruntime::PathString& ctx_onnx_model_path,
                                      std::string& cache_source,
                                      const logging::Logger& logger);

-Status GenerateCtxCacheOnnxModel(const std::string model_name,
-                                 const std::string model_description,
+Status GenerateCtxCacheOnnxModel(Model* model,
                                  unsigned char* buffer,
                                  uint64_t buffer_size,
                                  const std::string& sdk_build_version,
diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc
index 60f7bbe08cb6a..bedb22ff11ac2 100644
--- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc
+++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc
@@ -569,7 +569,7 @@ Status QNNExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs,
   }

   if (is_qnn_ctx_model || (context_cache_enabled_ && is_ctx_file_exist)) {
-    ORT_RETURN_IF(fused_nodes_and_graphs.size() != 1, "Only support single partition for context cache feature.");
+    //ORT_RETURN_IF(fused_nodes_and_graphs.size() != 1, "Only support single partition for context cache feature.");
     std::unique_ptr<qnn::QnnModel> qnn_model = std::make_unique<qnn::QnnModel>(logger, qnn_backend_manager_.get());
     // Load and execute from cached context if exist
     ORT_RETURN_IF_ERROR(qnn::LoadQnnCtxFromOnnxModel(graph_viewer,
@@ -592,11 +592,14 @@ Status QNNExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs,
   ORT_RETURN_IF_ERROR(CompileFromOrtGraph(fused_nodes_and_graphs, node_compute_funcs, logger));

   if (context_cache_enabled_ && !is_qnn_ctx_model) {
-    ORT_RETURN_IF(fused_nodes_and_graphs.size() != 1, "Only support single partition for context cache feature.");
+    //ORT_RETURN_IF(fused_nodes_and_graphs.size() != 1, "Only support single partition for context cache feature.");
     uint64_t buffer_size(0);
     auto context_buffer = qnn_backend_manager_->GetContextBinaryBuffer(buffer_size);
-    ORT_RETURN_IF_ERROR(qnn::GenerateCtxCacheOnnxModel(model_name,
-                                                       model_description,
+    std::unordered_map<std::string, int> domain_to_version = {{kOnnxDomain, 11}, {kMSDomain, 1}};
+    Model model(model_name, false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(),
+                domain_to_version, {}, logger);
+    qnn_ep_context_model_ = std::make_unique<Model>(model_name, false, logger);
+    ORT_RETURN_IF_ERROR(qnn::GenerateCtxCacheOnnxModel(qnn_ep_context_model_.get(),
                                                        context_buffer.get(),
                                                        buffer_size,
                                                        qnn_backend_manager_->GetSdkVersion(),
@@ -608,4 +611,13 @@ Status QNNExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs,
   }
   return Status::OK();
 }
+
+const std::vector<const Node*> QNNExecutionProvider::GetEpContextNodes() const {
+  std::vector<const Node*> ep_context_nodes;
+  const auto& graph = qnn_ep_context_model_->MainGraph();
+  for (const auto& node : graph.Nodes()) {
+    ep_context_nodes.push_back(graph.GetNode(node.Index()));
+  }
+  return ep_context_nodes;
+}
 }  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.h b/onnxruntime/core/providers/qnn/qnn_execution_provider.h
index 8b5d0929209ee..224842546e789 100644
--- a/onnxruntime/core/providers/qnn/qnn_execution_provider.h
+++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.h
@@ -35,6 +35,8 @@ class QNNExecutionProvider : public IExecutionProvider {

   DataLayout GetPreferredLayout() const override;

+  const std::vector<const Node*> GetEpContextNodes() const override;
+
  private:
   bool IsNodeSupported(qnn::QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit,
                        std::unordered_map<const NodeUnit*, bool>& node_unit_supported_result,
@@ -66,6 +68,7 @@ class QNNExecutionProvider : public IExecutionProvider {
   bool disable_cpu_ep_fallback_ = false;  // True if CPU EP fallback has been disabled for this session.
   bool qnn_context_embed_mode_ = true;
   int32_t vtcm_size_in_mb_ = 0;
+  std::unique_ptr<onnxruntime::Model> qnn_ep_context_model_;
 };

 }  // namespace onnxruntime
diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc
index 575529a06fb7a..fcc33a75ce9a0 100644
--- a/onnxruntime/core/session/inference_session.cc
+++ b/onnxruntime/core/session/inference_session.cc
@@ -1198,6 +1198,8 @@ common::Status InferenceSession::TransformGraph(onnxruntime::Graph& graph, bool saving_model_in_ort_format) {
     ORT_RETURN_IF_ERROR_SESSIONID_(apply_transformer_once(copy_transformer, *session_logger_, graph));
   }

+  ORT_RETURN_IF_ERROR_SESSIONID_(Model::Save(*model_, "partitioned_graph.onnx"));
+
 #ifdef ENABLE_TRAINING
   // Enable memory optimizations (mainly insert recomputation nodes with priority).
   // Only applicable for training scenarios.

From 8ffa12e81a57b533170831061581fe5ce86068e7 Mon Sep 17 00:00:00 2001
From: Hector Li
Date: Fri, 15 Dec 2023 00:17:19 -0800
Subject: [PATCH 02/28] 2.
 Load and execute the model with multiple EPContext nodes
---
 .../core/framework/graph_partitioner.cc      |   1 +
 .../qnn/builder/onnx_ctx_model_helper.cc     |  45 +++----
 .../qnn/builder/onnx_ctx_model_helper.h      |  11 +-
 .../qnn/builder/qnn_backend_manager.cc       |  15 ++-
 .../qnn/builder/qnn_backend_manager.h        |   3 +-
 .../core/providers/qnn/builder/qnn_model.cc  |   3 +-
 .../providers/qnn/qnn_execution_provider.cc  | 117 +++++++++++-------
 7 files changed, 113 insertions(+), 82 deletions(-)

diff --git a/onnxruntime/core/framework/graph_partitioner.cc b/onnxruntime/core/framework/graph_partitioner.cc
index 5f4ba4c373a2a..55e3bf6e05a7c 100644
--- a/onnxruntime/core/framework/graph_partitioner.cc
+++ b/onnxruntime/core/framework/graph_partitioner.cc
@@ -524,6 +524,7 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, FuncManager& func_mgr,
     Model ep_model(graph.Name(), false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(),
                    graph.DomainToVersionMap(), {}, *current_ep.GetLogger());
     auto& ep_graph = ep_model.MainGraph();
+    ep_graph.SetDescription(graph.Description());
     for (const auto& node : graph.Nodes()) {
       // the fused node and the EPContext node have the same node name
       auto ep_context_node = get_ep_context_node(node.Name());
diff --git a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc
index 0dc3d6f56697e..9464c8fca8df4 100644
--- a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc
+++ b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc
@@ -12,28 +12,21 @@
 namespace onnxruntime {
 namespace qnn {

-Status IsFusedGraphHasCtxNode(const std::vector<IExecutionProvider::FusedNodeAndGraph>& fused_nodes_and_graphs,
-                              bool& is_qnn_ctx_model) {
-  is_qnn_ctx_model = false;
-  for (const auto& fused_node_graph : fused_nodes_and_graphs) {
-    const onnxruntime::GraphViewer& graph_viewer(fused_node_graph.filtered_graph);
-    // It's an Onnx model with Qnn context cache binary if it only has a node with EPContext type
-    int count = 0;
-    for (const auto& node : graph_viewer.Nodes()) {
-      if (EPCONTEXT_OP == node.OpType()) {
-        is_qnn_ctx_model = true;
-      }
-      ++count;
+bool GraphHasEpContextNode(const onnxruntime::GraphViewer& graph_viewer) {
+  // It's an Onnx model with a Qnn context cache binary if it has a node of EPContext type
+  for (const auto& node : graph_viewer.Nodes()) {
+    if (EPCONTEXT_OP == node.OpType()) {
+      return true;
     }
-    ORT_RETURN_IF(is_qnn_ctx_model && count > 1, "Fused graph should only has 1 single EPContext node.");
   }
-  return Status::OK();
+  return false;
 }

-bool IsQnnCtxModel(const onnxruntime::GraphViewer& graph_viewer) {
-  // It's an Onnx model with Qnn context cache binary if it only has a node with EPContext type
-  for (const auto& node : graph_viewer.Nodes()) {
-    if (EPCONTEXT_OP == node.OpType()) {
+bool IsFusedGraphHasCtxNode(const std::vector<IExecutionProvider::FusedNodeAndGraph>& fused_nodes_and_graphs) {
+  for (const auto& fused_node_graph : fused_nodes_and_graphs) {
+    const onnxruntime::GraphViewer& graph_viewer(fused_node_graph.filtered_graph);
+    bool has_qnn_ep_context_node = GraphHasEpContextNode(graph_viewer);
+    if (has_qnn_ep_context_node) {
       return true;
     }
   }
@@ -62,7 +55,7 @@ Status CreateNodeArgs(const std::vector<std::string>& names,

 Status GetEpContextFromModel(const onnxruntime::PathString& ctx_onnx_model_path,
                              QnnBackendManager* qnn_backend_manager,
-                             QnnModel& qnn_model,
+                             std::unordered_map<std::string, std::unique_ptr<QnnModel>>& qnn_models,
                              const logging::Logger& logger) {
   using namespace onnxruntime;
   std::shared_ptr<Model> model;
@@ -71,13 +64,13 @@ Status GetEpContextFromModel(const onnxruntime::PathString& ctx_onnx_model_path,
   return GetEpContextFromGraph(GraphViewer(graph),
                                ctx_onnx_model_path,
                                qnn_backend_manager,
-                               qnn_model);
+                               qnn_models);
 }

 Status GetEpContextFromGraph(const onnxruntime::GraphViewer& graph_viewer,
                              const onnxruntime::PathString& ctx_onnx_model_path,
                              QnnBackendManager* qnn_backend_manager,
-                             QnnModel& qnn_model) {
+                             std::unordered_map<std::string, std::unique_ptr<QnnModel>>& qnn_models) {
   const auto& node = graph_viewer.Nodes().begin();
   NodeAttrHelper node_helper(*node);
   bool is_embed_mode = node_helper.Get(EMBED_MODE, true);
@@ -85,7 +78,7 @@ Status GetEpContextFromGraph(const onnxruntime::GraphViewer& graph_viewer,
     const std::string& context_binary = node_helper.Get(EP_CACHE_CONTEXT, "");
     return qnn_backend_manager->LoadCachedQnnContextFromBuffer(const_cast<char*>(context_binary.c_str()),
                                                                static_cast<uint64_t>(context_binary.length()),
-                                                               qnn_model);
+                                                               qnn_models);
   }

   std::string external_qnn_context_binary_file_name = node_helper.Get(EP_CACHE_CONTEXT, "");
@@ -109,7 +102,7 @@ Status GetEpContextFromGraph(const onnxruntime::GraphViewer& graph_viewer,
   cache_file.close();
   return qnn_backend_manager->LoadCachedQnnContextFromBuffer(buffer.get(),
                                                              static_cast<uint64_t>(buffer_size),
-                                                             qnn_model);
+                                                             qnn_models);
 }

 Status LoadQnnCtxFromOnnxModel(const onnxruntime::GraphViewer& graph_viewer,
@@ -117,13 +110,13 @@ Status LoadQnnCtxFromOnnxModel(const onnxruntime::GraphViewer& graph_viewer,
                                bool is_qnn_ctx_model,
                                bool is_ctx_cache_file_exist,
                                QnnBackendManager* qnn_backend_manager,
-                               QnnModel& qnn_model,
+                               std::unordered_map<std::string, std::unique_ptr<QnnModel>>& qnn_models,
                                const logging::Logger& logger) {
   Status status;
   if (is_qnn_ctx_model) {
-    status = GetEpContextFromGraph(graph_viewer, ctx_onnx_model_path, qnn_backend_manager, qnn_model);
+    status = GetEpContextFromGraph(graph_viewer, ctx_onnx_model_path, qnn_backend_manager, qnn_models);
   } else if (is_ctx_cache_file_exist) {
-    status = GetEpContextFromModel(ctx_onnx_model_path, qnn_backend_manager, qnn_model, logger);
+    status = GetEpContextFromModel(ctx_onnx_model_path, qnn_backend_manager, qnn_models, logger);
   }

   if (!status.IsOK()) {
diff --git a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h
index ba6fe23ecd56e..be55c4c0df740 100644
--- a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h
+++ b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h
@@ -28,10 +28,9 @@ static const std::string EP_SDK_VER = "ep_sdk_version";
 static const std::string PARTITION_NAME = "partition_name";
 static const std::string SOURCE = "source";

-Status IsFusedGraphHasCtxNode(const std::vector<IExecutionProvider::FusedNodeAndGraph>& fused_nodes_and_graphs,
-                              bool& is_qnn_ctx_model);
+bool GraphHasEpContextNode(const onnxruntime::GraphViewer& graph_viewer);

-bool IsQnnCtxModel(const onnxruntime::GraphViewer& graph_viewer);
+bool IsFusedGraphHasCtxNode(const std::vector<IExecutionProvider::FusedNodeAndGraph>& fused_nodes_and_graphs);

 Status CreateNodeArgs(const std::vector<std::string>& names,
                       const std::unordered_map<std::string, OnnxTensorInfo>& tensor_info_table,
                       std::vector<NodeArg*>& node_args,
@@ -44,20 +43,20 @@ bool IsContextCacheFileExists(const std::string& customer_context_cache_path,
                               const onnxruntime::PathString& model_pathstring,
                               onnxruntime::PathString& context_cache_path);

 Status GetEpContextFromModel(const onnxruntime::PathString& ctx_onnx_model_path,
                              QnnBackendManager* qnn_backend_manager,
-                             QnnModel& qnn_model,
+                             std::unordered_map<std::string, std::unique_ptr<QnnModel>>& qnn_models,
                              const logging::Logger& logger);

 Status GetEpContextFromGraph(const onnxruntime::GraphViewer& graph_viewer,
                              const onnxruntime::PathString& ctx_onnx_model_path,
                              QnnBackendManager* qnn_backend_manager,
-                             QnnModel& qnn_model);
+                             std::unordered_map<std::string, std::unique_ptr<QnnModel>>& qnn_models);

 Status LoadQnnCtxFromOnnxModel(const onnxruntime::GraphViewer& graph_viewer,
                                const onnxruntime::PathString& ctx_onnx_model_path,
                                bool is_qnn_ctx_model,
                                bool is_ctx_cache_file_exist,
                                QnnBackendManager* qnn_backend_manager,
-                               QnnModel& qnn_model,
+                               std::unordered_map<std::string, std::unique_ptr<QnnModel>>& qnn_models,
                                const logging::Logger& logger);

 Status ValidateWithContextFile(const onnxruntime::PathString& context_cache_path,
diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc
index 38d74909db86b..8f0f92def3f97 100644
--- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc
+++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc
@@ -483,7 +483,8 @@ std::unique_ptr<char[]> QnnBackendManager::GetContextBinaryBuffer(uint64_t& written_buffer_size) {
   return context_buffer;
 }

-Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t buffer_length, QnnModel& qnn_model) {
+Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t buffer_length,
+                                                         std::unordered_map<std::string, std::unique_ptr<QnnModel>>& qnn_models) {
   bool result = nullptr == qnn_sys_interface_.systemContextCreate ||
                 nullptr == qnn_sys_interface_.systemContextGetBinaryInfo ||
                 nullptr == qnn_sys_interface_.systemContextFree;
@@ -516,8 +517,9 @@ Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t buffer_length,
     graphs_info = binary_info->contextBinaryInfoV2.graphs;
   }

-  ORT_RETURN_IF(graph_count > 1, "Load from Qnn cached context only support 1 sub-graph.");
-  ORT_RETURN_IF(graphs_info == nullptr, "Failed to get graph info from Qnn cached context.");
+  ORT_RETURN_IF(graph_count < 1 || graphs_info == nullptr, "Failed to get graph info from Qnn cached context.");
+  LOGS(*logger_, VERBOSE) << "Graph count from QNN context: " << graph_count << ", EPContext node count: " << qnn_models.size();
+  ORT_RETURN_IF(graph_count != qnn_models.size(), "Graph count from QNN context not equal to EPContext node count.");

   ORT_RETURN_IF(nullptr == qnn_interface_.contextCreateFromBinary,
                 "Invalid function pointer for contextCreateFromBinary.");
@@ -537,7 +539,12 @@ Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t buffer_length,

   // More work to support multiple partition, how to map the graph name in compile to qnn graph name
   // Need the lower level framework to understand EPContext op and pass in the partition_name in fused_node during Compile
-  ORT_RETURN_IF_ERROR(qnn_model.DeserializeGraphInfoFromBinaryInfo(graphs_info[0]));
+  for (uint32_t i = 0; i < graph_count; ++i) {
+    std::string graph_name(graphs_info[i].graphInfoV1.graphName);
+    auto qnn_model_pos = qnn_models.find(graph_name);
+    ORT_RETURN_IF(qnn_model_pos == qnn_models.end(), graph_name + " does not match any EPContext node names.");
+    ORT_RETURN_IF_ERROR(qnn_model_pos->second->DeserializeGraphInfoFromBinaryInfo(graphs_info[i]));
+  }

   qnn_sys_interface_.systemContextFree(sys_ctx_handle);
   sys_ctx_handle = nullptr;
diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h
index bc05820da2f73..a3d9dbce7f13c 100644
--- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h
+++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h
@@ -75,7 +75,8 @@ class QnnBackendManager {

   std::unique_ptr<char[]> GetContextBinaryBuffer(uint64_t& written_buffer_size);

-  Status LoadCachedQnnContextFromBuffer(char* buffer, uint64_t buffer_length, QnnModel& qnn_model);
+  Status LoadCachedQnnContextFromBuffer(char* buffer, uint64_t buffer_length,
+                                        std::unordered_map<std::string, std::unique_ptr<QnnModel>>& qnn_models);

   Status SetupBackend(const logging::Logger& logger, bool load_from_cached_context);
diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model.cc b/onnxruntime/core/providers/qnn/builder/qnn_model.cc
index fd3a95b5f1f78..172df8e594505 100644
--- a/onnxruntime/core/providers/qnn/builder/qnn_model.cc
+++ b/onnxruntime/core/providers/qnn/builder/qnn_model.cc
@@ -97,7 +97,8 @@ Status QnnModel::ComposeGraph(const GraphViewer& graph_viewer,
   std::unordered_map<const Node*, const NodeUnit*> node_unit_map;
   std::tie(node_unit_holder, node_unit_map) = GetAllNodeUnits(graph_viewer);

-  const auto& graph_name = graph_viewer.Name();
+  // This name must be the same as the EPContext node name
+  const auto& graph_name = fused_node.Name();
   ORT_RETURN_IF_ERROR(SetGraphInputOutputInfo(graph_viewer, fused_node));

   QnnModelWrapper qnn_model_wrapper = QnnModelWrapper(graph_viewer, logger_,
diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc
index bedb22ff11ac2..25edc15e557e5 100644
--- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc
+++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc
@@ -251,11 +251,12 @@ QNNExecutionProvider::GetSupportedNodes(const GraphViewer& graph_viewer,
                                         bool load_from_cached_context,
                                         const logging::Logger& logger) const {
   std::unordered_set<const Node*> supported_nodes{};
-  // Enable Qnn context cache requires the whole graph partitioned to Qnn EP
-  // Blindly filter in all nodes if context cache is enabled
+  // Filter in the EPContext nodes if the model has such nodes
   if (load_from_cached_context) {
     for (const auto& node : graph_viewer.Nodes()) {
-      supported_nodes.insert(&node);
+      if (qnn::EPCONTEXT_OP == node.OpType()) {
+        supported_nodes.insert(&node);
+      }
     }
     return supported_nodes;
   }
@@ -341,13 +342,13 @@ QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer,
   const auto& logger = *GetLogger();

   bool load_from_cached_context = false;
-  bool is_qnn_ctx_model = qnn::IsQnnCtxModel(graph_viewer);
-  if (is_qnn_ctx_model) {
+  bool has_ep_context_node = qnn::GraphHasEpContextNode(graph_viewer);
+  if (has_ep_context_node) {
     load_from_cached_context = true;
   }

   // This is for case: QDQ model + Onnx Qnn context cache model
-  if (context_cache_enabled_ && !is_qnn_ctx_model) {
+  if (context_cache_enabled_ && !has_ep_context_node) {
     onnxruntime::PathString context_cache_path;
     load_from_cached_context = qnn::IsContextCacheFileExists(context_cache_path_cfg_,
                                                              graph_viewer.ModelPath().ToPathString(),
@@ -361,11 +362,12 @@ QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer,
     return result;
   }

-  if ((context_cache_enabled_ || is_qnn_ctx_model) && !IsNpuBackend(qnn_backend_manager_->GetQnnBackendType())) {
+  if ((context_cache_enabled_ || has_ep_context_node) && !IsNpuBackend(qnn_backend_manager_->GetQnnBackendType())) {
     LOGS(logger, ERROR) << "Qnn context cache only works for HTP or DSP backend.";
     return result;
   }

+  // Get all the NodeUnits in the graph_viewer
   std::vector<std::unique_ptr<NodeUnit>> node_unit_holder;
   std::unordered_map<const Node*, const NodeUnit*> node_unit_map;
@@ -426,7 +428,7 @@ QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer,

     if (partition && partition->sub_graph) {
       nodes_in_partition = partition->sub_graph->nodes.size();

-      if (nodes_in_partition == 1) {
+      if (nodes_in_partition == 1 && !load_from_cached_context) {
         const Node* node = graph_viewer.GetNode(partition->sub_graph->nodes[0]);

         if (!node) {
@@ -457,7 +459,7 @@ QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer,

   // Print list of unsupported nodes to the ERROR logger if the CPU EP
   // has been disabled for this inference session.
-  if (disable_cpu_ep_fallback_ && num_nodes_in_graph != num_of_supported_nodes) {
+  if (!load_from_cached_context && disable_cpu_ep_fallback_ && num_nodes_in_graph != num_of_supported_nodes) {
     LOGS(logger, ERROR) << "Unsupported nodes in QNN EP: " << get_unsupported_node_names();
   }

@@ -547,58 +549,82 @@ Status QNNExecutionProvider::CompileFromOrtGraph(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs,
 Status QNNExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs,
                                      std::vector<NodeComputeInfo>& node_compute_funcs) {
   const auto& logger = *GetLogger();
-  Node& fused_node = fused_nodes_and_graphs[0].fused_node;
-  const onnxruntime::GraphViewer& graph_viewer(fused_nodes_and_graphs[0].filtered_graph);
+  //const Node& fused_node_0 = fused_nodes_and_graphs[0].fused_node;
+  //const onnxruntime::GraphViewer& graph_viewer_0(fused_nodes_and_graphs[0].filtered_graph);

-  bool is_qnn_ctx_model = false;
-  ORT_RETURN_IF_ERROR(qnn::IsFusedGraphHasCtxNode(fused_nodes_and_graphs, is_qnn_ctx_model));
+  bool is_qnn_ctx_model = qnn::IsFusedGraphHasCtxNode(fused_nodes_and_graphs);

+  bool is_ctx_file_exist = false;
   onnxruntime::PathString context_cache_path;
-  bool is_ctx_file_exist = qnn::IsContextCacheFileExists(context_cache_path_cfg_,
-                                                         graph_viewer.ModelPath().ToPathString(),
-                                                         context_cache_path);
-  const std::string& model_name = graph_viewer.GetGraph().Name();
-  const std::string& model_description = graph_viewer.GetGraph().Description();
-  const std::string& graph_meta_id = fused_node.Name();
-  if (fused_nodes_and_graphs.size() == 1 && !is_qnn_ctx_model && is_ctx_file_exist) {
-    ORT_RETURN_IF_ERROR(qnn::ValidateWithContextFile(context_cache_path,
-                                                     model_name,
-                                                     model_description,
-                                                     graph_meta_id,
-                                                     logger));
-  }
+  //bool is_ctx_file_exist = is_qnn_ctx_model ? false :
+  //                         qnn::IsContextCacheFileExists(context_cache_path_cfg_,
+  //                                                       graph_viewer_0.ModelPath().ToPathString(),
+  //                                                       context_cache_path);
+  //const std::string& model_name = graph_viewer_0.GetGraph().Name();
+  //const std::string& model_description = graph_viewer_0.GetGraph().Description();
+  //const std::string& graph_meta_id0 = fused_node_0.Name();
+  //if (!is_qnn_ctx_model && is_ctx_file_exist) {
+  //  ORT_RETURN_IF_ERROR(qnn::ValidateWithContextFile(context_cache_path,
+  //                                                   model_name,
+  //                                                   model_description,
+  //                                                   graph_meta_id0,
+  //                                                   logger));
+  //}

   if (is_qnn_ctx_model || (context_cache_enabled_ && is_ctx_file_exist)) {
-    //ORT_RETURN_IF(fused_nodes_and_graphs.size() != 1, "Only support single partition for context cache feature.");
-    std::unique_ptr<qnn::QnnModel> qnn_model = std::make_unique<qnn::QnnModel>(logger, qnn_backend_manager_.get());
+    // Table<node name, QnnModel>, the node name is not the same as the graph_meta_id
+    std::unordered_map<std::string, std::unique_ptr<qnn::QnnModel>> qnn_models;
+    int main_context_pos = -1;
+    for (int i = 0; i < fused_nodes_and_graphs.size(); ++i) {
+      const onnxruntime::GraphViewer& graph_viewer(fused_nodes_and_graphs[i].filtered_graph);
+      const auto& ep_context_node = graph_viewer.Nodes().begin();
+      qnn_models.emplace(ep_context_node->Name(),
+                         std::make_unique<qnn::QnnModel>(logger, qnn_backend_manager_.get()));
+      NodeAttrHelper node_helper(*ep_context_node);
+      bool is_main_context = node_helper.Get(qnn::MAIN_CONTEXT, static_cast<int64_t>(0));
+      if (is_main_context) {
+        main_context_pos = i;
+      }
+    }
+
+    ORT_RETURN_IF(main_context_pos < 0, "Failed to find the EPContext node with main_context=1");
+
+    const onnxruntime::GraphViewer& main_ctx_graph_viewer(fused_nodes_and_graphs[main_context_pos].filtered_graph);
     // Load and execute from cached context if exist
-    ORT_RETURN_IF_ERROR(qnn::LoadQnnCtxFromOnnxModel(graph_viewer,
+    ORT_RETURN_IF_ERROR(qnn::LoadQnnCtxFromOnnxModel(main_ctx_graph_viewer,
                                                      context_cache_path,
                                                      is_qnn_ctx_model,
                                                      is_ctx_file_exist,
                                                      qnn_backend_manager_.get(),
-                                                     *(qnn_model.get()),
+                                                     qnn_models,
                                                      logger));
-    ORT_RETURN_IF_ERROR(qnn_model->SetGraphInputOutputInfo(graph_viewer, fused_node));
-    ORT_RETURN_IF_ERROR(qnn_model->SetupQnnInputOutput());

-    // fused node name is QNNExecutionProvider_QNN_[hash_id]_[id]
-    // the name here should be same with context->node_name in compute_info
-    qnn_models_.emplace(graph_meta_id, std::move(qnn_model));
+    for (auto fused_node_and_graph : fused_nodes_and_graphs) {
+      const onnxruntime::GraphViewer& graph_viewer(fused_node_and_graph.filtered_graph);
+      const auto& ep_context_node = graph_viewer.Nodes().begin();
+      const Node& fused_node = fused_node_and_graph.fused_node;
+      auto qnn_model = std::move(qnn_models[ep_context_node->Name()]);
+      ORT_RETURN_IF_ERROR(qnn_model->SetGraphInputOutputInfo(graph_viewer, fused_node));
+      ORT_RETURN_IF_ERROR(qnn_model->SetupQnnInputOutput());
+
+      // fused node name is QNNExecutionProvider_QNN_[hash_id]_[id]
+      // the name here must be the same as context->node_name in compute_info
+      const std::string& graph_meta_id = fused_node.Name();
+      qnn_models_.emplace(graph_meta_id, std::move(qnn_model));
+
+      ORT_RETURN_IF_ERROR(CreateComputeFunc(node_compute_funcs, logger));
+    }

-    ORT_RETURN_IF_ERROR(CreateComputeFunc(node_compute_funcs, logger));
     return Status::OK();
   }

   ORT_RETURN_IF_ERROR(CompileFromOrtGraph(fused_nodes_and_graphs, node_compute_funcs, logger));
   if (context_cache_enabled_ && !is_qnn_ctx_model) {
-    //ORT_RETURN_IF(fused_nodes_and_graphs.size() != 1, "Only support single partition for context cache feature.");
+    // All partitioned graphs share a single QNN context, included in the same context binary
     uint64_t buffer_size(0);
     auto context_buffer = qnn_backend_manager_->GetContextBinaryBuffer(buffer_size);
-    std::unordered_map<std::string, int> domain_to_version = {{kOnnxDomain, 11}, {kMSDomain, 1}};
-    Model model(model_name, false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(),
-                domain_to_version, {}, logger);
-    qnn_ep_context_model_ = std::make_unique<Model>(model_name, false, logger);
+    qnn_ep_context_model_ = std::make_unique<Model>("qnn_ep_context_model", false, logger);
     ORT_RETURN_IF_ERROR(qnn::GenerateCtxCacheOnnxModel(qnn_ep_context_model_.get(),
                                                        context_buffer.get(),
                                                        buffer_size,
                                                        qnn_backend_manager_->GetSdkVersion(),
@@ -614,10 +640,13 @@ Status QNNExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs,

 const std::vector<const Node*> QNNExecutionProvider::GetEpContextNodes() const {
   std::vector<const Node*> ep_context_nodes;
-  const auto& graph = qnn_ep_context_model_->MainGraph();
-  for (const auto& node : graph.Nodes()) {
-    ep_context_nodes.push_back(graph.GetNode(node.Index()));
+  if (qnn_ep_context_model_) {
+    const auto& graph = qnn_ep_context_model_->MainGraph();
+    for (const auto& node : graph.Nodes()) {
+      ep_context_nodes.push_back(graph.GetNode(node.Index()));
+    }
   }
+
   return ep_context_nodes;
 }
 }  // namespace onnxruntime

From 8117368a6fe5e1f3f674de627a1e5a03ed8e46d6 Mon Sep 17 00:00:00 2001
From: Hector Li
Date: Fri, 15 Dec 2023 23:36:01 -0800
Subject: [PATCH 03/28] 3. Mode: run with QDQ model + QNN context model

-- Validate the QNN context model against the graph partition result from the QDQ model
-- In Compile(), load the QNN context model, get all the EPContext nodes, create the QNN
   context from the context binary, create the QNN graphs from the binary, and execute
---
 .../qnn/builder/onnx_ctx_model_helper.cc      | 161 +++++++++++++-----
 .../qnn/builder/onnx_ctx_model_helper.h       |  42 +++--
 .../providers/qnn/qnn_execution_provider.cc   | 101 ++++++-----
 3 files changed, 190 insertions(+), 114 deletions(-)

diff --git a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc
index 9464c8fca8df4..a9075cb984734 100644
--- a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc
+++ b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc
@@ -33,6 +33,73 @@ bool IsFusedGraphHasCtxNode(const std::vector<IExecutionProvider::FusedNodeAndGraph>& fused_nodes_and_graphs) {
   return false;
 }

+Status GetMainContextNode(const std::vector<IExecutionProvider::FusedNodeAndGraph>& fused_nodes_and_graphs,
+                          QnnBackendManager* qnn_backend_manager,
+                          const logging::Logger& logger,
+                          int& main_context_pos,
+                          std::unordered_map<std::string, std::unique_ptr<QnnModel>>& qnn_models) {
+  main_context_pos = -1;
+  for (int i = 0; i < fused_nodes_and_graphs.size(); ++i) {
+    const onnxruntime::GraphViewer& graph_viewer(fused_nodes_and_graphs[i].filtered_graph);
+    const auto& ep_context_node = graph_viewer.Nodes().begin();
+    ORT_RETURN_IF_NOT(EPCONTEXT_OP == ep_context_node->OpType(), "Should only filter in the EPContext node.");
+    qnn_models.emplace(ep_context_node->Name(),
+                       std::make_unique<QnnModel>(logger, qnn_backend_manager));
+    NodeAttrHelper node_helper(*ep_context_node);
+    int64_t is_main_context = node_helper.Get(MAIN_CONTEXT, static_cast<int64_t>(0));
+    if (1 == is_main_context) {
+      main_context_pos = i;
+    }
+  }
+
+  ORT_RETURN_IF(main_context_pos < 0, "Failed to find the EPContext node with main_context=1");
+  return Status::OK();
+}
+
+Status GetContextFromOnnxModel(const std::vector<IExecutionProvider::FusedNodeAndGraph>& fused_nodes_and_graphs,
+                               const onnxruntime::PathString& ctx_onnx_model_path,
+                               QnnBackendManager* qnn_backend_manager,
+                               const logging::Logger& logger,
+                               std::unordered_map<std::string, std::unique_ptr<QnnModel>>& qnn_models) {
+  for (const auto& fused_node_and_graph : fused_nodes_and_graphs) {
+    const Node& fused_node = fused_node_and_graph.fused_node;
+    qnn_models.emplace(fused_node.Name(),
+                       std::make_unique<QnnModel>(logger, qnn_backend_manager));
+  }
+  using namespace onnxruntime;
+  std::shared_ptr<Model> model;
+  ORT_RETURN_IF_ERROR(Model::Load(ctx_onnx_model_path, model, {}, logger));
+  const auto& graph = GraphViewer(model->MainGraph());
+
+  for (const auto& ep_context_node : graph.Nodes()) {
+    if (EPCONTEXT_OP != ep_context_node.OpType()) {
+      continue;
+    }
+    NodeAttrHelper node_helper(ep_context_node);
+    int64_t is_main_context = node_helper.Get(MAIN_CONTEXT, static_cast<int64_t>(0));
+    if (1 == is_main_context) {
+      return GetEpContextFromMainNode(ep_context_node, ctx_onnx_model_path, qnn_backend_manager, qnn_models);
+    }
+  }
+
+  return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, "Failed to find EPContext node with main_context=1.");
+}
+
+Status LoadContextFromOnnxModel(const std::vector<IExecutionProvider::FusedNodeAndGraph>& fused_nodes_and_graphs,
+                                const onnxruntime::PathString& ctx_onnx_model_path,
+                                QnnBackendManager* qnn_backend_manager,
+                                const logging::Logger& logger,
+                                std::unordered_map<std::string, std::unique_ptr<QnnModel>>& qnn_models) {
+  Status status = GetContextFromOnnxModel(fused_nodes_and_graphs, ctx_onnx_model_path, qnn_backend_manager, logger, qnn_models);
+
+  // This is the protocol with the customer: a status with INVALID_GRAPH is generated if the context model fails to load
+  if (!status.IsOK()) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, "Failed to load from EpContextModel. ", status.ErrorMessage());
+  }
+
+  return Status::OK();
+}
+
 Status CreateNodeArgs(const std::vector<std::string>& names,
                       const std::unordered_map<std::string, OnnxTensorInfo>& tensor_info_table,
                       std::vector<NodeArg*>& node_args,
@@ -105,20 +172,12 @@ Status CreateNodeArgs(const std::vector<std::string>& names,
   return Status::OK();
 }

-Status GetEpContextFromModel(const onnxruntime::PathString& ctx_onnx_model_path,
-                             QnnBackendManager* qnn_backend_manager,
-                             std::unordered_map<std::string, std::unique_ptr<QnnModel>>& qnn_models,
-                             const logging::Logger& logger) {
-  using namespace onnxruntime;
-  std::shared_ptr<Model> model;
-  ORT_RETURN_IF_ERROR(Model::Load(ToPathString(ctx_onnx_model_path), model, {}, logger));
-  const auto& graph = model->MainGraph();
-  return GetEpContextFromGraph(GraphViewer(graph),
-                               ctx_onnx_model_path,
-                               qnn_backend_manager,
-                               qnn_models);
-}
-
-Status GetEpContextFromGraph(const onnxruntime::GraphViewer& graph_viewer,
-                             const onnxruntime::PathString& ctx_onnx_model_path,
-                             QnnBackendManager* qnn_backend_manager,
-                             std::unordered_map<std::string, std::unique_ptr<QnnModel>>& qnn_models) {
-  const auto& node = graph_viewer.Nodes().begin();
-  NodeAttrHelper node_helper(*node);
+Status GetEpContextFromMainNode(const onnxruntime::Node& main_context_node,
+                                const onnxruntime::PathString& ctx_onnx_model_path,
+                                QnnBackendManager* qnn_backend_manager,
+                                std::unordered_map<std::string, std::unique_ptr<QnnModel>>& qnn_models) {
+  ORT_RETURN_IF_NOT(EPCONTEXT_OP == main_context_node.OpType(), "Should only filter in the EPContext node.");
+  NodeAttrHelper node_helper(main_context_node);
   bool is_embed_mode = node_helper.Get(EMBED_MODE, true);
   if (is_embed_mode) {
     const std::string& context_binary = node_helper.Get(EP_CACHE_CONTEXT, "");
     return qnn_backend_manager->LoadCachedQnnContextFromBuffer(const_cast<char*>(context_binary.c_str()),
                                                                static_cast<uint64_t>(context_binary.length()),
                                                                qnn_models);
   }

   std::string external_qnn_context_binary_file_name = node_helper.Get(EP_CACHE_CONTEXT, "");
@@ -145,23 +204,16 @@ Status GetEpContextFromMainNode(const onnxruntime::Node& main_context_node,
   cache_file.close();
   return qnn_backend_manager->LoadCachedQnnContextFromBuffer(buffer.get(),
                                                              static_cast<uint64_t>(buffer_size),
                                                              qnn_models);
 }

-Status LoadQnnCtxFromOnnxModel(const onnxruntime::GraphViewer& graph_viewer,
-                               const onnxruntime::PathString& ctx_onnx_model_path,
-                               bool is_qnn_ctx_model,
-                               bool is_ctx_cache_file_exist,
-                               QnnBackendManager* qnn_backend_manager,
-                               std::unordered_map<std::string, std::unique_ptr<QnnModel>>& qnn_models,
-                               const logging::Logger& logger) {
-  Status status;
-  if (is_qnn_ctx_model) {
-    status = GetEpContextFromGraph(graph_viewer, ctx_onnx_model_path, qnn_backend_manager, qnn_models);
-  } else if (is_ctx_cache_file_exist) {
-    status = GetEpContextFromModel(ctx_onnx_model_path, qnn_backend_manager, qnn_models, logger);
-  }
+Status LoadQnnCtxFromOnnxGraph(const onnxruntime::GraphViewer& graph_viewer,
+                               const onnxruntime::PathString& ctx_onnx_model_path,
+                               QnnBackendManager* qnn_backend_manager,
+                               std::unordered_map<std::string, std::unique_ptr<QnnModel>>& qnn_models) {
+  Status status = GetEpContextFromMainNode(*graph_viewer.Nodes().begin(), ctx_onnx_model_path, qnn_backend_manager, qnn_models);

+  // This is the protocol with the customer: a status with INVALID_GRAPH is generated if the context model fails to load
   if (!status.IsOK()) {
     return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, "Failed to load from EpContextModel. ", status.ErrorMessage());
   }
@@ -172,19 +224,24 @@ Status LoadQnnCtxFromOnnxGraph(const onnxruntime::GraphViewer& graph_viewer,

 Status GetMetadataFromEpContextModel(const onnxruntime::PathString& ctx_onnx_model_path,
                                      std::string& model_name,
                                      std::string& model_description,
-                                     std::string& graph_partition_name,
+                                     std::vector<std::string>& graph_partition_names,
                                      std::string& cache_source,
                                      const logging::Logger& logger) {
   using namespace onnxruntime;
   std::shared_ptr<Model> model;
   ORT_RETURN_IF_ERROR(Model::Load(ctx_onnx_model_path, model, {}, logger));
   const auto& graph = GraphViewer(model->MainGraph());
-  const auto& node = graph.Nodes().begin();
-  NodeAttrHelper node_helper(*node);
   model_name = graph.Name();
   model_description = graph.Description();
-  graph_partition_name = node_helper.Get(PARTITION_NAME, "");
-  cache_source = node_helper.Get(SOURCE, "");
+
+  for (const auto& ep_context_node : graph.Nodes()) {
+    if (EPCONTEXT_OP != ep_context_node.OpType()) {
+      continue;
+    }
+    NodeAttrHelper node_helper(ep_context_node);
+    cache_source = node_helper.Get(SOURCE, "");
+    graph_partition_names.push_back(node_helper.Get(PARTITION_NAME, ""));
+  }

   return Status::OK();
 }
@@ -203,23 +260,51 @@ bool IsContextCacheFileExists(const std::string& customer_context_cache_path,
   return std::filesystem::is_regular_file(context_cache_path) && std::filesystem::exists(context_cache_path);
 }

-Status ValidateWithContextFile(const onnxruntime::PathString& context_cache_path,
+Status ValidateWithContextFile(const std::vector<IExecutionProvider::FusedNodeAndGraph>& fused_nodes_and_graphs,
+                               const onnxruntime::PathString& context_cache_path,
                                const std::string& model_name,
                                const std::string& model_description,
-                               const std::string& graph_partition_name,
                                const logging::Logger& logger) {
   std::string model_name_from_ctx_cache;
   std::string model_description_from_ctx_cache;
-  std::string graph_partition_name_from_ctx_cache;
+  std::vector<std::string> graph_partition_names;
   std::string cache_source;
+
   auto status = GetMetadataFromEpContextModel(context_cache_path,
                                               model_name_from_ctx_cache,
                                               model_description_from_ctx_cache,
-                                              graph_partition_name_from_ctx_cache,
+                                              graph_partition_names,
                                               cache_source,
                                               logger);
   if (!status.IsOK()) {
-    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, "Failed to get metadata from EpContextModel.");
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, "Failed to get metadata from EpContext model.");
   }

   // The source attribute from the skeleton onnx file indicate whether it's generated from QNN toolchain or ORT
@@ -228,15 +313,34 @@ Status ValidateWithContextFile(const std::vector<IExecutionProvider::FusedNodeAndGraph>& fused_nodes_and_graphs,
     return Status::OK();
   }

+  bool partition_names_matched = true;
+  for (const auto& fused_node_graph : fused_nodes_and_graphs) {
+    const Node& fused_node = fused_node_graph.fused_node;
+    const std::string& graph_meta_id = fused_node.Name();
+    bool name_found = false;
+    for (auto partition_name_from_ctx : graph_partition_names) {
+      if (partition_name_from_ctx == graph_meta_id) {
+        name_found = true;
+        break;
+      }
+    }
+
+    if (!name_found) {
+      LOGS(logger, ERROR) << "Partition meta_id not found from any EPContext node: " << graph_meta_id;
+      partition_names_matched = false;
+      break;
+    }
+  }
+
   if (model_name != model_name_from_ctx_cache ||
       model_description != model_description_from_ctx_cache ||
-      graph_partition_name != graph_partition_name_from_ctx_cache) {
+      !partition_names_matched) {
     std::string message = onnxruntime::MakeString("Metadata mismatch. onnx: ",
-                                                  model_name, " ", model_description, " ", graph_partition_name,
+                                                  model_name, " ", model_description,
                                                   " vs epcontext: ",
                                                   model_name_from_ctx_cache, " ",
-                                                  model_description_from_ctx_cache, " ",
-                                                  graph_partition_name_from_ctx_cache);
+                                                  model_description_from_ctx_cache,
+                                                  " or the partition names do not match.");
     return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, message);
   }
diff --git a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h
index be55c4c0df740..de2f0bd702924 100644
--- a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h
+++ b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h
@@ -32,6 +32,12 @@ bool GraphHasEpContextNode(const onnxruntime::GraphViewer& graph_viewer);

 bool IsFusedGraphHasCtxNode(const std::vector<IExecutionProvider::FusedNodeAndGraph>& fused_nodes_and_graphs);

+Status GetMainContextNode(const std::vector<IExecutionProvider::FusedNodeAndGraph>& fused_nodes_and_graphs,
+                          QnnBackendManager* qnn_backend_manager,
+                          const logging::Logger& logger,
+                          int& main_context_pos,
+                          std::unordered_map<std::string, std::unique_ptr<QnnModel>>& qnn_models);
+
 Status CreateNodeArgs(const std::vector<std::string>& names,
                       const std::unordered_map<std::string, OnnxTensorInfo>& tensor_info_table,
                       std::vector<NodeArg*>& node_args,
@@ -47,34 +53,38 @@ bool IsContextCacheFileExists(const std::string& customer_context_cache_path,
                               const onnxruntime::PathString& model_pathstring,
                               onnxruntime::PathString& context_cache_path);

-Status GetEpContextFromModel(const onnxruntime::PathString& ctx_onnx_model_path,
-                             QnnBackendManager* qnn_backend_manager,
-                             std::unordered_map<std::string, std::unique_ptr<QnnModel>>& qnn_models,
-                             const logging::Logger& logger);
+Status GetEpContextFromMainNode(const onnxruntime::Node& main_context_node,
+                                const onnxruntime::PathString& ctx_onnx_model_path,
+                                QnnBackendManager* qnn_backend_manager,
+                                std::unordered_map<std::string, std::unique_ptr<QnnModel>>& qnn_models);

-Status GetEpContextFromGraph(const onnxruntime::GraphViewer& graph_viewer,
-                             const onnxruntime::PathString& ctx_onnx_model_path,
-                             QnnBackendManager* qnn_backend_manager,
-                             std::unordered_map<std::string, std::unique_ptr<QnnModel>>& qnn_models);
+Status LoadQnnCtxFromOnnxGraph(const onnxruntime::GraphViewer& graph_viewer,
+                               const onnxruntime::PathString& ctx_onnx_model_path,
+                               QnnBackendManager* qnn_backend_manager,
+                               std::unordered_map<std::string, std::unique_ptr<QnnModel>>& qnn_models);

-Status LoadQnnCtxFromOnnxModel(const onnxruntime::GraphViewer& graph_viewer,
+Status GetContextFromOnnxModel(const std::vector<IExecutionProvider::FusedNodeAndGraph>& fused_nodes_and_graphs,
                                const onnxruntime::PathString& ctx_onnx_model_path,
-                               bool is_qnn_ctx_model,
-                               bool is_ctx_cache_file_exist,
                                QnnBackendManager* qnn_backend_manager,
-                               std::unordered_map<std::string, std::unique_ptr<QnnModel>>& qnn_models,
-                               const logging::Logger& logger);
+                               const logging::Logger& logger,
+                               std::unordered_map<std::string, std::unique_ptr<QnnModel>>& qnn_models);
+
+Status LoadContextFromOnnxModel(const std::vector<IExecutionProvider::FusedNodeAndGraph>& fused_nodes_and_graphs,
+                                const onnxruntime::PathString& ctx_onnx_model_path,
+                                QnnBackendManager* qnn_backend_manager,
+                                const logging::Logger& logger,
+                                std::unordered_map<std::string, std::unique_ptr<QnnModel>>& qnn_models);

-Status ValidateWithContextFile(const onnxruntime::PathString& context_cache_path,
+Status ValidateWithContextFile(const std::vector<IExecutionProvider::FusedNodeAndGraph>& fused_nodes_and_graphs,
+                               const onnxruntime::PathString& context_cache_path,
                                const std::string& model_name,
                                const std::string& model_description,
-                               const std::string& graph_partition_name,
                                const logging::Logger& logger);

 Status GetMetadataFromEpContextModel(const onnxruntime::PathString& ctx_onnx_model_path,
                                      std::string& model_name,
                                      std::string& model_description,
-                                     std::string& graph_partition_name,
+                                     std::vector<std::string>& graph_partition_names,
                                      std::string& cache_source,
                                      const logging::Logger& logger);
diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc
index 25edc15e557e5..30f3fddf59019 100644
--- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc
+++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc
@@ -248,13 +248,17 @@ std::unordered_set<const Node*>
 QNNExecutionProvider::GetSupportedNodes(const GraphViewer& graph_viewer,
                                         const std::unordered_map<const Node*, const NodeUnit*>& node_unit_map,
                                         const size_t node_unit_size,
-                                        bool load_from_cached_context,
+                                        bool is_qnn_ctx_model,
                                         const logging::Logger& logger) const {
   std::unordered_set<const Node*> supported_nodes{};
-  // Filter in the EPContext nodes if the model has such nodes
-  if (load_from_cached_context) {
+  // Filter in the EPContext nodes if it's a QNN context model
+  if (is_qnn_ctx_model) {
     for (const auto& node : graph_viewer.Nodes()) {
       if (qnn::EPCONTEXT_OP == node.OpType()) {
+        LOGS(logger, VERBOSE) << "Node supported: [1] index: [" << node.Index()
+                              << "] name: [" << node.Name()
+                              << "] Operator type: [EPContext"
+                              << "] index: [" << node.Index() << "]";
         supported_nodes.insert(&node);
       }
     }
     return supported_nodes;
   }
@@ -342,13 +346,13 @@ QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer,
   const auto& logger = *GetLogger();

   bool load_from_cached_context = false;
-  bool has_ep_context_node = qnn::GraphHasEpContextNode(graph_viewer);
-  if (has_ep_context_node) {
+  bool is_qnn_ctx_model = qnn::GraphHasEpContextNode(graph_viewer);
+  if (is_qnn_ctx_model) {
     load_from_cached_context = true;
   }

   // This is for case: QDQ model + Onnx Qnn context cache model
-  if (context_cache_enabled_ && !has_ep_context_node) {
+  if (context_cache_enabled_ && !is_qnn_ctx_model) {
     onnxruntime::PathString context_cache_path;
     load_from_cached_context = qnn::IsContextCacheFileExists(context_cache_path_cfg_,
                                                              graph_viewer.ModelPath().ToPathString(),
                                                              context_cache_path);
   }

-  // Load from cached context will load the QnnSystem lib and skip the Qnn context creation
+  // It will load the QnnSystem lib if load_from_cached_context=true, and
+  // delay the Qnn context creation to Compile() using the cached context
   auto rt = qnn_backend_manager_->SetupBackend(logger, load_from_cached_context);
   if (Status::OK() != rt) {
     LOGS(logger, ERROR) << "QNN SetupBackend failed " << rt.ErrorMessage();
     return result;
   }

-  if ((context_cache_enabled_ || has_ep_context_node) && !IsNpuBackend(qnn_backend_manager_->GetQnnBackendType())) {
+  if ((context_cache_enabled_ || is_qnn_ctx_model) && !IsNpuBackend(qnn_backend_manager_->GetQnnBackendType())) {
     LOGS(logger, ERROR) << "Qnn context cache only works for HTP or DSP backend.";
     return result;
   }

   // Get all the NodeUnits in the graph_viewer
   std::vector<std::unique_ptr<NodeUnit>> node_unit_holder;
   std::unordered_map<const Node*, const NodeUnit*> node_unit_map;

   std::tie(node_unit_holder, node_unit_map) = GetAllNodeUnits(graph_viewer);

   const auto supported_nodes = GetSupportedNodes(graph_viewer, node_unit_map, node_unit_holder.size(),
-                                                 load_from_cached_context, logger);
+                                                 is_qnn_ctx_model, logger);

   // Helper function that returns a string that lists all unsupported nodes.
   // Ex: { name: mul_123, type: Mul }, {}, ...
@@ -549,82 +553,86 @@ Status QNNExecutionProvider::CompileFromOrtGraph(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs,
 Status QNNExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs,
                                      std::vector<NodeComputeInfo>& node_compute_funcs) {
   const auto& logger = *GetLogger();
-  //const Node& fused_node_0 = fused_nodes_and_graphs[0].fused_node;
-  //const onnxruntime::GraphViewer& graph_viewer_0(fused_nodes_and_graphs[0].filtered_graph);

   bool is_qnn_ctx_model = qnn::IsFusedGraphHasCtxNode(fused_nodes_and_graphs);

-  bool is_ctx_file_exist = false;
   onnxruntime::PathString context_cache_path;
-  //bool is_ctx_file_exist = is_qnn_ctx_model ? false :
-  //                         qnn::IsContextCacheFileExists(context_cache_path_cfg_,
-  //                                                       graph_viewer_0.ModelPath().ToPathString(),
-  //                                                       context_cache_path);
-  //const std::string& model_name = graph_viewer_0.GetGraph().Name();
-  //const std::string& model_description = graph_viewer_0.GetGraph().Description();
-  //const std::string& graph_meta_id0 = fused_node_0.Name();
-  //if (!is_qnn_ctx_model && is_ctx_file_exist) {
-  //  ORT_RETURN_IF_ERROR(qnn::ValidateWithContextFile(context_cache_path,
-  //                                                   model_name,
-  //                                                   model_description,
-  //                                                   graph_meta_id0,
-  //                                                   logger));
-  //}
+  bool is_ctx_file_exist = false;
+  if (!is_qnn_ctx_model) {
+    const onnxruntime::GraphViewer& graph_viewer_0(fused_nodes_and_graphs[0].filtered_graph);
+    is_ctx_file_exist = qnn::IsContextCacheFileExists(context_cache_path_cfg_,
+                                                      graph_viewer_0.ModelPath().ToPathString(),
+                                                      context_cache_path);
+    if (is_ctx_file_exist) {
+      ORT_RETURN_IF_ERROR(qnn::ValidateWithContextFile(fused_nodes_and_graphs,
+                                                       context_cache_path,
+                                                       graph_viewer_0.GetGraph().Name(),
+                                                       graph_viewer_0.GetGraph().Description(),
+                                                       logger));
+    }
+  }

   if (is_qnn_ctx_model || (context_cache_enabled_ && is_ctx_file_exist)) {
     // Table<node name, QnnModel>, the node name is not the same as the graph_meta_id
     std::unordered_map<std::string, std::unique_ptr<qnn::QnnModel>> qnn_models;
-    int main_context_pos = -1;
-    for (int i = 0; i < fused_nodes_and_graphs.size(); ++i) {
-      const onnxruntime::GraphViewer& graph_viewer(fused_nodes_and_graphs[i].filtered_graph);
-      const auto& ep_context_node = graph_viewer.Nodes().begin();
-      qnn_models.emplace(ep_context_node->Name(),
-                         std::make_unique<qnn::QnnModel>(logger, qnn_backend_manager_.get()));
-      NodeAttrHelper node_helper(*ep_context_node);
-      bool is_main_context = node_helper.Get(qnn::MAIN_CONTEXT, static_cast<int64_t>(0));
-      if (is_main_context) {
-        main_context_pos = i;
-      }
-    }

-    ORT_RETURN_IF(main_context_pos < 0, "Failed to find the EPContext node with main_context=1");
+    if (is_qnn_ctx_model) {
+      int main_context_pos = -1;
+      ORT_RETURN_IF_ERROR(qnn::GetMainContextNode(fused_nodes_and_graphs, qnn_backend_manager_.get(),
+                                                  logger, main_context_pos, qnn_models));

-    const onnxruntime::GraphViewer& main_ctx_graph_viewer(fused_nodes_and_graphs[main_context_pos].filtered_graph);
-    // Load and execute from cached context if exist
-    ORT_RETURN_IF_ERROR(qnn::LoadQnnCtxFromOnnxModel(main_ctx_graph_viewer,
-                                                     context_cache_path,
-                                                     is_qnn_ctx_model,
-                                                     is_ctx_file_exist,
-                                                     qnn_backend_manager_.get(),
-                                                     qnn_models,
-                                                     logger));
+      const onnxruntime::GraphViewer& main_ctx_graph_viewer(fused_nodes_and_graphs[main_context_pos].filtered_graph);
+      // Create the QNN context from the cached binary, deserialize the QNN graphs from the binary
+      ORT_RETURN_IF_ERROR(qnn::LoadQnnCtxFromOnnxGraph(main_ctx_graph_viewer,
+                                                       context_cache_path,
+                                                       qnn_backend_manager_.get(),
+                                                       qnn_models));
+    } else {
+      ORT_RETURN_IF_ERROR(qnn::LoadContextFromOnnxModel(fused_nodes_and_graphs, context_cache_path,
+                                                        qnn_backend_manager_.get(), logger, qnn_models));
+    }

     for (auto fused_node_and_graph : fused_nodes_and_graphs) {
       const onnxruntime::GraphViewer& graph_viewer(fused_node_and_graph.filtered_graph);
       const auto& ep_context_node = graph_viewer.Nodes().begin();
       const Node& fused_node = fused_node_and_graph.fused_node;
-      auto qnn_model = std::move(qnn_models[ep_context_node->Name()]);
+      const std::string& graph_meta_id = fused_node.Name();
+      std::string key = is_qnn_ctx_model ? ep_context_node->Name() : graph_meta_id;
+      ORT_RETURN_IF(qnn_models.find(key) == qnn_models.end(), key + " does not exist in the qnn_models table.");
+      auto qnn_model = std::move(qnn_models[key]);
       ORT_RETURN_IF_ERROR(qnn_model->SetGraphInputOutputInfo(graph_viewer, fused_node));
       ORT_RETURN_IF_ERROR(qnn_model->SetupQnnInputOutput());

       // fused node name is QNNExecutionProvider_QNN_[hash_id]_[id]
       // the name here must be the same as context->node_name in compute_info
-      const std::string& graph_meta_id = fused_node.Name();
       qnn_models_.emplace(graph_meta_id, std::move(qnn_model));

       ORT_RETURN_IF_ERROR(CreateComputeFunc(node_compute_funcs, logger));
     }

     return Status::OK();
   }

   ORT_RETURN_IF_ERROR(CompileFromOrtGraph(fused_nodes_and_graphs, node_compute_funcs, logger));
-  if (context_cache_enabled_ && !is_qnn_ctx_model) {
+  // Generate the QNN context model if it's a QDQ model + context_cache_enabled=true + not exist already
+  if (!is_qnn_ctx_model && context_cache_enabled_ && !is_ctx_file_exist) {
     // All partitioned graphs share a single QNN context, included in the same context binary
     uint64_t buffer_size(0);
     auto context_buffer = qnn_backend_manager_->GetContextBinaryBuffer(buffer_size);

From 8a00784019a07a3755e11731274133cb62ef12d2 Mon Sep 17 00:00:00 2001
From: Hector Li
Date: Mon, 18 Dec 2023 01:33:05 -0800
Subject: [PATCH 04/28] Remove QNN EP options: qnn_context_cache_enable,
 qnn_context_cache_path, qnn_context_embed_mode. Add session options
 accordingly.
---
 .../core/session/onnxruntime_c_api.h          |   6 -
 .../onnxruntime_session_options_config_keys.h |  18 +++
 .../core/framework/graph_partitioner.cc       | 103 ++++++++++------
 .../core/framework/graph_partitioner.h        |   3 +
 .../qnn/builder/onnx_ctx_model_helper.cc      |  12 +-
 .../providers/qnn/qnn_execution_provider.cc   |  30 ++---
 onnxruntime/core/session/inference_session.cc |  14 ++-
 onnxruntime/test/onnx/main.cc                 |  20 +++-
 .../test/perftest/command_args_parser.cc      |   2 -
 onnxruntime/test/perftest/ort_test_session.cc |   6 -
 10 files changed, 135 insertions(+), 79 deletions(-)

diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h
index c41700453a73b..dbd5ad41255fa 100644
--- a/include/onnxruntime/core/session/onnxruntime_c_api.h
+++ b/include/onnxruntime/core/session/onnxruntime_c_api.h
@@ -3593,17 +3593,11 @@ struct OrtApi {
    *
    * QNN supported keys:
    *   "backend_path": file path to QNN backend library.
-   *   "qnn_context_cache_enable": 1 to enable QNN graph creation from cached QNN context file. If it's enabled: QNN EP will
-   *     load from cached QNN context binary if it exist. It will generate a context binary file if it's not exist
-   *   "qnn_context_cache_path": explicitly provide the QNN context cache file. Default to model_file.onnx.bin if not provided.
    *   "profiling_level": QNN profiling level, options: "off", "basic", "detailed". Default to off.
    *   "rpc_control_latency": QNN RPC control latency.
   *   "vtcm_mb": QNN VTCM size in MB. Default to 0 (not set).
* "htp_performance_mode": QNN performance mode, options: "burst", "balanced", "default", "high_performance", * "high_power_saver", "low_balanced", "low_power_saver", "power_saver", "sustained_high_performance". Default to "default". - * "qnn_context_embed_mode", 1 means dump the QNN context binary into node attribute EPContext->ep_cache_context in the ONNX skeleton model. - * 0 means dump the QNN context binary into separate bin file and set the path to EPContext->ep_cache_context. - * The path is relative path to the ONNX skeleton model file. * "qnn_saver_path": File path to the QNN Saver backend library. If specified, QNN Saver will be enabled and will * dump QNN API calls to disk for replay/debugging. QNN Saver produces incorrect model inference results and * may alter model/EP partitioning. Use only for debugging. diff --git a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h index a94973b2cc5d7..c0f503ea02821 100644 --- a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h +++ b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h @@ -235,3 +235,21 @@ static const char* const kOrtSessionOptionsOptimizedModelExternalInitializersFil // Use this config to control the minimum size of the initializer when externalizing it during serialization static const char* const kOrtSessionOptionsOptimizedModelExternalInitializersMinSizeInBytes = "session.optimized_model_external_initializers_min_size_in_bytes"; + +// Enable EP context feature to dump the partitioned graph which include the EP context into Onnx file. +// The dumped Onnx model with EP context can be used for future inference to avoid the EP graph partitioning/compile overhead. +// "0": disable. (default) +// "1": enable. +static const char* const kOrtSessionOptionEpContextEnable = "ep.ep_context_enable"; + +// Specify the file path for the Onnx model which has EP context. +// Default to original_file_name_ctx.onnx if not specified +static const char* const kOrtSessionOptionEpContextFilePath = "ep.ep_context_file_path"; + +// Flag to specify whether to dump the EP context into the Onnx model. +// "0": dump the EP context into separate file, keep the file name in the Onnx model. +// "1": dump the EP context into the Onnx model. (default). +static const char* const kOrtSessionOptionEpContextEmbedMode = "ep.ep_context_embed_mode"; + +// Dump the model after graph partitioning to file "partitioned_graph.onnx". 
+static const char* const kDumpPartitionedGraph = "session.dump_partitioned_graph"; \ No newline at end of file diff --git a/onnxruntime/core/framework/graph_partitioner.cc b/onnxruntime/core/framework/graph_partitioner.cc index 55e3bf6e05a7c..f6f1a1e6aba93 100644 --- a/onnxruntime/core/framework/graph_partitioner.cc +++ b/onnxruntime/core/framework/graph_partitioner.cc @@ -16,6 +16,7 @@ #include "core/graph/function_utils.h" #include "core/graph/graph_viewer.h" #include "core/graph/model.h" +#include "core/session/onnxruntime_session_options_config_keys.h" // uncomment this line to count non-CUDA ops in ONNX domain // #define COUNT_NON_CUDA_OPS @@ -510,34 +511,6 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, FuncManager& func_mgr, ORT_RETURN_IF_ERROR(graph.Resolve()); } - const std::vector ep_context_nodes = current_ep.GetEpContextNodes(); - auto get_ep_context_node = [&ep_context_nodes](const std::string& node_name) -> std::pair { - for (auto& node : ep_context_nodes) { - if (node_name == node->Name()) { - return std::make_pair(true, node); - } - } - return std::make_pair(false, static_cast(nullptr)); - }; - - if (ep_context_nodes.size() > 0) { - Model ep_model(graph.Name(), false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), - graph.DomainToVersionMap(), {}, *current_ep.GetLogger()); - auto& ep_graph = ep_model.MainGraph(); - ep_graph.SetDescription(graph.Description()); - for (const auto& node : graph.Nodes()) { - // the fused node and EPContext node has same node name - auto ep_context_node = get_ep_context_node(node.Name()); - // Use EpContext node created by current EP if name matched, otherwise use original node - if (ep_context_node.first) { - ep_graph.AddNode(*ep_context_node.second); - } else { - ep_graph.AddNode(node); - } - } - ORT_RETURN_IF_ERROR(Model::Save(ep_model, "ep_partition.onnx")); - } - // For some cases, like fp16 on cpu, right now we don't have any kernel support that. // But we will insert cast op to run the model, so skip the error checking here. 
// If after graph transform phase, the node still not assigned, we will report error @@ -662,9 +635,68 @@ static Status InlineFunctionsAOTImpl(const ExecutionProviders& execution_provide return Status::OK(); } +static Status CreateEpContextModel(const ExecutionProviders& execution_providers, + const Graph& graph, + const std::string& ep_context_path, + const logging::Logger& logger) { + std::vector all_ep_context_nodes; + for (const auto& ep : execution_providers) { + const std::vector ep_context_nodes = ep->GetEpContextNodes(); + all_ep_context_nodes.insert(all_ep_context_nodes.begin(), ep_context_nodes.begin(), ep_context_nodes.end()); + } + + auto get_ep_context_node = [&all_ep_context_nodes](const std::string& node_name) -> std::pair { + for (auto& node : all_ep_context_nodes) { + if (node_name == node->Name()) { + return std::make_pair(true, node); + } + } + return std::make_pair(false, static_cast(nullptr)); + }; + + onnxruntime::PathString context_cache_path; + PathString model_pathstring = graph.ModelPath().ToPathString(); + if (all_ep_context_nodes.size() > 0) { + if (!ep_context_path.empty()) { + context_cache_path = ToPathString(ep_context_path); + } else if (!model_pathstring.empty()) { + context_cache_path = model_pathstring + ToPathString("_ctx.onnx"); + } + + bool file_exist = std::filesystem::is_regular_file(context_cache_path) && std::filesystem::exists(context_cache_path); + + if (file_exist) { + // The user needs to remove the existing file to re-generate it + LOGS(logger, INFO) << "Ep context file already exists."; + return Status::OK(); + } + + Model ep_context_model(graph.Name(), false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), + graph.DomainToVersionMap(), {}, logger); + auto& ep_graph = ep_context_model.MainGraph(); + ep_graph.SetDescription(graph.Description()); + for (const auto& node : graph.Nodes()) { + // the fused node and EPContext node has same node name + auto ep_context_node = get_ep_context_node(node.Name()); + // Use the EpContext node created by the EPs if the name matches, otherwise use the node from the original model + if (ep_context_node.first) { + ep_graph.AddNode(*ep_context_node.second); + } else { + ep_graph.AddNode(node); + } + } + ORT_RETURN_IF_ERROR(Model::Save(ep_context_model, context_cache_path)); + } + + return Status::OK(); +} + static Status PartitionOnnxFormatModel(const PartitionParams& partition_params, GraphPartitioner::Mode mode, const ExecutionProviders& execution_providers, - KernelRegistryManager& kernel_registry_manager) { + KernelRegistryManager& kernel_registry_manager, + bool ep_context_enabled, + std::string ep_context_path, + const logging::Logger& logger) { bool modified_graph = false; auto& graph = partition_params.graph.get(); @@ -682,6 +714,10 @@ static Status PartitionOnnxFormatModel(const PartitionParams& partition_params, partition_params.debug_graph_fn)); } + if (ep_context_enabled) { + ORT_RETURN_IF_ERROR(CreateEpContextModel(execution_providers, graph, ep_context_path, logger)); + } + // expand any nodes that have an ONNX function definition but no matching ORT kernel.
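The path handling above reduces to a small fallback rule; a self-contained sketch of the assumed semantics (helper names are hypothetical, not part of the ORT code):

#include <filesystem>
#include <string>

// Prefer the user-supplied ep_context_path; otherwise derive "<model path>_ctx.onnx".
// Generation is skipped when the target file already exists.
std::string ResolveEpContextPath(const std::string& ep_context_path, const std::string& model_path) {
  if (!ep_context_path.empty()) return ep_context_path;
  if (!model_path.empty()) return model_path + "_ctx.onnx";
  return {};
}

bool EpContextFileExists(const std::string& path) {
  return !path.empty() && std::filesystem::is_regular_file(path);
}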
modified_graph = false; ORT_RETURN_IF_ERROR(InlineNodes(graph, modified_graph)); @@ -868,6 +904,8 @@ Status GraphPartitioner::InlineFunctionsAOT(Model& model, Status GraphPartitioner::Partition(Graph& graph, FuncManager& func_mgr, const layout_transformation::TransformLayoutFunction& transform_layout_function, + const ConfigOptions& config_options, + const logging::Logger& logger, Mode mode, const layout_transformation::DebugGraphFn& debug_graph_fn) const { // It is a greedy partitioning algorithm per provider preferences user provided when calling ONNX RUNTIME right now. @@ -912,8 +950,11 @@ Status GraphPartitioner::Partition(Graph& graph, FuncManager& func_mgr, if (mode == Mode::kNormal || mode == Mode::kAssignOnly) { #if !defined(ORT_MINIMAL_BUILD) - ORT_RETURN_IF_ERROR(PartitionOnnxFormatModel(partition_params, mode, - providers_, kernel_registry_mgr_)); + bool ep_context_enabled = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextEnable, "0") == "1"; + std::string ep_context_path = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextFilePath, ""); + ORT_RETURN_IF_ERROR(PartitionOnnxFormatModel(partition_params, mode, providers_, + kernel_registry_mgr_, ep_context_enabled, + ep_context_path, logger)); #else return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "ONNX models are not supported in this build."); #endif //! defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/core/framework/graph_partitioner.h b/onnxruntime/core/framework/graph_partitioner.h index 4fc85c2588260..d1ef193cf1520 100644 --- a/onnxruntime/core/framework/graph_partitioner.h +++ b/onnxruntime/core/framework/graph_partitioner.h @@ -13,6 +13,7 @@ namespace onnxruntime { class ExecutionProviders; class KernelRegistryManager; class Model; +struct ConfigOptions; class GraphPartitioner { public: @@ -31,6 +32,8 @@ class GraphPartitioner { // Run partitioning. 
Status Partition(Graph& graph, FuncManager& func_mgr, const layout_transformation::TransformLayoutFunction& transform_layout_function, + const ConfigOptions& config_options, + const logging::Logger& logger, Mode mode = Mode::kNormal, const layout_transformation::DebugGraphFn& debug_graph_fn = {}) const; diff --git a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc index a9075cb984734..08f862d2b4dcb 100644 --- a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc +++ b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc @@ -57,10 +57,10 @@ Status GetMainContextNode(const std::vector& fused_nodes_and_graphs, - const onnxruntime::PathString& ctx_onnx_model_path, - QnnBackendManager* qnn_backend_manager, - const logging::Logger& logger, - std::unordered_map>& qnn_models) { + const onnxruntime::PathString& ctx_onnx_model_path, + QnnBackendManager* qnn_backend_manager, + const logging::Logger& logger, + std::unordered_map>& qnn_models) { for (const auto& fused_node_and_graph : fused_nodes_and_graphs) { const Node& fused_node = fused_node_and_graph.fused_node; qnn_models.emplace(fused_node.Name(), @@ -204,7 +204,7 @@ bool IsContextCacheFileExists(const std::string& customer_context_cache_path, if (!customer_context_cache_path.empty()) { context_cache_path = ToPathString(customer_context_cache_path); } else if (!model_pathstring.empty()) { - context_cache_path = model_pathstring + ToPathString("_qnn_ctx.onnx"); + context_cache_path = model_pathstring + ToPathString("_ctx.onnx"); } return std::filesystem::is_regular_file(context_cache_path) && std::filesystem::exists(context_cache_path); @@ -305,7 +305,7 @@ Status GenerateCtxCacheOnnxModel(Model* model, nullptr, kMSDomain); - // Only dump the context buffer once since all QNN graph are in one single context + // Only dump the context buffer once since all QNN graphs are in one single context if (0 == index) { if (qnn_context_embed_mode) { std::string cache_payload(buffer, buffer + buffer_size); diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc index 30f3fddf59019..7acec57f51db5 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc @@ -114,29 +114,17 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio if (session_options) { disable_cpu_ep_fallback_ = session_options->config_options.GetConfigOrDefault( kOrtSessionOptionsDisableCPUEPFallback, "0") == "1"; - } - - static const std::string CONTEXT_CACHE_ENABLED = "qnn_context_cache_enable"; - auto context_cache_enabled_pos = provider_options_map.find(CONTEXT_CACHE_ENABLED); - if (context_cache_enabled_pos != provider_options_map.end()) { - if (context_cache_enabled_pos->second == "1") { - context_cache_enabled_ = true; - LOGS_DEFAULT(VERBOSE) << "Context cache enabled."; - } - } - static const std::string CONTEXT_CACHE_PATH = "qnn_context_cache_path"; - auto context_cache_path_pos = provider_options_map.find(CONTEXT_CACHE_PATH); - if (context_cache_path_pos != provider_options_map.end()) { - context_cache_path_cfg_ = context_cache_path_pos->second; - LOGS_DEFAULT(VERBOSE) << "User specified context cache path: " << context_cache_path_cfg_; - } + context_cache_enabled_ = session_options->config_options.GetConfigOrDefault( + kOrtSessionOptionEpContextEnable, "0") == "1"; + LOGS_DEFAULT(VERBOSE) << "Context cache 
enable: " << context_cache_enabled_; - static const std::string CONTEXT_CACHE_EMBED_MODE = "qnn_context_embed_mode"; - auto context_cache_embed_mode_pos = provider_options_map.find(CONTEXT_CACHE_EMBED_MODE); - if (context_cache_embed_mode_pos != provider_options_map.end()) { - qnn_context_embed_mode_ = context_cache_embed_mode_pos->second == "1"; + qnn_context_embed_mode_ = session_options->config_options.GetConfigOrDefault( + kOrtSessionOptionEpContextEmbedMode, "1") == "1"; LOGS_DEFAULT(VERBOSE) << "User specified context cache embed mode: " << qnn_context_embed_mode_; + + context_cache_path_cfg_ = session_options->config_options.GetConfigOrDefault(kOrtSessionOptionEpContextFilePath, ""); + LOGS_DEFAULT(VERBOSE) << "User specified context cache path: " << context_cache_path_cfg_; } static const std::string BACKEND_PATH = "backend_path"; @@ -557,7 +545,7 @@ Status QNNExecutionProvider::Compile(const std::vector& fused bool is_qnn_ctx_model = qnn::IsFusedGraphHasCtxNode(fused_nodes_and_graphs); onnxruntime::PathString context_cache_path; - bool is_ctx_file_exist = false; + bool is_ctx_file_exist = false; if (!is_qnn_ctx_model) { const onnxruntime::GraphViewer& graph_viewer_0(fused_nodes_and_graphs[0].filtered_graph); is_ctx_file_exist = qnn::IsContextCacheFileExists(context_cache_path_cfg_, diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index fcc33a75ce9a0..fda67b3685bd9 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -1166,6 +1166,7 @@ common::Status InferenceSession::TransformGraph(onnxruntime::Graph& graph, bool // Do partitioning based on execution providers' capabilities. ORT_RETURN_IF_ERROR_SESSIONID_(partitioner.Partition(graph, session_state_->GetMutableFuncMgr(), transform_layout_fn, + session_options_.config_options, *session_logger_, mode, debug_graph_fn)); // apply Level2 and higher transformers. @@ -1198,7 +1199,10 @@ common::Status InferenceSession::TransformGraph(onnxruntime::Graph& graph, bool ORT_RETURN_IF_ERROR_SESSIONID_(apply_transformer_once(copy_transformer, *session_logger_, graph)); } - ORT_RETURN_IF_ERROR_SESSIONID_(Model::Save(*model_, "partitioned_graph.onnx")); + bool dump_partitioned_graph = session_options_.config_options.GetConfigOrDefault(kDumpPartitionedGraph, "0") == "1"; + if (dump_partitioned_graph) { + ORT_RETURN_IF_ERROR_SESSIONID_(Model::Save(*model_, "partitioned_graph.onnx")); + } #ifdef ENABLE_TRAINING // Enable memory optimizations (mainly insert recomputation nodes with priority). 
@@ -1462,7 +1466,9 @@ namespace { Status PartitionOrtFormatModel(onnxruntime::Graph& graph, const ExecutionProviders& providers, KernelRegistryManager& kernel_registry_manager, - SessionState& session_state) { + SessionState& session_state, + const ConfigOptions& config_options, + const logging::Logger& logger) { layout_transformation::TransformLayoutFunction transform_layout_fn = nullptr; #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) @@ -1483,6 +1489,8 @@ Status PartitionOrtFormatModel(onnxruntime::Graph& graph, ORT_RETURN_IF_ERROR(partitioner.Partition(graph, session_state.GetMutableFuncMgr(), transform_layout_fn, + config_options, + logger, GraphPartitioner::Mode::kOrtFormatLoad)); return Status::OK(); @@ -1836,7 +1844,7 @@ common::Status InferenceSession::Initialize() { #endif // !defined(ORT_MINIMAL_BUILD) } else { ORT_RETURN_IF_ERROR_SESSIONID_(PartitionOrtFormatModel(graph, execution_providers_, kernel_registry_manager_, - *session_state_)); + *session_state_, session_options_.config_options, *session_logger_)); #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) const auto& cpu_ep = *execution_providers_.Get(onnxruntime::kCpuExecutionProvider); diff --git a/onnxruntime/test/onnx/main.cc b/onnxruntime/test/onnx/main.cc index 646ff7c95b229..cc0c9c69754fe 100644 --- a/onnxruntime/test/onnx/main.cc +++ b/onnxruntime/test/onnx/main.cc @@ -50,15 +50,12 @@ void usage() { "\t-a: Specify custom absolute tolerance values for output value comparison. default: 1e-5\n" "\t-i: Specify EP specific runtime options as key value pairs. Different runtime options available are: \n" "\t [QNN only] [backend_path]: QNN backend path. e.g '/folderpath/libQnnHtp.so', '/folderpath/libQnnCpu.so'.\n" - "\t [QNN only] [qnn_context_cache_enable]: 1 to enable cache QNN context. Default to false.\n" - "\t [QNN only] [qnn_context_cache_path]: File path to the qnn context cache. Default to model_file.onnx.bin if not set.\n" "\t [QNN only] [profiling_level]: QNN profiling level, options: 'basic', 'detailed', default 'off'.\n" "\t [QNN only] [rpc_control_latency]: QNN rpc control latency. default to 10.\n" "\t [QNN only] [vtcm_mb]: QNN VTCM size in MB. default to 0(not set).\n" "\t [QNN only] [htp_performance_mode]: QNN performance mode, options: 'burst', 'balanced', 'default', 'high_performance', \n" "\t 'high_power_saver', 'low_balanced', 'low_power_saver', 'power_saver', 'sustained_high_performance'. Default to 'default'. \n" "\t [QNN only] [qnn_context_priority]: QNN context priority, options: 'low', 'normal', 'normal_high', 'high'. Default to 'normal'. \n" - "\t [QNN only] [qnn_context_embed_mode]: 1 means dump the QNN context binary into the Onnx skeleton model.\n" "\t 0 means dump the QNN context binary into separate bin file and set the path in the Onnx skeleton model.\n" "\t [QNN only] [qnn_saver_path]: QNN Saver backend path. e.g '/folderpath/libQnnSaver.so'.\n" "\t [QNN only] [htp_graph_finalization_optimization_mode]: QNN graph finalization optimization mode, options: \n" @@ -73,6 +70,8 @@ void usage() { "\t [Example] [For SNPE EP] -e snpe -i \"runtime|CPU priority|low\" \n\n" "\t-o [optimization level]: Default is 99. Valid values are 0 (disable), 1 (basic), 2 (extended), 99 (all).\n" "\t\tPlease see onnxruntime_c_api.h (enum GraphOptimizationLevel) for the full list of all optimization levels. 
" + "\t-f: Enable EP context cache generation.\n" + "\t-b: Disable EP context embed mode.\n" "\n" "\t-h: help\n" "\n" @@ -179,11 +178,13 @@ int real_main(int argc, char* argv[], Ort::Env& env) { OrtLoggingLevel logging_level = ORT_LOGGING_LEVEL_ERROR; bool verbose_logging_required = false; + bool ep_context_enable = false; + bool disable_ep_context_embed_mode = false; bool pause = false; { int ch; - while ((ch = getopt(argc, argv, ORT_TSTR("Ac:hj:Mn:r:e:t:a:xvo:d:i:pz"))) != -1) { + while ((ch = getopt(argc, argv, ORT_TSTR("Ac:hj:Mn:r:e:t:a:xvo:d:i:pzfb"))) != -1) { switch (ch) { case 'A': enable_cpu_mem_arena = false; @@ -312,6 +313,12 @@ int real_main(int argc, char* argv[], Ort::Env& env) { case 'z': set_denormal_as_zero = true; break; + case 'b': + disable_ep_context_embed_mode = true; + break; + case 'f': + ep_context_enable = true; + break; case '?': case 'h': default: @@ -386,6 +393,11 @@ int real_main(int argc, char* argv[], Ort::Env& env) { if (set_denormal_as_zero) sf.AddConfigEntry(kOrtSessionOptionsConfigSetDenormalAsZero, "1"); + if (ep_context_enable) + sf.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"); + if (disable_ep_context_embed_mode) + sf.AddConfigEntry(kOrtSessionOptionEpContextEmbedMode, "0"); + if (enable_tensorrt) { #ifdef USE_TENSORRT OrtCUDAProviderOptions cuda_options; diff --git a/onnxruntime/test/perftest/command_args_parser.cc b/onnxruntime/test/perftest/command_args_parser.cc index 27e26fe0b3c45..6e3252aaeb4b8 100644 --- a/onnxruntime/test/perftest/command_args_parser.cc +++ b/onnxruntime/test/perftest/command_args_parser.cc @@ -65,8 +65,6 @@ namespace perftest { "\t [OpenVINO only] [cache_dir]: Explicitly specify the path to dump and load the blobs(Model caching) or cl_cache (Kernel Caching) files feature. If blob files are already present, it will be directly loaded.\n" "\t [OpenVINO only] [enable_opencl_throttling]: Enables OpenCL queue throttling for GPU device(Reduces the CPU Utilization while using GPU) \n" "\t [QNN only] [backend_path]: QNN backend path. e.g '/folderpath/libQnnHtp.so', '/folderpath/libQnnCpu.so'.\n" - "\t [QNN only] [qnn_context_cache_enable]: 1 to enable cache QNN context. Default to false.\n" - "\t [QNN only] [qnn_context_cache_path]: File path to the qnn context cache. Default to model_file.onnx.bin if not set.\n" "\t [QNN only] [profiling_level]: QNN profiling level, options: 'basic', 'detailed', default 'off'.\n" "\t [QNN only] [rpc_control_latency]: QNN rpc control latency. default to 10.\n" "\t [QNN only] [vtcm_mb]: QNN VTCM size in MB. 
default to 0(not set).\n" diff --git a/onnxruntime/test/perftest/ort_test_session.cc b/onnxruntime/test/perftest/ort_test_session.cc index eb2a77c07f803..abec16f787895 100644 --- a/onnxruntime/test/perftest/ort_test_session.cc +++ b/onnxruntime/test/perftest/ort_test_session.cc @@ -332,12 +332,6 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device if (value.empty()) { ORT_THROW("Please provide the QNN backend path."); } - } else if (key == "qnn_context_cache_enable") { - if (value != "1") { - ORT_THROW("Set to 1 to enable qnn_context_cache_enable."); - } - } else if (key == "qnn_context_cache_path") { - // no validation } else if (key == "profiling_level") { std::set supported_profiling_level = {"off", "basic", "detailed"}; if (supported_profiling_level.find(value) == supported_profiling_level.end()) { From 16058a8f185f15613922bfeaa5dae6a8fd1d3674 Mon Sep 17 00:00:00 2001 From: Hector Li Date: Mon, 18 Dec 2023 09:12:24 -0800 Subject: [PATCH 05/28] update test code --- onnxruntime/test/framework/session_state_test.cc | 12 +++++++++--- onnxruntime/test/onnx/main.cc | 10 ++-------- onnxruntime/test/perftest/ort_test_session.cc | 4 ++-- 3 files changed, 13 insertions(+), 13 deletions(-) diff --git a/onnxruntime/test/framework/session_state_test.cc b/onnxruntime/test/framework/session_state_test.cc index e1ce1d4abf81d..a6aa57dc4a81a 100644 --- a/onnxruntime/test/framework/session_state_test.cc +++ b/onnxruntime/test/framework/session_state_test.cc @@ -176,7 +176,9 @@ TEST_P(SessionStateTestP, TestInitializerProcessing) { AllocatorPtr cpu_allocator = std::make_shared(); return layout_transformation::TransformLayoutForEP( graph, modified, execution_provider, std::move(cpu_allocator), debug_graph_fn); - })); + }, + sess_options.config_options, + DefaultLoggingManager().DefaultLogger())); ASSERT_STATUS_OK(session_state.FinalizeSessionState(oss.str(), krm)); @@ -256,7 +258,9 @@ TEST(SessionStateTest, TestInitializerMemoryAllocatedUsingNonArenaMemory) { const layout_transformation::DebugGraphFn& debug_graph_fn) -> Status { return layout_transformation::TransformLayoutForEP(graph, modified, execution_provider, cpu_allocator, debug_graph_fn); - })); + }, + sess_options.config_options, + DefaultLoggingManager().DefaultLogger())); ASSERT_STATUS_OK(session_state.FinalizeSessionState(oss.str(), krm)); @@ -313,7 +317,9 @@ TEST(SessionStateTest, TestInitializerMemoryAllocatedUsingNonArenaMemory) { const layout_transformation::DebugGraphFn& debug_graph_fn) -> Status { return layout_transformation::TransformLayoutForEP( graph, modified, execution_provider, cpu_allocator, debug_graph_fn); - })); + }, + sess_options.config_options, + DefaultLoggingManager().DefaultLogger())); // Finalize the session state ASSERT_STATUS_OK(session_state.FinalizeSessionState(oss.str(), krm)); diff --git a/onnxruntime/test/onnx/main.cc b/onnxruntime/test/onnx/main.cc index cc0c9c69754fe..51edb91b5d3af 100644 --- a/onnxruntime/test/onnx/main.cc +++ b/onnxruntime/test/onnx/main.cc @@ -478,12 +478,6 @@ int real_main(int argc, char* argv[], Ort::Env& env) { if (value != "0") { ORT_THROW("Set to 0 to disable qnn_context_embed_mode."); } - } else if (key == "qnn_context_cache_enable") { - if (value != "1") { - ORT_THROW("Set to 1 to enable qnn_context_cache_enable."); - } - } else if (key == "qnn_context_cache_path") { - // no validation } else if (key == "profiling_level") { std::set supported_profiling_level = {"off", "basic", "detailed"}; if (supported_profiling_level.find(value) == 
supported_profiling_level.end()) { @@ -519,8 +513,8 @@ int real_main(int argc, char* argv[], Ort::Env& env) { ORT_THROW("Wrong value for htp_graph_finalization_optimization_mode. select from: " + str); } } else { - ORT_THROW(R"(Wrong key type entered. Choose from options: ['backend_path', 'qnn_context_cache_enable', -'qnn_context_cache_path', 'profiling_level', 'rpc_control_latency', 'vtcm_mb', 'htp_performance_mode', + ORT_THROW(R"(Wrong key type entered. Choose from options: ['backend_path', +'profiling_level', 'rpc_control_latency', 'vtcm_mb', 'htp_performance_mode', 'qnn_saver_path', 'htp_graph_finalization_optimization_mode', 'qnn_context_priority'])"); } diff --git a/onnxruntime/test/perftest/ort_test_session.cc b/onnxruntime/test/perftest/ort_test_session.cc index abec16f787895..640bfeae84700 100644 --- a/onnxruntime/test/perftest/ort_test_session.cc +++ b/onnxruntime/test/perftest/ort_test_session.cc @@ -367,8 +367,8 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device ORT_THROW("Supported qnn_context_priority: low, normal, normal_high, high"); } } else { - ORT_THROW(R"(Wrong key type entered. Choose from options: ['backend_path', 'qnn_context_cache_enable', -'qnn_context_cache_path', 'profiling_level', 'rpc_control_latency', 'vtcm_mb', 'htp_performance_mode', + ORT_THROW(R"(Wrong key type entered. Choose from options: ['backend_path', +'profiling_level', 'rpc_control_latency', 'vtcm_mb', 'htp_performance_mode', 'qnn_saver_path', 'htp_graph_finalization_optimization_mode', 'qnn_context_priority'])"); } From aacab16d18ac194fcae8b159f122922411b4a8b8 Mon Sep 17 00:00:00 2001 From: Hector Li Date: Tue, 19 Dec 2023 00:19:54 -0800 Subject: [PATCH 06/28] Update test code to reflect the changes which move provider options to session options --- .../onnxruntime_session_options_config_keys.h | 14 ++-- .../qnn/builder/onnx_ctx_model_helper.cc | 4 +- onnxruntime/core/session/inference_session.cc | 2 +- .../test/framework/session_state_test.cc | 19 +++--- .../test/providers/qnn/qnn_basic_test.cc | 37 +++++++--- .../test/providers/qnn/qnn_test_utils.cc | 20 ++++-- .../test/providers/qnn/qnn_test_utils.h | 28 +++++--- .../test/providers/qnn/simple_op_htp_test.cc | 68 ++++++++++++++----- onnxruntime/test/util/default_providers.cc | 5 +- .../test/util/include/default_providers.h | 3 +- 10 files changed, 137 insertions(+), 63 deletions(-) diff --git a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h index c0f503ea02821..74865e8ed54d8 100644 --- a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h +++ b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h @@ -240,16 +240,20 @@ static const char* const kOrtSessionOptionsOptimizedModelExternalInitializersMin // The dumped Onnx model with EP context can be used for future inference to avoid the EP graph partitioning/compile overhead. // "0": disable. (default) // "1": enable. -static const char* const kOrtSessionOptionEpContextEnable = "ep.ep_context_enable"; +static const char* const kOrtSessionOptionEpContextEnable = "ep.context_enable"; // Specify the file path for the Onnx model which has EP context. 
// Default to original_file_name_ctx.onnx if not specified -static const char* const kOrtSessionOptionEpContextFilePath = "ep.ep_context_file_path"; +static const char* const kOrtSessionOptionEpContextFilePath = "ep.context_file_path"; // Flag to specify whether to dump the EP context into the Onnx model. // "0": dump the EP context into separate file, keep the file name in the Onnx model. // "1": dump the EP context into the Onnx model. (default). -static const char* const kOrtSessionOptionEpContextEmbedMode = "ep.ep_context_embed_mode"; +static const char* const kOrtSessionOptionEpContextEmbedMode = "ep.context_embed_mode"; -// Dump the model after graph partitioning to file "partitioned_graph.onnx". -static const char* const kDumpPartitionedGraph = "session.dump_partitioned_graph"; \ No newline at end of file +// This option will dump out the model to assist debugging any issues with graph partitioning, +// and is primarily intended for developer usage. +// +// Default is off. Set to "1" to enable. +// The model will be saved to filename "partitioned_graph.onnx". +static const char* const kDebugGraphPartitioning = "session.debug_graph_partitioning"; \ No newline at end of file diff --git a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc index 08f862d2b4dcb..774c8c9e821f9 100644 --- a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc +++ b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc @@ -39,7 +39,7 @@ Status GetMainContextNode(const std::vector>& qnn_models) { main_context_pos = -1; - for (int i = 0; i < fused_nodes_and_graphs.size(); ++i) { + for (size_t i = 0; i < fused_nodes_and_graphs.size(); ++i) { const onnxruntime::GraphViewer& graph_viewer(fused_nodes_and_graphs[i].filtered_graph); const auto& ep_context_node = graph_viewer.Nodes().begin(); ORT_RETURN_IF_NOT(EPCONTEXT_OP == ep_context_node->OpType(), "Should only filter in the EPContext node."); @@ -48,7 +48,7 @@ Status GetMainContextNode(const std::vector(0)); if (1 == is_main_context) { - main_context_pos = i; + main_context_pos = static_cast(i); } } diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index fda67b3685bd9..1a26d689f3921 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -1199,7 +1199,7 @@ common::Status InferenceSession::TransformGraph(onnxruntime::Graph& graph, bool ORT_RETURN_IF_ERROR_SESSIONID_(apply_transformer_once(copy_transformer, *session_logger_, graph)); } - bool dump_partitioned_graph = session_options_.config_options.GetConfigOrDefault(kDumpPartitionedGraph, "0") == "1"; + bool dump_partitioned_graph = session_options_.config_options.GetConfigOrDefault(kDebugGraphPartitioning, "0") == "1"; if (dump_partitioned_graph) { ORT_RETURN_IF_ERROR_SESSIONID_(Model::Save(*model_, "partitioned_graph.onnx")); } diff --git a/onnxruntime/test/framework/session_state_test.cc b/onnxruntime/test/framework/session_state_test.cc index a6aa57dc4a81a..6c1681e627f98 100644 --- a/onnxruntime/test/framework/session_state_test.cc +++ b/onnxruntime/test/framework/session_state_test.cc @@ -170,15 +170,16 @@ TEST_P(SessionStateTestP, TestInitializerProcessing) { GraphPartitioner partitioner(krm, execution_providers); ASSERT_STATUS_OK( - partitioner.Partition(graph, session_state.GetMutableFuncMgr(), - [](Graph& graph, bool& modified, const IExecutionProvider& execution_provider, - const 
layout_transformation::DebugGraphFn& debug_graph_fn) -> Status { - AllocatorPtr cpu_allocator = std::make_shared(); - return layout_transformation::TransformLayoutForEP( - graph, modified, execution_provider, std::move(cpu_allocator), debug_graph_fn); - }, - sess_options.config_options, - DefaultLoggingManager().DefaultLogger())); + partitioner.Partition( + graph, session_state.GetMutableFuncMgr(), + [](Graph& graph, bool& modified, const IExecutionProvider& execution_provider, + const layout_transformation::DebugGraphFn& debug_graph_fn) -> Status { + AllocatorPtr cpu_allocator = std::make_shared(); + return layout_transformation::TransformLayoutForEP( + graph, modified, execution_provider, std::move(cpu_allocator), debug_graph_fn); + }, + sess_options.config_options, + DefaultLoggingManager().DefaultLogger())); ASSERT_STATUS_OK(session_state.FinalizeSessionState(oss.str(), krm)); diff --git a/onnxruntime/test/providers/qnn/qnn_basic_test.cc b/onnxruntime/test/providers/qnn/qnn_basic_test.cc index e30c79eca3a13..391d7bebc9589 100644 --- a/onnxruntime/test/providers/qnn/qnn_basic_test.cc +++ b/onnxruntime/test/providers/qnn/qnn_basic_test.cc @@ -375,17 +375,36 @@ TEST_F(QnnHTPBackendTests, QnnContextBinaryGeneration2InputTypes) { #else provider_options["backend_path"] = "libQnnHtp.so"; #endif - provider_options["qnn_context_cache_enable"] = "1"; + + // Add kMSDomain to cover contrib op like Gelu + const std::unordered_map domain_to_version = {{"", 13}, {kMSDomain, 1}}; + + auto& logging_manager = DefaultLoggingManager(); + logging_manager.SetDefaultLoggerSeverity(logging::Severity::kERROR); + + onnxruntime::Model model("QNN_EP_TestModel", false, ModelMetaData(), PathString(), + IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, {}, + logging_manager.DefaultLogger()); + Graph& graph = model.MainGraph(); + ModelTestBuilder helper(graph); + BuildCastAddTestCase()(helper); + helper.SetGraphOutputs(); + ASSERT_STATUS_OK(model.MainGraph().Resolve()); + + // Serialize the model to a string. 
+ std::string model_data; + model.ToProto().SerializeToString(&model_data); + + const auto model_data_span = AsByteSpan(model_data.data(), model_data.size()); + const std::string context_binary_file = "./qnn_context_binary_int32_fp32_inputs_test.onnx"; - provider_options["qnn_context_cache_path"] = context_binary_file; + Ort::SessionOptions so; + so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"); + so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, context_binary_file.c_str()); - RunQnnModelTest(BuildCastAddTestCase(), - provider_options, - 13, // opset - ExpectedEPNodeAssignment::All, - 1e-5f, - logging::Severity::kERROR, - false); + so.AppendExecutionProvider("QNN", provider_options); + + Ort::Session session(*ort_env, model_data_span.data(), model_data_span.size(), so); // Make sure the Qnn context cache binary file is generated EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str())); diff --git a/onnxruntime/test/providers/qnn/qnn_test_utils.cc b/onnxruntime/test/providers/qnn/qnn_test_utils.cc index 4c38109d30371..f5ebe45a07912 100644 --- a/onnxruntime/test/providers/qnn/qnn_test_utils.cc +++ b/onnxruntime/test/providers/qnn/qnn_test_utils.cc @@ -13,6 +13,7 @@ #include "core/common/span_utils.h" #include "core/framework/compute_capability.h" #include "core/graph/graph.h" +#include "core/session/onnxruntime_session_options_config_keys.h" namespace onnxruntime { namespace test { @@ -106,24 +107,31 @@ void RunQnnModelTest(const GetTestModelFn& build_test_case, ProviderOptions prov TryEnableQNNSaver(provider_options); RunAndVerifyOutputsWithEP(AsByteSpan(model_data.data(), model_data.size()), "QNN_EP_TestLogID", QnnExecutionProviderWithOptions(provider_options), - helper.feeds_, verification_params, {}, verify_outputs); + helper.feeds_, verification_params, + {}, verify_outputs); } void InferenceModel(const std::string& model_data, const char* log_id, - std::unique_ptr execution_provider, + const ProviderOptions& provider_options, ExpectedEPNodeAssignment expected_ep_assignment, const NameMLValMap& feeds, - std::vector& output_vals) { + std::vector& output_vals, + bool is_qnn_ep, + const std::unordered_map& session_option_pairs) { SessionOptions so; so.session_logid = log_id; + for (auto key_value : session_option_pairs) { + ASSERT_STATUS_OK(so.config_options.AddConfigEntry(key_value.first.c_str(), key_value.second.c_str())); + } RunOptions run_options; run_options.run_tag = so.session_logid; InferenceSessionWrapper session_object{so, GetEnvironment()}; std::string provider_type = kCpuExecutionProvider; - if (execution_provider) { - provider_type = execution_provider->Type(); - ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(std::move(execution_provider))); + if (is_qnn_ep) { + auto qnn_ep = QnnExecutionProviderWithOptions(provider_options, &so); + provider_type = qnn_ep->Type(); + ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(std::move(qnn_ep))); } ASSERT_STATUS_OK(session_object.Load(model_data.data(), static_cast(model_data.size()))); ASSERT_STATUS_OK(session_object.Initialize()); diff --git a/onnxruntime/test/providers/qnn/qnn_test_utils.h b/onnxruntime/test/providers/qnn/qnn_test_utils.h index 9ec0985e8130c..bfe5bab318313 100644 --- a/onnxruntime/test/providers/qnn/qnn_test_utils.h +++ b/onnxruntime/test/providers/qnn/qnn_test_utils.h @@ -220,15 +220,19 @@ inline QuantParams GetTestInputQuantParams(const TestInputDef& inp * * \param model_data The serialized ONNX model to inference. * \param log_id The logger ID. 
- * \param execution_provider The EP on which to run the model. Set to nullptr for CPU EP. + * \param provider_options provider options as key-value pairs. * \param expected_ep_assignment Describes "which nodes" should be assigned to the EP. * \param feeds The input feeds. * \param output_vals Initialized to the inference results. + * \param is_qnn_ep True: QNN EP is used. False: CPU EP is used (default). + * \param session_option_pairs extra session options. */ void InferenceModel(const std::string& model_data, const char* log_id, - std::unique_ptr execution_provider, + const ProviderOptions& provider_options, ExpectedEPNodeAssignment expected_ep_assignment, const NameMLValMap& feeds, - std::vector& output_vals); + std::vector& output_vals, + bool is_qnn_ep = false, + const std::unordered_map& session_option_pairs = {}); /** * If the ORT_UNIT_TEST_ENABLE_QNN_SAVER environment variable is enabled (set to 1), this function modifies @@ -287,7 +291,8 @@ inline void TestQDQModelAccuracy(const GetTestModelFn& f32_model_fn, const GetTe ExpectedEPNodeAssignment expected_ep_assignment, QDQTolerance tolerance = QDQTolerance(), logging::Severity log_severity = logging::Severity::kERROR, - const std::string& qnn_ctx_model_path = "") { + const std::string& qnn_ctx_model_path = "", + const std::unordered_map& session_option_pairs = {}) { // Add kMSDomain to cover contrib op like Gelu const std::unordered_map domain_to_version = {{"", opset_version}, {kMSDomain, 1}}; @@ -307,7 +312,7 @@ inline void TestQDQModelAccuracy(const GetTestModelFn& f32_model_fn, const GetTe // Run f32 model on CPU EP and collect outputs. std::vector cpu_f32_outputs; - InferenceModel(f32_model_data, "f32_model_logger", nullptr, ExpectedEPNodeAssignment::All, + InferenceModel(f32_model_data, "f32_model_logger", {}, ExpectedEPNodeAssignment::All, f32_helper.feeds_, cpu_f32_outputs); ASSERT_FALSE(cpu_f32_outputs.empty()); @@ -344,7 +349,7 @@ inline void TestQDQModelAccuracy(const GetTestModelFn& f32_model_fn, const GetTe ASSERT_STATUS_OK(qdq_model.MainGraph().Resolve()); qdq_model.ToProto().SerializeToString(&qdq_model_data); - // Run QDQ model on QNN EP and collect outputs. + bool is_qnn_ep = true; TryEnableQNNSaver(qnn_options); std::vector qnn_qdq_outputs; if (!qnn_ctx_model_path.empty()) { @@ -355,18 +360,19 @@ inline void TestQDQModelAccuracy(const GetTestModelFn& f32_model_fn, const GetTe std::string qnn_ctx_model_data; model_proto.SerializeToString(&qnn_ctx_model_data); // Run QNN context cache model on QNN EP and collect outputs. - InferenceModel(qnn_ctx_model_data, "qnn_ctx_model_logger", QnnExecutionProviderWithOptions(qnn_options), - expected_ep_assignment, qdq_helper.feeds_, qnn_qdq_outputs); + InferenceModel(qnn_ctx_model_data, "qnn_ctx_model_logger", qnn_options, + expected_ep_assignment, qdq_helper.feeds_, qnn_qdq_outputs, is_qnn_ep); } else { // Run QDQ model on QNN EP and collect outputs. - InferenceModel(qdq_model_data, "qdq_model_logger", QnnExecutionProviderWithOptions(qnn_options), - expected_ep_assignment, qdq_helper.feeds_, qnn_qdq_outputs); + // Only need to apply the extra session options to this QDQ model inference on QNN EP + InferenceModel(qdq_model_data, "qdq_model_logger", qnn_options, expected_ep_assignment, + qdq_helper.feeds_, qnn_qdq_outputs, is_qnn_ep, session_option_pairs); } if (expected_ep_assignment != ExpectedEPNodeAssignment::None) { // Run QDQ model on CPU EP and collect outputs.
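A sketch of a call site under the revised helper signature (hedged: the variable names are hypothetical, and model_data, provider_options, and feeds are assumed to be in scope):

// Run a QDQ model on QNN EP and also ask the session to dump the EP-context model.
std::unordered_map<std::string, std::string> extra_session_options{
    {kOrtSessionOptionEpContextEnable, "1"},
    {kOrtSessionOptionEpContextFilePath, "./model_ctx.onnx"}};
std::vector<OrtValue> outputs;
InferenceModel(model_data, "qnn_model_logger", provider_options,
               ExpectedEPNodeAssignment::All, feeds, outputs,
               /*is_qnn_ep*/ true, extra_session_options);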
std::vector cpu_qdq_outputs; - InferenceModel(qdq_model_data, "qdq_model_logger", nullptr, ExpectedEPNodeAssignment::All, + InferenceModel(qdq_model_data, "qdq_model_logger", {}, ExpectedEPNodeAssignment::All, qdq_helper.feeds_, cpu_qdq_outputs); ASSERT_EQ(cpu_qdq_outputs.size(), num_outputs); ASSERT_EQ(qnn_qdq_outputs.size(), num_outputs); diff --git a/onnxruntime/test/providers/qnn/simple_op_htp_test.cc b/onnxruntime/test/providers/qnn/simple_op_htp_test.cc index 39733f50482a6..8ff65c08e8633 100644 --- a/onnxruntime/test/providers/qnn/simple_op_htp_test.cc +++ b/onnxruntime/test/providers/qnn/simple_op_htp_test.cc @@ -8,6 +8,7 @@ #include #include "core/graph/graph.h" #include "core/graph/node_attr_utils.h" +#include "core/session/onnxruntime_session_options_config_keys.h" #include "test/optimizer/qdq_test_utils.h" #include "test/providers/qnn/qnn_test_utils.h" @@ -733,9 +734,11 @@ TEST_F(QnnHTPBackendTests, ContextBinaryCacheEmbedModeTest) { #else provider_options["backend_path"] = "libQnnHtp.so"; #endif - provider_options["qnn_context_cache_enable"] = "1"; const std::string context_binary_file = "./qnn_context_binary_test.onnx"; - provider_options["qnn_context_cache_path"] = context_binary_file; + + std::unordered_map session_option_pairs; + session_option_pairs.emplace(kOrtSessionOptionEpContextEnable, "1"); + session_option_pairs.emplace(kOrtSessionOptionEpContextFilePath, context_binary_file); const TestInputDef input_def({1, 2, 3}, false, -10.0f, 10.0f); const std::string op_type = "Atan"; @@ -746,7 +749,11 @@ TEST_F(QnnHTPBackendTests, ContextBinaryCacheEmbedModeTest) { BuildQDQOpTestCase(op_type, {input_def}, {}, {}), provider_options, 14, - ExpectedEPNodeAssignment::All); + ExpectedEPNodeAssignment::All, + QDQTolerance(), + logging::Severity::kERROR, + "", // context model file path, not required for this inference + session_option_pairs); // Make sure the Qnn context cache binary file is generated EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str())); @@ -756,7 +763,11 @@ TEST_F(QnnHTPBackendTests, ContextBinaryCacheEmbedModeTest) { BuildQDQOpTestCase(op_type, {input_def}, {}, {}), provider_options, 14, - ExpectedEPNodeAssignment::All); + ExpectedEPNodeAssignment::All, + QDQTolerance(), + logging::Severity::kERROR, + "", // context model file path, not required for this inference + session_option_pairs); // 3rd run directly loads and run from Qnn context cache model TestQDQModelAccuracy(BuildOpTestCase(op_type, {input_def}, {}, {}), @@ -780,10 +791,11 @@ TEST_F(QnnHTPBackendTests, ContextBinaryCacheNonEmbedModeTest) { #else provider_options["backend_path"] = "libQnnHtp.so"; #endif - provider_options["qnn_context_cache_enable"] = "1"; const std::string context_binary_file = "./qnn_context_cache_non_embed.onnx"; - provider_options["qnn_context_cache_path"] = context_binary_file; - provider_options["qnn_context_embed_mode"] = "0"; + std::unordered_map session_option_pairs; + session_option_pairs.emplace(kOrtSessionOptionEpContextEnable, "1"); + session_option_pairs.emplace(kOrtSessionOptionEpContextFilePath, context_binary_file); + session_option_pairs.emplace(kOrtSessionOptionEpContextEmbedMode, "0"); const TestInputDef input_def({1, 2, 3}, false, -10.0f, 10.0f); const std::string op_type = "Atan"; @@ -794,7 +806,11 @@ TEST_F(QnnHTPBackendTests, ContextBinaryCacheNonEmbedModeTest) { BuildQDQOpTestCase(op_type, {input_def}, {}, {}), provider_options, 14, - ExpectedEPNodeAssignment::All); + ExpectedEPNodeAssignment::All, + QDQTolerance(), + logging::Severity::kERROR, + 
"", // context model file path, not required for this inference + session_option_pairs); // Check the Onnx skeleton file is generated EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str())); @@ -806,7 +822,11 @@ TEST_F(QnnHTPBackendTests, ContextBinaryCacheNonEmbedModeTest) { BuildQDQOpTestCase(op_type, {input_def}, {}, {}), provider_options, 14, - ExpectedEPNodeAssignment::All); + ExpectedEPNodeAssignment::All, + QDQTolerance(), + logging::Severity::kERROR, + "", // context model file path, not required for this inference + session_option_pairs); // 3rd run directly loads and run from Onnx skeleton file + Qnn context cache binary file TestQDQModelAccuracy(BuildOpTestCase(op_type, {input_def}, {}, {}), @@ -829,10 +849,11 @@ TEST_F(QnnHTPBackendTests, ContextBinaryCache_InvalidGraph) { #else provider_options["backend_path"] = "libQnnHtp.so"; #endif - provider_options["qnn_context_cache_enable"] = "1"; const std::string context_binary_file = "./qnn_context_cache_non_embed.onnx"; - provider_options["qnn_context_cache_path"] = context_binary_file; - provider_options["qnn_context_embed_mode"] = "0"; + std::unordered_map session_option_pairs; + session_option_pairs.emplace(kOrtSessionOptionEpContextEnable, "1"); + session_option_pairs.emplace(kOrtSessionOptionEpContextFilePath, context_binary_file); + session_option_pairs.emplace(kOrtSessionOptionEpContextEmbedMode, "0"); const TestInputDef input_def({1, 2, 3}, false, -10.0f, 10.0f); const std::string op_type = "Atan"; @@ -843,7 +864,11 @@ TEST_F(QnnHTPBackendTests, ContextBinaryCache_InvalidGraph) { BuildQDQOpTestCase(op_type, {input_def}, {}, {}), provider_options, 14, - ExpectedEPNodeAssignment::All); + ExpectedEPNodeAssignment::All, + QDQTolerance(), + logging::Severity::kERROR, + "", // context model file path, not required for this inference + session_option_pairs); // Check the Onnx skeleton file is generated EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str())); @@ -886,9 +911,10 @@ TEST_F(QnnHTPBackendTests, ContextBinary2InputsTest) { #else provider_options["backend_path"] = "libQnnHtp.so"; #endif - provider_options["qnn_context_cache_enable"] = "1"; const std::string context_binary_file = "./qnn_context_binary_2inputs_test.onnx"; - provider_options["qnn_context_cache_path"] = context_binary_file; + std::unordered_map session_option_pairs; + session_option_pairs.emplace(kOrtSessionOptionEpContextEnable, "1"); + session_option_pairs.emplace(kOrtSessionOptionEpContextFilePath, context_binary_file); const TestInputDef input_def1({1, 2, 3}, false, -10.0f, 10.0f); const TestInputDef input_def2({1, 2, 3}, false, -10.0f, 10.0f); @@ -900,7 +926,11 @@ TEST_F(QnnHTPBackendTests, ContextBinary2InputsTest) { BuildQDQOpTestCase(op_type, {input_def1, input_def2}, {}, {}), provider_options, 14, - ExpectedEPNodeAssignment::All); + ExpectedEPNodeAssignment::All, + QDQTolerance(), + logging::Severity::kERROR, + "", // context model file path, not required for this inference + session_option_pairs); // Make sure the Qnn context cache binary file is generated EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str())); @@ -910,7 +940,11 @@ TEST_F(QnnHTPBackendTests, ContextBinary2InputsTest) { BuildQDQOpTestCase(op_type, {input_def1, input_def2}, {}, {}), provider_options, 14, - ExpectedEPNodeAssignment::All); + ExpectedEPNodeAssignment::All, + QDQTolerance(), + logging::Severity::kERROR, + "", // context model file path, not required for this inference + session_option_pairs); // 3rd run directly loads and run from Qnn context 
cache model TestQDQModelAccuracy(BuildOpTestCase(op_type, {input_def1, input_def2}, {}, {}), diff --git a/onnxruntime/test/util/default_providers.cc b/onnxruntime/test/util/default_providers.cc index 65646a7286719..1e537b1620ad4 100644 --- a/onnxruntime/test/util/default_providers.cc +++ b/onnxruntime/test/util/default_providers.cc @@ -242,9 +242,10 @@ std::unique_ptr DefaultQnnExecutionProvider() { #endif } -std::unique_ptr QnnExecutionProviderWithOptions(const ProviderOptions& options) { +std::unique_ptr QnnExecutionProviderWithOptions(const ProviderOptions& options, + const SessionOptions* session_options) { #ifdef USE_QNN - return QNNProviderFactoryCreator::Create(options, nullptr)->CreateProvider(); + return QNNProviderFactoryCreator::Create(options, session_options)->CreateProvider(); #else ORT_UNUSED_PARAMETER(options); return nullptr; diff --git a/onnxruntime/test/util/include/default_providers.h b/onnxruntime/test/util/include/default_providers.h index 1325f7aa43dbb..68c2f13a378c7 100644 --- a/onnxruntime/test/util/include/default_providers.h +++ b/onnxruntime/test/util/include/default_providers.h @@ -52,7 +52,8 @@ std::unique_ptr DefaultRocmExecutionProvider(bool test_tunab std::unique_ptr DefaultCoreMLExecutionProvider(); std::unique_ptr DefaultSnpeExecutionProvider(); std::unique_ptr DefaultQnnExecutionProvider(); -std::unique_ptr QnnExecutionProviderWithOptions(const ProviderOptions& options); +std::unique_ptr QnnExecutionProviderWithOptions(const ProviderOptions& options, + const SessionOptions* session_options = nullptr); std::unique_ptr DefaultXnnpackExecutionProvider(); std::unique_ptr DefaultCannExecutionProvider(); std::unique_ptr DefaultDmlExecutionProvider(); From 22b4c93c80e398b44f9f3eb402040af55cffb429 Mon Sep 17 00:00:00 2001 From: Hector Li Date: Tue, 26 Dec 2023 21:50:10 -0800 Subject: [PATCH 07/28] Fix Linux build --- onnxruntime/core/providers/qnn/qnn_execution_provider.h | 1 + 1 file changed, 1 insertion(+) diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.h b/onnxruntime/core/providers/qnn/qnn_execution_provider.h index 224842546e789..ddc04c1051efc 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.h +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.h @@ -9,6 +9,7 @@ #include "core/providers/qnn/builder/qnn_backend_manager.h" #include "core/providers/qnn/builder/qnn_model.h" #include "core/providers/qnn/builder/qnn_graph_configs_helper.h" +#include "core/graph/model.h" namespace onnxruntime { From de53da11c8d1a40cbf36c41dfd9ae83979f0f508 Mon Sep 17 00:00:00 2001 From: Hector Li Date: Thu, 28 Dec 2023 22:05:52 -0800 Subject: [PATCH 08/28] fix some build issues --- onnxruntime/core/framework/graph_partitioner.cc | 2 ++ 1 file changed, 2 insertions(+) diff --git a/onnxruntime/core/framework/graph_partitioner.cc b/onnxruntime/core/framework/graph_partitioner.cc index f6f1a1e6aba93..7d0e59ffb5033 100644 --- a/onnxruntime/core/framework/graph_partitioner.cc +++ b/onnxruntime/core/framework/graph_partitioner.cc @@ -956,6 +956,8 @@ Status GraphPartitioner::Partition(Graph& graph, FuncManager& func_mgr, kernel_registry_mgr_, ep_context_enabled, ep_context_path, logger)); #else + ORT_UNUSED_PARAMETER(config_options); + ORT_UNUSED_PARAMETER(logger); return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "ONNX models are not supported in this build."); #endif //! 
defined(ORT_MINIMAL_BUILD) } else { From c3883b152ff7a3224a21ccef8b5bdd0da59a5446 Mon Sep 17 00:00:00 2001 From: Hector Li Date: Wed, 17 Jan 2024 23:02:14 -0800 Subject: [PATCH 09/28] Set inputs outputs explicitly to make sure the order is same as the user model. --- .../core/framework/graph_partitioner.cc | 29 +++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/onnxruntime/core/framework/graph_partitioner.cc b/onnxruntime/core/framework/graph_partitioner.cc index 7d0e59ffb5033..15c5ee6e188db 100644 --- a/onnxruntime/core/framework/graph_partitioner.cc +++ b/onnxruntime/core/framework/graph_partitioner.cc @@ -675,6 +675,34 @@ static Status CreateEpContextModel(const ExecutionProviders& execution_providers graph.DomainToVersionMap(), {}, logger); auto& ep_graph = ep_context_model.MainGraph(); ep_graph.SetDescription(graph.Description()); + + // Set inputs outputs explicitly to make sure the order is same as the user model. + auto inputs = graph.GetInputs(); + auto outputs = graph.GetOutputs(); + + int i = 0; + std::vector ep_graph_inputs; + ep_graph_inputs.resize(inputs.size()); + for (auto& input : inputs) { + auto input_arg = graph.GetNodeArg(input->Name()); + auto& ep_graph_input_arg = ep_graph.GetOrCreateNodeArg(input_arg->Name(), input_arg->TypeAsProto()); + ep_graph_inputs[i] = &ep_graph_input_arg; + ++i; + } + + i = 0; + std::vector ep_graph_outputs; + ep_graph_outputs.resize(outputs.size()); + for (auto& output : outputs) { + auto output_arg = graph.GetNodeArg(output->Name()); + auto& ep_graph_output_arg = ep_graph.GetOrCreateNodeArg(output_arg->Name(), output_arg->TypeAsProto()); + ep_graph_outputs[i] = &ep_graph_output_arg; + ++i; + } + + ep_graph.SetInputs(ep_graph_inputs); + ep_graph.SetOutputs(ep_graph_outputs); + for (const auto& node : graph.Nodes()) { // the fused node and EPContext node has same node name auto ep_context_node = get_ep_context_node(node.Name()); @@ -685,6 +713,7 @@ static Status CreateEpContextModel(const ExecutionProviders& execution_providers ep_graph.AddNode(node); } } + ORT_RETURN_IF_ERROR(Model::Save(ep_context_model, context_cache_path)); } From 30c1ed7c2afb207e4ac296770232ef5225f4018e Mon Sep 17 00:00:00 2001 From: Hector Li Date: Fri, 19 Jan 2024 16:21:51 -0800 Subject: [PATCH 10/28] resolve conflict --- cmake/external/emsdk | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmake/external/emsdk b/cmake/external/emsdk index a896e3d066448..4e2496141eda1 160000 --- a/cmake/external/emsdk +++ b/cmake/external/emsdk @@ -1 +1 @@ -Subproject commit a896e3d066448b3530dbcaa48869fafefd738f57 +Subproject commit 4e2496141eda15040c44e9bbf237a1326368e34c From 55d10b22a5dce9b6ddbca8035ce2be904ddf46e8 Mon Sep 17 00:00:00 2001 From: Hector Li Date: Sat, 20 Jan 2024 19:24:12 -0800 Subject: [PATCH 11/28] resolved merge conflicts --- .../providers/qnn/qnn_execution_provider.cc | 108 +++++++++++------- 1 file changed, 67 insertions(+), 41 deletions(-) diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc index 56eb1f4f59f33..4e29e003d4628 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc @@ -134,6 +134,7 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio LOGS_DEFAULT(VERBOSE) << "User specified context cache embed mode: " << qnn_context_embed_mode_; context_cache_path_cfg_ = 
session_options->config_options.GetConfigOrDefault(kOrtSessionOptionEpContextFilePath, ""); + LOGS_DEFAULT(VERBOSE) << "User specified context cache path: " << context_cache_path_cfg_; } static const std::string BACKEND_PATH = "backend_path"; @@ -266,14 +267,20 @@ std::unordered_set QNNExecutionProvider::GetSupportedNodes(const GraphViewer& graph_viewer, const std::unordered_map& node_unit_map, const size_t node_unit_size, - bool load_from_cached_context, + bool is_qnn_ctx_model, const logging::Logger& logger) const { std::unordered_set supported_nodes{}; // Enable Qnn context cache requires the whole graph partitioned to Qnn EP // Blindly filter in all nodes if context cache is enabled - if (load_from_cached_context) { + if (is_qnn_ctx_model) { for (const auto& node : graph_viewer.Nodes()) { - supported_nodes.insert(&node); + if (qnn::EPCONTEXT_OP == node.OpType()) { + LOGS(logger, VERBOSE) << "Node supported: [1] index: [" << node.Index() + << "] name: [" << node.Name() + << "] Operator type: [EPContext" + << "] index: [" << node.Index() << "]"; + supported_nodes.insert(&node); + } } return supported_nodes; } @@ -359,7 +366,7 @@ QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer const auto& logger = *GetLogger(); bool load_from_cached_context = false; - bool is_qnn_ctx_model = qnn::IsQnnCtxModel(graph_viewer); + bool is_qnn_ctx_model = qnn::GraphHasEpContextNode(graph_viewer); if (is_qnn_ctx_model) { load_from_cached_context = true; } @@ -372,7 +379,8 @@ QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer context_cache_path); } - // Load from cached context will load the QnnSystem lib and skip the Qnn context creation + // It will load the QnnSystem lib if load_from_cached_context=true, and + // delay the Qnn context creation to Compile() using the cached context auto rt = qnn_backend_manager_->SetupBackend(logger, load_from_cached_context); if (Status::OK() != rt) { LOGS(logger, ERROR) << "QNN SetupBackend failed " << rt.ErrorMessage(); @@ -391,7 +399,7 @@ QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer std::tie(node_unit_holder, node_unit_map) = GetAllNodeUnits(graph_viewer); const auto supported_nodes = GetSupportedNodes(graph_viewer, node_unit_map, node_unit_holder.size(), - load_from_cached_context, logger); + is_qnn_ctx_model, logger); // Helper function that returns a string that lists all unsupported nodes. // Ex: { name: mul_123, type: Mul }, {}, ... 
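For orientation, the EPContext contract these hunks rely on, read the way the helper code reads it; a hedged sketch (the attribute string values are assumptions inferred from the MAIN_CONTEXT/SOURCE constants and the ep_cache_context references in this series):

// Given an EPContext node, recover the pieces the QNN EP needs.
NodeAttrHelper node_helper(node);
const bool is_main_context = node_helper.Get("main_context", static_cast<int64_t>(0)) == 1;  // owner of the shared QNN context binary
const std::string source = node_helper.Get("source", "");           // e.g. "QNNExecutionProvider"
const std::string cache = node_helper.Get("ep_cache_context", "");  // embedded binary, or a file path in non-embed mode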
@@ -444,7 +452,7 @@ QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer if (partition && partition->sub_graph) { nodes_in_partition = partition->sub_graph->nodes.size(); - if (nodes_in_partition == 1) { + if (nodes_in_partition == 1 && !load_from_cached_context) { const Node* node = graph_viewer.GetNode(partition->sub_graph->nodes[0]); if (!node) { @@ -565,52 +573,70 @@ Status QNNExecutionProvider::CompileFromOrtGraph(const std::vector& fused_nodes_and_graphs, std::vector& node_compute_funcs) { const auto& logger = *GetLogger(); - Node& fused_node = fused_nodes_and_graphs[0].fused_node; - const onnxruntime::GraphViewer& graph_viewer(fused_nodes_and_graphs[0].filtered_graph); - bool is_qnn_ctx_model = false; - ORT_RETURN_IF_ERROR(qnn::IsFusedGraphHasCtxNode(fused_nodes_and_graphs, is_qnn_ctx_model)); + bool is_qnn_ctx_model = qnn::IsFusedGraphHasCtxNode(fused_nodes_and_graphs); onnxruntime::PathString context_cache_path; - bool is_ctx_file_exist = qnn::IsContextCacheFileExists(context_cache_path_cfg_, - graph_viewer.ModelPath().ToPathString(), - context_cache_path); - const std::string& model_name = graph_viewer.GetGraph().Name(); - const std::string& model_description = graph_viewer.GetGraph().Description(); - const std::string& graph_meta_id = fused_node.Name(); - if (fused_nodes_and_graphs.size() == 1 && !is_qnn_ctx_model && is_ctx_file_exist) { - ORT_RETURN_IF_ERROR(qnn::ValidateWithContextFile(context_cache_path, - model_name, - model_description, - graph_meta_id, - logger)); + bool is_ctx_file_exist = false; + if (!is_qnn_ctx_model) { + const onnxruntime::GraphViewer& graph_viewer_0(fused_nodes_and_graphs[0].filtered_graph); + is_ctx_file_exist = qnn::IsContextCacheFileExists(context_cache_path_cfg_, + graph_viewer_0.ModelPath().ToPathString(), + context_cache_path); + if (is_ctx_file_exist) { + ORT_RETURN_IF_ERROR(qnn::ValidateWithContextFile(fused_nodes_and_graphs, + context_cache_path, + graph_viewer_0.GetGraph().Name(), + graph_viewer_0.GetGraph().Description(), + logger)); + } } if (is_qnn_ctx_model || (context_cache_enabled_ && is_ctx_file_exist)) { - ORT_RETURN_IF(fused_nodes_and_graphs.size() != 1, "Only support single partition for context cache feature."); - std::unique_ptr qnn_model = std::make_unique(logger, qnn_backend_manager_.get()); - // Load and execute from cached context if exist - ORT_RETURN_IF_ERROR(qnn::LoadQnnCtxFromOnnxModel(graph_viewer, - context_cache_path, - is_qnn_ctx_model, - is_ctx_file_exist, - qnn_backend_manager_.get(), - *(qnn_model.get()), - logger)); - ORT_RETURN_IF_ERROR(qnn_model->SetGraphInputOutputInfo(graph_viewer, fused_node)); - ORT_RETURN_IF_ERROR(qnn_model->SetupQnnInputOutput()); + // Table of QnnModel objects keyed by node name; the node name is not the same as the graph_meta_id + std::unordered_map> qnn_models; - // fused node name is QNNExecutionProvider_QNN_[hash_id]_[id] - // the name here should be same with context->node_name in compute_info - qnn_models_.emplace(graph_meta_id, std::move(qnn_model)); + if (is_qnn_ctx_model) { + int main_context_pos = -1; + ORT_RETURN_IF_ERROR(qnn::GetMainContextNode(fused_nodes_and_graphs, qnn_backend_manager_.get(), + logger, main_context_pos, qnn_models)); + + const onnxruntime::GraphViewer& main_ctx_graph_viewer(fused_nodes_and_graphs[main_context_pos].filtered_graph); + // Create QNN context from the cached binary, deserialize the QNN graph from the binary + ORT_RETURN_IF_ERROR(qnn::LoadQnnCtxFromOnnxGraph(main_ctx_graph_viewer, + context_cache_path, + qnn_backend_manager_.get(), + qnn_models)); + } else {
+ ORT_RETURN_IF_ERROR(qnn::LoadContextFromOnnxModel(fused_nodes_and_graphs, context_cache_path, + qnn_backend_manager_.get(), logger, qnn_models)); + } + + for (auto fused_node_and_graph : fused_nodes_and_graphs) { + const onnxruntime::GraphViewer& graph_viewer(fused_node_and_graph.filtered_graph); + const auto& ep_context_node = graph_viewer.Nodes().begin(); + const Node& fused_node = fused_node_and_graph.fused_node; + const std::string& graph_meta_id = fused_node.Name(); + std::string key = is_qnn_ctx_model ? ep_context_node->Name() : graph_meta_id; + ORT_RETURN_IF(qnn_models.find(key) == qnn_models.end(), key + " key name does not exist in table qnn_models."); + auto qnn_model = std::move(qnn_models[key]); + ORT_RETURN_IF_ERROR(qnn_model->SetGraphInputOutputInfo(graph_viewer, fused_node)); + ORT_RETURN_IF_ERROR(qnn_model->SetupQnnInputOutput()); + + // fused node name is QNNExecutionProvider_QNN_[hash_id]_[id] + // the name here must be the same as context->node_name in compute_info + qnn_models_.emplace(graph_meta_id, std::move(qnn_model)); + + ORT_RETURN_IF_ERROR(CreateComputeFunc(node_compute_funcs, logger)); + } - ORT_RETURN_IF_ERROR(CreateComputeFunc(node_compute_funcs, logger)); return Status::OK(); } ORT_RETURN_IF_ERROR(CompileFromOrtGraph(fused_nodes_and_graphs, node_compute_funcs, logger)); - if (context_cache_enabled_ && !is_qnn_ctx_model) { - ORT_RETURN_IF(fused_nodes_and_graphs.size() != 1, "Only support single partition for context cache feature."); + // Generate the QNN context model if it's a QDQ model, context_cache_enabled=true, and the context file does not exist already + if (!is_qnn_ctx_model && context_cache_enabled_ && !is_ctx_file_exist) { + // All partitioned graphs share a single QNN context, included in the same context binary uint64_t buffer_size(0); auto context_buffer = qnn_backend_manager_->GetContextBinaryBuffer(buffer_size); qnn_ep_context_model_ = std::make_unique("qnn_ep_context_model", false, logger); From ce3c64f4b29969a6c31b666e4c7c40bc91bda191 Mon Sep 17 00:00:00 2001 From: Hector Li Date: Sat, 20 Jan 2024 19:29:31 -0800 Subject: [PATCH 12/28] resolve merge conflicts --- onnxruntime/core/providers/qnn/qnn_execution_provider.cc | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc index 4e29e003d4628..b4dd7c61642e9 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc @@ -270,8 +270,7 @@ QNNExecutionProvider::GetSupportedNodes(const GraphViewer& graph_viewer, bool is_qnn_ctx_model, const logging::Logger& logger) const { std::unordered_set supported_nodes{}; - // Enable Qnn context cache requires the whole graph partitioned to Qnn EP - // Blindly filter in all nodes if context cache is enabled + // Filter in the EPContext node if it's a QNN context model if (is_qnn_ctx_model) { for (const auto& node : graph_viewer.Nodes()) { if (qnn::EPCONTEXT_OP == node.OpType()) { @@ -483,7 +482,7 @@ QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer // Print list of unsupported nodes to the ERROR logger if the CPU EP // has been disabled for this inference session.
- if (disable_cpu_ep_fallback_ && num_nodes_in_graph != num_of_supported_nodes) { + if (!load_from_cached_context && disable_cpu_ep_fallback_ && num_nodes_in_graph != num_of_supported_nodes) { LOGS(logger, ERROR) << "Unsupported nodes in QNN EP: " << get_unsupported_node_names(); } From 8c55f1942dd41f35feea0fdc1465a08669c57d74 Mon Sep 17 00:00:00 2001 From: Hector Li Date: Mon, 22 Jan 2024 14:10:09 -0800 Subject: [PATCH 13/28] remove the validation mode --- .../qnn/builder/onnx_ctx_model_helper.cc | 52 ++--------------- .../qnn/builder/onnx_ctx_model_helper.h | 12 ---- .../providers/qnn/qnn_execution_provider.cc | 57 +++++++------------ 3 files changed, 25 insertions(+), 96 deletions(-) diff --git a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc index ce3f1184601fa..884b2368ba11d 100644 --- a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc +++ b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc @@ -13,10 +13,14 @@ namespace onnxruntime { namespace qnn { bool GraphHasEpContextNode(const onnxruntime::GraphViewer& graph_viewer) { - // It's an Onnx model with Qnn context cache binary if it has a node with EPContext type + // It's an Onnx model with Qnn context cache binary if it has a node with EPContext type and the source is QNN or QNNExecutionProvider. for (const auto& node : graph_viewer.Nodes()) { if (EPCONTEXT_OP == node.OpType()) { - return true; + NodeAttrHelper node_helper(node); + const std::string cache_source = node_helper.Get(SOURCE, ""); + if (cache_source == kQnnExecutionProvider || cache_source == "QNN") { + return true; + } } } return false; @@ -56,50 +60,6 @@ Status GetMainContextNode(const std::vector& fused_nodes_and_graphs, - const onnxruntime::PathString& ctx_onnx_model_path, - QnnBackendManager* qnn_backend_manager, - const logging::Logger& logger, - std::unordered_map>& qnn_models) { - for (const auto& fused_node_and_graph : fused_nodes_and_graphs) { - const Node& fused_node = fused_node_and_graph.fused_node; - qnn_models.emplace(fused_node.Name(), - std::make_unique(logger, qnn_backend_manager)); - } - using namespace onnxruntime; - std::shared_ptr model; - ORT_RETURN_IF_ERROR(Model::Load(ctx_onnx_model_path, model, {}, logger)); - const auto& graph = GraphViewer(model->MainGraph()); - - for (const auto& ep_context_node : graph.Nodes()) { - if (EPCONTEXT_OP != ep_context_node.OpType()) { - continue; - } - NodeAttrHelper node_helper(ep_context_node); - int64_t is_main_context = node_helper.Get(MAIN_CONTEXT, static_cast(0)); - if (1 == is_main_context) { - return GetEpContextFromMainNode(ep_context_node, ctx_onnx_model_path, qnn_backend_manager, qnn_models); - } - } - - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, "Failed to find EPContext node with main_context=1."); -} - -Status LoadContextFromOnnxModel(const std::vector& fused_nodes_and_graphs, - const onnxruntime::PathString& ctx_onnx_model_path, - QnnBackendManager* qnn_backend_manager, - const logging::Logger& logger, - std::unordered_map>& qnn_models) { - Status status = GetContextFromOnnxModel(fused_nodes_and_graphs, ctx_onnx_model_path, qnn_backend_manager, logger, qnn_models); - - // This is the protocol with customer that status with INVALID_GRAPH will be generated if failed to load context model - if (!status.IsOK()) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, "Failed to load from EpContextModel. 
", status.ErrorMessage()); - } - - return Status::OK(); -} - Status CreateNodeArgs(const std::vector& names, const std::unordered_map& tensor_info_table, std::vector& node_args, diff --git a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h index de2f0bd702924..80252043e414d 100644 --- a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h +++ b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h @@ -57,18 +57,6 @@ Status LoadQnnCtxFromOnnxGraph(const onnxruntime::GraphViewer& graph_viewer, QnnBackendManager* qnn_backend_manager, std::unordered_map>& qnn_models); -Status GetContextFromOnnxModel(const std::vector& fused_nodes_and_graphs, - const onnxruntime::PathString& ctx_onnx_model_path, - QnnBackendManager* qnn_backend_manager, - const logging::Logger& logger, - std::unordered_map>& qnn_models); - -Status LoadContextFromOnnxModel(const std::vector& fused_nodes_and_graphs, - const onnxruntime::PathString& ctx_onnx_model_path, - QnnBackendManager* qnn_backend_manager, - const logging::Logger& logger, - std::unordered_map>& qnn_models); - Status ValidateWithContextFile(const std::vector& fused_nodes_and_graphs, const onnxruntime::PathString& context_cache_path, const std::string& model_name, diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc index b4dd7c61642e9..e946846d91181 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc @@ -366,21 +366,10 @@ QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer const auto& logger = *GetLogger(); bool load_from_cached_context = false; bool is_qnn_ctx_model = qnn::GraphHasEpContextNode(graph_viewer); - if (is_qnn_ctx_model) { - load_from_cached_context = true; - } - - // This is for case: QDQ model + Onnx Qnn context cache model - if (context_cache_enabled_ && !is_qnn_ctx_model) { - onnxruntime::PathString context_cache_path; - load_from_cached_context = qnn::IsContextCacheFileExists(context_cache_path_cfg_, - graph_viewer.ModelPath().ToPathString(), - context_cache_path); - } - // It will load the QnnSystem lib if load_from_cached_context=true, and - // delay the Qnn context creation to Compile() using the cached context - auto rt = qnn_backend_manager_->SetupBackend(logger, load_from_cached_context); + // It will load the QnnSystem lib if is_qnn_ctx_model=true, and + // delay the Qnn context creation to Compile() using the cached context binary + auto rt = qnn_backend_manager_->SetupBackend(logger, is_qnn_ctx_model); if (Status::OK() != rt) { LOGS(logger, ERROR) << "QNN SetupBackend failed " << rt.ErrorMessage(); return result; @@ -582,41 +571,33 @@ Status QNNExecutionProvider::Compile(const std::vector& fused is_ctx_file_exist = qnn::IsContextCacheFileExists(context_cache_path_cfg_, graph_viewer_0.ModelPath().ToPathString(), context_cache_path); - if (is_ctx_file_exist) { - ORT_RETURN_IF_ERROR(qnn::ValidateWithContextFile(fused_nodes_and_graphs, - context_cache_path, - graph_viewer_0.GetGraph().Name(), - graph_viewer_0.GetGraph().Description(), - logger)); - } + ORT_RETURN_IF(is_ctx_file_exist, + "The inference session is created from normal ONNX model. And an EP context model file is provided and existed. 
\
+                  Please remove the EP context model manually if you want to re-generate it.");
   }
 
-  if (is_qnn_ctx_model || (context_cache_enabled_ && is_ctx_file_exist)) {
-    // Table, the node name is not same with the graph_meta_id
+  if (is_qnn_ctx_model) {
+    // Table keyed by node name. The node name is the old graph_meta_id created from the user model
+    // that was used to generate the EP context model; for this session (created from an EP context model),
+    // the graph_meta_id is new.
     std::unordered_map> qnn_models;
 
-    if (is_qnn_ctx_model) {
-      int main_context_pos = -1;
-      ORT_RETURN_IF_ERROR(qnn::GetMainContextNode(fused_nodes_and_graphs, qnn_backend_manager_.get(),
-                                                  logger, main_context_pos, qnn_models));
+    int main_context_pos = -1;
+    ORT_RETURN_IF_ERROR(qnn::GetMainContextNode(fused_nodes_and_graphs, qnn_backend_manager_.get(),
+                                                logger, main_context_pos, qnn_models));
 
-      const onnxruntime::GraphViewer& main_ctx_graph_viewer(fused_nodes_and_graphs[main_context_pos].filtered_graph);
-      // Create QNN context from the cached binary, deserialize the QNN graph from the binary
-      ORT_RETURN_IF_ERROR(qnn::LoadQnnCtxFromOnnxGraph(main_ctx_graph_viewer,
-                                                       context_cache_path,
-                                                       qnn_backend_manager_.get(),
-                                                       qnn_models));
-    } else {
-      ORT_RETURN_IF_ERROR(qnn::LoadContextFromOnnxModel(fused_nodes_and_graphs, context_cache_path,
-                                                        qnn_backend_manager_.get(), logger, qnn_models));
-    }
+    const onnxruntime::GraphViewer& main_ctx_graph_viewer(fused_nodes_and_graphs[main_context_pos].filtered_graph);
+    // Create QNN context from the cached binary, deserialize the QNN graph from the binary
+    ORT_RETURN_IF_ERROR(qnn::LoadQnnCtxFromOnnxGraph(main_ctx_graph_viewer,
+                                                     context_cache_path,
+                                                     qnn_backend_manager_.get(),
+                                                     qnn_models));
 
     for (auto fused_node_and_graph : fused_nodes_and_graphs) {
       const onnxruntime::GraphViewer& graph_viewer(fused_node_and_graph.filtered_graph);
       const auto& ep_context_node = graph_viewer.Nodes().begin();
       const Node& fused_node = fused_node_and_graph.fused_node;
       const std::string& graph_meta_id = fused_node.Name();
-      std::string key = is_qnn_ctx_model ? ep_context_node->Name() : graph_meta_id;
+      std::string key = ep_context_node->Name();
       ORT_RETURN_IF(qnn_models.find(key) == qnn_models.end(), key + " key name not exist in table qnn_models.");
       auto qnn_model = std::move(qnn_models[key]);
       ORT_RETURN_IF_ERROR(qnn_model->SetGraphInputOutputInfo(graph_viewer, fused_node));

From e7c0827793f9a6750e9a92811bc24d9626575ddc Mon Sep 17 00:00:00 2001
From: Hector Li
Date: Mon, 22 Jan 2024 14:15:35 -0800
Subject: [PATCH 14/28] clean up some unused code

---
 .../qnn/builder/onnx_ctx_model_helper.cc      | 85 -------------------
 .../qnn/builder/onnx_ctx_model_helper.h       | 13 ---
 2 files changed, 98 deletions(-)

diff --git a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc
index 884b2368ba11d..86e1c2d241007 100644
--- a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc
+++ b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc
@@ -156,31 +156,6 @@ Status LoadQnnCtxFromOnnxGraph(const onnxruntime::GraphViewer& graph_viewer,
   return Status::OK();
 }
 
-Status GetMetadataFromEpContextModel(const onnxruntime::PathString& ctx_onnx_model_path,
-                                     std::string& model_name,
-                                     std::string& model_description,
-                                     std::vector& graph_partition_names,
-                                     std::string& cache_source,
-                                     const logging::Logger& logger) {
-  using namespace onnxruntime;
-  std::shared_ptr model;
-  ORT_RETURN_IF_ERROR(Model::Load(ctx_onnx_model_path, model, {}, logger));
-  const auto& graph = GraphViewer(model->MainGraph());
-  model_name = graph.Name();
-  model_description = graph.Description();
-
-  for (const auto& ep_context_node : graph.Nodes()) {
-    if (EPCONTEXT_OP != ep_context_node.OpType()) {
-      continue;
-    }
-    NodeAttrHelper node_helper(ep_context_node);
-    cache_source = node_helper.Get(SOURCE, "");
-    graph_partition_names.push_back(node_helper.Get(PARTITION_NAME, ""));
-  }
-
-  return Status::OK();
-}
-
 bool IsContextCacheFileExists(const std::string& customer_context_cache_path,
                               const onnxruntime::PathString& model_pathstring,
                               onnxruntime::PathString& context_cache_path) {
@@ -194,66 +169,6 @@ bool IsContextCacheFileExists(const std::string& customer_context_cache_path,
   return std::filesystem::is_regular_file(context_cache_path) && std::filesystem::exists(context_cache_path);
 }
 
-Status ValidateWithContextFile(const std::vector& fused_nodes_and_graphs,
-                               const onnxruntime::PathString& context_cache_path,
-                               const std::string& model_name,
-                               const std::string& model_description,
-                               const logging::Logger& logger) {
-  std::string model_name_from_ctx_cache;
-  std::string model_description_from_ctx_cache;
-  std::vector graph_partition_names;
-  std::string cache_source;
-
-  auto status = GetMetadataFromEpContextModel(context_cache_path,
-                                              model_name_from_ctx_cache,
-                                              model_description_from_ctx_cache,
-                                              graph_partition_names,
-                                              cache_source,
-                                              logger);
-  if (!status.IsOK()) {
-    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, "Failed to get metadata from EpContext model.");
-  }
-
-  // The source attribute from the skeleton onnx file indicate whether it's generated from QNN toolchain or ORT
-  if (cache_source != kQnnExecutionProvider) {
-    LOGS(logger, VERBOSE) << "Context binary cache is not generated by Ort.";
-    return Status::OK();
-  }
-
-  bool partition_names_matched = true;
-  for (const auto& fused_node_graph : fused_nodes_and_graphs) {
-    const Node& fused_node = fused_node_graph.fused_node;
-    const std::string& graph_meta_id = fused_node.Name();
-    bool name_found = false;
-    for (auto 
partition_name_from_ctx : graph_partition_names) { - if (partition_name_from_ctx == graph_meta_id) { - name_found = true; - break; - } - } - - if (!name_found) { - LOGS(logger, ERROR) << "Partition meta_id not found from any EPContext node: " << graph_meta_id; - partition_names_matched = false; - break; - } - } - - if (model_name != model_name_from_ctx_cache || - model_description != model_description_from_ctx_cache || - !partition_names_matched) { - std::string message = onnxruntime::MakeString("Metadata mismatch. onnx: ", - model_name, " ", model_description, - " vs epcontext: ", - model_name_from_ctx_cache, " ", - model_description_from_ctx_cache, - " or the partition name not match."); - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, message); - } - - return Status::OK(); -} - Status GenerateCtxCacheOnnxModel(Model* model, unsigned char* buffer, uint64_t buffer_size, diff --git a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h index 80252043e414d..5ad375fb9dcb5 100644 --- a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h +++ b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h @@ -57,19 +57,6 @@ Status LoadQnnCtxFromOnnxGraph(const onnxruntime::GraphViewer& graph_viewer, QnnBackendManager* qnn_backend_manager, std::unordered_map>& qnn_models); -Status ValidateWithContextFile(const std::vector& fused_nodes_and_graphs, - const onnxruntime::PathString& context_cache_path, - const std::string& model_name, - const std::string& model_description, - const logging::Logger& logger); - -Status GetMetadataFromEpContextModel(const onnxruntime::PathString& ctx_onnx_model_path, - std::string& model_name, - std::string& model_description, - std::vector& graph_partition_names, - std::string& cache_source, - const logging::Logger& logger); - Status GenerateCtxCacheOnnxModel(Model* model, unsigned char* buffer, uint64_t buffer_size, From d3feaa41905b9dfb3b5479bc6c7a80e952b5dc5e Mon Sep 17 00:00:00 2001 From: Hector Li Date: Mon, 22 Jan 2024 14:18:50 -0800 Subject: [PATCH 15/28] renaming --- .../qnn/builder/onnx_ctx_model_helper.cc | 18 +++++++++--------- .../qnn/builder/onnx_ctx_model_helper.h | 18 +++++++++--------- .../providers/qnn/qnn_execution_provider.cc | 18 +++++++++--------- 3 files changed, 27 insertions(+), 27 deletions(-) diff --git a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc index 86e1c2d241007..351915247d7be 100644 --- a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc +++ b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc @@ -169,15 +169,15 @@ bool IsContextCacheFileExists(const std::string& customer_context_cache_path, return std::filesystem::is_regular_file(context_cache_path) && std::filesystem::exists(context_cache_path); } -Status GenerateCtxCacheOnnxModel(Model* model, - unsigned char* buffer, - uint64_t buffer_size, - const std::string& sdk_build_version, - const std::vector& fused_nodes_and_graphs, - const std::unordered_map>& qnn_models, - const onnxruntime::PathString& context_cache_path, - bool qnn_context_embed_mode, - const logging::Logger& logger) { +Status CreateEPContextNodes(Model* model, + unsigned char* buffer, + uint64_t buffer_size, + const std::string& sdk_build_version, + const std::vector& fused_nodes_and_graphs, + const std::unordered_map>& qnn_models, + const onnxruntime::PathString& context_cache_path, + bool qnn_context_embed_mode, + 
const logging::Logger& logger) { auto& graph = model->MainGraph(); using namespace ONNX_NAMESPACE; diff --git a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h index 5ad375fb9dcb5..3a3799bd21bbf 100644 --- a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h +++ b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h @@ -57,14 +57,14 @@ Status LoadQnnCtxFromOnnxGraph(const onnxruntime::GraphViewer& graph_viewer, QnnBackendManager* qnn_backend_manager, std::unordered_map>& qnn_models); -Status GenerateCtxCacheOnnxModel(Model* model, - unsigned char* buffer, - uint64_t buffer_size, - const std::string& sdk_build_version, - const std::vector& fused_nodes_and_graphs, - const std::unordered_map>& qnn_models, - const onnxruntime::PathString& context_cache_path, - bool qnn_context_embed_mode, - const logging::Logger& logger); +Status CreateEPContextNodes(Model* model, + unsigned char* buffer, + uint64_t buffer_size, + const std::string& sdk_build_version, + const std::vector& fused_nodes_and_graphs, + const std::unordered_map>& qnn_models, + const onnxruntime::PathString& context_cache_path, + bool qnn_context_embed_mode, + const logging::Logger& logger); } // namespace qnn } // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc index e946846d91181..621164dabd33a 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc @@ -620,15 +620,15 @@ Status QNNExecutionProvider::Compile(const std::vector& fused uint64_t buffer_size(0); auto context_buffer = qnn_backend_manager_->GetContextBinaryBuffer(buffer_size); qnn_ep_context_model_ = std::make_unique("qnn_ep_context_model", false, logger); - ORT_RETURN_IF_ERROR(qnn::GenerateCtxCacheOnnxModel(qnn_ep_context_model_.get(), - context_buffer.get(), - buffer_size, - qnn_backend_manager_->GetSdkVersion(), - fused_nodes_and_graphs, - qnn_models_, - context_cache_path, - qnn_context_embed_mode_, - logger)); + ORT_RETURN_IF_ERROR(qnn::CreateEPContextNodes(qnn_ep_context_model_.get(), + context_buffer.get(), + buffer_size, + qnn_backend_manager_->GetSdkVersion(), + fused_nodes_and_graphs, + qnn_models_, + context_cache_path, + qnn_context_embed_mode_, + logger)); } return Status::OK(); } From 33516cdf651fdb144835197d08290b53d89e7450 Mon Sep 17 00:00:00 2001 From: Hector Li Date: Mon, 22 Jan 2024 20:30:58 -0800 Subject: [PATCH 16/28] Update tests --- .../providers/qnn/qnn_execution_provider.cc | 6 +- .../test/providers/qnn/simple_op_htp_test.cc | 62 ++++++------------- 2 files changed, 21 insertions(+), 47 deletions(-) diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc index 621164dabd33a..cdb30286d4489 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc @@ -566,14 +566,14 @@ Status QNNExecutionProvider::Compile(const std::vector& fused onnxruntime::PathString context_cache_path; bool is_ctx_file_exist = false; - if (!is_qnn_ctx_model) { + if (!is_qnn_ctx_model && context_cache_enabled_) { const onnxruntime::GraphViewer& graph_viewer_0(fused_nodes_and_graphs[0].filtered_graph); is_ctx_file_exist = qnn::IsContextCacheFileExists(context_cache_path_cfg_, graph_viewer_0.ModelPath().ToPathString(), context_cache_path); 
ORT_RETURN_IF(is_ctx_file_exist,
-                  "The inference session is created from normal ONNX model. And an EP context model file is provided and existed. \
-                  Please remove the EP context model manually if you want to re-generate it.");
+                  "The inference session is created from a normal ONNX model, but an EP context model file already exists. ",
+                  "Please remove the EP context model manually if you want to re-generate it.");
   }
 
   if (is_qnn_ctx_model) {
diff --git a/onnxruntime/test/providers/qnn/simple_op_htp_test.cc b/onnxruntime/test/providers/qnn/simple_op_htp_test.cc
index 1e938ae9e334b..7403a7c96e8e2 100644
--- a/onnxruntime/test/providers/qnn/simple_op_htp_test.cc
+++ b/onnxruntime/test/providers/qnn/simple_op_htp_test.cc
@@ -725,8 +725,7 @@ TEST_F(QnnHTPBackendTests, SpaceToDepthOp_U16) {
 
 // Run QDQ model on HTP 3 times
 // 1st run will generate the Qnn context cache onnx file
-// 2nd run will load and run from QDQ model + Qnn context cache model
-// 3rd run directly loads and run from Qnn context cache model
+// 2nd run directly loads and run from Qnn context cache model
 TEST_F(QnnHTPBackendTests, ContextBinaryCacheEmbedModeTest) {
   ProviderOptions provider_options;
 #if defined(_WIN32)
@@ -735,6 +734,7 @@ TEST_F(QnnHTPBackendTests, ContextBinaryCacheEmbedModeTest) {
   provider_options["backend_path"] = "libQnnHtp.so";
 #endif
   const std::string context_binary_file = "./qnn_context_binary_test.onnx";
+  std::remove(context_binary_file.c_str());
 
   std::unordered_map session_option_pairs;
   session_option_pairs.emplace(kOrtSessionOptionEpContextEnable, "1");
   session_option_pairs.emplace(kOrtSessionOptionEpContextFilePath, 
context_binary_file); session_option_pairs.emplace(kOrtSessionOptionEpContextEmbedMode, "0"); + std::remove(context_binary_file.c_str()); + std::remove(qnn_ctx_bin.c_str()); + const TestInputDef input_def({1, 2, 3}, false, -10.0f, 10.0f); const std::string op_type = "Atan"; @@ -817,21 +810,9 @@ TEST_F(QnnHTPBackendTests, ContextBinaryCacheNonEmbedModeTest) { // Check the Onnx skeleton file is generated EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str())); // Check the Qnn context cache binary file is generated - std::string qnn_ctx_bin = "qnn_context_cache_non_embed.onnx_QNNExecutionProvider_QNN_8283143575221199085_1_0.bin"; EXPECT_TRUE(std::filesystem::exists(qnn_ctx_bin)); - // 2nd run loads and run from QDQ model + Onnx skeleton file + Qnn context cache binary file - TestQDQModelAccuracy(BuildOpTestCase(op_type, {input_def}, {}, {}), - BuildQDQOpTestCase(op_type, {input_def}, {}, {}), - provider_options, - 14, - ExpectedEPNodeAssignment::All, - QDQTolerance(), - logging::Severity::kERROR, - "", // context model file path, not required for this inference - session_option_pairs); - - // 3rd run directly loads and run from Onnx skeleton file + Qnn context cache binary file + // 2nd run directly loads and run from Onnx skeleton file + Qnn context cache binary file TestQDQModelAccuracy(BuildOpTestCase(op_type, {input_def}, {}, {}), BuildQDQOpTestCase(op_type, {input_def}, {}, {}), provider_options, @@ -857,6 +838,10 @@ TEST_F(QnnHTPBackendTests, ContextBinaryCache_InvalidGraph) { provider_options["backend_path"] = "libQnnHtp.so"; #endif const std::string context_binary_file = "./qnn_context_cache_non_embed.onnx"; + std::filesystem::path context_bin = "qnn_context_cache_non_embed.onnx_QNNExecutionProvider_QNN_8283143575221199085_1_0.bin"; + std::remove(context_binary_file.c_str()); + std::remove(context_bin.string().c_str()); + std::unordered_map session_option_pairs; session_option_pairs.emplace(kOrtSessionOptionEpContextEnable, "1"); session_option_pairs.emplace(kOrtSessionOptionEpContextFilePath, context_binary_file); @@ -880,7 +865,6 @@ TEST_F(QnnHTPBackendTests, ContextBinaryCache_InvalidGraph) { // Check the Onnx skeleton file is generated EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str())); // Check the Qnn context cache binary file is generated - std::filesystem::path context_bin = "qnn_context_cache_non_embed.onnx_QNNExecutionProvider_QNN_8283143575221199085_1_0.bin"; EXPECT_TRUE(std::filesystem::exists(context_bin)); // Delete the Qnn context cache binary file EXPECT_TRUE(std::filesystem::remove(context_bin)); @@ -1041,8 +1025,7 @@ TEST_F(QnnHTPBackendTests, QnnContextBinaryFileEmptyStringTest) { // Run QDQ model on HTP with 2 inputs // 1st run will generate the Qnn context cache onnx file -// 2nd run will load and run from QDQ model + Qnn context cache model -// 3rd run directly loads and run from Qnn context cache model +// 2nd run directly loads and run from Qnn context cache model TEST_F(QnnHTPBackendTests, ContextBinary2InputsTest) { ProviderOptions provider_options; #if defined(_WIN32) @@ -1051,6 +1034,8 @@ TEST_F(QnnHTPBackendTests, ContextBinary2InputsTest) { provider_options["backend_path"] = "libQnnHtp.so"; #endif const std::string context_binary_file = "./qnn_context_binary_2inputs_test.onnx"; + std::remove(context_binary_file.c_str()); + std::unordered_map session_option_pairs; session_option_pairs.emplace(kOrtSessionOptionEpContextEnable, "1"); session_option_pairs.emplace(kOrtSessionOptionEpContextFilePath, context_binary_file); @@ -1074,18 
+1059,7 @@ TEST_F(QnnHTPBackendTests, ContextBinary2InputsTest) { // Make sure the Qnn context cache binary file is generated EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str())); - // 2nd run loads and run from QDQ model + Qnn context cache model - TestQDQModelAccuracy(BuildOpTestCase(op_type, {input_def1, input_def2}, {}, {}), - BuildQDQOpTestCase(op_type, {input_def1, input_def2}, {}, {}), - provider_options, - 14, - ExpectedEPNodeAssignment::All, - QDQTolerance(), - logging::Severity::kERROR, - "", // context model file path, not required for this inference - session_option_pairs); - - // 3rd run directly loads and run from Qnn context cache model + // 2nd run directly loads and run from Qnn context cache model TestQDQModelAccuracy(BuildOpTestCase(op_type, {input_def1, input_def2}, {}, {}), BuildQDQOpTestCase(op_type, {input_def1, input_def2}, {}, {}), provider_options, From 445bc1bef42087a3013d4796181adb504d8b0837 Mon Sep 17 00:00:00 2001 From: Hector Li Date: Thu, 25 Jan 2024 09:04:45 -0800 Subject: [PATCH 17/28] fix the issue relate to initializer handling --- .../core/framework/graph_partitioner.cc | 112 +++++++++--------- .../qnn/builder/onnx_ctx_model_helper.cc | 10 +- 2 files changed, 63 insertions(+), 59 deletions(-) diff --git a/onnxruntime/core/framework/graph_partitioner.cc b/onnxruntime/core/framework/graph_partitioner.cc index 07b465c80745a..90ee8a46f66a9 100644 --- a/onnxruntime/core/framework/graph_partitioner.cc +++ b/onnxruntime/core/framework/graph_partitioner.cc @@ -645,6 +645,10 @@ static Status CreateEpContextModel(const ExecutionProviders& execution_providers all_ep_context_nodes.insert(all_ep_context_nodes.begin(), ep_context_nodes.begin(), ep_context_nodes.end()); } + if (all_ep_context_nodes.size() < 1) { + return Status::OK(); + } + auto get_ep_context_node = [&all_ep_context_nodes](const std::string& node_name) -> std::pair { for (auto& node : all_ep_context_nodes) { if (node_name == node->Name()) { @@ -656,76 +660,70 @@ static Status CreateEpContextModel(const ExecutionProviders& execution_providers onnxruntime::PathString context_cache_path; PathString model_pathstring = graph.ModelPath().ToPathString(); - if (all_ep_context_nodes.size() > 0) { - if (!ep_context_path.empty()) { - context_cache_path = ToPathString(ep_context_path); - } else if (!model_pathstring.empty()) { - context_cache_path = model_pathstring + ToPathString("_ctx.onnx"); - } - { + if (!ep_context_path.empty()) { + context_cache_path = ToPathString(ep_context_path); + } else if (!model_pathstring.empty()) { + context_cache_path = model_pathstring + ToPathString("_ctx.onnx"); + } + + { #ifdef _WIN32 - std::wifstream fs(context_cache_path); + std::wifstream fs(context_cache_path); #else - std::ifstream fs(context_cache_path); + std::ifstream fs(context_cache_path); #endif - ORT_RETURN_IF(fs.good(), "Failed to generate EP context model since the file exist already."); - } + ORT_RETURN_IF(fs.good(), "Failed to generate EP context model since the file exist already."); + } - Model ep_context_model(graph.Name(), false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), - graph.DomainToVersionMap(), {}, logger); - auto& ep_graph = ep_context_model.MainGraph(); - ep_graph.SetDescription(graph.Description()); - - // Set inputs outputs explicitly to make sure the order is same as the user model. 
- auto inputs = graph.GetInputs(); - auto outputs = graph.GetOutputs(); - - InlinedVector ep_graph_inputs; - ep_graph_inputs.reserve(inputs.size()); - for (auto& input : inputs) { - auto input_arg = graph.GetNodeArg(input->Name()); - auto& ep_graph_input_arg = ep_graph.GetOrCreateNodeArg(input_arg->Name(), input_arg->TypeAsProto()); - ep_graph_inputs.push_back(&ep_graph_input_arg); - } + Model ep_context_model(graph.Name(), false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), + graph.DomainToVersionMap(), {}, logger); + auto& ep_graph = ep_context_model.MainGraph(); + ep_graph.SetDescription(graph.Description()); - InlinedVector ep_graph_outputs; - ep_graph_outputs.reserve(outputs.size()); - for (auto& output : outputs) { - auto output_arg = graph.GetNodeArg(output->Name()); - auto& ep_graph_output_arg = ep_graph.GetOrCreateNodeArg(output_arg->Name(), output_arg->TypeAsProto()); - ep_graph_outputs.push_back(&ep_graph_output_arg); - } + // Set inputs outputs explicitly to make sure the order is same as the user model. + auto inputs = graph.GetInputs(); + auto outputs = graph.GetOutputs(); - ep_graph.SetInputs(ep_graph_inputs); - ep_graph.SetOutputs(ep_graph_outputs); + InlinedVector ep_graph_inputs; + ep_graph_inputs.reserve(inputs.size()); + for (auto& input : inputs) { + auto input_arg = graph.GetNodeArg(input->Name()); + auto& ep_graph_input_arg = ep_graph.GetOrCreateNodeArg(input_arg->Name(), input_arg->TypeAsProto()); + ep_graph_inputs.push_back(&ep_graph_input_arg); + } - for (const auto& node : graph.Nodes()) { - // the fused node and EPContext node has same node name - auto ep_context_node = get_ep_context_node(node.Name()); - // Use EpContext node created by the EPs if name matched, otherwise use node from original model - if (ep_context_node.first) { - ep_graph.AddNode(*ep_context_node.second); - } else { - ep_graph.AddNode(node); - } - } + InlinedVector ep_graph_outputs; + ep_graph_outputs.reserve(outputs.size()); + for (auto& output : outputs) { + auto output_arg = graph.GetNodeArg(output->Name()); + auto& ep_graph_output_arg = ep_graph.GetOrCreateNodeArg(output_arg->Name(), output_arg->TypeAsProto()); + ep_graph_outputs.push_back(&ep_graph_output_arg); + } - // handle initializers - for (const auto& input : graph.GetInputsIncludingInitializers()) { - const ONNX_NAMESPACE::TensorProto* initializer = nullptr; - if (graph.GetInitializedTensor(input->Name(), initializer)) { - // There initializer could have duplicates so make sure we only add once - const ONNX_NAMESPACE::TensorProto* subgraph_initializer = nullptr; - if (!ep_graph.GetInitializedTensor(input->Name(), subgraph_initializer)) { - ep_graph.AddInitializedTensor(*initializer); - } - } + ep_graph.SetInputs(ep_graph_inputs); + ep_graph.SetOutputs(ep_graph_outputs); + + for (const auto& node : graph.Nodes()) { + // the fused node and EPContext node has same node name + auto ep_context_node = get_ep_context_node(node.Name()); + // Use EpContext node created by the EPs if name matched, otherwise use node from original model + if (ep_context_node.first) { + ep_graph.AddNode(*ep_context_node.second); + } else { + ep_graph.AddNode(node); } + } - ORT_RETURN_IF_ERROR(Model::Save(ep_context_model, context_cache_path)); + // handle initializers + for (const auto& initialized_tensor : graph.GetAllInitializedTensors()) { + if (ep_graph.GetNodeArg(initialized_tensor.first) != nullptr) { + ep_graph.AddInitializedTensor(*initialized_tensor.second); + } } + ORT_RETURN_IF_ERROR(Model::Save(ep_context_model, 
context_cache_path)); + return Status::OK(); } diff --git a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc index 351915247d7be..2ff0714292fbc 100644 --- a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc +++ b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc @@ -17,8 +17,14 @@ bool GraphHasEpContextNode(const onnxruntime::GraphViewer& graph_viewer) { for (const auto& node : graph_viewer.Nodes()) { if (EPCONTEXT_OP == node.OpType()) { NodeAttrHelper node_helper(node); - const std::string cache_source = node_helper.Get(SOURCE, ""); - if (cache_source == kQnnExecutionProvider || cache_source == "QNN") { + std::string cache_source = node_helper.Get(SOURCE, ""); + + std::transform(cache_source.begin(), + cache_source.end(), + cache_source.begin(), + [](unsigned char c) { return static_cast(std::tolower(c)); }); + + if (cache_source == "qnnexecutionprovider" || cache_source == "qnn") { return true; } } From 9c7bdfc54c0674b05d9e42b1a63c645e72a12a33 Mon Sep 17 00:00:00 2001 From: Hector Li Date: Thu, 25 Jan 2024 11:10:13 -0800 Subject: [PATCH 18/28] Move QNN context cache related tests to a separate file --- .../test/providers/qnn/qnn_basic_test.cc | 88 --- .../test/providers/qnn/qnn_ep_context_test.cc | 502 ++++++++++++++++++ .../test/providers/qnn/simple_op_htp_test.cc | 349 ------------ 3 files changed, 502 insertions(+), 437 deletions(-) create mode 100644 onnxruntime/test/providers/qnn/qnn_ep_context_test.cc diff --git a/onnxruntime/test/providers/qnn/qnn_basic_test.cc b/onnxruntime/test/providers/qnn/qnn_basic_test.cc index bc40682cf87b7..d6a8e318b8c43 100644 --- a/onnxruntime/test/providers/qnn/qnn_basic_test.cc +++ b/onnxruntime/test/providers/qnn/qnn_basic_test.cc @@ -559,94 +559,6 @@ static GetTestModelFn BuildCastAddTestCase() { }; } -// Test that models with 2 inputs which has different data type can still generate the context binary -TEST_F(QnnHTPBackendTests, QnnContextBinaryGeneration2InputTypes) { - ProviderOptions provider_options; -#if defined(_WIN32) - provider_options["backend_path"] = "QnnHtp.dll"; -#else - provider_options["backend_path"] = "libQnnHtp.so"; -#endif - - // Add kMSDomain to cover contrib op like Gelu - const std::unordered_map domain_to_version = {{"", 13}, {kMSDomain, 1}}; - - auto& logging_manager = DefaultLoggingManager(); - logging_manager.SetDefaultLoggerSeverity(logging::Severity::kERROR); - - onnxruntime::Model model("QNN_EP_TestModel", false, ModelMetaData(), PathString(), - IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, {}, - logging_manager.DefaultLogger()); - Graph& graph = model.MainGraph(); - ModelTestBuilder helper(graph); - BuildCastAddTestCase()(helper); - helper.SetGraphOutputs(); - ASSERT_STATUS_OK(model.MainGraph().Resolve()); - - // Serialize the model to a string. 
- std::string model_data; - model.ToProto().SerializeToString(&model_data); - - const auto model_data_span = AsByteSpan(model_data.data(), model_data.size()); - - const std::string context_binary_file = "./qnn_context_binary_int32_fp32_inputs_test.onnx"; - Ort::SessionOptions so; - so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"); - so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, context_binary_file.c_str()); - - so.AppendExecutionProvider("QNN", provider_options); - - Ort::Session session(*ort_env, model_data_span.data(), model_data_span.size(), so); - - // Make sure the Qnn context cache binary file is generated - EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str())); - - // clean up - ASSERT_EQ(std::remove(context_binary_file.c_str()), 0); -} - -// Generate context cache model from the ONNX models with 2 inputs. -// The generated model should have same input order. -// The input ONNX model is created in the way that the model inputs order -// is different with the order in the graph (topological order). -// It cause issue if the generated model doesn't set the inputs/outputs explicitly. -TEST_F(QnnHTPBackendTests, QnnContextGeneration2InputsOrderIssue) { - ProviderOptions provider_options; -#if defined(_WIN32) - provider_options["backend_path"] = "QnnHtp.dll"; -#else - provider_options["backend_path"] = "libQnnHtp.so"; -#endif - - // Add kMSDomain to cover contrib op like Gelu - const std::unordered_map domain_to_version = {{"", 13}, {kMSDomain, 1}}; - - auto& logging_manager = DefaultLoggingManager(); - logging_manager.SetDefaultLoggerSeverity(logging::Severity::kERROR); - - const std::string context_binary_file = "./qnn_ctx_2_inputs_order_test_gen.onnx"; - Ort::SessionOptions so; - so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"); - so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, context_binary_file.c_str()); - - so.AppendExecutionProvider("QNN", provider_options); - - Ort::Session session(*ort_env, ORT_TSTR("testdata/qnn_ctx_2_inputs_order_test.onnx"), so); - - // Make sure the Qnn context cache binary file is generated - EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str())); - - std::shared_ptr model; - ASSERT_STATUS_OK(Model::Load(ToPathString(context_binary_file), model, nullptr, DefaultLoggingManager().DefaultLogger())); - auto inputs = model->MainGraph().GetInputs(); - EXPECT_TRUE(inputs.size() == 2); - EXPECT_TRUE(inputs[0]->Name() == "attention_mask"); - EXPECT_TRUE(inputs[1]->Name() == "Add_input_0"); - - // clean up - ASSERT_EQ(std::remove(context_binary_file.c_str()), 0); -} - // A repro of QC case 06838696, accuracy issue for Cast + Op (quantized) // the value pair(1, 0.00392156886) at index #1 don't match, // which is -0.996078 from 1 diff --git a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc new file mode 100644 index 0000000000000..513be838f63e7 --- /dev/null +++ b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc @@ -0,0 +1,502 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
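+//
+// Tests for the QNN EP context cache feature: generating an EPContext ONNX model from a QDQ
+// model and then loading and running directly from the generated model (tests consolidated
+// here from qnn_basic_test.cc and simple_op_htp_test.cc).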
+
+#include
+#include
+#include
+
+#include "core/session/onnxruntime_cxx_api.h"
+#include "core/session/onnxruntime_session_options_config_keys.h"
+#include "core/providers/cpu/cpu_provider_factory.h"  // For OrtSessionOptionsAppendExecutionProvider_CPU
+#include "core/session/inference_session.h"
+
+#include "test/providers/qnn/qnn_test_utils.h"
+
+#include "gtest/gtest.h"
+#include "gmock/gmock.h"
+
+using namespace ONNX_NAMESPACE;
+using namespace onnxruntime::logging;
+
+#define ORT_MODEL_FOLDER ORT_TSTR("testdata/")
+
+// in test_main.cc
+extern std::unique_ptr ort_env;
+
+namespace onnxruntime {
+namespace test {
+
+#if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__)
+
+// Create a model with Cast + Add (quantized)
+// cast_input -> Cast -> Q -> DQ \
+//                                Add -> Q -> DQ -> output
+//             input2 -> Q -> DQ /
+static GetTestModelFn BuildCastAddTestCase() {
+  return [](ModelTestBuilder& builder) {
+    // Create Cast node: int32 -> float32
+    NodeArg* cast_input = MakeTestInput(builder, TestInputDef({2, 3}, false, {0, 1, 0, 1, 0, 1}));
+
+    auto* cast_output = builder.MakeIntermediate();
+    Node& cast_node = builder.AddNode("Cast", {cast_input}, {cast_output});
+    cast_node.AddAttribute("to", static_cast(ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT));
+
+    // Create Add node
+    std::vector data = {0.0f, 0.0f, 1.0f, 0.0f, 1.0f, 0.0f};
+    gsl::span data_range = gsl::make_span(data);
+    QuantParams q_parameter = GetDataQuantParams(data_range);
+    auto* add_input1_qdq = AddQDQNodePair(builder, cast_output, q_parameter.scale, q_parameter.zero_point);
+
+    NodeArg* add_input2 = MakeTestInput(builder, TestInputDef({2, 3}, false, data));
+    auto* add_input2_qdq = AddQDQNodePair(builder, add_input2, q_parameter.scale, q_parameter.zero_point);
+
+    auto* add_output = builder.MakeIntermediate();
+
+    builder.AddNode("Add", {add_input1_qdq, add_input2_qdq}, {add_output});
+
+    // add_output -> Q -> DQ -> output
+    AddQDQNodePairWithOutputAsGraphOutput(builder, add_output, q_parameter.scale, q_parameter.zero_point);
+  };
+}
+
+// Test that models with 2 inputs which have different data types can still generate the context binary
+TEST_F(QnnHTPBackendTests, QnnContextBinaryGeneration2InputTypes) {
+  ProviderOptions provider_options;
+#if defined(_WIN32)
+  provider_options["backend_path"] = "QnnHtp.dll";
+#else
+  provider_options["backend_path"] = "libQnnHtp.so";
+#endif
+
+  // Add kMSDomain to cover contrib op like Gelu
+  const std::unordered_map domain_to_version = {{"", 13}, {kMSDomain, 1}};
+
+  auto& logging_manager = DefaultLoggingManager();
+  logging_manager.SetDefaultLoggerSeverity(logging::Severity::kERROR);
+
+  onnxruntime::Model model("QNN_EP_TestModel", false, ModelMetaData(), PathString(),
+                           IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, {},
+                           logging_manager.DefaultLogger());
+  Graph& graph = model.MainGraph();
+  ModelTestBuilder helper(graph);
+  BuildCastAddTestCase()(helper);
+  helper.SetGraphOutputs();
+  ASSERT_STATUS_OK(model.MainGraph().Resolve());
+
+  // Serialize the model to a string.
+  std::string model_data;
+  model.ToProto().SerializeToString(&model_data);
+
+  const auto model_data_span = AsByteSpan(model_data.data(), model_data.size());
+
+  const std::string context_binary_file = "./qnn_context_binary_int32_fp32_inputs_test.onnx";
+  Ort::SessionOptions so;
+  so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1");
+  so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, context_binary_file.c_str());
+
+  so.AppendExecutionProvider("QNN", provider_options);
+
+  Ort::Session session(*ort_env, model_data_span.data(), model_data_span.size(), so);
+
+  // Make sure the Qnn context cache binary file is generated
+  EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str()));
+
+  // clean up
+  ASSERT_EQ(std::remove(context_binary_file.c_str()), 0);
+}
+
+// Generate a context cache model from an ONNX model with 2 inputs.
+// The generated model should have the same input order.
+// The input ONNX model is created in such a way that the order of the model inputs
+// differs from the order in the graph (topological order).
+// It causes issues if the generated model doesn't set the inputs/outputs explicitly.
+TEST_F(QnnHTPBackendTests, QnnContextGeneration2InputsOrderIssue) {
+  ProviderOptions provider_options;
+#if defined(_WIN32)
+  provider_options["backend_path"] = "QnnHtp.dll";
+#else
+  provider_options["backend_path"] = "libQnnHtp.so";
+#endif
+
+  // Add kMSDomain to cover contrib op like Gelu
+  const std::unordered_map domain_to_version = {{"", 13}, {kMSDomain, 1}};
+
+  auto& logging_manager = DefaultLoggingManager();
+  logging_manager.SetDefaultLoggerSeverity(logging::Severity::kERROR);
+
+  const std::string context_binary_file = "./qnn_ctx_2_inputs_order_test_gen.onnx";
+  Ort::SessionOptions so;
+  so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1");
+  so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, context_binary_file.c_str());
+
+  so.AppendExecutionProvider("QNN", provider_options);
+
+  Ort::Session session(*ort_env, ORT_TSTR("testdata/qnn_ctx_2_inputs_order_test.onnx"), so);
+
+  // Make sure the Qnn context cache binary file is generated
+  EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str()));
+
+  std::shared_ptr model;
+  ASSERT_STATUS_OK(Model::Load(ToPathString(context_binary_file), model, nullptr, DefaultLoggingManager().DefaultLogger()));
+  auto inputs = model->MainGraph().GetInputs();
+  EXPECT_TRUE(inputs.size() == 2);
+  EXPECT_TRUE(inputs[0]->Name() == "attention_mask");
+  EXPECT_TRUE(inputs[1]->Name() == "Add_input_0");
+
+  // clean up
+  ASSERT_EQ(std::remove(context_binary_file.c_str()), 0);
+}
+
+// Run QDQ model on HTP 2 times
+// 1st run will generate the Qnn context cache onnx file
+// 2nd run directly loads and runs from Qnn context cache model
+TEST_F(QnnHTPBackendTests, ContextBinaryCacheEmbedModeTest) {
+  ProviderOptions provider_options;
+#if defined(_WIN32)
+  provider_options["backend_path"] = "QnnHtp.dll";
+#else
+  provider_options["backend_path"] = "libQnnHtp.so";
+#endif
+  const std::string context_binary_file = "./qnn_context_binary_test.onnx";
+  std::remove(context_binary_file.c_str());
+
+  std::unordered_map session_option_pairs;
+  session_option_pairs.emplace(kOrtSessionOptionEpContextEnable, "1");
+  session_option_pairs.emplace(kOrtSessionOptionEpContextFilePath, context_binary_file);
+
+  const TestInputDef input_def({1, 2, 3}, false, -10.0f, 10.0f);
+  const std::string op_type = "Atan";
+
+  // Runs model with DQ-> Atan-> Q and compares the outputs of the CPU and QNN EPs.
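+  // kOrtSessionOptionEpContextEnable = "1" asks the QNN EP to dump the compiled context into the
+  // ONNX file named by kOrtSessionOptionEpContextFilePath; in the default embed mode the QNN
+  // context binary is stored inline in the EPContext node's ep_cache_context attribute, so a
+  // single .onnx file is produced.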
+  // 1st run will generate the Qnn context cache binary file
+  TestQDQModelAccuracy(BuildOpTestCase(op_type, {input_def}, {}, {}),
+                       BuildQDQOpTestCase(op_type, {input_def}, {}, {}),
+                       provider_options,
+                       14,
+                       ExpectedEPNodeAssignment::All,
+                       QDQTolerance(),
+                       logging::Severity::kERROR,
+                       "",  // context model file path, not required for this inference
+                       session_option_pairs);
+
+  // Make sure the Qnn context cache binary file is generated
+  EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str()));
+
+  // 2nd run directly loads and runs from Qnn context cache model
+  TestQDQModelAccuracy(BuildOpTestCase(op_type, {input_def}, {}, {}),
+                       BuildQDQOpTestCase(op_type, {input_def}, {}, {}),
+                       provider_options,
+                       14,
+                       ExpectedEPNodeAssignment::All,
+                       QDQTolerance(),
+                       logging::Severity::kERROR,
+                       context_binary_file);
+  // Clean up
+  ASSERT_EQ(std::remove(context_binary_file.c_str()), 0);
+}
+
+// Run QDQ model on HTP 2 times
+// 1st run will generate the Onnx skeleton file + Qnn context cache binary file
+// 2nd run directly loads and runs from Onnx skeleton file + Qnn context cache binary file
TEST_F(QnnHTPBackendTests, ContextBinaryCacheNonEmbedModeTest) {
+  ProviderOptions provider_options;
+#if defined(_WIN32)
+  provider_options["backend_path"] = "QnnHtp.dll";
+#else
+  provider_options["backend_path"] = "libQnnHtp.so";
+#endif
+  const std::string context_binary_file = "./qnn_context_cache_non_embed.onnx";
+  std::string qnn_ctx_bin = "qnn_context_cache_non_embed.onnx_QNNExecutionProvider_QNN_8283143575221199085_1_0.bin";
+
+  std::unordered_map session_option_pairs;
+  session_option_pairs.emplace(kOrtSessionOptionEpContextEnable, "1");
+  session_option_pairs.emplace(kOrtSessionOptionEpContextFilePath, context_binary_file);
+  session_option_pairs.emplace(kOrtSessionOptionEpContextEmbedMode, "0");
+
+  std::remove(context_binary_file.c_str());
+  std::remove(qnn_ctx_bin.c_str());
+
+  const TestInputDef input_def({1, 2, 3}, false, -10.0f, 10.0f);
+  const std::string op_type = "Atan";
+
+  // Runs model with DQ-> Atan-> Q and compares the outputs of the CPU and QNN EPs.
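+  // With kOrtSessionOptionEpContextEmbedMode = "0" the generated .onnx file is only a skeleton:
+  // the QNN context binary is written to the separate .bin file (qnn_ctx_bin above) and the
+  // EPContext node references it through its ep_cache_context attribute.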
+  // 1st run will generate the Onnx skeleton file + Qnn context cache binary file
+  TestQDQModelAccuracy(BuildOpTestCase(op_type, {input_def}, {}, {}),
+                       BuildQDQOpTestCase(op_type, {input_def}, {}, {}),
+                       provider_options,
+                       14,
+                       ExpectedEPNodeAssignment::All,
+                       QDQTolerance(),
+                       logging::Severity::kERROR,
+                       "",  // context model file path, not required for this inference
+                       session_option_pairs);
+
+  // Check the Onnx skeleton file is generated
+  EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str()));
+  // Check the Qnn context cache binary file is generated
+  EXPECT_TRUE(std::filesystem::exists(qnn_ctx_bin));
+
+  // 2nd run directly loads and runs from Onnx skeleton file + Qnn context cache binary file
+  TestQDQModelAccuracy(BuildOpTestCase(op_type, {input_def}, {}, {}),
+                       BuildQDQOpTestCase(op_type, {input_def}, {}, {}),
+                       provider_options,
+                       14,
+                       ExpectedEPNodeAssignment::All,
+                       QDQTolerance(),
+                       logging::Severity::kERROR,
+                       context_binary_file);
+
+  // Clean up
+  ASSERT_EQ(std::remove(context_binary_file.c_str()), 0);
+  ASSERT_EQ(std::remove(qnn_ctx_bin.c_str()), 0);
+}
+
+// Run QDQ model on HTP 2 times
+// 1st run will generate the Onnx skeleton file + Qnn context cache binary file
+// Then delete the context bin file to make the 2nd session.Initialize() return a status with code INVALID_GRAPH
+TEST_F(QnnHTPBackendTests, ContextBinaryCache_InvalidGraph) {
+  ProviderOptions provider_options;
+#if defined(_WIN32)
+  provider_options["backend_path"] = "QnnHtp.dll";
+#else
+  provider_options["backend_path"] = "libQnnHtp.so";
+#endif
+  const std::string context_binary_file = "./qnn_context_cache_non_embed.onnx";
+  std::filesystem::path context_bin = "qnn_context_cache_non_embed.onnx_QNNExecutionProvider_QNN_8283143575221199085_1_0.bin";
+  std::remove(context_binary_file.c_str());
+  std::remove(context_bin.string().c_str());
+
+  std::unordered_map session_option_pairs;
+  session_option_pairs.emplace(kOrtSessionOptionEpContextEnable, "1");
+  session_option_pairs.emplace(kOrtSessionOptionEpContextFilePath, context_binary_file);
+  session_option_pairs.emplace(kOrtSessionOptionEpContextEmbedMode, "0");
+
+  const TestInputDef input_def({1, 2, 3}, false, -10.0f, 10.0f);
+  const std::string op_type = "Atan";
+
+  // Runs model with DQ-> Atan-> Q and compares the outputs of the CPU and QNN EPs.
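+  // Deleting the .bin file after generation makes the EPContext load fail; a failure to load the
+  // context model is surfaced as a status with code INVALID_GRAPH, which is asserted below.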
+  // 1st run will generate the Onnx skeleton file + Qnn context cache binary file
+  TestQDQModelAccuracy(BuildOpTestCase(op_type, {input_def}, {}, {}),
+                       BuildQDQOpTestCase(op_type, {input_def}, {}, {}),
+                       provider_options,
+                       14,
+                       ExpectedEPNodeAssignment::All,
+                       QDQTolerance(),
+                       logging::Severity::kERROR,
+                       "",  // context model file path, not required for this inference
+                       session_option_pairs);
+
+  // Check the Onnx skeleton file is generated
+  EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str()));
+  // Check the Qnn context cache binary file is generated
+  EXPECT_TRUE(std::filesystem::exists(context_bin));
+  // Delete the Qnn context cache binary file
+  EXPECT_TRUE(std::filesystem::remove(context_bin));
+
+  // Load and run from the Onnx skeleton file + Qnn context cache binary file
+  onnx::ModelProto model_proto;
+  onnxruntime::Model qnn_ctx_model;
+  // Load the QNN context cache model from path specified
+  ASSERT_STATUS_OK(qnn_ctx_model.Load(ToPathString(context_binary_file), model_proto));
+  std::string qnn_ctx_model_data;
+  model_proto.SerializeToString(&qnn_ctx_model_data);
+
+  SessionOptions so;
+  so.session_logid = "qnn_ctx_model_logger";
+  RunOptions run_options;
+  run_options.run_tag = so.session_logid;
+
+  InferenceSessionWrapper session_object{so, GetEnvironment()};
+
+  std::string provider_type = kCpuExecutionProvider;
+  ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(QnnExecutionProviderWithOptions(provider_options)));
+  ASSERT_STATUS_OK(session_object.Load(qnn_ctx_model_data.data(), static_cast(qnn_ctx_model_data.size())));
+  // Verify the return status with code INVALID_GRAPH
+  ASSERT_TRUE(session_object.Initialize().Code() == common::StatusCode::INVALID_GRAPH);
+
+  // Clean up
+  ASSERT_EQ(std::remove(context_binary_file.c_str()), 0);
+}
+
+std::string CreateQnnCtxModelWithNonEmbedMode(std::string external_bin_path) {
+  const std::unordered_map domain_to_version = {{"", 11}, {kMSDomain, 1}};
+  auto& logging_manager = DefaultLoggingManager();
+  onnxruntime::Model model("QNN_ctx_model", false, ModelMetaData(), PathString(),
+                           IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, {},
+                           logging_manager.DefaultLogger());
+  Graph& graph = model.MainGraph();
+  ModelTestBuilder helper(graph);
+  std::vector shape = {2, 3};
+  NodeArg* graph_input = MakeTestInput(helper, TestInputDef(shape, true, {0.0f, 1.0f, 0.0f, 1.0f, 0.0f, 1.0f}));
+  auto* graph_output = helper.MakeOutput(shape);
+  Node& ep_context_node = helper.AddNode("EPContext", {graph_input}, {graph_output}, kMSDomain);
+  ep_context_node.AddAttribute("embed_mode", static_cast(0));
+  // The .. in the path will cause INVALID_GRAPH
+  ep_context_node.AddAttribute("ep_cache_context", external_bin_path);
+  ep_context_node.AddAttribute("partition_name", "QNNExecutionProvider_QNN_1110111000111000111_1_0");
+  ep_context_node.AddAttribute("source", "QNN");
+  helper.SetGraphOutputs();
+  std::string model_data;
+  model.ToProto().SerializeToString(&model_data);
+
+  return model_data;
+}
+
+// Create a model with an EPContext node. Set the node property ep_cache_context to a path that contains ".."
+// Verify that it returns INVALID_GRAPH status
+TEST_F(QnnHTPBackendTests, QnnContextBinaryRelativePathTest) {
+  std::string model_data = CreateQnnCtxModelWithNonEmbedMode("../qnn_context.bin");
+
+  SessionOptions so;
+  so.session_logid = "qnn_ctx_model_logger";
+  RunOptions run_options;
+  run_options.run_tag = so.session_logid;
+
+  InferenceSessionWrapper session_object{so, GetEnvironment()};
+
+  ProviderOptions provider_options;
+#if defined(_WIN32)
+  provider_options["backend_path"] = "QnnHtp.dll";
+#else
+  provider_options["backend_path"] = "libQnnHtp.so";
+#endif
+
+  ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(QnnExecutionProviderWithOptions(provider_options)));
+  ASSERT_STATUS_OK(session_object.Load(model_data.data(), static_cast(model_data.size())));
+  // Verify the return status with code INVALID_GRAPH
+  ASSERT_TRUE(session_object.Initialize().Code() == common::StatusCode::INVALID_GRAPH);
+}
+
+// Create a model with an EPContext node. Set the node property ep_cache_context to an absolute path
+// Verify that it returns INVALID_GRAPH status
+TEST_F(QnnHTPBackendTests, QnnContextBinaryAbsolutePathTest) {
+#if defined(_WIN32)
+  std::string external_ctx_bin_path = "D:/qnn_context.bin";
+#else
+  std::string external_ctx_bin_path = "/data/qnn_context.bin";
+#endif
+  std::string model_data = CreateQnnCtxModelWithNonEmbedMode(external_ctx_bin_path);
+
+  SessionOptions so;
+  so.session_logid = "qnn_ctx_model_logger";
+  RunOptions run_options;
+  run_options.run_tag = so.session_logid;
+
+  InferenceSessionWrapper session_object{so, GetEnvironment()};
+
+  ProviderOptions provider_options;
+#if defined(_WIN32)
+  provider_options["backend_path"] = "QnnHtp.dll";
+#else
+  provider_options["backend_path"] = "libQnnHtp.so";
+#endif
+
+  ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(QnnExecutionProviderWithOptions(provider_options)));
+  ASSERT_STATUS_OK(session_object.Load(model_data.data(), static_cast(model_data.size())));
+  // Verify the return status with code INVALID_GRAPH
+  ASSERT_TRUE(session_object.Initialize().Code() == common::StatusCode::INVALID_GRAPH);
+}
+
+// Create a model with an EPContext node. Set the node property ep_cache_context to a file that does not exist
+// Verify that it returns INVALID_GRAPH status
+TEST_F(QnnHTPBackendTests, QnnContextBinaryFileNotExistTest) {
+  std::string model_data = CreateQnnCtxModelWithNonEmbedMode("qnn_context_not_exist.bin");
+
+  SessionOptions so;
+  so.session_logid = "qnn_ctx_model_logger";
+  RunOptions run_options;
+  run_options.run_tag = so.session_logid;
+
+  InferenceSessionWrapper session_object{so, GetEnvironment()};
+
+  ProviderOptions provider_options;
+#if defined(_WIN32)
+  provider_options["backend_path"] = "QnnHtp.dll";
+#else
+  provider_options["backend_path"] = "libQnnHtp.so";
+#endif
+
+  ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(QnnExecutionProviderWithOptions(provider_options)));
+  ASSERT_STATUS_OK(session_object.Load(model_data.data(), static_cast(model_data.size())));
+  // Verify the return status with code INVALID_GRAPH
+  ASSERT_TRUE(session_object.Initialize().Code() == common::StatusCode::INVALID_GRAPH);
+}
+
+// Create a model with an EPContext node. Set the node property ep_cache_context to an empty string
+// Verify that it returns INVALID_GRAPH status
+TEST_F(QnnHTPBackendTests, QnnContextBinaryFileEmptyStringTest) {
+  std::string model_data = CreateQnnCtxModelWithNonEmbedMode("");
+
+  SessionOptions so;
+  so.session_logid = "qnn_ctx_model_logger";
+  RunOptions run_options;
+  run_options.run_tag = so.session_logid;
+
+  InferenceSessionWrapper session_object{so, GetEnvironment()};
+
+  ProviderOptions provider_options;
+#if defined(_WIN32)
+  provider_options["backend_path"] = "QnnHtp.dll";
+#else
+  provider_options["backend_path"] = "libQnnHtp.so";
+#endif
+
+  ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(QnnExecutionProviderWithOptions(provider_options)));
+  ASSERT_STATUS_OK(session_object.Load(model_data.data(), static_cast(model_data.size())));
+  // Verify the return status with code INVALID_GRAPH
+  ASSERT_TRUE(session_object.Initialize().Code() == common::StatusCode::INVALID_GRAPH);
+}
+
+// Run QDQ model on HTP with 2 inputs
+// 1st run will generate the Qnn context cache onnx file
+// 2nd run directly loads and runs from Qnn context cache model
+TEST_F(QnnHTPBackendTests, ContextBinary2InputsTest) {
+  ProviderOptions provider_options;
+#if defined(_WIN32)
+  provider_options["backend_path"] = "QnnHtp.dll";
+#else
+  provider_options["backend_path"] = "libQnnHtp.so";
+#endif
+  const std::string context_binary_file = "./qnn_context_binary_2inputs_test.onnx";
+  std::remove(context_binary_file.c_str());
+
+  std::unordered_map session_option_pairs;
+  session_option_pairs.emplace(kOrtSessionOptionEpContextEnable, "1");
+  session_option_pairs.emplace(kOrtSessionOptionEpContextFilePath, context_binary_file);
+
+  const TestInputDef input_def1({1, 2, 3}, false, -10.0f, 10.0f);
+  const TestInputDef input_def2({1, 2, 3}, false, -10.0f, 10.0f);
+  const std::string op_type = "Add";
+
+  // Runs model with DQ-> Add-> Q and compares the outputs of the CPU and QNN EPs.
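+  // Same generate-then-reload flow as ContextBinaryCacheEmbedModeTest, but with two graph inputs
+  // to cover a multi-input partition in the generated EPContext model.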
+ // 1st run will generate the Qnn context cache binary file + TestQDQModelAccuracy(BuildOpTestCase(op_type, {input_def1, input_def2}, {}, {}), + BuildQDQOpTestCase(op_type, {input_def1, input_def2}, {}, {}), + provider_options, + 14, + ExpectedEPNodeAssignment::All, + QDQTolerance(), + logging::Severity::kERROR, + "", // context model file path, not required for this inference + session_option_pairs); + + // Make sure the Qnn context cache binary file is generated + EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str())); + + // 2nd run directly loads and run from Qnn context cache model + TestQDQModelAccuracy(BuildOpTestCase(op_type, {input_def1, input_def2}, {}, {}), + BuildQDQOpTestCase(op_type, {input_def1, input_def2}, {}, {}), + provider_options, + 14, + ExpectedEPNodeAssignment::All, + QDQTolerance(), + logging::Severity::kERROR, + context_binary_file); + // Clean up + ASSERT_EQ(std::remove(context_binary_file.c_str()), 0); +} + +#endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) + +} // namespace test +} // namespace onnxruntime diff --git a/onnxruntime/test/providers/qnn/simple_op_htp_test.cc b/onnxruntime/test/providers/qnn/simple_op_htp_test.cc index 7403a7c96e8e2..2f3b0e84a123e 100644 --- a/onnxruntime/test/providers/qnn/simple_op_htp_test.cc +++ b/onnxruntime/test/providers/qnn/simple_op_htp_test.cc @@ -723,355 +723,6 @@ TEST_F(QnnHTPBackendTests, SpaceToDepthOp_U16) { true); // Use com.microsoft domain for Q/DQ ops } -// Run QDQ model on HTP 3 times -// 1st run will generate the Qnn context cache onnx file -// 2nd run directly loads and run from Qnn context cache model -TEST_F(QnnHTPBackendTests, ContextBinaryCacheEmbedModeTest) { - ProviderOptions provider_options; -#if defined(_WIN32) - provider_options["backend_path"] = "QnnHtp.dll"; -#else - provider_options["backend_path"] = "libQnnHtp.so"; -#endif - const std::string context_binary_file = "./qnn_context_binary_test.onnx"; - std::remove(context_binary_file.c_str()); - - std::unordered_map session_option_pairs; - session_option_pairs.emplace(kOrtSessionOptionEpContextEnable, "1"); - session_option_pairs.emplace(kOrtSessionOptionEpContextFilePath, context_binary_file); - - const TestInputDef input_def({1, 2, 3}, false, -10.0f, 10.0f); - const std::string op_type = "Atan"; - - // Runs model with DQ-> Atan-> Q and compares the outputs of the CPU and QNN EPs. 
- // 1st run will generate the Qnn context cache binary file - TestQDQModelAccuracy(BuildOpTestCase(op_type, {input_def}, {}, {}), - BuildQDQOpTestCase(op_type, {input_def}, {}, {}), - provider_options, - 14, - ExpectedEPNodeAssignment::All, - QDQTolerance(), - logging::Severity::kERROR, - "", // context model file path, not required for this inference - session_option_pairs); - - // Make sure the Qnn context cache binary file is generated - EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str())); - - // 2nd run directly loads and run from Qnn context cache model - TestQDQModelAccuracy(BuildOpTestCase(op_type, {input_def}, {}, {}), - BuildQDQOpTestCase(op_type, {input_def}, {}, {}), - provider_options, - 14, - ExpectedEPNodeAssignment::All, - QDQTolerance(), - logging::Severity::kERROR, - context_binary_file); - // Clean up - ASSERT_EQ(std::remove(context_binary_file.c_str()), 0); -} - -// Run QDQ model on HTP 3 times -// 1st run will generate the Onnx skeleton file + Qnn context cache binary file -// 2nd run directly loads and run from Onnx skeleton file + Qnn context cache binary file -TEST_F(QnnHTPBackendTests, ContextBinaryCacheNonEmbedModeTest) { - ProviderOptions provider_options; -#if defined(_WIN32) - provider_options["backend_path"] = "QnnHtp.dll"; -#else - provider_options["backend_path"] = "libQnnHtp.so"; -#endif - const std::string context_binary_file = "./qnn_context_cache_non_embed.onnx"; - std::string qnn_ctx_bin = "qnn_context_cache_non_embed.onnx_QNNExecutionProvider_QNN_8283143575221199085_1_0.bin"; - - std::unordered_map session_option_pairs; - session_option_pairs.emplace(kOrtSessionOptionEpContextEnable, "1"); - session_option_pairs.emplace(kOrtSessionOptionEpContextFilePath, context_binary_file); - session_option_pairs.emplace(kOrtSessionOptionEpContextEmbedMode, "0"); - - std::remove(context_binary_file.c_str()); - std::remove(qnn_ctx_bin.c_str()); - - const TestInputDef input_def({1, 2, 3}, false, -10.0f, 10.0f); - const std::string op_type = "Atan"; - - // Runs model with DQ-> Atan-> Q and compares the outputs of the CPU and QNN EPs. 
- // 1st run will generate the Onnx skeleton file + Qnn context cache binary file - TestQDQModelAccuracy(BuildOpTestCase(op_type, {input_def}, {}, {}), - BuildQDQOpTestCase(op_type, {input_def}, {}, {}), - provider_options, - 14, - ExpectedEPNodeAssignment::All, - QDQTolerance(), - logging::Severity::kERROR, - "", // context model file path, not required for this inference - session_option_pairs); - - // Check the Onnx skeleton file is generated - EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str())); - // Check the Qnn context cache binary file is generated - EXPECT_TRUE(std::filesystem::exists(qnn_ctx_bin)); - - // 2nd run directly loads and run from Onnx skeleton file + Qnn context cache binary file - TestQDQModelAccuracy(BuildOpTestCase(op_type, {input_def}, {}, {}), - BuildQDQOpTestCase(op_type, {input_def}, {}, {}), - provider_options, - 14, - ExpectedEPNodeAssignment::All, - QDQTolerance(), - logging::Severity::kERROR, - context_binary_file); - - // Clean up - ASSERT_EQ(std::remove(context_binary_file.c_str()), 0); - ASSERT_EQ(std::remove(qnn_ctx_bin.c_str()), 0); -} - -// Run QDQ model on HTP 2 times -// 1st run will generate the Onnx skeleton file + Qnn context cache binary file -// Then delete the context bin file to make the 2nd sesssion.Initialize() return the status with code INVALID_GRAPH -TEST_F(QnnHTPBackendTests, ContextBinaryCache_InvalidGraph) { - ProviderOptions provider_options; -#if defined(_WIN32) - provider_options["backend_path"] = "QnnHtp.dll"; -#else - provider_options["backend_path"] = "libQnnHtp.so"; -#endif - const std::string context_binary_file = "./qnn_context_cache_non_embed.onnx"; - std::filesystem::path context_bin = "qnn_context_cache_non_embed.onnx_QNNExecutionProvider_QNN_8283143575221199085_1_0.bin"; - std::remove(context_binary_file.c_str()); - std::remove(context_bin.string().c_str()); - - std::unordered_map session_option_pairs; - session_option_pairs.emplace(kOrtSessionOptionEpContextEnable, "1"); - session_option_pairs.emplace(kOrtSessionOptionEpContextFilePath, context_binary_file); - session_option_pairs.emplace(kOrtSessionOptionEpContextEmbedMode, "0"); - - const TestInputDef input_def({1, 2, 3}, false, -10.0f, 10.0f); - const std::string op_type = "Atan"; - - // Runs model with DQ-> Atan-> Q and compares the outputs of the CPU and QNN EPs. 
- // 1st run will generate the Onnx skeleton file + Qnn context cache binary file - TestQDQModelAccuracy(BuildOpTestCase(op_type, {input_def}, {}, {}), - BuildQDQOpTestCase(op_type, {input_def}, {}, {}), - provider_options, - 14, - ExpectedEPNodeAssignment::All, - QDQTolerance(), - logging::Severity::kERROR, - "", // context model file path, not required for this inference - session_option_pairs); - - // Check the Onnx skeleton file is generated - EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str())); - // Check the Qnn context cache binary file is generated - EXPECT_TRUE(std::filesystem::exists(context_bin)); - // Delete the Qnn context cache binary file - EXPECT_TRUE(std::filesystem::remove(context_bin)); - - // loads and run from Onnx skeleton file + Qnn context cache binary file - onnx::ModelProto model_proto; - onnxruntime::Model qnn_ctx_model; - // Load the QNN context cache model from path specified - ASSERT_STATUS_OK(qnn_ctx_model.Load(ToPathString(context_binary_file), model_proto)); - std::string qnn_ctx_model_data; - model_proto.SerializeToString(&qnn_ctx_model_data); - - SessionOptions so; - so.session_logid = "qnn_ctx_model_logger"; - RunOptions run_options; - run_options.run_tag = so.session_logid; - - InferenceSessionWrapper session_object{so, GetEnvironment()}; - - std::string provider_type = kCpuExecutionProvider; - ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(QnnExecutionProviderWithOptions(provider_options))); - ASSERT_STATUS_OK(session_object.Load(qnn_ctx_model_data.data(), static_cast(qnn_ctx_model_data.size()))); - // Verify the return status with code INVALID_GRAPH - ASSERT_TRUE(session_object.Initialize().Code() == common::StatusCode::INVALID_GRAPH); - - // Clean up - ASSERT_EQ(std::remove(context_binary_file.c_str()), 0); -} - -std::string CreateQnnCtxModelWithNonEmbedMode(std::string external_bin_path) { - const std::unordered_map domain_to_version = {{"", 11}, {kMSDomain, 1}}; - auto& logging_manager = DefaultLoggingManager(); - onnxruntime::Model model("QNN_ctx_model", false, ModelMetaData(), PathString(), - IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, {}, - logging_manager.DefaultLogger()); - Graph& graph = model.MainGraph(); - ModelTestBuilder helper(graph); - std::vector shape = {2, 3}; - NodeArg* graph_input = MakeTestInput(helper, TestInputDef(shape, true, {0.0f, 1.0f, 0.0f, 1.0f, 0.0f, 1.0f})); - auto* graph_output = helper.MakeOutput(shape); - Node& ep_context_node = helper.AddNode("EPContext", {graph_input}, {graph_output}, kMSDomain); - ep_context_node.AddAttribute("embed_mode", static_cast(0)); - // The .. in the path will cause INVALID_GRAPH - ep_context_node.AddAttribute("ep_cache_context", external_bin_path); - ep_context_node.AddAttribute("partition_name", "QNNExecutionProvider_QNN_1110111000111000111_1_0"); - ep_context_node.AddAttribute("source", "QNN"); - helper.SetGraphOutputs(); - std::string model_data; - model.ToProto().SerializeToString(&model_data); - - return model_data; -} - -// Create a model with EPContext node. Set the node property ep_cache_context has ".." 
-// Verify that it return INVALID_GRAPH status -TEST_F(QnnHTPBackendTests, QnnContextBinaryRelativePathTest) { - std::string model_data = CreateQnnCtxModelWithNonEmbedMode("../qnn_context.bin"); - - SessionOptions so; - so.session_logid = "qnn_ctx_model_logger"; - RunOptions run_options; - run_options.run_tag = so.session_logid; - - InferenceSessionWrapper session_object{so, GetEnvironment()}; - - ProviderOptions provider_options; -#if defined(_WIN32) - provider_options["backend_path"] = "QnnHtp.dll"; -#else - provider_options["backend_path"] = "libQnnHtp.so"; -#endif - - ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(QnnExecutionProviderWithOptions(provider_options))); - ASSERT_STATUS_OK(session_object.Load(model_data.data(), static_cast(model_data.size()))); - // Verify the return status with code INVALID_GRAPH - ASSERT_TRUE(session_object.Initialize().Code() == common::StatusCode::INVALID_GRAPH); -} - -// Create a model with EPContext node. Set the node property ep_cache_context has absolute path -// Verify that it return INVALID_GRAPH status -TEST_F(QnnHTPBackendTests, QnnContextBinaryAbsolutePathTest) { -#if defined(_WIN32) - std::string external_ctx_bin_path = "D:/qnn_context.bin"; -#else - std::string external_ctx_bin_path = "/data/qnn_context.bin"; -#endif - std::string model_data = CreateQnnCtxModelWithNonEmbedMode(external_ctx_bin_path); - - SessionOptions so; - so.session_logid = "qnn_ctx_model_logger"; - RunOptions run_options; - run_options.run_tag = so.session_logid; - - InferenceSessionWrapper session_object{so, GetEnvironment()}; - - ProviderOptions provider_options; -#if defined(_WIN32) - provider_options["backend_path"] = "QnnHtp.dll"; -#else - provider_options["backend_path"] = "libQnnHtp.so"; -#endif - - ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(QnnExecutionProviderWithOptions(provider_options))); - ASSERT_STATUS_OK(session_object.Load(model_data.data(), static_cast(model_data.size()))); - // Verify the return status with code INVALID_GRAPH - ASSERT_TRUE(session_object.Initialize().Code() == common::StatusCode::INVALID_GRAPH); -} - -// Create a model with EPContext node. Set the node property ep_cache_context to a file not exist -// Verify that it return INVALID_GRAPH status -TEST_F(QnnHTPBackendTests, QnnContextBinaryFileNotExistTest) { - std::string model_data = CreateQnnCtxModelWithNonEmbedMode("qnn_context_not_exist.bin"); - - SessionOptions so; - so.session_logid = "qnn_ctx_model_logger"; - RunOptions run_options; - run_options.run_tag = so.session_logid; - - InferenceSessionWrapper session_object{so, GetEnvironment()}; - - ProviderOptions provider_options; -#if defined(_WIN32) - provider_options["backend_path"] = "QnnHtp.dll"; -#else - provider_options["backend_path"] = "libQnnHtp.so"; -#endif - - ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(QnnExecutionProviderWithOptions(provider_options))); - ASSERT_STATUS_OK(session_object.Load(model_data.data(), static_cast(model_data.size()))); - // Verify the return status with code INVALID_GRAPH - ASSERT_TRUE(session_object.Initialize().Code() == common::StatusCode::INVALID_GRAPH); -} - -// Create a model with EPContext node. 
Set the node property ep_cache_context to empty string -// Verify that it return INVALID_GRAPH status -TEST_F(QnnHTPBackendTests, QnnContextBinaryFileEmptyStringTest) { - std::string model_data = CreateQnnCtxModelWithNonEmbedMode(""); - - SessionOptions so; - so.session_logid = "qnn_ctx_model_logger"; - RunOptions run_options; - run_options.run_tag = so.session_logid; - - InferenceSessionWrapper session_object{so, GetEnvironment()}; - - ProviderOptions provider_options; -#if defined(_WIN32) - provider_options["backend_path"] = "QnnHtp.dll"; -#else - provider_options["backend_path"] = "libQnnHtp.so"; -#endif - - ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(QnnExecutionProviderWithOptions(provider_options))); - ASSERT_STATUS_OK(session_object.Load(model_data.data(), static_cast(model_data.size()))); - // Verify the return status with code INVALID_GRAPH - ASSERT_TRUE(session_object.Initialize().Code() == common::StatusCode::INVALID_GRAPH); -} - -// Run QDQ model on HTP with 2 inputs -// 1st run will generate the Qnn context cache onnx file -// 2nd run directly loads and run from Qnn context cache model -TEST_F(QnnHTPBackendTests, ContextBinary2InputsTest) { - ProviderOptions provider_options; -#if defined(_WIN32) - provider_options["backend_path"] = "QnnHtp.dll"; -#else - provider_options["backend_path"] = "libQnnHtp.so"; -#endif - const std::string context_binary_file = "./qnn_context_binary_2inputs_test.onnx"; - std::remove(context_binary_file.c_str()); - - std::unordered_map session_option_pairs; - session_option_pairs.emplace(kOrtSessionOptionEpContextEnable, "1"); - session_option_pairs.emplace(kOrtSessionOptionEpContextFilePath, context_binary_file); - - const TestInputDef input_def1({1, 2, 3}, false, -10.0f, 10.0f); - const TestInputDef input_def2({1, 2, 3}, false, -10.0f, 10.0f); - const std::string op_type = "Add"; - - // Runs model with DQ-> Add-> Q and compares the outputs of the CPU and QNN EPs. 
- // 1st run will generate the Qnn context cache binary file - TestQDQModelAccuracy(BuildOpTestCase(op_type, {input_def1, input_def2}, {}, {}), - BuildQDQOpTestCase(op_type, {input_def1, input_def2}, {}, {}), - provider_options, - 14, - ExpectedEPNodeAssignment::All, - QDQTolerance(), - logging::Severity::kERROR, - "", // context model file path, not required for this inference - session_option_pairs); - - // Make sure the Qnn context cache binary file is generated - EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str())); - - // 2nd run directly loads and run from Qnn context cache model - TestQDQModelAccuracy(BuildOpTestCase(op_type, {input_def1, input_def2}, {}, {}), - BuildQDQOpTestCase(op_type, {input_def1, input_def2}, {}, {}), - provider_options, - 14, - ExpectedEPNodeAssignment::All, - QDQTolerance(), - logging::Severity::kERROR, - context_binary_file); - // Clean up - ASSERT_EQ(std::remove(context_binary_file.c_str()), 0); -} - TEST_F(QnnHTPBackendTests, QuantAccuracyTest) { ProviderOptions provider_options; From 3dfd94ba7c3f9ead417e248c476773ae4fd2018e Mon Sep 17 00:00:00 2001 From: Hector Li Date: Thu, 25 Jan 2024 11:14:38 -0800 Subject: [PATCH 19/28] rename some tests --- onnxruntime/test/providers/qnn/qnn_ep_context_test.cc | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc index 513be838f63e7..859fe4fdba4ae 100644 --- a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc +++ b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc @@ -150,7 +150,7 @@ TEST_F(QnnHTPBackendTests, QnnContextGeneration2InputsOrderIssue) { // Run QDQ model on HTP 3 times // 1st run will generate the Qnn context cache onnx file // 2nd run directly loads and run from Qnn context cache model -TEST_F(QnnHTPBackendTests, ContextBinaryCacheEmbedModeTest) { +TEST_F(QnnHTPBackendTests, QnnContextBinaryCacheEmbedModeTest) { ProviderOptions provider_options; #if defined(_WIN32) provider_options["backend_path"] = "QnnHtp.dll"; @@ -198,7 +198,7 @@ TEST_F(QnnHTPBackendTests, ContextBinaryCacheEmbedModeTest) { // Run QDQ model on HTP 3 times // 1st run will generate the Onnx skeleton file + Qnn context cache binary file // 2nd run directly loads and run from Onnx skeleton file + Qnn context cache binary file -TEST_F(QnnHTPBackendTests, ContextBinaryCacheNonEmbedModeTest) { +TEST_F(QnnHTPBackendTests, QnnContextBinaryCacheNonEmbedModeTest) { ProviderOptions provider_options; #if defined(_WIN32) provider_options["backend_path"] = "QnnHtp.dll"; @@ -254,7 +254,7 @@ TEST_F(QnnHTPBackendTests, ContextBinaryCacheNonEmbedModeTest) { // Run QDQ model on HTP 2 times // 1st run will generate the Onnx skeleton file + Qnn context cache binary file // Then delete the context bin file to make the 2nd sesssion.Initialize() return the status with code INVALID_GRAPH -TEST_F(QnnHTPBackendTests, ContextBinaryCache_InvalidGraph) { +TEST_F(QnnHTPBackendTests, QnnContextBinaryCache_InvalidGraph) { ProviderOptions provider_options; #if defined(_WIN32) provider_options["backend_path"] = "QnnHtp.dll"; @@ -450,7 +450,7 @@ TEST_F(QnnHTPBackendTests, QnnContextBinaryFileEmptyStringTest) { // Run QDQ model on HTP with 2 inputs // 1st run will generate the Qnn context cache onnx file // 2nd run directly loads and run from Qnn context cache model -TEST_F(QnnHTPBackendTests, ContextBinary2InputsTest) { +TEST_F(QnnHTPBackendTests, QnnContextBinary2InputsTest) { ProviderOptions provider_options; #if 
defined(_WIN32)
   provider_options["backend_path"] = "QnnHtp.dll";

From 3b8e8790f4532b9e5521fb7f6857996528369f75 Mon Sep 17 00:00:00 2001
From: Hector Li
Date: Thu, 25 Jan 2024 16:05:43 -0800
Subject: [PATCH 20/28] Add UT to verify the multi-partition support

---
 .../test/providers/qnn/qnn_ep_context_test.cc | 106 +++++++++++++++++-
 1 file changed, 105 insertions(+), 1 deletion(-)

diff --git a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc
index 859fe4fdba4ae..3fd499e470f85 100644
--- a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc
+++ b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc
@@ -28,6 +28,110 @@ namespace test {

 #if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__)

+// Create a model with non-quantized Add + quantized Add
+// input1 -> Add -> Q -> DQ \
+//                           Add -> Q -> DQ -> output
+// input2 -> Q -> DQ        /
+static GetTestModelFn BuildGraphWithQAndNonQ() {
+  return [](ModelTestBuilder& builder) {
+    // Creat non-quantized Add node
+    NodeArg* input1 = MakeTestInput(builder, TestInputDef<float>({2, 3}, false, {0, 1, 0, 1, 0, 1}));
+    NodeArg* add1_ini_input1 = MakeTestInput(builder, TestInputDef<float>({2, 3}, true, {0, 0, 0, 0, 0, 0}));
+
+    auto* add1_output = builder.MakeIntermediate();
+    builder.AddNode("Add", {input1, add1_ini_input1}, {add1_output});
+
+    // Create quantized Add node2
+    std::vector<float> data = {0.0f, 0.0f, 1.0f, 0.0f, 1.0f, 0.0f};
+    gsl::span<float> data_range = gsl::make_span(data);
+    QuantParams<uint8_t> q_parameter = GetDataQuantParams<uint8_t>(data_range);
+    auto* add2_input1_qdq = AddQDQNodePair<uint8_t>(builder, add1_output, q_parameter.scale, q_parameter.zero_point);
+
+    NodeArg* add2_input2 = MakeTestInput(builder, TestInputDef<float>({2, 3}, false, data));
+    auto* add2_input2_qdq = AddQDQNodePair<uint8_t>(builder, add2_input2, q_parameter.scale, q_parameter.zero_point);
+
+    auto* add2_output = builder.MakeIntermediate();
+
+    builder.AddNode("Add", {add2_input1_qdq, add2_input2_qdq}, {add2_output});
+
+    // add_output -> Q -> DQ -> output
+    AddQDQNodePairWithOutputAsGraphOutput<uint8_t>(builder, add2_output, q_parameter.scale, q_parameter.zero_point);
+  };
+}
+
+// Test that models with 1 non-quantized Add node and 1 quantized Add node can still generate the context binary
+// The generated Onnx model has 1 Add node and 1 EPContext node
+TEST_F(QnnHTPBackendTests, QnnContextBinaryMultiPartitionSupport) {
+  ProviderOptions provider_options;
+#if defined(_WIN32)
+  provider_options["backend_path"] = "QnnHtp.dll";
+#else
+  provider_options["backend_path"] = "libQnnHtp.so";
+#endif
+
+  const std::unordered_map<std::string, int> domain_to_version = {{"", 13}, {kMSDomain, 1}};
+
+  auto& logging_manager = DefaultLoggingManager();
+  logging_manager.SetDefaultLoggerSeverity(logging::Severity::kERROR);
+
+  onnxruntime::Model model("QNN_EP_TestModel", false, ModelMetaData(), PathString(),
+                           IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, {},
+                           logging_manager.DefaultLogger());
+  Graph& graph = model.MainGraph();
+  ModelTestBuilder helper(graph);
+  BuildGraphWithQAndNonQ()(helper);
+  helper.SetGraphOutputs();
+  ASSERT_STATUS_OK(model.MainGraph().Resolve());
+
+  // Serialize the model to a string.
+  std::string model_data;
+  model.ToProto().SerializeToString(&model_data);
+
+  const auto model_data_span = AsByteSpan(model_data.data(), model_data.size());
+
+  const std::string context_binary_file = "./qnn_context_binary_multi_partition_test.onnx";
+  std::remove(context_binary_file.c_str());
+  Ort::SessionOptions so;
+  so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1");
+  so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, context_binary_file.c_str());
+  so.SetLogSeverityLevel(0);
+
+  so.AppendExecutionProvider("QNN", provider_options);
+
+  Ort::Session session(*ort_env, model_data_span.data(), model_data_span.size(), so);
+
+  // Make sure the Qnn context cache binary file is generated
+  EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str()));
+
+  int ep_context_node_count = 0;
+  int non_ep_context_node_count = 0;
+  std::shared_ptr<Model> ctx_model;
+  ASSERT_STATUS_OK(Model::Load(ToPathString(context_binary_file), ctx_model, nullptr, DefaultLoggingManager().DefaultLogger()));
+  auto& ctx_graph = ctx_model->MainGraph();
+  for (auto& node : ctx_graph.Nodes()) {
+    if (node.OpType() == "EPContext") {
+      ++ep_context_node_count;
+    } else {
+      ++non_ep_context_node_count;
+    }
+  }
+
+  ASSERT_EQ(ep_context_node_count, 1);
+  ASSERT_EQ(non_ep_context_node_count, 1);
+
+  Ort::SessionOptions so2;
+  // context file path is required if it's non-embed mode and the model is loaded from memroy
+  so2.AddConfigEntry(kOrtSessionOptionEpContextFilePath, context_binary_file.c_str());
+  so2.AppendExecutionProvider("QNN", provider_options);
+
+  std::string ctx_model_data;
+  ctx_model->ToProto().SerializeToString(&ctx_model_data);
+  Ort::Session session2(*ort_env, ctx_model_data.data(), ctx_model_data.size(), so2);
+
+  // clean up
+  ASSERT_EQ(std::remove(context_binary_file.c_str()), 0);
+}
+
 // Create a model with Case + Add (quantized)
 // cast_input -> Cast -> Q -> DQ \
 //                                Add -> Q -> DQ -> output
@@ -68,7 +172,6 @@ TEST_F(QnnHTPBackendTests, QnnContextBinaryGeneration2InputTypes) {
   provider_options["backend_path"] = "libQnnHtp.so";
 #endif

-  // Add kMSDomain to cover contrib op like Gelu
   const std::unordered_map<std::string, int> domain_to_version = {{"", 13}, {kMSDomain, 1}};

   auto& logging_manager = DefaultLoggingManager();
@@ -90,6 +193,7 @@ TEST_F(QnnHTPBackendTests, QnnContextBinaryGeneration2InputTypes) {
   const auto model_data_span = AsByteSpan(model_data.data(), model_data.size());

   const std::string context_binary_file = "./qnn_context_binary_int32_fp32_inputs_test.onnx";
+  std::remove(context_binary_file.c_str());
   Ort::SessionOptions so;
   so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1");
   so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, context_binary_file.c_str());
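The UT above verifies the dumped model by loading it back through onnxruntime::Model and counting node types. The same sanity check can be done outside the test harness with nothing but the ONNX protobuf bindings; a minimal standalone sketch (not part of the patch series — the file name and printed fields are illustrative, based on the EPContext attributes used elsewhere in this change set):

    // inspect_ctx_model.cc -- link against the ONNX protobuf library.
    #include <fstream>
    #include <iostream>
    #include "onnx/onnx_pb.h"

    int main(int argc, char* argv[]) {
      if (argc != 2) {
        std::cerr << "usage: inspect_ctx_model <model_ctx.onnx>\n";
        return 1;
      }
      std::ifstream input(argv[1], std::ios::binary);
      onnx::ModelProto model;
      if (!model.ParseFromIstream(&input)) {
        std::cerr << "failed to parse model\n";
        return 1;
      }
      int ep_context_nodes = 0, other_nodes = 0;
      for (const auto& node : model.graph().node()) {
        if (node.op_type() == "EPContext") {
          ++ep_context_nodes;
          for (const auto& attr : node.attribute()) {
            // embed_mode is an int attribute; partition_name and source are strings.
            if (attr.name() == "embed_mode") {
              std::cout << "  embed_mode: " << attr.i() << "\n";
            } else if (attr.name() == "partition_name" || attr.name() == "source") {
              std::cout << "  " << attr.name() << ": " << attr.s() << "\n";
            }
          }
        } else {
          ++other_nodes;  // nodes that fell back to other EPs stay as regular ONNX nodes
        }
      }
      std::cout << "EPContext nodes: " << ep_context_nodes
                << ", other nodes: " << other_nodes << "\n";
      return 0;
    }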
From ff2c3137f941cd9a9fd2d4cc26aa00e0f0414f57 Mon Sep 17 00:00:00 2001
From: Hector Li
Date: Thu, 25 Jan 2024 20:15:35 -0800
Subject: [PATCH 21/28] fill some gaps in UT and fix an issue related to context cache path

---
 .../qnn/builder/onnx_ctx_model_helper.cc      | 21 +++++++----
 .../qnn/builder/onnx_ctx_model_helper.h       |  7 ++--
 .../providers/qnn/qnn_execution_provider.cc   | 16 +++++----
 .../test/providers/qnn/qnn_ep_context_test.cc | 36 ++++++++++++++-----
 .../test/providers/qnn/qnn_test_utils.h       |  2 +-
 5 files changed, 57 insertions(+), 25 deletions(-)

diff --git a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc
index 2ff0714292fbc..4bb7378234187 100644
--- a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc
+++ b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc
@@ -162,14 +162,23 @@ Status LoadQnnCtxFromOnnxGraph(const onnxruntime::GraphViewer& graph_viewer,
   return Status::OK();
 }

-bool IsContextCacheFileExists(const std::string& customer_context_cache_path,
-                              const onnxruntime::PathString& model_pathstring,
-                              onnxruntime::PathString& context_cache_path) {
-  // Use user provided context cache file path if exist, otherwise try model_file.onnx_ctx.onnx by default
+// Figure out the real context cache file path
+// return true if the context cache file exists
+bool ValidateContextCacheFilePath(bool is_qnn_ctx_model,
+                                  const std::string& customer_context_cache_path,
+                                  const onnxruntime::PathString& model_pathstring,
+                                  onnxruntime::PathString& context_cache_path) {
+  // always try the path set by user first, it's the only way to set it if load model from memory
   if (!customer_context_cache_path.empty()) {
     context_cache_path = ToPathString(customer_context_cache_path);
-  } else if (!model_pathstring.empty()) {
-    context_cache_path = model_pathstring + ToPathString("_ctx.onnx");
+  } else if (!model_pathstring.empty()) { // model loaded from file
+    if (is_qnn_ctx_model) {
+      // it's a context cache model, just use the model path
+      context_cache_path = model_pathstring;
+    } else if (!model_pathstring.empty()) {
+      // it's a normal Onnx model with no customer path set, create a default path for generation: model_path + _ctx.onnx
+      context_cache_path = model_pathstring + ToPathString("_ctx.onnx");
+    }
   }

   return std::filesystem::is_regular_file(context_cache_path) && std::filesystem::exists(context_cache_path);

diff --git a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h
index 3a3799bd21bbf..b1360b4e576fa 100644
--- a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h
+++ b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h
@@ -43,9 +43,10 @@ Status CreateNodeArgs(const std::vector<std::string>& names,
                       std::vector<NodeArg*>& node_args,
                       onnxruntime::Graph& graph);

-bool IsContextCacheFileExists(const std::string& customer_context_cache_path,
-                              const onnxruntime::PathString& model_pathstring,
-                              onnxruntime::PathString& context_cache_path);
+bool ValidateContextCacheFilePath(bool is_qnn_ctx_model,
+                                  const std::string& customer_context_cache_path,
+                                  const onnxruntime::PathString& model_pathstring,
+                                  onnxruntime::PathString& context_cache_path);

 Status GetEpContextFromMainNode(const onnxruntime::Node& main_context_node,
                                 const onnxruntime::PathString& ctx_onnx_model_path,
diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc
index cdb30286d4489..c547e8a06bb65 100644
--- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc
+++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc
@@ -566,16 +566,18 @@ Status QNNExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused
   onnxruntime::PathString context_cache_path;
   bool is_ctx_file_exist = false;
-  if (!is_qnn_ctx_model && context_cache_enabled_) {
+  if (is_qnn_ctx_model || context_cache_enabled_) {
     const onnxruntime::GraphViewer& graph_viewer_0(fused_nodes_and_graphs[0].filtered_graph);
-    is_ctx_file_exist = qnn::IsContextCacheFileExists(context_cache_path_cfg_,
-                                                      graph_viewer_0.ModelPath().ToPathString(),
-                                                      context_cache_path);
-    ORT_RETURN_IF(is_ctx_file_exist,
-                  "The inference session is created from normal ONNX model. And an EP context model file is provided and existed. ",
-                  "Please remove the EP context model manually if you want to re-generate it.");
+    is_ctx_file_exist = qnn::ValidateContextCacheFilePath(is_qnn_ctx_model,
+                                                          context_cache_path_cfg_,
+                                                          graph_viewer_0.ModelPath().ToPathString(),
+                                                          context_cache_path);
   }

+  ORT_RETURN_IF(is_ctx_file_exist && !is_qnn_ctx_model && context_cache_enabled_,
+                "The inference session is created from a normal ONNX model, but an EP context model file already exists. ",
+                "Please remove the EP context model manually if you want to re-generate it.");
+
   if (is_qnn_ctx_model) {
     // Table, the node name is the graph_meta_id (old) created from user model which used to generate the EP context model
     // for this session (created from an EP context model), the graph_meta_id is new
diff --git a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc
index 3fd499e470f85..16bb6abf1e758 100644
--- a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc
+++ b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc
@@ -18,8 +18,6 @@
 using namespace ONNX_NAMESPACE;
 using namespace onnxruntime::logging;

-#define ORT_MODEL_FOLDER ORT_TSTR("testdata/")
-
 // in test_main.cc
 extern std::unique_ptr<Ort::Env> ort_env;

@@ -94,8 +92,6 @@ TEST_F(QnnHTPBackendTests, QnnContextBinaryMultiPartitionSupport) {
   Ort::SessionOptions so;
   so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1");
   so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, context_binary_file.c_str());
-  so.SetLogSeverityLevel(0);
-
   so.AppendExecutionProvider("QNN", provider_options);

   Ort::Session session(*ort_env, model_data_span.data(), model_data_span.size(), so);
@@ -232,7 +228,6 @@ TEST_F(QnnHTPBackendTests, QnnContextGeneration2InputsOrderIssue) {
   Ort::SessionOptions so;
   so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1");
   so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, context_binary_file.c_str());
-
   so.AppendExecutionProvider("QNN", provider_options);

   Ort::Session session(*ort_env, ORT_TSTR("testdata/qnn_ctx_2_inputs_order_test.onnx"), so);
@@ -309,8 +304,8 @@ TEST_F(QnnHTPBackendTests, QnnContextBinaryCacheNonEmbedModeTest) {
 #else
   provider_options["backend_path"] = "libQnnHtp.so";
 #endif
-  const std::string context_binary_file = "./qnn_context_cache_non_embed.onnx";
-  std::string qnn_ctx_bin = "qnn_context_cache_non_embed.onnx_QNNExecutionProvider_QNN_8283143575221199085_1_0.bin";
+  const std::string context_binary_file = "./testdata/qnn_context_cache_non_embed.onnx";
+  std::string qnn_ctx_bin = "./testdata/qnn_context_cache_non_embed.onnx_QNNExecutionProvider_QNN_8283143575221199085_1_0.bin";

   std::unordered_map<std::string, std::string> session_option_pairs;
   session_option_pairs.emplace(kOrtSessionOptionEpContextEnable, "1");
@@ -340,6 +335,9 @@ TEST_F(QnnHTPBackendTests, QnnContextBinaryCacheNonEmbedModeTest) {
   // Check the Qnn context cache binary file is generated
   EXPECT_TRUE(std::filesystem::exists(qnn_ctx_bin));

+  std::unordered_map<std::string, std::string> session_option_pairs2;
+  // Need to set the context file path since TestQDQModelAccuracy loads the model from memory
+  session_option_pairs2.emplace(kOrtSessionOptionEpContextFilePath, context_binary_file);
   // 2nd run directly loads and run from Onnx skeleton file + Qnn context cache binary file
   TestQDQModelAccuracy(BuildOpTestCase<float>(op_type, {input_def}, {}, {}),
                        BuildQDQOpTestCase<uint8_t>(op_type, {input_def}, {}, {}),
                        provider_options,
                        14,
                        ExpectedEPNodeAssignment::All,
                        QDQTolerance(),
                        logging::Severity::kERROR,
-                       context_binary_file);
+                       context_binary_file,
+                       session_option_pairs2);
+
+  // load the model from file
+  std::vector<char> buffer;
+  {
+    std::ifstream file(context_binary_file, std::ios::binary | std::ios::ate);
+    if (!file)
+      ORT_THROW("Error reading model");
+    buffer.resize(narrow<size_t>(file.tellg()));
+    file.seekg(0, std::ios::beg);
+    if (!file.read(buffer.data(), buffer.size()))
+      ORT_THROW("Error reading model");
+  }
+
+  Ort::SessionOptions so; // No need to set the context file path in so since it's load from file
+  so.AppendExecutionProvider("QNN", provider_options);
+#ifdef _WIN32
+  std::wstring ctx_model_file(context_binary_file.begin(), context_binary_file.end());
+#else
+  std::string ctx_model_file(context_binary_file.begin(), context_binary_file.end());
+#endif
+  Ort::Session session(*ort_env.get(), ctx_model_file.c_str(), so);

   // Clean up
   ASSERT_EQ(std::remove(context_binary_file.c_str()), 0);
diff --git a/onnxruntime/test/providers/qnn/qnn_test_utils.h b/onnxruntime/test/providers/qnn/qnn_test_utils.h
index bfe5bab318313..f4febd99ddae7 100644
--- a/onnxruntime/test/providers/qnn/qnn_test_utils.h
+++ b/onnxruntime/test/providers/qnn/qnn_test_utils.h
@@ -361,7 +361,7 @@ inline void TestQDQModelAccuracy(const GetTestModelFn& f32_model_fn, const GetTe
     model_proto.SerializeToString(&qnn_ctx_model_data);
     // Run QNN context cache model on QNN EP and collect outputs.
     InferenceModel(qnn_ctx_model_data, "qnn_ctx_model_logger", qnn_options,
-                   expected_ep_assignment, qdq_helper.feeds_, qnn_qdq_outputs, is_qnn_ep);
+                   expected_ep_assignment, qdq_helper.feeds_, qnn_qdq_outputs, is_qnn_ep, session_option_pairs);
   } else {
     // Run QDQ model on QNN EP and collect outputs.
     // Only need to apply the extra session options to this QDQ model inference on QNN EP
From c8ea83d621d3fc7fd454ca7be96659a4eb777dde Mon Sep 17 00:00:00 2001
From: Hector Li
Date: Thu, 25 Jan 2024 21:12:21 -0800
Subject: [PATCH 22/28] add one more UT to dump the context cache model with 2
 EPContext nodes

---
 .../test/providers/qnn/qnn_ep_context_test.cc | 64 ++++++++++++++-----
 1 file changed, 49 insertions(+), 15 deletions(-)

diff --git a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc
index 16bb6abf1e758..7936c01362b8b 100644
--- a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc
+++ b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc
@@ -30,14 +30,14 @@ namespace test {
 // input1 -> Add -> Q -> DQ \
 //                           Add -> Q -> DQ -> output
 // input2 -> Q -> DQ        /
-static GetTestModelFn BuildGraphWithQAndNonQ() {
-  return [](ModelTestBuilder& builder) {
-    // Creat non-quantized Add node
+static GetTestModelFn BuildGraphWithQAndNonQ(bool single_ep_node = true) {
+  return [single_ep_node](ModelTestBuilder& builder) {
+    // Create non-quantized Add node1
     NodeArg* input1 = MakeTestInput(builder, TestInputDef<float>({2, 3}, false, {0, 1, 0, 1, 0, 1}));
-    NodeArg* add1_ini_input1 = MakeTestInput(builder, TestInputDef<float>({2, 3}, true, {0, 0, 0, 0, 0, 0}));
+    NodeArg* add1_ini_input2 = MakeTestInput(builder, TestInputDef<float>({2, 3}, true, {0, 0, 0, 0, 0, 0}));

     auto* add1_output = builder.MakeIntermediate();
-    builder.AddNode("Add", {input1, add1_ini_input1}, {add1_output});
+    builder.AddNode("Add", {input1, add1_ini_input2}, {add1_output});

     // Create quantized Add node2
     std::vector<float> data = {0.0f, 0.0f, 1.0f, 0.0f, 1.0f, 0.0f};
     gsl::span<float> data_range = gsl::make_span(data);
     QuantParams<uint8_t> q_parameter = GetDataQuantParams<uint8_t>(data_range);
     auto* add2_input1_qdq = AddQDQNodePair<uint8_t>(builder, add1_output, q_parameter.scale, q_parameter.zero_point);

-    NodeArg* add2_input2 = MakeTestInput(builder, TestInputDef<float>({2, 3}, false, data));
+    NodeArg* add2_input2 = MakeTestInput(builder, TestInputDef<float>({2, 3}, true, data));
     auto* add2_input2_qdq = AddQDQNodePair<uint8_t>(builder, add2_input2, q_parameter.scale, q_parameter.zero_point);

     auto* add2_output = builder.MakeIntermediate();

     builder.AddNode("Add", {add2_input1_qdq, add2_input2_qdq}, {add2_output});

-    // add_output -> Q -> DQ -> output
-    AddQDQNodePairWithOutputAsGraphOutput<uint8_t>(builder, add2_output, q_parameter.scale, q_parameter.zero_point);
+    if (single_ep_node) {
+      // add_output -> Q -> DQ -> output
+      AddQDQNodePairWithOutputAsGraphOutput<uint8_t>(builder, add2_output, q_parameter.scale, q_parameter.zero_point);
+    } else {
+      auto* add3_input1_qdq = AddQDQNodePair<uint8_t>(builder, add2_output, q_parameter.scale, q_parameter.zero_point);
+      NodeArg* add3_ini_input2 = MakeTestInput(builder, TestInputDef<float>({2, 3}, true, {0, 0, 0, 0, 0, 0}));
+
+      auto* add3_output = builder.MakeIntermediate();
+      builder.AddNode("Add", {add3_input1_qdq, add3_ini_input2}, {add3_output});
+
+      // Create quantized Add node4
+      auto* add4_input1_qdq = AddQDQNodePair<uint8_t>(builder, add3_output, q_parameter.scale, q_parameter.zero_point);
+
+      NodeArg* add4_input2 = MakeTestInput(builder, TestInputDef<float>({2, 3}, true, data));
+      auto* add4_input2_qdq = AddQDQNodePair<uint8_t>(builder, add4_input2, q_parameter.scale, q_parameter.zero_point);
+
+      auto* add4_output = builder.MakeIntermediate();
+
+      builder.AddNode("Add", {add4_input1_qdq, add4_input2_qdq}, {add4_output});
+      // add_output -> Q -> DQ -> output
+      AddQDQNodePairWithOutputAsGraphOutput<uint8_t>(builder, add4_output, q_parameter.scale, q_parameter.zero_point);
+    }
   };
 }

-// Test that models with 1 non-quantized Add node and 1 quantized Add node can still generate the context binary
-// The generated Onnx model has 1 Add node and 1 EPContext node
-TEST_F(QnnHTPBackendTests, QnnContextBinaryMultiPartitionSupport) {
+void QnnContextBinaryMultiPartitionTestBody(bool single_ep_node = true) {
   ProviderOptions provider_options;
 #if defined(_WIN32)
   provider_options["backend_path"] = "QnnHtp.dll";
 #else
   provider_options["backend_path"] = "libQnnHtp.so";
 #endif
@@ -77,7 +95,7 @@ TEST_F(QnnHTPBackendTests, QnnContextBinaryMultiPartitionSupport) {
                            logging_manager.DefaultLogger());
   Graph& graph = model.MainGraph();
   ModelTestBuilder helper(graph);
-  BuildGraphWithQAndNonQ()(helper);
+  BuildGraphWithQAndNonQ(single_ep_node)(helper);
   helper.SetGraphOutputs();
   ASSERT_STATUS_OK(model.MainGraph().Resolve());

@@ -112,11 +130,12 @@ TEST_F(QnnHTPBackendTests, QnnContextBinaryMultiPartitionSupport) {
     }
   }

-  ASSERT_EQ(ep_context_node_count, 1);
-  ASSERT_EQ(non_ep_context_node_count, 1);
+  int expected_node_count = single_ep_node ? 1 : 2;
+  ASSERT_EQ(ep_context_node_count, expected_node_count);
+  ASSERT_EQ(non_ep_context_node_count, expected_node_count);

   Ort::SessionOptions so2;
-  // context file path is required if it's non-embed mode and the model is loaded from memroy
+  // context file path is required if it's non-embed mode and the model is loaded from memory
   so2.AddConfigEntry(kOrtSessionOptionEpContextFilePath, context_binary_file.c_str());
   so2.AppendExecutionProvider("QNN", provider_options);

   std::string ctx_model_data;
   ctx_model->ToProto().SerializeToString(&ctx_model_data);
   Ort::Session session2(*ort_env, ctx_model_data.data(), ctx_model_data.size(), so2);

   // clean up
   ASSERT_EQ(std::remove(context_binary_file.c_str()), 0);
 }

+// Test that models with 1 non-quantized Add node and 1 quantized Add node can still generate the context binary
+// The generated Onnx model has 1 Add node and 1 EPContext node
+TEST_F(QnnHTPBackendTests, QnnContextBinaryMultiPartitionSupport1) {
+  bool single_ep_node = true;
+  QnnContextBinaryMultiPartitionTestBody(single_ep_node);
+}
+
+
+// Test that models with 2 non-quantized Add nodes and 2 quantized Add nodes can still generate the context binary
+// The generated Onnx model has 2 Add nodes and 2 EPContext nodes
+TEST_F(QnnHTPBackendTests, QnnContextBinaryMultiPartitionSupport2) {
+  bool single_ep_node = false;
+  QnnContextBinaryMultiPartitionTestBody(single_ep_node);
+}
+
 // Create a model with Case + Add (quantized)
 // cast_input -> Cast -> Q -> DQ \
 //                                Add -> Q -> DQ -> output

From 3dbb95d488efba93777c80e864ecfb3d14304fdd Mon Sep 17 00:00:00 2001
From: Hector Li
Date: Thu, 25 Jan 2024 21:19:23 -0800
Subject: [PATCH 23/28] formatting

---
 .../core/providers/qnn/builder/onnx_ctx_model_helper.cc | 2 +-
 onnxruntime/test/providers/qnn/qnn_ep_context_test.cc   | 3 +--
 2 files changed, 2 insertions(+), 3 deletions(-)

diff --git a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc
index 4bb7378234187..c2e71081b898e 100644
--- a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc
+++ b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc
@@ -171,7 +171,7 @@ bool ValidateContextCacheFilePath(bool is_qnn_ctx_model,
   // always try the path set by user first, it's the only way to set it if load model from memory
   if (!customer_context_cache_path.empty()) {
     context_cache_path = ToPathString(customer_context_cache_path);
-  } else if (!model_pathstring.empty()) { // model loaded from file
+  } else if (!model_pathstring.empty()) {  // model loaded from file
     if (is_qnn_ctx_model) {
       // it's a context cache model, just use the model path
       context_cache_path = model_pathstring;
diff --git a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc
index 7936c01362b8b..2945778bad5dd 100644
--- a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc
+++ b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc
@@ -154,7 +154,6 @@ TEST_F(QnnHTPBackendTests, QnnContextBinaryMultiPartitionSupport1) {
   QnnContextBinaryMultiPartitionTestBody(single_ep_node);
 }

-
 // Test that models with 2 non-quantized Add nodes and 2 quantized Add nodes can still generate the context binary
 // The generated Onnx model has 2 Add nodes and 2 EPContext nodes
 TEST_F(QnnHTPBackendTests, QnnContextBinaryMultiPartitionSupport2) {
@@ -395,7 +394,7 @@ TEST_F(QnnHTPBackendTests, QnnContextBinaryCacheNonEmbedModeTest) {
     ORT_THROW("Error reading model");
   }

-  Ort::SessionOptions so; // No need to set the context file path in so since it's load from file
+  Ort::SessionOptions so;  // No need to set the context file path in so since it's load from file
   so.AppendExecutionProvider("QNN", provider_options);
 #ifdef _WIN32
   std::wstring ctx_model_file(context_binary_file.begin(), context_binary_file.end());

From 9eb32aa73189ea4c90d33fcfc9d7700d696c185d Mon Sep 17 00:00:00 2001
From: Hector Li
Date: Fri, 26 Jan 2024 12:36:15 -0800
Subject: [PATCH 24/28] Use ContribOp FusedMatMul instead of Add op to make
 sure the float32 node falls back to the CPU EP. Add with float32 runs on
 Linux with the HTP backend simulator, which is why the test failed there.

---
 .../test/providers/qnn/qnn_ep_context_test.cc | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc
index 2945778bad5dd..8fbae36f1a3da 100644
--- a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc
+++ b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc
@@ -33,19 +33,19 @@ namespace test {
 static GetTestModelFn BuildGraphWithQAndNonQ(bool single_ep_node = true) {
   return [single_ep_node](ModelTestBuilder& builder) {
     // Create non-quantized Add node1
-    NodeArg* input1 = MakeTestInput(builder, TestInputDef<float>({2, 3}, false, {0, 1, 0, 1, 0, 1}));
-    NodeArg* add1_ini_input2 = MakeTestInput(builder, TestInputDef<float>({2, 3}, true, {0, 0, 0, 0, 0, 0}));
+    NodeArg* input1 = MakeTestInput(builder, TestInputDef<float>({2, 2}, false, {0, 1, 0, 1}));
+    NodeArg* add1_ini_input2 = MakeTestInput(builder, TestInputDef<float>({2, 2}, true, {0, 0, 0, 0}));

     auto* add1_output = builder.MakeIntermediate();
-    builder.AddNode("Add", {input1, add1_ini_input2}, {add1_output});
+    builder.AddNode("FusedMatMul", {input1, add1_ini_input2}, {add1_output}, kMSDomain);

     // Create quantized Add node2
-    std::vector<float> data = {0.0f, 0.0f, 1.0f, 0.0f, 1.0f, 0.0f};
+    std::vector<float> data = {0.0f, 0.0f, 1.0f, 0.0f};
     gsl::span<float> data_range = gsl::make_span(data);
     QuantParams<uint8_t> q_parameter = GetDataQuantParams<uint8_t>(data_range);
     auto* add2_input1_qdq = AddQDQNodePair<uint8_t>(builder, add1_output, q_parameter.scale, q_parameter.zero_point);

-    NodeArg* add2_input2 = MakeTestInput(builder, TestInputDef<float>({2, 3}, true, data));
+    NodeArg* add2_input2 = MakeTestInput(builder, TestInputDef<float>({2, 2}, true, data));
     auto* add2_input2_qdq = AddQDQNodePair<uint8_t>(builder, add2_input2, q_parameter.scale, q_parameter.zero_point);

     auto* add2_output = builder.MakeIntermediate();
@@ -57,15 +57,15 @@ static GetTestModelFn BuildGraphWithQAndNonQ(bool single_ep_node = true) {
       AddQDQNodePairWithOutputAsGraphOutput<uint8_t>(builder, add2_output, q_parameter.scale, q_parameter.zero_point);
     } else {
       auto* add3_input1_qdq = AddQDQNodePair<uint8_t>(builder, add2_output, q_parameter.scale, q_parameter.zero_point);
-      NodeArg* add3_ini_input2 = MakeTestInput(builder, TestInputDef<float>({2, 3}, true, {0, 0, 0, 0, 0, 0}));
+      NodeArg* add3_ini_input2 = MakeTestInput(builder, TestInputDef<float>({2, 2}, true, {0, 0, 0, 0}));

       auto* add3_output = builder.MakeIntermediate();
-      builder.AddNode("Add", {add3_input1_qdq, add3_ini_input2}, {add3_output});
+      builder.AddNode("FusedMatMul", {add3_input1_qdq, add3_ini_input2}, {add3_output}, kMSDomain);

       // Create quantized Add node4
       auto* add4_input1_qdq = AddQDQNodePair<uint8_t>(builder, add3_output, q_parameter.scale, q_parameter.zero_point);

-      NodeArg* add4_input2 = MakeTestInput(builder, TestInputDef<float>({2, 3}, true, data));
+      NodeArg* add4_input2 = MakeTestInput(builder, TestInputDef<float>({2, 2}, true, data));
       auto* add4_input2_qdq = AddQDQNodePair<uint8_t>(builder, add4_input2, q_parameter.scale, q_parameter.zero_point);

       auto* add4_output = builder.MakeIntermediate();
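The commit message above explains why the UT switched ops: a float32 Add can run on the Linux HTP simulator, so nothing fell back to the CPU EP and only one partition was created; a float32 FusedMatMul is rejected by the HTP backend, which guarantees the mixed CPU/QNN split the test needs. A hedged sketch of how an application can make such fallbacks visible rather than silent (assumes the disable-fallback config key from onnxruntime_session_options_config_keys.h; the model path is illustrative):

    #include <unordered_map>
    #include "onnxruntime_cxx_api.h"
    #include "onnxruntime_session_options_config_keys.h"

    void CreateQnnOnlySession(Ort::Env& env) {
      std::unordered_map<std::string, std::string> qnn_options;
      qnn_options["backend_path"] = "libQnnHtp.so";  // "QnnHtp.dll" on Windows

      Ort::SessionOptions so;
      // Fail session creation instead of silently falling back to the CPU EP for
      // unsupported nodes (e.g. a float32 FusedMatMul on the HTP backend).
      so.AddConfigEntry(kOrtSessionOptionsDisableCPUEPFallback, "1");
      so.AppendExecutionProvider("QNN", qnn_options);

      // Throws Ort::Exception if any node cannot be placed on the QNN EP.
      Ort::Session session(env, ORT_TSTR("model.onnx"), so);
    }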
From 74c5cefa4332f56d02dd9277e2d4ffd753ce84da Mon Sep 17 00:00:00 2001
From: Hector Li
Date: Wed, 31 Jan 2024 16:01:33 -0800
Subject: [PATCH 25/28] update according to review comments.

---
 onnxruntime/core/providers/qnn/qnn_execution_provider.cc | 5 ++---
 onnxruntime/test/providers/qnn/qnn_ep_context_test.cc    | 4 +---
 2 files changed, 3 insertions(+), 6 deletions(-)

diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc
index 7030545e9f2d9..2fd443bd5d083 100644
--- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc
+++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc
@@ -416,7 +416,6 @@ QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer
   }

   const auto& logger = *GetLogger();
-  bool load_from_cached_context = false;
   bool is_qnn_ctx_model = qnn::GraphHasEpContextNode(graph_viewer);

   // It will load the QnnSystem lib if is_qnn_ctx_model=true, and
@@ -492,7 +491,7 @@ QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer
     if (partition && partition->sub_graph) {
       nodes_in_partition = partition->sub_graph->nodes.size();

-      if (nodes_in_partition == 1 && !load_from_cached_context) {
+      if (nodes_in_partition == 1 && !is_qnn_ctx_model) {
         const Node* node = graph_viewer.GetNode(partition->sub_graph->nodes[0]);

         if (!node) {
@@ -523,7 +522,7 @@ QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer

   // Print list of unsupported nodes to the ERROR logger if the CPU EP
   // has been disabled for this inference session.
-  if (!load_from_cached_context && disable_cpu_ep_fallback_ && num_nodes_in_graph != num_of_supported_nodes) {
+  if (!is_qnn_ctx_model && disable_cpu_ep_fallback_ && num_nodes_in_graph != num_of_supported_nodes) {
     LOGS(logger, ERROR) << "Unsupported nodes in QNN EP: " << get_unsupported_node_names();
   }

diff --git a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc
index 8fbae36f1a3da..b1f3b52e77553 100644
--- a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc
+++ b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc
@@ -3,11 +3,9 @@

 #include
 #include
-#include

 #include "core/session/onnxruntime_cxx_api.h"
 #include "core/session/onnxruntime_session_options_config_keys.h"
-#include "core/providers/cpu/cpu_provider_factory.h"  // For OrtSessionOptionsAppendExecutionProvider_CPU
 #include "core/session/inference_session.h"

 #include "test/providers/qnn/qnn_test_utils.h"
@@ -28,7 +26,7 @@ namespace test {

 // Create a model with non-quantized Add + quantized Add
 // input1 -> Add -> Q -> DQ \
-//                           Add -> Q -> DQ -> output
+//                           FusedMatMul -> Q -> DQ -> output
 // input2 -> Q -> DQ        /

From 013717222e323871ab53631b0aadbaefd4dff0d8 Mon Sep 17 00:00:00 2001
From: Hector Li
Date: Thu, 1 Feb 2024 10:25:53 -0800
Subject: [PATCH 26/28] update according to review comments

---
 .../providers/qnn/qnn_execution_provider.cc   | 64 +++++++++++--------
 1 file changed, 37 insertions(+), 27 deletions(-)

diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc
index 2fd443bd5d083..a25b79e63fe44 100644
--- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc
+++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc
@@ -322,10 +322,18 @@ QNNExecutionProvider::GetSupportedNodes(const GraphViewer& graph_viewer,
                                         bool is_qnn_ctx_model,
                                         const logging::Logger& logger) const {
   std::unordered_set<const Node*> supported_nodes{};
-  // Filter in the EPContext node if its QNN Context model
+  // Filter in the EPContext node for QNN
   if (is_qnn_ctx_model) {
     for (const auto& node : graph_viewer.Nodes()) {
-      if (qnn::EPCONTEXT_OP == node.OpType()) {
+      NodeAttrHelper node_helper(node);
+      std::string cache_source = node_helper.Get(qnn::SOURCE, "");
+
+      std::transform(cache_source.begin(),
+                     cache_source.end(),
+                     cache_source.begin(),
+                     [](unsigned char c) { return static_cast<unsigned char>(std::tolower(c)); });
+
+      if (qnn::EPCONTEXT_OP == node.OpType() && (cache_source == "qnnexecutionprovider" || cache_source == "qnn")) {
         LOGS(logger, VERBOSE) << "Node supported: [1] index: [" << node.Index()
                               << "] name: [" << node.Name()
                               << "] Operator type: [EPContext"
@@ -484,34 +492,36 @@ QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer

     // Filter out partitions that consist of a single QuantizeLinear or DequantizeLinear node.
     // We also count the number of supported nodes in all valid partitions.
-    for (auto& partition : partitions) {
-      bool is_valid_partition = true;
-      size_t nodes_in_partition = 0;
-
-      if (partition && partition->sub_graph) {
-        nodes_in_partition = partition->sub_graph->nodes.size();
-
-        if (nodes_in_partition == 1 && !is_qnn_ctx_model) {
-          const Node* node = graph_viewer.GetNode(partition->sub_graph->nodes[0]);
-
-          if (!node) {
-            LOGS(logger, ERROR) << "QNN EP: Invalid node in partition of one node.";
-            is_valid_partition = false;
-          } else if (node->OpType() == "QuantizeLinear" || node->OpType() == "DequantizeLinear") {
-            LOGS(logger, WARNING) << "QNN EP does not support a single Quantize/Dequantize node in a partition.";
-            is_valid_partition = false;
+    if (!is_qnn_ctx_model) {
+      for (auto& partition : partitions) {
+        bool is_valid_partition = true;
+        size_t nodes_in_partition = 0;
+
+        if (partition && partition->sub_graph) {
+          nodes_in_partition = partition->sub_graph->nodes.size();
+
+          if (nodes_in_partition == 1) {
+            const Node* node = graph_viewer.GetNode(partition->sub_graph->nodes[0]);
+
+            if (!node) {
+              LOGS(logger, ERROR) << "QNN EP: Invalid node in partition of one node.";
+              is_valid_partition = false;
+            } else if (node->OpType() == "QuantizeLinear" || node->OpType() == "DequantizeLinear") {
+              LOGS(logger, WARNING) << "QNN EP does not support a single Quantize/Dequantize node in a partition.";
+              is_valid_partition = false;
+            }
           }
+        } else {
+          LOGS(logger, ERROR) << "QNN EP: Invalid partition.";
+          is_valid_partition = false;
         }
-      } else {
-        LOGS(logger, ERROR) << "QNN EP: Invalid partition.";
-        is_valid_partition = false;
-      }

-      if (is_valid_partition) {
-        result.push_back(std::move(partition));
-        num_of_supported_nodes += nodes_in_partition;
-      }
-    }
+        if (is_valid_partition) {
+          result.push_back(std::move(partition));
+          num_of_supported_nodes += nodes_in_partition;
+        }
+      }  // for
+    }  // if (!is_qnn_ctx_model)
   }

   const size_t num_of_partitions = result.size();

From b79a2566c8600100e485c281678519c2a30f178f Mon Sep 17 00:00:00 2001
From: Hector Li
Date: Thu, 1 Feb 2024 10:27:03 -0800
Subject: [PATCH 27/28] formatting

---
 onnxruntime/core/providers/qnn/qnn_execution_provider.cc | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc
index a25b79e63fe44..1e6edfdb1e3ef 100644
--- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc
+++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc
@@ -520,8 +520,8 @@ QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer
         result.push_back(std::move(partition));
         num_of_supported_nodes += nodes_in_partition;
       }
-    } // for
-  } // if (!is_qnn_ctx_model)
+    }  // for
+  }  // if (!is_qnn_ctx_model)
   }

   const size_t num_of_partitions = result.size();

From 9e71147e3ab1f55e86ed0a755c0c391aec0a5824 Mon Sep 17 00:00:00 2001
From: Hector Li
Date: Thu, 1 Feb 2024 11:27:52 -0800
Subject: [PATCH 28/28] revert a change

---
 .../providers/qnn/qnn_execution_provider.cc   | 52 +++++++++----------
 1 file changed, 25 insertions(+), 27 deletions(-)

diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc
index 1e6edfdb1e3ef..b58f6e10df94c 100644
--- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc
+++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc
@@ -492,36 +492,34 @@ QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer

     // Filter out partitions that consist of a single QuantizeLinear or DequantizeLinear node.
     // We also count the number of supported nodes in all valid partitions.
-    if (!is_qnn_ctx_model) {
-      for (auto& partition : partitions) {
-        bool is_valid_partition = true;
-        size_t nodes_in_partition = 0;
-
-        if (partition && partition->sub_graph) {
-          nodes_in_partition = partition->sub_graph->nodes.size();
-
-          if (nodes_in_partition == 1) {
-            const Node* node = graph_viewer.GetNode(partition->sub_graph->nodes[0]);
-
-            if (!node) {
-              LOGS(logger, ERROR) << "QNN EP: Invalid node in partition of one node.";
-              is_valid_partition = false;
-            } else if (node->OpType() == "QuantizeLinear" || node->OpType() == "DequantizeLinear") {
-              LOGS(logger, WARNING) << "QNN EP does not support a single Quantize/Dequantize node in a partition.";
-              is_valid_partition = false;
-            }
+    for (auto& partition : partitions) {
+      bool is_valid_partition = true;
+      size_t nodes_in_partition = 0;
+
+      if (partition && partition->sub_graph) {
+        nodes_in_partition = partition->sub_graph->nodes.size();
+
+        if (nodes_in_partition == 1 && !is_qnn_ctx_model) {
+          const Node* node = graph_viewer.GetNode(partition->sub_graph->nodes[0]);
+
+          if (!node) {
+            LOGS(logger, ERROR) << "QNN EP: Invalid node in partition of one node.";
+            is_valid_partition = false;
+          } else if (node->OpType() == "QuantizeLinear" || node->OpType() == "DequantizeLinear") {
+            LOGS(logger, WARNING) << "QNN EP does not support a single Quantize/Dequantize node in a partition.";
+            is_valid_partition = false;
          }
-        } else {
-          LOGS(logger, ERROR) << "QNN EP: Invalid partition.";
-          is_valid_partition = false;
        }
+      } else {
+        LOGS(logger, ERROR) << "QNN EP: Invalid partition.";
+        is_valid_partition = false;
+      }

-        if (is_valid_partition) {
-          result.push_back(std::move(partition));
-          num_of_supported_nodes += nodes_in_partition;
-        }
-      }  // for
-    }  // if (!is_qnn_ctx_model)
+      if (is_valid_partition) {
+        result.push_back(std::move(partition));
+        num_of_supported_nodes += nodes_in_partition;
+      }
+    }  // for
   }

   const size_t num_of_partitions = result.size();
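Taken together, the series enables this end-to-end flow for models that split across the QNN EP and the CPU EP. A minimal usage sketch with the public C++ API, mirroring the session options exercised by the tests above (model file names are illustrative; error handling omitted):

    #include <unordered_map>
    #include "onnxruntime_cxx_api.h"
    #include "onnxruntime_session_options_config_keys.h"

    int main() {
      Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "qnn_ctx_demo");
      std::unordered_map<std::string, std::string> qnn_options;
    #if defined(_WIN32)
      qnn_options["backend_path"] = "QnnHtp.dll";
    #else
      qnn_options["backend_path"] = "libQnnHtp.so";
    #endif

      {
        // 1st session: compile the model and dump the EP context model.
        // Supported subgraphs become EPContext nodes; everything else stays
        // as regular ONNX nodes that run on other EPs.
        Ort::SessionOptions so;
        so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1");
        so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, "./model.onnx_ctx.onnx");
        so.AppendExecutionProvider("QNN", qnn_options);
        Ort::Session generate(env, ORT_TSTR("model.onnx"), so);
      }  // the dump happens during session initialization

      // 2nd session: load the pre-compiled context model directly; the QNN
      // graph compile step is skipped, so initialization is much faster.
      Ort::SessionOptions so2;
      so2.AppendExecutionProvider("QNN", qnn_options);
      Ort::Session run(env, ORT_TSTR("model.onnx_ctx.onnx"), so2);
      return 0;
    }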