Skip to content

Commit

Permalink
fill some gaps in UT and fix an issue related to the context cache path
Browse files Browse the repository at this point in the history
  • Loading branch information
HectorSVC committed Jan 26, 2024
1 parent 3b8e879 commit ff2c313
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 25 deletions.
21 changes: 15 additions & 6 deletions onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -162,14 +162,23 @@ Status LoadQnnCtxFromOnnxGraph(const onnxruntime::GraphViewer& graph_viewer,
return Status::OK();
}

bool IsContextCacheFileExists(const std::string& customer_context_cache_path,
const onnxruntime::PathString& model_pathstring,
onnxruntime::PathString& context_cache_path) {
// Use the user-provided context cache file path if it exists, otherwise try model_file.onnx_ctx.onnx by default
// Figure out the real context cache file path
// return true if context cache file exists
bool ValidateContextCacheFilePath(bool is_qnn_ctx_model,
const std::string& customer_context_cache_path,
const onnxruntime::PathString& model_pathstring,
onnxruntime::PathString& context_cache_path) {
// always try the path set by the user first; it's the only way to set it when the model is loaded from memory
if (!customer_context_cache_path.empty()) {
context_cache_path = ToPathString(customer_context_cache_path);
} else if (!model_pathstring.empty()) {
context_cache_path = model_pathstring + ToPathString("_ctx.onnx");
} else if (!model_pathstring.empty()) { // model loaded from file

Check warning on line 174 in onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc#L174

At least two spaces is best between code and comments [whitespace/comments] [2]
Raw output
onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc:174:  At least two spaces is best between code and comments  [whitespace/comments] [2]
if (is_qnn_ctx_model) {
// it's a context cache model, just use the model path
context_cache_path = model_pathstring;
} else if (!model_pathstring.empty()) {
// this is a normal Onnx model, no customer path, create a default path for generation: model_path + _ctx.onnx
context_cache_path = model_pathstring + ToPathString("_ctx.onnx");
}
}

return std::filesystem::is_regular_file(context_cache_path) && std::filesystem::exists(context_cache_path);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,10 @@ Status CreateNodeArgs(const std::vector<std::string>& names,
std::vector<NodeArg*>& node_args,
onnxruntime::Graph& graph);

bool IsContextCacheFileExists(const std::string& customer_context_cache_path,
const onnxruntime::PathString& model_pathstring,
onnxruntime::PathString& context_cache_path);
bool ValidateContextCacheFilePath(bool is_qnn_ctx_model,
const std::string& customer_context_cache_path,
const onnxruntime::PathString& model_pathstring,
onnxruntime::PathString& context_cache_path);

Status GetEpContextFromMainNode(const onnxruntime::Node& main_context_node,
const onnxruntime::PathString& ctx_onnx_model_path,
Expand Down
16 changes: 9 additions & 7 deletions onnxruntime/core/providers/qnn/qnn_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -566,16 +566,18 @@ Status QNNExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused

onnxruntime::PathString context_cache_path;
bool is_ctx_file_exist = false;
if (!is_qnn_ctx_model && context_cache_enabled_) {
if (is_qnn_ctx_model || context_cache_enabled_) {
const onnxruntime::GraphViewer& graph_viewer_0(fused_nodes_and_graphs[0].filtered_graph);
is_ctx_file_exist = qnn::IsContextCacheFileExists(context_cache_path_cfg_,
graph_viewer_0.ModelPath().ToPathString(),
context_cache_path);
ORT_RETURN_IF(is_ctx_file_exist,
"The inference session is created from normal ONNX model. And an EP context model file is provided and existed. ",
"Please remove the EP context model manually if you want to re-generate it.");
is_ctx_file_exist = qnn::ValidateContextCacheFilePath(is_qnn_ctx_model,
context_cache_path_cfg_,
graph_viewer_0.ModelPath().ToPathString(),
context_cache_path);
}

ORT_RETURN_IF(is_ctx_file_exist && !is_qnn_ctx_model && context_cache_enabled_,
"The inference session is created from normal ONNX model. And an EP context model file is provided and existed. ",
"Please remove the EP context model manually if you want to re-generate it.");

if (is_qnn_ctx_model) {
// Table<EPContext node name, QnnModel>, the node name is the graph_meta_id (old) created from user model which used to generate the EP context model
// for this session (created from an EP context model), the graph_meta_id is new
Expand Down
36 changes: 28 additions & 8 deletions onnxruntime/test/providers/qnn/qnn_ep_context_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@
using namespace ONNX_NAMESPACE;

Check warning on line 18 in onnxruntime/test/providers/qnn/qnn_ep_context_test.cc

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/test/providers/qnn/qnn_ep_context_test.cc#L18

Do not use namespace using-directives. Use using-declarations instead. [build/namespaces] [5]
Raw output
onnxruntime/test/providers/qnn/qnn_ep_context_test.cc:18:  Do not use namespace using-directives.  Use using-declarations instead.  [build/namespaces] [5]
using namespace onnxruntime::logging;

Check warning on line 19 in onnxruntime/test/providers/qnn/qnn_ep_context_test.cc

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/test/providers/qnn/qnn_ep_context_test.cc#L19

Do not use namespace using-directives. Use using-declarations instead. [build/namespaces] [5]
Raw output
onnxruntime/test/providers/qnn/qnn_ep_context_test.cc:19:  Do not use namespace using-directives.  Use using-declarations instead.  [build/namespaces] [5]

#define ORT_MODEL_FOLDER ORT_TSTR("testdata/")

// in test_main.cc
extern std::unique_ptr<Ort::Env> ort_env;

Expand Down Expand Up @@ -94,8 +92,6 @@ TEST_F(QnnHTPBackendTests, QnnContextBinaryMultiPartitionSupport) {
Ort::SessionOptions so;
so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1");
so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, context_binary_file.c_str());
so.SetLogSeverityLevel(0);

so.AppendExecutionProvider("QNN", provider_options);

Ort::Session session(*ort_env, model_data_span.data(), model_data_span.size(), so);
Expand Down Expand Up @@ -232,7 +228,6 @@ TEST_F(QnnHTPBackendTests, QnnContextGeneration2InputsOrderIssue) {
Ort::SessionOptions so;
so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1");
so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, context_binary_file.c_str());

so.AppendExecutionProvider("QNN", provider_options);

Ort::Session session(*ort_env, ORT_TSTR("testdata/qnn_ctx_2_inputs_order_test.onnx"), so);
Expand Down Expand Up @@ -309,8 +304,8 @@ TEST_F(QnnHTPBackendTests, QnnContextBinaryCacheNonEmbedModeTest) {
#else
provider_options["backend_path"] = "libQnnHtp.so";
#endif
const std::string context_binary_file = "./qnn_context_cache_non_embed.onnx";
std::string qnn_ctx_bin = "qnn_context_cache_non_embed.onnx_QNNExecutionProvider_QNN_8283143575221199085_1_0.bin";
const std::string context_binary_file = "./testdata/qnn_context_cache_non_embed.onnx";
std::string qnn_ctx_bin = "./testdata/qnn_context_cache_non_embed.onnx_QNNExecutionProvider_QNN_8283143575221199085_1_0.bin";

Check warning on line 308 in onnxruntime/test/providers/qnn/qnn_ep_context_test.cc

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/test/providers/qnn/qnn_ep_context_test.cc#L308

Lines should be <= 120 characters long [whitespace/line_length] [2]
Raw output
onnxruntime/test/providers/qnn/qnn_ep_context_test.cc:308:  Lines should be <= 120 characters long  [whitespace/line_length] [2]

std::unordered_map<std::string, std::string> session_option_pairs;
session_option_pairs.emplace(kOrtSessionOptionEpContextEnable, "1");
Expand Down Expand Up @@ -340,6 +335,9 @@ TEST_F(QnnHTPBackendTests, QnnContextBinaryCacheNonEmbedModeTest) {
// Check the Qnn context cache binary file is generated
EXPECT_TRUE(std::filesystem::exists(qnn_ctx_bin));

std::unordered_map<std::string, std::string> session_option_pairs2;
// Need to set the context file path since TestQDQModelAccuracy loads the model from memory
session_option_pairs2.emplace(kOrtSessionOptionEpContextFilePath, context_binary_file);
// 2nd run directly loads and run from Onnx skeleton file + Qnn context cache binary file
TestQDQModelAccuracy(BuildOpTestCase<float>(op_type, {input_def}, {}, {}),
BuildQDQOpTestCase<uint8_t>(op_type, {input_def}, {}, {}),
Expand All @@ -348,7 +346,29 @@ TEST_F(QnnHTPBackendTests, QnnContextBinaryCacheNonEmbedModeTest) {
ExpectedEPNodeAssignment::All,
QDQTolerance(),
logging::Severity::kERROR,
context_binary_file);
context_binary_file,
session_option_pairs2);

// load the model from file
std::vector<char> buffer;
{
std::ifstream file(context_binary_file, std::ios::binary | std::ios::ate);
if (!file)
ORT_THROW("Error reading model");
buffer.resize(narrow<size_t>(file.tellg()));
file.seekg(0, std::ios::beg);
if (!file.read(buffer.data(), buffer.size()))
ORT_THROW("Error reading model");
}

Ort::SessionOptions so;  // No need to set the context file path in so since the model is loaded from file

Check warning on line 364 in onnxruntime/test/providers/qnn/qnn_ep_context_test.cc

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/test/providers/qnn/qnn_ep_context_test.cc#L364

At least two spaces is best between code and comments [whitespace/comments] [2]
Raw output
onnxruntime/test/providers/qnn/qnn_ep_context_test.cc:364:  At least two spaces is best between code and comments  [whitespace/comments] [2]
so.AppendExecutionProvider("QNN", provider_options);
#ifdef _WIN32
std::wstring ctx_model_file(context_binary_file.begin(), context_binary_file.end());
#else
std::string ctx_model_file(context_binary_file.begin(), context_binary_file.end());
#endif
Ort::Session session(*ort_env.get(), ctx_model_file.c_str(), so);

// Clean up
ASSERT_EQ(std::remove(context_binary_file.c_str()), 0);
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/test/providers/qnn/qnn_test_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,7 @@ inline void TestQDQModelAccuracy(const GetTestModelFn& f32_model_fn, const GetTe
model_proto.SerializeToString(&qnn_ctx_model_data);
// Run QNN context cache model on QNN EP and collect outputs.
InferenceModel(qnn_ctx_model_data, "qnn_ctx_model_logger", qnn_options,
expected_ep_assignment, qdq_helper.feeds_, qnn_qdq_outputs, is_qnn_ep);
expected_ep_assignment, qdq_helper.feeds_, qnn_qdq_outputs, is_qnn_ep, session_option_pairs);
} else {
// Run QDQ model on QNN EP and collect outputs.
// Only need to apply the extra session options to this QDQ model inference on QNN EP
Expand Down

0 comments on commit ff2c313

Please sign in to comment.