Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce custom external data loader #21634

Merged
merged 10 commits into from
Aug 27, 2024
14 changes: 14 additions & 0 deletions include/onnxruntime/core/framework/execution_provider.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "core/common/logging/logging.h"
#include "core/common/status.h"
#include "core/framework/data_transfer.h"
#include "core/framework/external_data_loader.h"
#include "core/framework/tensor.h"

namespace onnxruntime {
Expand Down Expand Up @@ -88,6 +89,19 @@ class IExecutionProvider {
return nullptr;
}

/**
 * Returns an external data loader object that implements methods to load data from external sources.
 *
 * By default, the framework handles external data loading by loading the data into CPU memory and then copying
 * it to the target device if required, so in most cases it is not necessary to override this method. Specifically,
 * in the WebAssembly build, because memory is limited and the Web platform supports loading data from external
 * sources directly into GPU memory, this method is overridden to provide a custom external data loader that avoids
 * the extra CPU memory usage.
 *
 * @return The provider's external data loader. This default implementation returns nullptr (no custom loader).
 */
virtual std::unique_ptr<onnxruntime::IExternalDataLoader> GetExternalDataLoader() const {
return nullptr;
}

/**
* Interface for performing kernel lookup within kernel registries.
* Abstracts away lower-level details about kernel registries and kernel matching.
Expand Down
100 changes: 100 additions & 0 deletions onnxruntime/core/framework/external_data_loader.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/framework/external_data_loader.h"
#ifndef SHARED_PROVIDER
#include "core/framework/tensor.h"
#endif
#if defined(__wasm__)
#include <emscripten.h>
#endif

namespace onnxruntime {

// Base-class implementation of IExternalDataLoader::LoadTensor.
// It does not load anything: every parameter is unused (hence [[maybe_unused]]) and the body only invokes
// ORT_NOT_IMPLEMENTED with the function name, so derived loaders are expected to provide the real implementation.
common::Status IExternalDataLoader::LoadTensor([[maybe_unused]] const Env& env,
[[maybe_unused]] const std::filesystem::path& data_file_path,
[[maybe_unused]] FileOffsetType data_offset,
[[maybe_unused]] SafeInt<size_t> data_length,
[[maybe_unused]] Tensor& tensor) const {
ORT_NOT_IMPLEMENTED(__FUNCTION__, " is not implemented");
}

#if defined(__wasm__)

// Loads `data_length` bytes starting at `data_offset` of a preloaded external data file, using inline
// JavaScript (EM_ASM_INT).
//
// The file is looked up by name (a leading "./" is stripped) in the JavaScript-side map "Module.MountedFiles".
// Depending on `load_type`, the selected bytes are either copied into WebAssembly memory (HEAPU8) at the address
// given by `tensor_data` (load type 0), or passed to Module.jsepUploadExternalBuffer() together with
// `tensor_data` (load type 1).
//
// The inline JavaScript reinterprets the offset, the length and `tensor_data` as unsigned 32-bit values
// (`>>> 0`). An offset outside [0, 2^32) is therefore rejected up front instead of being silently wrapped to a
// different position in the file.
//
// `env` is currently unused.
//
// Returns Status::OK() on success; otherwise a FAIL status whose message names the file and the reason
// (mounted files unavailable, file not found, out of bounds, or an error during the copy).
common::Status LoadWebAssemblyExternalData([[maybe_unused]] const Env& env,
                                           const std::filesystem::path& data_file_path,
                                           FileOffsetType data_offset,
                                           SafeInt<size_t> data_length,
                                           ExternalDataLoadType load_type,
                                           void* tensor_data) {
  // Only the low 32 bits of the offset reach the JavaScript code below. Refuse offsets that do not fit so that
  // a wrapped-around value can never pass the bounds check and load the wrong bytes.
  // NOTE(review): assumes FileOffsetType is a signed integer type - confirm against core/platform/env.h.
  constexpr uint64_t kMaxJsUint32 = 0xFFFFFFFFull;
  if (data_offset < 0 || static_cast<uint64_t>(data_offset) > kMaxJsUint32) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to load external data file \"", data_file_path,
                           "\", error: ", "Out of bounds.");
  }

  auto err_code = EM_ASM_INT(({
                               // If available, "Module.MountedFiles" is a Map for all preloaded files.
                               if (typeof Module == 'undefined' || !Module.MountedFiles) {
                                 return 1;  // "Module.MountedFiles" is not available.
                               }
                               let fileName = UTF8ToString($0 >>> 0);
                               if (fileName.startsWith('./')) {
                                 fileName = fileName.substring(2);
                               }
                               const fileData = Module.MountedFiles.get(fileName);
                               if (!fileData) {
                                 return 2;  // File not found in preloaded files.
                               }
                               const offset = $1 >>> 0;
                               const length = $2 >>> 0;
                               const dataIdOrBuffer = $3 >>> 0;
                               const loadType = $4;

                               if (offset + length > fileData.byteLength) {
                                 return 3;  // Out of bounds.
                               }

                               try {
                                 const data = fileData.subarray(offset, offset + length);
                                 switch (loadType) {
                                   case 0:
                                     // Load external data to CPU memory.
                                     // Copy the file data (fileData,offset,length) into WebAssembly memory
                                     // (HEAPU8,buffer,length).
                                     HEAPU8.set(data, dataIdOrBuffer);
                                     break;
                                   case 1:
                                     // Load external data to GPU.
                                     Module.jsepUploadExternalBuffer(dataIdOrBuffer, data);
                                     break;
                                   default:
                                     return 4;  // Unknown error occurred in memory copy.
                                 }
                                 return 0;
                               } catch {
                                 return 4;
                               }
                             }),
                             data_file_path.c_str(),
                             static_cast<int32_t>(data_offset),
                             static_cast<int32_t>(data_length),
                             tensor_data,
                             static_cast<int32_t>(load_type));

  // Translate the numeric code returned by the JavaScript snippet into a message.
  const char* err_msg = nullptr;
  switch (err_code) {
    case 0:
      return Status::OK();
    case 1:
      err_msg = "Module.MountedFiles is not available.";
      break;
    case 2:
      err_msg = "File not found in preloaded files.";
      break;
    case 3:
      err_msg = "Out of bounds.";
      break;
    default:
      err_msg = "Unknown error occurred in memory copy.";
  }
  return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to load external data file \"", data_file_path,
                         "\", error: ", err_msg);
}

#endif

} // namespace onnxruntime
60 changes: 60 additions & 0 deletions onnxruntime/core/framework/external_data_loader.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include <functional>
#include <vector>
#include <filesystem>

#include "core/common/common.h"
#include "core/common/safeint.h"
#include "core/platform/env.h"

struct OrtMemoryInfo;

namespace onnxruntime {
#ifndef SHARED_PROVIDER
class Tensor;
#endif
class Stream;

namespace common {
class Status;
}

// External data loader interface.
// An implementation reports, through CanLoad(), which target memory it can load into, and performs the load in
// LoadTensor().
class IExternalDataLoader {
public:
virtual ~IExternalDataLoader() = default;

// Returns true if this loader can load external data into memory described by `target_memory_info`.
virtual bool CanLoad(const OrtMemoryInfo& target_memory_info) const = 0;

// Loads `data_length` bytes at `data_offset` of the file `data_file_path` into `tensor`.
// Tensor should be already allocated with the correct memory info and size.
// This method is not pure virtual: the base implementation only reports that it is not implemented.
virtual common::Status LoadTensor(const Env& env,
const std::filesystem::path& data_file_path,
FileOffsetType data_offset,
SafeInt<size_t> data_length,
Tensor& tensor) const;
};

#if defined(__wasm__)

// Selects where LoadWebAssemblyExternalData() puts the loaded bytes.
// The numeric values matter: they are passed as integers to the inline JavaScript, which switches on 0 and 1.
enum class ExternalDataLoadType {
// Copy the data into WebAssembly (CPU) memory.
CPU = 0,
#if defined(USE_JSEP)
// Upload the data to the GPU; only available when USE_JSEP is defined.
WEBGPU_BUFFER = 1,
#endif
};

// Entry point for loading external data implementation using inline JavaScript.
//
// Reads `data_length` bytes at `data_offset` of the preloaded file `data_file_path` and, depending on
// `load_type`, either copies them into WebAssembly memory at `tensor_data` or uploads them to the GPU with
// `tensor_data` passed along as the destination.
// Returns OK on success, or a FAIL status describing why the load failed.
common::Status LoadWebAssemblyExternalData(const Env& env,
const std::filesystem::path& data_file_path,
FileOffsetType data_offset,
SafeInt<size_t> data_length,
ExternalDataLoadType load_type,
void* tensor_data);

#endif

} // namespace onnxruntime
29 changes: 29 additions & 0 deletions onnxruntime/core/framework/external_data_loader_manager.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
Fixed Show fixed Hide fixed
// Licensed under the MIT License.

#include "core/framework/external_data_loader_manager.h"
#include "core/framework/tensor.h"

namespace onnxruntime {
using namespace common;

Check warning on line 8 in onnxruntime/core/framework/external_data_loader_manager.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Do not use namespace using-directives. Use using-declarations instead. [build/namespaces] [5] Raw Output: onnxruntime/core/framework/external_data_loader_manager.cc:8: Do not use namespace using-directives. Use using-declarations instead. [build/namespaces] [5]

// Registers an external data loader with this manager, which takes ownership of it.
// Returns an INVALID_ARGUMENT status if `external_data_loader` is null; otherwise appends the loader to the list
// of registered loaders and returns OK.
Status ExternalDataLoaderManager::RegisterExternalDataLoader(
    std::unique_ptr<IExternalDataLoader> external_data_loader) {
  if (nullptr == external_data_loader) {
    return Status(ONNXRUNTIME, INVALID_ARGUMENT, "external_data_loader registered is nullptr.");
  }
  external_data_loaders_.push_back(std::move(external_data_loader));
  return Status::OK();
}

// Returns the first registered loader (in registration order) whose CanLoad() accepts `target_memory_info`,
// or nullptr if no registered loader can handle it. The returned pointer remains owned by this manager.
const IExternalDataLoader* ExternalDataLoaderManager::GetExternalDataLoader(
    const OrtMemoryInfo& target_memory_info) const {
  for (const auto& external_data_loader : external_data_loaders_) {
    if (external_data_loader->CanLoad(target_memory_info)) {
      return external_data_loader.get();
    }
  }
  return nullptr;
}

} // namespace onnxruntime
28 changes: 28 additions & 0 deletions onnxruntime/core/framework/external_data_loader_manager.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include <memory>
#include <vector>

#include "core/common/status.h"
#include "core/common/common.h"
#include "core/framework/external_data_loader.h"

namespace onnxruntime {

// The external data loader manager manages all registered external data loaders to allow custom
// external data loading implemented by execution providers.
class ExternalDataLoaderManager {
 public:
  ExternalDataLoaderManager() = default;

  // Takes ownership of `external_data_loader` and adds it to the registered loaders.
  // Returns a non-OK status if `external_data_loader` is null.
  common::Status RegisterExternalDataLoader(std::unique_ptr<IExternalDataLoader> external_data_loader);

  // Returns the first registered loader whose CanLoad() accepts `target_memory_info`, or nullptr if there is none.
  // The returned pointer is owned by this manager.
  const IExternalDataLoader* GetExternalDataLoader(const OrtMemoryInfo& target_memory_info) const;

 private:
  ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(ExternalDataLoaderManager);

  // It's assumed that external data loaders in this array have no overlap in terms of loading functionality:
  // lookup returns the first loader that reports it can load into the requested memory.
  std::vector<std::unique_ptr<IExternalDataLoader>> external_data_loaders_;
};
} // namespace onnxruntime
8 changes: 5 additions & 3 deletions onnxruntime/core/framework/session_state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ SessionState::SessionState(Graph& graph,
concurrency::ThreadPool* thread_pool,
concurrency::ThreadPool* inter_op_thread_pool,
const DataTransferManager& data_transfer_mgr,
const ExternalDataLoaderManager& external_data_loader_mgr,
const logging::Logger& logger,
profiling::Profiler& profiler,
const SessionOptions& sess_options,
Expand All @@ -78,6 +79,7 @@ SessionState::SessionState(Graph& graph,
thread_pool_(thread_pool),
inter_op_thread_pool_(inter_op_thread_pool),
data_transfer_mgr_(data_transfer_mgr),
external_data_loader_mgr_(external_data_loader_mgr),
sess_options_(sess_options),
prepacked_weights_container_(prepacked_weights_container)
#ifdef ORT_ENABLE_STREAM
Expand Down Expand Up @@ -1046,7 +1048,7 @@ Status SessionState::CreateSubgraphSessionState() {
auto subgraph_session_state =
std::make_unique<SessionState>(*subgraph, execution_providers_,
thread_pool_, inter_op_thread_pool_, data_transfer_mgr_,
logger_, profiler_, sess_options_,
external_data_loader_mgr_, logger_, profiler_, sess_options_,
prepacked_weights_container_, allocators_);

// Pass fused function manager to subgraph
Expand Down Expand Up @@ -1486,8 +1488,8 @@ Status SessionState::FinalizeSessionStateImpl(const std::basic_string<PATH_CHAR_
}
return Status::OK();
},
logger_, data_transfer_mgr_, *p_seq_exec_plan_, session_options, memory_profile_func,
name_to_buffered_tensor_));
logger_, data_transfer_mgr_, external_data_loader_mgr_, *p_seq_exec_plan_, session_options,
memory_profile_func, name_to_buffered_tensor_));

#if !defined(ORT_MINIMAL_BUILD) && defined(ORT_MEMORY_PROFILE)
// Record Weight allocation info on device
Expand Down
6 changes: 6 additions & 0 deletions onnxruntime/core/framework/session_state.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "core/framework/allocation_planner.h"
#include "core/framework/callback.h"
#include "core/framework/data_transfer_manager.h"
#include "core/framework/external_data_loader_manager.h"
#include "core/framework/execution_providers.h"
#include "core/framework/stream_execution_context.h"
#include "core/framework/feeds_fetches_manager.h"
Expand Down Expand Up @@ -93,6 +94,7 @@ class SessionState {
concurrency::ThreadPool* thread_pool,
concurrency::ThreadPool* inter_op_thread_pool,
const DataTransferManager& data_transfer_mgr,
const ExternalDataLoaderManager& external_data_loader_mgr,
const logging::Logger& logger,
profiling::Profiler& profiler,
const SessionOptions& sess_options,
Expand Down Expand Up @@ -296,6 +298,8 @@ class SessionState {

// Returns the data transfer manager this session state was constructed with.
const DataTransferManager& GetDataTransferMgr() const noexcept { return data_transfer_mgr_; }

// Returns the external data loader manager this session state was constructed with.
const ExternalDataLoaderManager& GetExternalDataLoaderMgr() const noexcept { return external_data_loader_mgr_; }

// Returns a mutable reference to the weights buffers held by this session state.
InlinedVector<BufferUniquePtr>& GetMutableWeightsBuffers() noexcept { return weights_buffers_; }

const NodeIndexInfo& GetNodeIndexInfo() const;
Expand Down Expand Up @@ -513,6 +517,8 @@ class SessionState {

const DataTransferManager& data_transfer_mgr_;

const ExternalDataLoaderManager& external_data_loader_mgr_;

const SessionOptions& sess_options_;

std::optional<NodeIndexInfo> node_index_info_;
Expand Down
Loading
Loading