Skip to content

Commit

Permalink
[ARM] MatMulNBits fp16 support - connect kernels (#22856)
Browse files Browse the repository at this point in the history
### Description
A breakdown PR of #22651



### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
  • Loading branch information
fajin-corp authored and guschmue committed Dec 2, 2024
1 parent 36ff99d commit db76bb4
Show file tree
Hide file tree
Showing 17 changed files with 850 additions and 121 deletions.
8 changes: 4 additions & 4 deletions cmake/onnxruntime_mlas.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,8 @@ function(setup_mlas_source_for_windows)
${MLAS_SRC_DIR}/qgemm_kernel_neon.cpp
${MLAS_SRC_DIR}/qgemm_kernel_udot.cpp
${MLAS_SRC_DIR}/qgemm_kernel_sdot.cpp
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon.h
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon.cpp
${MLAS_SRC_DIR}/qnbitgemm_kernel_neon.h
${MLAS_SRC_DIR}/qnbitgemm_kernel_neon.cpp
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_fp32.cpp
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp
${MLAS_SRC_DIR}/fp16_neon_common.cpp
Expand Down Expand Up @@ -363,8 +363,8 @@ else()
${MLAS_SRC_DIR}/qgemm_kernel_neon.cpp
${MLAS_SRC_DIR}/qgemm_kernel_udot.cpp
${MLAS_SRC_DIR}/qgemm_kernel_sdot.cpp
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon.h
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon.cpp
${MLAS_SRC_DIR}/qnbitgemm_kernel_neon.h
${MLAS_SRC_DIR}/qnbitgemm_kernel_neon.cpp
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_fp32.cpp
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp
)
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/mlas/lib/hqnbitgemm_kernel_neon_fp16.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ Module Name:

#include "fp16_common.h"
#include "qnbitgemm.h"
#include "sqnbitgemm_kernel_neon.h"
#include "qnbitgemm_kernel_neon.h"

namespace sqnbitgemm_neon
{
Expand Down
4 changes: 1 addition & 3 deletions onnxruntime/core/mlas/lib/platform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -531,6 +531,7 @@ Return Value:
this->SymmQgemmDispatch = &MlasSymmQgemmS8DispatchNeon;
this->ConvSymU8S8Dispatch = &MlasConvSymU8DispatchNeon;
this->ConvSymS8S8Dispatch = &MlasConvSymS8DispatchNeon;
this->QNBitGemmDispatch = &MlasSQNBitGemmDispatchNeon;

//
// Check if the processor supports ASIMD dot product instructions.
Expand Down Expand Up @@ -560,9 +561,6 @@ Return Value:
this->SymmQgemmDispatch = &MlasSymmQgemmS8DispatchSdot;
this->ConvSymU8S8Dispatch = &MlasConvSymU8DispatchDot;
this->ConvSymS8S8Dispatch = &MlasConvSymS8DispatchDot;

// MlasSQNBitGemmDispatchNeon has a dependency on dot product instructions
this->QNBitGemmDispatch = &MlasSQNBitGemmDispatchNeon;
}

#if defined(__linux__)
Expand Down
94 changes: 91 additions & 3 deletions onnxruntime/core/mlas/lib/qnbitgemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,12 @@ MlasIsQNBitGemmAvailable(
switch (Variant) {
case SQNBitGemmVariant_BitWidth4_CompFp32: {
return Dispatch->SQ4BitGemmM1Kernel_CompFp32 != nullptr &&
Dispatch->Q4BitBlkDequantBForSgemm_CompFp32 != nullptr;
Dispatch->SQ4BitBlkDequantBForSgemm_CompFp32 != nullptr;
}
case HQNBitGemmVariant_BitWidth4_CompFp16: {
return Dispatch->HQ4BitGemmPackQuantBData != nullptr &&
Dispatch->HQ4BitGemmKernel_CompFp16 != nullptr &&
Dispatch->HQ4BitBlkDequantBForHgemm_CompFp16 != nullptr;
}
case SQNBitGemmVariant_BitWidth4_CompInt8: { // SQ4BitGemmKernel_BlkSum_CompInt8
return
Expand Down Expand Up @@ -253,6 +258,16 @@ MlasQNBitGemmPackQuantBData(
packed_quant_b,
ThreadPool
);
} else if (ComputeType == HQNBIT_CompFp16 && Dispatch->HQ4BitGemmPackQuantBData != nullptr) {
Dispatch->HQ4BitGemmPackQuantBData(
N,
K,
BlkLen,
ComputeType,
static_cast<const std::byte*>(QuantBData),
static_cast<std::byte*>(PackedQuantBDataAndOrBlkSumWorkspace),
ThreadPool
);
} else if (Dispatch->SQ4BitGemmPackQuantBData != nullptr) {
// TODO: these assertions are true if called from matmul_nbits kernel but not from mlas tests.
//assert(QuantBScale == nullptr);
Expand Down Expand Up @@ -387,7 +402,7 @@ SQ4BitGemm_CompFp32(
float* c_blk = C + n;
const float* bias = (Bias == nullptr) ? nullptr : Bias + n;

GetMlasPlatform().QNBitGemmDispatch->Q4BitBlkDequantBForSgemm_CompFp32(
GetMlasPlatform().QNBitGemmDispatch->SQ4BitBlkDequantBForSgemm_CompFp32(
BlkLen,
dequant_b, b_col, b_col_scale, b_col_zp, CountN, K, k_blks
);
Expand Down Expand Up @@ -419,6 +434,79 @@ SQ4BitGemm_CompFp32(
}
}

void
HQ4BitGemm_CompFp16(
const size_t BlkLen,
const size_t K,
const MLAS_QNBIT_GEMM_DATA_PARAMS<MLAS_FP16>* const DataParams,
void* const PerGemmWorkspace,
const size_t RangeStartM,
const size_t RangeCountM,
const size_t RangeStartN,
const size_t RangeCountN
)
{
constexpr size_t BlkBitWidth = 4;
MLAS_UNREFERENCED_PARAMETER(PerGemmWorkspace);

const size_t lda = DataParams->lda;
const size_t ldc = DataParams->ldc;
const size_t k_blk_num = MlasDivRoundup(K, BlkLen);
const size_t qldb = k_blk_num * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen);
const size_t ldb = k_blk_num * BlkLen;
const size_t k_zp_bytes = MlasQNBitZeroPointsForBlksSizeInBytes<BlkBitWidth>(k_blk_num);

const MLAS_FP16* A = DataParams->A + RangeStartM * lda;
MLAS_FP16* C = DataParams->C + RangeStartM * ldc + RangeStartN;
const std::byte* QuantBData = static_cast<const std::byte*>(DataParams->PackedQuantBData) + RangeStartN * qldb;
const MLAS_FP16* QuantBScale = DataParams->QuantBScale + RangeStartN * k_blk_num;
const std::byte* QuantBZeroPoint =
(DataParams->QuantBZeroPoint == nullptr)
? nullptr
: static_cast<const std::byte*>(DataParams->QuantBZeroPoint) + RangeStartN * k_zp_bytes;
const MLAS_FP16* Bias = (DataParams->Bias == nullptr) ? nullptr : DataParams->Bias;

// 32N is the sweet spot of cache utilization. It is machine dependent though.
constexpr size_t StrideM = 2;
constexpr size_t StrideN = 32;

// TODO(fajin): move allocation up to the op.
size_t bufsize = ldb * StrideN * sizeof(MLAS_FP16);
MlasThreadedBufAlloc(bufsize);
auto* dequant_b = reinterpret_cast<MLAS_FP16*>(ThreadedBufHolder.get());

for (size_t n = 0, countN; n < RangeCountN; n += countN) {
countN = std::min(StrideN, RangeCountN - n);
GetMlasPlatform().QNBitGemmDispatch->HQ4BitBlkDequantBForHgemm_CompFp16(
BlkLen, dequant_b, QuantBData, QuantBScale, QuantBZeroPoint, countN, K, k_blk_num
);

const MLAS_FP16* a = A;
MLAS_FP16* c = C;
for (size_t m = 0, countM; m < RangeCountM; m += countM) {
countM = std::min(StrideM, RangeCountM - m);
GetMlasPlatform().QNBitGemmDispatch->HQ4BitGemmKernel_CompFp16(
a, dequant_b, Bias, c, countM, countN, K, lda, ldb, ldc
);

if (DataParams->PostProcessor != nullptr) {
DataParams->PostProcessor->Process(
DataParams->C, RangeStartM + m, RangeStartN + n, countM, countN, ldc
);
}

a += countM * lda;
c += countM * ldc;
}

QuantBData += countN * qldb;
QuantBScale += countN * k_blk_num;
QuantBZeroPoint = QuantBZeroPoint ? QuantBZeroPoint + countN * k_zp_bytes : nullptr;
Bias = Bias ? Bias + countN : nullptr;
C += countN;
}
}

void
SQ4BitGemm_CompInt8(
const size_t BlkLen,
Expand Down Expand Up @@ -720,7 +808,7 @@ GetQNBitGemm(QNBitGemmVariant variant)
{
switch (variant) {
case HQNBitGemmVariant_BitWidth4_CompFp16:
return nullptr;
return HQ4BitGemm_CompFp16;
default:
return nullptr;
}
Expand Down
81 changes: 72 additions & 9 deletions onnxruntime/core/mlas/lib/qnbitgemm.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,17 +91,17 @@ struct MLAS_QNBIT_GEMM_DISPATCH {
//

/** Gets size of packed quantized B data containing 4-bit integers. See MlasQNBitGemmPackQuantBDataSize(). */
typedef size_t(SQ4BitGemmPackQuantBDataSize_Fn)(
typedef size_t(Q4BitGemmPackQuantBDataSize_Fn)(
size_t N,
size_t K,
size_t BlkLen,
MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType
);

SQ4BitGemmPackQuantBDataSize_Fn* Q4BitGemmPackQuantBDataSize = nullptr;
Q4BitGemmPackQuantBDataSize_Fn* Q4BitGemmPackQuantBDataSize = nullptr;

/** Packs quantized B data containing 4-bit integers. See MlasQNBitGemmPackQuantBData(). */
typedef void(SQ4BitGemmPackQuantBData_Fn)(
typedef void(Q4BitGemmPackQuantBData_Fn)(
size_t N,
size_t K,
size_t BlkLen,
Expand All @@ -111,7 +111,8 @@ struct MLAS_QNBIT_GEMM_DISPATCH {
MLAS_THREADPOOL* ThreadPool
);

SQ4BitGemmPackQuantBData_Fn* SQ4BitGemmPackQuantBData = nullptr;
Q4BitGemmPackQuantBData_Fn* SQ4BitGemmPackQuantBData = nullptr;
Q4BitGemmPackQuantBData_Fn* HQ4BitGemmPackQuantBData = nullptr;

typedef void(SQ4BitGemmPackQuantBDataAndSumBlk_Fn)(
size_t N,
Expand Down Expand Up @@ -142,28 +143,28 @@ struct MLAS_QNBIT_GEMM_DISPATCH {
* @param[in] BlkLen number of quantized values per block
* @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values)
*/
typedef size_t(SQ4BitGemmPerGemmWorkspaceSize_Fn)(
typedef size_t(Q4BitGemmPerGemmWorkspaceSize_Fn)(
size_t M,
size_t N,
size_t K,
size_t BlkLen,
MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType
);

SQ4BitGemmPerGemmWorkspaceSize_Fn* Q4BitGemmPerGemmWorkspaceSize = nullptr;
Q4BitGemmPerGemmWorkspaceSize_Fn* Q4BitGemmPerGemmWorkspaceSize = nullptr;

/**
* @brief Gets the required byte alignment of the per-GEMM intermediate workspace.
*
* @param[in] BlkLen number of quantized values per block
* @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values)
*/
typedef size_t(SQ4BitGemmPerGemmWorkspaceAlignment_Fn)(
typedef size_t(Q4BitGemmPerGemmWorkspaceAlignment_Fn)(
size_t BlkLen,
MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType
);

SQ4BitGemmPerGemmWorkspaceAlignment_Fn* Q4BitGemmPerGemmWorkspaceAlignment = nullptr;
Q4BitGemmPerGemmWorkspaceAlignment_Fn* Q4BitGemmPerGemmWorkspaceAlignment = nullptr;

//
// SQNBIT_CompFp32 kernel function prototypes.
Expand Down Expand Up @@ -229,7 +230,38 @@ struct MLAS_QNBIT_GEMM_DISPATCH {
size_t BlockStrideQuantB
);

Q4BitBlkDequantBForSgemm_CompFp32_Fn* Q4BitBlkDequantBForSgemm_CompFp32 = nullptr;
Q4BitBlkDequantBForSgemm_CompFp32_Fn* SQ4BitBlkDequantBForSgemm_CompFp32 = nullptr;

/**
* @brief Dequantize B into the format expected by the Sgemm kernel.
* B is a quantized 4-bit integer matrix that is block quantized and column major.
* This is equivalent to dequantizing B and then running MlasSgemmCopyPackB.
*
* @param BlkLen Number of values in a block.
* @param[out] FpData Supplies the output buffer for the dequantized B float data.
* It should have enough space for
* (CountN + 16 - 1) / 16 * 16 * (CountK + BlkLen - 1) / BlkLen * BlkLen
* elements. Only the first (CountN + 16 - 1) / 16 * 16 * CountK elements are
* useful, but the kernel implementation can be simplified with the extra space.
* @param QuantBData Supplies the quantized B matrix block data.
* @param QuantBScale Supplies the quantized B matrix block scale values.
* @param QuantBZeroPoint Supplies the quantized B matrix block zero point values. Optional.
* @param CountN Number of columns of B.
* @param CountK Number of rows of B.
* @param BlockStrideQuantB Number of blocks between adjacent columns of the quantized B matrix.
*/
typedef void(Q4BitBlkDequantBForSgemm_CompFp16_Fn)(
size_t BlkLen,
MLAS_FP16* FpData,
const std::byte* QuantBData,
const MLAS_FP16* QuantBScale,
const std::byte* QuantBZeroPoint,
size_t CountN,
size_t CountK,
size_t BlockStrideQuantB
);

Q4BitBlkDequantBForSgemm_CompFp16_Fn* HQ4BitBlkDequantBForHgemm_CompFp16 = nullptr;

//
// SQNBIT_CompInt8 kernel function prototypes.
Expand Down Expand Up @@ -338,4 +370,35 @@ struct MLAS_QNBIT_GEMM_DISPATCH {
float* AScaledGroupSum // scale_k * Sum_blklen(a_i)
);
QuantizeARowComputeBlkSum_CompInt8_Fn* QuantizeARowComputeBlkSum_CompInt8 = nullptr;

/**
* @brief Multiply fp16 matrix A rows with fp16 matrix B columns.
* Results are written to fp16 matrix C.
* If bias is provided, the bias are added to the result.
*
* @param A first row of the A matrix segment. Row major.
* @param B first column of the B matrix segment. Column major.
* @param Bias the bias at the target column. Optional.
* @param[out] C first element of the output matrix segment. Row major.
* @param CountM the number of rows of A chunk.
* @param CountN the number of columns of B chunk.
* @param K the number of columns of A matrix and rows of B matrix.
* @param lda the leading dimension of A.
* @param ldb the leading dimension of B.
* @param ldc the leading dimension of C.
*/
typedef void(HQ4BitGemmKernel_CompFp16_Fn)(
const MLAS_FP16* A,
const MLAS_FP16* B,
const MLAS_FP16* Bias,
MLAS_FP16* C,
size_t CountM,
size_t CountN,
size_t K,
size_t lda,
size_t ldb,
size_t ldc
);

HQ4BitGemmKernel_CompFp16_Fn* HQ4BitGemmKernel_CompFp16 = nullptr;
};
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ Licensed under the MIT License.
Module Name:
sqnbitgemm_kernel_neon.cpp
qnbitgemm_kernel_neon.cpp
Abstract:
Expand All @@ -20,7 +20,7 @@ Module Name:
#include <cassert>

#include "qnbitgemm.h"
#include "sqnbitgemm_kernel_neon.h"
#include "qnbitgemm_kernel_neon.h"
#include "sqnbitgemm_q8_block.h"

namespace sqnbitgemm_neon
Expand Down Expand Up @@ -185,10 +185,17 @@ const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchNeon = []() {
d.Q4BitGemmPerGemmWorkspaceAlignment = sqnbitgemm_neon::Q4BitGemmPerGemmWorkspaceAlignment;

d.SQ4BitGemmM1Kernel_CompFp32 = sqnbitgemm_neon::SQ4BitGemmM1Kernel_CompFp32;
d.Q4BitBlkDequantBForSgemm_CompFp32 = sqnbitgemm_neon::Q4BitBlkDequantBForSgemm_CompFp32;

d.SQ4BitGemmKernel_CompInt8 = sqnbitgemm_neon::SQ4BitGemmKernel_CompInt8;
d.SQ4BitBlkDequantBForSgemm_CompFp32 = sqnbitgemm_neon::SQ4BitBlkDequantBForSgemm_CompFp32;
if (MLAS_CPUIDINFO::GetCPUIDInfo().HasArmNeonDot()) {
d.SQ4BitGemmKernel_CompInt8 = sqnbitgemm_neon::SQ4BitGemmKernel_CompInt8;
}
d.QuantizeARow_CompInt8 = sqnbitgemm_neon::QuantizeARow_CompInt8;

#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64)
d.HQ4BitGemmPackQuantBData = sqnbitgemm_neon::HQ4BitGemmPackQuantBData_CompFp16;
d.HQ4BitBlkDequantBForHgemm_CompFp16 = sqnbitgemm_neon::HQ4BitBlkDequantBForHgemm_CompFp16;
d.HQ4BitGemmKernel_CompFp16 = sqnbitgemm_neon::HQ4BitGemmKernel_CompFp16;
#endif // MLAS_F16VEC_INTRINSICS_SUPPORTED && MLAS_TARGET_ARM64

return d;
}();
Loading

0 comments on commit db76bb4

Please sign in to comment.