Skip to content

Commit

Permalink
avxvnni
Browse files Browse the repository at this point in the history
Signed-off-by: Liqun Fu <[email protected]>
  • Loading branch information
liqunfu committed Jul 20, 2024
1 parent 6654d22 commit 4b91bed
Show file tree
Hide file tree
Showing 6 changed files with 400 additions and 245 deletions.
2 changes: 2 additions & 0 deletions onnxruntime/core/mlas/lib/mlasi.h
Original file line number Diff line number Diff line change
Expand Up @@ -993,6 +993,8 @@ extern const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchNeon;

extern const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2;

extern const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2vnni;

extern const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512;

extern const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512vnni;
Expand Down
1 change: 1 addition & 0 deletions onnxruntime/core/mlas/lib/platform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,7 @@ Return Value:
this->GemmU8S8Kernel = MlasGemmU8S8KernelAvxVnni;
this->GemvU8S8Kernel = MlasGemvU8S8KernelAvxVnni;
this->ConvSymU8S8Dispatch = &MlasConvSymDispatchAvxVnni;
this->SQNBitGemmDispatch = &MlasSQNBitGemmDispatchAvx2vnni;
}

#if !defined(ORT_MINIMAL_BUILD)
Expand Down
88 changes: 82 additions & 6 deletions onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,7 @@ Q4BitBlkDequantBForSgemm_CompFp32_avx2(
}
}

template<bool vnni>
MLAS_FORCEINLINE
void
SQ4BitGemmKernel_CompInt8_avx2(
Expand Down Expand Up @@ -376,7 +377,7 @@ SQ4BitGemmKernel_CompInt8_avx2(
ldc
);
} else if (BlkLen == 32) {
MlasQ4Int8GemmKernelBlkLen32Avx2(
MlasQ4Int8GemmKernelBlkLen32Avx2<vnni>(
QuantA,
QuantAScale,
QuantBData,
Expand All @@ -390,7 +391,7 @@ SQ4BitGemmKernel_CompInt8_avx2(
ldc
);
} else {
MlasQ4Int8GemmKernelBlkLen64Avx2(
MlasQ4Int8GemmKernelBlkLen64Avx2<vnni>(
BlkLen,
QuantA,
QuantAScale,
Expand All @@ -406,6 +407,7 @@ SQ4BitGemmKernel_CompInt8_avx2(
}
}

template<bool vnni>
MLAS_FORCEINLINE
void
SQ4BitGemmM1Kernel_CompInt8_avx2(
Expand All @@ -425,7 +427,7 @@ SQ4BitGemmM1Kernel_CompInt8_avx2(
if (QuantBZeroPoint) {
if (BlkLen == 16) {
} else if (BlkLen == 32) {
MlasQ4Int8GemmM1KernelBlkLen32Avx2<true>(
MlasQ4Int8GemmM1KernelBlkLen32Avx2<true, vnni>(
QuantA,
QuantAScale,
QuantBData,
Expand Down Expand Up @@ -453,7 +455,7 @@ SQ4BitGemmM1Kernel_CompInt8_avx2(
} else {
if (BlkLen == 16) {
} else if (BlkLen == 32) {
MlasQ4Int8GemmM1KernelBlkLen32Avx2<false>(
MlasQ4Int8GemmM1KernelBlkLen32Avx2<false, vnni>(
QuantA,
QuantAScale,
QuantBData,
Expand Down Expand Up @@ -502,11 +504,66 @@ SQ4BitGemmKernel_BlkSum_CompInt8_avx2(
)
{
if (BlkLen >= 32 && CountM == 1) {
SQ4BitGemmM1Kernel_CompInt8_avx2(BlkLen, QuantA, QuantAScale, QuantBData, QuantBScale, QuantBZeroPoint, C, CountN, CountK, BlockCountK, Bias);
SQ4BitGemmM1Kernel_CompInt8_avx2<false>(BlkLen, QuantA, QuantAScale, QuantBData, QuantBScale, QuantBZeroPoint, C, CountN, CountK, BlockCountK, Bias);
return CountM;
}

SQ4BitGemmKernel_CompInt8_avx2(
SQ4BitGemmKernel_CompInt8_avx2<false>(
BlkLen,
QuantA,
QuantAScale,
QuantBData,
QuantBScale,
C,
CountM,
CountN,
CountK,
BlockCountK,
Bias,
ldc
);
float* c_blk = C;
const float* b_blk_sum = QuantBBlkSum;

size_t RowsRemaining = CountM;
const float* a_blksum_row = ABlockSum;
while (RowsRemaining > 0) {
auto RowsHandled = GetMlasPlatform().GemmFloatKernel(
a_blksum_row, b_blk_sum, c_blk, BlockCountK, RowsRemaining, CountN, BlockCountK, ldc, 1.f, false
);

c_blk += ldc * RowsHandled;
a_blksum_row += BlockCountK * RowsHandled;
RowsRemaining -= RowsHandled;
}
return CountM;
}

size_t
SQ4BitGemmKernel_BlkSum_CompInt8_avx2vnni(
const size_t BlkLen,
const std::byte* QuantA,
const float* QuantAScale,
const std::byte* QuantBData,
const float* QuantBScale,
const std::byte* QuantBZeroPoint,
float* C,
size_t CountM,
size_t CountN,
size_t CountK,
size_t BlockCountK,
const float* Bias,
size_t ldc,
const float* ABlockSum,
const float* QuantBBlkSum
)
{
if (BlkLen >= 32 && CountM == 1) {
SQ4BitGemmM1Kernel_CompInt8_avx2<true>(BlkLen, QuantA, QuantAScale, QuantBData, QuantBScale, QuantBZeroPoint, C, CountN, CountK, BlockCountK, Bias);
return CountM;
}

SQ4BitGemmKernel_CompInt8_avx2<true>(
BlkLen,
QuantA,
QuantAScale,
Expand Down Expand Up @@ -1246,3 +1303,22 @@ const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2 = []() {

return d;
}();

const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2vnni = []() {
    // Dispatch table for the AVX2 + AVX-VNNI build of the 4-bit quantized
    // GEMM. It reuses the AVX2 helpers for packing, workspace queries, the
    // fp32 compute path, and A-row quantization, while the int8 block-sum
    // GEMM kernel routes to the VNNI-specialized entry point.
    MLAS_SQNBIT_GEMM_DISPATCH d;

    // fp32 compute path (same implementations as the plain AVX2 dispatch).
    d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32_avx2;
    d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2;

    // int8 compute path: VNNI kernel for the GEMM itself; A-row
    // quantization is shared with the AVX2 implementation.
    d.SQ4BitGemmKernel_BlkSum_CompInt8 = SQ4BitGemmKernel_BlkSum_CompInt8_avx2vnni;
    d.QuantizeARow_CompInt8_2 = QuantizeARow_CompInt8_avx2;

    // Quantized-B packing helpers and per-GEMM workspace queries.
    d.SQ4BitGemmPackQuantBDataSize = SQ4BitGemmPackQuantBDataSize;
    d.SQ4BitGemmPackQuantBData = SQ4BitGemmPackQuantBData;
    d.SQ4BitGemmPackQuantBDataAndBlkSum = SQ4BitGemmPackQuantBDataAndBlkSum;
    d.SQ4BitGemmPerGemmWorkspaceSize = SQ4BitGemmPerGemmWorkspaceSize;
    d.SQ4BitGemmPerGemmWorkspaceAlignment = SQ4BitGemmPerGemmWorkspaceAlignment;

    return d;
}();
Loading

0 comments on commit 4b91bed

Please sign in to comment.