diff --git a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc
index 859fe4fdba4ae..3fd499e470f85 100644
--- a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc
+++ b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc
@@ -28,6 +28,110 @@ namespace test {
 
 #if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__)
 
+// Create a model with a non-quantized Add followed by a quantized Add
+// input1 -> Add -> Q -> DQ \
+//                           Add -> Q -> DQ -> output
+// input2 -> Q -> DQ        /
+static GetTestModelFn BuildGraphWithQAndNonQ() {
+  return [](ModelTestBuilder& builder) {
+    // Create non-quantized Add node 1
+    NodeArg* input1 = MakeTestInput(builder, TestInputDef<float>({2, 3}, false, {0, 1, 0, 1, 0, 1}));
+    NodeArg* add1_ini_input1 = MakeTestInput(builder, TestInputDef<float>({2, 3}, true, {0, 0, 0, 0, 0, 0}));
+
+    auto* add1_output = builder.MakeIntermediate();
+    builder.AddNode("Add", {input1, add1_ini_input1}, {add1_output});
+
+    // Create quantized Add node 2
+    std::vector<float> data = {0.0f, 0.0f, 1.0f, 0.0f, 1.0f, 0.0f};
+    gsl::span<float> data_range = gsl::make_span(data);
+    QuantParams<uint8_t> q_parameter = GetDataQuantParams<uint8_t>(data_range);
+    auto* add2_input1_qdq = AddQDQNodePair<uint8_t>(builder, add1_output, q_parameter.scale, q_parameter.zero_point);
+
+    NodeArg* add2_input2 = MakeTestInput(builder, TestInputDef<float>({2, 3}, false, data));
+    auto* add2_input2_qdq = AddQDQNodePair<uint8_t>(builder, add2_input2, q_parameter.scale, q_parameter.zero_point);
+
+    auto* add2_output = builder.MakeIntermediate();
+
+    builder.AddNode("Add", {add2_input1_qdq, add2_input2_qdq}, {add2_output});
+
+    // add2_output -> Q -> DQ -> output
+    AddQDQNodePairWithOutputAsGraphOutput<uint8_t>(builder, add2_output, q_parameter.scale, q_parameter.zero_point);
+  };
+}
+
+// Test that a model with 1 non-quantized Add node and 1 quantized Add node can still generate the context binary.
+// The generated ONNX model has 1 Add node and 1 EPContext node.
+TEST_F(QnnHTPBackendTests, QnnContextBinaryMultiPartitionSupport) {
+  ProviderOptions provider_options;
+#if defined(_WIN32)
+  provider_options["backend_path"] = "QnnHtp.dll";
+#else
+  provider_options["backend_path"] = "libQnnHtp.so";
+#endif
+
+  const std::unordered_map<std::string, int> domain_to_version = {{"", 13}, {kMSDomain, 1}};
+
+  auto& logging_manager = DefaultLoggingManager();
+  logging_manager.SetDefaultLoggerSeverity(logging::Severity::kERROR);
+
+  onnxruntime::Model model("QNN_EP_TestModel", false, ModelMetaData(), PathString(),
+                           IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, {},
+                           logging_manager.DefaultLogger());
+  Graph& graph = model.MainGraph();
+  ModelTestBuilder helper(graph);
+  BuildGraphWithQAndNonQ()(helper);
+  helper.SetGraphOutputs();
+  ASSERT_STATUS_OK(model.MainGraph().Resolve());
+
+  // Serialize the model to a string.
+  std::string model_data;
+  model.ToProto().SerializeToString(&model_data);
+
+  const auto model_data_span = AsByteSpan(model_data.data(), model_data.size());
+
+  const std::string context_binary_file = "./qnn_context_binary_multi_partition_test.onnx";
+  std::remove(context_binary_file.c_str());
+  Ort::SessionOptions so;
+  so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1");
+  so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, context_binary_file.c_str());
+  so.SetLogSeverityLevel(0);
+
+  so.AppendExecutionProvider("QNN", provider_options);
+
+  Ort::Session session(*ort_env, model_data_span.data(), model_data_span.size(), so);
+
+  // Make sure the QNN context cache binary file is generated
+  EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str()));
+
+  int ep_context_node_count = 0;
+  int non_ep_context_node_count = 0;
+  std::shared_ptr<Model> ctx_model;
+  ASSERT_STATUS_OK(Model::Load(ToPathString(context_binary_file), ctx_model, nullptr, DefaultLoggingManager().DefaultLogger()));
+  auto& ctx_graph = ctx_model->MainGraph();
+  for (auto& node : ctx_graph.Nodes()) {
+    if (node.OpType() == "EPContext") {
+      ++ep_context_node_count;
+    } else {
+      ++non_ep_context_node_count;
+    }
+  }
+
+  ASSERT_EQ(ep_context_node_count, 1);
+  ASSERT_EQ(non_ep_context_node_count, 1);
+
+  Ort::SessionOptions so2;
+  // The context file path is required in non-embed mode when the model is loaded from memory
+  so2.AddConfigEntry(kOrtSessionOptionEpContextFilePath, context_binary_file.c_str());
+  so2.AppendExecutionProvider("QNN", provider_options);
+
+  std::string ctx_model_data;
+  ctx_model->ToProto().SerializeToString(&ctx_model_data);
+  Ort::Session session2(*ort_env, ctx_model_data.data(), ctx_model_data.size(), so2);
+
+  // clean up
+  ASSERT_EQ(std::remove(context_binary_file.c_str()), 0);
+}
+
 // Create a model with Case + Add (quantized)
 // cast_input -> Cast -> Q -> DQ \
 //                                Add -> Q -> DQ -> output
@@ -68,7 +172,6 @@ TEST_F(QnnHTPBackendTests, QnnContextBinaryGeneration2InputTypes) {
   provider_options["backend_path"] = "libQnnHtp.so";
 #endif
 
-  // Add kMSDomain to cover contrib op like Gelu
   const std::unordered_map<std::string, int> domain_to_version = {{"", 13}, {kMSDomain, 1}};
 
   auto& logging_manager = DefaultLoggingManager();
@@ -90,6 +193,7 @@ TEST_F(QnnHTPBackendTests, QnnContextBinaryGeneration2InputTypes) {
   const auto model_data_span = AsByteSpan(model_data.data(), model_data.size());
 
   const std::string context_binary_file = "./qnn_context_binary_int32_fp32_inputs_test.onnx";
+  std::remove(context_binary_file.c_str());
   Ort::SessionOptions so;
   so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1");
   so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, context_binary_file.c_str());
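
For context, the new test mirrors the user-facing EPContext workflow: one session compiles the source model and dumps the context model to disk, and a second session reloads it from a memory buffer, which is why it must set the context file path explicitly. Below is a minimal sketch of that flow against the public C++ API, assuming the QNN EP and the same session option keys the test uses (kOrtSessionOptionEpContextEnable / kOrtSessionOptionEpContextFilePath); the model file names are illustrative placeholders, not part of this patch.

#include <onnxruntime_cxx_api.h>
#include <onnxruntime_session_options_config_keys.h>

#include <fstream>
#include <iterator>
#include <string>
#include <unordered_map>
#include <vector>

// Sketch only: "model.onnx" / "model_ctx.onnx" are placeholder paths.
static void CompileAndReloadEpContextModel(Ort::Env& env) {
  std::unordered_map<std::string, std::string> qnn_options;
#if defined(_WIN32)
  qnn_options["backend_path"] = "QnnHtp.dll";
#else
  qnn_options["backend_path"] = "libQnnHtp.so";
#endif

  // Step 1: compile the source model. ORT dumps an EPContext model that wraps
  // the QNN context binary, plus any nodes the QNN EP could not take.
  Ort::SessionOptions so;
  so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1");                 // "ep.context_enable"
  so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, "model_ctx.onnx");  // "ep.context_file_path"
  so.AppendExecutionProvider("QNN", qnn_options);
  Ort::Session compile_session(env, ORT_TSTR("model.onnx"), so);

  // Step 2: reload the EPContext model from a memory buffer. The file path
  // config entry is still required here so the EP can resolve a non-embedded
  // context binary relative to it (the case the new test covers).
  std::ifstream ctx_file("model_ctx.onnx", std::ios::binary);
  std::vector<char> ctx_bytes((std::istreambuf_iterator<char>(ctx_file)),
                              std::istreambuf_iterator<char>());
  Ort::SessionOptions so2;
  so2.AddConfigEntry(kOrtSessionOptionEpContextFilePath, "model_ctx.onnx");
  so2.AppendExecutionProvider("QNN", qnn_options);
  Ort::Session run_session(env, ctx_bytes.data(), ctx_bytes.size(), so2);
}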