diff --git a/onnxruntime/core/mlas/lib/amd64/QgemmU8S8KernelAvx2.asm b/onnxruntime/core/mlas/lib/amd64/QgemmU8S8KernelAvx2.asm index cccea70c140c8..3677fca30e307 100644 --- a/onnxruntime/core/mlas/lib/amd64/QgemmU8S8KernelAvx2.asm +++ b/onnxruntime/core/mlas/lib/amd64/QgemmU8S8KernelAvx2.asm @@ -35,6 +35,7 @@ GemmU8S8CopyPackAFrame STRUCT SavedXmm8 OWORD ? SavedXmm9 OWORD ? SavedXmm10 OWORD ? + SavedXmm11 OWORD ? Padding QWORD ? SavedR13 QWORD ? SavedR12 QWORD ? @@ -49,6 +50,7 @@ GemmU8S8CopyPackAFrame STRUCT PreviousP4Home QWORD ? CountK QWORD ? RowSumBuffer QWORD ? + AIsSigned QWORD ? GemmU8S8CopyPackAFrame ENDS @@ -121,6 +123,7 @@ GemmU8S8CopyPackBFrame ENDS save_xmm128 xmm8,GemmU8S8CopyPackAFrame.SavedXmm8 save_xmm128 xmm9,GemmU8S8CopyPackAFrame.SavedXmm9 save_xmm128 xmm10,GemmU8S8CopyPackAFrame.SavedXmm10 + save_xmm128 xmm11,GemmU8S8CopyPackAFrame.SavedXmm11 END_PROLOGUE @@ -135,6 +138,16 @@ GemmU8S8CopyPackBFrame ENDS vpsllw ymm9,ymm8,8 ; generate word vector [0x0100] vpor ymm9,ymm8,ymm9 ; generate word vector [0x0101] +; +; Compute bit flip vector to convert S8 to U8 +; + vpxor ymm11,ymm11,ymm11 + cmp BYTE PTR GemmU8S8CopyPackAFrame.AIsSigned[rsp],0 + jz SkipSignedBitFlipVector + vpsllw ymm11,ymm9,7 ; generate word vector [0x8080] + +SkipSignedBitFlipVector: + ; ; Compute the conditional load/store mask for an unaligned CountK. ; @@ -148,12 +161,11 @@ GemmU8S8CopyPackBFrame ENDS vmovdqu xmm10,XMMWORD PTR [rbx+rax*4] ; -; Zero initialize the padded stack buffers. +; Initialize the padded stack buffers. Zeroed if unsigned, bit-flip values if signed ; - vpxor xmm0,xmm0,xmm0 - vmovdqu YMMWORD PTR GemmU8S8CopyPackAFrame.PaddedMatrixAData[rsp],ymm0 - vmovdqu YMMWORD PTR GemmU8S8CopyPackAFrame.PaddedMatrixAData[rsp+32],ymm0 + vmovdqu YMMWORD PTR GemmU8S8CopyPackAFrame.PaddedMatrixAData[rsp],ymm11 + vmovdqu YMMWORD PTR GemmU8S8CopyPackAFrame.PaddedMatrixAData[rsp+32],ymm11 ; ; Process 4 rows of matrix A in a loop. @@ -182,6 +194,10 @@ ProcessNextColumnLoopM4: vmovdqu ymm5,YMMWORD PTR [rdx+r8] vmovdqu ymm6,YMMWORD PTR [rdx+r8*2] vmovdqu ymm7,YMMWORD PTR [rdx+r13] + vpxor ymm4,ymm4,ymm11 + vpxor ymm5,ymm5,ymm11 + vpxor ymm6,ymm6,ymm11 + vpxor ymm7,ymm7,ymm11 vmovdqu YMMWORD PTR [rcx],ymm4 vmovdqu YMMWORD PTR [rcx+r11],ymm5 vmovdqu YMMWORD PTR [rcx+r11*2],ymm6 @@ -208,6 +224,10 @@ ProcessRemainingColumnsM4: vmovdqu xmm5,XMMWORD PTR [rdx+r8] vmovdqu xmm6,XMMWORD PTR [rdx+r8*2] vmovdqu xmm7,XMMWORD PTR [rdx+r13] + vpxor xmm4,xmm4,xmm11 + vpxor xmm5,xmm5,xmm11 + vpxor xmm6,xmm6,xmm11 + vpxor xmm7,xmm7,xmm11 vmovdqu XMMWORD PTR [rcx],xmm4 vmovdqu XMMWORD PTR [rcx+r11],xmm5 vmovdqu XMMWORD PTR [rcx+r11*2],xmm6 @@ -295,6 +315,10 @@ ProcessPaddedMatrixADataM4: vmovdqu xmm6,XMMWORD PTR GemmU8S8CopyPackAFrame.PaddedMatrixAData[rsp+32] vmovdqu xmm7,XMMWORD PTR GemmU8S8CopyPackAFrame.PaddedMatrixAData[rsp+48] lea rax,[rcx+r11*2] ; compute matrix D plus 2 rows + vpxor xmm4,xmm4,xmm11 + vpxor xmm5,xmm5,xmm11 + vpxor xmm6,xmm6,xmm11 + vpxor xmm7,xmm7,xmm11 vpmaskmovd XMMWORD PTR [rcx],xmm10,xmm4 vpmaskmovd XMMWORD PTR [rcx+r11],xmm10,xmm5 vpmaskmovd XMMWORD PTR [rax],xmm10,xmm6 @@ -347,6 +371,7 @@ ProcessNextRowM1: ProcessNextColumnLoopM1: vmovdqu ymm4,YMMWORD PTR [rdx] + vpxor ymm4,ymm4,ymm11 vmovdqu YMMWORD PTR [rcx],ymm4 vpmaddubsw ymm4,ymm4,ymm9 ; horizontal byte+byte=word per row vpaddw ymm0,ymm0,ymm4 ; add words to row accumulators @@ -361,6 +386,7 @@ ProcessRemainingColumnsM1: test bl,16 ; (CountK & 16) != 0? jz CopyRemainingCountKLessThan16M1 vmovdqu xmm4,XMMWORD PTR [rdx] + vpxor xmm4,xmm4,xmm11 vmovdqu XMMWORD PTR [rcx],xmm4 vpmaddubsw xmm4,xmm4,xmm9 ; horizontal byte+byte=word per row vpaddw ymm0,ymm0,ymm4 ; add words to row accumulators @@ -411,6 +437,7 @@ CopyRemainingCountKLessThan2M1: ProcessPaddedMatrixADataM1: vmovdqu xmm4,XMMWORD PTR GemmU8S8CopyPackAFrame.PaddedMatrixAData[rsp] + vpxor xmm4,xmm4,xmm11 vpmaskmovd XMMWORD PTR [rcx],xmm10,xmm4 vpmaddubsw ymm4,ymm4,ymm9 ; horizontal byte+byte=word per row vpaddw ymm0,ymm0,ymm4 ; add words to row accumulators @@ -441,6 +468,7 @@ ExitRoutine: movaps xmm8,GemmU8S8CopyPackAFrame.SavedXmm8[rsp] movaps xmm9,GemmU8S8CopyPackAFrame.SavedXmm9[rsp] movaps xmm10,GemmU8S8CopyPackAFrame.SavedXmm10[rsp] + movaps xmm11,GemmU8S8CopyPackAFrame.SavedXmm11[rsp] add rsp,(GemmU8S8CopyPackAFrame.SavedR13) BEGIN_EPILOGUE diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index 83200187963e1..648609cba83fe 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -1081,6 +1081,8 @@ struct MLAS_PLATFORM { #if defined(MLAS_TARGET_AMD64_IX86) const MLAS_GEMM_QUANT_DISPATCH* GemmU8S8Dispatch; const MLAS_GEMM_QUANT_DISPATCH* GemmU8U8Dispatch; + const MLAS_GEMM_QUANT_DISPATCH* GemmS8S8Dispatch{&MlasGemmQuantDispatchDefault}; + const MLAS_GEMM_QUANT_DISPATCH* GemmS8U8Dispatch{&MlasGemmQuantDispatchDefault}; #elif defined(MLAS_TARGET_ARM64) const MLAS_GEMM_QUANT_DISPATCH* GemmU8U8Dispatch; const MLAS_GEMM_QUANT_DISPATCH* GemmU8S8Dispatch; diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index 72eb35c894094..388d4f9ccb401 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -399,6 +399,8 @@ Return Value: if ((Cpuid7_1[0] & 0x10) != 0) { this->GemmU8U8Dispatch = &MlasGemmU8S8DispatchAvx2; + this->GemmS8S8Dispatch = &MlasGemmU8S8DispatchAvx2; + this->GemmS8U8Dispatch = &MlasGemmU8S8DispatchAvx2; this->GemmU8S8Kernel = MlasGemmU8S8KernelAvxVnni; this->GemvU8S8Kernel = MlasGemvU8S8KernelAvxVnni; this->ConvSymU8S8Dispatch = &MlasConvSymDispatchAvxVnni; diff --git a/onnxruntime/core/mlas/lib/qgemm.h b/onnxruntime/core/mlas/lib/qgemm.h index 75c17a6b5a177..b34881efa0f4b 100644 --- a/onnxruntime/core/mlas/lib/qgemm.h +++ b/onnxruntime/core/mlas/lib/qgemm.h @@ -337,7 +337,7 @@ Return Value: // // Fixup the sign bit of the per-matrix zero point offset of matrix A if the - // kernel requires signed data. + // kernel requires opposite-signed data. // ZeroPointA = MlasGemmQuantFixupZeroPointA(ZeroPointA, Shape->AIsSigned); @@ -872,13 +872,12 @@ MlasGemmQuantGetDispatch( } #if defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_LARCH64) - if (!AIsSigned) { - if (BIsSigned) { - GemmQuantDispatch = GetMlasPlatform().GemmU8S8Dispatch; - } - else { - GemmQuantDispatch = GetMlasPlatform().GemmU8U8Dispatch; - } + if (AIsSigned) { + GemmQuantDispatch = + BIsSigned ? GetMlasPlatform().GemmS8S8Dispatch : GetMlasPlatform().GemmS8U8Dispatch; + } else { + GemmQuantDispatch = + BIsSigned ? GetMlasPlatform().GemmU8S8Dispatch : GetMlasPlatform().GemmU8U8Dispatch; } #elif defined(MLAS_TARGET_ARM64) if(BIsSigned) { diff --git a/onnxruntime/core/mlas/lib/qgemm_kernel_avx2.cpp b/onnxruntime/core/mlas/lib/qgemm_kernel_avx2.cpp index deec324d01401..43068cd73151c 100644 --- a/onnxruntime/core/mlas/lib/qgemm_kernel_avx2.cpp +++ b/onnxruntime/core/mlas/lib/qgemm_kernel_avx2.cpp @@ -38,7 +38,8 @@ extern "C" { size_t lda, size_t CountM, size_t CountK, - int32_t* RowSumBuffer + int32_t* RowSumBuffer, + bool AIsSigned ); void @@ -114,6 +115,21 @@ MlasGemmQuantTryGemvKernel( return false; } +template<> +MLAS_FORCEINLINE constexpr +int32_t +MlasGemmQuantFixupZeroPointA( + int32_t ZeroPointA, + bool AIsSigned + ) +{ + if (AIsSigned) { + ZeroPointA = MLAS_GEMM_U8S8_KERNEL_AVX2::OffsetAType(ZeroPointA ^ 0x80); + } + + return ZeroPointA; +} + template<> MLAS_FORCEINLINE constexpr int32_t @@ -142,8 +158,7 @@ MlasGemmQuantCopyPackA( bool AIsSigned ) { - MLAS_UNREFERENCED_PARAMETER(AIsSigned); - MlasGemmU8S8CopyPackAAvx2(D, A, lda, CountM, CountK, RowSumBuffer); + MlasGemmU8S8CopyPackAAvx2(D, A, lda, CountM, CountK, RowSumBuffer, AIsSigned); } template<> diff --git a/onnxruntime/core/mlas/lib/x86_64/QgemmU8S8KernelAvx2.S b/onnxruntime/core/mlas/lib/x86_64/QgemmU8S8KernelAvx2.S index 5068d41ceaf35..d54e37a9108a8 100644 --- a/onnxruntime/core/mlas/lib/x86_64/QgemmU8S8KernelAvx2.S +++ b/onnxruntime/core/mlas/lib/x86_64/QgemmU8S8KernelAvx2.S @@ -32,6 +32,7 @@ Abstract: .equ .LGemmU8S8CopyPackAFrame_SavedRbx, 16 .equ .LGemmU8S8CopyPackAFrame_SavedRbp, 24 .equ .LGemmU8S8CopyPackAFrame_ReturnAddress, 32 + .equ .LGemmU8S8CopyPackAFrame_AIsSigned, 40 // // Stack frame layout for the U8S8 CopyPackB routine. @@ -91,6 +92,16 @@ Return Value: vpsllw ymm9,ymm8,8 # generate word vector [0x0100] vpor ymm9,ymm8,ymm9 # generate word vector [0x0101] +// +// Compute bit flip vector to convert S8 to U8 +// + vpxor ymm11,ymm11,ymm11 + cmp BYTE PTR .LGemmU8S8CopyPackAFrame_AIsSigned[rsp],0 + jz .LCopyPackA.SkipSignedBitFlipVector + vpsllw ymm11,ymm9,7 # generate word vector [0x8080] + +.LCopyPackA.SkipSignedBitFlipVector: + // // Compute the conditional load/store mask for an unaligned CountK. // @@ -104,12 +115,11 @@ Return Value: vmovdqu xmm10,XMMWORD PTR [rbx+rax*4] // -// Zero initialize the padded stack buffers. +// Initialize the padded stack buffers. Zeroed if unsigned, bit-flip values if signed // - vpxor xmm0,xmm0,xmm0 - vmovdqu YMMWORD PTR .LGemmU8S8CopyPackAFrame_PaddedMatrixAData[rsp],ymm0 - vmovdqu YMMWORD PTR .LGemmU8S8CopyPackAFrame_PaddedMatrixAData[rsp+32],ymm0 + vmovdqu YMMWORD PTR .LGemmU8S8CopyPackAFrame_PaddedMatrixAData[rsp],ymm11 + vmovdqu YMMWORD PTR .LGemmU8S8CopyPackAFrame_PaddedMatrixAData[rsp+32],ymm11 // // Process 4 rows of matrix A in a loop. @@ -138,6 +148,10 @@ Return Value: vmovdqu ymm5,YMMWORD PTR [rdx+r10] vmovdqu ymm6,YMMWORD PTR [rdx+r10*2] vmovdqu ymm7,YMMWORD PTR [rdx+r13] + vpxor ymm4,ymm4,ymm11 + vpxor ymm5,ymm5,ymm11 + vpxor ymm6,ymm6,ymm11 + vpxor ymm7,ymm7,ymm11 vmovdqu YMMWORD PTR [rcx],ymm4 vmovdqu YMMWORD PTR [rcx+r12],ymm5 vmovdqu YMMWORD PTR [rcx+r12*2],ymm6 @@ -164,6 +178,10 @@ Return Value: vmovdqu xmm5,XMMWORD PTR [rdx+r10] vmovdqu xmm6,XMMWORD PTR [rdx+r10*2] vmovdqu xmm7,XMMWORD PTR [rdx+r13] + vpxor xmm4,xmm4,xmm11 + vpxor xmm5,xmm5,xmm11 + vpxor xmm6,xmm6,xmm11 + vpxor xmm7,xmm7,xmm11 vmovdqu XMMWORD PTR [rcx],xmm4 vmovdqu XMMWORD PTR [rcx+r12],xmm5 vmovdqu XMMWORD PTR [rcx+r12*2],xmm6 @@ -250,6 +268,10 @@ Return Value: vmovdqu xmm6,XMMWORD PTR .LGemmU8S8CopyPackAFrame_PaddedMatrixAData[rsp+32] vmovdqu xmm7,XMMWORD PTR .LGemmU8S8CopyPackAFrame_PaddedMatrixAData[rsp+48] lea rax,[rcx+r12*2] # compute matrix D plus 2 rows + vpxor xmm4,xmm4,xmm11 + vpxor xmm5,xmm5,xmm11 + vpxor xmm6,xmm6,xmm11 + vpxor xmm7,xmm7,xmm11 vpmaskmovd XMMWORD PTR [rcx],xmm10,xmm4 vpmaskmovd XMMWORD PTR [rcx+r12],xmm10,xmm5 vpmaskmovd XMMWORD PTR [rax],xmm10,xmm6 @@ -302,6 +324,7 @@ Return Value: .LCopyPackA.ProcessNextColumnLoopM1: vmovdqu ymm4,YMMWORD PTR [rdx] + vpxor ymm4,ymm4,ymm11 vmovdqu YMMWORD PTR [rcx],ymm4 vpmaddubsw ymm4,ymm4,ymm9 # horizontal byte+byte=word per row vpaddw ymm0,ymm0,ymm4 # add words to row accumulators @@ -316,6 +339,7 @@ Return Value: test bl,16 # (CountK & 16) != 0? jz .LCopyPackA.CopyRemainingCountKLessThan16M1 vmovdqu xmm4,XMMWORD PTR [rdx] + vpxor xmm4,xmm4,xmm11 vmovdqu XMMWORD PTR [rcx],xmm4 vpmaddubsw xmm4,xmm4,xmm9 # horizontal byte+byte=word per row vpaddw ymm0,ymm0,ymm4 # add words to row accumulators @@ -365,6 +389,7 @@ Return Value: .LCopyPackA.ProcessPaddedMatrixADataM1: vmovdqu xmm4,XMMWORD PTR .LGemmU8S8CopyPackAFrame_PaddedMatrixAData[rsp] + vpxor xmm4,xmm4,xmm11 vpmaskmovd XMMWORD PTR [rcx],xmm10,xmm4 vpmaddubsw ymm4,ymm4,ymm9 # horizontal byte+byte=word per row vpaddw ymm0,ymm0,ymm4 # accumulate per row along columns diff --git a/onnxruntime/test/mlas/unittest/test_qgemm.cpp b/onnxruntime/test/mlas/unittest/test_qgemm.cpp index 6bb93d35357f8..12955e6f04688 100644 --- a/onnxruntime/test/mlas/unittest/test_qgemm.cpp +++ b/onnxruntime/test/mlas/unittest/test_qgemm.cpp @@ -10,6 +10,8 @@ static size_t QGemmRegistLongExecute() { count += MlasLongExecuteTests>::RegisterLongExecute(); count += MlasLongExecuteTests>::RegisterLongExecute(); count += MlasLongExecuteTests>::RegisterLongExecute(); + count += MlasLongExecuteTests>::RegisterLongExecute(); + count += MlasLongExecuteTests>::RegisterLongExecute(); if (GetMlasThreadPool() != nullptr) { count += MlasLongExecuteTests>::RegisterLongExecute(); @@ -18,6 +20,8 @@ static size_t QGemmRegistLongExecute() { count += MlasLongExecuteTests>::RegisterLongExecute(); count += MlasLongExecuteTests>::RegisterLongExecute(); count += MlasLongExecuteTests>::RegisterLongExecute(); + count += MlasLongExecuteTests>::RegisterLongExecute(); + count += MlasLongExecuteTests>::RegisterLongExecute(); } return count; @@ -32,6 +36,8 @@ static size_t QGemmRegistShortExecute() { count += QgemmShortExecuteTest::RegisterShortExecuteTests(); count += QgemmShortExecuteTest::RegisterShortExecuteTests(); count += QgemmShortExecuteTest::RegisterShortExecuteTests(); + count += QgemmShortExecuteTest::RegisterShortExecuteTests(); + count += QgemmShortExecuteTest::RegisterShortExecuteTests(); if (MlasGemmPackBSize(128, 128, false /*AIsSigned*/, false /*BIsSigned*/) > 0) { // QGEMM U8U8=float packed tests count += QgemmShortExecuteTest::RegisterShortExecuteTests(); @@ -45,11 +51,17 @@ static size_t QGemmRegistShortExecute() { count += QgemmShortExecuteTest::RegisterShortExecuteTests(); } if (MlasGemmPackBSize(128, 128, true /*AIsSigned*/, true /*BIsSigned*/) > 0) { - // QGEMM U8S8=float packed tests + // QGEMM S8S8=float packed tests count += QgemmShortExecuteTest::RegisterShortExecuteTests(); - // QGEMM U8S8=int32_t packed tests + // QGEMM S8S8=int32_t packed tests count += QgemmShortExecuteTest::RegisterShortExecuteTests(); } + if (MlasGemmPackBSize(128, 128, true /*AIsSigned*/, false /*BIsSigned*/) > 0) { + // QGEMM S8U8=float packed tests + count += QgemmShortExecuteTest::RegisterShortExecuteTests(); + // QGEMM S8U8=int32_t packed tests + count += QgemmShortExecuteTest::RegisterShortExecuteTests(); + } if (GetMlasThreadPool() != nullptr) { count += QgemmShortExecuteTest::RegisterShortExecuteTests(); @@ -58,6 +70,8 @@ static size_t QGemmRegistShortExecute() { count += QgemmShortExecuteTest::RegisterShortExecuteTests(); count += QgemmShortExecuteTest::RegisterShortExecuteTests(); count += QgemmShortExecuteTest::RegisterShortExecuteTests(); + count += QgemmShortExecuteTest::RegisterShortExecuteTests(); + count += QgemmShortExecuteTest::RegisterShortExecuteTests(); if (MlasGemmPackBSize(128, 128, false /*AIsSigned*/, false /*BIsSigned*/) > 0) { count += QgemmShortExecuteTest::RegisterShortExecuteTests(); count += QgemmShortExecuteTest::RegisterShortExecuteTests(); @@ -70,6 +84,10 @@ static size_t QGemmRegistShortExecute() { count += QgemmShortExecuteTest::RegisterShortExecuteTests(); count += QgemmShortExecuteTest::RegisterShortExecuteTests(); } + if (MlasGemmPackBSize(128, 128, true /*AIsSigned*/, false /*BIsSigned*/) > 0) { + count += QgemmShortExecuteTest::RegisterShortExecuteTests(); + count += QgemmShortExecuteTest::RegisterShortExecuteTests(); + } } return count;