From 012e9c46848004acaf3cd0cd5435007a12d0331c Mon Sep 17 00:00:00 2001 From: liqunfu Date: Mon, 29 Jul 2024 21:27:03 +0000 Subject: [PATCH] hsum_float_16 Signed-off-by: liqunfu --- /* Reviewer note (not part of the commit message): this patch moves h_add_512() and hsum_float_16() unchanged from sqnbitgemm_kernel_avx512_int8_blklen32.h to sqnbitgemm_kernel_avx512_int8_blklen64.h. The blklen32 header includes the blklen64 header, so both helpers stay visible to blklen32 code; presumably the move lets blklen64 code call them as well (confirm against the callers). h_add_512() adds the upper and lower 256-bit halves of a __m512. hsum_float_16() keeps halving that result (upper 128 bits onto lower 128 bits, then the high pair onto the low pair, then lane 1 onto lane 0) and returns lane 0 as a float, i.e. the sum of all 16 lanes. */ .../sqnbitgemm_kernel_avx512_int8_blklen32.h | 18 ------------------ .../sqnbitgemm_kernel_avx512_int8_blklen64.h | 18 ++++++++++++++++++ 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen32.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen32.h index 6b2a93604af49..dc1f4d4a5a254 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen32.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen32.h @@ -8,24 +8,6 @@ #include "sqnbitgemm_kernel_avx2_int8_blklen32.h" #include "sqnbitgemm_kernel_avx512_int8_blklen64.h" -static MLAS_FORCEINLINE __m256 -h_add_512(__m512 a) -{ - return _mm256_add_ps(_mm512_castps512_ps256(a), _mm512_extractf32x8_ps(a, 1)); -} - -static MLAS_FORCEINLINE float -hsum_float_16(const __m512 x) -{ - __m256 hi = h_add_512(x); - __m128 hi128 = _mm256_extractf128_ps(hi, 1); - __m128 lo128 = _mm256_castps256_ps128(hi); - hi128 = _mm_add_ps(hi128, lo128); - hi128 = _mm_add_ps(hi128, _mm_movehl_ps(hi128, hi128)); - hi128 = _mm_add_ss(hi128, _mm_movehdup_ps(hi128)); - return _mm_cvtss_f32(hi128); -} - static MLAS_FORCEINLINE void load_4blk_4b_packed_blklen32(const std::byte* QuantBDataPtr, __m512i& bv0_64_epi8, __m512i& bv1_64_epi8) { diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen64.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen64.h index dbdcc751d61a6..2a65ac4af0c1d 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen64.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen64.h @@ -6,6 +6,24 @@ #include "sqnbitgemm.h" #include "sqnbitgemm_kernel_avx_common.h" +static MLAS_FORCEINLINE __m256 +h_add_512(__m512 a) +{ + return _mm256_add_ps(_mm512_castps512_ps256(a), 
_mm512_extractf32x8_ps(a, 1)); +} + +static MLAS_FORCEINLINE float +hsum_float_16(const __m512 x) +{ + __m256 hi = h_add_512(x); + __m128 hi128 = _mm256_extractf128_ps(hi, 1); + __m128 lo128 = _mm256_castps256_ps128(hi); + hi128 = _mm_add_ps(hi128, lo128); + hi128 = _mm_add_ps(hi128, _mm_movehl_ps(hi128, hi128)); + hi128 = _mm_add_ss(hi128, _mm_movehdup_ps(hi128)); + return _mm_cvtss_f32(hi128); +} + static MLAS_FORCEINLINE __m512i combine_two_m256i_to_m512i(const __m256i& a, const __m256i& b) {