From a605e6a74f30c3c08b9ae61d8b36b9b776fee825 Mon Sep 17 00:00:00 2001 From: liqunfu Date: Sun, 31 Mar 2024 23:02:06 +0000 Subject: [PATCH] avx512 port from q4gemm - pass M1N1K1 w/o bias Signed-off-by: liqunfu --- .../mlas/lib/sqnbitgemm_kernel_avx512.cpp | 163 +++++++++++++++++- 1 file changed, 161 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp index a63e4f5cb9449..dcbee4b1e85ef 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp @@ -20,6 +20,7 @@ Module Name: #include #include "sqnbitgemm.h" +#include "q4Common.h" // // Quantized B data packing function implementation. @@ -146,6 +147,34 @@ UnrolledLoop(IterationFn&& f) namespace { + /** + * @brief Horizontally sum 4 vectors and store + * the results in the returned vector + */ +static MLAS_FORCEINLINE __m128 + FoldAccumulators(const __m512& acc0, const __m512& acc1, const __m512& acc2, const __m512& acc3) +{ + __m512 acc_lo01 = _mm512_unpacklo_ps(acc0, acc1); + __m512 acc_hi01 = _mm512_unpackhi_ps(acc0, acc1); + __m512 acc_lo23 = _mm512_unpacklo_ps(acc2, acc3); + __m512 acc_hi23 = _mm512_unpackhi_ps(acc2, acc3); + + __m512 acc_lo0123 = _mm512_castpd_ps( + _mm512_unpacklo_pd(_mm512_castps_pd(acc_lo01), _mm512_castps_pd(acc_lo23))); + __m512 acc_hi0123 = _mm512_castpd_ps( + _mm512_unpackhi_pd(_mm512_castps_pd(acc_lo01), _mm512_castps_pd(acc_lo23))); + acc_lo0123 = _mm512_add_ps(acc_lo0123, acc_hi0123); + acc_hi0123 = _mm512_castpd_ps( + _mm512_unpacklo_pd(_mm512_castps_pd(acc_hi01), _mm512_castps_pd(acc_hi23))); + acc_lo0123 = _mm512_add_ps(acc_lo0123, acc_hi0123); + acc_hi0123 = _mm512_castpd_ps( + _mm512_unpackhi_pd(_mm512_castps_pd(acc_hi01), _mm512_castps_pd(acc_hi23))); + acc_lo0123 = _mm512_add_ps(acc_lo0123, acc_hi0123); + + __m256 acc_y = + _mm256_add_ps(_mm512_extractf32x8_ps(acc_lo0123, 0), _mm512_extractf32x8_ps(acc_lo0123, 1)); + return _mm_add_ps(_mm256_extractf32x4_ps(acc_y, 0), _mm256_extractf32x4_ps(acc_y, 1)); +} template MLAS_FORCEINLINE void @@ -163,8 +192,138 @@ ComputeDotProducts_BlkBitWidth4_CompFp32( const float* BiasPtr ) { - BlkLen, ARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK; - StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, BiasPtr; + constexpr size_t BlkBitWidth = 4; + //constexpr size_t SubBlkLen = 16; + + const __m256i lowMask = _mm256_set1_epi8(0xF); + + __m512 acc_lo[NCols]; + UnrolledLoop([&](size_t i) { + acc_lo[i] = _mm512_setzero_ps(); + }); + + const auto* b = QuantBDataColPtr; + const float* s = QuantBScaleColPtr; + + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + // only used if HasZeroPoint == true + + for (size_t k = 0; k < CountK; k += BlkLen) { + size_t ck = std::min(CountK - k, BlkLen); + + float scale_v[NCols]; + UnrolledLoop([&](size_t i) { + scale_v[i] = *(s + StrideQuantBScale * i); + }); + + __m128i* bptr[NCols]; + UnrolledLoop([&](size_t i) { + bptr[i] = (__m128i*)(b + StrideQuantBData * i); + }); + + [[maybe_unused]] uint8_t offset[NCols]; + // not ready for "Manual conversion to float" in neon yet. following neon to unpack to uint8_t. + if constexpr (HasZeroPoint) { + UnrolledLoop([&](size_t i) { + const std::byte zp_packed = + QuantBZeroPointColPtr[i * StrideQuantBZeroPoint + QuantBZeroPointIdx / 2]; + const std::byte zp = ((QuantBZeroPointIdx & 1) == 1) + ? (zp_packed >> 4) + : (zp_packed & std::byte{ 0x0F }); + offset[i] = std::to_integer(zp); + }); + } + + // TODO: block size shall be multiple of 16 but MLAS_QUANT4_BLK_UNIT is 32 + // follwing code compute 32 float as once so lets not to use SubBlkLen(16) + // code copied from MlasQ4GemmKernelAvx512f in q4gemm_avx512.cc which only works + // with the NCols==4 case. + for (size_t kk = 0; kk < ck; kk += MLAS_QUANT4_BLK_UNIT) { + size_t kklen = std::min((size_t)MLAS_QUANT4_BLK_UNIT, ck - kk); + + // Load A row vectors + uint32_t mask = 0xffffffff >> (MLAS_QUANT4_BLK_UNIT - kklen); + __m512 av_lo = _mm512_maskz_loadu_ps(__mmask16(mask), ARowPtr + k + kk); + + mask = mask >> 16; + __m512 av_hi = mask == 0 ? _mm512_setzero_ps() + : _mm512_maskz_loadu_ps(__mmask16(mask), ARowPtr + k + kk + 16); + + // Load B col vectors + __m128i bvi4[NCols]; + UnrolledLoop([&](size_t i) { + bvi4[i] = _mm_loadu_si128(bptr[i]++); + }); + + // expand 4b into byte array + __m256i bytes[NCols]; + UnrolledLoop([&](size_t i) { + bytes[i] = _mm256_set_m128i(_mm_srli_epi16(bvi4[i], 4), bvi4[i]); + bytes[i] = _mm256_and_si256(lowMask, bytes[i]); + }); + + // Subtract zero-point from the integers + if constexpr (HasZeroPoint) { + // Subtract zero-point from the integers + UnrolledLoop([&](size_t i) { + bytes[i] = _mm256_sub_epi8(bytes[i], _mm256_set1_epi8(offset[i])); + }); + } + else { + // Subtract 8 from the integers + const __m256i eight = _mm256_set1_epi8(8); + UnrolledLoop([&](size_t i) { + bytes[i] = _mm256_sub_epi8(bytes[i], eight); + }); + } + + // TODO: converting, scale and fma in one unroll loop + // Convert to 16-bit int + __m256i vx16_lo[NCols], vx16_hi[NCols]; + UnrolledLoop([&](size_t i) { + vx16_lo[i] = + _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes[i], 0)); + vx16_hi[i] = + _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes[i], 1)); + }); + + __m512 bvf_lo[NCols], bvf_hi[NCols]; + UnrolledLoop([&](size_t i) { + // Convert to 32-bit int -> float 32 + bvf_lo[i] = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_lo[i])); + bvf_hi[i] = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_hi[i])); + + // multiply by scale + __m512 s = _mm512_set1_ps(scale_v[i]); + bvf_lo[i] = _mm512_mul_ps(bvf_lo[i], s); + bvf_hi[i] = _mm512_mul_ps(bvf_hi[i], s); + + // c[m,n] += a[m,k] * b[k,n] + acc_lo[i] = _mm512_fmadd_ps(bvf_lo[i], av_lo, acc_lo[i]); + acc_lo[i] = _mm512_fmadd_ps(bvf_hi[i], av_hi, acc_lo[i]); + }); + } + + b += MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + + if constexpr (HasZeroPoint) { + QuantBZeroPointIdx += 1; + } + } + + if constexpr (NCols == 4) { + __m128 acc_x = FoldAccumulators(acc_lo[0], acc_lo[1], acc_lo[2], acc_lo[3]); + if (BiasPtr != nullptr) { + acc_x = _mm_add_ps(acc_x, _mm_loadu_ps(BiasPtr)); + } + _mm_storeu_ps(SumPtr, acc_x); + } + else { + for (size_t i = 0; i < NCols; ++i) { + SumPtr[i] = _mm512_reduce_add_ps(acc_lo[i]); + SumPtr[i] += BiasPtr == nullptr ? 0.0f : BiasPtr[i]; + } + } } template