diff --git a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc
index 2ff0714292fbc..4bb7378234187 100644
--- a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc
+++ b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc
@@ -162,14 +162,23 @@ Status LoadQnnCtxFromOnnxGraph(const onnxruntime::GraphViewer& graph_viewer,
   return Status::OK();
 }
 
-bool IsContextCacheFileExists(const std::string& customer_context_cache_path,
-                              const onnxruntime::PathString& model_pathstring,
-                              onnxruntime::PathString& context_cache_path) {
-  // Use user provided context cache file path if exist, otherwise try model_file.onnx_ctx.onnx by default
+// Figure out the real context cache file path
+// Return true if the context cache file exists
+bool ValidateContextCacheFilePath(bool is_qnn_ctx_model,
+                                  const std::string& customer_context_cache_path,
+                                  const onnxruntime::PathString& model_pathstring,
+                                  onnxruntime::PathString& context_cache_path) {
+  // Always try the path set by the user first; it's the only way to set it if the model is loaded from memory
   if (!customer_context_cache_path.empty()) {
     context_cache_path = ToPathString(customer_context_cache_path);
-  } else if (!model_pathstring.empty()) {
-    context_cache_path = model_pathstring + ToPathString("_ctx.onnx");
+  } else if (!model_pathstring.empty()) {  // model loaded from file
+    if (is_qnn_ctx_model) {
+      // it's a context cache model, just use the model path
+      context_cache_path = model_pathstring;
+    } else if (!model_pathstring.empty()) {
+      // this is a normal Onnx model, no customer path, create a default path for generation: model_path + "_ctx.onnx"
+      context_cache_path = model_pathstring + ToPathString("_ctx.onnx");
+    }
   }
 
   return std::filesystem::is_regular_file(context_cache_path) && std::filesystem::exists(context_cache_path);
diff --git a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h
index 3a3799bd21bbf..b1360b4e576fa 100644
--- a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h
+++ b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h
@@ -43,9 +43,10 @@ Status CreateNodeArgs(const std::vector<std::string>& names,
                       std::vector<NodeArg*>& node_args,
                       onnxruntime::Graph& graph);
 
-bool IsContextCacheFileExists(const std::string& customer_context_cache_path,
-                              const onnxruntime::PathString& model_pathstring,
-                              onnxruntime::PathString& context_cache_path);
+bool ValidateContextCacheFilePath(bool is_qnn_ctx_model,
+                                  const std::string& customer_context_cache_path,
+                                  const onnxruntime::PathString& model_pathstring,
+                                  onnxruntime::PathString& context_cache_path);
 
 Status GetEpContextFromMainNode(const onnxruntime::Node& main_context_node,
                                 const onnxruntime::PathString& ctx_onnx_model_path,
diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc
index cdb30286d4489..c547e8a06bb65 100644
--- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc
+++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc
@@ -566,16 +566,18 @@ Status QNNExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused
 
   onnxruntime::PathString context_cache_path;
   bool is_ctx_file_exist = false;
-  if (!is_qnn_ctx_model && context_cache_enabled_) {
+  if (is_qnn_ctx_model || context_cache_enabled_) {
     const onnxruntime::GraphViewer& graph_viewer_0(fused_nodes_and_graphs[0].filtered_graph);
-    is_ctx_file_exist = qnn::IsContextCacheFileExists(context_cache_path_cfg_,
-                                                      graph_viewer_0.ModelPath().ToPathString(),
-                                                      context_cache_path);
-    ORT_RETURN_IF(is_ctx_file_exist,
-                  "The inference session is created from normal ONNX model. And an EP context model file is provided and existed. ",
-                  "Please remove the EP context model manually if you want to re-generate it.");
+    is_ctx_file_exist = qnn::ValidateContextCacheFilePath(is_qnn_ctx_model,
+                                                          context_cache_path_cfg_,
+                                                          graph_viewer_0.ModelPath().ToPathString(),
+                                                          context_cache_path);
   }
 
+  ORT_RETURN_IF(is_ctx_file_exist && !is_qnn_ctx_model && context_cache_enabled_,
+                "The inference session is created from normal ONNX model. And an EP context model file is provided and existed. ",
+                "Please remove the EP context model manually if you want to re-generate it.");
+
   if (is_qnn_ctx_model) {
     // Table, the node name is the graph_meta_id (old) created from user model which used to generate the EP context model
     // for this session (created from an EP context model), the graph_meta_id is new
diff --git a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc
index 3fd499e470f85..16bb6abf1e758 100644
--- a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc
+++ b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc
@@ -18,8 +18,6 @@
 using namespace ONNX_NAMESPACE;
 using namespace onnxruntime::logging;
 
-#define ORT_MODEL_FOLDER ORT_TSTR("testdata/")
-
 // in test_main.cc
 extern std::unique_ptr<Ort::Env> ort_env;
 
@@ -94,8 +92,6 @@ TEST_F(QnnHTPBackendTests, QnnContextBinaryMultiPartitionSupport) {
   Ort::SessionOptions so;
   so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1");
   so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, context_binary_file.c_str());
-  so.SetLogSeverityLevel(0);
-
   so.AppendExecutionProvider("QNN", provider_options);
 
   Ort::Session session(*ort_env, model_data_span.data(), model_data_span.size(), so);
@@ -232,7 +228,6 @@ TEST_F(QnnHTPBackendTests, QnnContextGeneration2InputsOrderIssue) {
   Ort::SessionOptions so;
   so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1");
   so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, context_binary_file.c_str());
-
   so.AppendExecutionProvider("QNN", provider_options);
 
   Ort::Session session(*ort_env, ORT_TSTR("testdata/qnn_ctx_2_inputs_order_test.onnx"), so);
@@ -309,8 +304,8 @@ TEST_F(QnnHTPBackendTests, QnnContextBinaryCacheNonEmbedModeTest) {
 #else
   provider_options["backend_path"] = "libQnnHtp.so";
 #endif
-  const std::string context_binary_file = "./qnn_context_cache_non_embed.onnx";
-  std::string qnn_ctx_bin = "qnn_context_cache_non_embed.onnx_QNNExecutionProvider_QNN_8283143575221199085_1_0.bin";
+  const std::string context_binary_file = "./testdata/qnn_context_cache_non_embed.onnx";
+  std::string qnn_ctx_bin = "./testdata/qnn_context_cache_non_embed.onnx_QNNExecutionProvider_QNN_8283143575221199085_1_0.bin";
 
   std::unordered_map<std::string, std::string> session_option_pairs;
   session_option_pairs.emplace(kOrtSessionOptionEpContextEnable, "1");
@@ -340,6 +335,9 @@ TEST_F(QnnHTPBackendTests, QnnContextBinaryCacheNonEmbedModeTest) {
   // Check the Qnn context cache binary file is generated
   EXPECT_TRUE(std::filesystem::exists(qnn_ctx_bin));
 
+  std::unordered_map<std::string, std::string> session_option_pairs2;
+  // Need to set the context file path since TestQDQModelAccuracy loads the model from memory
+  session_option_pairs2.emplace(kOrtSessionOptionEpContextFilePath, context_binary_file);
   // 2nd run directly loads and run from Onnx skeleton file + Qnn context cache binary file
   TestQDQModelAccuracy(BuildOpTestCase<float>(op_type, {input_def}, {}, {}),
                        BuildQDQOpTestCase<uint8_t>(op_type, {input_def}, {}, {}),
@@ -348,7 +346,29 @@ TEST_F(QnnHTPBackendTests, QnnContextBinaryCacheNonEmbedModeTest) {
                        ExpectedEPNodeAssignment::All,
                        QDQTolerance(),
                        logging::Severity::kERROR,
-                       context_binary_file);
+                       context_binary_file,
+                       session_option_pairs2);
+
+  // load the model from file
+  std::vector<char> buffer;
+  {
+    std::ifstream file(context_binary_file, std::ios::binary | std::ios::ate);
+    if (!file)
+      ORT_THROW("Error reading model");
+    buffer.resize(narrow<size_t>(file.tellg()));
+    file.seekg(0, std::ios::beg);
+    if (!file.read(buffer.data(), buffer.size()))
+      ORT_THROW("Error reading model");
+  }
+
+  Ort::SessionOptions so;  // No need to set the context file path in so since the model is loaded from file
+  so.AppendExecutionProvider("QNN", provider_options);
+#ifdef _WIN32
+  std::wstring ctx_model_file(context_binary_file.begin(), context_binary_file.end());
+#else
+  std::string ctx_model_file(context_binary_file.begin(), context_binary_file.end());
+#endif
+  Ort::Session session(*ort_env.get(), ctx_model_file.c_str(), so);
 
   // Clean up
   ASSERT_EQ(std::remove(context_binary_file.c_str()), 0);
diff --git a/onnxruntime/test/providers/qnn/qnn_test_utils.h b/onnxruntime/test/providers/qnn/qnn_test_utils.h
index bfe5bab318313..f4febd99ddae7 100644
--- a/onnxruntime/test/providers/qnn/qnn_test_utils.h
+++ b/onnxruntime/test/providers/qnn/qnn_test_utils.h
@@ -361,7 +361,7 @@ inline void TestQDQModelAccuracy(const GetTestModelFn& f32_model_fn, const GetTe
       model_proto.SerializeToString(&qnn_ctx_model_data);
       // Run QNN context cache model on QNN EP and collect outputs.
       InferenceModel(qnn_ctx_model_data, "qnn_ctx_model_logger", qnn_options,
-                     expected_ep_assignment, qdq_helper.feeds_, qnn_qdq_outputs, is_qnn_ep);
+                     expected_ep_assignment, qdq_helper.feeds_, qnn_qdq_outputs, is_qnn_ep, session_option_pairs);
     } else {
       // Run QDQ model on QNN EP and collect outputs.
       // Only need to apply the extra session options to this QDQ model inference on QNN EP
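
Usage sketch (not part of the diff above): the test changes exercise the user-facing flow this change enables. After generating an EP context model, it can be loaded either from file (the model's own path is used as the context cache path, no extra option needed) or from memory (kOrtSessionOptionEpContextFilePath must be set, since a buffer has no model path to fall back on). The snippet below is a minimal illustration with the ORT C++ API; the model and backend paths are illustrative assumptions, while the session-option keys and API calls are the same ones used in the tests above.

#include <fstream>
#include <iterator>
#include <string>
#include <unordered_map>
#include <vector>

#include "onnxruntime_cxx_api.h"
#include "onnxruntime_session_options_config_keys.h"

int main() {
  Ort::Env env;
  std::unordered_map<std::string, std::string> provider_options;
  provider_options["backend_path"] = "QnnHtp.dll";  // "libQnnHtp.so" on Linux/Android

  // 1) Compile a normal ONNX model and dump the EP context model.
  //    Without an explicit file path, the default is <model_path>_ctx.onnx; set it explicitly here.
  {
    Ort::SessionOptions so;
    so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1");
    so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, "./model_ctx.onnx");
    so.AppendExecutionProvider("QNN", provider_options);
    Ort::Session gen_session(env, ORT_TSTR("model.onnx"), so);
  }

  // 2a) Load the generated context model from file: no context file path option is needed,
  //     the model path itself is used to resolve the context binary.
  {
    Ort::SessionOptions so;
    so.AppendExecutionProvider("QNN", provider_options);
    Ort::Session session(env, ORT_TSTR("model_ctx.onnx"), so);
  }

  // 2b) Load the same context model from memory: kOrtSessionOptionEpContextFilePath
  //     must be set, since a model loaded from a buffer has no path to fall back on.
  {
    std::ifstream file("model_ctx.onnx", std::ios::binary);
    std::vector<char> buffer((std::istreambuf_iterator<char>(file)),
                             std::istreambuf_iterator<char>());

    Ort::SessionOptions so;
    so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, "./model_ctx.onnx");
    so.AppendExecutionProvider("QNN", provider_options);
    Ort::Session session(env, buffer.data(), buffer.size(), so);
  }
  return 0;
}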