refine the matmul_int4 kernel
yufenglee committed Sep 12, 2023
1 parent 8af2630 commit caa83a0
Showing 1 changed file with 31 additions and 32 deletions.
@@ -17,24 +17,28 @@ namespace onnxruntime {
namespace contrib {
namespace cuda {

-inline __device__ float AccumulateEightElements(uint32_t values_quant, half scale, uint8_t zp, const half* a) {
+__device__ __forceinline__ float AccumulateEightElements(uint32_t values_quant, half scale, uint8_t zp, const half* a) {
half2 scale_half2 = {scale, scale};
half zp_adjust = -scale * __short2half_rn(zp);
half2 zp_adjust2 = {zp_adjust, zp_adjust};
-const half2* a_half2 = reinterpret_cast<const half2*>(a);
+uint4 vec_a = *(reinterpret_cast<const uint4*>(a));
+
half2 v0 = __hfma2(__halves2half2(__uint2half_rn(values_quant & 0xF), __uint2half_rn((values_quant >> 4) & 0xF)), scale_half2, zp_adjust2);
half2 v1 = __hfma2(__halves2half2(__uint2half_rn((values_quant >> 8) & 0xF), __uint2half_rn((values_quant >> 12) & 0xF)), scale_half2, zp_adjust2);
half2 v2 = __hfma2(__halves2half2(__uint2half_rn((values_quant >> 16) & 0xF), __uint2half_rn((values_quant >> 20) & 0xF)), scale_half2, zp_adjust2);
half2 v3 = __hfma2(__halves2half2(__uint2half_rn((values_quant >> 24) & 0xF), __uint2half_rn((values_quant >> 28) & 0xF)), scale_half2, zp_adjust2);
-v0 = __hmul2(v0, a_half2[0]);
-v1 = __hmul2(v1, a_half2[1]);
-v2 = __hfma2(v2, a_half2[2], v0);
-v3 = __hfma2(v3, a_half2[3], v1);
+v0 = __hmul2(v0, *(reinterpret_cast<half2*>(&(vec_a.x))));
+v1 = __hmul2(v1, *(reinterpret_cast<half2*>(&(vec_a.y))));
+v2 = __hfma2(v2, *(reinterpret_cast<half2*>(&(vec_a.z))), v0);
+v3 = __hfma2(v3, *(reinterpret_cast<half2*>(&(vec_a.w))), v1);
v3 = __hadd2(v2, v3);
return float(v3.x) + float(v3.y);
}

-inline __device__ float AccumulateEightElements(uint32_t values_quant, float scale, uint8_t zp, const float* a) {
+__device__ __forceinline__ float AccumulateEightElements(uint32_t values_quant, float scale, uint8_t zp, const float* a) {
+float4 a_vec_0 = *(reinterpret_cast<const float4*>(a));
+float4 a_vec_1 = *(reinterpret_cast<const float4*>(a + 4));
+
float zp_adjust = -scale * zp;
float v0 = float(values_quant & 0xF) * scale + zp_adjust;
float v1 = float((values_quant >> 4) & 0xF) * scale + zp_adjust;
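
Both overloads of AccumulateEightElements dot eight packed 4-bit weights against eight activation values. Each nibble q is dequantized as q * scale + zp_adjust, which equals (q - zp) * scale because zp_adjust = -scale * zp, so every element costs one FMA. A minimal scalar sketch of the same computation (illustration only; AccumulateEightElementsRef is not part of this commit):

__device__ float AccumulateEightElementsRef(uint32_t values_quant, float scale, uint8_t zp, const float* a) {
  float zp_adjust = -scale * zp;  // folding -scale*zp lets each element be one FMA: q * scale + zp_adjust
  float sum = 0.f;
  for (int i = 0; i < 8; ++i) {
    uint32_t q = (values_quant >> (4 * i)) & 0xF;  // i-th 4-bit weight, lowest nibble first
    sum += (float(q) * scale + zp_adjust) * a[i];  // == (q - zp) * scale * a[i]
  }
  return sum;
}
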
@@ -44,14 +48,15 @@ inline __device__ float AccumulateEightElements(uint32_t values_quant, float scale, uint8_t zp, const float* a) {
float v5 = float((values_quant >> 20) & 0xF) * scale + zp_adjust;
float v6 = float((values_quant >> 24) & 0xF) * scale + zp_adjust;
float v7 = float((values_quant >> 28) & 0xF) * scale + zp_adjust;
-v0 = v0 * a[0];
-v1 = v1 * a[1];
-v2 = v2 * a[2];
-v3 = v3 * a[3];
-v4 = v4 * a[4] + v0;
-v5 = v5 * a[5] + v1;
-v6 = v6 * a[6] + v2;
-v7 = v7 * a[7] + v3;
+
+v0 = v0 * a_vec_0.x;
+v1 = v1 * a_vec_0.y;
+v2 = v2 * a_vec_0.z;
+v3 = v3 * a_vec_0.w;
+v4 = v4 * a_vec_1.x + v0;
+v5 = v5 * a_vec_1.y + v1;
+v6 = v6 * a_vec_1.z + v2;
+v7 = v7 * a_vec_1.w + v3;
return v4 + v5 + v6 + v7;
}
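
The refinement replaces eight scalar reads of a[i] with two float4 loads (the half overload likewise uses a single uint4 load for its eight halves), so each thread fetches its activation values in 16-byte transactions. Those casts are only valid when the pointer is 16-byte aligned; in the kernel's call sites (next hunk) each lane reads at a multiple of 8 elements, so this reduces to the alignment of the a_data base pointer. An illustrative helper, assuming <cstdint> for uintptr_t (not part of this commit):

// a_vec_0 = {a[0], a[1], a[2], a[3]}   a_vec_1 = {a[4], a[5], a[6], a[7]}
__device__ inline bool IsAlignedForVec4(const void* a) {
  return (reinterpret_cast<uintptr_t>(a) & 0xF) == 0;  // 16-byte (float4 / uint4) alignment
}
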

@@ -92,24 +97,20 @@ __global__ void MatMulFloatInt4Kernel(
b_data_quant += n_id * group_count * (group_size / 2);

float sum = 0.f;
-for (int k_step = 0; k_step < k_iter; k_step++) {
-uint32_t value = *(reinterpret_cast<const uint32_t*>(b_data_quant + k_step * 128 + lane_id * 4));
-T scale = b_scale_vec[warp_id * group_count + (k_step * 256 + lane_id * 8) / group_size];
-uint8_t zp = b_zp_vec[warp_id * group_count + (k_step * 256 + lane_id * 8) / group_size];
-sum += AccumulateEightElements(value, scale, zp, a_data + (lane_id << 3));
-a_data += 256;
+int k_id = 0;
+for (; k_id < (k & 0xffffff00); k_id += 256) {
+uint32_t value = *(reinterpret_cast<const uint32_t*>(b_data_quant + (k_id >> 1) + lane_id * 4));
+T scale = b_scale_vec[warp_id * group_count + (k_id + lane_id * 8) / group_size];
+uint8_t zp = b_zp_vec[warp_id * group_count + (k_id + lane_id * 8) / group_size];
+sum += AccumulateEightElements(value, scale, zp, a_data + k_id + (lane_id << 3));
}

// handle remainder
-int k_id = k_iter * 256;
-int k_remainder = k - k_iter * 256;
-if (k_remainder > 0) {
-if (lane_id * 8 < k_remainder) {
-uint32_t value = *(reinterpret_cast<const uint32_t*>(b_data_quant + k_iter * 128 + lane_id * 4));
-T scale = b_scale_vec[warp_id * group_count + (k_id + lane_id * 8) / group_size];
-uint8_t zp = b_zp_vec[warp_id * group_count + (k_id + lane_id * 8) / group_size];
-sum += AccumulateEightElements(value, scale, zp, a_data + (lane_id << 3));
-}
+if (k_id + lane_id * 8 < k) {
+uint32_t value = *(reinterpret_cast<const uint32_t*>(b_data_quant + k_iter * 128 + lane_id * 4));
+T scale = b_scale_vec[warp_id * group_count + (k_id + lane_id * 8) / group_size];
+uint8_t zp = b_zp_vec[warp_id * group_count + (k_id + lane_id * 8) / group_size];
+sum += AccumulateEightElements(value, scale, zp, a_data + k_id + (lane_id << 3));
}
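
The rewritten loop walks k in fixed steps of 256 elements (32 lanes x 8 elements per AccumulateEightElements call); k & 0xffffff00 clears the low 8 bits of k, i.e. rounds it down to a multiple of 256, and (k_id >> 1) turns the element index into a byte offset in the packed data (two 4-bit weights per byte). A single guarded block then picks up the tail: each lane contributes only if its 8-element chunk starts below k, which assumes k is a multiple of 8. A small worked check with illustrative values:

// k = 1000: the main loop runs k_id = 0, 256, 512 and stops at 1000 & 0xffffff00 == 768;
// in the tail, lanes with 768 + lane_id * 8 < 1000 (lanes 0..28) cover the remaining 232 elements.
static_assert((1000 & 0xffffff00) == 768, "k & 0xffffff00 rounds k down to a multiple of 256");
static_assert(768 + 28 * 8 < 1000 && 768 + 29 * 8 >= 1000, "lanes 0..28 handle the tail");
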

// warp reduction
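
The reduction itself is outside the lines shown in this hunk. For orientation only, a typical warp-level pattern for summing `sum` across the 32 lanes looks like the sketch below (illustrative; not necessarily what this kernel does):

for (int offset = 16; offset > 0; offset >>= 1) {
  sum += __shfl_down_sync(0xffffffff, sum, offset);  // after the loop, lane 0 holds the warp total
}
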
@@ -141,8 +142,6 @@ bool TryMatMul4BitsWeight(
dim3 threads(32, 8);
int shared_mem_size = (sizeof(T) + 1) * ((k + group_size - 1) / group_size * 8);

-// printf("group size %d\n", group_size);
-// printf("shared_mem_size %d\n", shared_mem_size);
if (16 == group_size) {
MatMulFloatInt4Kernel<T, 16><<<blocks, threads, shared_mem_size, stream>>>(
output, a_data, b_data_quant, scales_data, zero_points, m, n, k);
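
The shared-memory size formula appears to reserve one scale (sizeof(T) bytes) plus one zero point (1 byte) per quantization group, staged for the 8 warps a block runs (threads.y == 8). A worked example with illustrative values (not taken from the diff):

// T = half, k = 4096, group_size = 32
constexpr int k_ex = 4096, group_size_ex = 32;
constexpr int group_count_ex = (k_ex + group_size_ex - 1) / group_size_ex;   // 128 groups per column
constexpr int shared_bytes_ex = (sizeof(half) + 1) * (group_count_ex * 8);   // (2 + 1) * 1024 = 3072 bytes
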