diff --git a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu index 9f2851f38ca29..4cb1cea7958a6 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu +++ b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu @@ -63,6 +63,9 @@ void dequantizeThread(ElementT* dst, const uint8_t* weights, const ElementT* sca using QuantBlk = typename BlkQuantTraits::QuantBlk; using ThreadBlk = typename BlkQuantTraits::ThreadBlk; + // !! 4b specific code + static_assert(qbits == 4, "Only 4b block quantization is supported!"); + const auto block_idx = blockIdx.x * blockDim.x + threadIdx.x; const auto row_blks = (rows + QuantBlk::kRow - 1) / QuantBlk::kRow; @@ -71,7 +74,6 @@ void dequantizeThread(ElementT* dst, const uint8_t* weights, const ElementT* sca // quantized matrix is stored in column major, packed by column const auto q_rows = (meta_rows * QuantBlk::kRow * qbits + 7) / 8; - const auto q_cols = meta_cols * QuantBlk::kColumn; int32_t r_blk_idx = static_cast(block_idx / thrd_col_blks); int32_t c_blk_idx = static_cast(block_idx % thrd_col_blks); @@ -82,36 +84,32 @@ void dequantizeThread(ElementT* dst, const uint8_t* weights, const ElementT* sca int32_t r_end = std::min(r + ThreadBlk::kRow, rows); int32_t c_end = std::min(c + ThreadBlk::kColumn, columns); + // for 4b quant, kPackSize = 2, so we have 2 scales and 2 offsets + const ElementT scale_buf[2] = { + scales[(c / QuantBlk::kColumn) * row_blks + r / QuantBlk::kRow], + ((r/QuantBlk::kRow) < (meta_rows - 1)) ? scales[(c / QuantBlk::kColumn) * row_blks + r / QuantBlk::kRow + 1] : static_cast(0.0f)}; + const uint8_t zp_pair = (zero_points == nullptr) + ? 0x88 + : zero_points[(c / QuantBlk::kColumn) * ((row_blks + 1) / 2) + (r / QuantBlk::kRow) / 2]; + const uint16_t zp_buf[2] = {(uint16_t)(zp_pair & 0x0f), (uint16_t)((zp_pair >> 4) & 0x0f)}; + const ElementT adjust_buf[2] = {(-scale_buf[0]) * static_cast(zp_buf[0]), + (-scale_buf[1]) * static_cast(zp_buf[1])}; + + const int32_t meta_col = c / QuantBlk::kColumn; for (int32_t j = c; j < c_end; ++j) { - const int32_t meta_col = j / QuantBlk::kColumn; - - // !! 4b specific code - // the whole loop is 4b specific due to sub 8 bit packing - // and unpacking. We can potentially make this qbits generic - // by wraping the packing/unpacking code like cutlass::Array - static_assert(qbits == 4, "Only 4b block quantization is supported!"); - for (int32_t i = r; i < (r_end - 1); i += 2) { - const int32_t meta_row = i / QuantBlk::kRow; - const auto scale0 = scales[meta_col * row_blks + meta_row]; - const uint16_t zp_pair = (zero_points == nullptr) - ? 0x88 - : zero_points[meta_col * ((row_blks + 1) / 2) + meta_row / 2]; - const uint16_t zp0 = (meta_row & 1) ? (zp_pair >> 4) : (zp_pair & 0xf); + const auto scale0 = scale_buf[(i - r) / QuantBlk::kRow]; + const auto adjust0 = adjust_buf[(i - r) / QuantBlk::kRow]; - auto scale1 = scale0; - uint16_t zp1 = zp0; - if constexpr (QuantBlk::kRow == 1) { - scale1 = scales[meta_col * row_blks + meta_row + 1]; - zp1 = (zp_pair >> 4) & 0xf; - } + const auto scale1 = scale_buf[(i + 1 - r) / QuantBlk::kRow];; + const auto adjust1 = adjust_buf[(i + 1 - r) / QuantBlk::kRow]; const auto vi = weights[j * q_rows + i / 2]; if constexpr (std::is_same::value){ half2 scale_half2 = {scale0, scale1}; - half2 zp_adjust2 = {(-scale0) * __ushort2half_rn(zp0), (-scale1) * __ushort2half_rn(zp1)}; + half2 zp_adjust2 = {adjust0, adjust1}; half2 v = {__ushort2half_rn(vi & 0xF), __ushort2half_rn((vi >> 4) & 0xF)}; half2 results = v * scale_half2 + zp_adjust2; @@ -123,8 +121,8 @@ void dequantizeThread(ElementT* dst, const uint8_t* weights, const ElementT* sca static_assert(std::is_same::value, "Only float and half are supported!"); const uint8_t vi0 = vi & 0xf; const uint8_t vi1 = vi >> 4; - dst[j * rows + i] = (static_cast(vi0) - zp0) * static_cast(scale0); - dst[j * rows + (i + 1)] = (static_cast(vi1) - zp1) * static_cast(scale1); + dst[j * rows + i] = static_cast(vi0) * scale0 + adjust0;; + dst[j * rows + (i + 1)] = static_cast(vi1) * scale1 + adjust1; } }