Skip to content

Commit

Permalink
update jblas
Browse files Browse the repository at this point in the history
  • Loading branch information
luoyu-intel committed Nov 14, 2023
1 parent bd36779 commit a8e292e
Show file tree
Hide file tree
Showing 20 changed files with 2,524 additions and 2,928 deletions.
64 changes: 51 additions & 13 deletions onnxruntime/core/mlas/lib/q4_dq.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,13 +152,32 @@ MlasJblasNBitsGemmPackB(void* PackedBuf,
MLAS_COMPUTE_TYPE CompType,
MLAS_THREADPOOL* ThreadPool)
{
GetCPUDevice();
switch (CompType) {
case CompInt8:
return JblaNBitsGemmPackB(JblasAvxVnniS4Fp32Fp32, PackedBuf, int(BlkSize), QData, Scale,
Zp, int(N), int(K), isAsym, lastCall, int(ldb), ThreadPool);
if (_cd->AVX512_VNNI()) {
return JblaNBitsGemmPackB(JblasAvx512VnniS4Fp32Fp32, PackedBuf, int(BlkSize), QData,
Scale, Zp, int(N), int(K), isAsym, lastCall, int(ldb),
ThreadPool);
}
if (_cd->AVX_VNNI()) {
return JblaNBitsGemmPackB(JblasAvxVnniS4Fp32Fp32, PackedBuf, int(BlkSize), QData,
Scale, Zp, int(N), int(K), isAsym, lastCall, int(ldb),
ThreadPool);
}
break;
case CompFp32:
return JblaNBitsGemmPackB(JblasAvx512fS4Fp32Fp32, PackedBuf, int(BlkSize), QData, Scale,
Zp, int(N), int(K), isAsym, lastCall, int(ldb), ThreadPool);
if (_cd->AVX512F()) {
return JblaNBitsGemmPackB(JblasAvx512fS4Fp32Fp32, PackedBuf, int(BlkSize), QData,
Scale, Zp, int(N), int(K), isAsym, lastCall, int(ldb),
ThreadPool);
}
if (_cd->AVX2()) {
return JblaNBitsGemmPackB(JblasAvx2S4Fp32Fp32, PackedBuf, int(BlkSize), QData,
Scale, Zp, int(N), int(K), isAsym, lastCall, int(ldb),
ThreadPool);
}
break;
case CompBf16:
case CompFp16:
default:
Expand All @@ -177,18 +196,37 @@ MlasJblasQ4GemmUnPackB(float* FpData,
auto ptr =
jblas::storage::gemm::PackedWeightParser::deserialBuffer(const_cast<void*>(PackedBuf));
ORTThreading orth(ThreadPool);
GetCPUDevice();
if (ptr) {
if (ptr->mPrologueID == JBLAS_PROLOGUEB_IDS::WeightKBlockS4) {
auto coretype = ptr->mCoreType;
auto NTile = uint32_t(coretype) & uint32_t(JBLAS_GEMM_CORE::NTILE_MASK);
auto CType = uint32_t(coretype) & uint32_t(JBLAS_GEMM_CORE::COMP_MASK);
if (NTile == 48 && CType == uint32_t(JBLAS_GEMM_CORE::COMP_FP32)) {
JblasAvx512fS4Fp32Fp32.mProB.unpackWeight(int(N), int(K), ptr, FpData, int(ldb),
&orth);
auto NTile =
jblas::gemm::CoreAttr::get_mask_val(ptr->mCoreId, jblas::gemm::CoreAttr::NTILE_MASK,
jblas::gemm::CoreAttr::NTILE_SHIFT);
auto CType = jblas::gemm::CoreAttr::get_mask_val(
ptr->mCoreId, jblas::gemm::CoreAttr::COMP_MASK, jblas::gemm::CoreAttr::COMP_SHIFT);
if (CType == uint32_t(jblas::gemm::CompType::COMP_FP32)) {
if (NTile == 48 && _cd->AVX512F()) {
JblasAvx512fS4Fp32Fp32.mProB.unpackWeight(int(N), int(K), ptr, FpData, int(ldb),
&orth);
return;
}
if (NTile == 24 && _cd->AVX2()) {
JblasAvx2S4Fp32Fp32.mProB.unpackWeight(int(N), int(K), ptr, FpData, int(ldb),
&orth);
return;
}
}
if (NTile == 48 && CType == uint32_t(JBLAS_GEMM_CORE::COMP_INT8_US)) {
JblasAvx512VnniS4Fp32Fp32.mProB.unpackWeight(int(N), int(K), ptr, FpData, int(ldb),
&orth);
if (CType == uint32_t(jblas::gemm::CompType::COMP_INT8_US_INT32)) {
if (NTile == 48 && _cd->AVX512_VNNI()) {
JblasAvx512VnniS4Fp32Fp32.mProB.unpackWeight(int(N), int(K), ptr, FpData,
int(ldb), &orth);
return;
}
if (NTile == 24 && _cd->AVX_VNNI()) {
JblasAvxVnniS4Fp32Fp32.mProB.unpackWeight(int(N), int(K), ptr, FpData, int(ldb),
&orth);
return;
}
}
}
delete ptr;
Expand Down
60 changes: 20 additions & 40 deletions onnxruntime/core/mlas/lib/q4gemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -155,30 +155,6 @@ jblas::ORTThreading::parallel_for(const jblas::parallel::thread_func& func)
[&](ptrdiff_t tid) { func(int(tid)); });
}

template <class Parallel_T, class Launch_T>
void
GemmKBlockRun(Launch_T& launcher,
const typename Launch_T::Param& args,
parallel::IThreading* threading)
{
device::CpuBase cb;
Parallel_T para({
threading->num_threads(),
cb.mL2Cache,
args.M,
args.N,
args.K,
args.KBlock,
});
threading->parallel_for([&](int tidx) {
typename Parallel_T::ThreadProblem thdp{tidx};
para.getIndex(thdp);
if (thdp.valid) {
launcher.run(args, thdp);
}
});
}

template <class GemmCore_T>
void
JblasQ4GemmCompF32(const int M,
Expand All @@ -205,7 +181,7 @@ JblasQ4GemmCompF32(const int M,
reduceA.template get<float>(), reduceA.lda};

typename Launcher::Param args{M, N, K, B->mBlockSize, {A, K}, {B}, blkargs, {C, N}};
GemmKBlockRun<Parallel>(kernel, args, th);
jblas::parallel::GemmKBlockRun<Parallel>(kernel, args, th);
}

template <class GemmCore_T>
Expand Down Expand Up @@ -236,10 +212,10 @@ JblasQ4GemmCompInt8(const int M,
{A, K, &quanA},
{B},
{B->template SPtr<int8_t>(), B->mScaT, B->mCStep, quanA.template SPtr<float>(),
quanA.mCStep, quanA.template ZPtr<uint8_t>(), B->template RPtr<float>(),
quanA.mCStep, quanA.template ZPtr<uint8_t>(), B->template RPtr<float>(), B->mRedT,
B->template ZPtr<int8_t>(), quanA.template RPtr<float>(), B->mBlockSize},
{C, N}};
GemmKBlockRun<Parallel>(kernel, args, th);
jblas::parallel::GemmKBlockRun<Parallel>(kernel, args, th);
}

void
Expand All @@ -259,35 +235,39 @@ JblasQ4GemmBatchDriver(const size_t M,
if (ptr) {
if (ptr->mPrologueID == JBLAS_PROLOGUEB_IDS::WeightKBlockS4) {
auto kptr = (jblas::storage::gemm::StorageWeightKBlockS4*)ptr;
auto coretype = ptr->mCoreType;
auto NTile = uint32_t(coretype) & uint32_t(JBLAS_GEMM_CORE::NTILE_MASK);
auto CType = uint32_t(coretype) & uint32_t(JBLAS_GEMM_CORE::COMP_MASK);
if (NTile == 48 && CType == uint32_t(JBLAS_GEMM_CORE::COMP_FP32)) {
if (_cd->AVX512F()) {
JblasQ4GemmCompF32<gemm::GemmCore_Row_NN_8x48_AVX512F>(
auto coretype = ptr->mCoreId;
auto NTile = jblas::gemm::CoreAttr::get_mask_val(
ptr->mCoreId, jblas::gemm::CoreAttr::NTILE_MASK,
jblas::gemm::CoreAttr::NTILE_SHIFT);
auto CType = jblas::gemm::CoreAttr::get_mask_val(ptr->mCoreId,
jblas::gemm::CoreAttr::COMP_MASK,
jblas::gemm::CoreAttr::COMP_SHIFT);
if (CType == uint32_t(gemm::CompType::COMP_FP32)) {
if (NTile == 48 && _cd->AVX512F()) {
JblasQ4GemmCompF32<gemm::SCoreRowNAvx512f<48, 8>>(
M, N, K, DataParams[i].A, DataParams[i].lda,
(jblas::storage::gemm::StorageWeightKBlockS4*)ptr, DataParams[i].C,
DataParams[i].ldc, WorkSpace, &orth);
return;
}
if (_cd->AVX2()) {
JblasQ4GemmCompF32<gemm::GemmCore_Row_NN_2x48_AVX2>(
if (NTile == 24 && _cd->AVX2()) {
JblasQ4GemmCompF32<gemm::SCoreRowNAvx2<24, 4>>(
M, N, K, DataParams[i].A, DataParams[i].lda,
(jblas::storage::gemm::StorageWeightKBlockS4*)ptr, DataParams[i].C,
DataParams[i].ldc, WorkSpace, &orth);
return;
}
}
if (NTile == 48 && CType == uint32_t(JBLAS_GEMM_CORE::COMP_INT8_US)) {
if (_cd->AVX512_VNNI()) {
JblasQ4GemmCompInt8<gemm::GemmCore_Row_NN_8x48_AVX512_VNNI>(
if (CType == uint32_t(gemm::CompType::COMP_INT8_US_INT32)) {
if (NTile == 48 && _cd->AVX512_VNNI()) {
JblasQ4GemmCompInt8<gemm::ICoreRowNAvx512vnni<48, 8>>(
M, N, K, DataParams[i].A, DataParams[i].lda,
(jblas::storage::gemm::StorageWeightKBlockS4*)ptr, DataParams[i].C,
DataParams[i].ldc, WorkSpace, &orth);
return;
}
if (_cd->AVX_VNNI()) {
JblasQ4GemmCompInt8<gemm::GemmCore_Row_NN_2x48_AVX_VNNI>(
if (NTile == 24 && _cd->AVX_VNNI()) {
JblasQ4GemmCompInt8<gemm::ICoreRowNAvxvnni<24, 4>>(
M, N, K, DataParams[i].A, DataParams[i].lda,
(jblas::storage::gemm::StorageWeightKBlockS4*)ptr, DataParams[i].C,
DataParams[i].ldc, WorkSpace, &orth);
Expand Down
37 changes: 35 additions & 2 deletions onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,21 @@ class JitBase : protected Xbyak::CodeGenerator {
#endif
}

void padto_le(const Xbyak::Reg64& _src, int padding) {
// _src=_src/padding*padding
if (padding == 1) {
return;
}
for (int i = 1; i < 16; i++) {
if ((1 << i) == padding) {
shr(_src, i);
shl(_src, i);
return;
}
}
assert(0);
}

void generate_Nbitsmask(const Xbyak::Opmask& _msk, const Xbyak::Reg64& _pos, const Xbyak::Address& _total,
const Xbyak::Reg64& _tmp, const Xbyak::Reg64& _tmp1, int N) {
inLocalLabel();
Expand Down Expand Up @@ -86,13 +101,16 @@ class JitBase : protected Xbyak::CodeGenerator {
class JitAvx : protected JitBase {
protected:
static int constexpr VBits = 256;
static int constexpr VecBytes = VBits / 8;
static int constexpr RegCount = 16;
typedef Xbyak::Ymm vreg_t;
};

class JitAvx2 : protected JitAvx {
protected:
static int constexpr VBits = 256;
typedef Xbyak::Ymm vreg_t;
void vxor(const vreg_t& x1, const vreg_t& x2, const Xbyak::Operand& op) { vpxor(x1, x2, op); }

void loadbf16_f32(const Xbyak::Ymm& dst, const Xbyak::Address& addr) {
vpmovzxwd(dst, addr);
Expand All @@ -103,8 +121,12 @@ class JitAvx2 : protected JitAvx {
class JitAvx512f : protected JitAvx2 {
protected:
static int constexpr VBits = 512;
static int constexpr VecBytes = VBits / 8;
static int constexpr RegCount = 32;
typedef Xbyak::Zmm vreg_t;

void vxor(const vreg_t& x1, const vreg_t& x2, const Xbyak::Operand& op) { vpxorq(x1, x2, op); }

void interleave_2rows_4regs(Xbyak::Zmm* src_2regs, Xbyak::Zmm* tmp_2reg) {
vpunpcklwd(tmp_2reg[0], src_2regs[0], src_2regs[1]);
vpunpckhwd(tmp_2reg[1], src_2regs[0], src_2regs[1]);
Expand Down Expand Up @@ -191,18 +213,20 @@ class JitAvx512f : protected JitAvx2 {
}
};

class JitAvx512_bf16 : protected JitAvx512f {};

class JitAvx512_fp16 : protected JitAvx512f {};

class JitAvx512vnni : protected JitAvx512f {
protected:
void vpdpbusds_evex(const Xbyak::Xmm& x1, const Xbyak::Xmm& x2, const Xbyak::Operand& op) {
void vpdpbusds_(const Xbyak::Xmm& x1, const Xbyak::Xmm& x2, const Xbyak::Operand& op) {
vpdpbusds(x1, x2, op, Xbyak::EvexEncoding);
}
};

class JitAvxvnni : protected JitAvx2 {
protected:
void vpdpbusds_vex(const Xbyak::Xmm& x1, const Xbyak::Xmm& x2, const Xbyak::Operand& op) {
void vpdpbusds_(const Xbyak::Xmm& x1, const Xbyak::Xmm& x2, const Xbyak::Operand& op) {
vpdpbusds(x1, x2, op, Xbyak::VexEncoding);
}
};
Expand All @@ -215,6 +239,15 @@ class JitAmxtile : protected JitAvx512f {
uint16_t colb[16];
uint8_t rows[16];
};
static int constexpr TileCount = 8;

typedef long long (*configure_t)(void*);

static void generate_config(Xbyak::CodeGenerator* g) {
Xbyak::util::StackFrame st(g, 1, 0, 0);
auto& parambase = st.p[0];
g->ldtilecfg(g->ptr[parambase]);
}

static void configure_tiles(tileconfig_t& tc, int TILE_M, int TILE_N, int TILE_K, int elesize, int ANum, int BNum,
int CNum) {
Expand Down
52 changes: 1 addition & 51 deletions onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ enum JBLAS_ISA : uint32_t {
JblasAMX_BF16,
JblasAMX_INT8,
JblasAVX512_FP16,
JblasAVX512_BF16,
};
enum class JBLAS_DTYPE : uint32_t {
EleBitsMask = 0xff,
Expand Down Expand Up @@ -80,57 +81,6 @@ enum JBLAS_ELTWISEOP {
LINEAR,
};

enum class JBLAS_GEMM_CORE : uint32_t {
// INT32=LSB|**8bits:NTile**||**8bits:PackRow**||**8bits:CompType**||**8bits:Reserve**|
Undef = 0,
NTILE_MASK = 0xff,
NTILE_SHIFT = 0,
NTILE_24 = 24,
NTILE_48 = 48,
NTILE_64 = 64,
NTILE_96 = 96,
PACKROW_MASK = 0xff00,
PACKROW_SHIFT = 8,
PACKROW_1 = 1 << PACKROW_SHIFT,
PACKROW_2 = 2 << PACKROW_SHIFT,
PACKROW_4 = 4 << PACKROW_SHIFT,
COMP_MASK = 0xff0000,
COMP_SHIFT = 16,
COMP_FP32 = 0 << COMP_SHIFT,
COMP_BF16 = 1 << COMP_SHIFT,
COMP_FP16 = 2 << COMP_SHIFT,
COMP_INT_START = 3 << COMP_SHIFT,
COMP_INT8_US = COMP_INT_START,
COMP_INT8_SS = 4 << COMP_SHIFT,
COMP_INT8_SU = 5 << COMP_SHIFT,
COMP_INT16_SS = 6 << COMP_SHIFT,
COMP_INT8_US_FP32 = 7 << COMP_SHIFT,
COMP_INT8_SS_FP32 = 8 << COMP_SHIFT,
COMP_INT8_SU_FP32 = 9 << COMP_SHIFT,
ISA_MASK = 0xff000000,
ISA_SHIFT = 24,
ISA_AVX2 = (uint32_t)JBLAS_ISA::JblasAVX2 << ISA_SHIFT,
ISA_AVX512F = (uint32_t)JBLAS_ISA::JblasAVX512F << ISA_SHIFT,
ISA_AVX_VNNI = (uint32_t)JBLAS_ISA::JblasAVX_VNNI << ISA_SHIFT,
ISA_AVX512_VNNI = (uint32_t)JBLAS_ISA::JblasAVX512_VNNI << ISA_SHIFT,
ISA_AMX_INT8 = (uint32_t)JBLAS_ISA::JblasAMX_INT8 << ISA_SHIFT,
ISA_AMX_BF16 = (uint32_t)JBLAS_ISA::JblasAMX_BF16 << ISA_SHIFT,
ISA_AVX512_FP16 = (uint32_t)JBLAS_ISA::JblasAVX512_FP16 << ISA_SHIFT,
AVX2_4X24 = NTILE_24 | PACKROW_1 | COMP_FP32 | ISA_AVX2,
AVX2_2X48 = NTILE_48 | PACKROW_1 | COMP_FP32 | ISA_AVX2,
AVX512F_8x48 = NTILE_48 | PACKROW_1 | COMP_FP32 | ISA_AVX512F,
AMX_BF16_16x64 = NTILE_64 | PACKROW_2 | COMP_BF16 | ISA_AMX_BF16,
AMX_BF16_16x48 = NTILE_48 | PACKROW_2 | COMP_BF16 | ISA_AMX_BF16,
AVX512_FP16_8x64 = NTILE_64 | PACKROW_2 | COMP_FP16 | ISA_AVX512_FP16,
AVX512_FP16_8x96 = NTILE_96 | PACKROW_2 | COMP_FP16 | ISA_AVX512_FP16,
AVX_VNNI_2x48 = NTILE_48 | PACKROW_4 | COMP_INT8_US | ISA_AVX_VNNI,
AVX_VNNI_4x24 = NTILE_24 | PACKROW_4 | COMP_INT8_US | ISA_AVX_VNNI,
AVX512_VNNI_8x48 = NTILE_48 | PACKROW_4 | COMP_INT8_US | ISA_AVX512_VNNI,
AMX_INT8_16x64_US = NTILE_64 | PACKROW_4 | COMP_INT8_US | ISA_AMX_INT8,
AMX_INT8_16x64_SS = NTILE_64 | PACKROW_4 | COMP_INT8_SS | ISA_AMX_INT8,
AMX_INT8_16x48_US = NTILE_48 | PACKROW_4 | COMP_INT8_US | ISA_AMX_INT8,
AMX_INT8_16x48_SS = NTILE_48 | PACKROW_4 | COMP_INT8_SS | ISA_AMX_INT8,
};
enum class JBLAS_PROLOGUEB_IDS : uint32_t {
Undef = (uint32_t)-1,
Begin = 0,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,7 @@ class CpuDevice {
}
}
inline int getThreads() { return numthreads; }
inline int getCores() { return numcores; }
inline uint32_t getL2CacheSize() { return L2Cache; }
inline uint32_t getL1CacheSize() { return L1Cache; }
inline bool AVX() { return mHasAVX; }
Expand Down Expand Up @@ -261,15 +262,15 @@ class CpuDevice {

#define GetCPUDevice() auto _cd = jblas::device::CpuDevice::getInstance();


class CpuBase {
public:
CpuBase() {
GetCPUDevice();
mL2Cache = _cd->getL2CacheSize();
mL1Cache = _cd->getL1CacheSize();
mNumThreads = _cd->getThreads();
}
size_t mL2Cache;
size_t mL2Cache, mL1Cache;
int mNumThreads;
};
} // namespace device
Expand Down
Loading

0 comments on commit a8e292e

Please sign in to comment.