diff --git a/onnxruntime/core/framework/graph_partitioner.cc b/onnxruntime/core/framework/graph_partitioner.cc
index 07b465c80745a..90ee8a46f66a9 100644
--- a/onnxruntime/core/framework/graph_partitioner.cc
+++ b/onnxruntime/core/framework/graph_partitioner.cc
@@ -645,6 +645,10 @@ static Status CreateEpContextModel(const ExecutionProviders& execution_providers
     all_ep_context_nodes.insert(all_ep_context_nodes.begin(), ep_context_nodes.begin(), ep_context_nodes.end());
   }
 
+  if (all_ep_context_nodes.size() < 1) {
+    return Status::OK();
+  }
+
   auto get_ep_context_node = [&all_ep_context_nodes](const std::string& node_name) -> std::pair<bool, const Node*> {
     for (auto& node : all_ep_context_nodes) {
       if (node_name == node->Name()) {
@@ -656,76 +660,70 @@ static Status CreateEpContextModel(const ExecutionProviders& execution_providers
   onnxruntime::PathString context_cache_path;
   PathString model_pathstring = graph.ModelPath().ToPathString();
-  if (all_ep_context_nodes.size() > 0) {
-    if (!ep_context_path.empty()) {
-      context_cache_path = ToPathString(ep_context_path);
-    } else if (!model_pathstring.empty()) {
-      context_cache_path = model_pathstring + ToPathString("_ctx.onnx");
-    }
-    {
+  if (!ep_context_path.empty()) {
+    context_cache_path = ToPathString(ep_context_path);
+  } else if (!model_pathstring.empty()) {
+    context_cache_path = model_pathstring + ToPathString("_ctx.onnx");
+  }
+
+  {
 #ifdef _WIN32
-      std::wifstream fs(context_cache_path);
+    std::wifstream fs(context_cache_path);
 #else
-      std::ifstream fs(context_cache_path);
+    std::ifstream fs(context_cache_path);
 #endif
-      ORT_RETURN_IF(fs.good(), "Failed to generate EP context model since the file exist already.");
-    }
+    ORT_RETURN_IF(fs.good(), "Failed to generate EP context model since the file exists already.");
+  }
 
-    Model ep_context_model(graph.Name(), false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(),
-                           graph.DomainToVersionMap(), {}, logger);
-    auto& ep_graph = ep_context_model.MainGraph();
-    ep_graph.SetDescription(graph.Description());
-
-    // Set inputs outputs explicitly to make sure the order is same as the user model.
-    auto inputs = graph.GetInputs();
-    auto outputs = graph.GetOutputs();
-
-    InlinedVector<const NodeArg*> ep_graph_inputs;
-    ep_graph_inputs.reserve(inputs.size());
-    for (auto& input : inputs) {
-      auto input_arg = graph.GetNodeArg(input->Name());
-      auto& ep_graph_input_arg = ep_graph.GetOrCreateNodeArg(input_arg->Name(), input_arg->TypeAsProto());
-      ep_graph_inputs.push_back(&ep_graph_input_arg);
-    }
+  Model ep_context_model(graph.Name(), false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(),
+                         graph.DomainToVersionMap(), {}, logger);
+  auto& ep_graph = ep_context_model.MainGraph();
+  ep_graph.SetDescription(graph.Description());
 
-    InlinedVector<const NodeArg*> ep_graph_outputs;
-    ep_graph_outputs.reserve(outputs.size());
-    for (auto& output : outputs) {
-      auto output_arg = graph.GetNodeArg(output->Name());
-      auto& ep_graph_output_arg = ep_graph.GetOrCreateNodeArg(output_arg->Name(), output_arg->TypeAsProto());
-      ep_graph_outputs.push_back(&ep_graph_output_arg);
-    }
+  // Set inputs and outputs explicitly to make sure the order is the same as in the user model.
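+  // The generated EP context model must expose the same I/O signature as the user
+  // model, so the I/O NodeArgs are recreated here from the original graph rather than
+  // derived from the EPContext nodes (see the order-sensitivity test
+  // QnnContextGeneration2InputsOrderIssue further below).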
+  auto inputs = graph.GetInputs();
+  auto outputs = graph.GetOutputs();
 
-    ep_graph.SetInputs(ep_graph_inputs);
-    ep_graph.SetOutputs(ep_graph_outputs);
+  InlinedVector<const NodeArg*> ep_graph_inputs;
+  ep_graph_inputs.reserve(inputs.size());
+  for (auto& input : inputs) {
+    auto input_arg = graph.GetNodeArg(input->Name());
+    auto& ep_graph_input_arg = ep_graph.GetOrCreateNodeArg(input_arg->Name(), input_arg->TypeAsProto());
+    ep_graph_inputs.push_back(&ep_graph_input_arg);
+  }
 
-    for (const auto& node : graph.Nodes()) {
-      // the fused node and EPContext node has same node name
-      auto ep_context_node = get_ep_context_node(node.Name());
-      // Use EpContext node created by the EPs if name matched, otherwise use node from original model
-      if (ep_context_node.first) {
-        ep_graph.AddNode(*ep_context_node.second);
-      } else {
-        ep_graph.AddNode(node);
-      }
-    }
+  InlinedVector<const NodeArg*> ep_graph_outputs;
+  ep_graph_outputs.reserve(outputs.size());
+  for (auto& output : outputs) {
+    auto output_arg = graph.GetNodeArg(output->Name());
+    auto& ep_graph_output_arg = ep_graph.GetOrCreateNodeArg(output_arg->Name(), output_arg->TypeAsProto());
+    ep_graph_outputs.push_back(&ep_graph_output_arg);
+  }
 
-    // handle initializers
-    for (const auto& input : graph.GetInputsIncludingInitializers()) {
-      const ONNX_NAMESPACE::TensorProto* initializer = nullptr;
-      if (graph.GetInitializedTensor(input->Name(), initializer)) {
-        // There initializer could have duplicates so make sure we only add once
-        const ONNX_NAMESPACE::TensorProto* subgraph_initializer = nullptr;
-        if (!ep_graph.GetInitializedTensor(input->Name(), subgraph_initializer)) {
-          ep_graph.AddInitializedTensor(*initializer);
-        }
-      }
+  ep_graph.SetInputs(ep_graph_inputs);
+  ep_graph.SetOutputs(ep_graph_outputs);
+
+  for (const auto& node : graph.Nodes()) {
+    // the fused node and the EPContext node have the same node name
+    auto ep_context_node = get_ep_context_node(node.Name());
+    // Use the EPContext node created by the EPs if the name matched, otherwise use the node from the original model
+    if (ep_context_node.first) {
+      ep_graph.AddNode(*ep_context_node.second);
+    } else {
+      ep_graph.AddNode(node);
     }
+  }
 
-    ORT_RETURN_IF_ERROR(Model::Save(ep_context_model, context_cache_path));
+  // handle initializers
+  for (const auto& initialized_tensor : graph.GetAllInitializedTensors()) {
+    if (ep_graph.GetNodeArg(initialized_tensor.first) != nullptr) {
+      ep_graph.AddInitializedTensor(*initialized_tensor.second);
+    }
   }
 
+  ORT_RETURN_IF_ERROR(Model::Save(ep_context_model, context_cache_path));
+
   return Status::OK();
 }
 
diff --git a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc
index 5d3f406f50612..c2e71081b898e 100644
--- a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc
+++ b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc
@@ -12,34 +12,60 @@ namespace onnxruntime {
 namespace qnn {
 
-Status IsFusedGraphHasCtxNode(const std::vector<IExecutionProvider::FusedNodeAndGraph>& fused_nodes_and_graphs,
-                              bool& is_qnn_ctx_model) {
-  is_qnn_ctx_model = false;
-  for (const auto& fused_node_graph : fused_nodes_and_graphs) {
-    const onnxruntime::GraphViewer& graph_viewer(fused_node_graph.filtered_graph);
-    // It's an Onnx model with Qnn context cache binary if it only has a node with EPContext type
-    int count = 0;
-    for (const auto& node : graph_viewer.Nodes()) {
-      if (EPCONTEXT_OP == node.OpType()) {
-        is_qnn_ctx_model = true;
+bool GraphHasEpContextNode(const onnxruntime::GraphViewer& graph_viewer) {
+  // It's an Onnx model with a QNN context cache binary if it has an EPContext node
+  // whose source is QNN or QNNExecutionProvider.
+  for (const auto& node : graph_viewer.Nodes()) {
+    if (EPCONTEXT_OP == node.OpType()) {
+      NodeAttrHelper node_helper(node);
+      std::string cache_source = node_helper.Get(SOURCE, "");
+
+      std::transform(cache_source.begin(),
+                     cache_source.end(),
+                     cache_source.begin(),
+                     [](unsigned char c) { return static_cast<unsigned char>(std::tolower(c)); });
+
+      if (cache_source == "qnnexecutionprovider" || cache_source == "qnn") {
+        return true;
       }
-      ++count;
     }
-    ORT_RETURN_IF(is_qnn_ctx_model && count > 1, "Fused graph should only has 1 single EPContext node.");
   }
-  return Status::OK();
+  return false;
 }
 
-bool IsQnnCtxModel(const onnxruntime::GraphViewer& graph_viewer) {
-  // It's an Onnx model with Qnn context cache binary if it only has a node with EPContext type
-  for (const auto& node : graph_viewer.Nodes()) {
-    if (EPCONTEXT_OP == node.OpType()) {
+bool IsFusedGraphHasCtxNode(const std::vector<IExecutionProvider::FusedNodeAndGraph>& fused_nodes_and_graphs) {
+  for (const auto& fused_node_graph : fused_nodes_and_graphs) {
+    const onnxruntime::GraphViewer& graph_viewer(fused_node_graph.filtered_graph);
+    bool has_qnn_ep_context_node = GraphHasEpContextNode(graph_viewer);
+    if (has_qnn_ep_context_node) {
       return true;
     }
   }
   return false;
 }
 
+Status GetMainContextNode(const std::vector<IExecutionProvider::FusedNodeAndGraph>& fused_nodes_and_graphs,
+                          QnnBackendManager* qnn_backend_manager,
+                          const logging::Logger& logger,
+                          int& main_context_pos,
+                          std::unordered_map<std::string, std::unique_ptr<QnnModel>>& qnn_models) {
+  main_context_pos = -1;
+  for (size_t i = 0; i < fused_nodes_and_graphs.size(); ++i) {
+    const onnxruntime::GraphViewer& graph_viewer(fused_nodes_and_graphs[i].filtered_graph);
+    const auto& ep_context_node = graph_viewer.Nodes().begin();
+    ORT_RETURN_IF_NOT(EPCONTEXT_OP == ep_context_node->OpType(), "Should only filter in the EPContext node.");
+    qnn_models.emplace(ep_context_node->Name(),
+                       std::make_unique<QnnModel>(logger, qnn_backend_manager));
+    NodeAttrHelper node_helper(*ep_context_node);
+    int64_t is_main_context = node_helper.Get(MAIN_CONTEXT, static_cast<int64_t>(0));
+    if (1 == is_main_context) {
+      main_context_pos = static_cast<int>(i);
+    }
+  }
+
+  ORT_RETURN_IF(main_context_pos < 0, "Failed to find the EPContext node with main_context=1");
+  return Status::OK();
+}
+
 Status CreateNodeArgs(const std::vector<std::string>& names,
                       const std::unordered_map<std::string, OnnxTensorInfo>& tensor_info_table,
                       std::vector<NodeArg*>& node_args,
@@ -60,32 +86,18 @@ Status CreateNodeArgs(const std::vector<std::string>& names,
   return Status::OK();
 }
 
-Status GetEpContextFromModel(const onnxruntime::PathString& ctx_onnx_model_path,
-                             QnnBackendManager* qnn_backend_manager,
-                             QnnModel& qnn_model,
-                             const logging::Logger& logger) {
-  using namespace onnxruntime;
-  std::shared_ptr<Model> model;
-  ORT_RETURN_IF_ERROR(Model::Load(ToPathString(ctx_onnx_model_path), model, {}, logger));
-  const auto& graph = model->MainGraph();
-  return GetEpContextFromGraph(GraphViewer(graph),
-                               ctx_onnx_model_path,
-                               qnn_backend_manager,
-                               qnn_model);
-}
-
-Status GetEpContextFromGraph(const onnxruntime::GraphViewer& graph_viewer,
-                             const onnxruntime::PathString& ctx_onnx_model_path,
-                             QnnBackendManager* qnn_backend_manager,
-                             QnnModel& qnn_model) {
-  const auto& node = graph_viewer.Nodes().begin();
-  NodeAttrHelper node_helper(*node);
+Status GetEpContextFromMainNode(const onnxruntime::Node& main_context_node,
+                                const onnxruntime::PathString& ctx_onnx_model_path,
+                                QnnBackendManager* qnn_backend_manager,
+                                std::unordered_map<std::string, std::unique_ptr<QnnModel>>& qnn_models) {
+  ORT_RETURN_IF_NOT(EPCONTEXT_OP == main_context_node.OpType(), "Should only filter in the EPContext node.");
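+  // For reference, the EPContext node attributes consumed here (names defined in
+  // onnx_ctx_model_helper.h):
+  //   main_context:     1 if this node owns the (possibly shared) QNN context binary
+  //   embed_mode:       1 -> binary embedded in ep_cache_context, 0 -> ep_cache_context
+  //                     holds the binary file name next to the context model
+  //   ep_cache_context: the context binary bytes or the file name, per embed_mode
+  //   source:           "QNN" or "QNNExecutionProvider", checked during partitioning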
+  NodeAttrHelper node_helper(main_context_node);
   bool is_embed_mode = node_helper.Get(EMBED_MODE, true);
   if (is_embed_mode) {
     const std::string& context_binary = node_helper.Get(EP_CACHE_CONTEXT, "");
     return qnn_backend_manager->LoadCachedQnnContextFromBuffer(const_cast<char*>(context_binary.c_str()),
                                                                static_cast<uint64_t>(context_binary.length()),
-                                                               qnn_model);
+                                                               qnn_models);
   }
 
   std::filesystem::path folder_path = std::filesystem::path(ctx_onnx_model_path).parent_path();
@@ -133,23 +145,16 @@ Status GetEpContextFromGraph(const onnxruntime::GraphViewer& graph_viewer,
   cache_file.close();
   return qnn_backend_manager->LoadCachedQnnContextFromBuffer(buffer.get(),
                                                              static_cast<uint64_t>(buffer_size),
-                                                             qnn_model);
+                                                             qnn_models);
 }
 
-Status LoadQnnCtxFromOnnxModel(const onnxruntime::GraphViewer& graph_viewer,
+Status LoadQnnCtxFromOnnxGraph(const onnxruntime::GraphViewer& graph_viewer,
                                const onnxruntime::PathString& ctx_onnx_model_path,
-                               bool is_qnn_ctx_model,
-                               bool is_ctx_cache_file_exist,
                                QnnBackendManager* qnn_backend_manager,
-                               QnnModel& qnn_model,
-                               const logging::Logger& logger) {
-  Status status;
-  if (is_qnn_ctx_model) {
-    status = GetEpContextFromGraph(graph_viewer, ctx_onnx_model_path, qnn_backend_manager, qnn_model);
-  } else if (is_ctx_cache_file_exist) {
-    status = GetEpContextFromModel(ctx_onnx_model_path, qnn_backend_manager, qnn_model, logger);
-  }
+                               std::unordered_map<std::string, std::unique_ptr<QnnModel>>& qnn_models) {
+  Status status = GetEpContextFromMainNode(*graph_viewer.Nodes().begin(), ctx_onnx_model_path, qnn_backend_manager, qnn_models);
 
+  // This is the protocol with the customer: a status with code INVALID_GRAPH is generated if the context model fails to load
   if (!status.IsOK()) {
     return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, "Failed to load from EpContextModel. ", status.ErrorMessage());
   }
 
@@ -157,88 +162,37 @@ Status LoadQnnCtxFromOnnxModel(const onnxruntime::GraphViewer& graph_viewer,
   return Status::OK();
 }
 
-Status GetMetadataFromEpContextModel(const onnxruntime::PathString& ctx_onnx_model_path,
-                                     std::string& model_name,
-                                     std::string& model_description,
-                                     std::string& graph_partition_name,
-                                     std::string& cache_source,
-                                     const logging::Logger& logger) {
-  using namespace onnxruntime;
-  std::shared_ptr<Model> model;
-  ORT_RETURN_IF_ERROR(Model::Load(ctx_onnx_model_path, model, {}, logger));
-  const auto& graph = GraphViewer(model->MainGraph());
-  const auto& node = graph.Nodes().begin();
-  NodeAttrHelper node_helper(*node);
-  model_name = graph.Name();
-  model_description = graph.Description();
-  graph_partition_name = node_helper.Get(PARTITION_NAME, "");
-  cache_source = node_helper.Get(SOURCE, "");
-
-  return Status::OK();
-}
-
-bool IsContextCacheFileExists(const std::string& customer_context_cache_path,
-                              const onnxruntime::PathString& model_pathstring,
-                              onnxruntime::PathString& context_cache_path) {
-  // Use user provided context cache file path if exist, otherwise try model_file.onnx_ctx.onnx by default
+// Figure out the real context cache file path
+// Returns true if the context cache file exists
+bool ValidateContextCacheFilePath(bool is_qnn_ctx_model,
+                                  const std::string& customer_context_cache_path,
+                                  const onnxruntime::PathString& model_pathstring,
+                                  onnxruntime::PathString& context_cache_path) {
+  // Always try the path set by the user first; it's the only way to set the path if the model is loaded from memory
   if (!customer_context_cache_path.empty()) {
     context_cache_path = ToPathString(customer_context_cache_path);
-  } else if (!model_pathstring.empty()) {
-    context_cache_path = model_pathstring + ToPathString("_ctx.onnx");
+  } else if (!model_pathstring.empty()) {  // model loaded from file
+    if (is_qnn_ctx_model) {
+      // it's a context cache model, just use the model path
+      context_cache_path = model_pathstring;
+    } else if (!model_pathstring.empty()) {
+      // it's a normal Onnx model and no customer path was set; create a default path for generation: model_path + _ctx.onnx
+      context_cache_path = model_pathstring + ToPathString("_ctx.onnx");
+    }
   }
 
   return std::filesystem::is_regular_file(context_cache_path) && std::filesystem::exists(context_cache_path);
 }
 
-Status ValidateWithContextFile(const onnxruntime::PathString& context_cache_path,
-                               const std::string& model_name,
-                               const std::string& model_description,
-                               const std::string& graph_partition_name,
-                               const logging::Logger& logger) {
-  std::string model_name_from_ctx_cache;
-  std::string model_description_from_ctx_cache;
-  std::string graph_partition_name_from_ctx_cache;
-  std::string cache_source;
-  auto status = GetMetadataFromEpContextModel(context_cache_path,
-                                              model_name_from_ctx_cache,
-                                              model_description_from_ctx_cache,
-                                              graph_partition_name_from_ctx_cache,
-                                              cache_source,
-                                              logger);
-  if (!status.IsOK()) {
-    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, "Failed to get metadata from EpContextModel.");
-  }
-
-  // The source attribute from the skeleton onnx file indicate whether it's generated from QNN toolchain or ORT
-  if (cache_source != kQnnExecutionProvider) {
-    LOGS(logger, VERBOSE) << "Context binary cache is not generated by Ort.";
-    return Status::OK();
-  }
-
-  if (model_name != model_name_from_ctx_cache ||
-      model_description != model_description_from_ctx_cache ||
-      graph_partition_name != graph_partition_name_from_ctx_cache) {
-    std::string message = onnxruntime::MakeString("Metadata mismatch. onnx: ",
-                                                  model_name, " ", model_description, " ", graph_partition_name,
-                                                  " vs epcontext: ",
-                                                  model_name_from_ctx_cache, " ",
-                                                  model_description_from_ctx_cache, " ",
-                                                  graph_partition_name_from_ctx_cache);
-    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, message);
-  }
-
-  return Status::OK();
-}
-
-Status GenerateCtxCacheOnnxModel(Model* model,
-                                 unsigned char* buffer,
-                                 uint64_t buffer_size,
-                                 const std::string& sdk_build_version,
-                                 const std::vector<IExecutionProvider::FusedNodeAndGraph>& fused_nodes_and_graphs,
-                                 const std::unordered_map<std::string, std::unique_ptr<QnnModel>>& qnn_models,
-                                 const onnxruntime::PathString& context_cache_path,
-                                 bool qnn_context_embed_mode,
-                                 const logging::Logger& logger) {
+Status CreateEPContextNodes(Model* model,
+                            unsigned char* buffer,
+                            uint64_t buffer_size,
+                            const std::string& sdk_build_version,
+                            const std::vector<IExecutionProvider::FusedNodeAndGraph>& fused_nodes_and_graphs,
+                            const std::unordered_map<std::string, std::unique_ptr<QnnModel>>& qnn_models,
+                            const onnxruntime::PathString& context_cache_path,
+                            bool qnn_context_embed_mode,
+                            const logging::Logger& logger) {
   auto& graph = model->MainGraph();
 
   using namespace ONNX_NAMESPACE;
diff --git a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h
index ba6fe23ecd56e..b1360b4e576fa 100644
--- a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h
+++ b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h
@@ -28,59 +28,44 @@ static const std::string EP_SDK_VER = "ep_sdk_version";
 static const std::string PARTITION_NAME = "partition_name";
 static const std::string SOURCE = "source";
 
-Status IsFusedGraphHasCtxNode(const std::vector<IExecutionProvider::FusedNodeAndGraph>& fused_nodes_and_graphs,
-                              bool& is_qnn_ctx_model);
+bool GraphHasEpContextNode(const onnxruntime::GraphViewer& graph_viewer);
 
-bool IsQnnCtxModel(const onnxruntime::GraphViewer& graph_viewer);
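+// The helpers below share a table of QnnModel instances keyed by EPContext node name,
+// which lets a single QNN context binary serve multiple graph partitions.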
+bool IsFusedGraphHasCtxNode(const std::vector<IExecutionProvider::FusedNodeAndGraph>& fused_nodes_and_graphs);
+
+Status GetMainContextNode(const std::vector<IExecutionProvider::FusedNodeAndGraph>& fused_nodes_and_graphs,
+                          QnnBackendManager* qnn_backend_manager,
+                          const logging::Logger& logger,
+                          int& main_context_pos,
+                          std::unordered_map<std::string, std::unique_ptr<QnnModel>>& qnn_models);
 
 Status CreateNodeArgs(const std::vector<std::string>& names,
                       const std::unordered_map<std::string, OnnxTensorInfo>& tensor_info_table,
                       std::vector<NodeArg*>& node_args,
                       onnxruntime::Graph& graph);
 
-bool IsContextCacheFileExists(const std::string& customer_context_cache_path,
-                              const onnxruntime::PathString& model_pathstring,
-                              onnxruntime::PathString& context_cache_path);
-
-Status GetEpContextFromModel(const onnxruntime::PathString& ctx_onnx_model_path,
-                             QnnBackendManager* qnn_backend_manager,
-                             QnnModel& qnn_model,
-                             const logging::Logger& logger);
+bool ValidateContextCacheFilePath(bool is_qnn_ctx_model,
+                                  const std::string& customer_context_cache_path,
+                                  const onnxruntime::PathString& model_pathstring,
+                                  onnxruntime::PathString& context_cache_path);
 
-Status GetEpContextFromGraph(const onnxruntime::GraphViewer& graph_viewer,
-                             const onnxruntime::PathString& ctx_onnx_model_path,
-                             QnnBackendManager* qnn_backend_manager,
-                             QnnModel& qnn_model);
+Status GetEpContextFromMainNode(const onnxruntime::Node& main_context_node,
+                                const onnxruntime::PathString& ctx_onnx_model_path,
+                                QnnBackendManager* qnn_backend_manager,
+                                std::unordered_map<std::string, std::unique_ptr<QnnModel>>& qnn_models);
 
-Status LoadQnnCtxFromOnnxModel(const onnxruntime::GraphViewer& graph_viewer,
+Status LoadQnnCtxFromOnnxGraph(const onnxruntime::GraphViewer& graph_viewer,
                                const onnxruntime::PathString& ctx_onnx_model_path,
-                               bool is_qnn_ctx_model,
-                               bool is_ctx_cache_file_exist,
                                QnnBackendManager* qnn_backend_manager,
-                               QnnModel& qnn_model,
-                               const logging::Logger& logger);
-
-Status ValidateWithContextFile(const onnxruntime::PathString& context_cache_path,
-                               const std::string& model_name,
-                               const std::string& model_description,
-                               const std::string& graph_partition_name,
-                               const logging::Logger& logger);
-
-Status GetMetadataFromEpContextModel(const onnxruntime::PathString& ctx_onnx_model_path,
-                                     std::string& model_name,
-                                     std::string& model_description,
-                                     std::string& graph_partition_name,
-                                     std::string& cache_source,
-                                     const logging::Logger& logger);
-
-Status GenerateCtxCacheOnnxModel(Model* model,
-                                 unsigned char* buffer,
-                                 uint64_t buffer_size,
-                                 const std::string& sdk_build_version,
-                                 const std::vector<IExecutionProvider::FusedNodeAndGraph>& fused_nodes_and_graphs,
-                                 const std::unordered_map<std::string, std::unique_ptr<QnnModel>>& qnn_models,
-                                 const onnxruntime::PathString& context_cache_path,
-                                 bool qnn_context_embed_mode,
-                                 const logging::Logger& logger);
+                               std::unordered_map<std::string, std::unique_ptr<QnnModel>>& qnn_models);
+
+Status CreateEPContextNodes(Model* model,
+                            unsigned char* buffer,
+                            uint64_t buffer_size,
+                            const std::string& sdk_build_version,
+                            const std::vector<IExecutionProvider::FusedNodeAndGraph>& fused_nodes_and_graphs,
+                            const std::unordered_map<std::string, std::unique_ptr<QnnModel>>& qnn_models,
+                            const onnxruntime::PathString& context_cache_path,
+                            bool qnn_context_embed_mode,
+                            const logging::Logger& logger);
 
 }  // namespace qnn
 }  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc
index 973b81d337c81..5f0b87c7cb9d7 100644
--- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc
+++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc
@@ -517,7 +517,8 @@ std::unique_ptr<char[]> QnnBackendManager::GetContextBinaryBuffer(uint64_t& written_buffer_size) {
   return context_buffer;
 }
 
-Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t buffer_length, QnnModel& qnn_model) {
+Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t buffer_length,
+                                                         std::unordered_map<std::string, std::unique_ptr<QnnModel>>& qnn_models) {
   bool result = nullptr == qnn_sys_interface_.systemContextCreate ||
                 nullptr == qnn_sys_interface_.systemContextGetBinaryInfo ||
                 nullptr == qnn_sys_interface_.systemContextFree;
@@ -550,8 +551,9 @@ Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t
     graphs_info = binary_info->contextBinaryInfoV2.graphs;
   }
 
-  ORT_RETURN_IF(graph_count > 1, "Load from Qnn cached context only support 1 sub-graph.");
-  ORT_RETURN_IF(graphs_info == nullptr, "Failed to get graph info from Qnn cached context.");
+  ORT_RETURN_IF(graph_count < 1 || graphs_info == nullptr, "Failed to get graph info from Qnn cached context.");
+  LOGS(*logger_, VERBOSE) << "Graph count from QNN context: " << graph_count << ", EPContext node count: " << qnn_models.size();
+  ORT_RETURN_IF(graph_count != qnn_models.size(), "Graph count from QNN context is not equal to the EPContext node count.");
 
   ORT_RETURN_IF(nullptr == qnn_interface_.contextCreateFromBinary,
                 "Invalid function pointer for contextCreateFromBinary.");
@@ -571,7 +573,12 @@ Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t
 
   // More work to support multiple partition, how to map the graph name in compile to qnn graph name
   // Need the lower level framework to understand EPContext op and pass in the partition_name in fused_node during Compile
-  ORT_RETURN_IF_ERROR(qnn_model.DeserializeGraphInfoFromBinaryInfo(graphs_info[0]));
+  for (uint32_t i = 0; i < graph_count; ++i) {
+    std::string graph_name(graphs_info[i].graphInfoV1.graphName);
+    auto qnn_model_pos = qnn_models.find(graph_name);
+    ORT_RETURN_IF(qnn_model_pos == qnn_models.end(), graph_name + " does not match any EPContext node names.");
+    ORT_RETURN_IF_ERROR(qnn_model_pos->second->DeserializeGraphInfoFromBinaryInfo(graphs_info[i]));
+  }
 
   qnn_sys_interface_.systemContextFree(sys_ctx_handle);
   sys_ctx_handle = nullptr;
diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h
index f7b8947ab84bb..36375522b5a0a 100644
--- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h
+++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h
@@ -87,7 +87,8 @@ class QnnBackendManager {
 
   std::unique_ptr<char[]> GetContextBinaryBuffer(uint64_t& written_buffer_size);
 
-  Status LoadCachedQnnContextFromBuffer(char* buffer, uint64_t buffer_length, QnnModel& qnn_model);
+  Status LoadCachedQnnContextFromBuffer(char* buffer, uint64_t buffer_length,
+                                        std::unordered_map<std::string, std::unique_ptr<QnnModel>>& qnn_models);
 
   Status SetupBackend(const logging::Logger& logger, bool load_from_cached_context);
 
diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model.cc b/onnxruntime/core/providers/qnn/builder/qnn_model.cc
index 869d9326d9232..314cab4a36ca9 100644
--- a/onnxruntime/core/providers/qnn/builder/qnn_model.cc
+++ b/onnxruntime/core/providers/qnn/builder/qnn_model.cc
@@ -97,7 +97,8 @@ Status QnnModel::ComposeGraph(const GraphViewer& graph_viewer,
   std::unordered_map<const Node*, const NodeUnit*> node_unit_map;
   std::tie(node_unit_holder, node_unit_map) = GetAllNodeUnits(graph_viewer);
 
-  const auto& graph_name = graph_viewer.Name();
+  // This name must be the same as the EPContext node name
+  const auto& graph_name = fused_node.Name();
   ORT_RETURN_IF_ERROR(SetGraphInputOutputInfo(graph_viewer, fused_node));
 
   QnnModelWrapper qnn_model_wrapper = QnnModelWrapper(graph_viewer, logger_,
diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc
index 5f4e2e62f063e..b58f6e10df94c 100644
--- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc
+++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc
@@ -150,6 +150,7 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio
     LOGS_DEFAULT(VERBOSE) << "User specified context cache embed mode: " << qnn_context_embed_mode_;
 
     context_cache_path_cfg_ = session_options->config_options.GetConfigOrDefault(kOrtSessionOptionEpContextFilePath, "");
+    LOGS_DEFAULT(VERBOSE) << "User specified context cache path: " << context_cache_path_cfg_;
   }
 
 static const std::string BACKEND_PATH = "backend_path";
@@ -318,14 +319,27 @@ std::unordered_set<const Node*>
 QNNExecutionProvider::GetSupportedNodes(const GraphViewer& graph_viewer,
                                         const std::unordered_map<const Node*, const NodeUnit*>& node_unit_map,
                                         const size_t node_unit_size,
-                                        bool load_from_cached_context,
+                                        bool is_qnn_ctx_model,
                                         const logging::Logger& logger) const {
   std::unordered_set<const Node*> supported_nodes{};
-  // Enable Qnn context cache requires the whole graph partitioned to Qnn EP
-  // Blindly filter in all nodes if context cache is enabled
-  if (load_from_cached_context) {
+  // Filter in only the EPContext nodes for QNN
+  if (is_qnn_ctx_model) {
     for (const auto& node : graph_viewer.Nodes()) {
-      supported_nodes.insert(&node);
+      NodeAttrHelper node_helper(node);
+      std::string cache_source = node_helper.Get(qnn::SOURCE, "");
+
+      std::transform(cache_source.begin(),
+                     cache_source.end(),
+                     cache_source.begin(),
+                     [](unsigned char c) { return static_cast<unsigned char>(std::tolower(c)); });
+
+      if (qnn::EPCONTEXT_OP == node.OpType() && (cache_source == "qnnexecutionprovider" || cache_source == "qnn")) {
+        LOGS(logger, VERBOSE) << "Node supported: [1] index: [" << node.Index()
+                              << "] name: [" << node.Name()
+                              << "] Operator type: [EPContext]";
+        supported_nodes.insert(&node);
+      }
     }
     return supported_nodes;
   }
@@ -410,22 +424,11 @@ QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer
   }
 
   const auto& logger = *GetLogger();
-  bool load_from_cached_context = false;
-  bool is_qnn_ctx_model = qnn::IsQnnCtxModel(graph_viewer);
-  if (is_qnn_ctx_model) {
-    load_from_cached_context = true;
-  }
+  bool is_qnn_ctx_model = qnn::GraphHasEpContextNode(graph_viewer);
 
-  // This is for case: QDQ model + Onnx Qnn context cache model
-  if (context_cache_enabled_ && !is_qnn_ctx_model) {
-    onnxruntime::PathString context_cache_path;
-    load_from_cached_context = qnn::IsContextCacheFileExists(context_cache_path_cfg_,
-                                                             graph_viewer.ModelPath().ToPathString(),
-                                                             context_cache_path);
-  }
-
-  // Load from cached context will load the QnnSystem lib and skip the Qnn context creation
-  auto rt = qnn_backend_manager_->SetupBackend(logger, load_from_cached_context);
+  // This loads the QnnSystem lib if is_qnn_ctx_model=true, and
+  // delays the Qnn context creation to Compile(), which uses the cached context binary
+  auto rt = qnn_backend_manager_->SetupBackend(logger, is_qnn_ctx_model);
   if (Status::OK() != rt) {
     LOGS(logger, ERROR) << "QNN SetupBackend failed " << rt.ErrorMessage();
     return result;
   }
@@ -443,7 +446,7 @@ QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer
   std::tie(node_unit_holder, node_unit_map) = GetAllNodeUnits(graph_viewer);
 
   const auto supported_nodes = GetSupportedNodes(graph_viewer, node_unit_map, node_unit_holder.size(),
-                                                 load_from_cached_context, logger);
+                                                 is_qnn_ctx_model, logger);
 
   // Helper function that returns a string that lists all unsupported nodes.
   // Ex: { name: mul_123, type: Mul }, {}, ...
@@ -496,7 +499,7 @@ QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer
     if (partition && partition->sub_graph) {
       nodes_in_partition = partition->sub_graph->nodes.size();
 
-      if (nodes_in_partition == 1) {
+      if (nodes_in_partition == 1 && !is_qnn_ctx_model) {
         const Node* node = graph_viewer.GetNode(partition->sub_graph->nodes[0]);
 
         if (!node) {
@@ -516,7 +519,7 @@ QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer
       result.push_back(std::move(partition));
       num_of_supported_nodes += nodes_in_partition;
     }
-  }
+  }  // for
 
   const size_t num_of_partitions = result.size();
@@ -527,7 +530,7 @@ QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer
 
   // Print list of unsupported nodes to the ERROR logger if the CPU EP
   // has been disabled for this inference session.
-  if (disable_cpu_ep_fallback_ && num_nodes_in_graph != num_of_supported_nodes) {
+  if (!is_qnn_ctx_model && disable_cpu_ep_fallback_ && num_nodes_in_graph != num_of_supported_nodes) {
     LOGS(logger, ERROR) << "Unsupported nodes in QNN EP: " << get_unsupported_node_names();
   }
 
@@ -618,64 +621,76 @@ Status QNNExecutionProvider::CompileFromOrtGraph(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs,
 
 Status QNNExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs,
                                      std::vector<NodeComputeInfo>& node_compute_funcs) {
   const auto& logger = *GetLogger();
-  Node& fused_node = fused_nodes_and_graphs[0].fused_node;
-  const onnxruntime::GraphViewer& graph_viewer(fused_nodes_and_graphs[0].filtered_graph);
 
-  bool is_qnn_ctx_model = false;
-  ORT_RETURN_IF_ERROR(qnn::IsFusedGraphHasCtxNode(fused_nodes_and_graphs, is_qnn_ctx_model));
+  bool is_qnn_ctx_model = qnn::IsFusedGraphHasCtxNode(fused_nodes_and_graphs);
 
   onnxruntime::PathString context_cache_path;
-  bool is_ctx_file_exist = qnn::IsContextCacheFileExists(context_cache_path_cfg_,
-                                                         graph_viewer.ModelPath().ToPathString(),
-                                                         context_cache_path);
-  const std::string& model_name = graph_viewer.GetGraph().Name();
-  const std::string& model_description = graph_viewer.GetGraph().Description();
-  const std::string& graph_meta_id = fused_node.Name();
-  if (fused_nodes_and_graphs.size() == 1 && !is_qnn_ctx_model && is_ctx_file_exist) {
-    ORT_RETURN_IF_ERROR(qnn::ValidateWithContextFile(context_cache_path,
-                                                     model_name,
-                                                     model_description,
-                                                     graph_meta_id,
-                                                     logger));
-  }
-
-  if (is_qnn_ctx_model || (context_cache_enabled_ && is_ctx_file_exist)) {
-    ORT_RETURN_IF(fused_nodes_and_graphs.size() != 1, "Only support single partition for context cache feature.");
-    std::unique_ptr<qnn::QnnModel> qnn_model = std::make_unique<qnn::QnnModel>(logger, qnn_backend_manager_.get());
-    // Load and execute from cached context if exist
-    ORT_RETURN_IF_ERROR(qnn::LoadQnnCtxFromOnnxModel(graph_viewer,
+  bool is_ctx_file_exist = false;
+  if (is_qnn_ctx_model || context_cache_enabled_) {
+    const onnxruntime::GraphViewer& graph_viewer_0(fused_nodes_and_graphs[0].filtered_graph);
+    is_ctx_file_exist = qnn::ValidateContextCacheFilePath(is_qnn_ctx_model,
+                                                          context_cache_path_cfg_,
+                                                          graph_viewer_0.ModelPath().ToPathString(),
+                                                          context_cache_path);
+  }
+
+  ORT_RETURN_IF(is_ctx_file_exist && !is_qnn_ctx_model && context_cache_enabled_,
+                "The inference session is created from a normal ONNX model, but an EP context model file already exists. ",
+                "Please remove the EP context model manually if you want to re-generate it.");
+
+  if (is_qnn_ctx_model) {
+    // Table<EPContext node name, QnnModel>: the node name is the graph_meta_id (old) created from the user model
+    // that was used to generate the EP context model; for this session (created from an EP context model) the graph_meta_id is new
+    std::unordered_map<std::string, std::unique_ptr<qnn::QnnModel>> qnn_models;
+
+    int main_context_pos = -1;
+    ORT_RETURN_IF_ERROR(qnn::GetMainContextNode(fused_nodes_and_graphs, qnn_backend_manager_.get(),
+                                                logger, main_context_pos, qnn_models));
+
+    const onnxruntime::GraphViewer& main_ctx_graph_viewer(fused_nodes_and_graphs[main_context_pos].filtered_graph);
+    // Create the QNN context from the cached binary, and deserialize the QNN graphs from the binary
+    ORT_RETURN_IF_ERROR(qnn::LoadQnnCtxFromOnnxGraph(main_ctx_graph_viewer,
                                                      context_cache_path,
-                                                     is_qnn_ctx_model,
-                                                     is_ctx_file_exist,
                                                      qnn_backend_manager_.get(),
-                                                     *(qnn_model.get()),
-                                                     logger));
-    ORT_RETURN_IF_ERROR(qnn_model->SetGraphInputOutputInfo(graph_viewer, fused_node));
-    ORT_RETURN_IF_ERROR(qnn_model->SetupQnnInputOutput());
-
-    // fused node name is QNNExecutionProvider_QNN_[hash_id]_[id]
-    // the name here should be same with context->node_name in compute_info
-    qnn_models_.emplace(graph_meta_id, std::move(qnn_model));
+                                                     qnn_models));
+
+    for (auto fused_node_and_graph : fused_nodes_and_graphs) {
+      const onnxruntime::GraphViewer& graph_viewer(fused_node_and_graph.filtered_graph);
+      const auto& ep_context_node = graph_viewer.Nodes().begin();
+      const Node& fused_node = fused_node_and_graph.fused_node;
+      const std::string& graph_meta_id = fused_node.Name();
+      std::string key = ep_context_node->Name();
+      ORT_RETURN_IF(qnn_models.find(key) == qnn_models.end(), key + " does not exist in the qnn_models table.");
+      auto qnn_model = std::move(qnn_models[key]);
+      ORT_RETURN_IF_ERROR(qnn_model->SetGraphInputOutputInfo(graph_viewer, fused_node));
+      ORT_RETURN_IF_ERROR(qnn_model->SetupQnnInputOutput());
+
+      // fused node name is QNNExecutionProvider_QNN_[hash_id]_[id]
+      // the name here must be the same as context->node_name in compute_info
+      qnn_models_.emplace(graph_meta_id, std::move(qnn_model));
+
+      ORT_RETURN_IF_ERROR(CreateComputeFunc(node_compute_funcs, logger));
+    }
 
-    ORT_RETURN_IF_ERROR(CreateComputeFunc(node_compute_funcs, logger));
     return Status::OK();
   }
 
   ORT_RETURN_IF_ERROR(CompileFromOrtGraph(fused_nodes_and_graphs, node_compute_funcs, logger));
-  if (context_cache_enabled_ && !is_qnn_ctx_model) {
-    ORT_RETURN_IF(fused_nodes_and_graphs.size() != 1, "Only support single partition for context cache feature.");
+  // Generate the QNN context model if it's a QDQ model, context_cache_enabled=true, and the file doesn't exist already
+  if (!is_qnn_ctx_model && context_cache_enabled_ && !is_ctx_file_exist) {
+    // All partitioned graphs share a single QNN context, included in the same context binary
     uint64_t buffer_size(0);
     auto context_buffer = qnn_backend_manager_->GetContextBinaryBuffer(buffer_size);
     qnn_ep_context_model_ = std::make_unique<Model>("qnn_ep_context_model", false, logger);
-    ORT_RETURN_IF_ERROR(qnn::GenerateCtxCacheOnnxModel(qnn_ep_context_model_.get(),
-                                                       context_buffer.get(),
-                                                       buffer_size,
-                                                       qnn_backend_manager_->GetSdkVersion(),
-                                                       fused_nodes_and_graphs,
-                                                       qnn_models_,
-                                                       context_cache_path,
-                                                       qnn_context_embed_mode_,
-                                                       logger));
+    ORT_RETURN_IF_ERROR(qnn::CreateEPContextNodes(qnn_ep_context_model_.get(),
+                                                  context_buffer.get(),
+                                                  buffer_size,
+                                                  qnn_backend_manager_->GetSdkVersion(),
+                                                  fused_nodes_and_graphs,
+                                                  qnn_models_,
+                                                  context_cache_path,
+                                                  qnn_context_embed_mode_,
+                                                  logger));
   }
 
   return Status::OK();
 }
diff --git a/onnxruntime/test/providers/qnn/qnn_basic_test.cc b/onnxruntime/test/providers/qnn/qnn_basic_test.cc
index c50b1002fa8c8..4e1aef2c40b2b 100644
--- a/onnxruntime/test/providers/qnn/qnn_basic_test.cc
+++ b/onnxruntime/test/providers/qnn/qnn_basic_test.cc
@@ -613,94 +613,6 @@ static GetTestModelFn BuildCastAddTestCase() {
   };
 }
 
-// Test that models with 2 inputs which has different data type can still generate the context binary
-TEST_F(QnnHTPBackendTests, QnnContextBinaryGeneration2InputTypes) {
-  ProviderOptions provider_options;
-#if defined(_WIN32)
-  provider_options["backend_path"] = "QnnHtp.dll";
-#else
-  provider_options["backend_path"] = "libQnnHtp.so";
-#endif
-
-  // Add kMSDomain to cover contrib op like Gelu
-  const std::unordered_map<std::string, int> domain_to_version = {{"", 13}, {kMSDomain, 1}};
-
-  auto& logging_manager = DefaultLoggingManager();
-  logging_manager.SetDefaultLoggerSeverity(logging::Severity::kERROR);
-
-  onnxruntime::Model model("QNN_EP_TestModel", false, ModelMetaData(), PathString(),
-                           IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, {},
-                           logging_manager.DefaultLogger());
-  Graph& graph = model.MainGraph();
-  ModelTestBuilder helper(graph);
-  BuildCastAddTestCase()(helper);
-  helper.SetGraphOutputs();
-  ASSERT_STATUS_OK(model.MainGraph().Resolve());
-
-  // Serialize the model to a string.
-  std::string model_data;
-  model.ToProto().SerializeToString(&model_data);
-
-  const auto model_data_span = AsByteSpan(model_data.data(), model_data.size());
-
-  const std::string context_binary_file = "./qnn_context_binary_int32_fp32_inputs_test.onnx";
-  Ort::SessionOptions so;
-  so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1");
-  so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, context_binary_file.c_str());
-
-  so.AppendExecutionProvider("QNN", provider_options);
-
-  Ort::Session session(*ort_env, model_data_span.data(), model_data_span.size(), so);
-
-  // Make sure the Qnn context cache binary file is generated
-  EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str()));
-
-  // clean up
-  ASSERT_EQ(std::remove(context_binary_file.c_str()), 0);
-}
-
-// Generate context cache model from the ONNX models with 2 inputs.
-// The generated model should have same input order.
-// The input ONNX model is created in the way that the model inputs order
-// is different with the order in the graph (topological order).
-// It cause issue if the generated model doesn't set the inputs/outputs explicitly.
-TEST_F(QnnHTPBackendTests, QnnContextGeneration2InputsOrderIssue) {
-  ProviderOptions provider_options;
-#if defined(_WIN32)
-  provider_options["backend_path"] = "QnnHtp.dll";
-#else
-  provider_options["backend_path"] = "libQnnHtp.so";
-#endif
-
-  // Add kMSDomain to cover contrib op like Gelu
-  const std::unordered_map<std::string, int> domain_to_version = {{"", 13}, {kMSDomain, 1}};
-
-  auto& logging_manager = DefaultLoggingManager();
-  logging_manager.SetDefaultLoggerSeverity(logging::Severity::kERROR);
-
-  const std::string context_binary_file = "./qnn_ctx_2_inputs_order_test_gen.onnx";
-  Ort::SessionOptions so;
-  so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1");
-  so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, context_binary_file.c_str());
-
-  so.AppendExecutionProvider("QNN", provider_options);
-
-  Ort::Session session(*ort_env, ORT_TSTR("testdata/qnn_ctx_2_inputs_order_test.onnx"), so);
-
-  // Make sure the Qnn context cache binary file is generated
-  EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str()));
-
-  std::shared_ptr<Model> model;
-  ASSERT_STATUS_OK(Model::Load(ToPathString(context_binary_file), model, nullptr, DefaultLoggingManager().DefaultLogger()));
-  auto inputs = model->MainGraph().GetInputs();
-  EXPECT_TRUE(inputs.size() == 2);
-  EXPECT_TRUE(inputs[0]->Name() == "attention_mask");
-  EXPECT_TRUE(inputs[1]->Name() == "Add_input_0");
-
-  // clean up
-  ASSERT_EQ(std::remove(context_binary_file.c_str()), 0);
-}
-
 // A repro of QC case 06838696, accuracy issue for Cast + Op (quantized)
 // the value pair(1, 0.00392156886) at index #1 don't match,
 // which is -0.996078 from 1
diff --git a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc
new file mode 100644
index 0000000000000..b1f3b52e77553
--- /dev/null
+++ b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc
@@ -0,0 +1,657 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
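+//
+// Tests for the QNN EP context cache feature. The typical flow exercised below
+// (session option keys from onnxruntime_session_options_config_keys.h):
+//
+//   Ort::SessionOptions so;
+//   so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1");    // generate the context model
+//   so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, path); // where to write / read it
+//   so.AppendExecutionProvider("QNN", provider_options);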
+
+#include <string>
+#include <filesystem>
+
+#include "core/session/onnxruntime_cxx_api.h"
+#include "core/session/onnxruntime_session_options_config_keys.h"
+#include "core/session/inference_session.h"
+
+#include "test/providers/qnn/qnn_test_utils.h"
+
+#include "gtest/gtest.h"
+#include "gmock/gmock.h"
+
+using namespace ONNX_NAMESPACE;
+using namespace onnxruntime::logging;
+
+// in test_main.cc
+extern std::unique_ptr<Ort::Env> ort_env;
+
+namespace onnxruntime {
+namespace test {
+
+#if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__)
+
+// Create a model with FusedMatMul + Add (quantized)
+// input1 -> FusedMatMul -> Q -> DQ \
+//                                   Add -> Q -> DQ -> output
+// input2 -> Q -> DQ                /
+static GetTestModelFn BuildGraphWithQAndNonQ(bool single_ep_node = true) {
+  return [single_ep_node](ModelTestBuilder& builder) {
+    // Create non-quantized FusedMatMul node1
+    NodeArg* input1 = MakeTestInput(builder, TestInputDef<float>({2, 2}, false, {0, 1, 0, 1}));
+    NodeArg* add1_ini_input2 = MakeTestInput(builder, TestInputDef<float>({2, 2}, true, {0, 0, 0, 0}));
+
+    auto* add1_output = builder.MakeIntermediate();
+    builder.AddNode("FusedMatMul", {input1, add1_ini_input2}, {add1_output}, kMSDomain);
+
+    // Create quantized Add node2
+    std::vector<float> data = {0.0f, 0.0f, 1.0f, 0.0f};
+    gsl::span<float> data_range = gsl::make_span(data);
+    QuantParams<uint8_t> q_parameter = GetDataQuantParams<uint8_t>(data_range);
+    auto* add2_input1_qdq = AddQDQNodePair<uint8_t>(builder, add1_output, q_parameter.scale, q_parameter.zero_point);
+
+    NodeArg* add2_input2 = MakeTestInput(builder, TestInputDef<float>({2, 2}, true, data));
+    auto* add2_input2_qdq = AddQDQNodePair<uint8_t>(builder, add2_input2, q_parameter.scale, q_parameter.zero_point);
+
+    auto* add2_output = builder.MakeIntermediate();
+
+    builder.AddNode("Add", {add2_input1_qdq, add2_input2_qdq}, {add2_output});
+
+    if (single_ep_node) {
+      // add_output -> Q -> DQ -> output
+      AddQDQNodePairWithOutputAsGraphOutput<uint8_t>(builder, add2_output, q_parameter.scale, q_parameter.zero_point);
+    } else {
+      auto* add3_input1_qdq = AddQDQNodePair<uint8_t>(builder, add2_output, q_parameter.scale, q_parameter.zero_point);
+      NodeArg* add3_ini_input2 = MakeTestInput(builder, TestInputDef<float>({2, 2}, true, {0, 0, 0, 0}));
+
+      auto* add3_output = builder.MakeIntermediate();
+      builder.AddNode("FusedMatMul", {add3_input1_qdq, add3_ini_input2}, {add3_output}, kMSDomain);
+
+      // Create quantized Add node4
+      auto* add4_input1_qdq = AddQDQNodePair<uint8_t>(builder, add3_output, q_parameter.scale, q_parameter.zero_point);
+
+      NodeArg* add4_input2 = MakeTestInput(builder, TestInputDef<float>({2, 2}, true, data));
+      auto* add4_input2_qdq = AddQDQNodePair<uint8_t>(builder, add4_input2, q_parameter.scale, q_parameter.zero_point);
+
+      auto* add4_output = builder.MakeIntermediate();
+
+      builder.AddNode("Add", {add4_input1_qdq, add4_input2_qdq}, {add4_output});
+      // add_output -> Q -> DQ -> output
+      AddQDQNodePairWithOutputAsGraphOutput<uint8_t>(builder, add4_output, q_parameter.scale, q_parameter.zero_point);
+    }
+  };
+}
+
+void QnnContextBinaryMultiPartitionTestBody(bool single_ep_node = true) {
+  ProviderOptions provider_options;
+#if defined(_WIN32)
+  provider_options["backend_path"] = "QnnHtp.dll";
+#else
+  provider_options["backend_path"] = "libQnnHtp.so";
+#endif
+
+  const std::unordered_map<std::string, int> domain_to_version = {{"", 13}, {kMSDomain, 1}};
+
+  auto& logging_manager = DefaultLoggingManager();
+  logging_manager.SetDefaultLoggerSeverity(logging::Severity::kERROR);
+
+  onnxruntime::Model model("QNN_EP_TestModel", false, ModelMetaData(), PathString(),
+                           IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, {},
+                           logging_manager.DefaultLogger());
+  Graph& graph = model.MainGraph();
+  ModelTestBuilder helper(graph);
+  BuildGraphWithQAndNonQ(single_ep_node)(helper);
+  helper.SetGraphOutputs();
+  ASSERT_STATUS_OK(model.MainGraph().Resolve());
+
+  // Serialize the model to a string.
+  std::string model_data;
+  model.ToProto().SerializeToString(&model_data);
+
+  const auto model_data_span = AsByteSpan(model_data.data(), model_data.size());
+
+  const std::string context_binary_file = "./qnn_context_binary_multi_partition_test.onnx";
+  std::remove(context_binary_file.c_str());
+  Ort::SessionOptions so;
+  so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1");
+  so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, context_binary_file.c_str());
+  so.AppendExecutionProvider("QNN", provider_options);
+
+  Ort::Session session(*ort_env, model_data_span.data(), model_data_span.size(), so);
+
+  // Make sure the Qnn context cache binary file is generated
+  EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str()));
+
+  int ep_context_node_count = 0;
+  int non_ep_context_node_count = 0;
+  std::shared_ptr<Model> ctx_model;
+  ASSERT_STATUS_OK(Model::Load(ToPathString(context_binary_file), ctx_model, nullptr, DefaultLoggingManager().DefaultLogger()));
+  auto& ctx_graph = ctx_model->MainGraph();
+  for (auto& node : ctx_graph.Nodes()) {
+    if (node.OpType() == "EPContext") {
+      ++ep_context_node_count;
+    } else {
+      ++non_ep_context_node_count;
+    }
+  }
+
+  int expected_node_count = single_ep_node ? 1 : 2;
+  ASSERT_EQ(ep_context_node_count, expected_node_count);
+  ASSERT_EQ(non_ep_context_node_count, expected_node_count);
+
+  Ort::SessionOptions so2;
+  // context file path is required if it's non-embed mode and the model is loaded from memory
+  so2.AddConfigEntry(kOrtSessionOptionEpContextFilePath, context_binary_file.c_str());
+  so2.AppendExecutionProvider("QNN", provider_options);
+
+  std::string ctx_model_data;
+  ctx_model->ToProto().SerializeToString(&ctx_model_data);
+  Ort::Session session2(*ort_env, ctx_model_data.data(), ctx_model_data.size(), so2);
+
+  // clean up
+  ASSERT_EQ(std::remove(context_binary_file.c_str()), 0);
+}
+
+// Test that a model with 1 non-quantized FusedMatMul node and 1 quantized Add node can still generate the context binary
+// The generated Onnx model has 1 FusedMatMul node and 1 EPContext node
+TEST_F(QnnHTPBackendTests, QnnContextBinaryMultiPartitionSupport1) {
+  bool single_ep_node = true;
+  QnnContextBinaryMultiPartitionTestBody(single_ep_node);
+}
+
+// Test that a model with 2 non-quantized FusedMatMul nodes and 2 quantized Add nodes can still generate the context binary
+// The generated Onnx model has 2 FusedMatMul nodes and 2 EPContext nodes
+TEST_F(QnnHTPBackendTests, QnnContextBinaryMultiPartitionSupport2) {
+  bool single_ep_node = false;
+  QnnContextBinaryMultiPartitionTestBody(single_ep_node);
+}
+
+// Create a model with Cast + Add (quantized)
+// cast_input -> Cast -> Q -> DQ \
+//                                Add -> Q -> DQ -> output
+// input2 -> Q -> DQ             /
+static GetTestModelFn BuildCastAddTestCase() {
+  return [](ModelTestBuilder& builder) {
+    // Create Cast node int32 -> float32
+    NodeArg* cast_input = MakeTestInput(builder, TestInputDef<int32_t>({2, 3}, false, {0, 1, 0, 1, 0, 1}));
+
+    auto* cast_output = builder.MakeIntermediate();
+    Node& cast_node = builder.AddNode("Cast", {cast_input}, {cast_output});
+    cast_node.AddAttribute("to", static_cast<int64_t>(ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT));
+
+    // Create Add node
+    std::vector<float> data = {0.0f, 0.0f, 1.0f, 0.0f, 1.0f, 0.0f};
+    gsl::span<float> data_range = gsl::make_span(data);
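+    // Derive the quantization parameters from the float data range so both QDQ pairs
+    // use the same scale and zero point.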
+    QuantParams<uint8_t> q_parameter = GetDataQuantParams<uint8_t>(data_range);
+    auto* add_input1_qdq = AddQDQNodePair<uint8_t>(builder, cast_output, q_parameter.scale, q_parameter.zero_point);
+
+    NodeArg* add_input2 = MakeTestInput(builder, TestInputDef<float>({2, 3}, false, data));
+    auto* add_input2_qdq = AddQDQNodePair<uint8_t>(builder, add_input2, q_parameter.scale, q_parameter.zero_point);
+
+    auto* add_output = builder.MakeIntermediate();
+
+    builder.AddNode("Add", {add_input1_qdq, add_input2_qdq}, {add_output});
+
+    // add_output -> Q -> DQ -> output
+    AddQDQNodePairWithOutputAsGraphOutput<uint8_t>(builder, add_output, q_parameter.scale, q_parameter.zero_point);
+  };
+}
+
+// Test that models with 2 inputs which have different data types can still generate the context binary
+TEST_F(QnnHTPBackendTests, QnnContextBinaryGeneration2InputTypes) {
+  ProviderOptions provider_options;
+#if defined(_WIN32)
+  provider_options["backend_path"] = "QnnHtp.dll";
+#else
+  provider_options["backend_path"] = "libQnnHtp.so";
+#endif
+
+  const std::unordered_map<std::string, int> domain_to_version = {{"", 13}, {kMSDomain, 1}};
+
+  auto& logging_manager = DefaultLoggingManager();
+  logging_manager.SetDefaultLoggerSeverity(logging::Severity::kERROR);
+
+  onnxruntime::Model model("QNN_EP_TestModel", false, ModelMetaData(), PathString(),
+                           IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, {},
+                           logging_manager.DefaultLogger());
+  Graph& graph = model.MainGraph();
+  ModelTestBuilder helper(graph);
+  BuildCastAddTestCase()(helper);
+  helper.SetGraphOutputs();
+  ASSERT_STATUS_OK(model.MainGraph().Resolve());
+
+  // Serialize the model to a string.
+  std::string model_data;
+  model.ToProto().SerializeToString(&model_data);
+
+  const auto model_data_span = AsByteSpan(model_data.data(), model_data.size());
+
+  const std::string context_binary_file = "./qnn_context_binary_int32_fp32_inputs_test.onnx";
+  std::remove(context_binary_file.c_str());
+  Ort::SessionOptions so;
+  so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1");
+  so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, context_binary_file.c_str());
+
+  so.AppendExecutionProvider("QNN", provider_options);
+
+  Ort::Session session(*ort_env, model_data_span.data(), model_data_span.size(), so);
+
+  // Make sure the Qnn context cache binary file is generated
+  EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str()));
+
+  // clean up
+  ASSERT_EQ(std::remove(context_binary_file.c_str()), 0);
+}
+
+// Generate a context cache model from an ONNX model with 2 inputs.
+// The generated model should have the same input order.
+// The input ONNX model is created in such a way that the model inputs order
+// is different from the order in the graph (topological order).
+// It causes an issue if the generated model doesn't set the inputs/outputs explicitly.
+TEST_F(QnnHTPBackendTests, QnnContextGeneration2InputsOrderIssue) {
+  ProviderOptions provider_options;
+#if defined(_WIN32)
+  provider_options["backend_path"] = "QnnHtp.dll";
+#else
+  provider_options["backend_path"] = "libQnnHtp.so";
+#endif
+
+  // Add kMSDomain to cover contrib op like Gelu
+  const std::unordered_map<std::string, int> domain_to_version = {{"", 13}, {kMSDomain, 1}};
+
+  auto& logging_manager = DefaultLoggingManager();
+  logging_manager.SetDefaultLoggerSeverity(logging::Severity::kERROR);
+
+  const std::string context_binary_file = "./qnn_ctx_2_inputs_order_test_gen.onnx";
+  Ort::SessionOptions so;
+  so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1");
+  so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, context_binary_file.c_str());
+  so.AppendExecutionProvider("QNN", provider_options);
+
+  Ort::Session session(*ort_env, ORT_TSTR("testdata/qnn_ctx_2_inputs_order_test.onnx"), so);
+
+  // Make sure the Qnn context cache binary file is generated
+  EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str()));
+
+  std::shared_ptr<Model> model;
+  ASSERT_STATUS_OK(Model::Load(ToPathString(context_binary_file), model, nullptr, DefaultLoggingManager().DefaultLogger()));
+  auto inputs = model->MainGraph().GetInputs();
+  EXPECT_TRUE(inputs.size() == 2);
+  EXPECT_TRUE(inputs[0]->Name() == "attention_mask");
+  EXPECT_TRUE(inputs[1]->Name() == "Add_input_0");
+
+  // clean up
+  ASSERT_EQ(std::remove(context_binary_file.c_str()), 0);
+}
+
+// Run QDQ model on HTP 2 times
+// 1st run will generate the Qnn context cache onnx file
+// 2nd run directly loads and runs from the Qnn context cache model
+TEST_F(QnnHTPBackendTests, QnnContextBinaryCacheEmbedModeTest) {
+  ProviderOptions provider_options;
+#if defined(_WIN32)
+  provider_options["backend_path"] = "QnnHtp.dll";
+#else
+  provider_options["backend_path"] = "libQnnHtp.so";
+#endif
+  const std::string context_binary_file = "./qnn_context_binary_test.onnx";
+  std::remove(context_binary_file.c_str());
+
+  std::unordered_map<std::string, std::string> session_option_pairs;
+  session_option_pairs.emplace(kOrtSessionOptionEpContextEnable, "1");
+  session_option_pairs.emplace(kOrtSessionOptionEpContextFilePath, context_binary_file);
+
+  const TestInputDef<float> input_def({1, 2, 3}, false, -10.0f, 10.0f);
+  const std::string op_type = "Atan";
+
+  // Runs model with DQ-> Atan-> Q and compares the outputs of the CPU and QNN EPs.
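+  // (TestQDQModelAccuracy builds the model in memory, so the context model path is
+  // passed through session options rather than derived from a model file path.)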
+  // 1st run will generate the Qnn context cache binary file
+  TestQDQModelAccuracy(BuildOpTestCase<float>(op_type, {input_def}, {}, {}),
+                       BuildQDQOpTestCase<uint8_t>(op_type, {input_def}, {}, {}),
+                       provider_options,
+                       14,
+                       ExpectedEPNodeAssignment::All,
+                       QDQTolerance(),
+                       logging::Severity::kERROR,
+                       "",  // context model file path, not required for this inference
+                       session_option_pairs);
+
+  // Make sure the Qnn context cache binary file is generated
+  EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str()));
+
+  // 2nd run directly loads and runs from the Qnn context cache model
+  TestQDQModelAccuracy(BuildOpTestCase<float>(op_type, {input_def}, {}, {}),
+                       BuildQDQOpTestCase<uint8_t>(op_type, {input_def}, {}, {}),
+                       provider_options,
+                       14,
+                       ExpectedEPNodeAssignment::All,
+                       QDQTolerance(),
+                       logging::Severity::kERROR,
+                       context_binary_file);
+  // Clean up
+  ASSERT_EQ(std::remove(context_binary_file.c_str()), 0);
+}
+
+// Run QDQ model on HTP 2 times
+// 1st run will generate the Onnx skeleton file + Qnn context cache binary file
+// 2nd run directly loads and runs from the Onnx skeleton file + Qnn context cache binary file
+TEST_F(QnnHTPBackendTests, QnnContextBinaryCacheNonEmbedModeTest) {
+  ProviderOptions provider_options;
+#if defined(_WIN32)
+  provider_options["backend_path"] = "QnnHtp.dll";
+#else
+  provider_options["backend_path"] = "libQnnHtp.so";
+#endif
+  const std::string context_binary_file = "./testdata/qnn_context_cache_non_embed.onnx";
+  std::string qnn_ctx_bin = "./testdata/qnn_context_cache_non_embed.onnx_QNNExecutionProvider_QNN_8283143575221199085_1_0.bin";
+
+  std::unordered_map<std::string, std::string> session_option_pairs;
+  session_option_pairs.emplace(kOrtSessionOptionEpContextEnable, "1");
+  session_option_pairs.emplace(kOrtSessionOptionEpContextFilePath, context_binary_file);
+  session_option_pairs.emplace(kOrtSessionOptionEpContextEmbedMode, "0");
+
+  std::remove(context_binary_file.c_str());
+  std::remove(qnn_ctx_bin.c_str());
+
+  const TestInputDef<float> input_def({1, 2, 3}, false, -10.0f, 10.0f);
+  const std::string op_type = "Atan";
+
+  // Runs model with DQ-> Atan-> Q and compares the outputs of the CPU and QNN EPs.
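+  // (In non-embed mode the EPContext node stores only the .bin file name; the binary
+  // itself is written next to the generated context model, hence the qnn_ctx_bin path
+  // checked below.)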
+  // 1st run will generate the Onnx skeleton file + Qnn context cache binary file
+  TestQDQModelAccuracy(BuildOpTestCase<float>(op_type, {input_def}, {}, {}),
+                       BuildQDQOpTestCase<uint8_t>(op_type, {input_def}, {}, {}),
+                       provider_options,
+                       14,
+                       ExpectedEPNodeAssignment::All,
+                       QDQTolerance(),
+                       logging::Severity::kERROR,
+                       "",  // context model file path, not required for this inference
+                       session_option_pairs);
+
+  // Check the Onnx skeleton file is generated
+  EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str()));
+  // Check the Qnn context cache binary file is generated
+  EXPECT_TRUE(std::filesystem::exists(qnn_ctx_bin));
+
+  std::unordered_map<std::string, std::string> session_option_pairs2;
+  // Need to set the context file path since TestQDQModelAccuracy loads the model from memory
+  session_option_pairs2.emplace(kOrtSessionOptionEpContextFilePath, context_binary_file);
+  // 2nd run directly loads and runs from the Onnx skeleton file + Qnn context cache binary file
+  TestQDQModelAccuracy(BuildOpTestCase<float>(op_type, {input_def}, {}, {}),
+                       BuildQDQOpTestCase<uint8_t>(op_type, {input_def}, {}, {}),
+                       provider_options,
+                       14,
+                       ExpectedEPNodeAssignment::All,
+                       QDQTolerance(),
+                       logging::Severity::kERROR,
+                       context_binary_file,
+                       session_option_pairs2);
+
+  // load the model from file
+  std::vector<char> buffer;
+  {
+    std::ifstream file(context_binary_file, std::ios::binary | std::ios::ate);
+    if (!file)
+      ORT_THROW("Error reading model");
+    buffer.resize(narrow<size_t>(file.tellg()));
+    file.seekg(0, std::ios::beg);
+    if (!file.read(buffer.data(), buffer.size()))
+      ORT_THROW("Error reading model");
+  }
+
+  Ort::SessionOptions so;  // No need to set the context file path in so since it's loaded from file
+  so.AppendExecutionProvider("QNN", provider_options);
+#ifdef _WIN32
+  std::wstring ctx_model_file(context_binary_file.begin(), context_binary_file.end());
+#else
+  std::string ctx_model_file(context_binary_file.begin(), context_binary_file.end());
+#endif
+  Ort::Session session(*ort_env.get(), ctx_model_file.c_str(), so);
+
+  // Clean up
+  ASSERT_EQ(std::remove(context_binary_file.c_str()), 0);
+  ASSERT_EQ(std::remove(qnn_ctx_bin.c_str()), 0);
+}
+
+// Run QDQ model on HTP 2 times
+// 1st run will generate the Onnx skeleton file + Qnn context cache binary file
+// Then delete the context bin file to make the 2nd session.Initialize() return the status with code INVALID_GRAPH
+TEST_F(QnnHTPBackendTests, QnnContextBinaryCache_InvalidGraph) {
+  ProviderOptions provider_options;
+#if defined(_WIN32)
+  provider_options["backend_path"] = "QnnHtp.dll";
+#else
+  provider_options["backend_path"] = "libQnnHtp.so";
+#endif
+  const std::string context_binary_file = "./qnn_context_cache_non_embed.onnx";
+  std::filesystem::path context_bin = "qnn_context_cache_non_embed.onnx_QNNExecutionProvider_QNN_8283143575221199085_1_0.bin";
+  std::remove(context_binary_file.c_str());
+  std::remove(context_bin.string().c_str());
+
+  std::unordered_map<std::string, std::string> session_option_pairs;
+  session_option_pairs.emplace(kOrtSessionOptionEpContextEnable, "1");
+  session_option_pairs.emplace(kOrtSessionOptionEpContextFilePath, context_binary_file);
+  session_option_pairs.emplace(kOrtSessionOptionEpContextEmbedMode, "0");
+
+  const TestInputDef<float> input_def({1, 2, 3}, false, -10.0f, 10.0f);
+  const std::string op_type = "Atan";
+
+  // Runs model with DQ-> Atan-> Q and compares the outputs of the CPU and QNN EPs.
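+  // (Deleting the .bin underneath the context model is expected to surface as
+  // INVALID_GRAPH when the second session initializes.)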
+  // 1st run will generate the Onnx skeleton file + Qnn context cache binary file
+  TestQDQModelAccuracy(BuildOpTestCase<float>(op_type, {input_def}, {}, {}),
+                       BuildQDQOpTestCase<uint8_t>(op_type, {input_def}, {}, {}),
+                       provider_options,
+                       14,
+                       ExpectedEPNodeAssignment::All,
+                       QDQTolerance(),
+                       logging::Severity::kERROR,
+                       "",  // context model file path, not required for this inference
+                       session_option_pairs);
+
+  // Check the Onnx skeleton file is generated
+  EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str()));
+  // Check the Qnn context cache binary file is generated
+  EXPECT_TRUE(std::filesystem::exists(context_bin));
+  // Delete the Qnn context cache binary file
+  EXPECT_TRUE(std::filesystem::remove(context_bin));
+
+  // Load and run from the Onnx skeleton file + (now missing) Qnn context cache binary file
+  onnx::ModelProto model_proto;
+  onnxruntime::Model qnn_ctx_model;
+  // Load the QNN context cache model from the path specified
+  ASSERT_STATUS_OK(qnn_ctx_model.Load(ToPathString(context_binary_file), model_proto));
+  std::string qnn_ctx_model_data;
+  model_proto.SerializeToString(&qnn_ctx_model_data);
+
+  SessionOptions so;
+  so.session_logid = "qnn_ctx_model_logger";
+  RunOptions run_options;
+  run_options.run_tag = so.session_logid;
+
+  InferenceSessionWrapper session_object{so, GetEnvironment()};
+
+  std::string provider_type = kCpuExecutionProvider;
+  ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(QnnExecutionProviderWithOptions(provider_options)));
+  ASSERT_STATUS_OK(session_object.Load(qnn_ctx_model_data.data(), static_cast<int>(qnn_ctx_model_data.size())));
+  // Verify the returned status has code INVALID_GRAPH
+  ASSERT_TRUE(session_object.Initialize().Code() == common::StatusCode::INVALID_GRAPH);
+
+  // Clean up
+  ASSERT_EQ(std::remove(context_binary_file.c_str()), 0);
+}
+
+std::string CreateQnnCtxModelWithNonEmbedMode(std::string external_bin_path) {
+  const std::unordered_map<std::string, int> domain_to_version = {{"", 11}, {kMSDomain, 1}};
+  auto& logging_manager = DefaultLoggingManager();
+  onnxruntime::Model model("QNN_ctx_model", false, ModelMetaData(), PathString(),
+                           IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, {},
+                           logging_manager.DefaultLogger());
+  Graph& graph = model.MainGraph();
+  ModelTestBuilder helper(graph);
+  std::vector<int64_t> shape = {2, 3};
+  NodeArg* graph_input = MakeTestInput(helper, TestInputDef<float>(shape, true, {0.0f, 1.0f, 0.0f, 1.0f, 0.0f, 1.0f}));
+  auto* graph_output = helper.MakeOutput<float>(shape);
+  Node& ep_context_node = helper.AddNode("EPContext", {graph_input}, {graph_output}, kMSDomain);
+  ep_context_node.AddAttribute("embed_mode", static_cast<int64_t>(0));
+  // An invalid ep_cache_context path (e.g. one containing "..") will cause INVALID_GRAPH
+  ep_context_node.AddAttribute("ep_cache_context", external_bin_path);
+  ep_context_node.AddAttribute("partition_name", "QNNExecutionProvider_QNN_1110111000111000111_1_0");
+  ep_context_node.AddAttribute("source", "QNN");
+  helper.SetGraphOutputs();
+  std::string model_data;
+  model.ToProto().SerializeToString(&model_data);
+
+  return model_data;
+}
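+
+// Usage sketch for the helper above (the negative tests that follow vary only the path
+// argument; QnnContextBinaryRelativePathTest is one concrete instance):
+//   std::string model_data = CreateQnnCtxModelWithNonEmbedMode("../qnn_context.bin");
+//   // Load model_data into an InferenceSessionWrapper with the QNN EP registered, then
+//   // expect session_object.Initialize().Code() == common::StatusCode::INVALID_GRAPH.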
+
+// Create a model with EPContext node. Set the node property ep_cache_context to a path containing ".."
+// Verify that it returns INVALID_GRAPH status
+TEST_F(QnnHTPBackendTests, QnnContextBinaryRelativePathTest) {
+  std::string model_data = CreateQnnCtxModelWithNonEmbedMode("../qnn_context.bin");
+
+  SessionOptions so;
+  so.session_logid = "qnn_ctx_model_logger";
+  RunOptions run_options;
+  run_options.run_tag = so.session_logid;
+
+  InferenceSessionWrapper session_object{so, GetEnvironment()};
+
+  ProviderOptions provider_options;
+#if defined(_WIN32)
+  provider_options["backend_path"] = "QnnHtp.dll";
+#else
+  provider_options["backend_path"] = "libQnnHtp.so";
+#endif
+
+  ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(QnnExecutionProviderWithOptions(provider_options)));
+  ASSERT_STATUS_OK(session_object.Load(model_data.data(), static_cast<int>(model_data.size())));
+  // Verify the returned status has code INVALID_GRAPH
+  ASSERT_TRUE(session_object.Initialize().Code() == common::StatusCode::INVALID_GRAPH);
+}
+
+// Create a model with EPContext node. Set the node property ep_cache_context to an absolute path
+// Verify that it returns INVALID_GRAPH status
+TEST_F(QnnHTPBackendTests, QnnContextBinaryAbsolutePathTest) {
+#if defined(_WIN32)
+  std::string external_ctx_bin_path = "D:/qnn_context.bin";
+#else
+  std::string external_ctx_bin_path = "/data/qnn_context.bin";
+#endif
+  std::string model_data = CreateQnnCtxModelWithNonEmbedMode(external_ctx_bin_path);
+
+  SessionOptions so;
+  so.session_logid = "qnn_ctx_model_logger";
+  RunOptions run_options;
+  run_options.run_tag = so.session_logid;
+
+  InferenceSessionWrapper session_object{so, GetEnvironment()};
+
+  ProviderOptions provider_options;
+#if defined(_WIN32)
+  provider_options["backend_path"] = "QnnHtp.dll";
+#else
+  provider_options["backend_path"] = "libQnnHtp.so";
+#endif
+
+  ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(QnnExecutionProviderWithOptions(provider_options)));
+  ASSERT_STATUS_OK(session_object.Load(model_data.data(), static_cast<int>(model_data.size())));
+  // Verify the returned status has code INVALID_GRAPH
+  ASSERT_TRUE(session_object.Initialize().Code() == common::StatusCode::INVALID_GRAPH);
+}
+
+// Create a model with EPContext node. Set the node property ep_cache_context to a file that does not exist
+// Verify that it returns INVALID_GRAPH status
+TEST_F(QnnHTPBackendTests, QnnContextBinaryFileNotExistTest) {
+  std::string model_data = CreateQnnCtxModelWithNonEmbedMode("qnn_context_not_exist.bin");
+
+  SessionOptions so;
+  so.session_logid = "qnn_ctx_model_logger";
+  RunOptions run_options;
+  run_options.run_tag = so.session_logid;
+
+  InferenceSessionWrapper session_object{so, GetEnvironment()};
+
+  ProviderOptions provider_options;
+#if defined(_WIN32)
+  provider_options["backend_path"] = "QnnHtp.dll";
+#else
+  provider_options["backend_path"] = "libQnnHtp.so";
+#endif
+
+  ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(QnnExecutionProviderWithOptions(provider_options)));
+  ASSERT_STATUS_OK(session_object.Load(model_data.data(), static_cast<int>(model_data.size())));
+  // Verify the returned status has code INVALID_GRAPH
+  ASSERT_TRUE(session_object.Initialize().Code() == common::StatusCode::INVALID_GRAPH);
+}
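+
+// Note: together, the relative-path (".."), absolute-path, missing-file, and empty-string
+// cases cover the ep_cache_context path validations the QNN EP is expected to enforce for
+// non-embed mode; each is asserted to fail with INVALID_GRAPH.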
+
+// Create a model with EPContext node. Set the node property ep_cache_context to an empty string
+// Verify that it returns INVALID_GRAPH status
+TEST_F(QnnHTPBackendTests, QnnContextBinaryFileEmptyStringTest) {
+  std::string model_data = CreateQnnCtxModelWithNonEmbedMode("");
+
+  SessionOptions so;
+  so.session_logid = "qnn_ctx_model_logger";
+  RunOptions run_options;
+  run_options.run_tag = so.session_logid;
+
+  InferenceSessionWrapper session_object{so, GetEnvironment()};
+
+  ProviderOptions provider_options;
+#if defined(_WIN32)
+  provider_options["backend_path"] = "QnnHtp.dll";
+#else
+  provider_options["backend_path"] = "libQnnHtp.so";
+#endif
+
+  ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(QnnExecutionProviderWithOptions(provider_options)));
+  ASSERT_STATUS_OK(session_object.Load(model_data.data(), static_cast<int>(model_data.size())));
+  // Verify the returned status has code INVALID_GRAPH
+  ASSERT_TRUE(session_object.Initialize().Code() == common::StatusCode::INVALID_GRAPH);
+}
+
+// Run QDQ model on HTP with 2 inputs
+// 1st run will generate the Qnn context cache onnx file
+// 2nd run directly loads and runs from the Qnn context cache model
+TEST_F(QnnHTPBackendTests, QnnContextBinary2InputsTest) {
+  ProviderOptions provider_options;
+#if defined(_WIN32)
+  provider_options["backend_path"] = "QnnHtp.dll";
+#else
+  provider_options["backend_path"] = "libQnnHtp.so";
+#endif
+  const std::string context_binary_file = "./qnn_context_binary_2inputs_test.onnx";
+  std::remove(context_binary_file.c_str());
+
+  std::unordered_map<std::string, std::string> session_option_pairs;
+  session_option_pairs.emplace(kOrtSessionOptionEpContextEnable, "1");
+  session_option_pairs.emplace(kOrtSessionOptionEpContextFilePath, context_binary_file);
+
+  const TestInputDef<float> input_def1({1, 2, 3}, false, -10.0f, 10.0f);
+  const TestInputDef<float> input_def2({1, 2, 3}, false, -10.0f, 10.0f);
+  const std::string op_type = "Add";
+
+  // Runs model with DQ -> Add -> Q and compares the outputs of the CPU and QNN EPs.
+  // 1st run will generate the Qnn context cache binary file
+  TestQDQModelAccuracy(BuildOpTestCase<float>(op_type, {input_def1, input_def2}, {}, {}),
+                       BuildQDQOpTestCase<uint8_t>(op_type, {input_def1, input_def2}, {}, {}),
+                       provider_options,
+                       14,
+                       ExpectedEPNodeAssignment::All,
+                       QDQTolerance(),
+                       logging::Severity::kERROR,
+                       "",  // context model file path, not required for this inference
+                       session_option_pairs);
+
+  // Make sure the Qnn context cache binary file is generated
+  EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str()));
+
+  // 2nd run directly loads and runs from the Qnn context cache model
+  TestQDQModelAccuracy(BuildOpTestCase<float>(op_type, {input_def1, input_def2}, {}, {}),
+                       BuildQDQOpTestCase<uint8_t>(op_type, {input_def1, input_def2}, {}, {}),
+                       provider_options,
+                       14,
+                       ExpectedEPNodeAssignment::All,
+                       QDQTolerance(),
+                       logging::Severity::kERROR,
+                       context_binary_file);
+  // Clean up
+  ASSERT_EQ(std::remove(context_binary_file.c_str()), 0);
+}
+
+#endif  // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__)
+
+}  // namespace test
+}  // namespace onnxruntime
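Note on the qnn_test_utils.h hunk below: InferenceModel() for the QNN context cache model now also receives session_option_pairs, so callers that load the context model from memory (as the non-embed tests above do) can pass kOrtSessionOptionEpContextFilePath through to that inference.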
diff --git a/onnxruntime/test/providers/qnn/qnn_test_utils.h b/onnxruntime/test/providers/qnn/qnn_test_utils.h
index bfe5bab318313..f4febd99ddae7 100644
--- a/onnxruntime/test/providers/qnn/qnn_test_utils.h
+++ b/onnxruntime/test/providers/qnn/qnn_test_utils.h
@@ -361,7 +361,7 @@ inline void TestQDQModelAccuracy(const GetTestModelFn& f32_model_fn, const GetTe
     model_proto.SerializeToString(&qnn_ctx_model_data);
     // Run QNN context cache model on QNN EP and collect outputs.
     InferenceModel(qnn_ctx_model_data, "qnn_ctx_model_logger", qnn_options,
-                   expected_ep_assignment, qdq_helper.feeds_, qnn_qdq_outputs, is_qnn_ep);
+                   expected_ep_assignment, qdq_helper.feeds_, qnn_qdq_outputs, is_qnn_ep, session_option_pairs);
   } else {
     // Run QDQ model on QNN EP and collect outputs.
     // Only need to apply the extra session options to this QDQ model inference on QNN EP
diff --git a/onnxruntime/test/providers/qnn/simple_op_htp_test.cc b/onnxruntime/test/providers/qnn/simple_op_htp_test.cc
index 1e938ae9e334b..2f3b0e84a123e 100644
--- a/onnxruntime/test/providers/qnn/simple_op_htp_test.cc
+++ b/onnxruntime/test/providers/qnn/simple_op_htp_test.cc
@@ -723,381 +723,6 @@ TEST_F(QnnHTPBackendTests, SpaceToDepthOp_U16) {
                        true);  // Use com.microsoft domain for Q/DQ ops
 }
 
-// Run QDQ model on HTP 3 times
-// 1st run will generate the Qnn context cache onnx file
-// 2nd run will load and run from QDQ model + Qnn context cache model
-// 3rd run directly loads and run from Qnn context cache model
-TEST_F(QnnHTPBackendTests, ContextBinaryCacheEmbedModeTest) {
-  ProviderOptions provider_options;
-#if defined(_WIN32)
-  provider_options["backend_path"] = "QnnHtp.dll";
-#else
-  provider_options["backend_path"] = "libQnnHtp.so";
-#endif
-  const std::string context_binary_file = "./qnn_context_binary_test.onnx";
-
-  std::unordered_map<std::string, std::string> session_option_pairs;
-  session_option_pairs.emplace(kOrtSessionOptionEpContextEnable, "1");
-  session_option_pairs.emplace(kOrtSessionOptionEpContextFilePath, context_binary_file);
-
-  const TestInputDef<float> input_def({1, 2, 3}, false, -10.0f, 10.0f);
-  const std::string op_type = "Atan";
-
-  // Runs model with DQ-> Atan-> Q and compares the outputs of the CPU and QNN EPs.
-  // 1st run will generate the Qnn context cache binary file
-  TestQDQModelAccuracy(BuildOpTestCase<float>(op_type, {input_def}, {}, {}),
-                       BuildQDQOpTestCase<uint8_t>(op_type, {input_def}, {}, {}),
-                       provider_options,
-                       14,
-                       ExpectedEPNodeAssignment::All,
-                       QDQTolerance(),
-                       logging::Severity::kERROR,
-                       "",  // context model file path, not required for this inference
-                       session_option_pairs);
-
-  // Make sure the Qnn context cache binary file is generated
-  EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str()));
-
-  // 2nd run loads and run from QDQ model + Qnn context cache model
-  TestQDQModelAccuracy(BuildOpTestCase<float>(op_type, {input_def}, {}, {}),
-                       BuildQDQOpTestCase<uint8_t>(op_type, {input_def}, {}, {}),
-                       provider_options,
-                       14,
-                       ExpectedEPNodeAssignment::All,
-                       QDQTolerance(),
-                       logging::Severity::kERROR,
-                       "",  // context model file path, not required for this inference
-                       session_option_pairs);
-
-  // 3rd run directly loads and run from Qnn context cache model
-  TestQDQModelAccuracy(BuildOpTestCase<float>(op_type, {input_def}, {}, {}),
-                       BuildQDQOpTestCase<uint8_t>(op_type, {input_def}, {}, {}),
-                       provider_options,
-                       14,
-                       ExpectedEPNodeAssignment::All,
-                       QDQTolerance(),
-                       logging::Severity::kERROR,
-                       context_binary_file);
-  // Clean up
-  ASSERT_EQ(std::remove(context_binary_file.c_str()), 0);
-}
-
-// Run QDQ model on HTP 3 times
-// 1st run will generate the Onnx skeleton file + Qnn context cache binary file
-// 2nd run will loads and run from QDQ model + Onnx skeleton file + Qnn context cache binary file
-// 3rd run directly loads and run from Onnx skeleton file + Qnn context cache binary file
-TEST_F(QnnHTPBackendTests, ContextBinaryCacheNonEmbedModeTest) {
-  ProviderOptions provider_options;
-#if defined(_WIN32)
-  provider_options["backend_path"] = "QnnHtp.dll";
-#else
-  provider_options["backend_path"] = "libQnnHtp.so";
"libQnnHtp.so"; -#endif - const std::string context_binary_file = "./qnn_context_cache_non_embed.onnx"; - std::unordered_map session_option_pairs; - session_option_pairs.emplace(kOrtSessionOptionEpContextEnable, "1"); - session_option_pairs.emplace(kOrtSessionOptionEpContextFilePath, context_binary_file); - session_option_pairs.emplace(kOrtSessionOptionEpContextEmbedMode, "0"); - - const TestInputDef input_def({1, 2, 3}, false, -10.0f, 10.0f); - const std::string op_type = "Atan"; - - // Runs model with DQ-> Atan-> Q and compares the outputs of the CPU and QNN EPs. - // 1st run will generate the Onnx skeleton file + Qnn context cache binary file - TestQDQModelAccuracy(BuildOpTestCase(op_type, {input_def}, {}, {}), - BuildQDQOpTestCase(op_type, {input_def}, {}, {}), - provider_options, - 14, - ExpectedEPNodeAssignment::All, - QDQTolerance(), - logging::Severity::kERROR, - "", // context model file path, not required for this inference - session_option_pairs); - - // Check the Onnx skeleton file is generated - EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str())); - // Check the Qnn context cache binary file is generated - std::string qnn_ctx_bin = "qnn_context_cache_non_embed.onnx_QNNExecutionProvider_QNN_8283143575221199085_1_0.bin"; - EXPECT_TRUE(std::filesystem::exists(qnn_ctx_bin)); - - // 2nd run loads and run from QDQ model + Onnx skeleton file + Qnn context cache binary file - TestQDQModelAccuracy(BuildOpTestCase(op_type, {input_def}, {}, {}), - BuildQDQOpTestCase(op_type, {input_def}, {}, {}), - provider_options, - 14, - ExpectedEPNodeAssignment::All, - QDQTolerance(), - logging::Severity::kERROR, - "", // context model file path, not required for this inference - session_option_pairs); - - // 3rd run directly loads and run from Onnx skeleton file + Qnn context cache binary file - TestQDQModelAccuracy(BuildOpTestCase(op_type, {input_def}, {}, {}), - BuildQDQOpTestCase(op_type, {input_def}, {}, {}), - provider_options, - 14, - ExpectedEPNodeAssignment::All, - QDQTolerance(), - logging::Severity::kERROR, - context_binary_file); - - // Clean up - ASSERT_EQ(std::remove(context_binary_file.c_str()), 0); - ASSERT_EQ(std::remove(qnn_ctx_bin.c_str()), 0); -} - -// Run QDQ model on HTP 2 times -// 1st run will generate the Onnx skeleton file + Qnn context cache binary file -// Then delete the context bin file to make the 2nd sesssion.Initialize() return the status with code INVALID_GRAPH -TEST_F(QnnHTPBackendTests, ContextBinaryCache_InvalidGraph) { - ProviderOptions provider_options; -#if defined(_WIN32) - provider_options["backend_path"] = "QnnHtp.dll"; -#else - provider_options["backend_path"] = "libQnnHtp.so"; -#endif - const std::string context_binary_file = "./qnn_context_cache_non_embed.onnx"; - std::unordered_map session_option_pairs; - session_option_pairs.emplace(kOrtSessionOptionEpContextEnable, "1"); - session_option_pairs.emplace(kOrtSessionOptionEpContextFilePath, context_binary_file); - session_option_pairs.emplace(kOrtSessionOptionEpContextEmbedMode, "0"); - - const TestInputDef input_def({1, 2, 3}, false, -10.0f, 10.0f); - const std::string op_type = "Atan"; - - // Runs model with DQ-> Atan-> Q and compares the outputs of the CPU and QNN EPs. 
-  // 1st run will generate the Onnx skeleton file + Qnn context cache binary file
-  TestQDQModelAccuracy(BuildOpTestCase<float>(op_type, {input_def}, {}, {}),
-                       BuildQDQOpTestCase<uint8_t>(op_type, {input_def}, {}, {}),
-                       provider_options,
-                       14,
-                       ExpectedEPNodeAssignment::All,
-                       QDQTolerance(),
-                       logging::Severity::kERROR,
-                       "",  // context model file path, not required for this inference
-                       session_option_pairs);
-
-  // Check the Onnx skeleton file is generated
-  EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str()));
-  // Check the Qnn context cache binary file is generated
-  std::filesystem::path context_bin = "qnn_context_cache_non_embed.onnx_QNNExecutionProvider_QNN_8283143575221199085_1_0.bin";
-  EXPECT_TRUE(std::filesystem::exists(context_bin));
-  // Delete the Qnn context cache binary file
-  EXPECT_TRUE(std::filesystem::remove(context_bin));
-
-  // loads and run from Onnx skeleton file + Qnn context cache binary file
-  onnx::ModelProto model_proto;
-  onnxruntime::Model qnn_ctx_model;
-  // Load the QNN context cache model from path specified
-  ASSERT_STATUS_OK(qnn_ctx_model.Load(ToPathString(context_binary_file), model_proto));
-  std::string qnn_ctx_model_data;
-  model_proto.SerializeToString(&qnn_ctx_model_data);
-
-  SessionOptions so;
-  so.session_logid = "qnn_ctx_model_logger";
-  RunOptions run_options;
-  run_options.run_tag = so.session_logid;
-
-  InferenceSessionWrapper session_object{so, GetEnvironment()};
-
-  std::string provider_type = kCpuExecutionProvider;
-  ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(QnnExecutionProviderWithOptions(provider_options)));
-  ASSERT_STATUS_OK(session_object.Load(qnn_ctx_model_data.data(), static_cast<int>(qnn_ctx_model_data.size())));
-  // Verify the return status with code INVALID_GRAPH
-  ASSERT_TRUE(session_object.Initialize().Code() == common::StatusCode::INVALID_GRAPH);
-
-  // Clean up
-  ASSERT_EQ(std::remove(context_binary_file.c_str()), 0);
-}
-
-std::string CreateQnnCtxModelWithNonEmbedMode(std::string external_bin_path) {
-  const std::unordered_map<std::string, int> domain_to_version = {{"", 11}, {kMSDomain, 1}};
-  auto& logging_manager = DefaultLoggingManager();
-  onnxruntime::Model model("QNN_ctx_model", false, ModelMetaData(), PathString(),
-                           IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, {},
-                           logging_manager.DefaultLogger());
-  Graph& graph = model.MainGraph();
-  ModelTestBuilder helper(graph);
-  std::vector<int64_t> shape = {2, 3};
-  NodeArg* graph_input = MakeTestInput(helper, TestInputDef<float>(shape, true, {0.0f, 1.0f, 0.0f, 1.0f, 0.0f, 1.0f}));
-  auto* graph_output = helper.MakeOutput<float>(shape);
-  Node& ep_context_node = helper.AddNode("EPContext", {graph_input}, {graph_output}, kMSDomain);
-  ep_context_node.AddAttribute("embed_mode", static_cast<int64_t>(0));
-  // The .. in the path will cause INVALID_GRAPH
-  ep_context_node.AddAttribute("ep_cache_context", external_bin_path);
-  ep_context_node.AddAttribute("partition_name", "QNNExecutionProvider_QNN_1110111000111000111_1_0");
-  ep_context_node.AddAttribute("source", "QNN");
-  helper.SetGraphOutputs();
-  std::string model_data;
-  model.ToProto().SerializeToString(&model_data);
-
-  return model_data;
-}
-
-// Create a model with EPContext node. Set the node property ep_cache_context has ".."
-// Verify that it return INVALID_GRAPH status
-TEST_F(QnnHTPBackendTests, QnnContextBinaryRelativePathTest) {
-  std::string model_data = CreateQnnCtxModelWithNonEmbedMode("../qnn_context.bin");
-
-  SessionOptions so;
-  so.session_logid = "qnn_ctx_model_logger";
-  RunOptions run_options;
-  run_options.run_tag = so.session_logid;
-
-  InferenceSessionWrapper session_object{so, GetEnvironment()};
-
-  ProviderOptions provider_options;
-#if defined(_WIN32)
-  provider_options["backend_path"] = "QnnHtp.dll";
-#else
-  provider_options["backend_path"] = "libQnnHtp.so";
-#endif
-
-  ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(QnnExecutionProviderWithOptions(provider_options)));
-  ASSERT_STATUS_OK(session_object.Load(model_data.data(), static_cast<int>(model_data.size())));
-  // Verify the return status with code INVALID_GRAPH
-  ASSERT_TRUE(session_object.Initialize().Code() == common::StatusCode::INVALID_GRAPH);
-}
-
-// Create a model with EPContext node. Set the node property ep_cache_context has absolute path
-// Verify that it return INVALID_GRAPH status
-TEST_F(QnnHTPBackendTests, QnnContextBinaryAbsolutePathTest) {
-#if defined(_WIN32)
-  std::string external_ctx_bin_path = "D:/qnn_context.bin";
-#else
-  std::string external_ctx_bin_path = "/data/qnn_context.bin";
-#endif
-  std::string model_data = CreateQnnCtxModelWithNonEmbedMode(external_ctx_bin_path);
-
-  SessionOptions so;
-  so.session_logid = "qnn_ctx_model_logger";
-  RunOptions run_options;
-  run_options.run_tag = so.session_logid;
-
-  InferenceSessionWrapper session_object{so, GetEnvironment()};
-
-  ProviderOptions provider_options;
-#if defined(_WIN32)
-  provider_options["backend_path"] = "QnnHtp.dll";
-#else
-  provider_options["backend_path"] = "libQnnHtp.so";
-#endif
-
-  ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(QnnExecutionProviderWithOptions(provider_options)));
-  ASSERT_STATUS_OK(session_object.Load(model_data.data(), static_cast<int>(model_data.size())));
-  // Verify the return status with code INVALID_GRAPH
-  ASSERT_TRUE(session_object.Initialize().Code() == common::StatusCode::INVALID_GRAPH);
-}
-
-// Create a model with EPContext node. Set the node property ep_cache_context to a file not exist
-// Verify that it return INVALID_GRAPH status
-TEST_F(QnnHTPBackendTests, QnnContextBinaryFileNotExistTest) {
-  std::string model_data = CreateQnnCtxModelWithNonEmbedMode("qnn_context_not_exist.bin");
-
-  SessionOptions so;
-  so.session_logid = "qnn_ctx_model_logger";
-  RunOptions run_options;
-  run_options.run_tag = so.session_logid;
-
-  InferenceSessionWrapper session_object{so, GetEnvironment()};
-
-  ProviderOptions provider_options;
-#if defined(_WIN32)
-  provider_options["backend_path"] = "QnnHtp.dll";
-#else
-  provider_options["backend_path"] = "libQnnHtp.so";
-#endif
-
-  ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(QnnExecutionProviderWithOptions(provider_options)));
-  ASSERT_STATUS_OK(session_object.Load(model_data.data(), static_cast<int>(model_data.size())));
-  // Verify the return status with code INVALID_GRAPH
-  ASSERT_TRUE(session_object.Initialize().Code() == common::StatusCode::INVALID_GRAPH);
-}
-
-// Create a model with EPContext node. Set the node property ep_cache_context to empty string
-// Verify that it return INVALID_GRAPH status
-TEST_F(QnnHTPBackendTests, QnnContextBinaryFileEmptyStringTest) {
-  std::string model_data = CreateQnnCtxModelWithNonEmbedMode("");
-
-  SessionOptions so;
-  so.session_logid = "qnn_ctx_model_logger";
-  RunOptions run_options;
-  run_options.run_tag = so.session_logid;
-
-  InferenceSessionWrapper session_object{so, GetEnvironment()};
-
-  ProviderOptions provider_options;
-#if defined(_WIN32)
-  provider_options["backend_path"] = "QnnHtp.dll";
-#else
-  provider_options["backend_path"] = "libQnnHtp.so";
-#endif
-
-  ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(QnnExecutionProviderWithOptions(provider_options)));
-  ASSERT_STATUS_OK(session_object.Load(model_data.data(), static_cast<int>(model_data.size())));
-  // Verify the return status with code INVALID_GRAPH
-  ASSERT_TRUE(session_object.Initialize().Code() == common::StatusCode::INVALID_GRAPH);
-}
-
-// Run QDQ model on HTP with 2 inputs
-// 1st run will generate the Qnn context cache onnx file
-// 2nd run will load and run from QDQ model + Qnn context cache model
-// 3rd run directly loads and run from Qnn context cache model
-TEST_F(QnnHTPBackendTests, ContextBinary2InputsTest) {
-  ProviderOptions provider_options;
-#if defined(_WIN32)
-  provider_options["backend_path"] = "QnnHtp.dll";
-#else
-  provider_options["backend_path"] = "libQnnHtp.so";
-#endif
-  const std::string context_binary_file = "./qnn_context_binary_2inputs_test.onnx";
-  std::unordered_map<std::string, std::string> session_option_pairs;
-  session_option_pairs.emplace(kOrtSessionOptionEpContextEnable, "1");
-  session_option_pairs.emplace(kOrtSessionOptionEpContextFilePath, context_binary_file);
-
-  const TestInputDef<float> input_def1({1, 2, 3}, false, -10.0f, 10.0f);
-  const TestInputDef<float> input_def2({1, 2, 3}, false, -10.0f, 10.0f);
-  const std::string op_type = "Add";
-
-  // Runs model with DQ-> Add-> Q and compares the outputs of the CPU and QNN EPs.
-  // 1st run will generate the Qnn context cache binary file
-  TestQDQModelAccuracy(BuildOpTestCase<float>(op_type, {input_def1, input_def2}, {}, {}),
-                       BuildQDQOpTestCase<uint8_t>(op_type, {input_def1, input_def2}, {}, {}),
-                       provider_options,
-                       14,
-                       ExpectedEPNodeAssignment::All,
-                       QDQTolerance(),
-                       logging::Severity::kERROR,
-                       "",  // context model file path, not required for this inference
-                       session_option_pairs);
-
-  // Make sure the Qnn context cache binary file is generated
-  EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str()));
-
-  // 2nd run loads and run from QDQ model + Qnn context cache model
-  TestQDQModelAccuracy(BuildOpTestCase<float>(op_type, {input_def1, input_def2}, {}, {}),
-                       BuildQDQOpTestCase<uint8_t>(op_type, {input_def1, input_def2}, {}, {}),
-                       provider_options,
-                       14,
-                       ExpectedEPNodeAssignment::All,
-                       QDQTolerance(),
-                       logging::Severity::kERROR,
-                       "",  // context model file path, not required for this inference
-                       session_option_pairs);
-
-  // 3rd run directly loads and run from Qnn context cache model
-  TestQDQModelAccuracy(BuildOpTestCase<float>(op_type, {input_def1, input_def2}, {}, {}),
-                       BuildQDQOpTestCase<uint8_t>(op_type, {input_def1, input_def2}, {}, {}),
-                       provider_options,
-                       14,
-                       ExpectedEPNodeAssignment::All,
-                       QDQTolerance(),
-                       logging::Severity::kERROR,
-                       context_binary_file);
-  // Clean up
-  ASSERT_EQ(std::remove(context_binary_file.c_str()), 0);
-}
-
 TEST_F(QnnHTPBackendTests, QuantAccuracyTest) {
   ProviderOptions provider_options;