Skip to content
This repository has been archived by the owner on Aug 30, 2024. It is now read-only.

Commit

Permalink
use vec register instead of general register
Browse files Browse the repository at this point in the history
  • Loading branch information
luoyu-intel committed Apr 18, 2024
1 parent 6665f14 commit 6062158
Showing 1 changed file with 12 additions and 13 deletions.
25 changes: 12 additions & 13 deletions bestla/bestla/kernel_avx2.h
Original file line number Diff line number Diff line change
Expand Up @@ -1231,16 +1231,12 @@ static inline BTLA_CODE decompress_kblock_bit3_packrow_fp(utils::bit2x4* bit2ptr
}

// Loads 64 bits of packed 2-bit data from `ptr` and expands them into one 256-bit vector.
// The result is produced by a per-dword left shift (vshift_y), an AND with vmask0_y and a
// per-lane byte shuffle (vsfhl_mask_y); the value of vmask0_y is supplied by callers.
//
// NOTE(review): this span is a rendered diff hunk with the +/- markers stripped, so the
// pre-commit and post-commit lines are interleaved and the text is not compilable as shown.
// Judging by the commit title and the addition/deletion counts, the scalar uint64_t path is
// presumably the removed code and the _mm_loadl_epi64 path the added code -- confirm against
// the repository.
static inline __m256i unpack_2bits_avx2(utils::bit2x4* ptr, const __m256i& vshift_y, const __m256i& vmask0_y,
// Pre-commit end of the signature (no vorder_y parameter).
const __m256i& vsfhl_mask_y) {
// Pre-commit source setup using general-purpose registers: read 8 bytes as one uint64_t.
auto raw64 = *(uint64_t*)ptr;
// rawlo32: the low 32 bits of raw64 copied into both halves of a 64-bit value.
auto rawlo32 = (raw64 & 0xffffffff) | (raw64 << 32);
// rawhi32: the high 32 bits of raw64 copied into both halves of a 64-bit value.
auto rawhi32 = (raw64 & 0xffffffff00000000) | (raw64 >> 32);
// Replicate each 64-bit pattern across a 128-bit register (four copies of one dword).
auto vlo_x = _mm_set_epi64x(*(int64_t*)&rawlo32, *(int64_t*)&rawlo32);
auto vhi_x = _mm_set_epi64x(*(int64_t*)&rawhi32, *(int64_t*)&rawhi32);
// Lower 128-bit lane = vlo_x (low dword x4), upper 128-bit lane = vhi_x (high dword x4).
auto vsrc_y = _mm256_set_m128i(vhi_x, vlo_x);
// Pre-commit shift, applied directly to vsrc_y.
auto vs_y = _mm256_sllv_epi32(vsrc_y, vshift_y);
// Post-commit end of the signature: adds vorder_y, the dword selection indices for the permute.
const __m256i& vsfhl_mask_y, const __m256i& vorder_y) {
// Post-commit source setup using vector registers: load 64 bits into the low half of an
// XMM register (the upper half is zeroed by _mm_loadl_epi64).
auto vraw_x = _mm_loadl_epi64((const __m128i*)ptr);
// Copy that low quadword into all four quadword positions of a YMM register.
auto vsrc_y = _mm256_broadcastq_epi64(vraw_x);
// Pick one source dword per output dword according to vorder_y. The callers visible in this
// diff pass _mm256_set_epi32(1, 1, 1, 1, 0, 0, 0, 0), which yields dword 0 x4 in the lower
// lane and dword 1 x4 in the upper lane -- the same layout the scalar path built.
auto vordered_y = _mm256_permutevar8x32_epi32(vsrc_y, vorder_y);
// Shift each 32-bit element left by its own count taken from vshift_y.
auto vs_y = _mm256_sllv_epi32(vordered_y, vshift_y);
// Keep only the bits selected by vmask0_y.
auto v2_y = _mm256_and_si256(vs_y, vmask0_y);

// Reorder bytes within each 128-bit lane as directed by vsfhl_mask_y.
auto vout_y = _mm256_shuffle_epi8(v2_y, vsfhl_mask_y);
return vout_y;
}
Expand All @@ -1256,9 +1252,10 @@ inline BTLA_CODE decompress_kblock_s2_s8fp(utils::bit2x4* bit2ptr, _DST_T* dstpt
auto vshift_y = _mm256_set_epi32(0, 2, 4, 6, 0, 2, 4, 6);
auto vsfhl_mask_y = _mm256_set_epi8(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2,
13, 9, 5, 1, 12, 8, 4, 0);
auto vorder_y = _mm256_set_epi32(1,1,1,1,0,0,0,0);
int elt_pad = utils::padto_le(unpack_elt, VElt);
for (; i < elt_pad; i += VElt) {
auto vout = unpack_2bits_avx2(bit2ptr + i / 4, vshift_y, vmask0, vsfhl_mask_y);
auto vout = unpack_2bits_avx2(bit2ptr + i / 4, vshift_y, vmask0, vsfhl_mask_y, vorder_y);
if (std::is_same_v<_DST_T, int8_t>) {
_mm256_storeu_si256((__m256i*)(dstptr + i), vout);
} else {
Expand Down Expand Up @@ -1778,6 +1775,7 @@ static inline BTLA_CODE gemv_2bit_u8s8_fp32(const utils::GemvParamA& A, const ut
auto vshift_y = _mm256_set_epi32(0, 2, 4, 6, 0, 2, 4, 6);
auto vsfhl_mask_y = _mm256_set_epi8(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2,
13, 9, 5, 1, 12, 8, 4, 0);
auto vorder_y = _mm256_set_epi32(1, 1, 1, 1, 0, 0, 0, 0);
const __m256i onesu8 = _mm256_set1_epi8(1);

if (azptr) {
Expand All @@ -1791,7 +1789,7 @@ static inline BTLA_CODE gemv_2bit_u8s8_fp32(const utils::GemvParamA& A, const ut
for (int ik = 0; ik < blocksize; ik += KTILE) {
auto va = _mm256_set1_epi32(*(int*)(a8ptr));
for (int i = 0; i < NReg; i++) {
auto vb = unpack_2bits_avx2(b2ptr, vshift_y, vmask0_y, vsfhl_mask_y);
auto vb = unpack_2bits_avx2(b2ptr, vshift_y, vmask0_y, vsfhl_mask_y, vorder_y);
iacc[i] = _mm256_dpbusd_avx_epi32(iacc[i], va, vb);
bacc[i] = _mm256_dpbusd_avx_epi32(bacc[i], onesu8, vb);
b2ptr += 8 * KTILE / 4;
Expand Down Expand Up @@ -1826,7 +1824,7 @@ static inline BTLA_CODE gemv_2bit_u8s8_fp32(const utils::GemvParamA& A, const ut
for (int ik = 0; ik < blocksize; ik += KTILE) {
auto va = _mm256_set1_epi32(*(int*)(a8ptr));
for (int i = 0; i < NReg; i++) {
auto vb = unpack_2bits_avx2(b2ptr, vshift_y, vmask0_y, vsfhl_mask_y);
auto vb = unpack_2bits_avx2(b2ptr, vshift_y, vmask0_y, vsfhl_mask_y, vorder_y);
iacc[i] = _mm256_dpbusd_avx_epi32(iacc[i], va, vb);
b2ptr += 8 * KTILE / 4;
}
Expand Down Expand Up @@ -1876,6 +1874,7 @@ static inline BTLA_CODE gemv_2bit_s8s8_fp32(const utils::GemvParamA& A, const ut
auto vshift_y = _mm256_set_epi32(0, 2, 4, 6, 0, 2, 4, 6);
auto vsfhl_mask_y = _mm256_set_epi8(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2,
13, 9, 5, 1, 12, 8, 4, 0);
auto vorder_y = _mm256_set_epi32(1, 1, 1, 1, 0, 0, 0, 0);
const __m256i onesu8 = _mm256_set1_epi8(1);

for (int ib = 0; ib < blks; ib += 1) {
Expand All @@ -1887,7 +1886,7 @@ static inline BTLA_CODE gemv_2bit_s8s8_fp32(const utils::GemvParamA& A, const ut
auto va = _mm256_set1_epi32(*(int*)(a8ptr));
auto vabsa = _mm256_sign_epi8(va, va);
for (int i = 0; i < NReg; i++) {
auto vb = unpack_2bits_avx2(b2ptr, vshift_y, vmask0_y, vsfhl_mask_y);
auto vb = unpack_2bits_avx2(b2ptr, vshift_y, vmask0_y, vsfhl_mask_y, vorder_y);
vb = _mm256_sign_epi8(vb, va);
iacc[i] = _mm256_dpbusd_avx_epi32(iacc[i], vabsa, vb);
b2ptr += 8 * KTILE / 4;
Expand Down

0 comments on commit 6062158

Please sign in to comment.