From 22ad629cf761e083336e15304c328eb413003763 Mon Sep 17 00:00:00 2001 From: wejoncy Date: Wed, 13 Mar 2024 09:27:46 +0800 Subject: [PATCH] [bug fix] dequantize 4bit (#19793) ### Description ### Motivation and Context --- .../contrib_ops/cpu/quantization/matmul_nbits_impl.cc | 3 ++- .../contrib_ops/cuda/quantization/dequantize_blockwise.cu | 7 ++++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc index f92e59e990ba5..7e343d85f4048 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc @@ -41,8 +41,9 @@ void Dequantize4BitsKernelReOrder( T* output_i = output + out_y * out_cols + out_x; uint32_t quant_value = *(reinterpret_cast(quant_data + element_offset / 2)); const int remain_x = std::min(8, out_cols - out_x); + const int32_t* reorder_idx_with_off = reorder_idx + kb_idx * block_size + ((threadIdx_x * 8) & (block_size - 1)); for (int i = 0; i < remain_x; i++) { - int32_t rid = reorder_idx ? reorder_idx[kb_idx * block_size + i] : kb_idx; + int32_t rid = reorder_idx ? reorder_idx_with_off[i] : kb_idx; T scale = *(scale_data + n_idx * scales_shape_x + rid); float zp_f = 8; if (zero_points) { diff --git a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu index cd6593352008b..265adf22eeb61 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu +++ b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu @@ -23,7 +23,7 @@ namespace cuda { __device__ __forceinline__ void DequantizeEightElements(uint32_t values_quant, half scale, half zp, half* output) { half2 scale_half2 = {scale, scale}; - half zp_adjust = -scale * __short2half_rn(zp); + half zp_adjust = -scale * zp; half2 zp_adjust2 = {zp_adjust, zp_adjust}; alignas(16) half2 results[4]; @@ -83,8 +83,9 @@ __global__ void Dequantize4BitsKernelReOrder( int element_offset = group_id * block_size + ((threadIdx.x * 8) & (block_size - 1)); T* output_i = output + element_offset; uint32_t quant_value = *(reinterpret_cast(quant_data + element_offset / 2)); + const int32_t* reorder_idx_with_off = reorder_idx + kb_idx * block_size + ((threadIdx.x * 8) & (block_size - 1)); for (int i = 0; i < 8; i++) { - int32_t rid = reorder_idx[kb_idx * block_size + i]; + int32_t rid = reorder_idx_with_off[i]; T scale = *(scale_data + n_idx * scales_shape_x + rid); uint8_t zp = 8; if (zero_points) { @@ -157,7 +158,7 @@ Status Dequantize4Bits( int groups_per_K = k / block_size; int total_groups = n * groups_per_K; // total elemenets in quant_data int groups_per_grid = static_cast(CeilDiv(total_groups, groups_per_threadblock)); - if (!reorder_idx) { + if (!reorder_idx || std::is_same_v) { Dequantize4BitsKernel<<>>( output, quant_data,