From 215a71581b3f1408af536d273c725db918441392 Mon Sep 17 00:00:00 2001 From: QuarticCat Date: Wed, 6 Sep 2023 04:28:09 +0800 Subject: [PATCH] Optimize q4_matmul: add missing assignment back --- exllama_ext/cuda_func/q4_matmul.cu | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/exllama_ext/cuda_func/q4_matmul.cu b/exllama_ext/cuda_func/q4_matmul.cu index ef181fb3..e6c51f34 100644 --- a/exllama_ext/cuda_func/q4_matmul.cu +++ b/exllama_ext/cuda_func/q4_matmul.cu @@ -96,14 +96,15 @@ __global__ void q4_matmul_kernel { half2 w_scale = w_scales_.item_half2half2(group, w_column); uint32_t w_zero = w_zeros_.item(group, w_column) + 1; - dot_product_8(acc, x_cache, w_, k, w_column, w_scale, w_zero, groupsize / 8); + acc = dot_product_8(acc, x_cache, w_, k, w_column, w_scale, w_zero, groupsize / 8); } else { half w_scale = w_scales_.item(group, w_column); uint32_t w_zero = w_zeros_.item(group, w_column) + 1; - dot_product_8_h(acc_h, x_cache_h, w_, k, w_column, w_scale, w_zero, groupsize / 8); + acc_h = dot_product_8_h(acc_h, x_cache_h, w_, k, w_column, w_scale, w_zero, groupsize / 8); } + __syncthreads(); } } else @@ -124,15 +125,16 @@ __global__ void q4_matmul_kernel int group = k / groupsize; half2 w_scale = w_scales_.item_half2half2(group, w_column); uint32_t w_zero = w_zeros_.item(group, w_column) + 1; - dot_product_8(acc, x_cache, w_, k, w_column, w_scale, w_zero, 1); + acc = dot_product_8(acc, x_cache, w_, k, w_column, w_scale, w_zero, 1); } else { int group = k / groupsize; half w_scale = w_scales_.item(group, w_column); uint32_t w_zero = w_zeros_.item(group, w_column) + 1; - dot_product_8_h(acc_h, x_cache_h, w_, k, w_column, w_scale, w_zero, 1); + acc_h = dot_product_8_h(acc_h, x_cache_h, w_, k, w_column, w_scale, w_zero, 1); } + __syncthreads(); } }