Skip to content
This repository has been archived by the owner on Aug 30, 2024. It is now read-only.

Commit

Permalink
fix compile error
Browse files Browse the repository at this point in the history
  • Loading branch information
luoyu-intel committed Apr 29, 2024
1 parent 58284ce commit 58ee0e3
Showing 1 changed file with 40 additions and 39 deletions.
79 changes: 40 additions & 39 deletions bestla/bestla/kernel_avx512f.h
Original file line number Diff line number Diff line change
Expand Up @@ -3113,45 +3113,6 @@ static inline BTLA_CODE decompress_kblock_s3_s8(utils::bit2x4* bit2ptr, utils::b
return BTLA_CODE::Success;
}

template <int PackRow, int NTILE, typename DST_T>
inline BTLA_CODE decompress_kblock_s3_fp_row(utils::bit2x4* b2ptr, utils::bit1x8* b1ptr, DST_T* dstptr, int row,
void* scales_, BTLA_DTYPE sdtype, int8_t* zero_points, int k_offset,
int n_offset, int blocksize, int ldzp, int8_t* tmp, size_t tmpsize) {
int constexpr NReg = NTILE / 8;
const auto DstSize = row * NTILE * sizeof(DST_T);
const auto S8Size = row * NTILE * sizeof(int8_t);
auto tmps8ptr = (int8_t*)dstptr;
tmps8ptr += DstSize - S8Size;
auto ret = decompress_kblock_s3_s8<PackRow, NTILE>(b2ptr, b1ptr, zero_points, tmps8ptr, blocksize, ldzp, n_offset,
k_offset, row, NTILE, tmp, tmpsize);
assert(ret == BTLA_CODE::Success);
return decompress_kblock_s8_fp_row<PackRow, NTILE, DST_T>(tmps8ptr, dstptr, row, scales_, sdtype, nullptr, k_offset,
n_offset, blocksize, ldzp, tmp, tmpsize);
}

template <int PackRow, int NTILE, typename DST_T>
inline BTLA_CODE decompress_kblock_s3_fp(utils::bit2x4* b2ptr, utils::bit1x8* b1ptr, DST_T* dstptr, int row, int col,
void* scales_, BTLA_DTYPE sdtype, int8_t* zero_points, int k_offset,
int n_offset, int blocksize, int ldzp, int8_t* tmp, size_t tmpsize) {
auto ret = BTLA_CODE::NotSupport;
if (col == NTILE) {
int head_end = utils::padto(k_offset, blocksize);
head_end = std::min(head_end, k_offset + row);
int head_size = head_end - k_offset;
if (head_size > 0) {
decompress_kblock_s3_fp_row<PackRow, NTILE, DST_T>(b2ptr, b1ptr, dstptr, head_size, scales_, sdtype, zero_points,
k_offset, n_offset, blocksize, ldzp, tmp, tmpsize);
}
int body_size = row - head_size;
if (body_size > 0) {
decompress_kblock_s3_fp_row<PackRow, NTILE, DST_T>(
b2ptr + head_size * NTILE / 4, b1ptr + head_size * NTILE / 8, dstptr + head_size * NTILE, body_size, scales_,
sdtype, zero_points, head_end, n_offset, blocksize, ldzp, tmp, tmpsize);
}
return BTLA_CODE::Success;
}
return ret;
}

template <int PackRow, int NTILE, typename DST_T>
inline BTLA_CODE decompress_kblock_s8_fp_row(int8_t* srcptr, DST_T* dstptr, int row, void* scales_, BTLA_DTYPE sdtype,
Expand Down Expand Up @@ -3440,6 +3401,46 @@ inline BTLA_CODE decompress_kblock_s2_fp(utils::bit2x4* b2ptr, DST_T* dstptr, in
return ret;
}

template <int PackRow, int NTILE, typename DST_T>
inline BTLA_CODE decompress_kblock_s3_fp_row(utils::bit2x4* b2ptr, utils::bit1x8* b1ptr, DST_T* dstptr, int row,
void* scales_, BTLA_DTYPE sdtype, int8_t* zero_points, int k_offset,
int n_offset, int blocksize, int ldzp, int8_t* tmp, size_t tmpsize) {
int constexpr NReg = NTILE / 8;
const auto DstSize = row * NTILE * sizeof(DST_T);
const auto S8Size = row * NTILE * sizeof(int8_t);
auto tmps8ptr = (int8_t*)dstptr;
tmps8ptr += DstSize - S8Size;
auto ret = decompress_kblock_s3_s8<PackRow, NTILE>(b2ptr, b1ptr, zero_points, tmps8ptr, blocksize, ldzp, n_offset,
k_offset, row, NTILE, tmp, tmpsize);
assert(ret == BTLA_CODE::Success);
return decompress_kblock_s8_fp_row<PackRow, NTILE, DST_T>(tmps8ptr, dstptr, row, scales_, sdtype, nullptr, k_offset,
n_offset, blocksize, ldzp, tmp, tmpsize);
}

template <int PackRow, int NTILE, typename DST_T>
inline BTLA_CODE decompress_kblock_s3_fp(utils::bit2x4* b2ptr, utils::bit1x8* b1ptr, DST_T* dstptr, int row, int col,
void* scales_, BTLA_DTYPE sdtype, int8_t* zero_points, int k_offset,
int n_offset, int blocksize, int ldzp, int8_t* tmp, size_t tmpsize) {
auto ret = BTLA_CODE::NotSupport;
if (col == NTILE) {
int head_end = utils::padto(k_offset, blocksize);
head_end = std::min(head_end, k_offset + row);
int head_size = head_end - k_offset;
if (head_size > 0) {
decompress_kblock_s3_fp_row<PackRow, NTILE, DST_T>(b2ptr, b1ptr, dstptr, head_size, scales_, sdtype, zero_points,
k_offset, n_offset, blocksize, ldzp, tmp, tmpsize);
}
int body_size = row - head_size;
if (body_size > 0) {
decompress_kblock_s3_fp_row<PackRow, NTILE, DST_T>(
b2ptr + head_size * NTILE / 4, b1ptr + head_size * NTILE / 8, dstptr + head_size * NTILE, body_size, scales_,
sdtype, zero_points, head_end, n_offset, blocksize, ldzp, tmp, tmpsize);
}
return BTLA_CODE::Success;
}
return ret;
}

template <typename T>
static inline __m512 load_T_fp32(const T* srcptr) {
__m512 vtmp;
Expand Down

0 comments on commit 58ee0e3

Please sign in to comment.