diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index 83200187963e1..4239e2ecaeb6e 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -993,6 +993,8 @@ extern const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchNeon; extern const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2; +extern const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2vnni; + extern const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512; extern const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512vnni; diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index 859b7c2f560a4..ed437f20f7c2a 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -409,6 +409,7 @@ Return Value: this->GemmU8S8Kernel = MlasGemmU8S8KernelAvxVnni; this->GemvU8S8Kernel = MlasGemvU8S8KernelAvxVnni; this->ConvSymU8S8Dispatch = &MlasConvSymDispatchAvxVnni; + this->SQNBitGemmDispatch = &MlasSQNBitGemmDispatchAvx2vnni; } #if !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp index 2a46d68692ddc..0ed41b150a0e5 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp @@ -344,6 +344,7 @@ Q4BitBlkDequantBForSgemm_CompFp32_avx2( } } +template MLAS_FORCEINLINE void SQ4BitGemmKernel_CompInt8_avx2( @@ -376,7 +377,7 @@ SQ4BitGemmKernel_CompInt8_avx2( ldc ); } else if (BlkLen == 32) { - MlasQ4Int8GemmKernelBlkLen32Avx2( + MlasQ4Int8GemmKernelBlkLen32Avx2( QuantA, QuantAScale, QuantBData, @@ -390,7 +391,7 @@ SQ4BitGemmKernel_CompInt8_avx2( ldc ); } else { - MlasQ4Int8GemmKernelBlkLen64Avx2( + MlasQ4Int8GemmKernelBlkLen64Avx2( BlkLen, QuantA, QuantAScale, @@ -406,6 +407,7 @@ SQ4BitGemmKernel_CompInt8_avx2( } } +template MLAS_FORCEINLINE void SQ4BitGemmM1Kernel_CompInt8_avx2( @@ -425,7 +427,7 @@ SQ4BitGemmM1Kernel_CompInt8_avx2( if (QuantBZeroPoint) { if (BlkLen == 16) { } else if (BlkLen == 32) { - MlasQ4Int8GemmM1KernelBlkLen32Avx2( + MlasQ4Int8GemmM1KernelBlkLen32Avx2( QuantA, QuantAScale, QuantBData, @@ -453,7 +455,7 @@ SQ4BitGemmM1Kernel_CompInt8_avx2( } else { if (BlkLen == 16) { } else if (BlkLen == 32) { - MlasQ4Int8GemmM1KernelBlkLen32Avx2( + MlasQ4Int8GemmM1KernelBlkLen32Avx2( QuantA, QuantAScale, QuantBData, @@ -502,11 +504,66 @@ SQ4BitGemmKernel_BlkSum_CompInt8_avx2( ) { if (BlkLen >= 32 && CountM == 1) { - SQ4BitGemmM1Kernel_CompInt8_avx2(BlkLen, QuantA, QuantAScale, QuantBData, QuantBScale, QuantBZeroPoint, C, CountN, CountK, BlockCountK, Bias); + SQ4BitGemmM1Kernel_CompInt8_avx2(BlkLen, QuantA, QuantAScale, QuantBData, QuantBScale, QuantBZeroPoint, C, CountN, CountK, BlockCountK, Bias); return CountM; } - SQ4BitGemmKernel_CompInt8_avx2( + SQ4BitGemmKernel_CompInt8_avx2( + BlkLen, + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + CountM, + CountN, + CountK, + BlockCountK, + Bias, + ldc + ); + float* c_blk = C; + const float* b_blk_sum = QuantBBlkSum; + + size_t RowsRemaining = CountM; + const float* a_blksum_row = ABlockSum; + while (RowsRemaining > 0) { + auto RowsHandled = GetMlasPlatform().GemmFloatKernel( + a_blksum_row, b_blk_sum, c_blk, BlockCountK, RowsRemaining, CountN, BlockCountK, ldc, 1.f, false + ); + + c_blk += ldc * RowsHandled; + a_blksum_row += BlockCountK * RowsHandled; + RowsRemaining -= RowsHandled; + } + return CountM; +} + +size_t +SQ4BitGemmKernel_BlkSum_CompInt8_avx2vnni( + 
const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountM, + size_t CountN, + size_t CountK, + size_t BlockCountK, + const float* Bias, + size_t ldc, + const float* ABlockSum, + const float* QuantBBlkSum +) +{ + if (BlkLen >= 32 && CountM == 1) { + SQ4BitGemmM1Kernel_CompInt8_avx2(BlkLen, QuantA, QuantAScale, QuantBData, QuantBScale, QuantBZeroPoint, C, CountN, CountK, BlockCountK, Bias); + return CountM; + } + + SQ4BitGemmKernel_CompInt8_avx2( BlkLen, QuantA, QuantAScale, @@ -1246,3 +1303,22 @@ const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2 = []() { return d; }(); + +const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2vnni = []() { + MLAS_SQNBIT_GEMM_DISPATCH d; + + d.SQ4BitGemmPackQuantBDataSize = SQ4BitGemmPackQuantBDataSize; + d.SQ4BitGemmPackQuantBData = SQ4BitGemmPackQuantBData; + d.SQ4BitGemmPackQuantBDataAndBlkSum = SQ4BitGemmPackQuantBDataAndBlkSum; + + d.SQ4BitGemmPerGemmWorkspaceSize = SQ4BitGemmPerGemmWorkspaceSize; + d.SQ4BitGemmPerGemmWorkspaceAlignment = SQ4BitGemmPerGemmWorkspaceAlignment; + + d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32_avx2; + d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2; + + d.SQ4BitGemmKernel_BlkSum_CompInt8 = SQ4BitGemmKernel_BlkSum_CompInt8_avx2vnni; + d.QuantizeARow_CompInt8_2 = QuantizeARow_CompInt8_avx2; + + return d; +}(); diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen32.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen32.h index 6ff52be83b3e3..d1762420b3a33 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen32.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen32.h @@ -22,11 +22,12 @@ accumulate_1blk_dot(const __m256i& av_32_epi8, const __m256i& bv_32_epi8, MLAS_FORCEINLINE void accumulate_1blk_dot_vnni(const __m256i& av_32_epi8, const __m256i& bv_32_epi8, const float& combined_scale, __m256& acc) { - __m256i sum_8_epi32 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bv_32_epi8, av_32_epi8); + __m256i sum_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv_32_epi8, av_32_epi8); const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); acc = _mm256_fmadd_ps(sum_ps, _mm256_set1_ps(combined_scale), acc); } +template static MLAS_FORCEINLINE void accumulate_blklen32_r2c1blk2_avx2( const __m256i& av00_32_epi8, @@ -53,100 +54,62 @@ accumulate_blklen32_r2c1blk2_avx2( // const __m256i bv1_32_epi8 = _mm256_and_si256(_mm256_srli_epi16(bv_packed, 4), low_mask); __m256i bv1_32_epi8 = _mm256_srli_epi16(_mm256_sub_epi8(bv_packed, bv0_32_epi8), 4); // 32~63 - const __m256i dot0_16_epi16 = _mm256_maddubs_epi16( - bv0_32_epi8, av00_32_epi8 - ); - const __m256i dot1_16_epi16 = _mm256_maddubs_epi16( - bv1_32_epi8, av01_32_epi8 - ); - const __m256i sum_16_epi16 = _mm256_hadd_epi16(dot0_16_epi16, dot1_16_epi16); - - // generating constant 1s is faster here. 
- // __m256i one = _mm256_set1_epi16(1); - __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv0_32_epi8, bv0_32_epi8), 15); - const __m256i sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, sum_16_epi16); - const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); - - __m256 scale_a0_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_a0)); - __m256 scale_b_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_b)); - // 1 0 1 0 1 0 1 0 -> 1 1 0 0 1 1 0 0 - __m256 scale_8_ps = _mm256_permute_ps(_mm256_mul_ps(scale_a0_2_ps, scale_b_2_ps), _MM_SHUFFLE(1, 1, 0, 0)); - acc0 = _mm256_fmadd_ps(sum_ps, scale_8_ps, acc0); - - - const __m256i dot0_16_epi16_ = _mm256_maddubs_epi16( - bv0_32_epi8, av10_32_epi8 - ); - const __m256i dot1_16_epi16_ = _mm256_maddubs_epi16( - bv1_32_epi8, av11_32_epi8 - ); - const __m256i sum_16_epi16_ = _mm256_hadd_epi16(dot0_16_epi16_, dot1_16_epi16_); - const __m256i sum_8_epi32_ = _mm256_madd_epi16(one_16_epi16, sum_16_epi16_); - const __m256 sum_ps_ = _mm256_cvtepi32_ps(sum_8_epi32_); - - __m256 scale_a1_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_a1)); - __m256 scale_8_ps_ = _mm256_permute_ps(_mm256_mul_ps(scale_a1_2_ps, scale_b_2_ps), _MM_SHUFFLE(1, 1, 0, 0)); - acc1 = _mm256_fmadd_ps(sum_ps_, scale_8_ps_, acc1); -} - -static MLAS_FORCEINLINE void -accumulate_blklen32_r2c1blk2_no_bc_avx2( - const __m256i& av00_32_epi8, - const __m256i& av01_32_epi8, - const __m256i& av10_32_epi8, - const __m256i& av11_32_epi8, - const std::byte* QuantBDataPtr, - const float& scale_a00, - const float& scale_a01, - const float& scale_a10, - const float& scale_a11, - const float& scale_b0, - const float& scale_b1, - __m256& acc0, - __m256& acc1 -) -{ - // | v0 v32 | v1 v33 | ... | v30 v62 | v31 v63 | - const __m256i bv_packed = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr)); - - // generating low_mask of 0x0Fs is not as fast as just calling _mm256_set1_epi8(0x0F). - const __m256i low_mask = _mm256_set1_epi8(0x0F); - //__m256i low_mask = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv_packed, bv_packed), 12); - // low_mask = _mm256_packus_epi16(low_mask, low_mask); - __m256i bv0_32_epi8 = _mm256_and_si256(bv_packed, low_mask); // 0~31 - // TODO: this (the second line below) is faster and does not keep low_mask in use. - // const __m256i bv1_32_epi8 = _mm256_and_si256(_mm256_srli_epi16(bv_packed, 4), low_mask); - __m256i bv1_32_epi8 = _mm256_srli_epi16(_mm256_sub_epi8(bv_packed, bv0_32_epi8), 4); // 32~63 - __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv1_32_epi8, bv1_32_epi8), 15); - accumulate_1blk_dot(av00_32_epi8, bv0_32_epi8, scale_a00 * scale_b0, one_16_epi16, acc0); - accumulate_1blk_dot(av01_32_epi8, bv1_32_epi8, scale_a01 * scale_b1, one_16_epi16, acc0); - accumulate_1blk_dot(av10_32_epi8, bv0_32_epi8, scale_a10 * scale_b0, one_16_epi16, acc1); - accumulate_1blk_dot(av11_32_epi8, bv1_32_epi8, scale_a11 * scale_b1, one_16_epi16, acc1); -} - -static MLAS_FORCEINLINE void -accumulate_blklen32_r1c1blk2_no_bc_avx2( - const __m256i& av00_32_epi8, - const __m256i& av01_32_epi8, - const std::byte* QuantBDataPtr, - const float& scale_a_0, - const float& scale_a_1, - const float& scale_b_0, - const float& scale_b_1, - __m256& acc0 -) -{ - // | v0 v32 | v1 v33 | ... 
| v30 v62 | v31 v63 | - const __m256i bv_packed = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr)); - const __m256i low_mask = _mm256_set1_epi8(0x0F); - __m256i bv0_32_epi8 = _mm256_and_si256(bv_packed, low_mask); // 0~31 - __m256i bv1_32_epi8 = _mm256_srli_epi16(_mm256_sub_epi8(bv_packed, bv0_32_epi8), 4); // 32~63 - - __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv0_32_epi8, bv0_32_epi8), 15); - accumulate_1blk_dot(av00_32_epi8, bv0_32_epi8, scale_a_0 * scale_b_0, one_16_epi16, acc0); - accumulate_1blk_dot(av01_32_epi8, bv1_32_epi8, scale_a_1 * scale_b_1, one_16_epi16, acc0); + if constexpr (vnni) { + __m256 scale_b_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_b)); + { + const __m256i dot0_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv0_32_epi8, av00_32_epi8); + const __m256i dot1_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv1_32_epi8, av01_32_epi8); + const __m256i sum_8_epi32 = _mm256_hadd_epi32(dot0_8_epi32, dot1_8_epi32); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + + __m256 scale_a0_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_a0)); + // 1 0 1 0 1 0 1 0 -> 1 1 0 0 1 1 0 0 + __m256 scale_8_ps = _mm256_permute_ps(_mm256_mul_ps(scale_a0_2_ps, scale_b_2_ps), _MM_SHUFFLE(1, 1, 0, 0)); + acc0 = _mm256_fmadd_ps(sum_ps, scale_8_ps, acc0); + } + { + const __m256i dot0_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv0_32_epi8, av10_32_epi8); + const __m256i dot1_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv1_32_epi8, av11_32_epi8); + const __m256i sum_8_epi32 = _mm256_hadd_epi32(dot0_8_epi32, dot1_8_epi32); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + + __m256 scale_a1_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_a1)); + __m256 scale_8_ps = _mm256_permute_ps(_mm256_mul_ps(scale_a1_2_ps, scale_b_2_ps), _MM_SHUFFLE(1, 1, 0, 0)); + acc1 = _mm256_fmadd_ps(sum_ps, scale_8_ps, acc1); + } + } else { + //{ + const __m256i dot0_16_epi16 = _mm256_maddubs_epi16(bv0_32_epi8, av00_32_epi8); + const __m256i dot1_16_epi16 = _mm256_maddubs_epi16(bv1_32_epi8, av01_32_epi8); + const __m256i sum_16_epi16 = _mm256_hadd_epi16(dot0_16_epi16, dot1_16_epi16); + + // generating constant 1s is faster here. 
+ // __m256i one = _mm256_set1_epi16(1); + __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv0_32_epi8, bv0_32_epi8), 15); + const __m256i sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, sum_16_epi16); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + + __m256 scale_a0_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_a0)); + __m256 scale_b_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_b)); + // 1 0 1 0 1 0 1 0 -> 1 1 0 0 1 1 0 0 + __m256 scale_8_ps = _mm256_permute_ps(_mm256_mul_ps(scale_a0_2_ps, scale_b_2_ps), _MM_SHUFFLE(1, 1, 0, 0)); + acc0 = _mm256_fmadd_ps(sum_ps, scale_8_ps, acc0); + //} + //{ + const __m256i dot0_16_epi16_ = _mm256_maddubs_epi16(bv0_32_epi8, av10_32_epi8); + const __m256i dot1_16_epi16_ = _mm256_maddubs_epi16(bv1_32_epi8, av11_32_epi8); + const __m256i sum_16_epi16_ = _mm256_hadd_epi16(dot0_16_epi16_, dot1_16_epi16_); + const __m256i sum_8_epi32_ = _mm256_madd_epi16(one_16_epi16, sum_16_epi16_); + const __m256 sum_ps_ = _mm256_cvtepi32_ps(sum_8_epi32_); + + __m256 scale_a1_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_a1)); + __m256 scale_8_ps_ = _mm256_permute_ps(_mm256_mul_ps(scale_a1_2_ps, scale_b_2_ps), _MM_SHUFFLE(1, 1, 0, 0)); + acc1 = _mm256_fmadd_ps(sum_ps_, scale_8_ps_, acc1); + //} + } } +template static MLAS_FORCEINLINE void accumulate_blklen32_r1c1blk2_avx2( const __m256i& av00_32_epi8, @@ -163,19 +126,33 @@ accumulate_blklen32_r1c1blk2_avx2( __m256i bv0_32_epi8 = _mm256_and_si256(bv_packed, low_mask); // 0~31 __m256i bv1_32_epi8 = _mm256_srli_epi16(_mm256_sub_epi8(bv_packed, bv0_32_epi8), 4); // 32~63 - const __m256i dot0_16_epi16 = _mm256_maddubs_epi16(bv0_32_epi8, av00_32_epi8); - const __m256i dot1_16_epi16 = _mm256_maddubs_epi16(bv1_32_epi8, av01_32_epi8); - const __m256i sum_16_epi16 = _mm256_hadd_epi16(dot0_16_epi16, dot1_16_epi16); + if constexpr (vnni) { + const __m256i dot0_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv0_32_epi8, av00_32_epi8); + const __m256i dot1_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv1_32_epi8, av01_32_epi8); + const __m256i sum_8_epi32 = _mm256_hadd_epi32(dot0_8_epi32, dot1_8_epi32); // 00110011 - __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv0_32_epi8, bv0_32_epi8), 15); - const __m256i sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, sum_16_epi16); - const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); - __m256 scale_a0_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_a0)); - __m256 scale_b_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_b)); - // 1 0 1 0 1 0 1 0 -> 1 1 0 0 1 1 0 0 - __m256 scale_8_ps = _mm256_permute_ps(_mm256_mul_ps(scale_a0_2_ps, scale_b_2_ps), _MM_SHUFFLE(1, 1, 0, 0)); - acc0 = _mm256_fmadd_ps(sum_ps, scale_8_ps, acc0); + __m256 scale_a0_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_a0)); + __m256 scale_b_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_b)); + // 1 0 1 0 1 0 1 0 -> 1 1 0 0 1 1 0 0 + __m256 scale_8_ps = _mm256_permute_ps(_mm256_mul_ps(scale_a0_2_ps, scale_b_2_ps), _MM_SHUFFLE(1, 1, 0, 0)); + acc0 = _mm256_fmadd_ps(sum_ps, scale_8_ps, acc0); + } else { + const __m256i dot0_16_epi16 = _mm256_maddubs_epi16(bv0_32_epi8, av00_32_epi8); + const __m256i dot1_16_epi16 = _mm256_maddubs_epi16(bv1_32_epi8, av01_32_epi8); + const __m256i sum_16_epi16 = _mm256_hadd_epi16(dot0_16_epi16, dot1_16_epi16); + + __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv0_32_epi8, bv0_32_epi8), 
15); + const __m256i sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, sum_16_epi16); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + + __m256 scale_a0_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_a0)); + __m256 scale_b_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_b)); + // 1 0 1 0 1 0 1 0 -> 1 1 0 0 1 1 0 0 + __m256 scale_8_ps = _mm256_permute_ps(_mm256_mul_ps(scale_a0_2_ps, scale_b_2_ps), _MM_SHUFFLE(1, 1, 0, 0)); + acc0 = _mm256_fmadd_ps(sum_ps, scale_8_ps, acc0); + } } template @@ -227,6 +204,7 @@ accumulate_blklen32_r1c1blk1_avx2( } } +template MLAS_FORCEINLINE void Q4Int8Gemm2x4x2BlkLen32Avx2( const std::byte* QuantA, @@ -291,18 +269,18 @@ Q4Int8Gemm2x4x2BlkLen32Avx2( const __m256i av_11_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + lda + BlkLen32)); { - accumulate_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc[0], acc[NCols4]); + accumulate_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc[0], acc[NCols4]); } { - accumulate_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + StrideQuantBData2, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + StrideQuantBScale2, acc[1], acc[NCols4 + 1]); + accumulate_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + StrideQuantBData2, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + StrideQuantBScale2, acc[1], acc[NCols4 + 1]); } { - accumulate_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 2 * StrideQuantBData2, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 2 * StrideQuantBScale2, acc[2], acc[NCols4 + 2]); + accumulate_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 2 * StrideQuantBData2, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 2 * StrideQuantBScale2, acc[2], acc[NCols4 + 2]); } { - accumulate_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 3 * StrideQuantBData2, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 3 * StrideQuantBScale2, acc[3], acc[NCols4 + 3]); + accumulate_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 3 * StrideQuantBData2, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 3 * StrideQuantBScale2, acc[3], acc[NCols4 + 3]); } // increment block pointers @@ -325,28 +303,28 @@ Q4Int8Gemm2x4x2BlkLen32Avx2( // Col0 const float scale_00 = scale_a00 * (QuantBScalePtr)[0]; const float scale_10 = scale_a10 * (QuantBScalePtr)[0]; - accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr, scale_00, scale_10, acc[0], acc[NCols4]); + accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr, scale_00, scale_10, acc[0], acc[NCols4]); } { // Col1 const float scale_00 = scale_a00 * (QuantBScalePtr + StrideQuantBScale1)[0]; const float scale_10 = scale_a10 * (QuantBScalePtr + StrideQuantBScale1)[0]; - accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr + StrideQuantBData1, scale_00, scale_10, acc[1], acc[NCols4 + 1]); + accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr + StrideQuantBData1, scale_00, scale_10, acc[1], acc[NCols4 + 1]); } { // Col2 const float scale_00 = scale_a00 * (QuantBScalePtr + 2 * 
StrideQuantBScale1)[0]; const float scale_10 = scale_a10 * (QuantBScalePtr + 2 * StrideQuantBScale1)[0]; - accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr + 2 * StrideQuantBData1, scale_00, scale_10, acc[2], acc[NCols4 + 2]); + accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr + 2 * StrideQuantBData1, scale_00, scale_10, acc[2], acc[NCols4 + 2]); } { // Col3 const float& scale_00 = scale_a00 * (QuantBScalePtr + 3 * StrideQuantBScale1)[0]; const float& scale_10 = scale_a10 * (QuantBScalePtr + 3 * StrideQuantBScale1)[0]; - accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr + 3 * StrideQuantBData1, scale_00, scale_10, acc[3], acc[NCols4 + 3]); + accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr + 3 * StrideQuantBData1, scale_00, scale_10, acc[3], acc[NCols4 + 3]); } } // k_blks_remaining @@ -371,6 +349,7 @@ Q4Int8Gemm2x4x2BlkLen32Avx2( } } +template void MLAS_FORCEINLINE Q4Int8Gemm2xXBlkLen32Avx2( const std::byte* QuantA, const float* QuantAScale, @@ -423,7 +402,7 @@ void MLAS_FORCEINLINE Q4Int8Gemm2xXBlkLen32Avx2( const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + lda)); const __m256i av_11_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + lda + BlkLen32)); - accumulate_blklen32_r2c1blk2_avx2( + accumulate_blklen32_r2c1blk2_avx2( av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc0, acc1); @@ -443,7 +422,7 @@ void MLAS_FORCEINLINE Q4Int8Gemm2xXBlkLen32Avx2( const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; const float& scale_10 = scale_a10 * (QuantBScalePtr)[0]; - accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr, scale_00, scale_10, acc0, acc1); + accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr, scale_00, scale_10, acc0, acc1); } *SumPtr = hsum_float_8(acc0); @@ -463,6 +442,7 @@ void MLAS_FORCEINLINE Q4Int8Gemm2xXBlkLen32Avx2( } } +template MLAS_FORCEINLINE void Q4Int8GemmXx4BlkLen32Avx2( const std::byte* QuantA, @@ -518,23 +498,23 @@ Q4Int8GemmXx4BlkLen32Avx2( const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + BlkLen32)); { - accumulate_blklen32_r1c1blk2_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, + accumulate_blklen32_r1c1blk2_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]); } { - accumulate_blklen32_r1c1blk2_avx2( + accumulate_blklen32_r1c1blk2_avx2( av_00_epi8, av_01_epi8, QuantBDataPtr + StrideQuantBData2, QuantAScalePtr, QuantBScalePtr + StrideQuantBScale2, acc[1] ); } { - accumulate_blklen32_r1c1blk2_avx2( + accumulate_blklen32_r1c1blk2_avx2( av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * StrideQuantBData2, QuantAScalePtr, QuantBScalePtr + 2 * StrideQuantBScale2, acc[2] ); } { - accumulate_blklen32_r1c1blk2_avx2( + accumulate_blklen32_r1c1blk2_avx2( av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * StrideQuantBData2, QuantAScalePtr, QuantBScalePtr + 3 * StrideQuantBScale2, acc[3] ); @@ -556,22 +536,22 @@ Q4Int8GemmXx4BlkLen32Avx2( { // Col0 const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; - accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr, scale_00, acc[0]); + accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr, scale_00, acc[0]); } { // Col1 const float& scale_00 = scale_a00 * (QuantBScalePtr + StrideQuantBScale1)[0]; - accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + StrideQuantBData1, scale_00, acc[1]); + 
accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + StrideQuantBData1, scale_00, acc[1]); } { // Col2 const float& scale_00 = scale_a00 * (QuantBScalePtr + 2 * StrideQuantBScale1)[0]; - accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + 2 * StrideQuantBData1, scale_00, acc[2]); + accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + 2 * StrideQuantBData1, scale_00, acc[2]); } { // Col3 const float& scale_00 = scale_a00 * (QuantBScalePtr + 3 * StrideQuantBScale1)[0]; - accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + 3 * StrideQuantBData1, scale_00, acc[3]); + accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + 3 * StrideQuantBData1, scale_00, acc[3]); } } @@ -591,6 +571,7 @@ Q4Int8GemmXx4BlkLen32Avx2( } } +template MLAS_FORCEINLINE void Q4Int8GemmXxXBlkLen32Avx2( const std::byte* QuantA, @@ -639,7 +620,7 @@ Q4Int8GemmXxXBlkLen32Avx2( for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr)); const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + BlkLen32)); - accumulate_blklen32_r1c1blk2_avx2( + accumulate_blklen32_r1c1blk2_avx2( av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc0 ); @@ -656,7 +637,7 @@ Q4Int8GemmXxXBlkLen32Avx2( const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); const float& scale_a00 = *QuantAScalePtr; const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; - accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr, scale_00, acc0); + accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr, scale_00, acc0); } *SumPtr = hsum_float_8(acc0); @@ -673,6 +654,7 @@ Q4Int8GemmXxXBlkLen32Avx2( } } +template MLAS_FORCEINLINE size_t MlasQ4Int8GemmKernelBlkLen32Avx2( @@ -707,7 +689,7 @@ MLAS_FORCEINLINE size_t multipleCols = CountN - remainingCols; if (multipleRows > 0 && multipleCols > 0) { - Q4Int8Gemm2x4x2BlkLen32Avx2( + Q4Int8Gemm2x4x2BlkLen32Avx2( QuantA, QuantAScale, QuantBData, @@ -721,7 +703,7 @@ MLAS_FORCEINLINE ); } if (remainingCols > 0 && multipleRows > 0) { - Q4Int8Gemm2xXBlkLen32Avx2( + Q4Int8Gemm2xXBlkLen32Avx2( QuantA, QuantAScale, QuantBData + multipleCols * StrideQuantBData, @@ -735,7 +717,7 @@ MLAS_FORCEINLINE } if (remainingRows > 0 && multipleCols > 0) { - Q4Int8GemmXx4BlkLen32Avx2( + Q4Int8GemmXx4BlkLen32Avx2( QuantA + multipleRows * lda, QuantAScale + multipleRows * lda_scale, QuantBData, @@ -749,7 +731,7 @@ MLAS_FORCEINLINE } if (remainingCols > 0 && remainingRows > 0) { - Q4Int8GemmXxXBlkLen32Avx2( + Q4Int8GemmXxXBlkLen32Avx2( QuantA + multipleRows * lda, QuantAScale + multipleRows * lda_scale, QuantBData + multipleCols * StrideQuantBData, diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen64.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen64.h index 15898a719632b..ec9a338348c88 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen64.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen64.h @@ -6,6 +6,7 @@ #include "sqnbitgemm.h" #include "sqnbitgemm_kernel_avx_common.h" +template static MLAS_FORCEINLINE void accumulate_blklen64_r2c1blk1_avx2( const __m256i& av00_32_epi8, @@ -25,32 +26,53 @@ accumulate_blklen64_r2c1blk1_avx2( __m256i bv0_32_epi8 = _mm256_and_si256(bv_packed, low_mask); // 0, 1,...30, 31 __m256i bv1_32_epi8 = _mm256_srli_epi16(_mm256_sub_epi8(bv_packed, bv0_32_epi8), 4); // 32, 33,...62, 63 - __m256i dot0_16_epi16 = _mm256_maddubs_epi16(bv0_32_epi8, av00_32_epi8); - 
__m256i dot1_16_epi16 = _mm256_maddubs_epi16(bv1_32_epi8, av01_32_epi8); - __m256i sum_16_epi16 = _mm256_hadd_epi16(dot0_16_epi16, dot1_16_epi16); + if constexpr (vnni) { + __m256i sum_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv0_32_epi8, av00_32_epi8); + sum_8_epi32 = _mm256_dpbusds_avx_epi32(sum_8_epi32, bv1_32_epi8, av01_32_epi8); + __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); - const __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv0_32_epi8, bv0_32_epi8), 15); + __m256 scale_a0_ps = _mm256_broadcast_ss(scale_a0); + __m256 scale_b_ps = _mm256_broadcast_ss(scale_b); - __m256i sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, sum_16_epi16); - __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + acc0 = _mm256_fmadd_ps(sum_ps, _mm256_mul_ps(scale_a0_ps, scale_b_ps), acc0); - __m256 scale_a0_ps = _mm256_broadcast_ss(scale_a0); - __m256 scale_b_ps = _mm256_broadcast_ss(scale_b); + sum_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv0_32_epi8, av10_32_epi8); + sum_8_epi32 = _mm256_dpbusds_avx_epi32(sum_8_epi32, bv1_32_epi8, av11_32_epi8); + sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); - acc0 = _mm256_fmadd_ps(sum_ps, _mm256_mul_ps(scale_a0_ps, scale_b_ps), acc0); + __m256 scale_a1_ps = _mm256_broadcast_ss(scale_a1); - dot0_16_epi16 = _mm256_maddubs_epi16(bv0_32_epi8, av10_32_epi8); - dot1_16_epi16 = _mm256_maddubs_epi16(bv1_32_epi8, av11_32_epi8); - sum_16_epi16 = _mm256_hadd_epi16(dot0_16_epi16, dot1_16_epi16); + acc1 = _mm256_fmadd_ps(sum_ps, _mm256_mul_ps(scale_a1_ps, scale_b_ps), acc1); + + } else { + __m256i dot0_16_epi16 = _mm256_maddubs_epi16(bv0_32_epi8, av00_32_epi8); + __m256i dot1_16_epi16 = _mm256_maddubs_epi16(bv1_32_epi8, av01_32_epi8); + __m256i sum_16_epi16 = _mm256_hadd_epi16(dot0_16_epi16, dot1_16_epi16); + + const __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv0_32_epi8, bv0_32_epi8), 15); + + __m256i sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, sum_16_epi16); + __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + + __m256 scale_a0_ps = _mm256_broadcast_ss(scale_a0); + __m256 scale_b_ps = _mm256_broadcast_ss(scale_b); + + acc0 = _mm256_fmadd_ps(sum_ps, _mm256_mul_ps(scale_a0_ps, scale_b_ps), acc0); - sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, sum_16_epi16); - sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + dot0_16_epi16 = _mm256_maddubs_epi16(bv0_32_epi8, av10_32_epi8); + dot1_16_epi16 = _mm256_maddubs_epi16(bv1_32_epi8, av11_32_epi8); + sum_16_epi16 = _mm256_hadd_epi16(dot0_16_epi16, dot1_16_epi16); - __m256 scale_a1_ps = _mm256_broadcast_ss(scale_a1); + sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, sum_16_epi16); + sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); - acc1 = _mm256_fmadd_ps(sum_ps, _mm256_mul_ps(scale_a1_ps, scale_b_ps), acc1); + __m256 scale_a1_ps = _mm256_broadcast_ss(scale_a1); + + acc1 = _mm256_fmadd_ps(sum_ps, _mm256_mul_ps(scale_a1_ps, scale_b_ps), acc1); + } } +template static MLAS_FORCEINLINE void accumulate_blklen64_r1c1blk1_avx2( const __m256i& av00_32_epi8, @@ -67,20 +89,32 @@ accumulate_blklen64_r1c1blk1_avx2( __m256i bv0_32_epi8 = _mm256_and_si256(bv_packed, low_mask); // 0, 1,...30, 31 __m256i bv1_32_epi8 = _mm256_srli_epi16(_mm256_sub_epi8(bv_packed, bv0_32_epi8), 4); // 32, 33,...62, 63 - const __m256i dot0_16_epi16 = _mm256_maddubs_epi16(bv0_32_epi8, av00_32_epi8); - const __m256i dot1_16_epi16 = _mm256_maddubs_epi16(bv1_32_epi8, av01_32_epi8); - const __m256i sum_16_epi16 = _mm256_hadd_epi16(dot0_16_epi16, dot1_16_epi16); + if constexpr (vnni) { + __m256i sum_8_epi32 = 
_mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv0_32_epi8, av00_32_epi8); + sum_8_epi32 = _mm256_dpbusds_avx_epi32(sum_8_epi32, bv1_32_epi8, av01_32_epi8); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + + __m256 scale_a_8_ps = _mm256_broadcast_ss(scale_a); + __m256 scale_b_8_ps = _mm256_broadcast_ss(scale_b); - __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv0_32_epi8, bv0_32_epi8), 15); - const __m256i sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, sum_16_epi16); - const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + acc0 = _mm256_fmadd_ps(sum_ps, _mm256_mul_ps(scale_a_8_ps, scale_b_8_ps), acc0); + } else { + const __m256i dot0_16_epi16 = _mm256_maddubs_epi16(bv0_32_epi8, av00_32_epi8); + const __m256i dot1_16_epi16 = _mm256_maddubs_epi16(bv1_32_epi8, av01_32_epi8); + const __m256i sum_16_epi16 = _mm256_hadd_epi16(dot0_16_epi16, dot1_16_epi16); - __m256 scale_a_8_ps = _mm256_broadcast_ss(scale_a); - __m256 scale_b_8_ps = _mm256_broadcast_ss(scale_b); + __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv0_32_epi8, bv0_32_epi8), 15); + const __m256i sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, sum_16_epi16); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); - acc0 = _mm256_fmadd_ps(sum_ps, _mm256_mul_ps(scale_a_8_ps, scale_b_8_ps), acc0); + __m256 scale_a_8_ps = _mm256_broadcast_ss(scale_a); + __m256 scale_b_8_ps = _mm256_broadcast_ss(scale_b); + + acc0 = _mm256_fmadd_ps(sum_ps, _mm256_mul_ps(scale_a_8_ps, scale_b_8_ps), acc0); + } } +template MLAS_FORCEINLINE void Q4Int8GemmR2xC4BlkLen64Avx2( const size_t BlkLen, @@ -138,10 +172,10 @@ Q4Int8GemmR2xC4BlkLen64Avx2( const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + lda)); const __m256i av_11_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + lda + 32)); - accumulate_blklen64_r2c1blk1_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc[0], acc[NCols4]); - accumulate_blklen64_r2c1blk1_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + SubblkDataSizeInBytes, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 1, acc[1], acc[NCols4 + 1]); - accumulate_blklen64_r2c1blk1_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 2 * SubblkDataSizeInBytes, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 2, acc[2], acc[NCols4 + 2]); - accumulate_blklen64_r2c1blk1_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 3 * SubblkDataSizeInBytes, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 3, acc[3], acc[NCols4 + 3]); + accumulate_blklen64_r2c1blk1_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc[0], acc[NCols4]); + accumulate_blklen64_r2c1blk1_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + SubblkDataSizeInBytes, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 1, acc[1], acc[NCols4 + 1]); + accumulate_blklen64_r2c1blk1_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 2 * SubblkDataSizeInBytes, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 2, acc[2], acc[NCols4 + 2]); + accumulate_blklen64_r2c1blk1_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 3 * SubblkDataSizeInBytes, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 3, acc[3], acc[NCols4 + 3]); // increment block pointers QuantAPtr += SubblkLen; @@ -170,6 +204,7 @@ 
Q4Int8GemmR2xC4BlkLen64Avx2( } } +template void MLAS_FORCEINLINE Q4Int8GemmR2xC1BlkLen64Avx2( const size_t BlkLen, @@ -223,7 +258,7 @@ Q4Int8GemmR2xC1BlkLen64Avx2( const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + lda)); const __m256i av_11_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + lda + 32)); - accumulate_blklen64_r2c1blk1_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc0, acc1); + accumulate_blklen64_r2c1blk1_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc0, acc1); // increment block pointers QuantAPtr += SubblkLen; @@ -249,6 +284,7 @@ Q4Int8GemmR2xC1BlkLen64Avx2( } } +template MLAS_FORCEINLINE void Q4Int8GemmR1xC4BlkLen64Avx2( const size_t BlkLen, @@ -298,10 +334,10 @@ Q4Int8GemmR1xC4BlkLen64Avx2( for (size_t kk = 0; kk < PerBlkSubblkCount; kk++) { const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + 32)); - accumulate_blklen64_r1c1blk1_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]); - accumulate_blklen64_r1c1blk1_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + SubblkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + 1, acc[1]); - accumulate_blklen64_r1c1blk1_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * SubblkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + 2, acc[2]); - accumulate_blklen64_r1c1blk1_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * SubblkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + 3, acc[3]); + accumulate_blklen64_r1c1blk1_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]); + accumulate_blklen64_r1c1blk1_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + SubblkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + 1, acc[1]); + accumulate_blklen64_r1c1blk1_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * SubblkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + 2, acc[2]); + accumulate_blklen64_r1c1blk1_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * SubblkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + 3, acc[3]); // increment block pointers QuantAPtr += SubblkLen; @@ -327,6 +363,7 @@ Q4Int8GemmR1xC4BlkLen64Avx2( } } +template MLAS_FORCEINLINE void Q4Int8GemmR1xC1BlkLen64Avx2( const size_t BlkLen, @@ -376,7 +413,7 @@ Q4Int8GemmR1xC1BlkLen64Avx2( const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + 32)); - accumulate_blklen64_r1c1blk1_avx2( + accumulate_blklen64_r1c1blk1_avx2( av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc0 ); @@ -402,6 +439,7 @@ Q4Int8GemmR1xC1BlkLen64Avx2( } } +template MLAS_FORCEINLINE size_t MlasQ4Int8GemmKernelBlkLen64Avx2( const size_t BlkLen, @@ -432,7 +470,7 @@ MlasQ4Int8GemmKernelBlkLen64Avx2( size_t multipleCols = CountN - remainingCols; if (multipleRows > 0 && multipleCols > 0) { - Q4Int8GemmR2xC4BlkLen64Avx2( + Q4Int8GemmR2xC4BlkLen64Avx2( BlkLen, QuantA, QuantAScale, @@ -447,7 +485,7 @@ MlasQ4Int8GemmKernelBlkLen64Avx2( ); } if (remainingCols > 0 && multipleRows > 0) { - Q4Int8GemmR2xC1BlkLen64Avx2( + Q4Int8GemmR2xC1BlkLen64Avx2( BlkLen, QuantA, QuantAScale, @@ -462,7 +500,7 @@ MlasQ4Int8GemmKernelBlkLen64Avx2( } if (remainingRows > 0 && multipleCols > 0) { - Q4Int8GemmR1xC4BlkLen64Avx2( + Q4Int8GemmR1xC4BlkLen64Avx2( BlkLen, QuantA + multipleRows * lda, 
QuantAScale + multipleRows * lda_scale, @@ -477,7 +515,7 @@ MlasQ4Int8GemmKernelBlkLen64Avx2( } if (remainingCols > 0 && remainingRows > 0) { - Q4Int8GemmR1xC1BlkLen64Avx2( + Q4Int8GemmR1xC1BlkLen64Avx2( BlkLen, QuantA + multipleRows * lda, QuantAScale + multipleRows * lda_scale, diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_m1_sym_kernel_avx2_int8_blklen32.h b/onnxruntime/core/mlas/lib/sqnbitgemm_m1_sym_kernel_avx2_int8_blklen32.h index 836c22f2cb826..0476766c23518 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_m1_sym_kernel_avx2_int8_blklen32.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_m1_sym_kernel_avx2_int8_blklen32.h @@ -6,7 +6,7 @@ #include "sqnbitgemm.h" #include "sqnbitgemm_kernel_avx_common.h" -template +template static MLAS_FORCEINLINE void accumulate_blklen32_r1c1blk1_zp_avx2( const __m256i& av_32_epi8, @@ -24,13 +24,20 @@ accumulate_blklen32_r1c1blk1_zp_avx2( bv_32_epi8 = _mm256_sub_epi8(bv_32_epi8, _mm256_set1_epi8(get_zp(true, QuantBZeroPointPtr))); - __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv_32_epi8, bv_32_epi8), 15); - const __m256i dot_16_epi16 = _mm256_maddubs_epi16(_mm256_sign_epi8(bv_32_epi8, bv_32_epi8), _mm256_sign_epi8(av_32_epi8, bv_32_epi8)); - const __m256i sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, dot_16_epi16); - const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); - acc = _mm256_fmadd_ps(sum_ps, _mm256_set1_ps(combined_scale), acc); + if constexpr (vnni) { + const __m256i sum_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), _mm256_sign_epi8(bv_32_epi8, bv_32_epi8), _mm256_sign_epi8(av_32_epi8, bv_32_epi8)); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + acc = _mm256_fmadd_ps(sum_ps, _mm256_set1_ps(combined_scale), acc); + } else { + __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv_32_epi8, bv_32_epi8), 15); + const __m256i dot_16_epi16 = _mm256_maddubs_epi16(_mm256_sign_epi8(bv_32_epi8, bv_32_epi8), _mm256_sign_epi8(av_32_epi8, bv_32_epi8)); + const __m256i sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, dot_16_epi16); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + acc = _mm256_fmadd_ps(sum_ps, _mm256_set1_ps(combined_scale), acc); + } } +template static MLAS_FORCEINLINE void accumulate_blklen32_r1c1blk2_zp_avx2( const __m256i& av0_32_epi8, @@ -47,29 +54,48 @@ accumulate_blklen32_r1c1blk2_zp_avx2( __m256i bv0_32_epi8 = _mm256_and_si256(bv_packed, low_mask); // 0~31 __m256i bv1_32_epi8 = _mm256_and_si256(_mm256_srli_epi16(bv_packed, 4), low_mask); // 32~63 - { - bv0_32_epi8 = _mm256_sub_epi8(bv0_32_epi8, _mm256_set1_epi8(get_zp(true, QuantBZeroPointPtr))); - const __m256 scale = _mm256_set1_ps(*(scale_a) * *(scale_b)); - __m256i dot_16_epi16 = _mm256_maddubs_epi16( - _mm256_sign_epi8(bv0_32_epi8, bv0_32_epi8), _mm256_sign_epi8(av0_32_epi8, bv0_32_epi8) - ); - __m256i sum_8_epi32 = _mm256_madd_epi16(_mm256_set1_epi16(1), dot_16_epi16); - const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); - acc0 = _mm256_fmadd_ps(sum_ps, scale, acc0); - } + if constexpr (vnni) { + { + bv0_32_epi8 = _mm256_sub_epi8(bv0_32_epi8, _mm256_set1_epi8(get_zp(true, QuantBZeroPointPtr))); + __m256i sum_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), _mm256_sign_epi8(bv0_32_epi8, bv0_32_epi8), _mm256_sign_epi8(av0_32_epi8, bv0_32_epi8)); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + const __m256 scale = _mm256_set1_ps(*(scale_a) * *(scale_b)); + acc0 = _mm256_fmadd_ps(sum_ps, scale, acc0); + } - { - bv1_32_epi8 = _mm256_sub_epi8(bv1_32_epi8, _mm256_set1_epi8(get_zp(false, 
QuantBZeroPointPtr))); - const __m256 scale = _mm256_set1_ps(*(scale_a + 1) * *(scale_b + 1)); - __m256i dot_16_epi16 = _mm256_maddubs_epi16( - _mm256_sign_epi8(bv1_32_epi8, bv1_32_epi8), _mm256_sign_epi8(av1_32_epi8, bv1_32_epi8) - ); - __m256i sum_8_epi32 = _mm256_madd_epi16(_mm256_set1_epi16(1), dot_16_epi16); - const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); - acc0 = _mm256_fmadd_ps(sum_ps, scale, acc0); + { + bv1_32_epi8 = _mm256_sub_epi8(bv1_32_epi8, _mm256_set1_epi8(get_zp(false, QuantBZeroPointPtr))); + const __m256 scale = _mm256_set1_ps(*(scale_a + 1) * *(scale_b + 1)); + __m256i sum_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), _mm256_sign_epi8(bv1_32_epi8, bv1_32_epi8), _mm256_sign_epi8(av1_32_epi8, bv1_32_epi8)); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + acc0 = _mm256_fmadd_ps(sum_ps, scale, acc0); + } + } else { + { + bv0_32_epi8 = _mm256_sub_epi8(bv0_32_epi8, _mm256_set1_epi8(get_zp(true, QuantBZeroPointPtr))); + const __m256 scale = _mm256_set1_ps(*(scale_a) * *(scale_b)); + __m256i dot_16_epi16 = _mm256_maddubs_epi16( + _mm256_sign_epi8(bv0_32_epi8, bv0_32_epi8), _mm256_sign_epi8(av0_32_epi8, bv0_32_epi8) + ); + __m256i sum_8_epi32 = _mm256_madd_epi16(_mm256_set1_epi16(1), dot_16_epi16); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + acc0 = _mm256_fmadd_ps(sum_ps, scale, acc0); + } + + { + bv1_32_epi8 = _mm256_sub_epi8(bv1_32_epi8, _mm256_set1_epi8(get_zp(false, QuantBZeroPointPtr))); + const __m256 scale = _mm256_set1_ps(*(scale_a + 1) * *(scale_b + 1)); + __m256i dot_16_epi16 = _mm256_maddubs_epi16( + _mm256_sign_epi8(bv1_32_epi8, bv1_32_epi8), _mm256_sign_epi8(av1_32_epi8, bv1_32_epi8) + ); + __m256i sum_8_epi32 = _mm256_madd_epi16(_mm256_set1_epi16(1), dot_16_epi16); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + acc0 = _mm256_fmadd_ps(sum_ps, scale, acc0); + } } } +template static MLAS_FORCEINLINE void accumulate_blklen32_r1c1blk2_zp_is_8_avx2( const __m256i& av0_32_epi8, @@ -96,24 +122,39 @@ accumulate_blklen32_r1c1blk2_zp_is_8_avx2( bv0_32_epi8 = _mm256_sub_epi8(bv0_32_epi8, bzp8); bv1_32_epi8 = _mm256_sub_epi8(bv1_32_epi8, bzp8); - __m256i dot0_16_epi16 = _mm256_maddubs_epi16(_mm256_sign_epi8(bv0_32_epi8, bv0_32_epi8), _mm256_sign_epi8(av0_32_epi8, bv0_32_epi8)); - __m256i dot1_16_epi16 = _mm256_maddubs_epi16(_mm256_sign_epi8(bv1_32_epi8, bv1_32_epi8), _mm256_sign_epi8(av1_32_epi8, bv1_32_epi8)); - const __m256i sum_16_epi16 = _mm256_hadd_epi16(dot0_16_epi16, dot1_16_epi16); + if constexpr (vnni) { + __m256i dot0_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), _mm256_sign_epi8(bv0_32_epi8, bv0_32_epi8), _mm256_sign_epi8(av0_32_epi8, bv0_32_epi8)); + __m256i dot1_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), _mm256_sign_epi8(bv1_32_epi8, bv1_32_epi8), _mm256_sign_epi8(av1_32_epi8, bv1_32_epi8)); + const __m256i sum_8_epi32 = _mm256_hadd_epi32(dot0_8_epi32, dot1_8_epi32); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); - const __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(low_mask, low_mask), 15); - const __m256i sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, sum_16_epi16); - const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + __m256 scale_a0_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_a)); + __m256 scale_b_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_b)); + // 1 0 1 0 1 0 1 0 -> 1 1 0 0 1 1 0 0 + __m256 scale_8_ps = _mm256_permute_ps(_mm256_mul_ps(scale_a0_2_ps, scale_b_2_ps), _MM_SHUFFLE(1, 1, 0, 0)); - __m256 scale_a0_2_ps = 
_mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_a)); - __m256 scale_b_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_b)); - // 1 0 1 0 1 0 1 0 -> 1 1 0 0 1 1 0 0 - __m256 scale_8_ps = _mm256_permute_ps( - _mm256_mul_ps(scale_a0_2_ps, scale_b_2_ps), _MM_SHUFFLE(1, 1, 0, 0) - ); + acc0 = _mm256_fmadd_ps(sum_ps, scale_8_ps, acc0); + } else { + __m256i dot0_16_epi16 = _mm256_maddubs_epi16(_mm256_sign_epi8(bv0_32_epi8, bv0_32_epi8), _mm256_sign_epi8(av0_32_epi8, bv0_32_epi8)); + __m256i dot1_16_epi16 = _mm256_maddubs_epi16(_mm256_sign_epi8(bv1_32_epi8, bv1_32_epi8), _mm256_sign_epi8(av1_32_epi8, bv1_32_epi8)); + const __m256i sum_16_epi16 = _mm256_hadd_epi16(dot0_16_epi16, dot1_16_epi16); + + const __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(low_mask, low_mask), 15); + const __m256i sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, sum_16_epi16); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + + __m256 scale_a0_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_a)); + __m256 scale_b_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_b)); + // 1 0 1 0 1 0 1 0 -> 1 1 0 0 1 1 0 0 + __m256 scale_8_ps = _mm256_permute_ps( + _mm256_mul_ps(scale_a0_2_ps, scale_b_2_ps), _MM_SHUFFLE(1, 1, 0, 0) + ); - acc0 = _mm256_fmadd_ps(sum_ps, scale_8_ps, acc0); + acc0 = _mm256_fmadd_ps(sum_ps, scale_8_ps, acc0); + } } +template static MLAS_FORCEINLINE void accumulate_blklen32_r1c1blk2_zp_is_8_no_bc_avx2( const __m256i& av0_32_epi8, @@ -136,25 +177,40 @@ accumulate_blklen32_r1c1blk2_zp_is_8_no_bc_avx2( bv0_32_epi8 = _mm256_sub_epi8(bv0_32_epi8, bzp8); bv1_32_epi8 = _mm256_sub_epi8(bv1_32_epi8, bzp8); - { - __m256i dot0_16_epi16 = _mm256_maddubs_epi16(_mm256_sign_epi8(bv0_32_epi8, bv0_32_epi8), _mm256_sign_epi8(av0_32_epi8, bv0_32_epi8)); - __m256i sum_8_epi32 = _mm256_madd_epi16(_mm256_set1_epi16(1), dot0_16_epi16); - const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); - - const __m256 scale = _mm256_mul_ps(_mm256_set1_ps(*scale_b), scale_a0_8_ps); - acc0 = _mm256_fmadd_ps(sum_ps, scale, acc0); - } - { - __m256i dot0_16_epi16 = _mm256_maddubs_epi16(_mm256_sign_epi8(bv1_32_epi8, bv1_32_epi8), _mm256_sign_epi8(av1_32_epi8, bv1_32_epi8)); - __m256i sum_8_epi32 = _mm256_madd_epi16(_mm256_set1_epi16(1), dot0_16_epi16); - const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + if constexpr (vnni) { + { + __m256i sum_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), _mm256_sign_epi8(bv0_32_epi8, bv0_32_epi8), _mm256_sign_epi8(av0_32_epi8, bv0_32_epi8)); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + const __m256 scale = _mm256_mul_ps(_mm256_set1_ps(*scale_b), scale_a0_8_ps); + acc0 = _mm256_fmadd_ps(sum_ps, scale, acc0); + } + { + __m256i sum_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), _mm256_sign_epi8(bv1_32_epi8, bv1_32_epi8), _mm256_sign_epi8(av1_32_epi8, bv1_32_epi8)); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + const __m256 scale = _mm256_mul_ps(_mm256_set1_ps(*(scale_b + 1)), scale_a1_8_ps); + acc0 = _mm256_fmadd_ps(sum_ps, scale, acc0); + } + } else { + { + __m256i dot0_16_epi16 = _mm256_maddubs_epi16(_mm256_sign_epi8(bv0_32_epi8, bv0_32_epi8), _mm256_sign_epi8(av0_32_epi8, bv0_32_epi8)); + __m256i sum_8_epi32 = _mm256_madd_epi16(_mm256_set1_epi16(1), dot0_16_epi16); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + + const __m256 scale = _mm256_mul_ps(_mm256_set1_ps(*scale_b), scale_a0_8_ps); + acc0 = _mm256_fmadd_ps(sum_ps, scale, acc0); + } + { + __m256i dot0_16_epi16 = 
_mm256_maddubs_epi16(_mm256_sign_epi8(bv1_32_epi8, bv1_32_epi8), _mm256_sign_epi8(av1_32_epi8, bv1_32_epi8)); + __m256i sum_8_epi32 = _mm256_madd_epi16(_mm256_set1_epi16(1), dot0_16_epi16); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); - const __m256 scale = _mm256_mul_ps(_mm256_set1_ps(*(scale_b + 1)), scale_a1_8_ps); - acc0 = _mm256_fmadd_ps(sum_ps, scale, acc0); + const __m256 scale = _mm256_mul_ps(_mm256_set1_ps(*(scale_b + 1)), scale_a1_8_ps); + acc0 = _mm256_fmadd_ps(sum_ps, scale, acc0); + } } } -template +template MLAS_FORCEINLINE void Q4Int8GemmM1C4BlkLen32Avx2( const std::byte* QuantA, @@ -217,17 +273,17 @@ Q4Int8GemmM1C4BlkLen32Avx2( //accumulate_blklen32_r1c1blk2_zp_is_8_no_bc_avx2(av_00_epi8, av_01_epi8, scale_a0_8_ps, scale_a1_8_ps, QuantBDataPtr + 2 * StrideQuantBData, QuantBScalePtr + 2 * StrideQuantBScale, acc[2], low_mask, bzp8); //accumulate_blklen32_r1c1blk2_zp_is_8_no_bc_avx2(av_00_epi8, av_01_epi8, scale_a0_8_ps, scale_a1_8_ps, QuantBDataPtr + 3 * StrideQuantBData, QuantBScalePtr + 3 * StrideQuantBScale, acc[3], low_mask, bzp8); if constexpr (HasZeroPoint) { - accumulate_blklen32_r1c1blk2_zp_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, QuantBZeroPointPtr, acc[0], low_mask); - accumulate_blklen32_r1c1blk2_zp_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + StrideQuantBData2, QuantAScalePtr, QuantBScalePtr + StrideQuantBScale2, QuantBZeroPointPtr + StrideQuantBZeroPoint, acc[1], low_mask); - accumulate_blklen32_r1c1blk2_zp_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * StrideQuantBData2, QuantAScalePtr, QuantBScalePtr + 2 * StrideQuantBScale2, QuantBZeroPointPtr + 2 * StrideQuantBZeroPoint, acc[2], low_mask); - accumulate_blklen32_r1c1blk2_zp_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * StrideQuantBData2, QuantAScalePtr, QuantBScalePtr + 3 * StrideQuantBScale2, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, acc[3], low_mask); + accumulate_blklen32_r1c1blk2_zp_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, QuantBZeroPointPtr, acc[0], low_mask); + accumulate_blklen32_r1c1blk2_zp_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + StrideQuantBData2, QuantAScalePtr, QuantBScalePtr + StrideQuantBScale2, QuantBZeroPointPtr + StrideQuantBZeroPoint, acc[1], low_mask); + accumulate_blklen32_r1c1blk2_zp_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * StrideQuantBData2, QuantAScalePtr, QuantBScalePtr + 2 * StrideQuantBScale2, QuantBZeroPointPtr + 2 * StrideQuantBZeroPoint, acc[2], low_mask); + accumulate_blklen32_r1c1blk2_zp_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * StrideQuantBData2, QuantAScalePtr, QuantBScalePtr + 3 * StrideQuantBScale2, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, acc[3], low_mask); } else { const __m256i bzp8 = _mm256_set1_epi8(8); - accumulate_blklen32_r1c1blk2_zp_is_8_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0], low_mask, bzp8); - accumulate_blklen32_r1c1blk2_zp_is_8_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + StrideQuantBData2, QuantAScalePtr, QuantBScalePtr + StrideQuantBScale2, acc[1], low_mask, bzp8); - accumulate_blklen32_r1c1blk2_zp_is_8_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * StrideQuantBData2, QuantAScalePtr, QuantBScalePtr + 2 * StrideQuantBScale2, acc[2], low_mask, bzp8); - accumulate_blklen32_r1c1blk2_zp_is_8_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * StrideQuantBData2, QuantAScalePtr, QuantBScalePtr + 3 * StrideQuantBScale2, acc[3], low_mask, bzp8); + accumulate_blklen32_r1c1blk2_zp_is_8_avx2(av_00_epi8, 
av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0], low_mask, bzp8); + accumulate_blklen32_r1c1blk2_zp_is_8_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + StrideQuantBData2, QuantAScalePtr, QuantBScalePtr + StrideQuantBScale2, acc[1], low_mask, bzp8); + accumulate_blklen32_r1c1blk2_zp_is_8_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * StrideQuantBData2, QuantAScalePtr, QuantBScalePtr + 2 * StrideQuantBScale2, acc[2], low_mask, bzp8); + accumulate_blklen32_r1c1blk2_zp_is_8_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * StrideQuantBData2, QuantAScalePtr, QuantBScalePtr + 3 * StrideQuantBScale2, acc[3], low_mask, bzp8); } // increment block pointers QuantAPtr += BlkLen32 * PerAccuBlk2; @@ -248,19 +304,19 @@ Q4Int8GemmM1C4BlkLen32Avx2( const float& scale_a00 = *QuantAScalePtr; { const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; - accumulate_blklen32_r1c1blk1_zp_avx2(av_00_epi8, QuantBDataPtr, scale_00, QuantBZeroPointPtr, acc[0], low_mask); + accumulate_blklen32_r1c1blk1_zp_avx2(av_00_epi8, QuantBDataPtr, scale_00, QuantBZeroPointPtr, acc[0], low_mask); } { const float& scale_00 = scale_a00 * (QuantBScalePtr + StrideQuantBScale1)[0]; - accumulate_blklen32_r1c1blk1_zp_avx2(av_00_epi8, QuantBDataPtr + StrideQuantBData1, scale_00, QuantBZeroPointPtr + StrideQuantBZeroPoint, acc[1], low_mask); + accumulate_blklen32_r1c1blk1_zp_avx2(av_00_epi8, QuantBDataPtr + StrideQuantBData1, scale_00, QuantBZeroPointPtr + StrideQuantBZeroPoint, acc[1], low_mask); } { const float& scale_00 = scale_a00 * (QuantBScalePtr + 2 * StrideQuantBScale1)[0]; - accumulate_blklen32_r1c1blk1_zp_avx2(av_00_epi8, QuantBDataPtr + 2 * StrideQuantBData1, scale_00, QuantBZeroPointPtr + 2 * StrideQuantBZeroPoint, acc[2], low_mask); + accumulate_blklen32_r1c1blk1_zp_avx2(av_00_epi8, QuantBDataPtr + 2 * StrideQuantBData1, scale_00, QuantBZeroPointPtr + 2 * StrideQuantBZeroPoint, acc[2], low_mask); } { const float& scale_00 = scale_a00 * (QuantBScalePtr + 3 * StrideQuantBScale1)[0]; - accumulate_blklen32_r1c1blk1_zp_avx2(av_00_epi8, QuantBDataPtr + 3 * StrideQuantBData1, scale_00, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, acc[3], low_mask); + accumulate_blklen32_r1c1blk1_zp_avx2(av_00_epi8, QuantBDataPtr + 3 * StrideQuantBData1, scale_00, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, acc[3], low_mask); } } @@ -283,7 +339,7 @@ Q4Int8GemmM1C4BlkLen32Avx2( } } -template +template MLAS_FORCEINLINE void Q4Int8GemmM1C1BlkLen32Avx2( const std::byte* QuantA, @@ -336,9 +392,9 @@ Q4Int8GemmM1C1BlkLen32Avx2( //const __m256 scale_a1_8_ps = _mm256_set1_ps(Q8BlkScale(QuantAPtr + Q8BlkSize(BlkLen32))); if constexpr (HasZeroPoint) { - accumulate_blklen32_r1c1blk2_zp_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, QuantBZeroPointPtr, acc0, low_mask); + accumulate_blklen32_r1c1blk2_zp_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, QuantBZeroPointPtr, acc0, low_mask); } else { - accumulate_blklen32_r1c1blk2_zp_is_8_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc0, low_mask, bzp8); + accumulate_blklen32_r1c1blk2_zp_is_8_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc0, low_mask, bzp8); } // increment block pointers @@ -356,7 +412,7 @@ Q4Int8GemmM1C1BlkLen32Avx2( const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); const float& scale_a00 = *QuantAScalePtr; const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; - accumulate_blklen32_r1c1blk1_zp_avx2(av_00_epi8, QuantBDataPtr, 
scale_00, QuantBZeroPointPtr, acc0, low_mask); + accumulate_blklen32_r1c1blk1_zp_avx2(av_00_epi8, QuantBDataPtr, scale_00, QuantBZeroPointPtr, acc0, low_mask); } *SumPtr = hsum_float_8(acc0); @@ -376,7 +432,7 @@ Q4Int8GemmM1C1BlkLen32Avx2( } } -template +template MLAS_FORCEINLINE void MlasQ4Int8GemmM1KernelBlkLen32Avx2( @@ -403,7 +459,7 @@ MlasQ4Int8GemmM1KernelBlkLen32Avx2( size_t multipleCols = CountN - remainingCols; if (multipleCols > 0) { - Q4Int8GemmM1C4BlkLen32Avx2( + Q4Int8GemmM1C4BlkLen32Avx2( QuantA, QuantAScale, QuantBData, @@ -416,7 +472,7 @@ MlasQ4Int8GemmM1KernelBlkLen32Avx2( } if (remainingCols > 0) { - Q4Int8GemmM1C1BlkLen32Avx2( + Q4Int8GemmM1C1BlkLen32Avx2( QuantA, QuantAScale, QuantBData + multipleCols * StrideQuantBData,
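
For reference, a minimal standalone sketch of the two reduction styles that the new vnni template parameter selects between, mirroring accumulate_1blk_dot and accumulate_1blk_dot_vnni above. This is not part of the patch; the function names and the -mavx2 -mfma -mavxvnni build flags are illustrative assumptions. The AVX2 path pairs _mm256_maddubs_epi16 with _mm256_madd_epi16 to widen the unsigned-weight x signed-activation byte products into eight 32-bit lane sums; the AVX-VNNI path folds the multiply and both additions into a single _mm256_dpbusds_avx_epi32.

    #include <immintrin.h>

    // AVX2 reduction: u8*s8 products summed in adjacent pairs into 16-bit lanes
    // (saturating), then widened to eight 32-bit sums via a multiply-add with ones.
    static inline __m256
    dot32_avx2(__m256i b_u8, __m256i a_s8, float combined_scale, __m256 acc)
    {
        const __m256i dot_16_epi16 = _mm256_maddubs_epi16(b_u8, a_s8);
        const __m256i sum_8_epi32 = _mm256_madd_epi16(dot_16_epi16, _mm256_set1_epi16(1));
        const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32);
        return _mm256_fmadd_ps(sum_ps, _mm256_set1_ps(combined_scale), acc);
    }

    // AVX-VNNI reduction: one instruction produces the same eight 32-bit sums of
    // four u8*s8 products each, accumulated with signed saturation into operand 0.
    static inline __m256
    dot32_avxvnni(__m256i b_u8, __m256i a_s8, float combined_scale, __m256 acc)
    {
        const __m256i sum_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), b_u8, a_s8);
        const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32);
        return _mm256_fmadd_ps(sum_ps, _mm256_set1_ps(combined_scale), acc);
    }

Both variants leave eight scaled float partial sums in acc, which the kernels above reduce with hsum_float_8 after the K loop; the VNNI form also removes the generated vector of 16-bit ones that the AVX2 form feeds to _mm256_madd_epi16.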