Switch use of SMEM in kernel based on ROCm and CUDA version
turboderp committed Sep 9, 2023
1 parent 85b009c commit 8a1d330
Showing 3 changed files with 227 additions and 44 deletions.
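In outline, the change wraps the kernel's shared-memory staging of the activation row in a USE_SMEM guard and falls back to reading activations directly from global memory when the guard is off (ROCm, or CUDA architectures below compute capability 7.0). A minimal sketch of that switch, with illustrative names rather than the real kernel's:

#include <cuda_fp16.h>

// Sketch only: not the real q4_matmul_kernel. USE_SMEM is defined the same
// way the commit defines it in q4_matmul.cuh.
#if !defined(USE_ROCM) && (!defined(__CUDA_ARCH__) || (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700))
#define USE_SMEM
#endif

__global__ void smem_switch_sketch(const half* x, half* out, int columns)
{
#if defined(USE_SMEM)
    // CUDA, sm_70+: stage the slice of x in dynamic shared memory once per
    // block, then read it from there.
    extern __shared__ half x_cache[];
    for (int i = threadIdx.x; i < columns; i += blockDim.x) x_cache[i] = x[i];
    __syncthreads();
    const half* x_src = x_cache;
#else
    // ROCm / older architectures: skip the staging step and read x straight
    // from global memory, relying on the hardware caches.
    const half* x_src = x;
#endif
    float acc = 0.0f;                                    // stand-in for the real dot product
    for (int i = 0; i < columns; i++) acc += __half2float(x_src[i]);
    if (threadIdx.x == 0) out[blockIdx.x] = __float2half(acc);
}

Such a kernel would be launched with the dynamic shared-memory size as the third configuration argument (or 0 when USE_SMEM is not defined), which is what the q4_matmul_cuda change further down does.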
156 changes: 112 additions & 44 deletions exllama_ext/cuda_func/q4_matmul.cu
@@ -11,7 +11,9 @@
const int THREADS_X = 32; // Block size and thread count along columns in w and out
const int THREADS_Y = 1; // Block size and thread count along rows in x and out

#if defined(USE_SMEM)
const int GROUP_STEP = 32; // Assumed group size when block_size_z % groupsize != 0
#endif

typedef void (*fp_q4_matmul_kernel)
(
@@ -46,8 +48,12 @@ __global__ void q4_matmul_kernel
bool no_zero
)
{
extern __shared__ half2 x_cache[];
half* x_cache_h = (half*)x_cache;
#if defined(USE_SMEM)

extern __shared__ half2 x_cache[];
half* x_cache_h = (half*)x_cache;

#endif

// Start of block

@@ -87,57 +93,109 @@

for (int k = x_column, group = x_column / groupsize; k < x_column + iterations * 8; group++, k += groupsize)
{
for (int i = threadIdx.x; i < groupsize; i += THREADS_X)
{
if constexpr (use_x_map) x_cache_h[i] = *x_.item_ptr(x_row, x_map[k + i]);
else x_cache_h[i] = *x_.item_ptr(x_row, k + i);
}
__syncthreads();

if constexpr (use_half2)
{
half2 w_scale = w_scales_.item_half2half2(group, w_column);
uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
acc = dot_product_8(acc, x_cache, w_, k, w_column, w_scale, w_zero, groupsize / 8);
}
else
{
half w_scale = w_scales_.item(group, w_column);
uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
acc_h = dot_product_8_h(acc_h, x_cache_h, w_, k, w_column, w_scale, w_zero, groupsize / 8);
}
__syncthreads();
#if defined(USE_SMEM)

for (int i = threadIdx.x; i < groupsize; i += THREADS_X)
{
if constexpr (use_x_map) x_cache_h[i] = *x_.item_ptr(x_row, x_map[k + i]);
else x_cache_h[i] = *x_.item_ptr(x_row, k + i);
}
__syncthreads();

if constexpr (use_half2)
{
half2 w_scale = w_scales_.item_half2half2(group, w_column);
uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
acc = dot_product_8(acc, x_cache, w_, k, w_column, w_scale, w_zero, groupsize / 8);
}
else
{
half w_scale = w_scales_.item(group, w_column);
uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
acc_h = dot_product_8_h(acc_h, x_cache_h, w_, k, w_column, w_scale, w_zero, groupsize / 8);
}
__syncthreads();

#else

if constexpr (use_half2)
{
half2 w_scale = w_scales_.item_half2half2(group, w_column);
uint32_t w_zero = w_zeros_.item(group, w_column) + 1;

if constexpr (use_x_map) acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map);
else acc = dot_product_8 (acc, (const half2*) x_.item_ptr(x_row, k), w_, k, w_column, w_scale, w_zero, groupsize / 8);
}
else
{
half w_scale = w_scales_.item(group, w_column);
uint32_t w_zero = w_zeros_.item(group, w_column) + 1;

if constexpr (use_x_map) acc_h = dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map);
else acc_h = dot_product_8_h (acc_h, x_.item_ptr(x_row, k), w_, k, w_column, w_scale, w_zero, groupsize / 8);
}

#endif
}
}
else
{
// Otherwise assume groupsize is a multiple of GROUP_STEP, do GROUP_STEP columns per iteration and trust the cache

for (int k = x_column; k < x_column + iterations * 8; k += GROUP_STEP)
{
for (int i = threadIdx.x; i < GROUP_STEP; i += THREADS_X)
{
if constexpr (use_x_map) x_cache_h[i] = *x_.item_ptr(x_row, x_map[k + i]);
else x_cache_h[i] = *x_.item_ptr(x_row, k + i);
}
__syncthreads();
#if defined(USE_SMEM)

if constexpr (use_half2)
for (int k = x_column; k < x_column + iterations * 8; k += GROUP_STEP)
{
int group = k / groupsize;
half2 w_scale = w_scales_.item_half2half2(group, w_column);
uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
acc = dot_product_8(acc, x_cache, w_, k, w_column, w_scale, w_zero, GROUP_STEP / 8);
for (int i = threadIdx.x; i < GROUP_STEP; i += THREADS_X)
{
if constexpr (use_x_map) x_cache_h[i] = *x_.item_ptr(x_row, x_map[k + i]);
else x_cache_h[i] = *x_.item_ptr(x_row, k + i);
}
__syncthreads();

if constexpr (use_half2)
{
int group = k / groupsize;
half2 w_scale = w_scales_.item_half2half2(group, w_column);
uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
acc = dot_product_8(acc, x_cache, w_, k, w_column, w_scale, w_zero, GROUP_STEP / 8);
}
else
{
int group = k / groupsize;
half w_scale = w_scales_.item(group, w_column);
uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
acc_h = dot_product_8_h(acc_h, x_cache_h, w_, k, w_column, w_scale, w_zero, GROUP_STEP / 8);
}
__syncthreads();
}
else

#else

for (int k = x_column; k < x_column + iterations * 8; k += 8)
{
int group = k / groupsize;
half w_scale = w_scales_.item(group, w_column);
uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
acc_h = dot_product_8_h(acc_h, x_cache_h, w_, k, w_column, w_scale, w_zero, GROUP_STEP / 8);
if constexpr (use_half2)
{
int group = k / groupsize;
half2 w_scale = w_scales_.item_half2half2(group, w_column);
uint32_t w_zero = w_zeros_.item(group, w_column) + 1;

if constexpr (use_x_map) acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, x_map);
else acc = dot_product_8 (acc, (const half2*) x_.item_ptr(x_row, k), w_, k, w_column, w_scale, w_zero, 1);
}
else
{
int group = k / groupsize;
half w_scale = w_scales_.item(group, w_column);
uint32_t w_zero = w_zeros_.item(group, w_column) + 1;

if constexpr (use_x_map) acc_h = dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, x_map);
else acc_h = dot_product_8_h (acc_h, x_.item_ptr(x_row, k), w_, k, w_column, w_scale, w_zero, 1);
}
}
__syncthreads();
}

#endif

}

// Add to block result
@@ -226,8 +284,18 @@ void q4_matmul_cuda
);

fp_q4_matmul_kernel kernel = q4_matmul_kernel_pick(tuningParams, block_size_z, w->groupsize, x_map);
int shared_mem = (block_size_z % w->groupsize == 0 ? w->groupsize : GROUP_STEP) * sizeof(half);
kernel<<<blocks, threads, shared_mem, alt_stream>>>(x_mapped, w->cuda_qweight, out, w->cuda_scales, w->cuda_qzeros, height, dim, width, w->groupsize, block_size_z, x_map, no_zero);

#if defined(USE_SMEM)

int shared_mem = (block_size_z % w->groupsize == 0 ? w->groupsize : GROUP_STEP) * sizeof(half);

#else

int shared_mem = 0;

#endif

kernel<<<blocks, threads, shared_mem, alt_stream>>> (x_mapped, w->cuda_qweight, out, w->cuda_scales, w->cuda_qzeros, height, dim, width, w->groupsize, block_size_z, x_map, no_zero);
}

void q4_matmul_recons_cuda
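At the launch site above, the third launch-configuration argument supplies the byte size of the kernel's extern __shared__ buffer: groupsize half values when block_size_z divides evenly into groups, GROUP_STEP (32) half values otherwise, and 0 when USE_SMEM is not defined. A self-contained sketch of that sizing rule (the kernel and launcher names here are hypothetical):

#include <cuda_fp16.h>

__global__ void smem_sizing_sketch(half* out, int n)
{
    extern __shared__ half x_cache[];                    // extent comes from the launch below
    for (int i = threadIdx.x; i < n; i += blockDim.x) x_cache[i] = __float2half((float)i);
    __syncthreads();
    if (threadIdx.x == 0) out[blockIdx.x] = x_cache[n - 1];
}

void launch_sketch(half* out, int groupsize, int block_size_z, cudaStream_t stream)
{
    const int GROUP_STEP = 32;                           // mirrors the constant in q4_matmul.cu
    int cache_len  = (block_size_z % groupsize == 0) ? groupsize : GROUP_STEP;
    int shared_mem = cache_len * sizeof(half);           // bytes of dynamic shared memory
    smem_sizing_sketch<<<1, 32, shared_mem, stream>>>(out, cache_len);
}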
4 changes: 4 additions & 0 deletions exllama_ext/cuda_func/q4_matmul.cuh
@@ -16,6 +16,10 @@
#define rocblas_handle hipblasHandle_t
#endif

#if !defined(USE_ROCM) && (!defined(__CUDA_ARCH__) || (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700))
#define USE_SMEM
#endif

void q4_matmul_cuda
(
ExLlamaTuning* tuningParams,
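The new guard in q4_matmul.cuh reads as follows: USE_SMEM is never defined under ROCm, and under CUDA it is defined both for the host compilation pass (where __CUDA_ARCH__ is not defined at all) and for device passes targeting compute capability 7.0 (Volta) or newer. The same condition restated with explanatory comments added:

// Host pass: __CUDA_ARCH__ undefined, so the macro is visible to host code
// that sizes the shared-memory allocation.
// Device pass: only defined when compiling for sm_70 and newer.
// ROCm: never defined, so the kernel always takes the global-memory path.
#if !defined(USE_ROCM) && (!defined(__CUDA_ARCH__) || (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700))
#define USE_SMEM
#endif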
111 changes: 111 additions & 0 deletions exllama_ext/matrix.cuh
@@ -174,4 +174,115 @@ __device__ __forceinline__ half dot_product_8_h
return result;
}

// Accumulated dot product of 8-element row vectors in h and quantized column vectors in v, constant zero/scale, with x_map

__device__ __forceinline__ half2 dot_product_8_x_map
(
const half2 acc,
MatrixView_half& h_,
const int h_row,
const int h_column, // divisible by 8
MatrixView_q4_column& v_,
const int v_row, // divisible by 8
const int v_column,
const half2 v_scale_2,
const uint32_t v_zero, // + 1 (!!)
const int count,
const uint32_t* x_map
)
{
const half* h_ptr = h_.item_ptr(h_row, 0);
const uint32_t* x_map_ptr = x_map + h_column;
const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column);
half2 result = acc;

for (int i = 0; i < count; i++)
{
uint32_t v_read = *v_ptr; v_ptr += v_.width;

half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero);
half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero);
half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero);
half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero);
half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero);
half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero);
half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero);
half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero);

half2 v_01 = __halves2half2(v_0, v_1);
half2 v_23 = __halves2half2(v_2, v_3);
half2 v_45 = __halves2half2(v_4, v_5);
half2 v_67 = __halves2half2(v_6, v_7);

half h_0 = h_ptr[*x_map_ptr++];
half h_1 = h_ptr[*x_map_ptr++];
half h_2 = h_ptr[*x_map_ptr++];
half h_3 = h_ptr[*x_map_ptr++];
half h_4 = h_ptr[*x_map_ptr++];
half h_5 = h_ptr[*x_map_ptr++];
half h_6 = h_ptr[*x_map_ptr++];
half h_7 = h_ptr[*x_map_ptr++];

half2 h_01 = __halves2half2(h_0, h_1);
half2 h_23 = __halves2half2(h_2, h_3);
half2 h_45 = __halves2half2(h_4, h_5);
half2 h_67 = __halves2half2(h_6, h_7);

half2 tmp = __hmul2(h_01, v_01);
tmp = __hfma2(h_23, v_23, tmp);
tmp = __hfma2(h_45, v_45, tmp);
tmp = __hfma2(h_67, v_67, tmp);
result = __hfma2(v_scale_2, tmp, result);
}

return result;
}

__device__ __forceinline__ half dot_product_8_x_map_h
(
const half acc,
MatrixView_half& h_,
const int h_row,
const int h_column, // divisible by 8
MatrixView_q4_column& v_,
const int v_row, // divisible by 8
const int v_column,
const half v_scale,
const uint32_t v_zero, // + 1 (!!)
const int count,
const uint32_t* x_map
)
{
const half* h_ptr = h_.item_ptr(h_row, 0);
const uint32_t* x_map_ptr = x_map + h_column;
const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column);
half result = acc;

for (int i = 0; i < count; i++)
{
uint32_t v_read = *v_ptr; v_ptr += v_.width;

half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero);
half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero);
half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero);
half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero);
half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero);
half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero);
half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero);
half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero);

half tmp = __hmul(h_ptr[*x_map_ptr++], v_0);
tmp = __hfma(h_ptr[*x_map_ptr++], v_1, tmp);
tmp = __hfma(h_ptr[*x_map_ptr++], v_2, tmp);
tmp = __hfma(h_ptr[*x_map_ptr++], v_3, tmp);
tmp = __hfma(h_ptr[*x_map_ptr++], v_4, tmp);
tmp = __hfma(h_ptr[*x_map_ptr++], v_5, tmp);
tmp = __hfma(h_ptr[*x_map_ptr++], v_6, tmp);
tmp = __hfma(h_ptr[*x_map_ptr++], v_7, tmp);
result = __hfma(v_scale, tmp, result);
}

return result;
}

#endif
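The dot_product_8_x_map(_h) helpers above unpack eight 4-bit weights from one uint32_t, subtract the zero point (already incremented by 1 at the call site), gather eight activations through the x_map column permutation, and fold the result into the accumulator with a single scale multiply. A plain C++ reference of that arithmetic in float, runnable on the host (a hypothetical helper, not part of the extension):

#include <cstdint>

// One packed word: eight 4-bit weights in v_read, activations gathered through
// x_map starting at h_column, scale and (pre-incremented) zero as in the kernel.
float dot_product_8_x_map_ref(float acc,
                              const float*    h_row,     // one row of activations
                              const uint32_t* x_map,     // column permutation
                              int             h_column,  // start column, divisible by 8
                              uint32_t        v_read,    // packed 4-bit weights
                              float           v_scale,
                              uint32_t        v_zero)    // zero point + 1 (!!)
{
    float tmp = 0.0f;
    for (int j = 0; j < 8; j++)
    {
        int   q = (int)((v_read >> (4 * j)) & 0x0f) - (int)v_zero;   // dequantize nibble j
        float h = h_row[x_map[h_column + j]];                        // gather via x_map
        tmp += h * (float)q;                                         // accumulate the 8-element dot
    }
    return acc + v_scale * tmp;                                      // scale once per 8 weights
}

The device versions do the same work in half and half2, pairing values so that __hfma2 processes two lanes per instruction.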
