diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_with_quant_weight.cu b/onnxruntime/contrib_ops/cuda/quantization/matmul_with_quant_weight.cu
index b88c303e21d8f..db02a4214da55 100644
--- a/onnxruntime/contrib_ops/cuda/quantization/matmul_with_quant_weight.cu
+++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_with_quant_weight.cu
@@ -17,24 +17,28 @@ namespace onnxruntime {
 namespace contrib {
 namespace cuda {
 
-inline __device__ float AccumulateEightElements(uint32_t values_quant, half scale, uint8_t zp, const half* a) {
+__device__ __forceinline__ float AccumulateEightElements(uint32_t values_quant, half scale, uint8_t zp, const half* a) {
   half2 scale_half2 = {scale, scale};
   half zp_adjust = -scale * __short2half_rn(zp);
   half2 zp_adjust2 = {zp_adjust, zp_adjust};
-  const half2* a_half2 = reinterpret_cast<const half2*>(a);
+  uint4 vec_a = *(reinterpret_cast<const uint4*>(a));
+
   half2 v0 = __hfma2(__halves2half2(__uint2half_rn(values_quant & 0xF), __uint2half_rn((values_quant >> 4) & 0xF)), scale_half2, zp_adjust2);
   half2 v1 = __hfma2(__halves2half2(__uint2half_rn((values_quant >> 8) & 0xF), __uint2half_rn((values_quant >> 12) & 0xF)), scale_half2, zp_adjust2);
   half2 v2 = __hfma2(__halves2half2(__uint2half_rn((values_quant >> 16) & 0xF), __uint2half_rn((values_quant >> 20) & 0xF)), scale_half2, zp_adjust2);
   half2 v3 = __hfma2(__halves2half2(__uint2half_rn((values_quant >> 24) & 0xF), __uint2half_rn((values_quant >> 28) & 0xF)), scale_half2, zp_adjust2);
-  v0 = __hmul2(v0, a_half2[0]);
-  v1 = __hmul2(v1, a_half2[1]);
-  v2 = __hfma2(v2, a_half2[2], v0);
-  v3 = __hfma2(v3, a_half2[3], v1);
+  v0 = __hmul2(v0, *(reinterpret_cast<half2*>(&(vec_a.x))));
+  v1 = __hmul2(v1, *(reinterpret_cast<half2*>(&(vec_a.y))));
+  v2 = __hfma2(v2, *(reinterpret_cast<half2*>(&(vec_a.z))), v0);
+  v3 = __hfma2(v3, *(reinterpret_cast<half2*>(&(vec_a.w))), v1);
   v3 = __hadd2(v2, v3);
   return float(v3.x) + float(v3.y);
 }
 
-inline __device__ float AccumulateEightElements(uint32_t values_quant, float scale, uint8_t zp, const float* a) {
+__device__ __forceinline__ float AccumulateEightElements(uint32_t values_quant, float scale, uint8_t zp, const float* a) {
+  float4 a_vec_0 = *(reinterpret_cast<const float4*>(a));
+  float4 a_vec_1 = *(reinterpret_cast<const float4*>(a + 4));
+
   float zp_adjust = -scale * zp;
   float v0 = float(values_quant & 0xF) * scale + zp_adjust;
   float v1 = float((values_quant >> 4) & 0xF) * scale + zp_adjust;
@@ -44,14 +48,15 @@ inline __device__ float AccumulateEightElements(uint32_t values_quant, float sca
   float v5 = float((values_quant >> 20) & 0xF) * scale + zp_adjust;
   float v6 = float((values_quant >> 24) & 0xF) * scale + zp_adjust;
   float v7 = float((values_quant >> 28) & 0xF) * scale + zp_adjust;
-  v0 = v0 * a[0];
-  v1 = v1 * a[1];
-  v2 = v2 * a[2];
-  v3 = v3 * a[3];
-  v4 = v4 * a[4] + v0;
-  v5 = v5 * a[5] + v1;
-  v6 = v6 * a[6] + v2;
-  v7 = v7 * a[7] + v3;
+
+  v0 = v0 * a_vec_0.x;
+  v1 = v1 * a_vec_0.y;
+  v2 = v2 * a_vec_0.z;
+  v3 = v3 * a_vec_0.w;
+  v4 = v4 * a_vec_1.x + v0;
+  v5 = v5 * a_vec_1.y + v1;
+  v6 = v6 * a_vec_1.z + v2;
+  v7 = v7 * a_vec_1.w + v3;
   return v4 + v5 + v6 + v7;
 }
 
@@ -92,24 +97,20 @@ __global__ void MatMulFloatInt4Kernel(
   b_data_quant += n_id * group_count * (group_size / 2);
 
   float sum = 0.f;
-  for (int k_step = 0; k_step < k_iter; k_step++) {
-    uint32_t value = *(reinterpret_cast<const uint32_t*>(b_data_quant + k_step * 128 + lane_id * 4));
-    T scale = b_scale_vec[warp_id * group_count + (k_step * 256 + lane_id * 8) / group_size];
-    uint8_t zp = b_zp_vec[warp_id * group_count + (k_step * 256 + lane_id * 8) / group_size];
-    sum += AccumulateEightElements(value, scale, zp, a_data + (lane_id << 3));
-    a_data += 256;
+  int k_id = 0;
+  for (; k_id < (k & 0xffffff00); k_id += 256) {
+    uint32_t value = *(reinterpret_cast<const uint32_t*>(b_data_quant + (k_id >> 1) + lane_id * 4));
+    T scale = b_scale_vec[warp_id * group_count + (k_id + lane_id * 8) / group_size];
+    uint8_t zp = b_zp_vec[warp_id * group_count + (k_id + lane_id * 8) / group_size];
+    sum += AccumulateEightElements(value, scale, zp, a_data + k_id + (lane_id << 3));
   }
 
   // handle reminder
-  int k_id = k_iter * 256;
-  int k_remainder = k - k_iter * 256;
-  if (k_remainder > 0) {
-    if (lane_id * 8 < k_remainder) {
-      uint32_t value = *(reinterpret_cast<const uint32_t*>(b_data_quant + k_iter * 128 + lane_id * 4));
-      T scale = b_scale_vec[warp_id * group_count + (k_id + lane_id * 8) / group_size];
-      uint8_t zp = b_zp_vec[warp_id * group_count + (k_id + lane_id * 8) / group_size];
-      sum += AccumulateEightElements(value, scale, zp, a_data + (lane_id << 3));
-    }
+  if (k_id + lane_id * 8 < k) {
+    uint32_t value = *(reinterpret_cast<const uint32_t*>(b_data_quant + k_iter * 128 + lane_id * 4));
+    T scale = b_scale_vec[warp_id * group_count + (k_id + lane_id * 8) / group_size];
+    uint8_t zp = b_zp_vec[warp_id * group_count + (k_id + lane_id * 8) / group_size];
+    sum += AccumulateEightElements(value, scale, zp, a_data + k_id + (lane_id << 3));
   }
 
   // warp reduction
@@ -141,8 +142,6 @@ bool TryMatMul4BitsWeight(
   dim3 threads(32, 8);
   int shared_mem_size = (sizeof(T) + 1) * ((k + group_size - 1) / group_size * 8);
 
-  // printf("group size %d\n", group_size);
-  // printf("shared_mem_size %d\n", shared_mem_size);
  if (16 == group_size) {
    MatMulFloatInt4Kernel<T, 16><<<blocks, threads, shared_mem_size, stream>>>(
        output, a_data, b_data_quant, scales_data, zero_points, m, n, k);
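
Reviewer note: the arithmetic both AccumulateEightElements overloads implement can be sanity-checked against a plain scalar reference. Each uint32_t packs eight 4-bit weights, lowest nibble first, and each weight is dequantized as q * scale - scale * zp (i.e. (q - zp) * scale) before being multiplied with the matching activation. The host-side sketch below is illustrative only; AccumulateEightElementsRef and the sample values are not part of this change.

// Scalar reference for the float path of AccumulateEightElements:
// dequantize eight packed 4-bit weights and dot them with eight activations.
#include <cstdint>
#include <cstdio>

float AccumulateEightElementsRef(uint32_t values_quant, float scale, uint8_t zp, const float* a) {
  float zp_adjust = -scale * zp;  // folded so each element is q * scale + zp_adjust
  float sum = 0.f;
  for (int i = 0; i < 8; ++i) {
    float v = float((values_quant >> (4 * i)) & 0xF) * scale + zp_adjust;
    sum += v * a[i];
  }
  return sum;
}

int main() {
  // Eight activations and eight quantized nibbles 0..7 packed as 0x76543210.
  float a[8] = {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f};
  uint32_t packed = 0x76543210u;
  printf("%f\n", AccumulateEightElementsRef(packed, 0.5f, 8, a));  // scale 0.5, zero point 8
  return 0;
}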