From 93dbb26602b1a1ef9e638715112366f544216c79 Mon Sep 17 00:00:00 2001
From: QuarticCat
Date: Wed, 6 Sep 2023 07:47:53 +0800
Subject: [PATCH] Optimize q4_matmul: optimize non-128 group sizes

---
 exllama_ext/cuda_func/q4_matmul.cu | 19 +++++++++++--------
 1 file changed, 11 insertions(+), 8 deletions(-)

diff --git a/exllama_ext/cuda_func/q4_matmul.cu b/exllama_ext/cuda_func/q4_matmul.cu
index e6c51f34..6a0f89c6 100644
--- a/exllama_ext/cuda_func/q4_matmul.cu
+++ b/exllama_ext/cuda_func/q4_matmul.cu
@@ -8,9 +8,11 @@
 #include "../hip_compat.cuh"
 #endif
 
-const int THREADS_X = 32;       // Block size and thread count along columns in w and out
+const int THREADS_X = 128;      // Block size and thread count along columns in w and out
 const int THREADS_Y = 1;        // Block size and thread count along rows in x and out
 
+const int GROUP_STEP = 128;     // Assumed group size when block_size_z % groupsize != 0
+
 typedef void (*fp_q4_matmul_kernel)
 (
     const half*,
@@ -52,7 +54,7 @@ __global__ void q4_matmul_kernel
     int x_column = block_size_z * blockIdx.z;
     int x_column_end = min(dim, block_size_z * (blockIdx.z + 1));
 
-    int w_column = THREADS_X * blockIdx.x + threadIdx.x;  // assume width of weight matrix divisible by THREADS_X (32)
+    int w_column = THREADS_X * blockIdx.x + threadIdx.x;  // assume width of weight matrix divisible by THREADS_X
     int x_row = THREADS_Y * blockIdx.y + threadIdx.y;
 
     int iterations = (x_column_end - x_column) / 8;
@@ -109,11 +111,11 @@ __global__ void q4_matmul_kernel
     }
     else
     {
-        // Otherwise assume groupsize is a multiple of 8, do 8 columns per iteration and trust the cache
+        // Otherwise assume groupsize is a multiple of GROUP_STEP, do GROUP_STEP columns per iteration and trust the cache
 
-        for (int k = x_column; k < x_column + iterations * 8; k += 8)
+        for (int k = x_column; k < x_column + iterations * 8; k += GROUP_STEP)
         {
-            for (int i = threadIdx.x; i < 8; i += THREADS_X)
+            for (int i = threadIdx.x; i < GROUP_STEP; i += THREADS_X)
             {
                 if constexpr (use_x_map) x_cache_h[i] = *x_.item_ptr(x_row, x_map[k + i]);
                 else                     x_cache_h[i] = *x_.item_ptr(x_row, k + i);
@@ -125,14 +127,14 @@ __global__ void q4_matmul_kernel
                 int group = k / groupsize;
                 half2 w_scale = w_scales_.item_half2half2(group, w_column);
                 uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
-                acc = dot_product_8(acc, x_cache, w_, k, w_column, w_scale, w_zero, 1);
+                acc = dot_product_8(acc, x_cache, w_, k, w_column, w_scale, w_zero, GROUP_STEP / 8);
             }
             else
             {
                 int group = k / groupsize;
                 half w_scale = w_scales_.item(group, w_column);
                 uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
-                acc_h = dot_product_8_h(acc_h, x_cache_h, w_, k, w_column, w_scale, w_zero, 1);
+                acc_h = dot_product_8_h(acc_h, x_cache_h, w_, k, w_column, w_scale, w_zero, GROUP_STEP / 8);
             }
             __syncthreads();
         }
@@ -224,7 +226,8 @@ void q4_matmul_cuda
     );
 
     fp_q4_matmul_kernel kernel = q4_matmul_kernel_pick(tuningParams, block_size_z, w->groupsize, x_map);
-    kernel<<<blocks, threads, w->groupsize * sizeof(half), alt_stream>>>(x_mapped, w->cuda_qweight, out, w->cuda_scales, w->cuda_qzeros, height, dim, width, w->groupsize, block_size_z, x_map, no_zero);
+    int shared_mem = (block_size_z % w->groupsize == 0 ? w->groupsize : GROUP_STEP) * sizeof(half);
+    kernel<<<blocks, threads, shared_mem, alt_stream>>>(x_mapped, w->cuda_qweight, out, w->cuda_scales, w->cuda_qzeros, height, dim, width, w->groupsize, block_size_z, x_map, no_zero);
 }
 
 void q4_matmul_recons_cuda
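
Note (not part of the patch): below is a minimal host-side sketch of the shared-memory sizing rule introduced by the last hunk -- cache one full quantization group per iteration when block_size_z is a multiple of groupsize, otherwise fall back to GROUP_STEP half-precision columns. The helper shared_mem_bytes and the example sizes are hypothetical illustrations, not code from the repository, and sizeof(half) is written out as 2 bytes so the snippet compiles without CUDA headers.

    #include <cstdio>

    const int GROUP_STEP = 128;  // mirrors the constant added by the patch

    // Bytes of dynamic shared memory requested per block: enough halves for one
    // quantization group when groupsize divides block_size_z, else GROUP_STEP halves.
    int shared_mem_bytes(int block_size_z, int groupsize)
    {
        int halves = (block_size_z % groupsize == 0) ? groupsize : GROUP_STEP;
        return halves * 2;  // sizeof(half) == 2 bytes
    }

    int main()
    {
        printf("%d\n", shared_mem_bytes(512, 64));  // divisible case:     64 halves  -> 128 bytes
        printf("%d\n", shared_mem_bytes(512, 96));  // non-divisible case: 128 halves -> 256 bytes
        return 0;
    }

In the non-divisible case the kernel steps through the row in GROUP_STEP-column chunks regardless of groupsize, so the shared x cache is sized for GROUP_STEP halves rather than for one group.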