diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md
index e295dfa203ae5..5f0100fad95a2 100644
--- a/docs/ContribOperators.md
+++ b/docs/ContribOperators.md
@@ -2808,22 +2808,23 @@ This version of the operator has been available since version 1 of the 'com.micr
And block_size is not an arbitrary number and must be a power of 2 and not smaller than 16, like 16, 32, 64, 128,..
3. Input B's scale and zero point are specified by input scales and zero_points.
- Input B is stored as uint8_t with shape: [N][n_blocks_per_col][blob_size] in which:
- - n_blocks_per_col = (K + block_size - 1) / block_size
- - blob_size = block_size / 8 * bits
+  Input B is stored as uint8_t with shape: [N][n_blocks_per_col][blob_size] in which:
+  - n_blocks_per_col = (K + block_size - 1) / block_size
+  - blob_size = CeilDiv(block_size * bits, 8), where 8 is the number of bits in a uint8_t
+    For all bits from 2-8, a row of data is stored tightly packed into uint8_t values.
+      - For 2, 4 and 8 bits, 4x2bit, 2x4bit and 1x8bit values are stored in one uint8_t.
+          4bit example:
+          |.|.|.|.| |.|.|.|.| = 1 uint8_t (2x4bit)
+      - For 3, 5, 6 and 7 bits, 32x3bit, 32x5bit, 16x6bit and 32x7bit values are stored in 12, 20, 12 and 28 uint8_t respectively; no bits are wasted.
+          3bit example:
+          |.|.|.| |.|.|.| |.|.|.| = 9 bits, which span 2 uint8_t; the highest bit of the second uint8_t is used.
+          The last uint8_t may have some bits unused.
- For a block blob. It is stored in format:
- struct Blob {
- uint8 one_bits[(bits & 0x1) * 1 * block_size / 8]; // highest 1 bit for 3, 5, 7 bits quantization
- uint8 two_bits[(bits & 0x2) * 2 * block_size / 8]; // high 2 bits for 2, 6, 7 bits quantization
- uint8 four_bits[(bits & 0x4) * 4 * block_size / 8]; // low 4 bits for 4, 5, 6 bits quantization
- }
Input scales is stored in same type as original type of B(float32, float16) with shape like: [N * n_blocks_per_col]
- Input zero_points is stored as uint8_t. If bits <= 4, two zero points are stored as one unit8_t. If bits > 4, one zero point is stored with one unit8_t. Thus, its shape is:
- - [(N * n_blocks_per_col + 1) / 2] if bits <=4
- - [N * n_blocks_per_col] if bits > 4
-
+  Input zero_points is stored either as uint8_t or in the same type as A. When stored as uint8_t, it uses the same packing method as input B, so its shape is:
+  - [CeilDiv((N * n_blocks_per_col + 1) * bits, 8)]
+  If zero_points has the same type as A, it is not packed and has the same shape as scales.
#### Version
@@ -2844,17 +2845,19 @@ This version of the operator has been available since version 1 of the 'com.micr
number of groupsize used for weight quantization,(default 128). It needs to be a power of 2 and not smaller than 16.
-#### Inputs (3 - 4)
+#### Inputs (3 - 5)
- A : T1
- The input tensor, not quantized
- B : T2
-- 1-dimensional data blob
+- 1- or 2-dimensional data blob
- scales : T1
- quantization scale
-- zero_points (optional) : T2
+- zero_points (optional) : T3
- quantization zero points
+- g_idx (optional) : T4
+- group index
#### Outputs
@@ -2869,8 +2872,12 @@ This version of the operator has been available since version 1 of the 'com.micr
- T1 : tensor(float), tensor(float16)
- Constrain input and output types to float/half_float tensors.
-- T2 : tensor(uint8)
-- Constrain quantized weight types to uint8.
+- T2 : tensor(uint8), tensor(int32)
+- Constrain quantized weight types to uint8/int32.
+- T3 : tensor(uint8), tensor(int32), tensor(float16), tensor(float)
+- Constrain quantized zero point types to uint8/int32/float16/float.
+- T4 : tensor(int32)
+- Constrain the group index tensor to int32.
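To make the packing rules above concrete, here is a minimal NumPy sketch of the common 4-bit case, assuming column-wise blocking along K and the low-nibble-first ordering used by the dequantize kernels in this patch; `pack_b_4bit` is an illustrative helper, not part of the patch.

```python
import numpy as np

def pack_b_4bit(q, block_size):
    """q: unsigned 4-bit codes (0..15) with shape [N, K].
    Returns a uint8 blob of shape [N, n_blocks_per_col, blob_size]."""
    n, k = q.shape
    bits = 4
    n_blocks_per_col = (k + block_size - 1) // block_size
    blob_size = (block_size * bits + 7) // 8  # CeilDiv(block_size * bits, 8)
    padded = np.zeros((n, n_blocks_per_col * block_size), dtype=np.uint8)
    padded[:, :k] = q
    pairs = padded.reshape(n, n_blocks_per_col, blob_size, 2)
    # element 2*j goes to the low nibble, element 2*j+1 to the high nibble
    return (pairs[..., 0] | (pairs[..., 1] << 4)).astype(np.uint8)

# A uint8 zero_points input is packed the same way, giving
# CeilDiv((N * n_blocks_per_col + 1) * bits, 8) bytes; a float/float16 zero_points
# input is left unpacked with the same shape as scales, [N * n_blocks_per_col].
```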
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index 0e60b4622f2fb..71b0def659741 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -470,7 +470,7 @@ Do not modify directly.*
|MatMulFpQ4|*in* A:**T1**
*in* B:**T2**
*in* B_shape:**T3**
*out* Y:**T1**|1+|**T1** = tensor(float)
**T2** = tensor(uint8)
**T3** = tensor(int64)|
|MatMulInteger16|*in* A:**T1**
*in* B:**T2**
*out* Y:**T3**|1+|**T1** = tensor(int16)
**T2** = tensor(int16)
**T3** = tensor(int32)|
|MatMulIntegerToFloat|*in* A:**T1**
*in* B:**T2**
*in* a_scale:**T3**
*in* b_scale:**T3**
*in* a_zero_point:**T1**
*in* b_zero_point:**T2**
*in* bias:**T3**
*out* Y:**T3**|1+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(float)|
-|MatMulNBits|*in* A:**T1**
*in* B:**T2**
*in* scales:**T1**
*in* zero_points:**T2**
*out* Y:**T1**|1+|**T1** = tensor(float)
**T2** = tensor(uint8)|
+|MatMulNBits|*in* A:**T1**
*in* B:**T2**
*in* scales:**T1**
*in* zero_points:**T3**
*in* g_idx:**T4**
*out* Y:**T1**|1+|**T1** = tensor(float)
**T2** = tensor(uint8)
**T3** = tensor(float), tensor(uint8)
**T4** = tensor(int32)|
|MaxpoolWithMask|*in* X:**T**
*in* M:**tensor(int32)**
*out* Y:**T**|1+|**T** = tensor(float)|
|MultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* key_padding_mask:**M**
*in* relative_position_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**T** = tensor(float)|
|MurmurHash3|*in* X:**T1**
*out* Y:**T2**|1+|**T1** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(string), tensor(uint32), tensor(uint64)
**T2** = tensor(int32), tensor(uint32)|
@@ -855,7 +855,7 @@ Do not modify directly.*
|Irfft|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|LongformerAttention|*in* input:**T**
*in* weight:**T**
*in* bias:**T**
*in* mask:**T**
*in* global_weight:**T**
*in* global_bias:**T**
*in* global:**G**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
|MatMulBnb4|*in* A:**T1**
*in* B:**T2**
*in* absmax:**T1**
*out* Y:**T1**|1+|**T1** = tensor(bfloat16), tensor(float), tensor(float16)
**T2** = tensor(uint8)|
-|MatMulNBits|*in* A:**T1**
*in* B:**T2**
*in* scales:**T1**
*in* zero_points:**T2**
*out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(uint8)|
+|MatMulNBits|*in* A:**T1**
*in* B:**T2**
*in* scales:**T1**
*in* zero_points:**T3**
*in* g_idx:**T4**
*out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(uint8)|
|MoE|*in* input:**T**
*in* router_probs:**T**
*in* fc1_experts_weights:**T**
*in* fc2_experts_weights:**T**
*in* fc1_experts_bias:**T**
*in* fc2_experts_bias:**T**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
|MultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* key_padding_mask:**M**
*in* relative_position_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**T** = tensor(float), tensor(float16)|
|NGramRepeatBlock|*in* input_ids:**Tid**
*in* scores:**T**
*out* scores_out:**T**|1+|**T** = tensor(float)
**Tid** = tensor(int64)|
diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
index 166f5c8f52f54..602dd98d8c0d6 100644
--- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
+++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
@@ -1,6 +1,12 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
+#include "contrib_ops/cpu/quantization/matmul_nbits_impl.h"
+
+#include
+#include
+
+#include "core/common/common.h"
#include "core/common/narrow.h"
#include "core/common/safeint.h"
#include "core/framework/op_kernel.h"
@@ -50,6 +56,17 @@ int64_t GetAccuracyLevel(size_t nbits, size_t block_size, int64_t accuracy_level
}
} // namespace
+bool GetType(const NodeArg& node_arg, int32_t& type) {
+ type = ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED;
+ const auto* type_proto = node_arg.TypeAsProto();
+ if (!type_proto || !type_proto->has_tensor_type() || !type_proto->tensor_type().has_elem_type()) {
+ return false;
+ }
+
+ type = type_proto->tensor_type().elem_type();
+ return true;
+}
+
class MatMulNBits final : public OpKernel {
public:
MatMulNBits(const OpKernelInfo& info)
@@ -59,6 +76,17 @@ class MatMulNBits final : public OpKernel {
block_size_{narrow(info.GetAttr("block_size"))},
nbits_{narrow(info.GetAttr("bits"))},
accuracy_level_{GetAccuracyLevel(nbits_, block_size_, info.GetAttr("accuracy_level"))} {
+ const auto& node = info.node();
+ auto input_defs = node.InputDefs();
+ // g_idx
+ if (input_defs.size() > 4) {
+ act_order_ = true;
+ }
+ int32_t type;
+ if (input_defs.size() > 3 && GetType(*input_defs[3], type)) {
+ zero_point_is_not_quant_ = type != ONNX_NAMESPACE::TensorProto_DataType_UINT8;
+ }
+
ORT_ENFORCE(nbits_ == 4,
"Only 4b quantization is supported for MatMulNBits op, additional bits support is planned.");
#ifdef ORT_NEURAL_SPEED
@@ -88,6 +116,8 @@ class MatMulNBits final : public OpKernel {
const size_t N_;
const size_t block_size_;
const size_t nbits_;
+ bool act_order_{false};
+ bool zero_point_is_not_quant_{false};
const int64_t accuracy_level_;
const bool column_wise_quant_{true};
IAllocatorUniquePtr packed_b_;
@@ -105,7 +135,9 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat
/*out*/ bool& is_packed,
/*out*/ PrePackedWeights* prepacked_weights) {
is_packed = false;
-
+ if (act_order_ || zero_point_is_not_quant_) {
+ return Status::OK();
+ }
#if defined(ORT_NEURAL_SPEED)
if (!all_constant_) {
@@ -212,7 +244,6 @@ Status MatMulNBits::UseSharedPrePackedBuffers(std::vector& prep
Status MatMulNBits::Compute(OpKernelContext* ctx) const {
concurrency::ThreadPool* thread_pool = ctx->GetOperatorThreadPool();
-
const Tensor* a = ctx->Input(0);
const auto* a_data = a->Data();
@@ -257,11 +288,14 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const {
#endif // defined(ORT_NEURAL_SPEED)
const Tensor* scales = ctx->Input(2);
- const Tensor* zero_points = ctx->Input(3);
+ const Tensor* zero_points = ctx->InputCount() > 3 ? ctx->Input(3) : nullptr;
+ const Tensor* reorder_idx = ctx->InputCount() > 4 ? ctx->Input(4) : nullptr;
+
const auto* scales_data = scales->Data();
- const auto* zero_points_data = zero_points == nullptr ? nullptr : zero_points->Data();
+ const auto* zero_points_data = zero_points == nullptr ? nullptr : zero_points->DataRaw();
TensorShape b_shape({static_cast(N_), static_cast(K_)});
+ const auto* reorder_idx_data = reorder_idx == nullptr ? nullptr : reorder_idx->Data();
MatMulComputeHelper helper;
ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b_shape, false, true));
@@ -281,8 +315,9 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const {
const size_t K = static_cast(helper.K());
const size_t lda = helper.Lda(false);
- const bool has_single_b_matrix = std::all_of(helper.RightOffsets().begin(), helper.RightOffsets().end(),
- [](size_t offset) { return offset == 0; });
+ const bool has_single_b_matrix =
+ (!act_order_) && (!zero_point_is_not_quant_) &&
+ std::all_of(helper.RightOffsets().begin(), helper.RightOffsets().end(), [](size_t offset) { return offset == 0; });
if (has_single_b_matrix) {
const auto compute_type = static_cast(accuracy_level_);
@@ -328,22 +363,50 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const {
const uint8_t* b_data = b->Data();
const size_t ldb = helper.Ldb(true);
-
AllocatorPtr allocator;
ORT_RETURN_IF_ERROR(ctx->GetTempSpaceAllocator(&allocator));
auto tmp_b_data_ptr = IAllocator::MakeUniquePtr(allocator, SafeInt(K_) * N_);
- // dequantize b, only 4b quantization is supported for now
- MlasDequantizeBlockwise(
- tmp_b_data_ptr.get(), // dequantized output
- b_data, // quantized input
- scales_data, // quantization scales
- zero_points_data, // quantization zero points
- static_cast(block_size_), // quantization block size
- column_wise_quant_, // columnwise quantization or row-wise
- static_cast(K_), // number of rows in quantized input
- static_cast(N_), // number of columns in quantized input
- thread_pool);
-
+ if ((reorder_idx_data == nullptr) && (!zero_points || !zero_points->IsDataType())) {
+ // dequantize b, only 4b quantization is supported for now
+ MlasDequantizeBlockwise(
+ tmp_b_data_ptr.get(), // dequantized output
+ b_data, // quantized input
+ scales_data, // quantization scales
+ static_cast(zero_points_data), // quantization zero points
+ static_cast(block_size_), // quantization block size
+ column_wise_quant_, // columnwise quantization or row-wise
+ static_cast(K_), // number of rows in quantized input
+ static_cast(N_), // number of columns in quantized input
+ thread_pool);
+ } else {
+ ORT_ENFORCE(column_wise_quant_, "Row-wise quantization is not supported for now");
+    // !!!!!!!!!!!!!! naive implementation, needs to be optimized !!!!!!!!!!!!!!
+ if ((zero_points && zero_points->IsDataType())) {
+ DequantizeBlockwise(
+ tmp_b_data_ptr.get(), // dequantized output
+ b_data, // quantized input
+ scales_data, // quantization scales
+ static_cast(zero_points_data), // quantization zero points
+ reorder_idx_data,
+ static_cast(block_size_), // quantization block size
+ column_wise_quant_, // columnwise quantization or row-wise
+ static_cast(K_), // number of rows in quantized input
+ static_cast(N_), // number of columns in quantized input
+ thread_pool);
+ } else {
+ DequantizeBlockwise(
+ tmp_b_data_ptr.get(), // dequantized output
+ b_data, // quantized input
+ scales_data, // quantization scales
+ static_cast(zero_points_data), // quantization zero points
+ reorder_idx_data,
+ static_cast(block_size_), // quantization block size
+ column_wise_quant_, // columnwise quantization or row-wise
+ static_cast(K_), // number of rows in quantized input
+ static_cast(N_), // number of columns in quantized input
+ thread_pool);
+ }
+ }
#if 0 // for debug
auto tm_b_data_ptr_trans = IAllocator::MakeUniquePtr(allocator, SafeInt(K_) * N_);
MlasTranspose(tmp_b_data_ptr.get(), tm_b_data_ptr_trans.get(), N_, K_);
@@ -374,7 +437,9 @@ ONNX_OPERATOR_KERNEL_EX(
kCpuExecutionProvider,
KernelDefBuilder()
.TypeConstraint("T1", DataTypeImpl::GetTensorType())
- .TypeConstraint("T2", DataTypeImpl::GetTensorType()),
+ .TypeConstraint("T2", DataTypeImpl::GetTensorType())
+ .TypeConstraint("T3", {DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()})
+ .TypeConstraint("T4", DataTypeImpl::GetTensorType()),
MatMulNBits);
} // namespace contrib
diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc
new file mode 100644
index 0000000000000..f92e59e990ba5
--- /dev/null
+++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc
@@ -0,0 +1,108 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+#include "contrib_ops/cpu/quantization/matmul_nbits_impl.h"
+
+#include
+#include
+#include
+#include
+#include
+
+#include "core/common/common.h"
+#include "core/framework/float16.h"
+#include "core/providers/common.h"
+#include "core/platform/threadpool.h"
+
+namespace onnxruntime {
+namespace contrib {
+
+template
+void Dequantize4BitsKernelReOrder(
+ T* output, const uint8_t* quant_data, const T* scale_data,
+ const zeroT* zero_points, const int32_t* reorder_idx, int block_size,
+ int groups_per_threadblock, int total_groups, int out_rows, int out_cols,
+ int blockIdx_x, int threadIdx_x) {
+ const int group_id = blockIdx_x * groups_per_threadblock + ((threadIdx_x * 8) / block_size);
+ if (group_id >= total_groups) {
+ return;
+ }
+ const int scales_shape_x = (out_cols + block_size - 1) / block_size;
+ const int zero_point_shape_x = (scales_shape_x + 1) / 2;
+
+ int n_idx = group_id / scales_shape_x;
+ int kb_idx = group_id % scales_shape_x;
+ int element_offset = group_id * block_size + ((threadIdx_x * 8) & (block_size - 1));
+
+ const int out_x = element_offset % (scales_shape_x * block_size);
+ const int out_y = element_offset / (scales_shape_x * block_size);
+ if (out_y >= out_rows || out_x >= out_cols) {
+ return;
+ }
+ T* output_i = output + out_y * out_cols + out_x;
+ uint32_t quant_value = *(reinterpret_cast(quant_data + element_offset / 2));
+ const int remain_x = std::min(8, out_cols - out_x);
+ for (int i = 0; i < remain_x; i++) {
+ int32_t rid = reorder_idx ? reorder_idx[kb_idx * block_size + i] : kb_idx;
+ T scale = *(scale_data + n_idx * scales_shape_x + rid);
+ float zp_f = 8;
+ if (zero_points) {
+ if constexpr (std::is_same_v) {
+ zp_f = *(zero_points + n_idx * scales_shape_x + rid);
+ } else {
+ uint8_t zp = 8;
+ zp = zero_points[n_idx * zero_point_shape_x + rid / 2];
+        zp = (rid & 0x01) ? (zp >> 4) : (zp & 0x0f);
+        zp_f = static_cast<float>(zp);
+      }
+ }
+
+ if constexpr (std::is_same_v) {
+ T zp_adjust = -scale * MLFloat16(zp_f);
+ output_i[i] = static_cast((quant_value >> (4 * i)) & 0xF) * scale + zp_adjust;
+ } else {
+ T zp_adjust = -scale * zp_f;
+ output_i[i] = T((quant_value >> (4 * i)) & 0xF) * scale + zp_adjust;
+ }
+ }
+}
+
+template
+void DequantizeBlockwise(
+ inputT* output, // dequantized output
+ const uint8_t* quant_data, // quantized input
+ const inputT* scales_data, // quantization scales
+ const zeroT* zero_points, // quantization zero points
+ const int32_t* reorder_idx, // reorder_idx for groupwise quantization
+ int32_t block_size, // quantization block size
+ bool, // columnwise quantization or row-wise
+ int32_t K, // number of rows in quantized input
+ int32_t N, // number of columns in quantized input
+ onnxruntime::concurrency::ThreadPool* pool) {
+ auto ceildiv = [](int a, int b) { return (a + b - 1) / b; };
+ constexpr int element_per_thread = 8;
+ int groups_per_threadblock = 256 * element_per_thread / block_size;
+ int groups_per_K = ceildiv(K, block_size);
+  int total_groups = N * groups_per_K;  // total number of quantization groups
+ int blocks_per_grid = static_cast(ceildiv(total_groups, groups_per_threadblock));
+ concurrency::ThreadPool::TrySimpleParallelFor(
+ pool, static_cast(blocks_per_grid),
+ [&](std::ptrdiff_t block_id) {
+ for (int j = 0; j < 256; j++) {
+ Dequantize4BitsKernelReOrder(output, quant_data, scales_data, zero_points,
+ reorder_idx, block_size, groups_per_threadblock,
+ total_groups, N, K, static_cast(block_id), j);
+ }
+ });
+}
+
+template void DequantizeBlockwise(
+ float* output, const uint8_t* quant_data, const float* scales_data,
+ const uint8_t* zero_points, const int32_t* reorder_idx, int32_t block_size,
+ bool columnwise, int32_t K, int32_t N, onnxruntime::concurrency::ThreadPool* thread_pool);
+
+template void DequantizeBlockwise(
+ float* output, const uint8_t* quant_data, const float* scales_data,
+ const float* zero_points, const int32_t* reorder_idx, int32_t block_size,
+ bool columnwise, int32_t K, int32_t N, onnxruntime::concurrency::ThreadPool* thread_pool);
+
+} // namespace contrib
+} // namespace onnxruntime
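For reference, a rough NumPy equivalent of what `Dequantize4BitsKernelReOrder` / `DequantizeBlockwise` compute for the 4-bit, column-wise case with uint8 zero points is sketched below (slow, illustration only; K is assumed padded to a multiple of block_size, and the function name is hypothetical):

```python
import numpy as np

def dequantize_4bit_with_g_idx(quant, scales, zero_points, g_idx, block_size, n, k):
    # quant: packed uint8, shape [n, k // 2]; scales: [n, k_blocks]
    # zero_points: packed uint8, shape [n, (k_blocks + 1) // 2] or None; g_idx: [k]
    out = np.zeros((n, k), dtype=scales.dtype)
    for col in range(n):
        for i in range(k):
            byte = int(quant[col, i // 2])
            q = (byte >> 4) & 0x0F if (i & 1) else byte & 0x0F
            g = int(g_idx[i]) if g_idx is not None else i // block_size
            scale = scales[col, g]
            if zero_points is None:
                zp = 8  # default zero point when none is provided
            else:
                zp_byte = int(zero_points[col, g // 2])
                zp = (zp_byte >> 4) & 0x0F if (g & 1) else zp_byte & 0x0F
            out[col, i] = (q - zp) * scale
    return out
```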
diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.h b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.h
new file mode 100644
index 0000000000000..5061ac5c800a6
--- /dev/null
+++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.h
@@ -0,0 +1,23 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+#include "core/providers/common.h"
+#include "core/platform/threadpool.h"
+
+namespace onnxruntime {
+namespace contrib {
+
+template
+void DequantizeBlockwise(
+ inputT* output, // dequantized output
+ const uint8_t* quant_data, // quantized input
+ const inputT* scales_data, // quantization scales
+ const zeroT* zero_points, // quantization zero points
+    const int32_t* reorder_idx,  // reorder index (g_idx) for groupwise quantization
+ int32_t block_size, // quantization block size
+ bool, // columnwise quantization or row-wise
+ int32_t K, // number of rows in quantized input
+ int32_t N, // number of columns in quantized input
+ onnxruntime::concurrency::ThreadPool* thread_pool);
+
+} // namespace contrib
+} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu
index 6b66f1d84e221..cd6593352008b 100644
--- a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu
+++ b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu
@@ -2,10 +2,12 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
+#include
#include
#include
#include
#include
+#include
#include
#include "core/providers/cuda/cu_inc/common.cuh"
#include "core/providers/cuda/cuda_common.h"
@@ -56,41 +58,94 @@ __device__ __forceinline__ void DequantizeEightElements(uint32_t values_quant, f
}
template
-__global__ void Dequantize4BitsKernel(
+__global__ void Dequantize4BitsKernelReOrder(
T* output,
const uint8_t* quant_data,
const T* scale_data,
const uint8_t* zero_points,
+ const int32_t* reorder_idx,
int block_size,
- int blocks_per_K,
- int blocks_per_threadblock,
- int total_blks,
- int shift) {
- int block_id = blockIdx.x * blocks_per_threadblock + ((threadIdx.x * 8) >> shift);
- if (block_id >= total_blks) {
+ int groups_per_K,
+ int groups_per_threadblock,
+ int total_groups) {
+ int group_id = blockIdx.x * groups_per_threadblock + ((threadIdx.x * 8) / block_size);
+ if (group_id >= total_groups) {
return;
}
- int n_idx = block_id / blocks_per_K;
- int kb_idx = block_id % blocks_per_K;
- int element_offset = block_id * block_size + ((threadIdx.x * 8) & ((1 << shift) - 1));
+ // T __shared__ zero_points_after_reorder[];//K
+ // T __shared__ scales_after_reorder[]; // K
+ // const int num_r_per_thread = k / 256;
+
+ const int zero_point_shape_x = (groups_per_K + 1) / 2;
+ const int scales_shape_x = groups_per_K;
+ int n_idx = group_id / scales_shape_x;
+ int kb_idx = group_id % scales_shape_x;
+ int element_offset = group_id * block_size + ((threadIdx.x * 8) & (block_size - 1));
+ T* output_i = output + element_offset;
+ uint32_t quant_value = *(reinterpret_cast(quant_data + element_offset / 2));
+ for (int i = 0; i < 8; i++) {
+ int32_t rid = reorder_idx[kb_idx * block_size + i];
+ T scale = *(scale_data + n_idx * scales_shape_x + rid);
+ uint8_t zp = 8;
+ if (zero_points) {
+ zp = zero_points[n_idx * zero_point_shape_x + rid / 2];
+ zp = (rid & 0x01) ? (zp >> 4) : (zp & 0x0f);
+ }
+
+ if constexpr (std::is_same_v) {
+ T zp_adjust = -scale * __short2half_rn(zp);
+ output_i[i] = __uint2half_rn((quant_value >> (4 * i)) & 0xF) * scale + zp_adjust;
+ } else {
+ T zp_adjust = -scale * T(zp);
+ output_i[i] = T((quant_value >> (4 * i)) & 0xF) * scale + zp_adjust;
+ }
+ }
+}
+
+template
+__global__ void Dequantize4BitsKernel(
+ T* output,
+ const uint8_t* quant_data,
+ const T* scale_data,
+ const ZeroT* zero_points,
+ int block_size,
+ int groups_per_K,
+ int groups_per_threadblock,
+ int total_groups) {
+ int block_id = blockIdx.x * groups_per_threadblock + ((threadIdx.x * 8) / block_size);
+ if (block_id >= total_groups) {
+ return;
+ }
+ int element_offset = block_id * block_size + ((threadIdx.x * 8) & (block_size - 1));
uint32_t quant_value = *(reinterpret_cast(quant_data + element_offset / 2));
T scale = *(scale_data + block_id);
- uint8_t zp = 8;
- if (zero_points) {
- zp = zero_points[n_idx * ((blocks_per_K + 1)/2) + kb_idx / 2];
- zp = (kb_idx & 0x01) ? (zp >> 4) : (zp & 0x0f);
+ T zero_point_value;
+ if constexpr (std::is_same_v) {
+ const int scales_shape_x = groups_per_K;
+ const int zero_point_shape_x = (groups_per_K + 1) / 2;
+ int kb_idx = block_id % scales_shape_x;
+ int n_idx = block_id / scales_shape_x;
+ uint8_t zp = 8;
+ if (zero_points) {
+ zp = zero_points[n_idx * zero_point_shape_x + kb_idx / 2];
+ zp = (kb_idx & 0x01) ? (zp >> 4) : (zp & 0x0f);
+ }
+ zero_point_value = static_cast(zp);
+ } else {
+ zero_point_value = zero_points? *(zero_points + block_id):static_cast(8);
}
output = output + element_offset;
- DequantizeEightElements(quant_value, scale, static_cast(zp), output);
+ DequantizeEightElements(quant_value, scale, zero_point_value, output);
}
-template
+template
Status Dequantize4Bits(
T* output,
const uint8_t* quant_data,
const T* scales_data,
- const uint8_t* zero_points, // shape: [N, (block_per_K + 1)/2]
+ const ZeroT* zero_points, // shape: [N, (block_per_K + 1)/2]
+ const int32_t* reorder_idx,
int k,
int n,
int block_size,
@@ -98,47 +153,79 @@ Status Dequantize4Bits(
// k is padded and equal to block_per_K * block_size
ORT_ENFORCE(k % block_size == 0, "k must be a multiplier of block_size");
constexpr int element_per_thread = 8;
- int blocks_per_threadblock = GridDim::maxThreadsPerBlock * element_per_thread / block_size;
- int blocks_per_K = k / block_size;
- int total_blks = n * blocks_per_K;
- int blocks_per_grid = static_cast(CeilDiv(n * blocks_per_K, blocks_per_threadblock));
- int shift = static_cast(log2f(float(block_size)));
-
- Dequantize4BitsKernel<<>>(
- output,
- quant_data,
- scales_data,
- zero_points,
- block_size,
- blocks_per_K,
- blocks_per_threadblock,
- total_blks,
- shift);
+ int groups_per_threadblock = GridDim::maxThreadsPerBlock * element_per_thread / block_size;
+ int groups_per_K = k / block_size;
+  int total_groups = n * groups_per_K;  // total number of quantization groups
+ int groups_per_grid = static_cast(CeilDiv(total_groups, groups_per_threadblock));
+ if (!reorder_idx) {
+ Dequantize4BitsKernel<<>>(
+ output,
+ quant_data,
+ scales_data,
+ zero_points,
+ block_size,
+ groups_per_K,
+ groups_per_threadblock,
+ total_groups);
+ } else {
+ // static_assert(std::is_same_v, "ZeroT must be uint8_t");
+ Dequantize4BitsKernelReOrder<<>>(
+ output,
+ quant_data,
+ scales_data,
+ (const uint8_t*)zero_points,
+ reorder_idx,
+ block_size,
+ groups_per_K,
+ groups_per_threadblock,
+ total_groups);
+ }
return Status::OK();
}
-template Status Dequantize4Bits(
+template Status Dequantize4Bits(
float* output,
const uint8_t* quant_data,
const float* scales_data,
const uint8_t* zero_points,
+ const int32_t* reorder_idx,
int k,
int n,
int block_size,
cudaStream_t stream);
-template Status Dequantize4Bits(
+template Status Dequantize4Bits(
half* output,
const uint8_t* quant_data,
const half* scales_data,
const uint8_t* zero_points,
+ const int32_t* reorder_idx,
+ int k,
+ int n,
+ int block_size,
+ cudaStream_t stream);
+template Status Dequantize4Bits(
+ float* output,
+ const uint8_t* quant_data,
+ const float* scales_data,
+ const float* zero_points,
+ const int32_t* reorder_idx,
int k,
int n,
int block_size,
cudaStream_t stream);
-
+template Status Dequantize4Bits(
+ half* output,
+ const uint8_t* quant_data,
+ const half* scales_data,
+ const half* zero_points,
+ const int32_t* reorder_idx,
+ int k,
+ int n,
+ int block_size,
+ cudaStream_t stream);
///////////////////////////////////////////////////////////////////////////////
// A more general block-wise dequantization implementation that supports
// different block sizes and block orientations (row-wise/column-wise).
diff --git a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cuh b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cuh
index f9c09c55fd893..580b5087f3fa3 100644
--- a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cuh
+++ b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cuh
@@ -7,18 +7,18 @@
namespace onnxruntime {
namespace contrib {
namespace cuda {
-template
+template
Status Dequantize4Bits(
T* output,
const uint8_t* quant_data,
const T* scales_data,
- const uint8_t* zero_points,
+ const ZeroT* zero_points,
+ const int32_t* reorder_idx,
int k,
int n,
int block_size,
cudaStream_t stream);
-
/**
* @brief Dequantize a block-wise quantized matrix, and store the result in a
* column major matrix for use in subsequent GEMM. This implementation supports
diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc
index 015df70c8ec3c..1cec6f6a12f1c 100644
--- a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc
+++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc
@@ -1,15 +1,12 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
-//
-// This module define MatMulFp32Q4 operator, it is basically
-// matmul float32 with right hand side being a 2-D matrix
-// pre-packed and block-compacted into int4
-//
-
-#include "core/common/safeint.h"
-#include "core/providers/cuda/cuda_kernel.h"
-#include "core/providers/cuda/shared_inc/fpgeneric.h"
+#include "contrib_ops/cuda/quantization/matmul_nbits.h"
+
+#include
+
+#include "core/common/status.h"
+#include "core/framework/float16.h"
#include "core/providers/cpu/math/matmul_helper.h"
#include "matmul_nbits.cuh"
#include "dequantize_blockwise.cuh"
@@ -19,40 +16,19 @@ namespace contrib {
namespace cuda {
using namespace onnxruntime::cuda;
-template
-class MatMulNBits final : public CudaKernel {
- public:
- MatMulNBits(const OpKernelInfo& info) : CudaKernel(info) {
- ORT_ENFORCE(Status::OK() == info.GetAttr("K", &K_));
- ORT_ENFORCE(Status::OK() == info.GetAttr("N", &N_));
- ORT_ENFORCE(Status::OK() == info.GetAttr("block_size", &block_size_));
- ORT_ENFORCE(Status::OK() == info.GetAttr("bits", &nbits_));
- ORT_ENFORCE(nbits_ == 4,
- "Only 4b quantization is supported for MatMulNBits op,"
- " additional bits support is planned.");
- }
-
- Status ComputeInternal(OpKernelContext* context) const override;
-
- private:
- int64_t K_;
- int64_t N_;
- int64_t block_size_;
- int64_t nbits_;
- bool column_wise_quant_blk_{true};
-};
-
template
Status MatMulNBits::ComputeInternal(OpKernelContext* ctx) const {
const Tensor* a = ctx->Input(0);
const Tensor* b = ctx->Input(1);
const Tensor* scales = ctx->Input(2);
const Tensor* zero_points = ctx->Input(3);
+ const Tensor* reorder_idx = ctx->Input(4);
const auto* a_data = a->Data();
const uint8_t* blob_data = b->Data();
const auto* scales_data = scales->Data();
- const auto* zero_points_data = zero_points == nullptr ? nullptr : zero_points->Data();
+ const auto* zero_points_data = zero_points == nullptr ? nullptr : zero_points->DataRaw();
+ const auto* reorder_idx_data = reorder_idx == nullptr ? nullptr : reorder_idx->Data();
typedef typename ToCudaType::MappedType CudaT;
@@ -67,77 +43,99 @@ Status MatMulNBits::ComputeInternal(OpKernelContext* ctx) const {
// Bail out early if the output is going to be empty
if (Y->Shape().Size() == 0) return Status::OK();
- bool is_4bit_done = TryMatMul4Bits(
- reinterpret_cast(Y->MutableData()),
- reinterpret_cast(a_data),
- blob_data,
- reinterpret_cast(scales_data),
- zero_points_data,
- SafeInt(helper.M()),
- SafeInt(helper.N()),
- SafeInt(helper.K()),
- SafeInt(block_size_),
- SafeInt(GetDeviceProp().sharedMemPerBlock),
- static_cast(ctx->GetComputeStream()->GetHandle()));
- if (!is_4bit_done) {
- int64_t K_padded = (K_ + block_size_ - 1) / block_size_ * block_size_;
- IAllocatorUniquePtr b_data_ptr = GetScratchBuffer(N_ * K_padded, ctx->GetComputeStream());
- auto* b_data = b_data_ptr.get();
- if (column_wise_quant_blk_) {
- // column-wise block
+ bool is_4bit_done = (reorder_idx_data == nullptr) &&
+ (!zero_points || !zero_points->IsDataType()) &&
+ TryMatMul4Bits(
+ reinterpret_cast(Y->MutableData()),
+ reinterpret_cast(a_data),
+ blob_data,
+ reinterpret_cast(scales_data),
+ static_cast(zero_points_data),
+ SafeInt(helper.M()),
+ SafeInt(helper.N()),
+ SafeInt(helper.K()),
+ SafeInt(block_size_),
+ SafeInt(GetDeviceProp().sharedMemPerBlock),
+ static_cast(ctx->GetComputeStream()->GetHandle()));
+
+ if (is_4bit_done) {
+ return Status::OK();
+ }
+
+ int64_t K_padded = (K_ + block_size_ - 1) / block_size_ * block_size_;
+ IAllocatorUniquePtr b_data_ptr = GetScratchBuffer(N_ * K_padded, ctx->GetComputeStream());
+ auto* b_data = b_data_ptr.get();
+ if (column_wise_quant_blk_) {
+ if (reorder_idx) {
+ ORT_ENFORCE(K_padded == reorder_idx->Shape()[0], "K_padded != g_idx->Shape()[0]");
+ }
+ // column-wise block
+ if ((zero_points && zero_points->IsDataType())) {
ORT_RETURN_IF_ERROR(Dequantize4Bits(
reinterpret_cast(b_data),
blob_data,
reinterpret_cast(scales_data),
- zero_points_data,
+ (const CudaT*)zero_points_data,
+ reorder_idx_data,
SafeInt(K_padded),
SafeInt(N_),
SafeInt(block_size_),
static_cast(ctx->GetComputeStream()->GetHandle())));
} else {
- // row-wise block
- K_padded = K_;
-
- ORT_RETURN_IF_ERROR(DequantizeBlockwise4b(
+ ORT_RETURN_IF_ERROR(Dequantize4Bits(
reinterpret_cast(b_data),
blob_data,
reinterpret_cast(scales_data),
- zero_points_data,
- SafeInt(block_size_),
- column_wise_quant_blk_,
- SafeInt(K_),
+ (const uint8_t*)zero_points_data,
+ reorder_idx_data,
+ SafeInt(K_padded),
SafeInt(N_),
+ SafeInt(block_size_),
static_cast(ctx->GetComputeStream()->GetHandle())));
}
+ } else {
+ // row-wise block
+ K_padded = K_;
+
+ ORT_RETURN_IF_ERROR(DequantizeBlockwise4b(
+ reinterpret_cast(b_data),
+ blob_data,
+ reinterpret_cast(scales_data),
+ (const uint8_t*)zero_points_data,
+ SafeInt(block_size_),
+ column_wise_quant_blk_,
+ SafeInt(K_),
+ SafeInt(N_),
+ static_cast(ctx->GetComputeStream()->GetHandle())));
+ }
#if 0
- cudaStreamSynchronize(static_cast(ctx->GetComputeStream()->GetHandle()));
- T* b_data_cpu = new T[K_ * N_];
- cudaMemcpy(b_data_cpu, b_data, K_ * N_ * sizeof(T), cudaMemcpyDeviceToHost);
- delete[] b_data_cpu;
+cudaStreamSynchronize(static_cast(ctx->GetComputeStream()->GetHandle()));
+T* b_data_cpu = new T[K_ * N_];
+cudaMemcpy(b_data_cpu, b_data, K_ * N_ * sizeof(T), cudaMemcpyDeviceToHost);
+delete[] b_data_cpu;
#endif
- const CudaT alpha = ToCudaType::FromFloat(1.f);
- const CudaT zero = ToCudaType::FromFloat(0.f);
-
- if (helper.OutputOffsets().size() == 1) {
- CUBLAS_RETURN_IF_ERROR(cublasGemmHelper(
- GetCublasHandle(ctx),
- CUBLAS_OP_T,
- CUBLAS_OP_N,
- SafeInt(helper.N()),
- SafeInt(helper.M()),
- SafeInt(helper.K()),
- &alpha,
- reinterpret_cast(b_data),
- SafeInt(K_padded),
- reinterpret_cast(a_data),
- helper.Lda(transa),
- &zero,
- reinterpret_cast(Y->MutableData()),
- helper.Ldc(),
- GetDeviceProp(),
- UseTF32()));
- }
+ const CudaT alpha = ToCudaType::FromFloat(1.f);
+ const CudaT zero = ToCudaType::FromFloat(0.f);
+
+ if (helper.OutputOffsets().size() == 1) {
+ CUBLAS_RETURN_IF_ERROR(cublasGemmHelper(
+ GetCublasHandle(ctx),
+ CUBLAS_OP_T,
+ CUBLAS_OP_N,
+ SafeInt(helper.N()),
+ SafeInt(helper.M()),
+ SafeInt(helper.K()),
+ &alpha,
+ reinterpret_cast(b_data),
+ SafeInt(K_padded),
+ reinterpret_cast(a_data),
+ helper.Lda(transa),
+ &zero,
+ reinterpret_cast(Y->MutableData()),
+ helper.Ldc(),
+ GetDeviceProp(),
+ UseTF32()));
}
return Status::OK();
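In pseudo-Python, the dispatch in `ComputeInternal` above boils down to the following (a sketch of the control flow under stated assumptions, not a faithful port): the fused `TryMatMul4Bits` path is used only for the plain layout, and everything else dequantizes B and runs a regular GEMM.

```python
def matmul_nbits_compute(a, packed_b, scales, zero_points, g_idx, k,
                         try_matmul_4bits, dequantize_4bits):
    # plain layout: no act-order reordering and packed uint8 zero points
    plain_layout = g_idx is None and (zero_points is None or zero_points.dtype == "uint8")
    if plain_layout:
        y = try_matmul_4bits(a, packed_b, scales, zero_points)
        if y is not None:  # fused 4-bit kernel handled it
            return y
    # Fallback: dequantize B into a [N, K_padded] buffer, then Y = A @ B_dq[:, :K].T,
    # which is what the CUBLAS_OP_T / CUBLAS_OP_N cublasGemmHelper call expresses.
    b_dq = dequantize_4bits(packed_b, scales, zero_points, g_idx)
    return a @ b_dq[:, :k].T
```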
diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.h b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.h
new file mode 100644
index 0000000000000..f5c2c6c4e4fdf
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.h
@@ -0,0 +1,41 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+//
+// This module define MatMulNBits operator, it is basically
+// matmul float with right hand side being a 2-D matrix
+// pre-packed and block-compacted into int4
+//
+#pragma once
+#include "core/common/safeint.h"
+#include "core/providers/cuda/cuda_kernel.h"
+#include "core/providers/cuda/shared_inc/fpgeneric.h"
+
+namespace onnxruntime {
+namespace contrib {
+namespace cuda {
+using namespace onnxruntime::cuda;
+
+template
+class MatMulNBits final : public CudaKernel {
+ public:
+ MatMulNBits(const OpKernelInfo& info) : CudaKernel(info) {
+ ORT_ENFORCE(Status::OK() == info.GetAttr("K", &K_));
+ ORT_ENFORCE(Status::OK() == info.GetAttr("N", &N_));
+ ORT_ENFORCE(Status::OK() == info.GetAttr("block_size", &block_size_));
+ ORT_ENFORCE(Status::OK() == info.GetAttr("bits", &nbits_));
+ }
+
+ Status ComputeInternal(OpKernelContext* context) const override;
+
+ private:
+ int64_t K_;
+ int64_t N_;
+ int64_t block_size_;
+ int64_t nbits_;
+ bool column_wise_quant_blk_{true};
+};
+
+} // namespace cuda
+} // namespace contrib
+} // namespace onnxruntime
diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
index e33ce20737f80..f06a3785f362d 100644
--- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
@@ -3343,22 +3343,23 @@ MatMulNBits is a MatMul with weight quantized with N bits(e.g., 2, 3, 4, 5, 6, 7
And block_size is not an arbitrary number and must be a power of 2 and not smaller than 16, like 16, 32, 64, 128,..
3. Input B's scale and zero point are specified by input scales and zero_points.
-Input B is stored as uint8_t with shape: [N][n_blocks_per_col][blob_size] in which:
-- n_blocks_per_col = (K + block_size - 1) / block_size
-- blob_size = block_size / 8 * bits
-
- For a block blob. It is stored in format:
- struct Blob {
- uint8 one_bits[(bits & 0x1) * 1 * block_size / 8]; // highest 1 bit for 3, 5, 7 bits quantization
- uint8 two_bits[(bits & 0x2) * 2 * block_size / 8]; // high 2 bits for 2, 6, 7 bits quantization
- uint8 four_bits[(bits & 0x4) * 4 * block_size / 8]; // low 4 bits for 4, 5, 6 bits quantization
- }
+  Input B is stored as uint8_t with shape: [N][n_blocks_per_col][blob_size] in which:
+  - n_blocks_per_col = (K + block_size - 1) / block_size
+  - blob_size = CeilDiv(block_size * bits, 8), where 8 is the number of bits in a uint8_t
+    For all bits from 2-8, a row of data is stored tightly packed into uint8_t values.
+      - For 2, 4 and 8 bits, 4x2bit, 2x4bit and 1x8bit values are stored in one uint8_t.
+          4bit example:
+          |.|.|.|.| |.|.|.|.| = 1 uint8_t (2x4bit)
+      - For 3, 5, 6 and 7 bits, 32x3bit, 32x5bit, 16x6bit and 32x7bit values are stored in 12, 20, 12 and 28 uint8_t respectively; no bits are wasted.
+          3bit example:
+          |.|.|.| |.|.|.| |.|.|.| = 9 bits, which span 2 uint8_t; the highest bit of the second uint8_t is used.
+          The last uint8_t may have some bits unused.
-Input scales is stored in same type as original type of B(float32, float16) with shape like: [N * n_blocks_per_col]
-Input zero_points is stored as uint8_t. If bits <= 4, two zero points are stored as one unit8_t. If bits > 4, one zero point is stored with one unit8_t. Thus, its shape is:
- - [(N * n_blocks_per_col + 1) / 2] if bits <=4
- - [N * n_blocks_per_col] if bits > 4
+Input scales is stored in same type as original type of B(float32, float16) with shape like: [N * n_blocks_per_col]
+Input zero_points is stored either as uint8_t or in the same type as A. When stored as uint8_t, it uses the same packing method as input B, so its shape is:
+  - [CeilDiv((N * n_blocks_per_col + 1) * bits, 8)]
+  If zero_points has the same type as A, it is not packed and has the same shape as scales.
)DOC";
ONNX_CONTRIB_OPERATOR_SCHEMA(MatMulNBits)
@@ -3377,12 +3378,15 @@ Input zero_points is stored as uint8_t. If bits <= 4, two zero points are stored
"type T1.",
AttributeProto::INT, static_cast(0))
.Input(0, "A", "The input tensor, not quantized", "T1")
- .Input(1, "B", "1-dimensional data blob", "T2")
+      .Input(1, "B", "1- or 2-dimensional data blob", "T2")
.Input(2, "scales", "quantization scale", "T1")
- .Input(3, "zero_points", "quantization zero points", "T2", OpSchema::Optional)
+ .Input(3, "zero_points", "quantization zero points", "T3", OpSchema::Optional)
+      .Input(4, "g_idx", "group index", "T4", OpSchema::Optional)
.Output(0, "Y", "tensor. The output tensor has the same rank as the input. ", "T1")
.TypeConstraint("T1", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float/half_float tensors.")
- .TypeConstraint("T2", {"tensor(uint8)"}, "Constrain quantized weight types to uint8.")
+ .TypeConstraint("T2", {"tensor(uint8)", "tensor(int32)"}, "Constrain quantized weight types to uint8/int32.")
+ .TypeConstraint("T3", {"tensor(uint8)", "tensor(int32)", "tensor(float16)", "tensor(float)"}, "Constrain quantized zero point types to uint8/int32/float16/float.")
+      .TypeConstraint("T4", {"tensor(int32)"}, "Constrain the group index tensor to int32.")
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
// Type inference
propagateElemTypeFromInputToOutput(ctx, 0, 0);
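With the extended schema above, a MatMulNBits node that carries a float16 zero_points input and a g_idx input can be built like this (tensor names and shapes here are made up for illustration):

```python
from onnx import helper

# Hypothetical dimensions: B was [K, N] = [4096, 11008] before 4-bit quantization.
node = helper.make_node(
    "MatMulNBits",
    inputs=["A", "B_Q4", "B_scales", "B_zero_points", "B_g_idx"],  # last two are optional
    outputs=["Y"],
    name="attn_qkv_MatMul_Q4",
    domain="com.microsoft",
    K=4096,
    N=11008,
    bits=4,
    block_size=128,
)
```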
diff --git a/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py b/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py
index eb7bbec997d59..a1916e806c5c0 100644
--- a/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py
+++ b/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py
@@ -65,7 +65,7 @@ def __init__(
self,
calibration_data_reader: CalibrationDataReader,
percdamp=0.01,
- blocksize=128,
+ block_size=128,
actorder=False,
mse=False,
perchannel=True,
@@ -79,7 +79,7 @@ def __init__(
a calibration data reader. It enumerates calibration data and generates inputs for the original model.
percdamp:
percent of the average Hessian diagonal to use for dampening.
- blocksize (int, optional):
+ block_size (int, optional):
channel number in one block to execute a GPTQ quantization iteration.
actorder (bool, optional):
whether rearrange Hessian matrix considering the diag's value.
@@ -93,42 +93,285 @@ def __init__(
)
self.calibration_data_reader = calibration_data_reader
self.percdamp = percdamp
- self.blocksize = blocksize
+ self.block_size = block_size
self.actorder = actorder
self.mse = mse
self.perchannel = perchannel
-class MatMul4BitsQuantizer:
- """Perform 4b quantization of constant MatMul weights"""
+class HQQWeightOnlyQuantConfig(WeightOnlyQuantConfig):
+ def __init__(
+ self,
+ block_size=128,
+ bits=4,
+ axis=1,
+ ):
+ """
+ This is a class for HQQ algorithm Weight Only Quant Configuration.
+ HQQ algorithm quant weight without needing calibrate data.
+
+ Args:
+ block_size (int, optional):
+                channel number in one block to execute an HQQ quantization iteration.
+            bits (int, optional):
+                how many bits to use to represent each weight.
+            axis (int, optional):
+                0 or 1. The axis to quantize along. See https://arxiv.org/pdf/2309.15531.pdf.
+ """
+ super().__init__(
+ algorithm="HQQ",
+ )
+ self.block_size = block_size
+ self.bits = bits
+ self.axis = axis
+
+class DefaultWeightOnlyQuantConfig(WeightOnlyQuantConfig):
def __init__(
self,
- model: ModelProto | str,
- block_size: int,
- is_symmetric: bool,
+ block_size: int = 128,
+ is_symmetric: bool = False,
accuracy_level: int | None = None,
- nodes_to_exclude=None,
- algo_config: WeightOnlyQuantConfig = None,
):
- if nodes_to_exclude is None:
- nodes_to_exclude = []
- self.model = ONNXModel(onnx.load(model)) if isinstance(model, str) else ONNXModel(model)
- self.model_path = model if isinstance(model, str) else None
+ super().__init__(algorithm="DEFAULT")
self.block_size = block_size
self.is_symmetric = is_symmetric
+ self.bits = 4
self.accuracy_level = accuracy_level
- self.nodes_to_exclude = set(nodes_to_exclude)
- self.algo_config = algo_config
+
+
+def is_divisible(val1, val2):
+ return int(val2 * np.ceil(val1 / val2)) == val1
+
+
+class HQQWeightOnlyQuantizer:
+ def __init__(
+ self,
+ config: HQQWeightOnlyQuantConfig,
+ ):
+ self.config = config
+
+ # Proximal solver || weight - dequantize(quantize(weight))||_p^p
+ @staticmethod
+ def optimize_weights(
+ tensor,
+ scale,
+ zero,
+ min_max: list[int],
+ axis: int = 0,
+ opt_params: dict = None, # noqa: RUF013
+ verbose=False,
+ ):
+ import torch
+
+ opt_params = {"lp_norm": 0.7, "beta": 1e1, "kappa": 1.01, "iters": 20} if opt_params is None else opt_params
+ lp_norm, beta, kappa, iters = (
+ opt_params["lp_norm"],
+ opt_params["beta"],
+ opt_params["kappa"],
+ opt_params["iters"],
+ )
+
+ dtype = torch.float16 if tensor.is_cuda else torch.float32
+ w_f = tensor.to(dtype)
+ scale = scale.to(dtype)
+ zero = zero.to(dtype)
+
+ if lp_norm == 1:
+
+ def shrink_op(x, beta):
+ return torch.sign(x) * torch.nn.functional.relu(torch.abs(x) - 1.0 / beta)
+
+ else:
+
+ def shrink_op(x, beta, p=lp_norm):
+ return torch.sign(x) * torch.nn.functional.relu(
+ torch.abs(x) - (1.0 / beta) * torch.pow(torch.abs(x) + 1e-8, p - 1)
+ )
+
+ best_error = 1e4
+ for i in range(iters):
+ w_q = torch.round(w_f * scale + zero).clamp(min_max[0], min_max[1])
+ w_r = (w_q - zero) / scale
+ w_e = shrink_op(w_f - w_r, beta)
+ zero = torch.mean(w_q - (w_f - w_e) * scale, axis=axis, keepdim=True)
+ beta *= kappa
+
+ current_error = float(torch.abs(w_f - w_r).mean())
+ if verbose:
+ print(i, np.round(current_error, 6))
+ if current_error < best_error:
+ best_error = current_error
+ else:
+ break
+
+ del w_f, w_q, w_r, w_e
+
+ return scale, zero
@staticmethod
- def __get_initializer(name, graph_path: list[GraphProto]) -> tuple[TensorProto, GraphProto]:
- for gid in range(len(graph_path) - 1, -1, -1):
- graph = graph_path[gid]
- for tensor in graph.initializer:
- if tensor.name == name:
- return tensor, graph
- return None, None
+ def pack_on_row_fast_248bit(pack_tensor, ori_int_tensor, bits):
+ if pack_tensor.shape[0] == ori_int_tensor.shape[0]:
+ ori_int_tensor = ori_int_tensor.T
+ pack_tensor = pack_tensor.T
+ if bits in [2, 4, 8]:
+ compress_ratio = pack_tensor.element_size() * 8 // bits
+ for j in range(0, compress_ratio):
+ pack_tensor[0:] |= ori_int_tensor[j::compress_ratio] << (bits * (j))
+ else:
+ raise NotImplementedError("Only 2,4,8 bits are supported.")
+
+ # from Official implementation of Half-Quadratic Quantization (HQQ)
+ def quantize_internal(
+ self, tensor, bits=4, channel_wise=True, group_size=64, optimize=True, round_zero=True, axis=1
+ ):
+ import torch
+
+ weight = tensor.float()
+ ori_shape = weight.shape
+
+ pad_len = (group_size - ori_shape[axis] % group_size) % group_size
+ if axis == 1:
+ weight = torch.nn.functional.pad(weight, (0, pad_len), "constant", 0)
+ else:
+ weight = torch.nn.functional.pad(weight, (0, 0, 0, pad_len), "constant", 0)
+ shape = weight.shape
+
+ # Reshape for grouping
+ if (group_size is not None) and channel_wise:
+ weight = weight.reshape([-1, group_size]) if (axis == 1) else weight.reshape([group_size, -1])
+
+ # Get min/max values
+ if channel_wise is False:
+ _min, _max = weight.min(), weight.max()
+ optimize = False
+ else:
+ _min = weight.min(axis=axis, keepdim=True)[0]
+ _max = weight.max(axis=axis, keepdim=True)[0]
+
+ max_v = 2**bits - 1
+ min_v = 0
+ min_max = [min_v, max_v]
+
+ # Note: here we work with the inverse of the scale to avoid division and quantize instead via weight*scale + zero, the scale is inverted later on.
+ # clamp to avoid half-precision problems
+ scale = (max_v / (_max - _min)).clamp(max=2e4)
+        # guard against a zero range (_max == _min), which would make the scale divide by zero
+ min_max_axis = _max - _min
+ if (min_max_axis == 0).sum().item() > 0:
+ min_max_axis[min_max_axis == 0] = max_v
+ scale = (max_v / min_max_axis).clamp(max=2e4)
+ zero = -_min * scale
+
+ if round_zero:
+ zero = torch.round(zero)
+
+ # Fine-tune weights
+ if optimize:
+ scale, zero = self.optimize_weights(tensor=weight, scale=scale, zero=zero, min_max=min_max, axis=axis)
+
+ # Quantize
+ # Necessary for fake quantization backprop
+ w_q = torch.round(weight * scale + zero).clamp(min_max[0], min_max[1])
+ w_q = w_q.reshape(shape).int()
+
+ scale = 1.0 / scale
+ if axis == 1:
+ scale = scale.reshape(shape[0], -1)
+ zero = zero.reshape(shape[0], -1)
+ else:
+ scale = scale.reshape(-1, shape[-1])
+ zero = zero.reshape(-1, shape[-1])
+ # cleanup
+ del weight, _min, _max
+
+ return w_q, scale.to(tensor.dtype), zero.to(tensor.dtype)
+
+ def quantize(self, node: NodeProto, graph_stack: list[GraphProto]):
+ """If the node is MatMul with fp32 const weight, quantize the weight with int4, and return the new node"""
+ if node.op_type != "MatMul":
+ return node # only care about MatMul for now
+ import torch
+
+ logger.info(f"start to quantize {node.name} ...")
+ inputB = node.input[1] # noqa: N806
+ b_pb, bs_graph = get_initializer(inputB, graph_stack)
+ if b_pb is None:
+ logger.info("MatMul doesn't have const weight. Skip to quantize")
+ return node # only care about constant weight
+
+ b_array = onnx.numpy_helper.to_array(b_pb)
+ if len(b_array.shape) != 2:
+ logger.info("MatMul weight is not 2D. Skip to quantize")
+ return node # can only process 2-D matrix
+ b_array_torch = torch.from_numpy(b_array)
+ if torch.cuda.is_available():
+ b_array_torch = b_array_torch.cuda()
+ quant_weight_torch, scales_torch, zero_points_torch = self.quantize_internal(
+ b_array_torch.T, bits=self.config.bits, group_size=self.config.block_size
+ )
+ quant_weight_torch = quant_weight_torch.contiguous()
+ scales_torch = scales_torch.contiguous()
+ zero_points_torch = zero_points_torch.contiguous()
+
+ packed_torch = torch.zeros(
+ (quant_weight_torch.shape[0], quant_weight_torch.shape[1] // 2),
+ dtype=torch.uint8,
+ device=quant_weight_torch.device,
+ )
+ self.pack_on_row_fast_248bit(packed_torch, quant_weight_torch, self.config.bits)
+ scales = scales_torch.cpu().numpy()
+ zero_points = zero_points_torch.cpu().numpy()
+ b_quant = onnx.numpy_helper.from_array(packed_torch.cpu().numpy())
+ b_quant.name = b_pb.name + "_Q4"
+ for input in bs_graph.input:
+ if input.name == inputB:
+ bs_graph.input.remove(input)
+ break
+
+ scales_tensor = onnx.numpy_helper.from_array(scales)
+ scales_tensor.name = b_pb.name + "_scales"
+ bs_graph.initializer.extend([b_quant, scales_tensor])
+
+ input_names = [node.input[0], b_quant.name, scales_tensor.name]
+ zp_tensor = onnx.numpy_helper.from_array(zero_points)
+ zp_tensor.name = b_pb.name + "_zero_points"
+ bs_graph.initializer.extend([zp_tensor])
+ input_names.append(zp_tensor.name)
+
+ kwargs = {}
+ rows, cols = b_array.shape
+ kwargs["K"] = rows
+ kwargs["N"] = cols
+ kwargs["bits"] = self.config.bits
+ kwargs["block_size"] = self.config.block_size
+
+ matmul_q4_node = onnx.helper.make_node(
+ "MatMulNBits",
+ inputs=input_names,
+ outputs=[node.output[0]],
+ name=node.name + "_Q4" if node.name else "",
+ domain="com.microsoft",
+ **kwargs,
+ )
+
+ logger.info(f"complete quantization of {node.name} ...")
+
+ return matmul_q4_node
+
+
+def get_initializer(name, graph_path: list[GraphProto]) -> tuple[TensorProto, GraphProto]:
+ for gid in range(len(graph_path) - 1, -1, -1):
+ graph = graph_path[gid]
+ for tensor in graph.initializer:
+ if tensor.name == name:
+ return tensor, graph
+ return None, None
+
+
+class DefaultWeightOnlyQuantizer:
+ def __init__(self, config: DefaultWeightOnlyQuantConfig):
+ self.config = config
def int4_block_quant(self, fp32weight: npt.ArrayLike) -> np.ndarray:
"""4b quantize fp32 weight to a blob"""
@@ -137,7 +380,7 @@ def int4_block_quant(self, fp32weight: npt.ArrayLike) -> np.ndarray:
raise ValueError("Current int4 block quantization only supports 2D tensors!")
rows, cols = fp32weight.shape
- block_size = self.block_size
+ block_size = self.config.block_size
blob_size = block_size // 2
k_blocks = (rows + block_size - 1) // block_size
padded_rows = k_blocks * block_size
@@ -149,23 +392,19 @@ def int4_block_quant(self, fp32weight: npt.ArrayLike) -> np.ndarray:
packed = np.zeros((cols, k_blocks, blob_size), dtype="uint8")
scales = np.zeros((cols * k_blocks), dtype=fp32weight.dtype)
zero_point = np.zeros(cols * ((k_blocks + 1) // 2), dtype="uint8")
- quantize_matmul_4bits(packed, fp32weight, scales, zero_point, block_size, cols, rows, self.is_symmetric)
+ quantize_matmul_4bits(packed, fp32weight, scales, zero_point, block_size, cols, rows, self.config.is_symmetric)
return (packed, scales, zero_point)
- def _q4_matmul_node_weight(self, node: NodeProto, graph_stack: list[GraphProto]) -> NodeProto:
+ def quantize(self, node: NodeProto, graph_stack: list[GraphProto]) -> NodeProto:
"""If the node is MatMul with fp32 const weight, quantize the weight with int4, and return the new node"""
if node.op_type != "MatMul":
return node # only care about MatMul for now
logger.info(f"start to quantize {node.name} ...")
- if node.name in self.nodes_to_exclude:
- logger.info(f"exclude to quantize {node.name} as specified by nodes_to_exclude...")
- return node
-
inputB = node.input[1] # noqa: N806
- B, Bs_graph = MatMul4BitsQuantizer.__get_initializer(inputB, graph_stack) # noqa: N806
+ B, Bs_graph = get_initializer(inputB, graph_stack) # noqa: N806
if B is None:
logger.info("MatMul doesn't have const weight. Skip to quantize")
return node # only care about constant weight
@@ -188,7 +427,7 @@ def _q4_matmul_node_weight(self, node: NodeProto, graph_stack: list[GraphProto])
Bs_graph.initializer.extend([B_quant, scales_tensor])
input_names = [node.input[0], B_quant.name, scales_tensor.name]
- if not self.is_symmetric:
+ if not self.config.is_symmetric:
zp_tensor = onnx.numpy_helper.from_array(zero_points)
zp_tensor.name = B.name + "_zero_points"
Bs_graph.initializer.extend([zp_tensor])
@@ -199,8 +438,8 @@ def _q4_matmul_node_weight(self, node: NodeProto, graph_stack: list[GraphProto])
kwargs["K"] = rows
kwargs["N"] = cols
kwargs["bits"] = 4
- kwargs["block_size"] = self.block_size
- if self.accuracy_level is not None:
+ kwargs["block_size"] = self.config.block_size
+ if self.config.accuracy_level is not None:
-            kwargs["accuracy_level"] = self.accuracy_level
+            kwargs["accuracy_level"] = self.config.accuracy_level
matmul_q4_node = onnx.helper.make_node(
@@ -216,6 +455,38 @@ def _q4_matmul_node_weight(self, node: NodeProto, graph_stack: list[GraphProto])
return matmul_q4_node
+
+class MatMul4BitsQuantizer:
+ """Perform 4b quantization of constant MatMul weights"""
+
+ def __init__(
+ self,
+ model: ModelProto | str,
+ block_size: int = 128,
+ is_symmetric: bool = False,
+ accuracy_level: int | None = None,
+ nodes_to_exclude=None,
+ algo_config: WeightOnlyQuantConfig = None,
+ ):
+ if nodes_to_exclude is None:
+ nodes_to_exclude = []
+ self.model = ONNXModel(onnx.load(model)) if isinstance(model, str) else ONNXModel(model)
+ self.model_path = model if isinstance(model, str) else None
+ self.block_size = block_size
+ self.is_symmetric = is_symmetric
+ self.accuracy_level = accuracy_level
+ self.nodes_to_exclude = set(nodes_to_exclude)
+ self.node_quantizer = None
+ if algo_config is None:
+ algo_config = DefaultWeightOnlyQuantConfig(
+ block_size=block_size, is_symmetric=is_symmetric, accuracy_level=accuracy_level
+ )
+ self.algo_config = algo_config
+ if algo_config.algorithm == "HQQ":
+ self.node_quantizer = HQQWeightOnlyQuantizer(self.algo_config)
+ elif algo_config.algorithm == "DEFAULT":
+ self.node_quantizer = DefaultWeightOnlyQuantizer(self.algo_config)
+
def _process_subgraph(self, graph_stack: list[GraphProto]):
new_nodes = []
graph = graph_stack[-1]
@@ -246,8 +517,15 @@ def _process_subgraph(self, graph_stack: list[GraphProto]):
node = onnx.helper.make_node( # noqa: PLW2901
node.op_type, node.input, node.output, name=node.name, **kwargs
)
-
- new_nodes.append(self._q4_matmul_node_weight(node, graph_stack))
+ out_node = None
+ if node.name in self.nodes_to_exclude:
+ logger.info(f"exclude to quantize {node.name} as specified by nodes_to_exclude...")
+ out_node = node
+            else:
+                out_node = self.node_quantizer.quantize(node, graph_stack)
+ new_nodes.append(out_node)
graph.ClearField("node")
graph.node.extend(new_nodes)
@@ -300,7 +578,7 @@ def inc_dataloader():
from neural_compressor.adaptor.ox_utils.weight_only import gptq_quantize
kwargs["percdamp"] = self.algo_config.percdamp
- kwargs["blocksize"] = self.algo_config.blocksize
+ kwargs["blocksize"] = self.algo_config.block_size
kwargs["actorder"] = self.algo_config.actorder
kwargs["mse"] = self.algo_config.mse
kwargs["perchannel"] = self.algo_config.perchannel
@@ -316,7 +594,7 @@ def inc_dataloader():
logger.info(f"complete quantization of model with {algorithm} algorithm.")
def process(self):
- if self.algo_config is None:
+ if self.algo_config.algorithm in ["HQQ", "DEFAULT"]:
# use a stack to keep track of sub-graphs
graph_stack = [self.model.graph()]
opset_import = self.model.opset_import()
@@ -327,7 +605,6 @@ def process(self):
has_ms_domain = True
if not has_ms_domain:
opset_import.extend([onnx.helper.make_opsetid("com.microsoft", 1)])
-
self._process_subgraph(graph_stack)
self.model.clean_initializers()
else:
@@ -366,6 +643,14 @@ def parse_args():
parser.add_argument("--input_model", required=True, help="Path to the input model file")
parser.add_argument("--output_model", required=True, help="Path to the output model file")
parser.add_argument("--block_size", required=False, default=32, type=int, help="Block size for quantization")
+ parser.add_argument(
+ "--quant_method",
+ default="default",
+ type=str,
+ choices=["default", "hqq"],
+ help="the algorithm used to quantize weight",
+ )
+ parser.add_argument("--bits", default=4, type=int, help="the target bits to represent weight")
parser.add_argument(
"--symmetric",
required=False,
@@ -411,12 +696,24 @@ def parse_args():
raise Exception(f"file {output_model_path} already exists")
model = onnx.load(input_model_path)
+ if args.quant_method == "hqq":
+ quant_config = HQQWeightOnlyQuantConfig(block_size=args.block_size, bits=args.bits)
+ elif args.quant_method == "default":
+ quant_config = DefaultWeightOnlyQuantConfig(
+ block_size=args.block_size, is_symmetric=args.symmetric, accuracy_level=args.accuracy_level
+ )
+ elif args.quant_method == "rtn":
+ quant_config = RTNWeightOnlyQuantConfig()
+ elif args.quant_method == "gptq":
+ quant_config = GPTQWeightOnlyQuantConfig(block_size=args.block_size)
+ else:
+ raise ValueError(f"Unsupported quantization method: {args.quant_method}")
+
quant = MatMul4BitsQuantizer(
model=model,
- block_size=args.block_size,
- is_symmetric=args.symmetric,
accuracy_level=args.accuracy_level,
nodes_to_exclude=args.nodes_to_exclude,
+ algo_config=quant_config,
)
quant.process()
quant.model.save_model_to_file(output_model_path, True)
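
For readers of this diff, here is a minimal sketch of driving the new `algo_config`-based API directly from Python. File paths and the block size are placeholders; only classes introduced above are used.

```python
# Sketch only: quantize an FP32 ONNX model to 4-bit MatMulNBits via the new algo_config path.
# "model_fp32.onnx" / "model_int4.onnx" are placeholder paths.
import onnx
from onnxruntime.quantization import matmul_4bits_quantizer

model = onnx.load("model_fp32.onnx")

# Pick one of the configs added in this change:
algo_config = matmul_4bits_quantizer.HQQWeightOnlyQuantConfig(block_size=32, bits=4)
# algo_config = matmul_4bits_quantizer.DefaultWeightOnlyQuantConfig(
#     block_size=32, is_symmetric=True, accuracy_level=4
# )

quant = matmul_4bits_quantizer.MatMul4BitsQuantizer(
    model,
    algo_config=algo_config,  # selects HQQWeightOnlyQuantizer or DefaultWeightOnlyQuantizer
    # nodes_to_exclude=["some_matmul_node_name"],  # optionally keep named MatMuls in full precision
)
quant.process()
quant.model.save_model_to_file("model_int4.onnx", True)  # True => store weights as external data
```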
diff --git a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc
index 2ad20eafc2ef1..d294fd4e2b0e0 100644
--- a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc
+++ b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc
@@ -2,6 +2,7 @@
// Licensed under the MIT License.
#ifndef ORT_MINIMAL_BUILD
+#include
#include "core/common/span_utils.h"
#include "core/framework/tensor.h"
@@ -66,7 +67,9 @@ void QuantizeDequantize(std::vector<float>& raw_vals,
}
void RunTest(int64_t M, int64_t N, int64_t K, int64_t block_size, int64_t accuracy_level,
- bool has_zeropoint, bool use_float16, float fp16_abs_error = 0.02f) {
+ bool has_zeropoint, bool use_float16, bool has_g_idx = false,
+ bool zp_is_4bit = true, float fp16_abs_error = 0.02f) {
+ zp_is_4bit = zp_is_4bit | has_g_idx;
RandomValueGenerator random{1234};
std::vector<float> input0_vals(random.Gaussian<float>(std::vector<int64_t>({M, K}), 0.0f, 0.25f));
std::vector<float> input1_f_vals(random.Gaussian<float>(std::vector<int64_t>({K, N}), 0.0f, 0.25f));
@@ -113,12 +116,40 @@ void RunTest(int64_t M, int64_t N, int64_t K, int64_t block_size, int64_t accura
test.AddAttribute("block_size", block_size);
test.AddAttribute("bits", QBits);
test.AddAttribute("accuracy_level", accuracy_level);
+ auto ceildiv = [](int64_t a, int64_t b) { return (a + b - 1) / b; };
+
if (use_float16) {
test.AddInput("A", {M, K}, ToFloat16(input0_vals), false);
test.AddInput("B", {q_cols, q_rows}, input1_vals, true);
test.AddInput("scales", {static_cast(q_scale_size)}, ToFloat16(scales), true);
if (has_zeropoint) {
- test.AddInput("zero_points", {static_cast(q_zp_size_in_bytes)}, zp, true);
+ if (zp_is_4bit) {
+ test.AddInput("zero_points", {static_cast(q_zp_size_in_bytes)}, zp, true);
+ } else {
+ std::vector zp_f;
+ zp_f.reserve(q_zp_size_in_bytes * 2);
+ for (size_t i = 0; i < zp.size(); i++) {
+ zp_f.push_back(static_cast(zp[i] & 0xf));
+ zp_f.push_back(static_cast((zp[i] >> 4) & 0xf));
+ }
+ size_t ind = zp_f.size() - 1;
+ while (zp_f.size() != q_scale_size) {
+ zp_f.erase(zp_f.begin() + ind);
+ ind -= q_scale_size / N + 1;
+ }
+
+ test.AddInput("zero_points", {static_cast(q_scale_size)}, ToFloat16(zp_f), true);
+ }
+ } else {
+ test.AddInput("", {0}, {});
+ }
+ if (has_g_idx) {
+ int K_pad = gsl::narrow(ceildiv(K, block_size) * block_size);
+ std::vector g_idx(K_pad);
+ for (int64_t i = 0; i < K_pad; i++) {
+ g_idx[i] = gsl::narrow(i / block_size);
+ }
+ test.AddInput("g_idx", {static_cast(K_pad)}, g_idx, true);
}
test.AddOutput("Y", {M, N}, ToFloat16(expected_vals));
@@ -132,9 +163,34 @@ void RunTest(int64_t M, int64_t N, int64_t K, int64_t block_size, int64_t accura
test.AddInput("B", {q_cols, q_rows}, input1_vals, true);
test.AddInput("scales", {static_cast(q_scale_size)}, scales, true);
if (has_zeropoint) {
- test.AddInput("zero_points", {static_cast(q_zp_size_in_bytes)}, zp, true);
- }
+ if (zp_is_4bit) {
+ test.AddInput("zero_points", {static_cast(q_zp_size_in_bytes)}, zp, true);
+ } else {
+ std::vector zp_f;
+ zp_f.reserve(q_zp_size_in_bytes * 2);
+ for (size_t i = 0; i < zp.size(); i++) {
+ zp_f.push_back(static_cast(zp[i] & 0xf));
+ zp_f.push_back(static_cast((zp[i] >> 4) & 0xf));
+ }
+ size_t ind = zp_f.size() - 1;
+ while (zp_f.size() != q_scale_size) {
+ zp_f.erase(zp_f.begin() + ind);
+ ind -= q_scale_size / N + 1;
+ }
+ test.AddInput("zero_points", {static_cast(q_scale_size)}, zp_f, true);
+ }
+ } else {
+ test.AddInput("", {0}, {});
+ }
+ if (has_g_idx) {
+ int K_pad = gsl::narrow(ceildiv(K, block_size) * block_size);
+ std::vector g_idx(K_pad);
+ for (int64_t i = 0; i < K_pad; i++) {
+ g_idx[i] = gsl::narrow(i / block_size);
+ }
+ test.AddInput("g_idx", {static_cast(K_pad)}, g_idx, true);
+ }
test.AddOutput("Y", {M, N}, expected_vals);
if (accuracy_level == 4) {
test.SetOutputAbsErr("Y", 0.1f);
@@ -158,6 +214,8 @@ TEST(MatMulNBits, Float32) {
for (auto accuracy_level : {0}) {
RunTest(M, N, K, block_size, accuracy_level, false, false);
RunTest(M, N, K, block_size, accuracy_level, true, false);
+ RunTest(M, N, K, block_size, accuracy_level, false, false, true);
+ RunTest(M, N, K, block_size, accuracy_level, true, false, false, false);
}
#endif
}
@@ -172,8 +230,10 @@ TEST(MatMulNBits, Float16) {
for (auto N : {1, 2, 32, 288}) {
for (auto K : {16, 32, 64, 128, 256, 1024, 93, 1234}) {
for (auto block_size : {16, 32, 64, 128}) {
- RunTest(M, N, K, block_size, 0, false, true);
- RunTest(M, N, K, block_size, 0, true, true);
+ for (auto has_gidx : {true, false}) {
+ RunTest(M, N, K, block_size, 0, false, true, has_gidx);
+ RunTest(M, N, K, block_size, 0, true, true, has_gidx, false);
+ }
}
}
}
@@ -183,9 +243,9 @@ TEST(MatMulNBits, Float16) {
TEST(MatMulNBits, Float16Large) {
for (auto block_size : {16, 32, 64, 128}) {
for (auto symmetric : {false, true}) {
- RunTest(1, 4096, 4096, block_size, 0, symmetric, true, 0.05f);
- RunTest(1, 4096, 11008, block_size, 0, symmetric, true, 0.05f);
- RunTest(1, 11008, 4096, block_size, 0, symmetric, true, 0.05f);
+ RunTest(1, 4096, 4096, block_size, 0, symmetric, true, false, true, 0.05f);
+ RunTest(1, 4096, 11008, block_size, 0, symmetric, true, false, true, 0.05f);
+ RunTest(1, 11008, 4096, block_size, 0, symmetric, true, false, true, 0.05f);
}
}
}
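
The two test branches above share the same zero-point unpacking and `g_idx` construction. The following Python transcription is purely illustrative (the helper names are mine, not part of the change) and assumes two 4-bit zero points are packed per byte, with a per-column padding nibble when the block count per column is odd.

```python
# Illustrative Python rendering of the C++ test helpers above; not part of the diff.
def unpack_zero_points(zp, q_scale_size, N):
    """Expand packed uint8 zero points (two 4-bit values per byte) to one float per block,
    then drop the per-column padding nibbles so len(result) == q_scale_size."""
    zp_f = []
    for byte in zp:
        zp_f.append(float(byte & 0xF))
        zp_f.append(float((byte >> 4) & 0xF))
    ind = len(zp_f) - 1
    while len(zp_f) != q_scale_size:
        del zp_f[ind]                 # remove the padding nibble at the end of a column
        ind -= q_scale_size // N + 1  # step back one packed column (n_blocks_per_col + 1 entries)
    return zp_f


def make_g_idx(K, block_size):
    """Group index for each row of K, padded up to a whole number of blocks."""
    k_pad = (K + block_size - 1) // block_size * block_size
    return [i // block_size for i in range(k_pad)]
```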
diff --git a/onnxruntime/test/python/quantization/op_test_utils.py b/onnxruntime/test/python/quantization/op_test_utils.py
index c1bbb49f10c7e..b30282f2ab41f 100644
--- a/onnxruntime/test/python/quantization/op_test_utils.py
+++ b/onnxruntime/test/python/quantization/op_test_utils.py
@@ -358,6 +358,7 @@ def check_model_correctness(
model_onnx = onnx.load(f)
ops_set = set(node.op_type for node in model_onnx.graph.node)
check_reference_evaluator = not (ops_set & {"EmbedLayerNormalization", "Conv", "Attention", "Transpose"})
+ check_target_evaluator = False
with open(model_path_to_check, "rb") as f:
model_check = onnx.load(f)
@@ -413,7 +414,7 @@ def check_model_correctness(
check_sign_f8_quantization(model_path_origin, model_path_to_check)
# Verifies the expected outputs.
- if check_reference_evaluator and onnx_recent_enough:
+ if check_target_evaluator and onnx_recent_enough:
if op_matmul:
reference_new_ops = [QLinearMatMul]
else:
diff --git a/onnxruntime/test/python/quantization/test_op_matmul_4bits.py b/onnxruntime/test/python/quantization/test_op_matmul_4bits.py
index 73dae08af8ece..88e5052db4e2e 100644
--- a/onnxruntime/test/python/quantization/test_op_matmul_4bits.py
+++ b/onnxruntime/test/python/quantization/test_op_matmul_4bits.py
@@ -125,7 +125,10 @@ def quant_test(
from onnxruntime.quantization import matmul_4bits_quantizer
model = quant_utils.load_model_with_shape_infer(Path(model_fp32_path))
- quant = matmul_4bits_quantizer.MatMul4BitsQuantizer(model, block_size, is_symmetric)
+ quant_config = matmul_4bits_quantizer.DefaultWeightOnlyQuantConfig(
+ block_size=block_size, is_symmetric=is_symmetric
+ )
+ quant = matmul_4bits_quantizer.MatMul4BitsQuantizer(model, algo_config=quant_config)
quant.process()
quant.model.save_model_to_file(model_int4_path, False)
@@ -165,6 +168,9 @@ def quant_test_with_algo(
elif algorithm == "GPTQ":
# test GPTQ algorithm
algo_config = matmul_4bits_quantizer.GPTQWeightOnlyQuantConfig(calibration_data_reader=data_reader)
+ elif algorithm == "HQQ":
+ # test HQQ algorithm
+ algo_config = matmul_4bits_quantizer.HQQWeightOnlyQuantConfig(block_size=block_size)
model = quant_utils.load_model_with_shape_infer(Path(model_fp32_path))
quant = matmul_4bits_quantizer.MatMul4BitsQuantizer(model, block_size, is_symmetric, algo_config=algo_config)
@@ -227,6 +233,17 @@ def test_quantize_matmul_int4_using_gptq_algo(self):
data_reader = self.input_feeds(1, {"input": [100, 52]})
self.quant_test_with_algo("GPTQ", model_fp32_path, data_reader, 32, False)
+ @unittest.skipIf(
+ find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_4bits"
+ )
+ def test_quantize_matmul_int4_using_hqq_algo(self):
+ if not find_spec("torch"):
+ self.skipTest("skip test_hqq_quant since torch is not installed")
+ model_fp32_path = str(Path(self._tmp_model_dir.name).joinpath("matmul_fp32_offset.onnx").absolute())
+ self.construct_model_matmul(model_fp32_path, symmetric=False)
+ data_reader = self.input_feeds(1, {"input": [100, 52]})
+ self.quant_test_with_algo("HQQ", model_fp32_path, data_reader, 32, False)
+
if __name__ == "__main__":
unittest.main()
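
Finally, for anyone wiring the operator up by hand, a MatMulNBits node with the optional `zero_points` and `g_idx` inputs could be built roughly as below. This is a sketch: tensor names and sizes are placeholders, and the attribute values follow what the tests above exercise plus the operator's `K`/`N` attributes.

```python
# Sketch: constructing a MatMulNBits node that consumes the optional g_idx input.
# Tensor names ("A", "B_packed", ...) and the sizes are placeholders.
import onnx

K, N, block_size, bits = 1024, 4096, 32, 4
node = onnx.helper.make_node(
    "MatMulNBits",
    inputs=["A", "B_packed", "scales", "zero_points", "g_idx"],
    outputs=["Y"],
    domain="com.microsoft",
    K=K,
    N=N,
    bits=bits,
    block_size=block_size,
    accuracy_level=0,  # 0 = default compute type, as exercised in the tests above
)
```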