Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Block-wise 4b quantization matmul operator change #18172

Merged
Merged 16 commits on Nov 3, 2023
129 changes: 0 additions & 129 deletions onnxruntime/contrib_ops/cpu/quantization/blockwise_quant_block.h

This file was deleted.

174 changes: 0 additions & 174 deletions onnxruntime/contrib_ops/cpu/quantization/dequantize_blockwise.h

This file was deleted.

25 changes: 14 additions & 11 deletions onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@
#include "core/framework/op_kernel.h"
#include "core/providers/cpu/math/matmul_helper.h"
#include "core/providers/common.h"
#include "dequantize_blockwise.h"
#include "core/mlas/inc/mlas.h"
#include "core/mlas/inc/mlas_q4.h"

namespace onnxruntime {
namespace contrib {
Expand All @@ -18,6 +17,7 @@
ORT_ENFORCE(Status::OK() == info.GetAttr<int64_t>("N", &N_));
ORT_ENFORCE(Status::OK() == info.GetAttr<int64_t>("block_size", &block_size_));
ORT_ENFORCE(Status::OK() == info.GetAttr<int64_t>("bits", &nbits_));
ORT_ENFORCE(nbits_ == 4, "Only 4b quantization is supported for MatMulNBits op, additional bits support is planed.");

Check warning on line 20 in onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc:20: Lines should be <= 120 characters long [whitespace/line_length] [2]
chenfucn marked this conversation as resolved.
Show resolved Hide resolved
}

Status Compute(OpKernelContext* context) const override;
Expand All @@ -27,6 +27,7 @@
int64_t N_;
int64_t block_size_;
int64_t nbits_;
bool column_wise_quant_{true};
};

Status MatMulNBits::Compute(OpKernelContext* ctx) const {
Expand All @@ -46,15 +47,17 @@
auto status = ctx->GetTempSpaceAllocator(&allocator);
ORT_RETURN_IF_ERROR(status);
auto tmp_b_data_ptr = IAllocator::MakeUniquePtr<float>(allocator, SafeInt<size_t>(K_) * N_);
DequantizeBlockwise<float>(tmp_b_data_ptr.get(),
b_data,
scales_data,
zero_points_data,
static_cast<int32_t>(block_size_),
static_cast<int32_t>(nbits_),
static_cast<int32_t>(N_),
static_cast<int32_t>(K_),
thread_pool);

chenfucn marked this conversation as resolved.
Show resolved Hide resolved
MlasDequantizeBlockwise<float>(
tmp_b_data_ptr.get(), // dequantized output
b_data, // quantized input
scales_data, // quantization scales
zero_points_data, // quantization zero points
static_cast<int32_t>(block_size_), // quantization block size
column_wise_quant_, // columnwise quantization or row-wise
static_cast<int32_t>(K_), // number of rows in quantized input
static_cast<int32_t>(N_), // number of columns in quantized input
thread_pool);

#if 0 // for debug
auto tm_b_data_ptr_trans = IAllocator::MakeUniquePtr<float>(allocator, SafeInt<size_t>(K_) * N_);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,15 +61,19 @@ __global__ void Dequantize4BitsKernel(
const T* scale_data,
const uint8_t* zero_points,
int block_size,
int blocks_per_K,
int blocks_per_threadblock,
int shift) {
int block_id = blockIdx.x * blocks_per_threadblock + ((threadIdx.x * 8) >> shift);
int n_idx = block_id / blocks_per_K;
int kb_idx = block_id % blocks_per_K;
int element_offset = block_id * block_size + ((threadIdx.x * 8) & ((1 << shift) - 1));
uint32_t quant_value = *(reinterpret_cast<const uint32_t*>(quant_data + element_offset / 2));
T scale = *(scale_data + block_id);
uint8_t zp = 8;
if (zero_points) {
zp = (block_id & 0x01) ? (zero_points[block_id / 2] >> 4) : (zero_points[block_id / 2] & 0x0f);
zp = zero_points[n_idx * ((blocks_per_K + 1)/2) + kb_idx / 2];
zp = (kb_idx & 0x01) ? (zp >> 4) : (zp & 0x0f);
}

output = output + element_offset;
Expand Down Expand Up @@ -100,6 +104,7 @@ Status Dequantize4Bits(
scales_data,
zero_points,
block_size,
blocks_per_K,
blocks_per_threadblock,
shift);

Expand Down
Loading
Loading