Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Block-wise 4b quantization matmul operator change #18172

Merged
merged 16 commits into from
Nov 3, 2023
129 changes: 0 additions & 129 deletions onnxruntime/contrib_ops/cpu/quantization/blockwise_quant_block.h

This file was deleted.

174 changes: 0 additions & 174 deletions onnxruntime/contrib_ops/cpu/quantization/dequantize_blockwise.h

This file was deleted.

28 changes: 17 additions & 11 deletions onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@
#include "core/framework/op_kernel.h"
#include "core/providers/cpu/math/matmul_helper.h"
#include "core/providers/common.h"
#include "dequantize_blockwise.h"
#include "core/mlas/inc/mlas.h"
#include "core/mlas/inc/mlas_q4.h"

namespace onnxruntime {
namespace contrib {
Expand All @@ -18,6 +17,9 @@ class MatMulNBits final : public OpKernel {
ORT_ENFORCE(Status::OK() == info.GetAttr<int64_t>("N", &N_));
ORT_ENFORCE(Status::OK() == info.GetAttr<int64_t>("block_size", &block_size_));
ORT_ENFORCE(Status::OK() == info.GetAttr<int64_t>("bits", &nbits_));
ORT_ENFORCE(nbits_ == 4,
"Only 4b quantization is supported for MatMulNBits op,"
" additional bits support is planned.");
}

Status Compute(OpKernelContext* context) const override;
Expand All @@ -27,6 +29,7 @@ class MatMulNBits final : public OpKernel {
int64_t N_;
int64_t block_size_;
int64_t nbits_;
bool column_wise_quant_{true};
};

Status MatMulNBits::Compute(OpKernelContext* ctx) const {
Expand All @@ -46,15 +49,18 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const {
auto status = ctx->GetTempSpaceAllocator(&allocator);
ORT_RETURN_IF_ERROR(status);
auto tmp_b_data_ptr = IAllocator::MakeUniquePtr<float>(allocator, SafeInt<size_t>(K_) * N_);
DequantizeBlockwise<float>(tmp_b_data_ptr.get(),
b_data,
scales_data,
zero_points_data,
static_cast<int32_t>(block_size_),
static_cast<int32_t>(nbits_),
static_cast<int32_t>(N_),
static_cast<int32_t>(K_),
thread_pool);

chenfucn marked this conversation as resolved.
Show resolved Hide resolved
// Dequantize B into the temporary float buffer; only 4-bit quantization is supported for now.
MlasDequantizeBlockwise<float, 4>(
tmp_b_data_ptr.get(), // dequantized output
b_data, // quantized input
scales_data, // quantization scales
zero_points_data, // quantization zero points
static_cast<int32_t>(block_size_), // quantization block size
column_wise_quant_, // columnwise quantization or row-wise
static_cast<int32_t>(K_), // number of rows in quantized input
static_cast<int32_t>(N_), // number of columns in quantized input
thread_pool);

#if 0 // for debug
auto tm_b_data_ptr_trans = IAllocator::MakeUniquePtr<float>(allocator, SafeInt<size_t>(K_) * N_);
Expand Down
Loading
Loading