Skip to content
This repository has been archived by the owner on Aug 30, 2024. It is now read-only.

Commit

Permalink
pop vnni flags
Browse files Browse the repository at this point in the history
  • Loading branch information
luoyu-intel committed Apr 29, 2024
1 parent a6053ed commit 4677e7a
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 7 deletions.
15 changes: 9 additions & 6 deletions bestla/bestla/kernel_avx2.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ static inline __m256i unpack_4bits(void* srcptr, __m256i mask) {
}

static inline __m256i unpack_2bits(utils::bit2x4* ptr, const __m256i& vshift_y, const __m256i& vmask0_y,
const __m256i& vsfhl_mask_y, const __m256i& vorder_y) {
const __m256i& vsfhl_mask_y, const __m256i& vorder_y) {
auto vraw_x = _mm_loadl_epi64((const __m128i*)ptr);
auto vsrc_y = _mm256_broadcastq_epi64(vraw_x);
auto vordered_y = _mm256_permutevar8x32_epi32(vsrc_y, vorder_y);
Expand All @@ -51,7 +51,7 @@ static inline __m256i unpack_2bits(utils::bit2x4* ptr, const __m256i& vshift_y,
}

static inline __m256i unpack_1bits(utils::bit1x8* ptr, const __m256i& bit1Shift_1, const __m256i& bit1Mask,
const __m256i& bit1Shift_2, const __m256i& highMask) {
const __m256i& bit1Shift_2, const __m256i& highMask) {
auto bit1x32 = _mm256_set1_epi32(*(int*)ptr);
bit1x32 = _mm256_srlv_epi32(bit1x32, bit1Shift_1);
bit1x32 = _mm256_and_si256(bit1x32, bit1Mask);
Expand Down Expand Up @@ -159,7 +159,7 @@ static inline void dequant_s8_N_avx2(DstT* dstptr, int8_t* srcptr, __m256* vscal
/// Loads 16 int8 zero points, sign-extends them to 16-bit lanes and rearranges
/// the widened vector with a byte shuffle.
///
/// @param zpptr  Address of the zero points; 16 bytes are read with an
///               unaligned 128-bit load.
/// @param vindex Control vector handed to _mm256_shuffle_epi8. The shuffle
///               selects bytes independently inside each 128-bit lane.
///               NOTE(review): presumably built by the caller to replicate
///               selected 16-bit zero points -- confirm against call sites.
/// @return The shuffled vector of sign-extended zero points.
static inline __m256i load_zp_epi8_broadcast_epi16_v16(int8_t* zpptr, const __m256i& vindex) {
  auto v_zp_x = _mm_loadu_si128((const __m128i*)zpptr);
  // Widen int8 -> int16 first so the shuffle operates on 16-bit zero points.
  auto v_zp_y = _mm256_cvtepi8_epi16(v_zp_x);
  return _mm256_shuffle_epi8(v_zp_y, vindex);
}

Expand Down Expand Up @@ -1192,7 +1192,6 @@ static inline void dequant_f4_N(_DST_T* dstptr, int8_t* srcptr, __m256* vscales,
}
}


template <int N, BTLA_DTYPE QT_T>
static inline void convert_s4_s8_N_avx2(int8_t* dstptr, int8_t* srcptr, __m256i mask) {
static_assert(N % 2 == 0);
Expand Down Expand Up @@ -1266,7 +1265,6 @@ inline BTLA_CODE decompress_kblock_f4_fp_noscale(utils::f4x2* srcptr, DST_T* dst
return BTLA_CODE::Success;
}


template <BTLA_DTYPE QT_T, bool _IS_SYM, int _NCOL, typename _ST, typename _DST_T>
static inline BTLA_CODE decompress_kblock_bit4_packrow1(utils::bit4x2* srcptr, _DST_T* dstptr, int row, int col,
int ld_src, int ld_dst, _ST* scales, int8_t* zero_points,
Expand Down Expand Up @@ -2332,6 +2330,10 @@ constexpr decltype(load_maskz_fp32_fp16_tr_x8_word<1>)* load_maskz_fp32_fp16_tr_
load_maskz_fp32_fp16_tr_x8_word<3>, load_maskz_fp32_fp16_tr_x8_word<4>, load_maskz_fp32_fp16_tr_x8_word<5>,
load_maskz_fp32_fp16_tr_x8_word<6>, load_maskz_fp32_fp16_tr_x8_word<7>, load_maskz_fp32_fp16_tr_x8_word<8>};

#ifdef __GNUC__
#pragma GCC diagnostic pop
#endif

template <int MTILE, int NReg, int Unroll>
static inline void accumulate_fp32_s8_fp32(const float* Aptr, int lda, int8_t* Bptr, __m256* vacc, __m256* vsca) {
if constexpr (MTILE == 1) {
Expand Down Expand Up @@ -3539,7 +3541,8 @@ static inline BTLA_CODE gemv_3bit_s8s8_fp32(const utils::GemvParamA& A, const ut
}

#ifdef __GNUC__
#pragma GCC diagnostic pop
#pragma GCC pop_options
#else
#endif
} // namespace vnni

Expand Down
4 changes: 3 additions & 1 deletion bestla/bestla/kernel_avx512f.h
Original file line number Diff line number Diff line change
Expand Up @@ -4337,8 +4337,10 @@ static inline BTLA_CODE gemv_3bit_s8s8_fp32(const utils::GemvParamA& A, const ut
return BTLA_CODE::Success;
}


#ifdef __GNUC__
#pragma GCC diagnostic pop
#pragma GCC pop_options
#else
#endif
} // namespace vnni

Expand Down

0 comments on commit 4677e7a

Please sign in to comment.