Skip to content

Commit

Permalink
[quant] supports act_order inputs in Matmulnbits and new quantization…
Browse files Browse the repository at this point in the history
… algorithm "hqq" (#19106)

### Description
<!-- Describe your changes. -->
1. Support quantized GPTQ weight in huggingface like
[TheBloke/Llama-2-7B-Chat-GPTQ](https://huggingface.co/TheBloke/Llama-2-7B-Chat-GPTQ)
2. Support Act_order for GPTQ
3. Support [HQQ](https://mobiusml.github.io/hqq_blog/) algorithm to
quantize matmul weight and add quant script



### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
  • Loading branch information
wejoncy authored Mar 5, 2024
1 parent 2a5c9b8 commit 7e613ee
Show file tree
Hide file tree
Showing 14 changed files with 942 additions and 234 deletions.
43 changes: 25 additions & 18 deletions docs/ContribOperators.md
Original file line number Diff line number Diff line change
Expand Up @@ -2808,22 +2808,23 @@ This version of the operator has been available since version 1 of the 'com.micr
And block_size is not an arbitrary number and must be a power of 2 and not smaller than 16, like 16, 32, 64, 128,..
3. Input B's scale and zero point are specified by input scales and zero_points.

Input B is stored as uint8_t with shape: [N][n_blocks_per_col][blob_size] in which:
- n_blocks_per_col = (K + block_size - 1) / block_size
- blob_size = block_size / 8 * bits
Input is stored as uint8_t with shape: [N][n_blocks_per_col][blob_size] in which:
- n_blocks_per_col = (K + block_size - 1) / block_size
- blob_size = CeilDiv(block_size * bits, bitsof(uint8_t)<8>)
For all bits from 2-8, a row of data is stored squeezely and represented by uint8_t.
- for 2,4,8 bits, 4x2bit,2x4bit,1x8bit are stored in one uint8_t.
4bit example:
|.|.|.|.| .|.|.|.| =uint8_t (2x4bit)
- for 3,5,6,7 bits, 32x3bit,32x5bit,16x6bit,32x7bit are stored in 12xuint8_t,20xuint8_t,12xuint8_t,28xuint8_t separately. no bits are wasted.
3bit example:
|.|.|. |.|.|. |.|.|. = 9bit, which across 2 uint8_t, the highest bit for the second uint8_t is used.
The last uint_8 may have some bits unused.

For a block blob. It is stored in format:
struct Blob {
uint8 one_bits[(bits & 0x1) * 1 * block_size / 8]; // highest 1 bit for 3, 5, 7 bits quantization
uint8 two_bits[(bits & 0x2) * 2 * block_size / 8]; // high 2 bits for 2, 6, 7 bits quantization
uint8 four_bits[(bits & 0x4) * 4 * block_size / 8]; // low 4 bits for 4, 5, 6 bits quantization
}

Input scales is stored in same type as original type of B(float32, float16) with shape like: [N * n_blocks_per_col]
Input zero_points is stored as uint8_t. If bits <= 4, two zero points are stored as one unit8_t. If bits > 4, one zero point is stored with one unit8_t. Thus, its shape is:
- [(N * n_blocks_per_col + 1) / 2] if bits <=4
- [N * n_blocks_per_col] if bits > 4

Input zero_points is stored as uint8_t or same as type(A). It has the same packing method as input B.
- [CeilDiv((N * n_blocks_per_col + 1) *bits, 8)]
If zero_points has same type as A, it's not packed and has the same shape as Scales.

#### Version

Expand All @@ -2844,17 +2845,19 @@ This version of the operator has been available since version 1 of the 'com.micr
<dd>number of groupsize used for weight quantization,(default 128). It needs to be a power of 2 and not smaller than 16.</dd>
</dl>

#### Inputs (3 - 4)
#### Inputs (3 - 5)

<dl>
<dt><tt>A</tt> : T1</dt>
<dd>The input tensor, not quantized</dd>
<dt><tt>B</tt> : T2</dt>
<dd>1-dimensional data blob</dd>
<dd>1 or 2 dimensional data blob</dd>
<dt><tt>scales</tt> : T1</dt>
<dd>quantization scale</dd>
<dt><tt>zero_points</tt> (optional) : T2</dt>
<dt><tt>zero_points</tt> (optional) : T3</dt>
<dd>quantization zero points</dd>
<dt><tt>g_idx</tt> (optional) : T4</dt>
<dd>group_idx</dd>
</dl>

#### Outputs
Expand All @@ -2869,8 +2872,12 @@ This version of the operator has been available since version 1 of the 'com.micr
<dl>
<dt><tt>T1</tt> : tensor(float), tensor(float16)</dt>
<dd>Constrain input and output types to float/half_float tensors.</dd>
<dt><tt>T2</tt> : tensor(uint8)</dt>
<dd>Constrain quantized weight types to uint8.</dd>
<dt><tt>T2</tt> : tensor(uint8), tensor(int32)</dt>
<dd>Constrain quantized weight types to uint8/int32.</dd>
<dt><tt>T3</tt> : tensor(uint8), tensor(int32), tensor(float16), tensor(float)</dt>
<dd>Constrain quantized zero point types to uint8/int32/float16/float.</dd>
<dt><tt>T4</tt> : tensor(int32)</dt>
<dd>the index tensor.</dd>
</dl>


Expand Down
4 changes: 2 additions & 2 deletions docs/OperatorKernels.md
Original file line number Diff line number Diff line change
Expand Up @@ -470,7 +470,7 @@ Do not modify directly.*
|MatMulFpQ4|*in* A:**T1**<br> *in* B:**T2**<br> *in* B_shape:**T3**<br> *out* Y:**T1**|1+|**T1** = tensor(float)<br/> **T2** = tensor(uint8)<br/> **T3** = tensor(int64)|
|MatMulInteger16|*in* A:**T1**<br> *in* B:**T2**<br> *out* Y:**T3**|1+|**T1** = tensor(int16)<br/> **T2** = tensor(int16)<br/> **T3** = tensor(int32)|
|MatMulIntegerToFloat|*in* A:**T1**<br> *in* B:**T2**<br> *in* a_scale:**T3**<br> *in* b_scale:**T3**<br> *in* a_zero_point:**T1**<br> *in* b_zero_point:**T2**<br> *in* bias:**T3**<br> *out* Y:**T3**|1+|**T1** = tensor(int8), tensor(uint8)<br/> **T2** = tensor(int8), tensor(uint8)<br/> **T3** = tensor(float)|
|MatMulNBits|*in* A:**T1**<br> *in* B:**T2**<br> *in* scales:**T1**<br> *in* zero_points:**T2**<br> *out* Y:**T1**|1+|**T1** = tensor(float)<br/> **T2** = tensor(uint8)|
|MatMulNBits|*in* A:**T1**<br> *in* B:**T2**<br> *in* scales:**T1**<br> *in* zero_points:**T3**<br> *in* g_idx:**T4**<br> *out* Y:**T1**|1+|**T1** = tensor(float)<br/> **T2** = tensor(uint8)<br/> **T3** = tensor(float), tensor(uint8)<br/> **T4** = tensor(int32)|
|MaxpoolWithMask|*in* X:**T**<br> *in* M:**tensor(int32)**<br> *out* Y:**T**|1+|**T** = tensor(float)|
|MultiHeadAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* bias:**T**<br> *in* key_padding_mask:**M**<br> *in* relative_position_bias:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**T** = tensor(float)|
|MurmurHash3|*in* X:**T1**<br> *out* Y:**T2**|1+|**T1** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(string), tensor(uint32), tensor(uint64)<br/> **T2** = tensor(int32), tensor(uint32)|
Expand Down Expand Up @@ -855,7 +855,7 @@ Do not modify directly.*
|Irfft|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|LongformerAttention|*in* input:**T**<br> *in* weight:**T**<br> *in* bias:**T**<br> *in* mask:**T**<br> *in* global_weight:**T**<br> *in* global_bias:**T**<br> *in* global:**G**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
|MatMulBnb4|*in* A:**T1**<br> *in* B:**T2**<br> *in* absmax:**T1**<br> *out* Y:**T1**|1+|**T1** = tensor(bfloat16), tensor(float), tensor(float16)<br/> **T2** = tensor(uint8)|
|MatMulNBits|*in* A:**T1**<br> *in* B:**T2**<br> *in* scales:**T1**<br> *in* zero_points:**T2**<br> *out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)<br/> **T2** = tensor(uint8)|
|MatMulNBits|*in* A:**T1**<br> *in* B:**T2**<br> *in* scales:**T1**<br> *in* zero_points:**T3**<br> *in* g_idx:**T4**<br> *out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)<br/> **T2** = tensor(uint8)|
|MoE|*in* input:**T**<br> *in* router_probs:**T**<br> *in* fc1_experts_weights:**T**<br> *in* fc2_experts_weights:**T**<br> *in* fc1_experts_bias:**T**<br> *in* fc2_experts_bias:**T**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
|MultiHeadAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* bias:**T**<br> *in* key_padding_mask:**M**<br> *in* relative_position_bias:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**T** = tensor(float), tensor(float16)|
|NGramRepeatBlock|*in* input_ids:**Tid**<br> *in* scores:**T**<br> *out* scores_out:**T**|1+|**T** = tensor(float)<br/> **Tid** = tensor(int64)|
Expand Down
105 changes: 85 additions & 20 deletions onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "contrib_ops/cpu/quantization/matmul_nbits_impl.h"

#include <cstdint>
#include <type_traits>

#include "core/common/common.h"
#include "core/common/narrow.h"
#include "core/common/safeint.h"
#include "core/framework/op_kernel.h"
Expand Down Expand Up @@ -50,6 +56,17 @@ int64_t GetAccuracyLevel(size_t nbits, size_t block_size, int64_t accuracy_level
}
} // namespace

bool GetType(const NodeArg& node_arg, int32_t& type) {
type = ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED;
const auto* type_proto = node_arg.TypeAsProto();
if (!type_proto || !type_proto->has_tensor_type() || !type_proto->tensor_type().has_elem_type()) {
return false;
}

type = type_proto->tensor_type().elem_type();
return true;
}

class MatMulNBits final : public OpKernel {
public:
MatMulNBits(const OpKernelInfo& info)
Expand All @@ -59,6 +76,17 @@ class MatMulNBits final : public OpKernel {
block_size_{narrow<size_t>(info.GetAttr<int64_t>("block_size"))},
nbits_{narrow<size_t>(info.GetAttr<int64_t>("bits"))},
accuracy_level_{GetAccuracyLevel(nbits_, block_size_, info.GetAttr<int64_t>("accuracy_level"))} {
const auto& node = info.node();
auto input_defs = node.InputDefs();
// g_idx
if (input_defs.size() > 4) {
act_order_ = true;
}
int32_t type;
if (input_defs.size() > 3 && GetType(*input_defs[3], type)) {
zero_point_is_not_quant_ = type != ONNX_NAMESPACE::TensorProto_DataType_UINT8;
}

ORT_ENFORCE(nbits_ == 4,
"Only 4b quantization is supported for MatMulNBits op, additional bits support is planned.");
#ifdef ORT_NEURAL_SPEED
Expand Down Expand Up @@ -88,6 +116,8 @@ class MatMulNBits final : public OpKernel {
const size_t N_;
const size_t block_size_;
const size_t nbits_;
bool act_order_{false};
bool zero_point_is_not_quant_{false};
const int64_t accuracy_level_;
const bool column_wise_quant_{true};
IAllocatorUniquePtr<void> packed_b_;
Expand All @@ -105,7 +135,9 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat
/*out*/ bool& is_packed,
/*out*/ PrePackedWeights* prepacked_weights) {
is_packed = false;

if (act_order_ || zero_point_is_not_quant_) {
return Status::OK();
}
#if defined(ORT_NEURAL_SPEED)

if (!all_constant_) {
Expand Down Expand Up @@ -212,7 +244,6 @@ Status MatMulNBits::UseSharedPrePackedBuffers(std::vector<BufferUniquePtr>& prep

Status MatMulNBits::Compute(OpKernelContext* ctx) const {
concurrency::ThreadPool* thread_pool = ctx->GetOperatorThreadPool();

const Tensor* a = ctx->Input<Tensor>(0);
const auto* a_data = a->Data<float>();

Expand Down Expand Up @@ -257,11 +288,14 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const {
#endif // defined(ORT_NEURAL_SPEED)

const Tensor* scales = ctx->Input<Tensor>(2);
const Tensor* zero_points = ctx->Input<Tensor>(3);
const Tensor* zero_points = ctx->InputCount() > 3 ? ctx->Input<Tensor>(3) : nullptr;
const Tensor* reorder_idx = ctx->InputCount() > 4 ? ctx->Input<Tensor>(4) : nullptr;

const auto* scales_data = scales->Data<float>();
const auto* zero_points_data = zero_points == nullptr ? nullptr : zero_points->Data<uint8_t>();
const auto* zero_points_data = zero_points == nullptr ? nullptr : zero_points->DataRaw();

TensorShape b_shape({static_cast<int64_t>(N_), static_cast<int64_t>(K_)});
const auto* reorder_idx_data = reorder_idx == nullptr ? nullptr : reorder_idx->Data<int32_t>();

MatMulComputeHelper helper;
ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b_shape, false, true));
Expand All @@ -281,8 +315,9 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const {
const size_t K = static_cast<size_t>(helper.K());
const size_t lda = helper.Lda(false);

const bool has_single_b_matrix = std::all_of(helper.RightOffsets().begin(), helper.RightOffsets().end(),
[](size_t offset) { return offset == 0; });
const bool has_single_b_matrix =
(!act_order_) && (!zero_point_is_not_quant_) &&
std::all_of(helper.RightOffsets().begin(), helper.RightOffsets().end(), [](size_t offset) { return offset == 0; });

if (has_single_b_matrix) {
const auto compute_type = static_cast<MLAS_SQNBIT_GEMM_COMPUTE_TYPE>(accuracy_level_);
Expand Down Expand Up @@ -328,22 +363,50 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const {
const uint8_t* b_data = b->Data<uint8_t>();

const size_t ldb = helper.Ldb(true);

AllocatorPtr allocator;
ORT_RETURN_IF_ERROR(ctx->GetTempSpaceAllocator(&allocator));
auto tmp_b_data_ptr = IAllocator::MakeUniquePtr<float>(allocator, SafeInt<size_t>(K_) * N_);
// dequantize b, only 4b quantization is supported for now
MlasDequantizeBlockwise<float, 4>(
tmp_b_data_ptr.get(), // dequantized output
b_data, // quantized input
scales_data, // quantization scales
zero_points_data, // quantization zero points
static_cast<int32_t>(block_size_), // quantization block size
column_wise_quant_, // columnwise quantization or row-wise
static_cast<int32_t>(K_), // number of rows in quantized input
static_cast<int32_t>(N_), // number of columns in quantized input
thread_pool);

if ((reorder_idx_data == nullptr) && (!zero_points || !zero_points->IsDataType<float>())) {
// dequantize b, only 4b quantization is supported for now
MlasDequantizeBlockwise<float, 4>(
tmp_b_data_ptr.get(), // dequantized output
b_data, // quantized input
scales_data, // quantization scales
static_cast<const uint8_t*>(zero_points_data), // quantization zero points
static_cast<int32_t>(block_size_), // quantization block size
column_wise_quant_, // columnwise quantization or row-wise
static_cast<int32_t>(K_), // number of rows in quantized input
static_cast<int32_t>(N_), // number of columns in quantized input
thread_pool);
} else {
ORT_ENFORCE(column_wise_quant_, "Row-wise quantization is not supported for now");
// !!!!!!!!!!!!!! naive implementation, need to be optimized !!!!!!!!!!!!!!
if ((zero_points && zero_points->IsDataType<float>())) {
DequantizeBlockwise<float, float>(
tmp_b_data_ptr.get(), // dequantized output
b_data, // quantized input
scales_data, // quantization scales
static_cast<const float*>(zero_points_data), // quantization zero points
reorder_idx_data,
static_cast<int32_t>(block_size_), // quantization block size
column_wise_quant_, // columnwise quantization or row-wise
static_cast<int32_t>(K_), // number of rows in quantized input
static_cast<int32_t>(N_), // number of columns in quantized input
thread_pool);
} else {
DequantizeBlockwise<float, uint8_t>(
tmp_b_data_ptr.get(), // dequantized output
b_data, // quantized input
scales_data, // quantization scales
static_cast<const uint8_t*>(zero_points_data), // quantization zero points
reorder_idx_data,
static_cast<int32_t>(block_size_), // quantization block size
column_wise_quant_, // columnwise quantization or row-wise
static_cast<int32_t>(K_), // number of rows in quantized input
static_cast<int32_t>(N_), // number of columns in quantized input
thread_pool);
}
}
#if 0 // for debug
auto tm_b_data_ptr_trans = IAllocator::MakeUniquePtr<float>(allocator, SafeInt<size_t>(K_) * N_);
MlasTranspose(tmp_b_data_ptr.get(), tm_b_data_ptr_trans.get(), N_, K_);
Expand Down Expand Up @@ -374,7 +437,9 @@ ONNX_OPERATOR_KERNEL_EX(
kCpuExecutionProvider,
KernelDefBuilder()
.TypeConstraint("T1", DataTypeImpl::GetTensorType<float>())
.TypeConstraint("T2", DataTypeImpl::GetTensorType<uint8_t>()),
.TypeConstraint("T2", DataTypeImpl::GetTensorType<uint8_t>())
.TypeConstraint("T3", {DataTypeImpl::GetTensorType<uint8_t>(), DataTypeImpl::GetTensorType<float>()})
.TypeConstraint("T4", DataTypeImpl::GetTensorType<int32_t>()),
MatMulNBits);

} // namespace contrib
Expand Down
Loading

0 comments on commit 7e613ee

Please sign in to comment.