Skip to content
This repository has been archived by the owner on Aug 30, 2024. It is now read-only.

Commit

Permalink
fix S3 quant error: add rounding and auto quant.
Browse files Browse the repository at this point in the history
  • Loading branch information
luoyu-intel committed Mar 13, 2024
1 parent 4ff4117 commit 1392bd7
Show file tree
Hide file tree
Showing 4 changed files with 95 additions and 36 deletions.
28 changes: 16 additions & 12 deletions bestla/bestla/kernel_avx512f.h
Original file line number Diff line number Diff line change
Expand Up @@ -862,9 +862,9 @@ static inline BTLA_CODE quantize_f32_sign_int_rowblock_sym(const float* srcptr,
}
return BTLA_CODE::Success;
}

static inline BTLA_CODE quantize_f32_sign_int_rowblock_sym_s4(const float* srcptr, int8_t* dstptr, int row, int col,
int ld_src, int ld_dst, float* scales, int blocksize) {
template <BTLA_DTYPE QDT_T>
static inline BTLA_CODE quantize_f32_sign_int_rowblock_sym_auto(const float* srcptr, int8_t* dstptr, int row, int col,
int ld_src, int ld_dst, float* scales, int blocksize) {
int constexpr VLen = 16;
int col16 = utils::padto_le(col, VLen);
int i = 0;
Expand All @@ -889,13 +889,16 @@ static inline BTLA_CODE quantize_f32_sign_int_rowblock_sym_s4(const float* srcpt
_mm512_storeu_ps(tmp_min, vminval);
_mm512_storeu_ps(tmp_max, vmaxval);
_mm512_storeu_ps(tmp_abs, vabsval);
auto constexpr NBits = utils::bestla_dtype_bits(QDT_T);
int constexpr FullValue = 1 << (NBits - 1);
int constexpr GenValue = FullValue - 1;
for (int iv = 0; iv < VLen; iv++) {
int NVal = 7;
int NVal = GenValue;
auto sum = tmp_max[iv] + tmp_min[iv];
if (abs(sum) >= tmp_abs[iv] / 7.5) {
NVal = sum > 0.f ? -8 : 8;
if (abs(sum) >= tmp_abs[iv] / FullValue) {
NVal = sum > 0.f ? -FullValue : FullValue;
}
NVal = NVal << 4;
NVal = NVal << (8 - NBits);
tmp_abs[iv] = NVal;
}
auto vmag = _mm512_loadu_ps(tmp_abs);
Expand All @@ -913,8 +916,8 @@ static inline BTLA_CODE quantize_f32_sign_int_rowblock_sym_s4(const float* srcpt
for (; j < align_row; j += blocksize) simd_process_block(blocksize);
if (j < row) simd_process_block(row - align_row);
}
kernel::ref::quantize_f32_sign_int_rowblock<BTLA_DTYPE::S4_CLIP>(srcptr + i, dstptr + i, row, col - i, ld_src, ld_dst,
scales + i, nullptr, blocksize);
kernel::ref::quantize_f32_sign_int_rowblock<QDT_T>(srcptr + i, dstptr + i, row, col - i, ld_src, ld_dst,
scales + i, nullptr, blocksize);
return BTLA_CODE::Success;
}

Expand Down Expand Up @@ -985,13 +988,14 @@ static inline BTLA_CODE quantize_f32_sign_int_rowblock_asym(const float* srcptr,
return BTLA_CODE::Success;
}

template <BTLA_DTYPE S4_T>
template <BTLA_DTYPE QDT_T>
static inline BTLA_CODE quantize_f32_sign_int_rowblock(const float* srcptr, int8_t* dstptr, int row, int col,
int ld_src, int ld_dst, float* scales, int8_t* zero_points,
int blocksize) {
if (zero_points == nullptr)
if constexpr (S4_T == BTLA_DTYPE::S4_CLIP) {
return quantize_f32_sign_int_rowblock_sym_s4(srcptr, dstptr, row, col, ld_src, ld_dst, scales, blocksize);
if constexpr (QDT_T == BTLA_DTYPE::S4_CLIP || QDT_T == BTLA_DTYPE::S3_CLIP) {
return quantize_f32_sign_int_rowblock_sym_auto<QDT_T>(srcptr, dstptr, row, col, ld_src, ld_dst, scales,
blocksize);
} else {
return quantize_f32_sign_int_rowblock_sym(srcptr, dstptr, row, col, ld_src, ld_dst, scales, blocksize);
}
Expand Down
54 changes: 35 additions & 19 deletions bestla/bestla/kernel_ref.h
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,14 @@ static inline BTLA_CODE compress_f4(const int8_t* srcptr, utils::f4x2* dstptr, i
static inline BTLA_CODE compress_3bit(const int8_t* srcptr, bestla::utils::bit2x4* bit2ptr, utils::bit1x8* bit1ptr,
int row, int col, int ld_src, int ld_dst) {
assert(col % 128 == 0);

// Maps an int8 value onto a signed 3-bit level: divides by 32 with
// round-half-away-from-zero, then saturates the result to [-4, 3].
auto round3bit = [](int8_t src) {
  const int32_t widened = src;
  // Bias by half a step in the direction of the sign so that the truncating
  // integer division below rounds to the nearest level.
  const int32_t bias = widened < 0 ? -16 : 16;
  const int32_t level = (widened + bias) / 32;
  if (level > 3) {
    return static_cast<int8_t>(3);
  }
  if (level < -4) {
    return static_cast<int8_t>(-4);
  }
  return static_cast<int8_t>(level);
};
auto bit2_interleave = [&](int8_t* src, int8_t* dst) {
for (int i = 0; i < 128 / 4; i++) {
dst[4 * i] = src[i];
Expand All @@ -191,30 +198,36 @@ static inline BTLA_CODE compress_3bit(const int8_t* srcptr, bestla::utils::bit2x
}
};

int8_t round_buf[128];
int8_t interleave_buf[128];

for (int i = 0; i < row; i++) {
for (int j = 0; j < col; j += 128) {
bit2_interleave(const_cast<int8_t*>(srcptr + i * ld_src + j), interleave_buf);
for (int k = 0; k < 128; k++) {
round_buf[k] = round3bit(const_cast<int8_t*>(srcptr + i * ld_src + j + k)[0]) << 5;
}
bit2_interleave(round_buf, interleave_buf);
for (int k = 0; k < 32; k++) {
bit2ptr[i * ld_dst / 4 + j / 4 + k].a = interleave_buf[4 * k] >> 5;
bit2ptr[i * ld_dst / 4 + j / 4 + k].b = interleave_buf[4 * k + 1] >> 5;
bit2ptr[i * ld_dst / 4 + j / 4 + k].c = interleave_buf[4 * k + 2] >> 5;
bit2ptr[i * ld_dst / 4 + j / 4 + k].d = interleave_buf[4 * k + 3] >> 5;
}
for (int k = j; k < j + 128; k += 8) {
bit1ptr[i * ld_dst / 8 + k / 8].a = round_buf[k - j] >> 7;
bit1ptr[i * ld_dst / 8 + k / 8].b = round_buf[k - j + 1] >> 7;
bit1ptr[i * ld_dst / 8 + k / 8].c = round_buf[k - j + 2] >> 7;
bit1ptr[i * ld_dst / 8 + k / 8].d = round_buf[k - j + 3] >> 7;
bit1ptr[i * ld_dst / 8 + k / 8].e = round_buf[k - j + 4] >> 7;
bit1ptr[i * ld_dst / 8 + k / 8].f = round_buf[k - j + 5] >> 7;
bit1ptr[i * ld_dst / 8 + k / 8].g = round_buf[k - j + 6] >> 7;
bit1ptr[i * ld_dst / 8 + k / 8].h = round_buf[k - j + 7] >> 7;
}
}
}
// store 1 bit without interleave as mask.
for (int i = 0; i < row; i++) {
for (int j = 0; j < col; j += 8) {
bit1ptr[i * ld_dst / 8 + j / 8].a = srcptr[i * ld_src + j] >> 7;
bit1ptr[i * ld_dst / 8 + j / 8].b = srcptr[i * ld_src + j + 1] >> 7;
bit1ptr[i * ld_dst / 8 + j / 8].c = srcptr[i * ld_src + j + 2] >> 7;
bit1ptr[i * ld_dst / 8 + j / 8].d = srcptr[i * ld_src + j + 3] >> 7;
bit1ptr[i * ld_dst / 8 + j / 8].e = srcptr[i * ld_src + j + 4] >> 7;
bit1ptr[i * ld_dst / 8 + j / 8].f = srcptr[i * ld_src + j + 5] >> 7;
bit1ptr[i * ld_dst / 8 + j / 8].g = srcptr[i * ld_src + j + 6] >> 7;
bit1ptr[i * ld_dst / 8 + j / 8].h = srcptr[i * ld_src + j + 7] >> 7;
}
}
return BTLA_CODE::Success;
Expand Down Expand Up @@ -864,7 +877,7 @@ static inline BTLA_CODE get2d_e8m0_scale(const void* srcptr, void* dstptr, int r
return BTLA_CODE::Success;
}

template <BTLA_DTYPE S4_T>
template <BTLA_DTYPE QDT_T>
inline BTLA_CODE quantize_f32_sign_int_rowblock(const float* srcptr, int8_t* dstptr, int row, int col, int ld_src,
int ld_dst, float* scales, int8_t* zero_points, int blocksize) {
int raw_blocksize = blocksize;
Expand Down Expand Up @@ -940,7 +953,10 @@ inline BTLA_CODE quantize_f32_sign_int_rowblock(const float* srcptr, int8_t* dst
dstptr[(j + ij) * ld_dst + i] = x << 4;
}
};
auto s4auto_calc_store_scale_and_quantv_sym = [&](int blocksize) {
auto sNauto_calc_store_scale_and_quantv_sym = [&](int blocksize) {
auto constexpr NBits = utils::bestla_dtype_bits(QDT_T);
int constexpr FullValue = 1 << (NBits - 1);
int constexpr GenValue = FullValue - 1;
float maxval = std::numeric_limits<float>::min();
float minval = std::numeric_limits<float>::max();
float absmax = 0;
Expand All @@ -949,12 +965,12 @@ inline BTLA_CODE quantize_f32_sign_int_rowblock(const float* srcptr, int8_t* dst
minval = std::min(minval, srcptr[(j + ij) * ld_src + i]);
absmax = std::max(absmax, std::abs(srcptr[(j + ij) * ld_src + i]));
}
int NVal = 7;
int NVal = GenValue;
auto sum = maxval + minval;
if (abs(sum) >= absmax / 7.5) {
NVal = sum > 0.f ? -8 : 8;
if (abs(sum) >= absmax / FullValue) {
NVal = sum > 0.f ? -FullValue : FullValue;
}
NVal = NVal << 4;
NVal = NVal << (8 - NBits);
float scale = absmax / NVal;
float rscale = 1.f / scale;
scales[j / raw_blocksize * ld_dst + i] = scale;
Expand All @@ -964,18 +980,18 @@ inline BTLA_CODE quantize_f32_sign_int_rowblock(const float* srcptr, int8_t* dst
};

auto dispatch_calc = [&](int blocksize) {
switch (S4_T) {
switch (QDT_T) {
case BTLA_DTYPE::S8:
case BTLA_DTYPE::S3_CLIP:
if (zero_points == nullptr) {
s8_calc_store_scale_and_quantv_sym(blocksize);
} else {
s8_calc_store_scale_and_quantv_asym(blocksize);
}
break;
case BTLA_DTYPE::S3_CLIP:
case BTLA_DTYPE::S4_CLIP:
if (zero_points == nullptr) {
s4auto_calc_store_scale_and_quantv_sym(blocksize);
sNauto_calc_store_scale_and_quantv_sym(blocksize);
} else {
s8_calc_store_scale_and_quantv_asym(blocksize);
}
Expand Down
8 changes: 4 additions & 4 deletions bestla/bestla/kernel_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -297,17 +297,17 @@ class Transpose2D {

class QuantizeSignIntRowBlock {
public:
template <BTLA_ISA ISA_T, BTLA_DTYPE S4_T>
template <BTLA_ISA ISA_T, BTLA_DTYPE QDT_T>
static inline BTLA_CODE forward(const float* srcptr, int8_t* dstptr, int row, int col, int ld_src, int ld_dst,
float* scales, int8_t* zero_points, int blocksize) {
#if CompileAVX512F()
if constexpr (utils::isa_base<ISA_T>::avx512f &&
S4_T != BTLA_DTYPE::S4_FULLRANGE) { // TODO(zhe): support simd version s4_fullrange quantization.
return avx512f::quantize_f32_sign_int_rowblock<S4_T>(srcptr, dstptr, row, col, ld_src, ld_dst, scales,
QDT_T != BTLA_DTYPE::S4_FULLRANGE) { // TODO(zhe): support simd version s4_fullrange quantization.
return avx512f::quantize_f32_sign_int_rowblock<QDT_T>(srcptr, dstptr, row, col, ld_src, ld_dst, scales,
zero_points, blocksize);
}
#endif
return ref::quantize_f32_sign_int_rowblock<S4_T>(srcptr, dstptr, row, col, ld_src, ld_dst, scales, zero_points,
return ref::quantize_f32_sign_int_rowblock<QDT_T>(srcptr, dstptr, row, col, ld_src, ld_dst, scales, zero_points,
blocksize);
}
};
Expand Down
41 changes: 40 additions & 1 deletion bestla/bestla/ut/bestla_prologue_b.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -173,15 +173,54 @@ class UT_BlockQunatize_F8 {
static UT_BlockQunatize_F8 sUT_BlockQunatize_F8;
#endif

class UT_BlockQunatize_S3S4 {
public:
UT_BlockQunatize_S3S4() {
UT_START();
CheckISA(AVX512F);
ut(127, 4096, 32, BTLA_DTYPE::S3_CLIP);
ut(4096, 4096, 32, BTLA_DTYPE::S3_CLIP);
ut(4096, 4096, 128, BTLA_DTYPE::S3_CLIP);
ut(127, 4096, 32, BTLA_DTYPE::S4_CLIP);
ut(4096, 4096, 32, BTLA_DTYPE::S4_CLIP);
ut(4096, 4096, 128, BTLA_DTYPE::S4_CLIP);
}

void ut(int n, int k, int blocksize, BTLA_DTYPE QUANT_T) {
printf("%s DType %s: %d %d %d\n", __FUNCTION__, utils::bestla_dtype_str(QUANT_T), n, k, blocksize);
int ldb = n;
utils::aligned_vector<float> raw(n * k);
ut::fill_buffer_randn(raw.data(), raw.size(), -0.5f, 1.8f);

auto constexpr RuntimeISA = BTLA_ISA::AVX512F;
using PrologueB = prologue_b::gemm::WeightKBlockNInteger<gemm::SCoreRowNAvx512f<48, 8>, RuntimeISA>;
PrologueB kernel;
auto ptr = kernel.createStorage(n, k, blocksize, QUANT_T, BTLA_DTYPE::F32, BTLA_DTYPE::F32, false);
avector<int8_t> buffer(ptr.mSize);
ptr.assign(buffer.data());
kernel.packWeight(n, k, raw.data(), ldb, &ptr, UT_Threading::get());
avector<float> dequant(n * k, 0);
kernel.unpackWeight(n, k, &ptr, dequant.data(), n, UT_Threading::get());
ut::buffer_error(raw.data(), dequant.data(), dequant.size(), 0.01f);
}
};
#ifdef BTLA_UT_PROLOGUE_B
// Disabled: no suitable error threshold has been established for this UT yet.
//static UT_BlockQunatize_S3S4 sUT_BlockQunatize_S3S4;
#endif

class UT_S3_WOQ {
public:
UT_S3_WOQ() {
UT_START();
CheckISA(AVX512F);
ut<sAVX512F, BTLA_ISA::AVX512F>(1, 4096, 4096, 32, 56);
CheckISA(AVX512_VNNI);
ut<gemm::ICoreRowNAvx512vnniKBlock<48, 4>, BTLA_ISA::AVX512_VNNI>(1, 4096, 4096, 128, 56);
CheckISA(AMX_BF16);
ut<sAMX_BF16, BTLA_ISA::AMX_BF16>(1, 4096, 4096, 32, 56);
CheckISA(AMX_INT8);
ut<gemm::ICoreRowNAmxint8KBlock<48, 16>, BTLA_ISA::AMX_INT8>(1, 4096, 4096, 128, 56);
ut<gemm::ICoreRowNAvx512vnniKBlock<48, 4>, BTLA_ISA::AVX512_VNNI>(1, 4096, 4096, 128, 56);
}

template <class GemmCore_T, BTLA_ISA ISA>
Expand Down

0 comments on commit 1392bd7

Please sign in to comment.