diff --git a/exllama_ext/cuda_func/q4_matmul.cu b/exllama_ext/cuda_func/q4_matmul.cu
index 3e88acb5..a1b7be70 100644
--- a/exllama_ext/cuda_func/q4_matmul.cu
+++ b/exllama_ext/cuda_func/q4_matmul.cu
@@ -11,7 +11,9 @@
 
 const int THREADS_X = 32;       // Block size and thread count along columns in w and out
 const int THREADS_Y = 1;        // Block size and thread count along rows in x and out
+#if defined(USE_SMEM)
 const int GROUP_STEP = 32;      // Assumed group size when block_size_z % groupsize != 0
+#endif
 
 typedef void (*fp_q4_matmul_kernel)
 (
@@ -46,8 +48,12 @@ __global__ void q4_matmul_kernel
     bool no_zero
 )
 {
-    extern __shared__ half2 x_cache[];
-    half* x_cache_h = (half*)x_cache;
+    #if defined(USE_SMEM)
+
+    extern __shared__ half2 x_cache[];
+    half* x_cache_h = (half*)x_cache;
+
+    #endif
 
     // Start of block
 
@@ -87,57 +93,109 @@ __global__ void q4_matmul_kernel
 
         for (int k = x_column, group = x_column / groupsize; k < x_column + iterations * 8; group++, k += groupsize)
         {
-            for (int i = threadIdx.x; i < groupsize; i += THREADS_X)
-            {
-                if constexpr (use_x_map) x_cache_h[i] = *x_.item_ptr(x_row, x_map[k + i]);
-                else                     x_cache_h[i] = *x_.item_ptr(x_row, k + i);
-            }
-            __syncthreads();
-
-            if constexpr (use_half2)
-            {
-                half2 w_scale = w_scales_.item_half2half2(group, w_column);
-                uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
-                acc = dot_product_8(acc, x_cache, w_, k, w_column, w_scale, w_zero, groupsize / 8);
-            }
-            else
-            {
-                half w_scale = w_scales_.item(group, w_column);
-                uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
-                acc_h = dot_product_8_h(acc_h, x_cache_h, w_, k, w_column, w_scale, w_zero, groupsize / 8);
-            }
-            __syncthreads();
+            #if defined(USE_SMEM)
+
+            for (int i = threadIdx.x; i < groupsize; i += THREADS_X)
+            {
+                if constexpr (use_x_map) x_cache_h[i] = *x_.item_ptr(x_row, x_map[k + i]);
+                else                     x_cache_h[i] = *x_.item_ptr(x_row, k + i);
+            }
+            __syncthreads();
+
+            if constexpr (use_half2)
+            {
+                half2 w_scale = w_scales_.item_half2half2(group, w_column);
+                uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
+                acc = dot_product_8(acc, x_cache, w_, k, w_column, w_scale, w_zero, groupsize / 8);
+            }
+            else
+            {
+                half w_scale = w_scales_.item(group, w_column);
+                uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
+                acc_h = dot_product_8_h(acc_h, x_cache_h, w_, k, w_column, w_scale, w_zero, groupsize / 8);
+            }
+            __syncthreads();
+
+            #else
+
+            if constexpr (use_half2)
+            {
+                half2 w_scale = w_scales_.item_half2half2(group, w_column);
+                uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
+
+                if constexpr (use_x_map) acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map);
+                else                     acc = dot_product_8      (acc, (const half2*) x_.item_ptr(x_row, k), w_, k, w_column, w_scale, w_zero, groupsize / 8);
+            }
+            else
+            {
+                half w_scale = w_scales_.item(group, w_column);
+                uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
+
+                if constexpr (use_x_map) acc_h = dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map);
+                else                     acc_h = dot_product_8_h      (acc_h, x_.item_ptr(x_row, k), w_, k, w_column, w_scale, w_zero, groupsize / 8);
+            }
+
+            #endif
         }
     }
     else
     {
         // Otherwise assume groupsize is a multiple of GROUP_STEP, do GROUP_STEP columns per iteration and trust the cache
 
-        for (int k = x_column; k < x_column + iterations * 8; k += GROUP_STEP)
-        {
-            for (int i = threadIdx.x; i < GROUP_STEP; i += THREADS_X)
-            {
-                if constexpr (use_x_map) x_cache_h[i] = *x_.item_ptr(x_row, x_map[k + i]);
-                else                     x_cache_h[i] = *x_.item_ptr(x_row, k + i);
-            }
-            __syncthreads();
+        #if defined(USE_SMEM)
 
-            if constexpr (use_half2)
+        for (int k = x_column; k < x_column + iterations * 8; k += GROUP_STEP)
         {
-                int group = k / groupsize;
-                half2 w_scale = w_scales_.item_half2half2(group, w_column);
-                uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
-                acc = dot_product_8(acc, x_cache, w_, k, w_column, w_scale, w_zero, GROUP_STEP / 8);
+            for (int i = threadIdx.x; i < GROUP_STEP; i += THREADS_X)
+            {
+                if constexpr (use_x_map) x_cache_h[i] = *x_.item_ptr(x_row, x_map[k + i]);
+                else                     x_cache_h[i] = *x_.item_ptr(x_row, k + i);
+            }
+            __syncthreads();
+
+            if constexpr (use_half2)
+            {
+                int group = k / groupsize;
+                half2 w_scale = w_scales_.item_half2half2(group, w_column);
+                uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
+                acc = dot_product_8(acc, x_cache, w_, k, w_column, w_scale, w_zero, GROUP_STEP / 8);
+            }
+            else
+            {
+                int group = k / groupsize;
+                half w_scale = w_scales_.item(group, w_column);
+                uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
+                acc_h = dot_product_8_h(acc_h, x_cache_h, w_, k, w_column, w_scale, w_zero, GROUP_STEP / 8);
+            }
+            __syncthreads();
         }
-            else
+
+        #else
+
+        for (int k = x_column; k < x_column + iterations * 8; k += 8)
         {
-                int group = k / groupsize;
-                half w_scale = w_scales_.item(group, w_column);
-                uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
-                acc_h = dot_product_8_h(acc_h, x_cache_h, w_, k, w_column, w_scale, w_zero, GROUP_STEP / 8);
+            if constexpr (use_half2)
+            {
+                int group = k / groupsize;
+                half2 w_scale = w_scales_.item_half2half2(group, w_column);
+                uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
+
+                if constexpr (use_x_map) acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, x_map);
+                else                     acc = dot_product_8      (acc, (const half2*) x_.item_ptr(x_row, k), w_, k, w_column, w_scale, w_zero, 1);
+            }
+            else
+            {
+                int group = k / groupsize;
+                half w_scale = w_scales_.item(group, w_column);
+                uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
+
+                if constexpr (use_x_map) acc_h = dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, x_map);
+                else                     acc_h = dot_product_8_h      (acc_h, x_.item_ptr(x_row, k), w_, k, w_column, w_scale, w_zero, 1);
+            }
         }
-            __syncthreads();
-        }
+
+        #endif
+    }
 
     // Add to block result
@@ -226,8 +284,18 @@
     );
 
     fp_q4_matmul_kernel kernel = q4_matmul_kernel_pick(tuningParams, block_size_z, w->groupsize, x_map);
-    int shared_mem = (block_size_z % w->groupsize == 0 ? w->groupsize : GROUP_STEP) * sizeof(half);
-    kernel<<<blocks, threads, shared_mem>>>(x_mapped, w->cuda_qweight, out, w->cuda_scales, w->cuda_qzeros, height, dim, width, w->groupsize, block_size_z, x_map, no_zero);
+
+    #if defined(USE_SMEM)
+
+    int shared_mem = (block_size_z % w->groupsize == 0 ? w->groupsize : GROUP_STEP) * sizeof(half);
+
+    # else
+
+    int shared_mem = 0;
+
+    #endif
+
+    kernel<<<blocks, threads, shared_mem>>> (x_mapped, w->cuda_qweight, out, w->cuda_scales, w->cuda_qzeros, height, dim, width, w->groupsize, block_size_z, x_map, no_zero);
 }
 
 void q4_matmul_recons_cuda
diff --git a/exllama_ext/cuda_func/q4_matmul.cuh b/exllama_ext/cuda_func/q4_matmul.cuh
index 8c90e821..1bd91250 100644
--- a/exllama_ext/cuda_func/q4_matmul.cuh
+++ b/exllama_ext/cuda_func/q4_matmul.cuh
@@ -16,6 +16,10 @@
 #define rocblas_handle hipblasHandle_t
 #endif
 
+#if !defined(USE_ROCM) && (!defined(__CUDA_ARCH__) || (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700))
+#define USE_SMEM
+#endif
+
 void q4_matmul_cuda
 (
     ExLlamaTuning* tuningParams,
diff --git a/exllama_ext/matrix.cuh b/exllama_ext/matrix.cuh
index 7179a1f4..0c421c60 100644
--- a/exllama_ext/matrix.cuh
+++ b/exllama_ext/matrix.cuh
@@ -174,4 +174,115 @@ __device__ __forceinline__ half dot_product_8_h
     return result;
 }
 
+// Accumulated dot product of 8-element row vectors in h and quantized column vectors in v, constant zero/scale, with x_map
+
+__device__ __forceinline__ half2 dot_product_8_x_map
+(
+    const half2 acc,
+    MatrixView_half& h_,
+    const int h_row,
+    const int h_column,             // divisible by 8
+    MatrixView_q4_column& v_,
+    const int v_row,                // divisible by 8
+    const int v_column,
+    const half2 v_scale_2,
+    const uint32_t v_zero,          // + 1 (!!)
+    const int count,
+    const uint32_t* x_map
+)
+{
+    const half* h_ptr = h_.item_ptr(h_row, 0);
+    const uint32_t* x_map_ptr = x_map + h_column;
+    const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column);
+    half2 result = acc;
+
+    for (int i = 0; i < count; i++)
+    {
+        uint32_t v_read = *v_ptr; v_ptr += v_.width;
+
+        half v_0 = __int2half_rn((int)((v_read      ) & 0x0f) - v_zero);
+        half v_1 = __int2half_rn((int)((v_read >>  4) & 0x0f) - v_zero);
+        half v_2 = __int2half_rn((int)((v_read >>  8) & 0x0f) - v_zero);
+        half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero);
+        half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero);
+        half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero);
+        half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero);
+        half v_7 = __int2half_rn((int)((v_read >> 28)       ) - v_zero);
+
+        half2 v_01 = __halves2half2(v_0, v_1);
+        half2 v_23 = __halves2half2(v_2, v_3);
+        half2 v_45 = __halves2half2(v_4, v_5);
+        half2 v_67 = __halves2half2(v_6, v_7);
+
+        half h_0 = h_ptr[*x_map_ptr++];
+        half h_1 = h_ptr[*x_map_ptr++];
+        half h_2 = h_ptr[*x_map_ptr++];
+        half h_3 = h_ptr[*x_map_ptr++];
+        half h_4 = h_ptr[*x_map_ptr++];
+        half h_5 = h_ptr[*x_map_ptr++];
+        half h_6 = h_ptr[*x_map_ptr++];
+        half h_7 = h_ptr[*x_map_ptr++];
+
+        half2 h_01 = __halves2half2(h_0, h_1);
+        half2 h_23 = __halves2half2(h_2, h_3);
+        half2 h_45 = __halves2half2(h_4, h_5);
+        half2 h_67 = __halves2half2(h_6, h_7);
+
+        half2 tmp = __hmul2(h_01, v_01);
+        tmp = __hfma2(h_23, v_23, tmp);
+        tmp = __hfma2(h_45, v_45, tmp);
+        tmp = __hfma2(h_67, v_67, tmp);
+        result = __hfma2(v_scale_2, tmp, result);
+    }
+
+    return result;
+}
+
+__device__ __forceinline__ half dot_product_8_x_map_h
+(
+    const half acc,
+    MatrixView_half& h_,
+    const int h_row,
+    const int h_column,             // divisible by 8
+    MatrixView_q4_column& v_,
+    const int v_row,                // divisible by 8
+    const int v_column,
+    const half v_scale,
+    const uint32_t v_zero,          // + 1 (!!)
+    const int count,
+    const uint32_t* x_map
+)
+{
+    const half* h_ptr = h_.item_ptr(h_row, 0);
+    const uint32_t* x_map_ptr = x_map + h_column;
+    const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column);
+    half result = acc;
+
+    for (int i = 0; i < count; i++)
+    {
+        uint32_t v_read = *v_ptr; v_ptr += v_.width;
+
+        half v_0 = __int2half_rn((int)((v_read      ) & 0x0f) - v_zero);
+        half v_1 = __int2half_rn((int)((v_read >>  4) & 0x0f) - v_zero);
+        half v_2 = __int2half_rn((int)((v_read >>  8) & 0x0f) - v_zero);
+        half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero);
+        half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero);
+        half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero);
+        half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero);
+        half v_7 = __int2half_rn((int)((v_read >> 28)       ) - v_zero);
+
+        half tmp = __hmul(h_ptr[*x_map_ptr++], v_0);
+        tmp = __hfma(h_ptr[*x_map_ptr++], v_1, tmp);
+        tmp = __hfma(h_ptr[*x_map_ptr++], v_2, tmp);
+        tmp = __hfma(h_ptr[*x_map_ptr++], v_3, tmp);
+        tmp = __hfma(h_ptr[*x_map_ptr++], v_4, tmp);
+        tmp = __hfma(h_ptr[*x_map_ptr++], v_5, tmp);
+        tmp = __hfma(h_ptr[*x_map_ptr++], v_6, tmp);
+        tmp = __hfma(h_ptr[*x_map_ptr++], v_7, tmp);
+        result = __hfma(v_scale, tmp, result);
+    }
+
+    return result;
+}
+
 #endif
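Note on the arithmetic: the new dot_product_8_x_map / dot_product_8_x_map_h helpers can be sanity-checked against a plain host-side float implementation. The sketch below is illustrative only and not part of the patch; the names (dot_product_8_reference, w_packed, w_offset, row_stride, zero_plus_one) are made up for the example, but the nibble unpacking, the zero point passed as zero + 1, the gather of activations through x_map, and the per-8-element application of the group scale mirror the device code above.

#include <cstdint>
#include <vector>

// Hypothetical host-side float reference for one group of the quantized dot product.
// w_packed holds the packed 4-bit weights as uint32 words, with the next word of the
// same column row_stride words away (mirroring v_.width in the kernel helpers).
float dot_product_8_reference(const std::vector<float>& x,         // one activation row
                              const std::vector<uint32_t>& x_map,  // column permutation, offset to this group
                              const std::vector<uint32_t>& w_packed,
                              int w_offset,                        // index of the group's first packed word
                              int row_stride,                      // packed words per row
                              float scale,                         // group scale
                              int zero_plus_one,                   // group zero point + 1, as in the kernels
                              int count)                           // packed words in the group (groupsize / 8)
{
    float acc = 0.0f;
    int h = 0;
    for (int i = 0; i < count; ++i)
    {
        uint32_t v_read = w_packed[w_offset + (size_t)i * row_stride];
        float word_sum = 0.0f;
        for (int j = 0; j < 8; ++j)
        {
            int q = (int)((v_read >> (4 * j)) & 0x0f) - zero_plus_one;  // unpack one nibble, apply zero point
            word_sum += x[x_map[h++]] * (float)q;                       // gather the activation through x_map
        }
        acc += scale * word_sum;                                        // scale applied per 8-element word
    }
    return acc;
}

The device code carries the same sum in half/half2 precision, so small rounding differences against this float reference are expected.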