diff --git a/bestla/bestla/kernel_avx2.h b/bestla/bestla/kernel_avx2.h index 8405feb3d..29110cf0d 100644 --- a/bestla/bestla/kernel_avx2.h +++ b/bestla/bestla/kernel_avx2.h @@ -550,7 +550,7 @@ static inline BTLA_CODE decompress_kblock_s4_s8(utils::int4x2* srcptr, int8_t* z int ldzp, int n_offset, int k_offset, int row, int col, int8_t* tmp, size_t tmpsize) { if (zpptr) { - typedef BTLA_CODE (*decompfunc)(utils::int4x2* srcptr, int8_t* zpptr, int8_t* dstptr, int blocksize, int ldzp, + typedef BTLA_CODE (*decompfunc)(utils::int4x2 * srcptr, int8_t * zpptr, int8_t * dstptr, int blocksize, int ldzp, int n_offset, int k_offset, int row, int8_t* tmp, size_t tmpsize); decompfunc func = nullptr; if (col == NTILE) { @@ -764,7 +764,7 @@ static inline BTLA_CODE decompress_kblock_s2_s8(utils::bit2x4* bit2ptr, int8_t* int ldzp, int n_offset, int k_offset, int row, int col, int8_t* tmp, size_t tmpsize) { if (zpptr) { - typedef BTLA_CODE (*decompfunc)(utils::bit2x4* srcptr, int8_t* zpptr, int8_t* dstptr, int blocksize, int ldzp, + typedef BTLA_CODE (*decompfunc)(utils::bit2x4 * srcptr, int8_t * zpptr, int8_t * dstptr, int blocksize, int ldzp, int n_offset, int k_offset, int row, int8_t* tmp, size_t tmpsize); decompfunc func = nullptr; if (col == NTILE) { @@ -1022,7 +1022,7 @@ static inline BTLA_CODE decompress_kblock_s3_s8(utils::bit2x4* bit2ptr, utils::b int8_t* dstptr, int blocksize, int ldzp, int n_offset, int k_offset, int row, int col, int8_t* tmp, size_t tmpsize) { if (zpptr) { - typedef BTLA_CODE (*decompfunc)(utils::bit2x4* bit2ptr, utils::bit1x8* bit1ptr, int8_t* zpptr, int8_t* dstptr, + typedef BTLA_CODE (*decompfunc)(utils::bit2x4 * bit2ptr, utils::bit1x8 * bit1ptr, int8_t * zpptr, int8_t * dstptr, int blocksize, int ldzp, int n_offset, int k_offset, int row, int8_t* tmp, size_t tmpsize); decompfunc func = nullptr; @@ -1247,7 +1247,7 @@ static inline BTLA_CODE decompress_kblock_s1_s8(utils::bit1x8* bit1ptr, int8_t* int ldzp, int n_offset, int k_offset, int row, int col, int8_t* tmp, size_t tmpsize) { if (zpptr) { - typedef BTLA_CODE (*decompfunc)(utils::bit1x8* bit1ptr, int8_t* zpptr, int8_t* dstptr, int blocksize, int ldzp, + typedef BTLA_CODE (*decompfunc)(utils::bit1x8 * bit1ptr, int8_t * zpptr, int8_t * dstptr, int blocksize, int ldzp, int n_offset, int k_offset, int row, int8_t* tmp, size_t tmpsize); decompfunc func = nullptr; if (col == NTILE) { @@ -1500,7 +1500,7 @@ static inline BTLA_CODE decompress_kblock_s5_s8(utils::bit4x2* bit4ptr, utils::b int8_t* dstptr, int blocksize, int ldzp, int n_offset, int k_offset, int row, int col, int8_t* tmp, size_t tmpsize) { if (zpptr) { - typedef BTLA_CODE (*decompfunc)(utils::bit4x2* bit4ptr, utils::bit1x8* bit1ptr, int8_t* zpptr, int8_t* dstptr, + typedef BTLA_CODE (*decompfunc)(utils::bit4x2 * bit4ptr, utils::bit1x8 * bit1ptr, int8_t * zpptr, int8_t * dstptr, int blocksize, int ldzp, int n_offset, int k_offset, int row, int8_t* tmp, size_t tmpsize); decompfunc func = nullptr; @@ -1814,9 +1814,9 @@ static inline BTLA_CODE decompress_kblock_s7_s8(utils::bit4x2* bit4ptr, utils::b int8_t* zpptr, int8_t* dstptr, int blocksize, int ldzp, int n_offset, int k_offset, int row, int col, int8_t* tmp, size_t tmpsize) { if (zpptr) { - typedef BTLA_CODE (*decompfunc)(utils::bit4x2* bit4ptr, utils::bit2x4* bit2ptr, utils::bit1x8* bit1ptr, - int8_t* zpptr, int8_t* dstptr, int blocksize, int ldzp, int n_offset, int k_offset, - int row, int8_t* tmp, size_t tmpsize); + typedef BTLA_CODE (*decompfunc)(utils::bit4x2 * bit4ptr, utils::bit2x4 * bit2ptr, utils::bit1x8 * bit1ptr, + int8_t * zpptr, int8_t * dstptr, int blocksize, int ldzp, int n_offset, + int k_offset, int row, int8_t* tmp, size_t tmpsize); decompfunc func = nullptr; if (col == NTILE) { if constexpr (PackRow == 1) { @@ -2077,7 +2077,7 @@ static inline BTLA_CODE decompress_kblock_s6_s8(utils::bit4x2* bit4ptr, utils::b int8_t* dstptr, int blocksize, int ldzp, int n_offset, int k_offset, int row, int col, int8_t* tmp, size_t tmpsize) { if (zpptr) { - typedef BTLA_CODE (*decompfunc)(utils::bit4x2* bit4ptr, utils::bit2x4* bit2ptr, int8_t* zpptr, int8_t* dstptr, + typedef BTLA_CODE (*decompfunc)(utils::bit4x2 * bit4ptr, utils::bit2x4 * bit2ptr, int8_t * zpptr, int8_t * dstptr, int blocksize, int ldzp, int n_offset, int k_offset, int row, int8_t* tmp, size_t tmpsize); decompfunc func = nullptr;