
Add matmul int4 for CUDA #17526

Closed
wants to merge 28 commits
Changes from 1 commit
28 commits
2b96e30
int4 support on GPU
yufenglee Sep 6, 2023
7be5564
change quant tool
yufenglee Sep 11, 2023
2d8c8f9
use fp32 as accumulator
yufenglee Sep 12, 2023
19ecb96
refine the matmul_int4 kernel
yufenglee Sep 12, 2023
0a01dce
refine the matmul_int4 kernel
yufenglee Sep 12, 2023
74c6b80
refine benchmark tool
yufenglee Sep 18, 2023
0c3f6c5
refine the dequant int4
yufenglee Sep 18, 2023
155d4b2
optimize dequant
yufenglee Sep 25, 2023
3ff52ed
refine quant tool
yufenglee Sep 26, 2023
453207f
add pybind for blockwise quant
yufenglee Sep 27, 2023
1c7f9d5
fix build breaks
yufenglee Sep 27, 2023
7eba97e
refine quant tool
yufenglee Sep 27, 2023
dda154f
fix fp16
yufenglee Sep 27, 2023
2050392
add unit test for QuantBlockwise pybind
yufenglee Sep 28, 2023
c015d4d
refine the quant tool
yufenglee Sep 28, 2023
2034ac9
add option to exclude logit layer
yufenglee Sep 28, 2023
8d972dd
handle subgraph properly
yufenglee Oct 2, 2023
78d0f02
fix build break
yufenglee Oct 6, 2023
068f4db
change matmul 4bits name
yufenglee Sep 29, 2023
fef4216
change zp to 4bits
yufenglee Oct 5, 2023
3375903
revert change in matmul_weight4_quantizer.py
yufenglee Oct 9, 2023
4d8dbc0
format code
yufenglee Oct 9, 2023
9585dca
fix build/test breaks
yufenglee Oct 9, 2023
68191b7
fix build break
yufenglee Oct 9, 2023
bccf3d7
fix test failures in training pipeline
yufenglee Oct 9, 2023
a3abfcd
fix build break in training CIs
yufenglee Oct 10, 2023
4e29283
fix training CIs
yufenglee Oct 10, 2023
d4e4145
fix training CIs
yufenglee Oct 11, 2023
optimize dequant
yufenglee committed Oct 10, 2023
commit 155d4b2611e33634867cbec70dc461858618c7df
28 changes: 17 additions & 11 deletions onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu
@@ -5,6 +5,7 @@
 #include <cub/cub.cuh>
 #include <cublas_v2.h>
 #include <cuda_fp16.h>
+#include <cmath>
 #include <math_constants.h>
 #include "core/providers/cuda/cu_inc/common.cuh"
 #include "core/providers/cuda/cuda_common.h"
@@ -21,12 +22,13 @@ __device__ __forceinline__ void DequantizeEightElements(uint32_t values_quant, h
   half2 scale_half2 = {scale, scale};
   half zp_adjust = -scale * __short2half_rn(zp);
   half2 zp_adjust2 = {zp_adjust, zp_adjust};
-  half2* output_half2 = reinterpret_cast<half2*>(output);
 
-  output_half2[0] = __hfma2(__halves2half2(__uint2half_rn(values_quant & 0xF), __uint2half_rn((values_quant >> 4) & 0xF)), scale_half2, zp_adjust2);
-  output_half2[1] = __hfma2(__halves2half2(__uint2half_rn((values_quant >> 8) & 0xF), __uint2half_rn((values_quant >> 12) & 0xF)), scale_half2, zp_adjust2);
-  output_half2[2] = __hfma2(__halves2half2(__uint2half_rn((values_quant >> 16) & 0xF), __uint2half_rn((values_quant >> 20) & 0xF)), scale_half2, zp_adjust2);
-  output_half2[3] = __hfma2(__halves2half2(__uint2half_rn((values_quant >> 24) & 0xF), __uint2half_rn((values_quant >> 28) & 0xF)), scale_half2, zp_adjust2);
+  alignas(16) half2 results[4];
+  results[0] = __hfma2(__halves2half2(__uint2half_rn(values_quant & 0xF), __uint2half_rn((values_quant >> 4) & 0xF)), scale_half2, zp_adjust2);
+  results[1] = __hfma2(__halves2half2(__uint2half_rn((values_quant >> 8) & 0xF), __uint2half_rn((values_quant >> 12) & 0xF)), scale_half2, zp_adjust2);
+  results[2] = __hfma2(__halves2half2(__uint2half_rn((values_quant >> 16) & 0xF), __uint2half_rn((values_quant >> 20) & 0xF)), scale_half2, zp_adjust2);
+  results[3] = __hfma2(__halves2half2(__uint2half_rn((values_quant >> 24) & 0xF), __uint2half_rn((values_quant >> 28) & 0xF)), scale_half2, zp_adjust2);
+  *(reinterpret_cast<float4*>(output)) = *(reinterpret_cast<float4*>(results));
 }
 
 __device__ __forceinline__ void DequantizeEightElements(uint32_t values_quant, float scale, uint8_t zp, float* output) {
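The change in the half path above replaces four separate 32-bit `half2` stores through `output_half2` with one 128-bit store: the eight dequantized values are staged in a 16-byte-aligned local array and then written through a single `float4` reinterpret. A minimal sketch of that store-widening pattern (`StoreEightHalves` is a hypothetical helper name, not code from this PR; it assumes `out` is 16-byte aligned):

```cuda
#include <cuda_fp16.h>

// Sketch: coalesce four 32-bit half2 stores into one 128-bit store.
__device__ __forceinline__ void StoreEightHalves(const half2 (&vals)[4], half* out) {
  alignas(16) half2 tmp[4];  // staging buffer with the alignment float4 requires
  tmp[0] = vals[0];
  tmp[1] = vals[1];
  tmp[2] = vals[2];
  tmp[3] = vals[3];
  // One 128-bit transaction instead of four 32-bit ones.
  *reinterpret_cast<float4*>(out) = *reinterpret_cast<const float4*>(tmp);
}
```

The `alignas(16)` staging array is what allows the single vectorized store; without the alignment guarantee, the `float4` reinterpret would be undefined behavior.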
@@ -48,12 +50,13 @@ __global__ void Dequantize4BitsKernel(
     const T* scale_data,
     const uint8_t* zero_points,
     int block_size,
-    int blocks_per_tb) {
-  int block_id = blockIdx.x * blocks_per_tb + (threadIdx.x * 8) / block_size;
-  int element_offset = block_id * block_size + (threadIdx.x * 8) % block_size;
+    int blocks_per_tb,
+    int shift) {
+  int block_id = blockIdx.x * blocks_per_tb + ((threadIdx.x * 8) >> shift);
+  int element_offset = block_id * block_size + ((threadIdx.x * 8) & ((1 << shift) - 1));
   uint32_t quant_value = *(reinterpret_cast<const uint32_t*>(quant_data + element_offset / 2));
   T scale = *(scale_data + block_id);
-  T zero_point = static_cast<T>(zero_points ? float(zero_points[block_id]) : 8.f);
+  T zero_point = static_cast<T>(zero_points ? zero_points[block_id] : (uint8_t)(8));
 
   output = output + element_offset;
   DequantizeEightElements(quant_value, scale, zero_point, output);
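Two things change in the kernel body. First, since `block_size` is a power of two, the divide and modulo used to locate a thread's quantization block are replaced by a shift and a mask, with `shift = log2(block_size)` computed once on the host. Second, the default zero point is now the integral `(uint8_t)(8)` rather than the float literal `8.f`, so `static_cast<T>` converts directly from an integer instead of round-tripping through float. A small host-side check of the shift/mask equivalence (illustrative only, under the power-of-two assumption):

```cpp
#include <cassert>

// For block_size == 1 << shift and idx >= 0, shift and mask reproduce
// integer divide and modulo exactly.
void CheckIndexMath(int idx, int shift) {
  int block_size = 1 << shift;
  assert(idx / block_size == (idx >> shift));
  assert(idx % block_size == (idx & (block_size - 1)));
}
```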
@@ -70,17 +73,20 @@ Status Dequantize4Bits(
     int block_size,
     cudaStream_t stream) {
   ORT_ENFORCE(k % block_size == 0, "k must be a multiplier of block_size");
-  int blocks_per_tb = GridDim::maxThreadsPerBlock * 8 / block_size;
+  constexpr int element_per_thread = 8;
+  int blocks_per_tb = GridDim::maxThreadsPerBlock * element_per_thread / block_size;
   int k_blocks = k / block_size;
   int blocks_per_grid = static_cast<int>(CeilDiv(n * k_blocks, blocks_per_tb));
+  int shift = static_cast<int>(log2f(block_size));
 
   Dequantize4BitsKernel<<<blocks_per_grid, GridDim::maxThreadsPerBlock, 0, stream>>>(
       output,
       quant_data,
       scales_data,
       zero_points,
       block_size,
-      blocks_per_tb);
+      blocks_per_tb,
+      shift);
 
   return Status::OK();
 }
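This host-side hunk is why `<cmath>` was added at the top of the file: `shift` is derived once per launch with `log2f`. The result is exact only when `block_size` is a power of two; the existing `ORT_ENFORCE` checks that `k` is divisible by `block_size` but not power-of-two-ness. A float-free alternative with an explicit guard (a sketch of a defensive variant, not something this PR adds):

```cpp
#include <cassert>

// Integer-only shift derivation with an explicit power-of-two check.
int ShiftForBlockSize(int block_size) {
  assert(block_size > 0 && (block_size & (block_size - 1)) == 0);  // power of two
  int shift = 0;
  while ((1 << shift) < block_size) ++shift;
  return shift;
}
```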
(second changed file: a Python benchmark script, filename not shown)
@@ -43,7 +43,7 @@ def profile_dequantize_int4_func(n, k, dtype, func):
     f = getattr(ke, func)
     my_op = f(output_d, quant_d, scales_d, n, k)
     duration_ms = my_op.Profile()
-    total_bytes = 2 * (n * k) * (dtype_to_bytes(dtype))
+    total_bytes = (n * k) / 2 + (n * k + n * k / 32) * dtype_to_bytes(dtype)
 
     ke.report(DequantizeInt4Metric(func, dtype, duration_ms, total_bytes, n, k))
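The old estimate counted 2 * n * k elements at the output dtype's width; the new one models the actual traffic of the 4-bit path: n * k / 2 bytes of packed int4 weights read, n * k dequantized elements written, and n * k / 32 scales read, which bakes in a block size of 32. A host-side sketch of the same accounting (`TotalBytes` and `dtype_bytes` are illustrative names, assuming block_size == 32):

```cpp
#include <cstddef>

// Bytes moved by blockwise int4 dequantization, assuming block_size == 32.
size_t TotalBytes(size_t n, size_t k, size_t dtype_bytes) {
  size_t weights_read = n * k / 2;                   // packed int4: two values per byte
  size_t scales_read  = (n * k / 32) * dtype_bytes;  // one scale per 32-element block
  size_t output_write = n * k * dtype_bytes;         // dequantized output
  return weights_read + scales_read + output_write;
}
```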