Skip to content

Commit

Permalink
undo tensorprotoutils.cc
Browse files Browse the repository at this point in the history
Signed-off-by: liqunfu <[email protected]>
  • Loading branch information
liqunfu committed Mar 23, 2024
1 parent 72c590c commit ba32cfb
Show file tree
Hide file tree
Showing 5 changed files with 3 additions and 203 deletions.
175 changes: 3 additions & 172 deletions onnxruntime/core/framework/tensorprotoutils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,9 @@
#include <emscripten.h>
#endif

#ifndef _WIN32
#include <sys/stat.h>
#endif

#include "core/common/gsl.h"
#include "core/common/logging/logging.h"
#include "core/common/narrow.h"
#include "core/common/path.h"
#include "core/common/span_utils.h"
#include "core/graph/onnx_protobuf.h"
#include "core/framework/endian_utils.h"
Expand Down Expand Up @@ -93,176 +88,12 @@ bool operator!=(const ONNX_NAMESPACE::TensorShapeProto_Dimension& l,
const ONNX_NAMESPACE::TensorShapeProto_Dimension& r) {
return !(l == r);
}

} // namespace ONNX_NAMESPACE

namespace onnxruntime {
namespace {
#ifdef _WIN32
#else
// Concatenates `origin` and `append`, inserting the preferred path separator between them
// unless `origin` is empty or already ends with that separator.
std::string path_join(const std::string& origin, const std::string& append) {
  const bool needs_separator = !origin.empty() && origin.back() != k_preferred_path_separator;

  std::string joined = origin;
  if (needs_separator) {
    joined += k_preferred_path_separator;
  }
  joined += append;
  return joined;
}

// Lexically normalizes a relative path without touching the file system:
//   - runs of separators collapse into one and "." elements are dropped,
//   - "name/.." pairs cancel each other out,
//   - ".." elements that cannot be cancelled are kept as a leading "../.." prefix.
// An empty input, or one that normalizes to nothing, yields ".".
std::string clean_relative_path(const std::string& path) {
  if (path.empty()) {
    return ".";
  }

  const size_t len = path.size();

  // True when `pos` is the end of the input or sits on a separator, i.e. the element
  // currently being examined stops right before `pos`.
  const auto ends_element = [&path, len](size_t pos) {
    return pos == len || path[pos] == k_preferred_path_separator;
  };

  std::string cleaned;
  size_t backtrack_floor = 0;  // size of the leading "../.." prefix that must never be popped
  size_t pos = 0;

  while (pos < len) {
    const char ch = path[pos];

    if (ch == k_preferred_path_separator) {
      // Redundant separator.
      ++pos;
    } else if (ch == '.' && ends_element(pos + 1)) {
      // "." element: refers to the current directory, contributes nothing.
      ++pos;
    } else if (ch == '.' && path[pos + 1] == '.' && ends_element(pos + 2)) {
      // ".." element.
      pos += 2;

      if (cleaned.size() > backtrack_floor) {
        // There is a real element to cancel: strip it together with its separator.
        while (cleaned.size() > backtrack_floor && cleaned.back() != k_preferred_path_separator) {
          cleaned.pop_back();
        }
        if (!cleaned.empty()) {
          cleaned.pop_back();
        }
      } else {
        // Nothing left to cancel: record the ".." and move the floor past it.
        if (!cleaned.empty()) {
          cleaned.push_back(k_preferred_path_separator);
        }
        cleaned.push_back('.');
        cleaned.push_back('.');
        backtrack_floor = cleaned.size();
      }
    } else {
      // Ordinary element: copy it verbatim, preceded by a separator when needed.
      if (!cleaned.empty() && cleaned.back() != k_preferred_path_separator) {
        cleaned.push_back(k_preferred_path_separator);
      }
      while (pos < len && path[pos] != k_preferred_path_separator) {
        cleaned.push_back(path[pos]);
        ++pos;
      }
    }
  }

  if (cleaned.empty()) {
    cleaned.push_back('.');
  }

  return cleaned;
}
#endif

template <typename T>
T resolve_external_data_location(
const T& base_dir,
const T& location,
const std::string& tensor_name) {
if constexpr (std::is_same_v<T, std::wstring>) {
auto file_path = std::filesystem::path(location);
if (file_path.is_absolute()) {
ORT_THROW(
"Location of external TensorProto ( tensor name: ",
tensor_name,
") should be a relative path, but it is an absolute path: ",
ToUTF8String(location));
}
#if defined(__APPLE__) && TARGET_OS_IPHONE
// workaround 'wstring' is unavailable: introduced in iOS 13.0
auto relative_path = ToWideString(file_path.lexically_normal().make_preferred().string());
#else
auto relative_path = file_path.lexically_normal().make_preferred().wstring();
#endif
// Check that normalized relative path contains ".." on Windows.
if (relative_path.find(L"..", 0) != std::string::npos) {
ORT_THROW(
"Data of TensorProto ( tensor name: ",
tensor_name,
") should be file inside the ",
ToUTF8String(base_dir),
", but the '",
ToUTF8String(location),
"' points outside the directory");
}
std::wstring data_path = onnxruntime::ConcatPathComponent(base_dir, relative_path);
struct _stat64 buff;
if (data_path.empty() || (data_path[0] != '#' && _wstat64(data_path.c_str(), &buff) != 0)) {
ORT_THROW(
"Data of TensorProto ( tensor name: ",
tensor_name,
") should be stored in ",
ToUTF8String(location),
", but it doesn't exist or is not accessible.");
}
return data_path;
} else if constexpr (std::is_same_v<T, std::string>) {
if (location.empty()) {
ORT_THROW(
"Location of external TensorProto ( tensor name: ",
tensor_name,
") should not be empty.");
} else if (location[0] == '/') {
ORT_THROW(
"Location of external TensorProto ( tensor name: ",
tensor_name,
") should be a relative path, but it is an absolute path: ",
location);
}
std::string relative_path = clean_relative_path(location);
// Check that normalized relative path contains ".." on POSIX
if (relative_path.find("..", 0) != std::string::npos) {
ORT_THROW(
"Data of TensorProto ( tensor name: ",
tensor_name,
") should be file inside the ",
base_dir,
", but the '",
location,
"' points outside the directory");
}
std::string data_path = path_join(base_dir, relative_path);
// use stat64 to check whether the file exists
#if defined(__APPLE__) || defined(__wasm__) || !defined(__GLIBC__)
struct stat buffer; // APPLE, wasm and non-glic stdlibs do not have stat64
if (data_path.empty() || (data_path[0] != '#' && stat((data_path).c_str(), &buffer) != 0)) {
#else
struct stat64 buffer; // All POSIX under glibc except APPLE and wasm have stat64
if (data_path.empty() || (data_path[0] != '#' && stat64((data_path).c_str(), &buffer) != 0)) {
#endif
ORT_THROW(
"Data of TensorProto ( tensor name: ",
tensor_name,
") should be stored in ",
data_path,
", but it doesn't exist or is not accessible.");
}
// Do not allow symlinks or directories.
if (data_path.empty() || (data_path[0] != '#' && !S_ISREG(buffer.st_mode))) {
ORT_THROW(
"Data of TensorProto ( tensor name: ",
tensor_name,
") should be stored in ",
data_path,
", but it is not regular file.");
}
return data_path;
}
}
// This function doesn't support string tensors
static Status UnpackTensorWithRawDataImpl(const void* raw_data, size_t raw_data_len,
size_t expected_num_elements, size_t element_size,
Expand Down Expand Up @@ -314,8 +145,8 @@ static Status GetExternalDataInfo(const ONNX_NAMESPACE::TensorProto& tensor_prot
external_file_path = location;
} else {
if (tensor_proto_dir != nullptr) {
external_file_path = resolve_external_data_location(
std::basic_string<ORTCHAR_T>(tensor_proto_dir), external_data_info->GetRelPath(), tensor_proto.name());
external_file_path = onnxruntime::ConcatPathComponent(tensor_proto_dir,
external_data_info->GetRelPath());
} else {
external_file_path = external_data_info->GetRelPath();
}
Expand Down
31 changes: 0 additions & 31 deletions onnxruntime/test/framework/inference_session_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3008,36 +3008,5 @@ TEST(InferenceSessionTests, InterThreadPoolWithDenormalAsZero) {
}
#endif

// A model whose external tensor data location is an absolute path must load, but
// session initialization has to fail with the "should be a relative path" error.
TEST(InferenceSessionTests, ModelWithAbsolutePathForExternalTensorData) {
  // Pick the fixture that carries an absolute path in the current platform's syntax.
#ifdef _WIN32
  const std::string model_path = "testdata/model_with_windows_absolute_path_for_external_tensor_data.onnx";
#else
  const std::string model_path = "testdata/model_with_linux_absolute_path_for_external_tensor_data.onnx";
#endif

  SessionOptions session_options;
  session_options.session_logid = "InferenceSessionTests.ModelWithAbsolutePathForExternalTensorData";

  InferenceSession session{session_options, GetEnvironment()};
  ASSERT_STATUS_OK(session.Load(model_path));

  const common::Status init_status = session.Initialize();
  ASSERT_FALSE(init_status.IsOK());
  EXPECT_THAT(init_status.ErrorMessage(),
              ::testing::ContainsRegex(".*Location of external TensorProto \\(.*\\) should be a relative path, but it is an absolute path: .*"));
}

// A model whose external tensor data location climbs out of the model directory must load,
// but session initialization has to fail with the "points outside the directory" error.
TEST(InferenceSessionTests, ModelWithTraversalPathForExternalTensorData) {
  const std::string model_path = "testdata/model_with_traversal_path_for_external_tensor_data.onnx";

  SessionOptions session_options;
  session_options.session_logid = "InferenceSessionTests.ModelWithTraversalPathForExternalTensorData";

  InferenceSession session{session_options, GetEnvironment()};
  ASSERT_STATUS_OK(session.Load(model_path));

  const common::Status init_status = session.Initialize();
  ASSERT_FALSE(init_status.IsOK());
  EXPECT_THAT(init_status.ErrorMessage(),
              ::testing::ContainsRegex(".*Data of TensorProto \\(.*\\) should be file inside the .*, but the \\'.*\\' points outside the directory"));
}
} // namespace test
} // namespace onnxruntime
Binary file not shown.
Binary file not shown.
Binary file not shown.

0 comments on commit ba32cfb

Please sign in to comment.