Skip to content

Commit

Permalink
hack ext data location to reduce qd matmul memory usage
Browse files Browse the repository at this point in the history
  • Loading branch information
fajin-corp committed Jul 23, 2024
1 parent 17e9ea6 commit f1ead4d
Show file tree
Hide file tree
Showing 13 changed files with 197 additions and 88 deletions.
7 changes: 5 additions & 2 deletions include/onnxruntime/core/optimizer/graph_transformer_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

#include "core/common/inlined_containers.h"
#include "core/framework/session_options.h"
#include "core/framework/tensor.h"
#include "core/optimizer/graph_transformer.h"
#include "core/platform/threadpool.h"

Expand Down Expand Up @@ -51,7 +52,8 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
const SessionOptions& session_options,
const IExecutionProvider& execution_provider /*required by constant folding*/,
const InlinedHashSet<std::string>& rules_and_transformers_to_disable = {},
concurrency::ThreadPool* intra_op_thread_pool = nullptr);
concurrency::ThreadPool* intra_op_thread_pool = nullptr,
std::unordered_map<std::string, std::unique_ptr<Tensor>>* p_buffered_tensors = nullptr);

#endif // !defined(ORT_MINIMAL_BUILD)

Expand Down Expand Up @@ -81,7 +83,8 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformersForMinimalB
const SatApplyContextVariant& apply_context,
const IExecutionProvider& cpu_execution_provider,
const InlinedHashSet<std::string>& rules_and_transformers_to_disable = {},
concurrency::ThreadPool* intra_op_thread_pool = nullptr);
concurrency::ThreadPool* intra_op_thread_pool = nullptr,
std::unordered_map<std::string, std::unique_ptr<Tensor>>* p_buffered_tensors = nullptr);

Check warning on line 87 in include/onnxruntime/core/optimizer/graph_transformer_utils.h

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <string> for string [build/include_what_you_use] [4] Raw Output: include/onnxruntime/core/optimizer/graph_transformer_utils.h:87: Add #include <string> for string [build/include_what_you_use] [4]

Check warning on line 87 in include/onnxruntime/core/optimizer/graph_transformer_utils.h

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <unordered_map> for unordered_map<> [build/include_what_you_use] [4] Raw Output: include/onnxruntime/core/optimizer/graph_transformer_utils.h:87: Add #include <unordered_map> for unordered_map<> [build/include_what_you_use] [4]

#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)

Expand Down
3 changes: 2 additions & 1 deletion onnxruntime/core/framework/session_state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1486,7 +1486,8 @@ Status SessionState::FinalizeSessionStateImpl(const std::basic_string<PATH_CHAR_
}
return Status::OK();
},
logger_, data_transfer_mgr_, *p_seq_exec_plan_, session_options, memory_profile_func));
logger_, data_transfer_mgr_, *p_seq_exec_plan_, session_options, memory_profile_func,
name_to_buffered_tensor_));

#if !defined(ORT_MINIMAL_BUILD) && defined(ORT_MEMORY_PROFILE)
// Record Weight allocation info on device
Expand Down
10 changes: 10 additions & 0 deletions onnxruntime/core/framework/session_state.h
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,10 @@ class SessionState {
const InlinedHashSet<NodeIndex>* GetToBeExecutedRange(gsl::span<int const> fetch_mlvalue_idxs) const;
#endif

// Gives callers mutable access to the session's name -> Tensor map (name_to_buffered_tensor_).
// The returned pointer is the address of a member, so it is never null.
std::unordered_map<std::string, std::unique_ptr<Tensor>>* GetMutableBufferedTensors() {
  auto* buffered_tensors = &name_to_buffered_tensor_;
  return buffered_tensors;
}

Status FinalizeSessionState(const std::basic_string<PATH_CHAR_TYPE>& graph_loc,
const KernelRegistryManager& kernel_registry_manager,
bool remove_initializers = true,
Expand Down Expand Up @@ -562,6 +566,12 @@ class SessionState {
// flag to indicate whether current session using any EP that create device stream dynamically.
bool has_device_stream_enabled_ep_ = false;
#endif

// Holds the tensors which provide the memory buffers for TensorProtos.
// Use case: in the optimizer, transform a TensorProto into a new TensorProto whose memory buffer is
// allocated by CPU instead of by protobuf's arena. Arena-style memory allocators do not fully release
// an instance's memory, which may result in large memory consumption; this is their tradeoff for speed.
std::unordered_map<std::string, std::unique_ptr<Tensor>> name_to_buffered_tensor_;

Check warning on line 574 in onnxruntime/core/framework/session_state.h

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <string> for string [build/include_what_you_use] [4] Raw Output: onnxruntime/core/framework/session_state.h:574: Add #include <string> for string [build/include_what_you_use] [4]
};

} // namespace onnxruntime
33 changes: 25 additions & 8 deletions onnxruntime/core/framework/session_state_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -61,17 +61,23 @@ struct ExtDataValueDeleter {

// given a tensor proto with external data return an OrtValue with a tensor for
// that data; the pointers for the tensor data and the tensor itself are owned
// by the OrtValue's deleter
// by the OrtValue's deleter.
// If tensor_proto's external file path is kTensorProtoMemoryAddressTag, and
// buffered_tensor is not null, buffered_tensor holds the real buffer pointed to
// by tensor_proto. buffered_tensor must be the owner of the buffer, and the deleter
// should release the buffer when tensor_proto is released.
static inline common::Status ExtDataTensorProtoToTensor(const Env& env,
const std::basic_string<PATH_CHAR_TYPE>& proto_path,
const ONNX_NAMESPACE::TensorProto& tensor_proto,
Tensor& tensor, OrtCallback& ext_data_deleter) {
Tensor& tensor, OrtCallback& ext_data_deleter,
Tensor* buffered_tensor = nullptr) {
ORT_ENFORCE(utils::HasExternalData(tensor_proto));

void* ext_data_buf = nullptr;
SafeInt<size_t> ext_data_len = 0;
ORT_RETURN_IF_ERROR(utils::GetExtDataFromTensorProto(env, proto_path.c_str(), tensor_proto,
ext_data_buf, ext_data_len, ext_data_deleter));
ext_data_buf, ext_data_len, ext_data_deleter,
buffered_tensor));

// NB: creating a do-nothing allocator per tensor is wasteful; can perhaps be
// avoided if the Tensor class implements the do-nothing behavior when given a
Expand All @@ -83,11 +89,16 @@ static inline common::Status ExtDataTensorProtoToTensor(const Env& env,
return common::Status::OK();
}

// If tensor_proto's external file path is kTensorProtoMemoryAddressTag, and
// buffered_tensor is not null, buffered_tensor holds the real buffer pointed to
// by tensor_proto. buffered_tensor must be the owner of the buffer, and the deleter
// should release the buffer when tensor_proto is released.
static common::Status DeserializeTensorProto(const Env& env, const std::basic_string<PATH_CHAR_TYPE>& proto_path,
const ONNX_NAMESPACE::TensorProto& tensor_proto, const MemBuffer* m,
const AllocatorPtr& alloc, const AllocatorPtr& default_cpu_alloc,
OrtValue& ort_value, const DataTransferManager& data_transfer_mgr,
bool use_device_allocator_for_initializers = false) {
bool use_device_allocator_for_initializers = false,
Tensor* buffered_tensor = nullptr) {
if (bool(alloc) == (m != nullptr)) {
return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT,
"DeserializeTensorProto() takes either pre-allocated buffer or an allocator!");
Expand Down Expand Up @@ -123,7 +134,8 @@ static common::Status DeserializeTensorProto(const Env& env, const std::basic_st
// utilize the mmap'd buffer directly by calling ExtDataTensorProtoToTensor. If we called
// TensorProtoToTensor it would copy the data, causing unnecessary overhead
OrtCallback ext_data_deleter;
ORT_RETURN_IF_ERROR(ExtDataTensorProtoToTensor(env, proto_path, tensor_proto, *p_tensor, ext_data_deleter));
ORT_RETURN_IF_ERROR(ExtDataTensorProtoToTensor(env, proto_path, tensor_proto, *p_tensor,
ext_data_deleter, buffered_tensor));

ExtDataValueDeleter deleter{ext_data_deleter, p_tensor.get()};

Expand Down Expand Up @@ -154,7 +166,7 @@ static common::Status DeserializeTensorProto(const Env& env, const std::basic_st
std::optional<ScopedOrtCallbackInvoker> scoped_ort_callback_invoker;
if (utils::HasExternalData(tensor_proto)) {
ORT_RETURN_IF_ERROR(ExtDataTensorProtoToTensor(env, proto_path, tensor_proto, *p_deserialize_tensor,
ext_data_deleter));
ext_data_deleter, buffered_tensor));
scoped_ort_callback_invoker = ScopedOrtCallbackInvoker(ext_data_deleter);
} else {
ORT_RETURN_IF_ERROR(utils::TensorProtoToTensor(env, proto_path.c_str(), tensor_proto, *p_deserialize_tensor));
Expand Down Expand Up @@ -187,7 +199,8 @@ common::Status SaveInitializedTensors(
const logging::Logger& logger, const DataTransferManager& data_transfer_mgr,
const ExecutionPlanBase& exec_plan,
const SessionOptions& session_options,
const MemoryProfileFunction& memory_profile_func) {
const MemoryProfileFunction& memory_profile_func,
std::unordered_map<std::string, std::unique_ptr<Tensor>>& buffered_tensors) {

Check warning on line 203 in onnxruntime/core/framework/session_state_utils.cc

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <memory> for unique_ptr<> [build/include_what_you_use] [4] Raw Output: onnxruntime/core/framework/session_state_utils.cc:203: Add #include <memory> for unique_ptr<> [build/include_what_you_use] [4]

Check warning on line 203 in onnxruntime/core/framework/session_state_utils.cc

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <unordered_map> for unordered_map<> [build/include_what_you_use] [4] Raw Output: onnxruntime/core/framework/session_state_utils.cc:203: Add #include <unordered_map> for unordered_map<> [build/include_what_you_use] [4]
LOGS(logger, INFO) << "Saving initialized tensors.";
ORT_ENFORCE(ort_value_name_idx_map.MaxIdx() > -1, "OrtValue indexes should have been populated.");

Expand Down Expand Up @@ -307,9 +320,13 @@ common::Status SaveInitializedTensors(
bool use_device_allocator_for_initializers =
session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsUseDeviceAllocatorForInitializers, "0") == "1";

Tensor* p_tensor = buffered_tensors.find(name) != buffered_tensors.end()
? buffered_tensors[name].release()
: nullptr;

Status st = DeserializeTensorProto(env, graph_loc, tensor_proto, (m.has_value()) ? &*m : nullptr, alloc,
default_cpu_alloc, ort_value, data_transfer_mgr,
use_device_allocator_for_initializers);
use_device_allocator_for_initializers, p_tensor);
if (!st.IsOK()) {
std::ostringstream oss;
oss << "Deserialize tensor " << name << " failed." << st.ErrorMessage();
Expand Down
3 changes: 2 additions & 1 deletion onnxruntime/core/framework/session_state_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ common::Status SaveInitializedTensors(
const DataTransferManager& data_transfer_mgr,
const ExecutionPlanBase& exec_plan,
const SessionOptions& session_options,
const MemoryProfileFunction& memory_profile_func);
const MemoryProfileFunction& memory_profile_func,
std::unordered_map<std::string, std::unique_ptr<Tensor>>& buffered_tensors);

Check warning on line 48 in onnxruntime/core/framework/session_state_utils.h

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <string> for string [build/include_what_you_use] [4] Raw Output: onnxruntime/core/framework/session_state_utils.h:48: Add #include <string> for string [build/include_what_you_use] [4]

Check warning on line 48 in onnxruntime/core/framework/session_state_utils.h

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <memory> for unique_ptr<> [build/include_what_you_use] [4] Raw Output: onnxruntime/core/framework/session_state_utils.h:48: Add #include <memory> for unique_ptr<> [build/include_what_you_use] [4]

Check warning on line 48 in onnxruntime/core/framework/session_state_utils.h

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <unordered_map> for unordered_map<> [build/include_what_you_use] [4] Raw Output: onnxruntime/core/framework/session_state_utils.h:48: Add #include <unordered_map> for unordered_map<> [build/include_what_you_use] [4]

common::Status SaveInputOutputNamesToNodeMapping(const GraphViewer& graph,
SessionState& session_state,
Expand Down
34 changes: 31 additions & 3 deletions onnxruntime/core/framework/tensorprotoutils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -987,7 +987,8 @@ static Status GetFileContent(const Env& env, const std::filesystem::path& file_p

Status GetExtDataFromTensorProto(const Env& env, const std::filesystem::path& model_path,
const ONNX_NAMESPACE::TensorProto& tensor_proto, void*& ext_data_buf,
SafeInt<size_t>& ext_data_len, OrtCallback& ext_data_deleter) {
SafeInt<size_t>& ext_data_len, OrtCallback& ext_data_deleter,
Tensor* buffered_tensor) {
ORT_ENFORCE(utils::HasExternalData(tensor_proto));
std::basic_string<ORTCHAR_T> tensor_proto_dir;
if (!model_path.empty()) {
Expand All @@ -1003,7 +1004,12 @@ Status GetExtDataFromTensorProto(const Env& env, const std::filesystem::path& mo
// the value in location is the memory address of the data
ext_data_buf = reinterpret_cast<void*>(file_offset);
ext_data_len = raw_data_safe_len;
ext_data_deleter = OrtCallback{nullptr, nullptr};
if (buffered_tensor != nullptr) {
  // The callback takes over the buffered tensor: invoking it deletes the Tensor that was
  // passed in. static_cast is sufficient (and preferred) for the void* <-> Tensor* round trip.
  ext_data_deleter = OrtCallback{[](void* p) noexcept { delete static_cast<Tensor*>(p); },
                                 static_cast<void*>(buffered_tensor)};
} else {
  // No owning tensor was supplied, so there is nothing for the callback to release.
  ext_data_deleter = OrtCallback{nullptr, nullptr};
}
} else {
#if defined(__wasm__)
ORT_RETURN_IF(file_offset < 0 || file_offset + raw_data_safe_len >= 4294967296,
Expand Down Expand Up @@ -1241,7 +1247,9 @@ ONNXTensorElementDataType GetTensorElementType(const ONNX_NAMESPACE::TensorProto
return CApiElementTypeFromProtoType(tensor_proto.data_type());
}

ONNX_NAMESPACE::TensorProto TensorToTensorProto(const Tensor& tensor, const std::string& tensor_proto_name) {
ONNX_NAMESPACE::TensorProto TensorToTensorProto(const Tensor& tensor,
const std::string& tensor_proto_name,
bool use_tensor_buffer) {
// Set name, dimensions, type, and data of the TensorProto.
ONNX_NAMESPACE::TensorProto tensor_proto;

Expand All @@ -1259,6 +1267,26 @@ ONNX_NAMESPACE::TensorProto TensorToTensorProto(const Tensor& tensor, const std:
for (; f < end; ++f) {
*mutable_string_data->Add() = *f;
}
} else if (use_tensor_buffer && tensor.SizeInBytes() > 127) {
const auto* raw_data = tensor.DataRaw();
ORT_ENFORCE(raw_data, "Missing raw data for tensor proto. Invalid tensor.");
static_assert(sizeof(void*) <= sizeof(ExternalDataInfo::OFFSET_TYPE));
tensor_proto.set_data_location(ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL);

// We reinterpret_cast this back to void* in tensorprotoutils.cc:GetExtDataFromTensorProto.
// Use intptr_t because OFFSET_TYPE is signed. In theory you could get a weird-looking value if the address uses
// the high bit, but that should be unlikely in a scenario where we care about memory usage enough to use this path.
auto offset = gsl::narrow<ExternalDataInfo::OFFSET_TYPE>(reinterpret_cast<intptr_t>(raw_data));

ONNX_NAMESPACE::StringStringEntryProto* entry = tensor_proto.mutable_external_data()->Add();
entry->set_key("location");
entry->set_value(ToUTF8String(onnxruntime::utils::kTensorProtoMemoryAddressTag));
entry = tensor_proto.mutable_external_data()->Add();
entry->set_key("offset");
entry->set_value(std::to_string(offset));
entry = tensor_proto.mutable_external_data()->Add();
entry->set_key("length");
entry->set_value(std::to_string(tensor.SizeInBytes()));
} else {
utils::SetRawDataInTensorProto(tensor_proto, tensor.DataRaw(), tensor.SizeInBytes());
}
Expand Down
29 changes: 21 additions & 8 deletions onnxruntime/core/framework/tensorprotoutils.h
Original file line number Diff line number Diff line change
Expand Up @@ -114,14 +114,22 @@ common::Status TensorProtoToTensor(const Env& env, const std::filesystem::path&
const ONNX_NAMESPACE::TensorProto& tensor_proto,
Tensor& tensor);

/** Creates a TensorProto from a Tensor.
@param[in] tensor the Tensor whose data and shape will be used to create the TensorProto.
@param[in] tensor_proto_name the name of the TensorProto.
@return the TensorProto.
Note: Method currently requires that data is in little-endian format.
/**
* @brief Creates a TensorProto from a Tensor.
* @param[in] tensor the Tensor whose data and shape will be used to create the TensorProto.
* @param[in] tensor_proto_name the name of the TensorProto.
* @param[in] use_tensor_buffer when true, and the tensor is a non-string tensor larger than 127 bytes,
*                              the tensor proto is set to use external location, with
*                              'location' set to onnxruntime::utils::kTensorProtoMemoryAddressTag,
*                              'offset' set to the tensor's memory address, and 'length' set to the
*                              tensor's size in bytes. The data is not copied, so the caller is
*                              responsible for maintaining the lifetime of the tensor's memory
*                              buffer. Use with caution.
* @return the TensorProto.
*
* Note: Method currently requires that data is in little-endian format.
*/
ONNX_NAMESPACE::TensorProto TensorToTensorProto(const Tensor& tensor, const std::string& tensor_proto_name);
ONNX_NAMESPACE::TensorProto TensorToTensorProto(const Tensor& tensor,
const std::string& tensor_proto_name,
bool use_tensor_buffer = false);

ONNXTensorElementDataType CApiElementTypeFromProtoType(int type);
ONNXTensorElementDataType GetTensorElementType(const ONNX_NAMESPACE::TensorProto& tensor_proto);
Expand All @@ -141,10 +149,15 @@ constexpr const ORTCHAR_T* kTensorProtoMemoryAddressTag = ORT_TSTR("*/_ORT_MEM_A

// Given a tensor proto with external data obtain a pointer to the data and its length.
// The ext_data_deleter argument is updated with a callback that owns/releases the data.
// If tensor_proto's external file path is kTensorProtoMemoryAddressTag, and
// buffered_tensor is not null, buffered_tensor holds the real buffer pointed to
// by tensor_proto. buffered_tensor must be the owner of the buffer, and the deleter
// should release the buffer when tensor_proto is released.
common::Status GetExtDataFromTensorProto(const Env& env, const std::filesystem::path& model_path,
const ONNX_NAMESPACE::TensorProto& tensor_proto,
void*& ext_data_buf, SafeInt<size_t>& ext_data_len,
OrtCallback& ext_data_deleter);
OrtCallback& ext_data_deleter,
Tensor* buffered_tensor = nullptr);

// Convert the AttributeProto from a Constant node into a TensorProto that can be used as an initializer
// If AttributeProto contains a TensorProto, this tensor proto is converted as is including the case when the
Expand Down
12 changes: 8 additions & 4 deletions onnxruntime/core/optimizer/graph_transformer_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,8 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
const SessionOptions& session_options,
const IExecutionProvider& cpu_execution_provider, /*required by constant folding*/
const InlinedHashSet<std::string>& rules_and_transformers_to_disable,
concurrency::ThreadPool* intra_op_thread_pool) {
concurrency::ThreadPool* intra_op_thread_pool,
std::unordered_map<std::string, std::unique_ptr<Tensor>>* p_buffered_tensors) {
InlinedVector<std::unique_ptr<GraphTransformer>> transformers;
const bool disable_quant_qdq =
session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsDisableQuantQDQ, "0") == "1";
Expand Down Expand Up @@ -309,7 +310,8 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
transformers.emplace_back(std::make_unique<QDQSelectorActionTransformer>(qdq_is_int8_allowed,
SatApplyContextVariant{},
qdq_matmulnbits_accuracy_level,
intra_op_thread_pool));
intra_op_thread_pool,
p_buffered_tensors));
}

transformers.emplace_back(std::make_unique<GemmActivationFusion>(cpu_ep));
Expand Down Expand Up @@ -419,7 +421,8 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformersForMinimalB
const SatApplyContextVariant& apply_context,
const IExecutionProvider& cpu_execution_provider,
const InlinedHashSet<std::string>& rules_and_transformers_to_disable,
concurrency::ThreadPool* intra_op_thread_pool) {
concurrency::ThreadPool* intra_op_thread_pool,
std::unordered_map<std::string, std::unique_ptr<Tensor>>* p_buffered_tensors) {
InlinedVector<std::unique_ptr<GraphTransformer>> transformers;
const bool saving = std::holds_alternative<SatRuntimeOptimizationSaveContext>(apply_context);

Expand All @@ -444,7 +447,8 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformersForMinimalB
transformers.emplace_back(std::make_unique<QDQSelectorActionTransformer>(qdq_is_int8_allowed,
apply_context,
qdq_matmulnbits_accuracy_level,
intra_op_thread_pool));
intra_op_thread_pool,
p_buffered_tensors));
}

transformers.emplace_back(std::make_unique<ConvActivationFusion>(cpu_ep, apply_context));
Expand Down
Loading

0 comments on commit f1ead4d

Please sign in to comment.