Skip to content

Commit

Permalink
optimize dequant
Browse files Browse the repository at this point in the history
  • Loading branch information
chenfucn committed Nov 1, 2023
1 parent c4148e3 commit 53e703e
Showing 1 changed file with 22 additions and 24 deletions.
46 changes: 22 additions & 24 deletions onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@ void dequantizeThread(ElementT* dst, const uint8_t* weights, const ElementT* sca
using QuantBlk = typename BlkQuantTraits<ElementT, block_size, qbits, Columnwise>::QuantBlk;
using ThreadBlk = typename BlkQuantTraits<ElementT, block_size, qbits, Columnwise>::ThreadBlk;

// !! 4b specific code
static_assert(qbits == 4, "Only 4b block quantization is supported!");

const auto block_idx = blockIdx.x * blockDim.x + threadIdx.x;
const auto row_blks = (rows + QuantBlk::kRow - 1) / QuantBlk::kRow;

Expand All @@ -71,7 +74,6 @@ void dequantizeThread(ElementT* dst, const uint8_t* weights, const ElementT* sca

// quantized matrix is stored in column major, packed by column
const auto q_rows = (meta_rows * QuantBlk::kRow * qbits + 7) / 8;
const auto q_cols = meta_cols * QuantBlk::kColumn;

int32_t r_blk_idx = static_cast<int32_t>(block_idx / thrd_col_blks);
int32_t c_blk_idx = static_cast<int32_t>(block_idx % thrd_col_blks);
Expand All @@ -82,36 +84,32 @@ void dequantizeThread(ElementT* dst, const uint8_t* weights, const ElementT* sca
int32_t r_end = std::min(r + ThreadBlk::kRow, rows);
int32_t c_end = std::min(c + ThreadBlk::kColumn, columns);

Check warning on line 85 in onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <algorithm> for min [build/include_what_you_use] [4] Raw Output: onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu:85: Add #include <algorithm> for min [build/include_what_you_use] [4]

// for 4b quant, kPackSize = 2, so we have 2 scales and 2 offsets
const ElementT scale_buf[2] = {
scales[(c / QuantBlk::kColumn) * row_blks + r / QuantBlk::kRow],
((r/QuantBlk::kRow) < (meta_rows - 1)) ? scales[(c / QuantBlk::kColumn) * row_blks + r / QuantBlk::kRow + 1] : static_cast<ElementT>(0.0f)};

Check warning on line 90 in onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu:90: Lines should be <= 120 characters long [whitespace/line_length] [2]
const uint8_t zp_pair = (zero_points == nullptr)
? 0x88
: zero_points[(c / QuantBlk::kColumn) * ((row_blks + 1) / 2) + (r / QuantBlk::kRow) / 2];
const uint16_t zp_buf[2] = {(uint16_t)(zp_pair & 0x0f), (uint16_t)((zp_pair >> 4) & 0x0f)};
const ElementT adjust_buf[2] = {(-scale_buf[0]) * static_cast<ElementT>(zp_buf[0]),
(-scale_buf[1]) * static_cast<ElementT>(zp_buf[1])};

const int32_t meta_col = c / QuantBlk::kColumn;
for (int32_t j = c; j < c_end; ++j) {
const int32_t meta_col = j / QuantBlk::kColumn;

// !! 4b specific code
// the whole loop is 4b specific due to sub 8 bit packing
// and unpacking. We can potentially make this qbits generic
// by wrapping the packing/unpacking code like cutlass::Array
static_assert(qbits == 4, "Only 4b block quantization is supported!");

for (int32_t i = r; i < (r_end - 1); i += 2) {
const int32_t meta_row = i / QuantBlk::kRow;

Check warning on line 101 in onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Redundant blank line at the start of a code block should be deleted. [whitespace/blank_line] [2] Raw Output: onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu:101: Redundant blank line at the start of a code block should be deleted. [whitespace/blank_line] [2]
const auto scale0 = scales[meta_col * row_blks + meta_row];
const uint16_t zp_pair = (zero_points == nullptr)
? 0x88
: zero_points[meta_col * ((row_blks + 1) / 2) + meta_row / 2];
const uint16_t zp0 = (meta_row & 1) ? (zp_pair >> 4) : (zp_pair & 0xf);
const auto scale0 = scale_buf[(i - r) / QuantBlk::kRow];
const auto adjust0 = adjust_buf[(i - r) / QuantBlk::kRow];

auto scale1 = scale0;
uint16_t zp1 = zp0;
if constexpr (QuantBlk::kRow == 1) {
scale1 = scales[meta_col * row_blks + meta_row + 1];
zp1 = (zp_pair >> 4) & 0xf;
}
const auto scale1 = scale_buf[(i + 1 - r) / QuantBlk::kRow];;
const auto adjust1 = adjust_buf[(i + 1 - r) / QuantBlk::kRow];

const auto vi = weights[j * q_rows + i / 2];

if constexpr (std::is_same<ElementT, half>::value){

Check warning on line 110 in onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Missing space before { [whitespace/braces] [5] Raw Output: onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu:110: Missing space before { [whitespace/braces] [5]
half2 scale_half2 = {scale0, scale1};
half2 zp_adjust2 = {(-scale0) * __ushort2half_rn(zp0), (-scale1) * __ushort2half_rn(zp1)};
half2 zp_adjust2 = {adjust0, adjust1};

half2 v = {__ushort2half_rn(vi & 0xF), __ushort2half_rn((vi >> 4) & 0xF)};
half2 results = v * scale_half2 + zp_adjust2;
Expand All @@ -123,8 +121,8 @@ void dequantizeThread(ElementT* dst, const uint8_t* weights, const ElementT* sca
static_assert(std::is_same<ElementT, float>::value, "Only float and half are supported!");
const uint8_t vi0 = vi & 0xf;
const uint8_t vi1 = vi >> 4;
dst[j * rows + i] = (static_cast<float>(vi0) - zp0) * static_cast<float>(scale0);
dst[j * rows + (i + 1)] = (static_cast<float>(vi1) - zp1) * static_cast<float>(scale1);
dst[j * rows + i] = static_cast<float>(vi0) * scale0 + adjust0;;
dst[j * rows + (i + 1)] = static_cast<float>(vi1) * scale1 + adjust1;
}

Check warning on line 127 in onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Redundant blank line at the end of a code block should be deleted. [whitespace/blank_line] [3] Raw Output: onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu:127: Redundant blank line at the end of a code block should be deleted. [whitespace/blank_line] [3]
}
Expand Down

0 comments on commit 53e703e

Please sign in to comment.