Skip to content

Commit

Permalink
Optimize q4_matmul: add missing assignment back
Browse files (browse the repository at this point in the history)
  • Loading branch information
QuarticCat committed Sep 5, 2023
1 parent caf4de8 commit 215a715
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions exllama_ext/cuda_func/q4_matmul.cu
Original file line number | Diff line number | Diff line change
Expand Up @@ -96,14 +96,15 @@ __global__ void q4_matmul_kernel
{
half2 w_scale = w_scales_.item_half2half2(group, w_column);
uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
dot_product_8(acc, x_cache, w_, k, w_column, w_scale, w_zero, groupsize / 8);
acc = dot_product_8(acc, x_cache, w_, k, w_column, w_scale, w_zero, groupsize / 8);
}
else
{
half w_scale = w_scales_.item(group, w_column);
uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
dot_product_8_h(acc_h, x_cache_h, w_, k, w_column, w_scale, w_zero, groupsize / 8);
acc_h = dot_product_8_h(acc_h, x_cache_h, w_, k, w_column, w_scale, w_zero, groupsize / 8);
}
__syncthreads();
}
}
else
Expand All @@ -124,15 +125,16 @@ __global__ void q4_matmul_kernel
int group = k / groupsize;
half2 w_scale = w_scales_.item_half2half2(group, w_column);
uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
dot_product_8(acc, x_cache, w_, k, w_column, w_scale, w_zero, 1);
acc = dot_product_8(acc, x_cache, w_, k, w_column, w_scale, w_zero, 1);
}
else
{
int group = k / groupsize;
half w_scale = w_scales_.item(group, w_column);
uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
dot_product_8_h(acc_h, x_cache_h, w_, k, w_column, w_scale, w_zero, 1);
acc_h = dot_product_8_h(acc_h, x_cache_h, w_, k, w_column, w_scale, w_zero, 1);
}
__syncthreads();
}
}

Expand Down

0 comments on commit 215a715

Please sign in to comment.