Skip to content

Commit

Permalink
dequant tail adjustment
Browse files Browse the repository at this point in the history
  • Loading branch information
chenfucn committed Nov 1, 2023
1 parent 53e703e commit 3a565c8
Showing 1 changed file with 11 additions and 18 deletions.
29 changes: 11 additions & 18 deletions onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ template <
bool Columnwise>
__global__
void dequantizeThread(ElementT* dst, const uint8_t* weights, const ElementT* scales,
const uint8_t* zero_points, int rows, int columns, int thrd_col_blks) {
const uint8_t* zero_points, int rows, int columns, int thrd_row_blks) {

Check warning on line 62 in onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Redundant blank line at the start of a code block should be deleted. [whitespace/blank_line] [2] Raw Output: onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu:62: Redundant blank line at the start of a code block should be deleted. [whitespace/blank_line] [2]
using QuantBlk = typename BlkQuantTraits<ElementT, block_size, qbits, Columnwise>::QuantBlk;
using ThreadBlk = typename BlkQuantTraits<ElementT, block_size, qbits, Columnwise>::ThreadBlk;
Expand All @@ -75,8 +75,8 @@ void dequantizeThread(ElementT* dst, const uint8_t* weights, const ElementT* sca
// quantized matrix is stored in column major, packed by column
const auto q_rows = (meta_rows * QuantBlk::kRow * qbits + 7) / 8;

int32_t r_blk_idx = static_cast<int32_t>(block_idx / thrd_col_blks);
int32_t c_blk_idx = static_cast<int32_t>(block_idx % thrd_col_blks);
int32_t r_blk_idx = static_cast<int32_t>(block_idx % thrd_row_blks);
int32_t c_blk_idx = static_cast<int32_t>(block_idx / thrd_row_blks);

int32_t r = r_blk_idx * ThreadBlk::kRow;
int32_t c = c_blk_idx * ThreadBlk::kColumn;
Expand All @@ -97,6 +97,7 @@ void dequantizeThread(ElementT* dst, const uint8_t* weights, const ElementT* sca

const int32_t meta_col = c / QuantBlk::kColumn;
for (int32_t j = c; j < c_end; ++j) {
const uint8_t* q_ptr = weights + j * q_rows;
for (int32_t i = r; i < (r_end - 1); i += 2) {

Check warning on line 102 in onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Redundant blank line at the start of a code block should be deleted. [whitespace/blank_line] [2] Raw Output: onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu:102: Redundant blank line at the start of a code block should be deleted. [whitespace/blank_line] [2]
const auto scale0 = scale_buf[(i - r) / QuantBlk::kRow];
Expand All @@ -105,7 +106,7 @@ void dequantizeThread(ElementT* dst, const uint8_t* weights, const ElementT* sca
const auto scale1 = scale_buf[(i + 1 - r) / QuantBlk::kRow];;
const auto adjust1 = adjust_buf[(i + 1 - r) / QuantBlk::kRow];

const auto vi = weights[j * q_rows + i / 2];
const auto vi = q_ptr[i / 2];

if constexpr (std::is_same<ElementT, half>::value){

Check warning on line 111 in onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Missing space before { [whitespace/braces] [5] Raw Output: onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu:111: Missing space before { [whitespace/braces] [5]
half2 scale_half2 = {scale0, scale1};
Expand All @@ -117,30 +118,22 @@ void dequantizeThread(ElementT* dst, const uint8_t* weights, const ElementT* sca
dst[j * rows + i] = results.x;
dst[j * rows + (i + 1)] = results.y;
} else {

static_assert(std::is_same<ElementT, float>::value, "Only float and half are supported!");
const uint8_t vi0 = vi & 0xf;
const uint8_t vi1 = vi >> 4;
dst[j * rows + i] = static_cast<float>(vi0) * scale0 + adjust0;;
dst[j * rows + (i + 1)] = static_cast<float>(vi1) * scale1 + adjust1;
}

}

if (r_end & 1){
const int32_t meta_row = (r_end - 1) / QuantBlk::kRow;

const float scale0 = static_cast<float>(scales[meta_col * row_blks + meta_row]);
const int zp_pair = (zero_points == nullptr)
? 0x88
: zero_points[meta_col * ((row_blks + 1) / 2) + meta_row / 2];
const int zp0 = (meta_row & 1) ? (zp_pair >> 4) : (zp_pair & 0xf);
if ((r_end & 1) && (r_end > r)){

Check warning on line 129 in onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Missing space before { [whitespace/braces] [5] Raw Output: onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu:129: Missing space before { [whitespace/braces] [5]
const auto scale0 = scale_buf[(r_end - 1 - r) / QuantBlk::kRow];
const auto adjust0 = adjust_buf[(r_end - 1 - r) / QuantBlk::kRow];

const auto vi = weights[j * q_rows + (r_end - 1) / 2];
const auto vi = q_ptr[(r_end - 1) / 2];
const uint8_t vi0 = vi & 0xf;
const float v0 = (static_cast<float>(vi0) - zp0) * scale0;

dst[j * rows + (r_end - 1)] = static_cast<ElementT>(v0);
dst[j * rows + (r_end - 1)] = static_cast<ElementT>(vi0) * scale0 + adjust0;
}
}
}
Expand Down Expand Up @@ -169,7 +162,7 @@ static void dequantize(ElementT* dst, const uint8_t* weights, const ElementT* sc
zero_points,
rows,
columns,
thrd_col_blks);
thrd_row_blks);
}


Expand Down

0 comments on commit 3a565c8

Please sign in to comment.