Skip to content

Commit

Permalink
[MLAS] AArch64 SQNBitGemm CompInt8 initial multi-row implementation (#21193)
Browse files Browse the repository at this point in the history

Update AArch64 SQNBitGemm CompInt8 kernels to process matrix in tiles. E.g., computing the output in 2x2 tiles allows us to compute four elements of the output with one read of two rows of A and two columns of B.

Also moved some code around as it was getting big for a single file.
  • Loading branch information
edgchen1 authored Jul 10, 2024
1 parent 8749fa3 commit 20cd339
Show file tree
Hide file tree
Showing 12 changed files with 2,248 additions and 1,384 deletions.
8 changes: 7 additions & 1 deletion cmake/onnxruntime_mlas.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,10 @@ function(setup_mlas_source_for_windows)
${MLAS_SRC_DIR}/qgemm_kernel_neon.cpp
${MLAS_SRC_DIR}/qgemm_kernel_udot.cpp
${MLAS_SRC_DIR}/qgemm_kernel_sdot.cpp
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon.h
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon.cpp
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_fp32.cpp
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp
)

set(mlas_platform_preprocess_srcs
Expand Down Expand Up @@ -350,9 +353,12 @@ else()
${MLAS_SRC_DIR}/qgemm_kernel_neon.cpp
${MLAS_SRC_DIR}/qgemm_kernel_udot.cpp
${MLAS_SRC_DIR}/qgemm_kernel_sdot.cpp
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon.h
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon.cpp
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_fp32.cpp
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp
)
set_source_files_properties(${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon.cpp
set_source_files_properties(${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp
PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+dotprod")
if (NOT APPLE)
set(mlas_platform_srcs
Expand Down
56 changes: 16 additions & 40 deletions onnxruntime/core/mlas/lib/sqnbitgemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,11 @@ Module Name:
--*/

#include "sqnbitgemm.h"
#include "sqnbitgemm_q8_block.h"

#include <cassert>

#include "sqnbitgemm_q8_block.h"

namespace
{

Expand Down Expand Up @@ -80,7 +81,7 @@ MlasIsSQNBitGemmAvailable(
Dispatch->Q4BitBlkDequantBForSgemm_CompFp32 != nullptr;
}
case SQNBitGemmVariant_BitWidth4_CompInt8: {
return Dispatch->SQ4BitGemmM1Kernel_CompInt8 != nullptr &&
return Dispatch->SQ4BitGemmKernel_CompInt8 != nullptr &&
Dispatch->QuantizeARow_CompInt8 != nullptr;
}
default: {
Expand Down Expand Up @@ -372,15 +373,17 @@ SQ4BitGemm_CompFp32(
if (bias) {
AddBiasForGemm(bias, c_blk, RowsHandled, CountN, ldc);
}

if (DataParams->PostProcessor != nullptr) {
DataParams->PostProcessor->Process(
DataParams->C, RangeStartM + RangeCountM - RowsRemaining, RangeStartN,
DataParams->C, RangeStartM + RangeCountM - RowsRemaining, RangeStartN + n,
RowsHandled, CountN, ldc
);
}

c_blk += ldc * RowsHandled;
a_row += lda * RowsHandled;

RowsRemaining -= RowsHandled;
}
}
Expand Down Expand Up @@ -431,36 +434,6 @@ SQ4BitGemm_CompInt8(

const float* Bias = (DataParams->Bias == nullptr) ? nullptr : DataParams->Bias + RangeStartN;

if (RangeCountM == 1) {
size_t CountN;
for (size_t n = 0; n < RangeCountN; n += CountN) {
CountN = std::min(RangeCountN - n, size_t{128});

const std::byte* a_row = QuantA;
const std::byte* b_col = QuantBData + n * ldb;
const float* b_col_scale = QuantBScale + n * k_blks;
const std::byte* b_col_zp =
(QuantBZeroPoint == nullptr) ? nullptr : QuantBZeroPoint + n * k_blks_zp_bytes;
float* c_blk = C + n;
const float* bias = (Bias == nullptr) ? nullptr : Bias + n;

GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmM1Kernel_CompInt8(
BlkLen,
a_row, b_col, b_col_scale, b_col_zp, c_blk, CountN, K, k_blks, bias
);

if (DataParams->PostProcessor != nullptr) {
DataParams->PostProcessor->Process(
DataParams->C, RangeStartM, RangeStartN + n,
RangeCountM, CountN, ldc
);
}
}
return;
}

// This is a naive M > 1 implementation that repeatedly calls the M=1 kernel.
// TODO Replace it with an optimized implementation.
size_t CountN;
for (size_t n = 0; n < RangeCountN; n += CountN) {
CountN = std::min(RangeCountN - n, size_t{128});
Expand All @@ -473,21 +446,24 @@ SQ4BitGemm_CompInt8(
float* c_blk = C + n;
const float* bias = (Bias == nullptr) ? nullptr : Bias + n;

for (size_t m = 0; m < RangeCountM; ++m) {
GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmM1Kernel_CompInt8(
size_t RowsRemaining = RangeCountM;
while (RowsRemaining > 0) {
const auto RowsHandled = GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmKernel_CompInt8(
BlkLen,
a_row, b_col, b_col_scale, b_col_zp, c_blk, CountN, K, k_blks, bias
a_row, b_col, b_col_scale, b_col_zp, c_blk, RowsRemaining, CountN, K, k_blks, ldc, bias
);

if (DataParams->PostProcessor != nullptr) {
DataParams->PostProcessor->Process(
DataParams->C, RangeStartM, RangeStartN + n,
RangeCountM, CountN, ldc
DataParams->C, RangeStartM + RangeCountM - RowsRemaining, RangeStartN + n,
RowsHandled, CountN, ldc
);
}

c_blk += ldc;
a_row += lda;
c_blk += RowsHandled * ldc;
a_row += RowsHandled * lda;

RowsRemaining -= RowsHandled;
}
}
}
Expand Down
17 changes: 11 additions & 6 deletions onnxruntime/core/mlas/lib/sqnbitgemm.h
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,6 @@ struct MLAS_SQNBIT_GEMM_DISPATCH {
/**
* @brief Multiply quantized 8-bit integer matrix A with quantized 4-bit integer matrix B.
* A and B are block quantized and B is column major.
* This kernel handles the special case where M, the number of rows of A and C, is 1.
*
* @param BlkLen Number of values in a block.
* @param QuantA Supplies the quantized A matrix.
Expand All @@ -193,25 +192,31 @@ struct MLAS_SQNBIT_GEMM_DISPATCH {
* @param QuantBScale Supplies the quantized B matrix block scale values.
* @param QuantBZeroPoint Supplies the quantized B matrix block zero point values. Optional.
* @param[out] C Supplies the output C matrix.
* @param CountN Number of columns of B and C.
* @param CountM Number of rows of A and C to process, an upper bound.
* @param CountN Number of columns of B and C to process.
* @param CountK Number of columns of A and rows of B.
* @param BlockStrideQuantB Number of blocks between adjacent columns of the quantized B matrix.
* @param BlockCountK Number of blocks in one row of A and one column of B.
* @param ldc Number of elements between adjacent rows of C.
* @param Bias Bias vector of length N.
*
* @return The number of rows of A and C that were processed, at most CountM.
*/
typedef void(SQ4BitGemmM1Kernel_CompInt8_Fn)(
typedef size_t(SQ4BitGemmKernel_CompInt8_Fn)(
size_t BlkLen,
const std::byte* QuantA,
const std::byte* QuantBData,
const float* QuantBScale,
const std::byte* QuantBZeroPoint,
float* C,
size_t CountM,
size_t CountN,
size_t CountK,
size_t BlockStrideQuantB,
size_t BlockCountK,
size_t ldc,
const float* Bias
);

SQ4BitGemmM1Kernel_CompInt8_Fn* SQ4BitGemmM1Kernel_CompInt8 = nullptr;
SQ4BitGemmKernel_CompInt8_Fn* SQ4BitGemmKernel_CompInt8 = nullptr;

/**
* @brief Block quantize values from one row of matrix A from floats to quantized 8-bit integers.
Expand Down
40 changes: 39 additions & 1 deletion onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -434,6 +434,44 @@ SQ4BitGemmM1Kernel_CompInt8_avx2(
}
}

// Multi-row CompInt8 kernel entry point for AVX2.
//
// Adapts the existing single-row (M == 1) AVX2 kernel to the multi-row
// dispatch interface: each call consumes exactly one row of A and produces
// one row of C, and the caller loops until all rows are handled.
//
// Returns the number of rows of A and C processed (0 or 1 here).
size_t
SQ4BitGemmKernel_CompInt8_avx2(
    size_t BlkLen,
    const std::byte* QuantA,
    const std::byte* QuantBData,
    const float* QuantBScale,
    const std::byte* QuantBZeroPoint,
    float* C,
    size_t CountM,
    size_t CountN,
    size_t CountK,
    size_t BlockCountK,
    size_t ldc,
    const float* Bias
)
{
    // ldc is unused: only a single output row is written per call.
    MLAS_UNREFERENCED_PARAMETER(ldc);

    // Nothing to do when no rows were requested.
    if (CountM == 0) {
        return 0;
    }

    // Delegate to the M == 1 kernel for the first (and only) row handled.
    SQ4BitGemmM1Kernel_CompInt8_avx2(
        BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint,
        C, CountN, CountK, BlockCountK, Bias
    );

    return 1;
}

template <size_t NCols, bool HasZeroPoint>
MLAS_FORCEINLINE void
ComputeDotProducts_BlkLen16_CompFp32_avx2(
Expand Down Expand Up @@ -1109,7 +1147,7 @@ const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2 = []() {
d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32_avx2;
d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2;

d.SQ4BitGemmM1Kernel_CompInt8 = SQ4BitGemmM1Kernel_CompInt8_avx2;
d.SQ4BitGemmKernel_CompInt8 = SQ4BitGemmKernel_CompInt8_avx2;
d.QuantizeARow_CompInt8 = QuantizeARow_CompInt8_avx2;

return d;
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512 = []() {
d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32_avx512;
d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2;

d.SQ4BitGemmM1Kernel_CompInt8 = SQ4BitGemmM1Kernel_CompInt8_avx2;
d.SQ4BitGemmKernel_CompInt8 = SQ4BitGemmKernel_CompInt8_avx2;
d.QuantizeARow_CompInt8 = QuantizeARow_CompInt8_avx512;

return d;
Expand Down
40 changes: 39 additions & 1 deletion onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,44 @@ SQ4BitGemmM1Kernel_CompInt8_avx512vnni(
}
}

// Multi-row CompInt8 kernel entry point for AVX-512 VNNI.
//
// Bridges the multi-row dispatch signature to the existing single-row
// (M == 1) AVX-512 VNNI kernel: at most one row of A/C is processed per
// invocation; the caller advances its pointers and calls again for the
// remaining rows.
//
// Returns the number of rows of A and C processed (0 or 1 here).
size_t
SQ4BitGemmKernel_CompInt8_avx512vnni(
    size_t BlkLen,
    const std::byte* QuantA,
    const std::byte* QuantBData,
    const float* QuantBScale,
    const std::byte* QuantBZeroPoint,
    float* C,
    size_t CountM,
    size_t CountN,
    size_t CountK,
    size_t BlockCountK,
    size_t ldc,
    const float* Bias
)
{
    // ldc is unused: only a single output row is written per call.
    MLAS_UNREFERENCED_PARAMETER(ldc);

    // Nothing to do when no rows were requested.
    if (CountM == 0) {
        return 0;
    }

    // Delegate to the M == 1 kernel for the first (and only) row handled.
    SQ4BitGemmM1Kernel_CompInt8_avx512vnni(
        BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint,
        C, CountN, CountK, BlockCountK, Bias
    );

    return 1;
}

void MLASCALL
MlasQ80BlkQuantRow_avx512(
size_t BlkLen,
Expand All @@ -260,7 +298,7 @@ const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512vnni = []() {
d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32;
d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2;

d.SQ4BitGemmM1Kernel_CompInt8 = SQ4BitGemmM1Kernel_CompInt8_avx512vnni;
d.SQ4BitGemmKernel_CompInt8 = SQ4BitGemmKernel_CompInt8_avx512vnni;
d.QuantizeARow_CompInt8 = MlasQ80BlkQuantRow_avx512;

return d;
Expand Down
8 changes: 5 additions & 3 deletions onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -158,17 +158,19 @@ Q4BitBlkDequantBForSgemm_CompFp32_avx2(
const size_t BlockStrideQuantB
);

void
SQ4BitGemmM1Kernel_CompInt8_avx2(
size_t
SQ4BitGemmKernel_CompInt8_avx2(
size_t BlkLen,
const std::byte* QuantA,
const std::byte* QuantBData,
const float* QuantBScale,
const std::byte* QuantBZeroPoint,
float* C,
size_t CountM,
size_t CountN,
size_t CountK,
size_t BlockStrideQuantB,
size_t BlockCountK,
size_t ldc,
const float* Bias
);

Expand Down
Loading

0 comments on commit 20cd339

Please sign in to comment.