Add matmul int4 for CUDA #17526

Closed
28 commits
2b96e30
int4 support on GPU
yufenglee Sep 6, 2023
7be5564
change quant tool
yufenglee Sep 11, 2023
2d8c8f9
use fp32 as accumulator
yufenglee Sep 12, 2023
19ecb96
refine the matmul_int4 kernel
yufenglee Sep 12, 2023
0a01dce
refine the matmul_int4 kernel
yufenglee Sep 12, 2023
74c6b80
refine benchmark tool
yufenglee Sep 18, 2023
0c3f6c5
refine the dequant int4
yufenglee Sep 18, 2023
155d4b2
optimize dequant
yufenglee Sep 25, 2023
3ff52ed
refine quant tool
yufenglee Sep 26, 2023
453207f
add pybind for blockwise quant
yufenglee Sep 27, 2023
1c7f9d5
fix build breaks
yufenglee Sep 27, 2023
7eba97e
refine quant tool
yufenglee Sep 27, 2023
dda154f
fix fp16
yufenglee Sep 27, 2023
2050392
add unit test for QuantBlockwise pybind
yufenglee Sep 28, 2023
c015d4d
refine the quant tool
yufenglee Sep 28, 2023
2034ac9
add option to exclude logit layer
yufenglee Sep 28, 2023
8d972dd
handle subgraph properly
yufenglee Oct 2, 2023
78d0f02
fix build break
yufenglee Oct 6, 2023
068f4db
change matmul 4bits name
yufenglee Sep 29, 2023
fef4216
change zp to 4bits
yufenglee Oct 5, 2023
3375903
revert change in matmul_weight4_quantizer.py
yufenglee Oct 9, 2023
4d8dbc0
format code
yufenglee Oct 9, 2023
9585dca
fix build/test breaks
yufenglee Oct 9, 2023
68191b7
fix build break
yufenglee Oct 9, 2023
bccf3d7
fix test failures in training pipeline
yufenglee Oct 9, 2023
a3abfcd
fix build break in training CIs
yufenglee Oct 10, 2023
4e29283
fix training CIs
yufenglee Oct 10, 2023
d4e4145
fix training CIs
yufenglee Oct 11, 2023
5 changes: 5 additions & 0 deletions cmake/onnxruntime_rocm_hipify.cmake
@@ -61,6 +61,11 @@ set(contrib_ops_excluded_files
"quantization/attention_quantization.h"
"quantization/attention_quantization_impl.cu"
"quantization/attention_quantization_impl.cuh"
"quantization/dequantize_blockwise.cuh"
"quantization/dequantize_blockwise.cu"
"quantization/matmul_nbits.cc"
"quantization/matmul_nbits.cuh"
"quantization/matmul_nbits.cu"
"quantization/quantize_dequantize_linear.cc"
"quantization/qordered_ops/qordered_attention_impl.cu"
"quantization/qordered_ops/qordered_attention_impl.h"
73 changes: 73 additions & 0 deletions docs/ContribOperators.md
@@ -49,6 +49,7 @@ Do not modify directly.*
* <a href="#com.microsoft.MatMulFpQ4">com.microsoft.MatMulFpQ4</a>
* <a href="#com.microsoft.MatMulInteger16">com.microsoft.MatMulInteger16</a>
* <a href="#com.microsoft.MatMulIntegerToFloat">com.microsoft.MatMulIntegerToFloat</a>
* <a href="#com.microsoft.MatMulNBits">com.microsoft.MatMulNBits</a>
* <a href="#com.microsoft.MaxpoolWithMask">com.microsoft.MaxpoolWithMask</a>
* <a href="#com.microsoft.MulInteger">com.microsoft.MulInteger</a>
* <a href="#com.microsoft.MultiHeadAttention">com.microsoft.MultiHeadAttention</a>
@@ -2593,6 +2594,78 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
</dl>


### <a name="com.microsoft.MatMulNBits"></a><a name="com.microsoft.matmulnbits">**com.microsoft.MatMulNBits**</a>

MatMulNBits is a MatMul with its weight quantized to N bits (e.g., 2, 3, 4, 5, 6, 7). It performs matrix multiplication like MatMul (https://github.com/onnx/onnx/blob/main/docs/Operators.md#matmul), with the following differences:
1. Input B is a 2D constant matrix. Its input feature count and output feature count are specified by attributes 'K' and 'N'.
2. Input B is quantized with the number of bits given by attribute 'bits'. It is quantized block-wise along dimension 0 (i.e., the column dimension), with the block size given by attribute 'block_size'. block_size is not arbitrary: it must be a power of 2 and no smaller than 16 (e.g., 16, 32, 64, 128, ...).
3. Input B's scales and zero points are provided by inputs 'scales' and 'zero_points'.

Input B is stored as uint8_t with shape [N][n_blocks_per_col][blob_size], where:
- n_blocks_per_col = (K + block_size - 1) / block_size
- blob_size = block_size / 8 * bits

Each block blob is stored in the following format:

    struct Blob {
      uint8 one_bits[(bits & 0x1) * 1 * block_size / 8];   // highest 1 bit for 3-, 5-, 7-bit quantization
      uint8 two_bits[(bits & 0x2) * 2 * block_size / 8];   // high 2 bits for 2-, 6-, 7-bit quantization
      uint8 four_bits[(bits & 0x4) * 4 * block_size / 8];  // low 4 bits for 4-, 5-, 6-bit quantization
    }

Input scales is stored in the same type as the original type of B (float32 or float16), with shape [N * n_blocks_per_col].
Input zero_points is stored as uint8_t. If bits <= 4, two zero points are packed into one uint8_t; if bits > 4, each zero point occupies one uint8_t. Its shape is therefore:
- [(N * n_blocks_per_col + 1) / 2] if bits <= 4
- [N * n_blocks_per_col] if bits > 4

The sketch below works through the resulting buffer sizes for a concrete configuration.
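As a rough illustration (an editor's sketch, not part of the PR diff), the following standalone C++ snippet computes those sizes for an assumed configuration of N = 4096, K = 4096, block_size = 32, bits = 4; all names and values here are illustrative.

```cpp
// Sketch: buffer sizes implied by the MatMulNBits storage layout (assumed example values).
#include <cstdint>
#include <iostream>

int main() {
  constexpr int64_t N = 4096;         // output features
  constexpr int64_t K = 4096;         // input features
  constexpr int64_t block_size = 32;  // power of 2, >= 16
  constexpr int64_t bits = 4;

  const int64_t n_blocks_per_col = (K + block_size - 1) / block_size;  // 128 blocks per column
  const int64_t blob_size = block_size / 8 * bits;                     // 16 bytes per block

  const int64_t b_bytes = N * n_blocks_per_col * blob_size;  // packed B, stored as uint8_t
  const int64_t scales_elems = N * n_blocks_per_col;         // one scale per block
  const int64_t zero_point_bytes = bits <= 4
                                       ? (N * n_blocks_per_col + 1) / 2  // two 4-bit zero points per byte
                                       : N * n_blocks_per_col;           // one byte per zero point

  std::cout << "B: " << b_bytes << " bytes, scales: " << scales_elems
            << " elements, zero_points: " << zero_point_bytes << " bytes\n";
  return 0;
}
```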


#### Version

This version of the operator has been available since version 1 of the 'com.microsoft' operator set.

#### Attributes

<dl>
<dt><tt>K</tt> : int (required)</dt>
<dd>size of the input feature dimension (the inner dimension of the matrix multiplication)</dd>
<dt><tt>N</tt> : int (required)</dt>
<dd>size of the output feature dimension (the number of columns of the output)</dd>
<dt><tt>bits</tt> : int (required)</dt>
<dd>number of bits used for weight quantization (default 4)</dd>
<dt><tt>block_size</tt> : int (required)</dt>
<dd>block (group) size used for weight quantization (default 128). It must be a power of 2 and no smaller than 16.</dd>
</dl>

#### Inputs (3 - 4)

<dl>
<dt><tt>A</tt> : T1</dt>
<dd>The input tensor, not quantized</dd>
<dt><tt>B</tt> : T2</dt>
<dd>1-dimensional blob containing the packed quantized weight</dd>
<dt><tt>scales</tt> : T1</dt>
<dd>quantization scale</dd>
<dt><tt>zero_points</tt> (optional) : T2</dt>
<dd>quantization zero points</dd>
</dl>

#### Outputs

<dl>
<dt><tt>Y</tt> : T1</dt>
<dd>The output tensor, with the same rank as input A.</dd>
</dl>

#### Type Constraints

<dl>
<dt><tt>T1</tt> : tensor(float), tensor(float16)</dt>
<dd>Constrain input and output types to float/half_float tensors.</dd>
<dt><tt>T2</tt> : tensor(uint8)</dt>
<dd>Constrain quantized weight types to uint8.</dd>
</dl>


### <a name="com.microsoft.MaxpoolWithMask"></a><a name="com.microsoft.maxpoolwithmask">**com.microsoft.MaxpoolWithMask**</a>

For internal use.
2 changes: 2 additions & 0 deletions docs/OperatorKernels.md
@@ -457,6 +457,7 @@ Do not modify directly.*
|MatMulFpQ4|*in* A:**T1**<br> *in* B:**T2**<br> *in* B_shape:**T3**<br> *out* Y:**T1**|1+|**T1** = tensor(float)<br/> **T2** = tensor(uint8)<br/> **T3** = tensor(int64)|
|MatMulInteger16|*in* A:**T1**<br> *in* B:**T2**<br> *out* Y:**T3**|1+|**T1** = tensor(int16)<br/> **T2** = tensor(int16)<br/> **T3** = tensor(int32)|
|MatMulIntegerToFloat|*in* A:**T1**<br> *in* B:**T2**<br> *in* a_scale:**T3**<br> *in* b_scale:**T3**<br> *in* a_zero_point:**T1**<br> *in* b_zero_point:**T2**<br> *in* bias:**T3**<br> *out* Y:**T3**|1+|**T1** = tensor(int8), tensor(uint8)<br/> **T2** = tensor(int8), tensor(uint8)<br/> **T3** = tensor(float)|
|MatMulNBits|*in* A:**T1**<br> *in* B:**T2**<br> *in* scales:**T1**<br> *in* zero_points:**T2**<br> *out* Y:**T1**|1+|**T1** = tensor(float)<br/> **T2** = tensor(uint8)|
|MaxpoolWithMask|*in* X:**T**<br> *in* M:**tensor(int32)**<br> *out* Y:**T**|1+|**T** = tensor(float)|
|MultiHeadAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* bias:**T**<br> *in* key_padding_mask:**M**<br> *in* relative_position_bias:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**T** = tensor(float)|
|MurmurHash3|*in* X:**T1**<br> *out* Y:**T2**|1+|**T1** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(string), tensor(uint32), tensor(uint64)<br/> **T2** = tensor(int32), tensor(uint32)|
@@ -844,6 +845,7 @@ Do not modify directly.*
|Inverse|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|Irfft|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|LongformerAttention|*in* input:**T**<br> *in* weight:**T**<br> *in* bias:**T**<br> *in* mask:**T**<br> *in* global_weight:**T**<br> *in* global_bias:**T**<br> *in* global:**G**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
|MatMulNBits|*in* A:**T1**<br> *in* B:**T2**<br> *in* scales:**T1**<br> *in* zero_points:**T2**<br> *out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)<br/> **T2** = tensor(uint8)|
|MultiHeadAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* bias:**T**<br> *in* key_padding_mask:**M**<br> *in* relative_position_bias:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**T** = tensor(float), tensor(float16)|
|NGramRepeatBlock|*in* input_ids:**Tid**<br> *in* scores:**T**<br> *out* scores_out:**T**|1+|**T** = tensor(float)<br/> **Tid** = tensor(int64)|
|NhwcConv|*in* X:**T**<br> *in* W:**T**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
2 changes: 2 additions & 0 deletions onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
@@ -27,6 +27,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, WordC
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, GatherND);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, TransposeMatMul); // backward compatibility
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, FusedMatMul);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulNBits);
#ifndef ORT_MINIMAL_BUILD
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulFpQ4);
#endif
@@ -262,6 +263,7 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MurmurHash3)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, TransposeMatMul)>, // backward compatibility
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, FusedMatMul)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulNBits)>,
#ifndef ORT_MINIMAL_BUILD
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulFpQ4)>,
#endif
129 changes: 129 additions & 0 deletions onnxruntime/contrib_ops/cpu/quantization/blockwise_quant_block.h
@@ -0,0 +1,129 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include <cstdint>
#include <algorithm>
#include <cmath>

namespace onnxruntime {
namespace contrib {

#if defined(_MSC_VER)
#define FORCEINLINE __forceinline
#else
#define FORCEINLINE __attribute__((always_inline)) inline
#endif

template <typename T, int32_t block_size, int32_t bits>
struct alignas(1) BlockwiseQuantBlock {
  static_assert(block_size % 8 == 0);

  uint8_t blob_data[block_size / 8 * bits];

  FORCEINLINE void dequant(T* dst, T scale, int32_t k_idx, int32_t K) const;
  FORCEINLINE void dequant(T* dst, T scale, uint8_t zp, int32_t k_idx, int32_t K) const;

  FORCEINLINE void quant(const T* src, T& scale, int32_t k_idx, int32_t K, int32_t N);
  FORCEINLINE void quant(const T* src, T& scale, uint8_t& zp, int32_t k_idx, int32_t K, int32_t N);
};

template <typename T, int32_t block_size>
struct alignas(1) BlockwiseQuantBlock<T, block_size, 4> {
  static_assert(block_size % 8 == 0);

  uint8_t blob_data[block_size / 2];

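  // Dequantize one block into dst: each byte of blob_data packs two 4-bit weights
  // (low nibble = even index, high nibble = odd index); entries at or beyond K are skipped.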
  FORCEINLINE void dequant(T* dst, T scale, uint8_t zp, int32_t k_idx, int32_t K) const {
    for (int i = 0; i < block_size; i += 2) {
      T zp_t = static_cast<T>(float(zp));
      if (k_idx + i < K) {
        T x0 = static_cast<T>(float(blob_data[i / 2] & 0xF));
        dst[i] = scale * (x0 - zp_t);
      }
      if (k_idx + i + 1 < K) {
        T x1 = static_cast<T>(float(blob_data[i / 2] >> 4));
        dst[i + 1] = scale * (x1 - zp_t);
      }
    }
  }

  FORCEINLINE void dequant(T* dst, T scale, int32_t k_idx, int32_t K) const {
    constexpr uint8_t zp = 8;
    dequant(dst, scale, zp, k_idx, K);
  }

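  // Asymmetric block-wise quantization: scan the block (read with stride N from a row-major
  // [K][N] source) for min/max, derive scale = (max - min) / 15 and a zero point clamped to
  // [0, 15], then pack pairs of 4-bit values into bytes.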
  FORCEINLINE void quant(const T* src, T& scale_block, uint8_t& zp, int32_t k_idx, int32_t K, int32_t N) {
    float min = static_cast<float>(*src);
    float max = static_cast<float>(*src);
    int32_t klen = std::min(block_size, K - k_idx);
    for (int32_t kk = 0; kk < klen; kk++) {
      const float v = static_cast<float>(src[N * kk]);
      if (v < min) min = v;
      if (v > max) max = v;
    }
    min = std::min(min, 0.0f);
    max = std::max(max, 0.0f);

    const float scale = (max - min) / ((1 << 4) - 1);
    scale_block = static_cast<T>(scale);

    const float reciprocal_scale = scale ? 1.0f / scale : 0.0f;
    float zero_point_fp = min;
    if (scale != 0.0f) {
      zero_point_fp = 0.f - min / scale;
    }

    // Handle any clamping
    if (zero_point_fp < 0.0f) {
      zp = 0;
    } else if (zero_point_fp > 15.0f) {
      zp = 15;
    } else {
      zp = (uint8_t)roundf(zero_point_fp);
    }

    for (int32_t kk = 0; kk < klen; kk += 2) {
      const float v0 = static_cast<float>(src[N * kk]);
      const uint8_t vi0 = (uint8_t)std::min(15.0f, std::max(0.0f, roundf(v0 * reciprocal_scale + zp)));

      const float v1 = static_cast<float>((kk + 1 < klen) ? src[N * (kk + 1)] : 0.f);
      const uint8_t vi1 = (uint8_t)std::min(15.0f, std::max(0.0f, roundf(v1 * reciprocal_scale + zp)));

      blob_data[kk / 2] = vi0 | (vi1 << 4);
    }
  }

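  // Symmetric block-wise quantization (no zero point stored): pick the element with the largest
  // magnitude, set scale = max / -8 so quantized values land in [0, 15] around an implicit
  // zero point of 8.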
  FORCEINLINE void quant(const T* src, T& scale_block, int32_t k_idx, int32_t K, int32_t N) {
    float amax = 0.0f;  // abs(max)
    float max = 0.0f;

    int32_t klen = std::min(block_size, K - k_idx);

    for (int32_t kk = 0; kk < klen; kk++) {
      const float v = static_cast<float>(src[N * kk]);
      if (amax < fabsf(v)) {
        amax = fabsf(v);
        max = v;
      }
    }

    const float scale = max / (-8.f);
    scale_block = static_cast<T>(scale);
    const float reciprocal_scale = scale ? 1.0f / scale : 0.0f;

    for (int32_t kk = 0; kk < klen; kk += 2) {
      const float v0 = src[N * kk] * reciprocal_scale;
      const uint8_t vi0 = (uint8_t)std::min(15.0f, std::max(0.0f, roundf(v0 + 8.f)));

      const float v1 = (kk + 1 < klen) ? src[N * (kk + 1)] * reciprocal_scale : 0;
      const uint8_t vi1 = (uint8_t)std::min(15.0f, std::max(0.0f, roundf(v1 + 8.f)));

      blob_data[kk / 2] = vi0 | (vi1 << 4);
    }
  }
};

} // namespace contrib
} // namespace onnxruntime
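The snippet below is a minimal, editor-added round-trip sketch (not part of the PR) showing how the 4-bit specialization above could be used, assuming the weight matrix is stored row-major as [K][N] so a column is read with stride N; the include path and sizes are assumptions.

```cpp
// Sketch: quantize one 32-element block of one column, then dequantize it back.
#include <cstdint>
#include <vector>

#include "contrib_ops/cpu/quantization/blockwise_quant_block.h"  // header added by this PR; adjust to your include dirs

int main() {
  using Block = onnxruntime::contrib::BlockwiseQuantBlock<float, /*block_size=*/32, /*bits=*/4>;

  constexpr int K = 32, N = 4;        // toy sizes for illustration
  std::vector<float> weights(K * N);  // row-major [K][N]
  for (int k = 0; k < K; ++k)
    for (int n = 0; n < N; ++n)
      weights[k * N + n] = 0.01f * static_cast<float>(k - 16);

  Block block{};  // 16-byte blob holding one 32-element block
  float scale = 0.0f;
  uint8_t zp = 0;
  constexpr int col = 1;  // quantize column 1, rows [0, 32)
  block.quant(&weights[0 * N + col], scale, zp, /*k_idx=*/0, K, N);

  float restored[32];  // approximately the original column values
  block.dequant(restored, scale, zp, /*k_idx=*/0, K);
  return 0;
}
```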