Skip to content

Commit

Permalink
cuda kernel ready
Browse files Browse the repository at this point in the history
  • Loading branch information
wejoncy committed Feb 23, 2024
1 parent 2e58ea2 commit de14fbc
Show file tree
Hide file tree
Showing 5 changed files with 209 additions and 72 deletions.
159 changes: 123 additions & 36 deletions onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include <cstdint>
#include <cub/cub.cuh>
#include <cublas_v2.h>
#include <cuda_fp16.h>
#include <cmath>
#include <type_traits>
#include <math_constants.h>
#include "core/providers/cuda/cu_inc/common.cuh"
#include "core/providers/cuda/cuda_common.h"
Expand Down Expand Up @@ -56,89 +58,174 @@ __device__ __forceinline__ void DequantizeEightElements(uint32_t values_quant, f
}

template <class T>
__global__ void Dequantize4BitsKernel(
__global__ void Dequantize4BitsKernelReOrder(
T* output,
const uint8_t* quant_data,
const T* scale_data,
const uint8_t* zero_points,
const int32_t* reorder_idx,
int block_size,
int blocks_per_K,
int blocks_per_threadblock,
int total_blks,
int shift) {
int block_id = blockIdx.x * blocks_per_threadblock + ((threadIdx.x * 8) >> shift);
if (block_id >= total_blks) {
int groups_per_K,
int groups_per_threadblock,
int total_groups) {
int group_id = blockIdx.x * groups_per_threadblock + ((threadIdx.x * 8) / block_size);
if (group_id >= total_groups) {
return;
}
int n_idx = block_id / blocks_per_K;
int kb_idx = block_id % blocks_per_K;
int element_offset = block_id * block_size + ((threadIdx.x * 8) & ((1 << shift) - 1));
//T __shared__ zero_points_after_reorder[];//K

Check warning on line 75 in onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Should have a space between // and comment [whitespace/comments] [4] Raw Output: onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu:75: Should have a space between // and comment [whitespace/comments] [4]
//T __shared__ scales_after_reorder[]; // K

Check warning on line 76 in onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Should have a space between // and comment [whitespace/comments] [4] Raw Output: onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu:76: Should have a space between // and comment [whitespace/comments] [4]
//const int num_r_per_thread = k / 256;

Check warning on line 77 in onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Should have a space between // and comment [whitespace/comments] [4] Raw Output: onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu:77: Should have a space between // and comment [whitespace/comments] [4]

const int zero_point_shape_x = (groups_per_K + 1) / 2;
const int scales_shape_x = groups_per_K;
int n_idx = group_id / scales_shape_x;
int kb_idx = group_id % scales_shape_x;
int element_offset = group_id * block_size + ((threadIdx.x * 8) & (block_size - 1));
T* output_i = output + element_offset;
uint32_t quant_value = *(reinterpret_cast<const uint32_t*>(quant_data + element_offset / 2));
for (int i = 0; i < 8; i++) {
int32_t rid = reorder_idx[kb_idx * block_size + i];
T scale = *(scale_data + n_idx * scales_shape_x + rid);
uint8_t zp = 8;
if (zero_points) {
zp = zero_points[n_idx * zero_point_shape_x + rid / 2];
zp = (rid & 0x01) ? (zp >> 4) : (zp & 0x0f);
}

if constexpr (std::is_same_v<T, half>) {
T zp_adjust = -scale * __short2half_rn(zp);
output_i[i] = __uint2half_rn((quant_value >> (4 * i)) & 0xF) * scale + zp_adjust;
} else {
T zp_adjust = -scale * T(zp);
output_i[i] = T((quant_value >> (4 * i)) & 0xF) * scale + zp_adjust;
}
}
}

template <class T, typename ZeroT=uint8_t>

Check warning on line 105 in onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Missing spaces around = [whitespace/operators] [4] Raw Output: onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu:105: Missing spaces around = [whitespace/operators] [4]
__global__ void Dequantize4BitsKernel(
T* output,
const uint8_t* quant_data,
const T* scale_data,
const ZeroT* zero_points,
int block_size,
int groups_per_K,
int groups_per_threadblock,
int total_groups) {
int block_id = blockIdx.x * groups_per_threadblock + ((threadIdx.x * 8) / block_size);
if (block_id >= total_groups) {
return;
}
const int zero_point_shape_x = (groups_per_K + 1) / 2;
const int scales_shape_x = groups_per_K;
int n_idx = block_id / scales_shape_x;
int kb_idx = block_id % scales_shape_x;
int element_offset = block_id * block_size + ((threadIdx.x * 8) & (block_size - 1));
uint32_t quant_value = *(reinterpret_cast<const uint32_t*>(quant_data + element_offset / 2));
T scale = *(scale_data + block_id);
uint8_t zp = 8;
if (zero_points) {
zp = zero_points[n_idx * ((blocks_per_K + 1)/2) + kb_idx / 2];
zp = (kb_idx & 0x01) ? (zp >> 4) : (zp & 0x0f);
T zero_point_value;
if constexpr(std::is_same_v<ZeroT, uint8_t>) {
uint8_t zp = 8;
if (zero_points) {
zp = zero_points[n_idx * zero_point_shape_x + kb_idx / 2];
zp = (kb_idx & 0x01) ? (zp >> 4) : (zp & 0x0f);
}
zero_point_value = static_cast<T>(zp);
} else {
zero_point_value = zero_points? *(zero_points + block_id):static_cast<T>(8);
}

output = output + element_offset;
DequantizeEightElements(quant_value, scale, static_cast<T>(zp), output);
DequantizeEightElements(quant_value, scale, zero_point_value, output);
}

template <class T>
template <class T, typename ZeroT>
Status Dequantize4Bits(
T* output,
const uint8_t* quant_data,
const T* scales_data,
const uint8_t* zero_points, // shape: [N, (block_per_K + 1)/2]
const ZeroT* zero_points, // shape: [N, (block_per_K + 1)/2]
const int32_t* reorder_idx,
int k,
int n,
int block_size,
cudaStream_t stream) {
// k is padded and equal to block_per_K * block_size
ORT_ENFORCE(k % block_size == 0, "k must be a multiplier of block_size");
constexpr int element_per_thread = 8;
int blocks_per_threadblock = GridDim::maxThreadsPerBlock * element_per_thread / block_size;
int blocks_per_K = k / block_size;
int total_blks = n * blocks_per_K;
int blocks_per_grid = static_cast<int>(CeilDiv(n * blocks_per_K, blocks_per_threadblock));
int shift = static_cast<int>(log2f(float(block_size)));

Dequantize4BitsKernel<<<blocks_per_grid, GridDim::maxThreadsPerBlock, 0, stream>>>(
output,
quant_data,
scales_data,
zero_points,
block_size,
blocks_per_K,
blocks_per_threadblock,
total_blks,
shift);
int groups_per_threadblock = GridDim::maxThreadsPerBlock * element_per_thread / block_size;
int groups_per_K = k / block_size;
int total_groups = n * groups_per_K; // total elemenets in quant_data
int groups_per_grid = static_cast<int>(CeilDiv(total_groups, groups_per_threadblock));
if (!reorder_idx) {
Dequantize4BitsKernel<T, ZeroT><<<groups_per_grid, GridDim::maxThreadsPerBlock, 0, stream>>>(
output,
quant_data,
scales_data,
zero_points,
block_size,
groups_per_K,
groups_per_threadblock,
total_groups);
} else {
//static_assert(std::is_same_v<ZeroT, uint8_t>, "ZeroT must be uint8_t");

Check warning on line 171 in onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Should have a space between // and comment [whitespace/comments] [4] Raw Output: onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu:171: Should have a space between // and comment [whitespace/comments] [4]
Dequantize4BitsKernelReOrder<<<groups_per_grid, GridDim::maxThreadsPerBlock, 0, stream>>>(
output,
quant_data,
scales_data,
(const uint8_t*)zero_points,
reorder_idx,
block_size,
groups_per_K,
groups_per_threadblock,
total_groups);
}

return Status::OK();
}

template Status Dequantize4Bits<float>(
template Status Dequantize4Bits<float, uint8_t>(
float* output,
const uint8_t* quant_data,
const float* scales_data,
const uint8_t* zero_points,
const int32_t* reorder_idx,
int k,
int n,
int block_size,
cudaStream_t stream);

template Status Dequantize4Bits<half>(
template Status Dequantize4Bits<half, uint8_t>(
half* output,
const uint8_t* quant_data,
const half* scales_data,
const uint8_t* zero_points,
const int32_t* reorder_idx,
int k,
int n,
int block_size,
cudaStream_t stream);
template Status Dequantize4Bits<float, float>(
float* output,
const uint8_t* quant_data,
const float* scales_data,
const float* zero_points,
const int32_t* reorder_idx,
int k,
int n,
int block_size,
cudaStream_t stream);


template Status Dequantize4Bits<half, half>(
half* output,
const uint8_t* quant_data,
const half* scales_data,
const half* zero_points,
const int32_t* reorder_idx,
int k,
int n,
int block_size,
cudaStream_t stream);
///////////////////////////////////////////////////////////////////////////////
// A more general block-wise dequantization implementation that supports
// different block sizes and block orientations (row-wise/column-wise).
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,18 @@
namespace onnxruntime {
namespace contrib {
namespace cuda {
template <class T>
template <class T, typename ZeroT>
Status Dequantize4Bits(
T* output,
const uint8_t* quant_data,
const T* scales_data,
const uint8_t* zero_points,
const ZeroT* zero_points,
const int32_t* reorder_idx,
int k,
int n,
int block_size,
cudaStream_t stream);


/**
* @brief Dequantize a block-wise quantized matrix, and store the result in a
* column major matrix for use in subsequent GEMM. This implementation supports
Expand Down
66 changes: 43 additions & 23 deletions onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
//

#include "matmul_nbits.h"

Check warning on line 10 in onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Include the directory when naming header files [build/include_subdir] [4] Raw Output: onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc:10: Include the directory when naming header files [build/include_subdir] [4]
#include <cstdint>
#include "core/common/status.h"
#include "core/framework/float16.h"
#include "core/providers/cpu/math/matmul_helper.h"
Expand All @@ -25,11 +26,13 @@ Status MatMulNBits<T>::ComputeInternal(OpKernelContext* ctx) const {
const Tensor* b = ctx->Input<Tensor>(1);
const Tensor* scales = ctx->Input<Tensor>(2);
const Tensor* zero_points = ctx->Input<Tensor>(3);
const Tensor* reorder_idx = ctx->Input<Tensor>(4);

const auto* a_data = a->Data<T>();
const uint8_t* blob_data = b->Data<uint8_t>();
const auto* scales_data = scales->Data<T>();
const auto* zero_points_data = zero_points == nullptr ? nullptr : zero_points->Data<uint8_t>();
const auto* zero_points_data = zero_points == nullptr ? nullptr : zero_points->DataRaw();
const auto* reorder_idx_data = reorder_idx == nullptr ? nullptr : reorder_idx->Data<int32_t>();

typedef typename ToCudaType<T>::MappedType CudaT;

Expand All @@ -44,33 +47,50 @@ Status MatMulNBits<T>::ComputeInternal(OpKernelContext* ctx) const {
// Bail out early if the output is going to be empty
if (Y->Shape().Size() == 0) return Status::OK();

bool is_4bit_done = TryMatMul4Bits(
reinterpret_cast<CudaT*>(Y->MutableData<T>()),
reinterpret_cast<const CudaT*>(a_data),
blob_data,
reinterpret_cast<const CudaT*>(scales_data),
zero_points_data,
SafeInt<int>(helper.M()),
SafeInt<int>(helper.N()),
SafeInt<int>(helper.K()),
SafeInt<int>(block_size_),
SafeInt<int>(GetDeviceProp().sharedMemPerBlock),
static_cast<cudaStream_t>(ctx->GetComputeStream()->GetHandle()));
bool is_4bit_done = (reorder_idx_data == nullptr) &&
(!zero_points || !zero_points->IsDataType<T>()) &&
TryMatMul4Bits(
reinterpret_cast<CudaT*>(Y->MutableData<T>()),
reinterpret_cast<const CudaT*>(a_data),
blob_data,
reinterpret_cast<const CudaT*>(scales_data),
static_cast<const uint8_t*>(zero_points_data),
SafeInt<int>(helper.M()),
SafeInt<int>(helper.N()),
SafeInt<int>(helper.K()),
SafeInt<int>(block_size_),
SafeInt<int>(GetDeviceProp().sharedMemPerBlock),
static_cast<cudaStream_t>(ctx->GetComputeStream()->GetHandle()));

if (!is_4bit_done) {
int64_t K_padded = (K_ + block_size_ - 1) / block_size_ * block_size_;
IAllocatorUniquePtr<T> b_data_ptr = GetScratchBuffer<T>(N_ * K_padded, ctx->GetComputeStream());
auto* b_data = b_data_ptr.get();
if (column_wise_quant_blk_) {
// column-wise block
ORT_RETURN_IF_ERROR(Dequantize4Bits(
reinterpret_cast<CudaT*>(b_data),
blob_data,
reinterpret_cast<const CudaT*>(scales_data),
zero_points_data,
SafeInt<int>(K_padded),
SafeInt<int>(N_),
SafeInt<int>(block_size_),
static_cast<cudaStream_t>(ctx->GetComputeStream()->GetHandle())));
if ((zero_points && zero_points->IsDataType<T>())) {
ORT_RETURN_IF_ERROR(Dequantize4Bits(
reinterpret_cast<CudaT*>(b_data),
blob_data,
reinterpret_cast<const CudaT*>(scales_data),
(const CudaT*)zero_points_data,
reorder_idx_data,
SafeInt<int>(K_padded),
SafeInt<int>(N_),
SafeInt<int>(block_size_),
static_cast<cudaStream_t>(ctx->GetComputeStream()->GetHandle())));
} else {
ORT_RETURN_IF_ERROR(Dequantize4Bits(
reinterpret_cast<CudaT*>(b_data),
blob_data,
reinterpret_cast<const CudaT*>(scales_data),
(const uint8_t*)zero_points_data,
reorder_idx_data,
SafeInt<int>(K_padded),
SafeInt<int>(N_),
SafeInt<int>(block_size_),
static_cast<cudaStream_t>(ctx->GetComputeStream()->GetHandle())));
}
} else {
// row-wise block
K_padded = K_;
Expand All @@ -79,7 +99,7 @@ Status MatMulNBits<T>::ComputeInternal(OpKernelContext* ctx) const {
reinterpret_cast<CudaT*>(b_data),
blob_data,
reinterpret_cast<const CudaT*>(scales_data),
zero_points_data,
(const uint8_t*)zero_points_data,
SafeInt<int>(block_size_),
column_wise_quant_blk_,
SafeInt<int>(K_),
Expand Down
7 changes: 4 additions & 3 deletions onnxruntime/core/graph/contrib_ops/contrib_defs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3368,11 +3368,12 @@ Input zero_points is stored as uint8_t. If bits <= 4, two zero points are stored
.Input(1, "B", "1 or 2 dimensional data blob", "T2")
.Input(2, "scales", "quantization scale", "T1")
.Input(3, "zero_points", "quantization zero points", "T3", OpSchema::Optional)
.Input(4, "g_idx", "group_idx for gptq", "T2", OpSchema::Optional)
.Input(4, "g_idx", "group_idx", "T4", OpSchema::Optional)
.Output(0, "Y", "tensor. The output tensor has the same rank as the input. ", "T1")
.TypeConstraint("T1", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float/half_float tensors.")
.TypeConstraint("T2", {"tensor(uint8)", "tensor(uint32)", "tensor(int32)"}, "Constrain quantized weight types to uint8/uint32/int32/float16.")
.TypeConstraint("T3", {"tensor(uint8)", "tensor(uint32)", "tensor(int32)", "tensor(float16)"}, "Constrain quantized zero point types to uint8/uint32/int32/float16.")
.TypeConstraint("T2", {"tensor(uint8)", "tensor(int32)"}, "Constrain quantized weight types to uint8/uint32/int32/float16.")
.TypeConstraint("T3", {"tensor(uint8)", "tensor(int32)", "tensor(float16)", "tensor(float)"}, "Constrain quantized zero point types to uint8/int32/float16/float.")
.TypeConstraint("T4", {"tensor(int32)"}, "the index tensor.")
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
// Type inference

Check warning on line 3378 in onnxruntime/core/graph/contrib_ops/contrib_defs.cc

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/core/graph/contrib_ops/contrib_defs.cc:3378: Lines should be <= 120 characters long [whitespace/line_length] [2]
propagateElemTypeFromInputToOutput(ctx, 0, 0);

Check warning on line 3379 in onnxruntime/core/graph/contrib_ops/contrib_defs.cc

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/core/graph/contrib_ops/contrib_defs.cc:3379: Lines should be <= 120 characters long [whitespace/line_length] [2]
Expand Down
Loading

0 comments on commit de14fbc

Please sign in to comment.