From 4677e7ad8b1a7053ecc2d8366d6177f64594f13d Mon Sep 17 00:00:00 2001 From: luoyu-intel Date: Mon, 29 Apr 2024 14:44:52 +0800 Subject: [PATCH] pop vnni flags --- bestla/bestla/kernel_avx2.h | 15 +++++++++------ bestla/bestla/kernel_avx512f.h | 4 +++- 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/bestla/bestla/kernel_avx2.h b/bestla/bestla/kernel_avx2.h index f2a6b3623..4ef7443b7 100644 --- a/bestla/bestla/kernel_avx2.h +++ b/bestla/bestla/kernel_avx2.h @@ -40,7 +40,7 @@ static inline __m256i unpack_4bits(void* srcptr, __m256i mask) { } static inline __m256i unpack_2bits(utils::bit2x4* ptr, const __m256i& vshift_y, const __m256i& vmask0_y, - const __m256i& vsfhl_mask_y, const __m256i& vorder_y) { + const __m256i& vsfhl_mask_y, const __m256i& vorder_y) { auto vraw_x = _mm_loadl_epi64((const __m128i*)ptr); auto vsrc_y = _mm256_broadcastq_epi64(vraw_x); auto vordered_y = _mm256_permutevar8x32_epi32(vsrc_y, vorder_y); @@ -51,7 +51,7 @@ static inline __m256i unpack_2bits(utils::bit2x4* ptr, const __m256i& vshift_y, } static inline __m256i unpack_1bits(utils::bit1x8* ptr, const __m256i& bit1Shift_1, const __m256i& bit1Mask, - const __m256i& bit1Shift_2, const __m256i& highMask) { + const __m256i& bit1Shift_2, const __m256i& highMask) { auto bit1x32 = _mm256_set1_epi32(*(int*)ptr); bit1x32 = _mm256_srlv_epi32(bit1x32, bit1Shift_1); bit1x32 = _mm256_and_si256(bit1x32, bit1Mask); @@ -159,7 +159,7 @@ static inline void dequant_s8_N_avx2(DstT* dstptr, int8_t* srcptr, __m256* vscal static inline __m256i load_zp_epi8_broadcast_epi16_v16(int8_t* zpptr, const __m256i& vindex) { auto v_zp_x = _mm_loadu_si128((const __m128i*)zpptr); auto v_zp_y = _mm256_cvtepi8_epi16(v_zp_x); - auto v_zp_y_cast = _mm256_shuffle_epi8(v_zp_y, vindex); + auto v_zp_y_cast = _mm256_shuffle_epi8(v_zp_y, vindex); return v_zp_y_cast; } @@ -1192,7 +1192,6 @@ static inline void dequant_f4_N(_DST_T* dstptr, int8_t* srcptr, __m256* vscales, } } - template static inline void convert_s4_s8_N_avx2(int8_t* dstptr, int8_t* srcptr, __m256i mask) { static_assert(N % 2 == 0); @@ -1266,7 +1265,6 @@ inline BTLA_CODE decompress_kblock_f4_fp_noscale(utils::f4x2* srcptr, DST_T* dst return BTLA_CODE::Success; } - template static inline BTLA_CODE decompress_kblock_bit4_packrow1(utils::bit4x2* srcptr, _DST_T* dstptr, int row, int col, int ld_src, int ld_dst, _ST* scales, int8_t* zero_points, @@ -2332,6 +2330,10 @@ constexpr decltype(load_maskz_fp32_fp16_tr_x8_word<1>)* load_maskz_fp32_fp16_tr_ load_maskz_fp32_fp16_tr_x8_word<3>, load_maskz_fp32_fp16_tr_x8_word<4>, load_maskz_fp32_fp16_tr_x8_word<5>, load_maskz_fp32_fp16_tr_x8_word<6>, load_maskz_fp32_fp16_tr_x8_word<7>, load_maskz_fp32_fp16_tr_x8_word<8>}; +#ifdef __GNUC__ +#pragma GCC diagnostic pop +#endif + template static inline void accumulate_fp32_s8_fp32(const float* Aptr, int lda, int8_t* Bptr, __m256* vacc, __m256* vsca) { if constexpr (MTILE == 1) { @@ -3539,7 +3541,8 @@ static inline BTLA_CODE gemv_3bit_s8s8_fp32(const utils::GemvParamA& A, const ut } #ifdef __GNUC__ -#pragma GCC diagnostic pop +#pragma GCC pop_options +#else #endif } // namespace vnni diff --git a/bestla/bestla/kernel_avx512f.h b/bestla/bestla/kernel_avx512f.h index cff3b9449..3085fa8db 100644 --- a/bestla/bestla/kernel_avx512f.h +++ b/bestla/bestla/kernel_avx512f.h @@ -4337,8 +4337,10 @@ static inline BTLA_CODE gemv_3bit_s8s8_fp32(const utils::GemvParamA& A, const ut return BTLA_CODE::Success; } + #ifdef __GNUC__ -#pragma GCC diagnostic pop +#pragma GCC pop_options +#else #endif } // namespace vnni