Skip to content

Commit

Permalink
Optimize q4_matmul: optimize non-128 group sizes
Browse files Browse the repository at this point in the history
  • Loading branch information
QuarticCat committed Sep 5, 2023
1 parent 215a715 commit 93dbb26
Showing 1 changed file with 11 additions and 8 deletions.
19 changes: 11 additions & 8 deletions exllama_ext/cuda_func/q4_matmul.cu
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,11 @@
#include "../hip_compat.cuh"
#endif

const int THREADS_X = 32; // Block size and thread count along columns in w and out
const int THREADS_X = 128; // Block size and thread count along columns in w and out
const int THREADS_Y = 1; // Block size and thread count along rows in x and out

const int GROUP_STEP = 128; // Assumed group size when block_size_z % groupsize != 0

typedef void (*fp_q4_matmul_kernel)
(
const half*,
Expand Down Expand Up @@ -52,7 +54,7 @@ __global__ void q4_matmul_kernel
int x_column = block_size_z * blockIdx.z;
int x_column_end = min(dim, block_size_z * (blockIdx.z + 1));

int w_column = THREADS_X * blockIdx.x + threadIdx.x; // assume width of weight matrix divisible by THREADS_X (32)
int w_column = THREADS_X * blockIdx.x + threadIdx.x; // assume width of weight matrix divisible by THREADS_X
int x_row = THREADS_Y * blockIdx.y + threadIdx.y;

int iterations = (x_column_end - x_column) / 8;
Expand Down Expand Up @@ -109,11 +111,11 @@ __global__ void q4_matmul_kernel
}
else
{
// Otherwise assume groupsize is a multiple of 8, do 8 columns per iteration and trust the cache
// Otherwise assume groupsize is a multiple of GROUP_STEP, do GROUP_STEP columns per iteration and trust the cache

for (int k = x_column; k < x_column + iterations * 8; k += 8)
for (int k = x_column; k < x_column + iterations * 8; k += GROUP_STEP)
{
for (int i = threadIdx.x; i < 8; i += THREADS_X)
for (int i = threadIdx.x; i < GROUP_STEP; i += THREADS_X)
{
if constexpr (use_x_map) x_cache_h[i] = *x_.item_ptr(x_row, x_map[k + i]);
else x_cache_h[i] = *x_.item_ptr(x_row, k + i);
Expand All @@ -125,14 +127,14 @@ __global__ void q4_matmul_kernel
int group = k / groupsize;
half2 w_scale = w_scales_.item_half2half2(group, w_column);
uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
acc = dot_product_8(acc, x_cache, w_, k, w_column, w_scale, w_zero, 1);
acc = dot_product_8(acc, x_cache, w_, k, w_column, w_scale, w_zero, GROUP_STEP / 8);
}
else
{
int group = k / groupsize;
half w_scale = w_scales_.item(group, w_column);
uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
acc_h = dot_product_8_h(acc_h, x_cache_h, w_, k, w_column, w_scale, w_zero, 1);
acc_h = dot_product_8_h(acc_h, x_cache_h, w_, k, w_column, w_scale, w_zero, GROUP_STEP / 8);
}
__syncthreads();
}
Expand Down Expand Up @@ -224,7 +226,8 @@ void q4_matmul_cuda
);

fp_q4_matmul_kernel kernel = q4_matmul_kernel_pick(tuningParams, block_size_z, w->groupsize, x_map);
kernel<<<blocks, threads, w->groupsize * sizeof(half), alt_stream>>>(x_mapped, w->cuda_qweight, out, w->cuda_scales, w->cuda_qzeros, height, dim, width, w->groupsize, block_size_z, x_map, no_zero);
int shared_mem = (block_size_z % w->groupsize == 0 ? w->groupsize : GROUP_STEP) * sizeof(half);
kernel<<<blocks, threads, shared_mem, alt_stream>>>(x_mapped, w->cuda_qweight, out, w->cuda_scales, w->cuda_qzeros, height, dim, width, w->groupsize, block_size_z, x_map, no_zero);
}

void q4_matmul_recons_cuda
Expand Down

0 comments on commit 93dbb26

Please sign in to comment.