Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MLAS] AArch64 SQNBitGemm CompInt8 initial multi-row implementation #21193

Merged
merged 14 commits into from
Jul 10, 2024
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion cmake/onnxruntime_mlas.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,10 @@ function(setup_mlas_source_for_windows)
${MLAS_SRC_DIR}/qgemm_kernel_neon.cpp
${MLAS_SRC_DIR}/qgemm_kernel_udot.cpp
${MLAS_SRC_DIR}/qgemm_kernel_sdot.cpp
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon.h
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon.cpp
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_fp32.cpp
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp
)

set(mlas_platform_preprocess_srcs
Expand Down Expand Up @@ -350,9 +353,12 @@ else()
${MLAS_SRC_DIR}/qgemm_kernel_neon.cpp
${MLAS_SRC_DIR}/qgemm_kernel_udot.cpp
${MLAS_SRC_DIR}/qgemm_kernel_sdot.cpp
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon.h
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon.cpp
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_fp32.cpp
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp
)
set_source_files_properties(${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon.cpp
set_source_files_properties(${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp
PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+dotprod")
if (NOT APPLE)
set(mlas_platform_srcs
Expand Down
46 changes: 9 additions & 37 deletions onnxruntime/core/mlas/lib/sqnbitgemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ MlasIsSQNBitGemmAvailable(
Dispatch->Q4BitBlkDequantBForSgemm_CompFp32 != nullptr;
}
case SQNBitGemmVariant_BitWidth4_CompInt8: {
return Dispatch->SQ4BitGemmM1Kernel_CompInt8 != nullptr &&
return Dispatch->SQ4BitGemmKernel_CompInt8 != nullptr &&
Dispatch->QuantizeARow_CompInt8 != nullptr;
}
default: {
Expand Down Expand Up @@ -431,36 +431,6 @@ SQ4BitGemm_CompInt8(

const float* Bias = (DataParams->Bias == nullptr) ? nullptr : DataParams->Bias + RangeStartN;

if (RangeCountM == 1) {
size_t CountN;
for (size_t n = 0; n < RangeCountN; n += CountN) {
CountN = std::min(RangeCountN - n, size_t{128});

const std::byte* a_row = QuantA;
const std::byte* b_col = QuantBData + n * ldb;
const float* b_col_scale = QuantBScale + n * k_blks;
const std::byte* b_col_zp =
(QuantBZeroPoint == nullptr) ? nullptr : QuantBZeroPoint + n * k_blks_zp_bytes;
float* c_blk = C + n;
const float* bias = (Bias == nullptr) ? nullptr : Bias + n;

GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmM1Kernel_CompInt8(
BlkLen,
a_row, b_col, b_col_scale, b_col_zp, c_blk, CountN, K, k_blks, bias
);

if (DataParams->PostProcessor != nullptr) {
DataParams->PostProcessor->Process(
DataParams->C, RangeStartM, RangeStartN + n,
RangeCountM, CountN, ldc
);
}
}
return;
}

// This is a naive M > 1 implementation that repeatedly calls the M=1 kernel.
// TODO Replace it with an optimized implementation.
size_t CountN;
for (size_t n = 0; n < RangeCountN; n += CountN) {
CountN = std::min(RangeCountN - n, size_t{128});
Expand All @@ -473,21 +443,23 @@ SQ4BitGemm_CompInt8(
float* c_blk = C + n;
const float* bias = (Bias == nullptr) ? nullptr : Bias + n;

for (size_t m = 0; m < RangeCountM; ++m) {
GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmM1Kernel_CompInt8(
for (size_t m = 0; m < RangeCountM;) {
const auto RowsHandled = GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmKernel_CompInt8(
BlkLen,
a_row, b_col, b_col_scale, b_col_zp, c_blk, CountN, K, k_blks, bias
a_row, b_col, b_col_scale, b_col_zp, c_blk, RangeCountM - m, CountN, K, k_blks, ldc, bias
);

if (DataParams->PostProcessor != nullptr) {
DataParams->PostProcessor->Process(
DataParams->C, RangeStartM, RangeStartN + n,
edgchen1 marked this conversation as resolved.
Show resolved Hide resolved
RangeCountM, CountN, ldc
RangeCountM - m, CountN, ldc
);
}

c_blk += ldc;
a_row += lda;
c_blk += RowsHandled * ldc;
a_row += RowsHandled * lda;

m += RowsHandled;
}
}
}
Expand Down
19 changes: 12 additions & 7 deletions onnxruntime/core/mlas/lib/sqnbitgemm.h
Original file line number Diff line number Diff line change
Expand Up @@ -181,10 +181,9 @@ struct MLAS_SQNBIT_GEMM_DISPATCH {
// CompInt8 kernel function prototypes.
//

/**
/**
edgchen1 marked this conversation as resolved.
Show resolved Hide resolved
* @brief Multiply quantized 8-bit integer matrix A with quantized 4-bit integer matrix B.
* A and B are block quantized and B is column major.
* This kernel handles the special case where M, the number of rows of A and C, is 1.
*
* @param BlkLen Number of values in a block.
* @param QuantA Supplies the quantized A matrix.
Expand All @@ -193,25 +192,31 @@ struct MLAS_SQNBIT_GEMM_DISPATCH {
* @param QuantBScale Supplies the quantized B matrix block scale values.
* @param QuantBZeroPoint Supplies the quantized B matrix block zero point values. Optional.
* @param[out] C Supplies the output C matrix.
* @param CountN Number of columns of B and C.
* @param CountM Number of rows of A and C to process, an upper bound.
* @param CountN Number of columns of B and C to process.
* @param CountK Number of columns of A and rows of B.
* @param BlockStrideQuantB Number of blocks between adjacent columns of the quantized B matrix.
* @param BlockCountK Number of blocks in one row of A and one column of B.
* @param ldc Number of elements between adjacent rows of C.
* @param Bias Bias vector of length N.
*
* @return The number of rows of A and C that were processed, at most CountM.
*/
typedef void(SQ4BitGemmM1Kernel_CompInt8_Fn)(
// Kernel entry point for the CompInt8 (quantized-A) 4-bit GEMM path.
// Returns the number of rows of A/C actually processed (at most CountM),
// so the caller can advance its row pointers and loop until done.
typedef size_t(SQ4BitGemmKernel_CompInt8_Fn)(
    size_t BlkLen,
    const std::byte* QuantA,
    const std::byte* QuantBData,
    const float* QuantBScale,
    const std::byte* QuantBZeroPoint,
    float* C,
    size_t CountM,
    size_t CountN,
    size_t CountK,
    size_t BlockCountK,
    size_t ldc,
    const float* Bias
);

// Dispatch slot; platform-specific initialization fills this in (nullptr when
// the CompInt8 variant is unavailable on the current platform).
SQ4BitGemmKernel_CompInt8_Fn* SQ4BitGemmKernel_CompInt8 = nullptr;

/**
* @brief Block quantize values from one row of matrix A from floats to quantized 8-bit integers.
Expand Down
40 changes: 39 additions & 1 deletion onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -434,6 +434,44 @@ SQ4BitGemmM1Kernel_CompInt8_avx2(
}
}

size_t
SQ4BitGemmKernel_CompInt8_avx2(
edgchen1 marked this conversation as resolved.
Show resolved Hide resolved
size_t BlkLen,
const std::byte* QuantA,
const std::byte* QuantBData,
const float* QuantBScale,
const std::byte* QuantBZeroPoint,
float* C,
size_t CountM,
size_t CountN,
size_t CountK,
size_t BlockCountK,
size_t ldc,
const float* Bias
)
{
MLAS_UNREFERENCED_PARAMETER(ldc);

if (CountM == 0) {
return 0;
}

SQ4BitGemmM1Kernel_CompInt8_avx2(
BlkLen,
QuantA,
QuantBData,
QuantBScale,
QuantBZeroPoint,
C,
CountN,
CountK,
BlockCountK,
Bias
);

return 1;
}

template <size_t NCols, bool HasZeroPoint>
MLAS_FORCEINLINE void
ComputeDotProducts_BlkLen16_CompFp32_avx2(
Expand Down Expand Up @@ -1109,7 +1147,7 @@ const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2 = []() {
d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32_avx2;
d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2;

d.SQ4BitGemmM1Kernel_CompInt8 = SQ4BitGemmM1Kernel_CompInt8_avx2;
d.SQ4BitGemmKernel_CompInt8 = SQ4BitGemmKernel_CompInt8_avx2;
d.QuantizeARow_CompInt8 = QuantizeARow_CompInt8_avx2;

return d;
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512 = []() {
d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32_avx512;
d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2;

d.SQ4BitGemmM1Kernel_CompInt8 = SQ4BitGemmM1Kernel_CompInt8_avx2;
d.SQ4BitGemmKernel_CompInt8 = SQ4BitGemmKernel_CompInt8_avx2;
d.QuantizeARow_CompInt8 = QuantizeARow_CompInt8_avx512;

return d;
Expand Down
40 changes: 39 additions & 1 deletion onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,44 @@ SQ4BitGemmM1Kernel_CompInt8_avx512vnni(
}
}

// AVX512-VNNI adapter for the multi-row CompInt8 kernel interface.
// Delegates to the single-row (M=1) kernel and reports one row handled,
// letting the caller's loop advance A/C by one row per iteration.
size_t
SQ4BitGemmKernel_CompInt8_avx512vnni(
    size_t BlkLen,
    const std::byte* QuantA,
    const std::byte* QuantBData,
    const float* QuantBScale,
    const std::byte* QuantBZeroPoint,
    float* C,
    size_t CountM,
    size_t CountN,
    size_t CountK,
    size_t BlockCountK,
    size_t ldc,
    const float* Bias
)
{
    // The underlying kernel writes a single row of C, so ldc is not consulted.
    MLAS_UNREFERENCED_PARAMETER(ldc);

    // Nothing to do when no rows were requested.
    if (CountM == 0) {
        return 0;
    }

    SQ4BitGemmM1Kernel_CompInt8_avx512vnni(
        BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint,
        C, CountN, CountK, BlockCountK, Bias
    );

    return 1;  // exactly one row of A/C consumed
}

void MLASCALL
MlasQ80BlkQuantRow_avx512(
size_t BlkLen,
Expand All @@ -260,7 +298,7 @@ const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512vnni = []() {
d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32;
d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2;

d.SQ4BitGemmM1Kernel_CompInt8 = SQ4BitGemmM1Kernel_CompInt8_avx512vnni;
d.SQ4BitGemmKernel_CompInt8 = SQ4BitGemmKernel_CompInt8_avx512vnni;
d.QuantizeARow_CompInt8 = MlasQ80BlkQuantRow_avx512;

return d;
Expand Down
8 changes: 5 additions & 3 deletions onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -158,17 +158,19 @@ Q4BitBlkDequantBForSgemm_CompFp32_avx2(
const size_t BlockStrideQuantB
);

// Multi-row CompInt8 kernel (AVX2). Returns the number of rows of A/C
// processed, at most CountM; see the dispatch typedef for the full contract.
size_t
SQ4BitGemmKernel_CompInt8_avx2(
    size_t BlkLen,
    const std::byte* QuantA,
    const std::byte* QuantBData,
    const float* QuantBScale,
    const std::byte* QuantBZeroPoint,
    float* C,
    size_t CountM,
    size_t CountN,
    size_t CountK,
    size_t BlockCountK,
    size_t ldc,
    const float* Bias
);

Expand Down
Loading
Loading