[quant] Support act_order inputs in MatMulNBits and new quantization algorithm "hqq" #19106

Merged · 6 commits · Mar 5, 2024
docs/ContribOperators.md (25 additions, 18 deletions)

@@ -2808,22 +2808,23 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
block_size is not an arbitrary number: it must be a power of 2 and not smaller than 16, e.g. 16, 32, 64, 128, ...
3. Input B's scale and zero point are specified by input scales and zero_points.

Input B is stored as uint8_t with shape: [N][n_blocks_per_col][blob_size] in which:
- n_blocks_per_col = (K + block_size - 1) / block_size
- blob_size = block_size / 8 * bits
Input is stored as uint8_t with shape: [N][n_blocks_per_col][blob_size] in which:
- n_blocks_per_col = (K + block_size - 1) / block_size
- blob_size = CeilDiv(block_size * bits, 8), i.e. the number of bytes needed to hold block_size elements of `bits` bits each
For all bit widths from 2 to 8, a row of data is packed tightly into uint8_t values.
- for 2, 4, 8 bits, 4x2bit, 2x4bit, 1x8bit are stored in one uint8_t.
4-bit example:
|.|.|.|.| |.|.|.|.| = uint8_t (2x4bit)
- for 3, 5, 6, 7 bits, 32x3bit, 32x5bit, 16x6bit, 32x7bit are stored in 12, 20, 12, 28 uint8_t respectively; no bits are wasted.
3-bit example:
|.|.|. |.|.|. |.|.|. = 9 bits, which span 2 uint8_t; only the highest bit of the second uint8_t is used.
The last uint8_t of a row may have some bits unused.

Each block blob is stored in the following format:
struct Blob {
uint8 one_bits[(bits & 0x1) * 1 * block_size / 8]; // highest 1 bit for 3, 5, 7 bits quantization
uint8 two_bits[(bits & 0x2) * 2 * block_size / 8]; // high 2 bits for 2, 6, 7 bits quantization
uint8 four_bits[(bits & 0x4) * 4 * block_size / 8]; // low 4 bits for 4, 5, 6 bits quantization
}
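
As a worked illustration of the sizes above, here is a minimal, self-contained C++ sketch. It is not part of the operator spec: `PackedSizes` and `ComputePackedSizes` are names invented for this example, and the `(bits & mask)` factors in the struct are read as 0/1 presence flags, which reproduces the byte counts listed earlier (e.g. 32 3-bit values in 12 bytes).

```cpp
// Illustrative sketch only: computes the packed sizes described above for 2 <= bits <= 8
// and a power-of-two block_size >= 16. PackedSizes/ComputePackedSizes are invented names.
#include <cstddef>
#include <cstdio>

constexpr size_t CeilDiv(size_t a, size_t b) { return (a + b - 1) / b; }

struct PackedSizes {
  size_t n_blocks_per_col;  // number of blocks along K for one column of B
  size_t blob_size;         // bytes per block blob
  size_t one_bits_bytes;    // 1-bit plane (present when bit 0 of `bits` is set)
  size_t two_bits_bytes;    // 2-bit plane (present when bit 1 of `bits` is set)
  size_t four_bits_bytes;   // 4-bit plane (present when bit 2 of `bits` is set)
};

PackedSizes ComputePackedSizes(size_t K, size_t block_size, size_t bits) {
  PackedSizes s{};
  s.n_blocks_per_col = CeilDiv(K, block_size);
  s.blob_size = CeilDiv(block_size * bits, 8);  // 8 bits per uint8_t
  // For 2-7 bits the three plane sizes sum to blob_size; 8-bit values simply occupy whole bytes.
  s.one_bits_bytes = (bits & 0x1) ? 1 * block_size / 8 : 0;
  s.two_bits_bytes = (bits & 0x2) ? 2 * block_size / 8 : 0;
  s.four_bits_bytes = (bits & 0x4) ? 4 * block_size / 8 : 0;
  return s;
}

int main() {
  // Example: K = 4096, block_size = 32, bits = 3 -> 128 blocks, 12-byte blobs (4 + 8 + 0).
  const PackedSizes s = ComputePackedSizes(4096, 32, 3);
  std::printf("blocks=%zu blob=%zu planes=%zu+%zu+%zu bytes\n",
              s.n_blocks_per_col, s.blob_size,
              s.one_bits_bytes, s.two_bits_bytes, s.four_bits_bytes);
  return 0;
}
```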

Input scales is stored in the same type as the original (unquantized) type of B (float32 or float16), with shape [N * n_blocks_per_col].
Input zero_points is stored as uint8_t. If bits <= 4, two zero points are packed into one uint8_t. If bits > 4, one zero point is stored per uint8_t. Thus, its shape is:
- [(N * n_blocks_per_col + 1) / 2] if bits <= 4
- [N * n_blocks_per_col] if bits > 4

Input zero_points is stored either as uint8_t or in the same type as input A. When stored as uint8_t, it uses the same bit-packing method as input B, with shape:
- [CeilDiv((N * n_blocks_per_col + 1) * bits, 8)]
If zero_points has the same type as A, it is not packed and has the same shape as scales.
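
To make the two layouts concrete, here is a short sketch (illustrative only; the helper names are invented) computing the zero_points storage size in each case, using the formulas above.

```cpp
// Illustrative only: zero_points storage size under the two layouts described above.
// PackedZeroPointsBytes / UnpackedZeroPointsElements are invented names for this sketch.
#include <cstddef>

constexpr size_t CeilDiv(size_t a, size_t b) { return (a + b - 1) / b; }

// zero_points given as uint8_t: bit-packed like input B.
constexpr size_t PackedZeroPointsBytes(size_t N, size_t n_blocks_per_col, size_t bits) {
  return CeilDiv((N * n_blocks_per_col + 1) * bits, 8);
}

// zero_points given in the same type as A: one value per block, same shape as scales.
constexpr size_t UnpackedZeroPointsElements(size_t N, size_t n_blocks_per_col) {
  return N * n_blocks_per_col;
}

// Example: N = 4096, n_blocks_per_col = 128, bits = 4 -> 262145 packed bytes vs 524288 unpacked values.
static_assert(PackedZeroPointsBytes(4096, 128, 4) == 262145, "two 4-bit zero points per byte, rounded up");
static_assert(UnpackedZeroPointsElements(4096, 128) == 524288, "one zero point per block");
```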

#### Version

@@ -2844,17 +2845,19 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
<dd>size of the quantization group used for weight quantization (default 128). It must be a power of 2 and not smaller than 16.</dd>
</dl>

#### Inputs (3 - 4)
#### Inputs (3 - 5)

<dl>
<dt><tt>A</tt> : T1</dt>
<dd>The input tensor, not quantized</dd>
<dt><tt>B</tt> : T2</dt>
<dd>1-dimensional data blob</dd>
<dd>1- or 2-dimensional data blob</dd>
<dt><tt>scales</tt> : T1</dt>
<dd>quantization scale</dd>
<dt><tt>zero_points</tt> (optional) : T2</dt>
<dt><tt>zero_points</tt> (optional) : T3</dt>
<dd>quantization zero points</dd>
<dt><tt>g_idx</tt> (optional) : T4</dt>
<dd>group_idx (group index used for act_order quantization; see the sketch after this list)</dd>
</dl>
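
The new g_idx input is what carries act_order information. Below is a minimal sketch of the assumed semantics (`GroupOfRow` is an invented helper; the spec text above only calls the input group_idx): each entry maps a row of the K dimension to its quantization group, replacing the default k / block_size mapping used to look up scales and zero_points.

```cpp
// Assumed semantics of g_idx (act_order), illustrative only.
#include <cstddef>
#include <cstdint>

// Without g_idx, row k of the K dimension belongs to group k / block_size;
// with g_idx, it belongs to group g_idx[k], so scales and zero_points are
// looked up through that index instead.
inline size_t GroupOfRow(size_t k, size_t block_size, const int32_t* g_idx /* may be nullptr */) {
  return (g_idx != nullptr) ? static_cast<size_t>(g_idx[k]) : k / block_size;
}
```

With block_size = 32 and no g_idx, rows 0-31 of a column share one scale; with act_order the rows belonging to one group need not be contiguous.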

#### Outputs
@@ -2869,8 +2872,12 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
<dl>
<dt><tt>T1</tt> : tensor(float), tensor(float16)</dt>
<dd>Constrain input and output types to float/half_float tensors.</dd>
<dt><tt>T2</tt> : tensor(uint8)</dt>
<dd>Constrain quantized weight types to uint8.</dd>
<dt><tt>T2</tt> : tensor(uint8), tensor(int32)</dt>
<dd>Constrain quantized weight types to uint8/int32.</dd>
<dt><tt>T3</tt> : tensor(uint8), tensor(int32), tensor(float16), tensor(float)</dt>
<dd>Constrain quantized zero point types to uint8/int32/float16/float.</dd>
<dt><tt>T4</tt> : tensor(int32)</dt>
<dd>The group index tensor.</dd>
</dl>


docs/OperatorKernels.md (2 additions, 2 deletions)

@@ -470,7 +470,7 @@ Do not modify directly.*
|MatMulFpQ4|*in* A:**T1**<br> *in* B:**T2**<br> *in* B_shape:**T3**<br> *out* Y:**T1**|1+|**T1** = tensor(float)<br/> **T2** = tensor(uint8)<br/> **T3** = tensor(int64)|
|MatMulInteger16|*in* A:**T1**<br> *in* B:**T2**<br> *out* Y:**T3**|1+|**T1** = tensor(int16)<br/> **T2** = tensor(int16)<br/> **T3** = tensor(int32)|
|MatMulIntegerToFloat|*in* A:**T1**<br> *in* B:**T2**<br> *in* a_scale:**T3**<br> *in* b_scale:**T3**<br> *in* a_zero_point:**T1**<br> *in* b_zero_point:**T2**<br> *in* bias:**T3**<br> *out* Y:**T3**|1+|**T1** = tensor(int8), tensor(uint8)<br/> **T2** = tensor(int8), tensor(uint8)<br/> **T3** = tensor(float)|
|MatMulNBits|*in* A:**T1**<br> *in* B:**T2**<br> *in* scales:**T1**<br> *in* zero_points:**T2**<br> *out* Y:**T1**|1+|**T1** = tensor(float)<br/> **T2** = tensor(uint8)|
|MatMulNBits|*in* A:**T1**<br> *in* B:**T2**<br> *in* scales:**T1**<br> *in* zero_points:**T3**<br> *in* g_idx:**T4**<br> *out* Y:**T1**|1+|**T1** = tensor(float)<br/> **T2** = tensor(uint8)<br/> **T3** = tensor(float), tensor(uint8)<br/> **T4** = tensor(int32)|
|MaxpoolWithMask|*in* X:**T**<br> *in* M:**tensor(int32)**<br> *out* Y:**T**|1+|**T** = tensor(float)|
|MultiHeadAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* bias:**T**<br> *in* key_padding_mask:**M**<br> *in* relative_position_bias:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**T** = tensor(float)|
|MurmurHash3|*in* X:**T1**<br> *out* Y:**T2**|1+|**T1** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(string), tensor(uint32), tensor(uint64)<br/> **T2** = tensor(int32), tensor(uint32)|
@@ -854,7 +854,7 @@ Do not modify directly.*
|Irfft|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|LongformerAttention|*in* input:**T**<br> *in* weight:**T**<br> *in* bias:**T**<br> *in* mask:**T**<br> *in* global_weight:**T**<br> *in* global_bias:**T**<br> *in* global:**G**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
|MatMulBnb4|*in* A:**T1**<br> *in* B:**T2**<br> *in* absmax:**T1**<br> *out* Y:**T1**|1+|**T1** = tensor(bfloat16), tensor(float), tensor(float16)<br/> **T2** = tensor(uint8)|
|MatMulNBits|*in* A:**T1**<br> *in* B:**T2**<br> *in* scales:**T1**<br> *in* zero_points:**T2**<br> *out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)<br/> **T2** = tensor(uint8)|
|MatMulNBits|*in* A:**T1**<br> *in* B:**T2**<br> *in* scales:**T1**<br> *in* zero_points:**T3**<br> *in* g_idx:**T4**<br> *out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)<br/> **T2** = tensor(uint8)|
|MoE|*in* input:**T**<br> *in* router_probs:**T**<br> *in* fc1_experts_weights:**T**<br> *in* fc2_experts_weights:**T**<br> *in* fc1_experts_bias:**T**<br> *in* fc2_experts_bias:**T**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
|MultiHeadAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* bias:**T**<br> *in* key_padding_mask:**M**<br> *in* relative_position_bias:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**T** = tensor(float), tensor(float16)|
|NGramRepeatBlock|*in* input_ids:**Tid**<br> *in* scores:**T**<br> *out* scores_out:**T**|1+|**T** = tensor(float)<br/> **Tid** = tensor(int64)|
onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc (85 additions, 20 deletions)

@@ -1,6 +1,12 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "contrib_ops/cpu/quantization/matmul_nbits_impl.h"

#include <cstdint>
#include <type_traits>

#include "core/common/common.h"
#include "core/common/narrow.h"
#include "core/common/safeint.h"
#include "core/framework/op_kernel.h"
@@ -50,6 +56,17 @@
}
} // namespace

bool GetType(const NodeArg& node_arg, int32_t& type) {
type = ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED;
const auto* type_proto = node_arg.TypeAsProto();
if (!type_proto || !type_proto->has_tensor_type() || !type_proto->tensor_type().has_elem_type()) {
return false;
}

type = type_proto->tensor_type().elem_type();
return true;
}

class MatMulNBits final : public OpKernel {
public:
MatMulNBits(const OpKernelInfo& info)
@@ -59,6 +76,17 @@
block_size_{narrow<size_t>(info.GetAttr<int64_t>("block_size"))},
nbits_{narrow<size_t>(info.GetAttr<int64_t>("bits"))},
accuracy_level_{GetAccuracyLevel(nbits_, block_size_, info.GetAttr<int64_t>("accuracy_level"))} {
const auto& node = info.node();
auto input_defs = node.InputDefs();
// g_idx
if (input_defs.size() > 4) {
act_order_ = true;
}
int32_t type;
if (input_defs.size() > 3 && GetType(*input_defs[3], type)) {
zero_point_is_not_quant_ = type != ONNX_NAMESPACE::TensorProto_DataType_UINT8;
}

ORT_ENFORCE(nbits_ == 4,
"Only 4b quantization is supported for MatMulNBits op, additional bits support is planned.");
#ifdef ORT_NEURAL_SPEED
@@ -88,6 +116,8 @@
const size_t N_;
const size_t block_size_;
const size_t nbits_;
bool act_order_{false};
bool zero_point_is_not_quant_{false};
const int64_t accuracy_level_;
const bool column_wise_quant_{true};
IAllocatorUniquePtr<void> packed_b_;
@@ -105,7 +135,9 @@
/*out*/ bool& is_packed,
/*out*/ PrePackedWeights* prepacked_weights) {
is_packed = false;

if (act_order_ || zero_point_is_not_quant_) {
return Status::OK();
}
#if defined(ORT_NEURAL_SPEED)

if (!all_constant_) {
@@ -212,7 +244,6 @@

Status MatMulNBits::Compute(OpKernelContext* ctx) const {
concurrency::ThreadPool* thread_pool = ctx->GetOperatorThreadPool();

const Tensor* a = ctx->Input<Tensor>(0);
const auto* a_data = a->Data<float>();

@@ -257,11 +288,14 @@
#endif // defined(ORT_NEURAL_SPEED)

const Tensor* scales = ctx->Input<Tensor>(2);
const Tensor* zero_points = ctx->Input<Tensor>(3);
const Tensor* zero_points = ctx->InputCount() > 3 ? ctx->Input<Tensor>(3) : nullptr;
const Tensor* reorder_idx = ctx->InputCount() > 4 ? ctx->Input<Tensor>(4) : nullptr;

const auto* scales_data = scales->Data<float>();
const auto* zero_points_data = zero_points == nullptr ? nullptr : zero_points->Data<uint8_t>();
const auto* zero_points_data = zero_points == nullptr ? nullptr : zero_points->DataRaw();

TensorShape b_shape({static_cast<int64_t>(N_), static_cast<int64_t>(K_)});
const auto* reorder_idx_data = reorder_idx == nullptr ? nullptr : reorder_idx->Data<int32_t>();

MatMulComputeHelper helper;
ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b_shape, false, true));
@@ -281,8 +315,9 @@
const size_t K = static_cast<size_t>(helper.K());
const size_t lda = helper.Lda(false);

const bool has_single_b_matrix = std::all_of(helper.RightOffsets().begin(), helper.RightOffsets().end(),
[](size_t offset) { return offset == 0; });
const bool has_single_b_matrix =
(!act_order_) && (!zero_point_is_not_quant_) &&
std::all_of(helper.RightOffsets().begin(), helper.RightOffsets().end(), [](size_t offset) { return offset == 0; });

if (has_single_b_matrix) {
const auto compute_type = static_cast<MLAS_SQNBIT_GEMM_COMPUTE_TYPE>(accuracy_level_);
Expand Down Expand Up @@ -328,22 +363,50 @@
const uint8_t* b_data = b->Data<uint8_t>();

const size_t ldb = helper.Ldb(true);

AllocatorPtr allocator;
ORT_RETURN_IF_ERROR(ctx->GetTempSpaceAllocator(&allocator));
auto tmp_b_data_ptr = IAllocator::MakeUniquePtr<float>(allocator, SafeInt<size_t>(K_) * N_);
// dequantize b, only 4b quantization is supported for now
MlasDequantizeBlockwise<float, 4>(
tmp_b_data_ptr.get(), // dequantized output
b_data, // quantized input
scales_data, // quantization scales
zero_points_data, // quantization zero points
static_cast<int32_t>(block_size_), // quantization block size
column_wise_quant_, // columnwise quantization or row-wise
static_cast<int32_t>(K_), // number of rows in quantized input
static_cast<int32_t>(N_), // number of columns in quantized input
thread_pool);

if ((reorder_idx_data == nullptr) && (!zero_points || !zero_points->IsDataType<float>())) {
// dequantize b, only 4b quantization is supported for now
MlasDequantizeBlockwise<float, 4>(
tmp_b_data_ptr.get(), // dequantized output
b_data, // quantized input
scales_data, // quantization scales
static_cast<const uint8_t*>(zero_points_data), // quantization zero points
static_cast<int32_t>(block_size_), // quantization block size
column_wise_quant_, // columnwise quantization or row-wise
static_cast<int32_t>(K_), // number of rows in quantized input
static_cast<int32_t>(N_), // number of columns in quantized input
thread_pool);
} else {
ORT_ENFORCE(column_wise_quant_, "Row-wise quantization is not supported for now");
// !!!!!!!!!!!!!! naive implementation, need to be optimized !!!!!!!!!!!!!!
if ((zero_points && zero_points->IsDataType<float>())) {
DequantizeBlockwise<float, float>(
tmp_b_data_ptr.get(), // dequantized output
b_data, // quantized input
scales_data, // quantization scales
static_cast<const float*>(zero_points_data), // quantization zero points
reorder_idx_data,
static_cast<int32_t>(block_size_), // quantization block size
column_wise_quant_, // columnwise quantization or row-wise
static_cast<int32_t>(K_), // number of rows in quantized input
static_cast<int32_t>(N_), // number of columns in quantized input
thread_pool);
} else {
DequantizeBlockwise<float, uint8_t>(
tmp_b_data_ptr.get(), // dequantized output
b_data, // quantized input
scales_data, // quantization scales
static_cast<const uint8_t*>(zero_points_data), // quantization zero points
reorder_idx_data,
static_cast<int32_t>(block_size_), // quantization block size
column_wise_quant_, // columnwise quantization or row-wise
static_cast<int32_t>(K_), // number of rows in quantized input
static_cast<int32_t>(N_), // number of columns in quantized input
thread_pool);
}
}
#if 0 // for debug
auto tm_b_data_ptr_trans = IAllocator::MakeUniquePtr<float>(allocator, SafeInt<size_t>(K_) * N_);
MlasTranspose(tmp_b_data_ptr.get(), tm_b_data_ptr_trans.get(), N_, K_);
@@ -374,7 +437,9 @@
kCpuExecutionProvider,
KernelDefBuilder()
.TypeConstraint("T1", DataTypeImpl::GetTensorType<float>())
.TypeConstraint("T2", DataTypeImpl::GetTensorType<uint8_t>()),
.TypeConstraint("T2", DataTypeImpl::GetTensorType<uint8_t>())
.TypeConstraint("T3", {DataTypeImpl::GetTensorType<uint8_t>(), DataTypeImpl::GetTensorType<float>()})
.TypeConstraint("T4", DataTypeImpl::GetTensorType<int32_t>()),
MatMulNBits);

} // namespace contrib
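
The dequantization entry points used above (`DequantizeBlockwise<float, float>` and `DequantizeBlockwise<float, uint8_t>`) are declared in matmul_nbits_impl.h, which is not shown in this diff. The following is a naive reference sketch, not the ONNX Runtime implementation, written only to make those call sites readable. Everything marked as an assumption is mine, not taken from the diff: the dst[n * K + k] output layout, low-nibble-first 4-bit packing, a default zero point of 2^(bits-1) = 8, and weights kept in original row order with g_idx only redirecting the scale/zero-point lookup.

```cpp
// Naive reference sketch of the DequantizeBlockwise<T1, T2> calls above; illustrative only,
// not the ONNX Runtime implementation (which lives in matmul_nbits_impl.cc, not in this diff).
#include <cstddef>
#include <cstdint>
#include <type_traits>

template <typename ZeroT>  // uint8_t (packed) or float (same type as A), matching the two call sites
void DequantizeBlockwiseSketch(float* dst,                // assumed output layout: dst[n * K + k]
                               const uint8_t* b_data,     // [N][n_blocks_per_col][blob_size], 4-bit packed
                               const float* scales,       // [N * n_blocks_per_col]
                               const ZeroT* zero_points,  // may be nullptr
                               const int32_t* g_idx,      // optional group index of size K, may be nullptr
                               int32_t block_size, int32_t K, int32_t N) {
  const size_t n_blocks_per_col = (static_cast<size_t>(K) + block_size - 1) / block_size;
  const size_t blob_size = static_cast<size_t>(block_size) / 2;  // 4-bit only: two values per byte
  for (size_t n = 0; n < static_cast<size_t>(N); ++n) {
    for (size_t k = 0; k < static_cast<size_t>(K); ++k) {
      // Weights are assumed to stay in original row order; g_idx only redirects the group lookup.
      const size_t group = g_idx ? static_cast<size_t>(g_idx[k]) : k / block_size;
      const uint8_t byte = b_data[(n * n_blocks_per_col + k / block_size) * blob_size + (k % block_size) / 2];
      const int q = (k & 1) ? (byte >> 4) : (byte & 0x0F);  // low-nibble-first packing (assumption)
      float zp = 8.0f;                                      // assumed default zero point: 2^(bits - 1)
      if (zero_points != nullptr) {
        const size_t zp_idx = n * n_blocks_per_col + group;
        if constexpr (std::is_same_v<ZeroT, float>) {
          zp = zero_points[zp_idx];  // unpacked zero points, same type as A
        } else {
          const uint8_t packed = zero_points[zp_idx / 2];  // two 4-bit zero points per byte
          zp = static_cast<float>((zp_idx & 1) ? (packed >> 4) : (packed & 0x0F));
        }
      }
      dst[n * K + k] = (static_cast<float>(q) - zp) * scales[n * n_blocks_per_col + group];
    }
  }
}
```

The actual kernel also receives the thread pool and the column-wise flag seen in the call sites; this sketch ignores both for brevity.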