Skip to content
This repository has been archived by the owner on Aug 30, 2024. It is now read-only.

[BesTLA] Support RTN int2 weight #178

Merged
merged 23 commits into from
Apr 9, 2024
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion CMakePresets.json
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,8 @@
"CMAKE_BUILD_TYPE": "Debug",
"NS_PROFILING": "ON",
"NS_USE_OMP": "ON",
"BTLA_UT_DEBUG": "ON"
"BTLA_UT_DEBUG": "ON",
"BTLA_UT_BENCHMARK": "ON"
}
},
{
Expand Down
2 changes: 2 additions & 0 deletions bestla/bestla/bestla.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ enum class BTLA_DTYPE : uint32_t {
EleBitsMask = 0xff,
EleBitsShift = 0,
EleBitsUndef = 0,
EleBits2 = 2,
EleBits3 = 3,
EleBits4 = 4,
EleBits8 = 8,
Expand Down Expand Up @@ -65,6 +66,7 @@ enum class BTLA_DTYPE : uint32_t {
DQ8_BNB = EleBits8 | TypeFloat | SubType4,
S8 = EleBits8 | TypeInt,
U8 = EleBits8 | TypeInt | SubType1,
S2_CLIP = EleBits2 | TypeInt,
S3_CLIP = EleBits3 | TypeInt,
S4_CLIP = EleBits4 | TypeInt,
F4_E2M1 = EleBits4 | TypeFloat,
Expand Down
38 changes: 18 additions & 20 deletions bestla/bestla/bestla_device.h
Original file line number Diff line number Diff line change
Expand Up @@ -259,8 +259,10 @@ class CpuDevice {
if (tmp[3] & (1U << 15)) mHybrid = true;
if (p) printf("!!!Hybrid:%d\t%x\t%x\t%x\t%x!!!\n", mHybrid, tmp[0], tmp[1], tmp[2], tmp[3]);
}
int total_cores = numcores * _cpu.getNumCores(Xbyak::util::IntelCpuTopologyLevel::SmtLevel);
if (total_cores <= 16) mClient = true;
if (mHybrid) {
int total_cores = numcores * _cpu.getNumCores(Xbyak::util::IntelCpuTopologyLevel::SmtLevel);
mClient = true;
std::vector<int> core_type(total_cores), core_id(total_cores), L1(total_cores), L2(total_cores);
std::map<int, int> core_id_count;

Expand Down Expand Up @@ -311,21 +313,14 @@ class CpuDevice {
for (auto& i : SMT_core) printf("%d,", i);
printf("\n");
}
if (!E_core.empty() && !P_core.empty()) {
mHybrid = !(E_core.empty() || P_core.empty()); // in case of bond core by external
if (!E_core.empty()) {
E_L1Cache = L1[E_core[0]];
E_L2Cache = L2[E_core[0]] / 4;
uint32_t scale = SMT_core.empty() ? 1 : 2;
L1Cache = E_L1Cache > L1[P_core[0]] / scale ? L1[P_core[0]] / scale : E_L1Cache;
L2Cache = E_L2Cache > L2[P_core[0]] / scale ? L2[P_core[0]] / scale : E_L2Cache;
} else if (!P_core.empty()) {
uint32_t scale = SMT_core.empty() ? 1 : 2;
L1Cache = L1[P_core[0]] / scale;
L2Cache = L2[P_core[0]] / scale;
mHybrid = false;
} else {
L1Cache = L1[E_core[0]];
L2Cache = L2[E_core[0]] / 4;
mHybrid = false;
};
if (!P_core.empty()) {
L1Cache = L1[P_core[0]];
L2Cache = L2[P_core[0]];
}
}
numcores = static_cast<int>(P_core.size() + E_core.size());
Expand Down Expand Up @@ -461,10 +456,11 @@ class CpuDevice {
}

bool isHybrid() { return mHybrid; }
bool isClient() { return mClient; }

protected:
uint32_t L2Cache, L1Cache, L3Cache;
bool mHybrid = false;
bool mHybrid = false, mClient = false;
bool mHasAVX2, mHasAVX_VNNI, mHasAVX, mHasAVX512_VNNI, mHasAMX_INT8, mHasAMX_BF16, mHasAVX512F, mHasAVX512_BF16,
mHasAVX512_FP16;
int numcores;
Expand Down Expand Up @@ -506,7 +502,7 @@ class CpuRuntime {
mL1Cache = _cd->getL1CacheSize();
maxThreads = _cd->getThreads();
mHybrid = false;
if (_cd->isHybrid() && thread > _cd->getPcoreNum()) {
if (_cd->isClient() && thread > _cd->getPcoreNum()) {
if (thread > _cd->getPcoreNum() + _cd->getEcoreNum()) {
mL1Cache_P = mL1Cache / 2;
mL2Cache_P = mL2Cache / 2;
Expand All @@ -518,10 +514,12 @@ class CpuRuntime {
P_core_num = static_cast<int>(_cd->getPcoreNum());
E_core_num = thread - P_core_num;
}
mL1Cache_E = _cd->getL1CacheSize_E();
mL2Cache_E = _cd->getL2CacheSize_E();
mHybrid = true;
memcpy(PE, _cd->getPE(), int(BTLA_ISA::ISA_COUNT) * sizeof(float));
if (mHybrid) {
mL1Cache_E = _cd->getL1CacheSize_E();
mL2Cache_E = _cd->getL2CacheSize_E();
mHybrid = true;
memcpy(PE, _cd->getPE(), int(BTLA_ISA::ISA_COUNT) * sizeof(float));
}
}
}
float PE[int(BTLA_ISA::ISA_COUNT)];
Expand Down
5 changes: 4 additions & 1 deletion bestla/bestla/bestla_parallel.h
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,10 @@ class StdThreading : public IThreading {
reinterpret_cast<void*>(_cd->getSMTCores()), _cd->getSMTcoreNum() * sizeof(int));
} else {
core_order.resize(mThreadNum);
for (int i = 0; i < mThreadNum; i++) core_order[i] = i;
if (_cd->isClient())
for (int i = 0; i < mThreadNum; i++) core_order[i] = 2 * i;
else
for (int i = 0; i < mThreadNum; i++) core_order[i] = i;
}
_cd->core_bond(core_order[0]);
if (cr->mHybrid) {
Expand Down
70 changes: 70 additions & 0 deletions bestla/bestla/bestla_prologue_b.h
Original file line number Diff line number Diff line change
Expand Up @@ -571,9 +571,27 @@ class WeightKBlockNInteger {
assert(ret == BTLA_CODE::Success);
}

static void compressBit2Weight(const int N, const int K, const int8_t* B, int8_t* dstptr,
parallel::IThreading* threading) {
// TODO(zhe): 1D parallel compress
parallel::Scheduler2D _para({threading->num_threads(), 1, K * N, 1, 64});
auto bit2ptr = reinterpret_cast<utils::bit2x4*>(dstptr);
threading->parallel_for([&](int tidx) {
parallel::ThreadProblem2D thdp({tidx});
_para.getIndex(thdp);
if (thdp.valid) {
auto ret =
kernel::wrapper::CompressBit2::forward<ISA_T>(B + thdp.loc[1], bit2ptr + thdp.loc[1] / 4, thdp.size[1]);
assert(ret == BTLA_CODE::Success);
(void)ret;
}
});
}

static void compressWeight(const int N, const int K, const int8_t* B, const int ldb, int8_t* dstptr, BTLA_DTYPE qtype,
parallel::IThreading* threading) {
if (qtype == BTLA_DTYPE::S3_CLIP) return compressBit3Weight(N, K, B, dstptr, threading);
if (qtype == BTLA_DTYPE::S2_CLIP) return compressBit2Weight(N, K, B, dstptr, threading);
parallel::Scheduler2D _para({threading->num_threads(), K, N, _GemmCore_T::KTILE, _GemmCore_T::NTILE});
threading->parallel_for([&](int tidx) {
parallel::ThreadProblem2D thdp({tidx});
Expand Down Expand Up @@ -629,6 +647,8 @@ class WeightKBlockNInteger {
return getQ4Weight(dstptr, dststep, k_size, n_size, k_offset, n_offset, _param, tmpcache, cachesize);
} else if (wptr->mDType == BTLA_DTYPE::S3_CLIP) {
return getQ3Weight(dstptr, dststep, k_size, n_size, k_offset, n_offset, _param, tmpcache, cachesize);
} else if (wptr->mDType == BTLA_DTYPE::S2_CLIP) {
return getQ2Weight(dstptr, dststep, k_size, n_size, k_offset, n_offset, _param, tmpcache, cachesize);
} else {
assert(0);
}
Expand Down Expand Up @@ -729,6 +749,13 @@ class WeightKBlockNInteger {
kernel::wrapper::DecompressKBlockS3S8Fp<T>::template forward<ISA_T, BTLA_DTYPE::S3_CLIP>(
bit2ptr, bit1ptr, *dstptr + i * k_size, k_offset * _GemmCore_T::NTILE,
k_size / _GemmCore_T::PACK_ROW * ColSize, tmpcache, cachesize);
} else if (wptr->mDType == BTLA_DTYPE::S2_CLIP) {
int8_t* bit2_ptr = wptr->template WPtr<int8_t>();
auto elt_offset = n_offset * KPad + k_offset * _GemmCore_T::NTILE + i * KPad;
assert(elt_offset % 4 == 0);
auto bit2ptr = reinterpret_cast<utils::bit2x4*>(bit2_ptr + elt_offset / 4);
kernel::wrapper::DecompressKBlockS2S8Fp<T>::template forward<ISA_T, BTLA_DTYPE::S2_CLIP>(
bit2ptr, *dstptr + i * k_size, k_size / _GemmCore_T::PACK_ROW * ColSize, tmpcache, cachesize);
} else {
assert(0);
}
Expand Down Expand Up @@ -776,6 +803,16 @@ class WeightKBlockNInteger {
bit2ptr, bit1ptr, *dstptr + i * k_size, k_offset * _GemmCore_T::NTILE, k_size / _GemmCore_T::PACK_ROW,
ColSize, sptr, zptr != nullptr ? zptr + n_offset + i : nullptr, k_offset / _GemmCore_T::PACK_ROW,
wptr->mBlockSize / _GemmCore_T::PACK_ROW, NPad, tmpcache, cachesize);
} else if (wptr->mDType == BTLA_DTYPE::S2_CLIP) {
int8_t* bit2_ptr = wptr->template WPtr<int8_t>();
auto elt_offset = n_offset * KPad + k_offset * _GemmCore_T::NTILE + i * KPad;
assert(elt_offset % 4 == 0);
auto bit2ptr = reinterpret_cast<utils::bit2x4*>(bit2_ptr + elt_offset / 4);
kernel::wrapper::DecompressKBlockS2Fp<_T, _GemmCore_T::PACK_ROW>::template forward<ISA_T, float,
BTLA_DTYPE::S2_CLIP>(
bit2ptr, *dstptr + i * k_size, k_size / _GemmCore_T::PACK_ROW, ColSize, sptr,
zptr != nullptr ? zptr + n_offset + i : nullptr, k_offset / _GemmCore_T::PACK_ROW,
wptr->mBlockSize / _GemmCore_T::PACK_ROW, NPad, tmpcache, cachesize);
} else {
assert(0);
}
Expand Down Expand Up @@ -809,6 +846,16 @@ class WeightKBlockNInteger {
bit2ptr, bit1ptr, *dstptr + i * k_size, k_offset * _GemmCore_T::NTILE, k_size / _GemmCore_T::PACK_ROW,
ColSize, sptr, zptr != nullptr ? zptr + n_offset + i : nullptr, k_offset / _GemmCore_T::PACK_ROW,
wptr->mBlockSize / _GemmCore_T::PACK_ROW, NPad, tmpcache, cachesize);
} else if (wptr->mDType == BTLA_DTYPE::S2_CLIP) {
int8_t* bit2_ptr = wptr->template WPtr<int8_t>();
auto elt_offset = n_offset * KPad + k_offset * _GemmCore_T::NTILE + i * KPad;
assert(elt_offset % 4 == 0);
auto bit2ptr = reinterpret_cast<utils::bit2x4*>(bit2_ptr + elt_offset / 4);
kernel::wrapper::DecompressKBlockS2Fp<_T, _GemmCore_T::PACK_ROW>::template forward<ISA_T, utils::bf16,
BTLA_DTYPE::S2_CLIP>(
bit2ptr, *dstptr + i * k_size, k_size / _GemmCore_T::PACK_ROW, ColSize, sptr,
zptr != nullptr ? zptr + n_offset + i : nullptr, k_offset / _GemmCore_T::PACK_ROW,
wptr->mBlockSize / _GemmCore_T::PACK_ROW, NPad, tmpcache, cachesize);
} else {
assert(0);
}
Expand Down Expand Up @@ -884,6 +931,26 @@ class WeightKBlockNInteger {
return BTLA_CODE::Success;
}

static inline BTLA_CODE getQ2Weight(int8_t** dstptr, int* dststep, int k_size, int n_size, int k_offset, int n_offset,
const Param& _param, void* tmpcache, size_t cachesize) {
auto wptr = _param.packedW;
int8_t* bit2_ptr = wptr->template WPtr<int8_t>();
auto KPad = wptr->mKPad;
auto NPad = wptr->mNPad;
int constexpr ColSize = _GemmCore_T::NTILE * _GemmCore_T::PACK_ROW;
auto base_offset = n_offset * KPad + k_offset * _GemmCore_T::NTILE;
for (int i = 0; i < n_size; i += _GemmCore_T::NTILE) {
auto elt_offset = base_offset + i * KPad;
assert(elt_offset % 4 == 0);
auto bit2ptr = reinterpret_cast<utils::bit2x4*>(bit2_ptr + elt_offset / 4);
kernel::wrapper::DecompressKBlockS2S8Fp<int8_t>::template forward<ISA_T, BTLA_DTYPE::S3_CLIP>(
bit2ptr, *dstptr + i * k_size, k_size / _GemmCore_T::PACK_ROW * ColSize, reinterpret_cast<int8_t*>(tmpcache),
cachesize);
}
*dststep = k_size;
return BTLA_CODE::Success;
}

virtual inline void quantRowBlock(const float* srcptr, int8_t* dstptr, int row, int col, int ld_src, int ld_dst,
float* scales, int8_t* zero_points, void* stor) {
auto ptr = reinterpret_cast<StorageWeight*>(stor);
Expand All @@ -897,6 +964,9 @@ class WeightKBlockNInteger {
} else if (quant_dtype == BTLA_DTYPE::S3_CLIP) {
kernel::wrapper::QuantizeSignIntRowBlock::forward<ISA_T, BTLA_DTYPE::S3_CLIP>(
srcptr, dstptr, row, col, ld_src, ld_dst, scales, zero_points, ptr->mBlockSize);
} else if (quant_dtype == BTLA_DTYPE::S2_CLIP) {
kernel::wrapper::QuantizeSignIntRowBlock::forward<ISA_T, BTLA_DTYPE::S2_CLIP>(
srcptr, dstptr, row, col, ld_src, ld_dst, scales, zero_points, ptr->mBlockSize);
} else {
assert(0);
}
Expand Down
4 changes: 4 additions & 0 deletions bestla/bestla/bestla_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,10 @@ inline const char* bestla_dtype_str(BTLA_DTYPE dtype) {
return "unsigned_int8";
case BTLA_DTYPE::S4_CLIP:
return "int4_clip";
case BTLA_DTYPE::S3_CLIP:
return "int3_clip";
case BTLA_DTYPE::S2_CLIP:
return "int2_clip";
case BTLA_DTYPE::F4_E2M1:
return "fp4_e2m1";
case BTLA_DTYPE::F4_BNB:
Expand Down
76 changes: 76 additions & 0 deletions bestla/bestla/kernel_avx2.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#pragma once
#include "bestla.h"
#include "bestla_utils.h"
#include "kernel_jit.h"
#include "kernel_ref.h"
#if CompileAVX2()
#include <immintrin.h>
Expand Down Expand Up @@ -1154,6 +1155,81 @@ static inline BTLA_CODE layernorm(const float* srcptr, const float* scaleptr, co
return BTLA_CODE::Success;
}

template <BTLA_DTYPE S3_T, typename _DST_T>
inline BTLA_CODE decompress_kblock_s3_s8fp(utils::bit2x4* bit2ptr, utils::bit1x8* bit1ptr, _DST_T* dstptr,
int interleave_n_offset, int unpack_elt, int8_t* tmp, size_t tmpsize) {
auto head_ignore_num = interleave_n_offset % 128;
const __m256i lowMask = _mm256_set1_epi8(0x03);
const __m256i highMask = _mm256_set1_epi8(0x04);
const __m256i bit1Mask = _mm256_set1_epi32(0x0F);
const __m256i bit1Shift_1 = _mm256_set_epi32(28, 24, 20, 16, 12, 8, 4, 0);
const __m256i bit1Shift_2 = _mm256_set1_epi32((1 << 23) + (1 << 16) + (1 << 9) + (1 << 2));

auto bit3_interleave_decompress_pack128 = [&](utils::bit2x4* src1, utils::bit1x8* src2, int8_t* dst) {
__m256i bit2_data = _mm256_loadu_si256((const __m256i*)src1);
int32_t* bit1_ptr = reinterpret_cast<int32_t*>(src2);
for (int i = 0; i < 4; i++) {
auto bit1x32 = _mm256_set1_epi32(bit1_ptr[i]);
bit1x32 = _mm256_srlv_epi32(bit1x32, bit1Shift_1);
bit1x32 = _mm256_and_si256(bit1x32, bit1Mask);
bit1x32 = _mm256_mullo_epi32(bit1x32, bit1Shift_2);
bit1x32 = _mm256_and_si256(highMask, bit1x32);

auto bit2x32 = _mm256_and_si256(lowMask, _mm256_srli_epi16(bit2_data, 2 * i));
auto res = _mm256_add_epi8(bit1x32, bit2x32);
res = _mm256_slli_epi32(res, 5);
_mm256_storeu_si256((__m256i*)(dst + 32 * i), res);
}
};
int compress_wei_ptr_offset = 0;
if (head_ignore_num != 0) {
assert(head_ignore_num % 8 == 0);

auto base_bit2ptr = bit2ptr - head_ignore_num / 4;
auto base_bit1ptr = bit1ptr - head_ignore_num / 8;
auto head_write_num = 128 - head_ignore_num;
bit3_interleave_decompress_pack128(base_bit2ptr, base_bit1ptr, tmp);
for (int i = 0; i < head_write_num; i++) dstptr[i] = tmp[head_ignore_num + i];
compress_wei_ptr_offset += head_write_num;
unpack_elt -= head_write_num;
}
auto body_loop = unpack_elt / 128;
auto tail_proc_num = unpack_elt % 128;

bestla::kernel::jit::DecompresssS3::forward_avx2(bit2ptr + compress_wei_ptr_offset / 4,
bit1ptr + compress_wei_ptr_offset / 8,
dstptr + compress_wei_ptr_offset, tmp, body_loop * 128);
compress_wei_ptr_offset += body_loop * 128;
if (tail_proc_num > 0) {
bit3_interleave_decompress_pack128(bit2ptr + compress_wei_ptr_offset / 4, bit1ptr + compress_wei_ptr_offset / 8,
tmp);
for (int i = 0; i < tail_proc_num; i++) dstptr[compress_wei_ptr_offset + i] = tmp[i];
}
return BTLA_CODE::Success;
}

template <BTLA_DTYPE _S3_T, typename _DST_T, int _PACK_ROW, typename _ST>
static inline BTLA_CODE decompress_kblock_bit3_packrow_fp(utils::bit2x4* bit2ptr, utils::bit1x8* bit1ptr,
_DST_T* dstptr, int interleave_n_offset, int row, int col,
_ST* scales, int8_t* zero_points, int k_offset, int kblock,
int NPad, void* tmp, size_t tmpsize) {
auto unpack_elt = row * col;
decompress_kblock_s3_s8fp<_S3_T>(bit2ptr, bit1ptr, dstptr, interleave_n_offset, unpack_elt,
reinterpret_cast<int8_t*>(tmp), tmpsize);
// TODO(zhe): simd version
for (int i = 0; i < row; i++) {
int kpos = (k_offset + i) / kblock;
auto sptr = scales + kpos * NPad;
for (int j = 0; j < col; j++) {
float tmp = static_cast<float>(dstptr[i * col + j]);
if (zero_points != nullptr) tmp -= static_cast<float>(zero_points[kpos * NPad + j / _PACK_ROW]);
dstptr[i * col + j] = static_cast<_DST_T>(tmp * sptr[j / _PACK_ROW]);
}
}

return BTLA_CODE::Success;
}

inline __m256 poly_scale_2nd_ps(const __m256i z, const __m256 f, const __m256 c0, const __m256 c1, const __m256 c2) {
const auto y = _mm256_fmadd_ps(_mm256_fmadd_ps(f, c0, c1), f, c2); // auto y = (f * c0 + c1) * f + c2;
static const auto mask_exp = _mm256_set1_epi32(0x7f800000);
Expand Down
23 changes: 10 additions & 13 deletions bestla/bestla/kernel_avx512f.h
Original file line number Diff line number Diff line change
Expand Up @@ -644,32 +644,29 @@ inline BTLA_CODE decompress_kblock_s3_s8fp(utils::bit2x4* bit2ptr, utils::bit1x8
_mm512_storeu_si512((__m512i*)dst, zmm1);
_mm512_storeu_si512((__m512i*)(dst + 64), zmm2);
};

assert(head_ignore_num % 8 == 0);

auto base_bit2ptr = bit2ptr - head_ignore_num / 4;
auto base_bit1ptr = bit1ptr - head_ignore_num / 8;
int compress_wei_ptr_offset = 0;
int8_t* s8_ptr = reinterpret_cast<int8_t*>(tmp);
auto head_write_num = 128 - head_ignore_num;
if (head_ignore_num != 0) {
assert(head_ignore_num % 8 == 0);

auto base_bit2ptr = bit2ptr - head_ignore_num / 4;
auto base_bit1ptr = bit1ptr - head_ignore_num / 8;
auto head_write_num = 128 - head_ignore_num;
bit3_interleave_decompress_pack128(base_bit2ptr, base_bit1ptr, tmp);
for (int i = 0; i < head_write_num; i++) dstptr[i] = s8_ptr[head_ignore_num + i];
for (int i = 0; i < head_write_num; i++) dstptr[i] = tmp[head_ignore_num + i];
compress_wei_ptr_offset += head_write_num;
unpack_elt -= head_write_num;
}

auto body_loop = (unpack_elt - head_write_num % 128) / 128;
auto tail_proc_num = (unpack_elt - head_write_num % 128) % 128;
auto body_loop = unpack_elt / 128;
auto tail_proc_num = unpack_elt % 128;

bestla::kernel::jit::DecompresssS3::forward_avx512f(bit2ptr + compress_wei_ptr_offset / 4,
bit1ptr + compress_wei_ptr_offset / 8,
dstptr + compress_wei_ptr_offset, tmp, body_loop * 128);
compress_wei_ptr_offset += body_loop * 128;
if (tail_proc_num > 0) {
bit3_interleave_decompress_pack128(base_bit2ptr, base_bit1ptr, tmp);
bit3_interleave_decompress_pack128(bit2ptr + compress_wei_ptr_offset / 4, bit1ptr + compress_wei_ptr_offset / 8,
tmp);
for (int i = 0; i < tail_proc_num; i++) dstptr[compress_wei_ptr_offset + i] = s8_ptr[i];
for (int i = 0; i < tail_proc_num; i++) dstptr[compress_wei_ptr_offset + i] = tmp[i];
}
return BTLA_CODE::Success;
}
Expand Down
Loading
Loading