From 60621586f75d5aee11c63798c663faf15453585e Mon Sep 17 00:00:00 2001 From: luoyu-intel Date: Thu, 18 Apr 2024 16:34:42 +0800 Subject: [PATCH] use vec register instead of general register --- bestla/bestla/kernel_avx2.h | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/bestla/bestla/kernel_avx2.h b/bestla/bestla/kernel_avx2.h index 45890e504..0940da9c0 100644 --- a/bestla/bestla/kernel_avx2.h +++ b/bestla/bestla/kernel_avx2.h @@ -1231,16 +1231,12 @@ static inline BTLA_CODE decompress_kblock_bit3_packrow_fp(utils::bit2x4* bit2ptr } static inline __m256i unpack_2bits_avx2(utils::bit2x4* ptr, const __m256i& vshift_y, const __m256i& vmask0_y, - const __m256i& vsfhl_mask_y) { - auto raw64 = *(uint64_t*)ptr; - auto rawlo32 = (raw64 & 0xffffffff) | (raw64 << 32); - auto rawhi32 = (raw64 & 0xffffffff00000000) | (raw64 >> 32); - auto vlo_x = _mm_set_epi64x(*(int64_t*)&rawlo32, *(int64_t*)&rawlo32); - auto vhi_x = _mm_set_epi64x(*(int64_t*)&rawhi32, *(int64_t*)&rawhi32); - auto vsrc_y = _mm256_set_m128i(vhi_x, vlo_x); - auto vs_y = _mm256_sllv_epi32(vsrc_y, vshift_y); + const __m256i& vsfhl_mask_y, const __m256i& vorder_y) { + auto vraw_x = _mm_loadl_epi64((const __m128i*)ptr); + auto vsrc_y = _mm256_broadcastq_epi64(vraw_x); + auto vordered_y = _mm256_permutevar8x32_epi32(vsrc_y, vorder_y); + auto vs_y = _mm256_sllv_epi32(vordered_y, vshift_y); auto v2_y = _mm256_and_si256(vs_y, vmask0_y); - auto vout_y = _mm256_shuffle_epi8(v2_y, vsfhl_mask_y); return vout_y; } @@ -1256,9 +1252,10 @@ inline BTLA_CODE decompress_kblock_s2_s8fp(utils::bit2x4* bit2ptr, _DST_T* dstpt auto vshift_y = _mm256_set_epi32(0, 2, 4, 6, 0, 2, 4, 6); auto vsfhl_mask_y = _mm256_set_epi8(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0); + auto vorder_y = _mm256_set_epi32(1,1,1,1,0,0,0,0); int elt_pad = utils::padto_le(unpack_elt, VElt); for (; i < elt_pad; i += VElt) { - auto vout = unpack_2bits_avx2(bit2ptr + i / 4, vshift_y, vmask0, vsfhl_mask_y); + auto vout = unpack_2bits_avx2(bit2ptr + i / 4, vshift_y, vmask0, vsfhl_mask_y, vorder_y); if (std::is_same_v<_DST_T, int8_t>) { _mm256_storeu_si256((__m256i*)(dstptr + i), vout); } else { @@ -1778,6 +1775,7 @@ static inline BTLA_CODE gemv_2bit_u8s8_fp32(const utils::GemvParamA& A, const ut auto vshift_y = _mm256_set_epi32(0, 2, 4, 6, 0, 2, 4, 6); auto vsfhl_mask_y = _mm256_set_epi8(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0); + auto vorder_y = _mm256_set_epi32(1, 1, 1, 1, 0, 0, 0, 0); const __m256i onesu8 = _mm256_set1_epi8(1); if (azptr) { @@ -1791,7 +1789,7 @@ static inline BTLA_CODE gemv_2bit_u8s8_fp32(const utils::GemvParamA& A, const ut for (int ik = 0; ik < blocksize; ik += KTILE) { auto va = _mm256_set1_epi32(*(int*)(a8ptr)); for (int i = 0; i < NReg; i++) { - auto vb = unpack_2bits_avx2(b2ptr, vshift_y, vmask0_y, vsfhl_mask_y); + auto vb = unpack_2bits_avx2(b2ptr, vshift_y, vmask0_y, vsfhl_mask_y, vorder_y); iacc[i] = _mm256_dpbusd_avx_epi32(iacc[i], va, vb); bacc[i] = _mm256_dpbusd_avx_epi32(bacc[i], onesu8, vb); b2ptr += 8 * KTILE / 4; @@ -1826,7 +1824,7 @@ static inline BTLA_CODE gemv_2bit_u8s8_fp32(const utils::GemvParamA& A, const ut for (int ik = 0; ik < blocksize; ik += KTILE) { auto va = _mm256_set1_epi32(*(int*)(a8ptr)); for (int i = 0; i < NReg; i++) { - auto vb = unpack_2bits_avx2(b2ptr, vshift_y, vmask0_y, vsfhl_mask_y); + auto vb = unpack_2bits_avx2(b2ptr, vshift_y, vmask0_y, vsfhl_mask_y, vorder_y); iacc[i] = _mm256_dpbusd_avx_epi32(iacc[i], va, vb); b2ptr += 8 * KTILE / 4; } @@ -1876,6 +1874,7 @@ static inline BTLA_CODE gemv_2bit_s8s8_fp32(const utils::GemvParamA& A, const ut auto vshift_y = _mm256_set_epi32(0, 2, 4, 6, 0, 2, 4, 6); auto vsfhl_mask_y = _mm256_set_epi8(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0); + auto vorder_y = _mm256_set_epi32(1, 1, 1, 1, 0, 0, 0, 0); const __m256i onesu8 = _mm256_set1_epi8(1); for (int ib = 0; ib < blks; ib += 1) { @@ -1887,7 +1886,7 @@ static inline BTLA_CODE gemv_2bit_s8s8_fp32(const utils::GemvParamA& A, const ut auto va = _mm256_set1_epi32(*(int*)(a8ptr)); auto vabsa = _mm256_sign_epi8(va, va); for (int i = 0; i < NReg; i++) { - auto vb = unpack_2bits_avx2(b2ptr, vshift_y, vmask0_y, vsfhl_mask_y); + auto vb = unpack_2bits_avx2(b2ptr, vshift_y, vmask0_y, vsfhl_mask_y, vorder_y); vb = _mm256_sign_epi8(vb, va); iacc[i] = _mm256_dpbusd_avx_epi32(iacc[i], vabsa, vb); b2ptr += 8 * KTILE / 4;