diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index e295dfa203ae5..5f0100fad95a2 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -2808,22 +2808,23 @@ This version of the operator has been available since version 1 of the 'com.micr And block_size is not an arbitrary number and must be a power of 2 and not smaller than 16, like 16, 32, 64, 128,.. 3. Input B's scale and zero point are specified by input scales and zero_points. - Input B is stored as uint8_t with shape: [N][n_blocks_per_col][blob_size] in which: - - n_blocks_per_col = (K + block_size - 1) / block_size - - blob_size = block_size / 8 * bits + Input B is stored as uint8_t with shape: [N][n_blocks_per_col][blob_size] in which: + - n_blocks_per_col = (K + block_size - 1) / block_size + - blob_size = CeilDiv(block_size * bits, 8) + For all bit widths from 2 to 8, a row of data is bit-packed into uint8_t values. + - for 2,4,8 bits, 4x2bit, 2x4bit, 1x8bit are stored in one uint8_t. + 4bit example: + |.|.|.|.| |.|.|.|.| = uint8_t (2x4bit) + - for 3,5,6,7 bits, 32x3bit, 32x5bit, 16x6bit, 32x7bit are stored in 12xuint8_t, 20xuint8_t, 12xuint8_t, 28xuint8_t respectively. No bits are wasted. + 3bit example: + |.|.|. |.|.|. |.|.|. = 9 bits, which span 2 uint8_t; the highest bit of the second uint8_t is used. + The last uint8_t may have some bits unused. - For a block blob. It is stored in format: - struct Blob { - uint8 one_bits[(bits & 0x1) * 1 * block_size / 8]; // highest 1 bit for 3, 5, 7 bits quantization - uint8 two_bits[(bits & 0x2) * 2 * block_size / 8]; // high 2 bits for 2, 6, 7 bits quantization - uint8 four_bits[(bits & 0x4) * 4 * block_size / 8]; // low 4 bits for 4, 5, 6 bits quantization - } Input scales is stored in same type as original type of B(float32, float16) with shape like: [N * n_blocks_per_col] - Input zero_points is stored as uint8_t. If bits <= 4, two zero points are stored as one unit8_t. If bits > 4, one zero point is stored with one unit8_t. Thus, its shape is: - - [(N * n_blocks_per_col + 1) / 2] if bits <=4 - - [N * n_blocks_per_col] if bits > 4 - + Input zero_points is stored as uint8_t or in the same type as input A. If stored as uint8_t, it uses the same packing method as input B and its shape is: + - [CeilDiv((N * n_blocks_per_col + 1) * bits, 8)] + If zero_points has the same type as A, it is not packed and has the same shape as scales. #### Version @@ -2844,17 +2845,19 @@ This version of the operator has been available since version 1 of the 'com.micr
number of groupsize used for weight quantization,(default 128). It needs to be a power of 2 and not smaller than 16.
-#### Inputs (3 - 4) +#### Inputs (3 - 5)
A : T1
The input tensor, not quantized
B : T2
-
1-dimensional data blob
+
1 or 2 dimensional data blob
scales : T1
quantization scale
-
zero_points (optional) : T2
+
zero_points (optional) : T3
quantization zero points
+
g_idx (optional) : T4
+
group_idx
#### Outputs @@ -2869,8 +2872,12 @@ This version of the operator has been available since version 1 of the 'com.micr
T1 : tensor(float), tensor(float16)
Constrain input and output types to float/half_float tensors.
-
T2 : tensor(uint8)
-
Constrain quantized weight types to uint8.
+
T2 : tensor(uint8), tensor(int32)
+
Constrain quantized weight types to uint8/int32.
+
T3 : tensor(uint8), tensor(int32), tensor(float16), tensor(float)
+
Constrain quantized zero point types to uint8/int32/float16/float.
+
T4 : tensor(int32)
+
the index tensor.
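The packing layout documented above can be made concrete with a short sketch. This is a minimal, hypothetical Python/NumPy illustration of the documented shapes for bits = 4 (the only width the kernels in this change accept); it is not the ORT/MLAS packing routine, and the helper names are invented for the example. The low nibble holds the earlier element and the high nibble the later one, matching the dequantize kernels' `(quant_value >> (4 * i)) & 0xF` and `(kb_idx & 0x01) ? (zp >> 4) : (zp & 0x0f)` indexing in this diff.

```python
import numpy as np

def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

def pack_4bit(values: np.ndarray) -> np.ndarray:
    """Pack uint8 values in [0, 15] two per byte: low nibble first, high nibble second."""
    assert values.dtype == np.uint8 and values.ndim == 1
    if values.size % 2:
        values = np.append(values, np.uint8(0))  # pad the tail of a block
    return (values[0::2] & 0x0F) | ((values[1::2] & 0x0F) << 4)

# Illustrative shapes for one MatMul weight of size [K, N], quantized column-wise.
K, N, block_size, bits = 1024, 4096, 32, 4
n_blocks_per_col = ceil_div(K, block_size)
blob_size = ceil_div(block_size * bits, 8)                 # bytes per block blob
b_shape = (N, n_blocks_per_col, blob_size)                 # packed input B
scales_shape = (N * n_blocks_per_col,)                     # one scale per block
zp_packed_shape = (ceil_div((N * n_blocks_per_col + 1) * bits, 8),)  # uint8 zero points
zp_unpacked_shape = scales_shape                           # zero points stored in type of A

block = np.arange(block_size, dtype=np.uint8) & 0x0F
assert pack_4bit(block).size == blob_size
```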
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 0e60b4622f2fb..71b0def659741 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -470,7 +470,7 @@ Do not modify directly.* |MatMulFpQ4|*in* A:**T1**
*in* B:**T2**
*in* B_shape:**T3**
*out* Y:**T1**|1+|**T1** = tensor(float)
**T2** = tensor(uint8)
**T3** = tensor(int64)| |MatMulInteger16|*in* A:**T1**
*in* B:**T2**
*out* Y:**T3**|1+|**T1** = tensor(int16)
**T2** = tensor(int16)
**T3** = tensor(int32)| |MatMulIntegerToFloat|*in* A:**T1**
*in* B:**T2**
*in* a_scale:**T3**
*in* b_scale:**T3**
*in* a_zero_point:**T1**
*in* b_zero_point:**T2**
*in* bias:**T3**
*out* Y:**T3**|1+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(float)| -|MatMulNBits|*in* A:**T1**
*in* B:**T2**
*in* scales:**T1**
*in* zero_points:**T2**
*out* Y:**T1**|1+|**T1** = tensor(float)
**T2** = tensor(uint8)| +|MatMulNBits|*in* A:**T1**
*in* B:**T2**
*in* scales:**T1**
*in* zero_points:**T3**
*in* g_idx:**T4**
*out* Y:**T1**|1+|**T1** = tensor(float)
**T2** = tensor(uint8)
**T3** = tensor(float), tensor(uint8)
**T4** = tensor(int32)| |MaxpoolWithMask|*in* X:**T**
*in* M:**tensor(int32)**
*out* Y:**T**|1+|**T** = tensor(float)| |MultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* key_padding_mask:**M**
*in* relative_position_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**T** = tensor(float)| |MurmurHash3|*in* X:**T1**
*out* Y:**T2**|1+|**T1** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(string), tensor(uint32), tensor(uint64)
**T2** = tensor(int32), tensor(uint32)| @@ -855,7 +855,7 @@ Do not modify directly.* |Irfft|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |LongformerAttention|*in* input:**T**
*in* weight:**T**
*in* bias:**T**
*in* mask:**T**
*in* global_weight:**T**
*in* global_bias:**T**
*in* global:**G**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| |MatMulBnb4|*in* A:**T1**
*in* B:**T2**
*in* absmax:**T1**
*out* Y:**T1**|1+|**T1** = tensor(bfloat16), tensor(float), tensor(float16)
**T2** = tensor(uint8)| -|MatMulNBits|*in* A:**T1**
*in* B:**T2**
*in* scales:**T1**
*in* zero_points:**T2**
*out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(uint8)| +|MatMulNBits|*in* A:**T1**
*in* B:**T2**
*in* scales:**T1**
*in* zero_points:**T3**
*in* g_idx:**T4**
*out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(uint8)| |MoE|*in* input:**T**
*in* router_probs:**T**
*in* fc1_experts_weights:**T**
*in* fc2_experts_weights:**T**
*in* fc1_experts_bias:**T**
*in* fc2_experts_bias:**T**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| |MultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* key_padding_mask:**M**
*in* relative_position_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**T** = tensor(float), tensor(float16)| |NGramRepeatBlock|*in* input_ids:**Tid**
*in* scores:**T**
*out* scores_out:**T**|1+|**T** = tensor(float)
**Tid** = tensor(int64)| diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc index 166f5c8f52f54..602dd98d8c0d6 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc @@ -1,6 +1,12 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include "contrib_ops/cpu/quantization/matmul_nbits_impl.h" + +#include +#include + +#include "core/common/common.h" #include "core/common/narrow.h" #include "core/common/safeint.h" #include "core/framework/op_kernel.h" @@ -50,6 +56,17 @@ int64_t GetAccuracyLevel(size_t nbits, size_t block_size, int64_t accuracy_level } } // namespace +bool GetType(const NodeArg& node_arg, int32_t& type) { + type = ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED; + const auto* type_proto = node_arg.TypeAsProto(); + if (!type_proto || !type_proto->has_tensor_type() || !type_proto->tensor_type().has_elem_type()) { + return false; + } + + type = type_proto->tensor_type().elem_type(); + return true; +} + class MatMulNBits final : public OpKernel { public: MatMulNBits(const OpKernelInfo& info) @@ -59,6 +76,17 @@ class MatMulNBits final : public OpKernel { block_size_{narrow(info.GetAttr("block_size"))}, nbits_{narrow(info.GetAttr("bits"))}, accuracy_level_{GetAccuracyLevel(nbits_, block_size_, info.GetAttr("accuracy_level"))} { + const auto& node = info.node(); + auto input_defs = node.InputDefs(); + // g_idx + if (input_defs.size() > 4) { + act_order_ = true; + } + int32_t type; + if (input_defs.size() > 3 && GetType(*input_defs[3], type)) { + zero_point_is_not_quant_ = type != ONNX_NAMESPACE::TensorProto_DataType_UINT8; + } + ORT_ENFORCE(nbits_ == 4, "Only 4b quantization is supported for MatMulNBits op, additional bits support is planned."); #ifdef ORT_NEURAL_SPEED @@ -88,6 +116,8 @@ class MatMulNBits final : public OpKernel { const size_t N_; const size_t block_size_; const size_t nbits_; + bool act_order_{false}; + bool zero_point_is_not_quant_{false}; const int64_t accuracy_level_; const bool column_wise_quant_{true}; IAllocatorUniquePtr packed_b_; @@ -105,7 +135,9 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) { is_packed = false; - + if (act_order_ || zero_point_is_not_quant_) { + return Status::OK(); + } #if defined(ORT_NEURAL_SPEED) if (!all_constant_) { @@ -212,7 +244,6 @@ Status MatMulNBits::UseSharedPrePackedBuffers(std::vector& prep Status MatMulNBits::Compute(OpKernelContext* ctx) const { concurrency::ThreadPool* thread_pool = ctx->GetOperatorThreadPool(); - const Tensor* a = ctx->Input(0); const auto* a_data = a->Data(); @@ -257,11 +288,14 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { #endif // defined(ORT_NEURAL_SPEED) const Tensor* scales = ctx->Input(2); - const Tensor* zero_points = ctx->Input(3); + const Tensor* zero_points = ctx->InputCount() > 3 ? ctx->Input(3) : nullptr; + const Tensor* reorder_idx = ctx->InputCount() > 4 ? ctx->Input(4) : nullptr; + const auto* scales_data = scales->Data(); - const auto* zero_points_data = zero_points == nullptr ? nullptr : zero_points->Data(); + const auto* zero_points_data = zero_points == nullptr ? nullptr : zero_points->DataRaw(); TensorShape b_shape({static_cast(N_), static_cast(K_)}); + const auto* reorder_idx_data = reorder_idx == nullptr ? 
nullptr : reorder_idx->Data(); MatMulComputeHelper helper; ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b_shape, false, true)); @@ -281,8 +315,9 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { const size_t K = static_cast(helper.K()); const size_t lda = helper.Lda(false); - const bool has_single_b_matrix = std::all_of(helper.RightOffsets().begin(), helper.RightOffsets().end(), - [](size_t offset) { return offset == 0; }); + const bool has_single_b_matrix = + (!act_order_) && (!zero_point_is_not_quant_) && + std::all_of(helper.RightOffsets().begin(), helper.RightOffsets().end(), [](size_t offset) { return offset == 0; }); if (has_single_b_matrix) { const auto compute_type = static_cast(accuracy_level_); @@ -328,22 +363,50 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { const uint8_t* b_data = b->Data(); const size_t ldb = helper.Ldb(true); - AllocatorPtr allocator; ORT_RETURN_IF_ERROR(ctx->GetTempSpaceAllocator(&allocator)); auto tmp_b_data_ptr = IAllocator::MakeUniquePtr(allocator, SafeInt(K_) * N_); - // dequantize b, only 4b quantization is supported for now - MlasDequantizeBlockwise( - tmp_b_data_ptr.get(), // dequantized output - b_data, // quantized input - scales_data, // quantization scales - zero_points_data, // quantization zero points - static_cast(block_size_), // quantization block size - column_wise_quant_, // columnwise quantization or row-wise - static_cast(K_), // number of rows in quantized input - static_cast(N_), // number of columns in quantized input - thread_pool); - + if ((reorder_idx_data == nullptr) && (!zero_points || !zero_points->IsDataType())) { + // dequantize b, only 4b quantization is supported for now + MlasDequantizeBlockwise( + tmp_b_data_ptr.get(), // dequantized output + b_data, // quantized input + scales_data, // quantization scales + static_cast(zero_points_data), // quantization zero points + static_cast(block_size_), // quantization block size + column_wise_quant_, // columnwise quantization or row-wise + static_cast(K_), // number of rows in quantized input + static_cast(N_), // number of columns in quantized input + thread_pool); + } else { + ORT_ENFORCE(column_wise_quant_, "Row-wise quantization is not supported for now"); + // !!!!!!!!!!!!!! naive implementation, need to be optimized !!!!!!!!!!!!!! 
+ if ((zero_points && zero_points->IsDataType())) { + DequantizeBlockwise( + tmp_b_data_ptr.get(), // dequantized output + b_data, // quantized input + scales_data, // quantization scales + static_cast(zero_points_data), // quantization zero points + reorder_idx_data, + static_cast(block_size_), // quantization block size + column_wise_quant_, // columnwise quantization or row-wise + static_cast(K_), // number of rows in quantized input + static_cast(N_), // number of columns in quantized input + thread_pool); + } else { + DequantizeBlockwise( + tmp_b_data_ptr.get(), // dequantized output + b_data, // quantized input + scales_data, // quantization scales + static_cast(zero_points_data), // quantization zero points + reorder_idx_data, + static_cast(block_size_), // quantization block size + column_wise_quant_, // columnwise quantization or row-wise + static_cast(K_), // number of rows in quantized input + static_cast(N_), // number of columns in quantized input + thread_pool); + } + } #if 0 // for debug auto tm_b_data_ptr_trans = IAllocator::MakeUniquePtr(allocator, SafeInt(K_) * N_); MlasTranspose(tmp_b_data_ptr.get(), tm_b_data_ptr_trans.get(), N_, K_); @@ -374,7 +437,9 @@ ONNX_OPERATOR_KERNEL_EX( kCpuExecutionProvider, KernelDefBuilder() .TypeConstraint("T1", DataTypeImpl::GetTensorType()) - .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + .TypeConstraint("T2", DataTypeImpl::GetTensorType()) + .TypeConstraint("T3", {DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}) + .TypeConstraint("T4", DataTypeImpl::GetTensorType()), MatMulNBits); } // namespace contrib diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc new file mode 100644 index 0000000000000..f92e59e990ba5 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc @@ -0,0 +1,108 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#include "contrib_ops/cpu/quantization/matmul_nbits_impl.h" + +#include +#include +#include +#include +#include + +#include "core/common/common.h" +#include "core/framework/float16.h" +#include "core/providers/common.h" +#include "core/platform/threadpool.h" + +namespace onnxruntime { +namespace contrib { + +template +void Dequantize4BitsKernelReOrder( + T* output, const uint8_t* quant_data, const T* scale_data, + const zeroT* zero_points, const int32_t* reorder_idx, int block_size, + int groups_per_threadblock, int total_groups, int out_rows, int out_cols, + int blockIdx_x, int threadIdx_x) { + const int group_id = blockIdx_x * groups_per_threadblock + ((threadIdx_x * 8) / block_size); + if (group_id >= total_groups) { + return; + } + const int scales_shape_x = (out_cols + block_size - 1) / block_size; + const int zero_point_shape_x = (scales_shape_x + 1) / 2; + + int n_idx = group_id / scales_shape_x; + int kb_idx = group_id % scales_shape_x; + int element_offset = group_id * block_size + ((threadIdx_x * 8) & (block_size - 1)); + + const int out_x = element_offset % (scales_shape_x * block_size); + const int out_y = element_offset / (scales_shape_x * block_size); + if (out_y >= out_rows || out_x >= out_cols) { + return; + } + T* output_i = output + out_y * out_cols + out_x; + uint32_t quant_value = *(reinterpret_cast(quant_data + element_offset / 2)); + const int remain_x = std::min(8, out_cols - out_x); + for (int i = 0; i < remain_x; i++) { + int32_t rid = reorder_idx ? 
reorder_idx[kb_idx * block_size + i] : kb_idx; + T scale = *(scale_data + n_idx * scales_shape_x + rid); + float zp_f = 8; + if (zero_points) { + if constexpr (std::is_same_v) { + zp_f = *(zero_points + n_idx * scales_shape_x + rid); + } else { + uint8_t zp = 8; + zp = zero_points[n_idx * zero_point_shape_x + rid / 2]; + zp = (rid & 0x01) ? (zp >> 4) : (zp & 0x0f); + } + } + + if constexpr (std::is_same_v) { + T zp_adjust = -scale * MLFloat16(zp_f); + output_i[i] = static_cast((quant_value >> (4 * i)) & 0xF) * scale + zp_adjust; + } else { + T zp_adjust = -scale * zp_f; + output_i[i] = T((quant_value >> (4 * i)) & 0xF) * scale + zp_adjust; + } + } +} + +template +void DequantizeBlockwise( + inputT* output, // dequantized output + const uint8_t* quant_data, // quantized input + const inputT* scales_data, // quantization scales + const zeroT* zero_points, // quantization zero points + const int32_t* reorder_idx, // reorder_idx for groupwise quantization + int32_t block_size, // quantization block size + bool, // columnwise quantization or row-wise + int32_t K, // number of rows in quantized input + int32_t N, // number of columns in quantized input + onnxruntime::concurrency::ThreadPool* pool) { + auto ceildiv = [](int a, int b) { return (a + b - 1) / b; }; + constexpr int element_per_thread = 8; + int groups_per_threadblock = 256 * element_per_thread / block_size; + int groups_per_K = ceildiv(K, block_size); + int total_groups = N * groups_per_K; // total elemenets in quant_data + int blocks_per_grid = static_cast(ceildiv(total_groups, groups_per_threadblock)); + concurrency::ThreadPool::TrySimpleParallelFor( + pool, static_cast(blocks_per_grid), + [&](std::ptrdiff_t block_id) { + for (int j = 0; j < 256; j++) { + Dequantize4BitsKernelReOrder(output, quant_data, scales_data, zero_points, + reorder_idx, block_size, groups_per_threadblock, + total_groups, N, K, static_cast(block_id), j); + } + }); +} + +template void DequantizeBlockwise( + float* output, const uint8_t* quant_data, const float* scales_data, + const uint8_t* zero_points, const int32_t* reorder_idx, int32_t block_size, + bool columnwise, int32_t K, int32_t N, onnxruntime::concurrency::ThreadPool* thread_pool); + +template void DequantizeBlockwise( + float* output, const uint8_t* quant_data, const float* scales_data, + const float* zero_points, const int32_t* reorder_idx, int32_t block_size, + bool columnwise, int32_t K, int32_t N, onnxruntime::concurrency::ThreadPool* thread_pool); + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.h b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.h new file mode 100644 index 0000000000000..5061ac5c800a6 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.h @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+#include "core/providers/common.h" +#include "core/platform/threadpool.h" + +namespace onnxruntime { +namespace contrib { + +template +void DequantizeBlockwise( + inputT* output, // dequantized output + const uint8_t* quant_data, // quantized input + const inputT* scales_data, // quantization scales + const zeroT* zero_points, // quantization zero points + const int32_t* reorder_idx, // quantization zero points + int32_t block_size, // quantization block size + bool, // columnwise quantization or row-wise + int32_t K, // number of rows in quantized input + int32_t N, // number of columns in quantized input + onnxruntime::concurrency::ThreadPool* thread_pool); + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu index 6b66f1d84e221..cd6593352008b 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu +++ b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu @@ -2,10 +2,12 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include #include #include #include #include +#include #include #include "core/providers/cuda/cu_inc/common.cuh" #include "core/providers/cuda/cuda_common.h" @@ -56,41 +58,94 @@ __device__ __forceinline__ void DequantizeEightElements(uint32_t values_quant, f } template -__global__ void Dequantize4BitsKernel( +__global__ void Dequantize4BitsKernelReOrder( T* output, const uint8_t* quant_data, const T* scale_data, const uint8_t* zero_points, + const int32_t* reorder_idx, int block_size, - int blocks_per_K, - int blocks_per_threadblock, - int total_blks, - int shift) { - int block_id = blockIdx.x * blocks_per_threadblock + ((threadIdx.x * 8) >> shift); - if (block_id >= total_blks) { + int groups_per_K, + int groups_per_threadblock, + int total_groups) { + int group_id = blockIdx.x * groups_per_threadblock + ((threadIdx.x * 8) / block_size); + if (group_id >= total_groups) { return; } - int n_idx = block_id / blocks_per_K; - int kb_idx = block_id % blocks_per_K; - int element_offset = block_id * block_size + ((threadIdx.x * 8) & ((1 << shift) - 1)); + // T __shared__ zero_points_after_reorder[];//K + // T __shared__ scales_after_reorder[]; // K + // const int num_r_per_thread = k / 256; + + const int zero_point_shape_x = (groups_per_K + 1) / 2; + const int scales_shape_x = groups_per_K; + int n_idx = group_id / scales_shape_x; + int kb_idx = group_id % scales_shape_x; + int element_offset = group_id * block_size + ((threadIdx.x * 8) & (block_size - 1)); + T* output_i = output + element_offset; + uint32_t quant_value = *(reinterpret_cast(quant_data + element_offset / 2)); + for (int i = 0; i < 8; i++) { + int32_t rid = reorder_idx[kb_idx * block_size + i]; + T scale = *(scale_data + n_idx * scales_shape_x + rid); + uint8_t zp = 8; + if (zero_points) { + zp = zero_points[n_idx * zero_point_shape_x + rid / 2]; + zp = (rid & 0x01) ? 
(zp >> 4) : (zp & 0x0f); + } + + if constexpr (std::is_same_v) { + T zp_adjust = -scale * __short2half_rn(zp); + output_i[i] = __uint2half_rn((quant_value >> (4 * i)) & 0xF) * scale + zp_adjust; + } else { + T zp_adjust = -scale * T(zp); + output_i[i] = T((quant_value >> (4 * i)) & 0xF) * scale + zp_adjust; + } + } +} + +template +__global__ void Dequantize4BitsKernel( + T* output, + const uint8_t* quant_data, + const T* scale_data, + const ZeroT* zero_points, + int block_size, + int groups_per_K, + int groups_per_threadblock, + int total_groups) { + int block_id = blockIdx.x * groups_per_threadblock + ((threadIdx.x * 8) / block_size); + if (block_id >= total_groups) { + return; + } + int element_offset = block_id * block_size + ((threadIdx.x * 8) & (block_size - 1)); uint32_t quant_value = *(reinterpret_cast(quant_data + element_offset / 2)); T scale = *(scale_data + block_id); - uint8_t zp = 8; - if (zero_points) { - zp = zero_points[n_idx * ((blocks_per_K + 1)/2) + kb_idx / 2]; - zp = (kb_idx & 0x01) ? (zp >> 4) : (zp & 0x0f); + T zero_point_value; + if constexpr (std::is_same_v) { + const int scales_shape_x = groups_per_K; + const int zero_point_shape_x = (groups_per_K + 1) / 2; + int kb_idx = block_id % scales_shape_x; + int n_idx = block_id / scales_shape_x; + uint8_t zp = 8; + if (zero_points) { + zp = zero_points[n_idx * zero_point_shape_x + kb_idx / 2]; + zp = (kb_idx & 0x01) ? (zp >> 4) : (zp & 0x0f); + } + zero_point_value = static_cast(zp); + } else { + zero_point_value = zero_points? *(zero_points + block_id):static_cast(8); } output = output + element_offset; - DequantizeEightElements(quant_value, scale, static_cast(zp), output); + DequantizeEightElements(quant_value, scale, zero_point_value, output); } -template +template Status Dequantize4Bits( T* output, const uint8_t* quant_data, const T* scales_data, - const uint8_t* zero_points, // shape: [N, (block_per_K + 1)/2] + const ZeroT* zero_points, // shape: [N, (block_per_K + 1)/2] + const int32_t* reorder_idx, int k, int n, int block_size, @@ -98,47 +153,79 @@ Status Dequantize4Bits( // k is padded and equal to block_per_K * block_size ORT_ENFORCE(k % block_size == 0, "k must be a multiplier of block_size"); constexpr int element_per_thread = 8; - int blocks_per_threadblock = GridDim::maxThreadsPerBlock * element_per_thread / block_size; - int blocks_per_K = k / block_size; - int total_blks = n * blocks_per_K; - int blocks_per_grid = static_cast(CeilDiv(n * blocks_per_K, blocks_per_threadblock)); - int shift = static_cast(log2f(float(block_size))); - - Dequantize4BitsKernel<<>>( - output, - quant_data, - scales_data, - zero_points, - block_size, - blocks_per_K, - blocks_per_threadblock, - total_blks, - shift); + int groups_per_threadblock = GridDim::maxThreadsPerBlock * element_per_thread / block_size; + int groups_per_K = k / block_size; + int total_groups = n * groups_per_K; // total elemenets in quant_data + int groups_per_grid = static_cast(CeilDiv(total_groups, groups_per_threadblock)); + if (!reorder_idx) { + Dequantize4BitsKernel<<>>( + output, + quant_data, + scales_data, + zero_points, + block_size, + groups_per_K, + groups_per_threadblock, + total_groups); + } else { + // static_assert(std::is_same_v, "ZeroT must be uint8_t"); + Dequantize4BitsKernelReOrder<<>>( + output, + quant_data, + scales_data, + (const uint8_t*)zero_points, + reorder_idx, + block_size, + groups_per_K, + groups_per_threadblock, + total_groups); + } return Status::OK(); } -template Status Dequantize4Bits( +template Status Dequantize4Bits( 
float* output, const uint8_t* quant_data, const float* scales_data, const uint8_t* zero_points, + const int32_t* reorder_idx, int k, int n, int block_size, cudaStream_t stream); -template Status Dequantize4Bits( +template Status Dequantize4Bits( half* output, const uint8_t* quant_data, const half* scales_data, const uint8_t* zero_points, + const int32_t* reorder_idx, + int k, + int n, + int block_size, + cudaStream_t stream); +template Status Dequantize4Bits( + float* output, + const uint8_t* quant_data, + const float* scales_data, + const float* zero_points, + const int32_t* reorder_idx, int k, int n, int block_size, cudaStream_t stream); - +template Status Dequantize4Bits( + half* output, + const uint8_t* quant_data, + const half* scales_data, + const half* zero_points, + const int32_t* reorder_idx, + int k, + int n, + int block_size, + cudaStream_t stream); /////////////////////////////////////////////////////////////////////////////// // A more general block-wise dequantization implementation that supports // different block sizes and block orientations (row-wise/column-wise). diff --git a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cuh b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cuh index f9c09c55fd893..580b5087f3fa3 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cuh +++ b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cuh @@ -7,18 +7,18 @@ namespace onnxruntime { namespace contrib { namespace cuda { -template +template Status Dequantize4Bits( T* output, const uint8_t* quant_data, const T* scales_data, - const uint8_t* zero_points, + const ZeroT* zero_points, + const int32_t* reorder_idx, int k, int n, int block_size, cudaStream_t stream); - /** * @brief Dequantize a block-wise quantized matrix, and store the result in a * column major matrix for use in subsequent GEMM. This implementation supports diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc index 015df70c8ec3c..1cec6f6a12f1c 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc @@ -1,15 +1,12 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-// -// This module define MatMulFp32Q4 operator, it is basically -// matmul float32 with right hand side being a 2-D matrix -// pre-packed and block-compacted into int4 -// - -#include "core/common/safeint.h" -#include "core/providers/cuda/cuda_kernel.h" -#include "core/providers/cuda/shared_inc/fpgeneric.h" +#include "contrib_ops/cuda/quantization/matmul_nbits.h" + +#include + +#include "core/common/status.h" +#include "core/framework/float16.h" #include "core/providers/cpu/math/matmul_helper.h" #include "matmul_nbits.cuh" #include "dequantize_blockwise.cuh" @@ -19,40 +16,19 @@ namespace contrib { namespace cuda { using namespace onnxruntime::cuda; -template -class MatMulNBits final : public CudaKernel { - public: - MatMulNBits(const OpKernelInfo& info) : CudaKernel(info) { - ORT_ENFORCE(Status::OK() == info.GetAttr("K", &K_)); - ORT_ENFORCE(Status::OK() == info.GetAttr("N", &N_)); - ORT_ENFORCE(Status::OK() == info.GetAttr("block_size", &block_size_)); - ORT_ENFORCE(Status::OK() == info.GetAttr("bits", &nbits_)); - ORT_ENFORCE(nbits_ == 4, - "Only 4b quantization is supported for MatMulNBits op," - " additional bits support is planned."); - } - - Status ComputeInternal(OpKernelContext* context) const override; - - private: - int64_t K_; - int64_t N_; - int64_t block_size_; - int64_t nbits_; - bool column_wise_quant_blk_{true}; -}; - template Status MatMulNBits::ComputeInternal(OpKernelContext* ctx) const { const Tensor* a = ctx->Input(0); const Tensor* b = ctx->Input(1); const Tensor* scales = ctx->Input(2); const Tensor* zero_points = ctx->Input(3); + const Tensor* reorder_idx = ctx->Input(4); const auto* a_data = a->Data(); const uint8_t* blob_data = b->Data(); const auto* scales_data = scales->Data(); - const auto* zero_points_data = zero_points == nullptr ? nullptr : zero_points->Data(); + const auto* zero_points_data = zero_points == nullptr ? nullptr : zero_points->DataRaw(); + const auto* reorder_idx_data = reorder_idx == nullptr ? 
nullptr : reorder_idx->Data(); typedef typename ToCudaType::MappedType CudaT; @@ -67,77 +43,99 @@ Status MatMulNBits::ComputeInternal(OpKernelContext* ctx) const { // Bail out early if the output is going to be empty if (Y->Shape().Size() == 0) return Status::OK(); - bool is_4bit_done = TryMatMul4Bits( - reinterpret_cast(Y->MutableData()), - reinterpret_cast(a_data), - blob_data, - reinterpret_cast(scales_data), - zero_points_data, - SafeInt(helper.M()), - SafeInt(helper.N()), - SafeInt(helper.K()), - SafeInt(block_size_), - SafeInt(GetDeviceProp().sharedMemPerBlock), - static_cast(ctx->GetComputeStream()->GetHandle())); - if (!is_4bit_done) { - int64_t K_padded = (K_ + block_size_ - 1) / block_size_ * block_size_; - IAllocatorUniquePtr b_data_ptr = GetScratchBuffer(N_ * K_padded, ctx->GetComputeStream()); - auto* b_data = b_data_ptr.get(); - if (column_wise_quant_blk_) { - // column-wise block + bool is_4bit_done = (reorder_idx_data == nullptr) && + (!zero_points || !zero_points->IsDataType()) && + TryMatMul4Bits( + reinterpret_cast(Y->MutableData()), + reinterpret_cast(a_data), + blob_data, + reinterpret_cast(scales_data), + static_cast(zero_points_data), + SafeInt(helper.M()), + SafeInt(helper.N()), + SafeInt(helper.K()), + SafeInt(block_size_), + SafeInt(GetDeviceProp().sharedMemPerBlock), + static_cast(ctx->GetComputeStream()->GetHandle())); + + if (is_4bit_done) { + return Status::OK(); + } + + int64_t K_padded = (K_ + block_size_ - 1) / block_size_ * block_size_; + IAllocatorUniquePtr b_data_ptr = GetScratchBuffer(N_ * K_padded, ctx->GetComputeStream()); + auto* b_data = b_data_ptr.get(); + if (column_wise_quant_blk_) { + if (reorder_idx) { + ORT_ENFORCE(K_padded == reorder_idx->Shape()[0], "K_padded != g_idx->Shape()[0]"); + } + // column-wise block + if ((zero_points && zero_points->IsDataType())) { ORT_RETURN_IF_ERROR(Dequantize4Bits( reinterpret_cast(b_data), blob_data, reinterpret_cast(scales_data), - zero_points_data, + (const CudaT*)zero_points_data, + reorder_idx_data, SafeInt(K_padded), SafeInt(N_), SafeInt(block_size_), static_cast(ctx->GetComputeStream()->GetHandle()))); } else { - // row-wise block - K_padded = K_; - - ORT_RETURN_IF_ERROR(DequantizeBlockwise4b( + ORT_RETURN_IF_ERROR(Dequantize4Bits( reinterpret_cast(b_data), blob_data, reinterpret_cast(scales_data), - zero_points_data, - SafeInt(block_size_), - column_wise_quant_blk_, - SafeInt(K_), + (const uint8_t*)zero_points_data, + reorder_idx_data, + SafeInt(K_padded), SafeInt(N_), + SafeInt(block_size_), static_cast(ctx->GetComputeStream()->GetHandle()))); } + } else { + // row-wise block + K_padded = K_; + + ORT_RETURN_IF_ERROR(DequantizeBlockwise4b( + reinterpret_cast(b_data), + blob_data, + reinterpret_cast(scales_data), + (const uint8_t*)zero_points_data, + SafeInt(block_size_), + column_wise_quant_blk_, + SafeInt(K_), + SafeInt(N_), + static_cast(ctx->GetComputeStream()->GetHandle()))); + } #if 0 - cudaStreamSynchronize(static_cast(ctx->GetComputeStream()->GetHandle())); - T* b_data_cpu = new T[K_ * N_]; - cudaMemcpy(b_data_cpu, b_data, K_ * N_ * sizeof(T), cudaMemcpyDeviceToHost); - delete[] b_data_cpu; +cudaStreamSynchronize(static_cast(ctx->GetComputeStream()->GetHandle())); +T* b_data_cpu = new T[K_ * N_]; +cudaMemcpy(b_data_cpu, b_data, K_ * N_ * sizeof(T), cudaMemcpyDeviceToHost); +delete[] b_data_cpu; #endif - const CudaT alpha = ToCudaType::FromFloat(1.f); - const CudaT zero = ToCudaType::FromFloat(0.f); - - if (helper.OutputOffsets().size() == 1) { - CUBLAS_RETURN_IF_ERROR(cublasGemmHelper( - 
GetCublasHandle(ctx), - CUBLAS_OP_T, - CUBLAS_OP_N, - SafeInt(helper.N()), - SafeInt(helper.M()), - SafeInt(helper.K()), - &alpha, - reinterpret_cast(b_data), - SafeInt(K_padded), - reinterpret_cast(a_data), - helper.Lda(transa), - &zero, - reinterpret_cast(Y->MutableData()), - helper.Ldc(), - GetDeviceProp(), - UseTF32())); - } + const CudaT alpha = ToCudaType::FromFloat(1.f); + const CudaT zero = ToCudaType::FromFloat(0.f); + + if (helper.OutputOffsets().size() == 1) { + CUBLAS_RETURN_IF_ERROR(cublasGemmHelper( + GetCublasHandle(ctx), + CUBLAS_OP_T, + CUBLAS_OP_N, + SafeInt(helper.N()), + SafeInt(helper.M()), + SafeInt(helper.K()), + &alpha, + reinterpret_cast(b_data), + SafeInt(K_padded), + reinterpret_cast(a_data), + helper.Lda(transa), + &zero, + reinterpret_cast(Y->MutableData()), + helper.Ldc(), + GetDeviceProp(), + UseTF32())); } return Status::OK(); diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.h b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.h new file mode 100644 index 0000000000000..f5c2c6c4e4fdf --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.h @@ -0,0 +1,41 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// +// This module define MatMulNBits operator, it is basically +// matmul float with right hand side being a 2-D matrix +// pre-packed and block-compacted into int4 +// +#pragma once +#include "core/common/safeint.h" +#include "core/providers/cuda/cuda_kernel.h" +#include "core/providers/cuda/shared_inc/fpgeneric.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { +using namespace onnxruntime::cuda; + +template +class MatMulNBits final : public CudaKernel { + public: + MatMulNBits(const OpKernelInfo& info) : CudaKernel(info) { + ORT_ENFORCE(Status::OK() == info.GetAttr("K", &K_)); + ORT_ENFORCE(Status::OK() == info.GetAttr("N", &N_)); + ORT_ENFORCE(Status::OK() == info.GetAttr("block_size", &block_size_)); + ORT_ENFORCE(Status::OK() == info.GetAttr("bits", &nbits_)); + } + + Status ComputeInternal(OpKernelContext* context) const override; + + private: + int64_t K_; + int64_t N_; + int64_t block_size_; + int64_t nbits_; + bool column_wise_quant_blk_{true}; +}; + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index e33ce20737f80..f06a3785f362d 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -3343,22 +3343,23 @@ MatMulNBits is a MatMul with weight quantized with N bits(e.g., 2, 3, 4, 5, 6, 7 And block_size is not an arbitrary number and must be a power of 2 and not smaller than 16, like 16, 32, 64, 128,.. 3. Input B's scale and zero point are specified by input scales and zero_points. -Input B is stored as uint8_t with shape: [N][n_blocks_per_col][blob_size] in which: -- n_blocks_per_col = (K + block_size - 1) / block_size -- blob_size = block_size / 8 * bits - - For a block blob. 
It is stored in format: - struct Blob { - uint8 one_bits[(bits & 0x1) * 1 * block_size / 8]; // highest 1 bit for 3, 5, 7 bits quantization - uint8 two_bits[(bits & 0x2) * 2 * block_size / 8]; // high 2 bits for 2, 6, 7 bits quantization - uint8 four_bits[(bits & 0x4) * 4 * block_size / 8]; // low 4 bits for 4, 5, 6 bits quantization - } + Input is stored as uint8_t with shape: [N][n_blocks_per_col][blob_size] in which: + - n_blocks_per_col = (K + block_size - 1) / block_size + - blob_size = CeilDiv(block_size * bits, bitsof(uint8_t)<8>) + For all bits from 2-8, a row of data is stored squeezely and represented by uint8_t. + - for 2,4,8 bits, 4x2bit,2x4bit,1x8bit are stored in one uint8_t. + 4bit example: + |.|.|.|.| .|.|.|.| =uint8_t (2x4bit) + - for 3,5,6,7 bits, 32x3bit,32x5bit,16x6bit,32x7bit are stored in 12xuint8_t,20xuint8_t,12xuint8_t,28xuint8_t separately. no bits are wasted. + 3bit example: + |.|.|. |.|.|. |.|.|. = 9bit, which across 2 uint8_t, the highest bit for the second uint8_t is used. + The last uint_8 may have some bits unused. -Input scales is stored in same type as original type of B(float32, float16) with shape like: [N * n_blocks_per_col] -Input zero_points is stored as uint8_t. If bits <= 4, two zero points are stored as one unit8_t. If bits > 4, one zero point is stored with one unit8_t. Thus, its shape is: - - [(N * n_blocks_per_col + 1) / 2] if bits <=4 - - [N * n_blocks_per_col] if bits > 4 +Input scales is stored in same type as original type of B(float32, float16) with shape like: [N * n_blocks_per_col] +Input zero_points is stored as uint8_t or same as type(A). It has the same packing method as input B. + - [CeilDiv((N * n_blocks_per_col + 1) *bits, 8)] + If zero_points has same type as A, it's not packed and has the same shape as Scales. )DOC"; ONNX_CONTRIB_OPERATOR_SCHEMA(MatMulNBits) @@ -3377,12 +3378,15 @@ Input zero_points is stored as uint8_t. If bits <= 4, two zero points are stored "type T1.", AttributeProto::INT, static_cast(0)) .Input(0, "A", "The input tensor, not quantized", "T1") - .Input(1, "B", "1-dimensional data blob", "T2") + .Input(1, "B", "1 or 2 dimensional data blob", "T2") .Input(2, "scales", "quantization scale", "T1") - .Input(3, "zero_points", "quantization zero points", "T2", OpSchema::Optional) + .Input(3, "zero_points", "quantization zero points", "T3", OpSchema::Optional) + .Input(4, "g_idx", "group_idx", "T4", OpSchema::Optional) .Output(0, "Y", "tensor. The output tensor has the same rank as the input. 
", "T1") .TypeConstraint("T1", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float/half_float tensors.") - .TypeConstraint("T2", {"tensor(uint8)"}, "Constrain quantized weight types to uint8.") + .TypeConstraint("T2", {"tensor(uint8)", "tensor(int32)"}, "Constrain quantized weight types to uint8/int32.") + .TypeConstraint("T3", {"tensor(uint8)", "tensor(int32)", "tensor(float16)", "tensor(float)"}, "Constrain quantized zero point types to uint8/int32/float16/float.") + .TypeConstraint("T4", {"tensor(int32)"}, "the index tensor.") .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { // Type inference propagateElemTypeFromInputToOutput(ctx, 0, 0); diff --git a/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py b/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py index eb7bbec997d59..a1916e806c5c0 100644 --- a/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py +++ b/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py @@ -65,7 +65,7 @@ def __init__( self, calibration_data_reader: CalibrationDataReader, percdamp=0.01, - blocksize=128, + block_size=128, actorder=False, mse=False, perchannel=True, @@ -79,7 +79,7 @@ def __init__( a calibration data reader. It enumerates calibration data and generates inputs for the original model. percdamp: percent of the average Hessian diagonal to use for dampening. - blocksize (int, optional): + block_size (int, optional): channel number in one block to execute a GPTQ quantization iteration. actorder (bool, optional): whether rearrange Hessian matrix considering the diag's value. @@ -93,42 +93,285 @@ def __init__( ) self.calibration_data_reader = calibration_data_reader self.percdamp = percdamp - self.blocksize = blocksize + self.block_size = block_size self.actorder = actorder self.mse = mse self.perchannel = perchannel -class MatMul4BitsQuantizer: - """Perform 4b quantization of constant MatMul weights""" +class HQQWeightOnlyQuantConfig(WeightOnlyQuantConfig): + def __init__( + self, + block_size=128, + bits=4, + axis=1, + ): + """ + This is a class for HQQ algorithm Weight Only Quant Configuration. + HQQ algorithm quant weight without needing calibrate data. + + Args: + block_size (int, optional): + channel number in one block to execute a GPTQ quantization iteration. + bits (int, optional): + how many bits to represent weight. + axis (int, optional): + 0 or 1. which axis to quantize. 
https://arxiv.org/pdf/2309.15531.pdf + """ + super().__init__( + algorithm="HQQ", + ) + self.block_size = block_size + self.bits = bits + self.axis = axis + +class DefaultWeightOnlyQuantConfig(WeightOnlyQuantConfig): def __init__( self, - model: ModelProto | str, - block_size: int, - is_symmetric: bool, + block_size: int = 128, + is_symmetric: bool = False, accuracy_level: int | None = None, - nodes_to_exclude=None, - algo_config: WeightOnlyQuantConfig = None, ): - if nodes_to_exclude is None: - nodes_to_exclude = [] - self.model = ONNXModel(onnx.load(model)) if isinstance(model, str) else ONNXModel(model) - self.model_path = model if isinstance(model, str) else None + super().__init__(algorithm="DEFAULT") self.block_size = block_size self.is_symmetric = is_symmetric + self.bits = 4 self.accuracy_level = accuracy_level - self.nodes_to_exclude = set(nodes_to_exclude) - self.algo_config = algo_config + + +def is_divisible(val1, val2): + return int(val2 * np.ceil(val1 / val2)) == val1 + + +class HQQWeightOnlyQuantizer: + def __init__( + self, + config: HQQWeightOnlyQuantConfig, + ): + self.config = config + + # Proximal solver || weight - dequantize(quantize(weight))||_p^p + @staticmethod + def optimize_weights( + tensor, + scale, + zero, + min_max: list[int], + axis: int = 0, + opt_params: dict = None, # noqa: RUF013 + verbose=False, + ): + import torch + + opt_params = {"lp_norm": 0.7, "beta": 1e1, "kappa": 1.01, "iters": 20} if opt_params is None else opt_params + lp_norm, beta, kappa, iters = ( + opt_params["lp_norm"], + opt_params["beta"], + opt_params["kappa"], + opt_params["iters"], + ) + + dtype = torch.float16 if tensor.is_cuda else torch.float32 + w_f = tensor.to(dtype) + scale = scale.to(dtype) + zero = zero.to(dtype) + + if lp_norm == 1: + + def shrink_op(x, beta): + return torch.sign(x) * torch.nn.functional.relu(torch.abs(x) - 1.0 / beta) + + else: + + def shrink_op(x, beta, p=lp_norm): + return torch.sign(x) * torch.nn.functional.relu( + torch.abs(x) - (1.0 / beta) * torch.pow(torch.abs(x) + 1e-8, p - 1) + ) + + best_error = 1e4 + for i in range(iters): + w_q = torch.round(w_f * scale + zero).clamp(min_max[0], min_max[1]) + w_r = (w_q - zero) / scale + w_e = shrink_op(w_f - w_r, beta) + zero = torch.mean(w_q - (w_f - w_e) * scale, axis=axis, keepdim=True) + beta *= kappa + + current_error = float(torch.abs(w_f - w_r).mean()) + if verbose: + print(i, np.round(current_error, 6)) + if current_error < best_error: + best_error = current_error + else: + break + + del w_f, w_q, w_r, w_e + + return scale, zero @staticmethod - def __get_initializer(name, graph_path: list[GraphProto]) -> tuple[TensorProto, GraphProto]: - for gid in range(len(graph_path) - 1, -1, -1): - graph = graph_path[gid] - for tensor in graph.initializer: - if tensor.name == name: - return tensor, graph - return None, None + def pack_on_row_fast_248bit(pack_tensor, ori_int_tensor, bits): + if pack_tensor.shape[0] == ori_int_tensor.shape[0]: + ori_int_tensor = ori_int_tensor.T + pack_tensor = pack_tensor.T + if bits in [2, 4, 8]: + compress_ratio = pack_tensor.element_size() * 8 // bits + for j in range(0, compress_ratio): + pack_tensor[0:] |= ori_int_tensor[j::compress_ratio] << (bits * (j)) + else: + raise NotImplementedError("Only 2,4,8 bits are supported.") + + # from Official implementation of Half-Quadratic Quantization (HQQ) + def quantize_internal( + self, tensor, bits=4, channel_wise=True, group_size=64, optimize=True, round_zero=True, axis=1 + ): + import torch + + weight = tensor.float() + ori_shape = 
weight.shape + + pad_len = (group_size - ori_shape[axis] % group_size) % group_size + if axis == 1: + weight = torch.nn.functional.pad(weight, (0, pad_len), "constant", 0) + else: + weight = torch.nn.functional.pad(weight, (0, 0, 0, pad_len), "constant", 0) + shape = weight.shape + + # Reshape for grouping + if (group_size is not None) and channel_wise: + weight = weight.reshape([-1, group_size]) if (axis == 1) else weight.reshape([group_size, -1]) + + # Get min/max values + if channel_wise is False: + _min, _max = weight.min(), weight.max() + optimize = False + else: + _min = weight.min(axis=axis, keepdim=True)[0] + _max = weight.max(axis=axis, keepdim=True)[0] + + max_v = 2**bits - 1 + min_v = 0 + min_max = [min_v, max_v] + + # Note: here we work with the inverse of the scale to avoid division and quantize instead via weight*scale + zero, the scale is inverted later on. + # clamp to avoid half-precision problems + scale = (max_v / (_max - _min)).clamp(max=2e4) + #!!!!!!!!!!!!!!! + min_max_axis = _max - _min + if (min_max_axis == 0).sum().item() > 0: + min_max_axis[min_max_axis == 0] = max_v + scale = (max_v / min_max_axis).clamp(max=2e4) + zero = -_min * scale + + if round_zero: + zero = torch.round(zero) + + # Fine-tune weights + if optimize: + scale, zero = self.optimize_weights(tensor=weight, scale=scale, zero=zero, min_max=min_max, axis=axis) + + # Quantize + # Necessary for fake quantization backprop + w_q = torch.round(weight * scale + zero).clamp(min_max[0], min_max[1]) + w_q = w_q.reshape(shape).int() + + scale = 1.0 / scale + if axis == 1: + scale = scale.reshape(shape[0], -1) + zero = zero.reshape(shape[0], -1) + else: + scale = scale.reshape(-1, shape[-1]) + zero = zero.reshape(-1, shape[-1]) + # cleanup + del weight, _min, _max + + return w_q, scale.to(tensor.dtype), zero.to(tensor.dtype) + + def quantize(self, node: NodeProto, graph_stack: list[GraphProto]): + """If the node is MatMul with fp32 const weight, quantize the weight with int4, and return the new node""" + if node.op_type != "MatMul": + return node # only care about MatMul for now + import torch + + logger.info(f"start to quantize {node.name} ...") + inputB = node.input[1] # noqa: N806 + b_pb, bs_graph = get_initializer(inputB, graph_stack) + if b_pb is None: + logger.info("MatMul doesn't have const weight. Skip to quantize") + return node # only care about constant weight + + b_array = onnx.numpy_helper.to_array(b_pb) + if len(b_array.shape) != 2: + logger.info("MatMul weight is not 2D. 
Skip to quantize") + return node # can only process 2-D matrix + b_array_torch = torch.from_numpy(b_array) + if torch.cuda.is_available(): + b_array_torch = b_array_torch.cuda() + quant_weight_torch, scales_torch, zero_points_torch = self.quantize_internal( + b_array_torch.T, bits=self.config.bits, group_size=self.config.block_size + ) + quant_weight_torch = quant_weight_torch.contiguous() + scales_torch = scales_torch.contiguous() + zero_points_torch = zero_points_torch.contiguous() + + packed_torch = torch.zeros( + (quant_weight_torch.shape[0], quant_weight_torch.shape[1] // 2), + dtype=torch.uint8, + device=quant_weight_torch.device, + ) + self.pack_on_row_fast_248bit(packed_torch, quant_weight_torch, self.config.bits) + scales = scales_torch.cpu().numpy() + zero_points = zero_points_torch.cpu().numpy() + b_quant = onnx.numpy_helper.from_array(packed_torch.cpu().numpy()) + b_quant.name = b_pb.name + "_Q4" + for input in bs_graph.input: + if input.name == inputB: + bs_graph.input.remove(input) + break + + scales_tensor = onnx.numpy_helper.from_array(scales) + scales_tensor.name = b_pb.name + "_scales" + bs_graph.initializer.extend([b_quant, scales_tensor]) + + input_names = [node.input[0], b_quant.name, scales_tensor.name] + zp_tensor = onnx.numpy_helper.from_array(zero_points) + zp_tensor.name = b_pb.name + "_zero_points" + bs_graph.initializer.extend([zp_tensor]) + input_names.append(zp_tensor.name) + + kwargs = {} + rows, cols = b_array.shape + kwargs["K"] = rows + kwargs["N"] = cols + kwargs["bits"] = self.config.bits + kwargs["block_size"] = self.config.block_size + + matmul_q4_node = onnx.helper.make_node( + "MatMulNBits", + inputs=input_names, + outputs=[node.output[0]], + name=node.name + "_Q4" if node.name else "", + domain="com.microsoft", + **kwargs, + ) + + logger.info(f"complete quantization of {node.name} ...") + + return matmul_q4_node + + +def get_initializer(name, graph_path: list[GraphProto]) -> tuple[TensorProto, GraphProto]: + for gid in range(len(graph_path) - 1, -1, -1): + graph = graph_path[gid] + for tensor in graph.initializer: + if tensor.name == name: + return tensor, graph + return None, None + + +class DefaultWeightOnlyQuantizer: + def __init__(self, config: DefaultWeightOnlyQuantConfig): + self.config = config def int4_block_quant(self, fp32weight: npt.ArrayLike) -> np.ndarray: """4b quantize fp32 weight to a blob""" @@ -137,7 +380,7 @@ def int4_block_quant(self, fp32weight: npt.ArrayLike) -> np.ndarray: raise ValueError("Current int4 block quantization only supports 2D tensors!") rows, cols = fp32weight.shape - block_size = self.block_size + block_size = self.config.block_size blob_size = block_size // 2 k_blocks = (rows + block_size - 1) // block_size padded_rows = k_blocks * block_size @@ -149,23 +392,19 @@ def int4_block_quant(self, fp32weight: npt.ArrayLike) -> np.ndarray: packed = np.zeros((cols, k_blocks, blob_size), dtype="uint8") scales = np.zeros((cols * k_blocks), dtype=fp32weight.dtype) zero_point = np.zeros(cols * ((k_blocks + 1) // 2), dtype="uint8") - quantize_matmul_4bits(packed, fp32weight, scales, zero_point, block_size, cols, rows, self.is_symmetric) + quantize_matmul_4bits(packed, fp32weight, scales, zero_point, block_size, cols, rows, self.config.is_symmetric) return (packed, scales, zero_point) - def _q4_matmul_node_weight(self, node: NodeProto, graph_stack: list[GraphProto]) -> NodeProto: + def quantize(self, node: NodeProto, graph_stack: list[GraphProto]) -> NodeProto: """If the node is MatMul with fp32 const weight, quantize the 
weight with int4, and return the new node""" if node.op_type != "MatMul": return node # only care about MatMul for now logger.info(f"start to quantize {node.name} ...") - if node.name in self.nodes_to_exclude: - logger.info(f"exclude to quantize {node.name} as specified by nodes_to_exclude...") - return node - inputB = node.input[1] # noqa: N806 - B, Bs_graph = MatMul4BitsQuantizer.__get_initializer(inputB, graph_stack) # noqa: N806 + B, Bs_graph = get_initializer(inputB, graph_stack) # noqa: N806 if B is None: logger.info("MatMul doesn't have const weight. Skip to quantize") return node # only care about constant weight @@ -188,7 +427,7 @@ def _q4_matmul_node_weight(self, node: NodeProto, graph_stack: list[GraphProto]) Bs_graph.initializer.extend([B_quant, scales_tensor]) input_names = [node.input[0], B_quant.name, scales_tensor.name] - if not self.is_symmetric: + if not self.config.is_symmetric: zp_tensor = onnx.numpy_helper.from_array(zero_points) zp_tensor.name = B.name + "_zero_points" Bs_graph.initializer.extend([zp_tensor]) @@ -199,8 +438,8 @@ def _q4_matmul_node_weight(self, node: NodeProto, graph_stack: list[GraphProto]) kwargs["K"] = rows kwargs["N"] = cols kwargs["bits"] = 4 - kwargs["block_size"] = self.block_size - if self.accuracy_level is not None: + kwargs["block_size"] = self.config.block_size + if self.config.accuracy_level is not None: kwargs["accuracy_level"] = self.accuracy_level matmul_q4_node = onnx.helper.make_node( @@ -216,6 +455,38 @@ def _q4_matmul_node_weight(self, node: NodeProto, graph_stack: list[GraphProto]) return matmul_q4_node + +class MatMul4BitsQuantizer: + """Perform 4b quantization of constant MatMul weights""" + + def __init__( + self, + model: ModelProto | str, + block_size: int = 128, + is_symmetric: bool = False, + accuracy_level: int | None = None, + nodes_to_exclude=None, + algo_config: WeightOnlyQuantConfig = None, + ): + if nodes_to_exclude is None: + nodes_to_exclude = [] + self.model = ONNXModel(onnx.load(model)) if isinstance(model, str) else ONNXModel(model) + self.model_path = model if isinstance(model, str) else None + self.block_size = block_size + self.is_symmetric = is_symmetric + self.accuracy_level = accuracy_level + self.nodes_to_exclude = set(nodes_to_exclude) + self.node_quantizer = None + if algo_config is None: + algo_config = DefaultWeightOnlyQuantConfig( + block_size=block_size, is_symmetric=is_symmetric, accuracy_level=accuracy_level + ) + self.algo_config = algo_config + if algo_config.algorithm == "HQQ": + self.node_quantizer = HQQWeightOnlyQuantizer(self.algo_config) + elif algo_config.algorithm == "DEFAULT": + self.node_quantizer = DefaultWeightOnlyQuantizer(self.algo_config) + def _process_subgraph(self, graph_stack: list[GraphProto]): new_nodes = [] graph = graph_stack[-1] @@ -246,8 +517,15 @@ def _process_subgraph(self, graph_stack: list[GraphProto]): node = onnx.helper.make_node( # noqa: PLW2901 node.op_type, node.input, node.output, name=node.name, **kwargs ) - - new_nodes.append(self._q4_matmul_node_weight(node, graph_stack)) + out_node = None + if node.name in self.nodes_to_exclude: + logger.info(f"exclude to quantize {node.name} as specified by nodes_to_exclude...") + out_node = node + elif self.algo_config is not None and self.algo_config.algorithm == "HQQ": + out_node = self.node_quantizer.quantize(node, graph_stack) + else: + out_node = self.node_quantizer.quantize(node, graph_stack) + new_nodes.append(out_node) graph.ClearField("node") graph.node.extend(new_nodes) @@ -300,7 +578,7 @@ def inc_dataloader(): 
from neural_compressor.adaptor.ox_utils.weight_only import gptq_quantize kwargs["percdamp"] = self.algo_config.percdamp - kwargs["blocksize"] = self.algo_config.blocksize + kwargs["blocksize"] = self.algo_config.block_size kwargs["actorder"] = self.algo_config.actorder kwargs["mse"] = self.algo_config.mse kwargs["perchannel"] = self.algo_config.perchannel @@ -316,7 +594,7 @@ def inc_dataloader(): logger.info(f"complete quantization of model with {algorithm} algorithm.") def process(self): - if self.algo_config is None: + if self.algo_config.algorithm in ["HQQ", "DEFAULT"]: # use a stack to keep track of sub-graphs graph_stack = [self.model.graph()] opset_import = self.model.opset_import() @@ -327,7 +605,6 @@ def process(self): has_ms_domain = True if not has_ms_domain: opset_import.extend([onnx.helper.make_opsetid("com.microsoft", 1)]) - self._process_subgraph(graph_stack) self.model.clean_initializers() else: @@ -366,6 +643,14 @@ def parse_args(): parser.add_argument("--input_model", required=True, help="Path to the input model file") parser.add_argument("--output_model", required=True, help="Path to the output model file") parser.add_argument("--block_size", required=False, default=32, type=int, help="Block size for quantization") + parser.add_argument( + "--quant_method", + default="default", + type=str, + choices=["default", "hqq"], + help="the algorithm used to quantize weight", + ) + parser.add_argument("--bits", default=4, type=int, help="the target bits to represent weight") parser.add_argument( "--symmetric", required=False, @@ -411,12 +696,24 @@ def parse_args(): raise Exception(f"file {output_model_path} already exists") model = onnx.load(input_model_path) + if args.quant_method == "hqq": + quant_config = HQQWeightOnlyQuantConfig(block_size=args.block_size, bits=args.bits) + elif args.quant_method == "default": + quant_config = DefaultWeightOnlyQuantConfig( + block_size=args.block_size, is_symmetric=args.symmetric, accuracy_level=args.accuracy_level + ) + elif args.quant_method == "rtn": + quant_config = RTNWeightOnlyQuantConfig() + elif args.quant_method == "gptq": + quant_config = GPTQWeightOnlyQuantConfig(block_size=args.block_size) + else: + raise ValueError(f"Unsupported quantization method: {args.quant_method}") + quant = MatMul4BitsQuantizer( model=model, - block_size=args.block_size, - is_symmetric=args.symmetric, accuracy_level=args.accuracy_level, nodes_to_exclude=args.nodes_to_exclude, + algo_config=quant_config, ) quant.process() quant.model.save_model_to_file(output_model_path, True) diff --git a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc index 2ad20eafc2ef1..d294fd4e2b0e0 100644 --- a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. 
diff --git a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc
index 2ad20eafc2ef1..d294fd4e2b0e0 100644
--- a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc
+++ b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc
@@ -2,6 +2,7 @@
 // Licensed under the MIT License.

 #ifndef ORT_MINIMAL_BUILD

+#include
 #include "core/common/span_utils.h"
 #include "core/framework/tensor.h"
@@ -66,7 +67,9 @@ void QuantizeDequantize(std::vector<float>& raw_vals,
 }

 void RunTest(int64_t M, int64_t N, int64_t K, int64_t block_size, int64_t accuracy_level,
-             bool has_zeropoint, bool use_float16, float fp16_abs_error = 0.02f) {
+             bool has_zeropoint, bool use_float16, bool has_g_idx = false,
+             bool zp_is_4bit = true, float fp16_abs_error = 0.02f) {
+  zp_is_4bit = zp_is_4bit | has_g_idx;
   RandomValueGenerator random{1234};
   std::vector<float> input0_vals(random.Gaussian<float>(std::vector<int64_t>({M, K}), 0.0f, 0.25f));
   std::vector<float> input1_f_vals(random.Gaussian<float>(std::vector<int64_t>({K, N}), 0.0f, 0.25f));
@@ -113,12 +116,40 @@ void RunTest(int64_t M, int64_t N, int64_t K, int64_t block_size, int64_t accura
   test.AddAttribute<int64_t>("block_size", block_size);
   test.AddAttribute<int64_t>("bits", QBits);
   test.AddAttribute<int64_t>("accuracy_level", accuracy_level);
+  auto ceildiv = [](int64_t a, int64_t b) { return (a + b - 1) / b; };
+
   if (use_float16) {
     test.AddInput<MLFloat16>("A", {M, K}, ToFloat16(input0_vals), false);
     test.AddInput<uint8_t>("B", {q_cols, q_rows}, input1_vals, true);
     test.AddInput<MLFloat16>("scales", {static_cast<int64_t>(q_scale_size)}, ToFloat16(scales), true);
     if (has_zeropoint) {
-      test.AddInput<uint8_t>("zero_points", {static_cast<int64_t>(q_zp_size_in_bytes)}, zp, true);
+      if (zp_is_4bit) {
+        test.AddInput<uint8_t>("zero_points", {static_cast<int64_t>(q_zp_size_in_bytes)}, zp, true);
+      } else {
+        std::vector<float> zp_f;
+        zp_f.reserve(q_zp_size_in_bytes * 2);
+        for (size_t i = 0; i < zp.size(); i++) {
+          zp_f.push_back(static_cast<float>(zp[i] & 0xf));
+          zp_f.push_back(static_cast<float>((zp[i] >> 4) & 0xf));
+        }
+        size_t ind = zp_f.size() - 1;
+        while (zp_f.size() != q_scale_size) {
+          zp_f.erase(zp_f.begin() + ind);
+          ind -= q_scale_size / N + 1;
+        }
+
+        test.AddInput<MLFloat16>("zero_points", {static_cast<int64_t>(q_scale_size)}, ToFloat16(zp_f), true);
+      }
+    } else {
+      test.AddInput<uint8_t>("", {0}, {});
+    }
+    if (has_g_idx) {
+      int K_pad = gsl::narrow<int32_t>(ceildiv(K, block_size) * block_size);
+      std::vector<int32_t> g_idx(K_pad);
+      for (int64_t i = 0; i < K_pad; i++) {
+        g_idx[i] = gsl::narrow<int32_t>(i / block_size);
+      }
+      test.AddInput<int32_t>("g_idx", {static_cast<int64_t>(K_pad)}, g_idx, true);
     }

     test.AddOutput<MLFloat16>("Y", {M, N}, ToFloat16(expected_vals));
@@ -132,9 +163,34 @@ void RunTest(int64_t M, int64_t N, int64_t K, int64_t block_size, int64_t accura
     test.AddInput<uint8_t>("B", {q_cols, q_rows}, input1_vals, true);
     test.AddInput<float>("scales", {static_cast<int64_t>(q_scale_size)}, scales, true);
     if (has_zeropoint) {
-      test.AddInput<uint8_t>("zero_points", {static_cast<int64_t>(q_zp_size_in_bytes)}, zp, true);
-    }
+      if (zp_is_4bit) {
+        test.AddInput<uint8_t>("zero_points", {static_cast<int64_t>(q_zp_size_in_bytes)}, zp, true);
+      } else {
+        std::vector<float> zp_f;
+        zp_f.reserve(q_zp_size_in_bytes * 2);
+        for (size_t i = 0; i < zp.size(); i++) {
+          zp_f.push_back(static_cast<float>(zp[i] & 0xf));
+          zp_f.push_back(static_cast<float>((zp[i] >> 4) & 0xf));
+        }
+        size_t ind = zp_f.size() - 1;
+        while (zp_f.size() != q_scale_size) {
+          zp_f.erase(zp_f.begin() + ind);
+          ind -= q_scale_size / N + 1;
+        }
+        test.AddInput<float>("zero_points", {static_cast<int64_t>(q_scale_size)}, zp_f, true);
+      }
+    } else {
+      test.AddInput<uint8_t>("", {0}, {});
+    }
+    if (has_g_idx) {
+      int K_pad = gsl::narrow<int32_t>(ceildiv(K, block_size) * block_size);
+      std::vector<int32_t> g_idx(K_pad);
+      for (int64_t i = 0; i < K_pad; i++) {
+        g_idx[i] = gsl::narrow<int32_t>(i / block_size);
+      }
+      test.AddInput<int32_t>("g_idx", {static_cast<int64_t>(K_pad)}, g_idx, true);
+    }

     test.AddOutput<float>("Y", {M, N}, expected_vals);
     if (accuracy_level == 4) {
       test.SetOutputAbsErr("Y", 0.1f);
@@ -158,6 +214,8 @@ TEST(MatMulNBits, Float32) {
           for (auto accuracy_level : {0}) {
            RunTest(M, N, K, block_size, accuracy_level, false, false);
             RunTest(M, N, K, block_size, accuracy_level, true, false);
+            RunTest(M, N, K, block_size, accuracy_level, false, false, true);
+            RunTest(M, N, K, block_size, accuracy_level, true, false, false, false);
           }
 #endif
         }
@@ -172,8 +230,10 @@ TEST(MatMulNBits, Float16) {
     for (auto N : {1, 2, 32, 288}) {
       for (auto K : {16, 32, 64, 128, 256, 1024, 93, 1234}) {
         for (auto block_size : {16, 32, 64, 128}) {
-          RunTest(M, N, K, block_size, 0, false, true);
-          RunTest(M, N, K, block_size, 0, true, true);
+          for (auto has_gidx : {true, false}) {
+            RunTest(M, N, K, block_size, 0, false, true, has_gidx);
+            RunTest(M, N, K, block_size, 0, true, true, has_gidx, false);
+          }
         }
       }
     }
@@ -183,9 +243,9 @@ TEST(MatMulNBits, Float16) {
 TEST(MatMulNBits, Float16Large) {
   for (auto block_size : {16, 32, 64, 128}) {
     for (auto symmetric : {false, true}) {
-      RunTest(1, 4096, 4096, block_size, 0, symmetric, true, 0.05f);
-      RunTest(1, 4096, 11008, block_size, 0, symmetric, true, 0.05f);
-      RunTest(1, 11008, 4096, block_size, 0, symmetric, true, 0.05f);
+      RunTest(1, 4096, 4096, block_size, 0, symmetric, true, false, true, 0.05f);
+      RunTest(1, 4096, 11008, block_size, 0, symmetric, true, false, true, 0.05f);
+      RunTest(1, 11008, 4096, block_size, 0, symmetric, true, false, true, 0.05f);
     }
   }
 }
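
To make the new `RunTest` branches easier to follow, here is a rough Python rendering of the zero-point handling for the non-4-bit case and of the `g_idx` input; it is an illustrative sketch only (`unpack_zero_points` and `make_g_idx` are not names from the patch) and assumes the test's per-column packing, where each column's 4-bit zero points are padded to whole bytes with the low nibble holding the earlier block:

```python
import numpy as np


def unpack_zero_points(zp_packed, n_cols, n_blocks_per_col):
    """Expand packed 4-bit zero points into one float per (column, block)."""
    nibbles = []
    for byte in zp_packed:
        nibbles.append(byte & 0xF)           # low nibble: earlier block in the column
        nibbles.append((byte >> 4) & 0xF)    # high nibble: later block in the column
    # Each column owns ceil(n_blocks_per_col / 2) bytes, so an odd block count leaves
    # one padding nibble per column that has to be dropped.
    per_col = np.asarray(nibbles, dtype=np.float32).reshape(n_cols, -1)
    return per_col[:, :n_blocks_per_col].reshape(-1)


def make_g_idx(k, block_size):
    """Group index used by the test: element i belongs to block i // block_size."""
    k_pad = ((k + block_size - 1) // block_size) * block_size
    return np.arange(k_pad, dtype=np.int32) // block_size
```

The C++ test reaches the same result with an erase loop that walks backwards from the end of the unpacked vector, removing one trailing padding nibble per column until the length matches `q_scale_size`.
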
diff --git a/onnxruntime/test/python/quantization/op_test_utils.py b/onnxruntime/test/python/quantization/op_test_utils.py
index c1bbb49f10c7e..b30282f2ab41f 100644
--- a/onnxruntime/test/python/quantization/op_test_utils.py
+++ b/onnxruntime/test/python/quantization/op_test_utils.py
@@ -358,6 +358,7 @@ def check_model_correctness(
         model_onnx = onnx.load(f)
     ops_set = set(node.op_type for node in model_onnx.graph.node)
     check_reference_evaluator = not (ops_set & {"EmbedLayerNormalization", "Conv", "Attention", "Transpose"})
+    check_target_evaluator = False

     with open(model_path_to_check, "rb") as f:
         model_check = onnx.load(f)
@@ -413,7 +414,7 @@ def check_model_correctness(
         check_sign_f8_quantization(model_path_origin, model_path_to_check)

     # Verifies the expected outputs.
-    if check_reference_evaluator and onnx_recent_enough:
+    if check_target_evaluator and onnx_recent_enough:
         if op_matmul:
             reference_new_ops = [QLinearMatMul]
         else:
diff --git a/onnxruntime/test/python/quantization/test_op_matmul_4bits.py b/onnxruntime/test/python/quantization/test_op_matmul_4bits.py
index 73dae08af8ece..88e5052db4e2e 100644
--- a/onnxruntime/test/python/quantization/test_op_matmul_4bits.py
+++ b/onnxruntime/test/python/quantization/test_op_matmul_4bits.py
@@ -125,7 +125,10 @@ def quant_test(
         from onnxruntime.quantization import matmul_4bits_quantizer

         model = quant_utils.load_model_with_shape_infer(Path(model_fp32_path))
-        quant = matmul_4bits_quantizer.MatMul4BitsQuantizer(model, block_size, is_symmetric)
+        quant_config = matmul_4bits_quantizer.DefaultWeightOnlyQuantConfig(
+            block_size=block_size, is_symmetric=is_symmetric
+        )
+        quant = matmul_4bits_quantizer.MatMul4BitsQuantizer(model, algo_config=quant_config)
         quant.process()
         quant.model.save_model_to_file(model_int4_path, False)

@@ -165,6 +168,9 @@ def quant_test_with_algo(
         elif algorithm == "GPTQ":
             # test GPTQ algorithm
             algo_config = matmul_4bits_quantizer.GPTQWeightOnlyQuantConfig(calibration_data_reader=data_reader)
+        elif algorithm == "HQQ":
+            # test HQQ algorithm
+            algo_config = matmul_4bits_quantizer.HQQWeightOnlyQuantConfig(block_size=block_size)

         model = quant_utils.load_model_with_shape_infer(Path(model_fp32_path))
         quant = matmul_4bits_quantizer.MatMul4BitsQuantizer(model, block_size, is_symmetric, algo_config=algo_config)
@@ -227,6 +233,17 @@ def test_quantize_matmul_int4_using_gptq_algo(self):
         data_reader = self.input_feeds(1, {"input": [100, 52]})
         self.quant_test_with_algo("GPTQ", model_fp32_path, data_reader, 32, False)

+    @unittest.skipIf(
+        find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_4bits"
+    )
+    def test_quantize_matmul_int4_using_hqq_algo(self):
+        if not find_spec("torch"):
+            self.skipTest("skip test_hqq_quant since torch is not installed")
+        model_fp32_path = str(Path(self._tmp_model_dir.name).joinpath("matmul_fp32_offset.onnx").absolute())
+        self.construct_model_matmul(model_fp32_path, symmetric=False)
+        data_reader = self.input_feeds(1, {"input": [100, 52]})
+        self.quant_test_with_algo("HQQ", model_fp32_path, data_reader, 32, False)
+

 if __name__ == "__main__":
     unittest.main()
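
The new HQQ unit test goes through the same public API as the other algorithms, and it is skipped when torch is unavailable, so the HQQ path is assumed to require torch. A standalone sketch of the equivalent call sequence, not code from the patch (file names are placeholders):

```python
from importlib.util import find_spec
from pathlib import Path

from onnxruntime.quantization import matmul_4bits_quantizer, quant_utils

if find_spec("torch") is None:
    raise SystemExit("HQQ weight-only quantization is expected to need torch installed")

# Load a float32 model, quantize its constant MatMul weights with HQQ, and save it.
model = quant_utils.load_model_with_shape_infer(Path("matmul_fp32_offset.onnx"))
algo_config = matmul_4bits_quantizer.HQQWeightOnlyQuantConfig(block_size=32)
quant = matmul_4bits_quantizer.MatMul4BitsQuantizer(model, algo_config=algo_config)
quant.process()
quant.model.save_model_to_file("matmul_int4_hqq.onnx", False)
```
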