Skip to content
This repository has been archived by the owner on Aug 30, 2024. It is now read-only.

Commit

Permalink
add UTs for new bits
Browse files Browse the repository at this point in the history
  • Loading branch information
luoyu-intel committed May 21, 2024
1 parent 9752312 commit 5f98ab8
Showing 1 changed file with 89 additions and 7 deletions.
96 changes: 89 additions & 7 deletions bestla/bestla/ut/bestla_prologue_b.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,7 @@ class UT_TransposeBlockQuantize_F4 {
}
};
#ifdef BTLA_UT_PROLOGUE_B
static UT_TransposeBlockQuantize_F4 sUT_TransposeBlockQuantize_F4;
static UT_TransposeBlockQuantize_F4 sUT_TransposeBlockQuantize_F4;
#endif

class UT_BlockQuantize_INT4 {
Expand Down Expand Up @@ -644,16 +644,14 @@ class UT_CompFp32 {
UT_CompFp32() {
UT_START();
ut_s6();

/*
ut_s5();
ut_s4();
ut_s2();
ut_s3();
ut_s8();

ut_f4();
ut_f8();*/
ut_f8();
}

void ut_s2() {
Expand Down Expand Up @@ -702,6 +700,20 @@ class UT_CompFp32 {
false);
ut_int<sAVX2, prologue_b::gemm::WeightKBlockNInteger>(2, 4096, 4096, -1, BTLA_DTYPE::S3_CLIP, BTLA_DTYPE::F32,
false);

CheckISA(AVX512F);
ut_int<sAVX512F, prologue_b::gemm::WeightKBlockNInteger>(2, 4096, 4096, 32, BTLA_DTYPE::S3_CLIP, BTLA_DTYPE::F32,
false);
ut_int<sAVX512F, prologue_b::gemm::WeightKBlockNInteger>(2, 4096, 4096, 32, BTLA_DTYPE::S3_CLIP, BTLA_DTYPE::F32,
true);
ut_int<sAVX512F, prologue_b::gemm::WeightKBlockNInteger>(8, 4096, 4096, 32, BTLA_DTYPE::S3_CLIP, BTLA_DTYPE::F32,
false);
ut_int<sAVX512F, prologue_b::gemm::WeightKBlockNInteger>(8, 4096, 4096, 32, BTLA_DTYPE::S3_CLIP, BTLA_DTYPE::F32,
true);
ut_int<sAVX512F, prologue_b::gemm::WeightKBlockNInteger>(2, 4096, 4096, 128, BTLA_DTYPE::S3_CLIP, BTLA_DTYPE::F32,
false);
ut_int<sAVX512F, prologue_b::gemm::WeightKBlockNInteger>(2, 4096, 4096, -1, BTLA_DTYPE::S3_CLIP, BTLA_DTYPE::F32,
false);
}

void ut_f8() {
Expand Down Expand Up @@ -759,6 +771,21 @@ class UT_CompFp32 {
true);
ut_int<sAVX2, prologue_b::gemm::WeightKBlockNInteger>(8, 4096, 4096, 32, BTLA_DTYPE::S5_CLIP, BTLA_DTYPE::F32,
false);
CheckISA(AVX512F);
ut_int<sAVX512F, prologue_b::gemm::WeightKBlockNInteger>(1, 4096, 4096, 32, BTLA_DTYPE::S5_CLIP, BTLA_DTYPE::F32,
true);
ut_int<sAVX512F, prologue_b::gemm::WeightKBlockNInteger>(1, 4096, 4096, 32, BTLA_DTYPE::S5_CLIP, BTLA_DTYPE::F32,
false);
ut_int<sAVX512F, prologue_b::gemm::WeightKBlockNInteger>(2, 4096, 4096, 128, BTLA_DTYPE::S5_CLIP, BTLA_DTYPE::F32,
false);
ut_int<sAVX512F, prologue_b::gemm::WeightKBlockNInteger>(2, 4096, 4096, -1, BTLA_DTYPE::S5_CLIP, BTLA_DTYPE::F32,
false);
ut_int<sAVX512F, prologue_b::gemm::WeightKBlockNInteger>(2, 4096, 4096, 32, BTLA_DTYPE::S5_CLIP, BTLA_DTYPE::BF16,
false);
ut_int<sAVX512F, prologue_b::gemm::WeightKBlockNInteger>(8, 4096, 4096, 32, BTLA_DTYPE::S5_CLIP, BTLA_DTYPE::F32,
true);
ut_int<sAVX512F, prologue_b::gemm::WeightKBlockNInteger>(8, 4096, 4096, 32, BTLA_DTYPE::S5_CLIP, BTLA_DTYPE::F32,
false);
}

void ut_s6() {
Expand All @@ -777,6 +804,21 @@ class UT_CompFp32 {
true);
ut_int<sAVX2, prologue_b::gemm::WeightKBlockNInteger>(8, 4096, 4096, 32, BTLA_DTYPE::S6_CLIP, BTLA_DTYPE::F32,
false);
CheckISA(AVX512F);
ut_int<sAVX512F, prologue_b::gemm::WeightKBlockNInteger>(1, 4096, 4096, 32, BTLA_DTYPE::S6_CLIP, BTLA_DTYPE::F32,
true);
ut_int<sAVX512F, prologue_b::gemm::WeightKBlockNInteger>(1, 4096, 4096, 32, BTLA_DTYPE::S6_CLIP, BTLA_DTYPE::F32,
false);
ut_int<sAVX512F, prologue_b::gemm::WeightKBlockNInteger>(2, 4096, 4096, 128, BTLA_DTYPE::S6_CLIP, BTLA_DTYPE::F32,
false);
ut_int<sAVX512F, prologue_b::gemm::WeightKBlockNInteger>(2, 4096, 4096, -1, BTLA_DTYPE::S6_CLIP, BTLA_DTYPE::F32,
false);
ut_int<sAVX512F, prologue_b::gemm::WeightKBlockNInteger>(2, 4096, 4096, 32, BTLA_DTYPE::S6_CLIP, BTLA_DTYPE::BF16,
false);
ut_int<sAVX512F, prologue_b::gemm::WeightKBlockNInteger>(8, 4096, 4096, 32, BTLA_DTYPE::S6_CLIP, BTLA_DTYPE::F32,
true);
ut_int<sAVX512F, prologue_b::gemm::WeightKBlockNInteger>(8, 4096, 4096, 32, BTLA_DTYPE::S6_CLIP, BTLA_DTYPE::F32,
false);
}

void ut_s8() {
Expand Down Expand Up @@ -903,16 +945,23 @@ class UT_CompInt8 {
UT_CompInt8() {
UT_START();
ut_s6();

/*
ut_s5();
ut_s4();
ut_s2();
ut_s3();*/
ut_s3();
}

void ut_s2() {
GetCPUDevice();
if (_cd->AVX2()) {
ut_newkblock<gemm::ICoreRowNAvx2vnniKBlock<24, 2>>(1, 4096, 4096, 32, BTLA_DTYPE::S2_CLIP, BTLA_DTYPE::F32, true);
ut_newkblock<gemm::ICoreRowNAvx2vnniKBlock<24, 2>>(1, 4096, 4096, 16, BTLA_DTYPE::S2_CLIP, BTLA_DTYPE::BF16);
ut_newkblock<gemm::ICoreRowNAvx2vnniKBlock<24, 2>>(2, 4096, 4096, 32, BTLA_DTYPE::S2_CLIP, BTLA_DTYPE::F32, true);
ut_newkblock<gemm::ICoreRowNAvx2vnniKBlock<24, 2>>(8, 4096, 4096, 32, BTLA_DTYPE::S2_CLIP, BTLA_DTYPE::F32, true);
ut_newkblock<gemm::ICoreRowNAvx2vnniKBlock<24, 2>>(8, 4096, 4096, 32, BTLA_DTYPE::S2_CLIP, BTLA_DTYPE::F32);
ut_newkblock<gemm::ICoreRowNAvx2vnniKBlock<24, 2>>(1, 4096, 4096, 32, BTLA_DTYPE::S2_CLIP, BTLA_DTYPE::F32);
ut_newkblock<gemm::ICoreRowNAvx2vnniKBlock<24, 2>>(1, 4096, 4096, 128, BTLA_DTYPE::S2_CLIP, BTLA_DTYPE::F32);
}
if (_cd->AVX_VNNI()) {
ut_newkblock<gemm::ICoreRowNAvxvnniKBlock<24, 2>>(1, 4096, 4096, 32, BTLA_DTYPE::S2_CLIP, BTLA_DTYPE::F32, true);
ut_newkblock<gemm::ICoreRowNAvxvnniKBlock<24, 2>>(1, 4096, 4096, 16, BTLA_DTYPE::S2_CLIP, BTLA_DTYPE::BF16);
Expand All @@ -934,10 +983,20 @@ class UT_CompInt8 {
ut_newkblock<gemm::ICoreRowNAvx512vnniKBlock<48, 4>>(1, 4096, 4096, 32, BTLA_DTYPE::S2_CLIP, BTLA_DTYPE::F32);
ut_newkblock<gemm::ICoreRowNAvx512vnniKBlock<48, 4>>(1, 4096, 4096, 128, BTLA_DTYPE::S2_CLIP, BTLA_DTYPE::F32);
}
if (_cd->AMX_INT8()) {
ut_newkblock<gemm::ICoreRowNAmxint8KBlock<48, 16>>(128, 4096, 4096, 128, BTLA_DTYPE::S2_CLIP, BTLA_DTYPE::F32);
ut_newkblock<gemm::ICoreRowNAmxint8KBlock<48, 16>>(1, 4096, 4096, 64, BTLA_DTYPE::S2_CLIP, BTLA_DTYPE::F32);
}
}

void ut_s3() {
GetCPUDevice();
if (_cd->AVX2()) {
ut_newkblock<gemm::ICoreRowNAvx2vnniKBlock<24, 2>>(1, 4096, 4096, 32, BTLA_DTYPE::S3_CLIP, BTLA_DTYPE::F32);
ut_newkblock<gemm::ICoreRowNAvx2vnniKBlock<24, 2>>(1, 4096, 4096, 32, BTLA_DTYPE::S3_CLIP, BTLA_DTYPE::F32, true);
ut_newkblock<gemm::ICoreRowNAvx2vnniKBlock<24, 2>>(8, 4096, 4096, 32, BTLA_DTYPE::S3_CLIP, BTLA_DTYPE::F32, true);
ut_newkblock<gemm::ICoreRowNAvx2vnniKBlock<24, 2>>(1, 4096, 4096, 128, BTLA_DTYPE::S3_CLIP, BTLA_DTYPE::F32);
}
if (_cd->AVX_VNNI()) {
ut_newkblock<gemm::ICoreRowNAvxvnniKBlock<24, 2>>(1, 4096, 4096, 32, BTLA_DTYPE::S3_CLIP, BTLA_DTYPE::F32);
ut_newkblock<gemm::ICoreRowNAvxvnniKBlock<24, 2>>(1, 4096, 4096, 32, BTLA_DTYPE::S3_CLIP, BTLA_DTYPE::F32, true);
Expand All @@ -952,6 +1011,10 @@ class UT_CompInt8 {
true);
ut_newkblock<gemm::ICoreRowNAvx512vnniKBlock<48, 4>>(1, 4096, 4096, 128, BTLA_DTYPE::S3_CLIP, BTLA_DTYPE::F32);
}
if (_cd->AMX_INT8()) {
ut_newkblock<gemm::ICoreRowNAmxint8KBlock<48, 16>>(128, 4096, 4096, 128, BTLA_DTYPE::S3_CLIP, BTLA_DTYPE::F32);
ut_newkblock<gemm::ICoreRowNAmxint8KBlock<48, 16>>(1, 4096, 4096, 64, BTLA_DTYPE::S3_CLIP, BTLA_DTYPE::F32);
}
}

void ut_s4() {
Expand Down Expand Up @@ -1009,6 +1072,16 @@ class UT_CompInt8 {
ut_newkblock<gemm::ICoreRowNAvxvnniKBlock<24, 2>>(1, 4096, 4096, 32, BTLA_DTYPE::S5_CLIP, BTLA_DTYPE::BF16);
ut_newkblock<gemm::ICoreRowNAvxvnniKBlock<24, 2>>(2, 4096, 4096, 32, BTLA_DTYPE::S5_CLIP, BTLA_DTYPE::F32);
}
if (_cd->AVX512_VNNI()) {
ut_newkblock<gemm::ICoreRowNAvx512vnniKBlock<48, 4>>(1, 11008, 4096, 32, BTLA_DTYPE::S5_CLIP, BTLA_DTYPE::F32);
ut_newkblock<gemm::ICoreRowNAvx512vnniKBlock<48, 4>>(2, 4096, 4096, 32, BTLA_DTYPE::S5_CLIP, BTLA_DTYPE::F32);
ut_newkblock<gemm::ICoreRowNAvx512vnniKBlock<48, 4>>(8, 4096, 4096, 32, BTLA_DTYPE::S5_CLIP, BTLA_DTYPE::F32);
}

if (_cd->AMX_INT8()) {
ut_newkblock<gemm::ICoreRowNAmxint8KBlock<48, 16>>(128, 4096, 4096, 128, BTLA_DTYPE::S5_CLIP, BTLA_DTYPE::F32);
ut_newkblock<gemm::ICoreRowNAmxint8KBlock<48, 16>>(1, 4096, 4096, 64, BTLA_DTYPE::S5_CLIP, BTLA_DTYPE::F32);
}
}

void ut_s6() {
Expand All @@ -1030,6 +1103,15 @@ class UT_CompInt8 {
ut_newkblock<gemm::ICoreRowNAvxvnniKBlock<24, 2>>(1, 4096, 4096, 32, BTLA_DTYPE::S6_CLIP, BTLA_DTYPE::BF16);
ut_newkblock<gemm::ICoreRowNAvxvnniKBlock<24, 2>>(2, 4096, 4096, 32, BTLA_DTYPE::S6_CLIP, BTLA_DTYPE::F32);
}
if (_cd->AVX512_VNNI()) {
ut_newkblock<gemm::ICoreRowNAvx512vnniKBlock<48, 4>>(1, 11008, 4096, 32, BTLA_DTYPE::S6_CLIP, BTLA_DTYPE::F32);
ut_newkblock<gemm::ICoreRowNAvx512vnniKBlock<48, 4>>(2, 4096, 4096, 32, BTLA_DTYPE::S6_CLIP, BTLA_DTYPE::F32);
ut_newkblock<gemm::ICoreRowNAvx512vnniKBlock<48, 4>>(8, 4096, 4096, 32, BTLA_DTYPE::S6_CLIP, BTLA_DTYPE::F32);
}
if (_cd->AMX_INT8()) {
ut_newkblock<gemm::ICoreRowNAmxint8KBlock<48, 16>>(128, 4096, 4096, 128, BTLA_DTYPE::S6_CLIP, BTLA_DTYPE::F32);
ut_newkblock<gemm::ICoreRowNAmxint8KBlock<48, 16>>(1, 4096, 4096, 64, BTLA_DTYPE::S6_CLIP, BTLA_DTYPE::F32);
}
}

template <class GemmCore_T>
Expand Down

0 comments on commit 5f98ab8

Please sign in to comment.