Skip to content

Commit

Permalink
[MLAS AArch64] SQNBitGemm optimization (#19272)
Browse files Browse the repository at this point in the history
1. Add support for packing 4-bit values 32 at a time for CompInt8. 32 4-bit values can fit into a single 128-bit NEON register. For CompInt8, this enables a more efficient path for block sizes greater than or equal to 32. CompFp32 seems to do better with handling 16 elements at a time, so this 32-value packing is not used there.
Pack differently based on compute type. Adjust APIs to handle this.

2. Introduce template argument for whether to handle zero-point. This results in less code for the no zero-point (symmetric) case. However, there is a binary size increase due to the additional template instantiations.
  • Loading branch information
edgchen1 authored Jan 30, 2024
1 parent 04afe77 commit c379a89
Show file tree
Hide file tree
Showing 7 changed files with 558 additions and 232 deletions.
130 changes: 93 additions & 37 deletions onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,47 @@
#include "core/mlas/inc/mlas_q4.h"
#include "core/providers/cpu/math/matmul_helper.h"
#include "core/providers/common.h"

#ifdef ORT_NEURAL_SPEED
#include "contrib_ops/cpu/quantization/neural_speed_gemm.h"
#endif

namespace onnxruntime {
namespace contrib {

namespace {
int64_t GetAccuracyLevel(size_t nbits, size_t block_size, int64_t accuracy_level_attr) {
const auto accuracy_level = std::clamp(accuracy_level_attr,
static_cast<int64_t>(CompMostAccurate),
static_cast<int64_t>(CompLeastAccurate));

#if defined(ORT_NEURAL_SPEED)

ORT_UNUSED_PARAMETER(nbits);
ORT_UNUSED_PARAMETER(block_size);

// Neural Speed APIs already expect a minimum accuracy level so just use the given value.
return accuracy_level;

#else // defined(ORT_NEURAL_SPEED)

// Find a supported accuracy level that is not less accurate than the one given.
// CompMostAccurate is always supported with the fallback implementation.
// Note: A higher numeric accuracy level value means lower accuracy, so the comparison order is reversed.
int64_t effective_accuracy_level = accuracy_level;
for (; effective_accuracy_level > CompMostAccurate; --effective_accuracy_level) {
const auto compute_type = static_cast<MLAS_SQNBIT_GEMM_COMPUTE_TYPE>(effective_accuracy_level);
if (MlasIsSQNBitGemmAvailable(nbits, block_size, compute_type)) {
break;
}
}

return effective_accuracy_level;

#endif // defined(ORT_NEURAL_SPEED)
}
} // namespace

class MatMulNBits final : public OpKernel {
public:
MatMulNBits(const OpKernelInfo& info)
Expand All @@ -24,7 +58,7 @@ class MatMulNBits final : public OpKernel {
N_{narrow<size_t>(info.GetAttr<int64_t>("N"))},
block_size_{narrow<size_t>(info.GetAttr<int64_t>("block_size"))},
nbits_{narrow<size_t>(info.GetAttr<int64_t>("bits"))},
accuracy_level_{info.GetAttr<int64_t>("accuracy_level")} {
accuracy_level_{GetAccuracyLevel(nbits_, block_size_, info.GetAttr<int64_t>("accuracy_level"))} {
ORT_ENFORCE(nbits_ == 4,
"Only 4b quantization is supported for MatMulNBits op, additional bits support is planned.");
#ifdef ORT_NEURAL_SPEED
Expand Down Expand Up @@ -58,17 +92,22 @@ class MatMulNBits final : public OpKernel {
const bool column_wise_quant_{true};
IAllocatorUniquePtr<void> packed_b_;
size_t packed_b_size_{0};
#ifdef ORT_NEURAL_SPEED

#if defined(ORT_NEURAL_SPEED)

bool is_asym_{false};
bool all_constant_{false};
#endif

#endif // defined(ORT_NEURAL_SPEED)
};

Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ AllocatorPtr alloc,
/*out*/ bool& is_packed,
/*out*/ PrePackedWeights* prepacked_weights) {
is_packed = false;
#ifdef ORT_NEURAL_SPEED

#if defined(ORT_NEURAL_SPEED)

if (!all_constant_) {
return Status::OK();
}
Expand Down Expand Up @@ -116,11 +155,17 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat
#else // defined(ORT_NEURAL_SPEED)

if (input_idx == 1) {
packed_b_size_ = MlasSQNBitGemmPackQuantBDataSize(N_, K_, nbits_, block_size_);
if (packed_b_size_ == 0) return Status::OK();
const auto compute_type = static_cast<MLAS_SQNBIT_GEMM_COMPUTE_TYPE>(accuracy_level_);
if (!MlasIsSQNBitGemmAvailable(nbits_, block_size_, compute_type)) {
return Status::OK();
}
packed_b_size_ = MlasSQNBitGemmPackQuantBDataSize(N_, K_, nbits_, block_size_, compute_type);
if (packed_b_size_ == 0) {
return Status::OK();
}
auto qptr = tensor.DataRaw();
packed_b_ = IAllocator::MakeUniquePtr<void>(alloc, packed_b_size_, true);
MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, qptr, packed_b_.get());
MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type, qptr, packed_b_.get());
if (prepacked_weights) {
prepacked_weights->buffers_.push_back(std::move(packed_b_));
prepacked_weights->buffer_sizes_.push_back(packed_b_size_);
Expand All @@ -136,7 +181,9 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat
Status MatMulNBits::UseSharedPrePackedBuffers(std::vector<BufferUniquePtr>& prepacked_buffers, int input_idx,
/*out*/ bool& used_shared_buffers) {
used_shared_buffers = false;
#ifdef ORT_NEURAL_SPEED

#if defined(ORT_NEURAL_SPEED)

// Pack three tensors into one buffer
if (input_idx == 1) {
used_shared_buffers = true;
Expand All @@ -159,6 +206,7 @@ Status MatMulNBits::UseSharedPrePackedBuffers(std::vector<BufferUniquePtr>& prep
}

#endif // defined(ORT_NEURAL_SPEED)

return Status::OK();
}

Expand All @@ -167,8 +215,10 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const {

const Tensor* a = ctx->Input<Tensor>(0);
const auto* a_data = a->Data<float>();
#ifdef ORT_NEURAL_SPEED
if (packed_b_.get()) {

#if defined(ORT_NEURAL_SPEED)

if (packed_b_) {
TensorShape b_shape({static_cast<int64_t>(N_), static_cast<int64_t>(K_)});

MatMulComputeHelper helper;
Expand Down Expand Up @@ -234,37 +284,43 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const {
const bool has_single_b_matrix = std::all_of(helper.RightOffsets().begin(), helper.RightOffsets().end(),
[](size_t offset) { return offset == 0; });

if (has_single_b_matrix && packed_b_) {
for (int64_t accuracy_level = accuracy_level_;
accuracy_level >= static_cast<int64_t>(CompMostAccurate);
--accuracy_level) {
const auto compute_type = static_cast<MLAS_SQNBIT_GEMM_COMPUTE_TYPE>(accuracy_level);
if (MlasIsSQNBitGemmAvailable(M, N, K, nbits_, block_size_, compute_type)) {
IAllocatorUniquePtr<std::byte> workspace{};
if (const size_t workspace_size = MlasSQNBitGemmBatchWorkspaceSize(M, N, K, batch_count,
nbits_, block_size_, compute_type);
workspace_size > 0) {
AllocatorPtr allocator;
ORT_RETURN_IF_ERROR(ctx->GetTempSpaceAllocator(&allocator));
workspace = IAllocator::MakeUniquePtr<std::byte>(allocator, workspace_size);
}
if (has_single_b_matrix) {
const auto compute_type = static_cast<MLAS_SQNBIT_GEMM_COMPUTE_TYPE>(accuracy_level_);

if (MlasIsSQNBitGemmAvailable(nbits_, block_size_, compute_type)) {
IAllocatorUniquePtr<std::byte> workspace{};
if (const size_t workspace_size = MlasSQNBitGemmBatchWorkspaceSize(M, N, K, batch_count,
nbits_, block_size_, compute_type);
workspace_size > 0) {
AllocatorPtr allocator;
ORT_RETURN_IF_ERROR(ctx->GetTempSpaceAllocator(&allocator));
workspace = IAllocator::MakeUniquePtr<std::byte>(allocator, workspace_size);
}

InlinedVector<MLAS_SQNBIT_GEMM_DATA_PARAMS> data(batch_count);
for (size_t i = 0; i < batch_count; ++i) {
data[i].A = a_data + helper.LeftOffsets()[i];
data[i].lda = lda;
data[i].QuantBData = packed_b_.get();
data[i].QuantBScale = scales_data;
data[i].QuantBZeroPoint = zero_points_data;
data[i].C = y_data + helper.OutputOffsets()[i];
data[i].ldc = N;
const void* b_data = [&]() -> const void* {
if (packed_b_) {
return packed_b_.get();
}

MlasSQNBitGemmBatch(M, N, K, batch_count, nbits_, block_size_, compute_type, data.data(), workspace.get(),
thread_pool);

return Status::OK();
const Tensor* b = ctx->Input<Tensor>(1);
return b->DataRaw();
}();

InlinedVector<MLAS_SQNBIT_GEMM_DATA_PARAMS> data(batch_count);
for (size_t i = 0; i < batch_count; ++i) {
data[i].A = a_data + helper.LeftOffsets()[i];
data[i].lda = lda;
data[i].QuantBData = b_data;
data[i].QuantBScale = scales_data;
data[i].QuantBZeroPoint = zero_points_data;
data[i].C = y_data + helper.OutputOffsets()[i];
data[i].ldc = N;
}

MlasSQNBitGemmBatch(M, N, K, batch_count, nbits_, block_size_, compute_type, data.data(), workspace.get(),
thread_pool);

return Status::OK();
}
}

Expand Down
16 changes: 6 additions & 10 deletions onnxruntime/core/mlas/inc/mlas_qnbit.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,7 @@ typedef enum {

CompMostAccurate = CompUndef,
CompLeastAccurate = CompInt8,
} MLAS_SQNBIT_COMPUTE_TYPE;

using MLAS_SQNBIT_GEMM_COMPUTE_TYPE = MLAS_SQNBIT_COMPUTE_TYPE; // TODO consolidate these
} MLAS_SQNBIT_GEMM_COMPUTE_TYPE;

/**
* @brief Data parameters for float/n-bit quantized int GEMM routine.
Expand Down Expand Up @@ -102,18 +100,12 @@ MlasSQNBitGemmBatch(
/**
* @brief Determines whether a float32/quantized n-bit int GEMM implementation is available on the current platform.
*
* @param[in] M row size of matrix A and C
* @param[in] N column size of matrix B and C
* @param[in] K column size of matrix A and row size of matrix B
* @param[in] BlkBitWidth quantized value bit width (e.g., 4 means 4 bit ints)
* @param[in] BlkLen number of quantized values per block
* @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values)
*/
bool MLASCALL
MlasIsSQNBitGemmAvailable(
size_t M,
size_t N,
size_t K,
size_t BlkBitWidth,
size_t BlkLen,
MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType
Expand Down Expand Up @@ -153,13 +145,15 @@ MlasSQNBitGemmBatchWorkspaceSize(
* @param[in] K column size of matrix A and row size of matrix B
* @param[in] BlkBitWidth quantized value bit width (e.g., 4 means 4 bit ints)
* @param[in] BlkLen number of quantized values per block
* @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values)
*/
size_t MLASCALL
MlasSQNBitGemmPackQuantBDataSize(
size_t N,
size_t K,
size_t BlkBitWidth,
size_t BlkLen
size_t BlkLen,
MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType
);

/**
Expand All @@ -169,6 +163,7 @@ MlasSQNBitGemmPackQuantBDataSize(
* @param[in] K column size of matrix A and row size of matrix B
* @param[in] BlkBitWidth quantized value bit width (e.g., 4 means 4 bit ints)
* @param[in] BlkLen number of quantized values per block
* @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values)
* @param[in] QuantBData quantized B data
* @param[out] PackedQuantBData packed quantized B data
* @param[in] ThreadPool optional thread pool to use
Expand All @@ -179,6 +174,7 @@ MlasSQNBitGemmPackQuantBData(
size_t K,
size_t BlkBitWidth,
size_t BlkLen,
MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType,
const void* QuantBData,
void* PackedQuantBData,
MLAS_THREADPOOL* ThreadPool = nullptr
Expand Down
Loading

0 comments on commit c379a89

Please sign in to comment.