From ca5c3d6d9ab9a94987899f8249e53021b80f7fcd Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Tue, 6 Aug 2024 01:41:54 -0700 Subject: [PATCH 1/7] Introduce custom external data loader --- .../core/framework/execution_provider.h | 8 ++++ .../core/framework/external_data_loader.cc | 19 +++++++++ .../core/framework/external_data_loader.h | 41 +++++++++++++++++++ .../framework/external_data_loader_manager.cc | 29 +++++++++++++ .../framework/external_data_loader_manager.h | 28 +++++++++++++ onnxruntime/core/framework/session_state.cc | 8 ++-- onnxruntime/core/framework/session_state.h | 6 +++ .../core/framework/session_state_utils.cc | 19 +++++++-- .../core/framework/session_state_utils.h | 2 + .../core/framework/tensorprotoutils.cc | 25 +++++++++++ onnxruntime/core/framework/tensorprotoutils.h | 7 ++++ .../core/providers/js/external_data_loader.h | 31 ++++++++++++++ .../providers/js/js_execution_provider.cc | 6 +++ .../core/providers/js/js_execution_provider.h | 1 + .../providers/shared_library/provider_api.h | 1 + onnxruntime/core/session/inference_session.cc | 13 ++++++ onnxruntime/core/session/inference_session.h | 9 ++++ .../test/framework/allocation_planner_test.cc | 5 ++- .../test/framework/execution_frame_test.cc | 15 ++++--- .../test/framework/session_state_test.cc | 23 +++++++++-- onnxruntime/test/providers/memcpy_test.cc | 3 +- 21 files changed, 281 insertions(+), 18 deletions(-) create mode 100644 onnxruntime/core/framework/external_data_loader.cc create mode 100644 onnxruntime/core/framework/external_data_loader.h create mode 100644 onnxruntime/core/framework/external_data_loader_manager.cc create mode 100644 onnxruntime/core/framework/external_data_loader_manager.h create mode 100644 onnxruntime/core/providers/js/external_data_loader.h diff --git a/include/onnxruntime/core/framework/execution_provider.h b/include/onnxruntime/core/framework/execution_provider.h index 49c3d1bdd088a..2771619fd6696 100644 --- a/include/onnxruntime/core/framework/execution_provider.h +++ b/include/onnxruntime/core/framework/execution_provider.h @@ -11,6 +11,7 @@ #include "core/common/logging/logging.h" #include "core/common/status.h" #include "core/framework/data_transfer.h" +#include "core/framework/external_data_loader.h" #include "core/framework/tensor.h" namespace onnxruntime { @@ -88,6 +89,13 @@ class IExecutionProvider { return nullptr; } + /** + * Returns an external data loader object that implements methods to load data from external sources. + */ + virtual std::unique_ptr GetExternalDataLoader() const { + return nullptr; + } + /** * Interface for performing kernel lookup within kernel registries. * Abstracts away lower-level details about kernel registries and kernel matching. diff --git a/onnxruntime/core/framework/external_data_loader.cc b/onnxruntime/core/framework/external_data_loader.cc new file mode 100644 index 0000000000000..8431cce4f27ee --- /dev/null +++ b/onnxruntime/core/framework/external_data_loader.cc @@ -0,0 +1,19 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/framework/external_data_loader.h" +#ifndef SHARED_PROVIDER +#include "core/framework/tensor.h" +#endif + +namespace onnxruntime { + +common::Status IExternalDataLoader::LoadTensor([[maybe_unused]] const Env& env, + [[maybe_unused]] const std::filesystem::path& data_file_path, + [[maybe_unused]] FileOffsetType data_offset, + [[maybe_unused]] SafeInt data_length, + [[maybe_unused]] Tensor& tensor) const { + ORT_NOT_IMPLEMENTED(__FUNCTION__, " is not implemented"); +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/framework/external_data_loader.h b/onnxruntime/core/framework/external_data_loader.h new file mode 100644 index 0000000000000..96945c4b15f9e --- /dev/null +++ b/onnxruntime/core/framework/external_data_loader.h @@ -0,0 +1,41 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include + +#include "core/common/common.h" +#include "core/common/safeint.h" +#include "core/platform/env.h" + +struct OrtMemoryInfo; + +namespace onnxruntime { +#ifndef SHARED_PROVIDER +class Tensor; +#endif +class Stream; + +namespace common { +class Status; +} + +// Data transfer interface. +class IExternalDataLoader { + public: + virtual ~IExternalDataLoader() = default; + + virtual bool CanLoad(const OrtMemoryInfo& target_memory_info) const = 0; + + // Tensor should be already allocated with the correct memory info and size. + virtual common::Status LoadTensor(const Env& env, + const std::filesystem::path& data_file_path, + FileOffsetType data_offset, + SafeInt data_length, + Tensor& tensor) const; +}; + +} // namespace onnxruntime diff --git a/onnxruntime/core/framework/external_data_loader_manager.cc b/onnxruntime/core/framework/external_data_loader_manager.cc new file mode 100644 index 0000000000000..91161b1d3dd4c --- /dev/null +++ b/onnxruntime/core/framework/external_data_loader_manager.cc @@ -0,0 +1,29 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/framework/external_data_loader_manager.h" +#include "core/framework/tensor.h" + +namespace onnxruntime { +using namespace common; + +Status ExternalDataLoaderManager::RegisterExternalDataLoader(std::unique_ptr external_data_loader) { + if (nullptr == external_data_loader) { + return Status(ONNXRUNTIME, INVALID_ARGUMENT, "external_data_loader registered is nullptr."); + } + external_data_loaders_.push_back(std::move(external_data_loader)); + return Status::OK(); +} + +const IExternalDataLoader* ExternalDataLoaderManager::GetExternalDataLoader(const OrtMemoryInfo& target_memory_info) const { + for (auto& external_data_loader : external_data_loaders_) { + if (!external_data_loader->CanLoad(target_memory_info)) { + continue; + } + + return external_data_loader.get(); + } + return nullptr; +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/framework/external_data_loader_manager.h b/onnxruntime/core/framework/external_data_loader_manager.h new file mode 100644 index 0000000000000..e2970fc416f12 --- /dev/null +++ b/onnxruntime/core/framework/external_data_loader_manager.h @@ -0,0 +1,28 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/common/status.h" +#include "core/common/common.h" +#include "core/framework/external_data_loader.h" + +namespace onnxruntime { + +// The external data loader manager manages all registered external data loaders to allow custom +// external data loading implemented by excution providers. +class ExternalDataLoaderManager { + public: + ExternalDataLoaderManager() = default; + + common::Status RegisterExternalDataLoader(std::unique_ptr external_data_loader); + + const IExternalDataLoader* GetExternalDataLoader(const OrtMemoryInfo& target_memory_info) const; + + private: + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(ExternalDataLoaderManager); + + // It's assumed that external data loaders in this array have no overlap in terms of copying functionality. + std::vector> external_data_loaders_; +}; +} // namespace onnxruntime diff --git a/onnxruntime/core/framework/session_state.cc b/onnxruntime/core/framework/session_state.cc index ddb0c3356e544..4df0370ac719e 100644 --- a/onnxruntime/core/framework/session_state.cc +++ b/onnxruntime/core/framework/session_state.cc @@ -66,6 +66,7 @@ SessionState::SessionState(Graph& graph, concurrency::ThreadPool* thread_pool, concurrency::ThreadPool* inter_op_thread_pool, const DataTransferManager& data_transfer_mgr, + const ExternalDataLoaderManager& external_data_loader_mgr, const logging::Logger& logger, profiling::Profiler& profiler, const SessionOptions& sess_options, @@ -78,6 +79,7 @@ SessionState::SessionState(Graph& graph, thread_pool_(thread_pool), inter_op_thread_pool_(inter_op_thread_pool), data_transfer_mgr_(data_transfer_mgr), + external_data_loader_mgr_(external_data_loader_mgr), sess_options_(sess_options), prepacked_weights_container_(prepacked_weights_container) #ifdef ORT_ENABLE_STREAM @@ -1046,7 +1048,7 @@ Status SessionState::CreateSubgraphSessionState() { auto subgraph_session_state = std::make_unique(*subgraph, execution_providers_, thread_pool_, inter_op_thread_pool_, data_transfer_mgr_, - logger_, profiler_, sess_options_, + external_data_loader_mgr_, logger_, profiler_, sess_options_, prepacked_weights_container_, allocators_); // Pass fused function manager to subgraph @@ -1486,8 +1488,8 @@ Status SessionState::FinalizeSessionStateImpl(const std::basic_string& GetMutableWeightsBuffers() noexcept { return weights_buffers_; } const NodeIndexInfo& GetNodeIndexInfo() const; @@ -513,6 +517,8 @@ class SessionState { const DataTransferManager& data_transfer_mgr_; + const ExternalDataLoaderManager& external_data_loader_mgr_; + const SessionOptions& sess_options_; std::optional node_index_info_; diff --git a/onnxruntime/core/framework/session_state_utils.cc b/onnxruntime/core/framework/session_state_utils.cc index b13b0cd27496d..4ae1ac30400d3 100644 --- a/onnxruntime/core/framework/session_state_utils.cc +++ b/onnxruntime/core/framework/session_state_utils.cc @@ -99,6 +99,7 @@ static common::Status DeserializeTensorProto(const Env& env, const std::basic_st const ONNX_NAMESPACE::TensorProto& tensor_proto, const MemBuffer* m, const AllocatorPtr& alloc, const AllocatorPtr& default_cpu_alloc, OrtValue& ort_value, const DataTransferManager& data_transfer_mgr, + const ExternalDataLoaderManager& external_data_loader_mgr, bool use_device_allocator_for_initializers = false, Tensor* buffered_tensor = nullptr) { if (bool(alloc) == (m != nullptr)) { @@ -113,6 +114,7 @@ static common::Status DeserializeTensorProto(const Env& env, const std::basic_st TensorShape tensor_shape = utils::GetTensorShapeFromTensorProto(tensor_proto); const DataTypeImpl* const type = DataTypeImpl::TensorTypeFromONNXEnum(tensor_proto.data_type())->GetElementType(); std::unique_ptr p_tensor; + if (m != nullptr) { p_tensor = std::make_unique(type, tensor_shape, m->GetBuffer(), m->GetAllocInfo()); if (m->GetLen() < p_tensor->SizeInBytes()) { @@ -132,7 +134,16 @@ static common::Status DeserializeTensorProto(const Env& env, const std::basic_st } } - if (p_tensor->Location().device.Type() == OrtDevice::CPU) { + // Check if custom external data loader is available for the tensor + const onnxruntime::IExternalDataLoader* external_data_loader = nullptr; + if (utils::HasExternalData(tensor_proto)) { + external_data_loader = external_data_loader_mgr.GetExternalDataLoader(alloc->Info()); + } + + if (external_data_loader != nullptr) { + ORT_RETURN_IF_ERROR(utils::LoadExtDataToTensorFromTensorProto(env, proto_path, tensor_proto, + *external_data_loader, *p_tensor)); + } else if (p_tensor->Location().device.Type() == OrtDevice::CPU) { // deserialize directly to CPU tensor if (utils::HasExternalData(tensor_proto)) { // NB: The file containing external data for the tensor is mmap'd. If the tensor will be used on CPU we can @@ -201,7 +212,9 @@ common::Status SaveInitializedTensors( const std::vector& initializer_allocation_order, ITensorAllocator& planner, const SaveTensorFunction& save_tensor_func, - const logging::Logger& logger, const DataTransferManager& data_transfer_mgr, + const logging::Logger& logger, + const DataTransferManager& data_transfer_mgr, + const ExternalDataLoaderManager& external_data_loader_mgr, const ExecutionPlanBase& exec_plan, const SessionOptions& session_options, const MemoryProfileFunction& memory_profile_func, @@ -333,7 +346,7 @@ common::Status SaveInitializedTensors( } Status st = DeserializeTensorProto(env, graph_loc, tensor_proto, (m.has_value()) ? &*m : nullptr, alloc, - default_cpu_alloc, ort_value, data_transfer_mgr, + default_cpu_alloc, ort_value, data_transfer_mgr, external_data_loader_mgr, use_device_allocator_for_initializers, p_tensor); if (!st.IsOK()) { std::ostringstream oss; diff --git a/onnxruntime/core/framework/session_state_utils.h b/onnxruntime/core/framework/session_state_utils.h index 499222b6ec613..ab066044025da 100644 --- a/onnxruntime/core/framework/session_state_utils.h +++ b/onnxruntime/core/framework/session_state_utils.h @@ -23,6 +23,7 @@ class SessionState; class GraphViewer; class OrtValueNameIdxMap; class DataTransferManager; +class ExternalDataLoaderManager; class NodeArg; #if !defined(ORT_MINIMAL_BUILD) && defined(ORT_MEMORY_PROFILE) class MemoryInfo; @@ -45,6 +46,7 @@ common::Status SaveInitializedTensors( const SaveTensorFunction& save_tensor_func, const logging::Logger& logger, const DataTransferManager& data_transfer_mgr, + const ExternalDataLoaderManager& external_data_loader_mgr, const ExecutionPlanBase& exec_plan, const SessionOptions& session_options, const MemoryProfileFunction& memory_profile_func, diff --git a/onnxruntime/core/framework/tensorprotoutils.cc b/onnxruntime/core/framework/tensorprotoutils.cc index cbd53298ab2ad..82fd3cb1c01ab 100644 --- a/onnxruntime/core/framework/tensorprotoutils.cc +++ b/onnxruntime/core/framework/tensorprotoutils.cc @@ -1095,6 +1095,31 @@ Status GetExtDataFromTensorProto(const Env& env, const std::filesystem::path& mo return Status::OK(); } +Status LoadExtDataToTensorFromTensorProto(const Env& env, const std::filesystem::path& model_path, + const ONNX_NAMESPACE::TensorProto& tensor_proto, + const IExternalDataLoader& ext_data_loader, + Tensor& tensor) { + ORT_ENFORCE(utils::HasExternalData(tensor_proto)); + std::basic_string tensor_proto_dir; + if (!model_path.empty()) { + ORT_RETURN_IF_ERROR(GetDirNameFromFilePath(model_path, tensor_proto_dir)); + } + std::basic_string external_data_file_path; + FileOffsetType file_offset; + SafeInt raw_data_safe_len = 0; + ORT_RETURN_IF_ERROR( + GetExternalDataInfo(tensor_proto, tensor_proto_dir, external_data_file_path, file_offset, raw_data_safe_len)); + + ORT_RETURN_IF(file_offset < 0 || raw_data_safe_len != tensor.SizeInBytes(), + "External initializer: ", tensor_proto.name(), " offset: ", file_offset, + " size to read: ", static_cast(raw_data_safe_len), + " does not match the tensor size: ", tensor.SizeInBytes()); + ORT_RETURN_IF(external_data_file_path == onnxruntime::utils::kTensorProtoMemoryAddressTag, + "Memory address tag is not supported by custom external data loader."); + + return ext_data_loader.LoadTensor(env, external_data_file_path, file_offset, raw_data_safe_len, tensor); +} + #define CASE_PROTO(X, Y) \ case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_##X: \ ORT_RETURN_IF_ERROR( \ diff --git a/onnxruntime/core/framework/tensorprotoutils.h b/onnxruntime/core/framework/tensorprotoutils.h index 2af1f080be7ee..227ba0706197e 100644 --- a/onnxruntime/core/framework/tensorprotoutils.h +++ b/onnxruntime/core/framework/tensorprotoutils.h @@ -14,6 +14,7 @@ #include "core/common/safeint.h" #include "core/framework/endian_utils.h" #include "core/framework/allocator.h" +#include "core/framework/external_data_loader.h" #include "core/framework/ort_value.h" #include "core/framework/mem_buffer.h" #include "core/framework/tensor_external_data_info.h" @@ -159,6 +160,12 @@ common::Status GetExtDataFromTensorProto(const Env& env, const std::filesystem:: OrtCallback& ext_data_deleter, Tensor* buffered_tensor = nullptr); +// Given a tensor proto with external data obtain a tensor using the specified custom external data loader. +common::Status LoadExtDataToTensorFromTensorProto(const Env& env, const std::filesystem::path& model_path, + const ONNX_NAMESPACE::TensorProto& tensor_proto, + const IExternalDataLoader& ext_data_loader, + Tensor& tensor); + // Convert the AttributeProto from a Constant node into a TensorProto that can be used as an initializer // If AttributeProto contains a TensorProto, this tensor proto is converted as is including the case when the // the data location is external. i.e. it does not load the external data. diff --git a/onnxruntime/core/providers/js/external_data_loader.h b/onnxruntime/core/providers/js/external_data_loader.h new file mode 100644 index 0000000000000..df2126d086ba1 --- /dev/null +++ b/onnxruntime/core/providers/js/external_data_loader.h @@ -0,0 +1,31 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/framework/external_data_loader.h" + +namespace onnxruntime { +namespace js { + +class ExternalDataLoader : public IExternalDataLoader { + public: + ExternalDataLoader() {}; + ~ExternalDataLoader() {}; + + bool CanLoad(const OrtMemoryInfo& target_memory_info) const override { + return target_memory_info.device.Type() == OrtDevice::GPU && + target_memory_info.name == WEBGPU_BUFFER; + } + + common::Status LoadTensor(const Env& env, + const std::filesystem::path& data_file_path, + FileOffsetType data_offset, + SafeInt data_length, + Tensor& tensor) const override { + return common::Status::OK(); + } +}; + +} // namespace js +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/js_execution_provider.cc b/onnxruntime/core/providers/js/js_execution_provider.cc index 0ad62b87d33b5..e1a7f1d2a67bf 100644 --- a/onnxruntime/core/providers/js/js_execution_provider.cc +++ b/onnxruntime/core/providers/js/js_execution_provider.cc @@ -22,6 +22,7 @@ #include "core/graph/function_utils.h" #include "core/graph/indexed_sub_graph.h" #include "data_transfer.h" +#include "external_data_loader.h" namespace onnxruntime { @@ -762,6 +763,11 @@ std::unique_ptr JsExecutionProvider::GetDataTransfer return std::make_unique(); } +std::unique_ptr JsExecutionProvider::GetExternalDataLoader() const { + // return std::make_unique(); + return nullptr; +} + JsExecutionProvider::~JsExecutionProvider() { } diff --git a/onnxruntime/core/providers/js/js_execution_provider.h b/onnxruntime/core/providers/js/js_execution_provider.h index efacf510e75df..966f9c6980212 100644 --- a/onnxruntime/core/providers/js/js_execution_provider.h +++ b/onnxruntime/core/providers/js/js_execution_provider.h @@ -48,6 +48,7 @@ class JsExecutionProvider : public IExecutionProvider { std::shared_ptr GetKernelRegistry() const override; std::unique_ptr GetDataTransfer() const override; + std::unique_ptr GetExternalDataLoader() const override; DataLayout GetPreferredLayout() const override { return preferred_data_layout_; } diff --git a/onnxruntime/core/providers/shared_library/provider_api.h b/onnxruntime/core/providers/shared_library/provider_api.h index 2f54a04e15304..e5defb3060071 100644 --- a/onnxruntime/core/providers/shared_library/provider_api.h +++ b/onnxruntime/core/providers/shared_library/provider_api.h @@ -220,6 +220,7 @@ using NameMLValMap = std::unordered_map; #include "core/providers/cpu/math/einsum_utils/einsum_compute_preprocessor.h" #include "core/providers/cpu/cpu_provider_shared.h" #include "core/framework/data_transfer.h" +#include "core/framework/external_data_loader.h" #include "core/framework/execution_provider.h" #include "provider_interfaces.h" #include "provider_wrappedtypes.h" diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 5eed7c5c6f2b5..4564593297987 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -830,6 +830,14 @@ common::Status InferenceSession::RegisterExecutionProvider(const std::shared_ptr } } + auto p_external_data_loader = p_exec_provider->GetExternalDataLoader(); + if (p_external_data_loader) { + auto st = external_data_loader_mgr_.RegisterExternalDataLoader(std::move(p_external_data_loader)); + if (!st.IsOK()) { + return st; + } + } + p_exec_provider->SetLogger(session_logger_); session_profiler_.AddEpProfilers(p_exec_provider->GetProfiler()); return execution_providers_.Add(provider_type, p_exec_provider); @@ -1731,6 +1739,7 @@ common::Status InferenceSession::Initialize() { GetIntraOpThreadPoolToUse(), GetInterOpThreadPoolToUse(), data_transfer_mgr_, + external_data_loader_mgr_, *session_logger_, session_profiler_, session_options_, @@ -2149,6 +2158,10 @@ const DataTransferManager& InferenceSession::GetDataTransferManager() const { return data_transfer_mgr_; } +const ExternalDataLoaderManager& InferenceSession::GetExternalDataLoaderManager() const { + return external_data_loader_mgr_; +} + common::Status InferenceSession::CheckShapes(const std::string& input_output_name, const TensorShape& input_output_shape, const TensorShape& expected_shape, const char* input_output_moniker) const { const auto shape_size = input_output_shape.NumDimensions(); diff --git a/onnxruntime/core/session/inference_session.h b/onnxruntime/core/session/inference_session.h index 9662095bf0ed3..8c22fac4dd0c5 100644 --- a/onnxruntime/core/session/inference_session.h +++ b/onnxruntime/core/session/inference_session.h @@ -18,6 +18,7 @@ #include "core/framework/execution_providers.h" #include "core/framework/framework_common.h" #include "core/framework/iexecutor.h" +#include "core/framework/external_data_loader_manager.h" #include "core/framework/kernel_registry_manager.h" #include "core/framework/prepacked_weights_container.h" #include "core/framework/session_state.h" @@ -454,6 +455,11 @@ class InferenceSession { */ const DataTransferManager& GetDataTransferManager() const; + /* + * Get the GetExternalDataLoaderManager associated with this session + */ + const ExternalDataLoaderManager& GetExternalDataLoaderManager() const; + /* * Get all the providers' options this session was initialized with. */ @@ -784,6 +790,9 @@ class InferenceSession { // Data transfer manager. DataTransferManager data_transfer_mgr_; + // External data loader manager. + ExternalDataLoaderManager external_data_loader_mgr_; + // Number of concurrently running executors std::atomic current_num_runs_ = 0; diff --git a/onnxruntime/test/framework/allocation_planner_test.cc b/onnxruntime/test/framework/allocation_planner_test.cc index 43d3782be3280..0105e90b5a24a 100644 --- a/onnxruntime/test/framework/allocation_planner_test.cc +++ b/onnxruntime/test/framework/allocation_planner_test.cc @@ -160,6 +160,7 @@ class PlannerTest : public ::testing::Test { ExecutionProviders execution_providers_; std::unique_ptr tp_; DataTransferManager dtm_; + ExternalDataLoaderManager edlm_; profiling::Profiler profiler_; std::unique_ptr sess_options_; std::unique_ptr state_; @@ -198,7 +199,7 @@ class PlannerTest : public ::testing::Test { sess_options_->enable_mem_pattern = false; sess_options_->use_deterministic_compute = false; sess_options_->enable_mem_reuse = true; - state_.reset(new SessionState(graph_, execution_providers_, tp_.get(), nullptr, dtm_, + state_.reset(new SessionState(graph_, execution_providers_, tp_.get(), nullptr, dtm_, edlm_, DefaultLoggingManager().DefaultLogger(), profiler_, *sess_options_)); } @@ -282,7 +283,7 @@ class PlannerTest : public ::testing::Test { } void CreatePlan(const std::vector& outer_scope_node_args = {}, bool invoke_createPlan_explicityly = true) { - state_.reset(new SessionState(graph_, execution_providers_, tp_.get(), nullptr, dtm_, + state_.reset(new SessionState(graph_, execution_providers_, tp_.get(), nullptr, dtm_, edlm_, DefaultLoggingManager().DefaultLogger(), profiler_, *sess_options_)); EXPECT_EQ(graph_.Resolve(), Status::OK()); diff --git a/onnxruntime/test/framework/execution_frame_test.cc b/onnxruntime/test/framework/execution_frame_test.cc index b95fd0b726a4e..67a0e7fb05241 100644 --- a/onnxruntime/test/framework/execution_frame_test.cc +++ b/onnxruntime/test/framework/execution_frame_test.cc @@ -59,6 +59,7 @@ TEST_F(ExecutionFrameTest, TensorAllocationTest) { ASSERT_STATUS_OK(kernel_registry_manager.RegisterKernels(execution_providers)); DataTransferManager dtm; + ExternalDataLoaderManager edlm; profiling::Profiler profiler; SessionOptions sess_options; @@ -67,7 +68,7 @@ TEST_F(ExecutionFrameTest, TensorAllocationTest) { sess_options.use_deterministic_compute = false; sess_options.enable_mem_reuse = true; - SessionState state(graph, execution_providers, &tp_, nullptr, dtm, + SessionState state(graph, execution_providers, &tp_, nullptr, dtm, edlm, DefaultLoggingManager().DefaultLogger(), profiler, sess_options); node->SetExecutionProviderType(xp_typ); @@ -143,6 +144,7 @@ TEST_F(ExecutionFrameTest, OutputShapeValidationTest) { ASSERT_STATUS_OK(kernel_registry_manager.RegisterKernels(execution_providers)); DataTransferManager dtm; + ExternalDataLoaderManager edlm; profiling::Profiler profiler; SessionOptions sess_options; @@ -151,7 +153,7 @@ TEST_F(ExecutionFrameTest, OutputShapeValidationTest) { sess_options.use_deterministic_compute = false; sess_options.enable_mem_reuse = true; - SessionState state(graph, execution_providers, &tp_, nullptr, dtm, + SessionState state(graph, execution_providers, &tp_, nullptr, dtm, edlm, DefaultLoggingManager().DefaultLogger(), profiler, sess_options); node->SetExecutionProviderType(xp_typ); @@ -215,6 +217,7 @@ TEST_F(ExecutionFrameTest, FeedInDataTest) { ASSERT_STATUS_OK(kernel_registry_manager.RegisterKernels(execution_providers)); DataTransferManager dtm; + ExternalDataLoaderManager edlm; profiling::Profiler profiler; SessionOptions sess_options; @@ -223,7 +226,7 @@ TEST_F(ExecutionFrameTest, FeedInDataTest) { sess_options.use_deterministic_compute = false; sess_options.enable_mem_reuse = true; - SessionState state(graph, execution_providers, &tp_, nullptr, dtm, + SessionState state(graph, execution_providers, &tp_, nullptr, dtm, edlm, DefaultLoggingManager().DefaultLogger(), profiler, sess_options); ASSERT_STATUS_OK(state.FinalizeSessionState(ORT_TSTR(""), kernel_registry_manager)); @@ -287,6 +290,7 @@ TEST_F(ExecutionFrameTest, MemPatternTest) { // 1. prepare input DataTransferManager dtm; + ExternalDataLoaderManager edlm; profiling::Profiler profiler; SessionOptions sess_options; @@ -295,7 +299,7 @@ TEST_F(ExecutionFrameTest, MemPatternTest) { sess_options.use_deterministic_compute = false; sess_options.enable_mem_reuse = true; - SessionState state(graph, execution_providers, &tp_, nullptr, dtm, + SessionState state(graph, execution_providers, &tp_, nullptr, dtm, edlm, DefaultLoggingManager().DefaultLogger(), profiler, sess_options); ASSERT_STATUS_OK(state.FinalizeSessionState(ORT_TSTR(""), kernel_registry_manager)); @@ -402,10 +406,11 @@ TEST_F(ExecutionFrameTest, MemPatternWithExternalOutputsTest) { ASSERT_STATUS_OK(kernel_registry_manager.RegisterKernels(execution_providers)); DataTransferManager dtm; + ExternalDataLoaderManager edlm; profiling::Profiler profiler; SessionOptions so; - SessionState state(graph, execution_providers, &tp_, nullptr, dtm, DefaultLoggingManager().DefaultLogger(), + SessionState state(graph, execution_providers, &tp_, nullptr, dtm, edlm, DefaultLoggingManager().DefaultLogger(), profiler, so); ASSERT_STATUS_OK(state.FinalizeSessionState(ORT_TSTR(""), kernel_registry_manager)); diff --git a/onnxruntime/test/framework/session_state_test.cc b/onnxruntime/test/framework/session_state_test.cc index ed698ab920147..b94d24a1b180b 100644 --- a/onnxruntime/test/framework/session_state_test.cc +++ b/onnxruntime/test/framework/session_state_test.cc @@ -61,6 +61,7 @@ TEST_P(SessionStateAddGetKernelTest, AddGetKernelTest) { ASSERT_STATUS_OK(execution_providers.Add(kCpuExecutionProvider, std::move(tmp_cpu_execution_provider))); DataTransferManager dtm; + ExternalDataLoaderManager edlm; profiling::Profiler profiler; SessionOptions sess_options; @@ -69,7 +70,7 @@ TEST_P(SessionStateAddGetKernelTest, AddGetKernelTest) { sess_options.use_deterministic_compute = false; sess_options.enable_mem_reuse = true; - SessionState s(graph, execution_providers, tp.get(), nullptr, dtm, + SessionState s(graph, execution_providers, tp.get(), nullptr, dtm, edlm, DefaultLoggingManager().DefaultLogger(), profiler, sess_options); std::vector inputs; @@ -159,6 +160,7 @@ TEST_P(SessionStateTestP, TestInitializerProcessing) { ASSERT_TRUE(status.IsOK()) << status; DataTransferManager dtm; + ExternalDataLoaderManager edlm; profiling::Profiler profiler; SessionOptions sess_options; @@ -167,7 +169,7 @@ TEST_P(SessionStateTestP, TestInitializerProcessing) { sess_options.use_deterministic_compute = false; sess_options.enable_mem_reuse = true; - SessionState session_state(graph, execution_providers, tp.get(), nullptr, dtm, + SessionState session_state(graph, execution_providers, tp.get(), nullptr, dtm, edlm, DefaultLoggingManager().DefaultLogger(), profiler, sess_options); GraphPartitioner partitioner(krm, execution_providers); @@ -239,6 +241,7 @@ TEST(SessionStateTest, TestInitializerMemoryAllocatedUsingNonArenaMemory) { ASSERT_TRUE(status.IsOK()) << status; DataTransferManager dtm; + ExternalDataLoaderManager edlm; profiling::Profiler profiler; SessionOptions sess_options; @@ -250,7 +253,7 @@ TEST(SessionStateTest, TestInitializerMemoryAllocatedUsingNonArenaMemory) { ASSERT_STATUS_OK(sess_options.config_options.AddConfigEntry(kOrtSessionOptionsUseDeviceAllocatorForInitializers, "1")); - SessionState session_state(graph, execution_providers, nullptr, nullptr, dtm, + SessionState session_state(graph, execution_providers, nullptr, nullptr, dtm, edlm, DefaultLoggingManager().DefaultLogger(), profiler, sess_options); // Partition the graph @@ -300,6 +303,7 @@ TEST(SessionStateTest, TestInitializerMemoryAllocatedUsingNonArenaMemory) { ASSERT_TRUE(status.IsOK()) << status; DataTransferManager dtm; + ExternalDataLoaderManager edlm; profiling::Profiler profiler; SessionOptions sess_options; @@ -308,7 +312,7 @@ TEST(SessionStateTest, TestInitializerMemoryAllocatedUsingNonArenaMemory) { sess_options.use_deterministic_compute = false; sess_options.enable_mem_reuse = true; - SessionState session_state(graph, execution_providers, nullptr, nullptr, dtm, + SessionState session_state(graph, execution_providers, nullptr, nullptr, dtm, edlm, DefaultLoggingManager().DefaultLogger(), profiler, sess_options); // Partition the graph @@ -545,6 +549,7 @@ TEST_P(SessionStatePrepackingTest, PrePackingTest) { ASSERT_STATUS_OK(execution_providers.Add(kCpuExecutionProvider, std::move(cpu_execution_provider))); DataTransferManager dtm; + ExternalDataLoaderManager edlm; profiling::Profiler profiler; std::unordered_map domain_to_version; @@ -573,6 +578,7 @@ TEST_P(SessionStatePrepackingTest, PrePackingTest) { tp.get(), nullptr, /*inter_op_thread_pool*/ dtm, + edlm, DefaultLoggingManager().DefaultLogger(), profiler, sess_options); @@ -604,6 +610,7 @@ class SessionStateTestSharedInitalizersWithPrePacking : public ::testing::Test { ExecutionProviders execution_providers; std::unordered_map domain_to_version; DataTransferManager dtm; + ExternalDataLoaderManager edlm; profiling::Profiler profiler; KernelRegistryManager kernel_registry_manager; std::unique_ptr tp; @@ -661,6 +668,7 @@ TEST_F(SessionStateTestSharedInitalizersWithPrePacking, test1) { tp.get(), nullptr, /*inter_op_thread_pool*/ dtm, + edlm, DefaultLoggingManager().DefaultLogger(), profiler, sess_options); @@ -687,6 +695,7 @@ TEST_F(SessionStateTestSharedInitalizersWithPrePacking, test1) { tp.get(), nullptr, /*inter_op_thread_pool*/ dtm, + edlm, DefaultLoggingManager().DefaultLogger(), profiler, sess_options); @@ -734,6 +743,7 @@ TEST_F(SessionStateTestSharedInitalizersWithPrePacking, test2) { tp.get(), nullptr, /*inter_op_thread_pool*/ dtm, + edlm, DefaultLoggingManager().DefaultLogger(), profiler, sess_options); @@ -760,6 +770,7 @@ TEST_F(SessionStateTestSharedInitalizersWithPrePacking, test2) { tp.get(), nullptr, /*inter_op_thread_pool*/ dtm, + edlm, DefaultLoggingManager().DefaultLogger(), profiler, sess_options); @@ -809,6 +820,7 @@ TEST_F(SessionStateTestSharedInitalizersWithPrePacking, test3) { tp.get(), nullptr, /*inter_op_thread_pool*/ dtm, + edlm, DefaultLoggingManager().DefaultLogger(), profiler, sess_options, @@ -840,6 +852,7 @@ TEST_F(SessionStateTestSharedInitalizersWithPrePacking, test3) { tp.get(), nullptr, /*inter_op_thread_pool*/ dtm, + edlm, DefaultLoggingManager().DefaultLogger(), profiler, sess_options, @@ -895,6 +908,7 @@ TEST_F(SessionStateTestSharedInitalizersWithPrePacking, test4) { tp.get(), nullptr, /*inter_op_thread_pool*/ dtm, + edlm, DefaultLoggingManager().DefaultLogger(), profiler, sess_options, @@ -945,6 +959,7 @@ TEST_F(SessionStateTestSharedInitalizersWithPrePacking, test4) { tp.get(), nullptr, /*inter_op_thread_pool*/ dtm, + edlm, DefaultLoggingManager().DefaultLogger(), profiler, sess_options, diff --git a/onnxruntime/test/providers/memcpy_test.cc b/onnxruntime/test/providers/memcpy_test.cc index b0cdb7dc97773..4efa359b4e589 100644 --- a/onnxruntime/test/providers/memcpy_test.cc +++ b/onnxruntime/test/providers/memcpy_test.cc @@ -47,6 +47,7 @@ TEST(MemcpyTest, copy1) { PutAllNodesOnOneProvider(model.MainGraph(), onnxruntime::kCpuExecutionProvider); DataTransferManager dtm; + ExternalDataLoaderManager edlm; profiling::Profiler profiler; SessionOptions sess_options; @@ -55,7 +56,7 @@ TEST(MemcpyTest, copy1) { sess_options.use_deterministic_compute = false; sess_options.enable_mem_reuse = true; - SessionState s(model.MainGraph(), execution_providers, &tp, nullptr, dtm, + SessionState s(model.MainGraph(), execution_providers, &tp, nullptr, dtm, edlm, DefaultLoggingManager().DefaultLogger(), profiler, sess_options); ASSERT_STATUS_OK(s.FinalizeSessionState(ORT_TSTR(""), kernel_registry_manager)); From 1d22809a27857cbd7a180239a977cbe974507d41 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Tue, 6 Aug 2024 20:00:14 -0700 Subject: [PATCH 2/7] impl --- .../core/framework/tensorprotoutils.cc | 62 ++-------- onnxruntime/core/providers/js/allocator.cc | 6 +- onnxruntime/core/providers/js/allocator.h | 15 +-- .../core/providers/js/external_data_loader.cc | 113 ++++++++++++++++++ .../core/providers/js/external_data_loader.h | 22 ++-- .../providers/js/js_execution_provider.cc | 7 +- onnxruntime/wasm/pre-jsep.js | 4 + 7 files changed, 150 insertions(+), 79 deletions(-) create mode 100644 onnxruntime/core/providers/js/external_data_loader.cc diff --git a/onnxruntime/core/framework/tensorprotoutils.cc b/onnxruntime/core/framework/tensorprotoutils.cc index 82fd3cb1c01ab..e812581a7add6 100644 --- a/onnxruntime/core/framework/tensorprotoutils.cc +++ b/onnxruntime/core/framework/tensorprotoutils.cc @@ -10,6 +10,9 @@ #include #if defined(__wasm__) #include + +#include "core/providers/js/external_data_loader.h" + #endif #include @@ -1022,59 +1025,12 @@ Status GetExtDataFromTensorProto(const Env& env, const std::filesystem::path& mo ext_data_buf = buffer.release(); ext_data_len = raw_data_safe_len; - // In WebAssembly, try use a simplified preloaded file map in WebAssembly when available. - auto err_code = EM_ASM_INT(({ - // If available, "Module.MountedFiles" is a Map for all preloaded files. - if (typeof Module == 'undefined' || !Module.MountedFiles) { - return 1; // "Module.MountedFiles" is not available. - } - let fileName = UTF8ToString($0 >>> 0); - if (fileName.startsWith('./')) { - fileName = fileName.substring(2); - } - const fileData = Module.MountedFiles.get(fileName); - if (!fileData) { - return 2; // File not found in preloaded files. - } - const offset = $1 >>> 0; - const length = $2 >>> 0; - const buffer = $3 >>> 0; - - if (offset + length > fileData.byteLength) { - return 3; // Out of bounds. - } - - try { - // Copy the file data (fileData,offset,length) into WebAssembly memory - // (HEAPU8,buffer,length). - HEAPU8.set(fileData.subarray(offset, offset + length), buffer); - return 0; - } catch { - return 4; - } - }), - external_data_file_path.c_str(), - static_cast(file_offset), - static_cast(raw_data_safe_len), - ext_data_buf); - const char* err_msg; - switch (err_code) { - case 0: - return Status::OK(); - case 1: - err_msg = "Module.MountedFiles is not available."; - break; - case 2: - err_msg = "File not found in preloaded files."; - break; - case 3: - err_msg = "Out of bounds."; - break; - default: - err_msg = "Unknown error occurred in memory copy."; - } - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to load external data file \"", external_data_file_path, - "\", error: ", err_msg); + ORT_RETURN_IF_ERROR(js::LoadExternalData(env, + external_data_file_path, + file_offset, + ext_data_len, + js::ExternalDataLoadType::CPU, + ext_data_buf)); #else // The GetFileContent function doesn't report error if the requested data range is invalid. Therefore we need to // manually check file size first. diff --git a/onnxruntime/core/providers/js/allocator.cc b/onnxruntime/core/providers/js/allocator.cc index 574c507222a5c..d37346a166b03 100644 --- a/onnxruntime/core/providers/js/allocator.cc +++ b/onnxruntime/core/providers/js/allocator.cc @@ -9,7 +9,7 @@ namespace onnxruntime { namespace js { -void* JsCustomAllocator::Alloc(size_t size) { +void* WebGpuAllocator::Alloc(size_t size) { if (size == 0) { return nullptr; } @@ -20,14 +20,14 @@ void* JsCustomAllocator::Alloc(size_t size) { return p; } -void JsCustomAllocator::Free(void* p) { +void WebGpuAllocator::Free(void* p) { if (p != nullptr) { size_t size = (size_t)(void*)EM_ASM_PTR({ return Module.jsepFree($0); }, p); stats_.bytes_in_use -= size; } } -void JsCustomAllocator::GetStats(AllocatorStats* stats) { +void WebGpuAllocator::GetStats(AllocatorStats* stats) { *stats = stats_; } diff --git a/onnxruntime/core/providers/js/allocator.h b/onnxruntime/core/providers/js/allocator.h index 267015b2ea58d..aafb0bb22da7e 100644 --- a/onnxruntime/core/providers/js/allocator.h +++ b/onnxruntime/core/providers/js/allocator.h @@ -9,20 +9,11 @@ namespace onnxruntime { namespace js { -class JsCPUAllocator : public CPUAllocator { +class WebGpuAllocator : public IAllocator { public: - JsCPUAllocator() - : CPUAllocator( - OrtMemoryInfo("JsCPUAllocator", OrtAllocatorType::OrtDeviceAllocator, - OrtDevice(OrtDevice::CPU, OrtDevice::MemType::DEFAULT, 0), - 0, OrtMemTypeCPU)) {}; -}; - -class JsCustomAllocator : public IAllocator { - public: - JsCustomAllocator() + WebGpuAllocator() : IAllocator( - OrtMemoryInfo("JsCustomAllocator", OrtAllocatorType::OrtDeviceAllocator, + OrtMemoryInfo(WEBGPU_BUFFER, OrtAllocatorType::OrtDeviceAllocator, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0), 0, OrtMemTypeDefault)) { } diff --git a/onnxruntime/core/providers/js/external_data_loader.cc b/onnxruntime/core/providers/js/external_data_loader.cc new file mode 100644 index 0000000000000..408c67be833dd --- /dev/null +++ b/onnxruntime/core/providers/js/external_data_loader.cc @@ -0,0 +1,113 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include + +#include "external_data_loader.h" + +#include "core/framework/tensor.h" + +namespace onnxruntime { +namespace js { + +bool ExternalDataLoader::CanLoad(const OrtMemoryInfo& target_memory_info) const { + return target_memory_info.device.Type() == OrtDevice::CPU || + (target_memory_info.device.Type() == OrtDevice::GPU && + target_memory_info.name == WEBGPU_BUFFER); +} + +common::Status ExternalDataLoader::LoadTensor(const Env& env, + const std::filesystem::path& data_file_path, + FileOffsetType data_offset, + SafeInt data_length, + Tensor& tensor) const { + ExternalDataLoadType load_type; + if (tensor.Location().device.Type() == OrtDevice::CPU) { + load_type = ExternalDataLoadType::CPU; + } else if (tensor.Location().device.Type() == OrtDevice::GPU && + tensor.Location().name == WEBGPU_BUFFER) { + load_type = ExternalDataLoadType::WEBGPU_BUFFER; + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported tensor location: ", tensor.Location().ToString()); + } + + return LoadExternalData(env, data_file_path, data_offset, data_length, load_type, tensor.MutableDataRaw()); +} + +common::Status LoadExternalData(const Env& env, + const std::filesystem::path& data_file_path, + FileOffsetType data_offset, + SafeInt data_length, + ExternalDataLoadType load_type, + void* tensor_data) { + auto err_code = EM_ASM_INT(({ + // If available, "Module.MountedFiles" is a Map for all preloaded files. + if (typeof Module == 'undefined' || !Module.MountedFiles) { + return 1; // "Module.MountedFiles" is not available. + } + let fileName = UTF8ToString($0 >>> 0); + if (fileName.startsWith('./')) { + fileName = fileName.substring(2); + } + const fileData = Module.MountedFiles.get(fileName); + if (!fileData) { + return 2; // File not found in preloaded files. + } + const offset = $1 >>> 0; + const length = $2 >>> 0; + const dataIdOrBuffer = $3 >>> 0; + const loadType = $4; + + if (offset + length > fileData.byteLength) { + return 3; // Out of bounds. + } + + try { + const data = fileData.subarray(offset, offset + length); + switch (loadType) { + case 0: + // Load external data to CPU memory. + // Copy the file data (fileData,offset,length) into WebAssembly memory + // (HEAPU8,buffer,length). + HEAPU8.set(data, dataIdOrBuffer); + break; + case 1: + // Load external data to GPU. + Module.jsepUploadExternalBuffer(dataIdOrBuffer, data); + break; + default: + return 4; // Unknown error occurred in memory copy. + } + return 0; + } catch { + return 4; + } + }), + data_file_path.c_str(), + static_cast(data_offset), + static_cast(data_length), + tensor_data, + static_cast(load_type)); + const char* err_msg; + switch (err_code) { + case 0: + return Status::OK(); + case 1: + err_msg = "Module.MountedFiles is not available."; + break; + case 2: + err_msg = "File not found in preloaded files."; + break; + case 3: + err_msg = "Out of bounds."; + break; + default: + err_msg = "Unknown error occurred in memory copy."; + } + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to load external data file \"", data_file_path, + "\", error: ", err_msg); +} + + +} // namespace js +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/external_data_loader.h b/onnxruntime/core/providers/js/external_data_loader.h index df2126d086ba1..56e235a1b40c8 100644 --- a/onnxruntime/core/providers/js/external_data_loader.h +++ b/onnxruntime/core/providers/js/external_data_loader.h @@ -8,24 +8,32 @@ namespace onnxruntime { namespace js { +enum class ExternalDataLoadType { + CPU = 0, + WEBGPU_BUFFER = 1, +}; + class ExternalDataLoader : public IExternalDataLoader { public: ExternalDataLoader() {}; ~ExternalDataLoader() {}; - bool CanLoad(const OrtMemoryInfo& target_memory_info) const override { - return target_memory_info.device.Type() == OrtDevice::GPU && - target_memory_info.name == WEBGPU_BUFFER; - } + bool CanLoad(const OrtMemoryInfo& target_memory_info) const override; common::Status LoadTensor(const Env& env, const std::filesystem::path& data_file_path, FileOffsetType data_offset, SafeInt data_length, - Tensor& tensor) const override { - return common::Status::OK(); - } + Tensor& tensor) const override; }; +// Entry point for loading external data implementation. +common::Status LoadExternalData(const Env& env, + const std::filesystem::path& data_file_path, + FileOffsetType data_offset, + SafeInt data_length, + ExternalDataLoadType load_type, + void* tensor_data); + } // namespace js } // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/js_execution_provider.cc b/onnxruntime/core/providers/js/js_execution_provider.cc index e1a7f1d2a67bf..5be965348788a 100644 --- a/onnxruntime/core/providers/js/js_execution_provider.cc +++ b/onnxruntime/core/providers/js/js_execution_provider.cc @@ -703,9 +703,9 @@ JsExecutionProvider::JsExecutionProvider(const JsExecutionProviderInfo& info, co std::vector JsExecutionProvider::CreatePreferredAllocators() { AllocatorCreationInfo customAllocatorCreationInfo([&](int) { - return std::make_unique(); + return std::make_unique(); }, - 0, false); // TODO(leca): REVIEW: need JsCPUAllocator? + 0, false); return std::vector{CreateAllocator(customAllocatorCreationInfo)}; } @@ -764,8 +764,7 @@ std::unique_ptr JsExecutionProvider::GetDataTransfer } std::unique_ptr JsExecutionProvider::GetExternalDataLoader() const { - // return std::make_unique(); - return nullptr; + return std::make_unique(); } JsExecutionProvider::~JsExecutionProvider() { diff --git a/onnxruntime/wasm/pre-jsep.js b/onnxruntime/wasm/pre-jsep.js index 1cb7c6f5d8250..70ed295887994 100644 --- a/onnxruntime/wasm/pre-jsep.js +++ b/onnxruntime/wasm/pre-jsep.js @@ -198,5 +198,9 @@ Module['jsepInit'] = (name, params) => { Module['jsepOnRunStart'] = sessionId => { return backend['onRunStart'](sessionId); }; + + Module.jsepUploadExternalBuffer = (dataId, buffer) => { + backend['upload'](dataId, buffer); + }; } }; From 9368cb684c62d9c09bee3e3d1b5aa455890a57e5 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Wed, 7 Aug 2024 19:07:08 -0700 Subject: [PATCH 3/7] resolve comments and fix linter --- .../core/framework/execution_provider.h | 6 + .../framework/external_data_loader_manager.h | 2 +- .../core/providers/js/external_data_loader.cc | 145 +++++++++--------- 3 files changed, 79 insertions(+), 74 deletions(-) diff --git a/include/onnxruntime/core/framework/execution_provider.h b/include/onnxruntime/core/framework/execution_provider.h index 2771619fd6696..a5b5d2edde46c 100644 --- a/include/onnxruntime/core/framework/execution_provider.h +++ b/include/onnxruntime/core/framework/execution_provider.h @@ -91,6 +91,12 @@ class IExecutionProvider { /** * Returns an external data loader object that implements methods to load data from external sources. + * + * By default, framework will handle external data loading by loading the data into CPU memory and then copying + * it to the target device if required. So in most cases, it's not necessary to override this method. Specifically, + * in WebAssembly build, because the memory is limited and Web platform supports loading data from external sources + * directly into GPU memory, this method is overridden to provide a custom external data loader to avoid the extra + * CPU memory usage. */ virtual std::unique_ptr GetExternalDataLoader() const { return nullptr; diff --git a/onnxruntime/core/framework/external_data_loader_manager.h b/onnxruntime/core/framework/external_data_loader_manager.h index e2970fc416f12..38881405c87ff 100644 --- a/onnxruntime/core/framework/external_data_loader_manager.h +++ b/onnxruntime/core/framework/external_data_loader_manager.h @@ -10,7 +10,7 @@ namespace onnxruntime { // The external data loader manager manages all registered external data loaders to allow custom -// external data loading implemented by excution providers. +// external data loading implemented by execution providers. class ExternalDataLoaderManager { public: ExternalDataLoaderManager() = default; diff --git a/onnxruntime/core/providers/js/external_data_loader.cc b/onnxruntime/core/providers/js/external_data_loader.cc index 408c67be833dd..34062bb42a649 100644 --- a/onnxruntime/core/providers/js/external_data_loader.cc +++ b/onnxruntime/core/providers/js/external_data_loader.cc @@ -21,15 +21,15 @@ common::Status ExternalDataLoader::LoadTensor(const Env& env, FileOffsetType data_offset, SafeInt data_length, Tensor& tensor) const { - ExternalDataLoadType load_type; - if (tensor.Location().device.Type() == OrtDevice::CPU) { - load_type = ExternalDataLoadType::CPU; - } else if (tensor.Location().device.Type() == OrtDevice::GPU && - tensor.Location().name == WEBGPU_BUFFER) { - load_type = ExternalDataLoadType::WEBGPU_BUFFER; - } else { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported tensor location: ", tensor.Location().ToString()); - } + ExternalDataLoadType load_type; + if (tensor.Location().device.Type() == OrtDevice::CPU) { + load_type = ExternalDataLoadType::CPU; + } else if (tensor.Location().device.Type() == OrtDevice::GPU && + tensor.Location().name == WEBGPU_BUFFER) { + load_type = ExternalDataLoadType::WEBGPU_BUFFER; + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported tensor location: ", tensor.Location().ToString()); + } return LoadExternalData(env, data_file_path, data_offset, data_length, load_type, tensor.MutableDataRaw()); } @@ -40,74 +40,73 @@ common::Status LoadExternalData(const Env& env, SafeInt data_length, ExternalDataLoadType load_type, void* tensor_data) { - auto err_code = EM_ASM_INT(({ - // If available, "Module.MountedFiles" is a Map for all preloaded files. - if (typeof Module == 'undefined' || !Module.MountedFiles) { - return 1; // "Module.MountedFiles" is not available. - } - let fileName = UTF8ToString($0 >>> 0); - if (fileName.startsWith('./')) { - fileName = fileName.substring(2); - } - const fileData = Module.MountedFiles.get(fileName); - if (!fileData) { - return 2; // File not found in preloaded files. - } - const offset = $1 >>> 0; - const length = $2 >>> 0; - const dataIdOrBuffer = $3 >>> 0; - const loadType = $4; + auto err_code = EM_ASM_INT(({ + // If available, "Module.MountedFiles" is a Map for all preloaded files. + if (typeof Module == 'undefined' || !Module.MountedFiles) { + return 1; // "Module.MountedFiles" is not available. + } + let fileName = UTF8ToString($0 >>> 0); + if (fileName.startsWith('./')) { + fileName = fileName.substring(2); + } + const fileData = Module.MountedFiles.get(fileName); + if (!fileData) { + return 2; // File not found in preloaded files. + } + const offset = $1 >>> 0; + const length = $2 >>> 0; + const dataIdOrBuffer = $3 >>> 0; + const loadType = $4; - if (offset + length > fileData.byteLength) { - return 3; // Out of bounds. - } + if (offset + length > fileData.byteLength) { + return 3; // Out of bounds. + } - try { - const data = fileData.subarray(offset, offset + length); - switch (loadType) { - case 0: - // Load external data to CPU memory. - // Copy the file data (fileData,offset,length) into WebAssembly memory - // (HEAPU8,buffer,length). - HEAPU8.set(data, dataIdOrBuffer); - break; - case 1: - // Load external data to GPU. - Module.jsepUploadExternalBuffer(dataIdOrBuffer, data); - break; - default: - return 4; // Unknown error occurred in memory copy. - } - return 0; - } catch { - return 4; + try { + const data = fileData.subarray(offset, offset + length); + switch (loadType) { + case 0: + // Load external data to CPU memory. + // Copy the file data (fileData,offset,length) into WebAssembly memory + // (HEAPU8,buffer,length). + HEAPU8.set(data, dataIdOrBuffer); + break; + case 1: + // Load external data to GPU. + Module.jsepUploadExternalBuffer(dataIdOrBuffer, data); + break; + default: + return 4; // Unknown error occurred in memory copy. } - }), - data_file_path.c_str(), - static_cast(data_offset), - static_cast(data_length), - tensor_data, - static_cast(load_type)); - const char* err_msg; - switch (err_code) { - case 0: - return Status::OK(); - case 1: - err_msg = "Module.MountedFiles is not available."; - break; - case 2: - err_msg = "File not found in preloaded files."; - break; - case 3: - err_msg = "Out of bounds."; - break; - default: - err_msg = "Unknown error occurred in memory copy."; - } - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to load external data file \"", data_file_path, - "\", error: ", err_msg); + return 0; + } catch { + return 4; + } + }), + data_file_path.c_str(), + static_cast(data_offset), + static_cast(data_length), + tensor_data, + static_cast(load_type)); + const char* err_msg; + switch (err_code) { + case 0: + return Status::OK(); + case 1: + err_msg = "Module.MountedFiles is not available."; + break; + case 2: + err_msg = "File not found in preloaded files."; + break; + case 3: + err_msg = "Out of bounds."; + break; + default: + err_msg = "Unknown error occurred in memory copy."; + } + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to load external data file \"", data_file_path, + "\", error: ", err_msg); } - } // namespace js } // namespace onnxruntime From 4e831ae21366b264abf7217643803e58f8248a26 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Wed, 7 Aug 2024 20:23:54 -0700 Subject: [PATCH 4/7] support non-webgpu wasm build --- onnxruntime/core/framework/tensorprotoutils.cc | 2 +- .../core/providers/js/js_execution_provider.cc | 2 +- .../js_external_data_loader.cc} | 12 ++++++++---- .../js_external_data_loader.h} | 2 +- 4 files changed, 11 insertions(+), 7 deletions(-) rename onnxruntime/{core/providers/js/external_data_loader.cc => wasm/js_external_data_loader.cc} (94%) rename onnxruntime/{core/providers/js/external_data_loader.h => wasm/js_external_data_loader.h} (93%) diff --git a/onnxruntime/core/framework/tensorprotoutils.cc b/onnxruntime/core/framework/tensorprotoutils.cc index e812581a7add6..8befe2ad6a058 100644 --- a/onnxruntime/core/framework/tensorprotoutils.cc +++ b/onnxruntime/core/framework/tensorprotoutils.cc @@ -11,7 +11,7 @@ #if defined(__wasm__) #include -#include "core/providers/js/external_data_loader.h" +#include "wasm/js_external_data_loader.h" #endif diff --git a/onnxruntime/core/providers/js/js_execution_provider.cc b/onnxruntime/core/providers/js/js_execution_provider.cc index 5be965348788a..392efd17b88e2 100644 --- a/onnxruntime/core/providers/js/js_execution_provider.cc +++ b/onnxruntime/core/providers/js/js_execution_provider.cc @@ -22,7 +22,7 @@ #include "core/graph/function_utils.h" #include "core/graph/indexed_sub_graph.h" #include "data_transfer.h" -#include "external_data_loader.h" +#include "wasm/js_external_data_loader.h" namespace onnxruntime { diff --git a/onnxruntime/core/providers/js/external_data_loader.cc b/onnxruntime/wasm/js_external_data_loader.cc similarity index 94% rename from onnxruntime/core/providers/js/external_data_loader.cc rename to onnxruntime/wasm/js_external_data_loader.cc index 34062bb42a649..01df6ddfd971e 100644 --- a/onnxruntime/core/providers/js/external_data_loader.cc +++ b/onnxruntime/wasm/js_external_data_loader.cc @@ -3,7 +3,7 @@ #include -#include "external_data_loader.h" +#include "js_external_data_loader.h" #include "core/framework/tensor.h" @@ -11,9 +11,11 @@ namespace onnxruntime { namespace js { bool ExternalDataLoader::CanLoad(const OrtMemoryInfo& target_memory_info) const { - return target_memory_info.device.Type() == OrtDevice::CPU || - (target_memory_info.device.Type() == OrtDevice::GPU && - target_memory_info.name == WEBGPU_BUFFER); + return target_memory_info.device.Type() == OrtDevice::CPU +#if defined(USE_JSEP) + || (target_memory_info.device.Type() == OrtDevice::GPU && target_memory_info.name == WEBGPU_BUFFER) +#endif + ; } common::Status ExternalDataLoader::LoadTensor(const Env& env, @@ -24,9 +26,11 @@ common::Status ExternalDataLoader::LoadTensor(const Env& env, ExternalDataLoadType load_type; if (tensor.Location().device.Type() == OrtDevice::CPU) { load_type = ExternalDataLoadType::CPU; +#if defined(USE_JSEP) } else if (tensor.Location().device.Type() == OrtDevice::GPU && tensor.Location().name == WEBGPU_BUFFER) { load_type = ExternalDataLoadType::WEBGPU_BUFFER; +#endif } else { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported tensor location: ", tensor.Location().ToString()); } diff --git a/onnxruntime/core/providers/js/external_data_loader.h b/onnxruntime/wasm/js_external_data_loader.h similarity index 93% rename from onnxruntime/core/providers/js/external_data_loader.h rename to onnxruntime/wasm/js_external_data_loader.h index 56e235a1b40c8..10385adff6f85 100644 --- a/onnxruntime/core/providers/js/external_data_loader.h +++ b/onnxruntime/wasm/js_external_data_loader.h @@ -27,7 +27,7 @@ class ExternalDataLoader : public IExternalDataLoader { Tensor& tensor) const override; }; -// Entry point for loading external data implementation. +// Entry point for loading external data implementation using inline JavaScript. common::Status LoadExternalData(const Env& env, const std::filesystem::path& data_file_path, FileOffsetType data_offset, From 523a06f0528441fe75d4a25a35d3de1a2dd0eaf1 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Thu, 8 Aug 2024 00:38:30 -0700 Subject: [PATCH 5/7] fix build and rearrange default loader --- .../core/framework/external_data_loader.cc | 81 ++++++++++++ .../core/framework/external_data_loader.h | 19 +++ .../core/framework/tensorprotoutils.cc | 15 +-- .../core/providers/js/external_data_loader.cc | 42 +++++++ .../providers/js/external_data_loader.h} | 13 -- .../providers/js/js_execution_provider.cc | 2 +- onnxruntime/wasm/js_external_data_loader.cc | 116 ------------------ 7 files changed, 149 insertions(+), 139 deletions(-) create mode 100644 onnxruntime/core/providers/js/external_data_loader.cc rename onnxruntime/{wasm/js_external_data_loader.h => core/providers/js/external_data_loader.h} (58%) delete mode 100644 onnxruntime/wasm/js_external_data_loader.cc diff --git a/onnxruntime/core/framework/external_data_loader.cc b/onnxruntime/core/framework/external_data_loader.cc index 8431cce4f27ee..ea6c499829391 100644 --- a/onnxruntime/core/framework/external_data_loader.cc +++ b/onnxruntime/core/framework/external_data_loader.cc @@ -5,6 +5,9 @@ #ifndef SHARED_PROVIDER #include "core/framework/tensor.h" #endif +#if defined(__wasm__) +#include +#endif namespace onnxruntime { @@ -16,4 +19,82 @@ common::Status IExternalDataLoader::LoadTensor([[maybe_unused]] const Env& env, ORT_NOT_IMPLEMENTED(__FUNCTION__, " is not implemented"); } +#if defined(__wasm__) + +common::Status LoadWebAssemblyExternalData(const Env& env, + const std::filesystem::path& data_file_path, + FileOffsetType data_offset, + SafeInt data_length, + ExternalDataLoadType load_type, + void* tensor_data) { + auto err_code = EM_ASM_INT(({ + // If available, "Module.MountedFiles" is a Map for all preloaded files. + if (typeof Module == 'undefined' || !Module.MountedFiles) { + return 1; // "Module.MountedFiles" is not available. + } + let fileName = UTF8ToString($0 >>> 0); + if (fileName.startsWith('./')) { + fileName = fileName.substring(2); + } + const fileData = Module.MountedFiles.get(fileName); + if (!fileData) { + return 2; // File not found in preloaded files. + } + const offset = $1 >>> 0; + const length = $2 >>> 0; + const dataIdOrBuffer = $3 >>> 0; + const loadType = $4; + + if (offset + length > fileData.byteLength) { + return 3; // Out of bounds. + } + + try { + const data = fileData.subarray(offset, offset + length); + switch (loadType) { + case 0: + // Load external data to CPU memory. + // Copy the file data (fileData,offset,length) into WebAssembly memory + // (HEAPU8,buffer,length). + HEAPU8.set(data, dataIdOrBuffer); + break; + case 1: + // Load external data to GPU. + Module.jsepUploadExternalBuffer(dataIdOrBuffer, data); + break; + default: + return 4; // Unknown error occurred in memory copy. + } + return 0; + } catch { + return 4; + } + }), + data_file_path.c_str(), + static_cast(data_offset), + static_cast(data_length), + tensor_data, + static_cast(load_type)); + const char* err_msg; + switch (err_code) { + case 0: + return Status::OK(); + case 1: + err_msg = "Module.MountedFiles is not available."; + break; + case 2: + err_msg = "File not found in preloaded files."; + break; + case 3: + err_msg = "Out of bounds."; + break; + default: + err_msg = "Unknown error occurred in memory copy."; + } + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to load external data file \"", data_file_path, + "\", error: ", err_msg); +} + +#endif + } // namespace onnxruntime diff --git a/onnxruntime/core/framework/external_data_loader.h b/onnxruntime/core/framework/external_data_loader.h index 96945c4b15f9e..117da7d0a4afa 100644 --- a/onnxruntime/core/framework/external_data_loader.h +++ b/onnxruntime/core/framework/external_data_loader.h @@ -38,4 +38,23 @@ class IExternalDataLoader { Tensor& tensor) const; }; +#if defined(__wasm__) + +enum class ExternalDataLoadType { + CPU = 0, +#if defined(USE_JSEP) + WEBGPU_BUFFER = 1, +#endif +}; + +// Entry point for loading external data implementation using inline JavaScript. +common::Status LoadWebAssemblyExternalData(const Env& env, + const std::filesystem::path& data_file_path, + FileOffsetType data_offset, + SafeInt data_length, + ExternalDataLoadType load_type, + void* tensor_data); + +#endif + } // namespace onnxruntime diff --git a/onnxruntime/core/framework/tensorprotoutils.cc b/onnxruntime/core/framework/tensorprotoutils.cc index 8befe2ad6a058..e5d64a5063c07 100644 --- a/onnxruntime/core/framework/tensorprotoutils.cc +++ b/onnxruntime/core/framework/tensorprotoutils.cc @@ -10,9 +10,6 @@ #include #if defined(__wasm__) #include - -#include "wasm/js_external_data_loader.h" - #endif #include @@ -1025,12 +1022,12 @@ Status GetExtDataFromTensorProto(const Env& env, const std::filesystem::path& mo ext_data_buf = buffer.release(); ext_data_len = raw_data_safe_len; - ORT_RETURN_IF_ERROR(js::LoadExternalData(env, - external_data_file_path, - file_offset, - ext_data_len, - js::ExternalDataLoadType::CPU, - ext_data_buf)); + ORT_RETURN_IF_ERROR(LoadWebAssemblyExternalData(env, + external_data_file_path, + file_offset, + ext_data_len, + ExternalDataLoadType::CPU, + ext_data_buf)); #else // The GetFileContent function doesn't report error if the requested data range is invalid. Therefore we need to // manually check file size first. diff --git a/onnxruntime/core/providers/js/external_data_loader.cc b/onnxruntime/core/providers/js/external_data_loader.cc new file mode 100644 index 0000000000000..193b373cf3696 --- /dev/null +++ b/onnxruntime/core/providers/js/external_data_loader.cc @@ -0,0 +1,42 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include + +#include "external_data_loader.h" + +#include "core/framework/tensor.h" + +namespace onnxruntime { +namespace js { + +bool ExternalDataLoader::CanLoad(const OrtMemoryInfo& target_memory_info) const { + return target_memory_info.device.Type() == OrtDevice::CPU +#if defined(USE_JSEP) + || (target_memory_info.device.Type() == OrtDevice::GPU && target_memory_info.name == WEBGPU_BUFFER) +#endif + ; +} + +common::Status ExternalDataLoader::LoadTensor(const Env& env, + const std::filesystem::path& data_file_path, + FileOffsetType data_offset, + SafeInt data_length, + Tensor& tensor) const { + ExternalDataLoadType load_type; + if (tensor.Location().device.Type() == OrtDevice::CPU) { + load_type = ExternalDataLoadType::CPU; +#if defined(USE_JSEP) + } else if (tensor.Location().device.Type() == OrtDevice::GPU && + tensor.Location().name == WEBGPU_BUFFER) { + load_type = ExternalDataLoadType::WEBGPU_BUFFER; +#endif + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported tensor location: ", tensor.Location().ToString()); + } + + return LoadWebAssemblyExternalData(env, data_file_path, data_offset, data_length, load_type, tensor.MutableDataRaw()); +} + +} // namespace js +} // namespace onnxruntime diff --git a/onnxruntime/wasm/js_external_data_loader.h b/onnxruntime/core/providers/js/external_data_loader.h similarity index 58% rename from onnxruntime/wasm/js_external_data_loader.h rename to onnxruntime/core/providers/js/external_data_loader.h index 10385adff6f85..5f35ed62bbcc1 100644 --- a/onnxruntime/wasm/js_external_data_loader.h +++ b/onnxruntime/core/providers/js/external_data_loader.h @@ -8,11 +8,6 @@ namespace onnxruntime { namespace js { -enum class ExternalDataLoadType { - CPU = 0, - WEBGPU_BUFFER = 1, -}; - class ExternalDataLoader : public IExternalDataLoader { public: ExternalDataLoader() {}; @@ -27,13 +22,5 @@ class ExternalDataLoader : public IExternalDataLoader { Tensor& tensor) const override; }; -// Entry point for loading external data implementation using inline JavaScript. -common::Status LoadExternalData(const Env& env, - const std::filesystem::path& data_file_path, - FileOffsetType data_offset, - SafeInt data_length, - ExternalDataLoadType load_type, - void* tensor_data); - } // namespace js } // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/js_execution_provider.cc b/onnxruntime/core/providers/js/js_execution_provider.cc index 392efd17b88e2..5be965348788a 100644 --- a/onnxruntime/core/providers/js/js_execution_provider.cc +++ b/onnxruntime/core/providers/js/js_execution_provider.cc @@ -22,7 +22,7 @@ #include "core/graph/function_utils.h" #include "core/graph/indexed_sub_graph.h" #include "data_transfer.h" -#include "wasm/js_external_data_loader.h" +#include "external_data_loader.h" namespace onnxruntime { diff --git a/onnxruntime/wasm/js_external_data_loader.cc b/onnxruntime/wasm/js_external_data_loader.cc deleted file mode 100644 index 01df6ddfd971e..0000000000000 --- a/onnxruntime/wasm/js_external_data_loader.cc +++ /dev/null @@ -1,116 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include - -#include "js_external_data_loader.h" - -#include "core/framework/tensor.h" - -namespace onnxruntime { -namespace js { - -bool ExternalDataLoader::CanLoad(const OrtMemoryInfo& target_memory_info) const { - return target_memory_info.device.Type() == OrtDevice::CPU -#if defined(USE_JSEP) - || (target_memory_info.device.Type() == OrtDevice::GPU && target_memory_info.name == WEBGPU_BUFFER) -#endif - ; -} - -common::Status ExternalDataLoader::LoadTensor(const Env& env, - const std::filesystem::path& data_file_path, - FileOffsetType data_offset, - SafeInt data_length, - Tensor& tensor) const { - ExternalDataLoadType load_type; - if (tensor.Location().device.Type() == OrtDevice::CPU) { - load_type = ExternalDataLoadType::CPU; -#if defined(USE_JSEP) - } else if (tensor.Location().device.Type() == OrtDevice::GPU && - tensor.Location().name == WEBGPU_BUFFER) { - load_type = ExternalDataLoadType::WEBGPU_BUFFER; -#endif - } else { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported tensor location: ", tensor.Location().ToString()); - } - - return LoadExternalData(env, data_file_path, data_offset, data_length, load_type, tensor.MutableDataRaw()); -} - -common::Status LoadExternalData(const Env& env, - const std::filesystem::path& data_file_path, - FileOffsetType data_offset, - SafeInt data_length, - ExternalDataLoadType load_type, - void* tensor_data) { - auto err_code = EM_ASM_INT(({ - // If available, "Module.MountedFiles" is a Map for all preloaded files. - if (typeof Module == 'undefined' || !Module.MountedFiles) { - return 1; // "Module.MountedFiles" is not available. - } - let fileName = UTF8ToString($0 >>> 0); - if (fileName.startsWith('./')) { - fileName = fileName.substring(2); - } - const fileData = Module.MountedFiles.get(fileName); - if (!fileData) { - return 2; // File not found in preloaded files. - } - const offset = $1 >>> 0; - const length = $2 >>> 0; - const dataIdOrBuffer = $3 >>> 0; - const loadType = $4; - - if (offset + length > fileData.byteLength) { - return 3; // Out of bounds. - } - - try { - const data = fileData.subarray(offset, offset + length); - switch (loadType) { - case 0: - // Load external data to CPU memory. - // Copy the file data (fileData,offset,length) into WebAssembly memory - // (HEAPU8,buffer,length). - HEAPU8.set(data, dataIdOrBuffer); - break; - case 1: - // Load external data to GPU. - Module.jsepUploadExternalBuffer(dataIdOrBuffer, data); - break; - default: - return 4; // Unknown error occurred in memory copy. - } - return 0; - } catch { - return 4; - } - }), - data_file_path.c_str(), - static_cast(data_offset), - static_cast(data_length), - tensor_data, - static_cast(load_type)); - const char* err_msg; - switch (err_code) { - case 0: - return Status::OK(); - case 1: - err_msg = "Module.MountedFiles is not available."; - break; - case 2: - err_msg = "File not found in preloaded files."; - break; - case 3: - err_msg = "Out of bounds."; - break; - default: - err_msg = "Unknown error occurred in memory copy."; - } - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to load external data file \"", data_file_path, - "\", error: ", err_msg); -} - -} // namespace js -} // namespace onnxruntime From 4f5c1d05860a29792b70bb674e0916d8e28771d8 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Fri, 16 Aug 2024 00:42:52 -0700 Subject: [PATCH 6/7] fix --- onnxruntime/core/framework/session_state_utils.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/framework/session_state_utils.cc b/onnxruntime/core/framework/session_state_utils.cc index d8b90cb350c11..ecde18aa1bf5f 100644 --- a/onnxruntime/core/framework/session_state_utils.cc +++ b/onnxruntime/core/framework/session_state_utils.cc @@ -115,7 +115,7 @@ static common::Status DeserializeTensorProto(const Env& env, const std::basic_st const DataTypeImpl* const type = DataTypeImpl::TensorTypeFromONNXEnum(tensor_proto.data_type())->GetElementType(); std::unique_ptr p_tensor; - auto memory_info = (alloc != nullptr) ? alloc->Info() : m->GetAllocInfo(); + auto& memory_info = (alloc != nullptr) ? alloc->Info() : m->GetAllocInfo(); auto device_type = memory_info.device.Type(); if (utils::HasExternalData(tensor_proto)) { From 628ebf788597dd1f31909c72d0ff9109d34aa083 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Sat, 17 Aug 2024 01:44:34 -0700 Subject: [PATCH 7/7] fix linter --- onnxruntime/core/framework/session_state_utils.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/framework/session_state_utils.cc b/onnxruntime/core/framework/session_state_utils.cc index ecde18aa1bf5f..2c74805c57dce 100644 --- a/onnxruntime/core/framework/session_state_utils.cc +++ b/onnxruntime/core/framework/session_state_utils.cc @@ -125,7 +125,7 @@ static common::Status DeserializeTensorProto(const Env& env, const std::basic_st ORT_RETURN_IF_ERROR(AllocateTensor(m, p_tensor, type, tensor_shape, use_device_allocator_for_initializers, alloc)); ORT_RETURN_IF_ERROR(utils::LoadExtDataToTensorFromTensorProto(env, proto_path, tensor_proto, - *external_data_loader, *p_tensor)); + *external_data_loader, *p_tensor)); auto ml_tensor = DataTypeImpl::GetType(); ort_value.Init(p_tensor.release(), ml_tensor, ml_tensor->GetDeleteFunc());