diff --git a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc index 6550b9aef1b8c..caef6a3d06f17 100644 --- a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc +++ b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc @@ -60,10 +60,10 @@ Status CreateNodeArgs(const std::vector& names, return Status::OK(); } -Status QnnCacheModelHandler::GetEpContextFromModel(const std::string& ctx_onnx_model_path, - QnnBackendManager* qnn_backend_manager, - QnnModel& qnn_model, - const logging::Logger& logger) { +Status GetEpContextFromModel(const std::string& ctx_onnx_model_path, + QnnBackendManager* qnn_backend_manager, + QnnModel& qnn_model, + const logging::Logger& logger) { using namespace onnxruntime; std::shared_ptr model; ORT_RETURN_IF_ERROR(Model::Load(ToPathString(ctx_onnx_model_path), model, {}, logger)); @@ -74,10 +74,10 @@ Status QnnCacheModelHandler::GetEpContextFromModel(const std::string& ctx_onnx_m qnn_model); } -Status QnnCacheModelHandler::GetEpContextFromGraph(const onnxruntime::GraphViewer& graph_viewer, - const std::string& ctx_onnx_model_path, - QnnBackendManager* qnn_backend_manager, - QnnModel& qnn_model) { +Status GetEpContextFromGraph(const onnxruntime::GraphViewer& graph_viewer, + const std::string& ctx_onnx_model_path, + QnnBackendManager* qnn_backend_manager, + QnnModel& qnn_model) { const auto& node = graph_viewer.Nodes().begin(); NodeAttrHelper node_helper(*node); bool is_embed_mode = node_helper.Get(EMBED_MODE, true); @@ -112,12 +112,27 @@ Status QnnCacheModelHandler::GetEpContextFromGraph(const onnxruntime::GraphViewe qnn_model); } -Status QnnCacheModelHandler::GetMetadataFromEpContextModel(const std::string& ctx_onnx_model_path, - std::string& model_name, - std::string& model_description, - std::string& graph_partition_name, - std::string& cache_source, - const logging::Logger& logger) { +Status LoadQnnCtxFromOnnxModel(const onnxruntime::GraphViewer& graph_viewer, + const std::string& ctx_onnx_model_path, + bool is_qnn_ctx_model, + bool is_ctx_cache_file_exist, + QnnBackendManager* qnn_backend_manager, + QnnModel& qnn_model, + const logging::Logger& logger) { + if (is_qnn_ctx_model) { + return GetEpContextFromGraph(graph_viewer, ctx_onnx_model_path, qnn_backend_manager, qnn_model); + } else if (is_ctx_cache_file_exist) { + return GetEpContextFromModel(ctx_onnx_model_path, qnn_backend_manager, qnn_model, logger); + } + return Status::OK(); +} + +Status GetMetadataFromEpContextModel(const std::string& ctx_onnx_model_path, + std::string& model_name, + std::string& model_description, + std::string& graph_partition_name, + std::string& cache_source, + const logging::Logger& logger) { using namespace onnxruntime; std::shared_ptr model; ORT_RETURN_IF_ERROR(Model::Load(ToPathString(ctx_onnx_model_path), model, {}, logger)); @@ -132,33 +147,29 @@ Status QnnCacheModelHandler::GetMetadataFromEpContextModel(const std::string& ct return Status::OK(); } -bool QnnCacheModelHandler::IsContextCacheFileExists(const std::string& customer_context_cache_path, - const std::string& model_name, - const std::string& model_description, - const onnxruntime::PathString& model_pathstring) { - model_name_ = model_name; - model_description_ = model_description; +bool IsContextCacheFileExists(const std::string& customer_context_cache_path, + const onnxruntime::PathString& model_pathstring, + std::string& context_cache_path) { // Use user provided context cache file path if exist, otherwise try model_file.onnx_ctx.onnx by default if (!customer_context_cache_path.empty()) { - context_cache_path_ = customer_context_cache_path; + context_cache_path = customer_context_cache_path; } else if (!model_pathstring.empty()) { - context_cache_path_ = PathToUTF8String(model_pathstring) + "_qnn_ctx.onnx"; + context_cache_path = PathToUTF8String(model_pathstring) + "_qnn_ctx.onnx"; } - ctx_file_exists_ = std::filesystem::is_regular_file(context_cache_path_) && std::filesystem::exists(context_cache_path_); - return ctx_file_exists_; + return std::filesystem::is_regular_file(context_cache_path) && std::filesystem::exists(context_cache_path); } -Status QnnCacheModelHandler::ValidateWithContextFile(const std::string& model_name, - const std::string& graph_partition_name, - const logging::Logger& logger) { - ORT_RETURN_IF(!ctx_file_exists_, "Qnn context binary file not exist for some reason!"); - +Status ValidateWithContextFile(const std::string& context_cache_path, + const std::string& model_name, + const std::string& model_description, + const std::string& graph_partition_name, + const logging::Logger& logger) { std::string model_name_from_ctx_cache; std::string model_description_from_ctx_cache; std::string graph_partition_name_from_ctx_cache; std::string cache_source; - ORT_RETURN_IF_ERROR(GetMetadataFromEpContextModel(context_cache_path_, + ORT_RETURN_IF_ERROR(GetMetadataFromEpContextModel(context_cache_path, model_name_from_ctx_cache, model_description_from_ctx_cache, graph_partition_name_from_ctx_cache, @@ -175,31 +186,34 @@ Status QnnCacheModelHandler::ValidateWithContextFile(const std::string& model_na " is different with target: " + model_name + ". Please make sure the context binary file matches the model."); - ORT_RETURN_IF(model_description_ != model_description_from_ctx_cache, + ORT_RETURN_IF(model_description != model_description_from_ctx_cache, "Model description from context cache metadata: " + model_description_from_ctx_cache + - " is different with target: " + model_description_ + + " is different with target: " + model_description + ". Please make sure the context binary file matches the model."); - ORT_RETURN_IF(graph_partition_name != graph_partition_name_from_ctx_cache && get_capability_round_2_, + ORT_RETURN_IF(graph_partition_name != graph_partition_name_from_ctx_cache, "Graph name from context cache metadata: " + graph_partition_name_from_ctx_cache + " is different with target: " + graph_partition_name + ". You may need to re-generate the context binary file."); - get_capability_round_2_ = true; return Status::OK(); } -Status QnnCacheModelHandler::GenerateCtxCacheOnnxModel(unsigned char* buffer, - uint64_t buffer_size, - const std::string& sdk_build_version, - const std::vector& fused_nodes_and_graphs, - const std::unordered_map>& qnn_models, - const logging::Logger& logger) { +Status GenerateCtxCacheOnnxModel(const std::string model_name, + const std::string model_description, + unsigned char* buffer, + uint64_t buffer_size, + const std::string& sdk_build_version, + const std::vector& fused_nodes_and_graphs, + const std::unordered_map>& qnn_models, + const std::string& context_cache_path, + bool qnn_context_embed_mode, + const logging::Logger& logger) { std::unordered_map domain_to_version = {{kOnnxDomain, 11}, {kMSDomain, 1}}; - Model model(model_name_, false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), + Model model(model_name, false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, {}, logger); auto& graph = model.MainGraph(); - graph.SetDescription(model_description_); + graph.SetDescription(model_description); using namespace ONNX_NAMESPACE; int index = 0; @@ -227,13 +241,13 @@ Status QnnCacheModelHandler::GenerateCtxCacheOnnxModel(unsigned char* buffer, // Only dump the context buffer once since all QNN graph are in one single context if (0 == index) { - if (qnn_context_embed_mode_) { + if (qnn_context_embed_mode) { std::string cache_payload(buffer, buffer + buffer_size); ep_node.AddAttribute(EP_CACHE_CONTEXT, cache_payload); } else { - std::string context_cache_path(context_cache_path_ + "_" + graph_name + ".bin"); - std::string context_cache_name(std::filesystem::path(context_cache_path).filename().string()); - std::ofstream of_stream(context_cache_path.c_str(), std::ofstream::binary); + std::string context_bin_path(context_cache_path + "_" + graph_name + ".bin"); + std::string context_cache_name(std::filesystem::path(context_bin_path).filename().string()); + std::ofstream of_stream(context_bin_path.c_str(), std::ofstream::binary); if (!of_stream) { LOGS(logger, ERROR) << "Failed to open create context file."; return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to open context cache file."); @@ -244,7 +258,7 @@ Status QnnCacheModelHandler::GenerateCtxCacheOnnxModel(unsigned char* buffer, } else { ep_node.AddAttribute(MAIN_CONTEXT, static_cast(0)); } - int64_t embed_mode = qnn_context_embed_mode_ ? static_cast(1) : static_cast(0); + int64_t embed_mode = qnn_context_embed_mode ? static_cast(1) : static_cast(0); ep_node.AddAttribute(EMBED_MODE, embed_mode); ep_node.AddAttribute(EP_SDK_VER, sdk_build_version); ep_node.AddAttribute(PARTITION_NAME, graph_name); @@ -252,7 +266,7 @@ Status QnnCacheModelHandler::GenerateCtxCacheOnnxModel(unsigned char* buffer, ++index; } ORT_RETURN_IF_ERROR(graph.Resolve()); - ORT_RETURN_IF_ERROR(Model::Save(model, context_cache_path_)); + ORT_RETURN_IF_ERROR(Model::Save(model, context_cache_path)); return Status::OK(); } diff --git a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h index 4a7c1c778bdbc..ed425a0f562a4 100644 --- a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h +++ b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h @@ -38,78 +38,50 @@ Status CreateNodeArgs(const std::vector& names, std::vector& node_args, onnxruntime::Graph& graph); -class QnnCacheModelHandler { - public: - QnnCacheModelHandler(bool qnn_context_embed_mode) : qnn_context_embed_mode_(qnn_context_embed_mode) { - } - ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(QnnCacheModelHandler); - - Status LoadQnnCtxFromOnnxModel(const onnxruntime::GraphViewer& graph_viewer, - const std::string& ctx_onnx_model_path, - bool is_qnn_ctx_model, - bool is_ctx_cache_file_exist, - QnnBackendManager* qnn_backend_manager, - QnnModel& qnn_model, - const logging::Logger& logger) { - if (is_qnn_ctx_model) { - return GetEpContextFromGraph(graph_viewer, ctx_onnx_model_path, qnn_backend_manager, qnn_model); - } else if (is_ctx_cache_file_exist) { - return GetEpContextFromModel(ctx_onnx_model_path, qnn_backend_manager, qnn_model, logger); - } - return Status::OK(); - } - - bool IsContextCacheFileExists(const std::string& customer_context_cache_path, - const std::string& model_name, - const std::string& model_description, - const onnxruntime::PathString& model_pathstring); - - bool GetIsContextCacheFileExists() const { - return ctx_file_exists_; - } - - Status ValidateWithContextFile(const std::string& model_name, - const std::string& graph_name, - const logging::Logger& logger); +bool IsContextCacheFileExists(const std::string& customer_context_cache_path, + const onnxruntime::PathString& model_pathstring, + std::string& context_cache_path); + +Status GetEpContextFromModel(const std::string& ctx_onnx_model_path, + QnnBackendManager* qnn_backend_manager, + QnnModel& qnn_model, + const logging::Logger& logger); - Status GetMetadataFromEpContextModel(const std::string& ctx_onnx_model_path, - std::string& model_name, - std::string& model_description, - std::string& graph_partition_name, - std::string& cache_source, - const logging::Logger& logger); - - Status GenerateCtxCacheOnnxModel(unsigned char* buffer, - uint64_t buffer_size, - const std::string& sdk_build_version, - const std::vector& fused_nodes_and_graphs, - const std::unordered_map>& qnn_models, - const logging::Logger& logger); - - private: - Status GetEpContextFromModel(const std::string& ctx_onnx_model_path, +Status GetEpContextFromGraph(const onnxruntime::GraphViewer& graph_viewer, + const std::string& ctx_onnx_model_path, + QnnBackendManager* qnn_backend_manager, + QnnModel& qnn_model); + +Status LoadQnnCtxFromOnnxModel(const onnxruntime::GraphViewer& graph_viewer, + const std::string& ctx_onnx_model_path, + bool is_qnn_ctx_model, + bool is_ctx_cache_file_exist, QnnBackendManager* qnn_backend_manager, QnnModel& qnn_model, const logging::Logger& logger); - Status GetEpContextFromGraph(const onnxruntime::GraphViewer& graph_viewer, - const std::string& ctx_onnx_model_path, - QnnBackendManager* qnn_backend_manager, - QnnModel& qnn_model); - - private: - bool is_metadata_ready_ = false; - // model_name_ to cache_source_ -- metadata get from generated Qnn context binary Onnx model - std::string model_name_ = ""; - std::string model_description_ = ""; - std::string graph_partition_name_ = ""; - std::string cache_source_ = ""; - - std::string context_cache_path_ = ""; - bool ctx_file_exists_ = false; - bool get_capability_round_2_ = false; - bool qnn_context_embed_mode_ = true; -}; // QnnCacheModelHandler +Status ValidateWithContextFile(const std::string& context_cache_path, + const std::string& model_name, + const std::string& model_description, + const std::string& graph_partition_name, + const logging::Logger& logger); +Status GetMetadataFromEpContextModel(const std::string& ctx_onnx_model_path, + std::string& model_name, + std::string& model_description, + std::string& graph_partition_name, + std::string& cache_source, + const logging::Logger& logger); + +Status GenerateCtxCacheOnnxModel(const std::string model_name, + const std::string model_description, + unsigned char* buffer, + uint64_t buffer_size, + const std::string& sdk_build_version, + const std::vector& fused_nodes_and_graphs, + const std::unordered_map>& qnn_models, + const std::string& context_cache_path, + bool qnn_context_embed_mode, + const logging::Logger& logger); } // namespace qnn } // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h index aac82c89d6f49..4edccea661642 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h @@ -22,7 +22,6 @@ namespace onnxruntime { namespace qnn { class QnnModel; -class QnnCacheModelHandler; class QnnBackendManager { public: diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc index f198e0264410f..4772da104c389 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc @@ -16,6 +16,7 @@ #include "core/providers/qnn/builder/qnn_model_wrapper.h" #include "core/providers/qnn/builder/op_builder_factory.h" #include "core/providers/qnn/builder/qnn_def.h" +#include "core/providers/qnn/builder/onnx_ctx_model_helper.h" namespace onnxruntime { @@ -134,16 +135,15 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio static const std::string CONTEXT_CACHE_PATH = "qnn_context_cache_path"; auto context_cache_path_pos = provider_options_map.find(CONTEXT_CACHE_PATH); if (context_cache_path_pos != provider_options_map.end()) { - context_cache_path_ = context_cache_path_pos->second; - LOGS_DEFAULT(VERBOSE) << "User specified context cache path: " << context_cache_path_; + context_cache_path_cfg_ = context_cache_path_pos->second; + LOGS_DEFAULT(VERBOSE) << "User specified context cache path: " << context_cache_path_cfg_; } - bool qnn_context_embed_mode = true; static const std::string CONTEXT_CACHE_EMBED_MODE = "qnn_context_embed_mode"; auto context_cache_embed_mode_pos = provider_options_map.find(CONTEXT_CACHE_EMBED_MODE); if (context_cache_embed_mode_pos != provider_options_map.end()) { - qnn_context_embed_mode = context_cache_embed_mode_pos->second == "1"; - LOGS_DEFAULT(VERBOSE) << "User specified context cache embed mode: " << qnn_context_embed_mode; + qnn_context_embed_mode_ = context_cache_embed_mode_pos->second == "1"; + LOGS_DEFAULT(VERBOSE) << "User specified context cache embed mode: " << qnn_context_embed_mode_; } static const std::string BACKEND_PATH = "backend_path"; @@ -206,7 +206,6 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio htp_performance_mode_, context_priority_, std::move(qnn_saver_path)); - qnn_cache_model_handler_ = std::make_unique(qnn_context_embed_mode); } bool QNNExecutionProvider::IsNodeSupported(qnn::QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit, @@ -343,10 +342,10 @@ QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer // This is for case: QDQ model + Onnx Qnn context cache model if (context_cache_enabled_ && !is_qnn_ctx_model) { - load_from_cached_context = qnn_cache_model_handler_->IsContextCacheFileExists(context_cache_path_, - graph_viewer.Name(), - graph_viewer.Description(), - graph_viewer.ModelPath().ToPathString()); + std::string context_cache_path; + load_from_cached_context = qnn::IsContextCacheFileExists(context_cache_path_cfg_, + graph_viewer.ModelPath().ToPathString(), + context_cache_path); } // Load from cached context will load the QnnSystem lib and skip the Qnn context creation @@ -537,24 +536,32 @@ Status QNNExecutionProvider::Compile(const std::vector& fused bool is_qnn_ctx_model = false; ORT_RETURN_IF_ERROR(qnn::IsFusedGraphHasCtxNode(fused_nodes_and_graphs, is_qnn_ctx_model)); - bool is_ctx_file_exist = qnn_cache_model_handler_->GetIsContextCacheFileExists(); + std::string context_cache_path; + bool is_ctx_file_exist = qnn::IsContextCacheFileExists(context_cache_path_cfg_, + graph_viewer.ModelPath().ToPathString(), + context_cache_path); + const std::string& model_name = graph_viewer.GetGraph().Name(); + LOGS(logger, VERBOSE) << "graph_viewer.GetGraph().Name(): " << model_name; + LOGS(logger, VERBOSE) << "graph_viewer.GetGraph().Description(): " << graph_viewer.GetGraph().Description(); if (fused_nodes_and_graphs.size() == 1 && !is_qnn_ctx_model && is_ctx_file_exist) { - ORT_RETURN_IF_ERROR(qnn_cache_model_handler_->ValidateWithContextFile(graph_viewer.Name(), - graph_viewer.Name(), - logger)); + ORT_RETURN_IF_ERROR(qnn::ValidateWithContextFile(context_cache_path, + model_name, + graph_viewer.GetGraph().Description(), + graph_viewer.Name(), + logger)); } if (is_qnn_ctx_model || (context_cache_enabled_ && is_ctx_file_exist)) { ORT_RETURN_IF(fused_nodes_and_graphs.size() != 1, "Only support single partition for context cache feature."); std::unique_ptr qnn_model = std::make_unique(logger, qnn_backend_manager_.get()); // Load and execute from cached context if exist - ORT_RETURN_IF_ERROR(qnn_cache_model_handler_->LoadQnnCtxFromOnnxModel(graph_viewer, - context_cache_path_, - is_qnn_ctx_model, - is_ctx_file_exist, - qnn_backend_manager_.get(), - *(qnn_model.get()), - logger)); + ORT_RETURN_IF_ERROR(qnn::LoadQnnCtxFromOnnxModel(graph_viewer, + context_cache_path_cfg_, + is_qnn_ctx_model, + is_ctx_file_exist, + qnn_backend_manager_.get(), + *(qnn_model.get()), + logger)); ORT_RETURN_IF_ERROR(qnn_model->SetGraphInputOutputInfo(graph_viewer, fused_node)); ORT_RETURN_IF_ERROR(qnn_model->SetupQnnInputOutput()); @@ -572,12 +579,16 @@ Status QNNExecutionProvider::Compile(const std::vector& fused ORT_RETURN_IF(fused_nodes_and_graphs.size() != 1, "Only support single partition for context cache feature."); uint64_t buffer_size(0); auto context_buffer = qnn_backend_manager_->GetContextBinaryBuffer(buffer_size); - ORT_RETURN_IF_ERROR(qnn_cache_model_handler_->GenerateCtxCacheOnnxModel(context_buffer.get(), - buffer_size, - qnn_backend_manager_->GetSdkVersion(), - fused_nodes_and_graphs, - qnn_models_, - logger)); + ORT_RETURN_IF_ERROR(qnn::GenerateCtxCacheOnnxModel(model_name, + graph_viewer.GetGraph().Description(), + context_buffer.get(), + buffer_size, + qnn_backend_manager_->GetSdkVersion(), + fused_nodes_and_graphs, + qnn_models_, + context_cache_path, + qnn_context_embed_mode_, + logger)); } return Status::OK(); } diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.h b/onnxruntime/core/providers/qnn/qnn_execution_provider.h index cf0bff8890d0c..8c99a916a6f69 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.h +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.h @@ -8,7 +8,6 @@ #include #include "core/providers/qnn/builder/qnn_backend_manager.h" #include "core/providers/qnn/builder/qnn_model.h" -#include "core/providers/qnn/builder/onnx_ctx_model_helper.h" #include "core/providers/qnn/builder/qnn_graph_configs_helper.h" namespace onnxruntime { @@ -71,10 +70,10 @@ class QNNExecutionProvider : public IExecutionProvider { std::unordered_map> qnn_models_; uint32_t rpc_control_latency_ = 0; bool context_cache_enabled_ = false; - std::string context_cache_path_ = ""; + std::string context_cache_path_cfg_ = ""; bool disable_cpu_ep_fallback_ = false; // True if CPU EP fallback has been disabled for this session. - std::unique_ptr qnn_cache_model_handler_; qnn::ContextPriority context_priority_ = qnn::ContextPriority::NORMAL; + bool qnn_context_embed_mode_ = true; }; } // namespace onnxruntime