change zp to 4bits
yufenglee committed Oct 9, 2023
1 parent 0a50f64 commit a1977f8
Showing 10 changed files with 248 additions and 102 deletions.
59 changes: 43 additions & 16 deletions onnxruntime/contrib_ops/cpu/quantization/dequantize_blockwise.h
@@ -5,8 +5,11 @@

#include "blockwise_quant_block.h"

#include <vector>

#include "core/framework/float16.h"
#include "core/platform/threadpool.h"
#include <iostream>

namespace onnxruntime {
namespace contrib {
@@ -15,32 +18,49 @@ template <typename T, int32_t block_size, int32_t bits>
void QuantizeBlockwise(
uint8_t* dst, // shape: [ N, block_per_K, block_blob_size ]
const T* src, // shape: [K, N]
T* scale, // shape: [N, block_per_K]
uint8_t* zero_points, // shape: [N, block_per_K]
T* scale, // shape: [N * block_per_K]
uint8_t* zero_points, // shape: [N * block_per_K] if bits > 4 else [(N * block_per_K + 1) / 2]
int32_t N,
int32_t K,
onnxruntime::concurrency::ThreadPool* thread_pool) {
BlockwiseQuantBlock<T, block_size, bits>* dst_blob =
reinterpret_cast<BlockwiseQuantBlock<T, block_size, bits>*>(dst);

int32_t block_per_K = (K + block_size - 1) / block_size;
int32_t task_count = N * block_per_K;
int32_t total_block_count = N * block_per_K;

std::vector<uint8_t> zero_points_tmp; // to avoid race condition
(void)zero_points_tmp;
uint8_t* zero_points_tmp_ptr = zero_points;
if (bits <= 4 && zero_points != nullptr) {
zero_points_tmp.resize(total_block_count, 0);
zero_points_tmp_ptr = zero_points_tmp.data();
}

concurrency::ThreadPool::TryBatchParallelFor(
thread_pool,
task_count,
[&](ptrdiff_t task_idx) {
int32_t n = static_cast<int32_t>(task_idx / block_per_K);
int32_t k_block_idx = static_cast<int32_t>(task_idx % block_per_K);
total_block_count,
[&](ptrdiff_t block_idx) {
int32_t n = static_cast<int32_t>(block_idx / block_per_K);
int32_t k_block_idx = static_cast<int32_t>(block_idx % block_per_K);
int32_t k = k_block_idx * block_size;
BlockwiseQuantBlock<T, block_size, bits>* blob_ptr = dst_blob + task_idx;
if (nullptr != zero_points) {
blob_ptr->quant(src + k * N + n, scale[task_idx], zero_points[task_idx], k, K, N);
BlockwiseQuantBlock<T, block_size, bits>* blob_ptr = dst_blob + block_idx;
if (nullptr != zero_points_tmp_ptr) {
blob_ptr->quant(src + k * N + n, scale[block_idx], zero_points_tmp_ptr[block_idx], k, K, N);
} else {
blob_ptr->quant(src + k * N + n, scale[task_idx], k, K, N);
blob_ptr->quant(src + k * N + n, scale[block_idx], k, K, N);
}
},
0);

if (bits <= 4 && zero_points != nullptr) { // compact zero points
for (int32_t zp_idx = 0; zp_idx < total_block_count / 2; zp_idx++) {
zero_points[zp_idx] = ((zero_points_tmp[zp_idx * 2]) | (zero_points_tmp[zp_idx * 2 + 1] << 4));
}
if (total_block_count & 1) {
zero_points[total_block_count / 2] = (zero_points[total_block_count / 2] & 0xf0) | zero_points_tmp[total_block_count - 1];
}
}
}

#define QuantizeBlockwise4Bits(block_size) \
@@ -78,10 +98,10 @@

template <typename T, int32_t block_size, int32_t bits>
void DequantizeBlockwise(
T* dst, // [N, K]
const uint8_t* src, // [N, block_per_K, block_blob_size]
const T* scale, // [N, block_per_K]
const uint8_t* zero_points, // [N, block_per_K]
T* dst, // shape: [N, K]
const uint8_t* src, // shape: [N, block_per_K, block_blob_size]
const T* scale, // shape: [N, block_per_K]
const uint8_t* zero_points, // shape: [N, block_per_K] if bits > 4 else [N, (block_per_K + 1) / 2]
int32_t N,
int32_t K,
onnxruntime::concurrency::ThreadPool* thread_pool) {
@@ -100,7 +120,14 @@ void DequantizeBlockwise(
int32_t k = k_block_idx * block_size;
const BlockwiseQuantBlock<T, block_size, bits>* blob_ptr = src_blob + task_idx;
if (nullptr != zero_points) {
blob_ptr->dequant(dst + n * K + k, scale[task_idx], zero_points[task_idx], k, K);
if constexpr (bits > 4) { // zero point is stored in a full byte
blob_ptr->dequant(dst + n * K + k, scale[task_idx], zero_points[task_idx], k, K);
} else { // zero point is stored in 4 bits
uint8_t zp = zero_points[task_idx / 2];
zp = (task_idx & 1) ? (zp >> 4) : (zp & 0xf);
blob_ptr->dequant(dst + n * K + k, scale[task_idx], zp, k, K);
}
} else {
blob_ptr->dequant(dst + n * K + k, scale[task_idx], k, K);
}
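Note: the new packing in dequantize_blockwise.h stores two 4-bit zero points per byte — the even-indexed block in the low nibble, the odd-indexed block in the high nibble, with a lone trailing value kept in the low nibble of the last byte. A minimal standalone sketch of that scheme (an illustration with hypothetical helper names, not code from this commit):

// Hypothetical illustration of the 4-bit zero-point packing used above.
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Pack one 4-bit zero point per block into half as many bytes:
// block 2*i -> low nibble of byte i, block 2*i+1 -> high nibble of byte i.
std::vector<uint8_t> PackZeroPoints4Bit(const std::vector<uint8_t>& zp) {
  std::vector<uint8_t> packed((zp.size() + 1) / 2, 0);
  for (size_t i = 0; i < zp.size(); ++i) {
    uint8_t nibble = zp[i] & 0x0f;
    packed[i / 2] |= (i & 1) ? static_cast<uint8_t>(nibble << 4) : nibble;
  }
  return packed;
}

// Recover the zero point of a given block, mirroring the dequantize path.
uint8_t UnpackZeroPoint4Bit(const std::vector<uint8_t>& packed, size_t block_idx) {
  uint8_t byte = packed[block_idx / 2];
  return (block_idx & 1) ? (byte >> 4) : (byte & 0x0f);
}

int main() {
  std::vector<uint8_t> zp = {3, 7, 2, 9, 4};  // odd count exercises the tail byte
  std::vector<uint8_t> packed = PackZeroPoints4Bit(zp);
  assert(packed.size() == 3);
  for (size_t i = 0; i < zp.size(); ++i) {
    assert(UnpackZeroPoint4Bit(packed, i) == zp[i]);  // round-trips exactly
  }
  return 0;
}
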
24 changes: 14 additions & 10 deletions onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu
@@ -50,33 +50,37 @@ __global__ void Dequantize4BitsKernel(
const T* scale_data,
const uint8_t* zero_points,
int block_size,
int blocks_per_tb,
int blocks_per_threadblock,
int shift) {
int block_id = blockIdx.x * blocks_per_tb + ((threadIdx.x * 8)>>shift);
int element_offset = block_id * block_size + ((threadIdx.x * 8) & ((1<<shift) - 1));
int block_id = blockIdx.x * blocks_per_threadblock + ((threadIdx.x * 8) >> shift);
int element_offset = block_id * block_size + ((threadIdx.x * 8) & ((1 << shift) - 1));
uint32_t quant_value = *(reinterpret_cast<const uint32_t*>(quant_data + element_offset / 2));
T scale = *(scale_data + block_id);
T zero_point = static_cast<T>(zero_points ? zero_points[block_id] : (uint8_t)(8));
uint8_t zp = 8;
if (zero_points) {
zp = (block_id & 0x01) ? (zero_points[block_id / 2] >> 4) : (zero_points[block_id / 2] & 0x0f);
}

output = output + element_offset;
DequantizeEightElements(quant_value, scale, zero_point, output);
DequantizeEightElements(quant_value, scale, static_cast<T>(zp), output);
}

template <class T>
Status Dequantize4Bits(
T* output,
const uint8_t* quant_data,
const T* scales_data,
const uint8_t* zero_points,
const uint8_t* zero_points, // shape: [N, (block_per_K + 1)/2]
int k,
int n,
int block_size,
cudaStream_t stream) {
// k is padded and equal to block_per_K * block_size
ORT_ENFORCE(k % block_size == 0, "k must be a multiple of block_size");
constexpr int element_per_thread = 8;
int blocks_per_tb = GridDim::maxThreadsPerBlock * element_per_thread / block_size;
int k_blocks = k / block_size;
int blocks_per_grid = static_cast<int>(CeilDiv(n * k_blocks, blocks_per_tb));
int blocks_per_threadblock = GridDim::maxThreadsPerBlock * element_per_thread / block_size;
int blocks_per_K = k / block_size;
int blocks_per_grid = static_cast<int>(CeilDiv(n * blocks_per_K, blocks_per_threadblock));
int shift = static_cast<int>(log2f(float(block_size)));

Dequantize4BitsKernel<<<blocks_per_grid, GridDim::maxThreadsPerBlock, 0, stream>>>(
@@ -85,7 +89,7 @@ Status Dequantize4Bits(
scales_data,
zero_points,
block_size,
blocks_per_tb,
blocks_per_threadblock,
shift);

return Status::OK();
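Note: a host-side sketch of how the Dequantize4BitsKernel launch math above maps threads to quantized blocks and output elements, assuming 256 threads per thread block in place of GridDim::maxThreadsPerBlock; constants and names here are illustrative, not from the commit:

// Hypothetical check that the shift-based indexing touches every element once.
#include <cassert>
#include <cmath>
#include <vector>

int main() {
  const int threads_per_block = 256;   // stand-in for GridDim::maxThreadsPerBlock
  const int element_per_thread = 8;
  const int block_size = 32;           // quantization block size
  const int n = 4, k = 1024;           // k assumed to be a multiple of block_size

  int blocks_per_threadblock = threads_per_block * element_per_thread / block_size;
  int blocks_per_K = k / block_size;
  int total_quant_blocks = n * blocks_per_K;
  int blocks_per_grid = (total_quant_blocks + blocks_per_threadblock - 1) / blocks_per_threadblock;
  int shift = static_cast<int>(std::log2(static_cast<float>(block_size)));

  std::vector<int> hits(n * k, 0);
  for (int bx = 0; bx < blocks_per_grid; ++bx) {
    for (int tx = 0; tx < threads_per_block; ++tx) {
      int block_id = bx * blocks_per_threadblock + ((tx * element_per_thread) >> shift);
      int element_offset = block_id * block_size +
                           ((tx * element_per_thread) & ((1 << shift) - 1));
      if (block_id >= total_quant_blocks) continue;  // guard for a partial last grid block
      for (int e = 0; e < element_per_thread; ++e) ++hits[element_offset + e];
    }
  }
  for (int h : hits) assert(h == 1);  // every output element is written exactly once
  return 0;
}
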
@@ -62,7 +62,7 @@ __device__ __forceinline__ float AccumulateEightElements(uint32_t values_quant,

constexpr int BLOCKSIZEN = 8;

template <class T, int group_size>
template <class T, int block_size>
__global__ void MatMulFloatInt4Kernel(
T* output,
const T* a_data,
@@ -77,39 +77,43 @@ __global__ void MatMulFloatInt4Kernel(
int lane_id = threadIdx.x;
int warp_id = threadIdx.y;
int n_id = n_block_id * BLOCKSIZEN + warp_id;
int group_count = (k + group_size - 1) / group_size;
int blocks_per_K = (k + block_size - 1) / block_size;
int thread_id = warp_id * 32 + lane_id;
int k_iter = k / 256;

extern __shared__ char shared_buffer[];

// load scale to shared buffer
T* b_scale_vec = (T*)shared_buffer;
uint8_t* b_zp_vec = reinterpret_cast<uint8_t*>(b_scale_vec + BLOCKSIZEN * group_count);
int offset = n_block_id * BLOCKSIZEN * group_count;
for (int i = thread_id; i < BLOCKSIZEN * group_count; i += 256) {
uint8_t* b_zp_vec = reinterpret_cast<uint8_t*>(b_scale_vec + BLOCKSIZEN * blocks_per_K);
int offset = n_block_id * BLOCKSIZEN * blocks_per_K;
for (int i = thread_id; i < BLOCKSIZEN * blocks_per_K; i += 256) {
b_scale_vec[i] = scales_data[offset + i];
b_zp_vec[i] = zero_points != nullptr ? zero_points[offset + i] : uint8_t(8);
}
for (int i = thread_id; i < BLOCKSIZEN * blocks_per_K / 2; i += 256) {
b_zp_vec[i] = zero_points != nullptr ? zero_points[offset/2 + i] : uint8_t(0x88);
}
__syncthreads();

a_data += m_id * k;
b_data_quant += n_id * group_count * (group_size / 2);
b_data_quant += n_id * blocks_per_K * (block_size / 2);

float sum = 0.f;
int k_id = 0;
for (; k_id < (k & 0xffffff00); k_id += 256) {
uint32_t value = *(reinterpret_cast<const uint32_t*>(b_data_quant + (k_id >> 1) + lane_id * 4));
T scale = b_scale_vec[warp_id * group_count + (k_id + lane_id * 8) / group_size];
uint8_t zp = b_zp_vec[warp_id * group_count + (k_id + lane_id * 8) / group_size];
int32_t block_idx = warp_id * blocks_per_K + (k_id + lane_id * 8) / block_size;
T scale = b_scale_vec[block_idx];
uint8_t zp = (block_idx & 0x01) ? (b_zp_vec[block_idx/2] >> 4) : (b_zp_vec[block_idx/2] & 0x0f);
sum += AccumulateEightElements(value, scale, zp, a_data + k_id + (lane_id << 3));
}

// handle remainder
if (k_id + lane_id * 8 < k) {
uint32_t value = *(reinterpret_cast<const uint32_t*>(b_data_quant + k_iter * 128 + lane_id * 4));
T scale = b_scale_vec[warp_id * group_count + (k_id + lane_id * 8) / group_size];
uint8_t zp = b_zp_vec[warp_id * group_count + (k_id + lane_id * 8) / group_size];
int32_t block_idx = warp_id * blocks_per_K + (k_id + lane_id * 8) / block_size;
T scale = b_scale_vec[block_idx];
uint8_t zp = (block_idx & 0x01) ? (b_zp_vec[block_idx/2] >> 4) : (b_zp_vec[block_idx/2] & 0x0f);
sum += AccumulateEightElements(value, scale, zp, a_data + k_id + (lane_id << 3));
}

@@ -133,29 +137,29 @@ bool TryMatMul4Bits(
int m,
int n,
int k,
int group_size,
int block_size,
cudaStream_t stream) {
if (n % BLOCKSIZEN != 0 || k % 8 != 0 || m > 1) {
return false;
}
dim3 blocks((n + BLOCKSIZEN - 1) / BLOCKSIZEN, m);
dim3 threads(32, 8);
int shared_mem_size = (sizeof(T) + 1) * ((k + group_size - 1) / group_size * 8);
int shared_mem_size = (sizeof(T) + 1) * ((k + block_size - 1) / block_size * 8);

if (16 == group_size) {
if (16 == block_size) {
MatMulFloatInt4Kernel<T, 16><<<blocks, threads, shared_mem_size, stream>>>(
output, a_data, b_data_quant, scales_data, zero_points, m, n, k);
} else if (32 == group_size) {
} else if (32 == block_size) {
MatMulFloatInt4Kernel<T, 32><<<blocks, threads, shared_mem_size, stream>>>(
output, a_data, b_data_quant, scales_data, zero_points, m, n, k);
} else if (64 == group_size) {
} else if (64 == block_size) {
MatMulFloatInt4Kernel<T, 64><<<blocks, threads, shared_mem_size, stream>>>(
output, a_data, b_data_quant, scales_data, zero_points, m, n, k);
} else if (128 == group_size) {
} else if (128 == block_size) {
MatMulFloatInt4Kernel<T, 128><<<blocks, threads, shared_mem_size, stream>>>(
output, a_data, b_data_quant, scales_data, zero_points, m, n, k);
} else {
ORT_THROW("block size ", group_size, " is not supported");
ORT_THROW("block size ", block_size, " is not supported");
}

return true;
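Note: a worked example of the shared-memory budget computed in TryMatMul4Bits above — one scale of sizeof(T) bytes plus one zero-point byte per quantization block, for the BLOCKSIZEN (= 8) columns each thread block handles; the 2-byte element size (half precision) is an assumption for illustration:

// Hypothetical arithmetic for shared_mem_size = (sizeof(T) + 1) * blocks_per_K * 8.
#include <cstdio>

int main() {
  const int BLOCKSIZEN = 8;
  const int k = 4096;
  const int block_size = 32;
  const int sizeof_T = 2;  // e.g. half precision; assumed for this example

  int blocks_per_K = (k + block_size - 1) / block_size;                   // 128
  int shared_mem_size = (sizeof_T + 1) * blocks_per_K * BLOCKSIZEN;       // 3072 bytes
  std::printf("blocks_per_K = %d, shared_mem_size = %d bytes\n",
              blocks_per_K, shared_mem_size);
  return 0;
}
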
8 changes: 2 additions & 6 deletions onnxruntime/core/graph/contrib_ops/contrib_defs.cc
@@ -3026,22 +3026,18 @@ struct Blob {
- shape: [n_cols, n_blocks_per_col, blob_size]
- type: uint8_t
scales:
- shape: [n_cols, n_blocks_per_col]
- shape: [n_cols * n_blocks_per_col]
- type: float32 or float16. Same as input A
zero_points:
- shape: [n_cols, (n_blocks_per_col * 4 + 4) / 8] for nbits <= 4 and [n_cols, n_blocks_per_col] for nbits > 4
- shape: [(n_cols * n_blocks_per_col + 1) / 2] for nbits <= 4 and [n_cols * n_blocks_per_col] for nbits > 4
- type: uint8_t
)DOC";

ONNX_CONTRIB_OPERATOR_SCHEMA(MatMulNBits)
.SetDomain(kMSDomain)
.SinceVersion(1)
.SetDoc(MatMulWithCompressWeight_ver1_doc)
.SetDoc(MatMulNBits_ver1_doc)
.Attr("K", "size of each input feature", AttributeProto::INT)
.Attr("N", "size of each output feature", AttributeProto::INT)
.Attr("bits", "number of bits used for weight quantization (default 4)", AttributeProto::INT)
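Note: a quick sketch of the MatMulNBits buffer sizes for a concrete shape, assuming bits = 4, blob_size = block_size * bits / 8, and n_blocks_per_col = ceil(K / block_size); these formulas are inferred from the surrounding description and are assumptions rather than quotes from the schema:

// Hypothetical size calculation for the packed operands described above.
#include <cstdio>

int main() {
  const int K = 4096, N = 4096;        // N corresponds to n_cols
  const int block_size = 32, bits = 4;

  int n_blocks_per_col = (K + block_size - 1) / block_size;    // 128
  int blob_size = block_size * bits / 8;                       // 16 bytes per block
  long long b_bytes = 1LL * N * n_blocks_per_col * blob_size;  // packed weight B
  long long scales_elems = 1LL * N * n_blocks_per_col;         // flattened scales
  long long zp_bytes = (1LL * N * n_blocks_per_col + 1) / 2;   // two 4-bit zps per byte

  std::printf("B: %lld bytes, scales: %lld elements, zero_points: %lld bytes\n",
              b_bytes, scales_elems, zp_bytes);
  return 0;
}
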
2 changes: 1 addition & 1 deletion onnxruntime/python/onnxruntime_pybind_quant.cc
@@ -38,7 +38,7 @@ void QuantizeMatMulNBitsBlockwise(
py::array_t<uint8_t> dst, // shape: [ N, block_per_K, block_blob_size ]
py::array_t<T> src, // shape: [K, N]
py::array_t<T> scale, // shape: [N, block_per_K]
py::array_t<uint8_t> zero_points, // shape: [N, block_per_K]
py::array_t<uint8_t> zero_points, // shape: [N, block_per_K] if bits > 4 else [N, (block_per_K + 1) / 2]
int32_t block_size,
int32_t N,
int32_t K,
@@ -50,7 +50,11 @@ struct MatrixFloatInt4Params :
template <typename T>
class MatrixFloatInt4 : public IKernelExplorer {
public:
MatrixFloatInt4(DeviceArray& output, DeviceArray& a, DeviceArray& b, DeviceArray& scales, int m, int n, int k) {
MatrixFloatInt4(DeviceArray& output,
DeviceArray& a,
DeviceArray& b,
DeviceArray& scales,
int m, int n, int k) {
params_.tuning_ctx = TuningContext();
params_.stream = Stream();
params_.output_ = static_cast<T*>(output.ptr());
@@ -63,6 +67,15 @@ class MatrixFloatInt4 : public IKernelExplorer {
params_.k_ = k;
}

MatrixFloatInt4(DeviceArray& output,
DeviceArray& a,
DeviceArray& b,
DeviceArray& scales,
DeviceArray& zeropoints,
int m, int n, int k) : MatrixFloatInt4(output, a, b, scales, m, n, k) {
params_.zero_points_ = static_cast<uint8_t*>(zeropoints.ptr());
}

void Run() override {
contrib::cuda::TryMatMul4Bits<T>(
params_.output_,
@@ -83,11 +96,12 @@ class MatrixFloatInt4 : public IKernelExplorer {
ParamsT params_{};
};

#define REGISTER_OP(name, type) \
py::class_<name<type>>(m, #name "_" #type) \
.def(py::init<DeviceArray&, DeviceArray&, DeviceArray&, DeviceArray&, int, int, int>()) \
.def("SetRepeats", &name<type>::SetRepeats) \
.def("Profile", &name<type>::Profile) \
#define REGISTER_OP(name, type) \
py::class_<name<type>>(m, #name "_" #type) \
.def(py::init<DeviceArray&, DeviceArray&, DeviceArray&, DeviceArray&, int, int, int>()) \
.def(py::init<DeviceArray&, DeviceArray&, DeviceArray&, DeviceArray&, DeviceArray&, int, int, int>()) \
.def("SetRepeats", &name<type>::SetRepeats) \
.def("Profile", &name<type>::Profile) \
.def("Run", &name<type>::Run);

KE_REGISTER(m) {