Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enabling S8S8 and S8U8 handling in QGemm for AVX2 and AVX-VNNI #21123

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 32 additions & 4 deletions onnxruntime/core/mlas/lib/amd64/QgemmU8S8KernelAvx2.asm
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ GemmU8S8CopyPackAFrame STRUCT
SavedXmm8 OWORD ?
SavedXmm9 OWORD ?
SavedXmm10 OWORD ?
SavedXmm11 OWORD ?
Padding QWORD ?
SavedR13 QWORD ?
SavedR12 QWORD ?
Expand All @@ -49,6 +50,7 @@ GemmU8S8CopyPackAFrame STRUCT
PreviousP4Home QWORD ?
CountK QWORD ?
RowSumBuffer QWORD ?
AIsSigned QWORD ?

GemmU8S8CopyPackAFrame ENDS

Expand Down Expand Up @@ -121,6 +123,7 @@ GemmU8S8CopyPackBFrame ENDS
save_xmm128 xmm8,GemmU8S8CopyPackAFrame.SavedXmm8
save_xmm128 xmm9,GemmU8S8CopyPackAFrame.SavedXmm9
save_xmm128 xmm10,GemmU8S8CopyPackAFrame.SavedXmm10
save_xmm128 xmm11,GemmU8S8CopyPackAFrame.SavedXmm11

END_PROLOGUE

Expand All @@ -135,6 +138,16 @@ GemmU8S8CopyPackBFrame ENDS
vpsllw ymm9,ymm8,8 ; generate word vector [0x0100]
vpor ymm9,ymm8,ymm9 ; generate word vector [0x0101]

;
; Compute bit flip vector to convert S8 to U8
;
vpxor ymm11,ymm11,ymm11
cmp BYTE PTR GemmU8S8CopyPackAFrame.AIsSigned[rsp],0
jz SkipSignedBitFlipVector
vpsllw ymm11,ymm9,7 ; generate word vector [0x8080]

SkipSignedBitFlipVector:

;
; Compute the conditional load/store mask for an unaligned CountK.
;
Expand All @@ -148,12 +161,11 @@ GemmU8S8CopyPackBFrame ENDS
vmovdqu xmm10,XMMWORD PTR [rbx+rax*4]

;
; Zero initialize the padded stack buffers.
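; Initialize the padded stack buffers with the bit-flip vector: zero when A is
; unsigned, 0x80 bytes when A is signed, so that the XOR applied when the
; buffers are read back leaves the padding bytes as zero.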
;

vpxor xmm0,xmm0,xmm0
vmovdqu YMMWORD PTR GemmU8S8CopyPackAFrame.PaddedMatrixAData[rsp],ymm0
vmovdqu YMMWORD PTR GemmU8S8CopyPackAFrame.PaddedMatrixAData[rsp+32],ymm0
vmovdqu YMMWORD PTR GemmU8S8CopyPackAFrame.PaddedMatrixAData[rsp],ymm11
vmovdqu YMMWORD PTR GemmU8S8CopyPackAFrame.PaddedMatrixAData[rsp+32],ymm11

;
; Process 4 rows of matrix A in a loop.
Expand Down Expand Up @@ -182,6 +194,10 @@ ProcessNextColumnLoopM4:
vmovdqu ymm5,YMMWORD PTR [rdx+r8]
vmovdqu ymm6,YMMWORD PTR [rdx+r8*2]
vmovdqu ymm7,YMMWORD PTR [rdx+r13]
vpxor ymm4,ymm4,ymm11
vpxor ymm5,ymm5,ymm11
vpxor ymm6,ymm6,ymm11
vpxor ymm7,ymm7,ymm11
vmovdqu YMMWORD PTR [rcx],ymm4
vmovdqu YMMWORD PTR [rcx+r11],ymm5
vmovdqu YMMWORD PTR [rcx+r11*2],ymm6
Expand All @@ -208,6 +224,10 @@ ProcessRemainingColumnsM4:
vmovdqu xmm5,XMMWORD PTR [rdx+r8]
vmovdqu xmm6,XMMWORD PTR [rdx+r8*2]
vmovdqu xmm7,XMMWORD PTR [rdx+r13]
vpxor xmm4,xmm4,xmm11
vpxor xmm5,xmm5,xmm11
vpxor xmm6,xmm6,xmm11
vpxor xmm7,xmm7,xmm11
vmovdqu XMMWORD PTR [rcx],xmm4
vmovdqu XMMWORD PTR [rcx+r11],xmm5
vmovdqu XMMWORD PTR [rcx+r11*2],xmm6
Expand Down Expand Up @@ -295,6 +315,10 @@ ProcessPaddedMatrixADataM4:
vmovdqu xmm6,XMMWORD PTR GemmU8S8CopyPackAFrame.PaddedMatrixAData[rsp+32]
vmovdqu xmm7,XMMWORD PTR GemmU8S8CopyPackAFrame.PaddedMatrixAData[rsp+48]
lea rax,[rcx+r11*2] ; compute matrix D plus 2 rows
vpxor xmm4,xmm4,xmm11
vpxor xmm5,xmm5,xmm11
vpxor xmm6,xmm6,xmm11
vpxor xmm7,xmm7,xmm11
vpmaskmovd XMMWORD PTR [rcx],xmm10,xmm4
vpmaskmovd XMMWORD PTR [rcx+r11],xmm10,xmm5
vpmaskmovd XMMWORD PTR [rax],xmm10,xmm6
Expand Down Expand Up @@ -347,6 +371,7 @@ ProcessNextRowM1:

ProcessNextColumnLoopM1:
vmovdqu ymm4,YMMWORD PTR [rdx]
vpxor ymm4,ymm4,ymm11
vmovdqu YMMWORD PTR [rcx],ymm4
vpmaddubsw ymm4,ymm4,ymm9 ; horizontal byte+byte=word per row
vpaddw ymm0,ymm0,ymm4 ; add words to row accumulators
Expand All @@ -361,6 +386,7 @@ ProcessRemainingColumnsM1:
test bl,16 ; (CountK & 16) != 0?
jz CopyRemainingCountKLessThan16M1
vmovdqu xmm4,XMMWORD PTR [rdx]
vpxor xmm4,xmm4,xmm11
vmovdqu XMMWORD PTR [rcx],xmm4
vpmaddubsw xmm4,xmm4,xmm9 ; horizontal byte+byte=word per row
vpaddw ymm0,ymm0,ymm4 ; add words to row accumulators
Expand Down Expand Up @@ -411,6 +437,7 @@ CopyRemainingCountKLessThan2M1:

ProcessPaddedMatrixADataM1:
vmovdqu xmm4,XMMWORD PTR GemmU8S8CopyPackAFrame.PaddedMatrixAData[rsp]
vpxor xmm4,xmm4,xmm11
vpmaskmovd XMMWORD PTR [rcx],xmm10,xmm4
vpmaddubsw ymm4,ymm4,ymm9 ; horizontal byte+byte=word per row
vpaddw ymm0,ymm0,ymm4 ; add words to row accumulators
Expand Down Expand Up @@ -441,6 +468,7 @@ ExitRoutine:
movaps xmm8,GemmU8S8CopyPackAFrame.SavedXmm8[rsp]
movaps xmm9,GemmU8S8CopyPackAFrame.SavedXmm9[rsp]
movaps xmm10,GemmU8S8CopyPackAFrame.SavedXmm10[rsp]
movaps xmm11,GemmU8S8CopyPackAFrame.SavedXmm11[rsp]
add rsp,(GemmU8S8CopyPackAFrame.SavedR13)

BEGIN_EPILOGUE
Expand Down
2 changes: 2 additions & 0 deletions onnxruntime/core/mlas/lib/mlasi.h
Original file line number Diff line number Diff line change
Expand Up @@ -1081,6 +1081,8 @@ struct MLAS_PLATFORM {
#if defined(MLAS_TARGET_AMD64_IX86)
const MLAS_GEMM_QUANT_DISPATCH* GemmU8S8Dispatch;
const MLAS_GEMM_QUANT_DISPATCH* GemmU8U8Dispatch;
const MLAS_GEMM_QUANT_DISPATCH* GemmS8S8Dispatch{&MlasGemmQuantDispatchDefault};
const MLAS_GEMM_QUANT_DISPATCH* GemmS8U8Dispatch{&MlasGemmQuantDispatchDefault};
#elif defined(MLAS_TARGET_ARM64)
const MLAS_GEMM_QUANT_DISPATCH* GemmU8U8Dispatch;
const MLAS_GEMM_QUANT_DISPATCH* GemmU8S8Dispatch;
Expand Down
2 changes: 2 additions & 0 deletions onnxruntime/core/mlas/lib/platform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,8 @@ Return Value:
if ((Cpuid7_1[0] & 0x10) != 0) {

this->GemmU8U8Dispatch = &MlasGemmU8S8DispatchAvx2;
this->GemmS8S8Dispatch = &MlasGemmU8S8DispatchAvx2;
this->GemmS8U8Dispatch = &MlasGemmU8S8DispatchAvx2;
this->GemmU8S8Kernel = MlasGemmU8S8KernelAvxVnni;
this->GemvU8S8Kernel = MlasGemvU8S8KernelAvxVnni;
this->ConvSymU8S8Dispatch = &MlasConvSymDispatchAvxVnni;
Expand Down
15 changes: 7 additions & 8 deletions onnxruntime/core/mlas/lib/qgemm.h
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,7 @@ Return Value:

//
// Fixup the sign bit of the per-matrix zero point offset of matrix A if the
// kernel requires signed data.
// kernel requires opposite-signed data.
//

ZeroPointA = MlasGemmQuantFixupZeroPointA<KernelType>(ZeroPointA, Shape->AIsSigned);
Expand Down Expand Up @@ -872,13 +872,12 @@ MlasGemmQuantGetDispatch(
}

#if defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_LARCH64)
if (!AIsSigned) {
if (BIsSigned) {
GemmQuantDispatch = GetMlasPlatform().GemmU8S8Dispatch;
}
else {
GemmQuantDispatch = GetMlasPlatform().GemmU8U8Dispatch;
}
if (AIsSigned) {
GemmQuantDispatch =
BIsSigned ? GetMlasPlatform().GemmS8S8Dispatch : GetMlasPlatform().GemmS8U8Dispatch;
} else {
GemmQuantDispatch =
BIsSigned ? GetMlasPlatform().GemmU8S8Dispatch : GetMlasPlatform().GemmU8U8Dispatch;
}
#elif defined(MLAS_TARGET_ARM64)
if(BIsSigned) {
Expand Down
21 changes: 18 additions & 3 deletions onnxruntime/core/mlas/lib/qgemm_kernel_avx2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ extern "C" {
size_t lda,
size_t CountM,
size_t CountK,
int32_t* RowSumBuffer
int32_t* RowSumBuffer,
bool AIsSigned
);

void
Expand Down Expand Up @@ -114,6 +115,21 @@ MlasGemmQuantTryGemvKernel<MLAS_GEMM_U8S8_KERNEL_AVX2>(
return false;
}

//
// Adjusts the zero point of matrix A for the AVX2 U8S8 kernel. When A holds
// signed data, the 0x80 bit of the zero point is toggled and the result is
// converted through the kernel's OffsetAType; otherwise the zero point is
// returned untouched.
//
// NOTE(review): presumably this mirrors the 0x80 bit flip that the packing
// routine applies to signed A data -- confirm against the CopyPackA kernel.
//
template<>
MLAS_FORCEINLINE constexpr
int32_t
MlasGemmQuantFixupZeroPointA<MLAS_GEMM_U8S8_KERNEL_AVX2>(
    int32_t ZeroPointA,
    bool AIsSigned
    )
{
    return AIsSigned
        ? static_cast<int32_t>(MLAS_GEMM_U8S8_KERNEL_AVX2::OffsetAType(ZeroPointA ^ 0x80))
        : ZeroPointA;
}

template<>
MLAS_FORCEINLINE constexpr
int32_t
Expand Down Expand Up @@ -142,8 +158,7 @@ MlasGemmQuantCopyPackA<MLAS_GEMM_U8S8_KERNEL_AVX2>(
bool AIsSigned
)
{
MLAS_UNREFERENCED_PARAMETER(AIsSigned);
MlasGemmU8S8CopyPackAAvx2(D, A, lda, CountM, CountK, RowSumBuffer);
MlasGemmU8S8CopyPackAAvx2(D, A, lda, CountM, CountK, RowSumBuffer, AIsSigned);
}

template<>
Expand Down
33 changes: 29 additions & 4 deletions onnxruntime/core/mlas/lib/x86_64/QgemmU8S8KernelAvx2.S
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ Abstract:
.equ .LGemmU8S8CopyPackAFrame_SavedRbx, 16
.equ .LGemmU8S8CopyPackAFrame_SavedRbp, 24
.equ .LGemmU8S8CopyPackAFrame_ReturnAddress, 32
.equ .LGemmU8S8CopyPackAFrame_AIsSigned, 40

//
// Stack frame layout for the U8S8 CopyPackB routine.
Expand Down Expand Up @@ -91,6 +92,16 @@ Return Value:
vpsllw ymm9,ymm8,8 # generate word vector [0x0100]
vpor ymm9,ymm8,ymm9 # generate word vector [0x0101]

//
// Compute bit flip vector to convert S8 to U8
//
vpxor ymm11,ymm11,ymm11
cmp BYTE PTR .LGemmU8S8CopyPackAFrame_AIsSigned[rsp],0
jz .LCopyPackA.SkipSignedBitFlipVector
vpsllw ymm11,ymm9,7 # generate word vector [0x8080]

.LCopyPackA.SkipSignedBitFlipVector:

//
// Compute the conditional load/store mask for an unaligned CountK.
//
Expand All @@ -104,12 +115,11 @@ Return Value:
vmovdqu xmm10,XMMWORD PTR [rbx+rax*4]

//
// Zero initialize the padded stack buffers.
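// Initialize the padded stack buffers with the bit-flip vector: zero when A is
// unsigned, 0x80 bytes when A is signed, so that the XOR applied when the
// buffers are read back leaves the padding bytes as zero.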
//

vpxor xmm0,xmm0,xmm0
vmovdqu YMMWORD PTR .LGemmU8S8CopyPackAFrame_PaddedMatrixAData[rsp],ymm0
vmovdqu YMMWORD PTR .LGemmU8S8CopyPackAFrame_PaddedMatrixAData[rsp+32],ymm0
vmovdqu YMMWORD PTR .LGemmU8S8CopyPackAFrame_PaddedMatrixAData[rsp],ymm11
vmovdqu YMMWORD PTR .LGemmU8S8CopyPackAFrame_PaddedMatrixAData[rsp+32],ymm11

//
// Process 4 rows of matrix A in a loop.
Expand Down Expand Up @@ -138,6 +148,10 @@ Return Value:
vmovdqu ymm5,YMMWORD PTR [rdx+r10]
vmovdqu ymm6,YMMWORD PTR [rdx+r10*2]
vmovdqu ymm7,YMMWORD PTR [rdx+r13]
vpxor ymm4,ymm4,ymm11
vpxor ymm5,ymm5,ymm11
vpxor ymm6,ymm6,ymm11
vpxor ymm7,ymm7,ymm11
vmovdqu YMMWORD PTR [rcx],ymm4
vmovdqu YMMWORD PTR [rcx+r12],ymm5
vmovdqu YMMWORD PTR [rcx+r12*2],ymm6
Expand All @@ -164,6 +178,10 @@ Return Value:
vmovdqu xmm5,XMMWORD PTR [rdx+r10]
vmovdqu xmm6,XMMWORD PTR [rdx+r10*2]
vmovdqu xmm7,XMMWORD PTR [rdx+r13]
vpxor xmm4,xmm4,xmm11
vpxor xmm5,xmm5,xmm11
vpxor xmm6,xmm6,xmm11
vpxor xmm7,xmm7,xmm11
vmovdqu XMMWORD PTR [rcx],xmm4
vmovdqu XMMWORD PTR [rcx+r12],xmm5
vmovdqu XMMWORD PTR [rcx+r12*2],xmm6
Expand Down Expand Up @@ -250,6 +268,10 @@ Return Value:
vmovdqu xmm6,XMMWORD PTR .LGemmU8S8CopyPackAFrame_PaddedMatrixAData[rsp+32]
vmovdqu xmm7,XMMWORD PTR .LGemmU8S8CopyPackAFrame_PaddedMatrixAData[rsp+48]
lea rax,[rcx+r12*2] # compute matrix D plus 2 rows
vpxor xmm4,xmm4,xmm11
vpxor xmm5,xmm5,xmm11
vpxor xmm6,xmm6,xmm11
vpxor xmm7,xmm7,xmm11
vpmaskmovd XMMWORD PTR [rcx],xmm10,xmm4
vpmaskmovd XMMWORD PTR [rcx+r12],xmm10,xmm5
vpmaskmovd XMMWORD PTR [rax],xmm10,xmm6
Expand Down Expand Up @@ -302,6 +324,7 @@ Return Value:

.LCopyPackA.ProcessNextColumnLoopM1:
vmovdqu ymm4,YMMWORD PTR [rdx]
vpxor ymm4,ymm4,ymm11
vmovdqu YMMWORD PTR [rcx],ymm4
vpmaddubsw ymm4,ymm4,ymm9 # horizontal byte+byte=word per row
vpaddw ymm0,ymm0,ymm4 # add words to row accumulators
Expand All @@ -316,6 +339,7 @@ Return Value:
test bl,16 # (CountK & 16) != 0?
jz .LCopyPackA.CopyRemainingCountKLessThan16M1
vmovdqu xmm4,XMMWORD PTR [rdx]
vpxor xmm4,xmm4,xmm11
vmovdqu XMMWORD PTR [rcx],xmm4
vpmaddubsw xmm4,xmm4,xmm9 # horizontal byte+byte=word per row
vpaddw ymm0,ymm0,ymm4 # add words to row accumulators
Expand Down Expand Up @@ -365,6 +389,7 @@ Return Value:

.LCopyPackA.ProcessPaddedMatrixADataM1:
vmovdqu xmm4,XMMWORD PTR .LGemmU8S8CopyPackAFrame_PaddedMatrixAData[rsp]
vpxor xmm4,xmm4,xmm11
vpmaskmovd XMMWORD PTR [rcx],xmm10,xmm4
vpmaddubsw ymm4,ymm4,ymm9 # horizontal byte+byte=word per row
vpaddw ymm0,ymm0,ymm4 # accumulate per row along columns
Expand Down
22 changes: 20 additions & 2 deletions onnxruntime/test/mlas/unittest/test_qgemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ static size_t QGemmRegistLongExecute() {
count += MlasLongExecuteTests<MlasQgemmTest<uint8_t, uint8_t, int32_t, true, false>>::RegisterLongExecute();
count += MlasLongExecuteTests<MlasQgemmTest<int8_t, int8_t, int32_t, false, false>>::RegisterLongExecute();
count += MlasLongExecuteTests<MlasQgemmTest<int8_t, int8_t, int32_t, true, false>>::RegisterLongExecute();
count += MlasLongExecuteTests<MlasQgemmTest<int8_t, uint8_t, int32_t, false, false>>::RegisterLongExecute();
count += MlasLongExecuteTests<MlasQgemmTest<int8_t, uint8_t, int32_t, true, false>>::RegisterLongExecute();

if (GetMlasThreadPool() != nullptr) {
count += MlasLongExecuteTests<MlasQgemmTest<uint8_t, int8_t, int32_t, false, true>>::RegisterLongExecute();
Expand All @@ -18,6 +20,8 @@ static size_t QGemmRegistLongExecute() {
count += MlasLongExecuteTests<MlasQgemmTest<uint8_t, uint8_t, int32_t, true, true>>::RegisterLongExecute();
count += MlasLongExecuteTests<MlasQgemmTest<int8_t, int8_t, int32_t, false, true>>::RegisterLongExecute();
count += MlasLongExecuteTests<MlasQgemmTest<int8_t, int8_t, int32_t, true, true>>::RegisterLongExecute();
count += MlasLongExecuteTests<MlasQgemmTest<int8_t, uint8_t, int32_t, false, true>>::RegisterLongExecute();
count += MlasLongExecuteTests<MlasQgemmTest<int8_t, uint8_t, int32_t, true, true>>::RegisterLongExecute();
}

return count;
Expand All @@ -32,6 +36,8 @@ static size_t QGemmRegistShortExecute() {
count += QgemmShortExecuteTest<uint8_t, uint8_t, int32_t, false, false>::RegisterShortExecuteTests();
count += QgemmShortExecuteTest<int8_t, int8_t, float, false, false>::RegisterShortExecuteTests();
count += QgemmShortExecuteTest<int8_t, int8_t, int32_t, false, false>::RegisterShortExecuteTests();
count += QgemmShortExecuteTest<int8_t, uint8_t, float, false, false>::RegisterShortExecuteTests();
count += QgemmShortExecuteTest<int8_t, uint8_t, int32_t, false, false>::RegisterShortExecuteTests();
if (MlasGemmPackBSize(128, 128, false /*AIsSigned*/, false /*BIsSigned*/) > 0) {
// QGEMM U8U8=float packed tests
count += QgemmShortExecuteTest<uint8_t, uint8_t, float, true, false>::RegisterShortExecuteTests();
Expand All @@ -45,11 +51,17 @@ static size_t QGemmRegistShortExecute() {
count += QgemmShortExecuteTest<uint8_t, int8_t, int32_t, true, false>::RegisterShortExecuteTests();
}
if (MlasGemmPackBSize(128, 128, true /*AIsSigned*/, true /*BIsSigned*/) > 0) {
// QGEMM U8S8=float packed tests
// QGEMM S8S8=float packed tests
count += QgemmShortExecuteTest<int8_t, int8_t, float, true, false>::RegisterShortExecuteTests();
// QGEMM U8S8=int32_t packed tests
// QGEMM S8S8=int32_t packed tests
count += QgemmShortExecuteTest<int8_t, int8_t, int32_t, true, false>::RegisterShortExecuteTests();
}
if (MlasGemmPackBSize(128, 128, true /*AIsSigned*/, false /*BIsSigned*/) > 0) {
// QGEMM S8U8=float packed tests
count += QgemmShortExecuteTest<int8_t, uint8_t, float, true, false>::RegisterShortExecuteTests();
// QGEMM S8U8=int32_t packed tests
count += QgemmShortExecuteTest<int8_t, uint8_t, int32_t, true, false>::RegisterShortExecuteTests();
}

if (GetMlasThreadPool() != nullptr) {
count += QgemmShortExecuteTest<uint8_t, int8_t, float, false, true>::RegisterShortExecuteTests();
Expand All @@ -58,6 +70,8 @@ static size_t QGemmRegistShortExecute() {
count += QgemmShortExecuteTest<uint8_t, uint8_t, int32_t, false, true>::RegisterShortExecuteTests();
count += QgemmShortExecuteTest<int8_t, int8_t, float, false, true>::RegisterShortExecuteTests();
count += QgemmShortExecuteTest<int8_t, int8_t, int32_t, false, true>::RegisterShortExecuteTests();
count += QgemmShortExecuteTest<int8_t, uint8_t, float, false, true>::RegisterShortExecuteTests();
count += QgemmShortExecuteTest<int8_t, uint8_t, int32_t, false, true>::RegisterShortExecuteTests();
if (MlasGemmPackBSize(128, 128, false /*AIsSigned*/, false /*BIsSigned*/) > 0) {
count += QgemmShortExecuteTest<uint8_t, uint8_t, float, true, true>::RegisterShortExecuteTests();
count += QgemmShortExecuteTest<uint8_t, uint8_t, int32_t, true, true>::RegisterShortExecuteTests();
Expand All @@ -70,6 +84,10 @@ static size_t QGemmRegistShortExecute() {
count += QgemmShortExecuteTest<int8_t, int8_t, float, true, true>::RegisterShortExecuteTests();
count += QgemmShortExecuteTest<int8_t, int8_t, int32_t, true, true>::RegisterShortExecuteTests();
}
if (MlasGemmPackBSize(128, 128, true /*AIsSigned*/, false /*BIsSigned*/) > 0) {
count += QgemmShortExecuteTest<int8_t, uint8_t, float, true, true>::RegisterShortExecuteTests();
count += QgemmShortExecuteTest<int8_t, uint8_t, int32_t, true, true>::RegisterShortExecuteTests();
}
}

return count;
Expand Down
Loading