Skip to content

Commit

Permalink
Add special handling if there is only 1 graph inside the cached QNN c…
Browse files Browse the repository at this point in the history
…ontext binary (#19594)

Add special handling for the case where there is only 1 graph inside the cached QNN context binary: there is no need to make the EPContext node name match the QNN graph name. This improves backward compatibility for QNN context models generated before the PR that added multi-partition support for QNN context binary models.
  • Loading branch information
HectorSVC authored Feb 22, 2024
1 parent fe82fcc commit 0962241
Show file tree
Hide file tree
Showing 5 changed files with 99 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -151,12 +151,14 @@ Status GetEpContextFromMainNode(const onnxruntime::Node& main_context_node,
Status LoadQnnCtxFromOnnxGraph(const onnxruntime::GraphViewer& graph_viewer,
const onnxruntime::PathString& ctx_onnx_model_path,
QnnBackendManager* qnn_backend_manager,
std::unordered_map<std::string, std::unique_ptr<qnn::QnnModel>>& qnn_models) {
std::unordered_map<std::string, std::unique_ptr<qnn::QnnModel>>& qnn_models,
const logging::Logger& logger) {
Status status = GetEpContextFromMainNode(*graph_viewer.Nodes().begin(), ctx_onnx_model_path, qnn_backend_manager, qnn_models);

// This is the protocol with customer that status with INVALID_GRAPH will be generated if failed to load context model
if (!status.IsOK()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, "Failed to load from EpContextModel. ", status.ErrorMessage());
LOGS(logger, ERROR) << "Failed to load from EpContext model. " << status.ErrorMessage();
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, "Failed to load from EpContext model. ", status.ErrorMessage());
}

return Status::OK();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ Status GetEpContextFromMainNode(const onnxruntime::Node& main_context_node,
Status LoadQnnCtxFromOnnxGraph(const onnxruntime::GraphViewer& graph_viewer,
const onnxruntime::PathString& ctx_onnx_model_path,
QnnBackendManager* qnn_backend_manager,
std::unordered_map<std::string, std::unique_ptr<qnn::QnnModel>>& qnn_models);
std::unordered_map<std::string, std::unique_ptr<qnn::QnnModel>>& qnn_models,
const logging::Logger& logger);

Status CreateEPContextNodes(Model* model,
unsigned char* buffer,
Expand Down
15 changes: 10 additions & 5 deletions onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -573,11 +573,16 @@ Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t

// More work to support multiple partition, how to map the graph name in compile to qnn graph name
// Need the lower level framework to understand EPContext op and pass in the partition_name in fused_node during Compile
for (uint32_t i = 0; i < graph_count; ++i) {
std::string graph_name(graphs_info[i].graphInfoV1.graphName);
auto qnn_model_pos = qnn_models.find(graph_name);
ORT_RETURN_IF(qnn_model_pos == qnn_models.end(), graph_name + " does not match any EPContext node names.");
ORT_RETURN_IF_ERROR(qnn_model_pos->second->DeserializeGraphInfoFromBinaryInfo(graphs_info[i]));
if (1 == graph_count) {
auto qnn_model_pose = qnn_models.begin();
ORT_RETURN_IF_ERROR(qnn_model_pose->second->DeserializeGraphInfoFromBinaryInfo(graphs_info[0]));
} else {
for (uint32_t i = 0; i < graph_count; ++i) {
std::string graph_name(graphs_info[i].graphInfoV1.graphName);
auto qnn_model_pos = qnn_models.find(graph_name);
ORT_RETURN_IF(qnn_model_pos == qnn_models.end(), graph_name + " does not match any EPContext node names.");
ORT_RETURN_IF_ERROR(qnn_model_pos->second->DeserializeGraphInfoFromBinaryInfo(graphs_info[i]));
}
}

qnn_sys_interface_.systemContextFree(sys_ctx_handle);
Expand Down
3 changes: 2 additions & 1 deletion onnxruntime/core/providers/qnn/qnn_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -670,7 +670,8 @@ Status QNNExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused
ORT_RETURN_IF_ERROR(qnn::LoadQnnCtxFromOnnxGraph(main_ctx_graph_viewer,
context_cache_path,
qnn_backend_manager_.get(),
qnn_models));
qnn_models,
logger));

for (auto fused_node_and_graph : fused_nodes_and_graphs) {
const onnxruntime::GraphViewer& graph_viewer(fused_node_and_graph.filtered_graph);
Expand Down
83 changes: 81 additions & 2 deletions onnxruntime/test/providers/qnn/qnn_ep_context_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,6 @@ TEST_F(QnnHTPBackendTests, QnnContextBinaryCache_InvalidGraph) {

InferenceSessionWrapper session_object{so, GetEnvironment()};

std::string provider_type = kCpuExecutionProvider;
ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(QnnExecutionProviderWithOptions(provider_options)));
ASSERT_STATUS_OK(session_object.Load(qnn_ctx_model_data.data(), static_cast<int>(qnn_ctx_model_data.size())));
// Verify the return status with code INVALID_GRAPH
Expand All @@ -486,7 +485,6 @@ std::string CreateQnnCtxModelWithNonEmbedMode(std::string external_bin_path) {
auto* graph_output = helper.MakeOutput<float>(shape);
Node& ep_context_node = helper.AddNode("EPContext", {graph_input}, {graph_output}, kMSDomain);
ep_context_node.AddAttribute("embed_mode", static_cast<int64_t>(0));
// The .. in the path will cause INVALID_GRAPH
ep_context_node.AddAttribute("ep_cache_context", external_bin_path);
ep_context_node.AddAttribute("partition_name", "QNNExecutionProvider_QNN_1110111000111000111_1_0");
ep_context_node.AddAttribute("source", "QNN");
Expand Down Expand Up @@ -651,6 +649,87 @@ TEST_F(QnnHTPBackendTests, QnnContextBinary2InputsTest) {
ASSERT_EQ(std::remove(context_binary_file.c_str()), 0);
}

// Context binary only contains a single QNN graph, generated context cache model (detached mode) only has 1 EPContext node
// Create another Onnx model which also reference to the bin file,
// but the node name is not the same as the QNN graph name inside the bin file.
// This is to support backward compatibility for models generated before the PR that
// made context generation support multi-partition.
TEST_F(QnnHTPBackendTests, QnnContextBinaryCache_SingleNodeNameNotMatchGraphNameInCtx) {
  // Select the QNN HTP backend library for the current platform.
  ProviderOptions provider_options;
#if defined(_WIN32)
  provider_options["backend_path"] = "QnnHtp.dll";
#else
  provider_options["backend_path"] = "libQnnHtp.so";
#endif
  // Paths of the artifacts produced by the first run: the ONNX skeleton model
  // and the detached QNN context binary it references.
  const std::string context_binary_file = "./qnn_context_cache_non_embed.onnx";
  std::filesystem::path context_bin = "qnn_context_cache_non_embed.onnx_QNNExecutionProvider_QNN_8283143575221199085_1_0.bin";
  // Remove stale artifacts from previous runs so the existence checks below are meaningful.
  std::remove(context_binary_file.c_str());
  std::remove(context_bin.string().c_str());

  // Enable EPContext generation in detached (non-embedded) mode, so the QNN
  // context binary is written to a separate .bin file next to the ONNX model.
  std::unordered_map<std::string, std::string> session_option_pairs;
  session_option_pairs.emplace(kOrtSessionOptionEpContextEnable, "1");
  session_option_pairs.emplace(kOrtSessionOptionEpContextFilePath, context_binary_file);
  session_option_pairs.emplace(kOrtSessionOptionEpContextEmbedMode, "0");

  const TestInputDef<float> input_def({1, 2, 3}, false, -10.0f, 10.0f);
  const std::string op_type = "Atan";

  // Runs model with DQ-> Atan-> Q and compares the outputs of the CPU and QNN EPs.
  // 1st run will generate the Onnx skeleton file + Qnn context cache binary file
  TestQDQModelAccuracy(BuildOpTestCase<float>(op_type, {input_def}, {}, {}),
                       BuildQDQOpTestCase<uint8_t>(op_type, {input_def}, {}, {}),
                       provider_options,
                       14,
                       ExpectedEPNodeAssignment::All,
                       QDQTolerance(),
                       logging::Severity::kERROR,
                       "",  // context model file path, not required for this inference
                       session_option_pairs);

  // Check the Onnx skeleton file is generated
  EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str()));
  // Check the Qnn context cache binary file is generated
  EXPECT_TRUE(std::filesystem::exists(context_bin));

  // Hand-build a second ONNX model that references the same context binary,
  // but whose EPContext node attributes (partition_name, node identity) do NOT
  // match the QNN graph name stored inside the binary. With only one graph in
  // the binary, loading should still succeed (backward-compat path).
  const std::unordered_map<std::string, int> domain_to_version = {{"", 11}, {kMSDomain, 1}};
  auto& logging_manager = DefaultLoggingManager();
  onnxruntime::Model model("QNN_ctx_model", false, ModelMetaData(), PathString(),
                           IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, {},
                           logging_manager.DefaultLogger());
  Graph& graph = model.MainGraph();
  ModelTestBuilder helper(graph);
  std::vector<int64_t> shape = {1, 2, 3};
  NodeArg* graph_input = MakeTestInput(helper, TestInputDef<float>(shape, false, {0.0f, 1.0f, 0.0f, 1.0f, 0.0f, 1.0f}));
  auto* graph_output = helper.MakeOutput<float>(shape);
  Node& ep_context_node = helper.AddNode("EPContext", {graph_input}, {graph_output}, kMSDomain);
  // embed_mode 0: ep_cache_context holds a path to the external .bin, not the bytes.
  ep_context_node.AddAttribute("embed_mode", static_cast<int64_t>(0));
  ep_context_node.AddAttribute("ep_cache_context", context_bin.string());
  // Deliberately mismatched partition name (does not equal the graph name in the binary).
  ep_context_node.AddAttribute("partition_name", "QNNExecutionProvider_QNN_1110111000111000111_1_0");
  // NOTE(review): other EPContext tests in this file use "QNN" for the source
  // attribute; confirm "QNNExecutionProvider" is also accepted by the loader.
  ep_context_node.AddAttribute("source", "QNNExecutionProvider");
  helper.SetGraphOutputs();
  ASSERT_STATUS_OK(graph.Resolve());
  std::string model_data;
  model.ToProto().SerializeToString(&model_data);

  // loads and run from Onnx skeleton file + Qnn context cache binary file

  SessionOptions so;
  so.session_logid = "qnn_ctx_model_logger";
  RunOptions run_options;
  run_options.run_tag = so.session_logid;

  InferenceSessionWrapper session_object{so, GetEnvironment()};

  ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(QnnExecutionProviderWithOptions(provider_options)));
  ASSERT_STATUS_OK(session_object.Load(model_data.data(), static_cast<int>(model_data.size())));
  // Initialization must succeed: with a single graph in the context binary, the
  // mismatched EPContext node name is tolerated (single-graph special case).
  ASSERT_TRUE(session_object.Initialize().Code() == common::StatusCode::OK);

  // Clean up
  ASSERT_EQ(std::remove(context_binary_file.c_str()), 0);
  ASSERT_EQ(std::remove(context_bin.string().c_str()), 0);
}

#endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__)

} // namespace test
Expand Down

0 comments on commit 0962241

Please sign in to comment.