Skip to content

Commit

Permalink
MLAS AArch64 quantized int4 Gemm kernel (#18031)
Browse files Browse the repository at this point in the history
- Implement MLAS function for quantized 4-bit int Gemm (Gemm with float A and quantized 4-bit int B) for ARM NEON. This is an initial implementation. Only the M=1 path (with M being number of rows of A and C) has any optimization attempted so far. More optimization to come in future PRs.

- Connect MatMulNBits contrib op to MLAS function.
  • Loading branch information
edgchen1 authored Nov 15, 2023
1 parent 586f06f commit 0a4d76d
Show file tree
Hide file tree
Showing 28 changed files with 1,744 additions and 205 deletions.
3 changes: 3 additions & 0 deletions cmake/onnxruntime_mlas.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ onnxruntime_add_static_library(onnxruntime_mlas
${MLAS_SRC_DIR}/qpostprocessor.cpp
${MLAS_SRC_DIR}/qlgavgpool.cpp
${MLAS_SRC_DIR}/qdwconv_kernelsize.cpp
${MLAS_SRC_DIR}/sqnbitgemm.cpp
)

if (NOT onnxruntime_ORT_MINIMAL_BUILD)
Expand Down Expand Up @@ -68,6 +69,7 @@ function(setup_mlas_source_for_windows)
${MLAS_SRC_DIR}/qgemm_kernel_neon.cpp
${MLAS_SRC_DIR}/qgemm_kernel_udot.cpp
${MLAS_SRC_DIR}/qgemm_kernel_sdot.cpp
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon.cpp
)

set(mlas_platform_preprocess_srcs
Expand Down Expand Up @@ -334,6 +336,7 @@ else()
${MLAS_SRC_DIR}/qgemm_kernel_neon.cpp
${MLAS_SRC_DIR}/qgemm_kernel_udot.cpp
${MLAS_SRC_DIR}/qgemm_kernel_sdot.cpp
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon.cpp
)
if (NOT APPLE)
set(mlas_platform_srcs
Expand Down
68 changes: 34 additions & 34 deletions include/onnxruntime/core/framework/op_node_proto_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,6 @@
#include "core/common/gsl.h"
#endif

#ifdef __has_attribute
#define ORT_HAVE_ATTRIBUTE(x) __has_attribute(x)
#else
#define ORT_HAVE_ATTRIBUTE(x) 0
#endif

#if ORT_HAVE_ATTRIBUTE(nodiscard)
#define MUST_USE_RESULT [[nodiscard]]
#elif defined(__clang__) && ORT_HAVE_ATTRIBUTE(warn_unused_result)
#define MUST_USE_RESULT __attribute__((warn_unused_result))
#else
#define MUST_USE_RESULT
#endif

class IMLOpKernel;

namespace onnxruntime {
Expand All @@ -43,14 +29,26 @@ class OpNodeProtoHelper {
Call this function for a required attribute or when a default value for an optional attribute is specified in the op schema
*/
template <typename T>
MUST_USE_RESULT Status GetAttr(const std::string& name, T* value) const;
Status GetAttr(const std::string& name, T* value) const;

/**
Get a single attribute
Call this function for a required attribute or when a default value for an optional attribute is specified in the op schema
Throws if an attribute with the specified type doesn't exist
*/
template <typename T>
[[nodiscard]] T GetAttr(const std::string& name) const {
T value;
ORT_THROW_IF_ERROR(GetAttr(name, &value));
return value;
}

/**
Get a single attribute
Call this function only when a default value for an optional attribute isn't specified in the op schema
*/
template <typename T>
T GetAttrOrDefault(const std::string& name, const T& default_value) const {
[[nodiscard]] T GetAttrOrDefault(const std::string& name, const T& default_value) const {
T tmp;
return GetAttr<T>(name, &tmp).IsOK() ? tmp : default_value;
}
Expand All @@ -70,7 +68,8 @@ class OpNodeProtoHelper {
Call this function only when a default value for an optional attribute isn't specified in the op schema
*/
template <typename T>
MUST_USE_RESULT std::vector<T> GetAttrsOrDefault(const std::string& name, const std::vector<T>& default_value = std::vector<T>{}) const {
[[nodiscard]] std::vector<T> GetAttrsOrDefault(const std::string& name,
const std::vector<T>& default_value = {}) const {
std::vector<T> tmp;
return GetAttrs<T>(name, tmp).IsOK() ? tmp : default_value;
}
Expand All @@ -87,11 +86,12 @@ class OpNodeProtoHelper {
/// <param name="values">Attribute data in a span, out parameter</param>
/// <returns>Status</returns>
template <typename T>
MUST_USE_RESULT Status GetAttrsAsSpan(const std::string& name, gsl::span<const T>& values) const;
Status GetAttrsAsSpan(const std::string& name, gsl::span<const T>& values) const;

MUST_USE_RESULT Status GetAttrs(const std::string& name, TensorShapeVector& out) const;
Status GetAttrs(const std::string& name, TensorShapeVector& out) const;

MUST_USE_RESULT TensorShapeVector GetAttrsOrDefault(const std::string& name, const TensorShapeVector& default_value = TensorShapeVector{}) const {
[[nodiscard]] TensorShapeVector GetAttrsOrDefault(const std::string& name,
const TensorShapeVector& default_value = {}) const {
TensorShapeVector tmp;
return GetAttrs(name, tmp).IsOK() ? tmp : default_value;
}
Expand All @@ -100,43 +100,43 @@ class OpNodeProtoHelper {
Get repeated attributes
*/
template <typename T>
MUST_USE_RESULT Status GetAttrs(const std::string& name, std::vector<T>& values) const;
Status GetAttrs(const std::string& name, std::vector<T>& values) const;

template <typename T>
MUST_USE_RESULT Status GetAttrs(const std::string& name, gsl::span<T> values) const;
Status GetAttrs(const std::string& name, gsl::span<T> values) const;

MUST_USE_RESULT Status GetAttrsStringRefs(const std::string& name,
std::vector<std::reference_wrapper<const std::string>>& refs) const;
Status GetAttrsStringRefs(const std::string& name,
std::vector<std::reference_wrapper<const std::string>>& refs) const;

uint32_t GetPrimitiveAttrElementCount(ONNX_NAMESPACE::AttributeProto_AttributeType type,
const std::string& name) const noexcept;
[[nodiscard]] uint32_t GetPrimitiveAttrElementCount(ONNX_NAMESPACE::AttributeProto_AttributeType type,
const std::string& name) const noexcept;

bool HasPrimitiveAttribute(ONNX_NAMESPACE::AttributeProto_AttributeType type,
const std::string& name) const noexcept;
[[nodiscard]] bool HasPrimitiveAttribute(ONNX_NAMESPACE::AttributeProto_AttributeType type,
const std::string& name) const noexcept;

uint32_t GetInputCount() const {
[[nodiscard]] uint32_t GetInputCount() const {
return gsl::narrow_cast<uint32_t>(impl_->getNumInputs());
}

uint32_t GetOutputCount() const {
[[nodiscard]] uint32_t GetOutputCount() const {
return gsl::narrow_cast<uint32_t>(impl_->getNumOutputs());
}

const ONNX_NAMESPACE::TypeProto* GetInputType(size_t index) const {
[[nodiscard]] const ONNX_NAMESPACE::TypeProto* GetInputType(size_t index) const {
return impl_->getInputType(index);
}

const ONNX_NAMESPACE::TypeProto* GetOutputType(size_t index) const {
[[nodiscard]] const ONNX_NAMESPACE::TypeProto* GetOutputType(size_t index) const {
// Work around lack of a const method from the onnx InferenceContext interface
return const_cast<Impl_t*>(impl_)->getOutputType(index);
}

// Try to query an attribute, returning nullptr if it doesn't exist
const ONNX_NAMESPACE::AttributeProto* TryGetAttribute(const std::string& name) const {
[[nodiscard]] const ONNX_NAMESPACE::AttributeProto* TryGetAttribute(const std::string& name) const {
return impl_->getAttribute(name);
}

const ONNX_NAMESPACE::AttributeProto* GetAttribute(const std::string& name) const {
[[nodiscard]] const ONNX_NAMESPACE::AttributeProto* GetAttribute(const std::string& name) const {
const ONNX_NAMESPACE::AttributeProto* attr = TryGetAttribute(name);
ORT_ENFORCE(attr != nullptr);
return attr;
Expand Down
111 changes: 71 additions & 40 deletions onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
Original file line number Diff line number Diff line change
@@ -1,35 +1,38 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/common/narrow.h"
#include "core/common/safeint.h"
#include "core/framework/op_kernel.h"
#include "core/mlas/inc/mlas.h"
#include "core/mlas/inc/mlas_qnbit.h"
#include "core/mlas/inc/mlas_q4.h"
#include "core/providers/cpu/math/matmul_helper.h"
#include "core/providers/common.h"
#include "core/mlas/inc/mlas_q4.h"

namespace onnxruntime {
namespace contrib {

class MatMulNBits final : public OpKernel {
public:
MatMulNBits(const OpKernelInfo& info) : OpKernel(info) {
ORT_ENFORCE(Status::OK() == info.GetAttr<int64_t>("K", &K_));
ORT_ENFORCE(Status::OK() == info.GetAttr<int64_t>("N", &N_));
ORT_ENFORCE(Status::OK() == info.GetAttr<int64_t>("block_size", &block_size_));
ORT_ENFORCE(Status::OK() == info.GetAttr<int64_t>("bits", &nbits_));
MatMulNBits(const OpKernelInfo& info)
: OpKernel(info),
K_{narrow<size_t>(info.GetAttr<int64_t>("K"))},
N_{narrow<size_t>(info.GetAttr<int64_t>("N"))},
block_size_{narrow<size_t>(info.GetAttr<int64_t>("block_size"))},
nbits_{narrow<size_t>(info.GetAttr<int64_t>("bits"))} {
ORT_ENFORCE(nbits_ == 4,
"Only 4b quantization is supported for MatMulNBits op,"
" additional bits support is planned.");
"Only 4b quantization is supported for MatMulNBits op, additional bits support is planned.");
}

Status Compute(OpKernelContext* context) const override;

private:
int64_t K_;
int64_t N_;
int64_t block_size_;
int64_t nbits_;
bool column_wise_quant_{true};
const size_t K_;
const size_t N_;
const size_t block_size_;
const size_t nbits_;
const bool column_wise_quant_{true};
};

Status MatMulNBits::Compute(OpKernelContext* ctx) const {
Expand All @@ -45,11 +48,60 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const {
const auto* scales_data = scales->Data<float>();
const auto* zero_points_data = zero_points == nullptr ? nullptr : zero_points->Data<uint8_t>();

TensorShape b_shape({static_cast<int64_t>(N_), static_cast<int64_t>(K_)});

MatMulComputeHelper helper;
ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b_shape, false, true));

Tensor* y = ctx->Output(0, helper.OutputShape());

// Bail out early if the output is going to be empty
if (y->Shape().Size() == 0)
return Status::OK();

auto* y_data = y->MutableData<float>();

const size_t batch_count = helper.OutputOffsets().size();
const size_t M = static_cast<size_t>(helper.M());
const size_t N = static_cast<size_t>(helper.N());
const size_t K = static_cast<size_t>(helper.K());
const size_t lda = helper.Lda(false);

if (MlasIsSQNBitGemmAvailable(nbits_, block_size_)) {
// number of bytes or elements between adjacent matrices
size_t b_data_matrix_stride_in_bytes, b_scale_matrix_stride, b_zero_point_matrix_stride_in_bytes;
MlasBlockwiseQuantizedBufferSizes(static_cast<int>(nbits_), static_cast<int>(block_size_), /* columnwise */ true,
static_cast<int>(K), static_cast<int>(N),
b_data_matrix_stride_in_bytes, b_scale_matrix_stride,
&b_zero_point_matrix_stride_in_bytes);

const size_t b_matrix_size = K * N;

InlinedVector<MLAS_SQNBIT_GEMM_DATA_PARAMS> data(batch_count);
for (size_t i = 0; i < batch_count; ++i) {
const size_t b_matrix_offset = helper.RightOffsets()[i] / b_matrix_size;

data[i].A = a_data + helper.LeftOffsets()[i];
data[i].lda = lda;
data[i].QuantBData = b_data + b_matrix_offset * b_data_matrix_stride_in_bytes;
data[i].QuantBScale = scales_data + b_matrix_offset * b_scale_matrix_stride;
data[i].QuantBZeroPoint = zero_points_data != nullptr
? zero_points_data + b_matrix_offset * b_zero_point_matrix_stride_in_bytes
: nullptr;
data[i].C = y_data + helper.OutputOffsets()[i];
data[i].ldc = N;
}

MlasSQNBitGemmBatch(M, N, K, batch_count, nbits_, block_size_, data.data(), thread_pool);

return Status::OK();
}

const size_t ldb = helper.Ldb(true);

AllocatorPtr allocator;
auto status = ctx->GetTempSpaceAllocator(&allocator);
ORT_RETURN_IF_ERROR(status);
ORT_RETURN_IF_ERROR(ctx->GetTempSpaceAllocator(&allocator));
auto tmp_b_data_ptr = IAllocator::MakeUniquePtr<float>(allocator, SafeInt<size_t>(K_) * N_);

// dequantize b, only 4b quantization is supported for now
MlasDequantizeBlockwise<float, 4>(
tmp_b_data_ptr.get(), // dequantized output
Expand All @@ -67,29 +119,8 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const {
MlasTranspose(tmp_b_data_ptr.get(), tm_b_data_ptr_trans.get(), N_, K_);
#endif

TensorShape b_shape({N_, K_});

MatMulComputeHelper helper;
ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b_shape, false, true));

Tensor* y = ctx->Output(0, helper.OutputShape());

// Bail out early if the output is going to be empty
if (y->Shape().Size() == 0)
return Status::OK();

auto* y_data = y->MutableData<float>();

const size_t max_len = helper.OutputOffsets().size();
const size_t M = static_cast<size_t>(helper.M());
const size_t N = static_cast<size_t>(helper.N());
const size_t K = static_cast<size_t>(helper.K());
const size_t lda = helper.Lda(false);
const size_t ldb = helper.Ldb(true);

// TODO: implement with native kernel
std::vector<MLAS_SGEMM_DATA_PARAMS> data(max_len);
for (size_t i = 0; i < max_len; i++) {
std::vector<MLAS_SGEMM_DATA_PARAMS> data(batch_count);
for (size_t i = 0; i < batch_count; i++) {
data[i].BIsPacked = false;
data[i].A = a_data + helper.LeftOffsets()[i];
data[i].lda = lda;
Expand All @@ -101,7 +132,7 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const {
data[i].beta = 0.0f;
}
MlasGemmBatch(CblasNoTrans, CblasTrans,
M, N, K, data.data(), max_len, thread_pool);
M, N, K, data.data(), batch_count, thread_pool);

return Status::OK();
}
Expand Down
Loading

0 comments on commit 0a4d76d

Please sign in to comment.