From 3a565c8e2f612664e7d27cce5f0f38051766efb1 Mon Sep 17 00:00:00 2001 From: Chen Fu <1316708+chenfucn@users.noreply.github.com> Date: Wed, 1 Nov 2023 16:07:38 +0000 Subject: [PATCH] dequant tail adjustment --- .../cuda/quantization/dequantize_blockwise.cu | 29 +++++++------------ 1 file changed, 11 insertions(+), 18 deletions(-) diff --git a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu index 4cb1cea7958a6..7aa08eabd1ae1 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu +++ b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu @@ -58,7 +58,7 @@ template < bool Columnwise> __global__ void dequantizeThread(ElementT* dst, const uint8_t* weights, const ElementT* scales, - const uint8_t* zero_points, int rows, int columns, int thrd_col_blks) { + const uint8_t* zero_points, int rows, int columns, int thrd_row_blks) { using QuantBlk = typename BlkQuantTraits::QuantBlk; using ThreadBlk = typename BlkQuantTraits::ThreadBlk; @@ -75,8 +75,8 @@ void dequantizeThread(ElementT* dst, const uint8_t* weights, const ElementT* sca // quantized matrix is stored in column major, packed by column const auto q_rows = (meta_rows * QuantBlk::kRow * qbits + 7) / 8; - int32_t r_blk_idx = static_cast(block_idx / thrd_col_blks); - int32_t c_blk_idx = static_cast(block_idx % thrd_col_blks); + int32_t r_blk_idx = static_cast(block_idx % thrd_row_blks); + int32_t c_blk_idx = static_cast(block_idx / thrd_row_blks); int32_t r = r_blk_idx * ThreadBlk::kRow; int32_t c = c_blk_idx * ThreadBlk::kColumn; @@ -97,6 +97,7 @@ void dequantizeThread(ElementT* dst, const uint8_t* weights, const ElementT* sca const int32_t meta_col = c / QuantBlk::kColumn; for (int32_t j = c; j < c_end; ++j) { + const uint8_t* q_ptr = weights + j * q_rows; for (int32_t i = r; i < (r_end - 1); i += 2) { const auto scale0 = scale_buf[(i - r) / QuantBlk::kRow]; @@ -105,7 +106,7 @@ void dequantizeThread(ElementT* dst, const uint8_t* weights, const ElementT* sca const auto scale1 = scale_buf[(i + 1 - r) / QuantBlk::kRow];; const auto adjust1 = adjust_buf[(i + 1 - r) / QuantBlk::kRow]; - const auto vi = weights[j * q_rows + i / 2]; + const auto vi = q_ptr[i / 2]; if constexpr (std::is_same::value){ half2 scale_half2 = {scale0, scale1}; @@ -117,30 +118,22 @@ void dequantizeThread(ElementT* dst, const uint8_t* weights, const ElementT* sca dst[j * rows + i] = results.x; dst[j * rows + (i + 1)] = results.y; } else { - static_assert(std::is_same::value, "Only float and half are supported!"); const uint8_t vi0 = vi & 0xf; const uint8_t vi1 = vi >> 4; dst[j * rows + i] = static_cast(vi0) * scale0 + adjust0;; dst[j * rows + (i + 1)] = static_cast(vi1) * scale1 + adjust1; } - } - if (r_end & 1){ - const int32_t meta_row = (r_end - 1) / QuantBlk::kRow; - - const float scale0 = static_cast(scales[meta_col * row_blks + meta_row]); - const int zp_pair = (zero_points == nullptr) - ? 0x88 - : zero_points[meta_col * ((row_blks + 1) / 2) + meta_row / 2]; - const int zp0 = (meta_row & 1) ? (zp_pair >> 4) : (zp_pair & 0xf); + if ((r_end & 1) && (r_end > r)){ + const auto scale0 = scale_buf[(r_end - 1 - r) / QuantBlk::kRow]; + const auto adjust0 = adjust_buf[(r_end - 1 - r) / QuantBlk::kRow]; - const auto vi = weights[j * q_rows + (r_end - 1) / 2]; + const auto vi = q_ptr[(r_end - 1) / 2]; const uint8_t vi0 = vi & 0xf; - const float v0 = (static_cast(vi0) - zp0) * scale0; - dst[j * rows + (r_end - 1)] = static_cast(v0); + dst[j * rows + (r_end - 1)] = static_cast(vi0) * scale0 + adjust0; } } } @@ -169,7 +162,7 @@ static void dequantize(ElementT* dst, const uint8_t* weights, const ElementT* sc zero_points, rows, columns, - thrd_col_blks); + thrd_row_blks); }