diff --git a/CMakePresets.json b/CMakePresets.json index 3a0694af9..0d470f1ed 100644 --- a/CMakePresets.json +++ b/CMakePresets.json @@ -8,7 +8,13 @@ "generator": "Ninja", "binaryDir": "${sourceDir}/out/build/${presetName}", "installDir": "${sourceDir}/out/install/${presetName}", - "cacheVariables": { "CMAKE_BUILD_TYPE": "Debug" }, + "cacheVariables": { + "CMAKE_BUILD_TYPE": "Debug", + "NS_PROFILING": "ON", + "NS_USE_OMP": "ON", + "BTLA_UT_DEBUG": "ON", + "BTLA_UT_BENCHMARK": "ON" + }, "condition": { "type": "equals", "lhs": "${hostSystemName}", @@ -107,16 +113,32 @@ "BTLA_UT_OPENMP": "OFF" } }, + { + "name": "x64-debug-sycl", + "displayName": "x64 Debug SYCL", + "description": "x64 Debug SYCL", + "inherits": "windows-base", + "architecture": { + "value": "x64", + "strategy": "external" + }, + "cacheVariables": { + "CMAKE_BUILD_TYPE": "Debug", + "BTLA_UT_DEBUG": "ON", + "BTLA_UT_ALL": "OFF", + "BTLA_SYCL": "ON", + "BTLA_UT_BENCHMARK": "ON", + "CMAKE_CXX_COMPILER": "icx", + "CMAKE_C_COMPILER": "icx" + } + }, { "name": "x64-release-sycl", - "displayName": "x64 Release SYCL", + "displayName": "x64 Release for SYCL", "description": "x64 SYCL", - "inherits": "x64-debug", + "inherits": "x64-debug-sycl", "cacheVariables": { - "CMAKE_CXX_COMPILER": "icx-cl", - "CMAKE_C_COMPILER": "icx-cl", - "CMAKE_BUILD_TYPE": "Release", - "BTLA_UT_ALL": "ON" + "CMAKE_BUILD_TYPE": "Release" } } ] diff --git a/bestla/CMakeLists.txt b/bestla/CMakeLists.txt index e11ea875c..9c44e2fcc 100644 --- a/bestla/CMakeLists.txt +++ b/bestla/CMakeLists.txt @@ -40,6 +40,7 @@ if(BTLA_SYCL) file(GLOB sycl_headers ${PROJECT_NAME}/sycl/*.h ${PROJECT_NAME}/sycl/*.hpp) add_compile_definitions(BTLA_SYCL) list(APPEND sycl_libs IntelSYCL::SYCL_CXX) + add_compile_options(-march=native) #add_link_options(-fsycl-targets=spir64 -Xsycl-target-backend "-options -ze-opt-large-register-file") endif(BTLA_SYCL) diff --git a/bestla/bestla/bestla_device.h b/bestla/bestla/bestla_device.h index 2a02416a6..8a582d551 100644 --- a/bestla/bestla/bestla_device.h +++ b/bestla/bestla/bestla_device.h @@ -340,14 +340,17 @@ class CpuDevice { case 9: // ALD PE[int(BTLA_ISA::AVX2)] = 3.0f; PE[int(BTLA_ISA::AVX_VNNI)] = 5.0f; + PE[int(BTLA_ISA::NoSIMD)] = 3.5f; break; case 10: // MTL PE[int(BTLA_ISA::AVX2)] = 2.2f; PE[int(BTLA_ISA::AVX_VNNI)] = 3.0f; + PE[int(BTLA_ISA::NoSIMD)] = 3.0f; break; case 11: // RPL PE[int(BTLA_ISA::AVX2)] = 1.8f; PE[int(BTLA_ISA::AVX_VNNI)] = 2.6f; + PE[int(BTLA_ISA::NoSIMD)] = 3.0f; break; } } @@ -488,7 +491,7 @@ class CpuRuntime { inline void adjustPE(const BTLA_ISA isa, const float PE_) { // printf("Adjust:%d,%f\n",int(isa),PE_); - PE[int(isa)] *= PE_; + PE[int(isa)] = PE[int(isa)] * PE_ * 0.7 + PE[int(isa)] * 0.3; } size_t mL2Cache, mL1Cache, mL2Cache_P = 0, mL1Cache_P = 0, mL2Cache_E = 0, mL1Cache_E = 0; @@ -514,7 +517,7 @@ class CpuRuntime { P_core_num = static_cast(_cd->getPcoreNum()); E_core_num = thread - P_core_num; } - if (mHybrid) { + if (_cd->isHybrid()) { mL1Cache_E = _cd->getL1CacheSize_E(); mL2Cache_E = _cd->getL2CacheSize_E(); mHybrid = true; diff --git a/bestla/bestla/bestla_epilogue.h b/bestla/bestla/bestla_epilogue.h index f2228c22f..3360688f5 100644 --- a/bestla/bestla/bestla_epilogue.h +++ b/bestla/bestla/bestla_epilogue.h @@ -37,10 +37,16 @@ class AccumulatorWriteBack { using DType = _DST_T; using Param = ParamAccumulatorWriteBack; - BTLA_CODE forward(const _SRC_T* cacheptr, const int cachestep, const int M_offset, const int N_offset, const int M, - const int N, const Param& _param, void* tmpcache, size_t 
cachesize) { + static BTLA_CODE forward(const _SRC_T* cacheptr, const int cachestep, const int M_offset, const int N_offset, + const int M, const int N, const Param& _param, void* tmpcache, size_t cachesize) { auto COffset = M_offset * _param.ldc + N_offset; auto cptr = _param.C + COffset; + if constexpr (std::is_same_v<_SRC_T, DType>) { + if (cacheptr == cptr) { + return BTLA_CODE::Success; + } + } + return kernel::wrapper::Memcpy2D::template forward(cacheptr, cptr, M, N, cachestep, _param.ldc, _param.elt_const_v); } @@ -50,8 +56,8 @@ template class CustomAccumulatorWriteBackWithEltop { public: using Param = ParamAccumulatorWriteBack<_DST_T>; - BTLA_CODE forward(const _SRC_T* cacheptr, const int cachestep, const int M_offset, const int N_offset, const int M, - const int N, const Param& _param, void* tmpcache, size_t cachesize) { + static BTLA_CODE forward(const _SRC_T* cacheptr, const int cachestep, const int M_offset, const int N_offset, + const int M, const int N, const Param& _param, void* tmpcache, size_t cachesize) { auto COffset = M_offset * _param.ldc + N_offset; auto cptr = _param.C + COffset; if constexpr (std::is_same<_SRC_T, float>::value && std::is_same<_DST_T, float>::value) { @@ -95,8 +101,8 @@ class AlphaBetaProcessFp32 { public: using Param = ParamAlphaBetaProcess; - BTLA_CODE forward(const float* cacheptr, const int cachestep, const int M_offset, const int N_offset, const int M, - const int N, const Param& _param, void* tmpcache, size_t cachesize) { + static BTLA_CODE forward(const float* cacheptr, const int cachestep, const int M_offset, const int N_offset, + const int M, const int N, const Param& _param, void* tmpcache, size_t cachesize) { auto DOffset = M_offset * _param.ldd + N_offset; auto COffset = M_offset * _param.ldc + N_offset; auto cptr = _param.C + COffset; @@ -118,9 +124,9 @@ template class CompFp32BlockEpilogue { public: using Param = ParamCompFp32BlockEpilogue; - BTLA_CODE forward(const float* srcptr, float* dstptr, const int cachestep, const int M_offset, const int N_offset, - const int K_offset, const int M, const int N, const Param& _param, void* tmpcache, - size_t cachesize) { + static BTLA_CODE forward(const float* srcptr, float* dstptr, const int cachestep, const int M_offset, + const int N_offset, const int K_offset, const int M, const int N, const Param& _param, + void* tmpcache, size_t cachesize) { auto ret = BTLA_CODE::NotSupport; if (_param.scaledtype == BTLA_DTYPE::F32) { ret = kernel::wrapper::CompFp32BlockScale::template forward( @@ -169,8 +175,8 @@ template class DequantInt32ToFp32 { public: using Param = ParamDequantInt32ToFp32; - BTLA_CODE forward(const int32_t* cacheptr, const int cachestep, const int M_offset, const int N_offset, const int M, - const int N, const Param& _param, void* tmpcache, size_t cachesize) { + static BTLA_CODE forward(const int32_t* cacheptr, const int cachestep, const int M_offset, const int N_offset, + const int M, const int N, const Param& _param, void* tmpcache, size_t cachesize) { auto COffset = M_offset * _param.ldc + N_offset; auto cptr = _param.C + COffset; return kernel::wrapper::DequanS32Fp32::template forward(cacheptr, cachestep, cptr, _param.ldc, M, N, @@ -198,9 +204,9 @@ template class CompInt8BlockEpilogue { public: using Param = ParamCompInt8BlockEpilogue; - BTLA_CODE forward(const int32_t* srcptr, float* dstptr, const int cachestep, const int M_offset, const int N_offset, - const int K_offset, const int M, const int N, const Param& _param, void* tmpcache, - size_t cachesize) { + static BTLA_CODE 
forward(const int32_t* srcptr, float* dstptr, const int cachestep, const int M_offset, + const int N_offset, const int K_offset, const int M, const int N, const Param& _param, + void* tmpcache, size_t cachesize) { BTLA_CODE ret = BTLA_CODE::NotSupport; float* scab = nullptr; size_t ScaleBTmpSize = N * sizeof(float); @@ -280,8 +286,8 @@ template class ZpDequantInt32ToFp32 { public: using Param = ParamZpDequantInt32ToFp32; - BTLA_CODE forward(const int32_t* cacheptr, const int cachestep, const int M_offset, const int N_offset, const int M, - const int N, const Param& _param, void* tmpcache, size_t cachesize) { + static BTLA_CODE forward(const int32_t* cacheptr, const int cachestep, const int M_offset, const int N_offset, + const int M, const int N, const Param& _param, void* tmpcache, size_t cachesize) { auto COffset = M_offset * _param.ldc + N_offset; auto cptr = _param.C + COffset; auto ret = kernel::wrapper::DequanS32Fp32::template forward(cacheptr, cachestep, cptr, _param.ldc, M, N, @@ -321,8 +327,8 @@ template class AlphaBetaProcessS32U8 { public: using Param = ParamAlphaBetaProcessS32U8; - BTLA_CODE forward(const int32_t* cacheptr, const int cachestep, const int M_offset, const int N_offset, const int M, - const int N, const Param& _param, void* tmpcache, size_t cachesize) { + static BTLA_CODE forward(const int32_t* cacheptr, const int cachestep, const int M_offset, const int N_offset, + const int M, const int N, const Param& _param, void* tmpcache, size_t cachesize) { auto COffset = M_offset * _param.ldc + N_offset; auto cptr = _param.C + COffset; return kernel::wrapper::QuanOutS32U32::template forward(_param.alpha, cacheptr, cachestep, cptr, _param.ldc, diff --git a/bestla/bestla/bestla_gemm.h b/bestla/bestla/bestla_gemm.h index 88bcc529c..9859d4e29 100644 --- a/bestla/bestla/bestla_gemm.h +++ b/bestla/bestla/bestla_gemm.h @@ -1222,19 +1222,21 @@ class Avx512vnniN16P4 : protected bestla::xbyak::JitAvx512vnni { } }; -template +template class AvxvnniN8P4 : protected bestla::xbyak::JitAvxvnni { public: static int constexpr RegLen = 8, PackRow = 4; static_assert(_NTILE % RegLen == 0); static int constexpr NRegs = _NTILE / RegLen; - static int constexpr MRegs = _MTILE == 0 ? (RegCount - 1) / NRegs : _MTILE; - static_assert(NRegs * MRegs <= RegCount - 1); + static int constexpr KeepRegs = std::is_same_v ? 1 : 3; + static int constexpr MRegs = _MTILE == 0 ? (RegCount - KeepRegs) / NRegs : _MTILE; + static_assert(NRegs * MRegs <= RegCount - KeepRegs); static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs, KTILE = 4; static int constexpr KUNROLL = 2; static auto constexpr ISA = BTLA_ISA::AVX_VNNI; - static auto constexpr COMPUTE = CompType::COMP_INT8_US_INT32; - typedef uint8_t AType; + static auto constexpr COMPUTE = + std::is_same_v ? 
CompType::COMP_INT8_US_INT32 : CompType::COMP_INT8_SS_INT32; + using AType = AT; typedef int8_t BType; typedef int32_t CType; struct params { @@ -1285,7 +1287,10 @@ class AvxvnniN8P4 : protected bestla::xbyak::JitAvxvnni { void assign_regs() { CRegCount = MRegs * NRegs; ARegCount = 1; - BRegCount = RegCount - ARegCount - CRegCount; + if (std::is_same_v) { + TmpRegCount = 2; + } + BRegCount = RegCount - ARegCount - CRegCount - TmpRegCount; if (BRegCount < NRegs) { BRegCount = 0; ARegCount = BRegCount + 1; @@ -1297,8 +1302,7 @@ class AvxvnniN8P4 : protected bestla::xbyak::JitAvxvnni { BReg = CReg + CRegCount; AReg = BReg + BRegCount; TmpReg = AReg + ARegCount; - assert(TmpReg <= RegCount); - TmpRegCount = RegCount - TmpReg; + assert(TmpReg + TmpRegCount <= RegCount); } void generate_mtile(int _mtile) { @@ -1379,9 +1383,17 @@ class AvxvnniN8P4 : protected bestla::xbyak::JitAvxvnni { } for (int mm = 0; mm < _mtile; mm++) { vpbroadcastd(vreg_t(AReg), ptr[reg_tmp1]); + if constexpr (std::is_same_v) { + vpsignb(vreg_t(TmpReg + 1), vreg_t(AReg), vreg_t(AReg)); + } add(reg_tmp1, reg_astride); for (int i = 0; i < NRegs; i++) { - vpdpbusds_(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg), vreg_t(BReg + i)); + if constexpr (std::is_same_v) { + vpdpbusds_(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg), vreg_t(BReg + i)); + } else { + vpsignb(vreg_t(TmpReg), vreg_t(BReg + i), vreg_t(AReg)); + vpdpbusds_(vreg_t(CReg + mm * NRegs + i), vreg_t(TmpReg + 1), vreg_t(TmpReg)); + } } } } else if (BRegCount == 0) { @@ -1389,10 +1401,272 @@ class AvxvnniN8P4 : protected bestla::xbyak::JitAvxvnni { int mm_re = utils::remainsize(mm, _mtile, ARegCount); for (int imm = 0; imm < mm_re; imm++) { vpbroadcastd(vreg_t(AReg + imm), ptr[reg_tmp1]); + if constexpr (std::is_same_v) { + vpsignb(vreg_t(TmpReg + 1), vreg_t(AReg + imm), vreg_t(AReg + imm)); + } add(reg_tmp1, reg_astride); for (int i = 0; i < NRegs; i++) { - vpdpbusds_(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg + imm), - ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); + if constexpr (std::is_same_v) { + vpdpbusds_(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg + imm), + ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); + } else { + vmovups(vreg_t(TmpReg), ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); + vpsignb(vreg_t(TmpReg), vreg_t(TmpReg), vreg_t(AReg + imm)); + vpdpbusds_(vreg_t(CReg + mm * NRegs + i), vreg_t(TmpReg + 1), vreg_t(TmpReg)); + } + } + } + } + } else { + assert(0); + } + } + } + + void init_regs(int _mtile) { + inLocalLabel(); + load32(reg_tmp, ptr[parambase + OFFSET(init)]); + cmp(reg_tmp, 0); + je(".read", T_NEAR); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vxor(vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j)); + } + } + jmp(".end", T_NEAR); + L(".read"); + mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); + lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); + load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vmovups(vreg_t(CReg + i * NRegs + j), ptr[reg_matCptr + j * VecBytes]); + } + add(reg_matCptr, reg_cstride); + } + L(".end"); + outLocalLabel(); + } + + void write_back(int _mtile) { + inLocalLabel(); + mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); + load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); + lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vmovups(ptr[reg_matCptr + j * 
VecBytes], vreg_t(CReg + i * NRegs + j)); + } + add(reg_matCptr, reg_cstride); + } + outLocalLabel(); + } +}; + +template +using AvxvnniN8P4U8 = AvxvnniN8P4; + +template +using AvxvnniN8P4S8 = AvxvnniN8P4; + +template +class Avx2vnniN8P4 : protected bestla::xbyak::JitAvx2 { + public: + static int constexpr RegLen = 8, PackRow = 4; + static_assert(_NTILE % RegLen == 0); + static int constexpr NRegs = _NTILE / RegLen; + static int constexpr KeepRegs = std::is_same_v ? 3 : 5; + static int constexpr MRegs = _MTILE == 0 ? (RegCount - KeepRegs) / NRegs : _MTILE; + static_assert(NRegs * MRegs <= RegCount - KeepRegs); + static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs, KTILE = 4; + static int constexpr KUNROLL = 2; + static auto constexpr ISA = BTLA_ISA::AVX2; + static auto constexpr COMPUTE = + std::is_same_v ? CompType::COMP_INT8_US_INT32 : CompType::COMP_INT8_SS_INT32; + using AType = AT; + typedef int8_t BType; + typedef int32_t CType; + struct params { + AType* matA; + int astride; + BType* matB; + int bstride; + CType* matC; + int cstride; + int k; + int n; + int init; + const int16_t one = 1; + }; + typedef long long (*func_t)(params*); + + int CRegCount = 0, BRegCount = 0, ARegCount = 0, TmpRegCount = 0; + int CReg = 0, BReg = 0, AReg = 0, TmpReg = 0; + static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType); + static int constexpr AKStepSize = KTILE * sizeof(AType); + + void generate_code(int _mtile) { + assign_regs(); + reset(); + generate_mtile(_mtile); + ready(); + mKernel = getCode(); + } + func_t mKernel = nullptr; + + private: + Xbyak::Reg64 parambase; + Xbyak::Reg64 reg_matAptr; + Xbyak::Reg64 reg_matBptr; + Xbyak::Reg64 reg_matCptr; + Xbyak::Reg64 reg_ksize; + Xbyak::Reg64 reg_nsize; + Xbyak::Reg64 reg_cstride; + Xbyak::Reg64 reg_astride; + Xbyak::Reg64 reg_iterk; + Xbyak::Reg64 reg_itern; + Xbyak::Reg64 reg_tmp; + Xbyak::Reg64 reg_tmp1; + Xbyak::Reg64 reg_tmp2; + Xbyak::Reg64 reg_ret = rax; + Xbyak::Opmask msk_wr = k1; + + protected: + void assign_regs() { + CRegCount = MRegs * NRegs; + ARegCount = 1; + if (std::is_same_v) { + TmpRegCount = 4; + } else { + TmpRegCount = 2; + } + BRegCount = RegCount - ARegCount - CRegCount - TmpRegCount; + if (BRegCount < NRegs) { + BRegCount = 0; + ARegCount = BRegCount + 1; + } + if (BRegCount > NRegs) { + BRegCount = NRegs; + } + CReg = 0; + BReg = CReg + CRegCount; + AReg = BReg + BRegCount; + TmpReg = AReg + ARegCount; + assert(TmpReg + TmpRegCount <= RegCount); + } + + void generate_mtile(int _mtile) { + inLocalLabel(); + Xbyak::util::StackFrame st(this, 1, 10, 16 * 10); + parambase = st.p[0]; + reg_matAptr = st.t[0]; + reg_matBptr = st.t[1]; + reg_matCptr = st.t[0]; + reg_ksize = st.t[2]; + reg_astride = st.t[3]; + reg_cstride = st.t[3]; + reg_iterk = st.t[4]; + reg_tmp = st.t[5]; + reg_tmp1 = st.t[6]; + reg_tmp2 = st.t[7]; + reg_nsize = st.t[8]; + reg_itern = st.t[9]; + reg_ret = rax; + + vreg_push(rsp); + vpbroadcastw(vreg_t(TmpReg + 0), ptr[parambase + OFFSET(one)]); + load32(reg_ksize, ptr[parambase + OFFSET(k)]); + load32(reg_nsize, ptr[parambase + OFFSET(n)]); + xor_(reg_itern, reg_itern); + L(".nloop"); + init_regs(_mtile); + mov(reg_matAptr, ptr[parambase + OFFSET(matA)]); + load32(reg_astride, ptr[parambase + OFFSET(astride)]); + mov(reg_matBptr, ptr[parambase + OFFSET(matB)]); + load32(reg_tmp, ptr[parambase + OFFSET(bstride)]); + imul(reg_tmp, reg_itern); + lea(reg_matBptr, ptr[reg_matBptr + reg_tmp]); + xor_(reg_iterk, reg_iterk); + generate_kloop(_mtile); + write_back(_mtile); + add(reg_itern, NTILE); + 
cmp(reg_itern, reg_nsize); + jb(".nloop"); + mov(reg_ret, 0); + vreg_pop(rsp); + + outLocalLabel(); // end of local label + } + + void generate_kloop(int _mtile) { + inLocalLabel(); + mov(reg_tmp, reg_ksize); + padto_le(reg_tmp, KUNROLL * KTILE); + cmp(reg_tmp, 0); + jz(".kloop", T_NEAR); + L(".unkloop"); + generate_fma(_mtile, KUNROLL); + add(reg_matAptr, KUNROLL * AKStepSize); + add(reg_matBptr, KUNROLL * BKStepSize); + add(reg_iterk, KUNROLL * KTILE); + cmp(reg_iterk, reg_tmp); // k iteration variable + jb(".unkloop"); + cmp(reg_tmp, reg_ksize); + jge(".kend", T_NEAR); + L(".kloop"); + generate_fma(_mtile, 1); + add(reg_matAptr, 1 * AKStepSize); + add(reg_matBptr, 1 * BKStepSize); + add(reg_iterk, 1 * KTILE); + cmp(reg_iterk, reg_ksize); // k iteration variable + jb(".kloop"); + L(".kend"); + outLocalLabel(); + } + + void generate_fma(int _mtile, int _kunroll) { + for (int kk = 0; kk < _kunroll; kk++) { + lea(reg_tmp1, ptr[reg_matAptr + kk * AKStepSize]); + if (BRegCount == NRegs) { + for (int i = 0; i < NRegs; i++) { + vmovups(vreg_t(BReg + i), ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); + } + for (int mm = 0; mm < _mtile; mm++) { + vpbroadcastd(vreg_t(AReg), ptr[reg_tmp1]); + if constexpr (std::is_same_v) { + vpsignb(vreg_t(TmpReg + 2), vreg_t(AReg), vreg_t(AReg)); + } + add(reg_tmp1, reg_astride); + for (int i = 0; i < NRegs; i++) { + if constexpr (std::is_same_v) { + vpdpbusds_(vreg_t(CReg + mm * NRegs + i), vreg_t(TmpReg + 1), vreg_t(AReg), vreg_t(BReg + i), + vreg_t(TmpReg + 0)); + } else { + vpsignb(vreg_t(TmpReg + 3), vreg_t(BReg + i), vreg_t(AReg)); + vpdpbusds_(vreg_t(CReg + mm * NRegs + i), vreg_t(TmpReg + 1), vreg_t(TmpReg + 2), vreg_t(TmpReg + 3), + vreg_t(TmpReg + 0)); + } + } + } + } else if (BRegCount == 0) { + for (int mm = 0; mm < _mtile; mm += ARegCount) { + int mm_re = utils::remainsize(mm, _mtile, ARegCount); + for (int imm = 0; imm < mm_re; imm++) { + vpbroadcastd(vreg_t(AReg + imm), ptr[reg_tmp1]); + if constexpr (std::is_same_v) { + vpsignb(vreg_t(TmpReg + 2), vreg_t(AReg + imm), vreg_t(AReg + imm)); + } + add(reg_tmp1, reg_astride); + for (int i = 0; i < NRegs; i++) { + if constexpr (std::is_same_v) { + vpdpbusds_(vreg_t(CReg + mm * NRegs + i), vreg_t(TmpReg + 1), vreg_t(AReg + imm), + ptr[reg_matBptr + kk * BKStepSize + i * VecBytes], vreg_t(TmpReg + 0)); + } else { + vmovups(vreg_t(TmpReg + 3), ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); + vpsignb(vreg_t(TmpReg + 3), vreg_t(TmpReg + 3), vreg_t(AReg + imm)); + vpdpbusds_(vreg_t(CReg + mm * NRegs + i), vreg_t(TmpReg + 1), vreg_t(TmpReg + 2), vreg_t(TmpReg + 3), + vreg_t(TmpReg + 0)); + } } } } @@ -1442,6 +1716,12 @@ class AvxvnniN8P4 : protected bestla::xbyak::JitAvxvnni { } }; +template +using Avx2vnniN8P4U8 = Avx2vnniN8P4; + +template +using Avx2vnniN8P4S8 = Avx2vnniN8P4; + template class Amxbf16N16P2 : protected bestla::xbyak::JitAmxbf16 { public: @@ -2520,7 +2800,7 @@ class Avx512vnniN16P4 : protected bestla::xbyak::JitAvx512vnni { } }; -template +template class AvxvnniN8P4 : protected bestla::xbyak::JitAvxvnni { public: static int constexpr RegLen = 8, PackRow = 4; @@ -2531,8 +2811,9 @@ class AvxvnniN8P4 : protected bestla::xbyak::JitAvxvnni { static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs, KTILE = 4; static int constexpr KUNROLL = 2; static auto constexpr ISA = BTLA_ISA::AVX_VNNI; - static auto constexpr COMPUTE = CompType::COMP_INT8_US_FP32; - typedef uint8_t AType; + static auto constexpr COMPUTE = + std::is_same_v ? 
CompType::COMP_INT8_US_FP32 : CompType::COMP_INT8_SS_FP32; + using AType = AT; typedef int8_t BType; typedef float CType; @@ -2703,9 +2984,19 @@ class AvxvnniN8P4 : protected bestla::xbyak::JitAvxvnni { if (BRegCount == 0) { for (int mm = 0; mm < _mtile; mm++) { vpbroadcastd(vreg_t(AReg), ptr[reg_tmp1]); + if constexpr (std::is_same_v) { + vpsignb(vreg_t(TmpReg + 1), vreg_t(AReg), vreg_t(AReg)); + } add(reg_tmp1, reg_astride); for (int i = 0; i < NRegs; i++) { - vpdpbusds_(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg), ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); + if constexpr (std::is_same_v) { + vpdpbusds_(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg), + ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); + } else { + vmovups(vreg_t(TmpReg), ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); + vpsignb(vreg_t(TmpReg), vreg_t(TmpReg), vreg_t(AReg)); + vpdpbusds_(vreg_t(CReg + mm * NRegs + i), vreg_t(TmpReg + 1), vreg_t(TmpReg)); + } } } } else { @@ -2714,9 +3005,17 @@ class AvxvnniN8P4 : protected bestla::xbyak::JitAvxvnni { } for (int mm = 0; mm < _mtile; mm++) { vpbroadcastd(vreg_t(AReg), ptr[reg_tmp1]); + if constexpr (std::is_same_v) { + vpsignb(vreg_t(TmpReg + 1), vreg_t(AReg), vreg_t(AReg)); + } add(reg_tmp1, reg_astride); for (int i = 0; i < NRegs; i++) { - vpdpbusds_(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg), vreg_t(BReg + i)); + if constexpr (std::is_same_v) { + vpdpbusds_(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg), vreg_t(BReg + i)); + } else { + vpsignb(vreg_t(TmpReg), vreg_t(BReg + i), vreg_t(AReg)); + vpdpbusds_(vreg_t(CReg + mm * NRegs + i), vreg_t(TmpReg + 1), vreg_t(TmpReg)); + } } } } @@ -2863,11 +3162,387 @@ class AvxvnniN8P4 : protected bestla::xbyak::JitAvxvnni { outLocalLabel(); } }; +template +using AvxvnniN8P4U8 = kblock::AvxvnniN8P4; +template +using AvxvnniN8P4S8 = kblock::AvxvnniN8P4; -template -class Amxint8N16P4 : protected bestla::xbyak::JitAmxint8 { +template +class Avx2vnniN8P4 : protected bestla::xbyak::JitAvx2 { public: - static int constexpr RegLen = 16, PackRow = 4; + static int constexpr RegLen = 8, PackRow = 4; + static_assert(_NTILE % RegLen == 0); + static int constexpr NRegs = _NTILE / RegLen; + static int constexpr TmpReserve = std::is_same_v ? 2 : 4; + static int constexpr MRegs = _MTILE == 0 ? (RegCount - (TmpReserve + 1)) / (NRegs * 2) : _MTILE; + static_assert(NRegs * MRegs <= RegCount - (TmpReserve + 1)); + static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs, KTILE = 4; + static int constexpr KUNROLL = 2; + static auto constexpr ISA = BTLA_ISA::AVX2; + static auto constexpr COMPUTE = + std::is_same_v ? 
CompType::COMP_INT8_US_FP32 : CompType::COMP_INT8_SS_FP32; + using AType = AT; + typedef int8_t BType; + typedef float CType; + + struct params { + AType* matA; + int astride; + BType* matB; + int bstride; + CType* matC; + int cstride; + uint8_t* zpA; + float* scaleA; + int ldsa; + float* scaleB; + float* reduceB; + int ldsb; + int k; + int n; + int kblock; + int init; + float kscale; + const uint16_t one = 1; + }; + typedef long long (*func_t)(params*); + + int CRegCount = 0, BRegCount = 0, ARegCount = 0, TmpRegCount = 0; + int CReg = 0, CF32Reg = 0, BReg = 0, AReg = 0, TmpReg = 0; + static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType); + static int constexpr AKStepSize = KTILE * sizeof(AType); + + void generate_code(int _mtile) { + assign_regs(); + reset(); + generate_mtile(_mtile); + ready(); + mKernel = getCode(); + } + func_t mKernel = nullptr; + + protected: + Xbyak::Reg64 parambase; + Xbyak::Reg64 reg_matAptr; + Xbyak::Reg64 reg_matBptr; + Xbyak::Reg64 reg_matCptr; + Xbyak::Reg64 reg_ksize; + Xbyak::Reg64 reg_nsize; + Xbyak::Reg64 reg_cstride; + Xbyak::Reg64 reg_astride; + Xbyak::Reg64 reg_iterk; + Xbyak::Reg64 reg_iterkb; + Xbyak::Reg64 reg_itern; + Xbyak::Reg64 reg_tmp; + Xbyak::Reg64 reg_tmp1; + Xbyak::Reg64 reg_tmp2; + Xbyak::Reg64 reg_tmp3; + Xbyak::Reg64 reg_tmp4; + Xbyak::Reg64 reg_ret = rax; + + void assign_regs() { + CRegCount = MRegs * NRegs; + ARegCount = 1; + BRegCount = RegCount - CRegCount - CRegCount - ARegCount - TmpReserve; + if (BRegCount >= NRegs) { + BRegCount = NRegs; + } else { + BRegCount = 0; + } + CReg = 0; + CF32Reg = CReg + CRegCount; + BReg = CF32Reg + CRegCount; + AReg = BReg + BRegCount; + TmpReg = AReg + ARegCount; + assert(TmpReg < RegCount); + TmpRegCount = RegCount - TmpReg; + assert(TmpRegCount >= TmpReserve); + } + + void generate_mtile(int _mtile) { + inLocalLabel(); // use local label for multiple instance + Xbyak::util::StackFrame st(this, 1, 13, 16 * 10); + parambase = st.p[0]; + reg_matAptr = st.t[0]; + reg_matBptr = st.t[1]; + reg_matCptr = st.t[0]; + reg_ksize = st.t[2]; + reg_astride = st.t[3]; + reg_cstride = st.t[3]; + reg_iterk = st.t[4]; + reg_iterkb = st.t[12]; + reg_tmp = st.t[5]; + reg_tmp1 = st.t[6]; + reg_tmp2 = st.t[7]; + reg_tmp3 = st.t[10]; + reg_tmp4 = st.t[11]; + reg_nsize = st.t[8]; + reg_itern = st.t[9]; + reg_ret = rax; + + vreg_push(rsp); + load32(reg_ksize, ptr[parambase + OFFSET(k)]); + load32(reg_nsize, ptr[parambase + OFFSET(n)]); + xor_(reg_itern, reg_itern); + L(".nloop"); + init_regs(_mtile); + mov(reg_matAptr, ptr[parambase + OFFSET(matA)]); + load32(reg_astride, ptr[parambase + OFFSET(astride)]); + mov(reg_matBptr, ptr[parambase + OFFSET(matB)]); + load32(reg_tmp, ptr[parambase + OFFSET(bstride)]); + imul(reg_tmp, reg_itern); + lea(reg_matBptr, ptr[reg_matBptr + reg_tmp]); + xor_(reg_iterk, reg_iterk); + generate_kloop(_mtile); + write_back(_mtile); + add(reg_itern, NTILE); + cmp(reg_itern, reg_nsize); + jb(".nloop"); + mov(reg_ret, 0); + vreg_pop(rsp); + + outLocalLabel(); // end of local label + } + + void generate_kloop(int _mtile) { + inLocalLabel(); + xor_(reg_iterkb, reg_iterkb); + L(".kloop"); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vxor(vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j)); + } + } + vpbroadcastw(vreg_t(TmpReg + 0), ptr[parambase + OFFSET(one)]); + xor_(reg_tmp2, reg_tmp2); + load32(reg_tmp3, ptr[parambase + OFFSET(kblock)]); + mov(reg_tmp, reg_tmp3); + padto_le(reg_tmp, KUNROLL * KTILE); + cmp(reg_tmp, 0); 
+ jz(".kbloop", T_NEAR); + L(".unkbloop"); + generate_fma(_mtile, KUNROLL, reg_tmp1); + add(reg_matAptr, KUNROLL * AKStepSize); + add(reg_matBptr, KUNROLL * BKStepSize); + add(reg_tmp2, KUNROLL * KTILE); + cmp(reg_tmp2, reg_tmp); + jb(".unkbloop"); + cmp(reg_tmp, reg_tmp3); + jge(".kend", T_NEAR); + L(".kbloop"); + generate_fma(_mtile, 1, reg_tmp1); + add(reg_matAptr, 1 * AKStepSize); + add(reg_matBptr, 1 * BKStepSize); + add(reg_tmp2, 1 * KTILE); + cmp(reg_tmp2, reg_tmp3); + jb(".kbloop"); + L(".kend"); + add(reg_iterk, reg_tmp2); + generate_f32_accumulate(_mtile); + generate_zp_correction(_mtile); + inc(reg_iterkb); + cmp(reg_iterk, reg_ksize); // k iteration variable + jb(".kloop"); + + outLocalLabel(); + } + + void generate_fma(int _mtile, int _ktile, Xbyak::Reg64& tmp) { + for (int kk = 0; kk < _ktile; kk++) { + lea(tmp, ptr[reg_matAptr + kk * AKStepSize]); + if (BRegCount == 0) { + for (int mm = 0; mm < _mtile; mm++) { + vpbroadcastd(vreg_t(AReg), ptr[reg_tmp1]); + if constexpr (std::is_same_v) { + vpsignb(vreg_t(TmpReg + 2), vreg_t(AReg), vreg_t(AReg)); + } + add(reg_tmp1, reg_astride); + for (int i = 0; i < NRegs; i++) { + if constexpr (std::is_same_v) { + vpdpbusds_(vreg_t(CReg + mm * NRegs + i), vreg_t(TmpReg + 1), vreg_t(AReg), + ptr[reg_matBptr + kk * BKStepSize + i * VecBytes], vreg_t(TmpReg + 0)); + } else { + vmovups(vreg_t(TmpReg + 3), ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); + vpsignb(vreg_t(TmpReg + 3), vreg_t(TmpReg + 3), vreg_t(AReg)); + vpdpbusds_(vreg_t(CReg + mm * NRegs + i), vreg_t(TmpReg + 1), vreg_t(TmpReg + 2), vreg_t(TmpReg + 3), + vreg_t(TmpReg + 0)); + } + } + } + } else { + for (int i = 0; i < NRegs; i++) { + vmovups(vreg_t(BReg + i), ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); + } + for (int mm = 0; mm < _mtile; mm++) { + vpbroadcastd(vreg_t(AReg), ptr[reg_tmp1]); + if constexpr (std::is_same_v) { + vpsignb(vreg_t(TmpReg + 2), vreg_t(AReg), vreg_t(AReg)); + } + add(reg_tmp1, reg_astride); + for (int i = 0; i < NRegs; i++) { + if constexpr (std::is_same_v) { + vpdpbusds_(vreg_t(CReg + mm * NRegs + i), vreg_t(TmpReg + 1), vreg_t(AReg), vreg_t(BReg + i), + vreg_t(TmpReg + 0)); + } else { + vpsignb(vreg_t(TmpReg + 3), vreg_t(BReg + i), vreg_t(AReg)); + vpdpbusds_(vreg_t(CReg + mm * NRegs + i), vreg_t(TmpReg + 1), vreg_t(TmpReg + 2), vreg_t(TmpReg + 3), + vreg_t(TmpReg + 0)); + } + } + } + } + } + } + + void init_regs(int _mtile) { + inLocalLabel(); + load32(reg_tmp, ptr[parambase + OFFSET(init)]); + cmp(reg_tmp, 0); + je(".read", T_NEAR); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vxor(vreg_t(CF32Reg + i * NRegs + j), vreg_t(CF32Reg + i * NRegs + j), vreg_t(CF32Reg + i * NRegs + j)); + } + } + jmp(".end", T_NEAR); + L(".read"); + mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); + lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); + load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vmovups(vreg_t(CF32Reg + i * NRegs + j), ptr[reg_matCptr + j * VecBytes]); + } + add(reg_matCptr, reg_cstride); + } + L(".end"); + outLocalLabel(); + } + + void generate_f32_accumulate(int _mtile) { + load32(reg_tmp, ptr[parambase + OFFSET(ldsb)]); + imul(reg_tmp, reg_iterkb); + mov(reg_tmp2, ptr[parambase + OFFSET(scaleB)]); + lea(reg_tmp2, ptr[reg_tmp2 + reg_tmp * sizeof(float)]); + lea(reg_tmp2, ptr[reg_tmp2 + reg_itern * sizeof(float)]); + + mov(reg_tmp, ptr[parambase + OFFSET(scaleA)]); + lea(reg_tmp, ptr[reg_tmp + reg_iterkb * 
sizeof(float)]); + load32(reg_tmp1, ptr[parambase + OFFSET(ldsa)]); + if (BRegCount == NRegs) { + for (int i = 0; i < NRegs; i++) { + vmovups(vreg_t(BReg + i), ptr[reg_tmp2 + i * VecBytes]); + } + for (int mm = 0; mm < _mtile; mm++) { + vbroadcastss(vreg_t(TmpReg), ptr[reg_tmp]); + lea(reg_tmp, ptr[reg_tmp + reg_tmp1 * sizeof(float)]); + for (int i = 0; i < NRegs; i++) { + vcvtdq2ps(vreg_t(CReg + mm * NRegs + i), vreg_t(CReg + mm * NRegs + i)); + vmulps(vreg_t(AReg), vreg_t(TmpReg), vreg_t(BReg + i)); + vmulps(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg)); + vaddps(vreg_t(CF32Reg + mm * NRegs + i), vreg_t(CReg + mm * NRegs + i)); + } + } + } else { + for (int mm = 0; mm < _mtile; mm++) { + vbroadcastss(vreg_t(TmpReg), ptr[reg_tmp]); + lea(reg_tmp, ptr[reg_tmp + reg_tmp1 * sizeof(float)]); + for (int i = 0; i < NRegs; i++) { + vcvtdq2ps(vreg_t(CReg + mm * NRegs + i), vreg_t(CReg + mm * NRegs + i)); + vmovups(vreg_t(AReg), ptr[reg_tmp2 + i * VecBytes]); + vmulps(vreg_t(AReg), vreg_t(TmpReg)); + vmulps(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg)); + vaddps(vreg_t(CF32Reg + mm * NRegs + i), vreg_t(CReg + mm * NRegs + i)); + } + } + } + } + + void generate_zp_correction(int _mtile) { + inLocalLabel(); + mov(reg_tmp, ptr[parambase + OFFSET(zpA)]); + cmp(reg_tmp, 0); + je(".NOZP", T_NEAR); + lea(reg_tmp, ptr[reg_tmp + reg_iterkb * sizeof(AType)]); + auto& reg_zpA = reg_tmp; + load32(reg_tmp1, ptr[parambase + OFFSET(ldsb)]); + imul(reg_tmp1, reg_iterkb); + mov(reg_tmp2, ptr[parambase + OFFSET(reduceB)]); + lea(reg_tmp2, ptr[reg_tmp2 + reg_tmp1 * sizeof(float)]); + lea(reg_tmp2, ptr[reg_tmp2 + reg_itern * sizeof(float)]); + auto& reg_redB = reg_tmp2; + + mov(reg_tmp1, ptr[parambase + OFFSET(scaleA)]); + lea(reg_tmp1, ptr[reg_tmp1 + reg_iterkb * sizeof(float)]); + auto& reg_scaleA = reg_tmp1; + + load32(reg_tmp3, ptr[parambase + OFFSET(ldsa)]); + auto& reg_ldsa = reg_tmp3; + + vbroadcastss(vreg_t(TmpReg), ptr[parambase + OFFSET(kscale)]); + auto& reg_kscale = reg_tmp4; + if (BRegCount == NRegs) { + for (int i = 0; i < NRegs; i++) { + vmovups(vreg_t(BReg + i), ptr[reg_redB + i * VecBytes]); + } + for (int i = 0; i < _mtile; i++) { + vpbroadcastb(Xbyak::Xmm(AReg), ptr[reg_zpA]); + vpmovzxbd(vreg_t(AReg), Xbyak::Xmm(AReg)); + vcvtdq2ps(vreg_t(AReg), vreg_t(AReg)); + vbroadcastss(vreg_t(TmpReg + 1), ptr[reg_scaleA]); + vmulps(vreg_t(AReg), vreg_t(AReg), vreg_t(TmpReg + 1)); + vmulps(vreg_t(AReg), vreg_t(AReg), vreg_t(TmpReg)); + for (int j = 0; j < NRegs; j++) { + vmulps(vreg_t(CReg + j), vreg_t(AReg), vreg_t(BReg + j)); + vsubps(vreg_t(CF32Reg + i * NRegs + j), vreg_t(CReg + j)); + } + lea(reg_zpA, ptr[reg_zpA + reg_ldsa * sizeof(AType)]); + lea(reg_scaleA, ptr[reg_scaleA + reg_ldsa * sizeof(float)]); + } + } else { + for (int i = 0; i < _mtile; i++) { + vpbroadcastb(Xbyak::Xmm(AReg), ptr[reg_zpA]); + vpmovzxbd(vreg_t(AReg), Xbyak::Xmm(AReg)); + vcvtdq2ps(vreg_t(AReg), vreg_t(AReg)); + vbroadcastss(vreg_t(TmpReg + 1), ptr[reg_scaleA]); + vmulps(vreg_t(AReg), vreg_t(AReg), vreg_t(TmpReg + 1)); + vmulps(vreg_t(AReg), vreg_t(AReg), vreg_t(TmpReg)); + for (int j = 0; j < NRegs; j++) { + vmulps(vreg_t(CReg + j), vreg_t(AReg), ptr[reg_redB + j * VecBytes]); + vsubps(vreg_t(CF32Reg + i * NRegs + j), vreg_t(CReg + j)); + } + lea(reg_zpA, ptr[reg_zpA + reg_ldsa * sizeof(AType)]); + lea(reg_scaleA, ptr[reg_scaleA + reg_ldsa * sizeof(float)]); + } + } + + L(".NOZP"); + outLocalLabel(); + } + + void write_back(int _mtile) { + inLocalLabel(); + mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); + load32(reg_cstride, 
ptr[parambase + OFFSET(cstride)]); + lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vmovups(ptr[reg_matCptr + j * VecBytes], vreg_t(CF32Reg + i * NRegs + j)); + } + add(reg_matCptr, reg_cstride); + } + outLocalLabel(); + } +}; +template +using Avx2vnniN8P4U8 = kblock::Avx2vnniN8P4; +template +using Avx2vnniN8P4S8 = kblock::Avx2vnniN8P4; + +template +class Amxint8N16P4 : protected bestla::xbyak::JitAmxint8 { + public: + static int constexpr RegLen = 16, PackRow = 4; static_assert(_NTILE % RegLen == 0); static_assert(_MTILE % RegLen == 0); static int constexpr NRegs = _NTILE / RegLen; @@ -3404,9 +4079,41 @@ class ICoreRowNAvx512vnniKBlock : public CoreCodeBase -class ICoreRowNAvxvnni : public CoreCodeBase { +class ICoreRowNAvxvnni : public CoreCodeBase { + public: + using Code = typename CoreCodeBase::Code; + + void forward(uint8_t* matA, int8_t* matB, int32_t* matC, int _m, int _n, int _k, int _astride, int _bstride, + int _cstride, int kpos, void* tmpcache, size_t cachesize) { + auto param = typename Code::params{matA, _astride, matB, _bstride, matC, _cstride, _k, _n, kpos == 0 ? 1 : 0}; + if (_m <= Code::MTILE) { + this->mCodes[_m - 1].mKernel(¶m); + } else { + assert(0); + } + } +}; + +template +class ICoreRowNAvxvnniSS : public CoreCodeBase { public: - using Code = typename CoreCodeBase::Code; + using Code = typename CoreCodeBase::Code; + + void forward(int8_t* matA, int8_t* matB, int32_t* matC, int _m, int _n, int _k, int _astride, int _bstride, + int _cstride, int kpos, void* tmpcache, size_t cachesize) { + auto param = typename Code::params{matA, _astride, matB, _bstride, matC, _cstride, _k, _n, kpos == 0 ? 1 : 0}; + if (_m <= Code::MTILE) { + this->mCodes[_m - 1].mKernel(¶m); + } else { + assert(0); + } + } +}; + +template +class ICoreRowNAvx2vnni : public CoreCodeBase { + public: + using Code = typename CoreCodeBase::Code; void forward(uint8_t* matA, int8_t* matB, int32_t* matC, int _m, int _n, int _k, int _astride, int _bstride, int _cstride, int kpos, void* tmpcache, size_t cachesize) { @@ -3420,9 +4127,25 @@ class ICoreRowNAvxvnni : public CoreCodeBase }; template -class ICoreRowNAvxvnniKBlock : public CoreCodeBase { +class ICoreRowNAvx2vnniSS : public CoreCodeBase { public: - using Code = typename CoreCodeBase::Code; + using Code = typename CoreCodeBase::Code; + + void forward(int8_t* matA, int8_t* matB, int32_t* matC, int _m, int _n, int _k, int _astride, int _bstride, + int _cstride, int kpos, void* tmpcache, size_t cachesize) { + auto param = typename Code::params{matA, _astride, matB, _bstride, matC, _cstride, _k, _n, kpos == 0 ? 
1 : 0}; + if (_m <= Code::MTILE) { + this->mCodes[_m - 1].mKernel(¶m); + } else { + assert(0); + } + } +}; + +template +class ICoreRowNAvxvnniKBlock : public CoreCodeBase { + public: + using Code = typename CoreCodeBase::Code; void forward(uint8_t* matA, int8_t* matB, float* matC, uint8_t* zpA, float* scaleA, int _ldsa, float* scaleB, float* reduceB, int _ldsb, int _m, int _n, int _k, int _kblock, int _astride, int _bstride, int _cstride, int kpos, float kscale, void* tmpcache, size_t cachesize) { @@ -3437,6 +4160,60 @@ class ICoreRowNAvxvnniKBlock : public CoreCodeBase +class ICoreRowNAvxvnniKBlockSS : public CoreCodeBase { + public: + using Code = typename CoreCodeBase::Code; + void forward(int8_t* matA, int8_t* matB, float* matC, int8_t* zpA, float* scaleA, int _ldsa, float* scaleB, + float* reduceB, int _ldsb, int _m, int _n, int _k, int _kblock, int _astride, int _bstride, int _cstride, + int kpos, float kscale, void* tmpcache, size_t cachesize) { + auto param = + typename Code::params{matA, _astride, matB, _bstride, matC, _cstride, nullptr, scaleA, _ldsa, + scaleB, reduceB, _ldsb, _k, _n, _kblock, kpos == 0 ? 1 : 0, kscale}; + if (_m <= Code::MTILE) { + this->mCodes[_m - 1].mKernel(¶m); + } else { + assert(0); + } + } +}; + +template +class ICoreRowNAvx2vnniKBlock : public CoreCodeBase { + public: + using Code = typename CoreCodeBase::Code; + void forward(uint8_t* matA, int8_t* matB, float* matC, uint8_t* zpA, float* scaleA, int _ldsa, float* scaleB, + float* reduceB, int _ldsb, int _m, int _n, int _k, int _kblock, int _astride, int _bstride, int _cstride, + int kpos, float kscale, void* tmpcache, size_t cachesize) { + auto param = typename Code::params{matA, _astride, matB, _bstride, matC, _cstride, zpA, scaleA, + _ldsa, scaleB, reduceB, _ldsb, _k, _n, _kblock, kpos == 0 ? 1 : 0, + kscale}; + if (_m <= Code::MTILE) { + this->mCodes[_m - 1].mKernel(¶m); + } else { + assert(0); + } + } +}; + +template +class ICoreRowNAvx2vnniKBlockSS : public CoreCodeBase { + public: + using Code = typename CoreCodeBase::Code; + void forward(int8_t* matA, int8_t* matB, float* matC, int8_t* zpA, float* scaleA, int _ldsa, float* scaleB, + float* reduceB, int _ldsb, int _m, int _n, int _k, int _kblock, int _astride, int _bstride, int _cstride, + int kpos, float kscale, void* tmpcache, size_t cachesize) { + auto param = + typename Code::params{matA, _astride, matB, _bstride, matC, _cstride, nullptr, scaleA, _ldsa, + scaleB, reduceB, _ldsb, _k, _n, _kblock, kpos == 0 ? 
1 : 0, kscale}; + if (_m <= Code::MTILE) { + this->mCodes[_m - 1].mKernel(¶m); + } else { + assert(0); + } + } +}; + template class ICoreRowNAmxint8 : public CoreCodeBaseAMX { public: diff --git a/bestla/bestla/bestla_jit.h b/bestla/bestla/bestla_jit.h index d1b76467b..b1a0fa093 100644 --- a/bestla/bestla/bestla_jit.h +++ b/bestla/bestla/bestla_jit.h @@ -116,6 +116,13 @@ class JitAvx2 : protected JitAvx { vpmovzxwd(dst, addr); vpslld(dst, dst, 16); } + + void vpdpbusds_(const Xbyak::Xmm& sum4, const Xbyak::Xmm& sum2, const Xbyak::Xmm& x, const Xbyak::Operand& op, + const Xbyak::Xmm& ones) { + vpmaddubsw(sum2, x, op); + vpmaddwd(sum2, sum2, ones); + vpaddd(sum4, sum4, sum2); + } }; class JitAvx512f : protected JitAvx2 { diff --git a/bestla/bestla/bestla_parallel.h b/bestla/bestla/bestla_parallel.h index 6ed42fa92..b60a81d89 100644 --- a/bestla/bestla/bestla_parallel.h +++ b/bestla/bestla/bestla_parallel.h @@ -29,7 +29,7 @@ using thread_func = std::function; class IThreading { public: - explicit IThreading(int nthreads, bool supportPE) : mThreadNum(nthreads), isSupportPE(supportPE) {} + explicit IThreading(bool supportPE) : isSupportPE(supportPE) {} // equal to "for(int i=begin1;i 1) { #pragma omp parallel @@ -109,11 +107,8 @@ class OMPThreading : public IThreading { class StdThreading : public IThreading { public: using Timer_T = utils::timer; - explicit StdThreading(int nthreads) : IThreading(nthreads, true) { - // printf("Using Std\n"); - cr = &device::CpuRuntime::getInstance(nthreads); - create_threads(); - } + explicit StdThreading() : IThreading(true) { cr = nullptr; } + void parallel_for(const thread_func& func) override { time_per_p = 0; time_per_e = 0; @@ -202,10 +197,8 @@ class StdThreading : public IThreading { stop = 1; for (int i = 0; i < mThreadNum - 1; i++) thdset[i].join(); thdset.clear(); - // printf("stop %d\n", mThreadNum); } void create_threads() { - // printf("create %d\n", mThreadNum); thdset.resize(mThreadNum - 1); stop = 0; GetCPUDevice(); @@ -282,7 +275,7 @@ class StdThreading : public IThreading { class SingleThread : public IThreading { public: - SingleThread() : IThreading(1, false) {} + SingleThread() : IThreading(false) { mThreadNum = 1; } void set_threads(int nthreads) override { assert(0); @@ -783,9 +776,9 @@ class SchedulerKBlockS : public SchedulerBase<_GemmCore_T> { this->mL2Use += static_cast(blks) * (this->mBlock[1] + this->mStep[0]) * (sizeof(float) + sizeof(int8_t) + sizeof(float)); // scale+zp+reduce assert(this->mL2Use <= this->mL2Size - ReservedSize); - assert(this->mBlock[0] > 0); - assert(this->mBlock[1] > 0); - assert(this->mBlock[2] > 0); + assert(this->mBlock[0] >= 0); + assert(this->mBlock[1] >= 0); + assert(this->mBlock[2] >= 0); assert(this->mBlock[2] % _GemmCore_T::KTILE == 0); } @@ -887,14 +880,14 @@ class SchedulerDispatcher { SchedulerDispatcher() = default; ~SchedulerDispatcher() { std::pair PEtime = th_->get_PEtime(); - if (needDispach && int(PEtime.first) > 0 && int(PEtime.second) > 0) + if (needDispatch && int(PEtime.first) > 0 && int(PEtime.second) > 0) cr->adjustPE(Scheduler::gemm_ISA(), PEtime.second / PEtime.first); } SchedulerDispatcher(const IThreading* th, const utils::GemmProblem& problem) { th_ = th; cr = &device::CpuRuntime::getInstance(th->num_threads()); - needDispach = cr->mHybrid && th->is_support_PE(); - if (!needDispach) { + needDispatch = cr->mHybrid && th->is_support_PE(); + if (!needDispatch) { Scheduler_P = std::move(Scheduler({th->num_threads(), problem, {0, 0}, cr->mL2Cache, cr->mL1Cache})); } else { Pcore_num = 
cr->P_core_num; @@ -902,7 +895,8 @@ class SchedulerDispatcher { utils::GemmProblem problem_P = problem, problem_E = problem; const int N = problem.dims[2]; auto PE_Ratio = cr->getPE(Scheduler::gemm_ISA()); - const int N_offset = utils::padto(N - int(N / (1 + PE_Ratio)), Scheduler::mStep[1]); + int N_offset = utils::padto(N - int(N / (1 + PE_Ratio)), Scheduler::mStep[1]); + N_offset = N_offset <= N ? N_offset : N; problem_P.dims[2] = N_offset; Scheduler_P = std::move(Scheduler({th->num_threads() - cr->E_core_num, problem_P, {0, 0}, cr->mL2Cache_P, cr->mL1Cache_P})); @@ -912,7 +906,7 @@ class SchedulerDispatcher { } void getIndex(ThreadProblem& problem) { - if (!needDispach) { + if (!needDispatch) { Scheduler_P.getIndex(problem); } else { if (problem.tid >= Pcore_num + Ecore_num) { @@ -928,16 +922,16 @@ class SchedulerDispatcher { } void print() { - printf("dispatch to hybrid:%d\n", needDispach); + printf("dispatch to hybrid:%d\n", needDispatch); Scheduler_P.print(); - if (needDispach) Scheduler_E.print(); + if (needDispatch) Scheduler_E.print(); } private: Scheduler Scheduler_P, Scheduler_E; const IThreading* th_; device::CpuRuntime* cr; - bool needDispach = false; + bool needDispatch = false; int Pcore_num = 0, Ecore_num = 0; }; @@ -949,15 +943,16 @@ class SchedulerDispatcher { ~SchedulerDispatcher() {} SchedulerDispatcher(const IThreading* th, const Config2D& config) { device::CpuRuntime& cr = device::CpuRuntime::getInstance(config.threads); - needDispach = cr.mHybrid && th->is_support_PE(); - if (!needDispach) { + needDispatch = cr.mHybrid && th->is_support_PE(); + if (!needDispatch) { Scheduler_P = std::move(Scheduler2D(config)); } else { Pcore_num = cr.P_core_num; Ecore_num = cr.E_core_num; Config2D config_P = config, config_E = config; const int N = config.size[1]; - const int N_offset = utils::padto(N - int(N / (1 + cr.getPE(BTLA_ISA::NoSIMD))), config.step[1]); + const auto pe = cr.getPE(BTLA_ISA::NoSIMD); + const int N_offset = utils::padto(N - int(N / (1 + pe)), config.step[1]); config_P.threads = config.threads - cr.E_core_num; config_P.size[1] = N_offset; Scheduler_P = std::move(Scheduler2D(config_P)); @@ -969,7 +964,7 @@ class SchedulerDispatcher { } void getIndex(ThreadProblem& problem) { - if (!needDispach) { + if (!needDispatch) { Scheduler_P.getIndex(problem); } else { if (problem.tid >= Pcore_num + Ecore_num) { @@ -985,14 +980,14 @@ class SchedulerDispatcher { } void print() { - printf("dispatch to hybrid:%d\n", needDispach); + printf("dispatch to hybrid:%d\n", needDispatch); Scheduler_P.print(); - if (needDispach) Scheduler_E.print(); + if (needDispatch) Scheduler_E.print(); } private: Scheduler2D Scheduler_P, Scheduler_E; - bool needDispach = false; + bool needDispatch = false; int Pcore_num = 0, Ecore_num = 0; }; diff --git a/bestla/bestla/bestla_prologue_b.h b/bestla/bestla/bestla_prologue_b.h index 052dd7f60..5a5bd2a24 100644 --- a/bestla/bestla/bestla_prologue_b.h +++ b/bestla/bestla/bestla_prologue_b.h @@ -456,18 +456,17 @@ class WeightKBlockNInteger { auto tmpscales = tmp; auto tmpzeropoints = reinterpret_cast(tmpscales + N * blks); if (scales) { - for (size_t i = 0; i < N * blks; i += 2) { - tmpscales[i] = scales[i] / 16; - tmpscales[i + 1] = scales[i + 1] / 16; + for (size_t i = 0; i < N * blks; i += 1) { + tmpscales[i] = scales[i]; } } if (zero_points) { for (size_t i = 0; i < N; i += 1) { for (size_t ib = 0; ib < blks; ib += 2) { auto tmpzp = *(zero_points + i * blks_padding2 / 2 + ib / 2); - tmpzeropoints[i * blks + ib] = ((tmpzp & 0xf) - 8) << 4; + 
tmpzeropoints[i * blks + ib] = (tmpzp & 0x0f) - 8; if (ib + 1 < blks) { - tmpzeropoints[i * blks + ib + 1] = (((tmpzp & 0xf0) >> 4) - 8) << 4; + tmpzeropoints[i * blks + ib + 1] = ((tmpzp & 0xf0) >> 4) - 8; } } } @@ -486,8 +485,8 @@ class WeightKBlockNInteger { for (size_t i = thdp.loc[0]; i < thdp.loc[0] + thdp.size[0]; i++) { for (size_t j = thdp.loc[1]; j < thdp.loc[1] + thdp.size[1]; j += 2) { auto src = *(B + i * ldb / 2 + j / 2); - s8ptr[(j + 0) * N + i] = ((src & 0xf) - 8) << 4; - s8ptr[(j + 1) * N + i] = (((src & 0xf0) >> 4) - 8) << 4; + s8ptr[(j + 0) * N + i] = ((src & 0xf) - 8); + s8ptr[(j + 1) * N + i] = (((src & 0xf0) >> 4) - 8); } } } @@ -558,16 +557,10 @@ class WeightKBlockNInteger { static void compressBit3Weight(const int N, const int K, const int8_t* B, int8_t* dstptr, parallel::IThreading* threading) { - // TODO(zhe): 1D parallel compress - auto ld_dst = _GemmCore_T::NTILE * utils::padto(K, 64); - auto col = _GemmCore_T::NTILE * K; - auto row = N / _GemmCore_T::NTILE; - auto pad_64_buf = utils::avector(row * ld_dst, 0); - kernel::wrapper::Memcpy2D::forward(B, pad_64_buf.data(), row, col, col, ld_dst); + auto bit1_offset = size_t(N) * K; auto bit2ptr = reinterpret_cast(dstptr); - auto bit1ptr = reinterpret_cast(dstptr + row * ld_dst / 4); - auto ret = - kernel::wrapper::CompressBit3::forward(pad_64_buf.data(), bit2ptr, bit1ptr, row, col, ld_dst, ld_dst); + auto bit1ptr = reinterpret_cast(dstptr + bit1_offset / 4); + auto ret = kernel::wrapper::CompressBit3::forward(B, bit2ptr, bit1ptr, bit1_offset); assert(ret == BTLA_CODE::Success); } @@ -655,21 +648,6 @@ class WeightKBlockNInteger { return BTLA_CODE::NotSupport; } - static inline BTLA_CODE getKBlockWeight(float** dstptr, int* dststep, int k_size, int n_size, int k_offset, - int n_offset, const Param& _param, void* tmpcache, size_t cachesize) { - return getFpKBlockWeight(dstptr, dststep, k_size, n_size, k_offset, n_offset, _param, tmpcache, cachesize); - } - - static inline BTLA_CODE getKBlockWeight(utils::bf16** dstptr, int* dststep, int k_size, int n_size, int k_offset, - int n_offset, const Param& _param, void* tmpcache, size_t cachesize) { - return getFpKBlockWeight(dstptr, dststep, k_size, n_size, k_offset, n_offset, _param, tmpcache, cachesize); - } - - static inline BTLA_CODE getKBlockWeight(int8_t** dstptr, int* dststep, int k_size, int n_size, int k_offset, - int n_offset, const Param& _param, void* tmpcache, size_t cachesize) { - return getWeight(dstptr, dststep, k_size, n_size, k_offset, n_offset, _param, tmpcache, cachesize); - } - static inline BTLA_CODE getScale(float** dstptr, int* dststep, int k_size, int n_size, int k_offset, int n_offset, const Param& _param, void* tmpcache, size_t cachesize) { auto wptr = _param.packedW; @@ -720,160 +698,67 @@ class WeightKBlockNInteger { } protected: - template - static inline BTLA_CODE getFpKBlockWeight(T** dstptr, int* dststep, int k_size, int n_size, int k_offset, - int n_offset, const Param& _param, void* tmpcache, size_t cachesize) { + template + static inline BTLA_CODE getFpWeight(_T** dstptr, int* dststep, int k_size, int n_size, int k_offset, int n_offset, + const Param& _param, void* tmpcache, size_t cachesize) { auto wptr = _param.packedW; auto NPad = wptr->mNPad; auto KPad = wptr->mKPad; int constexpr ColSize = _GemmCore_T::NTILE * _GemmCore_T::PACK_ROW; for (int i = 0; i < n_size; i += _GemmCore_T::NTILE) { + auto zptr = wptr->template ZPtr(); if (wptr->mDType == BTLA_DTYPE::S4_CLIP) { - kernel::wrapper::DecompressKBlockS4S8Fp::template forward( - 
wptr->template WPtr() + n_offset * KPad / 2 + k_offset * _GemmCore_T::NTILE / 2 + - i * KPad / 2, - *dstptr + i * k_size, k_size / _GemmCore_T::PACK_ROW, ColSize, ColSize, ColSize, tmpcache, cachesize); - } else if (wptr->mDType == BTLA_DTYPE::S8) { - kernel::wrapper::DecompressKBlockS8S8Fp::template forward( - wptr->template WPtr() + n_offset * KPad + k_offset * _GemmCore_T::NTILE + i * KPad, - *dstptr + i * k_size, k_size / _GemmCore_T::PACK_ROW, ColSize, ColSize, ColSize, tmpcache, cachesize); + if (wptr->SDtype() == BTLA_DTYPE::DQ8_BNB) { + auto internal_n_offset = n_offset + i; + if (wptr->mDType == BTLA_DTYPE::S4_CLIP) { + kernel::wrapper::DecompressDQKBlockS4Fp<_T, _GemmCore_T::PACK_ROW>::template forward( + wptr->template WPtr() + n_offset * KPad / 2 + k_offset * _GemmCore_T::NTILE / 2 + + i * KPad / 2, + *dstptr + i * k_size, k_size / _GemmCore_T::PACK_ROW, ColSize, ColSize, ColSize, + wptr->template SPtr(), wptr->template DQPtr(), k_offset / _GemmCore_T::PACK_ROW, + internal_n_offset, wptr->mBlockSize / _GemmCore_T::PACK_ROW, NPad, wptr->mN, wptr->mDqBlockSize, + wptr->mCorrection.mDQCorrectionBuf.mBufSize / sizeof(float) - 1, tmpcache, cachesize); + } + } else { + auto sptr = wptr->template SPtr(); + kernel::wrapper::DecompressKBlockS4Fp<_GemmCore_T::PACK_ROW, _GemmCore_T::NTILE, _T>::template forward( + wptr->template WPtr() + n_offset * KPad / 2 + k_offset * _GemmCore_T::NTILE / 2 + + i * KPad / 2, + *dstptr + i * k_size, k_size, _GemmCore_T::NTILE, sptr, wptr->SDtype(), zptr, k_offset, n_offset + i, + wptr->mBlockSize, NPad, tmpcache, cachesize); + } + } else if (wptr->mDType == BTLA_DTYPE::S3_CLIP) { + auto sptr = wptr->template SPtr(); int8_t* bit3_ptr = wptr->template WPtr(); - auto elt_offset = - n_offset * utils::padto(KPad, 128) + k_offset * _GemmCore_T::NTILE + i * utils::padto(KPad, 128); - auto ld_dst = _GemmCore_T::NTILE * utils::padto(KPad, 128); + auto elt_offset = n_offset * KPad + k_offset * _GemmCore_T::NTILE + i * KPad; + auto ld_dst = _GemmCore_T::NTILE * KPad; auto row = NPad / _GemmCore_T::NTILE; assert(elt_offset % 8 == 0); auto bit2ptr = reinterpret_cast(bit3_ptr + elt_offset / 4); auto bit1ptr = reinterpret_cast(bit3_ptr + row * ld_dst / 4 + elt_offset / 8); - kernel::wrapper::DecompressKBlockS3S8Fp::template forward( - bit2ptr, bit1ptr, *dstptr + i * k_size, k_offset * _GemmCore_T::NTILE, - k_size / _GemmCore_T::PACK_ROW * ColSize, tmpcache, cachesize); + kernel::wrapper::DecompressKBlockS3Fp<_GemmCore_T::PACK_ROW, _GemmCore_T::NTILE, _T>::template forward( + bit2ptr, bit1ptr, *dstptr + i * k_size, k_size, _GemmCore_T::NTILE, sptr, wptr->SDtype(), zptr, k_offset, + n_offset + i, wptr->mBlockSize, NPad, tmpcache, cachesize); } else if (wptr->mDType == BTLA_DTYPE::S2_CLIP) { + auto sptr = wptr->template SPtr(); int8_t* bit2_ptr = wptr->template WPtr(); auto elt_offset = n_offset * KPad + k_offset * _GemmCore_T::NTILE + i * KPad; - assert(elt_offset % 4 == 0); auto bit2ptr = reinterpret_cast(bit2_ptr + elt_offset / 4); - kernel::wrapper::DecompressKBlockS2S8Fp::template forward( - bit2ptr, *dstptr + i * k_size, k_size / _GemmCore_T::PACK_ROW * ColSize, tmpcache, cachesize); - } else { - assert(0); + kernel::wrapper::DecompressKBlockS2Fp<_GemmCore_T::PACK_ROW, _GemmCore_T::NTILE, _T>::template forward( + bit2ptr, *dstptr + i * k_size, k_size, _GemmCore_T::NTILE, sptr, wptr->SDtype(), zptr, k_offset, + n_offset + i, wptr->mBlockSize, NPad, tmpcache, cachesize); + } else if (wptr->mDType == BTLA_DTYPE::S8) { + auto sptr = wptr->template SPtr(); + auto 
elt_offset = n_offset * KPad + k_offset * _GemmCore_T::NTILE + i * KPad; + int8_t* bptr = wptr->template WPtr() + elt_offset; + kernel::wrapper::DecompressKBlockS8Fp<_GemmCore_T::PACK_ROW, _GemmCore_T::NTILE, _T>::template forward( + bptr, *dstptr + i * k_size, k_size, _GemmCore_T::NTILE, sptr, wptr->SDtype(), zptr, k_offset, n_offset + i, + wptr->mBlockSize, NPad, tmpcache, cachesize); } - } - *dststep = k_size; - return BTLA_CODE::Success; - } - template - static inline BTLA_CODE getFpWeight(_T** dstptr, int* dststep, int k_size, int n_size, int k_offset, int n_offset, - const Param& _param, void* tmpcache, size_t cachesize) { - auto wptr = _param.packedW; - auto NPad = wptr->mNPad; - auto KPad = wptr->mKPad; - int constexpr ColSize = _GemmCore_T::NTILE * _GemmCore_T::PACK_ROW; - for (int i = 0; i < n_size; i += _GemmCore_T::NTILE) { - auto zptr = wptr->template ZPtr(); - if (wptr->SDtype() == BTLA_DTYPE::F32) { - auto sptr = wptr->template SPtr() + n_offset + i; - if (wptr->mDType == BTLA_DTYPE::S4_CLIP) { - kernel::wrapper::DecompressKBlockS4Fp<_T, _GemmCore_T::PACK_ROW>::template forward( - wptr->template WPtr() + n_offset * KPad / 2 + k_offset * _GemmCore_T::NTILE / 2 + - i * KPad / 2, - *dstptr + i * k_size, k_size / _GemmCore_T::PACK_ROW, ColSize, ColSize, ColSize, sptr, - zptr != nullptr ? zptr + n_offset + i : nullptr, k_offset / _GemmCore_T::PACK_ROW, - wptr->mBlockSize / _GemmCore_T::PACK_ROW, NPad, tmpcache, cachesize); - } else if (wptr->mDType == BTLA_DTYPE::S8) { - kernel::wrapper::DecompressKBlockS8Fp<_T, _GemmCore_T::PACK_ROW>::template forward( - wptr->template WPtr() + n_offset * KPad + k_offset * _GemmCore_T::NTILE + i * KPad, - *dstptr + i * k_size, k_size / _GemmCore_T::PACK_ROW, ColSize, ColSize, ColSize, sptr, - zptr != nullptr ? zptr + n_offset + i : nullptr, k_offset / _GemmCore_T::PACK_ROW, - wptr->mBlockSize / _GemmCore_T::PACK_ROW, NPad, tmpcache, cachesize); - } else if (wptr->mDType == BTLA_DTYPE::S3_CLIP) { - int8_t* bit3_ptr = wptr->template WPtr(); - auto elt_offset = - n_offset * utils::padto(KPad, 128) + k_offset * _GemmCore_T::NTILE + i * utils::padto(KPad, 128); - auto ld_dst = _GemmCore_T::NTILE * utils::padto(KPad, 128); - auto row = NPad / _GemmCore_T::NTILE; - assert(elt_offset % 8 == 0); - auto bit2ptr = reinterpret_cast(bit3_ptr + elt_offset / 4); - auto bit1ptr = reinterpret_cast(bit3_ptr + row * ld_dst / 4 + elt_offset / 8); - kernel::wrapper::DecompressKBlockS3Fp<_T, _GemmCore_T::PACK_ROW>::template forward( - bit2ptr, bit1ptr, *dstptr + i * k_size, k_offset * _GemmCore_T::NTILE, k_size / _GemmCore_T::PACK_ROW, - ColSize, sptr, zptr != nullptr ? zptr + n_offset + i : nullptr, k_offset / _GemmCore_T::PACK_ROW, - wptr->mBlockSize / _GemmCore_T::PACK_ROW, NPad, tmpcache, cachesize); - } else if (wptr->mDType == BTLA_DTYPE::S2_CLIP) { - int8_t* bit2_ptr = wptr->template WPtr(); - auto elt_offset = n_offset * KPad + k_offset * _GemmCore_T::NTILE + i * KPad; - assert(elt_offset % 4 == 0); - auto bit2ptr = reinterpret_cast(bit2_ptr + elt_offset / 4); - kernel::wrapper::DecompressKBlockS2Fp<_T, _GemmCore_T::PACK_ROW>::template forward( - bit2ptr, *dstptr + i * k_size, k_size / _GemmCore_T::PACK_ROW, ColSize, sptr, - zptr != nullptr ? 
zptr + n_offset + i : nullptr, k_offset / _GemmCore_T::PACK_ROW, - wptr->mBlockSize / _GemmCore_T::PACK_ROW, NPad, tmpcache, cachesize); - } else { - assert(0); - } - } else if (wptr->SDtype() == BTLA_DTYPE::BF16) { - auto sptr = wptr->template SPtr() + n_offset + i; - if (wptr->mDType == BTLA_DTYPE::S4_CLIP) { - kernel::wrapper::DecompressKBlockS4Fp<_T, _GemmCore_T::PACK_ROW>::template forward( - wptr->template WPtr() + n_offset * KPad / 2 + k_offset * _GemmCore_T::NTILE / 2 + - i * KPad / 2, - *dstptr + i * k_size, k_size / _GemmCore_T::PACK_ROW, ColSize, ColSize, ColSize, sptr, - zptr != nullptr ? zptr + n_offset + i : nullptr, k_offset / _GemmCore_T::PACK_ROW, - wptr->mBlockSize / _GemmCore_T::PACK_ROW, NPad, tmpcache, cachesize); - } else if (wptr->mDType == BTLA_DTYPE::S8) { - kernel::wrapper::DecompressKBlockS8Fp<_T, _GemmCore_T::PACK_ROW>::template forward( - wptr->template WPtr() + n_offset * KPad + k_offset * _GemmCore_T::NTILE + i * KPad, - *dstptr + i * k_size, k_size / _GemmCore_T::PACK_ROW, ColSize, ColSize, ColSize, sptr, - zptr != nullptr ? zptr + n_offset + i : nullptr, k_offset / _GemmCore_T::PACK_ROW, - wptr->mBlockSize / _GemmCore_T::PACK_ROW, NPad, tmpcache, cachesize); - } else if (wptr->mDType == BTLA_DTYPE::S3_CLIP) { - int8_t* bit3_ptr = wptr->template WPtr(); - auto elt_offset = - n_offset * utils::padto(KPad, 128) + k_offset * _GemmCore_T::NTILE + i * utils::padto(KPad, 128); - auto ld_dst = _GemmCore_T::NTILE * utils::padto(KPad, 128); - auto row = NPad / _GemmCore_T::NTILE; - assert(elt_offset % 8 == 0); - auto bit2ptr = reinterpret_cast(bit3_ptr + elt_offset / 4); - auto bit1ptr = reinterpret_cast(bit3_ptr + row * ld_dst / 4 + elt_offset / 8); - kernel::wrapper::DecompressKBlockS3Fp<_T, _GemmCore_T::PACK_ROW>::template forward( - bit2ptr, bit1ptr, *dstptr + i * k_size, k_offset * _GemmCore_T::NTILE, k_size / _GemmCore_T::PACK_ROW, - ColSize, sptr, zptr != nullptr ? zptr + n_offset + i : nullptr, k_offset / _GemmCore_T::PACK_ROW, - wptr->mBlockSize / _GemmCore_T::PACK_ROW, NPad, tmpcache, cachesize); - } else if (wptr->mDType == BTLA_DTYPE::S2_CLIP) { - int8_t* bit2_ptr = wptr->template WPtr(); - auto elt_offset = n_offset * KPad + k_offset * _GemmCore_T::NTILE + i * KPad; - assert(elt_offset % 4 == 0); - auto bit2ptr = reinterpret_cast(bit2_ptr + elt_offset / 4); - kernel::wrapper::DecompressKBlockS2Fp<_T, _GemmCore_T::PACK_ROW>::template forward( - bit2ptr, *dstptr + i * k_size, k_size / _GemmCore_T::PACK_ROW, ColSize, sptr, - zptr != nullptr ? 
zptr + n_offset + i : nullptr, k_offset / _GemmCore_T::PACK_ROW, - wptr->mBlockSize / _GemmCore_T::PACK_ROW, NPad, tmpcache, cachesize); - } else { - assert(0); - } - } else if (wptr->SDtype() == BTLA_DTYPE::DQ8_BNB) { - auto internal_n_offset = n_offset + i; - if (wptr->mDType == BTLA_DTYPE::S4_CLIP) { - kernel::wrapper::DecompressDQKBlockS4Fp<_T, _GemmCore_T::PACK_ROW>::template forward( - wptr->template WPtr() + n_offset * KPad / 2 + k_offset * _GemmCore_T::NTILE / 2 + - i * KPad / 2, - *dstptr + i * k_size, k_size / _GemmCore_T::PACK_ROW, ColSize, ColSize, ColSize, - wptr->template SPtr(), wptr->template DQPtr(), k_offset / _GemmCore_T::PACK_ROW, - internal_n_offset, wptr->mBlockSize / _GemmCore_T::PACK_ROW, NPad, wptr->mN, wptr->mDqBlockSize, - wptr->mCorrection.mDQCorrectionBuf.mBufSize / sizeof(float) - 1, tmpcache, cachesize); - } else { - assert(0); - } - } else { + else { assert(0); } } @@ -898,11 +783,15 @@ class WeightKBlockNInteger { auto wptr = _param.packedW; auto KPad = wptr->mKPad; auto bptr = wptr->template WPtr() + n_offset * KPad / 2 + k_offset * _GemmCore_T::NTILE / 2; + auto zpptr = wptr->template ZPtr(); int constexpr ColSize = _GemmCore_T::NTILE * _GemmCore_T::PACK_ROW; + assert(wptr->mDType == BTLA_DTYPE::S4_CLIP); + for (int i = 0; i < n_size; i += _GemmCore_T::NTILE) { - kernel::wrapper::DecompressKBlockS4S8::template forward( - bptr + i * KPad / 2, *dstptr + i * k_size, k_size / _GemmCore_T::PACK_ROW, ColSize, ColSize, ColSize); + kernel::wrapper::DecompressKBlockS4S8<_GemmCore_T::PACK_ROW, _GemmCore_T::NTILE>::template forward( + bptr + i * KPad / 2, wptr->IsAsym() ? zpptr : nullptr, *dstptr + i * k_size, wptr->mBlockSize, wptr->CStep(), + n_offset + i, k_offset, k_size, _GemmCore_T::NTILE, tmpcache, cachesize); } *dststep = k_size; return BTLA_CODE::Success; @@ -912,20 +801,21 @@ class WeightKBlockNInteger { const Param& _param, void* tmpcache, size_t cachesize) { auto wptr = _param.packedW; int8_t* bit3_ptr = wptr->template WPtr(); + auto zpptr = wptr->template ZPtr(); auto KPad = wptr->mKPad; auto NPad = wptr->mNPad; int constexpr ColSize = _GemmCore_T::NTILE * _GemmCore_T::PACK_ROW; auto row = NPad / _GemmCore_T::NTILE; - auto ld_dst = _GemmCore_T::NTILE * utils::padto(KPad, 128); - auto base_offset = n_offset * utils::padto(KPad, 128) + k_offset * _GemmCore_T::NTILE; + auto ld_dst = _GemmCore_T::NTILE * KPad; + auto base_offset = n_offset * KPad + k_offset * _GemmCore_T::NTILE; for (int i = 0; i < n_size; i += _GemmCore_T::NTILE) { - auto elt_offset = base_offset + i * utils::padto(KPad, 128); + auto elt_offset = base_offset + i * KPad; assert(elt_offset % 8 == 0); auto bit2ptr = reinterpret_cast(bit3_ptr + elt_offset / 4); auto bit1ptr = reinterpret_cast(bit3_ptr + row * ld_dst / 4 + elt_offset / 8); - kernel::wrapper::DecompressKBlockS3S8Fp::template forward( - bit2ptr, bit1ptr, *dstptr + i * k_size, k_offset * _GemmCore_T::NTILE, - k_size / _GemmCore_T::PACK_ROW * ColSize, reinterpret_cast(tmpcache), cachesize); + kernel::wrapper::DecompressKBlockS3S8<_GemmCore_T::PACK_ROW, _GemmCore_T::NTILE>::template forward( + bit2ptr, bit1ptr, wptr->IsAsym() ? 
zpptr : nullptr, *dstptr + i * k_size, wptr->mBlockSize, wptr->CStep(), + n_offset + i, k_offset, k_size, _GemmCore_T::NTILE, tmpcache, cachesize); } *dststep = k_size; return BTLA_CODE::Success; @@ -935,17 +825,17 @@ class WeightKBlockNInteger { const Param& _param, void* tmpcache, size_t cachesize) { auto wptr = _param.packedW; int8_t* bit2_ptr = wptr->template WPtr(); + int8_t* zpptr = wptr->template ZPtr(); auto KPad = wptr->mKPad; auto NPad = wptr->mNPad; - int constexpr ColSize = _GemmCore_T::NTILE * _GemmCore_T::PACK_ROW; auto base_offset = n_offset * KPad + k_offset * _GemmCore_T::NTILE; for (int i = 0; i < n_size; i += _GemmCore_T::NTILE) { auto elt_offset = base_offset + i * KPad; assert(elt_offset % 4 == 0); auto bit2ptr = reinterpret_cast(bit2_ptr + elt_offset / 4); - kernel::wrapper::DecompressKBlockS2S8Fp::template forward( - bit2ptr, *dstptr + i * k_size, k_size / _GemmCore_T::PACK_ROW * ColSize, reinterpret_cast(tmpcache), - cachesize); + kernel::wrapper::DecompressKBlockS2S8<_GemmCore_T::PACK_ROW, _GemmCore_T::NTILE>::template forward( + bit2ptr, wptr->IsAsym() ? zpptr : nullptr, *dstptr + i * k_size, wptr->mBlockSize, wptr->CStep(), + n_offset + i, k_offset, k_size, _GemmCore_T::NTILE, tmpcache, cachesize); } *dststep = k_size; return BTLA_CODE::Success; diff --git a/bestla/bestla/bestla_storage.h b/bestla/bestla/bestla_storage.h index 7b13adbe9..7f00f9aa0 100644 --- a/bestla/bestla/bestla_storage.h +++ b/bestla/bestla/bestla_storage.h @@ -623,11 +623,17 @@ class StorageQuantActivation : public IActivationKBlockBase { mSize = utils::padto(mSize, Alignment); return mSize; } + template inline constexpr QT_T* APtr() { return mQBuf.get(); } + template + inline constexpr size_t ASize() { + return mQBuf.size(); + } + template inline constexpr QT_T* ZPtr() { return mCorrection.mZpBuf.get(); diff --git a/bestla/bestla/bestla_utils.h b/bestla/bestla/bestla_utils.h index 17e24b75e..57c840ebd 100644 --- a/bestla/bestla/bestla_utils.h +++ b/bestla/bestla/bestla_utils.h @@ -62,7 +62,9 @@ // Only the ISA you use in your project will be compiled. 
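// Minimal sketch (illustrative, not part of the patch) of how the per-compiler gates
// introduced in this hunk (CompileAVXVNNI(), CompileAVX512VNNI(), ...) are meant to be
// consumed, mirroring the existing CompileAVX512F()/CompileAVX2() usage in the kernel
// headers: query the gate at preprocessing time and keep only the paths the toolchain
// can emit. has_avx_vnni_codegen() is a hypothetical helper; bestla_utils.h, which
// defines the gates, is assumed to be included first.
constexpr bool has_avx_vnni_codegen() {
#if CompileAVXVNNI()
  return true;   // compiler is new enough to build the AVX-VNNI kernels
#else
  return false;  // this build keeps only the scalar / AVX2 paths
#endif
}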
#ifdef __GNUC__ #define CompileAVX512F() (__GNUC__ >= 6) +#define CompileAVX512VNNI() (__GNUC__ >= 9) #define CompileAVX2() (__GNUC__ >= 5) +#define CompileAVXVNNI() (__GNUC__ >= 11) #define CompileAMX() (__GNUC__ >= 11) #define CompileBF16() (__GNUC__ >= 11) #define CompileFP16() (__GNUC__ >= 13) @@ -72,20 +74,24 @@ #if defined(_MSC_VER) && !defined(__INTEL_LLVM_COMPILER) #define CompileAVX512F() _MSC_VER && (_MSC_VER >= 1911) +#define CompileAVX512VNNI() _MSC_VER && (_MSC_VER >= 1930) // TODO(Yu) check the minimum version #define CompileAVX2() _MSC_VER && (_MSC_VER >= 1900) -#define CompileAMX() 0 -#define CompileBF16() 0 -#define CompileFP16() 0 -#define CompileAMXBF16() 0 -#define CompileAMXINT8() 0 +#define CompileAVXVNNI() _MSC_VER && (_MSC_VER >= 1930) // TODO(Yu) check the minimum version +#define CompileAMX() _MSC_VER && (_MSC_VER >= 1930) // TODO(Yu) check the minimum version +#define CompileBF16() _MSC_VER && (_MSC_VER >= 1938) // TODO(Yu) check the minimum version +#define CompileFP16() _MSC_VER && (_MSC_VER >= 1938) // TODO(Yu) check the minimum version +#define CompileAMXBF16() (CompileAMX()) +#define CompileAMXINT8() (CompileAMX()) #endif #if defined(_MSC_VER) && defined(__INTEL_LLVM_COMPILER) #define CompileAVX512F() defined(__AVX512F__) +#define CompileAVX512VNNI() defined(__AVX512VNNI__) #define CompileAVX2() defined(__AVX2__) && defined(__F16C__) && defined(__FMA__) -#define CompileAMX() 0 -#define CompileBF16() 0 -#define CompileFP16() 0 +#define CompileAVXVNNI() defined(__AVXVNNI__) +#define CompileAMX() defined(__AMX_TILE__) +#define CompileBF16() defined(__AVX512BF16__) +#define CompileFP16() defined(__AVX512FP16__) #define CompileAMXBF16() (CompileAMX()) #define CompileAMXINT8() (CompileAMX()) #endif @@ -224,26 +230,26 @@ struct fp16 { }; struct bit2x4 { - int8_t a : 2; - int8_t b : 2; - int8_t c : 2; - int8_t d : 2; + uint8_t a : 2; + uint8_t b : 2; + uint8_t c : 2; + uint8_t d : 2; }; struct bit1x8 { - int8_t a : 1; - int8_t b : 1; - int8_t c : 1; - int8_t d : 1; - int8_t e : 1; - int8_t f : 1; - int8_t g : 1; - int8_t h : 1; + uint8_t a : 1; + uint8_t b : 1; + uint8_t c : 1; + uint8_t d : 1; + uint8_t e : 1; + uint8_t f : 1; + uint8_t g : 1; + uint8_t h : 1; }; struct bit4x2 { - int8_t x : 4; - int8_t y : 4; + uint8_t x : 4; + uint8_t y : 4; bit4x2(int8_t v) : x(v), y(v) {} bit4x2() : x(0), y(0) {} }; @@ -253,8 +259,6 @@ struct int4x2 : bit4x2 { int4x2() : bit4x2() {} static int8_t convert(int8_t src) { int32_t dst = src; - dst = dst >= 0 ? dst + 8 : dst - 8; - dst = dst / 16; dst = dst > 7 ? 7 : dst; dst = dst < -8 ? -8 : dst; return static_cast(dst); @@ -292,6 +296,24 @@ struct GemmProblem { } }; +template +struct GemvParamB { + uint8_t *b4ptr = 0, *b2ptr = 0, *b1ptr = 0; + ScaleT* sptr = 0; + int8_t* zpptr = 0; + int nbits = 0; + int ldzp = 0; + int kpad = 0; +}; + +struct GemvParamA { + uint8_t* aptr = 0; + float* sptr = 0; + uint8_t* zpptr = 0; + int lda = 0; + int ldzp = 0; +}; + template inline constexpr BTLA_DTYPE bestla_dtype = std::is_same_v ? BTLA_DTYPE::F64 : std::is_same_v ? 
BTLA_DTYPE::F32 @@ -682,9 +704,11 @@ class timer_statistics_logger { float min_val, max_val, avg_val; void record() { - min_val = statis.min_val / log_ratio; - max_val = statis.max_val / log_ratio; - avg_val = statis.avg_val / log_ratio; + if (statis.count) { + min_val = statis.min_val / log_ratio; + max_val = statis.max_val / log_ratio; + avg_val = statis.avg_val / log_ratio; + } } private: diff --git a/bestla/bestla/bestla_wrapper.h b/bestla/bestla/bestla_wrapper.h index d300aa03c..548e3d8e0 100644 --- a/bestla/bestla/bestla_wrapper.h +++ b/bestla/bestla/bestla_wrapper.h @@ -22,6 +22,81 @@ namespace bestla { namespace wrapper { +namespace gemv_nbits { +class S4 { + public: + static int constexpr NBits = 4; + template + static inline utils::GemvParamB createB(storage::gemm::StorageWeightKBlockNInteger* packedW) { + auto isasym = packedW->IsAsym(); + auto bzptr = packedW->template ZPtr(); + int ld_scaleb = packedW->CStep(); + utils::GemvParamB paramB{ + packedW->template WPtr(), nullptr, nullptr, packedW->template SPtr(), + isasym ? bzptr : nullptr, NBits, ld_scaleb, packedW->mKPad}; + return paramB; + } + template + static void updateBNStep(utils::GemvParamB& paramB, int n_offset) { + paramB.b4ptr += n_offset * paramB.kpad / 2; + paramB.sptr += n_offset; + if (paramB.zpptr) { + paramB.zpptr += n_offset; + } + } +}; + +class S3 { + public: + static int constexpr NBits = 3; + template + static inline utils::GemvParamB createB(storage::gemm::StorageWeightKBlockNInteger* packedW) { + auto isasym = packedW->IsAsym(); + auto bzptr = packedW->template ZPtr(); + int ld_scaleb = packedW->CStep(); + auto bwptr = packedW->template WPtr(); + auto bit1_offset = packedW->mNPad * packedW->mKPad / 4; + utils::GemvParamB paramB{ + nullptr, bwptr, bwptr + bit1_offset, packedW->template SPtr(), isasym ? bzptr : nullptr, + NBits, ld_scaleb, packedW->mKPad}; + return paramB; + } + template + static void updateBNStep(utils::GemvParamB& paramB, int n_offset) { + paramB.b2ptr += n_offset * paramB.kpad / 4; + paramB.b1ptr += n_offset * paramB.kpad / 8; + paramB.sptr += n_offset; + if (paramB.zpptr) { + paramB.zpptr += n_offset; + } + } +}; + +class S2 { + public: + static int constexpr NBits = 2; + template + static inline utils::GemvParamB createB(storage::gemm::StorageWeightKBlockNInteger* packedW) { + auto isasym = packedW->IsAsym(); + auto bzptr = packedW->template ZPtr(); + int ld_scaleb = packedW->CStep(); + auto bwptr = packedW->template WPtr(); + utils::GemvParamB paramB{ + nullptr, bwptr, nullptr, packedW->template SPtr(), isasym ? 
bzptr : nullptr, + NBits, ld_scaleb, packedW->mKPad}; + return paramB; + } + template + static void updateBNStep(utils::GemvParamB& paramB, int n_offset) { + paramB.b2ptr += n_offset * paramB.kpad / 4; + paramB.sptr += n_offset; + if (paramB.zpptr) { + paramB.zpptr += n_offset; + } + } +}; +} // namespace gemv_nbits + namespace gemm { template class _PrologueA_T, template class _PrologueB_T, template class _Epilogue_T> @@ -50,7 +125,163 @@ class LauncherBase { PrologueB mProB; Epilogue mEpilogue; + class GEMVWrapper { + public: + static constexpr bool support() { + if constexpr (!std::is_same_v>) { + return false; + } + if constexpr (!std::is_same_v> && + !std::is_same_v> && + !std::is_same_v>) { + return false; + } + + if constexpr (GemmCore::ISA == BTLA_ISA::AVX2) { +#if CompileAVX2() + static_assert(GemmCore::PACK_ROW == 1); + if constexpr (GemmCore::COMP == bestla::gemm::CompType::COMP_FP32) { + return true; + } +#endif + } + if constexpr (GemmCore::ISA == BTLA_ISA::AVX512F) { +#if CompileAVX512F() + static_assert(GemmCore::PACK_ROW == 1); + if constexpr (GemmCore::COMP == bestla::gemm::CompType::COMP_FP32) { + return true; + } +#endif + } + return false; + } + static int constexpr MaxGemvM = 4; + static bool implemented(const Param& _param) { + bool impl = true; + impl &= _param.paramB.packedW->mDType == BTLA_DTYPE::S4_CLIP || + _param.paramB.packedW->mDType == BTLA_DTYPE::S3_CLIP || + _param.paramB.packedW->mDType == BTLA_DTYPE::S2_CLIP; + if constexpr (support()) { + impl &= _param.paramB.packedW->mCorrection.mScaT == BTLA_DTYPE::F32 || + _param.paramB.packedW->mCorrection.mScaT == BTLA_DTYPE::BF16; + } + + impl &= _param.problem.dims[1] <= MaxGemvM; + return impl; + } + + template + static void gemv_kblock(const Param& _param, const parallel::gemm::ThreadProblemBase& _config) { + if constexpr (support()) { + auto constexpr TmpSize = 3 * 1024LL; + auto constexpr CSize = 1 * 1024LL; + auto StackTmp_ = alloca(TmpSize + CSize); + auto StackTmp = utils::cpu_pointer_align(StackTmp_); + auto tmpc_ptr = reinterpret_cast((char*)StackTmp + TmpSize); + utils::GemvParamB paramB = SNbits::template createB(_param.paramB.packedW); + const float* Aptr = _param.paramA.A; + if constexpr (std::is_same_v>) { + if (_param.paramA.reordered && _param.paramA.reordered->template APtr()) { + Aptr = _param.paramA.reordered->template APtr(); + } + } + int m = _param.problem.dims[1]; + int n = _param.problem.dims[2]; + int k = _param.problem.dims[3]; + int kblocksize = _param.problem.dims[4]; + auto Cptr = _param.paramC.C + _config.loc[1]; + SNbits::template updateBNStep(paramB, _config.loc[1]); + int size_padded = utils::padto_le(_config.size[1], GemmCore::NTILE); + int in = 0; + for (; in < size_padded; in += GemmCore::NTILE) { + if constexpr (std::is_same_v) { + kernel::wrapper::GEMVWoqNBits::forward_fp32_fp32<_RT_ISA_T, ScaleT, GemmCore::NTILE, MTILE>( + Aptr, _param.paramA.lda, paramB, Cptr, _param.paramC.ldc, k, kblocksize, StackTmp, TmpSize); + } + + Cptr += GemmCore::NTILE; + SNbits::template updateBNStep(paramB, GemmCore::NTILE); + } + if (size_padded != _config.size[1]) { + if constexpr (std::is_same_v) { + kernel::wrapper::GEMVWoqNBits::forward_fp32_fp32<_RT_ISA_T, ScaleT, GemmCore::NTILE, MTILE>( + Aptr, _param.paramA.lda, paramB, tmpc_ptr, GemmCore::NTILE, k, kblocksize, StackTmp, TmpSize); + } + for (int i = 0; i < MTILE; i++) { + memcpy(Cptr + i * _param.paramC.ldc, tmpc_ptr + i * GemmCore::NTILE, + (_config.size[1] - in) * sizeof(CType)); + } + } + Epilogue::forward(_param.paramC.C + 
_config.loc[1], _param.paramC.ldc, 0, _config.loc[1], MTILE, + _config.size[1], _param.paramC, StackTmp, TmpSize); + } + } + + static void gemv(const Param& _param, const parallel::gemm::ThreadProblemBase& _config) { + if constexpr (support()) { + assert(_param.problem.dims[4] > 0); + auto& m = _param.problem.dims[1]; + if (_param.paramB.packedW->mDType == BTLA_DTYPE::S4_CLIP) { + if (_param.paramB.packedW->SDtype() == BTLA_DTYPE::F32) { + if (m == 1) gemv_kblock(_param, _config); + if (m == 2) gemv_kblock(_param, _config); + if (m == 3) gemv_kblock(_param, _config); + if (m == 4) gemv_kblock(_param, _config); + + } else if (_param.paramB.packedW->SDtype() == BTLA_DTYPE::BF16) { + if (m == 1) gemv_kblock(_param, _config); + if (m == 2) gemv_kblock(_param, _config); + if (m == 3) gemv_kblock(_param, _config); + if (m == 4) gemv_kblock(_param, _config); + } + return; + } + if (_param.paramB.packedW->mDType == BTLA_DTYPE::S3_CLIP) { + if (_param.paramB.packedW->SDtype() == BTLA_DTYPE::F32) { + if (m == 1) gemv_kblock(_param, _config); + if (m == 2) gemv_kblock(_param, _config); + if (m == 3) gemv_kblock(_param, _config); + if (m == 4) gemv_kblock(_param, _config); + + } else if (_param.paramB.packedW->SDtype() == BTLA_DTYPE::BF16) { + if (m == 1) gemv_kblock(_param, _config); + if (m == 2) gemv_kblock(_param, _config); + if (m == 3) gemv_kblock(_param, _config); + if (m == 4) gemv_kblock(_param, _config); + } + return; + } + if (_param.paramB.packedW->mDType == BTLA_DTYPE::S2_CLIP) { + if (_param.paramB.packedW->SDtype() == BTLA_DTYPE::F32) { + if (m == 1) gemv_kblock(_param, _config); + if (m == 2) gemv_kblock(_param, _config); + if (m == 3) gemv_kblock(_param, _config); + if (m == 4) gemv_kblock(_param, _config); + + } else if (_param.paramB.packedW->SDtype() == BTLA_DTYPE::BF16) { + if (m == 1) gemv_kblock(_param, _config); + if (m == 2) gemv_kblock(_param, _config); + if (m == 3) gemv_kblock(_param, _config); + if (m == 4) gemv_kblock(_param, _config); + } + return; + } + } + } + }; + void run(const Param& _param, const parallel::gemm::ThreadProblemBase& _config) { + if (GEMVWrapper::support() && GEMVWrapper::implemented(_param)) { + GEMVWrapper::gemv(_param, _config); + } else { + gemm(_param, _config); + } + } + + protected: + void gemm(const Param& _param, const parallel::gemm::ThreadProblemBase& _config) { mGemmCore.configure(_config.size[0], _config.size[1], _param.problem.dims[3]); auto StackTmp = alloca(_config.stacksize); auto tmpB = reinterpret_cast(StackTmp); @@ -70,7 +301,6 @@ class LauncherBase { } } - protected: void run_block(const Param& _param, const parallel::gemm::ThreadProblemBase& _config, int blk_m, int blk_n, int blk_msize, int blk_nsize, AType* tmpA, BType* tmpB, CType* tmpC, void* tmpcache) { int n_padded = utils::padto(blk_nsize, GemmCore::NTILE); @@ -115,22 +345,19 @@ class LauncherBase { }; template class _PrologueA_T, - template class _PrologueB_T, template class _BlockEpilogue_T, - template class _Epilogue_T> -class LauncherKBlock { + template class _PrologueB_T, template class _Epilogue_T> +class LauncherIntKBlock { public: using GemmCore = _GemmCore_T; static constexpr BTLA_ISA ISA = _RT_ISA_T; using PrologueA = _PrologueA_T; using PrologueB = _PrologueB_T; using Epilogue = _Epilogue_T<_RT_ISA_T>; - using BlockEpilogue = _BlockEpilogue_T<_RT_ISA_T>; using AType = typename GemmCore::AType; using AParam = typename PrologueA::Param; using BType = typename GemmCore::BType; using BParam = typename PrologueB::Param; using CType = typename GemmCore::CType; - using 
BEpiParam = typename BlockEpilogue::Param; using EpiParam = typename Epilogue::Param; using AccType = float; static_assert(GemmCore::ISA <= _RT_ISA_T, "RunTime ISA should cover GEMM's ISA"); @@ -138,178 +365,184 @@ class LauncherKBlock { const utils::GemmProblem problem; const AParam paramA; const BParam paramB; - const BEpiParam paramBlk; const EpiParam paramC; }; _GemmCore_T mGemmCore; PrologueA mProA; PrologueB mProB; - BlockEpilogue mBlockEpi; Epilogue mEpilogue; - void run(const Param& _param, const parallel::gemm::ThreadProblemBase& _config) { - mGemmCore.configure(_config.size[0], _config.size[1], _param.problem.dims[3]); - auto StackTmp = alloca(_config.stacksize); - auto tmpB = reinterpret_cast(StackTmp); - tmpB = utils::cpu_pointer_align(tmpB); - auto tmpA = reinterpret_cast(tmpB + static_cast(_config.block[1]) * _config.block[2]); - tmpA = utils::cpu_pointer_align(tmpA); - auto tmpC = reinterpret_cast(tmpA + static_cast(GemmCore::MTILE) * _config.block[2]); - tmpC = utils::cpu_pointer_align(tmpC); - auto tmpBlk = reinterpret_cast(tmpC + static_cast(_config.block[0]) * _config.block[1]); - tmpBlk = utils::cpu_pointer_align(tmpBlk); - auto tmpCache = reinterpret_cast(tmpBlk + static_cast(_config.block[0]) * _config.block[1]); - tmpCache = utils::cpu_pointer_align(tmpCache); - for (int itern = 0; itern < _config.size[1]; itern += _config.block[1]) { - int n_remain = utils::remainsize(itern, _config.size[1], _config.block[1]); - for (int iterm = 0; iterm < _config.size[0]; iterm += _config.block[0]) { - int m_remain = utils::remainsize(iterm, _config.size[0], _config.block[0]); - std::memset(tmpC, 0, _config.block[0] * _config.block[1] * sizeof(AccType)); - auto& KBlock = _param.problem.dims[4]; - if (KBlock <= _config.block[2]) { - run_block(_param, _config, iterm, itern, m_remain, n_remain, tmpA, tmpB, tmpBlk, tmpC, tmpCache); - } else { - run_block_large(_param, _config, iterm, itern, m_remain, n_remain, tmpA, tmpB, tmpBlk, tmpC, tmpCache); + class GEMVWrapper { + public: + static constexpr bool support() { + if constexpr (!std::is_same_v>) { + return false; + } + if constexpr (!std::is_same_v> && + !std::is_same_v>) { + return false; + } + if constexpr (GemmCore::ISA == BTLA_ISA::AVX_VNNI) { +#if CompileAVXVNNI() + static_assert(GemmCore::PACK_ROW == 4); + if constexpr (GemmCore::COMP == bestla::gemm::CompType::COMP_INT8_US_FP32) { + return true; + } + if constexpr (GemmCore::COMP == bestla::gemm::CompType::COMP_INT8_SS_FP32) { + return true; } +#endif } + if constexpr (GemmCore::ISA == BTLA_ISA::AVX2) { +#if CompileAVX2() + static_assert(GemmCore::PACK_ROW == 4); + if constexpr (GemmCore::COMP == bestla::gemm::CompType::COMP_INT8_US_FP32) { + return true; + } + if constexpr (GemmCore::COMP == bestla::gemm::CompType::COMP_INT8_SS_FP32) { + return true; + } +#endif + } + if constexpr (GemmCore::ISA == BTLA_ISA::AVX512_VNNI || GemmCore::ISA == BTLA_ISA::AMX_INT8) { +#if CompileAVX512VNNI() + static_assert(GemmCore::PACK_ROW == 4); + if constexpr (GemmCore::COMP == bestla::gemm::CompType::COMP_INT8_US_FP32) { + return true; + } + if constexpr (GemmCore::COMP == bestla::gemm::CompType::COMP_INT8_SS_FP32) { + return true; + } +#endif + } + return false; } - } + static int constexpr MaxGemvM = 4; - protected: - void run_block(const Param& _param, const parallel::gemm::ThreadProblemBase& _config, int blk_m, int blk_n, - int blk_msize, int blk_nsize, AType* tmpA, BType* tmpB, CType* tmpBlk, AccType* tmpC, void* tmpcache) { - int n_padded = utils::padto(blk_nsize, GemmCore::NTILE); - 
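// Simplified sketch (illustrative only, hypothetical names) of the dispatch scheme the
// GEMVWrapper above relies on: the GEMV kernels are instantiated for a fixed MTILE of
// 1..MaxGemvM and the runtime m selects one instantiation, while larger m falls back to
// the blocked GEMM path. process_gemv/process_gemm stand in for gemv_kblock/gemm.
#include <cstdio>

constexpr int kMaxGemvM = 4;

template <int MTILE>
void process_gemv(int n, int k) { std::printf("gemv MTILE=%d n=%d k=%d\n", MTILE, n, k); }
void process_gemm(int m, int n, int k) { std::printf("gemm m=%d n=%d k=%d\n", m, n, k); }

void run_problem(int m, int n, int k) {
  if (m <= kMaxGemvM) {
    if (m == 1) return process_gemv<1>(n, k);
    if (m == 2) return process_gemv<2>(n, k);
    if (m == 3) return process_gemv<3>(n, k);
    if (m == 4) return process_gemv<4>(n, k);
  }
  process_gemm(m, n, k);  // m > kMaxGemvM: fall back to the tiled GEMM launcher
}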
auto& K = _param.problem.dims[3]; - auto& KBlock = _param.problem.dims[4]; - for (int iterk = 0; iterk < K; iterk += _config.block[2]) { - int k_remain = utils::remainsize(iterk, K, _config.block[2]); - int k_padded = utils::padto(k_remain, GemmCore::KTILE); - auto bptr_cache = tmpB; - int bcache_step = 0; - mProB.getKBlockWeight(&bptr_cache, &bcache_step, k_padded, n_padded, iterk, _config.loc[1] + blk_n, _param.paramB, - tmpcache, _config.tmpcachesize); - int bcache_stride = bcache_step * sizeof(BType); + static bool implemented(const Param& _param) { + bool impl = true; + impl &= _param.paramB.packedW->mDType == BTLA_DTYPE::S4_CLIP || + _param.paramB.packedW->mDType == BTLA_DTYPE::S3_CLIP || + _param.paramB.packedW->mDType == BTLA_DTYPE::S2_CLIP; + impl &= _param.paramB.packedW->mCorrection.mScaT == BTLA_DTYPE::F32 || + _param.paramB.packedW->mCorrection.mScaT == BTLA_DTYPE::BF16; + impl &= _param.problem.dims[1] <= MaxGemvM; + return impl; + } - for (int ikk = 0; ikk < k_remain; ikk += KBlock) { - int k_remain1 = utils::remainsize(iterk + ikk, K, KBlock); - int k_paddedle1 = utils::padto_le(k_remain1, GemmCore::KTILE); - for (int i = 0; i < blk_msize; i += GemmCore::MTILE) { - int m_remain = utils::remainsize(i, blk_msize, GemmCore::MTILE); - auto cptr_cache = tmpBlk + i * _config.block[1]; - int ccache_stride = _config.block[1] * sizeof(CType); - if (k_paddedle1) { - AType* aptr_cache = tmpA; - int acache_step = 0; - mProA.getActivation(&aptr_cache, &acache_step, _param.paramA, m_remain, k_paddedle1, - (blk_m + i + _config.loc[0]), iterk + ikk, tmpcache, _config.tmpcachesize); - mGemmCore.forward(aptr_cache, bptr_cache + ikk * GemmCore::NTILE, cptr_cache, m_remain, n_padded, - k_paddedle1, acache_step * sizeof(AType), bcache_stride, ccache_stride, 0, tmpcache, - _config.tmpcachesize); + template + static void gemv_kblock(const Param& _param, const parallel::gemm::ThreadProblemBase& _config) { + if constexpr (support()) { + auto constexpr TmpSize = 3 * 1024LL; + auto constexpr CSize = 1 * 1024LL; + auto StackTmp_ = alloca(TmpSize + CSize); + auto StackTmp = utils::cpu_pointer_align(StackTmp_); + auto tmpc_ptr = reinterpret_cast((char*)StackTmp + TmpSize); + utils::GemvParamB paramB = SNbits::template createB(_param.paramB.packedW); + utils::GemvParamA paramA{ + _param.paramA.quan->template APtr(), _param.paramA.quan->template SPtr(), + _param.paramA.quan->template ZPtr(), _param.paramA.quan->mKPad, _param.paramA.quan->CStep()}; + + int m = _param.problem.dims[1]; + int n = _param.problem.dims[2]; + int k = _param.problem.dims[3]; + int kblocksize = _param.problem.dims[4]; + auto Cptr = _param.paramC.C + _config.loc[1]; + SNbits::template updateBNStep(paramB, _config.loc[1]); + int size_padded = utils::padto_le(_config.size[1], GemmCore::NTILE); + int in = 0; + for (; in < size_padded; in += GemmCore::NTILE) { + if constexpr (std::is_same_v) { + kernel::wrapper::GEMVWoqNBits::forward_u8s8_fp32<_RT_ISA_T, ScaleT, GemmCore::NTILE, MTILE>( + paramA, paramB, Cptr, _param.paramC.ldc, k, kblocksize, StackTmp, TmpSize); + } else if constexpr (std::is_same_v) { + kernel::wrapper::GEMVWoqNBits::forward_s8s8_fp32<_RT_ISA_T, ScaleT, GemmCore::NTILE, MTILE>( + paramA, paramB, Cptr, _param.paramC.ldc, k, kblocksize, StackTmp, TmpSize); + } + + Cptr += GemmCore::NTILE; + SNbits::template updateBNStep(paramB, GemmCore::NTILE); + } + if (size_padded != _config.size[1]) { + if constexpr (std::is_same_v) { + kernel::wrapper::GEMVWoqNBits::forward_u8s8_fp32<_RT_ISA_T, ScaleT, GemmCore::NTILE, MTILE>( + 
paramA, paramB, tmpc_ptr, GemmCore::NTILE, k, kblocksize, StackTmp, TmpSize); + } else if constexpr (std::is_same_v) { + kernel::wrapper::GEMVWoqNBits::forward_s8s8_fp32<_RT_ISA_T, ScaleT, GemmCore::NTILE, MTILE>( + paramA, paramB, tmpc_ptr, GemmCore::NTILE, k, kblocksize, StackTmp, TmpSize); } - int k_tail = k_remain1 - k_paddedle1; - if (k_tail) { - AType* aptr_cache = tmpA; - int acache_step = 0; - mProA.getActivation(&aptr_cache, &acache_step, _param.paramA, m_remain, k_tail, - (blk_m + i + _config.loc[0]), iterk + ikk + k_paddedle1, tmpcache, - _config.tmpcachesize); - mGemmCore.forward(aptr_cache, bptr_cache + (ikk + k_paddedle1) * GemmCore::NTILE, cptr_cache, m_remain, - n_padded, k_tail, acache_step * sizeof(AType), bcache_stride, ccache_stride, - 0 + k_paddedle1, tmpcache, _config.tmpcachesize); + for (int i = 0; i < MTILE; i++) { + memcpy(Cptr + i * _param.paramC.ldc, tmpc_ptr + i * GemmCore::NTILE, + (_config.size[1] - in) * sizeof(CType)); } } - mBlockEpi.forward(tmpBlk, tmpC, _config.block[1], (_config.loc[0] + blk_m), _config.loc[1] + blk_n, - (iterk + ikk) / KBlock, blk_msize, blk_nsize, _param.paramBlk, tmpcache, - _config.tmpcachesize); + Epilogue::forward(_param.paramC.C + _config.loc[1], _param.paramC.ldc, 0, _config.loc[1], MTILE, + _config.size[1], _param.paramC, StackTmp, TmpSize); } } - auto cachewithblk = _config.tmpcachesize + static_cast(_config.block[0]) * _config.block[1] * sizeof(CType); - mEpilogue.forward(tmpC, _config.block[1], (_config.loc[0] + blk_m), _config.loc[1] + blk_n, blk_msize, blk_nsize, - _param.paramC, tmpBlk, cachewithblk); - } - void run_block_large(const Param& _param, const parallel::gemm::ThreadProblemBase& _config, int blk_m, int blk_n, - int blk_msize, int blk_nsize, AType* tmpA, BType* tmpB, CType* tmpBlk, AccType* tmpC, - void* tmpcache) { - int n_padded = utils::padto(blk_nsize, GemmCore::NTILE); - auto& K = _param.problem.dims[3]; - auto KBlock = _param.problem.dims[4]; - assert(K % KBlock == 0); - for (int iterk = 0; iterk < K; iterk += KBlock) { - memset(tmpBlk, 0, sizeof(CType) * blk_msize * _config.block[1]); - for (int iblkk = 0; iblkk < KBlock; iblkk += _config.block[2]) { - int k_remain = utils::remainsize(iterk + iblkk, iterk + KBlock, _config.block[2]); - int k_padded = utils::padto(k_remain, GemmCore::KTILE); - int k_paddedle = utils::padto_le(k_remain, GemmCore::KTILE); - auto bptr_cache = tmpB; - int bcache_step = 0; - mProB.getKBlockWeight(&bptr_cache, &bcache_step, k_padded, n_padded, iterk + iblkk, _config.loc[1] + blk_n, - _param.paramB, tmpcache, _config.tmpcachesize); - int bcache_stride = bcache_step * sizeof(BType); - for (int i = 0; i < blk_msize; i += GemmCore::MTILE) { - int m_remain = utils::remainsize(i, blk_msize, GemmCore::MTILE); - auto cptr_cache = tmpBlk + i * _config.block[1]; - int ccache_stride = _config.block[1] * sizeof(CType); - if (k_paddedle) { - AType* aptr_cache = tmpA; - int acache_step = 0; - mProA.getActivation(&aptr_cache, &acache_step, _param.paramA, m_remain, k_paddedle, - (blk_m + i + _config.loc[0]), iterk + iblkk, tmpcache, _config.tmpcachesize); - mGemmCore.forward(aptr_cache, bptr_cache, cptr_cache, m_remain, n_padded, k_paddedle, - acache_step * sizeof(AType), bcache_stride, ccache_stride, iblkk, tmpcache, - _config.tmpcachesize); + static void gemv(const Param& _param, const parallel::gemm::ThreadProblemBase& _config) { + if constexpr (support()) { + auto& m = _param.problem.dims[1]; + if (_param.paramB.packedW->mDType == BTLA_DTYPE::S4_CLIP) { + if (_param.paramB.packedW->SDtype() 
== BTLA_DTYPE::F32) { + if (m == 1) gemv_kblock(_param, _config); + if (m == 2) gemv_kblock(_param, _config); + if (m == 3) gemv_kblock(_param, _config); + if (m == 4) gemv_kblock(_param, _config); + + } else if (_param.paramB.packedW->SDtype() == BTLA_DTYPE::BF16) { + if (m == 1) gemv_kblock(_param, _config); + if (m == 2) gemv_kblock(_param, _config); + if (m == 3) gemv_kblock(_param, _config); + if (m == 4) gemv_kblock(_param, _config); } - int k_tail = k_remain - k_paddedle; - if (k_tail) { - AType* aptr_cache = tmpA; - int acache_step = 0; - mProA.getActivation(&aptr_cache, &acache_step, _param.paramA, m_remain, k_tail, - (blk_m + i + _config.loc[0]), iterk + k_paddedle + iblkk, tmpcache, - _config.tmpcachesize); - mGemmCore.forward(aptr_cache, bptr_cache + k_paddedle * GemmCore::NTILE, cptr_cache, m_remain, n_padded, - k_tail, acache_step * sizeof(AType), bcache_stride, ccache_stride, iblkk + k_paddedle, - tmpcache, _config.tmpcachesize); + return; + } + + if (_param.paramB.packedW->mDType == BTLA_DTYPE::S3_CLIP) { + if (_param.paramB.packedW->SDtype() == BTLA_DTYPE::F32) { + if (m == 1) gemv_kblock(_param, _config); + if (m == 2) gemv_kblock(_param, _config); + if (m == 3) gemv_kblock(_param, _config); + if (m == 4) gemv_kblock(_param, _config); + } else if (_param.paramB.packedW->SDtype() == BTLA_DTYPE::BF16) { + if (m == 1) gemv_kblock(_param, _config); + if (m == 2) gemv_kblock(_param, _config); + if (m == 3) gemv_kblock(_param, _config); + if (m == 4) gemv_kblock(_param, _config); + } + return; + } + if (_param.paramB.packedW->mDType == BTLA_DTYPE::S2_CLIP) { + if (_param.paramB.packedW->SDtype() == BTLA_DTYPE::F32) { + if (m == 1) gemv_kblock(_param, _config); + if (m == 2) gemv_kblock(_param, _config); + if (m == 3) gemv_kblock(_param, _config); + if (m == 4) gemv_kblock(_param, _config); + } else if (_param.paramB.packedW->SDtype() == BTLA_DTYPE::BF16) { + if (m == 1) gemv_kblock(_param, _config); + if (m == 2) gemv_kblock(_param, _config); + if (m == 3) gemv_kblock(_param, _config); + if (m == 4) gemv_kblock(_param, _config); } + return; } } - mBlockEpi.forward(tmpBlk, tmpC, _config.block[1], (_config.loc[0] + blk_m), _config.loc[1] + blk_n, - iterk / KBlock, blk_msize, blk_nsize, _param.paramBlk, tmpcache, _config.tmpcachesize); } - auto cachewithblk = _config.tmpcachesize + static_cast(_config.block[0]) * _config.block[1] * sizeof(CType); - mEpilogue.forward(tmpC, _config.block[1], (_config.loc[0] + blk_m), _config.loc[1] + blk_n, blk_msize, blk_nsize, - _param.paramC, tmpBlk, cachewithblk); - } -}; - -template class _PrologueA_T, - template class _PrologueB_T, template class _Epilogue_T> -class LauncherIntKBlock { - public: - using GemmCore = _GemmCore_T; - static constexpr BTLA_ISA ISA = _RT_ISA_T; - using PrologueA = _PrologueA_T; - using PrologueB = _PrologueB_T; - using Epilogue = _Epilogue_T<_RT_ISA_T>; - using AType = typename GemmCore::AType; - using AParam = typename PrologueA::Param; - using BType = typename GemmCore::BType; - using BParam = typename PrologueB::Param; - using CType = typename GemmCore::CType; - using EpiParam = typename Epilogue::Param; - using AccType = float; - static_assert(GemmCore::ISA <= _RT_ISA_T, "RunTime ISA should cover GEMM's ISA"); - struct Param { - const utils::GemmProblem problem; - const AParam paramA; - const BParam paramB; - const EpiParam paramC; }; - _GemmCore_T mGemmCore; - PrologueA mProA; - PrologueB mProB; - Epilogue mEpilogue; void run(const Param& _param, const parallel::gemm::ThreadProblemBase& _config) { + if 
(GEMVWrapper::support() && GEMVWrapper::implemented(_param)) { + GEMVWrapper::gemv(_param, _config); + } else { + gemm(_param, _config); + } + } + + protected: + void gemm(const Param& _param, const parallel::gemm::ThreadProblemBase& _config) { mGemmCore.configure(_config.size[0], _config.size[1], _param.problem.dims[3]); auto StackTmp = alloca(_config.stacksize); auto tmpB = reinterpret_cast(StackTmp); @@ -334,7 +567,6 @@ class LauncherIntKBlock { } } - protected: // _config.block[2]%kblock==0 // _config.block[2]>=kblock void run_block(const Param& _param, const parallel::gemm::ThreadProblemBase& _config, int blk_m, int blk_n, @@ -371,8 +603,8 @@ class LauncherIntKBlock { int ldsb_cache = tmp_ldsb; auto scaleB_cache = scaleB; auto reduceB_cache = reduceB; - mProB.getKBlockWeight(&bptr_cache, &bcache_step, k_padded, n_padded, iterk, _config.loc[1] + blk_n, _param.paramB, - tmp_, _config.tmpcachesize); + mProB.getWeight(&bptr_cache, &bcache_step, k_padded, n_padded, iterk, _config.loc[1] + blk_n, _param.paramB, tmp_, + _config.tmpcachesize); mProB.getScale(&scaleB_cache, &ldsb_cache, k_padded, n_padded, iterk, _config.loc[1] + blk_n, _param.paramB, tmp_, _config.tmpcachesize); mProB.getReduce(&reduceB_cache, &ldsb_cache, k_padded, n_padded, iterk, _config.loc[1] + blk_n, _param.paramB, @@ -437,8 +669,8 @@ class LauncherIntKBlock { int ldsb_cache = tmp_ldsb; auto scaleB_cache = scaleB; auto reduceB_cache = reduceB; - mProB.getKBlockWeight(&bptr_cache, &bcache_step, k_padded, n_padded, iterkk, _config.loc[1] + blk_n, - _param.paramB, tmp_, _config.tmpcachesize); + mProB.getWeight(&bptr_cache, &bcache_step, k_padded, n_padded, iterkk, _config.loc[1] + blk_n, _param.paramB, + tmp_, _config.tmpcachesize); mProB.getScale(&scaleB_cache, &ldsb_cache, k_padded, n_padded, iterkk, _config.loc[1] + blk_n, _param.paramB, tmp_, _config.tmpcachesize); mProB.getReduce(&reduceB_cache, &ldsb_cache, k_padded, n_padded, iterkk, _config.loc[1] + blk_n, _param.paramB, diff --git a/bestla/bestla/kernel_avx2.h b/bestla/bestla/kernel_avx2.h index e980fa90a..8856010b6 100644 --- a/bestla/bestla/kernel_avx2.h +++ b/bestla/bestla/kernel_avx2.h @@ -29,42 +29,35 @@ namespace avx2 { #elif defined(ICX) #pragma clang attribute push(__attribute__((target("avx,avx2,fma"))), apply_to = function) #endif -template -static inline __m256i unpack_4bits_avx2(void* srcptr, __m256i mask) { + +static inline __m256i unpack_4bits(void* srcptr, __m256i mask) { auto raw_data = _mm_loadu_si128(reinterpret_cast<__m128i*>(srcptr)); auto ymm0 = _mm256_cvtepu8_epi16(raw_data); - auto ymm1 = _mm256_slli_epi16(ymm0, 8); - ymm0 = _mm256_slli_epi16(ymm0, 4); + auto ymm1 = _mm256_slli_epi16(ymm0, 4); ymm0 = _mm256_or_si256(ymm0, ymm1); ymm0 = _mm256_and_si256(ymm0, mask); - if constexpr (LowBits) { - ymm0 = _mm256_srli_epi16(ymm0, 4); - } return ymm0; } -template -static inline void convert_s4_s8_N_avx2(int8_t* dstptr, int8_t* srcptr, __m256i mask) { - static_assert(N % 2 == 0); - static_assert(N <= 64); - if constexpr (N == 32) { - auto dst0 = unpack_4bits_avx2(srcptr, mask); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(dstptr), dst0); - } else if constexpr (N > 32) { - auto dst0 = unpack_4bits_avx2(srcptr, mask); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(dstptr), dst0); - int8_t temp[32]; - memcpy(temp, srcptr + 16, (N - 32) / 2); - dst0 = unpack_4bits_avx2(temp, mask); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(temp), dst0); - memcpy(dstptr + 32, temp, (N - 32)); - } else { - int8_t temp[32]; - memcpy(temp, srcptr, N / 2); - 
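// Scalar reference (illustrative sketch, not a library API) for what the new AVX2
// unpack_4bits + bias path computes in the symmetric decompress_s4_s8 case: each packed
// byte carries two unsigned 4-bit codes, low nibble first (matching the bit4x2 {x, y}
// layout), and subtracting 8 recenters them into signed int8 weights. The asymmetric
// k-block variants subtract (zero_point + 8) instead of the plain 8 shown here.
#include <cstddef>
#include <cstdint>

static void decompress_s4_s8_ref(const uint8_t* src, int8_t* dst, size_t n_pairs) {
  for (size_t i = 0; i < n_pairs; ++i) {
    const uint8_t b = src[i];
    dst[2 * i + 0] = static_cast<int8_t>((b & 0x0f) - 8);  // low nibble -> first element
    dst[2 * i + 1] = static_cast<int8_t>((b >> 4) - 8);    // high nibble -> second element
  }
}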
auto dst0 = unpack_4bits_avx2(temp, mask); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(temp), dst0); - memcpy(dstptr, temp, N); - } +static inline __m256i unpack_2bits(utils::bit2x4* ptr, const __m256i& vshift_y, const __m256i& vmask0_y, + const __m256i& vsfhl_mask_y, const __m256i& vorder_y) { + auto vraw_x = _mm_loadl_epi64((const __m128i*)ptr); + auto vsrc_y = _mm256_broadcastq_epi64(vraw_x); + auto vordered_y = _mm256_permutevar8x32_epi32(vsrc_y, vorder_y); + auto vs_y = _mm256_srlv_epi32(vordered_y, vshift_y); + auto v2_y = _mm256_and_si256(vs_y, vmask0_y); + auto vout_y = _mm256_shuffle_epi8(v2_y, vsfhl_mask_y); + return vout_y; +} + +static inline __m256i unpack_1bits(utils::bit1x8* ptr, const __m256i& bit1Shift_1, const __m256i& bit1Mask, + const __m256i& bit1Shift_2, const __m256i& highMask) { + auto bit1x32 = _mm256_set1_epi32(*(int*)ptr); + bit1x32 = _mm256_srlv_epi32(bit1x32, bit1Shift_1); + bit1x32 = _mm256_and_si256(bit1x32, bit1Mask); + bit1x32 = _mm256_mullo_epi32(bit1x32, bit1Shift_2); + bit1x32 = _mm256_and_si256(highMask, bit1x32); + return bit1x32; } inline __m256 ymm_cvt_bf16_fp32(__m128i vbf16) { @@ -73,48 +66,130 @@ inline __m256 ymm_cvt_bf16_fp32(__m128i vbf16) { } inline __m128i ymm_cvtepi32_epi16(__m256i src) { - __m128i tmp; -#if defined(__GNUC__) || defined(__clang_major__) - for (size_t i = 0; i < 8; i++) { - (reinterpret_cast(&tmp))[i] = (reinterpret_cast(&src))[i]; - } -#else - for (size_t i = 0; i < 8; i++) { - tmp.m128i_i16[i] = src.m256i_i32[i]; - } -#endif - return tmp; + const auto shuffle_mask_32_to_16 = _mm256_set_epi8(13, 12, 9, 8, 5, 4, 1, 0, 13, 12, 9, 8, 5, 4, 1, 0, 13, 12, 9, 8, + 5, 4, 1, 0, 13, 12, 9, 8, 5, 4, 1, 0); + __m256i trunc_elements = _mm256_shuffle_epi8(src, shuffle_mask_32_to_16); + __m256i ordered = _mm256_permute4x64_epi64(trunc_elements, 0x58); + __m128i result = _mm256_castsi256_si128(ordered); + return result; } -inline __m128i ymm_cvt_fp32_bf16(__m256 vfp32) { +inline __m128i ymm_cvt_fp32_bf16(const __m256& vfp32) { return ymm_cvtepi32_epi16(_mm256_bsrli_epi128(_mm256_castps_si256(vfp32), 2)); } -template -static inline void convert_s8_fp_v8(T* dstptr, int8_t* srcptr) { +static inline __m256i load_s8_s32(int8_t* srcptr) { auto xmm = _mm_loadl_epi64(reinterpret_cast<__m128i*>(srcptr)); auto ymm = _mm256_cvtepi8_epi32(xmm); - auto ymm1 = _mm256_cvtepi32_ps(ymm); + return ymm; +} + +static inline __m256 load_bf16_fp32(const utils::bf16* srcptr) { + auto tmp = _mm_loadu_si128(reinterpret_cast(srcptr)); + auto vf32 = ymm_cvt_bf16_fp32(tmp); + return vf32; +} + +template +static inline __m256 load_T_fp32(const T* srcptr) { + __m256 vtmp; + if constexpr (std::is_same_v) { + vtmp = _mm256_loadu_ps(srcptr); + } else if constexpr (std::is_same_v) { + vtmp = load_bf16_fp32(srcptr); + } else { + static_assert(std::is_same_v || std::is_same_v); + } + return vtmp; +} + +static inline __m256 load_s8_fp32(int8_t* srcptr) { + auto src_y = load_s8_s32(srcptr); + auto dst_y = _mm256_cvtepi32_ps(src_y); + return dst_y; +} + +template +static inline void store_fp_T(const __m256& src_y, T* dstptr) { if constexpr (std::is_same_v) { - auto xmm = ymm_cvt_fp32_bf16(ymm1); + auto xmm = ymm_cvt_fp32_bf16(src_y); _mm_storeu_si128(reinterpret_cast<__m128i*>(dstptr), xmm); + } else if constexpr (std::is_same_v) { + _mm256_storeu_ps(dstptr, src_y); } else { - _mm256_storeu_ps(dstptr, ymm1); + assert(0); } } -template -static inline void dequant_s8_N_avx2(float* dstptr, int8_t* srcptr, __m256* vscales, __m256i* vzps = nullptr) { +template +static inline 
void convert_s8_fp_v8(T* dstptr, int8_t* srcptr) { + auto src_fp_y = load_s8_fp32(srcptr); + store_fp_T(src_fp_y, dstptr); +} + +template +static inline __m256 dequant_s8_fp(int8_t* srcptr, __m256 vscales, __m256i vzps = __m256i()) { + auto src_s32_y = load_s8_s32(srcptr); + if constexpr (IsAsym) src_s32_y = _mm256_sub_epi32(src_s32_y, vzps); + auto src_fp_y = _mm256_cvtepi32_ps(src_s32_y); + src_fp_y = _mm256_mul_ps(src_fp_y, vscales); + return src_fp_y; +} + +template +static inline void dequant_s8_N_avx2(DstT* dstptr, int8_t* srcptr, __m256* vscales, __m256i* vzps = nullptr) { static_assert(N % 8 == 0); int constexpr VLoop = N / 8; for (int iv = 0; iv < VLoop; iv += 1) { - auto src_s8 = _mm_loadl_epi64(reinterpret_cast<__m128i*>(srcptr + iv * 8)); - auto zmm = _mm256_cvtepi8_epi32(src_s8); - if constexpr (!_IS_SYM) zmm = _mm256_sub_epi32(zmm, vzps[iv]); - auto fzmm = _mm256_cvtepi32_ps(zmm); - fzmm = _mm256_mul_ps(fzmm, vscales[iv]); - _mm256_storeu_ps(dstptr + iv * 8, fzmm); + __m256 dq_f32_y; + if constexpr (IsAsym) { + dq_f32_y = dequant_s8_fp(srcptr, vscales[iv], vzps[iv]); + } else { + dq_f32_y = dequant_s8_fp(srcptr, vscales[iv]); + } + store_fp_T(dq_f32_y, dstptr + iv * 8); + } +} + +static inline __m256i load_zp_epi8_broadcast_epi16_v16(int8_t* zpptr, const __m256i& vindex) { + auto v_zp_x = _mm_loadu_si128((const __m128i*)zpptr); + auto v_zp_y = _mm256_cvtepi8_epi16(v_zp_x); + auto v_zp_y_cast = _mm256_shuffle_epi8(v_zp_y, vindex); + return v_zp_y_cast; +} + +static inline __m256i load_zp_epi8_broadcast_epi16(int8_t* zpptr, const __m256i& vindex) { + auto v_zp_x = _mm_loadu_si128((const __m128i*)zpptr); + auto v_zp_y = _mm256_cvtepi8_epi16(v_zp_x); + auto v_zp_y_cast = _mm256_shuffle_epi8(v_zp_y, vindex); + return v_zp_y_cast; +} + +static inline __m256i load_zp_epi8_broadcast_epi32(int8_t* zpptr, const __m256i& vindex) { + auto v_zp_x = _mm_loadl_epi64((const __m128i*)zpptr); + auto v_zp_y = _mm256_cvtepi8_epi32(v_zp_x); + auto v_zp_y_cast = _mm256_shuffle_epi8(v_zp_y, vindex); + return v_zp_y_cast; +} + +// vout= {vsrc.f32[0],vsrc.f32[0],...,vsrc.f32[4],vsrc.f32[4]} +template +static inline __m256 broadcast_ps_1_2(__m256 vsrc_y, const __m256i& vshuf_index_y) { + __m256 tmp; + if constexpr (LowBits) { + tmp = _mm256_permute2f128_ps(vsrc_y, vsrc_y, 0); + } else { + tmp = _mm256_permute2f128_ps(vsrc_y, vsrc_y, 17); } + auto tmpi = _mm256_castps_si256(tmp); + + auto out = _mm256_shuffle_epi8(tmpi, vshuf_index_y); + return _mm256_castsi256_ps(out); +} + +template +static inline __m256i broadcast_epi32_1_2(__m256i vsrc_y, const __m256i& vshuf_index_y) { + return _mm256_castps_si256(broadcast_ps_1_2(_mm256_castsi256_ps(vsrc_y), vshuf_index_y)); } inline BTLA_CODE dq8_get_fp_scale(uint8_t* src, float* dst, int row, int col, int scale_offset, int dq_blk, @@ -199,57 +274,6 @@ static inline BTLA_CODE alphabeta_f32_f32(const float alpha, const float* srcptr return BTLA_CODE::Success; } -template -BTLA_CODE dequant_kblock_s8_fp_fwd(int8_t* srcptr, _DST_T* dstptr, int row, int col, int ld_src, int ld_dst, - float* scales, int8_t* zero_points, int k_offset, int kblock, int NPad) { - const int Vlen = 8; - size_t simd_process_num = utils::padto_le(col, Vlen); - auto packrow4_permute_idx = _mm256_setr_epi32(0, 0, 0, 0, 1, 1, 1, 1); - for (int i = 0; i < row; i++) { - int kpos = (k_offset + i) / kblock; - auto sptr = scales + kpos * NPad; - int j = 0; - for (; j < simd_process_num; j += Vlen) { - auto s8_ymm_v = _mm_loadl_epi64(reinterpret_cast<__m128i*>(srcptr + i * ld_src + j)); - auto 
s32_ymm_v = _mm256_cvtepi8_epi32(s8_ymm_v); - if constexpr (WITH_ZP) { - auto zp_ymm = - _mm256_cvtepi8_epi32(_mm_loadl_epi64(reinterpret_cast<__m128i*>(zero_points + kpos * NPad + j / PACK_ROW))); - if constexpr (PACK_ROW == 4) zp_ymm = _mm256_permutevar8x32_epi32(zp_ymm, packrow4_permute_idx); - s32_ymm_v = _mm256_sub_epi32(s32_ymm_v, zp_ymm); - } - auto f32_ymm_v = _mm256_cvtepi32_ps(s32_ymm_v); - auto scale_ymm = _mm256_loadu_ps(sptr + j / PACK_ROW); - if constexpr (PACK_ROW == 4) scale_ymm = _mm256_permutevar8x32_ps(scale_ymm, packrow4_permute_idx); - f32_ymm_v = _mm256_mul_ps(f32_ymm_v, scale_ymm); - if constexpr (std::is_same_v<_DST_T, float>) { - _mm256_storeu_ps(dstptr + i * ld_dst + j, f32_ymm_v); - } else if constexpr (std::is_same_v<_DST_T, utils::bf16>) { - _mm_storeu_si128(reinterpret_cast<__m128i*>(dstptr + i * ld_dst), ymm_cvt_fp32_bf16(f32_ymm_v)); - } else { - assert(0); - } - } - for (; j < col; j++) { - float tmp = (float)(srcptr[i * ld_src + j]); - if constexpr (WITH_ZP) tmp -= (float)(zero_points[kpos * NPad + j / PACK_ROW]); - dstptr[i * ld_dst + j] = tmp * sptr[j / PACK_ROW]; - } - } - return BTLA_CODE::Success; -} - -template -static inline BTLA_CODE dequant_kblock_s8_fp(int8_t* srcptr, _DST_T* dstptr, int row, int col, int ld_src, int ld_dst, - float* scales, int8_t* zero_points, int k_offset, int kblock, int NPad) { - if (zero_points == nullptr) - return dequant_kblock_s8_fp_fwd(srcptr, dstptr, row, col, ld_src, ld_dst, scales, zero_points, - k_offset, kblock, NPad); - else - return dequant_kblock_s8_fp_fwd(srcptr, dstptr, row, col, ld_src, ld_dst, scales, zero_points, - k_offset, kblock, NPad); -} - template static inline BTLA_CODE dequant_s32_fp32(const int32_t* srcptr, const int srcstep, float* dstptr, const int dststep, const int row, const int col, const float* scaleA, const int ldsa, @@ -365,52 +389,672 @@ static inline BTLA_CODE remove_zeropoint_bias(float* accptr, int ldacc, int row, return BTLA_CODE::Success; } -template -static inline BTLA_CODE decompress_s4_s8(utils::int4x2* srcptr, int8_t* dstptr, int row, int col, int ld_src, - int ld_dst) { - uint32_t mask = 0xf0f0f0f0; +template +static inline BTLA_CODE decompress_kblock_s4_s8_pack4_row(utils::int4x2* srcptr, int8_t* zpptr, int8_t* dstptr, + int blocksize, int ldzp, int n_offset, int k_offset, int row, + int8_t* tmp, size_t tmpsize) { + int constexpr NReg = NTILE / 8; + static_assert((NTILE % 8) == 0); + int constexpr PackRow = 4; + __m256i v_zp_y[NReg]; + uint32_t mask = 0x0f0f0f0f; auto vmask = _mm256_set1_epi32(*reinterpret_cast(&mask)); - if (col == ld_src) { - size_t elesize = static_cast(row) * col; - size_t velt = utils::padto_le(elesize, 32); - size_t i = 0; - for (; i < velt; i += 32) { - convert_s4_s8_N_avx2<32, S4_T>(dstptr + i, reinterpret_cast(srcptr + i / 2), vmask); + auto vbias = _mm256_set1_epi8(8); + const auto vindex = _mm256_set_epi8(12, 12, 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0, 12, 12, 12, 12, 8, 8, 8, 8, 4, + 4, 4, 4, 0, 0, 0, 0); + for (int ir = 0; ir < row; ir += blocksize) { + auto zptr = zpptr + (k_offset + ir) / blocksize * ldzp + n_offset; + for (int i = 0; i < NReg; i++) { + v_zp_y[i] = load_zp_epi8_broadcast_epi32(zptr + i * 8, vindex); + v_zp_y[i] = _mm256_add_epi8(v_zp_y[i], vbias); } - for (; i < elesize; i += 2) { - auto tmp = srcptr[i / 2]; - dstptr[i + 0] = kernel::ref::get_s8(tmp.x); - dstptr[i + 1] = kernel::ref::get_s8(tmp.y); + int k_remain = utils::remainsize(ir, row, blocksize); + for (int ib = 0; ib < k_remain; ib += PackRow) { + auto b4ptr = srcptr 
+ (ir + ib) * NTILE / 2; + for (int i = 0; i < NReg; i++) { + auto v_s8_y = unpack_4bits(b4ptr + i * 16, vmask); + v_s8_y = _mm256_sub_epi8(v_s8_y, v_zp_y[i]); + _mm256_storeu_si256((__m256i*)(dstptr + i * 32 + (ir + ib) * NTILE), v_s8_y); + } } - return BTLA_CODE::Success; } - return BTLA_CODE::NotSupport; + return BTLA_CODE::Success; } -template -inline BTLA_CODE decompress_kblock_s4_s8fp(utils::int4x2* srcptr, _DST_T* dstptr, int row, int col, int ld_src, - int ld_dst, int8_t* tmp, size_t tmpsize) { - uint32_t mask = 0xf0f0f0f0; +template +static inline BTLA_CODE decompress_kblock_s4_s8_pack2_row(utils::int4x2* srcptr, int8_t* zpptr, int8_t* dstptr, + int blocksize, int ldzp, int n_offset, int k_offset, int row, + int8_t* tmp, size_t tmpsize) { + int constexpr NReg = NTILE / 8; + static_assert((NTILE % 8) == 0); + int constexpr PackRow = 2; + int constexpr Unroll = 2; + __m256i v_zp_y[NReg]; + uint32_t mask = 0x0f0f0f0f; auto vmask = _mm256_set1_epi32(*reinterpret_cast(&mask)); - if (col == ld_src) { + auto vbias = _mm256_set1_epi8(8); + const auto vindex = _mm256_set_epi8(14, 14, 12, 12, 10, 10, 8, 8, 6, 6, 4, 4, 2, 2, 0, 0, 14, 14, 12, 12, 10, 10, 8, + 8, 6, 6, 4, 4, 2, 2, 0, 0); + for (int ir = 0; ir < row; ir += blocksize) { + auto zptr = zpptr + (k_offset + ir) / blocksize * ldzp + n_offset; + memcpy(tmp, zptr, NTILE * sizeof(int8_t)); + memcpy(tmp + NTILE, zptr, NTILE * sizeof(int8_t)); + for (int i = 0; i < NReg; i++) { + v_zp_y[i] = load_zp_epi8_broadcast_epi16_v16(tmp + i * 16, vindex); + v_zp_y[i] = _mm256_add_epi8(v_zp_y[i], vbias); + } + int k_remain = utils::remainsize(ir, row, blocksize); + int k_remain_unrll = utils::padto_le(k_remain, PackRow * Unroll); + int ib = 0; + for (; ib < k_remain_unrll; ib += PackRow * Unroll) { + auto b4ptr = srcptr + (ir + ib) * NTILE / 2; + for (int i = 0; i < NReg; i++) { + auto v_s8_y = unpack_4bits(b4ptr + i * 16, vmask); + v_s8_y = _mm256_sub_epi8(v_s8_y, v_zp_y[i]); + _mm256_storeu_si256((__m256i*)(dstptr + i * 32 + (ir + ib) * NTILE), v_s8_y); + } + } + int k_tail = k_remain - k_remain_unrll; + if (k_tail > 0) { + memcpy(tmp, srcptr + (ir + ib) * NTILE / 2, k_tail * NTILE / 2); + auto tmpout = tmp + Unroll * PackRow * NTILE / 2; + for (int i = 0; i < NReg; i++) { + auto v_s8_y = unpack_4bits(tmp + i * 16, vmask); + v_s8_y = _mm256_sub_epi8(v_s8_y, v_zp_y[i]); + _mm256_storeu_si256((__m256i*)(tmpout + i * 32), v_s8_y); + } + memcpy(dstptr + (ir + ib) * NTILE, tmpout, k_tail * NTILE); + } + } + return BTLA_CODE::Success; +} + +template +static inline BTLA_CODE decompress_kblock_s4_s8_pack1_row(utils::int4x2* srcptr, int8_t* zpptr, int8_t* dstptr, + int blocksize, int ldzp, int n_offset, int k_offset, int row, + int8_t* tmp, size_t tmpsize) { + int constexpr NReg = NTILE / 8; + static_assert((NTILE % 8) == 0); + int constexpr PackRow = 1; + int constexpr Unroll = 4; + int constexpr UnpackLoop = Unroll * NTILE / 32; + __m256i v_zp_y[UnpackLoop]; + uint32_t mask = 0x0f0f0f0f; + auto vmask = _mm256_set1_epi32(*reinterpret_cast(&mask)); + auto vbias = _mm256_set1_epi8(8); + const auto vindex = _mm256_set_epi8(12, 12, 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0, 12, 12, 12, 12, 8, 8, 8, 8, 4, + 4, 4, 4, 0, 0, 0, 0); + for (int ir = 0; ir < row; ir += blocksize) { + auto zptr = zpptr + (k_offset + ir) / blocksize * ldzp + n_offset; + for (int i = 0; i < Unroll; i++) { + memcpy(tmp + i * NTILE, zptr, NTILE * sizeof(int8_t)); + } + for (int i = 0; i < UnpackLoop; i++) { + v_zp_y[i] = _mm256_loadu_si256((const __m256i*)(tmp + i * 32)); + v_zp_y[i] 
= _mm256_add_epi8(v_zp_y[i], vbias); + } + int k_remain = utils::remainsize(ir, row, blocksize); + int k_remain_unrll = utils::padto_le(k_remain, Unroll); + int ib = 0; + for (; ib < k_remain_unrll; ib += Unroll) { + auto b4ptr = srcptr + (ir + ib) * NTILE / 2; + for (int i = 0; i < UnpackLoop; i++) { + auto v_s8_y = unpack_4bits(b4ptr + i * 16, vmask); + v_s8_y = _mm256_sub_epi8(v_s8_y, v_zp_y[i]); + _mm256_storeu_si256((__m256i*)(dstptr + i * 32 + (ir + ib) * NTILE), v_s8_y); + } + } + + int k_tail = k_remain - k_remain_unrll; + if (k_tail > 0) { + memcpy(tmp, srcptr + (ir + ib) * NTILE / 2, k_tail * NTILE / 2); + auto tmpout = tmp + Unroll * NTILE / 2; + for (int i = 0; i < UnpackLoop; i++) { + auto v_s8_y = unpack_4bits(tmp + i * 16, vmask); + v_s8_y = _mm256_sub_epi8(v_s8_y, v_zp_y[i]); + _mm256_storeu_si256((__m256i*)(tmpout + i * 32), v_s8_y); + } + memcpy(dstptr + (ir + ib) * NTILE, tmpout, k_tail * NTILE); + } + } + return BTLA_CODE::Success; +} + +template +static inline BTLA_CODE decompress_kblock_s2_s8_pack4_row(utils::bit2x4* srcptr, int8_t* zpptr, int8_t* dstptr, + int blocksize, int ldzp, int n_offset, int k_offset, int row, + int8_t* tmp, size_t tmpsize) { + int constexpr NReg = NTILE / 8; + static_assert((NTILE % 8) == 0); + int constexpr PackRow = 4; + __m256i v_zp_y[NReg]; + const auto vindex = _mm256_set_epi8(12, 12, 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0, 12, 12, 12, 12, 8, 8, 8, 8, 4, + 4, 4, 4, 0, 0, 0, 0); + uint64_t mask0 = 0x0303030303030303; + auto vmask0 = _mm256_set_epi64x(*(int64_t*)&mask0, *(int64_t*)&mask0, *(int64_t*)&mask0, *(int64_t*)&mask0); + auto vbias = _mm256_set1_epi8(2); + auto vshift_y = _mm256_set_epi32(6, 4, 2, 0, 6, 4, 2, 0); + auto vsfhl_mask_y = _mm256_set_epi8(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2, + 13, 9, 5, 1, 12, 8, 4, 0); + auto vorder_y = _mm256_set_epi32(1, 1, 1, 1, 0, 0, 0, 0); + for (int ir = 0; ir < row; ir += blocksize) { + auto zptr = zpptr + (k_offset + ir) / blocksize * ldzp + n_offset; + for (int i = 0; i < NReg; i++) { + v_zp_y[i] = load_zp_epi8_broadcast_epi32(zptr + i * 8, vindex); + v_zp_y[i] = _mm256_add_epi8(v_zp_y[i], vbias); + } + int k_remain = utils::remainsize(ir, row, blocksize); + for (int ib = 0; ib < k_remain; ib += PackRow) { + auto b2ptr = srcptr + (ir + ib) * NTILE / 4; + for (int i = 0; i < NReg; i++) { + auto v_s8_y = unpack_2bits(b2ptr + i * 8, vshift_y, vmask0, vsfhl_mask_y, vorder_y); + v_s8_y = _mm256_sub_epi8(v_s8_y, v_zp_y[i]); + _mm256_storeu_si256((__m256i*)(dstptr + i * 32 + (ir + ib) * NTILE), v_s8_y); + } + } + } + return BTLA_CODE::Success; +} + +template +static inline BTLA_CODE decompress_kblock_s2_s8_pack2_row(utils::bit2x4* srcptr, int8_t* zpptr, int8_t* dstptr, + int blocksize, int ldzp, int n_offset, int k_offset, int row, + int8_t* tmp, size_t tmpsize) { + int constexpr NReg = NTILE / 8; + static_assert((NTILE % 8) == 0); + int constexpr PackRow = 2; + int constexpr Unroll = 2; + __m256i v_zp_y[NReg]; + const auto vindex = _mm256_set_epi8(14, 14, 12, 12, 10, 10, 8, 8, 6, 6, 4, 4, 2, 2, 0, 0, 14, 14, 12, 12, 10, 10, 8, + 8, 6, 6, 4, 4, 2, 2, 0, 0); + uint64_t mask0 = 0x0303030303030303; + auto vmask0 = _mm256_set_epi64x(*(int64_t*)&mask0, *(int64_t*)&mask0, *(int64_t*)&mask0, *(int64_t*)&mask0); + auto vbias = _mm256_set1_epi8(2); + auto vshift_y = _mm256_set_epi32(6, 4, 2, 0, 6, 4, 2, 0); + auto vsfhl_mask_y = _mm256_set_epi8(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2, + 13, 9, 5, 1, 12, 8, 4, 0); + auto 
vorder_y = _mm256_set_epi32(1, 1, 1, 1, 0, 0, 0, 0); + for (int ir = 0; ir < row; ir += blocksize) { + auto zptr = zpptr + (k_offset + ir) / blocksize * ldzp + n_offset; + memcpy(tmp, zptr, NTILE * sizeof(int8_t)); + memcpy(tmp + NTILE, zptr, NTILE * sizeof(int8_t)); + for (int i = 0; i < NReg; i++) { + v_zp_y[i] = load_zp_epi8_broadcast_epi16_v16(tmp + i * 16, vindex); + v_zp_y[i] = _mm256_add_epi8(v_zp_y[i], vbias); + } + int k_remain = utils::remainsize(ir, row, blocksize); + int k_remain_unrll = utils::padto_le(k_remain, PackRow * Unroll); + int ib = 0; + for (; ib < k_remain_unrll; ib += PackRow * Unroll) { + auto b2ptr = srcptr + (ir + ib) * NTILE / 4; + for (int i = 0; i < NReg; i++) { + auto v_s8_y = unpack_2bits(b2ptr + i * 8, vshift_y, vmask0, vsfhl_mask_y, vorder_y); + v_s8_y = _mm256_sub_epi8(v_s8_y, v_zp_y[i]); + _mm256_storeu_si256((__m256i*)(dstptr + i * 32 + (ir + ib) * NTILE), v_s8_y); + } + } + int k_tail = k_remain - k_remain_unrll; + if (k_tail > 0) { + memcpy(tmp, srcptr + (ir + ib) * NTILE / 4, k_tail * NTILE / 4); + auto tmpout = tmp + Unroll * PackRow * NTILE / 4; + for (int i = 0; i < NReg; i++) { + auto v_s8_y = unpack_2bits((utils::bit2x4*)(tmp + i * 8), vshift_y, vmask0, vsfhl_mask_y, vorder_y); + v_s8_y = _mm256_sub_epi8(v_s8_y, v_zp_y[i]); + _mm256_storeu_si256((__m256i*)(tmpout + i * 32), v_s8_y); + } + memcpy(dstptr + (ir + ib) * NTILE, tmpout, k_tail * NTILE); + } + } + return BTLA_CODE::Success; +} + +template +static inline BTLA_CODE decompress_kblock_s2_s8_pack1_row(utils::bit2x4* srcptr, int8_t* zpptr, int8_t* dstptr, + int blocksize, int ldzp, int n_offset, int k_offset, int row, + int8_t* tmp, size_t tmpsize) { + int constexpr NReg = NTILE / 8; + static_assert((NTILE % 8) == 0); + int constexpr PackRow = 1; + int constexpr Unroll = 4; + int constexpr UnpackLoop = Unroll * NTILE / 32; + __m256i v_zp_y[UnpackLoop]; + uint64_t mask0 = 0x0303030303030303; + auto vmask0 = _mm256_set_epi64x(*(int64_t*)&mask0, *(int64_t*)&mask0, *(int64_t*)&mask0, *(int64_t*)&mask0); + auto vbias = _mm256_set1_epi8(2); + auto vshift_y = _mm256_set_epi32(6, 4, 2, 0, 6, 4, 2, 0); + auto vsfhl_mask_y = _mm256_set_epi8(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2, + 13, 9, 5, 1, 12, 8, 4, 0); + auto vorder_y = _mm256_set_epi32(1, 1, 1, 1, 0, 0, 0, 0); + for (int ir = 0; ir < row; ir += blocksize) { + auto zptr = zpptr + (k_offset + ir) / blocksize * ldzp + n_offset; + for (int i = 0; i < Unroll; i++) { + memcpy(tmp + i * NTILE, zptr, NTILE * sizeof(int8_t)); + } + for (int i = 0; i < UnpackLoop; i++) { + v_zp_y[i] = _mm256_loadu_si256((const __m256i*)(tmp + i * 32)); + v_zp_y[i] = _mm256_add_epi8(v_zp_y[i], vbias); + } + int k_remain = utils::remainsize(ir, row, blocksize); + int k_remain_unrll = utils::padto_le(k_remain, Unroll); + int ib = 0; + for (; ib < k_remain_unrll; ib += Unroll) { + auto b2ptr = srcptr + (ir + ib) * NTILE / 4; + for (int i = 0; i < UnpackLoop; i++) { + auto v_s8_y = unpack_2bits(b2ptr + i * 8, vshift_y, vmask0, vsfhl_mask_y, vorder_y); + v_s8_y = _mm256_sub_epi8(v_s8_y, v_zp_y[i]); + _mm256_storeu_si256((__m256i*)(dstptr + i * 32 + (ir + ib) * NTILE), v_s8_y); + } + } + + int k_tail = k_remain - k_remain_unrll; + if (k_tail > 0) { + memcpy(tmp, srcptr + (ir + ib) * NTILE / 4, k_tail * NTILE / 4); + auto tmpout = tmp + Unroll * NTILE / 4; + for (int i = 0; i < UnpackLoop; i++) { + auto v_s8_y = unpack_2bits((utils::bit2x4*)(tmp + i * 8), vshift_y, vmask0, vsfhl_mask_y, vorder_y); + v_s8_y = _mm256_sub_epi8(v_s8_y, 
v_zp_y[i]); + _mm256_storeu_si256((__m256i*)(tmpout + i * 32), v_s8_y); + } + memcpy(dstptr + (ir + ib) * NTILE, tmpout, k_tail * NTILE); + } + } + return BTLA_CODE::Success; +} + +static inline BTLA_CODE decompress_s4_s8(utils::int4x2* srcptr, int8_t* dstptr, size_t elesize, int8_t* tmp, + size_t tmpsize) { + uint32_t mask = 0x0f0f0f0f; + auto vmask = _mm256_set1_epi32(*reinterpret_cast(&mask)); + size_t velt = utils::padto_le(elesize, 32); + size_t i = 0; + auto vbias = _mm256_set1_epi8(8); + for (; i < velt; i += 32) { + auto vout_y = unpack_4bits(reinterpret_cast(srcptr + i / 2), vmask); + vout_y = _mm256_sub_epi8(vout_y, vbias); + _mm256_storeu_si256((__m256i*)(dstptr + i), vout_y); + } + if (velt < elesize) { + if (elesize >= 32) { + i = elesize - 32; + auto vout_y = unpack_4bits(reinterpret_cast(srcptr + i / 2), vmask); + vout_y = _mm256_sub_epi8(vout_y, vbias); + _mm256_storeu_si256((__m256i*)(dstptr + i), vout_y); + } else { + ref::decompress_kblock_s4_s8<1, 1>(srcptr + i / 2, nullptr, dstptr + i, 0, 0, 0, 0, 1, elesize - i, nullptr, 0); + } + } + return BTLA_CODE::Success; +} + +template +inline BTLA_CODE decompress_kblock_s4_s8(utils::int4x2* srcptr, int8_t* zpptr, int8_t* dstptr, int blocksize, int ldzp, + int n_offset, int k_offset, int row, int col, int8_t* tmp, size_t tmpsize) { + if (zpptr) { + typedef BTLA_CODE (*decompfunc)(utils::int4x2 * srcptr, int8_t * zpptr, int8_t * dstptr, int blocksize, int ldzp, + int n_offset, int k_offset, int row, int8_t* tmp, size_t tmpsize); + decompfunc func = nullptr; + if (col == NTILE) { + if constexpr (PackRow == 4) { + func = &decompress_kblock_s4_s8_pack4_row; + } + if constexpr (PackRow == 1) { + func = &decompress_kblock_s4_s8_pack1_row; + } + if constexpr (PackRow == 2) { + func = &decompress_kblock_s4_s8_pack2_row; + } + if (func) { + int head_end = utils::padto(k_offset, blocksize); + head_end = std::min(head_end, k_offset + row); + int head_size = head_end - k_offset; + if (head_size > 0) { + (*func)(srcptr, zpptr, dstptr, blocksize, ldzp, n_offset, k_offset, head_size, tmp, tmpsize); + } + int body_size = row - head_size; + if (body_size > 0) { + (*func)(srcptr + head_size * NTILE / 2, zpptr, dstptr + head_size * NTILE, blocksize, ldzp, n_offset, + head_end, body_size, tmp, tmpsize); + } + return BTLA_CODE::Success; + } + } + assert(0); + return BTLA_CODE::NotSupport; + } else { size_t elesize = static_cast(row) * col; + return decompress_s4_s8(srcptr, dstptr, elesize, tmp, tmpsize); + } + return BTLA_CODE::Success; +} - size_t velt = utils::padto_le(elesize, 32); - size_t i = 0; - assert(tmpsize >= 32); - for (; i < velt; i += 32) { - convert_s4_s8_N_avx2<32, S4_T>(tmp, reinterpret_cast(srcptr + i / 2), vmask); - convert_s8_fp_v8(dstptr + i, tmp); - convert_s8_fp_v8(dstptr + i + 8, tmp + 8); - convert_s8_fp_v8(dstptr + i + 16, tmp + 16); - convert_s8_fp_v8(dstptr + i + 24, tmp + 24); +static inline BTLA_CODE decompress_s2_s8(utils::bit2x4* bit2ptr, int8_t* dstptr, size_t unpack_elt, int8_t* tmp, + size_t tmpsize) { + int constexpr VBits = 256; + int constexpr VElt = VBits / 8; + int i = 0; + uint64_t mask0 = 0x0303030303030303; + auto vmask0 = _mm256_set_epi64x(*(int64_t*)&mask0, *(int64_t*)&mask0, *(int64_t*)&mask0, *(int64_t*)&mask0); + auto vbias = _mm256_set1_epi8(2); + auto vshift_y = _mm256_set_epi32(6, 4, 2, 0, 6, 4, 2, 0); + auto vsfhl_mask_y = _mm256_set_epi8(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2, + 13, 9, 5, 1, 12, 8, 4, 0); + auto vorder_y = _mm256_set_epi32(1, 1, 1, 1, 0, 
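// decompress_kblock_s4_s8 (and the s2/s3 dispatchers below) split an arbitrary
// [k_offset, k_offset + row) range into a "head" that finishes the partially consumed
// quantization block and a "body" that starts on a blocksize boundary, so the per-row
// kernels can always index zero points with (k_offset + ir) / blocksize:
//
//   int head_end  = utils::padto(k_offset, blocksize);   // next block boundary
//   head_end      = std::min(head_end, k_offset + row);
//   int head_size = head_end - k_offset;                  // rows left in the current block
//   int body_size = row - head_size;                      // rows aligned to a boundary
//
// The body call advances srcptr/dstptr by head_size rows and passes head_end as the new
// k_offset, keeping the zpptr/ldzp indexing consistent across both calls.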
0, 0, 0); + int elt_pad = utils::padto_le(unpack_elt, VElt); + for (; i < elt_pad; i += VElt) { + auto vout = unpack_2bits(bit2ptr + i / 4, vshift_y, vmask0, vsfhl_mask_y, vorder_y); + vout = _mm256_sub_epi8(vout, vbias); + _mm256_storeu_si256((__m256i*)(dstptr + i), vout); + } + if (elt_pad < unpack_elt) { + if (unpack_elt >= 32) { + i = unpack_elt - 32; + auto vout = unpack_2bits(bit2ptr + i / 4, vshift_y, vmask0, vsfhl_mask_y, vorder_y); + vout = _mm256_sub_epi8(vout, vbias); + _mm256_storeu_si256((__m256i*)(dstptr + i), vout); + } else { + ref::decompress_s2_s8(bit2ptr + i / 4, dstptr + i, unpack_elt - i, tmp, tmpsize); } - for (; i < elesize; i += 2) { - auto tmp = srcptr[i / 2]; - dstptr[i + 0] = static_cast<_DST_T>(static_cast(ref::get_s8(tmp.x))); - dstptr[i + 1] = static_cast<_DST_T>(static_cast(ref::get_s8(tmp.y))); + } + return BTLA_CODE::Success; +} + +template +static inline BTLA_CODE decompress_kblock_s2_s8(utils::bit2x4* bit2ptr, int8_t* zpptr, int8_t* dstptr, int blocksize, + int ldzp, int n_offset, int k_offset, int row, int col, int8_t* tmp, + size_t tmpsize) { + if (zpptr) { + typedef BTLA_CODE (*decompfunc)(utils::bit2x4 * srcptr, int8_t * zpptr, int8_t * dstptr, int blocksize, int ldzp, + int n_offset, int k_offset, int row, int8_t* tmp, size_t tmpsize); + decompfunc func = nullptr; + if (col == NTILE) { + if constexpr (PackRow == 4) { + func = &decompress_kblock_s2_s8_pack4_row; + } + if constexpr (PackRow == 1) { + func = &decompress_kblock_s2_s8_pack1_row; + } + if constexpr (PackRow == 2) { + func = &decompress_kblock_s2_s8_pack2_row; + } + if (func) { + int head_end = utils::padto(k_offset, blocksize); + head_end = std::min(head_end, k_offset + row); + int head_size = head_end - k_offset; + if (head_size > 0) { + (*func)(bit2ptr, zpptr, dstptr, blocksize, ldzp, n_offset, k_offset, head_size, tmp, tmpsize); + } + int body_size = row - head_size; + if (body_size > 0) { + (*func)(bit2ptr + head_size * NTILE / 4, zpptr, dstptr + head_size * NTILE, blocksize, ldzp, n_offset, + head_end, body_size, tmp, tmpsize); + } + return BTLA_CODE::Success; + } } - return BTLA_CODE::Success; + assert(0); + return BTLA_CODE::NotSupport; + } else { + size_t elesize = static_cast(row) * col; + return decompress_s2_s8(bit2ptr, dstptr, elesize, tmp, tmpsize); + } + return BTLA_CODE::Success; +} + +static inline BTLA_CODE decompress_s3_s8(utils::bit2x4* bit2ptr, utils::bit1x8* bit1ptr, int8_t* dstptr, + size_t unpack_elt, int8_t* tmp, size_t tmpsize) { + int constexpr VBits = 256; + int constexpr VElt = VBits / 8; + int i = 0; + uint64_t mask0 = 0x0303030303030303; + auto vmask0 = _mm256_set_epi64x(*(int64_t*)&mask0, *(int64_t*)&mask0, *(int64_t*)&mask0, *(int64_t*)&mask0); + auto vbias = _mm256_set1_epi8(4); + auto vshift_y = _mm256_set_epi32(6, 4, 2, 0, 6, 4, 2, 0); + auto vsfhl_mask_y = _mm256_set_epi8(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2, + 13, 9, 5, 1, 12, 8, 4, 0); + auto vorder_y = _mm256_set_epi32(1, 1, 1, 1, 0, 0, 0, 0); + + const __m256i highMask = _mm256_set1_epi8(0x04); + const __m256i bit1Mask = _mm256_set1_epi32(0x0F); + const __m256i bit1Shift_1 = _mm256_set_epi32(28, 24, 20, 16, 12, 8, 4, 0); + const __m256i bit1Shift_2 = _mm256_set1_epi32((1 << 23) + (1 << 16) + (1 << 9) + (1 << 2)); + int elt_pad = utils::padto_le(unpack_elt, VElt); + for (; i < elt_pad; i += VElt) { + auto vout = unpack_2bits(bit2ptr + i / 4, vshift_y, vmask0, vsfhl_mask_y, vorder_y); + auto vb1 = unpack_1bits(bit1ptr + i / 8, bit1Shift_1, bit1Mask, 
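// 3-bit weights are stored as a 2-bit plane plus a 1-bit plane. unpack_1bits positions
// the extra bit at bit 2 (highMask = 0x04), so OR-ing it onto the 2-bit code and
// subtracting the bias 4 reconstructs the signed value. Value-level sketch (in-register
// ordering is handled by the shift/shuffle masks and not modeled here):
//
//   int8_t decode_s3(uint8_t q2 /*0..3*/, uint8_t q1 /*0 or 1*/) {
//     int v = (q1 << 2) | q2;             // 0..7
//     return static_cast<int8_t>(v - 4);  // -4..3
//   }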
bit1Shift_2, highMask); + vout = _mm256_or_si256(vout, vb1); + vout = _mm256_sub_epi8(vout, vbias); + _mm256_storeu_si256((__m256i*)(dstptr + i), vout); + } + if (elt_pad < unpack_elt) { + if (unpack_elt >= 32) { + i = unpack_elt - 32; + auto vout = unpack_2bits(bit2ptr + i / 4, vshift_y, vmask0, vsfhl_mask_y, vorder_y); + auto vb1 = unpack_1bits(bit1ptr + i / 8, bit1Shift_1, bit1Mask, bit1Shift_2, highMask); + vout = _mm256_or_si256(vout, vb1); + vout = _mm256_sub_epi8(vout, vbias); + _mm256_storeu_si256((__m256i*)(dstptr + i), vout); + } else { + ref::decompress_s3_s8(bit2ptr + i / 4, bit1ptr + i / 8, dstptr + i, unpack_elt - i, tmp, tmpsize); + } + } + return BTLA_CODE::Success; +} + +template +static inline BTLA_CODE decompress_kblock_s3_s8_pack4_row(utils::bit2x4* srcptr, utils::bit1x8* bit1ptr, int8_t* zpptr, + int8_t* dstptr, int blocksize, int ldzp, int n_offset, + int k_offset, int row, int8_t* tmp, size_t tmpsize) { + int constexpr NReg = NTILE / 8; + static_assert((NTILE % 8) == 0); + int constexpr PackRow = 4; + __m256i v_zp_y[NReg]; + const auto vindex = _mm256_set_epi8(12, 12, 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0, 12, 12, 12, 12, 8, 8, 8, 8, 4, + 4, 4, 4, 0, 0, 0, 0); + uint64_t mask0 = 0x0303030303030303; + auto vmask0 = _mm256_set_epi64x(*(int64_t*)&mask0, *(int64_t*)&mask0, *(int64_t*)&mask0, *(int64_t*)&mask0); + auto vbias = _mm256_set1_epi8(4); + auto vshift_y = _mm256_set_epi32(6, 4, 2, 0, 6, 4, 2, 0); + auto vsfhl_mask_y = _mm256_set_epi8(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2, + 13, 9, 5, 1, 12, 8, 4, 0); + auto vorder_y = _mm256_set_epi32(1, 1, 1, 1, 0, 0, 0, 0); + + const __m256i highMask = _mm256_set1_epi8(0x04); + const __m256i bit1Mask = _mm256_set1_epi32(0x0F); + const __m256i bit1Shift_1 = _mm256_set_epi32(28, 24, 20, 16, 12, 8, 4, 0); + const __m256i bit1Shift_2 = _mm256_set1_epi32((1 << 23) + (1 << 16) + (1 << 9) + (1 << 2)); + for (int ir = 0; ir < row; ir += blocksize) { + auto zptr = zpptr + (k_offset + ir) / blocksize * ldzp + n_offset; + for (int i = 0; i < NReg; i++) { + v_zp_y[i] = load_zp_epi8_broadcast_epi32(zptr + i * 8, vindex); + v_zp_y[i] = _mm256_add_epi8(v_zp_y[i], vbias); + } + int k_remain = utils::remainsize(ir, row, blocksize); + for (int ib = 0; ib < k_remain; ib += PackRow) { + auto b2ptr = srcptr + (ir + ib) * NTILE / 4; + auto b1ptr = bit1ptr + (ir + ib) * NTILE / 8; + for (int i = 0; i < NReg; i++) { + auto v_s8_y = unpack_2bits(b2ptr + i * 8, vshift_y, vmask0, vsfhl_mask_y, vorder_y); + auto vb1 = unpack_1bits(b1ptr + i * 4, bit1Shift_1, bit1Mask, bit1Shift_2, highMask); + v_s8_y = _mm256_or_si256(v_s8_y, vb1); + v_s8_y = _mm256_sub_epi8(v_s8_y, v_zp_y[i]); + _mm256_storeu_si256((__m256i*)(dstptr + i * 32 + (ir + ib) * NTILE), v_s8_y); + } + } + } + return BTLA_CODE::Success; +} + +template +static inline BTLA_CODE decompress_kblock_s3_s8_pack2_row(utils::bit2x4* srcptr, utils::bit1x8* bit1ptr, int8_t* zpptr, + int8_t* dstptr, int blocksize, int ldzp, int n_offset, + int k_offset, int row, int8_t* tmp, size_t tmpsize) { + int constexpr NReg = NTILE / 8; + static_assert((NTILE % 8) == 0); + int constexpr PackRow = 2; + int constexpr Unroll = 2; + __m256i v_zp_y[NReg]; + const auto vindex = _mm256_set_epi8(14, 14, 12, 12, 10, 10, 8, 8, 6, 6, 4, 4, 2, 2, 0, 0, 14, 14, 12, 12, 10, 10, 8, + 8, 6, 6, 4, 4, 2, 2, 0, 0); + uint64_t mask0 = 0x0303030303030303; + auto vmask0 = _mm256_set_epi64x(*(int64_t*)&mask0, *(int64_t*)&mask0, *(int64_t*)&mask0, *(int64_t*)&mask0); + auto vbias = 
_mm256_set1_epi8(4); + auto vshift_y = _mm256_set_epi32(6, 4, 2, 0, 6, 4, 2, 0); + auto vsfhl_mask_y = _mm256_set_epi8(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2, + 13, 9, 5, 1, 12, 8, 4, 0); + auto vorder_y = _mm256_set_epi32(1, 1, 1, 1, 0, 0, 0, 0); + + const __m256i highMask = _mm256_set1_epi8(0x04); + const __m256i bit1Mask = _mm256_set1_epi32(0x0F); + const __m256i bit1Shift_1 = _mm256_set_epi32(28, 24, 20, 16, 12, 8, 4, 0); + const __m256i bit1Shift_2 = _mm256_set1_epi32((1 << 23) + (1 << 16) + (1 << 9) + (1 << 2)); + + for (int ir = 0; ir < row; ir += blocksize) { + auto zptr = zpptr + (k_offset + ir) / blocksize * ldzp + n_offset; + memcpy(tmp, zptr, NTILE * sizeof(int8_t)); + memcpy(tmp + NTILE, zptr, NTILE * sizeof(int8_t)); + for (int i = 0; i < NReg; i++) { + v_zp_y[i] = load_zp_epi8_broadcast_epi16_v16(tmp + i * 16, vindex); + v_zp_y[i] = _mm256_add_epi8(v_zp_y[i], vbias); + } + int k_remain = utils::remainsize(ir, row, blocksize); + int k_remain_unrll = utils::padto_le(k_remain, PackRow * Unroll); + int ib = 0; + for (; ib < k_remain_unrll; ib += PackRow * Unroll) { + auto b2ptr = srcptr + (ir + ib) * NTILE / 4; + auto b1ptr = bit1ptr + (ir + ib) * NTILE / 8; + for (int i = 0; i < NReg; i++) { + auto v_s8_y = unpack_2bits(b2ptr + i * 8, vshift_y, vmask0, vsfhl_mask_y, vorder_y); + auto vb1 = unpack_1bits(b1ptr + i * 4, bit1Shift_1, bit1Mask, bit1Shift_2, highMask); + v_s8_y = _mm256_or_si256(v_s8_y, vb1); + v_s8_y = _mm256_sub_epi8(v_s8_y, v_zp_y[i]); + _mm256_storeu_si256((__m256i*)(dstptr + i * 32 + (ir + ib) * NTILE), v_s8_y); + } + } + int k_tail = k_remain - k_remain_unrll; + if (k_tail > 0) { + auto tmpb2ptr = tmp; + memcpy(tmpb2ptr, srcptr + (ir + ib) * NTILE / 4, k_tail * NTILE / 4); + auto tmpb1ptr = tmp + Unroll * NTILE / 2; + memcpy(tmpb1ptr, bit1ptr + (ir + ib) * NTILE / 8, k_tail * NTILE / 8); + auto tmpout = tmp + Unroll * NTILE; + for (int i = 0; i < NReg; i++) { + auto v_s8_y = unpack_2bits((utils::bit2x4*)(tmpb2ptr + i * 8), vshift_y, vmask0, vsfhl_mask_y, vorder_y); + auto vb1 = unpack_1bits((utils::bit1x8*)(tmpb1ptr + i * 4), bit1Shift_1, bit1Mask, bit1Shift_2, highMask); + v_s8_y = _mm256_or_si256(v_s8_y, vb1); + v_s8_y = _mm256_sub_epi8(v_s8_y, v_zp_y[i]); + _mm256_storeu_si256((__m256i*)(tmpout + i * 32), v_s8_y); + } + memcpy(dstptr + (ir + ib) * NTILE, tmpout, k_tail * NTILE); + } + } + return BTLA_CODE::Success; +} + +template +static inline BTLA_CODE decompress_kblock_s3_s8_pack1_row(utils::bit2x4* srcptr, utils::bit1x8* bit1ptr, int8_t* zpptr, + int8_t* dstptr, int blocksize, int ldzp, int n_offset, + int k_offset, int row, int8_t* tmp, size_t tmpsize) { + int constexpr NReg = NTILE / 8; + static_assert((NTILE % 8) == 0); + int constexpr PackRow = 1; + int constexpr Unroll = 4; + int constexpr UnpackLoop = Unroll * NTILE / 32; + __m256i v_zp_y[UnpackLoop]; + uint64_t mask0 = 0x0303030303030303; + auto vmask0 = _mm256_set_epi64x(*(int64_t*)&mask0, *(int64_t*)&mask0, *(int64_t*)&mask0, *(int64_t*)&mask0); + auto vbias = _mm256_set1_epi8(4); + auto vshift_y = _mm256_set_epi32(6, 4, 2, 0, 6, 4, 2, 0); + auto vsfhl_mask_y = _mm256_set_epi8(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2, + 13, 9, 5, 1, 12, 8, 4, 0); + auto vorder_y = _mm256_set_epi32(1, 1, 1, 1, 0, 0, 0, 0); + + const __m256i highMask = _mm256_set1_epi8(0x04); + const __m256i bit1Mask = _mm256_set1_epi32(0x0F); + const __m256i bit1Shift_1 = _mm256_set_epi32(28, 24, 20, 16, 12, 8, 4, 0); + const __m256i bit1Shift_2 = 
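// Tail handling in the pack2/pack1 row kernels: when k_remain is not a multiple of the
// unroll factor, the leftover packed bytes are first memcpy'd into tmp so the unpack
// helpers can keep issuing full 32-byte loads and stores, and only k_tail * NTILE valid
// bytes are copied back out. The pattern, simplified (unpack_full_vectors stands in for
// the unpack_2bits/unpack_1bits plus zero-point subtraction done above):
//
//   memcpy(tmp, packed_tail, tail_packed_bytes);   // pad the read side
//   unpack_full_vectors(tmp, tmp_out);             // always full-width SIMD
//   memcpy(dst_tail, tmp_out, k_tail * NTILE);     // keep only the valid bytes
//
// This keeps the SIMD path branch-free at the cost of a small staging area inside tmp.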
_mm256_set1_epi32((1 << 23) + (1 << 16) + (1 << 9) + (1 << 2)); + for (int ir = 0; ir < row; ir += blocksize) { + auto zptr = zpptr + (k_offset + ir) / blocksize * ldzp + n_offset; + for (int i = 0; i < Unroll; i++) { + memcpy(tmp + i * NTILE, zptr, NTILE * sizeof(int8_t)); + } + for (int i = 0; i < UnpackLoop; i++) { + v_zp_y[i] = _mm256_loadu_si256((const __m256i*)(tmp + i * 32)); + v_zp_y[i] = _mm256_add_epi8(v_zp_y[i], vbias); + } + int k_remain = utils::remainsize(ir, row, blocksize); + int k_remain_unrll = utils::padto_le(k_remain, Unroll); + int ib = 0; + for (; ib < k_remain_unrll; ib += Unroll) { + auto b2ptr = srcptr + (ir + ib) * NTILE / 4; + auto b1ptr = bit1ptr + (ir + ib) * NTILE / 8; + for (int i = 0; i < UnpackLoop; i++) { + auto v_s8_y = unpack_2bits(b2ptr + i * 8, vshift_y, vmask0, vsfhl_mask_y, vorder_y); + auto vb1 = unpack_1bits(b1ptr + i * 4, bit1Shift_1, bit1Mask, bit1Shift_2, highMask); + v_s8_y = _mm256_or_si256(v_s8_y, vb1); + v_s8_y = _mm256_sub_epi8(v_s8_y, v_zp_y[i]); + _mm256_storeu_si256((__m256i*)(dstptr + i * 32 + (ir + ib) * NTILE), v_s8_y); + } + } + + int k_tail = k_remain - k_remain_unrll; + if (k_tail > 0) { + auto tmpb2ptr = tmp; + memcpy(tmpb2ptr, srcptr + (ir + ib) * NTILE / 4, k_tail * NTILE / 4); + auto tmpb1ptr = tmp + Unroll * NTILE / 2; + memcpy(tmpb1ptr, bit1ptr + (ir + ib) * NTILE / 8, k_tail * NTILE / 8); + auto tmpout = tmp + Unroll * NTILE; + for (int i = 0; i < UnpackLoop; i++) { + auto v_s8_y = unpack_2bits((utils::bit2x4*)(tmpb2ptr + i * 8), vshift_y, vmask0, vsfhl_mask_y, vorder_y); + auto vb1 = unpack_1bits((utils::bit1x8*)(tmpb1ptr + i * 4), bit1Shift_1, bit1Mask, bit1Shift_2, highMask); + v_s8_y = _mm256_or_si256(v_s8_y, vb1); + v_s8_y = _mm256_sub_epi8(v_s8_y, v_zp_y[i]); + _mm256_storeu_si256((__m256i*)(tmpout + i * 32), v_s8_y); + } + memcpy(dstptr + (ir + ib) * NTILE, tmpout, k_tail * NTILE); + } + } + return BTLA_CODE::Success; +} + +template +static inline BTLA_CODE decompress_kblock_s3_s8(utils::bit2x4* bit2ptr, utils::bit1x8* bit1ptr, int8_t* zpptr, + int8_t* dstptr, int blocksize, int ldzp, int n_offset, int k_offset, + int row, int col, int8_t* tmp, size_t tmpsize) { + if (zpptr) { + typedef BTLA_CODE (*decompfunc)(utils::bit2x4 * bit2ptr, utils::bit1x8 * bit1ptr, int8_t * zpptr, int8_t * dstptr, + int blocksize, int ldzp, int n_offset, int k_offset, int row, int8_t* tmp, + size_t tmpsize); + decompfunc func = nullptr; + if (col == NTILE) { + if constexpr (PackRow == 1) { + func = &decompress_kblock_s3_s8_pack1_row; + } + if constexpr (PackRow == 2) { + func = &decompress_kblock_s3_s8_pack2_row; + } + if constexpr (PackRow == 4) { + func = &decompress_kblock_s3_s8_pack4_row; + } + if (func) { + int head_end = utils::padto(k_offset, blocksize); + head_end = std::min(head_end, k_offset + row); + int head_size = head_end - k_offset; + if (head_size > 0) { + (*func)(bit2ptr, bit1ptr, zpptr, dstptr, blocksize, ldzp, n_offset, k_offset, head_size, tmp, tmpsize); + } + int body_size = row - head_size; + if (body_size > 0) { + (*func)(bit2ptr + head_size * NTILE / 4, bit1ptr + head_size * NTILE / 8, zpptr, dstptr + head_size * NTILE, + blocksize, ldzp, n_offset, head_end, body_size, tmp, tmpsize); + } + return BTLA_CODE::Success; + } + } + assert(0); + return BTLA_CODE::NotSupport; + } else { + size_t elesize = static_cast(row) * col; + return decompress_s3_s8(bit2ptr, bit1ptr, dstptr, elesize, tmp, tmpsize); } return BTLA_CODE::Success; } @@ -480,28 +1124,6 @@ inline BTLA_CODE decompress_kblock_f8_fp(utils::f8* srcptr, _DST_T* 
dstptr, int return BTLA_CODE::Success; } -template -inline BTLA_CODE decompress_kblock_s8_s8fp(int8_t* srcptr, DST_T* dstptr, int row, int col, int ld_src, int ld_dst) { - if (col == ld_src) { - size_t elesize = (size_t)row * col; - size_t ele64 = utils::padto_le(elesize, 64); - size_t i = 0; - if (i + 64 <= ele64) { - for (; i < ele64; i += 64) { - for (size_t j = 0; j < 64; j += 8) { - convert_s8_fp_v8(dstptr + i + j, srcptr + i + j); - } - } - } - for (; i < elesize; i += 1) { - auto tmp = srcptr[i]; - dstptr[i] = static_cast(static_cast(tmp)); - } - return BTLA_CODE::Success; - } - return BTLA_CODE::NotSupport; -} - template static inline BTLA_CODE accum_alphaN_f32_f32(const SCA_T* alpha, const float* srcptr, const int srcstep, float* dstptr, const int dststep, const int M, const int N) { @@ -565,10 +1187,47 @@ static inline void dequant_f4_N(_DST_T* dstptr, int8_t* srcptr, __m256* vscales, } } +template +static inline void convert_s4_s8_N_avx2(int8_t* dstptr, int8_t* srcptr, __m256i mask) { + static_assert(N % 2 == 0); + static_assert(N <= 64); + const auto vbias = _mm256_set1_epi8(8); + if constexpr (N == 32) { + auto dst0 = unpack_4bits(srcptr, mask); + if constexpr (QT_T == BTLA_DTYPE::S4_CLIP) { + dst0 = _mm256_sub_epi8(dst0, vbias); + } + _mm256_storeu_si256(reinterpret_cast<__m256i*>(dstptr), dst0); + } else if constexpr (N > 32) { + auto dst0 = unpack_4bits(srcptr, mask); + if constexpr (QT_T == BTLA_DTYPE::S4_CLIP) { + dst0 = _mm256_sub_epi8(dst0, vbias); + } + _mm256_storeu_si256(reinterpret_cast<__m256i*>(dstptr), dst0); + int8_t temp[32]; + memcpy(temp, srcptr + 16, (N - 32) / 2); + dst0 = unpack_4bits(temp, mask); + if constexpr (QT_T == BTLA_DTYPE::S4_CLIP) { + dst0 = _mm256_sub_epi8(dst0, vbias); + } + _mm256_storeu_si256(reinterpret_cast<__m256i*>(temp), dst0); + memcpy(dstptr + 32, temp, (N - 32)); + } else { + int8_t temp[32]; + memcpy(temp, srcptr, N / 2); + auto dst0 = unpack_4bits(temp, mask); + if constexpr (QT_T == BTLA_DTYPE::S4_CLIP) { + dst0 = _mm256_sub_epi8(dst0, vbias); + } + _mm256_storeu_si256(reinterpret_cast<__m256i*>(temp), dst0); + memcpy(dstptr, temp, N); + } +} + template inline BTLA_CODE decompress_kblock_f4_fp_noscale(utils::f4x2* srcptr, DST_T* dstptr, int row, int col, int ld_src, int ld_dst, int8_t* tmp, size_t tmpsize) { - uint32_t mask = 0xf0f0f0f0; + uint32_t mask = 0x0f0f0f0f; auto vmask = _mm256_set1_epi32(*reinterpret_cast(&mask)); float* LUT; static_assert(F4_T == BTLA_DTYPE::F4_BNB || F4_T == BTLA_DTYPE::F4_NF4 || F4_T == BTLA_DTYPE::F4_E2M1, @@ -606,7 +1265,7 @@ static inline BTLA_CODE decompress_kblock_bit4_packrow1(utils::bit4x2* srcptr, _ int ld_src, int ld_dst, _ST* scales, int8_t* zero_points, int k_offset, int kblock, int NPad, int8_t* tmpbuf, size_t tmpsize) { - uint32_t mask = 0xf0f0f0f0; + uint32_t mask = 0x0f0f0f0f; auto vmask = _mm256_set1_epi32(*reinterpret_cast(&mask)); float* LUT = nullptr; if constexpr (QT_T == BTLA_DTYPE::F4_BNB) { @@ -721,38 +1380,335 @@ static inline BTLA_CODE decompress_kblock_bit4_packrow2(utils::bit4x2* srcptr, _ return BTLA_CODE::NotSupport; } -template -static inline BTLA_CODE decompress_kblock_s4_fp(utils::int4x2* srcptr, _DST_T* dstptr, int row, int col, int ld_src, - int ld_dst, _ST* scales, int8_t* zero_points, int k_offset, int kblock, - int NPad, int8_t* tmp, size_t tmpsize) { - auto ret = BTLA_CODE::NotSupport; - if constexpr (_PACK_ROW == 1 && std::is_same_v<_DST_T, float> && std::is_same_v<_ST, float>) { - if (zero_points == nullptr) { - if (col == 24) { - ret = 
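// convert_s4_s8_N_avx2 and the corrected 0x0f0f0f0f masks reflect the low-nibble-first
// layout of utils::int4x2: each byte holds element 2j in its low nibble and element
// 2j+1 in its high nibble, and S4_CLIP codes carry a bias of 8. Scalar reference for one
// packed byte (the low/high ordering follows the ref:: tail path and is an assumption of
// this sketch, not something the vector code spells out):
//
//   void unpack_int4x2_ref(uint8_t b, int8_t* dst) {
//     dst[0] = static_cast<int8_t>((b & 0x0f) - 8);         // low nibble first
//     dst[1] = static_cast<int8_t>(((b >> 4) & 0x0f) - 8);  // then high nibble
//   }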
decompress_kblock_bit4_packrow1(srcptr, dstptr, row, col, ld_src, ld_dst, scales, - zero_points, k_offset, kblock, NPad, - reinterpret_cast(tmp), tmpsize); - } else if (col == 48) { - ret = decompress_kblock_bit4_packrow1(srcptr, dstptr, row, col, ld_src, ld_dst, scales, - zero_points, k_offset, kblock, NPad, - reinterpret_cast(tmp), tmpsize); +template +inline BTLA_CODE decompress_kblock_s8_fp_row(int8_t* srcptr, DST_T* dstptr, int row, void* scales_, BTLA_DTYPE sdtype, + int8_t* zero_points, int k_offset, int n_offset, int blocksize, int ldzp, + int8_t* tmp, size_t tmpsize) { + int constexpr NReg = NTILE / 8; + const auto DstSize = row * NTILE * sizeof(DST_T); + const auto S8Size = row * NTILE * sizeof(int8_t); + if (zero_points == nullptr) { + for (int ir = 0; ir < row; ir += blocksize) { + int k_remain = utils::remainsize(ir, row, blocksize); + int ele_off = (k_offset + ir) / blocksize * ldzp + n_offset; + if constexpr (PackRow == 1) { + __m256 vscale_y[NReg]; + if (sdtype == BTLA_DTYPE::F32) { + auto sptr = (float*)scales_ + ele_off; + for (int i = 0; i < NReg; i++) vscale_y[i] = _mm256_loadu_ps(sptr + i * 8); + } else if (sdtype == BTLA_DTYPE::BF16) { + auto sptr = (utils::bf16*)scales_ + ele_off; + for (int i = 0; i < NReg; i++) vscale_y[i] = load_bf16_fp32(sptr + i * 8); + } + for (int ib = 0; ib < k_remain; ib += PackRow) { + auto b8ptr = srcptr + (ir + ib) * NTILE; + for (int i = 0; i < NReg; i++) { + auto vdeq_y = dequant_s8_fp(b8ptr + i * 8, vscale_y[i]); + store_fp_T(vdeq_y, dstptr + (ir + ib) * NTILE + i * 8); + } + } + } else if constexpr (PackRow == 4) { + const auto vshuf_index_y = _mm256_set_epi8(15, 14, 13, 12, 15, 14, 13, 12, 11, 10, 9, 8, 11, 10, 9, 8, 7, 6, 5, + 4, 7, 6, 5, 4, 3, 2, 1, 0, 3, 2, 1, 0); + __m256 vscale_y[PackRow * NReg]; + for (int i = 0; i < NReg; i++) { + __m256 vraw; + if (sdtype == BTLA_DTYPE::F32) { + auto sptr = (float*)scales_ + ele_off; + vraw = _mm256_loadu_ps(sptr + i * 8); + } else if (sdtype == BTLA_DTYPE::BF16) { + auto sptr = (utils::bf16*)scales_ + ele_off; + vraw = load_bf16_fp32(sptr + i * 8); + } else { + assert(0); + } + auto vcast_y = broadcast_ps_1_2(vraw, vshuf_index_y); + vscale_y[i * PackRow + 0] = broadcast_ps_1_2(vcast_y, vshuf_index_y); + vscale_y[i * PackRow + 1] = broadcast_ps_1_2(vcast_y, vshuf_index_y); + vcast_y = broadcast_ps_1_2(vraw, vshuf_index_y); + vscale_y[i * PackRow + 2] = broadcast_ps_1_2(vcast_y, vshuf_index_y); + vscale_y[i * PackRow + 3] = broadcast_ps_1_2(vcast_y, vshuf_index_y); + } + for (int ib = 0; ib < k_remain; ib += PackRow) { + auto b8ptr = srcptr + (ir + ib) * NTILE; + for (int i = 0; i < NReg; i++) { + for (int ip = 0; ip < PackRow; ip++) { + auto vdeq_y = dequant_s8_fp(b8ptr + i * 8 * PackRow + ip * 8, vscale_y[i * PackRow + ip]); + store_fp_T(vdeq_y, dstptr + (ir + ib) * NTILE + i * 8 * PackRow + ip * 8); + } + } + } + } else if constexpr (PackRow == 2) { + const auto vshuf_index_y = _mm256_set_epi8(15, 14, 13, 12, 15, 14, 13, 12, 11, 10, 9, 8, 11, 10, 9, 8, 7, 6, 5, + 4, 7, 6, 5, 4, 3, 2, 1, 0, 3, 2, 1, 0); + __m256 vscale_y[PackRow * NReg]; + for (int i = 0; i < NReg; i++) { + __m256 vraw; + if (sdtype == BTLA_DTYPE::F32) { + auto sptr = (float*)scales_ + ele_off; + vraw = _mm256_loadu_ps(sptr + i * 8); + } else if (sdtype == BTLA_DTYPE::BF16) { + auto sptr = (utils::bf16*)scales_ + ele_off; + vraw = load_bf16_fp32(sptr + i * 8); + } + vscale_y[i * PackRow + 0] = broadcast_ps_1_2(vraw, vshuf_index_y); + vscale_y[i * PackRow + 1] = broadcast_ps_1_2(vraw, vshuf_index_y); + } + for (int ib = 0; ib 
< k_remain; ib += PackRow) { + auto b8ptr = srcptr + (ir + ib) * NTILE; + for (int i = 0; i < NReg; i++) { + for (int ip = 0; ip < PackRow; ip++) { + auto vdeq_y = dequant_s8_fp(b8ptr + i * 8 * PackRow + ip * 8, vscale_y[i * PackRow + ip]); + store_fp_T(vdeq_y, dstptr + (ir + ib) * NTILE + i * 8 * PackRow + ip * 8); + } + } + } } else { assert(0); } - - } else { - if (col == 24) { - ret = decompress_kblock_bit4_packrow1(srcptr, dstptr, row, col, ld_src, ld_dst, scales, - zero_points, k_offset, kblock, NPad, - reinterpret_cast(tmp), tmpsize); - } else if (col == 48) { - ret = decompress_kblock_bit4_packrow1(srcptr, dstptr, row, col, ld_src, ld_dst, scales, - zero_points, k_offset, kblock, NPad, - reinterpret_cast(tmp), tmpsize); + } + return BTLA_CODE::Success; + } else { + for (int ir = 0; ir < row; ir += blocksize) { + int k_remain = utils::remainsize(ir, row, blocksize); + int ele_off = (k_offset + ir) / blocksize * ldzp + n_offset; + if constexpr (PackRow == 1) { + __m256 vscale_y[NReg]; + if (sdtype == BTLA_DTYPE::F32) { + auto sptr = (float*)scales_ + ele_off; + for (int i = 0; i < NReg; i++) vscale_y[i] = _mm256_loadu_ps(sptr + i * 8); + } else if (sdtype == BTLA_DTYPE::BF16) { + auto sptr = (utils::bf16*)scales_ + ele_off; + for (int i = 0; i < NReg; i++) vscale_y[i] = load_bf16_fp32(sptr + i * 8); + } + __m256i vzp_y[NReg]; + for (int i = 0; i < NReg; i++) vzp_y[i] = load_s8_s32(zero_points + ele_off + i * 8); + for (int ib = 0; ib < k_remain; ib += PackRow) { + auto b8ptr = srcptr + (ir + ib) * NTILE; + for (int i = 0; i < NReg; i++) { + auto vdeq_y = dequant_s8_fp(b8ptr + i * 8, vscale_y[i], vzp_y[i]); + store_fp_T(vdeq_y, dstptr + (ir + ib) * NTILE + i * 8); + } + } + } else if constexpr (PackRow == 4) { + const auto vshuf_index_y = _mm256_set_epi8(15, 14, 13, 12, 15, 14, 13, 12, 11, 10, 9, 8, 11, 10, 9, 8, 7, 6, 5, + 4, 7, 6, 5, 4, 3, 2, 1, 0, 3, 2, 1, 0); + __m256 vscale_y[PackRow * NReg]; + __m256i vzp_y[PackRow * NReg]; + for (int i = 0; i < NReg; i++) { + __m256 vraw; + if (sdtype == BTLA_DTYPE::F32) { + auto sptr = (float*)scales_ + ele_off; + vraw = _mm256_loadu_ps(sptr + i * 8); + } else if (sdtype == BTLA_DTYPE::BF16) { + auto sptr = (utils::bf16*)scales_ + ele_off; + vraw = load_bf16_fp32(sptr + i * 8); + } else { + assert(0); + } + auto vcast_y = broadcast_ps_1_2(vraw, vshuf_index_y); + vscale_y[i * PackRow + 0] = broadcast_ps_1_2(vcast_y, vshuf_index_y); + vscale_y[i * PackRow + 1] = broadcast_ps_1_2(vcast_y, vshuf_index_y); + vcast_y = broadcast_ps_1_2(vraw, vshuf_index_y); + vscale_y[i * PackRow + 2] = broadcast_ps_1_2(vcast_y, vshuf_index_y); + vscale_y[i * PackRow + 3] = broadcast_ps_1_2(vcast_y, vshuf_index_y); + + auto tmp = load_s8_s32(zero_points + ele_off + i * 8); + auto vcasti_y = broadcast_epi32_1_2(tmp, vshuf_index_y); + vzp_y[i * PackRow + 0] = broadcast_epi32_1_2(vcasti_y, vshuf_index_y); + vzp_y[i * PackRow + 1] = broadcast_epi32_1_2(vcasti_y, vshuf_index_y); + vcasti_y = broadcast_epi32_1_2(tmp, vshuf_index_y); + vzp_y[i * PackRow + 2] = broadcast_epi32_1_2(vcasti_y, vshuf_index_y); + vzp_y[i * PackRow + 3] = broadcast_epi32_1_2(vcasti_y, vshuf_index_y); + } + for (int ib = 0; ib < k_remain; ib += PackRow) { + auto b8ptr = srcptr + (ir + ib) * NTILE; + for (int i = 0; i < NReg; i++) { + for (int ip = 0; ip < PackRow; ip++) { + auto vdeq_y = dequant_s8_fp(b8ptr + i * 8 * PackRow + ip * 8, vscale_y[i * PackRow + ip], + vzp_y[i * PackRow + ip]); + store_fp_T(vdeq_y, dstptr + (ir + ib) * NTILE + i * 8 * PackRow + ip * 8); + } + } + } + } else if 
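// Scale handling in decompress_kblock_s8_fp_row: for PackRow > 1 the s8 weights store
// PackRow consecutive k values per column interleaved within each 32-byte group, so the
// single per-column scale (and zero point, when zero_points != nullptr) is replicated
// PackRow times into matching lanes before the fused dequant. The element-level math
// that dequant_s8_fp is assumed to implement is simply:
//
//   // zp is 0 on the no-zero-point path
//   float dequant_s8_ref(int8_t w, int8_t zp, float scale) {
//     return (static_cast<float>(w) - static_cast<float>(zp)) * scale;
//   }
//
// followed by a store as float or bf16 depending on DST_T.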
constexpr (PackRow == 2) { + const auto vshuf_index_y = _mm256_set_epi8(15, 14, 13, 12, 15, 14, 13, 12, 11, 10, 9, 8, 11, 10, 9, 8, 7, 6, 5, + 4, 7, 6, 5, 4, 3, 2, 1, 0, 3, 2, 1, 0); + __m256 vscale_y[PackRow * NReg]; + __m256i vzp_y[PackRow * NReg]; + for (int i = 0; i < NReg; i++) { + __m256 vraw; + if (sdtype == BTLA_DTYPE::F32) { + auto sptr = (float*)scales_ + ele_off; + vraw = _mm256_loadu_ps(sptr + i * 8); + } else if (sdtype == BTLA_DTYPE::BF16) { + auto sptr = (utils::bf16*)scales_ + ele_off; + vraw = load_bf16_fp32(sptr + i * 8); + } + vscale_y[i * PackRow + 0] = broadcast_ps_1_2(vraw, vshuf_index_y); + vscale_y[i * PackRow + 1] = broadcast_ps_1_2(vraw, vshuf_index_y); + auto tmp = load_s8_s32(zero_points + ele_off + i * 8); + vzp_y[i * PackRow + 0] = broadcast_epi32_1_2(tmp, vshuf_index_y); + vzp_y[i * PackRow + 1] = broadcast_epi32_1_2(tmp, vshuf_index_y); + } + for (int ib = 0; ib < k_remain; ib += PackRow) { + auto b8ptr = srcptr + (ir + ib) * NTILE; + for (int i = 0; i < NReg; i++) { + for (int ip = 0; ip < PackRow; ip++) { + auto vdeq_y = dequant_s8_fp(b8ptr + i * 8 * PackRow + ip * 8, vscale_y[i * PackRow + ip], + vzp_y[i * PackRow + ip]); + store_fp_T(vdeq_y, dstptr + (ir + ib) * NTILE + i * 8 * PackRow + ip * 8); + } + } + } } else { assert(0); } } + return BTLA_CODE::Success; + } +} + +template +inline BTLA_CODE decompress_kblock_s8_fp(int8_t* srcptr, DST_T* dstptr, int row, int col, void* scales_, + BTLA_DTYPE sdtype, int8_t* zero_points, int k_offset, int n_offset, + int blocksize, int ldzp, int8_t* tmp, size_t tmpsize) { + auto ret = BTLA_CODE::NotSupport; + if (col == NTILE) { + int head_end = utils::padto(k_offset, blocksize); + head_end = std::min(head_end, k_offset + row); + int head_size = head_end - k_offset; + if (head_size > 0) { + decompress_kblock_s8_fp_row(srcptr, dstptr, head_size, scales_, sdtype, zero_points, + k_offset, n_offset, blocksize, ldzp, tmp, tmpsize); + } + int body_size = row - head_size; + if (body_size > 0) { + decompress_kblock_s8_fp_row(srcptr + head_size * NTILE, dstptr + head_size * NTILE, + body_size, scales_, sdtype, zero_points, head_end, n_offset, + blocksize, ldzp, tmp, tmpsize); + } + return BTLA_CODE::Success; + } + return ret; +} + +template +inline BTLA_CODE decompress_kblock_s4_fp_row(utils::int4x2* srcptr, DST_T* dstptr, int row, void* scales_, + BTLA_DTYPE sdtype, int8_t* zero_points, int k_offset, int n_offset, + int blocksize, int ldzp, int8_t* tmp, size_t tmpsize) { + int constexpr NReg = NTILE / 8; + const auto DstSize = row * NTILE * sizeof(DST_T); + const auto S8Size = row * NTILE * sizeof(int8_t); + auto tmps8ptr = (int8_t*)dstptr; + tmps8ptr += DstSize - S8Size; + auto ret = decompress_kblock_s4_s8(srcptr, zero_points, tmps8ptr, blocksize, ldzp, n_offset, k_offset, + row, NTILE, tmp, tmpsize); + assert(ret == BTLA_CODE::Success); + return decompress_kblock_s8_fp_row(tmps8ptr, dstptr, row, scales_, sdtype, nullptr, k_offset, + n_offset, blocksize, ldzp, tmp, tmpsize); +} + +template +inline BTLA_CODE decompress_kblock_s4_fp(utils::int4x2* srcptr, DST_T* dstptr, int row, int col, void* scales_, + BTLA_DTYPE sdtype, int8_t* zero_points, int k_offset, int n_offset, + int blocksize, int ldzp, int8_t* tmp, size_t tmpsize) { + auto ret = BTLA_CODE::NotSupport; + if (col == NTILE) { + int head_end = utils::padto(k_offset, blocksize); + head_end = std::min(head_end, k_offset + row); + int head_size = head_end - k_offset; + if (head_size > 0) { + decompress_kblock_s4_fp_row(srcptr, dstptr, head_size, scales_, sdtype, 
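// decompress_kblock_s4_fp_row reuses the destination buffer as its s8 scratch: the 4-bit
// codes are expanded to int8 at tmps8ptr = (int8_t*)dstptr + DstSize - S8Size, i.e. into
// the last row * NTILE bytes of the output, and decompress_kblock_s8_fp_row then
// dequantizes front to back. At element granularity this cannot clobber unread data:
// after writing output element i the write cursor sits at sizeof(DST_T) * (i + 1) bytes,
// which never exceeds (DstSize - S8Size) + (i + 1), the offset of the next s8 value to
// read, for any sizeof(DST_T) >= 1. The vector loops load each s8 chunk before storing
// its dequantized result, so the same argument carries over at 32-byte granularity.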
zero_points, + k_offset, n_offset, blocksize, ldzp, tmp, tmpsize); + } + int body_size = row - head_size; + if (body_size > 0) { + decompress_kblock_s4_fp_row(srcptr + head_size * NTILE / 2, dstptr + head_size * NTILE, + body_size, scales_, sdtype, zero_points, head_end, n_offset, + blocksize, ldzp, tmp, tmpsize); + } + return BTLA_CODE::Success; + } + return ret; +} + +template +inline BTLA_CODE decompress_kblock_s3_fp_row(utils::bit2x4* b2ptr, utils::bit1x8* b1ptr, DST_T* dstptr, int row, + void* scales_, BTLA_DTYPE sdtype, int8_t* zero_points, int k_offset, + int n_offset, int blocksize, int ldzp, int8_t* tmp, size_t tmpsize) { + int constexpr NReg = NTILE / 8; + const auto DstSize = row * NTILE * sizeof(DST_T); + const auto S8Size = row * NTILE * sizeof(int8_t); + auto tmps8ptr = (int8_t*)dstptr; + tmps8ptr += DstSize - S8Size; + auto ret = decompress_kblock_s3_s8(b2ptr, b1ptr, zero_points, tmps8ptr, blocksize, ldzp, n_offset, + k_offset, row, NTILE, tmp, tmpsize); + assert(ret == BTLA_CODE::Success); + return decompress_kblock_s8_fp_row(tmps8ptr, dstptr, row, scales_, sdtype, nullptr, k_offset, + n_offset, blocksize, ldzp, tmp, tmpsize); +} + +template +inline BTLA_CODE decompress_kblock_s3_fp(utils::bit2x4* b2ptr, utils::bit1x8* b1ptr, DST_T* dstptr, int row, int col, + void* scales_, BTLA_DTYPE sdtype, int8_t* zero_points, int k_offset, + int n_offset, int blocksize, int ldzp, int8_t* tmp, size_t tmpsize) { + auto ret = BTLA_CODE::NotSupport; + if (col == NTILE) { + int head_end = utils::padto(k_offset, blocksize); + head_end = std::min(head_end, k_offset + row); + int head_size = head_end - k_offset; + if (head_size > 0) { + decompress_kblock_s3_fp_row(b2ptr, b1ptr, dstptr, head_size, scales_, sdtype, zero_points, + k_offset, n_offset, blocksize, ldzp, tmp, tmpsize); + } + int body_size = row - head_size; + if (body_size > 0) { + decompress_kblock_s3_fp_row( + b2ptr + head_size * NTILE / 4, b1ptr + head_size * NTILE / 8, dstptr + head_size * NTILE, body_size, scales_, + sdtype, zero_points, head_end, n_offset, blocksize, ldzp, tmp, tmpsize); + } + return BTLA_CODE::Success; + } + return ret; +} + +template +inline BTLA_CODE decompress_kblock_s2_fp_row(utils::bit2x4* b2ptr, DST_T* dstptr, int row, void* scales_, + BTLA_DTYPE sdtype, int8_t* zero_points, int k_offset, int n_offset, + int blocksize, int ldzp, int8_t* tmp, size_t tmpsize) { + int constexpr NReg = NTILE / 8; + const auto DstSize = row * NTILE * sizeof(DST_T); + const auto S8Size = row * NTILE * sizeof(int8_t); + auto tmps8ptr = (int8_t*)dstptr; + tmps8ptr += DstSize - S8Size; + auto ret = decompress_kblock_s2_s8(b2ptr, zero_points, tmps8ptr, blocksize, ldzp, n_offset, k_offset, + row, NTILE, tmp, tmpsize); + assert(ret == BTLA_CODE::Success); + return decompress_kblock_s8_fp_row(tmps8ptr, dstptr, row, scales_, sdtype, nullptr, k_offset, + n_offset, blocksize, ldzp, tmp, tmpsize); +} + +template +inline BTLA_CODE decompress_kblock_s2_fp(utils::bit2x4* b2ptr, DST_T* dstptr, int row, int col, void* scales_, + BTLA_DTYPE sdtype, int8_t* zero_points, int k_offset, int n_offset, + int blocksize, int ldzp, int8_t* tmp, size_t tmpsize) { + auto ret = BTLA_CODE::NotSupport; + if (col == NTILE) { + int head_end = utils::padto(k_offset, blocksize); + head_end = std::min(head_end, k_offset + row); + int head_size = head_end - k_offset; + if (head_size > 0) { + decompress_kblock_s2_fp_row(b2ptr, dstptr, head_size, scales_, sdtype, zero_points, + k_offset, n_offset, blocksize, ldzp, tmp, tmpsize); + } + int body_size = row - 
head_size; + if (body_size > 0) { + decompress_kblock_s2_fp_row(b2ptr + head_size * NTILE / 4, dstptr + head_size * NTILE, + body_size, scales_, sdtype, zero_points, head_end, n_offset, + blocksize, ldzp, tmp, tmpsize); + } + return BTLA_CODE::Success; } return ret; } @@ -1177,7 +2133,7 @@ inline BTLA_CODE decompress_kblock_s3_s8fp(utils::bit2x4* bit2ptr, utils::bit1x8 auto bit2x32 = _mm256_and_si256(lowMask, _mm256_srli_epi16(bit2_data, 2 * i)); auto res = _mm256_add_epi8(bit1x32, bit2x32); - res = _mm256_slli_epi32(res, 5); + res = _mm256_sub_epi8(res, highMask); _mm256_storeu_si256((__m256i*)(dst + 32 * i), res); } }; @@ -1230,6 +2186,54 @@ static inline BTLA_CODE decompress_kblock_bit3_packrow_fp(utils::bit2x4* bit2ptr return BTLA_CODE::Success; } +template +inline BTLA_CODE decompress_kblock_s2_s8fp(utils::bit2x4* bit2ptr, _DST_T* dstptr, int unpack_elt, int8_t* tmp, + size_t tmpsize) { + int constexpr VBits = 256; + int constexpr VElt = VBits / 8; + int i = 0; + uint64_t mask0 = 0x0303030303030303; + auto vmask0 = _mm256_set_epi64x(*(int64_t*)&mask0, *(int64_t*)&mask0, *(int64_t*)&mask0, *(int64_t*)&mask0); + auto vshift_y = _mm256_set_epi32(6, 4, 2, 0, 6, 4, 2, 0); + auto vsfhl_mask_y = _mm256_set_epi8(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2, + 13, 9, 5, 1, 12, 8, 4, 0); + auto vorder_y = _mm256_set_epi32(1, 1, 1, 1, 0, 0, 0, 0); + int elt_pad = utils::padto_le(unpack_elt, VElt); + for (; i < elt_pad; i += VElt) { + auto vout = unpack_2bits(bit2ptr + i / 4, vshift_y, vmask0, vsfhl_mask_y, vorder_y); + if (std::is_same_v<_DST_T, int8_t>) { + _mm256_storeu_si256((__m256i*)(dstptr + i), vout); + } else { + _mm256_storeu_si256((__m256i*)tmp, vout); + for (int j = 0; j < VElt; j++) { + dstptr[i + j] = tmp[j]; + } + } + } + ref::decompress_kblock_s2_s8fp(bit2ptr + i / 4, dstptr + i, unpack_elt - i, tmp, tmpsize); + return BTLA_CODE::Success; +} + +template +static inline BTLA_CODE decompress_kblock_bit2_packrow_fp(utils::bit2x4* bit2ptr, _DST_T* dstptr, int row, int col, + _ST* scales, int8_t* zero_points, int k_offset, int kblock, + int NPad, void* tmp, size_t tmpsize) { + auto unpack_elt = row * col; + decompress_kblock_s2_s8fp<_S2_T>(bit2ptr, dstptr, unpack_elt, reinterpret_cast(tmp), tmpsize); + // TODO(zhe): simd version + for (int i = 0; i < row; i++) { + int kpos = (k_offset + i) / kblock; + auto sptr = scales + kpos * NPad; + for (int j = 0; j < col; j++) { + float tmp = static_cast(dstptr[i * col + j]); + if (zero_points != nullptr) tmp -= static_cast(zero_points[kpos * NPad + j / _PACK_ROW]); + dstptr[i * col + j] = static_cast<_DST_T>(tmp * sptr[j / _PACK_ROW]); + } + } + + return BTLA_CODE::Success; +} + inline __m256 poly_scale_2nd_ps(const __m256i z, const __m256 f, const __m256 c0, const __m256 c1, const __m256 c2) { const auto y = _mm256_fmadd_ps(_mm256_fmadd_ps(f, c0, c1), f, c2); // auto y = (f * c0 + c1) * f + c2; static const auto mask_exp = _mm256_set1_epi32(0x7f800000); @@ -1320,10 +2324,1536 @@ constexpr decltype(load_maskz_fp32_fp16_tr_x8_word<1>)* load_maskz_fp32_fp16_tr_ load_maskz_fp32_fp16_tr_x8_word<1>, load_maskz_fp32_fp16_tr_x8_word<1>, load_maskz_fp32_fp16_tr_x8_word<2>, load_maskz_fp32_fp16_tr_x8_word<3>, load_maskz_fp32_fp16_tr_x8_word<4>, load_maskz_fp32_fp16_tr_x8_word<5>, load_maskz_fp32_fp16_tr_x8_word<6>, load_maskz_fp32_fp16_tr_x8_word<7>, load_maskz_fp32_fp16_tr_x8_word<8>}; + #ifdef __GNUC__ #pragma GCC diagnostic pop #endif +template +static inline void accumulate_fp32_s8_fp32(const float* Aptr, int lda, 
int8_t* Bptr, __m256* vacc, __m256* vsca) { + if constexpr (MTILE == 1) { + for (int ikk = 0; ikk < Unroll; ikk++) { + __m256 va = _mm256_set1_ps(*(Aptr + ikk)); + for (int i = 0; i < NReg; i++) { + auto ftmp = load_s8_fp32(Bptr + i * 8 + ikk * NReg * 8); + ftmp = _mm256_mul_ps(ftmp, vsca[i]); + vacc[i] = _mm256_fmadd_ps(va, ftmp, vacc[i]); + } + } + } else { + for (int ikk = 0; ikk < Unroll; ikk++) { + __m256 va[MTILE]; + for (int i = 0; i < NReg; i++) { + auto ftmp = load_s8_fp32(Bptr + i * 8 + ikk * NReg * 8); + ftmp = _mm256_mul_ps(ftmp, vsca[i]); + for (int im = 0; im < MTILE; im++) { + if (i == 0) { + va[im] = _mm256_set1_ps(*(Aptr + ikk + im * lda)); + } + vacc[im * NReg + i] = _mm256_fmadd_ps(va[im], ftmp, vacc[im * NReg + i]); + } + } + } + } +} + +template +static inline void accumulate_fp32_s8_fp32(const float* Aptr, int lda, int8_t* Bptr, __m256* vacc_loc) { + if constexpr (MTILE == 1) { + for (int ikk = 0; ikk < Unroll; ikk++) { + __m256 va = _mm256_set1_ps(*(Aptr + ikk)); + for (int i = 0; i < NReg; i++) { + auto ftmp = load_s8_fp32(Bptr + i * 8 + ikk * NReg * 8); + vacc_loc[i] = _mm256_fmadd_ps(va, ftmp, vacc_loc[i]); + } + } + } else { + for (int ikk = 0; ikk < Unroll; ikk++) { + __m256 va[MTILE]; + for (int i = 0; i < NReg; i++) { + auto ftmp = load_s8_fp32(Bptr + i * 8 + ikk * NReg * 8); + for (int im = 0; im < MTILE; im++) { + if (i == 0) { + va[im] = _mm256_set1_ps(*(Aptr + ikk + im * lda)); + } + vacc_loc[im * NReg + i] = _mm256_fmadd_ps(va[im], ftmp, vacc_loc[im * NReg + i]); + } + } + } + } +} + +template +static inline BTLA_CODE gemv_4bit_fp32_fp32(const float* A, int lda, const utils::GemvParamB& B, float* C, + int ldc, int k, int blocksize, int8_t* tmp, size_t tmpsize) { + auto& b4ptr = B.b4ptr; + int blks = k / blocksize; + int constexpr NReg = NTILE / 8; + int constexpr MReg = MTILE; + // Initialize accumulator with zeros + __m256 acc[NReg * MReg]; + for (int i = 0; i < NReg * MReg; i++) { + acc[i] = _mm256_setzero_ps(); + } + uint32_t mask = 0x0f0f0f0f; + auto vmask = _mm256_set1_epi32(*reinterpret_cast(&mask)); + auto vbias = _mm256_set1_epi8(8); + for (int ib = 0; ib < blks; ib += 1) { + auto bsptr = B.sptr + ib * B.ldzp; + __m256 v_b_scale[NReg]; + for (int i = 0; i < NReg; i++) { + v_b_scale[i] = load_T_fp32(bsptr + i * 8); + } + + int constexpr Unroll = 4; + assert((blocksize % 4) == 0); + assert(tmpsize >= NTILE * Unroll); + + if (B.zpptr) { + __m256i bzp[NReg]; + auto bzptr = B.zpptr + ib * B.ldzp; + + for (int i = 0; i < Unroll; i++) { + memcpy(tmp + i * NTILE, bzptr, NTILE); + } + for (int i = 0; i < NReg; i++) { + bzp[i] = _mm256_loadu_si256((const __m256i*)(tmp + i * 32)); + bzp[i] = _mm256_add_epi8(bzp[i], vbias); + } + for (int ik = 0; ik < blocksize; ik += Unroll) { + for (int i = 0; i < NReg; i++) { + auto vb = kernel::avx2::unpack_4bits((void*)(b4ptr + i * 16 + (ib * blocksize + ik) * NTILE / 2), vmask); + vb = _mm256_sub_epi8(vb, bzp[i]); + _mm256_storeu_si256((__m256i*)(tmp + 32 * i), vb); + } + accumulate_fp32_s8_fp32(A + ib * blocksize + ik, lda, tmp, acc, v_b_scale); + } + + } else { + for (int ik = 0; ik < blocksize; ik += Unroll) { + for (int i = 0; i < NReg; i++) { + auto vb = kernel::avx2::unpack_4bits((void*)(b4ptr + i * 16 + (ib * blocksize + ik) * NTILE / 2), vmask); + vb = _mm256_sub_epi8(vb, vbias); + _mm256_storeu_si256((__m256i*)(tmp + 32 * i), vb); + } + accumulate_fp32_s8_fp32(A + ib * blocksize + ik, lda, tmp, acc, v_b_scale); + } + } + } + + for (int j = 0; j < MReg; j++) { + for (int i = 0; i < NReg; i++) { + 
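// gemv_4bit_fp32_fp32 computes, per quantization block ib (value-level reference, with
// zp[ib][n] read as 0 when B.zpptr == nullptr):
//
//   // C[m][n] += sum_{k in block ib} A[m][k] * scale[ib][n] * (q4(B[k][n]) - 8 - zp[ib][n])
//
// The inner loop unpacks Unroll = 4 rows of 4-bit codes into tmp as plain int8, and
// accumulate_fp32_s8_fp32 broadcasts A[m][k], converts each int8 column vector to float,
// applies the per-column scale and FMAs into the NReg x MTILE accumulator tile. A
// minimal scalar model of that accumulate step for a single k:
//
//   void accumulate_ref(const float* A, int lda, const int8_t* Bs8, const float* bscale,
//                       float* acc /* MTILE x NTILE, row-major */, int M, int N) {
//     for (int m = 0; m < M; ++m)
//       for (int n = 0; n < N; ++n)
//         acc[m * N + n] += A[m * lda] * static_cast<float>(Bs8[n]) * bscale[n];
//   }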
_mm256_storeu_ps(C + i * 8 + j * ldc, acc[j * NReg + i]); + } + } + return BTLA_CODE::Success; +} + +template +static inline BTLA_CODE gemv_2bit_fp32_fp32(const float* A, int lda, const utils::GemvParamB& B, float* C, + int ldc, int k, int blocksize, int8_t* tmp, size_t tmpsize) { + auto b2ptr = (utils::bit2x4*)B.b2ptr; + + int blks = k / blocksize; + int constexpr NReg = NTILE / 8; + int constexpr MReg = MTILE; + // Initialize accumulator with zeros + __m256 acc[NReg * MReg]; + for (int i = 0; i < NReg * MReg; i++) { + acc[i] = _mm256_setzero_ps(); + } + uint64_t mask0 = 0x0303030303030303; + auto vmask0_y = _mm256_set_epi64x(*(int64_t*)&mask0, *(int64_t*)&mask0, *(int64_t*)&mask0, *(int64_t*)&mask0); + auto vshift_y = _mm256_set_epi32(6, 4, 2, 0, 6, 4, 2, 0); + auto vsfhl_mask_y = _mm256_set_epi8(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2, + 13, 9, 5, 1, 12, 8, 4, 0); + auto vorder_y = _mm256_set_epi32(1, 1, 1, 1, 0, 0, 0, 0); + auto vbias = _mm256_set1_epi8(2); + + int constexpr KTILE = 1; + for (int ib = 0; ib < blks; ib += 1) { + auto bsptr = B.sptr + ib * B.ldzp; + + __m256 acc_loc[NReg * MReg]; + for (int i = 0; i < NReg * MReg; i++) { + acc_loc[i] = _mm256_setzero_ps(); + } + int constexpr Unroll = 4; + assert((blocksize % 4) == 0); + assert(tmpsize >= NTILE * Unroll); + + if (B.zpptr) { + __m256i bzp[NReg]; + auto bzptr = B.zpptr + ib * B.ldzp; + for (int i = 0; i < Unroll; i++) { + memcpy(tmp + i * NTILE, bzptr, NTILE); + } + for (int i = 0; i < NReg; i++) { + bzp[i] = _mm256_loadu_si256((const __m256i*)(tmp + i * 32)); + bzp[i] = _mm256_add_epi8(bzp[i], vbias); + } + for (int ik = 0; ik < blocksize; ik += Unroll) { + for (int i = 0; i < NReg; i++) { + auto vb = unpack_2bits(b2ptr, vshift_y, vmask0_y, vsfhl_mask_y, vorder_y); + vb = _mm256_sub_epi8(vb, bzp[i]); + _mm256_storeu_si256((__m256i*)(tmp + 32 * i), vb); + b2ptr += 8 * Unroll / 4; + } + accumulate_fp32_s8_fp32(A + ib * blocksize + ik, lda, tmp, acc_loc); + } + + } else { + for (int ik = 0; ik < blocksize; ik += Unroll) { + for (int i = 0; i < NReg; i++) { + auto vb = unpack_2bits(b2ptr, vshift_y, vmask0_y, vsfhl_mask_y, vorder_y); + vb = _mm256_sub_epi8(vb, vbias); + _mm256_storeu_si256((__m256i*)(tmp + 32 * i), vb); + b2ptr += 8 * Unroll / 4; + } + accumulate_fp32_s8_fp32(A + ib * blocksize + ik, lda, tmp, acc_loc); + } + } + + __m256 v_b_scale[NReg]; + for (int i = 0; i < NReg; i++) { + v_b_scale[i] = load_T_fp32(bsptr + i * 8); + } + for (int im = 0; im < MTILE; im++) { + for (int in = 0; in < NReg; in++) { + acc[im * NReg + in] = _mm256_fmadd_ps(acc_loc[im * NReg + in], v_b_scale[in], acc[im * NReg + in]); + } + } + } + + for (int j = 0; j < MReg; j++) { + for (int i = 0; i < NReg; i++) { + _mm256_storeu_ps(C + i * 8 + j * ldc, acc[j * NReg + i]); + } + } + return BTLA_CODE::Success; +} + +template +static inline BTLA_CODE gemv_3bit_fp32_fp32(const float* A, int lda, const utils::GemvParamB& B, float* C, + int ldc, int k, int blocksize, int8_t* tmp, size_t tmpsize) { + auto b2ptr = (utils::bit2x4*)B.b2ptr; + auto b1ptr = (utils::bit1x8*)B.b1ptr; + + int blks = k / blocksize; + int constexpr NReg = NTILE / 8; + int constexpr MReg = MTILE; + // Initialize accumulator with zeros + __m256 acc[NReg * MReg]; + for (int i = 0; i < NReg * MReg; i++) { + acc[i] = _mm256_setzero_ps(); + } + + uint64_t mask0 = 0x0303030303030303; + auto vmask0_y = _mm256_set_epi64x(*(int64_t*)&mask0, *(int64_t*)&mask0, *(int64_t*)&mask0, *(int64_t*)&mask0); + auto vshift_y = _mm256_set_epi32(6, 4, 2, 0, 6, 
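// gemv_2bit_fp32_fp32 and gemv_3bit_fp32_fp32 differ from the 4-bit kernel in where the
// scale is applied: they accumulate A[k] * s8(B) into a per-block acc_loc first and fold
// the per-column scale in once per block,
//
//   //   acc[m][n] += acc_loc[m][n] * bscale[ib][n]
//
// which is valid because the scale is constant within a block and saves one multiply in
// the innermost loop.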
4, 2, 0); + auto vsfhl_mask_y = _mm256_set_epi8(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2, + 13, 9, 5, 1, 12, 8, 4, 0); + auto vorder_y = _mm256_set_epi32(1, 1, 1, 1, 0, 0, 0, 0); + auto vbias = _mm256_set1_epi8(4); + + const __m256i highMask = _mm256_set1_epi8(0x04); + const __m256i bit1Mask = _mm256_set1_epi32(0x0F); + const __m256i bit1Shift_1 = _mm256_set_epi32(28, 24, 20, 16, 12, 8, 4, 0); + const __m256i bit1Shift_2 = _mm256_set1_epi32((1 << 23) + (1 << 16) + (1 << 9) + (1 << 2)); + int constexpr KTILE = 1; + for (int ib = 0; ib < blks; ib += 1) { + auto bsptr = B.sptr + ib * B.ldzp; + + __m256 acc_loc[NReg * MReg]; + for (int i = 0; i < NReg * MReg; i++) { + acc_loc[i] = _mm256_setzero_ps(); + } + int constexpr Unroll = 4; + assert((blocksize % 4) == 0); + assert(tmpsize >= NTILE * Unroll); + + if (B.zpptr) { + __m256i bzp[NReg]; + auto bzptr = B.zpptr + ib * B.ldzp; + for (int i = 0; i < Unroll; i++) { + memcpy(tmp + i * NTILE, bzptr, NTILE); + } + for (int i = 0; i < NReg; i++) { + bzp[i] = _mm256_loadu_si256((const __m256i*)(tmp + i * 32)); + bzp[i] = _mm256_add_epi8(bzp[i], vbias); + } + for (int ik = 0; ik < blocksize; ik += Unroll) { + for (int i = 0; i < NReg; i++) { + auto vb = unpack_2bits(b2ptr, vshift_y, vmask0_y, vsfhl_mask_y, vorder_y); + auto vb1 = unpack_1bits(b1ptr, bit1Shift_1, bit1Mask, bit1Shift_2, highMask); + vb = _mm256_or_si256(vb, vb1); + vb = _mm256_sub_epi8(vb, bzp[i]); + _mm256_storeu_si256((__m256i*)(tmp + 32 * i), vb); + b2ptr += 8 * Unroll / 4; + b1ptr += 8 * Unroll / 8; + } + accumulate_fp32_s8_fp32(A + ib * blocksize + ik, lda, tmp, acc_loc); + } + + } else { + for (int ik = 0; ik < blocksize; ik += Unroll) { + for (int i = 0; i < NReg; i++) { + auto vb = unpack_2bits(b2ptr, vshift_y, vmask0_y, vsfhl_mask_y, vorder_y); + auto vb1 = unpack_1bits(b1ptr, bit1Shift_1, bit1Mask, bit1Shift_2, highMask); + vb = _mm256_or_si256(vb, vb1); + vb = _mm256_sub_epi8(vb, vbias); + _mm256_storeu_si256((__m256i*)(tmp + 32 * i), vb); + b2ptr += 8 * Unroll / 4; + b1ptr += 8 * Unroll / 8; + } + accumulate_fp32_s8_fp32(A + ib * blocksize + ik, lda, tmp, acc_loc); + } + } + + __m256 v_b_scale[NReg]; + for (int i = 0; i < NReg; i++) { + v_b_scale[i] = load_T_fp32(bsptr + i * 8); + } + for (int im = 0; im < MTILE; im++) { + for (int in = 0; in < NReg; in++) { + acc[im * NReg + in] = _mm256_fmadd_ps(acc_loc[im * NReg + in], v_b_scale[in], acc[im * NReg + in]); + } + } + } + + for (int j = 0; j < MReg; j++) { + for (int i = 0; i < NReg; i++) { + _mm256_storeu_ps(C + i * 8 + j * ldc, acc[j * NReg + i]); + } + } + return BTLA_CODE::Success; +} + +static inline __m256i _mm256_dpbusd_avx2_epi32(__m256i& c, const __m256i& a, const __m256i& b) { + const __m256i dot2 = _mm256_maddubs_epi16(a, b); + const __m256i ones = _mm256_set1_epi16(1); + const __m256i sum4 = _mm256_madd_epi16(ones, dot2); + return _mm256_add_epi32(c, sum4); +} + +template +static inline void gemv_dequant_s32fp32(const float* asptr, int ldzp, const ScaleT* bsptr, __m256i* iacc, + __m256* facc) { + __m256 v_a_scale[MTILE]; + for (int im = 0; im < MTILE; im++) { + v_a_scale[im] = _mm256_set1_ps(*(asptr + im * ldzp)); + } + + for (int i = 0; i < NReg; i++) { + __m256 v_b_scale = load_T_fp32(bsptr + i * 8); + for (int im = 0; im < MTILE; im++) { + auto vtmp = _mm256_mul_ps(v_a_scale[im], v_b_scale); + auto tmp = _mm256_cvtepi32_ps(iacc[im * NReg + i]); + facc[im * NReg + i] = _mm256_fmadd_ps(tmp, vtmp, facc[im * NReg + i]); + } + } +} + +template +static inline void 
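// _mm256_dpbusd_avx2_epi32 emulates the AVX-VNNI u8 x s8 dot product on plain AVX2:
//
//   __m256i dot2 = _mm256_maddubs_epi16(a, b);    // u8*s8 pair sums, 16-bit, saturating
//   __m256i sum4 = _mm256_madd_epi16(ones, dot2); // widen and add pairs -> 32-bit
//   return _mm256_add_epi32(c, sum4);             // accumulate into c
//
// Caveat: _mm256_maddubs_epi16 saturates at int16, so two large u8*s8 products in the
// same pair can clip, whereas the real vpdpbusd used in the vnni:: namespace below
// accumulates exactly in 32 bits. For the value ranges these quantized kernels produce
// this is normally benign, but it is an approximation rather than an identity.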
gemv_remove_zp(const uint8_t* azptr, int ldzp, __m256i* iacc, __m256i* bacc) { + if constexpr (MReg == 1) { + auto zp = int(azptr[0]); + __m256i v_a_zp = _mm256_set1_epi32(zp); + for (int in = 0; in < NReg; in++) { + auto vtmp = _mm256_mullo_epi32(v_a_zp, bacc[in]); + iacc[in] = _mm256_sub_epi32(iacc[in], vtmp); + } + } else { + __m256i v_a_zp[MReg]; + for (int im = 0; im < MReg; im++) { + auto zp = int(azptr[im * ldzp]); + v_a_zp[im] = _mm256_set1_epi32(zp); + for (int in = 0; in < NReg; in++) { + auto vtmp = _mm256_mullo_epi32(v_a_zp[im], bacc[in]); + iacc[im * NReg + in] = _mm256_sub_epi32(iacc[im * NReg + in], vtmp); + } + } + } +} + +template +static inline BTLA_CODE gemv_4bit_u8s8_fp32(const utils::GemvParamA& A, const utils::GemvParamB& B, float* C, + int ldc, int k, int blocksize, int8_t* tmp, size_t tmpsize) { + auto& a8ptr = A.aptr; + auto& b4ptr = B.b4ptr; + auto& asptr = A.sptr; + auto& azptr = A.zpptr; + + int blks = k / blocksize; + int constexpr NReg = NTILE / 8; + int constexpr MReg = MTILE; + // Initialize accumulator with zeros + __m256 acc[NReg * MReg]; + for (int i = 0; i < NReg * MReg; i++) { + acc[i] = _mm256_setzero_ps(); + } + uint32_t mask = 0x0f0f0f0f; + auto vmask = _mm256_set1_epi32(*reinterpret_cast(&mask)); + const __m256i onesu8 = _mm256_set1_epi8(1); + const __m256i vbias = _mm256_set1_epi8(8); + const auto vindex = _mm256_set_epi8(12, 12, 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0, 12, 12, 12, 12, 8, 8, 8, 8, 4, + 4, 4, 4, 0, 0, 0, 0); + + for (int ib = 0; ib < blks; ib += 1) { + __m256i iacc[NReg * MReg]; + __m256i bacc[NReg]; + for (int i = 0; i < NReg * MReg; i++) { + iacc[i] = _mm256_setzero_si256(); + } + for (int i = 0; i < NReg; i++) { + bacc[i] = _mm256_setzero_si256(); + } + if (B.zpptr) { + __m256i bzp[NReg]; + auto bzptr = B.zpptr + ib * B.ldzp; + for (int i = 0; i < NReg; i++) { + bzp[i] = load_zp_epi8_broadcast_epi32(bzptr + i * 8, vindex); + bzp[i] = _mm256_add_epi8(bzp[i], vbias); + } + for (int ik = 0; ik < blocksize; ik += 4) { + if constexpr (MTILE == 1) { + __m256i va = _mm256_set1_epi32(*(int*)(a8ptr + ib * blocksize + ik)); + + for (int i = 0; i < NReg; i++) { + auto vb = kernel::avx2::unpack_4bits((void*)(b4ptr + i * 16 + (ib * blocksize + ik) * NTILE / 2), vmask); + vb = _mm256_sub_epi8(vb, bzp[i]); + bacc[i] = _mm256_dpbusd_avx2_epi32(bacc[i], onesu8, vb); + iacc[i] = _mm256_dpbusd_avx2_epi32(iacc[i], va, vb); + } + } else { + __m256i va[MReg]; + for (int i = 0; i < MReg; i++) { + va[i] = _mm256_set1_epi32(*(int*)(a8ptr + ib * blocksize + ik + i * A.lda)); + } + for (int i = 0; i < NReg; i++) { + auto vb = kernel::avx2::unpack_4bits((void*)(b4ptr + i * 16 + (ib * blocksize + ik) * NTILE / 2), vmask); + vb = _mm256_sub_epi8(vb, bzp[i]); + bacc[i] = _mm256_dpbusd_avx2_epi32(bacc[i], onesu8, vb); + for (int j = 0; j < MReg; j++) { + iacc[j * NReg + i] = _mm256_dpbusd_avx2_epi32(iacc[j * NReg + i], va[j], vb); + } + } + } + } + } else { + for (int ik = 0; ik < blocksize; ik += 4) { + if constexpr (MTILE == 1) { + __m256i va = _mm256_set1_epi32(*(int*)(a8ptr + ib * blocksize + ik)); + for (int i = 0; i < NReg; i++) { + auto vb = kernel::avx2::unpack_4bits((void*)(b4ptr + i * 16 + (ib * blocksize + ik) * NTILE / 2), vmask); + vb = _mm256_sub_epi8(vb, vbias); + bacc[i] = _mm256_dpbusd_avx2_epi32(bacc[i], onesu8, vb); + iacc[i] = _mm256_dpbusd_avx2_epi32(iacc[i], va, vb); + } + } else { + __m256i va[MReg]; + for (int i = 0; i < MReg; i++) { + va[i] = _mm256_set1_epi32(*(int*)(a8ptr + ib * blocksize + ik + i * A.lda)); + } + for (int i = 0; i 
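// gemv_remove_zp corrects for the unsigned zero point of activation A. The integer
// accumulators hold sums over raw u8 codes, so per row m and column n:
//
//   sum_k (a_u8[m][k] - azp[m]) * b[k][n]
//     = sum_k a_u8[m][k] * b[k][n]  -  azp[m] * sum_k b[k][n]
//     =            iacc[m][n]       -  azp[m] *    bacc[n]
//
// which is why every inner loop also feeds dpbusd(ones, vb) into bacc: it collects the
// per-column sum of B needed for this correction at the cost of one extra dot product.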
< NReg; i++) { + auto vb = kernel::avx2::unpack_4bits((void*)(b4ptr + i * 16 + (ib * blocksize + ik) * NTILE / 2), vmask); + vb = _mm256_sub_epi8(vb, vbias); + bacc[i] = _mm256_dpbusd_avx2_epi32(bacc[i], onesu8, vb); + for (int j = 0; j < MReg; j++) { + iacc[j * NReg + i] = _mm256_dpbusd_avx2_epi32(iacc[j * NReg + i], va[j], vb); + } + } + } + } + } + + gemv_remove_zp(A.zpptr + ib, A.ldzp, iacc, bacc); + gemv_dequant_s32fp32(A.sptr + ib, A.ldzp, B.sptr + ib * B.ldzp, iacc, acc); + } + + for (int j = 0; j < MReg; j++) { + for (int i = 0; i < NReg; i++) { + _mm256_storeu_ps(C + i * 8 + j * ldc, acc[j * NReg + i]); + } + } + return BTLA_CODE::Success; +} + +template +static inline BTLA_CODE gemv_3bit_u8s8_fp32(const utils::GemvParamA& A, const utils::GemvParamB& B, float* C, + int ldc, int k, int blocksize, int8_t* tmp, size_t tmpsize) { + auto b2ptr = reinterpret_cast(B.b2ptr); + auto b1ptr = reinterpret_cast(B.b1ptr); + + int blks = k / blocksize; + int constexpr NReg = NTILE / 8; + int constexpr MReg = MTILE; + __m256 acc[NReg * MReg]; + for (int i = 0; i < NReg * MReg; i++) { + acc[i] = _mm256_setzero_ps(); + } + uint64_t mask0 = 0x0303030303030303; + auto vmask0_y = _mm256_set_epi64x(*(int64_t*)&mask0, *(int64_t*)&mask0, *(int64_t*)&mask0, *(int64_t*)&mask0); + auto vshift_y = _mm256_set_epi32(6, 4, 2, 0, 6, 4, 2, 0); + auto vsfhl_mask_y = _mm256_set_epi8(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2, + 13, 9, 5, 1, 12, 8, 4, 0); + auto vorder_y = _mm256_set_epi32(1, 1, 1, 1, 0, 0, 0, 0); + const __m256i onesu8 = _mm256_set1_epi8(1); + const __m256i vbias = _mm256_set1_epi8(4); + const auto vindex = _mm256_set_epi8(12, 12, 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0, 12, 12, 12, 12, 8, 8, 8, 8, 4, + 4, 4, 4, 0, 0, 0, 0); + + const __m256i highMask = _mm256_set1_epi8(0x04); + const __m256i bit1Mask = _mm256_set1_epi32(0x0F); + const __m256i bit1Shift_1 = _mm256_set_epi32(28, 24, 20, 16, 12, 8, 4, 0); + const __m256i bit1Shift_2 = _mm256_set1_epi32((1 << 23) + (1 << 16) + (1 << 9) + (1 << 2)); + int constexpr KTILE = 4; + for (int ib = 0; ib < blks; ib += 1) { + __m256i iacc[NReg * MReg]; + __m256i bacc[NReg]; + for (int i = 0; i < NReg * MReg; i++) { + iacc[i] = _mm256_setzero_si256(); + } + for (int i = 0; i < NReg; i++) { + bacc[i] = _mm256_setzero_si256(); + } + if (B.zpptr) { + __m256i bzp[NReg]; + auto bzptr = B.zpptr + ib * B.ldzp; + for (int i = 0; i < NReg; i++) { + bzp[i] = load_zp_epi8_broadcast_epi32(bzptr + i * 8, vindex); + bzp[i] = _mm256_add_epi8(bzp[i], vbias); + } + for (int ik = 0; ik < blocksize; ik += KTILE) { + if constexpr (MTILE == 1) { + __m256i va = _mm256_set1_epi32(*(int*)(A.aptr + ib * blocksize + ik)); + for (int i = 0; i < NReg; i++) { + auto vb = unpack_2bits(b2ptr, vshift_y, vmask0_y, vsfhl_mask_y, vorder_y); + auto vb1 = unpack_1bits(b1ptr, bit1Shift_1, bit1Mask, bit1Shift_2, highMask); + vb = _mm256_or_si256(vb, vb1); + vb = _mm256_sub_epi8(vb, bzp[i]); + bacc[i] = _mm256_dpbusd_avx2_epi32(bacc[i], onesu8, vb); + iacc[i] = _mm256_dpbusd_avx2_epi32(iacc[i], va, vb); + b2ptr += 8 * KTILE / 4; + b1ptr += 8 * KTILE / 8; + } + } else { + __m256i va[MReg]; + for (int i = 0; i < MReg; i++) { + va[i] = _mm256_set1_epi32(*(int*)(A.aptr + ib * blocksize + ik + i * A.lda)); + } + for (int i = 0; i < NReg; i++) { + auto vb = unpack_2bits(b2ptr, vshift_y, vmask0_y, vsfhl_mask_y, vorder_y); + auto vb1 = unpack_1bits(b1ptr, bit1Shift_1, bit1Mask, bit1Shift_2, highMask); + vb = _mm256_or_si256(vb, vb1); + vb = _mm256_sub_epi8(vb, 
bzp[i]); + bacc[i] = _mm256_dpbusd_avx2_epi32(bacc[i], onesu8, vb); + for (int j = 0; j < MReg; j++) { + iacc[j * NReg + i] = _mm256_dpbusd_avx2_epi32(iacc[j * NReg + i], va[j], vb); + } + b2ptr += 8 * KTILE / 4; + b1ptr += 8 * KTILE / 8; + } + } + } + } else { + for (int ik = 0; ik < blocksize; ik += KTILE) { + if constexpr (MTILE == 1) { + __m256i va = _mm256_set1_epi32(*(int*)(A.aptr + ib * blocksize + ik)); + for (int i = 0; i < NReg; i++) { + auto vb = unpack_2bits(b2ptr, vshift_y, vmask0_y, vsfhl_mask_y, vorder_y); + auto vb1 = unpack_1bits(b1ptr, bit1Shift_1, bit1Mask, bit1Shift_2, highMask); + vb = _mm256_or_si256(vb, vb1); + vb = _mm256_sub_epi8(vb, vbias); + bacc[i] = _mm256_dpbusd_avx2_epi32(bacc[i], onesu8, vb); + iacc[i] = _mm256_dpbusd_avx2_epi32(iacc[i], va, vb); + + b2ptr += 8 * KTILE / 4; + b1ptr += 8 * KTILE / 8; + } + } else { + __m256i va[MReg]; + for (int i = 0; i < MReg; i++) { + va[i] = _mm256_set1_epi32(*(int*)(A.aptr + ib * blocksize + ik + i * A.lda)); + } + for (int i = 0; i < NReg; i++) { + auto vb = unpack_2bits(b2ptr, vshift_y, vmask0_y, vsfhl_mask_y, vorder_y); + auto vb1 = unpack_1bits(b1ptr, bit1Shift_1, bit1Mask, bit1Shift_2, highMask); + vb = _mm256_or_si256(vb, vb1); + vb = _mm256_sub_epi8(vb, vbias); + bacc[i] = _mm256_dpbusd_avx2_epi32(bacc[i], onesu8, vb); + for (int j = 0; j < MReg; j++) { + iacc[j * NReg + i] = _mm256_dpbusd_avx2_epi32(iacc[j * NReg + i], va[j], vb); + } + b2ptr += 8 * KTILE / 4; + b1ptr += 8 * KTILE / 8; + } + } + } + } + + gemv_remove_zp(A.zpptr + ib, A.ldzp, iacc, bacc); + gemv_dequant_s32fp32(A.sptr + ib, A.ldzp, B.sptr + ib * B.ldzp, iacc, acc); + } + + for (int j = 0; j < MReg; j++) { + for (int i = 0; i < NReg; i++) { + _mm256_storeu_ps(C + i * 8 + j * ldc, acc[j * NReg + i]); + } + } + return BTLA_CODE::Success; +} + +template +static inline BTLA_CODE gemv_2bit_u8s8_fp32(const utils::GemvParamA& A, const utils::GemvParamB& B, float* C, + int ldc, int k, int blocksize, int8_t* tmp, size_t tmpsize) { + auto b2ptr = reinterpret_cast(B.b2ptr); + + int blks = k / blocksize; + int constexpr NReg = NTILE / 8; + int constexpr MReg = MTILE; + __m256 acc[NReg * MReg]; + for (int i = 0; i < NReg * MReg; i++) { + acc[i] = _mm256_setzero_ps(); + } + uint64_t mask0 = 0x0303030303030303; + auto vmask0_y = _mm256_set_epi64x(*(int64_t*)&mask0, *(int64_t*)&mask0, *(int64_t*)&mask0, *(int64_t*)&mask0); + auto vshift_y = _mm256_set_epi32(6, 4, 2, 0, 6, 4, 2, 0); + auto vsfhl_mask_y = _mm256_set_epi8(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2, + 13, 9, 5, 1, 12, 8, 4, 0); + auto vorder_y = _mm256_set_epi32(1, 1, 1, 1, 0, 0, 0, 0); + const __m256i onesu8 = _mm256_set1_epi8(1); + const __m256i vbias = _mm256_set1_epi8(2); + const auto vindex = _mm256_set_epi8(12, 12, 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0, 12, 12, 12, 12, 8, 8, 8, 8, 4, + 4, 4, 4, 0, 0, 0, 0); + int constexpr KTILE = 4; + for (int ib = 0; ib < blks; ib += 1) { + __m256i iacc[NReg * MReg]; + __m256i bacc[NReg]; + for (int i = 0; i < NReg * MReg; i++) { + iacc[i] = _mm256_setzero_si256(); + } + for (int i = 0; i < NReg; i++) { + bacc[i] = _mm256_setzero_si256(); + } + if (B.zpptr) { + __m256i bzp[NReg]; + auto bzptr = B.zpptr + ib * B.ldzp; + for (int i = 0; i < NReg; i++) { + bzp[i] = load_zp_epi8_broadcast_epi32(bzptr + i * 8, vindex); + bzp[i] = _mm256_add_epi8(bzp[i], vbias); + } + for (int ik = 0; ik < blocksize; ik += KTILE) { + if constexpr (MTILE == 1) { + __m256i va = _mm256_set1_epi32(*(int*)(A.aptr + ib * blocksize + ik)); + for 
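// Pointer strides in the 2-/3-bit u8s8 kernels: each inner step consumes KTILE (= 4)
// rows of K for one group of 8 columns, i.e. 32 quantized values per NReg iteration, so
//
//   b2ptr += 8 * KTILE / 4;   // 32 values * 2 bits = 8 bytes of bit2x4
//   b1ptr += 8 * KTILE / 8;   // 32 values * 1 bit  = 4 bytes of bit1x8 (3-bit only)
//
// keeps the packed planes in lock-step with the 32-byte registers being unpacked.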
(int i = 0; i < NReg; i++) { + auto vb = unpack_2bits(b2ptr, vshift_y, vmask0_y, vsfhl_mask_y, vorder_y); + vb = _mm256_sub_epi8(vb, bzp[i]); + bacc[i] = _mm256_dpbusd_avx2_epi32(bacc[i], onesu8, vb); + iacc[i] = _mm256_dpbusd_avx2_epi32(iacc[i], va, vb); + b2ptr += 8 * KTILE / 4; + } + } else { + __m256i va[MReg]; + for (int i = 0; i < MReg; i++) { + va[i] = _mm256_set1_epi32(*(int*)(A.aptr + ib * blocksize + ik + i * A.lda)); + } + for (int i = 0; i < NReg; i++) { + auto vb = unpack_2bits(b2ptr, vshift_y, vmask0_y, vsfhl_mask_y, vorder_y); + vb = _mm256_sub_epi8(vb, bzp[i]); + bacc[i] = _mm256_dpbusd_avx2_epi32(bacc[i], onesu8, vb); + for (int j = 0; j < MReg; j++) { + iacc[j * NReg + i] = _mm256_dpbusd_avx2_epi32(iacc[j * NReg + i], va[j], vb); + } + b2ptr += 8 * KTILE / 4; + } + } + } + } else { + for (int ik = 0; ik < blocksize; ik += KTILE) { + if constexpr (MTILE == 1) { + __m256i va = _mm256_set1_epi32(*(int*)(A.aptr + ib * blocksize + ik)); + for (int i = 0; i < NReg; i++) { + auto vb = unpack_2bits(b2ptr, vshift_y, vmask0_y, vsfhl_mask_y, vorder_y); + vb = _mm256_sub_epi8(vb, vbias); + bacc[i] = _mm256_dpbusd_avx2_epi32(bacc[i], onesu8, vb); + iacc[i] = _mm256_dpbusd_avx2_epi32(iacc[i], va, vb); + b2ptr += 8 * KTILE / 4; + } + } else { + __m256i va[MReg]; + for (int i = 0; i < MReg; i++) { + va[i] = _mm256_set1_epi32(*(int*)(A.aptr + ib * blocksize + ik + i * A.lda)); + } + for (int i = 0; i < NReg; i++) { + auto vb = unpack_2bits(b2ptr, vshift_y, vmask0_y, vsfhl_mask_y, vorder_y); + vb = _mm256_sub_epi8(vb, vbias); + bacc[i] = _mm256_dpbusd_avx2_epi32(bacc[i], onesu8, vb); + for (int j = 0; j < MReg; j++) { + iacc[j * NReg + i] = _mm256_dpbusd_avx2_epi32(iacc[j * NReg + i], va[j], vb); + } + b2ptr += 8 * KTILE / 4; + } + } + } + } + + gemv_remove_zp(A.zpptr + ib, A.ldzp, iacc, bacc); + gemv_dequant_s32fp32(A.sptr + ib, A.ldzp, B.sptr + ib * B.ldzp, iacc, acc); + } + + for (int j = 0; j < MReg; j++) { + for (int i = 0; i < NReg; i++) { + _mm256_storeu_ps(C + i * 8 + j * ldc, acc[j * NReg + i]); + } + } + return BTLA_CODE::Success; +} + +namespace vnni { + +#if CompileAVXVNNI() +#ifdef __GNUC__ +#pragma GCC push_options +#pragma GCC target("avxvnni") +#endif + +template +static inline BTLA_CODE gemv_4bit_u8s8_fp32(const utils::GemvParamA& A, const utils::GemvParamB& B, float* C, + int ldc, int k, int blocksize, int8_t* tmp, size_t tmpsize) { + auto& a8ptr = A.aptr; + auto& b4ptr = B.b4ptr; + auto& asptr = A.sptr; + auto& azptr = A.zpptr; + + int blks = k / blocksize; + int constexpr NReg = NTILE / 8; + int constexpr MReg = MTILE; + // Initialize accumulator with zeros + __m256 acc[NReg * MReg]; + for (int i = 0; i < NReg * MReg; i++) { + acc[i] = _mm256_setzero_ps(); + } + uint32_t mask = 0x0f0f0f0f; + auto vmask = _mm256_set1_epi32(*reinterpret_cast(&mask)); + const __m256i onesu8 = _mm256_set1_epi8(1); + const __m256i vbias = _mm256_set1_epi8(8); + const auto vindex = _mm256_set_epi8(12, 12, 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0, 12, 12, 12, 12, 8, 8, 8, 8, 4, + 4, 4, 4, 0, 0, 0, 0); + + for (int ib = 0; ib < blks; ib += 1) { + __m256i iacc[NReg * MReg]; + __m256i bacc[NReg]; + for (int i = 0; i < NReg * MReg; i++) { + iacc[i] = _mm256_setzero_si256(); + } + for (int i = 0; i < NReg; i++) { + bacc[i] = _mm256_setzero_si256(); + } + if (B.zpptr) { + __m256i bzp[NReg]; + auto bzptr = B.zpptr + ib * B.ldzp; + for (int i = 0; i < NReg; i++) { + bzp[i] = load_zp_epi8_broadcast_epi32(bzptr + i * 8, vindex); + bzp[i] = _mm256_add_epi8(bzp[i], vbias); + } + for (int ik = 0; ik < 
blocksize; ik += 4) { + if constexpr (MTILE == 1) { + __m256i va = _mm256_set1_epi32(*(int*)(a8ptr + ib * blocksize + ik)); + for (int i = 0; i < NReg; i++) { + auto vb = kernel::avx2::unpack_4bits((void*)(b4ptr + i * 16 + (ib * blocksize + ik) * NTILE / 2), vmask); + vb = _mm256_sub_epi8(vb, bzp[i]); + bacc[i] = _mm256_dpbusd_avx_epi32(bacc[i], onesu8, vb); + iacc[i] = _mm256_dpbusd_avx_epi32(iacc[i], va, vb); + } + } else { + __m256i va[MReg]; + for (int i = 0; i < MReg; i++) { + va[i] = _mm256_set1_epi32(*(int*)(a8ptr + ib * blocksize + ik + i * A.lda)); + } + for (int i = 0; i < NReg; i++) { + auto vb = kernel::avx2::unpack_4bits((void*)(b4ptr + i * 16 + (ib * blocksize + ik) * NTILE / 2), vmask); + vb = _mm256_sub_epi8(vb, bzp[i]); + bacc[i] = _mm256_dpbusd_avx_epi32(bacc[i], onesu8, vb); + for (int j = 0; j < MReg; j++) { + iacc[j * NReg + i] = _mm256_dpbusd_avx_epi32(iacc[j * NReg + i], va[j], vb); + } + } + } + } + } else { + for (int ik = 0; ik < blocksize; ik += 4) { + if constexpr (MTILE == 1) { + __m256i va = _mm256_set1_epi32(*(int*)(a8ptr + ib * blocksize + ik)); + for (int i = 0; i < NReg; i++) { + auto vb = kernel::avx2::unpack_4bits((void*)(b4ptr + i * 16 + (ib * blocksize + ik) * NTILE / 2), vmask); + vb = _mm256_sub_epi8(vb, vbias); + bacc[i] = _mm256_dpbusd_avx_epi32(bacc[i], onesu8, vb); + iacc[i] = _mm256_dpbusd_avx_epi32(iacc[i], va, vb); + } + } else { + __m256i va[MReg]; + for (int i = 0; i < MReg; i++) { + va[i] = _mm256_set1_epi32(*(int*)(a8ptr + ib * blocksize + ik + i * A.lda)); + } + for (int i = 0; i < NReg; i++) { + auto vb = kernel::avx2::unpack_4bits((void*)(b4ptr + i * 16 + (ib * blocksize + ik) * NTILE / 2), vmask); + vb = _mm256_sub_epi8(vb, vbias); + bacc[i] = _mm256_dpbusd_avx_epi32(bacc[i], onesu8, vb); + for (int j = 0; j < MReg; j++) { + iacc[j * NReg + i] = _mm256_dpbusd_avx_epi32(iacc[j * NReg + i], va[j], vb); + } + } + } + } + } + + gemv_remove_zp(A.zpptr + ib, A.ldzp, iacc, bacc); + gemv_dequant_s32fp32(A.sptr + ib, A.ldzp, B.sptr + ib * B.ldzp, iacc, acc); + } + + for (int j = 0; j < MReg; j++) { + for (int i = 0; i < NReg; i++) { + _mm256_storeu_ps(C + i * 8 + j * ldc, acc[j * NReg + i]); + } + } + return BTLA_CODE::Success; +} + +template +static inline BTLA_CODE gemv_4bit_s8s8_fp32(const utils::GemvParamA& A, const utils::GemvParamB& B, float* C, + int ldc, int k, int blocksize, int8_t* tmp, size_t tmpsize) { + auto& a8ptr = A.aptr; + auto& b4ptr = B.b4ptr; + auto& asptr = A.sptr; + + int blks = k / blocksize; + int constexpr NReg = NTILE / 8; + int constexpr MReg = MTILE; + __m256 acc[NReg * MReg]; + for (int i = 0; i < NReg * MReg; i++) { + acc[i] = _mm256_setzero_ps(); + } + const __m256i vbias = _mm256_set1_epi8(8); + uint32_t mask = 0x0f0f0f0f; + auto vmask = _mm256_set1_epi32(*reinterpret_cast(&mask)); + const auto vindex = _mm256_set_epi8(12, 12, 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0, 12, 12, 12, 12, 8, 8, 8, 8, 4, + 4, 4, 4, 0, 0, 0, 0); + for (int ib = 0; ib < blks; ib += 1) { + __m256i iacc[NReg * MReg]; + for (int i = 0; i < NReg * MReg; i++) { + iacc[i] = _mm256_setzero_si256(); + } + if (B.zpptr) { + __m256i bzp[NReg]; + auto bzptr = B.zpptr + ib * B.ldzp; + for (int i = 0; i < NReg; i++) { + bzp[i] = load_zp_epi8_broadcast_epi32(bzptr + i * 8, vindex); + bzp[i] = _mm256_add_epi8(bzp[i], vbias); + } + for (int ik = 0; ik < blocksize; ik += 4) { + __m256i va[MReg]; + for (int i = 0; i < MReg; i++) { + va[i] = _mm256_set1_epi32(*(int*)(a8ptr + ib * blocksize + ik + i * A.lda)); + } + for (int i = 0; i < NReg; i++) { + auto 
vb = kernel::avx2::unpack_4bits((void*)(b4ptr + i * 16 + (ib * blocksize + ik) * NTILE / 2), vmask); + vb = _mm256_sub_epi8(vb, bzp[i]); + for (int j = 0; j < MReg; j++) { + auto vsb = _mm256_sign_epi8(vb, va[j]); + auto vabsa = _mm256_sign_epi8(va[j], va[j]); + iacc[j * NReg + i] = _mm256_dpbusd_avx_epi32(iacc[j * NReg + i], vabsa, vsb); + } + } + } + } else { + for (int ik = 0; ik < blocksize; ik += 4) { + __m256i va[MReg]; + for (int i = 0; i < MReg; i++) { + va[i] = _mm256_set1_epi32(*(int*)(a8ptr + ib * blocksize + ik + i * A.lda)); + } + for (int i = 0; i < NReg; i++) { + auto vb = kernel::avx2::unpack_4bits((void*)(b4ptr + i * 16 + (ib * blocksize + ik) * NTILE / 2), vmask); + vb = _mm256_sub_epi8(vb, vbias); + for (int j = 0; j < MReg; j++) { + auto vsb = _mm256_sign_epi8(vb, va[j]); + auto vabsa = _mm256_sign_epi8(va[j], va[j]); + iacc[j * NReg + i] = _mm256_dpbusd_avx_epi32(iacc[j * NReg + i], vabsa, vsb); + } + } + } + } + + gemv_dequant_s32fp32(A.sptr + ib, A.ldzp, B.sptr + ib * B.ldzp, iacc, acc); + } + + for (int j = 0; j < MReg; j++) { + for (int i = 0; i < NReg; i++) { + _mm256_storeu_ps(C + i * 8 + j * ldc, acc[j * NReg + i]); + } + } + return BTLA_CODE::Success; +} + +template +static inline BTLA_CODE gemv_3bit_u8s8_fp32_align128(const utils::GemvParamA& A, const utils::GemvParamB& B, + float* C, int k, int ld_scaleb, int blocksize, int8_t* tmp, + size_t tmpsize) { + auto a8ptr = A.aptr; + auto b2ptr = reinterpret_cast(B.b2ptr); + auto b1ptr = reinterpret_cast(B.b1ptr); + auto asptr = A.sptr; + auto azptr = A.zpptr; + + int blks = k / blocksize; + int constexpr NReg = NTILE / 8; + // Initialize accumulator with zeros + __m256 acc[NReg]; + int constexpr EltPadding = 128; + static_assert(NTILE % 8 == 0); + int constexpr KTILE = 4; + int constexpr UnpackElt = EltPadding / 8 / KTILE; + int constexpr TotalElt = UnpackElt * NTILE * KTILE; + int constexpr Loop128 = TotalElt / 128; + int8_t UnpackBuf[TotalElt]; + for (int i = 0; i < NReg; i++) { + acc[i] = _mm256_setzero_ps(); + } + + uint32_t mask = 0x0f0f0f0f; + auto vmask = _mm256_set1_epi32(*reinterpret_cast(&mask)); + const __m256i onesu8 = _mm256_set1_epi8(1); + const __m256i lowMask = _mm256_set1_epi8(0x03); + const __m256i highMask = _mm256_set1_epi8(0x04); + const __m256i bit1Mask = _mm256_set1_epi32(0x0F); + const __m256i bit1Shift_1 = _mm256_set_epi32(28, 24, 20, 16, 12, 8, 4, 0); + const __m256i bit1Shift_2 = _mm256_set1_epi32((1 << 23) + (1 << 16) + (1 << 9) + (1 << 2)); + const auto vindex = _mm256_set_epi8(12, 12, 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0, 12, 12, 12, 12, 8, 8, 8, 8, 4, + 4, 4, 4, 0, 0, 0, 0); + auto bit3_interleave_decompress_pack128 = [&](utils::bit2x4* src1, utils::bit1x8* src2, int8_t* dst) { + __m256i bit2_data = _mm256_loadu_si256((const __m256i*)src1); + int32_t* bit1_ptr = reinterpret_cast(src2); + for (int i = 0; i < 4; i++) { + auto bit1x32 = _mm256_set1_epi32(bit1_ptr[i]); + bit1x32 = _mm256_srlv_epi32(bit1x32, bit1Shift_1); + bit1x32 = _mm256_and_si256(bit1x32, bit1Mask); + bit1x32 = _mm256_mullo_epi32(bit1x32, bit1Shift_2); + bit1x32 = _mm256_and_si256(highMask, bit1x32); + + auto bit2x32 = _mm256_and_si256(lowMask, _mm256_srli_epi16(bit2_data, 2 * i)); + auto res = _mm256_add_epi8(bit1x32, bit2x32); + res = _mm256_sub_epi8(res, highMask); + _mm256_storeu_si256((__m256i*)(dst + 32 * i), res); + } + }; + assert(azptr); + for (int ib = 0; ib < blks; ib += 1) { + __m256i iacc[NReg]; + __m256i bacc[NReg]; + for (int i = 0; i < NReg; i++) { + iacc[i] = _mm256_setzero_si256(); + bacc[i] = 
_mm256_setzero_si256(); + } + if (B.zpptr) { + __m256i bzp[NReg]; + auto bzptr = B.zpptr + ib * ld_scaleb; + for (int i = 0; i < NReg; i++) { + bzp[i] = load_zp_epi8_broadcast_epi32(bzptr + i * 8, vindex); + } + for (int ik = 0; ik < blocksize; ik += KTILE * UnpackElt) { + for (int il = 0; il < Loop128; il++) { + bit3_interleave_decompress_pack128(b2ptr, b1ptr, UnpackBuf + il * 128); + b2ptr += 128 / 4; + b1ptr += 128 / 8; + } + for (int iu = 0; iu < UnpackElt; iu++) { + auto va = _mm256_set1_epi32(*(int*)(a8ptr + iu * KTILE)); + for (int i = 0; i < NReg; i++) { + auto vb = _mm256_loadu_si256((const __m256i*)(UnpackBuf + iu * NTILE * KTILE + i * 32)); + vb = _mm256_sub_epi8(vb, bzp[i]); + iacc[i] = _mm256_dpbusd_avx_epi32(iacc[i], va, vb); + bacc[i] = _mm256_dpbusd_avx_epi32(bacc[i], onesu8, vb); + } + } + a8ptr += KTILE * UnpackElt; + } + } else { + for (int ik = 0; ik < blocksize; ik += KTILE * UnpackElt) { + for (int il = 0; il < Loop128; il++) { + bit3_interleave_decompress_pack128(b2ptr, b1ptr, UnpackBuf + il * 128); + b2ptr += 128 / 4; + b1ptr += 128 / 8; + } + for (int iu = 0; iu < UnpackElt; iu++) { + auto va = _mm256_set1_epi32(*(int*)(a8ptr + iu * KTILE)); + for (int i = 0; i < NReg; i++) { + auto vb = _mm256_loadu_si256((const __m256i*)(UnpackBuf + iu * NTILE * KTILE + i * 32)); + iacc[i] = _mm256_dpbusd_avx_epi32(iacc[i], va, vb); + bacc[i] = _mm256_dpbusd_avx_epi32(bacc[i], onesu8, vb); + } + } + a8ptr += KTILE * UnpackElt; + } + } + + const __m256 v_a_scale = _mm256_set1_ps(*(asptr + ib)); + auto zp = int(azptr[ib]); + const __m256i v_a_zp = _mm256_set1_epi32(zp); + auto bsptr = B.sptr + ib * ld_scaleb; + for (int i = 0; i < NReg; i++) { + bacc[i] = _mm256_mullo_epi32(v_a_zp, bacc[i]); + iacc[i] = _mm256_sub_epi32(iacc[i], bacc[i]); + __m256 v_b_scale; + if constexpr (std::is_same_v) { + v_b_scale = _mm256_loadu_ps(bsptr + i * 8); + } else if constexpr (std::is_same_v) { + auto tmp = _mm_loadu_si128((const __m128i*)(bsptr + i * 8)); + v_b_scale = kernel::avx2::ymm_cvt_bf16_fp32(tmp); + } + v_b_scale = _mm256_mul_ps(v_a_scale, v_b_scale); + auto tmp = _mm256_cvtepi32_ps(iacc[i]); + acc[i] = _mm256_fmadd_ps(tmp, v_b_scale, acc[i]); + } + } + + for (int i = 0; i < NReg; i++) { + _mm256_storeu_ps(C + i * 8, acc[i]); + } + return BTLA_CODE::Success; +} + +template +static inline BTLA_CODE gemv_3bit_s8s8_fp32_align128(const utils::GemvParamA& A, const utils::GemvParamB& B, + float* C, int k, int ld_scaleb, int blocksize, int8_t* tmp, + size_t tmpsize) { + auto a8ptr = A.aptr; + auto b2ptr = reinterpret_cast(B.b2ptr); + auto b1ptr = reinterpret_cast(B.b1ptr); + auto asptr = A.sptr; + auto azptr = A.zpptr; + + int blks = k / blocksize; + int constexpr NReg = NTILE / 8; + // Initialize accumulator with zeros + __m256 acc[NReg]; + int constexpr EltPadding = 128; + static_assert(NTILE % 8 == 0); + int constexpr KTILE = 4; + int constexpr UnpackElt = EltPadding / 8 / KTILE; + int constexpr TotalElt = UnpackElt * NTILE * KTILE; + int constexpr Loop128 = TotalElt / 128; + int8_t UnpackBuf[TotalElt]; + for (int i = 0; i < NReg; i++) { + acc[i] = _mm256_setzero_ps(); + } + uint32_t mask = 0x0f0f0f0f; + auto vmask = _mm256_set1_epi32(*reinterpret_cast(&mask)); + const __m256i lowMask = _mm256_set1_epi8(0x03); + const __m256i highMask = _mm256_set1_epi8(0x04); + const __m256i bit1Mask = _mm256_set1_epi32(0x0F); + const __m256i bit1Shift_1 = _mm256_set_epi32(28, 24, 20, 16, 12, 8, 4, 0); + const __m256i bit1Shift_2 = _mm256_set1_epi32((1 << 23) + (1 << 16) + (1 << 9) + (1 << 2)); + const auto 
vindex = _mm256_set_epi8(12, 12, 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0, 12, 12, 12, 12, 8, 8, 8, 8, 4, + 4, 4, 4, 0, 0, 0, 0); + auto bit3_interleave_decompress_pack128 = [&](utils::bit2x4* src1, utils::bit1x8* src2, int8_t* dst) { + __m256i bit2_data = _mm256_loadu_si256((const __m256i*)src1); + int32_t* bit1_ptr = reinterpret_cast(src2); + for (int i = 0; i < 4; i++) { + auto bit1x32 = _mm256_set1_epi32(bit1_ptr[i]); + bit1x32 = _mm256_srlv_epi32(bit1x32, bit1Shift_1); + bit1x32 = _mm256_and_si256(bit1x32, bit1Mask); + bit1x32 = _mm256_mullo_epi32(bit1x32, bit1Shift_2); + bit1x32 = _mm256_and_si256(highMask, bit1x32); + + auto bit2x32 = _mm256_and_si256(lowMask, _mm256_srli_epi16(bit2_data, 2 * i)); + auto res = _mm256_add_epi8(bit1x32, bit2x32); + res = _mm256_slli_epi32(res, 5); + _mm256_storeu_si256((__m256i*)(dst + 32 * i), res); + } + }; + for (int ib = 0; ib < blks; ib += 1) { + __m256i iacc[NReg]; + for (int i = 0; i < NReg; i++) { + iacc[i] = _mm256_setzero_si256(); + } + if (B.zpptr) { + __m256i bzp[NReg]; + auto bzptr = B.zpptr + ib * ld_scaleb; + for (int i = 0; i < NReg; i++) { + bzp[i] = load_zp_epi8_broadcast_epi32(bzptr + i * 8, vindex); + } + for (int ik = 0; ik < blocksize; ik += KTILE * UnpackElt) { + for (int il = 0; il < Loop128; il++) { + bit3_interleave_decompress_pack128(b2ptr, b1ptr, UnpackBuf + il * 128); + b2ptr += 128 / 4; + b1ptr += 128 / 8; + } + for (int iu = 0; iu < UnpackElt; iu++) { + auto va = _mm256_set1_epi32(*(int*)(a8ptr + iu * KTILE)); + auto vabsa = _mm256_sign_epi8(va, va); + for (int i = 0; i < NReg; i++) { + auto vb = _mm256_loadu_si256((const __m256i*)(UnpackBuf + iu * NTILE * KTILE + i * 32)); + vb = _mm256_sub_epi8(vb, bzp[i]); + vb = _mm256_sign_epi8(vb, va); + iacc[i] = _mm256_dpbusd_avx_epi32(iacc[i], vabsa, vb); + } + } + a8ptr += KTILE * UnpackElt; + } + } else { + for (int ik = 0; ik < blocksize; ik += KTILE * UnpackElt) { + for (int il = 0; il < Loop128; il++) { + bit3_interleave_decompress_pack128(b2ptr, b1ptr, UnpackBuf + il * 128); + b2ptr += 128 / 4; + b1ptr += 128 / 8; + } + for (int iu = 0; iu < UnpackElt; iu++) { + auto va = _mm256_set1_epi32(*(int*)(a8ptr + iu * KTILE)); + auto vabsa = _mm256_sign_epi8(va, va); + for (int i = 0; i < NReg; i++) { + auto vb = _mm256_loadu_si256((const __m256i*)(UnpackBuf + iu * NTILE * KTILE + i * 32)); + vb = _mm256_sign_epi8(vb, va); + iacc[i] = _mm256_dpbusd_avx_epi32(iacc[i], vabsa, vb); + } + } + a8ptr += KTILE * UnpackElt; + } + } + + const __m256 v_a_scale = _mm256_set1_ps(*(asptr + ib)); + auto bsptr = B.sptr + ib * ld_scaleb; + for (int i = 0; i < NReg; i++) { + __m256 v_b_scale; + if constexpr (std::is_same_v) { + v_b_scale = _mm256_loadu_ps(bsptr + i * 8); + } else if constexpr (std::is_same_v) { + auto tmp = _mm_loadu_si128((const __m128i*)(bsptr + i * 8)); + v_b_scale = kernel::avx2::ymm_cvt_bf16_fp32(tmp); + } + v_b_scale = _mm256_mul_ps(v_a_scale, v_b_scale); + auto tmp = _mm256_cvtepi32_ps(iacc[i]); + acc[i] = _mm256_fmadd_ps(tmp, v_b_scale, acc[i]); + } + } + + for (int i = 0; i < NReg; i++) { + _mm256_storeu_ps(C + i * 8, acc[i]); + } + return BTLA_CODE::Success; +} + +template +static inline BTLA_CODE gemv_2bit_u8s8_fp32(const utils::GemvParamA& A, const utils::GemvParamB& B, float* C, + int ldc, int k, int blocksize, int8_t* tmp, size_t tmpsize) { + auto b2ptr = reinterpret_cast(B.b2ptr); + + int blks = k / blocksize; + int constexpr NReg = NTILE / 8; + int constexpr MReg = MTILE; + __m256 acc[NReg * MReg]; + for (int i = 0; i < NReg * MReg; i++) { + acc[i] = 
_mm256_setzero_ps(); + } + uint64_t mask0 = 0x0303030303030303; + auto vmask0_y = _mm256_set_epi64x(*(int64_t*)&mask0, *(int64_t*)&mask0, *(int64_t*)&mask0, *(int64_t*)&mask0); + auto vshift_y = _mm256_set_epi32(6, 4, 2, 0, 6, 4, 2, 0); + auto vsfhl_mask_y = _mm256_set_epi8(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2, + 13, 9, 5, 1, 12, 8, 4, 0); + auto vorder_y = _mm256_set_epi32(1, 1, 1, 1, 0, 0, 0, 0); + const __m256i onesu8 = _mm256_set1_epi8(1); + const __m256i vbias = _mm256_set1_epi8(2); + const auto vindex = _mm256_set_epi8(12, 12, 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0, 12, 12, 12, 12, 8, 8, 8, 8, 4, + 4, 4, 4, 0, 0, 0, 0); + int constexpr KTILE = 4; + for (int ib = 0; ib < blks; ib += 1) { + __m256i iacc[NReg * MReg]; + __m256i bacc[NReg]; + for (int i = 0; i < NReg * MReg; i++) { + iacc[i] = _mm256_setzero_si256(); + } + for (int i = 0; i < NReg; i++) { + bacc[i] = _mm256_setzero_si256(); + } + if (B.zpptr) { + __m256i bzp[NReg]; + auto bzptr = B.zpptr + ib * B.ldzp; + for (int i = 0; i < NReg; i++) { + bzp[i] = load_zp_epi8_broadcast_epi32(bzptr + i * 8, vindex); + bzp[i] = _mm256_add_epi8(bzp[i], vbias); + } + for (int ik = 0; ik < blocksize; ik += KTILE) { + if constexpr (MTILE == 1) { + __m256i va = _mm256_set1_epi32(*(int*)(A.aptr + ib * blocksize + ik)); + for (int i = 0; i < NReg; i++) { + auto vb = unpack_2bits(b2ptr, vshift_y, vmask0_y, vsfhl_mask_y, vorder_y); + vb = _mm256_sub_epi8(vb, bzp[i]); + bacc[i] = _mm256_dpbusd_avx_epi32(bacc[i], onesu8, vb); + iacc[i] = _mm256_dpbusd_avx_epi32(iacc[i], va, vb); + b2ptr += 8 * KTILE / 4; + } + } else { + __m256i va[MReg]; + for (int i = 0; i < MReg; i++) { + va[i] = _mm256_set1_epi32(*(int*)(A.aptr + ib * blocksize + ik + i * A.lda)); + } + for (int i = 0; i < NReg; i++) { + auto vb = unpack_2bits(b2ptr, vshift_y, vmask0_y, vsfhl_mask_y, vorder_y); + vb = _mm256_sub_epi8(vb, bzp[i]); + bacc[i] = _mm256_dpbusd_avx_epi32(bacc[i], onesu8, vb); + for (int j = 0; j < MReg; j++) { + iacc[j * NReg + i] = _mm256_dpbusd_avx_epi32(iacc[j * NReg + i], va[j], vb); + } + b2ptr += 8 * KTILE / 4; + } + } + } + } else { + for (int ik = 0; ik < blocksize; ik += KTILE) { + if constexpr (MTILE == 1) { + __m256i va = _mm256_set1_epi32(*(int*)(A.aptr + ib * blocksize + ik)); + for (int i = 0; i < NReg; i++) { + auto vb = unpack_2bits(b2ptr, vshift_y, vmask0_y, vsfhl_mask_y, vorder_y); + vb = _mm256_sub_epi8(vb, vbias); + bacc[i] = _mm256_dpbusd_avx_epi32(bacc[i], onesu8, vb); + iacc[i] = _mm256_dpbusd_avx_epi32(iacc[i], va, vb); + b2ptr += 8 * KTILE / 4; + } + } else { + __m256i va[MReg]; + for (int i = 0; i < MReg; i++) { + va[i] = _mm256_set1_epi32(*(int*)(A.aptr + ib * blocksize + ik + i * A.lda)); + } + for (int i = 0; i < NReg; i++) { + auto vb = unpack_2bits(b2ptr, vshift_y, vmask0_y, vsfhl_mask_y, vorder_y); + vb = _mm256_sub_epi8(vb, vbias); + bacc[i] = _mm256_dpbusd_avx_epi32(bacc[i], onesu8, vb); + for (int j = 0; j < MReg; j++) { + iacc[j * NReg + i] = _mm256_dpbusd_avx_epi32(iacc[j * NReg + i], va[j], vb); + } + b2ptr += 8 * KTILE / 4; + } + } + } + } + + gemv_remove_zp(A.zpptr + ib, A.ldzp, iacc, bacc); + gemv_dequant_s32fp32(A.sptr + ib, A.ldzp, B.sptr + ib * B.ldzp, iacc, acc); + } + + for (int j = 0; j < MReg; j++) { + for (int i = 0; i < NReg; i++) { + _mm256_storeu_ps(C + i * 8 + j * ldc, acc[j * NReg + i]); + } + } + return BTLA_CODE::Success; +} + +template +static inline BTLA_CODE gemv_2bit_s8s8_fp32(const utils::GemvParamA& A, const utils::GemvParamB& B, float* C, + int ldc, int k, 
int blocksize, int8_t* tmp, size_t tmpsize) { + auto b2ptr = reinterpret_cast(B.b2ptr); + + int blks = k / blocksize; + int constexpr NReg = NTILE / 8; + int constexpr MReg = MTILE; + __m256 acc[NReg * MReg]; + for (int i = 0; i < NReg * MReg; i++) { + acc[i] = _mm256_setzero_ps(); + } + uint64_t mask0 = 0x0303030303030303; + auto vmask0_y = _mm256_set_epi64x(*(int64_t*)&mask0, *(int64_t*)&mask0, *(int64_t*)&mask0, *(int64_t*)&mask0); + auto vshift_y = _mm256_set_epi32(6, 4, 2, 0, 6, 4, 2, 0); + auto vsfhl_mask_y = _mm256_set_epi8(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2, + 13, 9, 5, 1, 12, 8, 4, 0); + auto vorder_y = _mm256_set_epi32(1, 1, 1, 1, 0, 0, 0, 0); + const __m256i onesu8 = _mm256_set1_epi8(1); + const __m256i vbias = _mm256_set1_epi8(2); + const auto vindex = _mm256_set_epi8(12, 12, 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0, 12, 12, 12, 12, 8, 8, 8, 8, 4, + 4, 4, 4, 0, 0, 0, 0); + int constexpr KTILE = 4; + for (int ib = 0; ib < blks; ib += 1) { + __m256i iacc[NReg * MReg]; + for (int i = 0; i < NReg * MReg; i++) { + iacc[i] = _mm256_setzero_si256(); + } + + if (B.zpptr) { + __m256i bzp[NReg]; + auto bzptr = B.zpptr + ib * B.ldzp; + for (int i = 0; i < NReg; i++) { + bzp[i] = load_zp_epi8_broadcast_epi32(bzptr + i * 8, vindex); + bzp[i] = _mm256_add_epi8(vbias, bzp[i]); + } + for (int ik = 0; ik < blocksize; ik += KTILE) { + __m256i va[MReg]; + for (int i = 0; i < MReg; i++) { + va[i] = _mm256_set1_epi32(*(int*)(A.aptr + ib * blocksize + ik + i * A.lda)); + } + for (int i = 0; i < NReg; i++) { + auto vb = unpack_2bits(b2ptr, vshift_y, vmask0_y, vsfhl_mask_y, vorder_y); + vb = _mm256_sub_epi8(vb, bzp[i]); + for (int j = 0; j < MReg; j++) { + auto vsb = _mm256_sign_epi8(vb, va[j]); + auto vabsa = _mm256_sign_epi8(va[j], va[j]); + iacc[j * NReg + i] = _mm256_dpbusd_avx_epi32(iacc[j * NReg + i], vabsa, vsb); + } + b2ptr += 8 * KTILE / 4; + } + } + } else { + for (int ik = 0; ik < blocksize; ik += KTILE) { + __m256i va[MReg]; + for (int i = 0; i < MReg; i++) { + va[i] = _mm256_set1_epi32(*(int*)(A.aptr + ib * blocksize + ik + i * A.lda)); + } + for (int i = 0; i < NReg; i++) { + auto vb = unpack_2bits(b2ptr, vshift_y, vmask0_y, vsfhl_mask_y, vorder_y); + vb = _mm256_sub_epi8(vb, vbias); + for (int j = 0; j < MReg; j++) { + auto vsb = _mm256_sign_epi8(vb, va[j]); + auto vabsa = _mm256_sign_epi8(va[j], va[j]); + iacc[j * NReg + i] = _mm256_dpbusd_avx_epi32(iacc[j * NReg + i], vabsa, vsb); + } + b2ptr += 8 * KTILE / 4; + } + } + } + gemv_dequant_s32fp32(A.sptr + ib, A.ldzp, B.sptr + ib * B.ldzp, iacc, acc); + } + + for (int j = 0; j < MReg; j++) { + for (int i = 0; i < NReg; i++) { + _mm256_storeu_ps(C + i * 8 + j * ldc, acc[j * NReg + i]); + } + } + return BTLA_CODE::Success; +} + +template +static inline BTLA_CODE gemv_3bit_u8s8_fp32(const utils::GemvParamA& A, const utils::GemvParamB& B, float* C, + int ldc, int k, int blocksize, int8_t* tmp, size_t tmpsize) { + auto b2ptr = reinterpret_cast(B.b2ptr); + auto b1ptr = reinterpret_cast(B.b1ptr); + + int blks = k / blocksize; + int constexpr NReg = NTILE / 8; + int constexpr MReg = MTILE; + __m256 acc[NReg * MReg]; + for (int i = 0; i < NReg * MReg; i++) { + acc[i] = _mm256_setzero_ps(); + } + uint64_t mask0 = 0x0303030303030303; + auto vmask0_y = _mm256_set_epi64x(*(int64_t*)&mask0, *(int64_t*)&mask0, *(int64_t*)&mask0, *(int64_t*)&mask0); + auto vshift_y = _mm256_set_epi32(6, 4, 2, 0, 6, 4, 2, 0); + auto vsfhl_mask_y = _mm256_set_epi8(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 
7, 3, 14, 10, 6, 2, + 13, 9, 5, 1, 12, 8, 4, 0); + auto vorder_y = _mm256_set_epi32(1, 1, 1, 1, 0, 0, 0, 0); + const __m256i onesu8 = _mm256_set1_epi8(1); + const __m256i vbias = _mm256_set1_epi8(4); + const auto vindex = _mm256_set_epi8(12, 12, 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0, 12, 12, 12, 12, 8, 8, 8, 8, 4, + 4, 4, 4, 0, 0, 0, 0); + + const __m256i highMask = _mm256_set1_epi8(0x04); + const __m256i bit1Mask = _mm256_set1_epi32(0x0F); + const __m256i bit1Shift_1 = _mm256_set_epi32(28, 24, 20, 16, 12, 8, 4, 0); + const __m256i bit1Shift_2 = _mm256_set1_epi32((1 << 23) + (1 << 16) + (1 << 9) + (1 << 2)); + int constexpr KTILE = 4; + for (int ib = 0; ib < blks; ib += 1) { + __m256i iacc[NReg * MReg]; + __m256i bacc[NReg]; + for (int i = 0; i < NReg * MReg; i++) { + iacc[i] = _mm256_setzero_si256(); + } + for (int i = 0; i < NReg; i++) { + bacc[i] = _mm256_setzero_si256(); + } + if (B.zpptr) { + __m256i bzp[NReg]; + auto bzptr = B.zpptr + ib * B.ldzp; + for (int i = 0; i < NReg; i++) { + bzp[i] = load_zp_epi8_broadcast_epi32(bzptr + i * 8, vindex); + bzp[i] = _mm256_add_epi8(bzp[i], vbias); + } + for (int ik = 0; ik < blocksize; ik += KTILE) { + if constexpr (MTILE == 1) { + __m256i va = _mm256_set1_epi32(*(int*)(A.aptr + ib * blocksize + ik)); + for (int i = 0; i < NReg; i++) { + auto vb = unpack_2bits(b2ptr, vshift_y, vmask0_y, vsfhl_mask_y, vorder_y); + auto vb1 = unpack_1bits(b1ptr, bit1Shift_1, bit1Mask, bit1Shift_2, highMask); + vb = _mm256_or_si256(vb, vb1); + vb = _mm256_sub_epi8(vb, bzp[i]); + bacc[i] = _mm256_dpbusd_avx_epi32(bacc[i], onesu8, vb); + iacc[i] = _mm256_dpbusd_avx_epi32(iacc[i], va, vb); + b2ptr += 8 * KTILE / 4; + b1ptr += 8 * KTILE / 8; + } + } else { + __m256i va[MReg]; + for (int i = 0; i < MReg; i++) { + va[i] = _mm256_set1_epi32(*(int*)(A.aptr + ib * blocksize + ik + i * A.lda)); + } + for (int i = 0; i < NReg; i++) { + auto vb = unpack_2bits(b2ptr, vshift_y, vmask0_y, vsfhl_mask_y, vorder_y); + auto vb1 = unpack_1bits(b1ptr, bit1Shift_1, bit1Mask, bit1Shift_2, highMask); + vb = _mm256_or_si256(vb, vb1); + vb = _mm256_sub_epi8(vb, bzp[i]); + bacc[i] = _mm256_dpbusd_avx_epi32(bacc[i], onesu8, vb); + for (int j = 0; j < MReg; j++) { + iacc[j * NReg + i] = _mm256_dpbusd_avx_epi32(iacc[j * NReg + i], va[j], vb); + } + b2ptr += 8 * KTILE / 4; + b1ptr += 8 * KTILE / 8; + } + } + } + } else { + for (int ik = 0; ik < blocksize; ik += KTILE) { + if constexpr (MTILE == 1) { + __m256i va = _mm256_set1_epi32(*(int*)(A.aptr + ib * blocksize + ik)); + for (int i = 0; i < NReg; i++) { + auto vb = unpack_2bits(b2ptr, vshift_y, vmask0_y, vsfhl_mask_y, vorder_y); + auto vb1 = unpack_1bits(b1ptr, bit1Shift_1, bit1Mask, bit1Shift_2, highMask); + vb = _mm256_or_si256(vb, vb1); + vb = _mm256_sub_epi8(vb, vbias); + bacc[i] = _mm256_dpbusd_avx_epi32(bacc[i], onesu8, vb); + iacc[i] = _mm256_dpbusd_avx_epi32(iacc[i], va, vb); + + b2ptr += 8 * KTILE / 4; + b1ptr += 8 * KTILE / 8; + } + } else { + __m256i va[MReg]; + for (int i = 0; i < MReg; i++) { + va[i] = _mm256_set1_epi32(*(int*)(A.aptr + ib * blocksize + ik + i * A.lda)); + } + for (int i = 0; i < NReg; i++) { + auto vb = unpack_2bits(b2ptr, vshift_y, vmask0_y, vsfhl_mask_y, vorder_y); + auto vb1 = unpack_1bits(b1ptr, bit1Shift_1, bit1Mask, bit1Shift_2, highMask); + vb = _mm256_or_si256(vb, vb1); + vb = _mm256_sub_epi8(vb, vbias); + bacc[i] = _mm256_dpbusd_avx_epi32(bacc[i], onesu8, vb); + for (int j = 0; j < MReg; j++) { + iacc[j * NReg + i] = _mm256_dpbusd_avx_epi32(iacc[j * NReg + i], va[j], vb); + } + b2ptr += 8 * KTILE / 
4; + b1ptr += 8 * KTILE / 8; + } + } + } + } + + gemv_remove_zp(A.zpptr + ib, A.ldzp, iacc, bacc); + gemv_dequant_s32fp32(A.sptr + ib, A.ldzp, B.sptr + ib * B.ldzp, iacc, acc); + } + + for (int j = 0; j < MReg; j++) { + for (int i = 0; i < NReg; i++) { + _mm256_storeu_ps(C + i * 8 + j * ldc, acc[j * NReg + i]); + } + } + return BTLA_CODE::Success; +} + +template +static inline BTLA_CODE gemv_3bit_s8s8_fp32(const utils::GemvParamA& A, const utils::GemvParamB& B, float* C, + int ldc, int k, int blocksize, int8_t* tmp, size_t tmpsize) { + auto b2ptr = reinterpret_cast(B.b2ptr); + auto b1ptr = reinterpret_cast(B.b1ptr); + + int blks = k / blocksize; + int constexpr NReg = NTILE / 8; + int constexpr MReg = MTILE; + __m256 acc[NReg * MReg]; + for (int i = 0; i < NReg * MReg; i++) { + acc[i] = _mm256_setzero_ps(); + } + uint64_t mask0 = 0x0303030303030303; + auto vmask0_y = _mm256_set_epi64x(*(int64_t*)&mask0, *(int64_t*)&mask0, *(int64_t*)&mask0, *(int64_t*)&mask0); + auto vshift_y = _mm256_set_epi32(6, 4, 2, 0, 6, 4, 2, 0); + auto vsfhl_mask_y = _mm256_set_epi8(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2, + 13, 9, 5, 1, 12, 8, 4, 0); + auto vorder_y = _mm256_set_epi32(1, 1, 1, 1, 0, 0, 0, 0); + const __m256i onesu8 = _mm256_set1_epi8(1); + const __m256i vbias = _mm256_set1_epi8(4); + const auto vindex = _mm256_set_epi8(12, 12, 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0, 12, 12, 12, 12, 8, 8, 8, 8, 4, + 4, 4, 4, 0, 0, 0, 0); + + const __m256i highMask = _mm256_set1_epi8(0x04); + const __m256i bit1Mask = _mm256_set1_epi32(0x0F); + const __m256i bit1Shift_1 = _mm256_set_epi32(28, 24, 20, 16, 12, 8, 4, 0); + const __m256i bit1Shift_2 = _mm256_set1_epi32((1 << 23) + (1 << 16) + (1 << 9) + (1 << 2)); + int constexpr KTILE = 4; + for (int ib = 0; ib < blks; ib += 1) { + __m256i iacc[NReg * MReg]; + for (int i = 0; i < NReg * MReg; i++) { + iacc[i] = _mm256_setzero_si256(); + } + if (B.zpptr) { + __m256i bzp[NReg]; + auto bzptr = B.zpptr + ib * B.ldzp; + for (int i = 0; i < NReg; i++) { + bzp[i] = load_zp_epi8_broadcast_epi32(bzptr + i * 8, vindex); + bzp[i] = _mm256_add_epi8(bzp[i], vbias); + } + for (int ik = 0; ik < blocksize; ik += KTILE) { + __m256i va[MReg]; + for (int i = 0; i < MReg; i++) { + va[i] = _mm256_set1_epi32(*(int*)(A.aptr + ib * blocksize + ik + i * A.lda)); + } + for (int i = 0; i < NReg; i++) { + auto vb = unpack_2bits(b2ptr, vshift_y, vmask0_y, vsfhl_mask_y, vorder_y); + auto vb1 = unpack_1bits(b1ptr, bit1Shift_1, bit1Mask, bit1Shift_2, highMask); + vb = _mm256_or_si256(vb, vb1); + vb = _mm256_sub_epi8(vb, bzp[i]); + for (int j = 0; j < MReg; j++) { + auto vsb = _mm256_sign_epi8(vb, va[j]); + auto vabsa = _mm256_sign_epi8(va[j], va[j]); + iacc[j * NReg + i] = _mm256_dpbusd_avx_epi32(iacc[j * NReg + i], vabsa, vsb); + } + b2ptr += 8 * KTILE / 4; + b1ptr += 8 * KTILE / 8; + } + } + } else { + for (int ik = 0; ik < blocksize; ik += KTILE) { + __m256i va[MReg]; + for (int i = 0; i < MReg; i++) { + va[i] = _mm256_set1_epi32(*(int*)(A.aptr + ib * blocksize + ik + i * A.lda)); + } + for (int i = 0; i < NReg; i++) { + auto vb = unpack_2bits(b2ptr, vshift_y, vmask0_y, vsfhl_mask_y, vorder_y); + auto vb1 = unpack_1bits(b1ptr, bit1Shift_1, bit1Mask, bit1Shift_2, highMask); + vb = _mm256_or_si256(vb, vb1); + vb = _mm256_sub_epi8(vb, vbias); + for (int j = 0; j < MReg; j++) { + auto vsb = _mm256_sign_epi8(vb, va[j]); + auto vabsa = _mm256_sign_epi8(va[j], va[j]); + iacc[j * NReg + i] = _mm256_dpbusd_avx_epi32(iacc[j * NReg + i], vabsa, vsb); + } + b2ptr += 
8 * KTILE / 4; + b1ptr += 8 * KTILE / 8; + } + } + } + + gemv_dequant_s32fp32(A.sptr + ib, A.ldzp, B.sptr + ib * B.ldzp, iacc, acc); + } + + for (int j = 0; j < MReg; j++) { + for (int i = 0; i < NReg; i++) { + _mm256_storeu_ps(C + i * 8 + j * ldc, acc[j * NReg + i]); + } + } + return BTLA_CODE::Success; +} + +#ifdef __GNUC__ +#pragma GCC pop_options +#else +#endif +#endif +} // namespace vnni + #ifdef __GNUC__ #pragma GCC pop_options #else diff --git a/bestla/bestla/kernel_avx512_bf16.h b/bestla/bestla/kernel_avx512_bf16.h index ece55a5dd..c1ca028a3 100644 --- a/bestla/bestla/kernel_avx512_bf16.h +++ b/bestla/bestla/kernel_avx512_bf16.h @@ -20,9 +20,11 @@ namespace bestla { namespace kernel { namespace avx512_bf16 { #if CompileBF16() +#if defined(__GNUC__) #pragma GCC push_options #pragma GCC target("avx512bf16", "avx512vl", "avx512bw") #endif +#endif static inline BTLA_CODE bf16_cvt_fp32_2D_write_back(const utils::bf16* src_ptr, float* dst_ptr, int row, int col, int src_step, int dst_step, bool zeropadding) { #if CompileBF16() @@ -36,13 +38,12 @@ static inline BTLA_CODE bf16_cvt_fp32_2D_write_back(const utils::bf16* src_ptr, auto dst = dst_ptr + i * dst_step; int j = 0; for (; j < col_body; j += simd_proc_elt) - _mm512_storeu_ps( - dst + j, // - reinterpret_cast<__m512>(_mm512_bslli_epi128(_mm512_cvtepu16_epi32(_mm256_loadu_epi16(src + j)), 2))); + _mm512_storeu_ps(dst + j, // + _mm512_castsi512_ps(_mm512_bslli_epi128(_mm512_cvtepu16_epi32(_mm256_loadu_epi16(src + j)), 2))); if (col_tail > 0) _mm512_mask_storeu_ps( dst + j, tail_mask, - reinterpret_cast<__m512>(_mm512_bslli_epi128(_mm512_cvtepu16_epi32(_mm256_loadu_epi16(src + j)), 2))); + _mm512_castsi512_ps(_mm512_bslli_epi128(_mm512_cvtepu16_epi32(_mm256_loadu_epi16(src + j)), 2))); if (zeropadding && npadding) std::memset(dst + col, 0, npadding); } return BTLA_CODE::Success; @@ -92,8 +93,10 @@ static inline BTLA_CODE fp32_cvt_bf16_2D_write_back(const void* raw_srcptr, void return BTLA_CODE::NotSupport; } #if CompileBF16() +#if defined(__GNUC__) #pragma GCC pop_options #endif +#endif } // namespace avx512_bf16 } // namespace kernel } // namespace bestla diff --git a/bestla/bestla/kernel_avx512f.h b/bestla/bestla/kernel_avx512f.h index 822c4457c..41c82bbcf 100644 --- a/bestla/bestla/kernel_avx512f.h +++ b/bestla/bestla/kernel_avx512f.h @@ -54,7 +54,39 @@ inline __m256i zmm_cvt_fp32_bf16(__m512 vfp32) { #endif } -static inline __m512i unpack_4bits(__m256i v4bits, __m512i vmask) { +static inline __m512 load_bf16_fp32(const utils::bf16* srcptr) { + auto tmp = _mm256_loadu_si256(reinterpret_cast(srcptr)); + auto vf32 = zmm_cvt_bf16_fp32(tmp); + return vf32; +} + +static inline __m512i unpack_4bits(void* srcptr, __m512i mask) { + auto raw_data = _mm256_loadu_si256(reinterpret_cast<__m256i*>(srcptr)); + auto ymm0 = _mm512_cvtepu8_epi16(raw_data); + auto ymm1 = _mm512_slli_epi16(ymm0, 4); + ymm0 = _mm512_or_si512(ymm0, ymm1); + ymm0 = _mm512_and_si512(ymm0, mask); + return ymm0; +} + +static inline __m512i unpack_2bits(utils::bit2x4* ptr, const __m512i& vshift_y, const __m512i& vmask0_y, + const __m512i& vsfhl_mask_y, const __m512i& vorder_y) { + auto vraw_x = _mm_loadu_si128((const __m128i*)ptr); + auto vsrc_y = _mm512_broadcast_i64x2(vraw_x); + auto vordered_y = _mm512_permutex2var_epi32(vsrc_y, vorder_y, vsrc_y); + auto vs_y = _mm512_srlv_epi32(vordered_y, vshift_y); + auto v2_y = _mm512_and_si512(vs_y, vmask0_y); + auto vout_y = _mm512_shuffle_epi8(v2_y, vsfhl_mask_y); + return vout_y; +} + +static inline __m512i 
unpack_1bits(utils::bit1x8* ptr, const __m512i& zmm_0x00, const __m512i& zmm_0x04) { + auto bit1_mask1 = _cvtu64_mask64(*(uint64_t*)ptr); + auto zmm1_ = _mm512_mask_mov_epi8(zmm_0x00, bit1_mask1, zmm_0x04); + return zmm1_; +} + +static inline __m512i unpack_4bits_high(__m256i v4bits, __m512i vmask) { auto ymm1 = _mm256_slli_epi32(v4bits, 4); auto zmm = _mm512_cvtepi8_epi16(v4bits); auto zmm1 = _mm512_cvtepi8_epi16(ymm1); @@ -66,14 +98,14 @@ static inline __m512i unpack_4bits(__m256i v4bits, __m512i vmask) { static inline void convert_s4_s8_highbits(int8_t* dstptr, int8_t* srcptr, __m512i vmask, int LoadMask) { auto ymm = _mm256_maskz_loadu_epi32(__mmask8(LoadMask), reinterpret_cast(srcptr)); - auto zmm = unpack_4bits(ymm, vmask); + auto zmm = unpack_4bits_high(ymm, vmask); _mm512_mask_storeu_epi64(dstptr, __mmask8(LoadMask), zmm); } static inline void convert_s4_s8_highbits_v32(int8_t* dstptr, int8_t* srcptr, __m512i vmask, int LoadMask) { auto xmm = _mm_maskz_loadu_epi32(__mmask8(LoadMask), reinterpret_cast(srcptr)); auto ymm = _mm256_castsi128_si256(xmm); - auto zmm = unpack_4bits(ymm, vmask); + auto zmm = unpack_4bits_high(ymm, vmask); auto ymm_out = _mm512_castsi512_si256(zmm); _mm256_mask_storeu_epi64(dstptr, __mmask8(LoadMask), ymm_out); } @@ -114,6 +146,33 @@ static inline void dequant_s8_N(_DST_T* dstptr, int8_t* srcptr, __m512* vscales, } } +static inline __m512i load_s8_s32(int8_t* srcptr) { + auto xmm = _mm_loadu_si128(reinterpret_cast<__m128i*>(srcptr)); + auto ymm = _mm512_cvtepi8_epi32(xmm); + return ymm; +} + +template +static inline __m512 dequant_s8_fp(int8_t* srcptr, __m512 vscales, __m512i vzps = __m512i()) { + auto src_s32_y = load_s8_s32(srcptr); + if constexpr (IsAsym) src_s32_y = _mm512_sub_epi32(src_s32_y, vzps); + auto src_fp_y = _mm512_cvtepi32_ps(src_s32_y); + src_fp_y = _mm512_mul_ps(src_fp_y, vscales); + return src_fp_y; +} + +template +static inline void store_fp_T(__m512 src_y, T* dstptr) { + if constexpr (std::is_same_v) { + auto xmm = zmm_cvt_fp32_bf16(src_y); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(dstptr), xmm); + } else if constexpr (std::is_same_v) { + _mm512_storeu_ps(dstptr, src_y); + } else { + assert(false); + } +} + template static inline void dequant_f4_N(_DST_T* dstptr, int8_t* srcptr, __m512* vscales, __m512i* vzps = nullptr) { static_assert(N % 16 == 0); @@ -200,6 +259,23 @@ static inline void vec_broadcast_ps_1_2(__m512* dst2regs, __m512* src1regs, __m5 dst2regs[1] = _mm512_castsi512_ps(_mm512_unpackhi_epi32(tmpreg, tmpreg)); } +template +static inline __m512 broadcast_ps_1_2(__m512 vsrc_y, const __m512i& vshuf_index_high, const __m512i& vshuf_index_low) { + __m512 tmp; + if constexpr (LowBits) { + tmp = _mm512_permutex2var_ps(vsrc_y, vshuf_index_low, vsrc_y); + } else { + tmp = _mm512_permutex2var_ps(vsrc_y, vshuf_index_high, vsrc_y); + } + return tmp; +} + +template +static inline __m512i broadcast_epi32_1_2(__m512i vsrc_y, const __m512i& vshuf_index_high, + const __m512i& vshuf_index_low) { + return _mm512_castps_si512(broadcast_ps_1_2(_mm512_castsi512_ps(vsrc_y), vshuf_index_high, vshuf_index_low)); +} + static inline void vec_broadcast_epi32_1_2(__m512i* dst2regs, __m512i* src1regs, __m512i idxreg) { auto tmpreg = _mm512_permutexvar_epi64(idxreg, src1regs[0]); dst2regs[0] = _mm512_unpacklo_epi32(tmpreg, tmpreg); @@ -588,9 +664,9 @@ inline BTLA_CODE decompress_kblock_f8_fp(utils::f8* srcptr, _DST_T* dstptr, int } template -static inline BTLA_CODE decompress_kblock_s4_fp(utils::int4x2* srcptr, _DST_T* dstptr, int row, int col, int 
ld_src, - int ld_dst, _ST* scales, int8_t* zero_points, int k_offset, int kblock, - int NPad, int8_t* tmp, size_t tmpsize) { +static inline BTLA_CODE decompress_kblock_s4_fp_Dep(utils::int4x2* srcptr, _DST_T* dstptr, int row, int col, int ld_src, + int ld_dst, _ST* scales, int8_t* zero_points, int k_offset, + int kblock, int NPad, int8_t* tmp, size_t tmpsize) { if constexpr (_PACK_ROW == 1) { if (zero_points == nullptr) { return decompress_kblock_bit4_packrow1<_ST, _DST_T, true>( @@ -745,38 +821,6 @@ inline BTLA_CODE decompress_kblock_f4_fp_noscale(utils::f4x2* srcptr, DST_T* dst return BTLA_CODE::NotSupport; } -template -static inline BTLA_CODE decompress_s4_s8(utils::int4x2* srcptr, int8_t* dstptr, int row, int col, int ld_src, - int ld_dst) { - uint32_t mask = 0xf0f0f0f0; - auto zmm_mask = _mm512_set1_epi32(*reinterpret_cast(&mask)); - if (col == ld_src) { - size_t elesize = (size_t)row * col; - size_t ele256 = utils::padto_le(elesize, 256); - size_t ele64 = utils::padto_le(elesize, 64); - size_t i = 0; - constexpr int LoadMask64 = (1 << (64 / 8)) - 1; - for (; i < ele256; i += 256) { - convert_s4_s8_highbits(dstptr + i + 0, reinterpret_cast(srcptr + i / 2 + 0), zmm_mask, LoadMask64); - convert_s4_s8_highbits(dstptr + i + 64, reinterpret_cast(srcptr + i / 2 + 32), zmm_mask, LoadMask64); - convert_s4_s8_highbits(dstptr + i + 128, reinterpret_cast(srcptr + i / 2 + 64), zmm_mask, LoadMask64); - convert_s4_s8_highbits(dstptr + i + 192, reinterpret_cast(srcptr + i / 2 + 96), zmm_mask, LoadMask64); - } - if (i + 64 <= ele64) { - for (; i < ele64; i += 64) { - convert_s4_s8_highbits(dstptr + i, reinterpret_cast(srcptr + i / 2), zmm_mask, LoadMask64); - } - } - for (; i < elesize; i += 2) { - auto tmp = srcptr[i / 2]; - dstptr[i + 0] = kernel::ref::get_s8(tmp.x); - dstptr[i + 1] = kernel::ref::get_s8(tmp.y); - } - return BTLA_CODE::Success; - } - return BTLA_CODE::NotSupport; -} - static inline BTLA_CODE quantize_f32_sign_int_rowblock_sym(const float* srcptr, int8_t* dstptr, int row, int col, int ld_src, int ld_dst, float* scales, int blocksize) { int constexpr VLen = 16; @@ -1434,67 +1478,6 @@ static inline BTLA_CODE alphabeta_f32_f32(const float alpha, const float* srcptr return BTLA_CODE::Success; } -template -inline BTLA_CODE decompress_kblock_s4_s8fp(utils::int4x2* srcptr, _DST_T* dstptr, int row, int col, int ld_src, - int ld_dst, int8_t* tmp, size_t tmpsize) { - uint32_t mask = 0xf0f0f0f0; - auto zmm_mask = _mm512_set1_epi32(*reinterpret_cast(&mask)); - if (col == ld_src) { - size_t elesize = (size_t)row * col; - size_t ele256 = utils::padto_le(elesize, 256); - size_t ele64 = utils::padto_le(elesize, 64); - assert(tmpsize >= 256); - size_t i = 0; - constexpr int LoadMask64 = (1 << (64 / 8)) - 1; - for (; i < ele256; i += 256) { - convert_s4_s8_highbits(tmp + 0, reinterpret_cast(srcptr + i / 2 + 0), zmm_mask, LoadMask64); - convert_s4_s8_highbits(tmp + 64, reinterpret_cast(srcptr + i / 2 + 32), zmm_mask, LoadMask64); - convert_s4_s8_highbits(tmp + 128, reinterpret_cast(srcptr + i / 2 + 64), zmm_mask, LoadMask64); - convert_s4_s8_highbits(tmp + 192, reinterpret_cast(srcptr + i / 2 + 96), zmm_mask, LoadMask64); - for (size_t j = 0; j < 256; j += 16) { - convert_s8_fp_v16(dstptr + i + j, tmp + j); - } - } - if (i + 64 <= ele64) { - for (; i < ele64; i += 64) { - convert_s4_s8_highbits(tmp, reinterpret_cast(srcptr + i / 2), zmm_mask, LoadMask64); - for (size_t j = 0; j < 64; j += 16) { - convert_s8_fp_v16(dstptr + i + j, tmp + j); - } - } - } - for (; i < elesize; i += 2) { - auto tmp = 
srcptr[i / 2]; - dstptr[i + 0] = static_cast<_DST_T>(static_cast(kernel::ref::get_s8(tmp.x))); - dstptr[i + 1] = static_cast<_DST_T>(static_cast(kernel::ref::get_s8(tmp.y))); - } - return BTLA_CODE::Success; - } - return BTLA_CODE::NotSupport; -} - -template -inline BTLA_CODE decompress_kblock_s8_s8fp(int8_t* srcptr, DST_T* dstptr, int row, int col, int ld_src, int ld_dst) { - if (col == ld_src) { - size_t elesize = (size_t)row * col; - size_t ele64 = utils::padto_le(elesize, 64); - size_t i = 0; - if (i + 64 <= ele64) { - for (; i < ele64; i += 64) { - for (size_t j = 0; j < 64; j += 16) { - convert_s8_fp_v16(dstptr + i + j, srcptr + i + j); - } - } - } - for (; i < elesize; i += 1) { - auto tmp = srcptr[i]; - dstptr[i] = static_cast(static_cast(tmp)); - } - return BTLA_CODE::Success; - } - return BTLA_CODE::NotSupport; -} - template static inline BTLA_CODE accum_alphaN_f32_f32(const SCA_T* alpha, const float* srcptr, const int srcstep, float* dstptr, const int dststep, const int M, const int N) { @@ -2433,6 +2416,2001 @@ inline __m512 exp_ps_0_1(const __m512 x) { return poly_scale_2nd_ps(z, f, c0, c1, c2); } +static inline __m512i load_zp_epi8_broadcast_epi16(int8_t* zpptr, const __m512i& vindex) { + auto v_zp_x = _mm256_loadu_si256((const __m256i*)zpptr); + auto v_zp_y = _mm512_cvtepi8_epi16(v_zp_x); + auto v_zp_y_cast = _mm512_shuffle_epi8(v_zp_y, vindex); // TODO(Yu) AVX512F only + return v_zp_y_cast; +} + +static inline __m512i load_zp_epi8_broadcast_epi32(int8_t* zpptr, const __m512i& vindex) { + auto v_zp_x = _mm_loadu_si128((const __m128i*)zpptr); + auto v_zp_y = _mm512_cvtepi8_epi32(v_zp_x); + auto v_zp_y_cast = _mm512_shuffle_epi8(v_zp_y, vindex); // TODO(Yu) AVX512F only + return v_zp_y_cast; +} + +static inline BTLA_CODE decompress_s4_s8(utils::int4x2* srcptr, int8_t* dstptr, size_t elesize, int8_t* tmp, + size_t tmpsize) { + uint32_t mask = 0x0f0f0f0f; + auto vmask = _mm512_set1_epi32(*reinterpret_cast(&mask)); + size_t velt = utils::padto_le(elesize, 64); + size_t i = 0; + auto vbias = _mm512_set1_epi8(8); + for (; i < velt; i += 64) { + auto vout_y = unpack_4bits(reinterpret_cast(srcptr + i / 2), vmask); + vout_y = _mm512_sub_epi8(vout_y, vbias); + _mm512_storeu_si512((__m512i*)(dstptr + i), vout_y); + } + if (velt < elesize) { + if (elesize >= 64) { + i = elesize - 64; + auto vout_y = unpack_4bits(reinterpret_cast(srcptr + i / 2), vmask); + vout_y = _mm512_sub_epi8(vout_y, vbias); + _mm512_storeu_si512((__m512i*)(dstptr + i), vout_y); + } else { + ref::decompress_kblock_s4_s8<1, 1>(srcptr + i / 2, nullptr, dstptr + i, 0, 0, 0, 0, 1, elesize - i, nullptr, 0); + } + } + return BTLA_CODE::Success; +} + +template +static inline BTLA_CODE decompress_kblock_s4_s8_pack1_row(utils::int4x2* srcptr, int8_t* zpptr, int8_t* dstptr, + int blocksize, int ldzp, int n_offset, int k_offset, int row, + int8_t* tmp, size_t tmpsize) { + int constexpr NReg = NTILE / 16; + static_assert((NTILE % 16) == 0); + int constexpr PackRow = 1; + int constexpr Unroll = 4; + int constexpr UnpackLoop = Unroll * NTILE / 64; + __m512i v_zp_y[UnpackLoop]; + uint32_t mask = 0x0f0f0f0f; + auto vmask = _mm512_set1_epi32(*reinterpret_cast(&mask)); + auto vbias = _mm512_set1_epi8(8); + const auto vindex = _mm512_set_epi8(12, 12, 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0, 12, 12, 12, 12, 8, 8, 8, 8, 4, + 4, 4, 4, 0, 0, 0, 0, 12, 12, 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0, 12, 12, + 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0); + for (int ir = 0; ir < row; ir += blocksize) { + auto zptr = zpptr + (k_offset + ir) 
/ blocksize * ldzp + n_offset; + for (int i = 0; i < Unroll; i++) { + memcpy(tmp + i * NTILE, zptr, NTILE * sizeof(int8_t)); + } + for (int i = 0; i < UnpackLoop; i++) { + v_zp_y[i] = _mm512_loadu_si512((const __m512i*)(tmp + i * 64)); + v_zp_y[i] = _mm512_add_epi8(v_zp_y[i], vbias); + } + int k_remain = utils::remainsize(ir, row, blocksize); + int k_remain_unrll = utils::padto_le(k_remain, Unroll); + int ib = 0; + for (; ib < k_remain_unrll; ib += Unroll) { + auto b4ptr = srcptr + (ir + ib) * NTILE / 2; + for (int i = 0; i < UnpackLoop; i++) { + auto v_s8_y = unpack_4bits(b4ptr + i * 32, vmask); + v_s8_y = _mm512_sub_epi8(v_s8_y, v_zp_y[i]); + _mm512_storeu_si512((__m512i*)(dstptr + i * 64 + (ir + ib) * NTILE), v_s8_y); + } + } + + int k_tail = k_remain - k_remain_unrll; + if (k_tail > 0) { + memcpy(tmp, srcptr + (ir + ib) * NTILE / 2, k_tail * NTILE / 2); + auto tmpout = tmp + Unroll * NTILE / 2; + for (int i = 0; i < UnpackLoop; i++) { + auto v_s8_y = unpack_4bits(tmp + i * 32, vmask); + v_s8_y = _mm512_sub_epi8(v_s8_y, v_zp_y[i]); + _mm512_storeu_si512((__m512i*)(tmpout + i * 64), v_s8_y); + } + memcpy(dstptr + (ir + ib) * NTILE, tmpout, k_tail * NTILE); + } + } + return BTLA_CODE::Success; +} + +template +static inline BTLA_CODE decompress_kblock_s4_s8_pack2_row(utils::int4x2* srcptr, int8_t* zpptr, int8_t* dstptr, + int blocksize, int ldzp, int n_offset, int k_offset, int row, + int8_t* tmp, size_t tmpsize) { + int constexpr NReg = NTILE / 16; + static_assert((NTILE % 16) == 0); + int constexpr PackRow = 2; + int constexpr Unroll = 2; + __m512i v_zp_y[NReg]; + uint32_t mask = 0x0f0f0f0f; + auto vmask = _mm512_set1_epi32(*reinterpret_cast(&mask)); + auto vbias = _mm512_set1_epi8(8); + const auto vindex = _mm512_set_epi8(14, 14, 12, 12, 10, 10, 8, 8, 6, 6, 4, 4, 2, 2, 0, 0, 14, 14, 12, 12, 10, 10, 8, + 8, 6, 6, 4, 4, 2, 2, 0, 0, 14, 14, 12, 12, 10, 10, 8, 8, 6, 6, 4, 4, 2, 2, 0, 0, + 14, 14, 12, 12, 10, 10, 8, 8, 6, 6, 4, 4, 2, 2, 0, 0); + for (int ir = 0; ir < row; ir += blocksize) { + auto zptr = zpptr + (k_offset + ir) / blocksize * ldzp + n_offset; + memcpy(tmp, zptr, NTILE * sizeof(int8_t)); + memcpy(tmp + NTILE, zptr, NTILE * sizeof(int8_t)); + for (int i = 0; i < NReg; i++) { + v_zp_y[i] = load_zp_epi8_broadcast_epi16(tmp + i * 32, vindex); + v_zp_y[i] = _mm512_add_epi8(v_zp_y[i], vbias); + } + int k_remain = utils::remainsize(ir, row, blocksize); + int k_remain_unrll = utils::padto_le(k_remain, PackRow * Unroll); + int ib = 0; + for (; ib < k_remain_unrll; ib += PackRow * Unroll) { + auto b4ptr = srcptr + (ir + ib) * NTILE / 2; + for (int i = 0; i < NReg; i++) { + auto v_s8_y = unpack_4bits(b4ptr + i * 32, vmask); + v_s8_y = _mm512_sub_epi8(v_s8_y, v_zp_y[i]); + _mm512_storeu_si512((__m512i*)(dstptr + i * 64 + (ir + ib) * NTILE), v_s8_y); + } + } + + int k_tail = k_remain - k_remain_unrll; + if (k_tail > 0) { + memcpy(tmp, srcptr + (ir + ib) * NTILE / 2, k_tail * NTILE / 2); + auto tmpout = tmp + Unroll * PackRow * NTILE / 2; + for (int i = 0; i < NReg; i++) { + auto v_s8_y = unpack_4bits(tmp + i * 32, vmask); + v_s8_y = _mm512_sub_epi8(v_s8_y, v_zp_y[i]); + _mm512_storeu_si512((__m512i*)(tmpout + i * 64), v_s8_y); + } + memcpy(dstptr + (ir + ib) * NTILE, tmpout, k_tail * NTILE); + } + } + return BTLA_CODE::Success; +} + +template +static inline BTLA_CODE decompress_kblock_s4_s8_pack4_row(utils::int4x2* srcptr, int8_t* zpptr, int8_t* dstptr, + int blocksize, int ldzp, int n_offset, int k_offset, int row, + int8_t* tmp, size_t tmpsize) { + int constexpr NReg = NTILE / 16; + 
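+  // NOTE: with PackRow == 4 the packed weights are stored 4 K-rows deep per column, so the
+  // per-column zero point has to be replicated across 4 adjacent byte lanes; that is what
+  // load_zp_epi8_broadcast_epi32 together with the vindex shuffle below does. vbias (8) is
+  // pre-added to the zero points, so a single _mm512_sub_epi8 per 64-byte chunk converts the
+  // unsigned 4-bit values back to signed int8.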
static_assert((NTILE % 16) == 0); + int constexpr PackRow = 4; + __m512i v_zp_y[NReg]; + uint32_t mask = 0x0f0f0f0f; + auto vmask = _mm512_set1_epi32(*reinterpret_cast(&mask)); + auto vbias = _mm512_set1_epi8(8); + const auto vindex = _mm512_set_epi8(12, 12, 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0, 12, 12, 12, 12, 8, 8, 8, 8, 4, + 4, 4, 4, 0, 0, 0, 0, 12, 12, 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0, 12, 12, + 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0); + for (int ir = 0; ir < row; ir += blocksize) { + auto zptr = zpptr + (k_offset + ir) / blocksize * ldzp + n_offset; + for (int i = 0; i < NReg; i++) { + v_zp_y[i] = load_zp_epi8_broadcast_epi32(zptr + i * 16, vindex); + v_zp_y[i] = _mm512_add_epi8(v_zp_y[i], vbias); + } + int k_remain = utils::remainsize(ir, row, blocksize); + for (int ib = 0; ib < k_remain; ib += PackRow) { + auto b4ptr = srcptr + (ir + ib) * NTILE / 2; + for (int i = 0; i < NReg; i++) { + auto v_s8_y = unpack_4bits(b4ptr + i * 32, vmask); + v_s8_y = _mm512_sub_epi8(v_s8_y, v_zp_y[i]); + _mm512_storeu_si512((__m512i*)(dstptr + i * 64 + (ir + ib) * NTILE), v_s8_y); + } + } + } + return BTLA_CODE::Success; +} + +template +inline BTLA_CODE decompress_kblock_s4_s8(utils::int4x2* srcptr, int8_t* zpptr, int8_t* dstptr, int blocksize, int ldzp, + int n_offset, int k_offset, int row, int col, int8_t* tmp, size_t tmpsize) { + if (zpptr) { + typedef BTLA_CODE (*decompfunc)(utils::int4x2 * srcptr, int8_t * zpptr, int8_t * dstptr, int blocksize, int ldzp, + int n_offset, int k_offset, int row, int8_t* tmp, size_t tmpsize); + decompfunc func = nullptr; + if (col == NTILE) { + if constexpr (PackRow == 1) { + func = &decompress_kblock_s4_s8_pack1_row; + } + if constexpr (PackRow == 2) { + func = &decompress_kblock_s4_s8_pack2_row; + } + if constexpr (PackRow == 4) { + func = &decompress_kblock_s4_s8_pack4_row; + } + if (func) { + int head_end = utils::padto(k_offset, blocksize); + head_end = std::min(head_end, k_offset + row); + int head_size = head_end - k_offset; + if (head_size > 0) { + (*func)(srcptr, zpptr, dstptr, blocksize, ldzp, n_offset, k_offset, head_size, tmp, tmpsize); + } + int body_size = row - head_size; + if (body_size > 0) { + (*func)(srcptr + head_size * NTILE / 2, zpptr, dstptr + head_size * NTILE, blocksize, ldzp, n_offset, + head_end, body_size, tmp, tmpsize); + } + return BTLA_CODE::Success; + } + } + assert(0); + return BTLA_CODE::NotSupport; + } else { + size_t elesize = static_cast(row) * col; + return decompress_s4_s8(srcptr, dstptr, elesize, tmp, tmpsize); + } + return BTLA_CODE::Success; +} + +static inline BTLA_CODE decompress_s2_s8(utils::bit2x4* bit2ptr, int8_t* dstptr, size_t unpack_elt, int8_t* tmp, + size_t tmpsize) { + int constexpr VBits = 512; + int constexpr VElt = VBits / 8; + int i = 0; + uint64_t mask0 = 0x0303030303030303; + auto vmask0 = _mm512_set1_epi64(*(int64_t*)&mask0); + auto vbias = _mm512_set1_epi8(2); + auto vshift_y = _mm512_set_epi32(6, 4, 2, 0, 6, 4, 2, 0, 6, 4, 2, 0, 6, 4, 2, 0); + auto vsfhl_mask_y = _mm512_set_epi8(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2, + 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, + 15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0); + auto vorder_y = _mm512_set_epi32(3, 3, 3, 3, 2, 2, 2, 2, 1, 1, 1, 1, 0, 0, 0, 0); + int elt_pad = utils::padto_le(unpack_elt, VElt); + for (; i < elt_pad; i += VElt) { + auto vout = unpack_2bits(bit2ptr + i / 4, vshift_y, vmask0, vsfhl_mask_y, vorder_y); + vout = _mm512_sub_epi8(vout, vbias); 
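+    // NOTE: each utils::bit2x4 byte packs four 2-bit weights; ignoring the lane reordering done
+    // by vsfhl_mask_y/vorder_y, the scalar equivalent of this decode is roughly
+    //   dst[4 * i + j] = int8_t(((src[i] >> (2 * j)) & 0x3) - 2);  // j = 0..3
+    // i.e. unpack_2bits yields unsigned values in [0, 3] and the _mm512_sub_epi8 above recenters
+    // them to the signed range [-2, 1] before the 64-byte store below.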
+ _mm512_storeu_si512((__m512i*)(dstptr + i), vout); + } + if (elt_pad < unpack_elt) { + if (unpack_elt >= VElt) { + i = unpack_elt - VElt; + auto vout = unpack_2bits(bit2ptr + i / 4, vshift_y, vmask0, vsfhl_mask_y, vorder_y); + vout = _mm512_sub_epi8(vout, vbias); + _mm512_storeu_si512((__m512i*)(dstptr + i), vout); + } else { + ref::decompress_s2_s8(bit2ptr + i / 4, dstptr + i, unpack_elt - i, tmp, tmpsize); + } + } + return BTLA_CODE::Success; +} + +template +static inline BTLA_CODE decompress_kblock_s2_s8_pack4_row(utils::bit2x4* srcptr, int8_t* zpptr, int8_t* dstptr, + int blocksize, int ldzp, int n_offset, int k_offset, int row, + int8_t* tmp, size_t tmpsize) { + int constexpr VLen = 16; + int constexpr NReg = NTILE / VLen; + static_assert((NTILE % VLen) == 0); + int constexpr PackRow = 4; + __m512i v_zp_y[NReg]; + const auto vindex = _mm512_set_epi8(12, 12, 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0, 12, 12, 12, 12, 8, 8, 8, 8, 4, + 4, 4, 4, 0, 0, 0, 0, 12, 12, 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0, 12, 12, + 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0); + uint64_t mask0 = 0x0303030303030303; + auto vmask0 = _mm512_set1_epi64(*(int64_t*)&mask0); + auto vbias = _mm512_set1_epi8(2); + auto vshift_y = _mm512_set_epi32(6, 4, 2, 0, 6, 4, 2, 0, 6, 4, 2, 0, 6, 4, 2, 0); + auto vsfhl_mask_y = _mm512_set_epi8(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2, + 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, + 15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0); + auto vorder_y = _mm512_set_epi32(3, 3, 3, 3, 2, 2, 2, 2, 1, 1, 1, 1, 0, 0, 0, 0); + for (int ir = 0; ir < row; ir += blocksize) { + auto zptr = zpptr + (k_offset + ir) / blocksize * ldzp + n_offset; + for (int i = 0; i < NReg; i++) { + v_zp_y[i] = load_zp_epi8_broadcast_epi32(zptr + i * 16, vindex); + v_zp_y[i] = _mm512_add_epi8(v_zp_y[i], vbias); + } + int k_remain = utils::remainsize(ir, row, blocksize); + for (int ib = 0; ib < k_remain; ib += PackRow) { + auto b2ptr = srcptr + (ir + ib) * NTILE / 4; + for (int i = 0; i < NReg; i++) { + auto v_s8_y = unpack_2bits(b2ptr + i * 16, vshift_y, vmask0, vsfhl_mask_y, vorder_y); + v_s8_y = _mm512_sub_epi8(v_s8_y, v_zp_y[i]); + _mm512_storeu_si512((__m512i*)(dstptr + i * 64 + (ir + ib) * NTILE), v_s8_y); + } + } + } + return BTLA_CODE::Success; +} + +template +static inline BTLA_CODE decompress_kblock_s2_s8_pack2_row(utils::bit2x4* srcptr, int8_t* zpptr, int8_t* dstptr, + int blocksize, int ldzp, int n_offset, int k_offset, int row, + int8_t* tmp, size_t tmpsize) { + int constexpr VLen = 16; + int constexpr NReg = NTILE / VLen; + static_assert((NTILE % VLen) == 0); + int constexpr PackRow = 2; + int constexpr Unroll = 2; + __m512i v_zp_y[NReg]; + const auto vindex = _mm512_set_epi8(14, 14, 12, 12, 10, 10, 8, 8, 6, 6, 4, 4, 2, 2, 0, 0, 14, 14, 12, 12, 10, 10, 8, + 8, 6, 6, 4, 4, 2, 2, 0, 0, 14, 14, 12, 12, 10, 10, 8, 8, 6, 6, 4, 4, 2, 2, 0, 0, + 14, 14, 12, 12, 10, 10, 8, 8, 6, 6, 4, 4, 2, 2, 0, 0); + uint64_t mask0 = 0x0303030303030303; + auto vmask0 = _mm512_set1_epi64(*(int64_t*)&mask0); + auto vbias = _mm512_set1_epi8(2); + auto vshift_y = _mm512_set_epi32(6, 4, 2, 0, 6, 4, 2, 0, 6, 4, 2, 0, 6, 4, 2, 0); + auto vsfhl_mask_y = _mm512_set_epi8(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2, + 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, + 15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0); + auto vorder_y = _mm512_set_epi32(3, 3, 3, 3, 2, 2, 2, 2, 
1, 1, 1, 1, 0, 0, 0, 0); + for (int ir = 0; ir < row; ir += blocksize) { + auto zptr = zpptr + (k_offset + ir) / blocksize * ldzp + n_offset; + memcpy(tmp, zptr, NTILE * sizeof(int8_t)); + memcpy(tmp + NTILE, zptr, NTILE * sizeof(int8_t)); + for (int i = 0; i < NReg; i++) { + v_zp_y[i] = load_zp_epi8_broadcast_epi16(tmp + i * 32, vindex); + v_zp_y[i] = _mm512_add_epi8(v_zp_y[i], vbias); + } + int k_remain = utils::remainsize(ir, row, blocksize); + int k_remain_unrll = utils::padto_le(k_remain, PackRow * Unroll); + int ib = 0; + for (; ib < k_remain_unrll; ib += PackRow * Unroll) { + auto b2ptr = srcptr + (ir + ib) * NTILE / 4; + for (int i = 0; i < NReg; i++) { + auto v_s8_y = unpack_2bits(b2ptr + i * 16, vshift_y, vmask0, vsfhl_mask_y, vorder_y); + v_s8_y = _mm512_sub_epi8(v_s8_y, v_zp_y[i]); + _mm512_storeu_si512((__m512i*)(dstptr + i * 64 + (ir + ib) * NTILE), v_s8_y); + } + } + int k_tail = k_remain - k_remain_unrll; + if (k_tail > 0) { + memcpy(tmp, srcptr + (ir + ib) * NTILE / 4, k_tail * NTILE / 4); + auto tmpout = tmp + Unroll * PackRow * NTILE / 4; + for (int i = 0; i < NReg; i++) { + auto v_s8_y = unpack_2bits((utils::bit2x4*)(tmp + i * 16), vshift_y, vmask0, vsfhl_mask_y, vorder_y); + v_s8_y = _mm512_sub_epi8(v_s8_y, v_zp_y[i]); + _mm512_storeu_si512((__m512i*)(tmpout + i * 64), v_s8_y); + } + memcpy(dstptr + (ir + ib) * NTILE, tmpout, k_tail * NTILE); + } + } + return BTLA_CODE::Success; +} + +template +static inline BTLA_CODE decompress_kblock_s2_s8_pack1_row(utils::bit2x4* srcptr, int8_t* zpptr, int8_t* dstptr, + int blocksize, int ldzp, int n_offset, int k_offset, int row, + int8_t* tmp, size_t tmpsize) { + int constexpr VLen = 16; + int constexpr NReg = NTILE / VLen; + static_assert((NTILE % VLen) == 0); + int constexpr PackRow = 1; + int constexpr Unroll = 4; + __m512i v_zp_y[NReg]; + uint64_t mask0 = 0x0303030303030303; + auto vmask0 = _mm512_set1_epi64(*(int64_t*)&mask0); + auto vbias = _mm512_set1_epi8(2); + auto vshift_y = _mm512_set_epi32(6, 4, 2, 0, 6, 4, 2, 0, 6, 4, 2, 0, 6, 4, 2, 0); + auto vsfhl_mask_y = _mm512_set_epi8(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2, + 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, + 15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0); + auto vorder_y = _mm512_set_epi32(3, 3, 3, 3, 2, 2, 2, 2, 1, 1, 1, 1, 0, 0, 0, 0); + for (int ir = 0; ir < row; ir += blocksize) { + auto zptr = zpptr + (k_offset + ir) / blocksize * ldzp + n_offset; + for (int i = 0; i < Unroll; i++) { + memcpy(tmp + i * NTILE, zptr, NTILE * sizeof(int8_t)); + } + for (int i = 0; i < NReg; i++) { + v_zp_y[i] = _mm512_loadu_si512((const __m512i*)(tmp + i * 64)); + v_zp_y[i] = _mm512_add_epi8(v_zp_y[i], vbias); + } + int k_remain = utils::remainsize(ir, row, blocksize); + int k_remain_unrll = utils::padto_le(k_remain, Unroll); + int ib = 0; + for (; ib < k_remain_unrll; ib += Unroll) { + auto b2ptr = srcptr + (ir + ib) * NTILE / 4; + for (int i = 0; i < NReg; i++) { + auto v_s8_y = unpack_2bits(b2ptr + i * 16, vshift_y, vmask0, vsfhl_mask_y, vorder_y); + v_s8_y = _mm512_sub_epi8(v_s8_y, v_zp_y[i]); + _mm512_storeu_si512((__m512i*)(dstptr + i * 64 + (ir + ib) * NTILE), v_s8_y); + } + } + + int k_tail = k_remain - k_remain_unrll; + if (k_tail > 0) { + memcpy(tmp, srcptr + (ir + ib) * NTILE / 4, k_tail * NTILE / 4); + auto tmpout = tmp + Unroll * NTILE / 4; + for (int i = 0; i < NReg; i++) { + auto v_s8_y = unpack_2bits((utils::bit2x4*)(tmp + i * 16), vshift_y, vmask0, vsfhl_mask_y, vorder_y); + 
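+        // NOTE: for PackRow == 1 the zero-point row was copied Unroll (4) times into tmp above,
+        // so each v_zp_y register lines up with the row-major layout of the four K-rows decoded
+        // per iteration; vbias (2) is already folded into v_zp_y, hence the single
+        // _mm512_sub_epi8 below removes both the unsigned-storage bias and the quantization
+        // zero point in one step.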
v_s8_y = _mm512_sub_epi8(v_s8_y, v_zp_y[i]); + _mm512_storeu_si512((__m512i*)(tmpout + i * 64), v_s8_y); + } + memcpy(dstptr + (ir + ib) * NTILE, tmpout, k_tail * NTILE); + } + } + return BTLA_CODE::Success; +} + +template +static inline BTLA_CODE decompress_kblock_s2_s8(utils::bit2x4* bit2ptr, int8_t* zpptr, int8_t* dstptr, int blocksize, + int ldzp, int n_offset, int k_offset, int row, int col, int8_t* tmp, + size_t tmpsize) { + if (zpptr) { + typedef BTLA_CODE (*decompfunc)(utils::bit2x4 * srcptr, int8_t * zpptr, int8_t * dstptr, int blocksize, int ldzp, + int n_offset, int k_offset, int row, int8_t* tmp, size_t tmpsize); + decompfunc func = nullptr; + if (col == NTILE) { + if constexpr (PackRow == 1) { + func = &decompress_kblock_s2_s8_pack1_row; + } + if constexpr (PackRow == 2) { + func = &decompress_kblock_s2_s8_pack2_row; + } + if constexpr (PackRow == 4) { + func = &decompress_kblock_s2_s8_pack4_row; + } + if (func) { + int head_end = utils::padto(k_offset, blocksize); + head_end = std::min(head_end, k_offset + row); + int head_size = head_end - k_offset; + if (head_size > 0) { + (*func)(bit2ptr, zpptr, dstptr, blocksize, ldzp, n_offset, k_offset, head_size, tmp, tmpsize); + } + int body_size = row - head_size; + if (body_size > 0) { + (*func)(bit2ptr + head_size * NTILE / 4, zpptr, dstptr + head_size * NTILE, blocksize, ldzp, n_offset, + head_end, body_size, tmp, tmpsize); + } + return BTLA_CODE::Success; + } + } + assert(0); + return BTLA_CODE::NotSupport; + } else { + size_t elesize = static_cast(row) * col; + return decompress_s2_s8(bit2ptr, dstptr, elesize, tmp, tmpsize); + } + return BTLA_CODE::Success; +} + +static inline BTLA_CODE decompress_s3_s8(utils::bit2x4* bit2ptr, utils::bit1x8* bit1ptr, int8_t* dstptr, + size_t unpack_elt, int8_t* tmp, size_t tmpsize) { + int constexpr VBits = 512; + int constexpr VElt = VBits / 8; + int i = 0; + uint64_t mask0 = 0x0303030303030303; + auto vmask0 = _mm512_set1_epi64(*(int64_t*)&mask0); + auto vbias = _mm512_set1_epi8(4); + auto vshift_y = _mm512_set_epi32(6, 4, 2, 0, 6, 4, 2, 0, 6, 4, 2, 0, 6, 4, 2, 0); + auto vsfhl_mask_y = _mm512_set_epi8(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2, + 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, + 15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0); + auto vorder_y = _mm512_set_epi32(3, 3, 3, 3, 2, 2, 2, 2, 1, 1, 1, 1, 0, 0, 0, 0); + + auto zmm_0x04 = _mm512_set1_epi8(0x04); + auto zmm_0x00 = _mm512_set1_epi8(0x00); + int elt_pad = utils::padto_le(unpack_elt, VElt); + for (; i < elt_pad; i += VElt) { + auto vout = unpack_2bits(bit2ptr + i / 4, vshift_y, vmask0, vsfhl_mask_y, vorder_y); + auto vb1 = unpack_1bits(bit1ptr + i / 8, zmm_0x00, zmm_0x04); + vout = _mm512_or_si512(vout, vb1); + vout = _mm512_sub_epi8(vout, vbias); + _mm512_storeu_si512((__m512i*)(dstptr + i), vout); + } + if (elt_pad < unpack_elt) { + if (unpack_elt >= VElt) { + i = unpack_elt - VElt; + auto vout = unpack_2bits(bit2ptr + i / 4, vshift_y, vmask0, vsfhl_mask_y, vorder_y); + auto vb1 = unpack_1bits(bit1ptr + i / 8, zmm_0x00, zmm_0x04); + vout = _mm512_or_si512(vout, vb1); + vout = _mm512_sub_epi8(vout, vbias); + _mm512_storeu_si512((__m512i*)(dstptr + i), vout); + } else { + ref::decompress_s3_s8(bit2ptr + i / 4, bit1ptr + i / 8, dstptr + i, unpack_elt - i, tmp, tmpsize); + } + } + return BTLA_CODE::Success; +} + +template +static inline BTLA_CODE decompress_kblock_s3_s8_pack1_row(utils::bit2x4* srcptr, utils::bit1x8* bit1ptr, int8_t* zpptr, + int8_t* 
dstptr, int blocksize, int ldzp, int n_offset, + int k_offset, int row, int8_t* tmp, size_t tmpsize) { + int constexpr VLen = 16; + int constexpr NReg = NTILE / VLen; + static_assert((NTILE % VLen) == 0); + int constexpr PackRow = 1; + int constexpr Unroll = 4; + __m512i v_zp_y[NReg]; + uint64_t mask0 = 0x0303030303030303; + auto vmask0 = _mm512_set1_epi64(*(int64_t*)&mask0); + auto vbias = _mm512_set1_epi8(4); + auto vshift_y = _mm512_set_epi32(6, 4, 2, 0, 6, 4, 2, 0, 6, 4, 2, 0, 6, 4, 2, 0); + auto vsfhl_mask_y = _mm512_set_epi8(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2, + 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, + 15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0); + auto vorder_y = _mm512_set_epi32(3, 3, 3, 3, 2, 2, 2, 2, 1, 1, 1, 1, 0, 0, 0, 0); + + auto zmm_0x04 = _mm512_set1_epi8(0x04); + auto zmm_0x00 = _mm512_set1_epi8(0x00); + for (int ir = 0; ir < row; ir += blocksize) { + auto zptr = zpptr + (k_offset + ir) / blocksize * ldzp + n_offset; + for (int i = 0; i < Unroll; i++) { + memcpy(tmp + i * NTILE, zptr, NTILE * sizeof(int8_t)); + } + for (int i = 0; i < NReg; i++) { + v_zp_y[i] = _mm512_loadu_si512((const __m512i*)(tmp + i * 64)); + v_zp_y[i] = _mm512_add_epi8(v_zp_y[i], vbias); + } + int k_remain = utils::remainsize(ir, row, blocksize); + int k_remain_unrll = utils::padto_le(k_remain, Unroll); + int ib = 0; + for (; ib < k_remain_unrll; ib += Unroll) { + auto b2ptr = srcptr + (ir + ib) * NTILE / 4; + auto b1ptr = bit1ptr + (ir + ib) * NTILE / 8; + for (int i = 0; i < NReg; i++) { + auto v_s8_y = unpack_2bits(b2ptr + i * 16, vshift_y, vmask0, vsfhl_mask_y, vorder_y); + auto vb1 = unpack_1bits(b1ptr + i * 8, zmm_0x00, zmm_0x04); + v_s8_y = _mm512_or_si512(v_s8_y, vb1); + v_s8_y = _mm512_sub_epi8(v_s8_y, v_zp_y[i]); + _mm512_storeu_si512((__m512i*)(dstptr + i * 64 + (ir + ib) * NTILE), v_s8_y); + } + } + + int k_tail = k_remain - k_remain_unrll; + if (k_tail > 0) { + auto tmpb2ptr = tmp; + memcpy(tmpb2ptr, srcptr + (ir + ib) * NTILE / 4, k_tail * NTILE / 4); + auto tmpb1ptr = tmp + Unroll * NTILE / 2; + memcpy(tmpb1ptr, bit1ptr + (ir + ib) * NTILE / 8, k_tail * NTILE / 8); + auto tmpout = tmp + Unroll * NTILE; + for (int i = 0; i < NReg; i++) { + auto v_s8_y = unpack_2bits((utils::bit2x4*)(tmpb2ptr + i * 16), vshift_y, vmask0, vsfhl_mask_y, vorder_y); + auto vb1 = unpack_1bits((utils::bit1x8*)(tmpb1ptr + i * 8), zmm_0x00, zmm_0x04); + v_s8_y = _mm512_or_si512(v_s8_y, vb1); + v_s8_y = _mm512_sub_epi8(v_s8_y, v_zp_y[i]); + _mm512_storeu_si512((__m512i*)(tmpout + i * 64), v_s8_y); + } + memcpy(dstptr + (ir + ib) * NTILE, tmpout, k_tail * NTILE); + } + } + return BTLA_CODE::Success; +} + +template +static inline BTLA_CODE decompress_kblock_s3_s8_pack2_row(utils::bit2x4* srcptr, utils::bit1x8* bit1ptr, int8_t* zpptr, + int8_t* dstptr, int blocksize, int ldzp, int n_offset, + int k_offset, int row, int8_t* tmp, size_t tmpsize) { + int constexpr VLen = 16; + int constexpr NReg = NTILE / VLen; + static_assert((NTILE % VLen) == 0); + int constexpr PackRow = 1; + int constexpr Unroll = 4; + __m512i v_zp_y[NReg]; + uint64_t mask0 = 0x0303030303030303; + auto vmask0 = _mm512_set1_epi64(*(int64_t*)&mask0); + auto vbias = _mm512_set1_epi8(4); + auto vshift_y = _mm512_set_epi32(6, 4, 2, 0, 6, 4, 2, 0, 6, 4, 2, 0, 6, 4, 2, 0); + auto vsfhl_mask_y = _mm512_set_epi8(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2, + 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 
12, 8, 4, 0, + 15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0); + auto vorder_y = _mm512_set_epi32(3, 3, 3, 3, 2, 2, 2, 2, 1, 1, 1, 1, 0, 0, 0, 0); + + auto zmm_0x04 = _mm512_set1_epi8(0x04); + auto zmm_0x00 = _mm512_set1_epi8(0x00); + + const auto vindex = _mm512_set_epi8(14, 14, 12, 12, 10, 10, 8, 8, 6, 6, 4, 4, 2, 2, 0, 0, 14, 14, 12, 12, 10, 10, 8, + 8, 6, 6, 4, 4, 2, 2, 0, 0, 14, 14, 12, 12, 10, 10, 8, 8, 6, 6, 4, 4, 2, 2, 0, 0, + 14, 14, 12, 12, 10, 10, 8, 8, 6, 6, 4, 4, 2, 2, 0, 0); + for (int ir = 0; ir < row; ir += blocksize) { + auto zptr = zpptr + (k_offset + ir) / blocksize * ldzp + n_offset; + memcpy(tmp, zptr, NTILE * sizeof(int8_t)); + memcpy(tmp + NTILE, zptr, NTILE * sizeof(int8_t)); + for (int i = 0; i < NReg; i++) { + v_zp_y[i] = load_zp_epi8_broadcast_epi16(tmp + i * 32, vindex); + v_zp_y[i] = _mm512_add_epi8(v_zp_y[i], vbias); + } + int k_remain = utils::remainsize(ir, row, blocksize); + int k_remain_unrll = utils::padto_le(k_remain, PackRow * Unroll); + int ib = 0; + for (; ib < k_remain_unrll; ib += PackRow * Unroll) { + auto b2ptr = srcptr + (ir + ib) * NTILE / 4; + auto b1ptr = bit1ptr + (ir + ib) * NTILE / 8; + for (int i = 0; i < NReg; i++) { + auto v_s8_y = unpack_2bits(b2ptr + i * 16, vshift_y, vmask0, vsfhl_mask_y, vorder_y); + auto vb1 = unpack_1bits(b1ptr + i * 8, zmm_0x00, zmm_0x04); + v_s8_y = _mm512_or_si512(v_s8_y, vb1); + v_s8_y = _mm512_sub_epi8(v_s8_y, v_zp_y[i]); + _mm512_storeu_si512((__m512i*)(dstptr + i * 64 + (ir + ib) * NTILE), v_s8_y); + } + } + int k_tail = k_remain - k_remain_unrll; + if (k_tail > 0) { + auto tmpb2ptr = tmp; + memcpy(tmpb2ptr, srcptr + (ir + ib) * NTILE / 4, k_tail * NTILE / 4); + auto tmpb1ptr = tmp + Unroll * NTILE / 2; + memcpy(tmpb1ptr, bit1ptr + (ir + ib) * NTILE / 8, k_tail * NTILE / 8); + auto tmpout = tmp + Unroll * NTILE; + for (int i = 0; i < NReg; i++) { + auto v_s8_y = unpack_2bits((utils::bit2x4*)(tmpb2ptr + i * 16), vshift_y, vmask0, vsfhl_mask_y, vorder_y); + auto vb1 = unpack_1bits((utils::bit1x8*)(tmpb1ptr + i * 8), zmm_0x00, zmm_0x04); + v_s8_y = _mm512_or_si512(v_s8_y, vb1); + v_s8_y = _mm512_sub_epi8(v_s8_y, v_zp_y[i]); + _mm512_storeu_si512((__m512i*)(tmpout + i * 64), v_s8_y); + } + memcpy(dstptr + (ir + ib) * NTILE, tmpout, k_tail * NTILE); + } + } + return BTLA_CODE::Success; +} + +template +static inline BTLA_CODE decompress_kblock_s3_s8_pack4_row(utils::bit2x4* srcptr, utils::bit1x8* bit1ptr, int8_t* zpptr, + int8_t* dstptr, int blocksize, int ldzp, int n_offset, + int k_offset, int row, int8_t* tmp, size_t tmpsize) { + int constexpr VLen = 16; + int constexpr NReg = NTILE / VLen; + static_assert((NTILE % VLen) == 0); + int constexpr PackRow = 4; + __m512i v_zp_y[NReg]; + uint64_t mask0 = 0x0303030303030303; + auto vmask0 = _mm512_set1_epi64(*(int64_t*)&mask0); + auto vbias = _mm512_set1_epi8(4); + auto vshift_y = _mm512_set_epi32(6, 4, 2, 0, 6, 4, 2, 0, 6, 4, 2, 0, 6, 4, 2, 0); + auto vsfhl_mask_y = _mm512_set_epi8(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2, + 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, + 15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0); + auto vorder_y = _mm512_set_epi32(3, 3, 3, 3, 2, 2, 2, 2, 1, 1, 1, 1, 0, 0, 0, 0); + + auto zmm_0x04 = _mm512_set1_epi8(0x04); + auto zmm_0x00 = _mm512_set1_epi8(0x00); + const auto vindex = _mm512_set_epi8(12, 12, 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0, 12, 12, 12, 12, 8, 8, 8, 8, 4, + 4, 4, 4, 0, 0, 0, 0, 12, 12, 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0, 
12, 12, + 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0); + for (int ir = 0; ir < row; ir += blocksize) { + auto zptr = zpptr + (k_offset + ir) / blocksize * ldzp + n_offset; + for (int i = 0; i < NReg; i++) { + v_zp_y[i] = load_zp_epi8_broadcast_epi32(zptr + i * 16, vindex); + v_zp_y[i] = _mm512_add_epi8(v_zp_y[i], vbias); + } + int k_remain = utils::remainsize(ir, row, blocksize); + for (int ib = 0; ib < k_remain; ib += PackRow) { + auto b2ptr = srcptr + (ir + ib) * NTILE / 4; + auto b1ptr = bit1ptr + (ir + ib) * NTILE / 8; + for (int i = 0; i < NReg; i++) { + auto v_s8_y = unpack_2bits(b2ptr + i * 16, vshift_y, vmask0, vsfhl_mask_y, vorder_y); + auto vb1 = unpack_1bits(b1ptr + i * 8, zmm_0x00, zmm_0x04); + v_s8_y = _mm512_or_si512(v_s8_y, vb1); + v_s8_y = _mm512_sub_epi8(v_s8_y, v_zp_y[i]); + _mm512_storeu_si512((__m512i*)(dstptr + i * 64 + (ir + ib) * NTILE), v_s8_y); + } + } + } + return BTLA_CODE::Success; +} + +template +static inline BTLA_CODE decompress_kblock_s3_s8(utils::bit2x4* bit2ptr, utils::bit1x8* bit1ptr, int8_t* zpptr, + int8_t* dstptr, int blocksize, int ldzp, int n_offset, int k_offset, + int row, int col, int8_t* tmp, size_t tmpsize) { + if (zpptr) { + typedef BTLA_CODE (*decompfunc)(utils::bit2x4 * bit2ptr, utils::bit1x8 * bit1ptr, int8_t * zpptr, int8_t * dstptr, + int blocksize, int ldzp, int n_offset, int k_offset, int row, int8_t* tmp, + size_t tmpsize); + decompfunc func = nullptr; + if (col == NTILE) { + if constexpr (PackRow == 1) { + func = &decompress_kblock_s3_s8_pack1_row; + } + if constexpr (PackRow == 2) { + func = &decompress_kblock_s3_s8_pack2_row; + } + if constexpr (PackRow == 4) { + func = &decompress_kblock_s3_s8_pack4_row; + } + + if (func) { + int head_end = utils::padto(k_offset, blocksize); + head_end = std::min(head_end, k_offset + row); + int head_size = head_end - k_offset; + if (head_size > 0) { + (*func)(bit2ptr, bit1ptr, zpptr, dstptr, blocksize, ldzp, n_offset, k_offset, head_size, tmp, tmpsize); + } + int body_size = row - head_size; + if (body_size > 0) { + (*func)(bit2ptr + head_size * NTILE / 4, bit1ptr + head_size * NTILE / 8, zpptr, dstptr + head_size * NTILE, + blocksize, ldzp, n_offset, head_end, body_size, tmp, tmpsize); + } + return BTLA_CODE::Success; + } + } + assert(0); + return BTLA_CODE::NotSupport; + } else { + size_t elesize = static_cast(row) * col; + return decompress_s3_s8(bit2ptr, bit1ptr, dstptr, elesize, tmp, tmpsize); + } + return BTLA_CODE::Success; +} + +template +inline BTLA_CODE decompress_kblock_s8_fp_row(int8_t* srcptr, DST_T* dstptr, int row, void* scales_, BTLA_DTYPE sdtype, + int8_t* zero_points, int k_offset, int n_offset, int blocksize, int ldzp, + int8_t* tmp, size_t tmpsize) { + int constexpr VLen = 16; + int constexpr NReg = NTILE / VLen; + const auto DstSize = row * NTILE * sizeof(DST_T); + const auto S8Size = row * NTILE * sizeof(int8_t); + const auto vshuf_index_low = _mm512_set_epi32(7, 7, 6, 6, 5, 5, 4, 4, 3, 3, 2, 2, 1, 1, 0, 0); + const auto vshuf_index_high = _mm512_set_epi32(15, 15, 14, 14, 13, 13, 12, 12, 11, 11, 10, 10, 9, 9, 8, 8); + if (zero_points == nullptr) { + for (int ir = 0; ir < row; ir += blocksize) { + int k_remain = utils::remainsize(ir, row, blocksize); + int ele_off = (k_offset + ir) / blocksize * ldzp + n_offset; + if constexpr (PackRow == 1) { + __m512 vscale_y[NReg]; + if (sdtype == BTLA_DTYPE::F32) { + auto sptr = (float*)scales_ + ele_off; + for (int i = 0; i < NReg; i++) vscale_y[i] = _mm512_loadu_ps(sptr + i * VLen); + } else if (sdtype == BTLA_DTYPE::BF16) { + auto sptr = 
(utils::bf16*)scales_ + ele_off; + for (int i = 0; i < NReg; i++) vscale_y[i] = load_bf16_fp32(sptr + i * VLen); + } + for (int ib = 0; ib < k_remain; ib += PackRow) { + auto b8ptr = srcptr + (ir + ib) * NTILE; + for (int i = 0; i < NReg; i++) { + auto vdeq_y = dequant_s8_fp(b8ptr + i * VLen, vscale_y[i]); + store_fp_T(vdeq_y, dstptr + (ir + ib) * NTILE + i * VLen); + } + } + } else if constexpr (PackRow == 4) { + __m512 vscale_y[PackRow * NReg]; + for (int i = 0; i < NReg; i++) { + __m512 vraw; + if (sdtype == BTLA_DTYPE::F32) { + auto sptr = (float*)scales_ + ele_off; + vraw = _mm512_loadu_ps(sptr + i * VLen); + } else if (sdtype == BTLA_DTYPE::BF16) { + auto sptr = (utils::bf16*)scales_ + ele_off; + vraw = load_bf16_fp32(sptr + i * VLen); + } else { + assert(0); + } + auto vcast_y = broadcast_ps_1_2(vraw, vshuf_index_high, vshuf_index_low); + vscale_y[i * PackRow + 0] = broadcast_ps_1_2(vcast_y, vshuf_index_high, vshuf_index_low); + vscale_y[i * PackRow + 1] = broadcast_ps_1_2(vcast_y, vshuf_index_high, vshuf_index_low); + vcast_y = broadcast_ps_1_2(vraw, vshuf_index_high, vshuf_index_low); + vscale_y[i * PackRow + 2] = broadcast_ps_1_2(vcast_y, vshuf_index_high, vshuf_index_low); + vscale_y[i * PackRow + 3] = broadcast_ps_1_2(vcast_y, vshuf_index_high, vshuf_index_low); + } + for (int ib = 0; ib < k_remain; ib += PackRow) { + auto b8ptr = srcptr + (ir + ib) * NTILE; + for (int i = 0; i < NReg; i++) { + for (int ip = 0; ip < PackRow; ip++) { + auto vdeq_y = dequant_s8_fp(b8ptr + i * VLen * PackRow + ip * VLen, vscale_y[i * PackRow + ip]); + store_fp_T(vdeq_y, dstptr + (ir + ib) * NTILE + i * VLen * PackRow + ip * VLen); + } + } + } + } else if constexpr (PackRow == 2) { + __m512 vscale_y[PackRow * NReg]; + for (int i = 0; i < NReg; i++) { + __m512 vraw; + if (sdtype == BTLA_DTYPE::F32) { + auto sptr = (float*)scales_ + ele_off; + vraw = _mm512_loadu_ps(sptr + i * VLen); + } else if (sdtype == BTLA_DTYPE::BF16) { + auto sptr = (utils::bf16*)scales_ + ele_off; + vraw = load_bf16_fp32(sptr + i * VLen); + } + vscale_y[i * PackRow + 0] = broadcast_ps_1_2(vraw, vshuf_index_high, vshuf_index_low); + vscale_y[i * PackRow + 1] = broadcast_ps_1_2(vraw, vshuf_index_high, vshuf_index_low); + } + for (int ib = 0; ib < k_remain; ib += PackRow) { + auto b8ptr = srcptr + (ir + ib) * NTILE; + for (int i = 0; i < NReg; i++) { + for (int ip = 0; ip < PackRow; ip++) { + auto vdeq_y = dequant_s8_fp(b8ptr + i * VLen * PackRow + ip * VLen, vscale_y[i * PackRow + ip]); + store_fp_T(vdeq_y, dstptr + (ir + ib) * NTILE + i * VLen * PackRow + ip * VLen); + } + } + } + } else { + assert(0); + } + } + return BTLA_CODE::Success; + } else { + for (int ir = 0; ir < row; ir += blocksize) { + int k_remain = utils::remainsize(ir, row, blocksize); + int ele_off = (k_offset + ir) / blocksize * ldzp + n_offset; + if constexpr (PackRow == 1) { + __m512 vscale_y[NReg]; + if (sdtype == BTLA_DTYPE::F32) { + auto sptr = (float*)scales_ + ele_off; + for (int i = 0; i < NReg; i++) vscale_y[i] = _mm512_loadu_ps(sptr + i * VLen); + } else if (sdtype == BTLA_DTYPE::BF16) { + auto sptr = (utils::bf16*)scales_ + ele_off; + for (int i = 0; i < NReg; i++) vscale_y[i] = load_bf16_fp32(sptr + i * VLen); + } + __m512i vzp_y[NReg]; + for (int i = 0; i < NReg; i++) vzp_y[i] = load_s8_s32(zero_points + ele_off + i * VLen); + for (int ib = 0; ib < k_remain; ib += PackRow) { + auto b8ptr = srcptr + (ir + ib) * NTILE; + for (int i = 0; i < NReg; i++) { + auto vdeq_y = dequant_s8_fp(b8ptr + i * VLen, vscale_y[i], vzp_y[i]); + store_fp_T(vdeq_y, 
dstptr + (ir + ib) * NTILE + i * VLen); + } + } + } else if constexpr (PackRow == 4) { + __m512 vscale_y[PackRow * NReg]; + __m512i vzp_y[PackRow * NReg]; + for (int i = 0; i < NReg; i++) { + __m512 vraw; + if (sdtype == BTLA_DTYPE::F32) { + auto sptr = (float*)scales_ + ele_off; + vraw = _mm512_loadu_ps(sptr + i * VLen); + } else if (sdtype == BTLA_DTYPE::BF16) { + auto sptr = (utils::bf16*)scales_ + ele_off; + vraw = load_bf16_fp32(sptr + i * VLen); + } else { + assert(0); + } + auto vcast_y = broadcast_ps_1_2(vraw, vshuf_index_high, vshuf_index_low); + vscale_y[i * PackRow + 0] = broadcast_ps_1_2(vcast_y, vshuf_index_high, vshuf_index_low); + vscale_y[i * PackRow + 1] = broadcast_ps_1_2(vcast_y, vshuf_index_high, vshuf_index_low); + vcast_y = broadcast_ps_1_2(vraw, vshuf_index_high, vshuf_index_low); + vscale_y[i * PackRow + 2] = broadcast_ps_1_2(vcast_y, vshuf_index_high, vshuf_index_low); + vscale_y[i * PackRow + 3] = broadcast_ps_1_2(vcast_y, vshuf_index_high, vshuf_index_low); + + auto tmp = load_s8_s32(zero_points + ele_off + i * VLen); + auto vcasti_y = broadcast_epi32_1_2(tmp, vshuf_index_high, vshuf_index_low); + vzp_y[i * PackRow + 0] = broadcast_epi32_1_2(vcasti_y, vshuf_index_high, vshuf_index_low); + vzp_y[i * PackRow + 1] = broadcast_epi32_1_2(vcasti_y, vshuf_index_high, vshuf_index_low); + vcasti_y = broadcast_epi32_1_2(tmp, vshuf_index_high, vshuf_index_low); + vzp_y[i * PackRow + 2] = broadcast_epi32_1_2(vcasti_y, vshuf_index_high, vshuf_index_low); + vzp_y[i * PackRow + 3] = broadcast_epi32_1_2(vcasti_y, vshuf_index_high, vshuf_index_low); + } + for (int ib = 0; ib < k_remain; ib += PackRow) { + auto b8ptr = srcptr + (ir + ib) * NTILE; + for (int i = 0; i < NReg; i++) { + for (int ip = 0; ip < PackRow; ip++) { + auto vdeq_y = dequant_s8_fp(b8ptr + i * VLen * PackRow + ip * VLen, vscale_y[i * PackRow + ip], + vzp_y[i * PackRow + ip]); + store_fp_T(vdeq_y, dstptr + (ir + ib) * NTILE + i * VLen * PackRow + ip * VLen); + } + } + } + } else if constexpr (PackRow == 2) { + __m512 vscale_y[PackRow * NReg]; + __m512i vzp_y[PackRow * NReg]; + for (int i = 0; i < NReg; i++) { + __m512 vraw; + if (sdtype == BTLA_DTYPE::F32) { + auto sptr = (float*)scales_ + ele_off; + vraw = _mm512_loadu_ps(sptr + i * VLen); + } else if (sdtype == BTLA_DTYPE::BF16) { + auto sptr = (utils::bf16*)scales_ + ele_off; + vraw = load_bf16_fp32(sptr + i * VLen); + } + vscale_y[i * PackRow + 0] = broadcast_ps_1_2(vraw, vshuf_index_high, vshuf_index_low); + vscale_y[i * PackRow + 1] = broadcast_ps_1_2(vraw, vshuf_index_high, vshuf_index_low); + auto tmp = load_s8_s32(zero_points + ele_off + i * VLen); + vzp_y[i * PackRow + 0] = broadcast_epi32_1_2(tmp, vshuf_index_high, vshuf_index_low); + vzp_y[i * PackRow + 1] = broadcast_epi32_1_2(tmp, vshuf_index_high, vshuf_index_low); + } + for (int ib = 0; ib < k_remain; ib += PackRow) { + auto b8ptr = srcptr + (ir + ib) * NTILE; + for (int i = 0; i < NReg; i++) { + for (int ip = 0; ip < PackRow; ip++) { + auto vdeq_y = dequant_s8_fp(b8ptr + i * VLen * PackRow + ip * VLen, vscale_y[i * PackRow + ip], + vzp_y[i * PackRow + ip]); + store_fp_T(vdeq_y, dstptr + (ir + ib) * NTILE + i * VLen * PackRow + ip * VLen); + } + } + } + } else { + assert(0); + } + } + return BTLA_CODE::Success; + } +} + +template +inline BTLA_CODE decompress_kblock_s8_fp(int8_t* srcptr, DST_T* dstptr, int row, int col, void* scales_, + BTLA_DTYPE sdtype, int8_t* zero_points, int k_offset, int n_offset, + int blocksize, int ldzp, int8_t* tmp, size_t tmpsize) { + auto ret = BTLA_CODE::NotSupport; + 
if (col == NTILE) { + int head_end = utils::padto(k_offset, blocksize); + head_end = std::min(head_end, k_offset + row); + int head_size = head_end - k_offset; + if (head_size > 0) { + decompress_kblock_s8_fp_row(srcptr, dstptr, head_size, scales_, sdtype, zero_points, + k_offset, n_offset, blocksize, ldzp, tmp, tmpsize); + } + int body_size = row - head_size; + if (body_size > 0) { + decompress_kblock_s8_fp_row(srcptr + head_size * NTILE, dstptr + head_size * NTILE, + body_size, scales_, sdtype, zero_points, head_end, n_offset, + blocksize, ldzp, tmp, tmpsize); + } + return BTLA_CODE::Success; + } + return ret; +} +template +inline BTLA_CODE decompress_kblock_s4_fp_row(utils::int4x2* srcptr, DST_T* dstptr, int row, void* scales_, + BTLA_DTYPE sdtype, int8_t* zero_points, int k_offset, int n_offset, + int blocksize, int ldzp, int8_t* tmp, size_t tmpsize) { + int constexpr NReg = NTILE / 8; + const auto DstSize = row * NTILE * sizeof(DST_T); + const auto S8Size = row * NTILE * sizeof(int8_t); + auto tmps8ptr = (int8_t*)dstptr; + tmps8ptr += DstSize - S8Size; + auto ret = decompress_kblock_s4_s8(srcptr, zero_points, tmps8ptr, blocksize, ldzp, n_offset, k_offset, + row, NTILE, tmp, tmpsize); + assert(ret == BTLA_CODE::Success); + return decompress_kblock_s8_fp_row(tmps8ptr, dstptr, row, scales_, sdtype, nullptr, k_offset, + n_offset, blocksize, ldzp, tmp, tmpsize); +} + +template +inline BTLA_CODE decompress_kblock_s4_fp(utils::int4x2* srcptr, DST_T* dstptr, int row, int col, void* scales_, + BTLA_DTYPE sdtype, int8_t* zero_points, int k_offset, int n_offset, + int blocksize, int ldzp, int8_t* tmp, size_t tmpsize) { + auto ret = BTLA_CODE::NotSupport; + if (col == NTILE) { + int head_end = utils::padto(k_offset, blocksize); + head_end = std::min(head_end, k_offset + row); + int head_size = head_end - k_offset; + if (head_size > 0) { + decompress_kblock_s4_fp_row(srcptr, dstptr, head_size, scales_, sdtype, zero_points, + k_offset, n_offset, blocksize, ldzp, tmp, tmpsize); + } + int body_size = row - head_size; + if (body_size > 0) { + decompress_kblock_s4_fp_row(srcptr + head_size * NTILE / 2, dstptr + head_size * NTILE, + body_size, scales_, sdtype, zero_points, head_end, n_offset, + blocksize, ldzp, tmp, tmpsize); + } + return BTLA_CODE::Success; + } + return ret; +} + +template +inline BTLA_CODE decompress_kblock_s2_fp_row(utils::bit2x4* b2ptr, DST_T* dstptr, int row, void* scales_, + BTLA_DTYPE sdtype, int8_t* zero_points, int k_offset, int n_offset, + int blocksize, int ldzp, int8_t* tmp, size_t tmpsize) { + int constexpr NReg = NTILE / 8; + const auto DstSize = row * NTILE * sizeof(DST_T); + const auto S8Size = row * NTILE * sizeof(int8_t); + auto tmps8ptr = (int8_t*)dstptr; + tmps8ptr += DstSize - S8Size; + auto ret = decompress_kblock_s2_s8(b2ptr, zero_points, tmps8ptr, blocksize, ldzp, n_offset, k_offset, + row, NTILE, tmp, tmpsize); + assert(ret == BTLA_CODE::Success); + return decompress_kblock_s8_fp_row(tmps8ptr, dstptr, row, scales_, sdtype, nullptr, k_offset, + n_offset, blocksize, ldzp, tmp, tmpsize); +} + +template +inline BTLA_CODE decompress_kblock_s2_fp(utils::bit2x4* b2ptr, DST_T* dstptr, int row, int col, void* scales_, + BTLA_DTYPE sdtype, int8_t* zero_points, int k_offset, int n_offset, + int blocksize, int ldzp, int8_t* tmp, size_t tmpsize) { + auto ret = BTLA_CODE::NotSupport; + if (col == NTILE) { + int head_end = utils::padto(k_offset, blocksize); + head_end = std::min(head_end, k_offset + row); + int head_size = head_end - k_offset; + if (head_size > 0) { + 
decompress_kblock_s2_fp_row(b2ptr, dstptr, head_size, scales_, sdtype, zero_points, + k_offset, n_offset, blocksize, ldzp, tmp, tmpsize); + } + int body_size = row - head_size; + if (body_size > 0) { + decompress_kblock_s2_fp_row(b2ptr + head_size * NTILE / 4, dstptr + head_size * NTILE, + body_size, scales_, sdtype, zero_points, head_end, n_offset, + blocksize, ldzp, tmp, tmpsize); + } + return BTLA_CODE::Success; + } + return ret; +} + +template +inline BTLA_CODE decompress_kblock_s3_fp_row(utils::bit2x4* b2ptr, utils::bit1x8* b1ptr, DST_T* dstptr, int row, + void* scales_, BTLA_DTYPE sdtype, int8_t* zero_points, int k_offset, + int n_offset, int blocksize, int ldzp, int8_t* tmp, size_t tmpsize) { + int constexpr NReg = NTILE / 8; + const auto DstSize = row * NTILE * sizeof(DST_T); + const auto S8Size = row * NTILE * sizeof(int8_t); + auto tmps8ptr = (int8_t*)dstptr; + tmps8ptr += DstSize - S8Size; + auto ret = decompress_kblock_s3_s8(b2ptr, b1ptr, zero_points, tmps8ptr, blocksize, ldzp, n_offset, + k_offset, row, NTILE, tmp, tmpsize); + assert(ret == BTLA_CODE::Success); + return decompress_kblock_s8_fp_row(tmps8ptr, dstptr, row, scales_, sdtype, nullptr, k_offset, + n_offset, blocksize, ldzp, tmp, tmpsize); +} + +template +inline BTLA_CODE decompress_kblock_s3_fp(utils::bit2x4* b2ptr, utils::bit1x8* b1ptr, DST_T* dstptr, int row, int col, + void* scales_, BTLA_DTYPE sdtype, int8_t* zero_points, int k_offset, + int n_offset, int blocksize, int ldzp, int8_t* tmp, size_t tmpsize) { + auto ret = BTLA_CODE::NotSupport; + if (col == NTILE) { + int head_end = utils::padto(k_offset, blocksize); + head_end = std::min(head_end, k_offset + row); + int head_size = head_end - k_offset; + if (head_size > 0) { + decompress_kblock_s3_fp_row(b2ptr, b1ptr, dstptr, head_size, scales_, sdtype, zero_points, + k_offset, n_offset, blocksize, ldzp, tmp, tmpsize); + } + int body_size = row - head_size; + if (body_size > 0) { + decompress_kblock_s3_fp_row( + b2ptr + head_size * NTILE / 4, b1ptr + head_size * NTILE / 8, dstptr + head_size * NTILE, body_size, scales_, + sdtype, zero_points, head_end, n_offset, blocksize, ldzp, tmp, tmpsize); + } + return BTLA_CODE::Success; + } + return ret; +} + +template +static inline __m512 load_T_fp32(const T* srcptr) { + __m512 vtmp; + if constexpr (std::is_same_v) { + vtmp = _mm512_loadu_ps(srcptr); + } else if constexpr (std::is_same_v) { + vtmp = load_bf16_fp32(srcptr); + } else { + assert(0); + } + return vtmp; +} + +static inline __m512 load_s8_fp32(int8_t* srcptr) { + auto src_y = load_s8_s32(srcptr); + auto dst_y = _mm512_cvtepi32_ps(src_y); + return dst_y; +} + +static inline __m512i _mm512_sign_epi8(__m512i a, __m512i b) { + __m512i zero = _mm512_setzero_si512(); + __mmask64 blt0 = _mm512_movepi8_mask(b); + return _mm512_mask_sub_epi8(a, blt0, zero, a); + ; +} + +template +static inline void gemv_dequant_s32fp32(const float* asptr, int ldzp, const ScaleT* bsptr, __m512i* iacc, + __m512* facc) { + __m512 v_a_scale[MTILE]; + for (int im = 0; im < MTILE; im++) { + v_a_scale[im] = _mm512_set1_ps(*(asptr + im * ldzp)); + } + + for (int i = 0; i < NReg; i++) { + __m512 v_b_scale = load_T_fp32(bsptr + i * 16); + for (int im = 0; im < MTILE; im++) { + auto vtmp = _mm512_mul_ps(v_a_scale[im], v_b_scale); + auto tmp = _mm512_cvtepi32_ps(iacc[im * NReg + i]); + facc[im * NReg + i] = _mm512_fmadd_ps(tmp, vtmp, facc[im * NReg + i]); + } + } +} + +template +static inline void gemv_remove_zp(const uint8_t* azptr, int ldzp, __m512i* iacc, __m512i* bacc) { + if constexpr (MReg == 
1) { + auto zp = int(azptr[0]); + __m512i v_a_zp = _mm512_set1_epi32(zp); + for (int in = 0; in < NReg; in++) { + auto vtmp = _mm512_mullo_epi32(v_a_zp, bacc[in]); + iacc[in] = _mm512_sub_epi32(iacc[in], vtmp); + } + } else { + __m512i v_a_zp[MReg]; + for (int im = 0; im < MReg; im++) { + auto zp = int(azptr[im * ldzp]); + v_a_zp[im] = _mm512_set1_epi32(zp); + for (int in = 0; in < NReg; in++) { + auto vtmp = _mm512_mullo_epi32(v_a_zp[im], bacc[in]); + iacc[im * NReg + in] = _mm512_sub_epi32(iacc[im * NReg + in], vtmp); + } + } + } +} + +template +static inline void accumulate_fp32_s8_fp32(const float* Aptr, int lda, int8_t* Bptr, __m512* vacc, __m512* vsca) { + if constexpr (MTILE == 1) { + for (int ikk = 0; ikk < Unroll; ikk++) { + __m512 va = _mm512_set1_ps(*(Aptr + ikk)); + for (int i = 0; i < NReg; i++) { + auto ftmp = load_s8_fp32(Bptr + i * 16 + ikk * NReg * 16); + ftmp = _mm512_mul_ps(ftmp, vsca[i]); + vacc[i] = _mm512_fmadd_ps(va, ftmp, vacc[i]); + } + } + } else { + for (int ikk = 0; ikk < Unroll; ikk++) { + __m512 va[MTILE]; + for (int i = 0; i < NReg; i++) { + auto ftmp = load_s8_fp32(Bptr + i * 16 + ikk * NReg * 16); + ftmp = _mm512_mul_ps(ftmp, vsca[i]); + for (int im = 0; im < MTILE; im++) { + if (i == 0) { + va[im] = _mm512_set1_ps(*(Aptr + ikk + im * lda)); + } + vacc[im * NReg + i] = _mm512_fmadd_ps(va[im], ftmp, vacc[im * NReg + i]); + } + } + } + } +} + +template +static inline void accumulate_fp32_s8_fp32(const float* Aptr, int lda, int8_t* Bptr, __m512* vacc_loc) { + if constexpr (MTILE == 1) { + for (int ikk = 0; ikk < Unroll; ikk++) { + __m512 va = _mm512_set1_ps(*(Aptr + ikk)); + for (int i = 0; i < NReg; i++) { + auto ftmp = load_s8_fp32(Bptr + i * 16 + ikk * NReg * 16); + vacc_loc[i] = _mm512_fmadd_ps(va, ftmp, vacc_loc[i]); + } + } + } else { + for (int ikk = 0; ikk < Unroll; ikk++) { + __m512 va[MTILE]; + for (int i = 0; i < NReg; i++) { + auto ftmp = load_s8_fp32(Bptr + i * 16 + ikk * NReg * 16); + for (int im = 0; im < MTILE; im++) { + if (i == 0) { + va[im] = _mm512_set1_ps(*(Aptr + ikk + im * lda)); + } + vacc_loc[im * NReg + i] = _mm512_fmadd_ps(va[im], ftmp, vacc_loc[im * NReg + i]); + } + } + } + } +} + +template +static inline BTLA_CODE gemv_4bit_fp32_fp32(const float* A, int lda, const utils::GemvParamB& B, float* C, + int ldc, int k, int blocksize, int8_t* tmp, size_t tmpsize) { + auto& b4ptr = B.b4ptr; + int blks = k / blocksize; + int constexpr VLen = 16; + int constexpr NReg = NTILE / VLen; + int constexpr MReg = MTILE; + // Initialize accumulator with zeros + __m512 acc[NReg * MReg]; + for (int i = 0; i < NReg * MReg; i++) { + acc[i] = _mm512_setzero_ps(); + } + uint32_t mask = 0x0f0f0f0f; + auto vmask = _mm512_set1_epi32(*reinterpret_cast(&mask)); + auto vbias = _mm512_set1_epi8(8); + for (int ib = 0; ib < blks; ib += 1) { + auto bsptr = B.sptr + ib * B.ldzp; + __m512 v_b_scale[NReg]; + for (int i = 0; i < NReg; i++) { + v_b_scale[i] = load_T_fp32(bsptr + i * VLen); + } + + int constexpr Unroll = 4; + assert((blocksize % 4) == 0); + assert(tmpsize >= NTILE * Unroll); + + if (B.zpptr) { + __m512i bzp[NReg]; + auto bzptr = B.zpptr + ib * B.ldzp; + + for (int i = 0; i < Unroll; i++) { + memcpy(tmp + i * NTILE, bzptr, NTILE); + } + for (int i = 0; i < NReg; i++) { + bzp[i] = _mm512_loadu_si512((const __m512i*)(tmp + i * 64)); + bzp[i] = _mm512_add_epi8(bzp[i], vbias); + } + for (int ik = 0; ik < blocksize; ik += Unroll) { + for (int i = 0; i < NReg; i++) { + auto vb = unpack_4bits((void*)(b4ptr + i * 32 + (ib * blocksize + ik) * NTILE / 2), vmask); + 
vb = _mm512_sub_epi8(vb, bzp[i]); + _mm512_storeu_si512((__m512i*)(tmp + 64 * i), vb); + } + accumulate_fp32_s8_fp32(A + ib * blocksize + ik, lda, tmp, acc, v_b_scale); + } + + } else { + for (int ik = 0; ik < blocksize; ik += Unroll) { + for (int i = 0; i < NReg; i++) { + auto vb = unpack_4bits((void*)(b4ptr + i * 32 + (ib * blocksize + ik) * NTILE / 2), vmask); + vb = _mm512_sub_epi8(vb, vbias); + _mm512_storeu_si512((__m512i*)(tmp + 64 * i), vb); + } + accumulate_fp32_s8_fp32(A + ib * blocksize + ik, lda, tmp, acc, v_b_scale); + } + } + } + + for (int j = 0; j < MReg; j++) { + for (int i = 0; i < NReg; i++) { + _mm512_storeu_ps(C + i * VLen + j * ldc, acc[j * NReg + i]); + } + } + return BTLA_CODE::Success; +} + +template +static inline BTLA_CODE gemv_2bit_fp32_fp32(const float* A, int lda, const utils::GemvParamB& B, float* C, + int ldc, int k, int blocksize, int8_t* tmp, size_t tmpsize) { + auto b2ptr = (utils::bit2x4*)B.b2ptr; + int constexpr VLen = 16; + int blks = k / blocksize; + int constexpr NReg = NTILE / VLen; + int constexpr MReg = MTILE; + __m512 acc[NReg * MReg]; + for (int i = 0; i < NReg * MReg; i++) { + acc[i] = _mm512_setzero_ps(); + } + uint64_t mask0 = 0x0303030303030303; + auto vmask0 = _mm512_set1_epi64(*(int64_t*)&mask0); + auto vbias = _mm512_set1_epi8(2); + auto vshift_y = _mm512_set_epi32(6, 4, 2, 0, 6, 4, 2, 0, 6, 4, 2, 0, 6, 4, 2, 0); + auto vsfhl_mask_y = _mm512_set_epi8(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2, + 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, + 15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0); + auto vorder_y = _mm512_set_epi32(3, 3, 3, 3, 2, 2, 2, 2, 1, 1, 1, 1, 0, 0, 0, 0); + + int constexpr KTILE = 1; + for (int ib = 0; ib < blks; ib += 1) { + auto bsptr = B.sptr + ib * B.ldzp; + + __m512 acc_loc[NReg * MReg]; + for (int i = 0; i < NReg * MReg; i++) { + acc_loc[i] = _mm512_setzero_ps(); + } + int constexpr Unroll = 4; + assert((blocksize % 4) == 0); + assert(tmpsize >= NTILE * Unroll); + + if (B.zpptr) { + __m512i bzp[NReg]; + auto bzptr = B.zpptr + ib * B.ldzp; + for (int i = 0; i < Unroll; i++) { + memcpy(tmp + i * NTILE, bzptr, NTILE); + } + for (int i = 0; i < NReg; i++) { + bzp[i] = _mm512_loadu_si512((const __m512i*)(tmp + i * 64)); + bzp[i] = _mm512_add_epi8(bzp[i], vbias); + } + for (int ik = 0; ik < blocksize; ik += Unroll) { + for (int i = 0; i < NReg; i++) { + auto vb = unpack_2bits(b2ptr, vshift_y, vmask0, vsfhl_mask_y, vorder_y); + vb = _mm512_sub_epi8(vb, bzp[i]); + _mm512_storeu_si512((__m512i*)(tmp + 64 * i), vb); + b2ptr += VLen * Unroll / 4; + } + accumulate_fp32_s8_fp32(A + ib * blocksize + ik, lda, tmp, acc_loc); + } + + } else { + for (int ik = 0; ik < blocksize; ik += Unroll) { + for (int i = 0; i < NReg; i++) { + auto vb = unpack_2bits(b2ptr, vshift_y, vmask0, vsfhl_mask_y, vorder_y); + vb = _mm512_sub_epi8(vb, vbias); + _mm512_storeu_si512((__m512i*)(tmp + 64 * i), vb); + b2ptr += VLen * Unroll / 4; + } + accumulate_fp32_s8_fp32(A + ib * blocksize + ik, lda, tmp, acc_loc); + } + } + + __m512 v_b_scale[NReg]; + for (int i = 0; i < NReg; i++) { + v_b_scale[i] = load_T_fp32(bsptr + i * VLen); + } + for (int im = 0; im < MTILE; im++) { + for (int in = 0; in < NReg; in++) { + acc[im * NReg + in] = _mm512_fmadd_ps(acc_loc[im * NReg + in], v_b_scale[in], acc[im * NReg + in]); + } + } + } + + for (int j = 0; j < MReg; j++) { + for (int i = 0; i < NReg; i++) { + _mm512_storeu_ps(C + i * VLen + j * ldc, acc[j * NReg + i]); + } + } + return 
BTLA_CODE::Success; +} + +template +static inline BTLA_CODE gemv_3bit_fp32_fp32(const float* A, int lda, const utils::GemvParamB& B, float* C, + int ldc, int k, int blocksize, int8_t* tmp, size_t tmpsize) { + auto b2ptr = (utils::bit2x4*)B.b2ptr; + auto b1ptr = (utils::bit1x8*)B.b1ptr; + + int constexpr VLen = 16; + int blks = k / blocksize; + int constexpr NReg = NTILE / VLen; + int constexpr MReg = MTILE; + __m512 acc[NReg * MReg]; + for (int i = 0; i < NReg * MReg; i++) { + acc[i] = _mm512_setzero_ps(); + } + uint64_t mask0 = 0x0303030303030303; + auto vmask0 = _mm512_set1_epi64(*(int64_t*)&mask0); + auto vbias = _mm512_set1_epi8(4); + auto vshift_y = _mm512_set_epi32(6, 4, 2, 0, 6, 4, 2, 0, 6, 4, 2, 0, 6, 4, 2, 0); + auto vsfhl_mask_y = _mm512_set_epi8(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2, + 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, + 15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0); + auto vorder_y = _mm512_set_epi32(3, 3, 3, 3, 2, 2, 2, 2, 1, 1, 1, 1, 0, 0, 0, 0); + + auto zmm_0x04 = _mm512_set1_epi8(0x04); + auto zmm_0x00 = _mm512_set1_epi8(0x00); + int constexpr KTILE = 1; + for (int ib = 0; ib < blks; ib += 1) { + auto bsptr = B.sptr + ib * B.ldzp; + + __m512 acc_loc[NReg * MReg]; + for (int i = 0; i < NReg * MReg; i++) { + acc_loc[i] = _mm512_setzero_ps(); + } + int constexpr Unroll = 4; + assert((blocksize % 4) == 0); + assert(tmpsize >= NTILE * Unroll); + + if (B.zpptr) { + __m512i bzp[NReg]; + auto bzptr = B.zpptr + ib * B.ldzp; + for (int i = 0; i < Unroll; i++) { + memcpy(tmp + i * NTILE, bzptr, NTILE); + } + for (int i = 0; i < NReg; i++) { + bzp[i] = _mm512_loadu_si512((const __m512i*)(tmp + i * 64)); + bzp[i] = _mm512_add_epi8(bzp[i], vbias); + } + for (int ik = 0; ik < blocksize; ik += Unroll) { + for (int i = 0; i < NReg; i++) { + auto vb = unpack_2bits(b2ptr, vshift_y, vmask0, vsfhl_mask_y, vorder_y); + auto vb1 = unpack_1bits(b1ptr, zmm_0x00, zmm_0x04); + vb = _mm512_or_si512(vb, vb1); + vb = _mm512_sub_epi8(vb, bzp[i]); + _mm512_storeu_si512((__m512i*)(tmp + 64 * i), vb); + b2ptr += VLen * Unroll / 4; + b1ptr += VLen * Unroll / 8; + } + accumulate_fp32_s8_fp32(A + ib * blocksize + ik, lda, tmp, acc_loc); + } + + } else { + for (int ik = 0; ik < blocksize; ik += Unroll) { + for (int i = 0; i < NReg; i++) { + auto vb = unpack_2bits(b2ptr, vshift_y, vmask0, vsfhl_mask_y, vorder_y); + auto vb1 = unpack_1bits(b1ptr, zmm_0x00, zmm_0x04); + vb = _mm512_or_si512(vb, vb1); + vb = _mm512_sub_epi8(vb, vbias); + _mm512_storeu_si512((__m512i*)(tmp + 64 * i), vb); + b2ptr += VLen * Unroll / 4; + b1ptr += VLen * Unroll / 8; + } + accumulate_fp32_s8_fp32(A + ib * blocksize + ik, lda, tmp, acc_loc); + } + } + + __m512 v_b_scale[NReg]; + for (int i = 0; i < NReg; i++) { + v_b_scale[i] = load_T_fp32(bsptr + i * VLen); + } + for (int im = 0; im < MTILE; im++) { + for (int in = 0; in < NReg; in++) { + acc[im * NReg + in] = _mm512_fmadd_ps(acc_loc[im * NReg + in], v_b_scale[in], acc[im * NReg + in]); + } + } + } + + for (int j = 0; j < MReg; j++) { + for (int i = 0; i < NReg; i++) { + _mm512_storeu_ps(C + i * VLen + j * ldc, acc[j * NReg + i]); + } + } + return BTLA_CODE::Success; +} + +namespace vnni { + +#if CompileAVX512VNNI() +#ifdef __GNUC__ +#pragma GCC push_options +#pragma GCC target("avx512vnni") +#endif + +template +static inline BTLA_CODE gemv_4bit_u8s8_fp32(const utils::GemvParamA& A, const utils::GemvParamB& B, float* C, + int ldc, int k, int blocksize, int8_t* tmp, size_t tmpsize) { + 
auto& a8ptr = A.aptr; + auto& b4ptr = B.b4ptr; + auto& asptr = A.sptr; + auto& azptr = A.zpptr; + int constexpr VLen = 16; + int blks = k / blocksize; + int constexpr NReg = NTILE / VLen; + int constexpr MReg = MTILE; + // Initialize accumulator with zeros + __m512 acc[NReg * MReg]; + for (int i = 0; i < NReg * MReg; i++) { + acc[i] = _mm512_setzero_ps(); + } + uint32_t mask = 0x0f0f0f0f; + auto vmask = _mm512_set1_epi32(*reinterpret_cast(&mask)); + const __m512i onesu8 = _mm512_set1_epi8(1); + const __m512i vbias = _mm512_set1_epi8(8); + const auto vindex = _mm512_set_epi8(12, 12, 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0, 12, 12, 12, 12, 8, 8, 8, 8, 4, + 4, 4, 4, 0, 0, 0, 0, 12, 12, 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0, 12, 12, + 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0); + + for (int ib = 0; ib < blks; ib += 1) { + __m512i iacc[NReg * MReg]; + __m512i bacc[NReg]; + for (int i = 0; i < NReg * MReg; i++) { + iacc[i] = _mm512_setzero_si512(); + } + for (int i = 0; i < NReg; i++) { + bacc[i] = _mm512_setzero_si512(); + } + if (B.zpptr) { + __m512i bzp[NReg]; + auto bzptr = B.zpptr + ib * B.ldzp; + for (int i = 0; i < NReg; i++) { + bzp[i] = load_zp_epi8_broadcast_epi32(bzptr + i * VLen, vindex); + bzp[i] = _mm512_add_epi8(bzp[i], vbias); + } + for (int ik = 0; ik < blocksize; ik += 4) { + if constexpr (MTILE == 1) { + __m512i va = _mm512_set1_epi32(*(int*)(a8ptr + ib * blocksize + ik)); + for (int i = 0; i < NReg; i++) { + auto vb = unpack_4bits((void*)(b4ptr + i * 32 + (ib * blocksize + ik) * NTILE / 2), vmask); + vb = _mm512_sub_epi8(vb, bzp[i]); + bacc[i] = _mm512_dpbusd_epi32(bacc[i], onesu8, vb); + iacc[i] = _mm512_dpbusd_epi32(iacc[i], va, vb); + } + } else { + __m512i va[MReg]; + for (int i = 0; i < MReg; i++) { + va[i] = _mm512_set1_epi32(*(int*)(a8ptr + ib * blocksize + ik + i * A.lda)); + } + for (int i = 0; i < NReg; i++) { + auto vb = unpack_4bits((void*)(b4ptr + i * 32 + (ib * blocksize + ik) * NTILE / 2), vmask); + vb = _mm512_sub_epi8(vb, bzp[i]); + bacc[i] = _mm512_dpbusd_epi32(bacc[i], onesu8, vb); + for (int j = 0; j < MReg; j++) { + iacc[j * NReg + i] = _mm512_dpbusd_epi32(iacc[j * NReg + i], va[j], vb); + } + } + } + } + } else { + for (int ik = 0; ik < blocksize; ik += 4) { + if constexpr (MTILE == 1) { + __m512i va = _mm512_set1_epi32(*(int*)(a8ptr + ib * blocksize + ik)); + for (int i = 0; i < NReg; i++) { + auto vb = unpack_4bits((void*)(b4ptr + i * 32 + (ib * blocksize + ik) * NTILE / 2), vmask); + vb = _mm512_sub_epi8(vb, vbias); + bacc[i] = _mm512_dpbusd_epi32(bacc[i], onesu8, vb); + iacc[i] = _mm512_dpbusd_epi32(iacc[i], va, vb); + } + } else { + __m512i va[MReg]; + for (int i = 0; i < MReg; i++) { + va[i] = _mm512_set1_epi32(*(int*)(a8ptr + ib * blocksize + ik + i * A.lda)); + } + for (int i = 0; i < NReg; i++) { + auto vb = unpack_4bits((void*)(b4ptr + i * 32 + (ib * blocksize + ik) * NTILE / 2), vmask); + vb = _mm512_sub_epi8(vb, vbias); + bacc[i] = _mm512_dpbusd_epi32(bacc[i], onesu8, vb); + for (int j = 0; j < MReg; j++) { + iacc[j * NReg + i] = _mm512_dpbusd_epi32(iacc[j * NReg + i], va[j], vb); + } + } + } + } + } + gemv_remove_zp(A.zpptr + ib, A.ldzp, iacc, bacc); + gemv_dequant_s32fp32(A.sptr + ib, A.ldzp, B.sptr + ib * B.ldzp, iacc, acc); + } + + for (int j = 0; j < MReg; j++) { + for (int i = 0; i < NReg; i++) { + _mm512_storeu_ps(C + i * VLen + j * ldc, acc[j * NReg + i]); + } + } + return BTLA_CODE::Success; +} + +template +static inline BTLA_CODE gemv_4bit_s8s8_fp32(const utils::GemvParamA& A, const utils::GemvParamB& B, float* C, + int 
ldc, int k, int blocksize, int8_t* tmp, size_t tmpsize) { + auto& a8ptr = A.aptr; + auto& b4ptr = B.b4ptr; + auto& asptr = A.sptr; + + int blks = k / blocksize; + int constexpr VLen = 16; + int constexpr NReg = NTILE / VLen; + int constexpr MReg = MTILE; + // Initialize accumulator with zeros + __m512 acc[NReg * MReg]; + for (int i = 0; i < NReg * MReg; i++) { + acc[i] = _mm512_setzero_ps(); + } + const __m512i vbias = _mm512_set1_epi8(8); + uint32_t mask = 0x0f0f0f0f; + auto vmask = _mm512_set1_epi32(*reinterpret_cast(&mask)); + const auto vindex = _mm512_set_epi8(12, 12, 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0, 12, 12, 12, 12, 8, 8, 8, 8, 4, + 4, 4, 4, 0, 0, 0, 0, 12, 12, 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0, 12, 12, + 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0); + for (int ib = 0; ib < blks; ib += 1) { + __m512i iacc[NReg * MReg]; + for (int i = 0; i < NReg * MReg; i++) { + iacc[i] = _mm512_setzero_si512(); + } + if (B.zpptr) { + __m512i bzp[NReg]; + auto bzptr = B.zpptr + ib * B.ldzp; + for (int i = 0; i < NReg; i++) { + bzp[i] = load_zp_epi8_broadcast_epi32(bzptr + i * VLen, vindex); + bzp[i] = _mm512_add_epi8(bzp[i], vbias); + } + for (int ik = 0; ik < blocksize; ik += 4) { + __m512i va[MReg]; + for (int i = 0; i < MReg; i++) { + va[i] = _mm512_set1_epi32(*(int*)(a8ptr + ib * blocksize + ik + i * A.lda)); + } + for (int i = 0; i < NReg; i++) { + auto vb = unpack_4bits((void*)(b4ptr + i * 32 + (ib * blocksize + ik) * NTILE / 2), vmask); + vb = _mm512_sub_epi8(vb, bzp[i]); + for (int j = 0; j < MReg; j++) { + auto vsb = _mm512_sign_epi8(vb, va[j]); + auto vabsa = _mm512_sign_epi8(va[j], va[j]); + iacc[j * NReg + i] = _mm512_dpbusd_epi32(iacc[j * NReg + i], vabsa, vsb); + } + } + } + } else { + for (int ik = 0; ik < blocksize; ik += 4) { + __m512i va[MReg]; + for (int i = 0; i < MReg; i++) { + va[i] = _mm512_set1_epi32(*(int*)(a8ptr + ib * blocksize + ik + i * A.lda)); + } + for (int i = 0; i < NReg; i++) { + auto vb = unpack_4bits((void*)(b4ptr + i * 32 + (ib * blocksize + ik) * NTILE / 2), vmask); + vb = _mm512_sub_epi8(vb, vbias); + for (int j = 0; j < MReg; j++) { + auto vsb = _mm512_sign_epi8(vb, va[j]); + auto vabsa = _mm512_sign_epi8(va[j], va[j]); + iacc[j * NReg + i] = _mm512_dpbusd_epi32(iacc[j * NReg + i], vabsa, vsb); + } + } + } + } + + gemv_dequant_s32fp32(A.sptr + ib, A.ldzp, B.sptr + ib * B.ldzp, iacc, acc); + } + + for (int j = 0; j < MReg; j++) { + for (int i = 0; i < NReg; i++) { + _mm512_storeu_ps(C + i * VLen + j * ldc, acc[j * NReg + i]); + } + } + return BTLA_CODE::Success; +} + +template +static inline BTLA_CODE gemv_2bit_u8s8_fp32(const utils::GemvParamA& A, const utils::GemvParamB& B, float* C, + int ldc, int k, int blocksize, int8_t* tmp, size_t tmpsize) { + auto b2ptr = reinterpret_cast(B.b2ptr); + int constexpr VLen = 16; + int blks = k / blocksize; + int constexpr NReg = NTILE / VLen; + int constexpr MReg = MTILE; + __m512 acc[NReg * MReg]; + for (int i = 0; i < NReg * MReg; i++) { + acc[i] = _mm512_setzero_ps(); + } + + const auto onesu8 = _mm512_set1_epi8(1); + const auto vindex = _mm512_set_epi8(12, 12, 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0, 12, 12, 12, 12, 8, 8, 8, 8, 4, + 4, 4, 4, 0, 0, 0, 0, 12, 12, 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0, 12, 12, + 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0); + uint64_t mask0 = 0x0303030303030303; + auto vmask0 = _mm512_set1_epi64(*(int64_t*)&mask0); + auto vbias = _mm512_set1_epi8(2); + auto vshift_y = _mm512_set_epi32(6, 4, 2, 0, 6, 4, 2, 0, 6, 4, 2, 0, 6, 4, 2, 0); + auto vsfhl_mask_y = 
_mm512_set_epi8(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2, + 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, + 15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0); + auto vorder_y = _mm512_set_epi32(3, 3, 3, 3, 2, 2, 2, 2, 1, 1, 1, 1, 0, 0, 0, 0); + int constexpr KTILE = 4; + for (int ib = 0; ib < blks; ib += 1) { + __m512i iacc[NReg * MReg]; + __m512i bacc[NReg]; + for (int i = 0; i < NReg * MReg; i++) { + iacc[i] = _mm512_setzero_si512(); + } + for (int i = 0; i < NReg; i++) { + bacc[i] = _mm512_setzero_si512(); + } + if (B.zpptr) { + __m512i bzp[NReg]; + auto bzptr = B.zpptr + ib * B.ldzp; + for (int i = 0; i < NReg; i++) { + bzp[i] = load_zp_epi8_broadcast_epi32(bzptr + i * 16, vindex); + bzp[i] = _mm512_add_epi8(bzp[i], vbias); + } + for (int ik = 0; ik < blocksize; ik += KTILE) { + if constexpr (MTILE == 1) { + __m512i va = _mm512_set1_epi32(*(int*)(A.aptr + ib * blocksize + ik)); + for (int i = 0; i < NReg; i++) { + auto vb = unpack_2bits(b2ptr, vshift_y, vmask0, vsfhl_mask_y, vorder_y); + vb = _mm512_sub_epi8(vb, bzp[i]); + bacc[i] = _mm512_dpbusd_epi32(bacc[i], onesu8, vb); + iacc[i] = _mm512_dpbusd_epi32(iacc[i], va, vb); + b2ptr += VLen * KTILE / 4; + } + } else { + __m512i va[MReg]; + for (int i = 0; i < MReg; i++) { + va[i] = _mm512_set1_epi32(*(int*)(A.aptr + ib * blocksize + ik + i * A.lda)); + } + for (int i = 0; i < NReg; i++) { + auto vb = unpack_2bits(b2ptr, vshift_y, vmask0, vsfhl_mask_y, vorder_y); + vb = _mm512_sub_epi8(vb, bzp[i]); + bacc[i] = _mm512_dpbusd_epi32(bacc[i], onesu8, vb); + for (int j = 0; j < MReg; j++) { + iacc[j * NReg + i] = _mm512_dpbusd_epi32(iacc[j * NReg + i], va[j], vb); + } + b2ptr += VLen * KTILE / 4; + } + } + } + } else { + for (int ik = 0; ik < blocksize; ik += KTILE) { + if constexpr (MTILE == 1) { + __m512i va = _mm512_set1_epi32(*(int*)(A.aptr + ib * blocksize + ik)); + for (int i = 0; i < NReg; i++) { + auto vb = unpack_2bits(b2ptr, vshift_y, vmask0, vsfhl_mask_y, vorder_y); + vb = _mm512_sub_epi8(vb, vbias); + bacc[i] = _mm512_dpbusd_epi32(bacc[i], onesu8, vb); + iacc[i] = _mm512_dpbusd_epi32(iacc[i], va, vb); + b2ptr += VLen * KTILE / 4; + } + } else { + __m512i va[MReg]; + for (int i = 0; i < MReg; i++) { + va[i] = _mm512_set1_epi32(*(int*)(A.aptr + ib * blocksize + ik + i * A.lda)); + } + for (int i = 0; i < NReg; i++) { + auto vb = unpack_2bits(b2ptr, vshift_y, vmask0, vsfhl_mask_y, vorder_y); + vb = _mm512_sub_epi8(vb, vbias); + bacc[i] = _mm512_dpbusd_epi32(bacc[i], onesu8, vb); + for (int j = 0; j < MReg; j++) { + iacc[j * NReg + i] = _mm512_dpbusd_epi32(iacc[j * NReg + i], va[j], vb); + } + b2ptr += VLen * KTILE / 4; + } + } + } + } + + gemv_remove_zp(A.zpptr + ib, A.ldzp, iacc, bacc); + gemv_dequant_s32fp32(A.sptr + ib, A.ldzp, B.sptr + ib * B.ldzp, iacc, acc); + } + + for (int j = 0; j < MReg; j++) { + for (int i = 0; i < NReg; i++) { + _mm512_storeu_ps(C + i * VLen + j * ldc, acc[j * NReg + i]); + } + } + return BTLA_CODE::Success; +} + +template +static inline BTLA_CODE gemv_2bit_s8s8_fp32(const utils::GemvParamA& A, const utils::GemvParamB& B, float* C, + int ldc, int k, int blocksize, int8_t* tmp, size_t tmpsize) { + auto b2ptr = reinterpret_cast(B.b2ptr); + int constexpr VLen = 16; + int blks = k / blocksize; + int constexpr NReg = NTILE / VLen; + int constexpr MReg = MTILE; + __m512 acc[NReg * MReg]; + for (int i = 0; i < NReg * MReg; i++) { + acc[i] = _mm512_setzero_ps(); + } + + const auto onesu8 = _mm512_set1_epi8(1); + const auto vindex = 
_mm512_set_epi8(12, 12, 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0, 12, 12, 12, 12, 8, 8, 8, 8, 4, + 4, 4, 4, 0, 0, 0, 0, 12, 12, 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0, 12, 12, + 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0); + uint64_t mask0 = 0x0303030303030303; + auto vmask0 = _mm512_set1_epi64(*(int64_t*)&mask0); + auto vbias = _mm512_set1_epi8(2); + auto vshift_y = _mm512_set_epi32(6, 4, 2, 0, 6, 4, 2, 0, 6, 4, 2, 0, 6, 4, 2, 0); + auto vsfhl_mask_y = _mm512_set_epi8(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2, + 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, + 15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0); + auto vorder_y = _mm512_set_epi32(3, 3, 3, 3, 2, 2, 2, 2, 1, 1, 1, 1, 0, 0, 0, 0); + int constexpr KTILE = 4; + for (int ib = 0; ib < blks; ib += 1) { + __m512i iacc[NReg * MReg]; + for (int i = 0; i < NReg * MReg; i++) { + iacc[i] = _mm512_setzero_si512(); + } + + if (B.zpptr) { + __m512i bzp[NReg]; + auto bzptr = B.zpptr + ib * B.ldzp; + for (int i = 0; i < NReg; i++) { + bzp[i] = load_zp_epi8_broadcast_epi32(bzptr + i * 16, vindex); + bzp[i] = _mm512_add_epi8(vbias, bzp[i]); + } + for (int ik = 0; ik < blocksize; ik += KTILE) { + __m512i va[MReg]; + for (int i = 0; i < MReg; i++) { + va[i] = _mm512_set1_epi32(*(int*)(A.aptr + ib * blocksize + ik + i * A.lda)); + } + for (int i = 0; i < NReg; i++) { + auto vb = unpack_2bits(b2ptr, vshift_y, vmask0, vsfhl_mask_y, vorder_y); + vb = _mm512_sub_epi8(vb, bzp[i]); + for (int j = 0; j < MReg; j++) { + auto vsb = _mm512_sign_epi8(vb, va[j]); + auto vabsa = _mm512_sign_epi8(va[j], va[j]); + iacc[j * NReg + i] = _mm512_dpbusd_epi32(iacc[j * NReg + i], vabsa, vsb); + } + b2ptr += VLen * KTILE / 4; + } + } + } else { + for (int ik = 0; ik < blocksize; ik += KTILE) { + __m512i va[MReg]; + for (int i = 0; i < MReg; i++) { + va[i] = _mm512_set1_epi32(*(int*)(A.aptr + ib * blocksize + ik + i * A.lda)); + } + for (int i = 0; i < NReg; i++) { + auto vb = unpack_2bits(b2ptr, vshift_y, vmask0, vsfhl_mask_y, vorder_y); + vb = _mm512_sub_epi8(vb, vbias); + for (int j = 0; j < MReg; j++) { + auto vsb = _mm512_sign_epi8(vb, va[j]); + auto vabsa = _mm512_sign_epi8(va[j], va[j]); + iacc[j * NReg + i] = _mm512_dpbusd_epi32(iacc[j * NReg + i], vabsa, vsb); + } + b2ptr += VLen * KTILE / 4; + } + } + } + gemv_dequant_s32fp32(A.sptr + ib, A.ldzp, B.sptr + ib * B.ldzp, iacc, acc); + } + + for (int j = 0; j < MReg; j++) { + for (int i = 0; i < NReg; i++) { + _mm512_storeu_ps(C + i * VLen + j * ldc, acc[j * NReg + i]); + } + } + return BTLA_CODE::Success; +} + +template +static inline BTLA_CODE gemv_3bit_u8s8_fp32(const utils::GemvParamA& A, const utils::GemvParamB& B, float* C, + int ldc, int k, int blocksize, int8_t* tmp, size_t tmpsize) { + auto b2ptr = reinterpret_cast(B.b2ptr); + auto b1ptr = reinterpret_cast(B.b1ptr); + + int blks = k / blocksize; + int constexpr VLen = 16; + int constexpr NReg = NTILE / VLen; + int constexpr MReg = MTILE; + __m512 acc[NReg * MReg]; + for (int i = 0; i < NReg * MReg; i++) { + acc[i] = _mm512_setzero_ps(); + } + uint64_t mask0 = 0x0303030303030303; + auto vmask0 = _mm512_set1_epi64(*(int64_t*)&mask0); + auto vbias = _mm512_set1_epi8(4); + auto vshift_y = _mm512_set_epi32(6, 4, 2, 0, 6, 4, 2, 0, 6, 4, 2, 0, 6, 4, 2, 0); + auto vsfhl_mask_y = _mm512_set_epi8(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2, + 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, + 15, 11, 7, 3, 14, 10, 6, 
2, 13, 9, 5, 1, 12, 8, 4, 0); + auto vorder_y = _mm512_set_epi32(3, 3, 3, 3, 2, 2, 2, 2, 1, 1, 1, 1, 0, 0, 0, 0); + + auto zmm_0x04 = _mm512_set1_epi8(0x04); + auto zmm_0x00 = _mm512_set1_epi8(0x00); + const auto vindex = _mm512_set_epi8(12, 12, 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0, 12, 12, 12, 12, 8, 8, 8, 8, 4, + 4, 4, 4, 0, 0, 0, 0, 12, 12, 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0, 12, 12, + 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0); + const auto onesu8 = _mm512_set1_epi8(1); + int constexpr KTILE = 4; + for (int ib = 0; ib < blks; ib += 1) { + __m512i iacc[NReg * MReg]; + __m512i bacc[NReg]; + for (int i = 0; i < NReg * MReg; i++) { + iacc[i] = _mm512_setzero_si512(); + } + for (int i = 0; i < NReg; i++) { + bacc[i] = _mm512_setzero_si512(); + } + if (B.zpptr) { + __m512i bzp[NReg]; + auto bzptr = B.zpptr + ib * B.ldzp; + for (int i = 0; i < NReg; i++) { + bzp[i] = load_zp_epi8_broadcast_epi32(bzptr + i * 16, vindex); + bzp[i] = _mm512_add_epi8(bzp[i], vbias); + } + for (int ik = 0; ik < blocksize; ik += KTILE) { + if constexpr (MTILE == 1) { + __m512i va = _mm512_set1_epi32(*(int*)(A.aptr + ib * blocksize + ik)); + for (int i = 0; i < NReg; i++) { + auto vb = unpack_2bits(b2ptr, vshift_y, vmask0, vsfhl_mask_y, vorder_y); + auto vb1 = unpack_1bits(b1ptr, zmm_0x00, zmm_0x04); + vb = _mm512_or_si512(vb, vb1); + vb = _mm512_sub_epi8(vb, bzp[i]); + bacc[i] = _mm512_dpbusd_epi32(bacc[i], onesu8, vb); + iacc[i] = _mm512_dpbusd_epi32(iacc[i], va, vb); + b2ptr += VLen * KTILE / 4; + b1ptr += VLen * KTILE / 8; + } + } else { + __m512i va[MReg]; + for (int i = 0; i < MReg; i++) { + va[i] = _mm512_set1_epi32(*(int*)(A.aptr + ib * blocksize + ik + i * A.lda)); + } + for (int i = 0; i < NReg; i++) { + auto vb = unpack_2bits(b2ptr, vshift_y, vmask0, vsfhl_mask_y, vorder_y); + auto vb1 = unpack_1bits(b1ptr, zmm_0x00, zmm_0x04); + vb = _mm512_or_si512(vb, vb1); + vb = _mm512_sub_epi8(vb, bzp[i]); + bacc[i] = _mm512_dpbusd_epi32(bacc[i], onesu8, vb); + for (int j = 0; j < MReg; j++) { + iacc[j * NReg + i] = _mm512_dpbusd_epi32(iacc[j * NReg + i], va[j], vb); + } + b2ptr += VLen * KTILE / 4; + b1ptr += VLen * KTILE / 8; + } + } + } + } else { + for (int ik = 0; ik < blocksize; ik += KTILE) { + if constexpr (MTILE == 1) { + __m512i va = _mm512_set1_epi32(*(int*)(A.aptr + ib * blocksize + ik)); + for (int i = 0; i < NReg; i++) { + auto vb = unpack_2bits(b2ptr, vshift_y, vmask0, vsfhl_mask_y, vorder_y); + auto vb1 = unpack_1bits(b1ptr, zmm_0x00, zmm_0x04); + vb = _mm512_or_si512(vb, vb1); + vb = _mm512_sub_epi8(vb, vbias); + bacc[i] = _mm512_dpbusd_epi32(bacc[i], onesu8, vb); + iacc[i] = _mm512_dpbusd_epi32(iacc[i], va, vb); + b2ptr += VLen * KTILE / 4; + b1ptr += VLen * KTILE / 8; + } + } else { + __m512i va[MReg]; + for (int i = 0; i < MReg; i++) { + va[i] = _mm512_set1_epi32(*(int*)(A.aptr + ib * blocksize + ik + i * A.lda)); + } + for (int i = 0; i < NReg; i++) { + auto vb = unpack_2bits(b2ptr, vshift_y, vmask0, vsfhl_mask_y, vorder_y); + auto vb1 = unpack_1bits(b1ptr, zmm_0x00, zmm_0x04); + vb = _mm512_or_si512(vb, vb1); + vb = _mm512_sub_epi8(vb, vbias); + bacc[i] = _mm512_dpbusd_epi32(bacc[i], onesu8, vb); + for (int j = 0; j < MReg; j++) { + iacc[j * NReg + i] = _mm512_dpbusd_epi32(iacc[j * NReg + i], va[j], vb); + } + b2ptr += VLen * KTILE / 4; + b1ptr += VLen * KTILE / 8; + } + } + } + } + + gemv_remove_zp(A.zpptr + ib, A.ldzp, iacc, bacc); + gemv_dequant_s32fp32(A.sptr + ib, A.ldzp, B.sptr + ib * B.ldzp, iacc, acc); + } + + for (int j = 0; j < MReg; j++) { + for (int i = 0; i < NReg; 
i++) { + _mm512_storeu_ps(C + i * VLen + j * ldc, acc[j * NReg + i]); + } + } + return BTLA_CODE::Success; +} + +template +static inline BTLA_CODE gemv_3bit_s8s8_fp32(const utils::GemvParamA& A, const utils::GemvParamB& B, float* C, + int ldc, int k, int blocksize, int8_t* tmp, size_t tmpsize) { + auto b2ptr = reinterpret_cast(B.b2ptr); + auto b1ptr = reinterpret_cast(B.b1ptr); + + int blks = k / blocksize; + int constexpr VLen = 16; + int constexpr NReg = NTILE / VLen; + int constexpr MReg = MTILE; + __m512 acc[NReg * MReg]; + for (int i = 0; i < NReg * MReg; i++) { + acc[i] = _mm512_setzero_ps(); + } + uint64_t mask0 = 0x0303030303030303; + auto vmask0 = _mm512_set1_epi64(*(int64_t*)&mask0); + auto vbias = _mm512_set1_epi8(4); + auto vshift_y = _mm512_set_epi32(6, 4, 2, 0, 6, 4, 2, 0, 6, 4, 2, 0, 6, 4, 2, 0); + auto vsfhl_mask_y = _mm512_set_epi8(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2, + 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, + 15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0); + auto vorder_y = _mm512_set_epi32(3, 3, 3, 3, 2, 2, 2, 2, 1, 1, 1, 1, 0, 0, 0, 0); + + auto zmm_0x04 = _mm512_set1_epi8(0x04); + auto zmm_0x00 = _mm512_set1_epi8(0x00); + const auto vindex = _mm512_set_epi8(12, 12, 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0, 12, 12, 12, 12, 8, 8, 8, 8, 4, + 4, 4, 4, 0, 0, 0, 0, 12, 12, 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0, 12, 12, + 12, 12, 8, 8, 8, 8, 4, 4, 4, 4, 0, 0, 0, 0); + int constexpr KTILE = 4; + for (int ib = 0; ib < blks; ib += 1) { + __m512i iacc[NReg * MReg]; + for (int i = 0; i < NReg * MReg; i++) { + iacc[i] = _mm512_setzero_si512(); + } + if (B.zpptr) { + __m512i bzp[NReg]; + auto bzptr = B.zpptr + ib * B.ldzp; + for (int i = 0; i < NReg; i++) { + bzp[i] = load_zp_epi8_broadcast_epi32(bzptr + i * 16, vindex); + bzp[i] = _mm512_add_epi8(bzp[i], vbias); + } + for (int ik = 0; ik < blocksize; ik += KTILE) { + __m512i va[MReg]; + for (int i = 0; i < MReg; i++) { + va[i] = _mm512_set1_epi32(*(int*)(A.aptr + ib * blocksize + ik + i * A.lda)); + } + for (int i = 0; i < NReg; i++) { + auto vb = unpack_2bits(b2ptr, vshift_y, vmask0, vsfhl_mask_y, vorder_y); + auto vb1 = unpack_1bits(b1ptr, zmm_0x00, zmm_0x04); + vb = _mm512_or_si512(vb, vb1); + vb = _mm512_sub_epi8(vb, bzp[i]); + for (int j = 0; j < MReg; j++) { + auto vsb = _mm512_sign_epi8(vb, va[j]); + auto vabsa = _mm512_sign_epi8(va[j], va[j]); + iacc[j * NReg + i] = _mm512_dpbusd_epi32(iacc[j * NReg + i], vabsa, vsb); + } + b2ptr += VLen * KTILE / 4; + b1ptr += VLen * KTILE / 8; + } + } + } else { + for (int ik = 0; ik < blocksize; ik += KTILE) { + __m512i va[MReg]; + for (int i = 0; i < MReg; i++) { + va[i] = _mm512_set1_epi32(*(int*)(A.aptr + ib * blocksize + ik + i * A.lda)); + } + for (int i = 0; i < NReg; i++) { + auto vb = unpack_2bits(b2ptr, vshift_y, vmask0, vsfhl_mask_y, vorder_y); + auto vb1 = unpack_1bits(b1ptr, zmm_0x00, zmm_0x04); + vb = _mm512_or_si512(vb, vb1); + vb = _mm512_sub_epi8(vb, vbias); + for (int j = 0; j < MReg; j++) { + auto vsb = _mm512_sign_epi8(vb, va[j]); + auto vabsa = _mm512_sign_epi8(va[j], va[j]); + iacc[j * NReg + i] = _mm512_dpbusd_epi32(iacc[j * NReg + i], vabsa, vsb); + } + b2ptr += VLen * KTILE / 4; + b1ptr += VLen * KTILE / 8; + } + } + } + + gemv_dequant_s32fp32(A.sptr + ib, A.ldzp, B.sptr + ib * B.ldzp, iacc, acc); + } + + for (int j = 0; j < MReg; j++) { + for (int i = 0; i < NReg; i++) { + _mm512_storeu_ps(C + i * VLen + j * ldc, acc[j * NReg + i]); + } + } + return 
BTLA_CODE::Success;
+}
+
+#ifdef __GNUC__
+#pragma GCC pop_options
+#else
+#endif
+#endif
+} // namespace vnni
+
 #ifdef __GNUC__
 #pragma GCC pop_options
 #else
diff --git a/bestla/bestla/kernel_jit.h b/bestla/bestla/kernel_jit.h
index fd4900ab6..3347f272f 100644
--- a/bestla/bestla/kernel_jit.h
+++ b/bestla/bestla/kernel_jit.h
@@ -313,7 +313,7 @@ class DecompressS3 {
         vpsrlw(Xbyak::Ymm(4 + i), bit2_data, 2 * i);
         vpand(Xbyak::Ymm(4 + i), Xbyak::Ymm(4 + i), lowMask);
         vpaddb(Xbyak::Ymm(i), Xbyak::Ymm(i), Xbyak::Ymm(4 + i));
-        vpslld(Xbyak::Ymm(i), Xbyak::Ymm(i), 5);
+        vpsubb(Xbyak::Ymm(i), Xbyak::Ymm(i), highMask);
         if constexpr (std::is_same_v<_DST_T, int8_t>) {
           vmovdqu(ptr[reg_dst + 32 * i], Xbyak::Ymm(i));
         } else if constexpr (std::is_same_v<_DST_T, float>) {
diff --git a/bestla/bestla/kernel_ref.h b/bestla/bestla/kernel_ref.h
index 489c25737..fb3fb1f65 100644
--- a/bestla/bestla/kernel_ref.h
+++ b/bestla/bestla/kernel_ref.h
@@ -157,8 +157,8 @@ static inline BTLA_CODE compress_s8_s4(const int8_t* srcptr, utils::int4x2* dstp
   for (int j = 0; j < row; j++) {
     for (int ii = 0; ii < col; ii += 2) {
       utils::int4x2 tmp;
-      tmp.x = utils::int4x2::convert(srcptr[j * ld_src + ii + 0]);
-      tmp.y = utils::int4x2::convert(srcptr[j * ld_src + ii + 1]);
+      tmp.x = utils::int4x2::convert(srcptr[j * ld_src + ii + 0]) + 8;
+      tmp.y = utils::int4x2::convert(srcptr[j * ld_src + ii + 1]) + 8;
       dstptr[j * ld_dst / 2 + ii / 2] = tmp;
     }
   }
@@ -178,13 +178,11 @@ static inline BTLA_CODE compress_f4(const int8_t* srcptr, utils::f4x2* dstptr, i
   return BTLA_CODE::Success;
 }

-static inline BTLA_CODE compress_3bit(const int8_t* srcptr, bestla::utils::bit2x4* bit2ptr, utils::bit1x8* bit1ptr,
-                                      int row, int col, int ld_src, int ld_dst) {
+static inline BTLA_CODE compress_3bit_align128(const int8_t* srcptr, bestla::utils::bit2x4* bit2ptr,
+                                               utils::bit1x8* bit1ptr, int row, int col, int ld_src, int ld_dst) {
   assert(col % 128 == 0);
   auto round3bit = [](int8_t src) {
     int32_t dst = src;
-    dst = dst >= 0 ? dst + 16 : dst - 16;
-    dst = dst / 32;
     dst = dst > 3 ? 3 : dst;
     dst = dst < -4 ?
-4 : dst; return static_cast(dst); @@ -204,141 +202,336 @@ static inline BTLA_CODE compress_3bit(const int8_t* srcptr, bestla::utils::bit2x for (int i = 0; i < row; i++) { for (int j = 0; j < col; j += 128) { for (int k = 0; k < 128; k++) { - round_buf[k] = round3bit(const_cast(srcptr + i * ld_src + j + k)[0]) << 5; + round_buf[k] = round3bit(const_cast(srcptr + i * ld_src + j + k)[0]) + 4; } bit2_interleave(round_buf, interleave_buf); for (int k = 0; k < 32; k++) { - bit2ptr[i * ld_dst / 4 + j / 4 + k].a = interleave_buf[4 * k] >> 5; - bit2ptr[i * ld_dst / 4 + j / 4 + k].b = interleave_buf[4 * k + 1] >> 5; - bit2ptr[i * ld_dst / 4 + j / 4 + k].c = interleave_buf[4 * k + 2] >> 5; - bit2ptr[i * ld_dst / 4 + j / 4 + k].d = interleave_buf[4 * k + 3] >> 5; + bit2ptr[i * ld_dst / 4 + j / 4 + k].a = interleave_buf[4 * k]; + bit2ptr[i * ld_dst / 4 + j / 4 + k].b = interleave_buf[4 * k + 1]; + bit2ptr[i * ld_dst / 4 + j / 4 + k].c = interleave_buf[4 * k + 2]; + bit2ptr[i * ld_dst / 4 + j / 4 + k].d = interleave_buf[4 * k + 3]; } for (int k = j; k < j + 128; k += 8) { - bit1ptr[i * ld_dst / 8 + k / 8].a = round_buf[k - j] >> 7; - bit1ptr[i * ld_dst / 8 + k / 8].b = round_buf[k - j + 1] >> 7; - bit1ptr[i * ld_dst / 8 + k / 8].c = round_buf[k - j + 2] >> 7; - bit1ptr[i * ld_dst / 8 + k / 8].d = round_buf[k - j + 3] >> 7; - bit1ptr[i * ld_dst / 8 + k / 8].e = round_buf[k - j + 4] >> 7; - bit1ptr[i * ld_dst / 8 + k / 8].f = round_buf[k - j + 5] >> 7; - bit1ptr[i * ld_dst / 8 + k / 8].g = round_buf[k - j + 6] >> 7; - bit1ptr[i * ld_dst / 8 + k / 8].h = round_buf[k - j + 7] >> 7; + bit1ptr[i * ld_dst / 8 + k / 8].a = round_buf[k - j] >> 2; + bit1ptr[i * ld_dst / 8 + k / 8].b = round_buf[k - j + 1] >> 2; + bit1ptr[i * ld_dst / 8 + k / 8].c = round_buf[k - j + 2] >> 2; + bit1ptr[i * ld_dst / 8 + k / 8].d = round_buf[k - j + 3] >> 2; + bit1ptr[i * ld_dst / 8 + k / 8].e = round_buf[k - j + 4] >> 2; + bit1ptr[i * ld_dst / 8 + k / 8].f = round_buf[k - j + 5] >> 2; + bit1ptr[i * ld_dst / 8 + k / 8].g = round_buf[k - j + 6] >> 2; + bit1ptr[i * ld_dst / 8 + k / 8].h = round_buf[k - j + 7] >> 2; } } } return BTLA_CODE::Success; } +static inline BTLA_CODE compress_3bit(const int8_t* srcptr, bestla::utils::bit2x4* bit2ptr, utils::bit1x8* bit1ptr, + size_t size) { + assert(size % 8 == 0); + auto round3bit = [](int8_t src) { + int32_t dst = src; + dst = dst > 3 ? 3 : dst; + dst = dst < -4 ? 
-4 : dst; + return static_cast(dst); + }; + + for (int j = 0; j < size; j += 8) { + auto tmp = round3bit(srcptr[j + 0]) + 4; + bit2ptr[j / 4 + 0].a = tmp & 0x3; + bit1ptr[j / 8].a = tmp >> 2; + tmp = round3bit(srcptr[j + 1]) + 4; + bit2ptr[j / 4 + 0].b = tmp & 0x3; + bit1ptr[j / 8].b = tmp >> 2; + tmp = round3bit(srcptr[j + 2]) + 4; + bit2ptr[j / 4 + 0].c = tmp & 0x3; + bit1ptr[j / 8].c = tmp >> 2; + tmp = round3bit(srcptr[j + 3]) + 4; + bit2ptr[j / 4 + 0].d = tmp & 0x3; + bit1ptr[j / 8].d = tmp >> 2; + + tmp = round3bit(srcptr[j + 4]) + 4; + bit2ptr[j / 4 + 1].a = tmp & 0x3; + bit1ptr[j / 8].e = tmp >> 2; + tmp = round3bit(srcptr[j + 5]) + 4; + bit2ptr[j / 4 + 1].b = tmp & 0x3; + bit1ptr[j / 8].f = tmp >> 2; + tmp = round3bit(srcptr[j + 6]) + 4; + bit2ptr[j / 4 + 1].c = tmp & 0x3; + bit1ptr[j / 8].g = tmp >> 2; + tmp = round3bit(srcptr[j + 7]) + 4; + bit2ptr[j / 4 + 1].d = tmp & 0x3; + bit1ptr[j / 8].h = tmp >> 2; + } + + return BTLA_CODE::Success; +} + static inline BTLA_CODE compress_2bit(const int8_t* srcptr, bestla::utils::bit2x4* bit2ptr, size_t size) { assert(size % 4 == 0); auto round2bit = [](int8_t src) { int32_t dst = src; - dst = dst >= 0 ? dst + 32 : dst - 32; - dst = dst / 64; dst = dst > 1 ? 1 : dst; dst = dst < -2 ? -2 : dst; return static_cast(dst); }; for (size_t i = 0; i < size; i += 4) { - bit2ptr[i / 4].a = round2bit(const_cast(srcptr + i)[0]); - bit2ptr[i / 4].b = round2bit(const_cast(srcptr + i + 1)[0]); - bit2ptr[i / 4].c = round2bit(const_cast(srcptr + i + 2)[0]); - bit2ptr[i / 4].d = round2bit(const_cast(srcptr + i + 3)[0]); + bit2ptr[i / 4].a = round2bit(const_cast(srcptr + i)[0]) + 2; + bit2ptr[i / 4].b = round2bit(const_cast(srcptr + i + 1)[0]) + 2; + bit2ptr[i / 4].c = round2bit(const_cast(srcptr + i + 2)[0]) + 2; + bit2ptr[i / 4].d = round2bit(const_cast(srcptr + i + 3)[0]) + 2; } return BTLA_CODE::Success; } -template -static inline BTLA_CODE decompress_s4_f32(utils::int4x2* srcptr, float* dstptr, int row, int col, int ld_src, - int ld_dst, float* scales) { - for (int i = 0; i < row; i++) { - for (int j = 0; j < col; j += 2) { - auto tmp = srcptr[i * ld_src / 2 + j / 2]; - auto noffset = i * NTile + j % NTile; - dstptr[i * ld_dst + j + 0] = static_cast(static_cast(tmp.x) << 4) * scales[noffset + 0]; - dstptr[i * ld_dst + j + 1] = static_cast(static_cast(tmp.y) << 4) * scales[noffset + 1]; - } - } - return BTLA_CODE::Success; -} - -template -inline int8_t get_s8(int8_t v) { - static_assert(S4_T == BTLA_DTYPE::S4_CLIP); - return v << 4; -} - -template -inline void convert_s4_s8_8(int8_t* dstptr, int8_t* srcptr) { - auto src32 = *reinterpret_cast(srcptr); - auto tmp = static_cast(src32 & 0xf) << 4; - dstptr[0] = tmp; - tmp = static_cast(src32 & 0xf0); - dstptr[1] = tmp; - tmp = static_cast((src32 & 0xf00) >> 4); - dstptr[2] = tmp; - tmp = static_cast((src32 & 0xf000) >> 8); - dstptr[3] = tmp; - tmp = static_cast((src32 & 0xf0000) >> 12); - dstptr[4] = tmp; - tmp = static_cast((src32 & 0xf00000) >> 16); - dstptr[5] = tmp; - tmp = static_cast((src32 & 0xf000000) >> 20); - dstptr[6] = tmp; - tmp = static_cast((src32 & 0xf0000000) >> 24); - dstptr[7] = tmp; -} - -inline void convert_s4_s8_8_lowbits(int8_t* dstptr, int8_t* srcptr) { +template +static inline void convert_s4_s8_8(int8_t* dstptr, int8_t* srcptr) { auto src32 = *reinterpret_cast(srcptr); auto tmp = static_cast(src32 & 0xf); + if constexpr (Q4T == BTLA_DTYPE::S4_CLIP) { + tmp -= 8; + } dstptr[0] = static_cast(tmp); tmp = static_cast(src32 & 0xf0) >> 4; + if constexpr (Q4T == BTLA_DTYPE::S4_CLIP) { + tmp 
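// Editor's note (illustrative sketch, not part of this patch): with this change the sub-byte
// formats are stored with an unsigned bias instead of being shifted into the high bits:
// 4-bit values carry +8, 3-bit values carry +4 (low 2 bits in the bit2x4 stream, high bit in
// the bit1x8 stream), and 2-bit values carry +2. Per-element model with hypothetical names:
#include <cstdint>
static inline void pack3(int8_t v /* in [-4, 3] */, uint8_t& lo2, uint8_t& hi1) {
  uint8_t u = uint8_t(v + 4);  // bias to [0, 7]
  lo2 = u & 0x3;               // goes into the bit2x4 stream
  hi1 = u >> 2;                // goes into the bit1x8 stream
}
static inline int8_t unpack3(uint8_t lo2, uint8_t hi1) {
  return int8_t((lo2 | (hi1 << 2)) - 4);  // what decompress_s3_s8 computes per element
}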
-= 8; + } dstptr[1] = static_cast(tmp); tmp = static_cast((src32 & 0xf00) >> 8); + if constexpr (Q4T == BTLA_DTYPE::S4_CLIP) { + tmp -= 8; + } dstptr[2] = static_cast(tmp); tmp = static_cast((src32 & 0xf000) >> 12); + if constexpr (Q4T == BTLA_DTYPE::S4_CLIP) { + tmp -= 8; + } dstptr[3] = static_cast(tmp); tmp = static_cast((src32 & 0xf0000) >> 16); + if constexpr (Q4T == BTLA_DTYPE::S4_CLIP) { + tmp -= 8; + } dstptr[4] = static_cast(tmp); tmp = static_cast((src32 & 0xf00000) >> 20); + if constexpr (Q4T == BTLA_DTYPE::S4_CLIP) { + tmp -= 8; + } dstptr[5] = static_cast(tmp); tmp = static_cast((src32 & 0xf000000) >> 24); + if constexpr (Q4T == BTLA_DTYPE::S4_CLIP) { + tmp -= 8; + } dstptr[6] = static_cast(tmp); tmp = static_cast((src32 & 0xf0000000) >> 28); + if constexpr (Q4T == BTLA_DTYPE::S4_CLIP) { + tmp -= 8; + } dstptr[7] = static_cast(tmp); } -template <> -inline void convert_s4_s8_8(int8_t* dstptr, int8_t* srcptr) { - convert_s4_s8_8_lowbits(dstptr, srcptr); +static inline BTLA_CODE decompress_s4_s8(utils::int4x2* srcptr, int8_t* dstptr, size_t unpackelt, int8_t* tmp, + size_t tmpsize) { + for (int j = 0; j < unpackelt; j += 2) { + auto tmp = srcptr[j / 2]; + dstptr[j + 0] = tmp.x - 8; + dstptr[j + 1] = tmp.y - 8; + } + return BTLA_CODE::Success; +} + +static inline BTLA_CODE decompress_s3_s8(utils::bit2x4* bit2ptr, utils::bit1x8* bit1ptr, int8_t* dstptr, int unpack_elt, + int8_t* tmp, size_t tmpsize) { + for (size_t i = 0; i < unpack_elt; i += 8) { + auto bit1 = bit1ptr[i / 8]; + auto tmp = bit2ptr[i / 4]; + dstptr[i + 0] = (tmp.a | (bit1.a << 2)) - 4; + dstptr[i + 1] = (tmp.b | (bit1.b << 2)) - 4; + dstptr[i + 2] = (tmp.c | (bit1.c << 2)) - 4; + dstptr[i + 3] = (tmp.d | (bit1.d << 2)) - 4; + tmp = bit2ptr[i / 4 + 1]; + dstptr[i + 4] = (tmp.a | (bit1.e << 2)) - 4; + dstptr[i + 5] = (tmp.b | (bit1.f << 2)) - 4; + dstptr[i + 6] = (tmp.c | (bit1.g << 2)) - 4; + dstptr[i + 7] = (tmp.d | (bit1.h << 2)) - 4; + } + return BTLA_CODE::Success; } -template <> -inline void convert_s4_s8_8(int8_t* dstptr, int8_t* srcptr) { - convert_s4_s8_8_lowbits(dstptr, srcptr); +static inline BTLA_CODE decompress_s2_s8(utils::bit2x4* srcptr, int8_t* dstptr, size_t unpackelt, int8_t* tmp, + size_t tmpsize) { + for (int j = 0; j < unpackelt; j += 4) { + auto tmp = srcptr[j / 4]; + dstptr[j + 0] = tmp.a - 2; + dstptr[j + 1] = tmp.b - 2; + dstptr[j + 2] = tmp.c - 2; + dstptr[j + 3] = tmp.d - 2; + } + return BTLA_CODE::Success; } -template <> -inline void convert_s4_s8_8(int8_t* dstptr, int8_t* srcptr) { - convert_s4_s8_8_lowbits(dstptr, srcptr); +template +static inline BTLA_CODE decompress_kblock_s4_s8(utils::int4x2* srcptr, int8_t* zpptr, int8_t* dstptr, int blocksize, + int ldzp, int n_offset, int k_offset, int row, int col, int8_t* tmp, + size_t tmpsize) { + if (zpptr) { + if constexpr (PackRow == 4 || PackRow == 2) { + for (int i = 0; i < row; i += PackRow) { + auto zptr = zpptr + (i + k_offset) / blocksize * ldzp + n_offset; + for (int j = 0; j < col; j += 1) { + auto zp = zptr[j] + 8; + for (int ir = 0; ir < PackRow; ir += 2) { + auto tmp = srcptr[i * col / 2 + j * PackRow / 2 + ir / 2]; + dstptr[i * col + j * PackRow + ir + 0] = tmp.x - zp; + dstptr[i * col + j * PackRow + ir + 1] = tmp.y - zp; + } + } + } + } else if constexpr (PackRow == 1) { + for (int i = 0; i < row; i += 1) { + auto zptr = zpptr + (i + k_offset) / blocksize * ldzp + n_offset; + for (int j = 0; j < col; j += 2) { + auto tmp = srcptr[i * col / 2 + j / 2]; + dstptr[i * col + j + 0] = tmp.x - 8 - zptr[j + 0]; + dstptr[i * col + j + 1] 
= tmp.y - 8 - zptr[j + 1]; + } + } + } else { + static_assert(PackRow == 1 || PackRow == 2 || PackRow == 4); + } + } else { + return decompress_s4_s8(srcptr, dstptr, size_t(row) * col, tmp, tmpsize); + } + return BTLA_CODE::Success; } -template -inline BTLA_CODE decompress_s4_s8(utils::int4x2* srcptr, int8_t* dstptr, int row, int col, int ld_src, int ld_dst) { - static_assert(S4_T == BTLA_DTYPE::S4_CLIP); - for (int i = 0; i < row; i++) { - for (int j = 0; j < col; j += 2) { - auto tmp = srcptr[i * ld_src / 2 + j / 2]; - dstptr[i * ld_dst + j + 0] = get_s8(tmp.x); - dstptr[i * ld_dst + j + 1] = get_s8(tmp.y); +template +static inline BTLA_CODE decompress_kblock_s3_s8(utils::bit2x4* bit2ptr, utils::bit1x8* bit1ptr, int8_t* zpptr, + int8_t* dstptr, int blocksize, int ldzp, int n_offset, int k_offset, + int row, int col, int8_t* tmp, size_t tmpsize) { + static_assert(NTILE % 8 == 0); + assert(((col * PackRow) % 8) == 0); + if (zpptr) { + if constexpr (PackRow == 4) { + for (int i = 0; i < row; i += PackRow) { + auto zptr = zpptr + (i + k_offset) / blocksize * ldzp + n_offset; + for (int j = 0; j < col; j += 2) { + auto zp = zptr[j] + 4; + auto bit1 = bit1ptr[(i * col + j * PackRow) / 8]; + auto tmp = bit2ptr[(i * col + j * PackRow) / 4]; + dstptr[i * col + j * PackRow + 0] = (tmp.a | (bit1.a << 2)) - zp; + dstptr[i * col + j * PackRow + 1] = (tmp.b | (bit1.b << 2)) - zp; + dstptr[i * col + j * PackRow + 2] = (tmp.c | (bit1.c << 2)) - zp; + dstptr[i * col + j * PackRow + 3] = (tmp.d | (bit1.d << 2)) - zp; + zp = zptr[j + 1] + 4; + tmp = bit2ptr[(i * col + j * PackRow) / 4 + 1]; + dstptr[i * col + j * PackRow + 4] = (tmp.a | (bit1.e << 2)) - zp; + dstptr[i * col + j * PackRow + 5] = (tmp.b | (bit1.f << 2)) - zp; + dstptr[i * col + j * PackRow + 6] = (tmp.c | (bit1.g << 2)) - zp; + dstptr[i * col + j * PackRow + 7] = (tmp.d | (bit1.h << 2)) - zp; + } + } + } else if constexpr (PackRow == 1) { + for (int i = 0; i < row; i += 1) { + auto zptr = zpptr + (i + k_offset) / blocksize * ldzp + n_offset; + for (int j = 0; j < col; j += 8) { + auto bit1 = bit1ptr[(i * col + j * PackRow) / 8]; + auto tmp = bit2ptr[(i * col + j * PackRow) / 4]; + dstptr[i * col + j * PackRow + 0] = (tmp.a | (bit1.a << 2)) - 4 - zptr[j + 0]; + dstptr[i * col + j * PackRow + 1] = (tmp.b | (bit1.b << 2)) - 4 - zptr[j + 1]; + dstptr[i * col + j * PackRow + 2] = (tmp.c | (bit1.c << 2)) - 4 - zptr[j + 2]; + dstptr[i * col + j * PackRow + 3] = (tmp.d | (bit1.d << 2)) - 4 - zptr[j + 3]; + tmp = bit2ptr[(i * col + j * PackRow) / 4 + 1]; + dstptr[i * col + j * PackRow + 4] = (tmp.a | (bit1.e << 2)) - 4 - zptr[j + 4]; + dstptr[i * col + j * PackRow + 5] = (tmp.b | (bit1.f << 2)) - 4 - zptr[j + 5]; + dstptr[i * col + j * PackRow + 6] = (tmp.c | (bit1.g << 2)) - 4 - zptr[j + 6]; + dstptr[i * col + j * PackRow + 7] = (tmp.d | (bit1.h << 2)) - 4 - zptr[j + 7]; + } + } + } else if constexpr (PackRow == 2) { + for (int i = 0; i < row; i += PackRow) { + auto zptr = zpptr + (i + k_offset) / blocksize * ldzp + n_offset; + for (int j = 0; j < col; j += 4) { + auto bit1 = bit1ptr[(i * col + j * PackRow) / 8]; + auto tmp = bit2ptr[(i * col + j * PackRow) / 4]; + auto zp = zptr[j] + 4; + dstptr[i * col + j * PackRow + 0] = (tmp.a | (bit1.a << 2)) - zp; + dstptr[i * col + j * PackRow + 1] = (tmp.b | (bit1.b << 2)) - zp; + zp = zptr[j + 1] + 4; + dstptr[i * col + j * PackRow + 2] = (tmp.c | (bit1.c << 2)) - zp; + dstptr[i * col + j * PackRow + 3] = (tmp.d | (bit1.d << 2)) - zp; + tmp = bit2ptr[(i * col + j * PackRow) / 4 + 1]; + zp = zptr[j + 2] + 
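// Editor's note (illustrative sketch, not part of this patch): the kblock decompress routines
// above locate the zero point for element (row i, col j) by grouping rows into k-blocks of
// `blocksize`, with one zero point per column per block and a row stride of `ldzp`; the storage
// bias (+8 / +4 / +2) is folded into the zero point before subtraction. Hypothetical helper:
static inline int zp_index(int i, int j, int k_offset, int n_offset, int blocksize, int ldzp) {
  return (i + k_offset) / blocksize * ldzp + n_offset + j;
}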
4; + dstptr[i * col + j * PackRow + 4] = (tmp.a | (bit1.e << 2)) - zp; + dstptr[i * col + j * PackRow + 5] = (tmp.b | (bit1.f << 2)) - zp; + zp = zptr[j + 3] + 4; + dstptr[i * col + j * PackRow + 6] = (tmp.c | (bit1.g << 2)) - zp; + dstptr[i * col + j * PackRow + 7] = (tmp.d | (bit1.h << 2)) - zp; + } + } + } else { + static_assert(PackRow == 1 || PackRow == 2 || PackRow == 4); } + } else { + return decompress_s3_s8(bit2ptr, bit1ptr, dstptr, size_t(row) * col, tmp, tmpsize); } return BTLA_CODE::Success; } -inline float f8_to_fp32(utils::f8 v, BTLA_DTYPE f8_t) { +template +static inline BTLA_CODE decompress_kblock_s2_s8(utils::bit2x4* bit2ptr, int8_t* zpptr, int8_t* dstptr, int blocksize, + int ldzp, int n_offset, int k_offset, int row, int col, int8_t* tmp, + size_t tmpsize) { + static_assert(NTILE % 4 == 0); + assert(((col * PackRow) % 4) == 0); + if (zpptr) { + if constexpr (PackRow == 4) { + for (int i = 0; i < row; i += PackRow) { + auto zptr = zpptr + (i + k_offset) / blocksize * ldzp + n_offset; + for (int j = 0; j < col; j += 1) { + auto zp = zptr[j] + 2; + auto tmp = bit2ptr[(i * col + j * PackRow) / 4]; + dstptr[i * col + j * PackRow + 0] = (tmp.a) - zp; + dstptr[i * col + j * PackRow + 1] = (tmp.b) - zp; + dstptr[i * col + j * PackRow + 2] = (tmp.c) - zp; + dstptr[i * col + j * PackRow + 3] = (tmp.d) - zp; + } + } + } else if constexpr (PackRow == 1) { + for (int i = 0; i < row; i += 1) { + auto zptr = zpptr + (i + k_offset) / blocksize * ldzp + n_offset; + for (int j = 0; j < col; j += 4) { + auto tmp = bit2ptr[(i * col + j * PackRow) / 4]; + dstptr[i * col + j * PackRow + 0] = (tmp.a) - 2 - zptr[j + 0]; + dstptr[i * col + j * PackRow + 1] = (tmp.b) - 2 - zptr[j + 1]; + dstptr[i * col + j * PackRow + 2] = (tmp.c) - 2 - zptr[j + 2]; + dstptr[i * col + j * PackRow + 3] = (tmp.d) - 2 - zptr[j + 3]; + } + } + } else if constexpr (PackRow == 2) { + for (int i = 0; i < row; i += PackRow) { + auto zptr = zpptr + (i + k_offset) / blocksize * ldzp + n_offset; + for (int j = 0; j < col; j += 2) { + auto tmp = bit2ptr[(i * col + j * PackRow) / 4]; + auto zp = zptr[j] + 2; + dstptr[i * col + j * PackRow + 0] = (tmp.a) - zp; + dstptr[i * col + j * PackRow + 1] = (tmp.b) - zp; + zp = zptr[j + 1] + 2; + dstptr[i * col + j * PackRow + 2] = (tmp.c) - zp; + dstptr[i * col + j * PackRow + 3] = (tmp.d) - zp; + } + } + } else { + static_assert(PackRow == 1 || PackRow == 2 || PackRow == 4); + } + } else { + return decompress_s2_s8(bit2ptr, dstptr, size_t(row) * col, tmp, tmpsize); + } + return BTLA_CODE::Success; +} + +static inline float f8_to_fp32(utils::f8 v, BTLA_DTYPE f8_t) { uint32_t sign_revert = v.x; uint32_t e_revert = v.x; uint32_t mantissa_revert = v.x; @@ -358,8 +551,9 @@ inline float f8_to_fp32(utils::f8 v, BTLA_DTYPE f8_t) { } template -inline BTLA_CODE decompress_kblock_f8_fp(utils::f8* srcptr, _DST_T* dstptr, int row, int col, int ld_src, int ld_dst, - _S_T* scales, int k_offset, int kblock, int NPad, BTLA_DTYPE src_f8_type) { +static inline BTLA_CODE decompress_kblock_f8_fp(utils::f8* srcptr, _DST_T* dstptr, int row, int col, int ld_src, + int ld_dst, _S_T* scales, int k_offset, int kblock, int NPad, + BTLA_DTYPE src_f8_type) { for (int i = 0; i < row; i++) { int kpos = (k_offset + i) / kblock; auto sptr = scales + kpos * NPad; @@ -380,57 +574,94 @@ inline BTLA_CODE decompress_kblock_f8_fp(utils::f8* srcptr, _DST_T* dstptr, int return BTLA_CODE::Success; } -template -inline BTLA_CODE decompress_kblock_s8_fp(int8_t* srcptr, _DST_T* dstptr, int row, int col, int ld_src, int ld_dst, - 
_S_T* scales, int8_t* zero_points, int k_offset, int kblock, int NPad) { - for (int i = 0; i < row; i++) { - int kpos = (k_offset + i) / kblock; - auto sptr = scales + kpos * NPad; - for (int j = 0; j < col; j += 1) { - float tmp = static_cast(srcptr[i * ld_src + j]); - if (zero_points != nullptr) tmp -= static_cast(zero_points[kpos * NPad + j / _PACK_ROW]); - dstptr[i * ld_dst + j] = static_cast<_DST_T>(tmp * sptr[j / _PACK_ROW]); +template +inline BTLA_CODE decompress_kblock_s8_fp(int8_t* srcptr, DST_T* dstptr, int row, int col, void* scales_, + BTLA_DTYPE sdtype, int8_t* zero_points, int k_offset, int n_offset, + int blocksize, int ldzp, int8_t* tmp, size_t tmpsize) { + for (int i = 0; i < row; i += PackRow) { + int kpos = (k_offset + i) / blocksize * ldzp + n_offset; + if (sdtype == BTLA_DTYPE::F32) { + auto sptr = (float*)scales_ + kpos; + for (int j = 0; j < col; j += 1) { + auto scale = float(sptr[j]); + auto zp = zero_points ? zero_points[kpos + j] : 0; + for (int ir = 0; ir < PackRow; ir++) { + float tmp = static_cast(srcptr[i * col + j * PackRow + ir] - zp) * scale; + dstptr[i * col + j * PackRow + ir] = tmp; + } + } + } else if (sdtype == BTLA_DTYPE::BF16) { + auto sptr = (utils::bf16*)scales_ + kpos; + for (int j = 0; j < col; j += 1) { + auto scale = float(sptr[j]); + auto zp = zero_points ? zero_points[kpos + j] : 0; + for (int ir = 0; ir < PackRow; ir++) { + float tmp = static_cast(srcptr[i * col + j * PackRow + ir] - zp) * scale; + dstptr[i * col + j * PackRow + ir] = tmp; + } + } } } return BTLA_CODE::Success; } -template -inline BTLA_CODE decompress_kblock_s4_fp(utils::int4x2* srcptr, _DST_T* dstptr, int row, int col, int ld_src, - int ld_dst, _S_T* scales, int8_t* zero_points, int k_offset, int kblock, - int NPad, int8_t* tmp, size_t tmpsize) { - for (int i = 0; i < row; i++) { - int kpos = (k_offset + i) / kblock; - auto sptr = scales + kpos * NPad; - for (int j = 0; j < col; j += 2) { - auto tmp = srcptr[i * ld_src / 2 + j / 2]; - float scale0, scale1, dst0, dst1; - int s0_idx, s1_idx; - s0_idx = j / _PACK_ROW; - s1_idx = (j + 1) / _PACK_ROW; - scale0 = static_cast(sptr[s0_idx]); - scale1 = static_cast(sptr[s1_idx]); - if (zero_points != nullptr) { - dst0 = (static_cast(get_s8(tmp.x)) - static_cast((zero_points + kpos * NPad)[s0_idx])) * - scale0; - dst1 = (static_cast(get_s8(tmp.y)) - static_cast((zero_points + kpos * NPad)[s1_idx])) * - scale1; - } else { - dst0 = static_cast(get_s8(tmp.x)) * scale0; - dst1 = static_cast(get_s8(tmp.y)) * scale1; - } - dstptr[i * ld_dst + j + 0] = static_cast<_DST_T>(dst0); - dstptr[i * ld_dst + j + 1] = static_cast<_DST_T>(dst1); - } - } +template +static inline BTLA_CODE decompress_kblock_s4_fp(utils::int4x2* srcptr, DST_T* dstptr, int row, int col, void* scales_, + BTLA_DTYPE sdtype, int8_t* zero_points, int k_offset, int n_offset, + int blocksize, int ldzp, int8_t* tmp, size_t tmpsize) { + assert(tmpsize >= PackRow * NTILE); + assert(NTILE == col); + const auto DstSize = row * NTILE * sizeof(DST_T); + const auto S8Size = row * NTILE * sizeof(int8_t); + auto tmps8ptr = (int8_t*)dstptr; + tmps8ptr += DstSize - S8Size; + decompress_kblock_s4_s8(srcptr, zero_points, tmps8ptr, blocksize, ldzp, n_offset, k_offset, row, col, + tmp, tmpsize); + decompress_kblock_s8_fp(tmps8ptr, dstptr, row, col, scales_, sdtype, nullptr, k_offset, n_offset, + blocksize, ldzp, tmp, tmpsize); + return BTLA_CODE::Success; +} + +template +static inline BTLA_CODE decompress_kblock_s3_fp(utils::bit2x4* b2ptr, utils::bit1x8* b1ptr, DST_T* dstptr, int row, + int 
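// Editor's note (illustrative sketch, not part of this patch): decompress_kblock_s8_fp above
// applies dst = (q - zp) * scale per element, with the scale stored either as fp32 or bf16.
// The bf16 helper below is a generic bit-level conversion (bf16 keeps the upper 16 bits of an
// fp32), not the library's utils::bf16:
#include <cstdint>
#include <cstring>
static inline float bf16_bits_to_float(uint16_t v) {
  uint32_t u = uint32_t(v) << 16;
  float f;
  std::memcpy(&f, &u, sizeof(f));
  return f;
}
static inline float dequant_s8(int8_t q, int8_t zp, float scale) {
  return float(int(q) - int(zp)) * scale;
}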
col, void* scales_, BTLA_DTYPE sdtype, int8_t* zero_points, + int k_offset, int n_offset, int blocksize, int ldzp, int8_t* tmp, + size_t tmpsize) { + assert(tmpsize >= PackRow * NTILE); + assert(NTILE == col); + const auto DstSize = row * NTILE * sizeof(DST_T); + const auto S8Size = row * NTILE * sizeof(int8_t); + auto tmps8ptr = (int8_t*)dstptr; + tmps8ptr += DstSize - S8Size; + decompress_kblock_s3_s8(b2ptr, b1ptr, zero_points, tmps8ptr, blocksize, ldzp, n_offset, k_offset, row, + col, tmp, tmpsize); + decompress_kblock_s8_fp(tmps8ptr, dstptr, row, col, scales_, sdtype, nullptr, k_offset, n_offset, + blocksize, ldzp, tmp, tmpsize); + return BTLA_CODE::Success; +} + +template +static inline BTLA_CODE decompress_kblock_s2_fp(utils::bit2x4* b2ptr, DST_T* dstptr, int row, int col, void* scales_, + BTLA_DTYPE sdtype, int8_t* zero_points, int k_offset, int n_offset, + int blocksize, int ldzp, int8_t* tmp, size_t tmpsize) { + assert(tmpsize >= PackRow * NTILE); + assert(NTILE == col); + const auto DstSize = row * NTILE * sizeof(DST_T); + const auto S8Size = row * NTILE * sizeof(int8_t); + auto tmps8ptr = (int8_t*)dstptr; + tmps8ptr += DstSize - S8Size; + decompress_kblock_s2_s8(b2ptr, zero_points, tmps8ptr, blocksize, ldzp, n_offset, k_offset, row, col, + tmp, tmpsize); + decompress_kblock_s8_fp(tmps8ptr, dstptr, row, col, scales_, sdtype, nullptr, k_offset, n_offset, + blocksize, ldzp, tmp, tmpsize); return BTLA_CODE::Success; } template -inline BTLA_CODE decompress_dq_kblock_s4_fp(utils::int4x2* srcptr, _DST_T* dstptr, int row, int col, int ld_src, - int ld_dst, uint8_t* scales, float* dq_scale, int k_offset, int n_offset, - int kblock, int dq_blk, int dq_offset_idx, int NPad, int N, void* tmp, - size_t tmpsize) { +static inline BTLA_CODE decompress_dq_kblock_s4_fp(utils::int4x2* srcptr, _DST_T* dstptr, int row, int col, int ld_src, + int ld_dst, uint8_t* scales, float* dq_scale, int k_offset, + int n_offset, int kblock, int dq_blk, int dq_offset_idx, int NPad, + int N, void* tmp, size_t tmpsize) { auto sptr_base = scales + n_offset; for (int i = 0; i < row; i++) { int kpos = (k_offset + i) / kblock; @@ -445,8 +676,8 @@ inline BTLA_CODE decompress_dq_kblock_s4_fp(utils::int4x2* srcptr, _DST_T* dstpt auto dq_s1_idx = (n_offset + kpos * N + s1_idx) / dq_blk; scale0 = dq8_bnb_LUT[sptr[s0_idx]] * dq_scale[dq_s0_idx] + dq_scale[dq_offset_idx]; scale1 = dq8_bnb_LUT[sptr[s1_idx]] * dq_scale[dq_s1_idx] + dq_scale[dq_offset_idx]; - dst0 = static_cast(get_s8(tmp.x)) * scale0; - dst1 = static_cast(get_s8(tmp.y)) * scale1; + dst0 = static_cast(tmp.x - 8) * scale0; + dst1 = static_cast(tmp.y - 8) * scale1; dstptr[i * ld_dst + j + 0] = static_cast<_DST_T>(dst0); dstptr[i * ld_dst + j + 1] = static_cast<_DST_T>(dst1); } @@ -454,31 +685,7 @@ inline BTLA_CODE decompress_dq_kblock_s4_fp(utils::int4x2* srcptr, _DST_T* dstpt return BTLA_CODE::Success; } -template -inline BTLA_CODE decompress_kblock_s4_s8fp(utils::int4x2* srcptr, _DST_T* dstptr, int row, int col, int ld_src, - int ld_dst, int8_t* tmp, size_t tmpsize) { - for (int i = 0; i < row; i++) { - for (int j = 0; j < col; j += 2) { - auto tmp = srcptr[i * ld_src / 2 + j / 2]; - dstptr[i * ld_dst + j + 0] = static_cast<_DST_T>(static_cast(get_s8(tmp.x))); - dstptr[i * ld_dst + j + 1] = static_cast<_DST_T>(static_cast(get_s8(tmp.y))); - } - } - return BTLA_CODE::Success; -} - -template -inline BTLA_CODE decompress_kblock_s8_s8fp(int8_t* srcptr, DST_T* dstptr, int row, int col, int ld_src, int ld_dst) { - for (int i = 0; i < row; i++) { - for (int j = 0; j < 
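// Editor's note (illustrative sketch, not part of this patch): the s2/s3/s4 -> float paths
// above stage the int8 weights in the *tail* of the destination buffer and then dequantize
// front-to-back, so no separate scratch buffer is needed. This is safe because each float
// output (4 bytes) is written at or below the address its 1-byte int8 source occupies.
// Compact model with hypothetical names (scales and zero points folded into one `scale`):
#include <cstdint>
#include <cstring>
static inline void two_stage_dequant(const int8_t* q, float scale, float* dst, int n) {
  auto s8stage = reinterpret_cast<int8_t*>(dst) + n * sizeof(float) - n;  // tail of dst
  std::memcpy(s8stage, q, size_t(n));  // stands in for decompress_kblock_sX_s8
  for (int i = 0; i < n; i++) {        // stands in for decompress_kblock_s8_fp
    dst[i] = float(s8stage[i]) * scale;
  }
}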
col; j += 1) { - auto tmp = srcptr[i * ld_src + j]; - dstptr[i * ld_dst + j] = static_cast(static_cast(tmp)); - } - } - return BTLA_CODE::Success; -} - -inline float fp4_bnb_unpack(uint8_t val) { +static inline float fp4_bnb_unpack(uint8_t val) { float sign = (val & 0b1000) == 8 ? -1.0f : 1.0f; if ((val & 0b0100) == 4) // 0 if ((val & 0b0010) == 2) // 01 @@ -501,9 +708,9 @@ inline float fp4_bnb_unpack(uint8_t val) { return 0.00000000f * sign; // 1000 } -inline float fp4_bnb_dequantize(uint8_t val, float absmax) { return fp4_bnb_unpack(val) * absmax; } +static inline float fp4_bnb_dequantize(uint8_t val, float absmax) { return fp4_bnb_unpack(val) * absmax; } -inline int8_t fp4_bnb_quantize(float x) { +static inline int8_t fp4_bnb_quantize(float x) { int sign = x < 0 ? 0b1000 : 0b0000; x = fabsf(x); if (x > 0.29166667f) @@ -527,7 +734,7 @@ inline int8_t fp4_bnb_quantize(float x) { return static_cast(0b0000 + sign); } -inline int8_t fp4_e2m1_quantize(float x) { +static inline int8_t fp4_e2m1_quantize(float x) { // FP4 with bias of 1 // first bit is a sign // subnormals @@ -569,7 +776,7 @@ inline int8_t fp4_e2m1_quantize(float x) { } } -inline float fp4_e2m1_unpack(uint8_t val) { +static inline float fp4_e2m1_unpack(uint8_t val) { float sign = (val & 0b1000) == 8 ? -1.0f : 1.0f; if ((val & 0b0100) == 4) // 0 if ((val & 0b0010) == 2) // 01 @@ -592,9 +799,9 @@ inline float fp4_e2m1_unpack(uint8_t val) { return 0.00000000f * sign; // 1000 } -inline float fp4_e2m1_dequantize(uint8_t val, float absmax) { return fp4_e2m1_unpack(val) * absmax; } +static inline float fp4_e2m1_dequantize(uint8_t val, float absmax) { return fp4_e2m1_unpack(val) * absmax; } -inline float nf4_unpack(int8_t val) { +static inline float nf4_unpack(int8_t val) { if ((val & 0b1000) == 8) if ((val & 0b0100) == 4) // 1 if ((val & 0b0010) == 2) // 11 @@ -637,12 +844,12 @@ inline float nf4_unpack(int8_t val) { return 0.f; } -inline float nf4_dequantize(int8_t val, float absmax) { return nf4_unpack(val) * absmax; } +static inline float nf4_dequantize(int8_t val, float absmax) { return nf4_unpack(val) * absmax; } // Note: In the BNB Nf4 definition, 0 has a non-zero value after dequantization, but BTLA uses 0 for padding, which // leads to calculation errors. We ultimately choose to swap the binary bits of -1 and 0 in Nf4 to avoid this // conflict. 
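// Editor's note (illustrative sketch, not part of this patch): the fp4/nf4 quantizers in this
// file are effectively unrolled nearest-codeword searches over a 16-entry table, and
// dequantization is codebook[index] * absmax. Generic model; `codebook` is a placeholder,
// not the real NF4/FP4 table:
#include <cmath>
#include <cstddef>
#include <cstdint>
static inline int8_t quantize_codebook(float x, const float* codebook, size_t n) {
  size_t best = 0;
  float bestdist = std::fabs(x - codebook[0]);
  for (size_t i = 1; i < n; i++) {
    float d = std::fabs(x - codebook[i]);
    if (d < bestdist) {
      bestdist = d;
      best = i;
    }
  }
  return static_cast<int8_t>(best);
}
static inline float dequantize_codebook(int8_t idx, const float* codebook, float absmax) {
  return codebook[idx] * absmax;
}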
-inline int8_t nf4_quantize(float x) { +static inline int8_t nf4_quantize(float x) { if (x > 0.03979014977812767f) if (x > 0.3893125355243683f) // 1 if (x > 0.6427869200706482f) // 11 @@ -685,7 +892,7 @@ inline int8_t nf4_quantize(float x) { } template -inline float f4_unpack(int8_t v) { +static inline float f4_unpack(int8_t v) { static_assert(F4_T == BTLA_DTYPE::F4_BNB || F4_T == BTLA_DTYPE::F4_NF4 || F4_T == BTLA_DTYPE::F4_E2M1, "Unsupported F4 type"); switch (F4_T) { @@ -702,14 +909,14 @@ inline float f4_unpack(int8_t v) { } template -inline float f4_dequantize(int8_t v, float scale) { +static inline float f4_dequantize(int8_t v, float scale) { static_assert(F4_T == BTLA_DTYPE::F4_BNB || F4_T == BTLA_DTYPE::F4_NF4 || F4_T == BTLA_DTYPE::F4_E2M1, "Unsupported F4 type"); return f4_unpack(v) * scale; } template -inline int8_t f4_quantize(float x) { +static inline int8_t f4_quantize(float x) { static_assert(F4_T == BTLA_DTYPE::F4_BNB || F4_T == BTLA_DTYPE::F4_NF4 || F4_T == BTLA_DTYPE::F4_E2M1, "Unsupported F4 type"); switch (F4_T) { @@ -726,9 +933,9 @@ inline int8_t f4_quantize(float x) { } template -inline BTLA_CODE decompress_kblock_f4_fp(utils::f4x2* srcptr, _DST_T* dstptr, int row, int col, int ld_src, int ld_dst, - _S_T* scales, int k_offset, int kblock, int NPad, int8_t* tmp, - size_t tmpsize) { +static inline BTLA_CODE decompress_kblock_f4_fp(utils::f4x2* srcptr, _DST_T* dstptr, int row, int col, int ld_src, + int ld_dst, _S_T* scales, int k_offset, int kblock, int NPad, + int8_t* tmp, size_t tmpsize) { for (int i = 0; i < row; i++) { int kpos = (k_offset + i) / kblock; auto sptr = scales + kpos * NPad; @@ -750,10 +957,10 @@ inline BTLA_CODE decompress_kblock_f4_fp(utils::f4x2* srcptr, _DST_T* dstptr, in } template -inline BTLA_CODE decompress_dq_kblock_f4_fp(utils::f4x2* srcptr, _DST_T* dstptr, int row, int col, int ld_src, - int ld_dst, SCA_T* scales, float* dq_scale, int k_offset, int n_offset, - int kblock, int dq_blk, int dq_offset_idx, int NPad, int N, void* tmp, - size_t tmpsize) { +static inline BTLA_CODE decompress_dq_kblock_f4_fp(utils::f4x2* srcptr, _DST_T* dstptr, int row, int col, int ld_src, + int ld_dst, SCA_T* scales, float* dq_scale, int k_offset, + int n_offset, int kblock, int dq_blk, int dq_offset_idx, int NPad, + int N, void* tmp, size_t tmpsize) { auto sptr_base = scales + n_offset; for (int i = 0; i < row; i++) { int kpos = (k_offset + i) / kblock; @@ -778,8 +985,8 @@ inline BTLA_CODE decompress_dq_kblock_f4_fp(utils::f4x2* srcptr, _DST_T* dstptr, } template -inline BTLA_CODE decompress_kblock_f4_fp_noscale(utils::f4x2* srcptr, _DST_T* dstptr, int row, int col, int ld_src, - int ld_dst, int8_t* tmp, size_t tmpsize) { +static inline BTLA_CODE decompress_kblock_f4_fp_noscale(utils::f4x2* srcptr, _DST_T* dstptr, int row, int col, + int ld_src, int ld_dst, int8_t* tmp, size_t tmpsize) { for (int i = 0; i < row; i++) { for (int j = 0; j < col; j += 2) { auto tmp = srcptr[i * ld_src / 2 + j / 2]; @@ -791,8 +998,8 @@ inline BTLA_CODE decompress_kblock_f4_fp_noscale(utils::f4x2* srcptr, _DST_T* ds } template -inline BTLA_CODE decompress_kblock_f8_fp_noscale(utils::f8* srcptr, _DST_T* dstptr, int row, int col, int ld_src, - int ld_dst, BTLA_DTYPE src_f8_t) { +static inline BTLA_CODE decompress_kblock_f8_fp_noscale(utils::f8* srcptr, _DST_T* dstptr, int row, int col, int ld_src, + int ld_dst, BTLA_DTYPE src_f8_t) { for (int i = 0; i < row; i++) { for (int j = 0; j < col; j++) { dstptr[i * ld_dst + j] = f8_to_fp32(srcptr[i * ld_src + j], src_f8_t); @@ -838,7 +1045,7 @@ 
static inline BTLA_CODE memcpy2d(const _SRC_T* srcptr, _DST_T* dstptr, int row, return BTLA_CODE::Success; } -static float postop(float x, BTLA_ELTWISEOP op, void* const_elt_v) { +static inline float postop(float x, BTLA_ELTWISEOP op, void* const_elt_v) { if (op == BTLA_ELTWISEOP::GELU) { return 0.5f * x * (1.f + tanhf(0.7978845834732056f * (x + 0.044714998453855515f * x * x * x))); } @@ -878,8 +1085,9 @@ static inline BTLA_CODE get2d_e8m0_scale(const void* srcptr, void* dstptr, int r } template -inline BTLA_CODE quantize_f32_sign_int_rowblock(const float* srcptr, int8_t* dstptr, int row, int col, int ld_src, - int ld_dst, float* scales, int8_t* zero_points, int blocksize) { +static inline BTLA_CODE quantize_f32_sign_int_rowblock(const float* srcptr, int8_t* dstptr, int row, int col, + int ld_src, int ld_dst, float* scales, int8_t* zero_points, + int blocksize) { int raw_blocksize = blocksize; for (int i = 0; i < col; i++) { int align_row_loop = row / blocksize * blocksize; @@ -906,11 +1114,11 @@ inline BTLA_CODE quantize_f32_sign_int_rowblock(const float* srcptr, int8_t* dst float scale = (maxval - minval) / 255; float rscale = 1.f / scale; scales[j / raw_blocksize * ld_dst + i] = scale; - float fmedium = (maxval + minval) / 2; - int8_t bzp = utils::cast((0 - fmedium) * rscale); + int8_t bzp = utils::cast(utils::cast((0 - minval) * rscale) - 128); zero_points[j / raw_blocksize * ld_dst + i] = bzp; for (size_t ij = 0; ij < blocksize; ij++) { - dstptr[(j + ij) * ld_dst + i] = utils::cast((srcptr[(j + ij) * ld_src + i] - fmedium) * rscale); + dstptr[(j + ij) * ld_dst + i] = + utils::cast(utils::cast((srcptr[(j + ij) * ld_src + i]) * rscale) + bzp); } }; auto sNauto_calc_store_scale_and_quantv_sym = [&](int blocksize) { @@ -930,7 +1138,6 @@ inline BTLA_CODE quantize_f32_sign_int_rowblock(const float* srcptr, int8_t* dst if (abs(sum) >= absmax / FullValue) { NVal = sum > 0.f ? 
-FullValue : FullValue; } - NVal = NVal << (8 - NBits); float scale = absmax / NVal; float rscale = 1.f / scale; scales[j / raw_blocksize * ld_dst + i] = scale; @@ -939,22 +1146,42 @@ inline BTLA_CODE quantize_f32_sign_int_rowblock(const float* srcptr, int8_t* dst } }; + auto sNauto_calc_store_scale_and_quantv_asym = [&](int blocksize) { + auto constexpr NBits = utils::bestla_dtype_bits(QDT_T); + int constexpr FullValue = 1 << (NBits - 1); + float maxval = 0.f; + float minval = 0.f; + for (size_t ij = 0; ij < blocksize; ij++) { + maxval = std::max(maxval, srcptr[(j + ij) * ld_src + i]); + minval = std::min(minval, srcptr[(j + ij) * ld_src + i]); + } + float scale = (maxval - minval) / ((1 << NBits) - 1); + float rscale = 1.f / scale; + scales[j / raw_blocksize * ld_dst + i] = scale; + int bzp = utils::cast((0 - minval) * rscale) - FullValue; + auto clip = [&](int s) { + s = std::max(s, -FullValue); + s = std::min(s, FullValue - 1); + return s; + }; + bzp = clip(bzp); + zero_points[j / raw_blocksize * ld_dst + i] = static_cast(bzp); + for (size_t ij = 0; ij < blocksize; ij++) { + auto tmp = utils::cast((srcptr[(j + ij) * ld_src + i]) * rscale) + bzp; + tmp = clip(tmp); + dstptr[(j + ij) * ld_dst + i] = tmp; + } + }; auto dispatch_calc = [&](int blocksize) { switch (QDT_T) { case BTLA_DTYPE::S8: - if (zero_points == nullptr) { - s8_calc_store_scale_and_quantv_sym(blocksize); - } else { - s8_calc_store_scale_and_quantv_asym(blocksize); - } - break; case BTLA_DTYPE::S2_CLIP: case BTLA_DTYPE::S3_CLIP: case BTLA_DTYPE::S4_CLIP: if (zero_points == nullptr) { sNauto_calc_store_scale_and_quantv_sym(blocksize); } else { - s8_calc_store_scale_and_quantv_asym(blocksize); + sNauto_calc_store_scale_and_quantv_asym(blocksize); } break; default: @@ -970,7 +1197,7 @@ inline BTLA_CODE quantize_f32_sign_int_rowblock(const float* srcptr, int8_t* dst } template -int8_t f8_mx_quantize(float v, float scale, BTLA_DTYPE scale_dtype) { +static inline int8_t f8_mx_quantize(float v, float scale, BTLA_DTYPE scale_dtype) { if (scale_dtype == BTLA_DTYPE::F8_E8M0) { v /= std::pow(2, scale); } else { @@ -1012,8 +1239,9 @@ int8_t f8_mx_quantize(float v, float scale, BTLA_DTYPE scale_dtype) { } template -inline BTLA_CODE quantize_f32_f8_rowblock_mxscale(const float* srcptr, int8_t* dstptr, int row, int col, int ld_src, - int ld_dst, float* scales, int blocksize, BTLA_DTYPE scale_dtype) { +static inline BTLA_CODE quantize_f32_f8_rowblock_mxscale(const float* srcptr, int8_t* dstptr, int row, int col, + int ld_src, int ld_dst, float* scales, int blocksize, + BTLA_DTYPE scale_dtype) { for (int i = 0; i < col; i++) { int align_row_loop = row / blocksize * blocksize; int j = 0; @@ -1049,8 +1277,8 @@ inline BTLA_CODE quantize_f32_f8_rowblock_mxscale(const float* srcptr, int8_t* d } template -inline BTLA_CODE quantize_f32_f4_rowblock(const float* srcptr, int8_t* dstptr, int row, int col, int ld_src, int ld_dst, - float* scales, int8_t* zero_points, int blocksize) { +static inline BTLA_CODE quantize_f32_f4_rowblock(const float* srcptr, int8_t* dstptr, int row, int col, int ld_src, + int ld_dst, float* scales, int8_t* zero_points, int blocksize) { int raw_blocksize = blocksize; for (int i = 0; i < col; i++) { int align_row_loop = row / blocksize * blocksize; @@ -1094,8 +1322,9 @@ inline BTLA_CODE quantize_f32_f4_rowblock(const float* srcptr, int8_t* dstptr, i } template -inline BTLA_CODE quantize_fp_u8_colblock(int row, int col, const SRC_T* srcptr, int ld_src, uint8_t* dstptr, int ld_dst, - float* scales, int ld_scale, uint8_t* zps, 
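// Editor's note (illustrative sketch, not part of this patch): the new asymmetric path
// (sNauto_calc_store_scale_and_quantv_asym) quantizes a row block to a signed N-bit range:
//   scale = (max - min) / (2^N - 1), zp = round(-min / scale) - 2^(N-1),
//   q = clip(round(x / scale) + zp), so dequantization is (q - zp) * scale.
// Scalar model with hypothetical names:
#include <algorithm>
#include <cmath>
#include <cstdint>
struct BlockQuantResult {
  float scale;
  int8_t zp;
};
template <int NBits>
static BlockQuantResult quantize_block_asym(const float* x, int8_t* q, int n) {
  float maxv = 0.f, minv = 0.f;  // the range always includes 0, matching the kernel above
  for (int i = 0; i < n; i++) {
    maxv = std::max(maxv, x[i]);
    minv = std::min(minv, x[i]);
  }
  constexpr int kFull = 1 << (NBits - 1);
  float scale = (maxv - minv) / float((1 << NBits) - 1);
  float rscale = scale != 0.f ? 1.f / scale : 0.f;
  auto clip = [&](int v) { return std::min(std::max(v, -kFull), kFull - 1); };
  int zp = clip(int(std::roundf((0.f - minv) * rscale)) - kFull);
  for (int i = 0; i < n; i++) {
    q[i] = static_cast<int8_t>(clip(int(std::roundf(x[i] * rscale)) + zp));
  }
  return {scale, static_cast<int8_t>(zp)};
}
// e.g. quantize_block_asym<4>(src, dst, blocksize) mirrors the S4_CLIP asymmetric case.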
int blocksize, float* blkreduce) { +static inline BTLA_CODE quantize_fp_u8_colblock(int row, int col, const SRC_T* srcptr, int ld_src, uint8_t* dstptr, + int ld_dst, float* scales, int ld_scale, uint8_t* zps, int blocksize, + float* blkreduce) { int colblk = utils::padto_le(col, blocksize); for (int i = 0; i < row; i++) { size_t j = 0; @@ -1154,8 +1383,8 @@ inline BTLA_CODE quantize_fp_u8_colblock(int row, int col, const SRC_T* srcptr, } template -inline BTLA_CODE quantize_fp_s8_colblock(int row, int col, const SRC_T* srcptr, int ld_src, int8_t* dstptr, int ld_dst, - float* scales, int ld_scale, int blocksize, float* reduce) { +static inline BTLA_CODE quantize_fp_s8_colblock(int row, int col, const SRC_T* srcptr, int ld_src, int8_t* dstptr, + int ld_dst, float* scales, int ld_scale, int blocksize, float* reduce) { int colblk = utils::padto_le(col, blocksize); for (int i = 0; i < row; i++) { size_t j = 0; @@ -1198,7 +1427,7 @@ inline BTLA_CODE quantize_fp_s8_colblock(int row, int col, const SRC_T* srcptr, return BTLA_CODE::Success; } -inline uint8_t get_dq8_bnb(float v) { +static inline uint8_t get_dq8_bnb(float v) { int left = 0; int right = 255; while (left <= right) { @@ -1220,7 +1449,7 @@ inline uint8_t get_dq8_bnb(float v) { } } template -inline BTLA_CODE dq8_bnb_double_quant(float* scale, size_t scale_size, int dq_blocksize, float* dq_buf) { +static inline BTLA_CODE dq8_bnb_double_quant(float* scale, size_t scale_size, int dq_blocksize, float* dq_buf) { float offset = 0.f; for (int i = 0; i < scale_size; i++) offset += scale[i]; offset /= scale_size; @@ -1249,9 +1478,9 @@ inline BTLA_CODE dq8_bnb_double_quant(float* scale, size_t scale_size, int dq_bl return BTLA_CODE::Success; } -inline BTLA_CODE dq8_get_fp_scale(uint8_t* src, float* dst, int row, int col, int scale_offset, int dq_blk, - int dq_offset_idx, float* dq_scale, int src_stride, int dst_stride, bool zeropadding, - int mN) { +static inline BTLA_CODE dq8_get_fp_scale(uint8_t* src, float* dst, int row, int col, int scale_offset, int dq_blk, + int dq_offset_idx, float* dq_scale, int src_stride, int dst_stride, + bool zeropadding, int mN) { for (int i = 0; i < row; i++) { for (int j = 0; j < col; j++) { auto dq_s_idx = (i * mN + scale_offset + j) / dq_blk; @@ -1333,8 +1562,8 @@ static inline BTLA_CODE dequant_s32_fp32(const int32_t* srcptr, const int srcste return BTLA_CODE::Success; } -inline BTLA_CODE minmax_f32_kblock(const float* srcptr, int row, int col, int ld_src, float* minmaxptr, int ld_minmax, - int fsize_minmax, int blocksize) { +static inline BTLA_CODE minmax_f32_kblock(const float* srcptr, int row, int col, int ld_src, float* minmaxptr, + int ld_minmax, int fsize_minmax, int blocksize) { for (int i = 0; i < row; i++) { if (col >= blocksize) { for (int icol = 0; icol < col; icol += blocksize) { @@ -1533,8 +1762,9 @@ inline float exp_ps_0_1(float x) { } template -inline BTLA_CODE decompress_kblock_s3_s8fp(utils::bit2x4* bit2ptr, utils::bit1x8* bit1ptr, _DST_T* dstptr, - int interleave_n_offset, int unpack_elt, int8_t* tmp, size_t tmpsize) { +static inline BTLA_CODE decompress_kblock_s3_s8fp(utils::bit2x4* bit2ptr, utils::bit1x8* bit1ptr, _DST_T* dstptr, + int interleave_n_offset, int unpack_elt, int8_t* tmp, + size_t tmpsize) { auto head_ignore_num = interleave_n_offset % 128; auto bit3_interleave_decompress_pack128 = [&](utils::bit2x4* src1, utils::bit1x8* src2, int8_t* dst) { auto b2ptr = reinterpret_cast(src1); @@ -1546,9 +1776,9 @@ inline BTLA_CODE decompress_kblock_s3_s8fp(utils::bit2x4* bit2ptr, utils::bit1x8 
for (size_t j = 0; j < 8; j++) { uint8_t bit2 = *(b2ptr + byteoff + j); bit2 >>= bit2off; - uint8_t dst8 = ((bit2 & 0x3) << 5) | ((bit1 & 0x1) << 7); + uint8_t dst8 = ((bit2 & 0x3)) | ((bit1 & 0x1) << 2); bit1 >>= 1; - dst[i + j] = *(int8_t*)&dst8; + dst[i + j] = (*(int8_t*)&dst8) - 4; } } }; @@ -1606,14 +1836,14 @@ static inline BTLA_CODE decompress_kblock_bit3_packrow_fp(utils::bit2x4* bit2ptr } template -inline BTLA_CODE decompress_kblock_s2_s8fp(utils::bit2x4* bit2ptr, _DST_T* dstptr, int unpack_elt, int8_t* tmp, - size_t tmpsize) { +static inline BTLA_CODE decompress_kblock_s2_s8fp(utils::bit2x4* bit2ptr, _DST_T* dstptr, size_t unpack_elt, + int8_t* tmp, size_t tmpsize) { for (size_t i = 0; i < unpack_elt; i += 4) { auto tmp = bit2ptr[i / 4]; - dstptr[i + 0] = _DST_T(tmp.a << 6); - dstptr[i + 1] = _DST_T(tmp.b << 6); - dstptr[i + 2] = _DST_T(tmp.c << 6); - dstptr[i + 3] = _DST_T(tmp.d << 6); + dstptr[i + 0] = _DST_T(tmp.a - 2); + dstptr[i + 1] = _DST_T(tmp.b - 2); + dstptr[i + 2] = _DST_T(tmp.c - 2); + dstptr[i + 3] = _DST_T(tmp.d - 2); } return BTLA_CODE::Success; } @@ -1638,6 +1868,518 @@ static inline BTLA_CODE decompress_kblock_bit2_packrow_fp(utils::bit2x4* bit2ptr return BTLA_CODE::Success; } +template +static inline BTLA_CODE gemv_4bit_u8s8_fp32(const utils::GemvParamA& A, const utils::GemvParamB& B, float* C, + int ldc, int k, int blocksize, int8_t* tmp, size_t tmpsize) { + int blks = k / blocksize; + float accf[NTILE * MTILE]; + std::memset(accf, 0, sizeof(accf)); + auto a8ptr = A.aptr; + auto b4ptr = B.b4ptr; + auto asptr = A.sptr; + auto azptr = A.zpptr; + for (int ib = 0; ib < blks; ib += 1) { + auto bsptr = B.sptr + ib * B.ldzp; + if (B.zpptr) { + auto bzptr = B.zpptr + ib * B.ldzp; + for (int ik = 0; ik < blocksize; ik += 4) { + for (int im = 0; im < MTILE; im++) { + int azp = azptr[ib + im * A.ldzp]; + float ascale = asptr[ib + im * A.ldzp]; + for (int in = 0; in < NTILE; in++) { + auto bv0 = *(utils::int4x2*)(b4ptr + in * 2); + auto bv1 = *(utils::int4x2*)(b4ptr + in * 2 + 1); + auto vscale = ascale * bsptr[in]; + int bzp = bzptr[in] + 8; + accf[im * NTILE + in] += int(a8ptr[0 + im * A.lda] - azp) * (bv0.x - bzp) * vscale; + accf[im * NTILE + in] += int(a8ptr[1 + im * A.lda] - azp) * (bv0.y - bzp) * vscale; + accf[im * NTILE + in] += int(a8ptr[2 + im * A.lda] - azp) * (bv1.x - bzp) * vscale; + accf[im * NTILE + in] += int(a8ptr[3 + im * A.lda] - azp) * (bv1.y - bzp) * vscale; + } + } + a8ptr += 4; + b4ptr += NTILE * 2; + } + } else { + for (int ik = 0; ik < blocksize; ik += 4) { + for (int im = 0; im < MTILE; im++) { + int azp = azptr[ib + im * A.ldzp]; + float ascale = asptr[ib + im * A.ldzp]; + for (int in = 0; in < NTILE; in++) { + auto bv0 = *(utils::int4x2*)(b4ptr + in * 2); + auto bv1 = *(utils::int4x2*)(b4ptr + in * 2 + 1); + auto vscale = ascale * bsptr[in]; + accf[im * NTILE + in] += int(a8ptr[0 + im * A.lda] - azp) * (bv0.x - 8) * vscale; + accf[im * NTILE + in] += int(a8ptr[1 + im * A.lda] - azp) * (bv0.y - 8) * vscale; + accf[im * NTILE + in] += int(a8ptr[2 + im * A.lda] - azp) * (bv1.x - 8) * vscale; + accf[im * NTILE + in] += int(a8ptr[3 + im * A.lda] - azp) * (bv1.y - 8) * vscale; + } + } + a8ptr += 4; + b4ptr += NTILE * 2; + } + } + } + for (int im = 0; im < MTILE; im++) { + for (int in = 0; in < NTILE; in++) { + C[in + im * ldc] = accf[im * NTILE + in]; + } + } + return BTLA_CODE::Success; +} + +template +static inline BTLA_CODE gemv_4bit_s8s8_fp32(const utils::GemvParamA& A, const utils::GemvParamB& B, float* C, + int ldc, int k, int blocksize, 
int8_t* tmp, size_t tmpsize) { + int blks = k / blocksize; + float accf[NTILE * MTILE]; + std::memset(accf, 0, sizeof(accf)); + auto a8ptr = reinterpret_cast(A.aptr); + auto b4ptr = B.b4ptr; + auto asptr = A.sptr; + auto azptr = A.zpptr; + for (int ib = 0; ib < blks; ib += 1) { + auto bsptr = B.sptr + ib * B.ldzp; + if (B.zpptr) { + auto bzptr = B.zpptr + ib * B.ldzp; + for (int ik = 0; ik < blocksize; ik += 4) { + for (int im = 0; im < MTILE; im++) { + float ascale = asptr[ib + im * A.ldzp]; + for (int in = 0; in < NTILE; in++) { + auto bv0 = *(utils::int4x2*)(b4ptr + in * 2); + auto bv1 = *(utils::int4x2*)(b4ptr + in * 2 + 1); + auto vscale = ascale * bsptr[in]; + auto bzp = bzptr[in] + 8; + accf[im * NTILE + in] += int(a8ptr[0 + im * A.lda]) * (bv0.x - bzp) * vscale; + accf[im * NTILE + in] += int(a8ptr[1 + im * A.lda]) * (bv0.y - bzp) * vscale; + accf[im * NTILE + in] += int(a8ptr[2 + im * A.lda]) * (bv1.x - bzp) * vscale; + accf[im * NTILE + in] += int(a8ptr[3 + im * A.lda]) * (bv1.y - bzp) * vscale; + } + } + a8ptr += 4; + b4ptr += NTILE * 2; + } + } else { + for (int ik = 0; ik < blocksize; ik += 4) { + for (int im = 0; im < MTILE; im++) { + float ascale = asptr[ib + im * A.ldzp]; + for (int in = 0; in < NTILE; in++) { + auto bv0 = *(utils::int4x2*)(b4ptr + in * 2); + auto bv1 = *(utils::int4x2*)(b4ptr + in * 2 + 1); + auto vscale = ascale * bsptr[in]; + accf[im * NTILE + in] += int(a8ptr[0 + im * A.lda]) * (bv0.x - 8) * vscale; + accf[im * NTILE + in] += int(a8ptr[1 + im * A.lda]) * (bv0.y - 8) * vscale; + accf[im * NTILE + in] += int(a8ptr[2 + im * A.lda]) * (bv1.x - 8) * vscale; + accf[im * NTILE + in] += int(a8ptr[3 + im * A.lda]) * (bv1.y - 8) * vscale; + } + } + a8ptr += 4; + b4ptr += NTILE * 2; + } + } + } + for (int im = 0; im < MTILE; im++) { + for (int in = 0; in < NTILE; in++) { + C[in + im * ldc] = accf[im * NTILE + in]; + } + } + return BTLA_CODE::Success; +} + +template +static inline BTLA_CODE gemv_4bit_fp32_fp32(const float* A, int lda, const utils::GemvParamB& B, float* C, + int ldc, int k, int blocksize, int8_t* tmp, size_t tmpsize) { + int blks = k / blocksize; + float accf[NTILE * MTILE]; + std::memset(accf, 0, sizeof(accf)); + auto b4ptr = B.b4ptr; + for (int ib = 0; ib < blks; ib += 1) { + auto bsptr = B.sptr + ib * B.ldzp; + if (B.zpptr) { + auto bzptr = B.zpptr + ib * B.ldzp; + for (int ik = 0; ik < blocksize; ik += 1) { + for (int im = 0; im < MTILE; im++) { + auto aval = A[ib * blocksize + ik + im * lda]; + for (int in = 0; in < NTILE; in += 2) { + auto bv0 = *(utils::int4x2*)(b4ptr + in / 2); + accf[im * NTILE + in + 0] += aval * (bv0.x - 8 - bzptr[in + 0]) * bsptr[in + 0]; + accf[im * NTILE + in + 1] += aval * (bv0.y - 8 - bzptr[in + 1]) * bsptr[in + 1]; + } + } + b4ptr += NTILE / 2; + } + } else { + for (int ik = 0; ik < blocksize; ik += 1) { + for (int im = 0; im < MTILE; im++) { + auto aval = A[ib * blocksize + ik + im * lda]; + for (int in = 0; in < NTILE; in += 2) { + auto bv0 = *(utils::int4x2*)(b4ptr + in / 2); + accf[im * NTILE + in + 0] += aval * (bv0.x - 8) * bsptr[in + 0]; + accf[im * NTILE + in + 1] += aval * (bv0.y - 8) * bsptr[in + 1]; + } + } + b4ptr += NTILE / 2; + } + } + } + for (int im = 0; im < MTILE; im++) { + for (int in = 0; in < NTILE; in++) { + C[in + im * ldc] = accf[im * NTILE + in]; + } + } + return BTLA_CODE::Success; +} + +template +static inline BTLA_CODE gemv_3bit_u8s8_fp32_align128(const utils::GemvParamA& A, const utils::GemvParamB& B, + float* C, int k, int ld_scaleb, int blocksize, int8_t* tmp, + size_t tmpsize) { + int 
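// Editor's note (illustrative sketch, not part of this patch): the reference gemv_4bit kernels
// above all compute, per k-block,
//   C[m][n] += scaleA[m] * scaleB[n] * sum_k (a[m][k] - zpA[m]) * (w4[k][n] - 8 - zpB[n]),
// with zpA/zpB dropped on the symmetric paths. Single-output plain-float reference with
// hypothetical names:
#include <cstdint>
static inline float gemv4_one_output(const uint8_t* a, const uint8_t* w4 /* 0..15 */, int k,
                                     int blocksize, const float* sA, const uint8_t* zA,
                                     const float* sB, const int8_t* zB) {
  float acc = 0.f;
  for (int ib = 0; ib < k / blocksize; ib++) {
    int32_t isum = 0;
    for (int ik = 0; ik < blocksize; ik++) {
      int kk = ib * blocksize + ik;
      isum += (int(a[kk]) - int(zA[ib])) * (int(w4[kk]) - 8 - int(zB[ib]));
    }
    acc += float(isum) * sA[ib] * sB[ib];
  }
  return acc;
}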
blks = k / blocksize; + float accf[NTILE]; + std::memset(accf, 0, sizeof(accf)); + auto a8ptr = A.aptr; + auto b2ptr = reinterpret_cast(B.b2ptr); + auto b1ptr = reinterpret_cast(B.b1ptr); + auto asptr = A.sptr; + auto azptr = A.zpptr; + int constexpr EltPadding = 128; + static_assert(NTILE % 8 == 0); + int constexpr KTILE = 4; + int constexpr UnpackElt = EltPadding / 8 / KTILE; + int8_t UnpackBuf[UnpackElt * NTILE * KTILE]; + for (int ib = 0; ib < blks; ib += 1) { + auto bsptr = B.sptr + ib * ld_scaleb; + int acci[NTILE]; + std::memset(acci, 0, sizeof(acci)); + int wacc[NTILE]; + std::memset(wacc, 0, sizeof(wacc)); + for (int ik = 0; ik < blocksize; ik += KTILE * UnpackElt) { + decompress_kblock_s3_s8fp(b2ptr, b1ptr, UnpackBuf, ik * NTILE, + NTILE * KTILE * UnpackElt, tmp, tmpsize); + for (int iu = 0; iu < UnpackElt; iu++) { + for (int in = 0; in < NTILE; in++) { + for (int ikt = 0; ikt < KTILE; ikt++) { + auto bval = UnpackBuf[iu * NTILE * KTILE + in * KTILE + ikt]; + acci[in] += int(a8ptr[iu * KTILE + ikt]) * bval; + wacc[in] += bval; + } + } + } + + b2ptr += KTILE * UnpackElt * NTILE / 4; + b1ptr += KTILE * UnpackElt * NTILE / 8; + a8ptr += KTILE * UnpackElt; + } + float scale = asptr[ib]; + int zp = azptr[ib]; + for (int in = 0; in < NTILE; in++) { + auto tmp = float(acci[in] - zp * wacc[in]); + tmp = tmp * (scale * bsptr[in]); + accf[in] += tmp; + } + } + for (int in = 0; in < NTILE; in++) { + C[in] = accf[in]; + } + return BTLA_CODE::Success; +} + +template +static inline BTLA_CODE gemv_3bit_s8s8_fp32_align128(const utils::GemvParamA& A, const utils::GemvParamB& B, + float* C, int k, int ld_scaleb, int blocksize, int8_t* tmp, + size_t tmpsize) { + int blks = k / blocksize; + float accf[NTILE]; + std::memset(accf, 0, sizeof(accf)); + auto a8ptr = reinterpret_cast(A.aptr); + auto b2ptr = reinterpret_cast(B.b2ptr); + auto b1ptr = reinterpret_cast(B.b1ptr); + auto asptr = A.sptr; + int constexpr EltPadding = 128; + static_assert(NTILE % 8 == 0); + int constexpr KTILE = 4; + int constexpr UnpackElt = EltPadding / 8 / KTILE; + int8_t UnpackBuf[UnpackElt * NTILE * KTILE]; + for (int ib = 0; ib < blks; ib += 1) { + auto bsptr = B.sptr + ib * ld_scaleb; + int acci[NTILE]; + std::memset(acci, 0, sizeof(acci)); + for (int ik = 0; ik < blocksize; ik += KTILE * UnpackElt) { + decompress_kblock_s3_s8fp(b2ptr, b1ptr, UnpackBuf, ik * NTILE, + NTILE * KTILE * UnpackElt, tmp, tmpsize); + for (int iu = 0; iu < UnpackElt; iu++) { + for (int in = 0; in < NTILE; in++) { + for (int ikt = 0; ikt < KTILE; ikt++) { + auto bval = UnpackBuf[iu * NTILE * KTILE + in * KTILE + ikt]; + acci[in] += int(a8ptr[iu * KTILE + ikt]) * bval; + } + } + } + b2ptr += KTILE * UnpackElt * NTILE / 4; + b1ptr += KTILE * UnpackElt * NTILE / 8; + a8ptr += KTILE * UnpackElt; + } + + float scale = asptr[ib]; + for (int in = 0; in < NTILE; in++) { + auto tmp = float(acci[in]); + tmp = tmp * (scale * bsptr[in]); + accf[in] += tmp; + } + } + for (int in = 0; in < NTILE; in++) { + C[in] = accf[in]; + } + return BTLA_CODE::Success; +} + +template +static inline BTLA_CODE gemv_2bit_u8s8_fp32(const utils::GemvParamA& A, const utils::GemvParamB& B, float* C, + int ldc, int k, int blocksize, int8_t* tmp, size_t tmpsize) { + int blks = k / blocksize; + float accf[NTILE * MTILE]; + std::memset(accf, 0, sizeof(accf)); + auto a8ptr = A.aptr; + auto b2ptr = reinterpret_cast(B.b2ptr); + int constexpr KTILE = 4; + int8_t UnpackBuf[NTILE * KTILE]; + for (int ib = 0; ib < blks; ib += 1) { + auto bsptr = B.sptr + ib * B.ldzp; + auto bzptr = B.zpptr + 
ib * B.ldzp; + for (int ik = 0; ik < blocksize; ik += KTILE) { + decompress_kblock_s2_s8<4, NTILE>(b2ptr, B.zpptr ? bzptr : nullptr, UnpackBuf, blocksize, B.ldzp, 0, 0, KTILE, + NTILE, tmp, tmpsize); + for (int im = 0; im < MTILE; im++) { + float ascale = A.sptr[ib + im * A.ldzp]; + auto azp = A.zpptr[ib + im * A.ldzp]; + for (int in = 0; in < NTILE; in++) { + for (int ikt = 0; ikt < KTILE; ikt++) { + auto bval = (UnpackBuf[in * KTILE + ikt]) * bsptr[in]; + auto aval = int(a8ptr[ikt + im * A.lda] - azp) * ascale; + accf[im * NTILE + in] += aval * bval; + } + } + } + b2ptr += KTILE * NTILE / 4; + a8ptr += KTILE; + } + } + for (int im = 0; im < MTILE; im++) { + for (int in = 0; in < NTILE; in++) { + C[in + im * ldc] = accf[im * NTILE + in]; + } + } + return BTLA_CODE::Success; +} +template +static inline BTLA_CODE gemv_2bit_s8s8_fp32(const utils::GemvParamA& A, const utils::GemvParamB& B, float* C, + int ldc, int k, int blocksize, int8_t* tmp, size_t tmpsize) { + int blks = k / blocksize; + float accf[NTILE * MTILE]; + std::memset(accf, 0, sizeof(accf)); + auto a8ptr = (int8_t*)A.aptr; + auto b2ptr = reinterpret_cast(B.b2ptr); + int constexpr KTILE = 4; + int8_t UnpackBuf[NTILE * KTILE]; + for (int ib = 0; ib < blks; ib += 1) { + auto bsptr = B.sptr + ib * B.ldzp; + auto bzptr = B.zpptr + ib * B.ldzp; + for (int ik = 0; ik < blocksize; ik += KTILE) { + decompress_kblock_s2_s8<4, NTILE>(b2ptr, B.zpptr ? bzptr : nullptr, UnpackBuf, blocksize, B.ldzp, 0, 0, KTILE, + NTILE, tmp, tmpsize); + for (int im = 0; im < MTILE; im++) { + float ascale = A.sptr[ib + im * A.ldzp]; + for (int in = 0; in < NTILE; in++) { + for (int ikt = 0; ikt < KTILE; ikt++) { + auto bval = (UnpackBuf[in * KTILE + ikt]) * bsptr[in]; + auto aval = int(a8ptr[ikt + im * A.lda]) * ascale; + accf[im * NTILE + in] += aval * bval; + } + } + } + b2ptr += KTILE * NTILE / 4; + a8ptr += KTILE; + } + } + for (int im = 0; im < MTILE; im++) { + for (int in = 0; in < NTILE; in++) { + C[in + im * ldc] = accf[im * NTILE + in]; + } + } + return BTLA_CODE::Success; +} + +template +static inline BTLA_CODE gemv_2bit_fp32_fp32(const float* A, int lda, const utils::GemvParamB& B, float* C, + int ldc, int k, int blocksize, int8_t* tmp, size_t tmpsize) { + int blks = k / blocksize; + float accf[NTILE * MTILE]; + std::memset(accf, 0, sizeof(accf)); + auto b2ptr = reinterpret_cast(B.b2ptr); + int constexpr KTILE = 1; + int constexpr Unroll = 4; + assert((blocksize % 4) == 0); + assert(tmpsize >= NTILE * Unroll); + int8_t UnpackBuf[NTILE * Unroll]; + for (int ib = 0; ib < blks; ib += 1) { + auto bsptr = B.sptr + ib * B.ldzp; + auto bzptr = B.zpptr + ib * B.ldzp; + for (int ik = 0; ik < blocksize; ik += Unroll) { + decompress_kblock_s2_s8fp(b2ptr, UnpackBuf, NTILE * Unroll, tmp, tmpsize); + if (B.zpptr) { + for (int ikt = 0; ikt < Unroll; ikt++) { + for (int im = 0; im < MTILE; im++) { + for (int in = 0; in < NTILE; in++) { + auto bval = (UnpackBuf[in + ikt * NTILE] - bzptr[in]) * bsptr[in]; + auto aval = A[ikt + im * lda]; + accf[im * NTILE + in] += aval * bval; + } + } + } + } else { + for (int im = 0; im < MTILE; im++) { + for (int in = 0; in < NTILE; in++) { + for (int ikt = 0; ikt < Unroll; ikt++) { + auto bval = (UnpackBuf[in + ikt * NTILE]) * bsptr[in]; + auto aval = A[ikt + im * lda]; + accf[im * NTILE + in] += aval * bval; + } + } + } + } + b2ptr += Unroll * NTILE / 4; + A += Unroll; + } + } + for (int im = 0; im < MTILE; im++) { + for (int in = 0; in < NTILE; in++) { + C[in + im * ldc] = accf[im * NTILE + in]; + } + } + return 
BTLA_CODE::Success; +} + +template +static inline BTLA_CODE gemv_3bit_fp32_fp32(const float* A, int lda, const utils::GemvParamB& B, float* C, + int ldc, int k, int blocksize, int8_t* tmp, size_t tmpsize) { + int blks = k / blocksize; + float accf[NTILE * MTILE]; + std::memset(accf, 0, sizeof(accf)); + auto b2ptr = reinterpret_cast(B.b2ptr); + auto b1ptr = reinterpret_cast(B.b1ptr); + int constexpr KTILE = 1; + int constexpr Unroll = 4; + assert((blocksize % 4) == 0); + assert(tmpsize >= NTILE * Unroll); + int8_t UnpackBuf[NTILE * Unroll]; + for (int ib = 0; ib < blks; ib += 1) { + auto bsptr = B.sptr + ib * B.ldzp; + auto bzptr = B.zpptr + ib * B.ldzp; + for (int ik = 0; ik < blocksize; ik += Unroll) { + decompress_kblock_s3_s8<1, NTILE>(b2ptr, b1ptr, B.zpptr ? bzptr : nullptr, UnpackBuf, blocksize, B.ldzp, 0, 0, + Unroll, NTILE, tmp, tmpsize); + for (int im = 0; im < MTILE; im++) { + for (int in = 0; in < NTILE; in++) { + for (int ikt = 0; ikt < Unroll; ikt++) { + auto bval = (UnpackBuf[in + ikt * NTILE]) * bsptr[in]; + auto aval = A[ikt + im * lda]; + accf[im * NTILE + in] += aval * bval; + } + } + } + b2ptr += Unroll * NTILE / 4; + b1ptr += Unroll * NTILE / 8; + A += Unroll; + } + } + for (int im = 0; im < MTILE; im++) { + for (int in = 0; in < NTILE; in++) { + C[in + im * ldc] = accf[im * NTILE + in]; + } + } + return BTLA_CODE::Success; +} + +template +static inline BTLA_CODE gemv_3bit_u8s8_fp32(const utils::GemvParamA& A, const utils::GemvParamB& B, float* C, + int ldc, int k, int blocksize, int8_t* tmp, size_t tmpsize) { + int blks = k / blocksize; + float accf[NTILE * MTILE]; + std::memset(accf, 0, sizeof(accf)); + auto a8ptr = A.aptr; + auto b2ptr = reinterpret_cast(B.b2ptr); + auto b1ptr = reinterpret_cast(B.b1ptr); + int constexpr KTILE = 4; + int8_t UnpackBuf[NTILE * KTILE]; + for (int ib = 0; ib < blks; ib += 1) { + auto bsptr = B.sptr + ib * B.ldzp; + auto bzptr = B.zpptr + ib * B.ldzp; + for (int ik = 0; ik < blocksize; ik += KTILE) { + decompress_kblock_s3_s8<4, NTILE>(b2ptr, b1ptr, B.zpptr ? bzptr : nullptr, UnpackBuf, blocksize, B.ldzp, 0, 0, + KTILE, NTILE, tmp, tmpsize); + for (int im = 0; im < MTILE; im++) { + float ascale = A.sptr[ib + im * A.ldzp]; + auto azp = A.zpptr[ib + im * A.ldzp]; + for (int in = 0; in < NTILE; in++) { + for (int ikt = 0; ikt < KTILE; ikt++) { + auto bval = (UnpackBuf[in * KTILE + ikt]) * bsptr[in]; + auto aval = int(a8ptr[ikt + im * A.lda] - azp) * ascale; + accf[im * NTILE + in] += aval * bval; + } + } + } + b2ptr += KTILE * NTILE / 4; + b1ptr += KTILE * NTILE / 8; + a8ptr += KTILE; + } + } + for (int im = 0; im < MTILE; im++) { + for (int in = 0; in < NTILE; in++) { + C[in + im * ldc] = accf[im * NTILE + in]; + } + } + return BTLA_CODE::Success; +} + +template +static inline BTLA_CODE gemv_3bit_s8s8_fp32(const utils::GemvParamA& A, const utils::GemvParamB& B, float* C, + int ldc, int k, int blocksize, int8_t* tmp, size_t tmpsize) { + int blks = k / blocksize; + float accf[NTILE * MTILE]; + std::memset(accf, 0, sizeof(accf)); + auto a8ptr = (int8_t*)A.aptr; + auto b2ptr = reinterpret_cast(B.b2ptr); + auto b1ptr = reinterpret_cast(B.b1ptr); + int constexpr KTILE = 4; + int8_t UnpackBuf[NTILE * KTILE]; + for (int ib = 0; ib < blks; ib += 1) { + auto bsptr = B.sptr + ib * B.ldzp; + auto bzptr = B.zpptr + ib * B.ldzp; + for (int ik = 0; ik < blocksize; ik += KTILE) { + decompress_kblock_s3_s8<4, NTILE>(b2ptr, b1ptr, B.zpptr ? 
bzptr : nullptr, UnpackBuf, blocksize, B.ldzp, 0, 0, + KTILE, NTILE, tmp, tmpsize); + for (int im = 0; im < MTILE; im++) { + float ascale = A.sptr[ib + im * A.ldzp]; + for (int in = 0; in < NTILE; in++) { + for (int ikt = 0; ikt < KTILE; ikt++) { + auto bval = (UnpackBuf[in * KTILE + ikt]) * bsptr[in]; + auto aval = int(a8ptr[ikt + im * A.lda]) * ascale; + accf[im * NTILE + in] += aval * bval; + } + } + } + b2ptr += KTILE * NTILE / 4; + b1ptr += KTILE * NTILE / 8; + a8ptr += KTILE; + } + } + for (int im = 0; im < MTILE; im++) { + for (int in = 0; in < NTILE; in++) { + C[in + im * ldc] = accf[im * NTILE + in]; + } + } + return BTLA_CODE::Success; +} } // namespace ref } // namespace kernel } // namespace bestla diff --git a/bestla/bestla/kernel_wrapper.h b/bestla/bestla/kernel_wrapper.h index f8751b3c9..beb9ba733 100644 --- a/bestla/bestla/kernel_wrapper.h +++ b/bestla/bestla/kernel_wrapper.h @@ -278,9 +278,9 @@ class CompressFp4 { class CompressBit3 { public: template - static inline BTLA_CODE forward(const int8_t* srcptr, bestla::utils::bit2x4* bit2ptr, utils::bit1x8* bit1ptr, int row, - int col, int ld_src, int ld_dst) { - return ref::compress_3bit(srcptr, bit2ptr, bit1ptr, row, col, ld_src, ld_dst); + static inline BTLA_CODE forward(const int8_t* srcptr, bestla::utils::bit2x4* bit2ptr, utils::bit1x8* bit1ptr, + size_t size) { + return ref::compress_3bit(srcptr, bit2ptr, bit1ptr, size); } }; @@ -306,12 +306,13 @@ class QuantizeSignIntRowBlock { template static inline BTLA_CODE forward(const float* srcptr, int8_t* dstptr, int row, int col, int ld_src, int ld_dst, float* scales, int8_t* zero_points, int blocksize) { -#if CompileAVX512F() - if constexpr (utils::isa_base::avx512f) { - return avx512f::quantize_f32_sign_int_rowblock(srcptr, dstptr, row, col, ld_src, ld_dst, scales, - zero_points, blocksize); - } -#endif + // TODO(Yu) simd version for quick quant + // #if CompileAVX512F() + // if constexpr (utils::isa_base::avx512f) { + // return avx512f::quantize_f32_sign_int_rowblock(srcptr, dstptr, row, col, ld_src, ld_dst, scales, + // zero_points, blocksize); + // } + // #endif return ref::quantize_f32_sign_int_rowblock(srcptr, dstptr, row, col, ld_src, ld_dst, scales, zero_points, blocksize); } @@ -409,149 +410,184 @@ class AccumulateDequantizeS32F32 { } }; -template // zero points always be int8_t, not compressed -class DecompressKBlockS4Fp { +template +class DecompressKBlockS4S8 { public: - template - static inline BTLA_CODE forward(utils::int4x2* srcptr, _DST_T* dstptr, int row, int col, int ld_src, int ld_dst, - _SCA_T* scales, int8_t* zero_points, int k_offset, int kblock, int NPad, void* tmp, - size_t tmpsize) { - BTLA_CODE ret = BTLA_CODE::NotSupport; + template + static inline BTLA_CODE forward(utils::int4x2* srcptr, int8_t* zpptr, int8_t* dstptr, int blocksize, int ldzp, + int n_offset, int k_offset, int row, int col, void* tmp, size_t tmpsize) { #if CompileAVX512F() if constexpr (utils::isa_base::avx512f) { - ret = avx512f::decompress_kblock_s4_fp( - srcptr, dstptr, row, col, ld_src, ld_dst, scales, zero_points, k_offset, kblock, NPad, - reinterpret_cast(tmp), tmpsize); - if (ret == BTLA_CODE::Success) return ret; + return avx512f::decompress_kblock_s4_s8(srcptr, zpptr, dstptr, blocksize, ldzp, n_offset, + k_offset, row, col, (int8_t*)tmp, tmpsize); } #endif #if CompileAVX2() - // AVX2 device only focus on fp32 data and layout if constexpr (utils::isa_base::avx2) { - ret = avx2::decompress_kblock_s4_fp(srcptr, dstptr, row, col, ld_src, ld_dst, - scales, zero_points, k_offset, 
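// CompressBit3::forward now packs a flat run of `size` elements instead of a row/col tile. The
// inverse of the unpack sketch above, under the same assumed LSB-first split into a 2-bit and a
// 1-bit plane (illustrative only; the library's compress_3bit may order bits differently).
#include <cstdint>
#include <cstring>

// `size` is assumed to be a multiple of 8 so both planes end on a byte boundary.
static void pack_s8_to_s3(const int8_t* src, uint8_t* plane2, uint8_t* plane1, int size) {
  std::memset(plane2, 0, size / 4);
  std::memset(plane1, 0, size / 8);
  for (int i = 0; i < size; ++i) {
    int code = src[i] + 4;  // [-4,3] -> [0,7]
    plane2[i / 4] |= (code & 0x3) << ((i % 4) * 2);
    plane1[i / 8] |= ((code >> 2) & 0x1) << (i % 8);
  }
}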
kblock, NPad, - reinterpret_cast(tmp), tmpsize); - if (ret == BTLA_CODE::Success) return ret; + return avx2::decompress_kblock_s4_s8(srcptr, zpptr, dstptr, blocksize, ldzp, n_offset, k_offset, + row, col, (int8_t*)tmp, tmpsize); } #endif - ret = ref::decompress_kblock_s4_fp(srcptr, dstptr, row, col, ld_src, ld_dst, - scales, zero_points, k_offset, kblock, NPad, - reinterpret_cast(tmp), tmpsize); - return ret; + return ref::decompress_kblock_s4_s8(srcptr, zpptr, dstptr, blocksize, ldzp, n_offset, k_offset, row, + col, (int8_t*)tmp, tmpsize); } }; -template // zero points always be int8_t, not compressed -class DecompressKBlockS3Fp { +template +class DecompressKBlockS3S8 { public: - template - static inline BTLA_CODE forward(utils::bit2x4* bit2ptr, utils::bit1x8* bit1ptr, _DST_T* dstptr, - int interleave_n_offset, int row, int col, _SCA_T* scales, int8_t* zero_points, - int k_offset, int kblock, int NPad, void* tmp, size_t tmpsize) { - BTLA_CODE ret = BTLA_CODE::NotSupport; + template + static inline BTLA_CODE forward(utils::bit2x4* b2ptr, utils::bit1x8* b1ptr, int8_t* zpptr, int8_t* dstptr, + int blocksize, int ldzp, int n_offset, int k_offset, int row, int col, void* tmp, + size_t tmpsize) { +#if CompileAVX2() + if constexpr (utils::isa_base::avx2) { + return avx2::decompress_kblock_s3_s8(b2ptr, b1ptr, zpptr, dstptr, blocksize, ldzp, n_offset, + k_offset, row, col, (int8_t*)tmp, tmpsize); + } +#endif + return ref::decompress_kblock_s3_s8(b2ptr, b1ptr, zpptr, dstptr, blocksize, ldzp, n_offset, + k_offset, row, col, (int8_t*)tmp, tmpsize); + } +}; + +template +class DecompressKBlockS2S8 { + public: + template + static inline BTLA_CODE forward(utils::bit2x4* b2ptr, int8_t* zpptr, int8_t* dstptr, int blocksize, int ldzp, + int n_offset, int k_offset, int row, int col, void* tmp, size_t tmpsize) { #if CompileAVX512F() if constexpr (utils::isa_base::avx512f) { - ret = avx512f::decompress_kblock_bit3_packrow_fp( - bit2ptr, bit1ptr, dstptr, interleave_n_offset, row, col, scales, zero_points, k_offset, kblock, NPad, tmp, - tmpsize); - assert(ret == BTLA_CODE::Success); - return ret; + return avx512f::decompress_kblock_s2_s8(b2ptr, zpptr, dstptr, blocksize, ldzp, n_offset, k_offset, + row, col, (int8_t*)tmp, tmpsize); } #endif #if CompileAVX2() if constexpr (utils::isa_base::avx2) { - ret = avx2::decompress_kblock_bit3_packrow_fp( - bit2ptr, bit1ptr, dstptr, interleave_n_offset, row, col, scales, zero_points, k_offset, kblock, NPad, tmp, - tmpsize); - assert(ret == BTLA_CODE::Success); - return ret; + return avx2::decompress_kblock_s2_s8(b2ptr, zpptr, dstptr, blocksize, ldzp, n_offset, k_offset, + row, col, (int8_t*)tmp, tmpsize); } #endif - ret = ref::decompress_kblock_bit3_packrow_fp( - bit2ptr, bit1ptr, dstptr, interleave_n_offset, row, col, scales, zero_points, k_offset, kblock, NPad, tmp, - tmpsize); - assert(ret == BTLA_CODE::Success); - return ret; + return ref::decompress_kblock_s2_s8(b2ptr, zpptr, dstptr, blocksize, ldzp, n_offset, k_offset, row, + col, (int8_t*)tmp, tmpsize); } }; -template // zero points always be int8_t, not compressed -class DecompressKBlockS2Fp { +template +class DecompressKBlockS8Fp { public: - template - static inline BTLA_CODE forward(utils::bit2x4* bit2ptr, _DST_T* dstptr, int row, int col, _SCA_T* scales, - int8_t* zero_points, int k_offset, int kblock, int NPad, void* tmp, size_t tmpsize) { + template + static inline BTLA_CODE forward(int8_t* srcptr, DstT* dstptr, int row, int col, void* scales, BTLA_DTYPE sdtype, + int8_t* zero_points, int k_offset, int 
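// Each DecompressKBlockS*S8 wrapper follows the same shape: pick the widest kernel that is both
// compiled in and allowed by the ISA_T template parameter, otherwise fall back to ref::. A
// stripped-down sketch of that dispatch with placeholder backends (the enum and function names
// here are illustrative, not the library's):
#include <cstdint>

enum class Isa { NoSIMD, AVX2, AVX512F };
enum class Code { Success, NotSupport };

static Code decompress_avx512(int8_t*, int) { return Code::Success; }  // stands in for avx512f::
static Code decompress_avx2(int8_t*, int) { return Code::Success; }    // stands in for avx2::
static Code decompress_ref(int8_t*, int) { return Code::Success; }     // stands in for ref::

template <Isa ISA_T>
static Code decompress_dispatch(int8_t* dst, int n) {
  // The CompileAVX512F()/CompileAVX2() build guards are omitted here for brevity.
  if constexpr (ISA_T >= Isa::AVX512F) {
    return decompress_avx512(dst, n);
  } else if constexpr (ISA_T >= Isa::AVX2) {
    return decompress_avx2(dst, n);
  } else {
    return decompress_ref(dst, n);
  }
}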
n_offset, int kblock, int NPad, void* tmp, + size_t tmpsize) { BTLA_CODE ret = BTLA_CODE::NotSupport; - ret = ref::decompress_kblock_bit2_packrow_fp( - bit2ptr, dstptr, row, col, scales, zero_points, k_offset, kblock, NPad, tmp, tmpsize); - assert(ret == BTLA_CODE::Success); +#if CompileAVX512F() + if constexpr (utils::isa_base::avx512f) { + ret = avx512f::decompress_kblock_s8_fp(srcptr, dstptr, row, col, scales, sdtype, + zero_points, k_offset, n_offset, kblock, NPad, + reinterpret_cast(tmp), tmpsize); + if (ret == BTLA_CODE::Success) return ret; + } +#endif +#if CompileAVX2() + if constexpr (utils::isa_base::avx2) { + ret = avx2::decompress_kblock_s8_fp(srcptr, dstptr, row, col, scales, sdtype, zero_points, + k_offset, n_offset, kblock, NPad, + reinterpret_cast(tmp), tmpsize); + if (ret == BTLA_CODE::Success) return ret; + } +#endif + ret = ref::decompress_kblock_s8_fp(srcptr, dstptr, row, col, scales, sdtype, zero_points, + k_offset, n_offset, kblock, NPad, + reinterpret_cast(tmp), tmpsize); return ret; } }; -template // zero points always be int8_t, not compressed -class DecompressKBlockS4S8Fp { +template +class DecompressKBlockS4Fp { public: - template - static inline BTLA_CODE forward(utils::int4x2* srcptr, _DST_T* dstptr, int row, int col, int ld_src, int ld_dst, - void* tmp, size_t tmpsize) { + template + static inline BTLA_CODE forward(utils::int4x2* srcptr, DstT* dstptr, int row, int col, void* scales, + BTLA_DTYPE sdtype, int8_t* zero_points, int k_offset, int n_offset, int kblock, + int NPad, void* tmp, size_t tmpsize) { BTLA_CODE ret = BTLA_CODE::NotSupport; #if CompileAVX512F() if constexpr (utils::isa_base::avx512f) { - return avx512f::decompress_kblock_s4_s8fp(srcptr, dstptr, row, col, ld_src, ld_dst, - reinterpret_cast(tmp), tmpsize); + return avx512f::decompress_kblock_s4_fp(srcptr, dstptr, row, col, scales, sdtype, + zero_points, k_offset, n_offset, kblock, NPad, + reinterpret_cast(tmp), tmpsize); } #endif #if CompileAVX2() if constexpr (utils::isa_base::avx2) { - return avx2::decompress_kblock_s4_s8fp(srcptr, dstptr, row, col, ld_src, ld_dst, - reinterpret_cast(tmp), tmpsize); + return avx2::decompress_kblock_s4_fp(srcptr, dstptr, row, col, scales, sdtype, zero_points, + k_offset, n_offset, kblock, NPad, + reinterpret_cast(tmp), tmpsize); } #endif - return ref::decompress_kblock_s4_s8fp(srcptr, dstptr, row, col, ld_src, ld_dst, - reinterpret_cast(tmp), tmpsize); + ret = ref::decompress_kblock_s4_fp(srcptr, dstptr, row, col, scales, sdtype, zero_points, + k_offset, n_offset, kblock, NPad, + reinterpret_cast(tmp), tmpsize); + return ret; } }; -template -class DecompressKBlockS3S8Fp { +template +class DecompressKBlockS3Fp { public: - template - static inline BTLA_CODE forward(utils::bit2x4* bit2ptr, utils::bit1x8* bit1ptr, _DST_T* dstptr, - int interleave_n_offset, int unpack_elt, void* tmp, size_t tmpsize) { + template + static inline BTLA_CODE forward(utils::bit2x4* b2ptr, utils::bit1x8* b1ptr, DstT* dstptr, int row, int col, + void* scales, BTLA_DTYPE sdtype, int8_t* zero_points, int k_offset, int n_offset, + int kblock, int NPad, void* tmp, size_t tmpsize) { BTLA_CODE ret = BTLA_CODE::NotSupport; #if CompileAVX512F() if constexpr (utils::isa_base::avx512f) { - ret = avx512f::decompress_kblock_s3_s8fp(bit2ptr, bit1ptr, dstptr, interleave_n_offset, unpack_elt, - reinterpret_cast(tmp), tmpsize); - assert(ret == BTLA_CODE::Success); - return ret; + return avx512f::decompress_kblock_s3_fp(b2ptr, b1ptr, dstptr, row, col, scales, sdtype, + zero_points, k_offset, n_offset, 
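// DecompressKBlockS8Fp/S4Fp now receive scales as a type-erased pointer plus a BTLA_DTYPE tag
// rather than a templated scale type. A sketch of how such a tag could be consumed when applying
// scales, assuming only f32 and bf16 scales; load_scale/dequant_row are illustrative names.
#include <cstdint>
#include <cstring>

enum class ScaleDtype { F32, BF16 };

static float bf16_to_f32(uint16_t v) {  // bf16 is the upper 16 bits of an IEEE-754 f32
  uint32_t u = static_cast<uint32_t>(v) << 16;
  float f;
  std::memcpy(&f, &u, sizeof(f));
  return f;
}

static float load_scale(const void* scales, ScaleDtype sdtype, int i) {
  if (sdtype == ScaleDtype::F32) return static_cast<const float*>(scales)[i];
  return bf16_to_f32(static_cast<const uint16_t*>(scales)[i]);
}

// Dequantize one row of n int8 values: dst[j] = (src[j] - zp) * scale[j].
static void dequant_row(const int8_t* src, float* dst, const void* scales, ScaleDtype sdtype,
                        int8_t zp, int n) {
  for (int j = 0; j < n; ++j) dst[j] = (src[j] - zp) * load_scale(scales, sdtype, j);
}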
kblock, NPad, + reinterpret_cast(tmp), tmpsize); } #endif #if CompileAVX2() if constexpr (utils::isa_base::avx2) { - ret = avx2::decompress_kblock_s3_s8fp(bit2ptr, bit1ptr, dstptr, interleave_n_offset, unpack_elt, - reinterpret_cast(tmp), tmpsize); - assert(ret == BTLA_CODE::Success); - return ret; + return avx2::decompress_kblock_s3_fp(b2ptr, b1ptr, dstptr, row, col, scales, sdtype, + zero_points, k_offset, n_offset, kblock, NPad, + reinterpret_cast(tmp), tmpsize); } #endif - ret = ref::decompress_kblock_s3_s8fp(bit2ptr, bit1ptr, dstptr, interleave_n_offset, unpack_elt, - reinterpret_cast(tmp), tmpsize); - assert(ret == BTLA_CODE::Success); + ret = ref::decompress_kblock_s3_fp(b2ptr, b1ptr, dstptr, row, col, scales, sdtype, + zero_points, k_offset, n_offset, kblock, NPad, + reinterpret_cast(tmp), tmpsize); return ret; } }; -template -class DecompressKBlockS2S8Fp { +template +class DecompressKBlockS2Fp { public: - template - static inline BTLA_CODE forward(utils::bit2x4* bit2ptr, _DST_T* dstptr, int unpack_elt, void* tmp, size_t tmpsize) { + template + static inline BTLA_CODE forward(utils::bit2x4* b2ptr, DstT* dstptr, int row, int col, void* scales, BTLA_DTYPE sdtype, + int8_t* zero_points, int k_offset, int n_offset, int kblock, int NPad, void* tmp, + size_t tmpsize) { BTLA_CODE ret = BTLA_CODE::NotSupport; - ret = ref::decompress_kblock_s2_s8fp(bit2ptr, dstptr, unpack_elt, reinterpret_cast(tmp), - tmpsize); - assert(ret == BTLA_CODE::Success); +#if CompileAVX512F() + if constexpr (utils::isa_base::avx512f) { + return avx512f::decompress_kblock_s2_fp(b2ptr, dstptr, row, col, scales, sdtype, + zero_points, k_offset, n_offset, kblock, NPad, + reinterpret_cast(tmp), tmpsize); + } +#endif +#if CompileAVX2() + if constexpr (utils::isa_base::avx2) { + return avx2::decompress_kblock_s2_fp(b2ptr, dstptr, row, col, scales, sdtype, zero_points, + k_offset, n_offset, kblock, NPad, + reinterpret_cast(tmp), tmpsize); + } +#endif + ret = ref::decompress_kblock_s2_fp(b2ptr, dstptr, row, col, scales, sdtype, zero_points, + k_offset, n_offset, kblock, NPad, + reinterpret_cast(tmp), tmpsize); return ret; } }; @@ -622,27 +658,6 @@ class DecompressKBlockF4FpNoscale { } }; -class DecompressKBlockS4S8 { - public: - template - static inline BTLA_CODE forward(utils::int4x2* srcptr, int8_t* dstptr, int row, int col, int ld_src, int ld_dst) { - if constexpr (utils::isa_base::avx512f && S4_T == BTLA_DTYPE::S4_CLIP) { - return jit::decompress_s4_s8(srcptr, dstptr, row, col, ld_src, ld_dst); - } -#if CompileAVX512F() - if constexpr (utils::isa_base::avx512f) { - return avx512f::decompress_s4_s8(srcptr, dstptr, row, col, ld_src, ld_dst); - } -#endif -#if CompileAVX2() - if constexpr (utils::isa_base::avx2) { - return avx2::decompress_s4_s8(srcptr, dstptr, row, col, ld_src, ld_dst); - } -#endif - return ref::decompress_s4_s8(srcptr, dstptr, row, col, ld_src, ld_dst); - } -}; - template class DecompressKBlockF8FP { public: @@ -666,50 +681,6 @@ class DecompressKBlockF8FP { } }; -template -class DecompressKBlockS8Fp { - public: - template - static inline BTLA_CODE forward(int8_t* srcptr, _DST_T* dstptr, int row, int col, int ld_src, int ld_dst, - SCA_T* scales, int8_t* zero_points, int k_offset, int kblock, int NPad, void* tmp, - size_t tmpsize) { - if constexpr (utils::isa_base::avx512f && std::is_same_v) { // TODO Scale type support - return jit::DequanKBlockS8Fp::forward_avx512f(srcptr, dstptr, row, col, ld_src, ld_dst, scales, - zero_points, k_offset, kblock, NPad); - } -#if CompileAVX2() - // PACK_ROW must be 1/4 
when using avx2 proB. - if constexpr (utils::isa_base::avx2 && std::is_same_v && - (PACK_ROW == 1 || PACK_ROW == 4)) { // TODO Scale type support - return avx2::dequant_kblock_s8_fp(srcptr, dstptr, row, col, ld_src, ld_dst, scales, zero_points, - k_offset, kblock, NPad); - } -#endif - return ref::decompress_kblock_s8_fp<_DST_T, PACK_ROW, SCA_T>(srcptr, dstptr, row, col, ld_src, ld_dst, scales, - zero_points, k_offset, kblock, NPad); - } -}; - -template -class DecompressKBlockS8S8Fp { - public: - template - static inline BTLA_CODE forward(int8_t* srcptr, _DST_T* dstptr, int row, int col, int ld_src, int ld_dst, void* tmp, - size_t tmpsize) { -#if CompileAVX512F() - if constexpr (utils::isa_base::avx512f) { // TODO Scale type support - return avx512f::decompress_kblock_s8_s8fp<_DST_T>(srcptr, dstptr, row, col, ld_src, ld_dst); - } -#endif -#if CompileAVX2() - if constexpr (utils::isa_base::avx2) { // TODO Scale type support - return avx2::decompress_kblock_s8_s8fp<_DST_T>(srcptr, dstptr, row, col, ld_src, ld_dst); - } -#endif - return ref::decompress_kblock_s8_s8fp<_DST_T>(srcptr, dstptr, row, col, ld_src, ld_dst); - } -}; - template class DecompressKBlockF8FpNoScale { public: @@ -951,6 +922,169 @@ class LayerNormalization { simplified); } }; + +class GEMVWoqNBits { + public: + template + static inline BTLA_CODE forward_u8s8_fp32(const utils::GemvParamA& A, const utils::GemvParamB& B, float* C, + int ldc, int k, int blocksize, void* tmp, size_t tmpsize) { + if (B.nbits == 4) { +#if CompileAVX512VNNI() + if (ISA_T >= BTLA_ISA::AVX512_VNNI) { + return avx512f::vnni::gemv_4bit_u8s8_fp32(A, B, C, ldc, k, blocksize, (int8_t*)tmp, + tmpsize); + } +#endif +#if CompileAVXVNNI() + if (ISA_T >= BTLA_ISA::AVX_VNNI) { + return avx2::vnni::gemv_4bit_u8s8_fp32(A, B, C, ldc, k, blocksize, (int8_t*)tmp, tmpsize); + } +#endif +#if CompileAVX2() + if (ISA_T >= BTLA_ISA::AVX2) { + return avx2::gemv_4bit_u8s8_fp32(A, B, C, ldc, k, blocksize, (int8_t*)tmp, tmpsize); + } +#endif + return ref::gemv_4bit_u8s8_fp32(A, B, C, ldc, k, blocksize, (int8_t*)tmp, tmpsize); + } + if (B.nbits == 3) { +#if CompileAVX512VNNI() + if (ISA_T >= BTLA_ISA::AVX512_VNNI) { + return avx512f::vnni::gemv_3bit_u8s8_fp32(A, B, C, ldc, k, blocksize, (int8_t*)tmp, + tmpsize); + } +#endif +#if CompileAVXVNNI() + if (ISA_T >= BTLA_ISA::AVX_VNNI) { + return avx2::vnni::gemv_3bit_u8s8_fp32(A, B, C, ldc, k, blocksize, (int8_t*)tmp, tmpsize); + } +#endif +#if CompileAVX2() + if (ISA_T >= BTLA_ISA::AVX2) { + return avx2::gemv_3bit_u8s8_fp32(A, B, C, ldc, k, blocksize, (int8_t*)tmp, tmpsize); + } +#endif + return ref::gemv_3bit_u8s8_fp32(A, B, C, ldc, k, blocksize, (int8_t*)tmp, tmpsize); + } + if (B.nbits == 2) { +#if CompileAVX512VNNI() + if (ISA_T >= BTLA_ISA::AVX512_VNNI) { + return avx512f::vnni::gemv_2bit_u8s8_fp32(A, B, C, ldc, k, blocksize, (int8_t*)tmp, + tmpsize); + } +#endif +#if CompileAVXVNNI() + if (ISA_T >= BTLA_ISA::AVX_VNNI) { + return avx2::vnni::gemv_2bit_u8s8_fp32(A, B, C, ldc, k, blocksize, (int8_t*)tmp, tmpsize); + } +#endif +#if CompileAVX2() + if (ISA_T >= BTLA_ISA::AVX2) { + return avx2::gemv_2bit_u8s8_fp32(A, B, C, ldc, k, blocksize, (int8_t*)tmp, tmpsize); + } +#endif + return ref::gemv_2bit_u8s8_fp32(A, B, C, ldc, k, blocksize, (int8_t*)tmp, tmpsize); + } + return BTLA_CODE::NotSupport; + } + + template + static inline BTLA_CODE forward_s8s8_fp32(const utils::GemvParamA& A, const utils::GemvParamB& B, float* C, + int ldc, int k, int blocksize, void* tmp, size_t tmpsize) { + if (B.nbits == 4) { +#if CompileAVX512VNNI() + 
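// GEMVWoqNBits layers a runtime dispatch on B.nbits over the compile-time ISA ladder and always
// bottoms out in the ref:: kernel, returning NotSupport only for bit widths it does not handle.
// A condensed sketch of that two-level structure (this one wires up 4- and 2-bit only); the
// types and backend functions below are placeholders, not the bestla API.
enum class IsaLevel { NoSIMD, AVX2, AVX_VNNI, AVX512_VNNI };
enum class RetCode { Success, NotSupport };

struct PackedB { int nbits; /* packed bit planes, scales, zero points, ... */ };

static RetCode gemv_4bit_vnni(const PackedB&) { return RetCode::Success; }  // placeholder
static RetCode gemv_4bit_ref(const PackedB&) { return RetCode::Success; }   // placeholder
static RetCode gemv_2bit_ref(const PackedB&) { return RetCode::Success; }   // placeholder

template <IsaLevel ISA_T>
static RetCode gemv_woq(const PackedB& B) {
  if (B.nbits == 4) {
    if constexpr (ISA_T >= IsaLevel::AVX_VNNI) {
      return gemv_4bit_vnni(B);
    } else {
      return gemv_4bit_ref(B);
    }
  }
  if (B.nbits == 2) return gemv_2bit_ref(B);
  return RetCode::NotSupport;  // callers can fall back to a dequant + GEMM path
}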
if (ISA_T >= BTLA_ISA::AVX512_VNNI) { + return avx512f::vnni::gemv_4bit_s8s8_fp32(A, B, C, ldc, k, blocksize, (int8_t*)tmp, + tmpsize); + } +#endif +#if CompileAVXVNNI() + if (ISA_T >= BTLA_ISA::AVX_VNNI) { + return avx2::vnni::gemv_4bit_s8s8_fp32(A, B, C, ldc, k, blocksize, (int8_t*)tmp, tmpsize); + } +#endif + return ref::gemv_4bit_s8s8_fp32(A, B, C, ldc, k, blocksize, (int8_t*)tmp, tmpsize); + } + if (B.nbits == 3) { +#if CompileAVX512VNNI() + if (ISA_T >= BTLA_ISA::AVX512_VNNI) { + return avx512f::vnni::gemv_3bit_s8s8_fp32(A, B, C, ldc, k, blocksize, (int8_t*)tmp, + tmpsize); + } +#endif +#if CompileAVXVNNI() + if (ISA_T >= BTLA_ISA::AVX_VNNI) { + return avx2::vnni::gemv_3bit_s8s8_fp32(A, B, C, ldc, k, blocksize, (int8_t*)tmp, tmpsize); + } +#endif + return ref::gemv_3bit_s8s8_fp32(A, B, C, ldc, k, blocksize, (int8_t*)tmp, tmpsize); + } + if (B.nbits == 2) { +#if CompileAVX512VNNI() + if (ISA_T >= BTLA_ISA::AVX512_VNNI) { + return avx512f::vnni::gemv_2bit_s8s8_fp32(A, B, C, ldc, k, blocksize, (int8_t*)tmp, + tmpsize); + } +#endif +#if CompileAVXVNNI() + if (ISA_T >= BTLA_ISA::AVX_VNNI) { + return avx2::vnni::gemv_2bit_s8s8_fp32(A, B, C, ldc, k, blocksize, (int8_t*)tmp, tmpsize); + } +#endif + return ref::gemv_2bit_s8s8_fp32(A, B, C, ldc, k, blocksize, (int8_t*)tmp, tmpsize); + } + return BTLA_CODE::NotSupport; + } + + template + static inline BTLA_CODE forward_fp32_fp32(const float* A, int lda, const utils::GemvParamB& B, float* C, + int ldc, int k, int blocksize, void* tmp, size_t tmpsize) { + if (B.nbits == 4) { +#if CompileAVX512F() + if (ISA_T >= BTLA_ISA::AVX512F) { + return avx512f::gemv_4bit_fp32_fp32(A, lda, B, C, ldc, k, blocksize, (int8_t*)tmp, + tmpsize); + } +#endif +#if CompileAVX2() + if (ISA_T >= BTLA_ISA::AVX2) { + return avx2::gemv_4bit_fp32_fp32(A, lda, B, C, ldc, k, blocksize, (int8_t*)tmp, tmpsize); + } +#endif + return ref::gemv_4bit_fp32_fp32(A, lda, B, C, ldc, k, blocksize, (int8_t*)tmp, tmpsize); + } + if (B.nbits == 3) { +#if CompileAVX512F() + if (ISA_T >= BTLA_ISA::AVX512F) { + return avx512f::gemv_3bit_fp32_fp32(A, lda, B, C, ldc, k, blocksize, (int8_t*)tmp, + tmpsize); + } +#endif +#if CompileAVX2() + if (ISA_T >= BTLA_ISA::AVX2) { + return avx2::gemv_3bit_fp32_fp32(A, lda, B, C, ldc, k, blocksize, (int8_t*)tmp, tmpsize); + } +#endif + return ref::gemv_3bit_fp32_fp32(A, lda, B, C, ldc, k, blocksize, (int8_t*)tmp, tmpsize); + } + if (B.nbits == 2) { +#if CompileAVX512F() + if (ISA_T >= BTLA_ISA::AVX512F) { + return avx512f::gemv_2bit_fp32_fp32(A, lda, B, C, ldc, k, blocksize, (int8_t*)tmp, + tmpsize); + } +#endif +#if CompileAVX2() + if (ISA_T >= BTLA_ISA::AVX2) { + return avx2::gemv_2bit_fp32_fp32(A, lda, B, C, ldc, k, blocksize, (int8_t*)tmp, tmpsize); + } +#endif + return ref::gemv_2bit_fp32_fp32(A, lda, B, C, ldc, k, blocksize, (int8_t*)tmp, tmpsize); + } + return BTLA_CODE::NotSupport; + } +}; + } // namespace wrapper } // namespace kernel } // namespace bestla diff --git a/bestla/bestla/ut/bestla_benchmark.cpp b/bestla/bestla/ut/bestla_benchmark.cpp index ae3f15027..e274c5cee 100644 --- a/bestla/bestla/ut/bestla_benchmark.cpp +++ b/bestla/bestla/ut/bestla_benchmark.cpp @@ -252,6 +252,12 @@ class Benchmark_S8S8S32 { GetCPUDevice(); auto threads_cfg = UT_Threading::get_threads_config(); for (auto threads : threads_cfg) { + if (_cd->AVX_VNNI()) { + benchmark, LOG>(m, n, k, batch, A.data(), B.data(), C.data(), testtime, + threads); + benchmark, LOG>(m, n, k, batch, A.data(), B.data(), C.data(), testtime, + threads); + } if (_cd->AMX_INT8()) { benchmark, 
LOG>(m, n, k, batch, A.data(), B.data(), C.data(), testtime, threads); @@ -435,11 +441,20 @@ class UTWOQ_CompFp32 { public: UTWOQ_CompFp32() { UT_START(); + ut_s2(); ut_s4(); - ut_s8(); - ut_f4(); + ut_s3(); + // ut_s8(); + // ut_f4(); + } + void ut_s2() { + benchmark_all(1, 4096, 4096, BTLA_DTYPE::S2_CLIP); + benchmark_all(1024, 4096, 4096, BTLA_DTYPE::S2_CLIP); + } + void ut_s3() { + benchmark_all(1, 4096, 4096, BTLA_DTYPE::S3_CLIP); + benchmark_all(1024, 4096, 4096, BTLA_DTYPE::S3_CLIP); } - void ut_s4() { benchmark_all(1, 4096, 4096, BTLA_DTYPE::S4_CLIP); benchmark_all(1024, 4096, 4096, BTLA_DTYPE::S4_CLIP); @@ -493,7 +508,7 @@ class UTWOQ_CompFp32 { while (tm.stop() < timems) { for (int i = 0; i < batch; i++) { log.start(); - GemmProblem gp(1, m, n, k); + GemmProblem gp(1, m, n, k, blocksize); typename Launcher::Param args{gp, {A + i * m * k, k}, {&packBs[i]}, {C + i * m * n, n}}; parallel::GemmRun(kernel, args, UT_Threading::get()); log.stop(); @@ -509,63 +524,6 @@ class UTWOQ_CompFp32 { corestr, log.get_log_str(), flops, flops / threads, band); } - template class Wei, typename Scale_T> - void benchmark_mem(int m, int n, int k, int batch, int blocksize, float* A, float* B, float* C, float timems, - int threads, BTLA_DTYPE qtype) { - LOG_T log; - using Parallel = parallel::gemm::SchedulerKBlock; - using Launcher = - wrapper::gemm::LauncherKBlock; - Launcher kernel; - UT_Threading::set_threads(threads); - auto corestr = gemm::CoreAttr::to_str(Core_T::ID); - utils::timer tm; - using WType = typename Wei::StorageWeight; - WType tmpB(0); - if constexpr (std::is_same_v, - prologue_b::gemm::WeightKBlockNInteger>) { - tmpB = kernel.mProB.createStorage(n, k, blocksize, qtype, bestla_dtype, bestla_dtype, false); - - } else if constexpr (std::is_same_v, - prologue_b::gemm::WeightKBlockNFloat>) { - tmpB = kernel.mProB.createStorage(n, k, blocksize, qtype, bestla_dtype); - } - auto memsize = (size_t)tmpB.mSize + (m * k + m * n) * sizeof(float); - std::vector packBs(batch, 0); - avector bufB(tmpB.mSize * batch); - for (size_t i = 0; i < batch; i++) { - packBs[i] = tmpB; - packBs[i].assign(bufB.data() + i * tmpB.mSize); - } - kernel.mProB.packWeight(n, k, B, n, &packBs[0], UT_Threading::get()); - for (size_t i = 1; i < batch; i++) { - memcpy(packBs[i].template WPtr(), packBs[0].template WPtr(), packBs[0].template WSize()); - memcpy(packBs[i].template SPtr(), packBs[0].template SPtr(), packBs[0].CSize() * sizeof(Scale_T)); - } - auto psize = (size_t)m * n * k * 2; - tm.start(); - while (tm.stop() < timems) { - log.start(); - for (size_t i = 0; i < batch; i++) { - GemmProblem gp(1, m, n, k, blocksize); - typename Launcher::Param args{gp, - {A + i * m * k, k}, - {&packBs[i]}, - {packBs[i].template SPtr(), packBs[i].SDtype(), packBs[i].CStep()}, - {C + i * m * n, n}}; - parallel::GemmRun(kernel, args, UT_Threading::get()); - } - log.stop(); - } - log.record(); - double t = log.min_val / batch; - double flops = double(psize) / t / 1e6; - double band = double(memsize) / t / 1e6; - printf("Threads %d Block %d %s Flops:%.3fG PerCoreFlops:%.3fG MemoryBandwidth:%.3fGB/s\n", threads, blocksize, - corestr, flops, flops / threads, band); - } - template
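// The benchmark helpers above report throughput from the standard 2*m*n*k FLOP count of a GEMM
// (one multiply and one add per MAC). Assuming the timer values are in milliseconds, as the
// existing /1e6 scaling suggests, the conversions to GFLOP/s and GB/s reduce to:
static double gemm_gflops(double m, double n, double k, double time_ms) {
  return 2.0 * m * n * k / (time_ms * 1e6);  // flop / (ms * 1e6) == Gflop/s
}
static double bandwidth_gbps(double bytes, double time_ms) {
  return bytes / (time_ms * 1e6);  // byte / (ms * 1e6) == GB/s
}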