split dequant impl
chenfucn committed Nov 1, 2023
1 parent 3a565c8 commit b05046e
Showing 4 changed files with 166 additions and 19 deletions.
119 changes: 118 additions & 1 deletion onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu
@@ -19,6 +19,124 @@ namespace contrib {
namespace cuda {


__device__ __forceinline__ void DequantizeEightElements(uint32_t values_quant, half scale, half zp, half* output) {
half2 scale_half2 = {scale, scale};
half zp_adjust = -scale * __short2half_rn(zp);
half2 zp_adjust2 = {zp_adjust, zp_adjust};

alignas(16) half2 results[4];
half v0 = __uint2half_rn(values_quant & 0xF);
half v1 = __uint2half_rn((values_quant >> 4) & 0xF);
results[0] = __halves2half2(v0, v1) * scale_half2 + zp_adjust2;

half v2 = __uint2half_rn((values_quant >> 8) & 0xF);
half v3 = __uint2half_rn((values_quant >> 12) & 0xF);
results[1] = __halves2half2(v2, v3) * scale_half2 + zp_adjust2;

half v4 = __uint2half_rn((values_quant >> 16) & 0xF);
half v5 = __uint2half_rn((values_quant >> 20) & 0xF);
results[2] = __halves2half2(v4, v5) * scale_half2 + zp_adjust2;

half v6 = __uint2half_rn((values_quant >> 24) & 0xF);
half v7 = __uint2half_rn((values_quant >> 28) & 0xF);
results[3] = __halves2half2(v6, v7) * scale_half2 + zp_adjust2;
*(reinterpret_cast<float4*>(output)) = *(reinterpret_cast<float4*>(results));
}

__device__ __forceinline__ void DequantizeEightElements(uint32_t values_quant, float scale, float zp, float* output) {
float zp_adjust = -scale * zp;
output[0] = float(values_quant & 0xF) * scale + zp_adjust;
output[1] = float((values_quant >> 4) & 0xF) * scale + zp_adjust;
output[2] = float((values_quant >> 8) & 0xF) * scale + zp_adjust;
output[3] = float((values_quant >> 12) & 0xF) * scale + zp_adjust;
output[4] = float((values_quant >> 16) & 0xF) * scale + zp_adjust;
output[5] = float((values_quant >> 20) & 0xF) * scale + zp_adjust;
output[6] = float((values_quant >> 24) & 0xF) * scale + zp_adjust;
output[7] = float((values_quant >> 28) & 0xF) * scale + zp_adjust;
}
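
// Both overloads above compute value = (q - zp) * scale, written as
// q * scale + (-scale * zp) so the half2 path needs only one fused multiply-add
// per element pair. A minimal host-side reference of the same unpacking
// (hypothetical helper, shown only for illustration):
inline void DequantizeEightElementsRef(uint32_t values_quant, float scale, float zp, float* output) {
  float zp_adjust = -scale * zp;  // hoisted so each element is a single multiply-add
  for (int i = 0; i < 8; ++i) {
    uint32_t q = (values_quant >> (4 * i)) & 0xF;            // nibbles are packed low-to-high
    output[i] = static_cast<float>(q) * scale + zp_adjust;   // (q - zp) * scale
  }
}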

template <class T>
__global__ void Dequantize4BitsKernel(
T* output,
const uint8_t* quant_data,
const T* scale_data,
const uint8_t* zero_points,
int block_size,
int blocks_per_K,
int blocks_per_threadblock,
int shift) {
int block_id = blockIdx.x * blocks_per_threadblock + ((threadIdx.x * 8) >> shift);
int n_idx = block_id / blocks_per_K;
int kb_idx = block_id % blocks_per_K;
int element_offset = block_id * block_size + ((threadIdx.x * 8) & ((1 << shift) - 1));
uint32_t quant_value = *(reinterpret_cast<const uint32_t*>(quant_data + element_offset / 2));
T scale = *(scale_data + block_id);
uint8_t zp = 8;
if (zero_points) {
zp = zero_points[n_idx * ((blocks_per_K + 1)/2) + kb_idx / 2];
zp = (kb_idx & 0x01) ? (zp >> 4) : (zp & 0x0f);
}

output = output + element_offset;
DequantizeEightElements(quant_value, scale, static_cast<T>(zp), output);
}
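
// Each thread dequantizes eight consecutive elements of one quant block. With
// block_size = 32 and 256 threads per block (illustrative values), shift = 5 and
// blocks_per_threadblock = 256 * 8 / 32 = 64. For blockIdx.x = 2, threadIdx.x = 37:
//   threadIdx.x * 8  = 296
//   block_id         = 2 * 64 + (296 >> 5)   = 137
//   n_idx, kb_idx    = 137 / blocks_per_K, 137 % blocks_per_K
//   element_offset   = 137 * 32 + (296 & 31) = 4392, i.e. element 8 of block 137
// Zero points are packed two per byte along K: an even kb_idx reads the low nibble
// and an odd kb_idx the high nibble, matching the (kb_idx & 0x01) selection above.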

template <class T>
Status Dequantize4Bits(
T* output,
const uint8_t* quant_data,
const T* scales_data,
const uint8_t* zero_points, // shape: [N, (blocks_per_K + 1)/2]
int k,
int n,
int block_size,
cudaStream_t stream) {
// k is padded and equal to blocks_per_K * block_size
ORT_ENFORCE(k % block_size == 0, "k must be a multiple of block_size");
constexpr int element_per_thread = 8;
int blocks_per_threadblock = GridDim::maxThreadsPerBlock * element_per_thread / block_size;
int blocks_per_K = k / block_size;
int blocks_per_grid = static_cast<int>(CeilDiv(n * blocks_per_K, blocks_per_threadblock));
int shift = static_cast<int>(log2f(float(block_size)));
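// Note: block_size is assumed to be a power of two here; otherwise the
// shift-based index math in Dequantize4BitsKernel would not cover each block exactly.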

Dequantize4BitsKernel<<<blocks_per_grid, GridDim::maxThreadsPerBlock, 0, stream>>>(
output,
quant_data,
scales_data,
zero_points,
block_size,
blocks_per_K,
blocks_per_threadblock,
shift);

return Status::OK();
}

template Status Dequantize4Bits<float>(
float* output,
const uint8_t* quant_data,
const float* scales_data,
const uint8_t* zero_points,
int k,
int n,
int block_size,
cudaStream_t stream);

template Status Dequantize4Bits<half>(
half* output,
const uint8_t* quant_data,
const half* scales_data,
const uint8_t* zero_points,
int k,
int n,
int block_size,
cudaStream_t stream);
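
// Illustrative call (hypothetical buffer names, device pointers assumed allocated):
// for an n x k 4-bit weight matrix with k a multiple of block_size,
//   quant_data  holds n * k / 2 bytes (two 4-bit values per byte),
//   scales_data holds n * (k / block_size) scale values,
//   zero_points holds n * ((k / block_size + 1) / 2) bytes, or nullptr to default every zero point to 8.
//
// ORT_RETURN_IF_ERROR(Dequantize4Bits<half>(dequantized_weights, quant_data, scales_data,
//                                           zero_points, k, n, block_size, stream));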


///////////////////////////////////////////////////////////////////////////////
// A more general block-wise dequantization implementation that supports
// different block sizes and block orientations (row-wise/column-wise).

template <
int Row_, ///< rows of a matrix
int Column_ ///< columns of a matrix
@@ -70,7 +188,6 @@ void dequantizeThread(ElementT* dst, const uint8_t* weights, const ElementT* sca
const auto row_blks = (rows + QuantBlk::kRow - 1) / QuantBlk::kRow;

const auto meta_rows = (rows + QuantBlk::kRow - 1) / QuantBlk::kRow;
const auto meta_cols = (columns + QuantBlk::kColumn - 1) / QuantBlk::kColumn;

// quantized matrix is stored in column major, packed by column
const auto q_rows = (meta_rows * QuantBlk::kRow * qbits + 7) / 8;
16 changes: 14 additions & 2 deletions onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cuh
@@ -7,10 +7,22 @@
namespace onnxruntime {
namespace contrib {
namespace cuda {
template <class T>
Status Dequantize4Bits(
T* output,
const uint8_t* quant_data,
const T* scales_data,
const uint8_t* zero_points,
int k,
int n,
int block_size,
cudaStream_t stream);


/**
* @brief Dequantize a column major quantized matrix, and store the result in a column major
* matrix for use in subsequent GEMM
* @brief Dequantize a block-wise quantized matrix, and store the result in a
* column major matrix for use in subsequent GEMM. This implementation supports
* columnwise and rowwise block orientation.
* @param[out] dst pointer to the dequantized matrix, column major: [columns, rows]
* @param[in] qelements pointer to the quantized elements, column major: [columns, rows]
* @param[in] scales pointer to the scales of quantized blocks, column major layout
45 changes: 32 additions & 13 deletions onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc
@@ -27,7 +27,9 @@ class MatMulNBits final : public CudaKernel {
ORT_ENFORCE(Status::OK() == info.GetAttr<int64_t>("N", &N_));
ORT_ENFORCE(Status::OK() == info.GetAttr<int64_t>("block_size", &block_size_));
ORT_ENFORCE(Status::OK() == info.GetAttr<int64_t>("bits", &nbits_));
ORT_ENFORCE(nbits_ == 4, "Only 4b quantization is supported for MatMulNBits op, additional bits support is planned.");
ORT_ENFORCE(nbits_ == 4,
"Only 4b quantization is supported for MatMulNBits op,"
" additional bits support is planned.");
}

Status ComputeInternal(OpKernelContext* context) const override;
@@ -37,6 +39,7 @@ class MatMulNBits final : public CudaKernel {
int64_t N_;
int64_t block_size_;
int64_t nbits_;
bool column_wise_quant_blk_{true};
};

template <typename T>
@@ -77,18 +80,34 @@ Status MatMulNBits<T>::ComputeInternal(OpKernelContext* ctx) const {
SafeInt<int>(GetDeviceProp().sharedMemPerBlock),
static_cast<cudaStream_t>(ctx->GetComputeStream()->GetHandle()));
if (!is_4bit_done) {
IAllocatorUniquePtr<T> b_data_ptr = GetScratchBuffer<T>(N_ * K_, ctx->GetComputeStream());
int64_t K_padded = (K_ + block_size_ - 1) / block_size_ * block_size_;
IAllocatorUniquePtr<T> b_data_ptr = GetScratchBuffer<T>(N_ * K_padded, ctx->GetComputeStream());
auto* b_data = b_data_ptr.get();
ORT_RETURN_IF_ERROR(DequantizeBlockwise4b(
reinterpret_cast<CudaT*>(b_data),
blob_data,
reinterpret_cast<const CudaT*>(scales_data),
zero_points_data,
SafeInt<int>(block_size_),
true,
SafeInt<int>(K_),
SafeInt<int>(N_),
static_cast<cudaStream_t>(ctx->GetComputeStream()->GetHandle())));
if (column_wise_quant_blk_) {
// column-wise block
ORT_RETURN_IF_ERROR(Dequantize4Bits(reinterpret_cast<CudaT*>(b_data),
blob_data,
reinterpret_cast<const CudaT*>(scales_data),
zero_points_data,
SafeInt<int>(K_padded),
SafeInt<int>(N_),
SafeInt<int>(block_size_),
static_cast<cudaStream_t>(ctx->GetComputeStream()->GetHandle())));
} else {
// row-wise block
K_padded = K_;

ORT_RETURN_IF_ERROR(DequantizeBlockwise4b(
reinterpret_cast<CudaT*>(b_data),
blob_data,
reinterpret_cast<const CudaT*>(scales_data),
zero_points_data,
SafeInt<int>(block_size_),
column_wise_quant_blk_,
SafeInt<int>(K_),
SafeInt<int>(N_),
static_cast<cudaStream_t>(ctx->GetComputeStream()->GetHandle())));
}
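// For example (illustrative values): with K_ = 4100 and block_size_ = 64,
// K_padded = (4100 + 64 - 1) / 64 * 64 = 4160, so the column-wise path allocates
// N_ * 4160 scratch elements and passes 4160 as the GEMM leading dimension below,
// while the row-wise path resets K_padded to K_ since DequantizeBlockwise4b writes
// exactly N_ * K_ values.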
#if 0
cudaStreamSynchronize(static_cast<cudaStream_t>(ctx->GetComputeStream()->GetHandle()));
T* b_data_cpu = new T[K_ * N_];
@@ -109,7 +128,7 @@ Status MatMulNBits<T>::ComputeInternal(OpKernelContext* ctx) const {
SafeInt<int>(helper.K()),
&alpha,
reinterpret_cast<const CudaT*>(b_data),
SafeInt<int>(K_),
SafeInt<int>(K_padded),
reinterpret_cast<const CudaT*>(a_data),
helper.Lda(transa),
&zero,
@@ -46,15 +46,14 @@ class DequantizeInt4 : public IKernelExplorer {
}

void Run() override {
ORT_THROW_IF_ERROR(contrib::cuda::DequantizeBlockwise4b(
ORT_THROW_IF_ERROR(contrib::cuda::Dequantize4Bits(
params_.output_,
params_.quant_,
params_.scales_,
params_.zero_points_,
32,
true,
params_.k_,
params_.n_,
32,
params_.StreamHandle()));
}

