Enable serializing prepacked weights into the data file (#22256)
### Description
Part of #21448.
This change is intended to reduce CPU memory usage during model load for
inference.
Added the session option `save_prepacked_constant_initializers`. With
`save_prepacked_constant_initializers` turned on:
1. Optimize the model with an inference session; prepacked external initializers
are saved into the data file.
2. Load the optimized model and its external data file containing the prepacked
initializers; no prepacking is needed.
3. Run inference with the optimized model and data file (see the usage sketch below).
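
A minimal end-to-end sketch of these three steps with the C++ API is shown below. File names are placeholders; the external-initializers file-name key is the pre-existing session option, and `session.save_prepacked_constant_initializers` is the key added by this PR.

```cpp
// Sketch only: paths and values are placeholders, not values required by ORT.
#include <onnxruntime_cxx_api.h>

int main() {
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "prepack_demo");

  // Step 1: optimize once and serialize prepacked constant initializers into the data file.
  {
    Ort::SessionOptions so;
    so.SetOptimizedModelFilePath(ORT_TSTR("model_opt.onnx"));
    so.AddConfigEntry("session.optimized_model_external_initializers_file_name", "model_opt.onnx.data");
    so.AddConfigEntry("session.save_prepacked_constant_initializers", "1");
    Ort::Session build_session(env, ORT_TSTR("model.onnx"), so);  // writes model_opt.onnx + data file
  }

  // Steps 2-3: later loads use the optimized model; prepacked weights are mmapped from the
  // data file, so no prepacking (and no extra heap allocation) happens before inference.
  Ort::SessionOptions so;
  Ort::Session session(env, ORT_TSTR("model_opt.onnx"), so);
  // session.Run(...);
  return 0;
}
```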

Tested with the Phi-3-mini-instruct-onnx model.

With ORT 1.12.0:

![image](https://github.com/user-attachments/assets/3c0337be-f340-4bb7-8f9f-30f3552072ef)

With this change:

![image](https://github.com/user-attachments/assets/23282990-2e1e-4a1f-92de-afa8ed7e6a43)

Peak memory usage dropped from **5.438 GB to 2.726 GB**.
This change takes advantage of the fact that ORT loads external initializers with mmap
on CPU. Prepacking allocates extra memory on the heap, so omitting the prepack step
saves that memory (roughly the same size as the external initializers).
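
The screenshots above come from an unspecified memory profiler; as a rough cross-check, peak resident set size can be read around session creation with a Linux-only helper such as the hypothetical `PeakRssKb` below (not part of ORT).

```cpp
// Linux-only sketch: read the process's peak resident set size (VmHWM) from /proc/self/status.
#include <fstream>
#include <iostream>
#include <string>

long PeakRssKb() {
  std::ifstream status("/proc/self/status");
  std::string line;
  while (std::getline(status, line)) {
    if (line.rfind("VmHWM:", 0) == 0) {  // e.g. "VmHWM:  2858124 kB"
      return std::stol(line.substr(6));
    }
  }
  return -1;  // field not found (non-Linux platform)
}

int main() {
  // ... create the Ort::Session in the configuration under test here ...
  std::cout << "peak RSS: " << PeakRssKb() << " kB\n";
  return 0;
}
```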

Next step:
Update all CPU kernels that implement the PrePack method and test them properly.
This will be done in a follow-up PR.



### Motivation and Context
Part of #21448. Serializing prepacked weights lets ORT mmap them from the data file
instead of re-prepacking them on the heap, reducing peak CPU memory during model load
for inference.
frank-dong-ms authored Oct 25, 2024
1 parent 4ed5bec commit c5b6be0
Showing 72 changed files with 872 additions and 137 deletions.
22 changes: 22 additions & 0 deletions include/onnxruntime/core/framework/op_kernel.h
@@ -79,6 +79,7 @@ class OpKernel {
// the allocator tied to the session if the kernel owns the pre-packed buffer or an
// allocator shared between sessions if the pre-packed buffer is to be shared across sessions
// (i.e.) the kernel does not own the buffer.
// @param save_prepacked_initializers: Set it to true if the caller intends to save prepacked initializers to the external data file.
// @param is_packed: Set it to true if the kernel packed the tensor or to false
// The kernel is responsible for keeping the packed data and related metadata if is_packed is true,
// and the original initialized constant tensor will be released and not accessible anymore in
@@ -88,6 +89,7 @@

virtual Status
PrePack(const Tensor& /*tensor*/, int /*input_idx*/, AllocatorPtr /*alloc*/,
bool, /*save_prepacked_initializers*/
/*out*/ bool& is_packed, /*out*/ PrePackedWeights* /*prepacked_weights*/) {
is_packed = false;
return Status::OK();
@@ -129,6 +131,26 @@ class OpKernel {
return Status::OK();
}

// Override this function to get pre-packed tensors from this kernel.
// Only useful for models running on CPU, where ORT can load prepacked weights directly from the
// ONNX data file with mmap and skip on-the-fly prepacking, saving a large amount of heap memory.
// @param input_idx : The index of the input that was prepacked and whose packed tensor should be returned.
// Please refer to the matmul_nbits kernel for a complete example.
virtual std::optional<Tensor> GetPrePackTensor(int /*input_idx*/) {
return std::nullopt;
}

// Override this function to set pre-packed tensors on this kernel and restore the prepacked weight buffer.
// Only useful for models running on CPU, where ORT can load prepacked weights directly from the
// ONNX data file with mmap and skip on-the-fly prepacking, saving a large amount of heap memory.
// Please refer to the matmul_nbits kernel for a complete example.
// @param input_idx : The input index of the tensor in this kernel.
// @param pre_packed_tensor: The prepacked tensor read from the ONNX data file, used to restore the
// prepacked weight buffer.
virtual Status SetPrePackTensor(int /*input_idx*/, const Tensor& /*pre_packed_tensor*/) {
return Status::OK();
}

const OrtDevice GetDevice(OrtMemType mem_type) const;
const OpKernelInfo& Info() const {
return *op_kernel_info_;
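
To make the new hooks concrete, the sketch below shows how a CPU kernel could implement them. It is illustrative only: `MyCpuKernel` and `PackWeightForMyKernel` are hypothetical names, and the pattern simply mirrors the MatMulNBits changes further down in this diff.

```cpp
// Illustrative sketch of a kernel wiring up PrePack/GetPrePackTensor/SetPrePackTensor.
#include "core/framework/op_kernel.h"

namespace onnxruntime {

// Hypothetical packing routine, not a real ORT function.
void PackWeightForMyKernel(const Tensor& src, void* dst);

class MyCpuKernel final : public OpKernel {
 public:
  explicit MyCpuKernel(const OpKernelInfo& info) : OpKernel(info) {}

  Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc,
                 bool save_prepacked_initializers,
                 /*out*/ bool& is_packed, /*out*/ PrePackedWeights* /*prepacked_weights*/) override {
    is_packed = false;
    if (input_idx != 1) return Status::OK();  // only the weight input is packed in this sketch

    packed_size_ = tensor.SizeInBytes();  // a real kernel computes its packed layout size here
    packed_buf_ = BufferUniquePtr(alloc->Alloc(packed_size_), BufferDeleter(alloc));
    PackWeightForMyKernel(tensor, packed_buf_.get());
    is_packed = true;

    if (save_prepacked_initializers) {
      // Expose the packed buffer as a Tensor so the framework can serialize it into the data file.
      packed_tensor_ = Tensor(DataTypeImpl::GetType<uint8_t>(),
                              TensorShape({static_cast<int64_t>(packed_size_)}),
                              packed_buf_.get(),
                              OrtMemoryInfo(CPU, OrtAllocatorType::OrtDeviceAllocator));
    }
    return Status::OK();
  }

  std::optional<Tensor> GetPrePackTensor(int /*input_idx*/) override {
    // The framework serializes this tensor in place of the original constant initializer.
    return std::move(packed_tensor_);
  }

  Status SetPrePackTensor(int input_idx, const Tensor& pre_packed_tensor) override {
    if (input_idx == 1) {
      // The data is owned by session state (mmapped data file), so attach an empty deleter.
      packed_buf_ = BufferUniquePtr(const_cast<void*>(pre_packed_tensor.DataRaw()), BufferDeleter());
    }
    return Status::OK();
  }

 private:
  BufferUniquePtr packed_buf_;
  size_t packed_size_{0};
  std::optional<Tensor> packed_tensor_{std::nullopt};
};

}  // namespace onnxruntime
```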
29 changes: 27 additions & 2 deletions include/onnxruntime/core/graph/graph.h
@@ -1148,6 +1148,11 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi
void FinalizeFuseSubGraph(const IndexedSubGraph& sub_graph, Node& fused_node);
#endif

// Since one constant initializer could be used by different kernels
// and prepacked differently, use an unordered_map to store prepacked
// initializers in the format <initializer_name, <node_name, prepacked_initializer>>.
typedef std::unordered_map<std::string, std::unordered_map<std::string, ONNX_NAMESPACE::TensorProto>> PrePackedTensorProtoToSave;

#if !defined(ORT_MINIMAL_BUILD)
/** Gets the GraphProto representation of this Graph. */
const ONNX_NAMESPACE::GraphProto& ToGraphProto();
@@ -1182,18 +1187,26 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi
@param initializer_size_threshold initializers larger or equal to this threshold (in bytes) are saved
in the external file. Initializer smaller than this threshold are included in the onnx file.
@param align_info offset alignment info.
@param save_prepacked_constant_initializers whether to save prepacked initializers into the external data file.
If false, prepacked initializers are not saved into the onnxruntime data file and
constant initializers are kept as they are.
@param pre_packed_initializers struct used to store all the prepacked initializers.
@returns GraphProto serialization of the graph.
*/
ONNX_NAMESPACE::GraphProto ToGraphProtoWithExternalInitializers(const std::filesystem::path& external_file_path,
const std::filesystem::path& model_file_path,
size_t initializer_size_threshold,
const OffsetAlignmentInfo& align_info) const;
const OffsetAlignmentInfo& align_info,
bool save_prepacked_constant_initializers,
PrePackedTensorProtoToSave& pre_packed_initializers) const;

ONNX_NAMESPACE::GraphProto ToGraphProtoWithExternalInitializers(const std::filesystem::path& external_file_path,
const std::filesystem::path& model_file_path,
size_t initializer_size_threshold) const {
OffsetAlignmentInfo default_options;
return ToGraphProtoWithExternalInitializers(external_file_path, model_file_path, initializer_size_threshold, default_options);
PrePackedTensorProtoToSave pre_packed_initializers;
return ToGraphProtoWithExternalInitializers(external_file_path, model_file_path, initializer_size_threshold, default_options,
false, pre_packed_initializers);
}

/** Gets the ISchemaRegistry instances being used with this Graph. */
@@ -1508,6 +1521,18 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi
private:
void InitializeStateFromModelFileGraphProto();

// Private method used to set up an external initializer properly during model save;
// the external initializer could be an original initializer or a prepacked initializer.
static void SetUpExternalInitializer(const Graph::OffsetAlignmentInfo& align_info,
size_t tensor_bytes_size,
int64_t& external_offset,
std::ofstream& external_stream,
gsl::span<const uint8_t> raw_data,
ONNX_NAMESPACE::TensorProto& output_proto,
const std::filesystem::path& external_file_path,
const ONNX_NAMESPACE::TensorProto& initializer,
bool is_prepacked);

// Add node with specified <node_proto>.
Node& AddNode(const ONNX_NAMESPACE::NodeProto& node_proto,
const ArgNameToTypeMap& name_to_type);
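
For orientation, here is a hypothetical illustration of how this nested map is keyed and handed to the serializer; the initializer and node names are made up, and the real population of the map happens in ORT internals not shown in this excerpt.

```cpp
#include "core/graph/graph.h"

ONNX_NAMESPACE::GraphProto SaveWithPrepackedWeights(const onnxruntime::Graph& graph,
                                                    const std::filesystem::path& external_file,
                                                    const std::filesystem::path& model_file,
                                                    ONNX_NAMESPACE::TensorProto packed_proto) {
  // Outer key: original initializer name; inner key: name of the consuming node
  // (the same initializer may be prepacked differently by different kernels).
  onnxruntime::Graph::PrePackedTensorProtoToSave pre_packed_initializers;
  pre_packed_initializers["qkv_proj.weight"]["/attn/MatMulNBits_0"] = std::move(packed_proto);

  onnxruntime::Graph::OffsetAlignmentInfo align_info;  // default offset alignment
  return graph.ToGraphProtoWithExternalInitializers(external_file, model_file,
                                                    /*initializer_size_threshold=*/1024,
                                                    align_info,
                                                    /*save_prepacked_constant_initializers=*/true,
                                                    pre_packed_initializers);
}
```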
@@ -246,6 +246,12 @@ static const char* const kOrtSessionOptionsDisableCPUEPFallback = "session.disab
static const char* const kOrtSessionOptionsOptimizedModelExternalInitializersFileName =
"session.optimized_model_external_initializers_file_name";

// Use this config to save prepacked constant initializers to the ONNX external data file.
// By default, prepacked initializers are not saved to the ONNX data file.
// Sample usage: sess_options.add_session_config_entry('session.save_prepacked_constant_initializers', "1")
static const char* const kOrtSessionOptionsSavePrePackedConstantInitializers =
"session.save_prepacked_constant_initializers";

// Use this config to control the minimum size of the initializer when externalizing it during serialization
static const char* const kOrtSessionOptionsOptimizedModelExternalInitializersMinSizeInBytes =
"session.optimized_model_external_initializers_min_size_in_bytes";
2 changes: 2 additions & 0 deletions onnxruntime/contrib_ops/cpu/bert/attention.cc
@@ -30,6 +30,7 @@ class Attention : public OpKernel, public AttentionCPUBase {
Status Compute(OpKernelContext* context) const override;

Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc,
bool save_prepacked_initializers,
/*out*/ bool& is_packed,
/*out*/ PrePackedWeights* prepacked_weights) override;

@@ -101,6 +102,7 @@ bool Attention<T>::IsPackWeightsSuccessful(int qkv_index,

template <typename T>
Status Attention<T>::PrePack(const Tensor& weights, int input_idx, AllocatorPtr alloc,
bool /*save_prepacked_initializers*/,
/*out*/ bool& is_packed,
/*out*/ PrePackedWeights* prepacked_weights) {
/* The PrePack() massages the weights to speed up Compute(), there is an option to
2 changes: 2 additions & 0 deletions onnxruntime/contrib_ops/cpu/quantization/attention_quant.cc
@@ -24,6 +24,7 @@ class QAttention : public OpKernel, public AttentionCPUBase {
Status Compute(OpKernelContext* context) const override;

Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc,
bool save_prepacked_initializers,
bool& /*out*/ is_packed,
/*out*/ PrePackedWeights* prepacked_weights) override;

@@ -58,6 +59,7 @@ QAttention<T>::QAttention(const OpKernelInfo& info) : OpKernel(info), AttentionC

template <typename T>
Status QAttention<T>::PrePack(const Tensor& weights, int input_idx, AllocatorPtr alloc,
bool /*save_prepacked_initializers*/,
/*out*/ bool& is_packed,
/*out*/ PrePackedWeights* prepacked_weights) {
if (1 != input_idx) {
@@ -13,7 +13,7 @@ class DynamicQuantizeLSTM : public OpKernel, public LSTMBase {
DynamicQuantizeLSTM(const OpKernelInfo& info) : OpKernel(info), LSTMBase(info) {}

Status PrePack(const Tensor& tensor, int input_idx,
AllocatorPtr alloc, /*out*/ bool& is_packed,
AllocatorPtr alloc, bool save_prepacked_initializers, /*out*/ bool& is_packed,
/*out*/ PrePackedWeights* prepacked_weights) override;

Status UseSharedPrePackedBuffers(std::vector<BufferUniquePtr>& prepacked_buffers,
@@ -91,6 +91,7 @@ static void UseSharedPrePackedBuffersImpl(std::vector<BufferUniquePtr>& prepacke
}

Status DynamicQuantizeLSTM::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc,
bool /*save_prepacked_initializers*/,
/*out*/ bool& is_packed,
/*out*/ PrePackedWeights* prepacked_weights) {
is_packed = false;
56 changes: 56 additions & 0 deletions onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
@@ -98,12 +98,19 @@ class MatMulNBits final : public OpKernel {
Status Compute(OpKernelContext* context) const override;

Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc,
bool save_prepacked_initializers,
/*out*/ bool& is_packed,
/*out*/ PrePackedWeights* prepacked_weights) override;

void ConvertPrepackWeightIntoTensor(const onnxruntime::Tensor& tensor, int input_idx);

Status UseSharedPrePackedBuffers(std::vector<BufferUniquePtr>& prepacked_buffers, int input_idx,
/*out*/ bool& used_shared_buffers) override;

std::optional<Tensor> GetPrePackTensor(int /*input_idx*/) override;

Status SetPrePackTensor(int input_idx, const Tensor& pre_packed_tensor) override;

private:
const size_t K_;
const size_t N_;
@@ -119,6 +126,8 @@ class MatMulNBits final : public OpKernel {
size_t packed_b_size_{0};
IAllocatorUniquePtr<float> scales_fp32_{};
IAllocatorUniquePtr<float> bias_fp32_{};
std::optional<Tensor> packed_tensor_{std::nullopt};
MLDataType prepack_tensor_data_type_;

bool has_zp_input_{false};

@@ -148,8 +157,22 @@ class MatMulNBits final : public OpKernel {
}
};

template <typename T1>
void MatMulNBits<T1>::ConvertPrepackWeightIntoTensor(const onnxruntime::Tensor& tensor, int input_idx) {
if (input_idx == InputIndex::B) {
prepack_tensor_data_type_ = tensor.DataType();
}

TensorShapeVector weights_dims = {static_cast<int64_t>((packed_b_size_ - 1) / prepack_tensor_data_type_->Size()) + 1};
packed_tensor_ = Tensor(prepack_tensor_data_type_,
TensorShape(weights_dims),
packed_b_.get(),
OrtMemoryInfo(CPU, OrtAllocatorType::OrtDeviceAllocator));
}

template <typename T1>
Status MatMulNBits<T1>::PrePack(const Tensor& tensor, int input_idx, /*out*/ AllocatorPtr alloc,
bool save_prepacked_initializers,
/*out*/ bool& is_packed,
/*out*/ PrePackedWeights* prepacked_weights) {
ORT_UNUSED_PARAMETER(prepacked_weights);
@@ -185,11 +208,16 @@ Status MatMulNBits<T1>::PrePack(const Tensor& tensor, int input_idx, /*out*/ All
#endif // MLAS_TARGET_AMD64_IX86
}

if (save_prepacked_initializers) {
ConvertPrepackWeightIntoTensor(tensor, input_idx);
}

return Status::OK();
}

template <>
Status MatMulNBits<MLFloat16>::PrePack(const Tensor& tensor, int input_idx, /*out*/ AllocatorPtr alloc,
bool save_prepacked_initializers,
/*out*/ bool& is_packed,
/*out*/ PrePackedWeights* prepacked_weights) {
ORT_UNUSED_PARAMETER(prepacked_weights);
@@ -239,6 +267,34 @@ Status MatMulNBits<MLFloat16>::PrePack(const Tensor& tensor, int input_idx, /*ou
#endif // MLAS_TARGET_AMD64_IX86
}

if (save_prepacked_initializers) {
ConvertPrepackWeightIntoTensor(tensor, input_idx);
}

return Status::OK();
}

template <typename T1>
std::optional<Tensor> MatMulNBits<T1>::GetPrePackTensor(int input_idx) {
// For this kernel, prepacking is performed on input_B and possibly on scales and zero_points.
// During compute, scales and zero_points are kept as they are; only input_B is replaced by the
// prepacked buffer.
// To cope with this logic, we return the latest prepacked buffer and serialize only that one,
// so packed_tensor_ is always returned here, not only for input_B.
ORT_UNUSED_PARAMETER(input_idx);
return std::move(packed_tensor_);
}

template <typename T1>
Status MatMulNBits<T1>::SetPrePackTensor(int input_idx, const Tensor& pre_packed_tensor) {
if (input_idx == 1) {
// pre_packed_tensor is a constant initialized tensor whose lifecycle is managed by session_state,
// and session_state will release its memory. packed_b_ must not release that memory, so an
// empty/default buffer deleter is passed here.
// The const_cast here is temporary and will be fixed in a follow-up PR.
packed_b_ = BufferUniquePtr(const_cast<void*>(pre_packed_tensor.DataRaw()), BufferDeleter());
}

return Status::OK();
}

1 change: 1 addition & 0 deletions onnxruntime/contrib_ops/cpu/skip_layer_norm.cc
@@ -278,6 +278,7 @@ Status SkipLayerNorm<T, simplified>::Compute(OpKernelContext* p_ctx) const {

template <typename T, bool simplified>
Status SkipLayerNorm<T, simplified>::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc,
bool /*save_prepacked_initializers*/,
bool& is_packed, PrePackedWeights* prepacked_weights) {
ORT_UNUSED_PARAMETER(prepacked_weights);

2 changes: 1 addition & 1 deletion onnxruntime/contrib_ops/cpu/skip_layer_norm.h
@@ -16,7 +16,7 @@ class SkipLayerNorm final : public OpKernel {
SkipLayerNorm(const OpKernelInfo& op_kernel_info);
Status Compute(OpKernelContext* p_op_kernel_context) const override;

Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc,
Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, bool save_prepacked_initializers,
bool& is_packed, PrePackedWeights* prepacked_weights) override;

private:
1 change: 1 addition & 0 deletions onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc
@@ -95,6 +95,7 @@ GroupNorm::GroupNorm(const OpKernelInfo& op_info) : CudaKernel(op_info) {
}

Status GroupNorm::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr /*alloc*/,
bool /*save_prepacked_initializers*/,
bool& is_packed, PrePackedWeights* /*prepacked_weights*/) {
is_packed = false;

1 change: 1 addition & 0 deletions onnxruntime/contrib_ops/cuda/diffusion/group_norm.h
@@ -17,6 +17,7 @@ class GroupNorm final : public CudaKernel {
Status ComputeInternal(OpKernelContext* context) const override;

Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc,
bool save_prepacked_initializers,
bool& is_packed, PrePackedWeights* prepacked_weights) override;

private:
@@ -99,6 +99,7 @@ Status QOrderedAttention::PutIntoMergedBias(const Tensor& tensor, AllocatorPtr a
}

Status QOrderedAttention::PrePack(const Tensor& tensor, int input_idx, /*out*/ AllocatorPtr alloc,
bool /*save_prepacked_initializers*/,
/*out*/ bool& is_packed,
/*out*/ PrePackedWeights* /*prepacked_weights*/) {
is_packed = false;
@@ -20,6 +20,7 @@ class QOrderedAttention final : public CudaKernel, public AttentionBase {

public:
Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc,
bool save_prepacked_initializers,
/*out*/ bool& is_packed,
/*out*/ PrePackedWeights* prepacked_weights) override;

@@ -51,6 +51,7 @@ QOrderedMatMul::QOrderedMatMul(const OpKernelInfo& info) : CudaKernel(info) {
}

Status QOrderedMatMul::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc,
bool /*save_prepacked_initializers*/,
/*out*/ bool& is_packed,
/*out*/ PrePackedWeights* /* prepacked_weights */) {
is_packed = false;
@@ -18,6 +18,7 @@ class QOrderedMatMul final : public CudaKernel {
Status ComputeInternal(OpKernelContext* context) const override;

Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc,
bool save_prepacked_initializers,
/*out*/ bool& is_packed,
/*out*/ PrePackedWeights* prepacked_weights) override;

6 changes: 6 additions & 0 deletions onnxruntime/core/framework/session_options.h
@@ -83,6 +83,11 @@ struct SessionOptions {
// enable profiling for this session.
bool enable_profiling = false;

// Save pre-packed constant external initializers instead of the original initializers to the onnxruntime data file.
// Only useful for models running on CPU, where ORT can load prepacked weights directly from the
// ONNX data file with mmap and skip on-the-fly prepacking, saving a large amount of heap memory.
bool save_prepacked_constant_initializers = false;

// Non empty filepath enables serialization of the transformed optimized model to the specified filepath.
//
// Set session config value for ORT_SESSION_OPTIONS_CONFIG_SAVE_MODEL_FORMAT to 'ORT' or 'ONNX' to explicitly
@@ -191,6 +196,7 @@ inline std::ostream& operator<<(std::ostream& os, const SessionOptions& session_
<< " execution_mode:" << session_options.execution_mode
<< " execution_order:" << session_options.execution_order
<< " enable_profiling:" << session_options.enable_profiling
<< " save_prepacked_constant_initializers:" << session_options.save_prepacked_constant_initializers
<< " optimized_model_filepath:" << ORT_TSTR_CONVERT_TO_PRINTABLE_STRING(session_options.optimized_model_filepath)
<< " enable_mem_pattern:" << session_options.enable_mem_pattern
<< " enable_mem_reuse:" << session_options.enable_mem_reuse