Introduce custom external data loader (#21634)
### Description

This PR introduces support for custom external data loaders. An EP can
register a custom external data loader to override the default behavior,
making it possible to upload initializers directly to the GPU.



### Motivation and Context

- In ONNX Runtime Web, WebAssembly uses a 32-bit pointer type
(`sizeof(size_t)==4`), which puts a hard 4GB limit on the maximum
memory. As ONNX models get larger, this becomes a blocker for
supporting medium-sized language models.

- ORT runs out of memory because the current code always loads data into
CPU memory, including the .onnx file (protobuf) and external data
file(s). However, when a GPU EP is used, the large data does not need to
stay on the CPU: all ORT does with it is load it into memory, upload it
to the GPU, and then release it.

- Some platforms offer developers a way to upload data directly to the
GPU. For example, WebGPU allows uploading from any ArrayBuffer (which can
be a side buffer that does not count toward the 4GB limit) directly to
the GPU. This helps reduce CPU memory usage significantly.

### Design

Classes `IExternalDataLoader` and `ExternalDataLoaderManager` are
introduced. They are similar to `DataTransfer` and
`DataTransferManager`. `InferenceSession` owns the manager object, and
`SessionState` keeps a reference to it.
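
The ownership arrangement described above can be sketched roughly as follows. This is a minimal illustration, not the ORT code: `Manager`, `State`, and `Session` are hypothetical stand-ins for `ExternalDataLoaderManager`, `SessionState`, and `InferenceSession`, and `AllShareOneManager` is a helper invented for the example.

```cpp
#include <cassert>
#include <memory>

// Hypothetical stand-in for ExternalDataLoaderManager; made non-copyable so it can only be shared by reference.
class Manager {
 public:
  Manager() = default;
  Manager(const Manager&) = delete;
  Manager& operator=(const Manager&) = delete;
};

// Hypothetical stand-in for SessionState: keeps a const reference, never a copy.
class State {
 public:
  explicit State(const Manager& mgr) : mgr_(mgr) {}
  const Manager& GetManager() const noexcept { return mgr_; }
  // A subgraph state is built from the same reference as its parent.
  std::unique_ptr<State> CreateSubgraphState() const { return std::make_unique<State>(mgr_); }

 private:
  const Manager& mgr_;
};

// Hypothetical stand-in for InferenceSession: owns the manager and the top-level state.
class Session {
 public:
  Session() : state_(std::make_unique<State>(mgr_)) {}
  const Manager& manager() const noexcept { return mgr_; }
  const State& state() const noexcept { return *state_; }

 private:
  Manager mgr_;                   // declared first, so it is constructed before state_
  std::unique_ptr<State> state_;  // destroyed before mgr_, so the reference never dangles
};

// Returns true when the session, its state, and a subgraph state all see one manager object.
bool AllShareOneManager() {
  Session session;
  auto sub = session.state().CreateSubgraphState();
  assert(sub != nullptr);
  return &session.state().GetManager() == &session.manager() &&
         &sub->GetManager() == &session.manager();
}
```

Holding a reference rather than a copy means every state, including subgraph states, resolves loaders against the same registry; the cost is that the owner must outlive all states, which the member order in `Session` guarantees in this sketch.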

Added a new method `GetExternalDataLoader` in `IExecutionProvider`. An
EP can override the method to register an instance of custom external
data loader.
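
A rough sketch of how an EP might opt in through such a hook, under stated assumptions: all types below are simplified stand-ins rather than the real ORT classes, the device name `"WebGPU_Buffer"` is illustrative, and `CollectLoaders` is a hypothetical helper (the code that gathers loaders from EPs is not part of this excerpt).

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Simplified stand-in for OrtMemoryInfo (assumption: only a device name matters here).
struct MemoryInfo {
  std::string device;
};

// Simplified stand-in for the loader interface.
class Loader {
 public:
  virtual ~Loader() = default;
  virtual bool CanLoad(const MemoryInfo& target) const = 0;
};

// Simplified stand-in for the execution provider base class.
class ExecutionProvider {
 public:
  virtual ~ExecutionProvider() = default;
  // Default: no custom loader, so the framework keeps its CPU-first path.
  virtual std::unique_ptr<Loader> GetExternalDataLoader() const { return nullptr; }
};

class GpuLoader : public Loader {
 public:
  bool CanLoad(const MemoryInfo& target) const override { return target.device == "WebGPU_Buffer"; }
};

// An EP that opts in by overriding the hook.
class GpuExecutionProvider : public ExecutionProvider {
 public:
  std::unique_ptr<Loader> GetExternalDataLoader() const override { return std::make_unique<GpuLoader>(); }
};

// Hypothetical helper: asks every EP for a loader and keeps the non-null ones.
std::vector<std::unique_ptr<Loader>> CollectLoaders(const std::vector<std::unique_ptr<ExecutionProvider>>& eps) {
  std::vector<std::unique_ptr<Loader>> loaders;
  for (const auto& ep : eps) {
    if (auto loader = ep->GetExternalDataLoader()) {
      loaders.push_back(std::move(loader));
    }
  }
  return loaders;
}

// Builds one default EP and one GPU EP, then counts the loaders they provide.
std::size_t CountLoadersForDemo() {
  std::vector<std::unique_ptr<ExecutionProvider>> eps;
  eps.push_back(std::make_unique<ExecutionProvider>());
  eps.push_back(std::make_unique<GpuExecutionProvider>());
  return CollectLoaders(eps).size();
}
```

Returning `nullptr` by default keeps existing EPs unchanged: only an EP that overrides the hook contributes a loader.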

The key function in the `IExternalDataLoader` class is the `LoadTensor` method:

```c++
  // the tensor is pre-created using the TensorProto info of the initializer and the MemoryInfo (from allocation plan).
  virtual common::Status LoadTensor(const Env& env,
                                    const std::filesystem::path& data_file_path,
                                    FileOffsetType data_offset,
                                    SafeInt<size_t> data_length,
                                    Tensor& tensor) const;
```

An EP registers a loader that implements this function. The loader is
passed through a few layers and is eventually used in
`DeserializeTensorProto()` during the finalizing stage of session
initialization, where initializer tensors are created. The behavior is
changed to first look up a registered external data loader that can
handle the current memory info. If one is available, it is used;
otherwise the old code path is followed.
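
The branch described above might look roughly like this. It is a sketch under assumptions, not the code in `DeserializeTensorProto()`: `Tensor`, `Loader`, `LoadPath`, `DefaultLoad`, and `LoadInitializer` are hypothetical, the "file" is an in-memory byte vector, and a failing custom loader is treated as a failure rather than falling back.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical, simplified tensor: pre-sized to the initializer's byte length.
struct Tensor {
  std::vector<std::uint8_t> bytes;
};

class Loader {
 public:
  virtual ~Loader() = default;
  // Writes file[offset, offset + tensor.bytes.size()) into the pre-created tensor.
  virtual bool LoadTensor(const std::vector<std::uint8_t>& file, std::size_t offset, Tensor& tensor) const = 0;
};

enum class LoadPath { kCustomLoader, kDefaultPath, kFailed };

// Old path, modelled as "stage the bytes in CPU memory, then copy to the destination".
bool DefaultLoad(const std::vector<std::uint8_t>& file, std::size_t offset, Tensor& tensor) {
  const std::size_t length = tensor.bytes.size();
  if (length > file.size() || offset > file.size() - length) return false;
  const auto first = file.begin() + static_cast<std::ptrdiff_t>(offset);
  std::vector<std::uint8_t> staging(first, first + static_cast<std::ptrdiff_t>(length));
  tensor.bytes = staging;
  return true;
}

// `loader` is whatever the lookup returned; nullptr means no loader claimed the target memory.
LoadPath LoadInitializer(const Loader* loader, const std::vector<std::uint8_t>& file, std::size_t offset,
                         Tensor& tensor) {
  if (loader != nullptr) {
    return loader->LoadTensor(file, offset, tensor) ? LoadPath::kCustomLoader : LoadPath::kFailed;
  }
  return DefaultLoad(file, offset, tensor) ? LoadPath::kDefaultPath : LoadPath::kFailed;
}

// Toy loader that writes straight into the tensor without a staging copy.
class DirectLoader : public Loader {
 public:
  bool LoadTensor(const std::vector<std::uint8_t>& file, std::size_t offset, Tensor& tensor) const override {
    const std::size_t length = tensor.bytes.size();
    if (length > file.size() || offset > file.size() - length) return false;
    const auto first = file.begin() + static_cast<std::ptrdiff_t>(offset);
    std::copy(first, first + static_cast<std::ptrdiff_t>(length), tensor.bytes.begin());
    return true;
  }
};

// Loads 3 bytes from a 6-byte file, with or without a custom loader, and reports the path taken.
LoadPath DemoLoad(bool with_loader, std::size_t offset) {
  const std::vector<std::uint8_t> file = {10, 11, 12, 13, 14, 15};
  Tensor tensor{std::vector<std::uint8_t>(3)};
  DirectLoader direct;
  const Loader* loader = with_loader ? &direct : nullptr;
  return LoadInitializer(loader, file, offset, tensor);
}
```

The custom loader avoids the staging buffer that the default path allocates, which is the memory saving the PR is after.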
fs-eire authored Aug 27, 2024
1 parent b7f09d4 commit d2a1b7a
Showing 25 changed files with 448 additions and 102 deletions.
14 changes: 14 additions & 0 deletions include/onnxruntime/core/framework/execution_provider.h
@@ -11,6 +11,7 @@
 #include "core/common/logging/logging.h"
 #include "core/common/status.h"
 #include "core/framework/data_transfer.h"
+#include "core/framework/external_data_loader.h"
 #include "core/framework/tensor.h"

 namespace onnxruntime {
@@ -88,6 +89,19 @@ class IExecutionProvider {
     return nullptr;
   }

+  /**
+   * Returns an external data loader object that implements methods to load data from external sources.
+   *
+   * By default, framework will handle external data loading by loading the data into CPU memory and then copying
+   * it to the target device if required. So in most cases, it's not necessary to override this method. Specifically,
+   * in WebAssembly build, because the memory is limited and Web platform supports loading data from external sources
+   * directly into GPU memory, this method is overridden to provide a custom external data loader to avoid the extra
+   * CPU memory usage.
+   */
+  virtual std::unique_ptr<onnxruntime::IExternalDataLoader> GetExternalDataLoader() const {
+    return nullptr;
+  }
+
   /**
    * Interface for performing kernel lookup within kernel registries.
    * Abstracts away lower-level details about kernel registries and kernel matching.
100 changes: 100 additions & 0 deletions onnxruntime/core/framework/external_data_loader.cc
@@ -0,0 +1,100 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/framework/external_data_loader.h"
#ifndef SHARED_PROVIDER
#include "core/framework/tensor.h"
#endif
#if defined(__wasm__)
#include <emscripten.h>
#endif

namespace onnxruntime {

common::Status IExternalDataLoader::LoadTensor([[maybe_unused]] const Env& env,
                                               [[maybe_unused]] const std::filesystem::path& data_file_path,
                                               [[maybe_unused]] FileOffsetType data_offset,
                                               [[maybe_unused]] SafeInt<size_t> data_length,
                                               [[maybe_unused]] Tensor& tensor) const {
  ORT_NOT_IMPLEMENTED(__FUNCTION__, " is not implemented");
}

#if defined(__wasm__)

common::Status LoadWebAssemblyExternalData(const Env& env,
                                           const std::filesystem::path& data_file_path,
                                           FileOffsetType data_offset,
                                           SafeInt<size_t> data_length,
                                           ExternalDataLoadType load_type,
                                           void* tensor_data) {
  auto err_code = EM_ASM_INT(({
    // If available, "Module.MountedFiles" is a Map for all preloaded files.
    if (typeof Module == 'undefined' || !Module.MountedFiles) {
      return 1;  // "Module.MountedFiles" is not available.
    }
    let fileName = UTF8ToString($0 >>> 0);
    if (fileName.startsWith('./')) {
      fileName = fileName.substring(2);
    }
    const fileData = Module.MountedFiles.get(fileName);
    if (!fileData) {
      return 2;  // File not found in preloaded files.
    }
    const offset = $1 >>> 0;
    const length = $2 >>> 0;
    const dataIdOrBuffer = $3 >>> 0;
    const loadType = $4;

    if (offset + length > fileData.byteLength) {
      return 3;  // Out of bounds.
    }

    try {
      const data = fileData.subarray(offset, offset + length);
      switch (loadType) {
        case 0:
          // Load external data to CPU memory.
          // Copy the file data (fileData,offset,length) into WebAssembly memory
          // (HEAPU8,buffer,length).
          HEAPU8.set(data, dataIdOrBuffer);
          break;
        case 1:
          // Load external data to GPU.
          Module.jsepUploadExternalBuffer(dataIdOrBuffer, data);
          break;
        default:
          return 4;  // Unknown error occurred in memory copy.
      }
      return 0;
    } catch {
      return 4;
    }
  }),
                             data_file_path.c_str(),
                             static_cast<int32_t>(data_offset),
                             static_cast<int32_t>(data_length),
                             tensor_data,
                             static_cast<int32_t>(load_type));
  const char* err_msg;
  switch (err_code) {
    case 0:
      return Status::OK();
    case 1:
      err_msg = "Module.MountedFiles is not available.";
      break;
    case 2:
      err_msg = "File not found in preloaded files.";
      break;
    case 3:
      err_msg = "Out of bounds.";
      break;
    default:
      err_msg = "Unknown error occurred in memory copy.";
  }
  return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to load external data file \"", data_file_path,
                         "\", error: ", err_msg);
}

#endif

} // namespace onnxruntime
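
The bounds check, slice copy, and code-to-message mapping in the inline JavaScript above can be mirrored in plain C++ roughly as follows. This is an illustrative analog, not part of the commit: `CopyFileRange`, `DescribeError`, and `DemoCopy` are hypothetical names. The JavaScript adds `offset + length` as doubles, which cannot wrap; with `size_t` the sum can, so the sketch phrases the check without computing it.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <string>
#include <vector>

// Error codes mirror the inline JavaScript: 0 = success, 3 = out of bounds.
constexpr int kOk = 0;
constexpr int kOutOfBounds = 3;

// Copies file_data[offset, offset + length) into dest, which must hold at least `length` bytes.
int CopyFileRange(const std::vector<std::uint8_t>& file_data, std::size_t offset, std::size_t length,
                  std::uint8_t* dest) {
  // Written so that offset + length is never computed and therefore cannot wrap around.
  if (length > file_data.size() || offset > file_data.size() - length) {
    return kOutOfBounds;
  }
  if (length > 0) {
    std::memcpy(dest, file_data.data() + offset, length);
  }
  return kOk;
}

// Same code-to-message mapping as the switch that follows EM_ASM_INT.
std::string DescribeError(int code) {
  switch (code) {
    case 0:
      return "OK";
    case 1:
      return "Module.MountedFiles is not available.";
    case 2:
      return "File not found in preloaded files.";
    case 3:
      return "Out of bounds.";
    default:
      return "Unknown error occurred in memory copy.";
  }
}

// Copies from an 8-byte file {0..7} and returns the error code.
int DemoCopy(std::size_t offset, std::size_t length) {
  const std::vector<std::uint8_t> file = {0, 1, 2, 3, 4, 5, 6, 7};
  std::vector<std::uint8_t> dest(file.size(), 0xFF);
  const int code = CopyFileRange(file, offset, length, dest.data());
  if (code == kOk) {
    // On success, the destination starts with the requested slice.
    assert(std::memcmp(dest.data(), file.data() + offset, length) == 0);
  }
  return code;
}
```

For example, `DemoCopy(2, 4)` succeeds, while `DemoCopy(6, 4)` is rejected because the range would end two bytes past the file.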
60 changes: 60 additions & 0 deletions onnxruntime/core/framework/external_data_loader.h
@@ -0,0 +1,60 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include <functional>
#include <vector>
#include <filesystem>

#include "core/common/common.h"
#include "core/common/safeint.h"
#include "core/platform/env.h"

struct OrtMemoryInfo;

namespace onnxruntime {
#ifndef SHARED_PROVIDER
class Tensor;
#endif
class Stream;

namespace common {
class Status;
}

// External data loader interface.
class IExternalDataLoader {
 public:
  virtual ~IExternalDataLoader() = default;

  virtual bool CanLoad(const OrtMemoryInfo& target_memory_info) const = 0;

  // Tensor should be already allocated with the correct memory info and size.
  virtual common::Status LoadTensor(const Env& env,
                                    const std::filesystem::path& data_file_path,
                                    FileOffsetType data_offset,
                                    SafeInt<size_t> data_length,
                                    Tensor& tensor) const;
};

#if defined(__wasm__)

enum class ExternalDataLoadType {
  CPU = 0,
#if defined(USE_JSEP)
  WEBGPU_BUFFER = 1,
#endif
};

// Entry point for loading external data implementation using inline JavaScript.
common::Status LoadWebAssemblyExternalData(const Env& env,
                                           const std::filesystem::path& data_file_path,
                                           FileOffsetType data_offset,
                                           SafeInt<size_t> data_length,
                                           ExternalDataLoadType load_type,
                                           void* tensor_data);

#endif

} // namespace onnxruntime
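
In the interface above, `CanLoad` is pure virtual while `LoadTensor` has a base implementation that only reports "not implemented". A minimal sketch of what that split means for subclasses, assuming the failure surfaces as an exception (the definition of `ORT_NOT_IMPLEMENTED` is not part of this excerpt, so `std::logic_error` is a stand-in); all class names are hypothetical:

```cpp
#include <cassert>
#include <stdexcept>
#include <string>

// Simplified stand-ins for OrtMemoryInfo and Tensor.
struct MemoryInfo {
  std::string device;
};
struct Tensor {
  int value = 0;
};

class LoaderBase {
 public:
  virtual ~LoaderBase() = default;
  // Must be overridden: the class cannot be instantiated otherwise.
  virtual bool CanLoad(const MemoryInfo& target) const = 0;
  // Optional to override: the default only reports that nothing is implemented.
  virtual void LoadTensor(Tensor& /*tensor*/) const {
    throw std::logic_error("LoadTensor is not implemented");
  }
};

// Compiles, but fails at load time because LoadTensor was never overridden.
class IncompleteLoader : public LoaderBase {
 public:
  bool CanLoad(const MemoryInfo&) const override { return true; }
};

class CompleteLoader : public LoaderBase {
 public:
  bool CanLoad(const MemoryInfo&) const override { return true; }
  void LoadTensor(Tensor& tensor) const override { tensor.value = 42; }
};

// Returns true when calling LoadTensor hits the base "not implemented" path.
bool LoadThrows(const LoaderBase& loader) {
  Tensor tensor;
  try {
    loader.LoadTensor(tensor);
    return false;
  } catch (const std::logic_error&) {
    return true;
  }
}
```

A loader that claims a memory type through `CanLoad` but inherits the base `LoadTensor` is therefore caught only when an initializer is actually loaded, not at compile time.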
29 changes: 29 additions & 0 deletions onnxruntime/core/framework/external_data_loader_manager.cc
@@ -0,0 +1,29 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/framework/external_data_loader_manager.h"
#include "core/framework/tensor.h"

namespace onnxruntime {
using namespace common;

Status ExternalDataLoaderManager::RegisterExternalDataLoader(std::unique_ptr<IExternalDataLoader> external_data_loader) {
  if (nullptr == external_data_loader) {
    return Status(ONNXRUNTIME, INVALID_ARGUMENT, "external_data_loader registered is nullptr.");
  }
  external_data_loaders_.push_back(std::move(external_data_loader));
  return Status::OK();
}

const IExternalDataLoader* ExternalDataLoaderManager::GetExternalDataLoader(const OrtMemoryInfo& target_memory_info) const {
  for (auto& external_data_loader : external_data_loaders_) {
    if (!external_data_loader->CanLoad(target_memory_info)) {
      continue;
    }

    return external_data_loader.get();
  }
  return nullptr;
}

} // namespace onnxruntime
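
The lookup above returns the first registered loader whose `CanLoad` accepts the target, so registration order matters if two loaders ever claim the same memory type. A small self-contained sketch of that behavior, using hypothetical stand-in types (`Loader` here carries a name and a device string purely for the demonstration):

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Simplified stand-in for OrtMemoryInfo.
struct MemoryInfo {
  std::string device;
};

// Toy loader identified by a name; it claims exactly one device.
class Loader {
 public:
  Loader(std::string name, std::string device) : name_(std::move(name)), device_(std::move(device)) {}
  bool CanLoad(const MemoryInfo& target) const { return target.device == device_; }
  const std::string& name() const { return name_; }

 private:
  std::string name_;
  std::string device_;
};

class Manager {
 public:
  // Rejects null loaders, mirroring the INVALID_ARGUMENT check.
  bool Register(std::unique_ptr<Loader> loader) {
    if (loader == nullptr) return false;
    loaders_.push_back(std::move(loader));
    return true;
  }

  // First match wins; nullptr means "use the default path".
  const Loader* Get(const MemoryInfo& target) const {
    for (const auto& loader : loaders_) {
      if (loader->CanLoad(target)) return loader.get();
    }
    return nullptr;
  }

 private:
  std::vector<std::unique_ptr<Loader>> loaders_;
};

// Registers two loaders for "gpu" and reports which one a lookup returns.
std::string DemoLookup(const std::string& device) {
  Manager manager;
  manager.Register(std::make_unique<Loader>("first", "gpu"));
  manager.Register(std::make_unique<Loader>("second", "gpu"));
  const Loader* found = manager.Get(MemoryInfo{device});
  return found != nullptr ? found->name() : std::string("none");
}

bool DemoRejectsNull() {
  Manager manager;
  return !manager.Register(nullptr);
}
```

With overlapping loaders the second one is silently shadowed, which is why the manager's header states the assumption that registered loaders do not overlap.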
28 changes: 28 additions & 0 deletions onnxruntime/core/framework/external_data_loader_manager.h
@@ -0,0 +1,28 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/common/status.h"
#include "core/common/common.h"
#include "core/framework/external_data_loader.h"

namespace onnxruntime {

// The external data loader manager manages all registered external data loaders to allow custom
// external data loading implemented by execution providers.
class ExternalDataLoaderManager {
 public:
  ExternalDataLoaderManager() = default;

  common::Status RegisterExternalDataLoader(std::unique_ptr<IExternalDataLoader> external_data_loader);

  const IExternalDataLoader* GetExternalDataLoader(const OrtMemoryInfo& target_memory_info) const;

 private:
  ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(ExternalDataLoaderManager);

  // It's assumed that external data loaders in this array have no overlap in terms of loading functionality.
  std::vector<std::unique_ptr<IExternalDataLoader>> external_data_loaders_;
};
} // namespace onnxruntime
8 changes: 5 additions & 3 deletions onnxruntime/core/framework/session_state.cc
@@ -66,6 +66,7 @@ SessionState::SessionState(Graph& graph,
                            concurrency::ThreadPool* thread_pool,
                            concurrency::ThreadPool* inter_op_thread_pool,
                            const DataTransferManager& data_transfer_mgr,
+                           const ExternalDataLoaderManager& external_data_loader_mgr,
                            const logging::Logger& logger,
                            profiling::Profiler& profiler,
                            const SessionOptions& sess_options,
@@ -78,6 +79,7 @@ SessionState::SessionState(Graph& graph,
       thread_pool_(thread_pool),
       inter_op_thread_pool_(inter_op_thread_pool),
       data_transfer_mgr_(data_transfer_mgr),
+      external_data_loader_mgr_(external_data_loader_mgr),
       sess_options_(sess_options),
       prepacked_weights_container_(prepacked_weights_container)
 #ifdef ORT_ENABLE_STREAM
@@ -1046,7 +1048,7 @@ Status SessionState::CreateSubgraphSessionState() {
         auto subgraph_session_state =
             std::make_unique<SessionState>(*subgraph, execution_providers_,
                                            thread_pool_, inter_op_thread_pool_, data_transfer_mgr_,
-                                           logger_, profiler_, sess_options_,
+                                           external_data_loader_mgr_, logger_, profiler_, sess_options_,
                                            prepacked_weights_container_, allocators_);

         // Pass fused function manager to subgraph
@@ -1486,8 +1488,8 @@ Status SessionState::FinalizeSessionStateImpl(const std::basic_string<PATH_CHAR_
             }
             return Status::OK();
           },
-          logger_, data_transfer_mgr_, *p_seq_exec_plan_, session_options, memory_profile_func,
-          name_to_buffered_tensor_));
+          logger_, data_transfer_mgr_, external_data_loader_mgr_, *p_seq_exec_plan_, session_options,
+          memory_profile_func, name_to_buffered_tensor_));

 #if !defined(ORT_MINIMAL_BUILD) && defined(ORT_MEMORY_PROFILE)
   // Record Weight allocation info on device
6 changes: 6 additions & 0 deletions onnxruntime/core/framework/session_state.h
@@ -20,6 +20,7 @@
 #include "core/framework/allocation_planner.h"
 #include "core/framework/callback.h"
 #include "core/framework/data_transfer_manager.h"
+#include "core/framework/external_data_loader_manager.h"
 #include "core/framework/execution_providers.h"
 #include "core/framework/stream_execution_context.h"
 #include "core/framework/feeds_fetches_manager.h"
@@ -93,6 +94,7 @@ class SessionState {
                concurrency::ThreadPool* thread_pool,
                concurrency::ThreadPool* inter_op_thread_pool,
                const DataTransferManager& data_transfer_mgr,
+               const ExternalDataLoaderManager& external_data_loader_mgr,
                const logging::Logger& logger,
                profiling::Profiler& profiler,
                const SessionOptions& sess_options,
@@ -296,6 +298,8 @@ class SessionState {

   const DataTransferManager& GetDataTransferMgr() const noexcept { return data_transfer_mgr_; }

+  const ExternalDataLoaderManager& GetExternalDataLoaderMgr() const noexcept { return external_data_loader_mgr_; }
+
   InlinedVector<BufferUniquePtr>& GetMutableWeightsBuffers() noexcept { return weights_buffers_; }

   const NodeIndexInfo& GetNodeIndexInfo() const;
@@ -513,6 +517,8 @@ class SessionState {

   const DataTransferManager& data_transfer_mgr_;

+  const ExternalDataLoaderManager& external_data_loader_mgr_;
+
   const SessionOptions& sess_options_;

   std::optional<NodeIndexInfo> node_index_info_;