From f1ead4d49d15166f25928036b3ad6cc3bcd4f163 Mon Sep 17 00:00:00 2001 From: Jing Fang Date: Mon, 22 Jul 2024 18:02:55 -0700 Subject: [PATCH] hack ext data location to reduce qd matmul memory usage --- .../core/optimizer/graph_transformer_utils.h | 7 +- onnxruntime/core/framework/session_state.cc | 3 +- onnxruntime/core/framework/session_state.h | 10 ++ .../core/framework/session_state_utils.cc | 33 ++++-- .../core/framework/session_state_utils.h | 3 +- .../core/framework/tensorprotoutils.cc | 34 +++++- onnxruntime/core/framework/tensorprotoutils.h | 29 +++-- .../core/optimizer/graph_transformer_utils.cc | 12 ++- .../selectors_actions/qdq_actions.cc | 102 ++++++++++-------- .../selectors_actions/qdq_actions.h | 5 +- .../qdq_selector_action_transformer.cc | 30 ++++-- .../qdq_selector_action_transformer.h | 3 +- onnxruntime/core/session/inference_session.cc | 14 ++- 13 files changed, 197 insertions(+), 88 deletions(-) diff --git a/include/onnxruntime/core/optimizer/graph_transformer_utils.h b/include/onnxruntime/core/optimizer/graph_transformer_utils.h index 0bb5c7432f0a7..e32777439769d 100644 --- a/include/onnxruntime/core/optimizer/graph_transformer_utils.h +++ b/include/onnxruntime/core/optimizer/graph_transformer_utils.h @@ -9,6 +9,7 @@ #include "core/common/inlined_containers.h" #include "core/framework/session_options.h" +#include "core/framework/tensor.h" #include "core/optimizer/graph_transformer.h" #include "core/platform/threadpool.h" @@ -51,7 +52,8 @@ InlinedVector> GenerateTransformers( const SessionOptions& session_options, const IExecutionProvider& execution_provider /*required by constant folding*/, const InlinedHashSet& rules_and_transformers_to_disable = {}, - concurrency::ThreadPool* intra_op_thread_pool = nullptr); + concurrency::ThreadPool* intra_op_thread_pool = nullptr, + std::unordered_map>* p_buffered_tensors = nullptr); #endif // !defined(ORT_MINIMAL_BUILD) @@ -81,7 +83,8 @@ InlinedVector> GenerateTransformersForMinimalB const SatApplyContextVariant& apply_context, const IExecutionProvider& cpu_execution_provider, const InlinedHashSet& rules_and_transformers_to_disable = {}, - concurrency::ThreadPool* intra_op_thread_pool = nullptr); + concurrency::ThreadPool* intra_op_thread_pool = nullptr, + std::unordered_map>* p_buffered_tensors = nullptr); #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) diff --git a/onnxruntime/core/framework/session_state.cc b/onnxruntime/core/framework/session_state.cc index 42fb7b392283a..e463976b65209 100644 --- a/onnxruntime/core/framework/session_state.cc +++ b/onnxruntime/core/framework/session_state.cc @@ -1486,7 +1486,8 @@ Status SessionState::FinalizeSessionStateImpl(const std::basic_string* GetToBeExecutedRange(gsl::span fetch_mlvalue_idxs) const; #endif + std::unordered_map>* GetMutableBufferedTensors() { + return &name_to_buffered_tensor_; + } + Status FinalizeSessionState(const std::basic_string& graph_loc, const KernelRegistryManager& kernel_registry_manager, bool remove_initializers = true, @@ -562,6 +566,12 @@ class SessionState { // flag to indicate whether current session using any EP that create device stream dynamically. bool has_device_stream_enabled_ep_ = false; #endif + + // Holds the tensors which provide memory buffer for TensorProtos + // Use case: in optimizer, transform a TensorProto to a new TensorProto whose the memory buffer is + // allocated by CPU instead by protobuf's arena. Arena style memory allocators do not fully release + // a instance's memory which may result large memory consumption, which is a tradeoff for speed. + std::unordered_map> name_to_buffered_tensor_; }; } // namespace onnxruntime diff --git a/onnxruntime/core/framework/session_state_utils.cc b/onnxruntime/core/framework/session_state_utils.cc index 059de8e3c8c4a..39b317a7eaae9 100644 --- a/onnxruntime/core/framework/session_state_utils.cc +++ b/onnxruntime/core/framework/session_state_utils.cc @@ -61,17 +61,23 @@ struct ExtDataValueDeleter { // given a tensor proto with external data return an OrtValue with a tensor for // that data; the pointers for the tensor data and the tensor itself are owned -// by the OrtValue's deleter +// by the OrtValue's deleter. +// If tensor_proto's external file path is kTensorProtoMemoryAddressTag, and +// buffered_tensor is not null, buffered_tensor holds the real buffer pointed +// by tensor_proto. buffered_tensor must be the owner of the buffer and deleter +// should release the buffer when tensor_proto is released. static inline common::Status ExtDataTensorProtoToTensor(const Env& env, const std::basic_string& proto_path, const ONNX_NAMESPACE::TensorProto& tensor_proto, - Tensor& tensor, OrtCallback& ext_data_deleter) { + Tensor& tensor, OrtCallback& ext_data_deleter, + Tensor* buffered_tensor = nullptr) { ORT_ENFORCE(utils::HasExternalData(tensor_proto)); void* ext_data_buf = nullptr; SafeInt ext_data_len = 0; ORT_RETURN_IF_ERROR(utils::GetExtDataFromTensorProto(env, proto_path.c_str(), tensor_proto, - ext_data_buf, ext_data_len, ext_data_deleter)); + ext_data_buf, ext_data_len, ext_data_deleter, + buffered_tensor)); // NB: creating a do-nothing allocator per tensor is wasteful; can perhaps be // avoided if the Tensor class implements the do-nothing behavior when given a @@ -83,11 +89,16 @@ static inline common::Status ExtDataTensorProtoToTensor(const Env& env, return common::Status::OK(); } +// If tensor_proto's external file path is kTensorProtoMemoryAddressTag, and +// buffered_tensor is not null, buffered_tensor holds the real buffer pointed +// by tensor_proto. buffered_tensor must be the owner of the buffer and deleter +// should release the buffer when tensor_proto is released. static common::Status DeserializeTensorProto(const Env& env, const std::basic_string& proto_path, const ONNX_NAMESPACE::TensorProto& tensor_proto, const MemBuffer* m, const AllocatorPtr& alloc, const AllocatorPtr& default_cpu_alloc, OrtValue& ort_value, const DataTransferManager& data_transfer_mgr, - bool use_device_allocator_for_initializers = false) { + bool use_device_allocator_for_initializers = false, + Tensor* buffered_tensor = nullptr) { if (bool(alloc) == (m != nullptr)) { return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "DeserializeTensorProto() takes either pre-allocated buffer or an allocator!"); @@ -123,7 +134,8 @@ static common::Status DeserializeTensorProto(const Env& env, const std::basic_st // utilize the mmap'd buffer directly by calling ExtDataTensorProtoToTensor. If we called // TensorProtoToTensor it would copy the data, causing unnecessary overhead OrtCallback ext_data_deleter; - ORT_RETURN_IF_ERROR(ExtDataTensorProtoToTensor(env, proto_path, tensor_proto, *p_tensor, ext_data_deleter)); + ORT_RETURN_IF_ERROR(ExtDataTensorProtoToTensor(env, proto_path, tensor_proto, *p_tensor, + ext_data_deleter, buffered_tensor)); ExtDataValueDeleter deleter{ext_data_deleter, p_tensor.get()}; @@ -154,7 +166,7 @@ static common::Status DeserializeTensorProto(const Env& env, const std::basic_st std::optional scoped_ort_callback_invoker; if (utils::HasExternalData(tensor_proto)) { ORT_RETURN_IF_ERROR(ExtDataTensorProtoToTensor(env, proto_path, tensor_proto, *p_deserialize_tensor, - ext_data_deleter)); + ext_data_deleter, buffered_tensor)); scoped_ort_callback_invoker = ScopedOrtCallbackInvoker(ext_data_deleter); } else { ORT_RETURN_IF_ERROR(utils::TensorProtoToTensor(env, proto_path.c_str(), tensor_proto, *p_deserialize_tensor)); @@ -187,7 +199,8 @@ common::Status SaveInitializedTensors( const logging::Logger& logger, const DataTransferManager& data_transfer_mgr, const ExecutionPlanBase& exec_plan, const SessionOptions& session_options, - const MemoryProfileFunction& memory_profile_func) { + const MemoryProfileFunction& memory_profile_func, + std::unordered_map>& buffered_tensors) { LOGS(logger, INFO) << "Saving initialized tensors."; ORT_ENFORCE(ort_value_name_idx_map.MaxIdx() > -1, "OrtValue indexes should have been populated."); @@ -307,9 +320,13 @@ common::Status SaveInitializedTensors( bool use_device_allocator_for_initializers = session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsUseDeviceAllocatorForInitializers, "0") == "1"; + Tensor* p_tensor = buffered_tensors.find(name) != buffered_tensors.end() + ? buffered_tensors[name].release() + : nullptr; + Status st = DeserializeTensorProto(env, graph_loc, tensor_proto, (m.has_value()) ? &*m : nullptr, alloc, default_cpu_alloc, ort_value, data_transfer_mgr, - use_device_allocator_for_initializers); + use_device_allocator_for_initializers, p_tensor); if (!st.IsOK()) { std::ostringstream oss; oss << "Deserialize tensor " << name << " failed." << st.ErrorMessage(); diff --git a/onnxruntime/core/framework/session_state_utils.h b/onnxruntime/core/framework/session_state_utils.h index af44c35fbb7f5..64ea657c44e77 100644 --- a/onnxruntime/core/framework/session_state_utils.h +++ b/onnxruntime/core/framework/session_state_utils.h @@ -44,7 +44,8 @@ common::Status SaveInitializedTensors( const DataTransferManager& data_transfer_mgr, const ExecutionPlanBase& exec_plan, const SessionOptions& session_options, - const MemoryProfileFunction& memory_profile_func); + const MemoryProfileFunction& memory_profile_func, + std::unordered_map>& buffered_tensors); common::Status SaveInputOutputNamesToNodeMapping(const GraphViewer& graph, SessionState& session_state, diff --git a/onnxruntime/core/framework/tensorprotoutils.cc b/onnxruntime/core/framework/tensorprotoutils.cc index 4ecd61962d797..913d74218189c 100644 --- a/onnxruntime/core/framework/tensorprotoutils.cc +++ b/onnxruntime/core/framework/tensorprotoutils.cc @@ -987,7 +987,8 @@ static Status GetFileContent(const Env& env, const std::filesystem::path& file_p Status GetExtDataFromTensorProto(const Env& env, const std::filesystem::path& model_path, const ONNX_NAMESPACE::TensorProto& tensor_proto, void*& ext_data_buf, - SafeInt& ext_data_len, OrtCallback& ext_data_deleter) { + SafeInt& ext_data_len, OrtCallback& ext_data_deleter, + Tensor* buffered_tensor) { ORT_ENFORCE(utils::HasExternalData(tensor_proto)); std::basic_string tensor_proto_dir; if (!model_path.empty()) { @@ -1003,7 +1004,12 @@ Status GetExtDataFromTensorProto(const Env& env, const std::filesystem::path& mo // the value in location is the memory address of the data ext_data_buf = reinterpret_cast(file_offset); ext_data_len = raw_data_safe_len; - ext_data_deleter = OrtCallback{nullptr, nullptr}; + if (buffered_tensor) { + ext_data_deleter = OrtCallback{[](void* p) noexcept { delete reinterpret_cast(p); }, + reinterpret_cast(buffered_tensor)}; + } else { + ext_data_deleter = OrtCallback{nullptr, nullptr}; + } } else { #if defined(__wasm__) ORT_RETURN_IF(file_offset < 0 || file_offset + raw_data_safe_len >= 4294967296, @@ -1241,7 +1247,9 @@ ONNXTensorElementDataType GetTensorElementType(const ONNX_NAMESPACE::TensorProto return CApiElementTypeFromProtoType(tensor_proto.data_type()); } -ONNX_NAMESPACE::TensorProto TensorToTensorProto(const Tensor& tensor, const std::string& tensor_proto_name) { +ONNX_NAMESPACE::TensorProto TensorToTensorProto(const Tensor& tensor, + const std::string& tensor_proto_name, + bool use_tensor_buffer) { // Set name, dimensions, type, and data of the TensorProto. ONNX_NAMESPACE::TensorProto tensor_proto; @@ -1259,6 +1267,26 @@ ONNX_NAMESPACE::TensorProto TensorToTensorProto(const Tensor& tensor, const std: for (; f < end; ++f) { *mutable_string_data->Add() = *f; } + } else if (use_tensor_buffer && tensor.SizeInBytes() > 127) { + const auto* raw_data = tensor.DataRaw(); + ORT_ENFORCE(raw_data, "Missing raw data for tensor proto. Invalid tensor."); + static_assert(sizeof(void*) <= sizeof(ExternalDataInfo::OFFSET_TYPE)); + tensor_proto.set_data_location(ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL); + + // we reinterpret_cast this back to void* in tensorprotoutils.cc:GetExtDataFromTensorProto. + // use intptr_t as OFFSET_TYPE is signed. in theory you could get a weird looking value if the address uses the + // high bit, but that should be unlikely in a scenario where we care about memory usage enough to use this path. + auto offset = gsl::narrow(reinterpret_cast(raw_data)); + + ONNX_NAMESPACE::StringStringEntryProto* entry = tensor_proto.mutable_external_data()->Add(); + entry->set_key("location"); + entry->set_value(ToUTF8String(onnxruntime::utils::kTensorProtoMemoryAddressTag)); + entry = tensor_proto.mutable_external_data()->Add(); + entry->set_key("offset"); + entry->set_value(std::to_string(offset)); + entry = tensor_proto.mutable_external_data()->Add(); + entry->set_key("length"); + entry->set_value(std::to_string(tensor.SizeInBytes())); } else { utils::SetRawDataInTensorProto(tensor_proto, tensor.DataRaw(), tensor.SizeInBytes()); } diff --git a/onnxruntime/core/framework/tensorprotoutils.h b/onnxruntime/core/framework/tensorprotoutils.h index aabfc0487f3e0..c3ba57bd81d1c 100644 --- a/onnxruntime/core/framework/tensorprotoutils.h +++ b/onnxruntime/core/framework/tensorprotoutils.h @@ -114,14 +114,22 @@ common::Status TensorProtoToTensor(const Env& env, const std::filesystem::path& const ONNX_NAMESPACE::TensorProto& tensor_proto, Tensor& tensor); -/** Creates a TensorProto from a Tensor. - @param[in] tensor the Tensor whose data and shape will be used to create the TensorProto. - @param[in] tensor_proto_name the name of the TensorProto. - @return the TensorProto. - - Note: Method currently requires that data is in little-endian format. +/** + * @brief Creates a TensorProto from a Tensor. + * @param[in] tensor the Tensor whose data and shape will be used to create the TensorProto. + * @param[in] tensor_proto_name the name of the TensorProto. + * @param[in] use_tensor_buffer the tensor proto is set to use external location, with + * 'location' set to onnxruntime::utils::kTensorProtoMemoryAddressTag + * 'offset' set to tensor's memory location, and 'length' set to tensor's + * memory size. The caller is responsible to maintain the lifetime of + * the allocated memory buffer. Use with caution. + * @return the TensorProto. + * + * Note: Method currently requires that data is in little-endian format. */ -ONNX_NAMESPACE::TensorProto TensorToTensorProto(const Tensor& tensor, const std::string& tensor_proto_name); +ONNX_NAMESPACE::TensorProto TensorToTensorProto(const Tensor& tensor, + const std::string& tensor_proto_name, + bool use_tensor_buffer = false); ONNXTensorElementDataType CApiElementTypeFromProtoType(int type); ONNXTensorElementDataType GetTensorElementType(const ONNX_NAMESPACE::TensorProto& tensor_proto); @@ -141,10 +149,15 @@ constexpr const ORTCHAR_T* kTensorProtoMemoryAddressTag = ORT_TSTR("*/_ORT_MEM_A // Given a tensor proto with external data obtain a pointer to the data and its length. // The ext_data_deleter argument is updated with a callback that owns/releases the data. +// If tensor_proto's external file path is kTensorProtoMemoryAddressTag, and +// buffered_tensor is not null, buffered_tensor holds the real buffer pointed +// by tensor_proto. buffered_tensor must be the owner of the buffer and deleter +// should release the buffer when tensor_proto is released. common::Status GetExtDataFromTensorProto(const Env& env, const std::filesystem::path& model_path, const ONNX_NAMESPACE::TensorProto& tensor_proto, void*& ext_data_buf, SafeInt& ext_data_len, - OrtCallback& ext_data_deleter); + OrtCallback& ext_data_deleter, + Tensor* buffered_tensor = nullptr); // Convert the AttributeProto from a Constant node into a TensorProto that can be used as an initializer // If AttributeProto contains a TensorProto, this tensor proto is converted as is including the case when the diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index 7da65f18ccacb..aa47fcd51b556 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -189,7 +189,8 @@ InlinedVector> GenerateTransformers( const SessionOptions& session_options, const IExecutionProvider& cpu_execution_provider, /*required by constant folding*/ const InlinedHashSet& rules_and_transformers_to_disable, - concurrency::ThreadPool* intra_op_thread_pool) { + concurrency::ThreadPool* intra_op_thread_pool, + std::unordered_map>* p_buffered_tensors) { InlinedVector> transformers; const bool disable_quant_qdq = session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsDisableQuantQDQ, "0") == "1"; @@ -309,7 +310,8 @@ InlinedVector> GenerateTransformers( transformers.emplace_back(std::make_unique(qdq_is_int8_allowed, SatApplyContextVariant{}, qdq_matmulnbits_accuracy_level, - intra_op_thread_pool)); + intra_op_thread_pool, + p_buffered_tensors)); } transformers.emplace_back(std::make_unique(cpu_ep)); @@ -419,7 +421,8 @@ InlinedVector> GenerateTransformersForMinimalB const SatApplyContextVariant& apply_context, const IExecutionProvider& cpu_execution_provider, const InlinedHashSet& rules_and_transformers_to_disable, - concurrency::ThreadPool* intra_op_thread_pool) { + concurrency::ThreadPool* intra_op_thread_pool, + std::unordered_map>* p_buffered_tensors) { InlinedVector> transformers; const bool saving = std::holds_alternative(apply_context); @@ -444,7 +447,8 @@ InlinedVector> GenerateTransformersForMinimalB transformers.emplace_back(std::make_unique(qdq_is_int8_allowed, apply_context, qdq_matmulnbits_accuracy_level, - intra_op_thread_pool)); + intra_op_thread_pool, + p_buffered_tensors)); } transformers.emplace_back(std::make_unique(cpu_ep, apply_context)); diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc index 74fecb0427e14..f4487c4175a5d 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc @@ -275,8 +275,10 @@ Status MatMulReplaceWithQLinear::Run(Graph& graph, const NodesToOptimize& select } } -DQMatMulToMatMulNBitsAction::DQMatMulToMatMulNBitsAction(int64_t accuracy_level, - concurrency::ThreadPool* intra_op_thread_pool) +DQMatMulToMatMulNBitsAction::DQMatMulToMatMulNBitsAction( + int64_t accuracy_level, + concurrency::ThreadPool* intra_op_thread_pool, + std::unordered_map>* p_buffered_tensors) : accuracy_level_{accuracy_level}, domain_{kMSDomain}, op_type_{"MatMulNBits"}, @@ -286,7 +288,8 @@ DQMatMulToMatMulNBitsAction::DQMatMulToMatMulNBitsAction(int64_t accuracy_level, MoveAndAppend(target, ArgType::kInput, 0, ArgType::kInput), MoveAll(target, ArgType::kOutput)}; }()}, - intra_op_thread_pool_{intra_op_thread_pool} { + intra_op_thread_pool_{intra_op_thread_pool}, + p_buffered_tensors_{p_buffered_tensors} { ORT_ENFORCE(accuracy_level_ >= 0 && accuracy_level_ <= 4, "MatMulNBits accuracy level must be between 0 and 4"); } @@ -311,6 +314,7 @@ DQMatMulToMatMulNBitsAction::ExtraAttributes(const RuntimeState& runtime_state) Status DQMatMulToMatMulNBitsAction::ProcessNewNode(Graph& graph, const NodesToOptimize& selected_nodes, Node& replacement_node) const { + ORT_RETURN_IF_NOT(p_buffered_tensors_, "Buffered tensors map cannot be null"); const auto* dq_node = selected_nodes.Input(0); const auto* weight_arg = dq_node->InputDefs()[0]; const auto* scale_arg = dq_node->InputDefs()[1]; @@ -338,24 +342,32 @@ Status DQMatMulToMatMulNBitsAction::ProcessNewNode(Graph& graph, // to what we need. But it does not handle external data. Initializer weight_src(*weight_tensor_proto, graph.ModelPath()); Initializer scale_src(*scale_tensor_proto, graph.ModelPath()); - std::optional zp_src; - Initializer weight_dst(ONNX_NAMESPACE::TensorProto_DataType_UINT8, - graph.GenerateNodeArgName(weight_arg->Name() + "_T"), - std::vector{N, quant_num, blob_bytes}); - Initializer scale_dst(static_cast(scale_src.data_type()), - graph.GenerateNodeArgName(scale_arg->Name() + "_T"), - std::vector{N * quant_num}); - std::optional zp_dst; + auto uint8_type = DataTypeImpl::TensorTypeFromONNXEnum(ONNX_NAMESPACE::TensorProto_DataType_UINT8)->GetElementType(); + auto scale_type = DataTypeImpl::TensorTypeFromONNXEnum(scale_src.data_type())->GetElementType(); + std::optional> zp_src_ptr; + auto weight_dst_name = graph.GenerateNodeArgName(weight_arg->Name() + "_T"); + auto weight_dst_ptr = std::make_unique(uint8_type, + TensorShape{N, quant_num, blob_bytes}, + std::make_shared()); + auto scale_dst_name = graph.GenerateNodeArgName(scale_arg->Name() + "_T"); + auto scale_dst_ptr = std::make_unique(scale_type, + TensorShape{N * quant_num}, + std::make_shared()); + std::optional zp_dst_name; + std::optional> zp_dst_ptr; if (zp_tensor_proto) { - zp_src.emplace(*zp_tensor_proto, graph.ModelPath()); - zp_dst.emplace(ONNX_NAMESPACE::TensorProto_DataType_UINT8, - graph.GenerateNodeArgName(zp_arg->Name() + "_T"), - std::vector{N * ((quant_num + 1) / 2)}); + zp_src_ptr = std::make_unique(*zp_tensor_proto, graph.ModelPath()); + zp_dst_name = graph.GenerateNodeArgName(zp_arg->Name() + "_T"); + zp_dst_ptr = std::make_unique(uint8_type, + TensorShape{N * ((quant_num + 1) / 2)}, + std::make_shared()); } else if (weight_src.data_type() == ONNX_NAMESPACE::TensorProto_DataType_UINT4) { - zp_dst.emplace(ONNX_NAMESPACE::TensorProto_DataType_UINT8, - graph.GenerateNodeArgName("fused_DQ_MatMul_zero_point_T"), - std::vector{N * ((quant_num + 1) / 2)}); + zp_dst_name = graph.GenerateNodeArgName("fused_DQ_MatMul_zero_point_T"); + zp_dst_ptr = std::make_unique(uint8_type, + TensorShape{N * ((quant_num + 1) / 2)}, + std::make_shared()); + memset(zp_dst_ptr.value()->MutableDataRaw(), 0, zp_dst_ptr.value()->SizeInBytes()); } if (scale_src.data_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { @@ -363,10 +375,10 @@ Status DQMatMulToMatMulNBitsAction::ProcessNewNode(Graph& graph, MlasQDQTransposeBlockwiseQuantized( weight_src.DataAsByteSpan().data(), scale_src.data(), - zp_src ? zp_src->DataAsByteSpan().data() : nullptr, - weight_dst.data(), - scale_dst.data(), - zp_dst ? zp_dst->data() : nullptr, + zp_src_ptr ? zp_src_ptr.value()->DataAsByteSpan().data() : nullptr, + weight_dst_ptr->MutableData(), + scale_dst_ptr->MutableData(), + zp_dst_ptr ? zp_dst_ptr.value()->MutableData() : nullptr, true, static_cast(K), static_cast(N), @@ -376,10 +388,10 @@ Status DQMatMulToMatMulNBitsAction::ProcessNewNode(Graph& graph, MlasQDQTransposeBlockwiseQuantized( weight_src.DataAsByteSpan().data(), scale_src.data(), - zp_src ? zp_src->DataAsByteSpan().data() : nullptr, - weight_dst.data(), - scale_dst.data(), - zp_dst ? zp_dst->data() : nullptr, + zp_src_ptr ? zp_src_ptr.value()->DataAsByteSpan().data() : nullptr, + weight_dst_ptr->MutableData(), + scale_dst_ptr->MutableData(), + zp_dst_ptr ? zp_dst_ptr.value()->MutableData() : nullptr, true, static_cast(K), static_cast(N), @@ -391,10 +403,10 @@ Status DQMatMulToMatMulNBitsAction::ProcessNewNode(Graph& graph, MlasQDQTransposeBlockwiseQuantized( weight_src.DataAsByteSpan().data(), scale_src.data(), - zp_src ? zp_src->DataAsByteSpan().data() : nullptr, - weight_dst.data(), - scale_dst.data(), - zp_dst ? zp_dst->data() : nullptr, + zp_src_ptr ? zp_src_ptr.value()->DataAsByteSpan().data() : nullptr, + weight_dst_ptr->MutableData(), + scale_dst_ptr->MutableData(), + zp_dst_ptr ? zp_dst_ptr.value()->MutableData() : nullptr, true, static_cast(K), static_cast(N), @@ -405,10 +417,10 @@ Status DQMatMulToMatMulNBitsAction::ProcessNewNode(Graph& graph, MlasQDQTransposeBlockwiseQuantized( weight_src.DataAsByteSpan().data(), scale_src.data(), - zp_src ? zp_src->DataAsByteSpan().data() : nullptr, - weight_dst.data(), - scale_dst.data(), - zp_dst ? zp_dst->data() : nullptr, + zp_src_ptr ? zp_src_ptr.value()->DataAsByteSpan().data() : nullptr, + weight_dst_ptr->MutableData(), + scale_dst_ptr->MutableData(), + zp_dst_ptr ? zp_dst_ptr.value()->MutableData() : nullptr, true, static_cast(K), static_cast(N), @@ -417,28 +429,32 @@ Status DQMatMulToMatMulNBitsAction::ProcessNewNode(Graph& graph, } } - ONNX_NAMESPACE::TensorProto weight_T_tp; - ONNX_NAMESPACE::TensorProto scale_T_tp; + auto weight_T_tp = utils::TensorToTensorProto(*weight_dst_ptr, weight_dst_name, true); + auto scale_T_tp = utils::TensorToTensorProto(*scale_dst_ptr, scale_dst_name, true); std::optional zp_T_tp; - // TODO(fajin): external_data to memory location to avoid arena allocation - // https://github.com/microsoft/onnxruntime/pull/12465 - weight_dst.ToProto(weight_T_tp); - scale_dst.ToProto(scale_T_tp); - if (zp_dst) { - zp_T_tp.emplace(); - zp_dst->ToProto(zp_T_tp.value()); + if (zp_dst_ptr) { + zp_T_tp = utils::TensorToTensorProto(*zp_dst_ptr.value(), zp_dst_name.value(), true); } auto& input_defs = replacement_node.MutableInputDefs(); input_defs.push_back(&graph_utils::AddInitializer(graph, weight_T_tp)); replacement_node.MutableInputArgsCount().push_back(1); + ORT_RETURN_IF_NOT(p_buffered_tensors_->emplace(weight_dst_name, std::move(weight_dst_ptr)).second, + "Failed to add buffered tensor ", + weight_dst_name); input_defs.push_back(&graph_utils::AddInitializer(graph, scale_T_tp)); replacement_node.MutableInputArgsCount().push_back(1); + ORT_RETURN_IF_NOT(p_buffered_tensors_->emplace(scale_dst_name, std::move(scale_dst_ptr)).second, + "Failed to add buffered tensor ", + scale_dst_name); if (zp_T_tp) { input_defs.push_back(&graph_utils::AddInitializer(graph, zp_T_tp.value())); replacement_node.MutableInputArgsCount().push_back(1); + ORT_RETURN_IF_NOT(p_buffered_tensors_->emplace(zp_dst_name.value(), std::move(zp_dst_ptr.value())).second, + "Failed to add buffered tensor ", + zp_dst_name.value()); } return Status::OK(); diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h index 47821619db65a..1c80fab5ff7da 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h @@ -9,6 +9,7 @@ #include "core/optimizer/selectors_actions/actions.h" #include "core/platform/threadpool.h" +#include "core/framework/tensor.h" namespace onnxruntime { @@ -84,7 +85,8 @@ struct MatMulReplaceWithQLinear : public Action { // used together with DQMatMulNodeGroupSelector, which does the sanity check struct DQMatMulToMatMulNBitsAction : public ReplaceWithNew { DQMatMulToMatMulNBitsAction(int64_t accuracy_level, - concurrency::ThreadPool* intra_op_thread_pool); + concurrency::ThreadPool* intra_op_thread_pool, + std::unordered_map>* p_buffered_tensors); private: std::string OpType(const RuntimeState&) const override { return op_type_; } @@ -103,6 +105,7 @@ struct DQMatMulToMatMulNBitsAction : public ReplaceWithNew { const std::string op_type_; const std::vector value_moves_; concurrency::ThreadPool* intra_op_thread_pool_; + std::unordered_map>* p_buffered_tensors_; }; struct GemmReplaceWithQuant : public Action { diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc index 17e66a3953b97..0eee7e569069b 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc @@ -230,7 +230,8 @@ void MatMulQDQRules(SelectorActionRegistry& qdq_selector_action_registry, bool i void DQMatMulToMatMulNBitsRules(SelectorActionRegistry& qdq_selector_action_registry, int64_t qdq_matmulnbits_accuracy_level, - concurrency::ThreadPool* intra_op_thread_pool) { + concurrency::ThreadPool* intra_op_thread_pool, + std::unordered_map>* p_buffered_tensors) { // 2 nodes. DQ -> MatMul. DQ is the second input to MatMul. // DQ's weight is int4/uint4. DQ's scale is float/float16. // DQ is block-quantized along axis 0, with block_size >= 16 and as 2's power. @@ -238,7 +239,8 @@ void DQMatMulToMatMulNBitsRules(SelectorActionRegistry& qdq_selector_action_regi std::unique_ptr action = std::make_unique(qdq_matmulnbits_accuracy_level, - intra_op_thread_pool); + intra_op_thread_pool, + p_buffered_tensors); #if !defined(ORT_MINIMAL_BUILD) std::unique_ptr selector = std::make_unique(); @@ -295,9 +297,11 @@ void WhereQDQRules(SelectorActionRegistry& qdq_selector_action_registry) { #endif } -SelectorActionRegistry CreateSelectorActionRegistry(bool is_int8_allowed, - int64_t qdq_matmulnbits_accuracy_level, - concurrency::ThreadPool* intra_op_thread_pool) { +SelectorActionRegistry CreateSelectorActionRegistry( + bool is_int8_allowed, + int64_t qdq_matmulnbits_accuracy_level, + concurrency::ThreadPool* intra_op_thread_pool, + std::unordered_map>* p_buffered_tensors) { SelectorActionRegistry qdq_selector_action_registry; SplitQDQRules(qdq_selector_action_registry); DropQDQNodesRules(qdq_selector_action_registry); @@ -311,20 +315,24 @@ SelectorActionRegistry CreateSelectorActionRegistry(bool is_int8_allowed, WhereQDQRules(qdq_selector_action_registry); DQMatMulToMatMulNBitsRules(qdq_selector_action_registry, qdq_matmulnbits_accuracy_level, - intra_op_thread_pool); + intra_op_thread_pool, + p_buffered_tensors); return qdq_selector_action_registry; } } // namespace -QDQSelectorActionTransformer::QDQSelectorActionTransformer(bool is_int8_allowed, - const SatApplyContextVariant& apply_context, - int64_t qdq_matmulnbits_accuracy_level, - concurrency::ThreadPool* intra_op_thread_pool) +QDQSelectorActionTransformer::QDQSelectorActionTransformer( + bool is_int8_allowed, + const SatApplyContextVariant& apply_context, + int64_t qdq_matmulnbits_accuracy_level, + concurrency::ThreadPool* intra_op_thread_pool, + std::unordered_map>* p_buffered_tensors) : SelectorActionTransformer{ "QDQSelectorActionTransformer", - CreateSelectorActionRegistry(is_int8_allowed, qdq_matmulnbits_accuracy_level, intra_op_thread_pool), + CreateSelectorActionRegistry(is_int8_allowed, qdq_matmulnbits_accuracy_level, + intra_op_thread_pool, p_buffered_tensors), apply_context, // this transformer is only compatible with the CPU and DML EP {kCpuExecutionProvider, kDmlExecutionProvider}} { diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h index ba636f76d1900..942c558a0ad76 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h @@ -25,7 +25,8 @@ class QDQSelectorActionTransformer : public SelectorActionTransformer { QDQSelectorActionTransformer(bool is_int8_allowed, const SatApplyContextVariant& apply_context = {}, int64_t qdq_matmulnbits_accuracy_level = 4, - concurrency::ThreadPool* intra_op_thread_pool = nullptr); + concurrency::ThreadPool* intra_op_thread_pool = nullptr, + std::unordered_map>* p_buffered_tensors = nullptr); }; } // namespace onnxruntime diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 3fd6e84e0e5ce..fa382e12c0009 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -1610,7 +1610,8 @@ Status PartitionOrtFormatModel(onnxruntime::Graph& graph, Status ApplyOrtFormatModelRuntimeOptimizations( onnxruntime::Graph& graph, const logging::Logger& logger, const SessionOptions& session_options, const InlinedHashSet& optimizers_to_disable, const IExecutionProvider& cpu_ep, - concurrency::ThreadPool* intra_op_thread_pool) { + concurrency::ThreadPool* intra_op_thread_pool, + std::unordered_map>* p_buffered_tensors) { bool modified = false; for (int level = static_cast(TransformerLevel::Level2); @@ -1618,7 +1619,7 @@ Status ApplyOrtFormatModelRuntimeOptimizations( ++level) { const auto transformers = optimizer_utils::GenerateTransformersForMinimalBuild( static_cast(level), session_options, SatRuntimeOptimizationLoadContext{}, cpu_ep, - optimizers_to_disable, intra_op_thread_pool); + optimizers_to_disable, intra_op_thread_pool, p_buffered_tensors); for (const auto& transformer : transformers) { ORT_RETURN_IF_ERROR(transformer->Apply(graph, modified, logger)); @@ -2007,7 +2008,8 @@ common::Status InferenceSession::Initialize() { const auto& cpu_ep = *execution_providers_.Get(onnxruntime::kCpuExecutionProvider); ORT_RETURN_IF_ERROR_SESSIONID_( ApplyOrtFormatModelRuntimeOptimizations(graph, *session_logger_, session_options_, optimizers_to_disable_, - cpu_ep, GetIntraOpThreadPoolToUse())); + cpu_ep, GetIntraOpThreadPoolToUse(), + session_state_->GetMutableBufferedTensors())); #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) } @@ -3170,7 +3172,8 @@ common::Status InferenceSession::AddPredefinedTransformers( if (use_full_build_optimizations) { return optimizer_utils::GenerateTransformers(level, session_options_, cpu_ep, optimizers_to_disable_, - GetIntraOpThreadPoolToUse()); + GetIntraOpThreadPoolToUse(), + session_state_->GetMutableBufferedTensors()); } else { const auto sat_context = minimal_build_optimization_handling == @@ -3180,7 +3183,8 @@ common::Status InferenceSession::AddPredefinedTransformers( : SatApplyContextVariant{SatDirectApplicationContext{}}; return optimizer_utils::GenerateTransformersForMinimalBuild(level, session_options_, sat_context, cpu_ep, optimizers_to_disable_, - GetIntraOpThreadPoolToUse()); + GetIntraOpThreadPoolToUse(), + session_state_->GetMutableBufferedTensors()); } }();