From 72e49b722bf0ea28cc4cdffef5464ed45c328876 Mon Sep 17 00:00:00 2001 From: Michael Guynn Date: Mon, 29 Apr 2024 11:35:41 -0700 Subject: [PATCH 1/4] Implementation of sign flipping in QGemm CopyPackA to enable S8S8 and S8U8 handling in AVX2 and AVX-VNNI. Added corresponding unit testing. --- .../mlas/lib/amd64/QgemmU8S8KernelAvx2.asm | 36 +++++++++++++-- onnxruntime/core/mlas/lib/mlasi.h | 2 + onnxruntime/core/mlas/lib/platform.cpp | 2 + onnxruntime/core/mlas/lib/qgemm.h | 15 +++---- .../core/mlas/lib/qgemm_kernel_avx2.cpp | 21 +++++++-- .../mlas/lib/x86_64/QgemmU8S8KernelAvx2.S | 33 ++++++++++++-- onnxruntime/test/mlas/unittest/test_qgemm.cpp | 44 +++++++++++++++---- 7 files changed, 126 insertions(+), 27 deletions(-) diff --git a/onnxruntime/core/mlas/lib/amd64/QgemmU8S8KernelAvx2.asm b/onnxruntime/core/mlas/lib/amd64/QgemmU8S8KernelAvx2.asm index cccea70c140c8..3677fca30e307 100644 --- a/onnxruntime/core/mlas/lib/amd64/QgemmU8S8KernelAvx2.asm +++ b/onnxruntime/core/mlas/lib/amd64/QgemmU8S8KernelAvx2.asm @@ -35,6 +35,7 @@ GemmU8S8CopyPackAFrame STRUCT SavedXmm8 OWORD ? SavedXmm9 OWORD ? SavedXmm10 OWORD ? + SavedXmm11 OWORD ? Padding QWORD ? SavedR13 QWORD ? SavedR12 QWORD ? @@ -49,6 +50,7 @@ GemmU8S8CopyPackAFrame STRUCT PreviousP4Home QWORD ? CountK QWORD ? RowSumBuffer QWORD ? + AIsSigned QWORD ? GemmU8S8CopyPackAFrame ENDS @@ -121,6 +123,7 @@ GemmU8S8CopyPackBFrame ENDS save_xmm128 xmm8,GemmU8S8CopyPackAFrame.SavedXmm8 save_xmm128 xmm9,GemmU8S8CopyPackAFrame.SavedXmm9 save_xmm128 xmm10,GemmU8S8CopyPackAFrame.SavedXmm10 + save_xmm128 xmm11,GemmU8S8CopyPackAFrame.SavedXmm11 END_PROLOGUE @@ -135,6 +138,16 @@ GemmU8S8CopyPackBFrame ENDS vpsllw ymm9,ymm8,8 ; generate word vector [0x0100] vpor ymm9,ymm8,ymm9 ; generate word vector [0x0101] +; +; Compute bit flip vector to convert S8 to U8 +; + vpxor ymm11,ymm11,ymm11 + cmp BYTE PTR GemmU8S8CopyPackAFrame.AIsSigned[rsp],0 + jz SkipSignedBitFlipVector + vpsllw ymm11,ymm9,7 ; generate word vector [0x8080] + +SkipSignedBitFlipVector: + ; ; Compute the conditional load/store mask for an unaligned CountK. ; @@ -148,12 +161,11 @@ GemmU8S8CopyPackBFrame ENDS vmovdqu xmm10,XMMWORD PTR [rbx+rax*4] ; -; Zero initialize the padded stack buffers. +; Initialize the padded stack buffers. Zeroed if unsigned, bit-flip values if signed ; - vpxor xmm0,xmm0,xmm0 - vmovdqu YMMWORD PTR GemmU8S8CopyPackAFrame.PaddedMatrixAData[rsp],ymm0 - vmovdqu YMMWORD PTR GemmU8S8CopyPackAFrame.PaddedMatrixAData[rsp+32],ymm0 + vmovdqu YMMWORD PTR GemmU8S8CopyPackAFrame.PaddedMatrixAData[rsp],ymm11 + vmovdqu YMMWORD PTR GemmU8S8CopyPackAFrame.PaddedMatrixAData[rsp+32],ymm11 ; ; Process 4 rows of matrix A in a loop. @@ -182,6 +194,10 @@ ProcessNextColumnLoopM4: vmovdqu ymm5,YMMWORD PTR [rdx+r8] vmovdqu ymm6,YMMWORD PTR [rdx+r8*2] vmovdqu ymm7,YMMWORD PTR [rdx+r13] + vpxor ymm4,ymm4,ymm11 + vpxor ymm5,ymm5,ymm11 + vpxor ymm6,ymm6,ymm11 + vpxor ymm7,ymm7,ymm11 vmovdqu YMMWORD PTR [rcx],ymm4 vmovdqu YMMWORD PTR [rcx+r11],ymm5 vmovdqu YMMWORD PTR [rcx+r11*2],ymm6 @@ -208,6 +224,10 @@ ProcessRemainingColumnsM4: vmovdqu xmm5,XMMWORD PTR [rdx+r8] vmovdqu xmm6,XMMWORD PTR [rdx+r8*2] vmovdqu xmm7,XMMWORD PTR [rdx+r13] + vpxor xmm4,xmm4,xmm11 + vpxor xmm5,xmm5,xmm11 + vpxor xmm6,xmm6,xmm11 + vpxor xmm7,xmm7,xmm11 vmovdqu XMMWORD PTR [rcx],xmm4 vmovdqu XMMWORD PTR [rcx+r11],xmm5 vmovdqu XMMWORD PTR [rcx+r11*2],xmm6 @@ -295,6 +315,10 @@ ProcessPaddedMatrixADataM4: vmovdqu xmm6,XMMWORD PTR GemmU8S8CopyPackAFrame.PaddedMatrixAData[rsp+32] vmovdqu xmm7,XMMWORD PTR GemmU8S8CopyPackAFrame.PaddedMatrixAData[rsp+48] lea rax,[rcx+r11*2] ; compute matrix D plus 2 rows + vpxor xmm4,xmm4,xmm11 + vpxor xmm5,xmm5,xmm11 + vpxor xmm6,xmm6,xmm11 + vpxor xmm7,xmm7,xmm11 vpmaskmovd XMMWORD PTR [rcx],xmm10,xmm4 vpmaskmovd XMMWORD PTR [rcx+r11],xmm10,xmm5 vpmaskmovd XMMWORD PTR [rax],xmm10,xmm6 @@ -347,6 +371,7 @@ ProcessNextRowM1: ProcessNextColumnLoopM1: vmovdqu ymm4,YMMWORD PTR [rdx] + vpxor ymm4,ymm4,ymm11 vmovdqu YMMWORD PTR [rcx],ymm4 vpmaddubsw ymm4,ymm4,ymm9 ; horizontal byte+byte=word per row vpaddw ymm0,ymm0,ymm4 ; add words to row accumulators @@ -361,6 +386,7 @@ ProcessRemainingColumnsM1: test bl,16 ; (CountK & 16) != 0? jz CopyRemainingCountKLessThan16M1 vmovdqu xmm4,XMMWORD PTR [rdx] + vpxor xmm4,xmm4,xmm11 vmovdqu XMMWORD PTR [rcx],xmm4 vpmaddubsw xmm4,xmm4,xmm9 ; horizontal byte+byte=word per row vpaddw ymm0,ymm0,ymm4 ; add words to row accumulators @@ -411,6 +437,7 @@ CopyRemainingCountKLessThan2M1: ProcessPaddedMatrixADataM1: vmovdqu xmm4,XMMWORD PTR GemmU8S8CopyPackAFrame.PaddedMatrixAData[rsp] + vpxor xmm4,xmm4,xmm11 vpmaskmovd XMMWORD PTR [rcx],xmm10,xmm4 vpmaddubsw ymm4,ymm4,ymm9 ; horizontal byte+byte=word per row vpaddw ymm0,ymm0,ymm4 ; add words to row accumulators @@ -441,6 +468,7 @@ ExitRoutine: movaps xmm8,GemmU8S8CopyPackAFrame.SavedXmm8[rsp] movaps xmm9,GemmU8S8CopyPackAFrame.SavedXmm9[rsp] movaps xmm10,GemmU8S8CopyPackAFrame.SavedXmm10[rsp] + movaps xmm11,GemmU8S8CopyPackAFrame.SavedXmm11[rsp] add rsp,(GemmU8S8CopyPackAFrame.SavedR13) BEGIN_EPILOGUE diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index 83200187963e1..648609cba83fe 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -1081,6 +1081,8 @@ struct MLAS_PLATFORM { #if defined(MLAS_TARGET_AMD64_IX86) const MLAS_GEMM_QUANT_DISPATCH* GemmU8S8Dispatch; const MLAS_GEMM_QUANT_DISPATCH* GemmU8U8Dispatch; + const MLAS_GEMM_QUANT_DISPATCH* GemmS8S8Dispatch{&MlasGemmQuantDispatchDefault}; + const MLAS_GEMM_QUANT_DISPATCH* GemmS8U8Dispatch{&MlasGemmQuantDispatchDefault}; #elif defined(MLAS_TARGET_ARM64) const MLAS_GEMM_QUANT_DISPATCH* GemmU8U8Dispatch; const MLAS_GEMM_QUANT_DISPATCH* GemmU8S8Dispatch; diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index 72eb35c894094..a317242a67e0f 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -357,6 +357,8 @@ Return Value: this->GemmU8U8Dispatch = &MlasGemmU8U8DispatchAvx2; this->GemmU8U8Kernel = MlasGemmU8U8KernelAvx2; this->ConvSymU8S8Dispatch = &MlasConvSymDispatchAvx2; + this->GemmS8S8Dispatch = &MlasGemmU8S8DispatchAvx2; + this->GemmS8U8Dispatch = &MlasGemmU8S8DispatchAvx2; this->GemmFloatKernel = MlasGemmFloatKernelFma3; this->GemmDoubleKernel = MlasGemmDoubleKernelFma3; diff --git a/onnxruntime/core/mlas/lib/qgemm.h b/onnxruntime/core/mlas/lib/qgemm.h index 75c17a6b5a177..b34881efa0f4b 100644 --- a/onnxruntime/core/mlas/lib/qgemm.h +++ b/onnxruntime/core/mlas/lib/qgemm.h @@ -337,7 +337,7 @@ Return Value: // // Fixup the sign bit of the per-matrix zero point offset of matrix A if the - // kernel requires signed data. + // kernel requires opposite-signed data. // ZeroPointA = MlasGemmQuantFixupZeroPointA(ZeroPointA, Shape->AIsSigned); @@ -872,13 +872,12 @@ MlasGemmQuantGetDispatch( } #if defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_LARCH64) - if (!AIsSigned) { - if (BIsSigned) { - GemmQuantDispatch = GetMlasPlatform().GemmU8S8Dispatch; - } - else { - GemmQuantDispatch = GetMlasPlatform().GemmU8U8Dispatch; - } + if (AIsSigned) { + GemmQuantDispatch = + BIsSigned ? GetMlasPlatform().GemmS8S8Dispatch : GetMlasPlatform().GemmS8U8Dispatch; + } else { + GemmQuantDispatch = + BIsSigned ? GetMlasPlatform().GemmU8S8Dispatch : GetMlasPlatform().GemmU8U8Dispatch; } #elif defined(MLAS_TARGET_ARM64) if(BIsSigned) { diff --git a/onnxruntime/core/mlas/lib/qgemm_kernel_avx2.cpp b/onnxruntime/core/mlas/lib/qgemm_kernel_avx2.cpp index deec324d01401..43068cd73151c 100644 --- a/onnxruntime/core/mlas/lib/qgemm_kernel_avx2.cpp +++ b/onnxruntime/core/mlas/lib/qgemm_kernel_avx2.cpp @@ -38,7 +38,8 @@ extern "C" { size_t lda, size_t CountM, size_t CountK, - int32_t* RowSumBuffer + int32_t* RowSumBuffer, + bool AIsSigned ); void @@ -114,6 +115,21 @@ MlasGemmQuantTryGemvKernel( return false; } +template<> +MLAS_FORCEINLINE constexpr +int32_t +MlasGemmQuantFixupZeroPointA( + int32_t ZeroPointA, + bool AIsSigned + ) +{ + if (AIsSigned) { + ZeroPointA = MLAS_GEMM_U8S8_KERNEL_AVX2::OffsetAType(ZeroPointA ^ 0x80); + } + + return ZeroPointA; +} + template<> MLAS_FORCEINLINE constexpr int32_t @@ -142,8 +158,7 @@ MlasGemmQuantCopyPackA( bool AIsSigned ) { - MLAS_UNREFERENCED_PARAMETER(AIsSigned); - MlasGemmU8S8CopyPackAAvx2(D, A, lda, CountM, CountK, RowSumBuffer); + MlasGemmU8S8CopyPackAAvx2(D, A, lda, CountM, CountK, RowSumBuffer, AIsSigned); } template<> diff --git a/onnxruntime/core/mlas/lib/x86_64/QgemmU8S8KernelAvx2.S b/onnxruntime/core/mlas/lib/x86_64/QgemmU8S8KernelAvx2.S index 5068d41ceaf35..d54e37a9108a8 100644 --- a/onnxruntime/core/mlas/lib/x86_64/QgemmU8S8KernelAvx2.S +++ b/onnxruntime/core/mlas/lib/x86_64/QgemmU8S8KernelAvx2.S @@ -32,6 +32,7 @@ Abstract: .equ .LGemmU8S8CopyPackAFrame_SavedRbx, 16 .equ .LGemmU8S8CopyPackAFrame_SavedRbp, 24 .equ .LGemmU8S8CopyPackAFrame_ReturnAddress, 32 + .equ .LGemmU8S8CopyPackAFrame_AIsSigned, 40 // // Stack frame layout for the U8S8 CopyPackB routine. @@ -91,6 +92,16 @@ Return Value: vpsllw ymm9,ymm8,8 # generate word vector [0x0100] vpor ymm9,ymm8,ymm9 # generate word vector [0x0101] +// +// Compute bit flip vector to convert S8 to U8 +// + vpxor ymm11,ymm11,ymm11 + cmp BYTE PTR .LGemmU8S8CopyPackAFrame_AIsSigned[rsp],0 + jz .LCopyPackA.SkipSignedBitFlipVector + vpsllw ymm11,ymm9,7 # generate word vector [0x8080] + +.LCopyPackA.SkipSignedBitFlipVector: + // // Compute the conditional load/store mask for an unaligned CountK. // @@ -104,12 +115,11 @@ Return Value: vmovdqu xmm10,XMMWORD PTR [rbx+rax*4] // -// Zero initialize the padded stack buffers. +// Initialize the padded stack buffers. Zeroed if unsigned, bit-flip values if signed // - vpxor xmm0,xmm0,xmm0 - vmovdqu YMMWORD PTR .LGemmU8S8CopyPackAFrame_PaddedMatrixAData[rsp],ymm0 - vmovdqu YMMWORD PTR .LGemmU8S8CopyPackAFrame_PaddedMatrixAData[rsp+32],ymm0 + vmovdqu YMMWORD PTR .LGemmU8S8CopyPackAFrame_PaddedMatrixAData[rsp],ymm11 + vmovdqu YMMWORD PTR .LGemmU8S8CopyPackAFrame_PaddedMatrixAData[rsp+32],ymm11 // // Process 4 rows of matrix A in a loop. @@ -138,6 +148,10 @@ Return Value: vmovdqu ymm5,YMMWORD PTR [rdx+r10] vmovdqu ymm6,YMMWORD PTR [rdx+r10*2] vmovdqu ymm7,YMMWORD PTR [rdx+r13] + vpxor ymm4,ymm4,ymm11 + vpxor ymm5,ymm5,ymm11 + vpxor ymm6,ymm6,ymm11 + vpxor ymm7,ymm7,ymm11 vmovdqu YMMWORD PTR [rcx],ymm4 vmovdqu YMMWORD PTR [rcx+r12],ymm5 vmovdqu YMMWORD PTR [rcx+r12*2],ymm6 @@ -164,6 +178,10 @@ Return Value: vmovdqu xmm5,XMMWORD PTR [rdx+r10] vmovdqu xmm6,XMMWORD PTR [rdx+r10*2] vmovdqu xmm7,XMMWORD PTR [rdx+r13] + vpxor xmm4,xmm4,xmm11 + vpxor xmm5,xmm5,xmm11 + vpxor xmm6,xmm6,xmm11 + vpxor xmm7,xmm7,xmm11 vmovdqu XMMWORD PTR [rcx],xmm4 vmovdqu XMMWORD PTR [rcx+r12],xmm5 vmovdqu XMMWORD PTR [rcx+r12*2],xmm6 @@ -250,6 +268,10 @@ Return Value: vmovdqu xmm6,XMMWORD PTR .LGemmU8S8CopyPackAFrame_PaddedMatrixAData[rsp+32] vmovdqu xmm7,XMMWORD PTR .LGemmU8S8CopyPackAFrame_PaddedMatrixAData[rsp+48] lea rax,[rcx+r12*2] # compute matrix D plus 2 rows + vpxor xmm4,xmm4,xmm11 + vpxor xmm5,xmm5,xmm11 + vpxor xmm6,xmm6,xmm11 + vpxor xmm7,xmm7,xmm11 vpmaskmovd XMMWORD PTR [rcx],xmm10,xmm4 vpmaskmovd XMMWORD PTR [rcx+r12],xmm10,xmm5 vpmaskmovd XMMWORD PTR [rax],xmm10,xmm6 @@ -302,6 +324,7 @@ Return Value: .LCopyPackA.ProcessNextColumnLoopM1: vmovdqu ymm4,YMMWORD PTR [rdx] + vpxor ymm4,ymm4,ymm11 vmovdqu YMMWORD PTR [rcx],ymm4 vpmaddubsw ymm4,ymm4,ymm9 # horizontal byte+byte=word per row vpaddw ymm0,ymm0,ymm4 # add words to row accumulators @@ -316,6 +339,7 @@ Return Value: test bl,16 # (CountK & 16) != 0? jz .LCopyPackA.CopyRemainingCountKLessThan16M1 vmovdqu xmm4,XMMWORD PTR [rdx] + vpxor xmm4,xmm4,xmm11 vmovdqu XMMWORD PTR [rcx],xmm4 vpmaddubsw xmm4,xmm4,xmm9 # horizontal byte+byte=word per row vpaddw ymm0,ymm0,ymm4 # add words to row accumulators @@ -365,6 +389,7 @@ Return Value: .LCopyPackA.ProcessPaddedMatrixADataM1: vmovdqu xmm4,XMMWORD PTR .LGemmU8S8CopyPackAFrame_PaddedMatrixAData[rsp] + vpxor xmm4,xmm4,xmm11 vpmaskmovd XMMWORD PTR [rcx],xmm10,xmm4 vpmaddubsw ymm4,ymm4,ymm9 # horizontal byte+byte=word per row vpaddw ymm0,ymm0,ymm4 # accumulate per row along columns diff --git a/onnxruntime/test/mlas/unittest/test_qgemm.cpp b/onnxruntime/test/mlas/unittest/test_qgemm.cpp index 6bb93d35357f8..068cd44af245c 100644 --- a/onnxruntime/test/mlas/unittest/test_qgemm.cpp +++ b/onnxruntime/test/mlas/unittest/test_qgemm.cpp @@ -10,6 +10,8 @@ static size_t QGemmRegistLongExecute() { count += MlasLongExecuteTests>::RegisterLongExecute(); count += MlasLongExecuteTests>::RegisterLongExecute(); count += MlasLongExecuteTests>::RegisterLongExecute(); + count += MlasLongExecuteTests>::RegisterLongExecute(); + count += MlasLongExecuteTests>::RegisterLongExecute(); if (GetMlasThreadPool() != nullptr) { count += MlasLongExecuteTests>::RegisterLongExecute(); @@ -18,6 +20,8 @@ static size_t QGemmRegistLongExecute() { count += MlasLongExecuteTests>::RegisterLongExecute(); count += MlasLongExecuteTests>::RegisterLongExecute(); count += MlasLongExecuteTests>::RegisterLongExecute(); + count += MlasLongExecuteTests>::RegisterLongExecute(); + count += MlasLongExecuteTests>::RegisterLongExecute(); } return count; @@ -32,6 +36,8 @@ static size_t QGemmRegistShortExecute() { count += QgemmShortExecuteTest::RegisterShortExecuteTests(); count += QgemmShortExecuteTest::RegisterShortExecuteTests(); count += QgemmShortExecuteTest::RegisterShortExecuteTests(); + count += QgemmShortExecuteTest::RegisterShortExecuteTests(); + count += QgemmShortExecuteTest::RegisterShortExecuteTests(); if (MlasGemmPackBSize(128, 128, false /*AIsSigned*/, false /*BIsSigned*/) > 0) { // QGEMM U8U8=float packed tests count += QgemmShortExecuteTest::RegisterShortExecuteTests(); @@ -44,11 +50,22 @@ static size_t QGemmRegistShortExecute() { // QGEMM U8S8=int32_t packed tests count += QgemmShortExecuteTest::RegisterShortExecuteTests(); } - if (MlasGemmPackBSize(128, 128, true /*AIsSigned*/, true /*BIsSigned*/) > 0) { - // QGEMM U8S8=float packed tests - count += QgemmShortExecuteTest::RegisterShortExecuteTests(); - // QGEMM U8S8=int32_t packed tests - count += QgemmShortExecuteTest::RegisterShortExecuteTests(); + try { + if (MlasGemmPackBSize(128, 128, true /*AIsSigned*/, true /*BIsSigned*/) > 0) { + // QGEMM S8S8=float packed tests + count += QgemmShortExecuteTest::RegisterShortExecuteTests(); + // QGEMM S8S8=int32_t packed tests + count += QgemmShortExecuteTest::RegisterShortExecuteTests(); + } + if (MlasGemmPackBSize(128, 128, true /*AIsSigned*/, false /*BIsSigned*/) > 0) { + // QGEMM S8U8=float packed tests + count += QgemmShortExecuteTest::RegisterShortExecuteTests(); + // QGEMM S8U8=int32_t packed tests + count += QgemmShortExecuteTest::RegisterShortExecuteTests(); + } + } catch (const std::invalid_argument& e) { + (void)e; + // no support for these types on this device, ignore and continue } if (GetMlasThreadPool() != nullptr) { @@ -58,6 +75,8 @@ static size_t QGemmRegistShortExecute() { count += QgemmShortExecuteTest::RegisterShortExecuteTests(); count += QgemmShortExecuteTest::RegisterShortExecuteTests(); count += QgemmShortExecuteTest::RegisterShortExecuteTests(); + count += QgemmShortExecuteTest::RegisterShortExecuteTests(); + count += QgemmShortExecuteTest::RegisterShortExecuteTests(); if (MlasGemmPackBSize(128, 128, false /*AIsSigned*/, false /*BIsSigned*/) > 0) { count += QgemmShortExecuteTest::RegisterShortExecuteTests(); count += QgemmShortExecuteTest::RegisterShortExecuteTests(); @@ -66,9 +85,18 @@ static size_t QGemmRegistShortExecute() { count += QgemmShortExecuteTest::RegisterShortExecuteTests(); count += QgemmShortExecuteTest::RegisterShortExecuteTests(); } - if (MlasGemmPackBSize(128, 128, true /*AIsSigned*/, true /*BIsSigned*/) > 0) { - count += QgemmShortExecuteTest::RegisterShortExecuteTests(); - count += QgemmShortExecuteTest::RegisterShortExecuteTests(); + try { + if (MlasGemmPackBSize(128, 128, true /*AIsSigned*/, true /*BIsSigned*/) > 0) { + count += QgemmShortExecuteTest::RegisterShortExecuteTests(); + count += QgemmShortExecuteTest::RegisterShortExecuteTests(); + } + if (MlasGemmPackBSize(128, 128, true /*AIsSigned*/, false /*BIsSigned*/) > 0) { + count += QgemmShortExecuteTest::RegisterShortExecuteTests(); + count += QgemmShortExecuteTest::RegisterShortExecuteTests(); + } + } catch (const std::invalid_argument& e) { + (void)e; + // no support for these types on this device, ignore and continue } } From a032de7dc00544a96a7f213d257ab06432bcd2d9 Mon Sep 17 00:00:00 2001 From: Michael Guynn Date: Wed, 26 Jun 2024 18:46:23 -0700 Subject: [PATCH 2/4] Removed exception handling in unit tests --- onnxruntime/test/mlas/unittest/test_qgemm.cpp | 46 ++++++++----------- 1 file changed, 18 insertions(+), 28 deletions(-) diff --git a/onnxruntime/test/mlas/unittest/test_qgemm.cpp b/onnxruntime/test/mlas/unittest/test_qgemm.cpp index 068cd44af245c..6fb9a4c692a19 100644 --- a/onnxruntime/test/mlas/unittest/test_qgemm.cpp +++ b/onnxruntime/test/mlas/unittest/test_qgemm.cpp @@ -50,22 +50,17 @@ static size_t QGemmRegistShortExecute() { // QGEMM U8S8=int32_t packed tests count += QgemmShortExecuteTest::RegisterShortExecuteTests(); } - try { - if (MlasGemmPackBSize(128, 128, true /*AIsSigned*/, true /*BIsSigned*/) > 0) { - // QGEMM S8S8=float packed tests - count += QgemmShortExecuteTest::RegisterShortExecuteTests(); - // QGEMM S8S8=int32_t packed tests - count += QgemmShortExecuteTest::RegisterShortExecuteTests(); - } - if (MlasGemmPackBSize(128, 128, true /*AIsSigned*/, false /*BIsSigned*/) > 0) { - // QGEMM S8U8=float packed tests - count += QgemmShortExecuteTest::RegisterShortExecuteTests(); - // QGEMM S8U8=int32_t packed tests - count += QgemmShortExecuteTest::RegisterShortExecuteTests(); - } - } catch (const std::invalid_argument& e) { - (void)e; - // no support for these types on this device, ignore and continue + if (MlasGemmPackBSize(128, 128, true /*AIsSigned*/, true /*BIsSigned*/) > 0) { + // QGEMM S8S8=float packed tests + count += QgemmShortExecuteTest::RegisterShortExecuteTests(); + // QGEMM S8S8=int32_t packed tests + count += QgemmShortExecuteTest::RegisterShortExecuteTests(); + } + if (MlasGemmPackBSize(128, 128, true /*AIsSigned*/, false /*BIsSigned*/) > 0) { + // QGEMM S8U8=float packed tests + count += QgemmShortExecuteTest::RegisterShortExecuteTests(); + // QGEMM S8U8=int32_t packed tests + count += QgemmShortExecuteTest::RegisterShortExecuteTests(); } if (GetMlasThreadPool() != nullptr) { @@ -85,18 +80,13 @@ static size_t QGemmRegistShortExecute() { count += QgemmShortExecuteTest::RegisterShortExecuteTests(); count += QgemmShortExecuteTest::RegisterShortExecuteTests(); } - try { - if (MlasGemmPackBSize(128, 128, true /*AIsSigned*/, true /*BIsSigned*/) > 0) { - count += QgemmShortExecuteTest::RegisterShortExecuteTests(); - count += QgemmShortExecuteTest::RegisterShortExecuteTests(); - } - if (MlasGemmPackBSize(128, 128, true /*AIsSigned*/, false /*BIsSigned*/) > 0) { - count += QgemmShortExecuteTest::RegisterShortExecuteTests(); - count += QgemmShortExecuteTest::RegisterShortExecuteTests(); - } - } catch (const std::invalid_argument& e) { - (void)e; - // no support for these types on this device, ignore and continue + if (MlasGemmPackBSize(128, 128, true /*AIsSigned*/, true /*BIsSigned*/) > 0) { + count += QgemmShortExecuteTest::RegisterShortExecuteTests(); + count += QgemmShortExecuteTest::RegisterShortExecuteTests(); + } + if (MlasGemmPackBSize(128, 128, true /*AIsSigned*/, false /*BIsSigned*/) > 0) { + count += QgemmShortExecuteTest::RegisterShortExecuteTests(); + count += QgemmShortExecuteTest::RegisterShortExecuteTests(); } } From 40e378382b5c8c3ffa1d7d2dec8bfa419f4c4086 Mon Sep 17 00:00:00 2001 From: Michael Guynn Date: Mon, 1 Jul 2024 11:15:06 -0700 Subject: [PATCH 3/4] Fixed typo in unit test --- onnxruntime/test/mlas/unittest/test_qgemm.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/test/mlas/unittest/test_qgemm.cpp b/onnxruntime/test/mlas/unittest/test_qgemm.cpp index 6fb9a4c692a19..12955e6f04688 100644 --- a/onnxruntime/test/mlas/unittest/test_qgemm.cpp +++ b/onnxruntime/test/mlas/unittest/test_qgemm.cpp @@ -85,8 +85,8 @@ static size_t QGemmRegistShortExecute() { count += QgemmShortExecuteTest::RegisterShortExecuteTests(); } if (MlasGemmPackBSize(128, 128, true /*AIsSigned*/, false /*BIsSigned*/) > 0) { - count += QgemmShortExecuteTest::RegisterShortExecuteTests(); - count += QgemmShortExecuteTest::RegisterShortExecuteTests(); + count += QgemmShortExecuteTest::RegisterShortExecuteTests(); + count += QgemmShortExecuteTest::RegisterShortExecuteTests(); } } From ae16d59130dfcea90cb51e34163b406a003d4c2b Mon Sep 17 00:00:00 2001 From: Michael Guynn Date: Mon, 8 Jul 2024 10:45:52 -0700 Subject: [PATCH 4/4] Moved S8S8 and S8U8 support to AVX-VNNI only. AVX2 will keep default C++ implementation --- onnxruntime/core/mlas/lib/platform.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index a317242a67e0f..388d4f9ccb401 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -357,8 +357,6 @@ Return Value: this->GemmU8U8Dispatch = &MlasGemmU8U8DispatchAvx2; this->GemmU8U8Kernel = MlasGemmU8U8KernelAvx2; this->ConvSymU8S8Dispatch = &MlasConvSymDispatchAvx2; - this->GemmS8S8Dispatch = &MlasGemmU8S8DispatchAvx2; - this->GemmS8U8Dispatch = &MlasGemmU8S8DispatchAvx2; this->GemmFloatKernel = MlasGemmFloatKernelFma3; this->GemmDoubleKernel = MlasGemmDoubleKernelFma3; @@ -401,6 +399,8 @@ Return Value: if ((Cpuid7_1[0] & 0x10) != 0) { this->GemmU8U8Dispatch = &MlasGemmU8S8DispatchAvx2; + this->GemmS8S8Dispatch = &MlasGemmU8S8DispatchAvx2; + this->GemmS8U8Dispatch = &MlasGemmU8S8DispatchAvx2; this->GemmU8S8Kernel = MlasGemmU8S8KernelAvxVnni; this->GemvU8S8Kernel = MlasGemvU8S8KernelAvxVnni; this->ConvSymU8S8Dispatch = &MlasConvSymDispatchAvxVnni;