Skip to content

Commit

Permalink
Check the ep_cache_context and don't allow access outside the directo…
Browse files Browse the repository at this point in the history
…ry (microsoft#19174)

### Description
Check the ep_cache_context node property for EPContext node, and don't
allow relative path like "../file_path"
  • Loading branch information
HectorSVC authored Jan 18, 2024
1 parent 9da3e36 commit dadd3ea
Show file tree
Hide file tree
Showing 2 changed files with 155 additions and 2 deletions.
28 changes: 26 additions & 2 deletions onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,33 @@ Status GetEpContextFromGraph(const onnxruntime::GraphViewer& graph_viewer,
qnn_model);
}

std::string external_qnn_context_binary_file_name = node_helper.Get(EP_CACHE_CONTEXT, "");
std::filesystem::path folder_path = std::filesystem::path(ctx_onnx_model_path).parent_path();
std::filesystem::path context_binary_path = folder_path.append(external_qnn_context_binary_file_name);
std::string external_qnn_ctx_binary_file_name = node_helper.Get(EP_CACHE_CONTEXT, "");
ORT_RETURN_IF(external_qnn_ctx_binary_file_name.empty(), "The file path in ep_cache_context should not be empty.");
#ifdef _WIN32
onnxruntime::PathString external_qnn_context_binary_path = onnxruntime::ToPathString(external_qnn_ctx_binary_file_name);
auto ctx_file_path = std::filesystem::path(external_qnn_context_binary_path.c_str());
ORT_RETURN_IF(ctx_file_path.is_absolute(), "External mode should set ep_cache_context field with a relative path, but it is an absolute path: ",
external_qnn_ctx_binary_file_name);
auto relative_path = ctx_file_path.lexically_normal().make_preferred().wstring();
if (relative_path.find(L"..", 0) != std::string::npos) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, "The file path in ep_cache_context field has '..'. It's not allowed to point outside the directory.");
}

std::filesystem::path context_binary_path = folder_path.append(relative_path);
#else
ORT_RETURN_IF(external_qnn_ctx_binary_file_name[0] == '/',
"External mode should set ep_cache_context field with a relative path, but it is an absolute path: ",
external_qnn_ctx_binary_file_name);
if (external_qnn_ctx_binary_file_name.find("..", 0) != std::string::npos) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, "The file path in ep_cache_context field has '..'. It's not allowed to point outside the directory.");
}
std::filesystem::path context_binary_path = folder_path.append(external_qnn_ctx_binary_file_name);
std::string file_full_path = context_binary_path.string();
#endif
if (!std::filesystem::is_regular_file(context_binary_path)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, "The file path in ep_cache_context does not exist or is not accessible.");
}

size_t buffer_size{0};
std::ifstream cache_file(context_binary_path.string().c_str(), std::ifstream::binary);
Expand Down
129 changes: 129 additions & 0 deletions onnxruntime/test/providers/qnn/simple_op_htp_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -908,6 +908,135 @@ TEST_F(QnnHTPBackendTests, ContextBinaryCache_InvalidGraph) {
ASSERT_EQ(std::remove(context_binary_file.c_str()), 0);
}

std::string CreateQnnCtxModelWithNonEmbedMode(std::string external_bin_path) {
const std::unordered_map<std::string, int> domain_to_version = {{"", 11}, {kMSDomain, 1}};
auto& logging_manager = DefaultLoggingManager();
onnxruntime::Model model("QNN_ctx_model", false, ModelMetaData(), PathString(),
IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, {},
logging_manager.DefaultLogger());
Graph& graph = model.MainGraph();
ModelTestBuilder helper(graph);
std::vector<int64_t> shape = {2, 3};
NodeArg* graph_input = MakeTestInput(helper, TestInputDef<float>(shape, true, {0.0f, 1.0f, 0.0f, 1.0f, 0.0f, 1.0f}));
auto* graph_output = helper.MakeOutput<float>(shape);
Node& ep_context_node = helper.AddNode("EPContext", {graph_input}, {graph_output}, kMSDomain);
ep_context_node.AddAttribute("embed_mode", static_cast<int64_t>(0));
// The .. in the path will cause INVALID_GRAPH
ep_context_node.AddAttribute("ep_cache_context", external_bin_path);
ep_context_node.AddAttribute("partition_name", "QNNExecutionProvider_QNN_1110111000111000111_1_0");
ep_context_node.AddAttribute("source", "QNN");
helper.SetGraphOutputs();
std::string model_data;
model.ToProto().SerializeToString(&model_data);

return model_data;
}

// Create a model with EPContext node. Set the node property ep_cache_context has ".."
// Verify that it return INVALID_GRAPH status
TEST_F(QnnHTPBackendTests, QnnContextBinaryRelativePathTest) {
std::string model_data = CreateQnnCtxModelWithNonEmbedMode("../qnn_context.bin");

SessionOptions so;
so.session_logid = "qnn_ctx_model_logger";
RunOptions run_options;
run_options.run_tag = so.session_logid;

InferenceSessionWrapper session_object{so, GetEnvironment()};

ProviderOptions provider_options;
#if defined(_WIN32)
provider_options["backend_path"] = "QnnHtp.dll";
#else
provider_options["backend_path"] = "libQnnHtp.so";
#endif

ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(QnnExecutionProviderWithOptions(provider_options)));
ASSERT_STATUS_OK(session_object.Load(model_data.data(), static_cast<int>(model_data.size())));
// Verify the return status with code INVALID_GRAPH
ASSERT_TRUE(session_object.Initialize().Code() == common::StatusCode::INVALID_GRAPH);
}

// Create a model with EPContext node. Set the node property ep_cache_context has absolute path
// Verify that it return INVALID_GRAPH status
TEST_F(QnnHTPBackendTests, QnnContextBinaryAbsolutePathTest) {
#if defined(_WIN32)
std::string external_ctx_bin_path = "D:/qnn_context.bin";
#else
std::string external_ctx_bin_path = "/data/qnn_context.bin";
#endif
std::string model_data = CreateQnnCtxModelWithNonEmbedMode(external_ctx_bin_path);

SessionOptions so;
so.session_logid = "qnn_ctx_model_logger";
RunOptions run_options;
run_options.run_tag = so.session_logid;

InferenceSessionWrapper session_object{so, GetEnvironment()};

ProviderOptions provider_options;
#if defined(_WIN32)
provider_options["backend_path"] = "QnnHtp.dll";
#else
provider_options["backend_path"] = "libQnnHtp.so";
#endif

ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(QnnExecutionProviderWithOptions(provider_options)));
ASSERT_STATUS_OK(session_object.Load(model_data.data(), static_cast<int>(model_data.size())));
// Verify the return status with code INVALID_GRAPH
ASSERT_TRUE(session_object.Initialize().Code() == common::StatusCode::INVALID_GRAPH);
}

// Create a model with EPContext node. Set the node property ep_cache_context to a file not exist
// Verify that it return INVALID_GRAPH status
TEST_F(QnnHTPBackendTests, QnnContextBinaryFileNotExistTest) {
std::string model_data = CreateQnnCtxModelWithNonEmbedMode("qnn_context_not_exist.bin");

SessionOptions so;
so.session_logid = "qnn_ctx_model_logger";
RunOptions run_options;
run_options.run_tag = so.session_logid;

InferenceSessionWrapper session_object{so, GetEnvironment()};

ProviderOptions provider_options;
#if defined(_WIN32)
provider_options["backend_path"] = "QnnHtp.dll";
#else
provider_options["backend_path"] = "libQnnHtp.so";
#endif

ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(QnnExecutionProviderWithOptions(provider_options)));
ASSERT_STATUS_OK(session_object.Load(model_data.data(), static_cast<int>(model_data.size())));
// Verify the return status with code INVALID_GRAPH
ASSERT_TRUE(session_object.Initialize().Code() == common::StatusCode::INVALID_GRAPH);
}

// Create a model with EPContext node. Set the node property ep_cache_context to empty string
// Verify that it return INVALID_GRAPH status
TEST_F(QnnHTPBackendTests, QnnContextBinaryFileEmptyStringTest) {
std::string model_data = CreateQnnCtxModelWithNonEmbedMode("");

SessionOptions so;
so.session_logid = "qnn_ctx_model_logger";
RunOptions run_options;
run_options.run_tag = so.session_logid;

InferenceSessionWrapper session_object{so, GetEnvironment()};

ProviderOptions provider_options;
#if defined(_WIN32)
provider_options["backend_path"] = "QnnHtp.dll";
#else
provider_options["backend_path"] = "libQnnHtp.so";
#endif

ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(QnnExecutionProviderWithOptions(provider_options)));
ASSERT_STATUS_OK(session_object.Load(model_data.data(), static_cast<int>(model_data.size())));
// Verify the return status with code INVALID_GRAPH
ASSERT_TRUE(session_object.Initialize().Code() == common::StatusCode::INVALID_GRAPH);
}

// Run QDQ model on HTP with 2 inputs
// 1st run will generate the Qnn context cache onnx file
// 2nd run will load and run from QDQ model + Qnn context cache model
Expand Down

0 comments on commit dadd3ea

Please sign in to comment.