Skip to content
This repository has been archived by the owner on Aug 30, 2024. It is now read-only.

Commit

Permalink
add avx2 gemv for int2
Browse files Browse the repository at this point in the history
  • Loading branch information
luoyu-intel committed Apr 18, 2024
1 parent 2cab309 commit 5f3ec2e
Show file tree
Hide file tree
Showing 4 changed files with 190 additions and 7 deletions.
171 changes: 165 additions & 6 deletions bestla/bestla/kernel_avx2.h
Original file line number Diff line number Diff line change
Expand Up @@ -1230,7 +1230,7 @@ static inline BTLA_CODE decompress_kblock_bit3_packrow_fp(utils::bit2x4* bit2ptr
return BTLA_CODE::Success;
}

static inline __m256i unpack_4bits_avx2(utils::bit2x4* ptr, const __m256i& vshift_y, const __m256i& vmask0_y,
static inline __m256i unpack_2bits_avx2(utils::bit2x4* ptr, const __m256i& vshift_y, const __m256i& vmask0_y,
const __m256i& vsfhl_mask_y) {
auto raw64 = *(uint64_t*)ptr;
auto rawlo32 = (raw64 & 0xffffffff) | (raw64 << 32);
Expand Down Expand Up @@ -1258,7 +1258,7 @@ inline BTLA_CODE decompress_kblock_s2_s8fp(utils::bit2x4* bit2ptr, _DST_T* dstpt
13, 9, 5, 1, 12, 8, 4, 0);
int elt_pad = utils::padto_le(unpack_elt, VElt);
for (; i < elt_pad; i += VElt) {
auto vout = unpack_4bits_avx2(bit2ptr + i / 4, vshift_y, vmask0, vsfhl_mask_y);
auto vout = unpack_2bits_avx2(bit2ptr + i / 4, vshift_y, vmask0, vsfhl_mask_y);
if (std::is_same_v<_DST_T, int8_t>) {
_mm256_storeu_si256((__m256i*)(dstptr + i), vout);
} else {
Expand Down Expand Up @@ -1503,12 +1503,12 @@ static inline BTLA_CODE gemv_4bit_s8s8_fp32(const utils::GemvParamA& A, const ut
}
for (int ik = 0; ik < blocksize; ik += 4) {
auto va = _mm256_set1_epi32(*(int*)(a8ptr + ib * blocksize + ik));
auto vabsa = _mm256_sign_epi8(va, va);
for (int i = 0; i < NReg; i++) {
auto vb =
kernel::avx2::unpack_4bits_avx2<false>((void*)(b4ptr + i * 16 + (ib * blocksize + ik) * NTILE / 2), vmask);
vb = _mm256_sign_epi8(vb, va);
va = _mm256_sign_epi8(va, va);
iacc[i] = _mm256_dpbusd_avx_epi32(iacc[i], va, vb);
iacc[i] = _mm256_dpbusd_avx_epi32(iacc[i], vabsa, vb);
}
}
const __m256 v_a_scale = _mm256_set1_ps(*(asptr + ib));
Expand Down Expand Up @@ -1726,11 +1726,11 @@ static inline BTLA_CODE gemv_3bit_s8s8_fp32(const utils::GemvParamA& A, const ut
}
for (int iu = 0; iu < UnpackElt; iu++) {
auto va = _mm256_set1_epi32(*(int*)(a8ptr + iu * KTILE));
auto vabsa = _mm256_sign_epi8(va, va);
for (int i = 0; i < NReg; i++) {
auto vb = _mm256_loadu_si256((const __m256i*)(UnpackBuf + iu * NTILE * KTILE + i * 32));
vb = _mm256_sign_epi8(vb, va);
va = _mm256_sign_epi8(va, va);
iacc[i] = _mm256_dpbusd_avx_epi32(iacc[i], va, vb);
iacc[i] = _mm256_dpbusd_avx_epi32(iacc[i], vabsa, vb);
}
}
a8ptr += KTILE * UnpackElt;
Expand All @@ -1756,6 +1756,165 @@ static inline BTLA_CODE gemv_3bit_s8s8_fp32(const utils::GemvParamA& A, const ut
}
return BTLA_CODE::Success;
}

// GEMV microkernel: C[0..NTILE) = A(1 x k, u8) * B(k x NTILE, 2-bit), fp32 output.
// Quantization is grouped: every `blocksize` elements along k share one scale
// (and, for A, one optional zero point).
//   A.aptr   : u8 activations; A.sptr: per-block activation scales (fp32);
//   A.zpptr  : per-block activation zero points, may be nullptr (symmetric A).
//   B.b2ptr  : weights packed 4 values per byte (2 bits each);
//   B.sptr   : per-block weight scales (ScaleT = float or utils::bf16);
//   ld_scaleb: row stride of B's scale matrix.
//   tmp/tmpsize: scratch buffer, unused by this kernel (kept so all gemv
//   kernels share a uniform signature).
// Returns BTLA_CODE::Success.
template <typename ScaleT, int NTILE>
static inline BTLA_CODE gemv_2bit_u8s8_fp32(const utils::GemvParamA& A, const utils::GemvParamB<ScaleT>& B, float* C,
                                            int k, int ld_scaleb, int blocksize, int8_t* tmp, size_t tmpsize) {
  auto a8ptr = A.aptr;
  auto b2ptr = reinterpret_cast<utils::bit2x4*>(B.b2ptr);
  auto asptr = A.sptr;
  auto azptr = A.zpptr;

  int blks = k / blocksize;
  int constexpr NReg = NTILE / 8;  // 8 int32/fp32 lanes per ymm register
  // fp32 output accumulators, one ymm per 8 columns.
  __m256 acc[NReg];
  int constexpr KTILE = 4;  // dpbusd fuses 4 k-elements per 32-bit lane
  for (int i = 0; i < NReg; i++) {
    acc[i] = _mm256_setzero_ps();
  }
  // Constants for unpack_2bits_avx2: mask selecting the top 2 bits of each
  // byte, per-32bit-lane left-shift amounts, and the byte shuffle that
  // restores the original element order after unpacking.
  uint64_t mask0 = 0xc0c0c0c0c0c0c0c0;
  auto vmask0_y = _mm256_set_epi64x(*(int64_t*)&mask0, *(int64_t*)&mask0, *(int64_t*)&mask0, *(int64_t*)&mask0);
  auto vshift_y = _mm256_set_epi32(0, 2, 4, 6, 0, 2, 4, 6);
  auto vsfhl_mask_y = _mm256_set_epi8(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2,
                                      13, 9, 5, 1, 12, 8, 4, 0);
  const __m256i onesu8 = _mm256_set1_epi8(1);  // multiplier used to form per-column sum(B)

  if (azptr) {
    // Asymmetric activations: besides the a*b products (iacc), also
    // accumulate sum(b) per column (bacc) so the zero-point term
    // zp * sum(b) can be subtracted before dequantization.
    for (int ib = 0; ib < blks; ib += 1) {
      __m256i iacc[NReg];
      __m256i bacc[NReg];
      for (int i = 0; i < NReg; i++) {
        iacc[i] = _mm256_setzero_si256();
        bacc[i] = _mm256_setzero_si256();
      }
      for (int ik = 0; ik < blocksize; ik += KTILE) {
        // Broadcast 4 consecutive u8 activations into every 32-bit lane.
        auto va = _mm256_set1_epi32(*(int*)(a8ptr));
        for (int i = 0; i < NReg; i++) {
          auto vb = unpack_2bits_avx2(b2ptr, vshift_y, vmask0_y, vsfhl_mask_y);
          iacc[i] = _mm256_dpbusd_avx_epi32(iacc[i], va, vb);
          bacc[i] = _mm256_dpbusd_avx_epi32(bacc[i], onesu8, vb);
          b2ptr += 8 * KTILE / 4;  // 8 columns * KTILE rows, 4 values per byte
        }
        a8ptr += KTILE;
      }
      const __m256 v_a_scale = _mm256_set1_ps(*(asptr + ib));
      auto zp = int(azptr[ib]);
      const __m256i v_a_zp = _mm256_set1_epi32(zp);
      auto bsptr = B.sptr + ib * ld_scaleb;
      for (int i = 0; i < NReg; i++) {
        // Remove the zero-point contribution: iacc -= zp * sum(b).
        bacc[i] = _mm256_mullo_epi32(v_a_zp, bacc[i]);
        iacc[i] = _mm256_sub_epi32(iacc[i], bacc[i]);
        __m256 v_b_scale;
        if constexpr (std::is_same_v<ScaleT, float>) {
          v_b_scale = _mm256_loadu_ps(bsptr + i * 8);
        } else if constexpr (std::is_same_v<ScaleT, utils::bf16>) {
          // NOTE: renamed from `tmp` to avoid shadowing the `tmp` parameter.
          auto vbf16 = _mm_loadu_si128((const __m128i*)(bsptr + i * 8));
          v_b_scale = kernel::avx2::ymm_cvt_bf16_fp32(vbf16);
        }
        // Dequantize: acc += (integer dot) * (a_scale * b_scale).
        v_b_scale = _mm256_mul_ps(v_a_scale, v_b_scale);
        auto vdeq = _mm256_cvtepi32_ps(iacc[i]);
        acc[i] = _mm256_fmadd_ps(vdeq, v_b_scale, acc[i]);
      }
    }
  } else {
    // Symmetric activations: no zero-point correction needed.
    for (int ib = 0; ib < blks; ib += 1) {
      __m256i iacc[NReg];
      for (int i = 0; i < NReg; i++) {
        iacc[i] = _mm256_setzero_si256();
      }
      for (int ik = 0; ik < blocksize; ik += KTILE) {
        auto va = _mm256_set1_epi32(*(int*)(a8ptr));
        for (int i = 0; i < NReg; i++) {
          auto vb = unpack_2bits_avx2(b2ptr, vshift_y, vmask0_y, vsfhl_mask_y);
          iacc[i] = _mm256_dpbusd_avx_epi32(iacc[i], va, vb);
          b2ptr += 8 * KTILE / 4;
        }
        a8ptr += KTILE;
      }
      const __m256 v_a_scale = _mm256_set1_ps(*(asptr + ib));
      auto bsptr = B.sptr + ib * ld_scaleb;
      for (int i = 0; i < NReg; i++) {
        __m256 v_b_scale;
        if constexpr (std::is_same_v<ScaleT, float>) {
          v_b_scale = _mm256_loadu_ps(bsptr + i * 8);
        } else if constexpr (std::is_same_v<ScaleT, utils::bf16>) {
          auto vbf16 = _mm_loadu_si128((const __m128i*)(bsptr + i * 8));
          v_b_scale = kernel::avx2::ymm_cvt_bf16_fp32(vbf16);
        }
        v_b_scale = _mm256_mul_ps(v_a_scale, v_b_scale);
        auto vdeq = _mm256_cvtepi32_ps(iacc[i]);
        acc[i] = _mm256_fmadd_ps(vdeq, v_b_scale, acc[i]);
      }
    }
  }

  for (int i = 0; i < NReg; i++) {
    _mm256_storeu_ps(C + i * 8, acc[i]);
  }
  return BTLA_CODE::Success;
}

// GEMV microkernel: C[0..NTILE) = A(1 x k, s8) * B(k x NTILE, 2-bit), fp32 output.
// Same layout as gemv_2bit_u8s8_fp32, but A is signed, so no zero-point
// correction is needed.
// dpbusd requires an unsigned first operand, so the signed product is
// rewritten as a*b = |a| * (sign(a) applied to b): `vabsa` carries |a| and
// _mm256_sign_epi8 transfers a's sign onto b.
//   tmp/tmpsize: scratch buffer, unused by this kernel (kept so all gemv
//   kernels share a uniform signature).
// Returns BTLA_CODE::Success.
template <typename ScaleT, int NTILE>
static inline BTLA_CODE gemv_2bit_s8s8_fp32(const utils::GemvParamA& A, const utils::GemvParamB<ScaleT>& B, float* C,
                                            int k, int ld_scaleb, int blocksize, int8_t* tmp, size_t tmpsize) {
  auto a8ptr = A.aptr;
  auto b2ptr = reinterpret_cast<utils::bit2x4*>(B.b2ptr);
  auto asptr = A.sptr;

  int blks = k / blocksize;
  int constexpr NReg = NTILE / 8;  // 8 int32/fp32 lanes per ymm register
  // fp32 output accumulators, one ymm per 8 columns.
  __m256 acc[NReg];
  int constexpr KTILE = 4;  // dpbusd fuses 4 k-elements per 32-bit lane
  for (int i = 0; i < NReg; i++) {
    acc[i] = _mm256_setzero_ps();
  }
  // Constants for unpack_2bits_avx2: mask selecting the top 2 bits of each
  // byte, per-32bit-lane left-shift amounts, and the byte shuffle that
  // restores the original element order after unpacking.
  uint64_t mask0 = 0xc0c0c0c0c0c0c0c0;
  auto vmask0_y = _mm256_set_epi64x(*(int64_t*)&mask0, *(int64_t*)&mask0, *(int64_t*)&mask0, *(int64_t*)&mask0);
  auto vshift_y = _mm256_set_epi32(0, 2, 4, 6, 0, 2, 4, 6);
  auto vsfhl_mask_y = _mm256_set_epi8(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2,
                                      13, 9, 5, 1, 12, 8, 4, 0);

  for (int ib = 0; ib < blks; ib += 1) {
    __m256i iacc[NReg];
    for (int i = 0; i < NReg; i++) {
      iacc[i] = _mm256_setzero_si256();
    }
    for (int ik = 0; ik < blocksize; ik += KTILE) {
      // Broadcast 4 consecutive s8 activations into every 32-bit lane.
      auto va = _mm256_set1_epi32(*(int*)(a8ptr));
      auto vabsa = _mm256_sign_epi8(va, va);  // |a|, computed once per KTILE
      for (int i = 0; i < NReg; i++) {
        auto vb = unpack_2bits_avx2(b2ptr, vshift_y, vmask0_y, vsfhl_mask_y);
        vb = _mm256_sign_epi8(vb, va);  // move a's sign onto b
        iacc[i] = _mm256_dpbusd_avx_epi32(iacc[i], vabsa, vb);
        b2ptr += 8 * KTILE / 4;  // 8 columns * KTILE rows, 4 values per byte
      }
      a8ptr += KTILE;
    }
    const __m256 v_a_scale = _mm256_set1_ps(*(asptr + ib));
    auto bsptr = B.sptr + ib * ld_scaleb;
    for (int i = 0; i < NReg; i++) {
      __m256 v_b_scale;
      if constexpr (std::is_same_v<ScaleT, float>) {
        v_b_scale = _mm256_loadu_ps(bsptr + i * 8);
      } else if constexpr (std::is_same_v<ScaleT, utils::bf16>) {
        // NOTE: renamed from `tmp` to avoid shadowing the `tmp` parameter.
        auto vbf16 = _mm_loadu_si128((const __m128i*)(bsptr + i * 8));
        v_b_scale = kernel::avx2::ymm_cvt_bf16_fp32(vbf16);
      }
      // Dequantize: acc += (integer dot) * (a_scale * b_scale).
      v_b_scale = _mm256_mul_ps(v_a_scale, v_b_scale);
      auto vdeq = _mm256_cvtepi32_ps(iacc[i]);
      acc[i] = _mm256_fmadd_ps(vdeq, v_b_scale, acc[i]);
    }
  }

  for (int i = 0; i < NReg; i++) {
    _mm256_storeu_ps(C + i * 8, acc[i]);
  }
  return BTLA_CODE::Success;
}
#ifdef __GNUC__
#pragma GCC diagnostic pop
#endif
Expand Down
6 changes: 6 additions & 0 deletions bestla/bestla/kernel_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -970,6 +970,9 @@ class GEMVWoqNBits {
return ref::gemv_3bit_u8s8_fp32<ScaleT, NTILE>(A, B, C, k, ld_scaleb, blocksize, (int8_t*)tmp, tmpsize);
}
if (B.nbits == 2) {
if (ISA_T >= BTLA_ISA::AVX2) {
return avx2::gemv_2bit_u8s8_fp32<ScaleT, NTILE>(A, B, C, k, ld_scaleb, blocksize, (int8_t*)tmp, tmpsize);
}
return ref::gemv_2bit_u8s8_fp32<ScaleT, NTILE>(A, B, C, k, ld_scaleb, blocksize, (int8_t*)tmp, tmpsize);
}
return BTLA_CODE::NotSupport;
Expand All @@ -991,6 +994,9 @@ class GEMVWoqNBits {
return ref::gemv_3bit_s8s8_fp32<ScaleT, NTILE>(A, B, C, k, ld_scaleb, blocksize, (int8_t*)tmp, tmpsize);
}
if (B.nbits == 2) {
if (ISA_T >= BTLA_ISA::AVX2) {
return avx2::gemv_2bit_s8s8_fp32<ScaleT, NTILE>(A, B, C, k, ld_scaleb, blocksize, (int8_t*)tmp, tmpsize);
}
return ref::gemv_2bit_s8s8_fp32<ScaleT, NTILE>(A, B, C, k, ld_scaleb, blocksize, (int8_t*)tmp, tmpsize);
}
return BTLA_CODE::NotSupport;
Expand Down
8 changes: 8 additions & 0 deletions bestla/bestla/ut/bestla_benchmark.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -803,11 +803,19 @@ class UTWOQ_CompInt8 {
public:
UTWOQ_CompInt8() {
UT_START();
ut_s2();
ut_s3();
ut_s4();
// ut_s8();
}

void ut_s2() {
benchmark_all<prologue_b::gemm::WeightKBlockNInteger, float>(1, 4096, 4096, BTLA_DTYPE::S2_CLIP);
benchmark_all<prologue_b::gemm::WeightKBlockNInteger, utils::bf16>(1, 4096, 4096, BTLA_DTYPE::S2_CLIP);
/*benchmark_all<prologue_b::gemm::WeightKBlockNInteger, utils::bf16>(1024, 4096, 4096, BTLA_DTYPE::S4_CLIP);
benchmark_all<prologue_b::gemm::WeightKBlockNInteger, utils::bf16>(2048, 4096, 4096, BTLA_DTYPE::S4_CLIP);*/
}

void ut_s3() {
benchmark_all<prologue_b::gemm::WeightKBlockNInteger, float>(1, 4096, 4096, BTLA_DTYPE::S3_CLIP);
benchmark_all<prologue_b::gemm::WeightKBlockNInteger, utils::bf16>(1, 4096, 4096, BTLA_DTYPE::S3_CLIP);
Expand Down
12 changes: 11 additions & 1 deletion bestla/bestla/ut/bestla_prologue_b.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -967,11 +967,21 @@ class UT_CompInt8 {
public:
UT_CompInt8() {
UT_START();
ut_s2();
ut_s3();
ut_s4_newkblock();
ut_s4();
ut_s8();
}

void ut_s2() {
GetCPUDevice();
if (_cd->AVX_VNNI()) {
ut_newkblock<gemm::ICoreRowNAvxvnniKBlock<24, 2>>(1, 4096, 4096, 32, BTLA_DTYPE::S2_CLIP, BTLA_DTYPE::F32);
ut_newkblock<gemm::ICoreRowNAvxvnniKBlock<24, 2>>(1, 4096, 4096, 128, BTLA_DTYPE::S2_CLIP, BTLA_DTYPE::F32);
}
}

void ut_s3() {
GetCPUDevice();
if (_cd->AVX_VNNI()) {
Expand Down Expand Up @@ -1276,8 +1286,8 @@ class UT_CompInt8 {
}
};
#ifdef BTLA_UT_PROLOGUE_B
static UT_CompInt8 sUT_CompInt8;
#endif
static UT_CompInt8 sUT_CompInt8;

class UT_CompBf16 {
public:
Expand Down

0 comments on commit 5f3ec2e

Please sign in to comment.