Skip to content

Commit

Permalink
[bug fix] dequantize 4bit (#19793)
Browse files Browse the repository at this point in the history
### Description
<!-- Describe your changes. -->



### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
  • Loading branch information
wejoncy authored Mar 13, 2024
1 parent 860eb76 commit 22ad629
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,9 @@ void Dequantize4BitsKernelReOrder(
T* output_i = output + out_y * out_cols + out_x;
uint32_t quant_value = *(reinterpret_cast<const uint32_t*>(quant_data + element_offset / 2));
const int remain_x = std::min(8, out_cols - out_x);
const int32_t* reorder_idx_with_off = reorder_idx + kb_idx * block_size + ((threadIdx_x * 8) & (block_size - 1));
for (int i = 0; i < remain_x; i++) {
int32_t rid = reorder_idx ? reorder_idx[kb_idx * block_size + i] : kb_idx;
int32_t rid = reorder_idx ? reorder_idx_with_off[i] : kb_idx;
T scale = *(scale_data + n_idx * scales_shape_x + rid);
float zp_f = 8;
if (zero_points) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ namespace cuda {

__device__ __forceinline__ void DequantizeEightElements(uint32_t values_quant, half scale, half zp, half* output) {
half2 scale_half2 = {scale, scale};
half zp_adjust = -scale * __short2half_rn(zp);
half zp_adjust = -scale * zp;
half2 zp_adjust2 = {zp_adjust, zp_adjust};

alignas(16) half2 results[4];
Expand Down Expand Up @@ -83,8 +83,9 @@ __global__ void Dequantize4BitsKernelReOrder(
int element_offset = group_id * block_size + ((threadIdx.x * 8) & (block_size - 1));
T* output_i = output + element_offset;
uint32_t quant_value = *(reinterpret_cast<const uint32_t*>(quant_data + element_offset / 2));
const int32_t* reorder_idx_with_off = reorder_idx + kb_idx * block_size + ((threadIdx.x * 8) & (block_size - 1));
for (int i = 0; i < 8; i++) {
int32_t rid = reorder_idx[kb_idx * block_size + i];
int32_t rid = reorder_idx_with_off[i];
T scale = *(scale_data + n_idx * scales_shape_x + rid);
uint8_t zp = 8;
if (zero_points) {
Expand Down Expand Up @@ -157,7 +158,7 @@ Status Dequantize4Bits(
int groups_per_K = k / block_size;
int total_groups = n * groups_per_K; // total elemenets in quant_data
int groups_per_grid = static_cast<int>(CeilDiv(total_groups, groups_per_threadblock));
if (!reorder_idx) {
if (!reorder_idx || std::is_same_v<ZeroT, T>) {
Dequantize4BitsKernel<T, ZeroT><<<groups_per_grid, GridDim::maxThreadsPerBlock, 0, stream>>>(
output,
quant_data,
Expand Down

0 comments on commit 22ad629

Please sign in to comment.