From 5f98ab8605851a06a2b6acafc33b5ce53dbbf0d5 Mon Sep 17 00:00:00 2001 From: luoyu-intel Date: Tue, 21 May 2024 10:34:22 +0800 Subject: [PATCH] add UTs for new bits --- bestla/bestla/ut/bestla_prologue_b.cpp | 96 ++++++++++++++++++++++++-- 1 file changed, 89 insertions(+), 7 deletions(-) diff --git a/bestla/bestla/ut/bestla_prologue_b.cpp b/bestla/bestla/ut/bestla_prologue_b.cpp index f30422a5d..5ff317e47 100644 --- a/bestla/bestla/ut/bestla_prologue_b.cpp +++ b/bestla/bestla/ut/bestla_prologue_b.cpp @@ -411,7 +411,7 @@ class UT_TransposeBlockQuantize_F4 { } }; #ifdef BTLA_UT_PROLOGUE_B -static UT_TransposeBlockQuantize_F4 sUT_TransposeBlockQuantize_F4; +static UT_TransposeBlockQuantize_F4 sUT_TransposeBlockQuantize_F4; #endif class UT_BlockQuantize_INT4 { @@ -644,8 +644,6 @@ class UT_CompFp32 { UT_CompFp32() { UT_START(); ut_s6(); - - /* ut_s5(); ut_s4(); ut_s2(); @@ -653,7 +651,7 @@ class UT_CompFp32 { ut_s8(); ut_f4(); - ut_f8();*/ + ut_f8(); } void ut_s2() { @@ -702,6 +700,20 @@ class UT_CompFp32 { false); ut_int(2, 4096, 4096, -1, BTLA_DTYPE::S3_CLIP, BTLA_DTYPE::F32, false); + + CheckISA(AVX512F); + ut_int(2, 4096, 4096, 32, BTLA_DTYPE::S3_CLIP, BTLA_DTYPE::F32, + false); + ut_int(2, 4096, 4096, 32, BTLA_DTYPE::S3_CLIP, BTLA_DTYPE::F32, + true); + ut_int(8, 4096, 4096, 32, BTLA_DTYPE::S3_CLIP, BTLA_DTYPE::F32, + false); + ut_int(8, 4096, 4096, 32, BTLA_DTYPE::S3_CLIP, BTLA_DTYPE::F32, + true); + ut_int(2, 4096, 4096, 128, BTLA_DTYPE::S3_CLIP, BTLA_DTYPE::F32, + false); + ut_int(2, 4096, 4096, -1, BTLA_DTYPE::S3_CLIP, BTLA_DTYPE::F32, + false); } void ut_f8() { @@ -759,6 +771,21 @@ class UT_CompFp32 { true); ut_int(8, 4096, 4096, 32, BTLA_DTYPE::S5_CLIP, BTLA_DTYPE::F32, false); + CheckISA(AVX512F); + ut_int(1, 4096, 4096, 32, BTLA_DTYPE::S5_CLIP, BTLA_DTYPE::F32, + true); + ut_int(1, 4096, 4096, 32, BTLA_DTYPE::S5_CLIP, BTLA_DTYPE::F32, + false); + ut_int(2, 4096, 4096, 128, BTLA_DTYPE::S5_CLIP, BTLA_DTYPE::F32, + false); + ut_int(2, 4096, 4096, -1, BTLA_DTYPE::S5_CLIP, BTLA_DTYPE::F32, + false); + ut_int(2, 4096, 4096, 32, BTLA_DTYPE::S5_CLIP, BTLA_DTYPE::BF16, + false); + ut_int(8, 4096, 4096, 32, BTLA_DTYPE::S5_CLIP, BTLA_DTYPE::F32, + true); + ut_int(8, 4096, 4096, 32, BTLA_DTYPE::S5_CLIP, BTLA_DTYPE::F32, + false); } void ut_s6() { @@ -777,6 +804,21 @@ class UT_CompFp32 { true); ut_int(8, 4096, 4096, 32, BTLA_DTYPE::S6_CLIP, BTLA_DTYPE::F32, false); + CheckISA(AVX512F); + ut_int(1, 4096, 4096, 32, BTLA_DTYPE::S6_CLIP, BTLA_DTYPE::F32, + true); + ut_int(1, 4096, 4096, 32, BTLA_DTYPE::S6_CLIP, BTLA_DTYPE::F32, + false); + ut_int(2, 4096, 4096, 128, BTLA_DTYPE::S6_CLIP, BTLA_DTYPE::F32, + false); + ut_int(2, 4096, 4096, -1, BTLA_DTYPE::S6_CLIP, BTLA_DTYPE::F32, + false); + ut_int(2, 4096, 4096, 32, BTLA_DTYPE::S6_CLIP, BTLA_DTYPE::BF16, + false); + ut_int(8, 4096, 4096, 32, BTLA_DTYPE::S6_CLIP, BTLA_DTYPE::F32, + true); + ut_int(8, 4096, 4096, 32, BTLA_DTYPE::S6_CLIP, BTLA_DTYPE::F32, + false); } void ut_s8() { @@ -903,16 +945,23 @@ class UT_CompInt8 { UT_CompInt8() { UT_START(); ut_s6(); - - /* ut_s5(); ut_s4(); ut_s2(); - ut_s3();*/ + ut_s3(); } void ut_s2() { GetCPUDevice(); + if (_cd->AVX2()) { + ut_newkblock>(1, 4096, 4096, 32, BTLA_DTYPE::S2_CLIP, BTLA_DTYPE::F32, true); + ut_newkblock>(1, 4096, 4096, 16, BTLA_DTYPE::S2_CLIP, BTLA_DTYPE::BF16); + ut_newkblock>(2, 4096, 4096, 32, BTLA_DTYPE::S2_CLIP, BTLA_DTYPE::F32, true); + ut_newkblock>(8, 4096, 4096, 32, BTLA_DTYPE::S2_CLIP, BTLA_DTYPE::F32, true); + ut_newkblock>(8, 4096, 4096, 32, BTLA_DTYPE::S2_CLIP, BTLA_DTYPE::F32); + ut_newkblock>(1, 4096, 4096, 32, BTLA_DTYPE::S2_CLIP, BTLA_DTYPE::F32); + ut_newkblock>(1, 4096, 4096, 128, BTLA_DTYPE::S2_CLIP, BTLA_DTYPE::F32); + } if (_cd->AVX_VNNI()) { ut_newkblock>(1, 4096, 4096, 32, BTLA_DTYPE::S2_CLIP, BTLA_DTYPE::F32, true); ut_newkblock>(1, 4096, 4096, 16, BTLA_DTYPE::S2_CLIP, BTLA_DTYPE::BF16); @@ -934,10 +983,20 @@ class UT_CompInt8 { ut_newkblock>(1, 4096, 4096, 32, BTLA_DTYPE::S2_CLIP, BTLA_DTYPE::F32); ut_newkblock>(1, 4096, 4096, 128, BTLA_DTYPE::S2_CLIP, BTLA_DTYPE::F32); } + if (_cd->AMX_INT8()) { + ut_newkblock>(128, 4096, 4096, 128, BTLA_DTYPE::S2_CLIP, BTLA_DTYPE::F32); + ut_newkblock>(1, 4096, 4096, 64, BTLA_DTYPE::S2_CLIP, BTLA_DTYPE::F32); + } } void ut_s3() { GetCPUDevice(); + if (_cd->AVX2()) { + ut_newkblock>(1, 4096, 4096, 32, BTLA_DTYPE::S3_CLIP, BTLA_DTYPE::F32); + ut_newkblock>(1, 4096, 4096, 32, BTLA_DTYPE::S3_CLIP, BTLA_DTYPE::F32, true); + ut_newkblock>(8, 4096, 4096, 32, BTLA_DTYPE::S3_CLIP, BTLA_DTYPE::F32, true); + ut_newkblock>(1, 4096, 4096, 128, BTLA_DTYPE::S3_CLIP, BTLA_DTYPE::F32); + } if (_cd->AVX_VNNI()) { ut_newkblock>(1, 4096, 4096, 32, BTLA_DTYPE::S3_CLIP, BTLA_DTYPE::F32); ut_newkblock>(1, 4096, 4096, 32, BTLA_DTYPE::S3_CLIP, BTLA_DTYPE::F32, true); @@ -952,6 +1011,10 @@ class UT_CompInt8 { true); ut_newkblock>(1, 4096, 4096, 128, BTLA_DTYPE::S3_CLIP, BTLA_DTYPE::F32); } + if (_cd->AMX_INT8()) { + ut_newkblock>(128, 4096, 4096, 128, BTLA_DTYPE::S3_CLIP, BTLA_DTYPE::F32); + ut_newkblock>(1, 4096, 4096, 64, BTLA_DTYPE::S3_CLIP, BTLA_DTYPE::F32); + } } void ut_s4() { @@ -1009,6 +1072,16 @@ class UT_CompInt8 { ut_newkblock>(1, 4096, 4096, 32, BTLA_DTYPE::S5_CLIP, BTLA_DTYPE::BF16); ut_newkblock>(2, 4096, 4096, 32, BTLA_DTYPE::S5_CLIP, BTLA_DTYPE::F32); } + if (_cd->AVX512_VNNI()) { + ut_newkblock>(1, 11008, 4096, 32, BTLA_DTYPE::S5_CLIP, BTLA_DTYPE::F32); + ut_newkblock>(2, 4096, 4096, 32, BTLA_DTYPE::S5_CLIP, BTLA_DTYPE::F32); + ut_newkblock>(8, 4096, 4096, 32, BTLA_DTYPE::S5_CLIP, BTLA_DTYPE::F32); + } + + if (_cd->AMX_INT8()) { + ut_newkblock>(128, 4096, 4096, 128, BTLA_DTYPE::S5_CLIP, BTLA_DTYPE::F32); + ut_newkblock>(1, 4096, 4096, 64, BTLA_DTYPE::S5_CLIP, BTLA_DTYPE::F32); + } } void ut_s6() { @@ -1030,6 +1103,15 @@ class UT_CompInt8 { ut_newkblock>(1, 4096, 4096, 32, BTLA_DTYPE::S6_CLIP, BTLA_DTYPE::BF16); ut_newkblock>(2, 4096, 4096, 32, BTLA_DTYPE::S6_CLIP, BTLA_DTYPE::F32); } + if (_cd->AVX512_VNNI()) { + ut_newkblock>(1, 11008, 4096, 32, BTLA_DTYPE::S6_CLIP, BTLA_DTYPE::F32); + ut_newkblock>(2, 4096, 4096, 32, BTLA_DTYPE::S6_CLIP, BTLA_DTYPE::F32); + ut_newkblock>(8, 4096, 4096, 32, BTLA_DTYPE::S6_CLIP, BTLA_DTYPE::F32); + } + if (_cd->AMX_INT8()) { + ut_newkblock>(128, 4096, 4096, 128, BTLA_DTYPE::S6_CLIP, BTLA_DTYPE::F32); + ut_newkblock>(1, 4096, 4096, 64, BTLA_DTYPE::S6_CLIP, BTLA_DTYPE::F32); + } } template