From bb2a6df8f737e177e2518ce0c4349c4dd0ea2928 Mon Sep 17 00:00:00 2001 From: Yufeng Li Date: Thu, 14 Dec 2023 05:04:22 +0000 Subject: [PATCH] optimize int4 gemv kernel with cuda --- .../cuda/quantization/matmul_nbits.cu | 247 ++++++++++++------ .../kernel_explorer/kernels/matmul_4bits.py | 2 +- .../test/contrib_ops/matmul_4bits_test.cc | 15 +- 3 files changed, 186 insertions(+), 78 deletions(-) diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cu b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cu index f2600a506285d..c94bbca88cf22 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cu +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cu @@ -17,33 +17,109 @@ namespace onnxruntime { namespace contrib { namespace cuda { -__device__ __forceinline__ float AccumulateEightElements(uint32_t values_quant, half scale, uint8_t zp, const half* a) { +template +__device__ __forceinline__ T WarpUniform(T value) { + struct { + union { + T value; + uint32_t asInt; + }; + } p; + p.value = value; + p.asInt = __shfl_sync(0xffffffff, (unsigned)p.asInt, 0); + return p.value; +} + +// Convert eight 4-bit integers stored in one uint32_t to eight halfs. +// The 4-bit values in order 0,1,2,3,4,5,6,7 are converted to 8 halfs in order 0,4,1,5,2,6,3,7 +__device__ __forceinline__ void Convert8xInt4To8xHalfs(uint32_t value, half2* half_2x4) { + uint32_t* h = reinterpret_cast(half_2x4); + + // From https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h + // First, we extract the i4s and construct an intermediate fp16 number. + constexpr uint32_t kImmLut = (0xf0 & 0xcc) | 0xaa; + constexpr uint32_t kBottomMask = 0x000f000f; + constexpr uint32_t kTopMask = 0x00f000f0; + constexpr uint32_t kI4sToF16sMagicNum = 0x64006400; + + // Note that the entire sequence only requires 1 shift instruction. 
This is thanks to the register packing + // format and the fact that we force our integers to be unsigned, and account for this in the fp16 subtractions. + // In addition, I exploit the fact that sub and fma have the same throughput in order to convert elt_15 and + // elt_37 to fp16 without having to shift them to the bottom bits beforehand. + + // Shift right by 8 to now consider elt_26 and elt_37. Issue first to hide RAW dependency if we issue + // immediately before required. + const uint32_t top_i4s = value >> 8; + // Extract elt_04 - (i4s & 0x000f000f) | 0x64006400 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[0]) + : "r"(value), "n"(kBottomMask), "n"(kI4sToF16sMagicNum), "n"(kImmLut)); + // Extract elt_15 (i4s & 0x00f000f0) | 0x64006400 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[1]) + : "r"(value), "n"(kTopMask), "n"(kI4sToF16sMagicNum), "n"(kImmLut)); + // Extract elt_26 (top_i4s & 0x000f000f) | 0x64006400 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[2]) + : "r"(top_i4s), "n"(kBottomMask), "n"(kI4sToF16sMagicNum), "n"(kImmLut)); + // Extract elt_37 (top_i4s & 0x00f000f0) | 0x64006400 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[3]) + : "r"(top_i4s), "n"(kTopMask), "n"(kI4sToF16sMagicNum), "n"(kImmLut)); + + // I use inline PTX below because I am not sure if the compiler will emit float2half instructions if I use the + // half2 ctor. In this case, I chose performance reliability over code readability. + + // This is the half2 {1024, 1024} represented as an integer. + constexpr uint32_t kFp16TopMagicNum = 0x64006400; + // This is the half2 {1 / 16, 1 / 16} represented as an integer. + constexpr uint32_t kOneSixteenth = 0x2c002c00; + // This is the half2 {-64, -64} represented as an integer. + constexpr uint32_t kNeg64 = 0xd400d400; + + // Finally, we construct the output numbers. 
+ // Convert elt_04 + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(kFp16TopMagicNum)); + // Convert elt_15 + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(kOneSixteenth), "r"(kNeg64)); + // Convert elt_26 + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(kFp16TopMagicNum)); + // Convert elt_37 + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(kOneSixteenth), "r"(kNeg64)); +} + +__device__ __forceinline__ float AccumulateEightElements(uint32_t values_quant, half scale, uint8_t zp, const half* a, half* sums) { half2 scale_half2 = {scale, scale}; half zp_adjust = -scale * __short2half_rn(zp); half2 zp_adjust2 = {zp_adjust, zp_adjust}; uint4 vec_a = *(reinterpret_cast(a)); - half2 element01 = __halves2half2(__uint2half_rn(values_quant & 0xF), __uint2half_rn((values_quant >> 4) & 0xF)); - half2 v0 = element01 * scale_half2 + zp_adjust2; + constexpr uint32_t kLowHalf2 = 0x5410; + constexpr uint32_t kHighHalf2 = 0x7632; + + uint4 vec_permuted; + asm volatile("prmt.b32 %0, %1, %2, %3;\n" : "=r"(vec_permuted.x) : "r"(vec_a.x), "r"(vec_a.z), "r"(kLowHalf2)); + asm volatile("prmt.b32 %0, %1, %2, %3;\n" : "=r"(vec_permuted.y) : "r"(vec_a.x), "r"(vec_a.z), "r"(kHighHalf2)); + asm volatile("prmt.b32 %0, %1, %2, %3;\n" : "=r"(vec_permuted.z) : "r"(vec_a.y), "r"(vec_a.w), "r"(kLowHalf2)); + asm volatile("prmt.b32 %0, %1, %2, %3;\n" : "=r"(vec_permuted.w) : "r"(vec_a.y), "r"(vec_a.w), "r"(kHighHalf2)); - half2 element23 = __halves2half2(__uint2half_rn((values_quant >> 8) & 0xF), __uint2half_rn((values_quant >> 12) & 0xF)); - half2 v1 = element23 * scale_half2 + zp_adjust2; + half2 elements[4]; // [04, 15, 26, 37] - half2 element45 = __halves2half2(__uint2half_rn((values_quant >> 16) & 0xF), __uint2half_rn((values_quant >> 20) & 0xF)); - half2 v2 = element45 * scale_half2 + zp_adjust2; + Convert8xInt4To8xHalfs(values_quant, elements); - half2 element67 = 
__halves2half2(__uint2half_rn((values_quant >> 24) & 0xF), __uint2half_rn((values_quant >> 28) & 0xF)); - half2 v3 = element67 * scale_half2 + zp_adjust2; + half2 v0 = elements[0] * scale_half2 + zp_adjust2; + half2 v1 = elements[1] * scale_half2 + zp_adjust2; + half2 v2 = elements[2] * scale_half2 + zp_adjust2; + half2 v3 = elements[3] * scale_half2 + zp_adjust2; - v0 = v0 * (*(reinterpret_cast(&(vec_a.x)))); - v1 = v1 * (*(reinterpret_cast(&(vec_a.y)))); - v2 = v2 * (*(reinterpret_cast(&(vec_a.z)))) + v0; - v3 = v3 * (*(reinterpret_cast(&(vec_a.w)))) + v1; - v3 = v2 + v3; - return float(v3.x) + float(v3.y); + half2* sums_half2 = reinterpret_cast(sums); + sums_half2[0] += v0 * (*(reinterpret_cast(&(vec_permuted.x)))); + sums_half2[1] += v1 * (*(reinterpret_cast(&(vec_permuted.y)))); + sums_half2[2] += v2 * (*(reinterpret_cast(&(vec_permuted.z)))); + sums_half2[3] += v3 * (*(reinterpret_cast(&(vec_permuted.w)))); } -__device__ __forceinline__ float AccumulateEightElements(uint32_t values_quant, float scale, uint8_t zp, const float* a) { +__device__ __forceinline__ float AccumulateEightElements(uint32_t values_quant, float scale, uint8_t zp, const float* a, float* sums) { float4 a_vec_0 = *(reinterpret_cast(a)); float4 a_vec_1 = *(reinterpret_cast(a + 4)); @@ -57,15 +133,14 @@ __device__ __forceinline__ float AccumulateEightElements(uint32_t values_quant, float v6 = float((values_quant >> 24) & 0xF) * scale + zp_adjust; float v7 = float((values_quant >> 28) & 0xF) * scale + zp_adjust; - v0 = v0 * a_vec_0.x; - v1 = v1 * a_vec_0.y; - v2 = v2 * a_vec_0.z; - v3 = v3 * a_vec_0.w; - v4 = v4 * a_vec_1.x + v0; - v5 = v5 * a_vec_1.y + v1; - v6 = v6 * a_vec_1.z + v2; - v7 = v7 * a_vec_1.w + v3; - return v4 + v5 + v6 + v7; + sums[0] += v0 * a_vec_0.x; + sums[1] += v1 * a_vec_0.y; + sums[2] += v2 * a_vec_0.z; + sums[3] += v3 * a_vec_0.w; + sums[4] += v4 * a_vec_1.x; + sums[5] += v5 * a_vec_1.y; + sums[6] += v6 * a_vec_1.z; + sums[7] += v7 * a_vec_1.w; } constexpr int 
kColsPerThreadBlock = 8; @@ -76,8 +151,8 @@ constexpr int kWarpSize = 32; // The thread block size is (kWarpSize, kColsPerThreadBlock) and grid size is (N/kColsPerThreadBlock, 1) // Each thread block computes [1, K] x [kColsPerThreadBlock, (K + block_size - 1)/block_size, blob], // i.e., computing kColsPerThreadBlock per block and a warp reduce (1, K) x (K) -template -__global__ void MatMulFloatInt4Kernel( +template +__global__ void __launch_bounds__(kWarpSize* kColsPerThreadBlock) MatMulFloatInt4Kernel( T* output, const T* a_data, const uint8_t* b_data_quant, @@ -87,63 +162,78 @@ __global__ void MatMulFloatInt4Kernel( int n, int k, int blocks_per_K) { - int n_block_id = blockIdx.x; - int m_id = blockIdx.y; - int lane_id = threadIdx.x; - int warp_id = threadIdx.y; - int n_id = n_block_id * kColsPerThreadBlock + warp_id; - int thread_id = warp_id * kWarpSize + lane_id; + const int n_block_id = blockIdx.x; + const int m_id = blockIdx.y; + const int lane_id = threadIdx.x; + const int warp_id = WarpUniform(threadIdx.y); + const int n_id = n_block_id * kColsPerThreadBlock + warp_id; constexpr int k_per_iter = 256; - int k_iter = k / k_per_iter; - - // blocks_per_k is the number of scales and zero points on the k dim - const int b_zp_k = (blocks_per_K + 1)/ 2; extern __shared__ char shared_buffer[]; - // load scale to shared buffer T* b_scale_vec = (T*)shared_buffer; - uint8_t* b_zp_vec = reinterpret_cast(b_scale_vec + kColsPerThreadBlock * blocks_per_K); int offset = n_block_id * kColsPerThreadBlock * blocks_per_K; - for (int i = thread_id; i < kColsPerThreadBlock * blocks_per_K; i += kColsPerThreadBlock * kWarpSize) { + for (int i = warp_id * kWarpSize + lane_id; i < kColsPerThreadBlock * blocks_per_K; i += kColsPerThreadBlock * kWarpSize) { b_scale_vec[i] = scales_data[offset + i]; } - int zp_offset = n_block_id * kColsPerThreadBlock * b_zp_k; - for (int i = thread_id; i < kColsPerThreadBlock * b_zp_k; i += kColsPerThreadBlock * kWarpSize) { - b_zp_vec[i] = 
zero_points != nullptr ? zero_points[zp_offset + i] : uint8_t(0x88); + const int b_zp_k = (blocks_per_K + 1) / 2; + uint8_t* b_zp_vec = reinterpret_cast(b_scale_vec + kColsPerThreadBlock * blocks_per_K); + if constexpr (has_zero_point) { + int zp_offset = n_block_id * kColsPerThreadBlock * b_zp_k; + for (int i = warp_id * kWarpSize + lane_id; i < kColsPerThreadBlock * b_zp_k; i += kColsPerThreadBlock * kWarpSize) { + b_zp_vec[2 * i] = (zero_points[zp_offset + i] & 0x0f); + b_zp_vec[2 * i + 1] = (zero_points[zp_offset + i] >> 4); + } + b_zp_vec += warp_id * b_zp_k * 2; } __syncthreads(); - a_data += m_id * k; - b_data_quant += n_id * blocks_per_K * (block_size / 2); + a_data += m_id * k + (lane_id << 3); - const int scale_col_offset = warp_id * blocks_per_K; - const int zp_col_offset = warp_id * b_zp_k; + b_scale_vec += warp_id * blocks_per_K; - float sum = 0.f; + T sums[8] = {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}; int k_id = 0; - for (; k_id < (k & 0xffffff00); k_id += k_per_iter) { - const int t_k = k_id + (lane_id << 3); // k index for this thread - const int t_meta_k = t_k / block_size; // k index for this thread, points to the scale and zero point - uint32_t value = *(reinterpret_cast(b_data_quant + (t_k >> 1))); - T scale = b_scale_vec[scale_col_offset + t_meta_k]; - uint8_t zp = b_zp_vec[zp_col_offset + t_meta_k/2]; - zp = (t_meta_k & 0x01) ? 
(zp >> 4) : (zp & 0x0f); - sum += AccumulateEightElements(value, scale, zp, a_data + k_id + (lane_id << 3)); - } + int t_meta_k = lane_id * 8 / block_size; + b_data_quant += n_id * blocks_per_K * (block_size / 2) + lane_id * 4; + +#define UnRollReduction(unroll_size) \ + do { \ + constexpr int kUnroll = unroll_size; \ + constexpr int kUnrollMask = 0xffffffff & (~(kUnroll * k_per_iter - 1)); \ + for (; k_id < (k & kUnrollMask); k_id += kUnroll * k_per_iter) { \ + _Pragma("unroll") for (int i = 0; i < kUnroll; i++) { \ + uint32_t value = *(reinterpret_cast(b_data_quant + k_per_iter / 2 * i)); \ + T scale = b_scale_vec[t_meta_k + k_per_iter / block_size * i]; \ + uint8_t zp = 8; \ + if constexpr (has_zero_point) { \ + zp = b_zp_vec[t_meta_k + k_per_iter / block_size * i]; \ + } \ + AccumulateEightElements(value, scale, zp, a_data + k_id + i * k_per_iter, sums); \ + } \ + b_data_quant += k_per_iter / 2 * kUnroll; \ + t_meta_k += k_per_iter / block_size * kUnroll; \ + } \ + } while (false) + + UnRollReduction(16); + UnRollReduction(4); + UnRollReduction(1); +#undef UnRollReduction // handle reminder if (k_id + lane_id * 8 < k) { - const int t_k = k_id + (lane_id << 3); // k index for this thread - const int t_meta_k = t_k / block_size; // k index for this thread, points to the scale and zero point - uint32_t value = *(reinterpret_cast(b_data_quant + k_iter * 128 + lane_id * 4)); - T scale = b_scale_vec[scale_col_offset + t_meta_k]; - uint8_t zp = b_zp_vec[zp_col_offset + t_meta_k/2]; - zp = (t_meta_k & 0x01) ? 
(zp >> 4) : (zp & 0x0f); - sum += AccumulateEightElements(value, scale, zp, a_data + k_id + (lane_id << 3)); + uint32_t value = *(reinterpret_cast(b_data_quant)); + T scale = b_scale_vec[t_meta_k]; + uint8_t zp = 8; + if constexpr (has_zero_point) { + zp = b_zp_vec[t_meta_k]; + } + AccumulateEightElements(value, scale, zp, a_data + k_id, sums); } + float sum = (float)(sums[0] + sums[1] + sums[2] + sums[3] + sums[4] + sums[5] + sums[6] + sums[7]); // warp reduction for (int i = 16; i > 0; i = i / 2) { sum += __shfl_down_sync(0xffffffff, sum, i); @@ -152,7 +242,7 @@ __global__ void MatMulFloatInt4Kernel( if (lane_id == 0) { output[m_id * n + n_id] = sum; } -} +} template bool TryMatMul4Bits( @@ -173,28 +263,35 @@ bool TryMatMul4Bits( dim3 blocks((n + kColsPerThreadBlock - 1) / kColsPerThreadBlock, m); dim3 threads(kWarpSize, kColsPerThreadBlock); int blocks_per_K = (k + block_size - 1) / block_size; - int blocks_per_thread_block = blocks_per_K * kColsPerThreadBlock; - int shared_mem_size = sizeof(T) * blocks_per_thread_block + blocks_per_thread_block / 2; + int shared_mem_size = sizeof(T) * blocks_per_K * kColsPerThreadBlock + + (zero_points != nullptr ? 
(blocks_per_K + 1) / 2 * kColsPerThreadBlock * 2 : 0); if (shared_mem_size > shared_mem_per_block) { return false; } +#define MatMulFloatInt4KernelDispatch(block_size) \ + if (nullptr != zero_points) { \ + MatMulFloatInt4Kernel<<>>( \ + output, a_data, b_data_quant, scales_data, zero_points, m, n, k, blocks_per_K); \ + } else { \ + MatMulFloatInt4Kernel<<>>( \ + output, a_data, b_data_quant, scales_data, zero_points, m, n, k, blocks_per_K); \ + } + if (16 == block_size) { - MatMulFloatInt4Kernel<<>>( - output, a_data, b_data_quant, scales_data, zero_points, m, n, k, blocks_per_K); + MatMulFloatInt4KernelDispatch(16); } else if (32 == block_size) { - MatMulFloatInt4Kernel<<>>( - output, a_data, b_data_quant, scales_data, zero_points, m, n, k, blocks_per_K); + MatMulFloatInt4KernelDispatch(32); } else if (64 == block_size) { - MatMulFloatInt4Kernel<<>>( - output, a_data, b_data_quant, scales_data, zero_points, m, n, k, blocks_per_K); + MatMulFloatInt4KernelDispatch(64); } else if (128 == block_size) { - MatMulFloatInt4Kernel<<>>( - output, a_data, b_data_quant, scales_data, zero_points, m, n, k, blocks_per_K); + MatMulFloatInt4KernelDispatch(128); } else { ORT_THROW("block size ", block_size, " is not supported"); } +#undef MatMulFloatInt4KernelDispatch + return true; } diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/matmul_4bits.py b/onnxruntime/python/tools/kernel_explorer/kernels/matmul_4bits.py index 111e156cd6d01..b3bda914465b4 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/matmul_4bits.py +++ b/onnxruntime/python/tools/kernel_explorer/kernels/matmul_4bits.py @@ -109,7 +109,7 @@ def profile(): dims_m = [1] for dt in dtypes: for m in dims_m: - for n, k in ((4096, 4096), (4096, 12288), (12288, 4096)): + for n, k in ((4096, 4096), (4096, 12288), (12288, 4096), (4096, 11008), (11008, 4096),(2*11008, 4096),): profile_with_args(m, n, k, dt, False) print() diff --git a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc 
b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc index 3c6217915bef0..0957c04d05d4b 100644 --- a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc @@ -62,7 +62,8 @@ void QuantizeDequantize(std::vector& raw_vals, tp.get()); } -void RunTest(int64_t M, int64_t N, int64_t K, int64_t block_size, bool has_zeropoint, bool use_float16) { +void RunTest(int64_t M, int64_t N, int64_t K, int64_t block_size, bool has_zeropoint, + bool use_float16, float fp16_abs_error = 0.02f) { RandomValueGenerator random{1234}; std::vector input0_vals(random.Gaussian(std::vector({M, K}), 0.0f, 0.25f)); std::vector input1_f_vals(random.Gaussian(std::vector({K, N}), 0.0f, 0.25f)); @@ -117,7 +118,7 @@ void RunTest(int64_t M, int64_t N, int64_t K, int64_t block_size, bool has_zerop } test.AddOutput("Y", {M, N}, ToFloat16(expected_vals)); - test.SetOutputAbsErr("Y", 0.02f); + test.SetOutputAbsErr("Y", fp16_abs_error); std::vector> execution_providers; execution_providers.push_back(DefaultCudaExecutionProvider()); @@ -163,6 +164,16 @@ TEST(MatMulNBits, Float16) { } } +TEST(MatMulNBits, Float16Large) { + for (auto block_size : {16, 32, 64, 128}) { + for (auto symmetric : {false, true}) { + RunTest(1, 4096, 4096, block_size, symmetric, true, 0.05); + RunTest(1, 4096, 11008, block_size, symmetric, true, 0.05); + RunTest(1, 11008, 4096, block_size, symmetric, true, 0.05); + } + } +} + #endif } // namespace test } // namespace onnxruntime