From f99d29c4c26d7a3f5777f00a680c7bd8a35e411f Mon Sep 17 00:00:00 2001
From: luoyu-intel
Date: Mon, 8 Apr 2024 16:27:55 +0800
Subject: [PATCH 001/111] add intrinsic for s4_clip

---
 bestla/bestla/bestla_storage.h        |   6 +
 bestla/bestla/bestla_utils.h          |   2 +-
 bestla/bestla/ut/bestla_benchmark.cpp | 186 +++++++++++++++++++++++++-
 3 files changed, 189 insertions(+), 5 deletions(-)

diff --git a/bestla/bestla/bestla_storage.h b/bestla/bestla/bestla_storage.h
index 7b13adbe9..7f00f9aa0 100644
--- a/bestla/bestla/bestla_storage.h
+++ b/bestla/bestla/bestla_storage.h
@@ -623,11 +623,17 @@ class StorageQuantActivation : public IActivationKBlockBase {
     mSize = utils::padto(mSize, Alignment);
     return mSize;
   }
+
   template <typename QT_T>
   inline constexpr QT_T* APtr() {
     return mQBuf.get();
   }
 
+  template <typename QT_T>
+  inline constexpr size_t ASize() {
+    return mQBuf.size();
+  }
+
   template <typename QT_T>
   inline constexpr QT_T* ZPtr() {
     return mCorrection.mZpBuf.get();
diff --git a/bestla/bestla/bestla_utils.h b/bestla/bestla/bestla_utils.h
index 17e24b75e..35a1915ef 100644
--- a/bestla/bestla/bestla_utils.h
+++ b/bestla/bestla/bestla_utils.h
@@ -592,7 +592,7 @@ using microseconds = std::chrono::microseconds;
 template <typename _DUR = std::chrono::milliseconds>
 class timer {
  public:
-  using sclock_t = std::chrono::steady_clock;
+  using sclock_t = std::chrono::high_resolution_clock;
   using stime_point_t = std::chrono::time_point<sclock_t>;
   timer() { clear(); }
 
diff --git a/bestla/bestla/ut/bestla_benchmark.cpp b/bestla/bestla/ut/bestla_benchmark.cpp
index ae3f15027..0273e7755 100644
--- a/bestla/bestla/ut/bestla_benchmark.cpp
+++ b/bestla/bestla/ut/bestla_benchmark.cpp
@@ -904,8 +904,8 @@ class UTWOQ_CompInt8 {
   }
 };
 #ifdef BTLA_UT_PROLOGUE_B
-#endif
 static UTWOQ_CompInt8 sUTWOQ_CompInt8;
+#endif
 
 typedef struct {
   float d;         // delta
   uint8_t qs[32];  // nibbles / quants
 } block_q8_0;
 #define __AVX2__
+#define __AVXVNNI__ 1
 // Unpack 32 4-bit fields into 32 bytes
 // The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
 static inline __m256i bytes_from_nibbles_32(const uint8_t* rsi) {
@@ -935,7 +936,7 @@ static inline __m256 sum_i16_pairs_float(const __m256i x) {
 static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) {
 #if __AVXVNNI__
   const __m256i zero = _mm256_setzero_si256();
-  const __m256i summed_pairs = _mm256_dpbusd_epi32(zero, ax, sy);
+  const __m256i summed_pairs = _mm256_dpbusd_avx_epi32(zero, ax, sy);
   return _mm256_cvtepi32_ps(summed_pairs);
 #else
   // Perform multiplication and create 16-bit values
@@ -1181,7 +1182,7 @@ class UTWOQ_GGML {
     ut_q40();
   }
 
-  void ut_q40() { benchmark_all(1, 4096, 4096, BTLA_DTYPE::S4_CLIP); }
+  void ut_q40() { benchmark_all(1, 4608, 4096, BTLA_DTYPE::S4_CLIP); }
 
   template <typename Core_T, template <class _T, BTLA_ISA> class Wei, typename Scale_T>
   void benchmark(int m, int n, int k, int batch, int blocksize, float* A, float* B, float* C, float timems, int threads,
@@ -1212,10 +1213,10 @@ class UTWOQ_GGML {
     utils::avector bufferA(quanA.mSize);
     quanA.assign(bufferA.data());
     auto psize = (size_t)m * n * k * 2;
-    auto memsize = (size_t)packBs[0].mSize + (m * k + m * n) * sizeof(float);
     int blks = updiv(k, blocksize);
     std::vector<block_q4_0> QB(batch * n * blks);
     std::vector<block_q8_0> QA(batch * m * blks);
+    auto memsize = sizeof(block_q4_0) * blks * n + sizeof(block_q8_0) * blks * m + m * n * sizeof(float);
     int dr = updiv(n, threads);
     tm.start();
     while (tm.stop() < timems) {
@@ -1285,6 +1286,183 @@ class UTWOQ_GGML {
   }
 };
 static UTWOQ_GGML sUTWOQ_GGML;
+
+#include "kernel_avx2.h"
+template <int NTILE, typename SBT>
+static void bestla_vec_dot_q4_0_q8_0(const int k_reduce, const int blocksize, float* out, const uint8_t* a_ptr,
+                                     const float* a_scale, const uint8_t* b_ptr, const SBT* b_scale, int b_step) {
+  const int k_blks = k_reduce / blocksize;
+  int constexpr NReg = NTILE / 8;
+  // Initialize accumulator with zeros
+  __m256 acc[NReg];
+  for (int i = 0; i < NReg; i++) {
+    acc[i] = _mm256_setzero_ps();
+  }
+  uint32_t mask = 0xf0f0f0f0;
+  auto vmask = _mm256_set1_epi32(*reinterpret_cast<int*>(&mask));
+
+  // Main loop
+  for (int ib = 0; ib < k_blks; ++ib) {
+    /* Compute combined scale for the block */
+    __m256i iacc[NReg];
+    for (int i = 0; i < NReg; i++) {
+      iacc[i] = _mm256_setzero_si256();
+    }
+    for (int ik = 0; ik < blocksize; ik += 4) {
+      auto va = _mm256_set1_epi32(*(int*)(a_ptr + ib * blocksize + ik));
+      for (int i = 0; i < NReg; i++) {
+        auto vb =
+            kernel::avx2::unpack_4bits_avx2((void*)(b_ptr + i * 16 + (ib * blocksize + ik) * NTILE / 2), vmask);
+        iacc[i] = _mm256_dpbusd_avx_epi32(iacc[i], va, vb);
+      }
+    }
+    const __m256 v_a_scale = _mm256_set1_ps(*(a_scale + ib));
+    for (int i = 0; i < NReg; i++) {
+      __m256 v_b_scale;
+      if constexpr (std::is_same_v<SBT, float>) {
+        v_b_scale = _mm256_loadu_ps(b_scale + ib * b_step + i * 8);
+      } else if constexpr (std::is_same_v<SBT, utils::bf16>) {
+        auto tmp = _mm_loadu_si128((const __m128i*)(b_scale + ib * b_step + i * 8));
+        v_b_scale = kernel::avx2::ymm_cvt_bf16_fp32(tmp);
+      }
+      v_b_scale = _mm256_mul_ps(v_a_scale, v_b_scale);
+      auto tmp = _mm256_cvtepi32_ps(iacc[i]);
+      acc[i] = _mm256_fmadd_ps(tmp, v_b_scale, acc[i]);
+    }
+  }
+  for (int i = 0; i < NReg; i++) {
+    _mm256_storeu_ps(out + i * 8, acc[i]);
+  }
+}
+
+class UTWOQ_S4_VecDot {
+ public:
+  UTWOQ_S4_VecDot() {
+    UT_START();
+    benchmark_all(1, 4608, 4096, BTLA_DTYPE::S4_CLIP);
+  }
+
+  template <typename Core_T, template <class _T, BTLA_ISA> class Wei, typename Scale_T>
+  void benchmark(int m, int n, int k, int batch, int blocksize, float* A, float* B, float* C, float timems, int threads,
+                 BTLA_DTYPE qtype) {
+    LOG_T log;
+    using Parallel = parallel::gemm::SchedulerKBlockS;
+    using Launcher =
+        wrapper::gemm::LauncherIntKBlock;
+    Launcher kernel;
+    UT_Threading::set_threads(threads);
+    auto corestr = gemm::CoreAttr::to_str(Core_T::ID);
+    utils::timer tm;
+    using WType = typename Wei::StorageWeight;
+    WType tmpB = kernel.mProB.createStorage(n, k, blocksize, qtype, bestla_dtype, bestla_dtype, false);
+    std::vector<WType> packBs(batch, 0);
+    avector bufB(tmpB.mSize * batch);
+    for (size_t i = 0; i < batch; i++) {
+      packBs[i] = tmpB;
+      packBs[i].assign(bufB.data() + i * tmpB.mSize);
+    }
+    kernel.mProB.packWeight(n, k, B, n, &packBs[0], UT_Threading::get());
+    for (size_t i = 1; i < batch; i++) {
+      memcpy(packBs[i].template WPtr(), packBs[0].template WPtr(), packBs[0].template WSize());
+      memcpy(packBs[i].template SPtr(), packBs[0].template SPtr(), packBs[0].CSize() * sizeof(Scale_T));
+    }
+    auto quanA = kernel.mProA.createStorage(m, k, blocksize, false);
+    std::vector As(batch);
+    utils::avector bufferA(quanA.mSize * batch);
+    for (size_t i = 0; i < batch; i++) {
+      As[i] = quanA;
+      As[i].assign(bufferA.data() + i * quanA.mSize);
+    }
+    kernel.mProA.quantize({A, k, &As[0]}, m, k, UT_Threading::get());
+    for (size_t i = 1; i < batch; i++) {
+      memcpy(As[i].template APtr(), As[0].template APtr(), As[0].template ASize());
+      memcpy(As[i].template SPtr(), As[0].template SPtr(), As[0].CSize() * sizeof(Scale_T));
+    }
+    auto psize = (size_t)m * n * k * 2;
+    auto memsize = (size_t)packBs[0].mSize + As[0].mSize + (m * n) * sizeof(float);
+    assert(m == 1);
+    parallel::Scheduler2D sch({UT_Threading::get()->num_threads(), 1, n, 1, Core_T::NTILE, 0, 0});
+
+    tm.start();
+    while (tm.stop() < timems) {
+      for (int i = 0; i < batch; i++) {
+        log.start();
+        auto cbptr = C + i * m * n;
+        auto awptr = As[i].template APtr();
+        auto asptr = As[i].template SPtr();
+        auto bwptr = packBs[i].template WPtr();
+        auto bsptr = packBs[i].template SPtr();
+        UT_Threading::get()->parallel_for([&](int idx) {
+          parallel::ThreadProblem2D thp{idx};
+          sch.getIndex(thp);
+          if (thp.valid) {
+            for (int in = 0; in < thp.size[1]; in += Core_T::NTILE) {
+              bestla_vec_dot_q4_0_q8_0<Core_T::NTILE, Scale_T>(k, blocksize, cbptr + thp.loc[1] + in, awptr, asptr,
+                                                               bwptr + (thp.loc[1] + in) * k / 2,
+                                                               bsptr + thp.loc[1] + in, n);
+            }
+          }
+        });
+        log.stop();
+        if (tm.stop() >= timems) {
+          break;
+        }
+      }
+    }
+    log.record();
+    double flops = double(psize) / log.min_val / 1e6;
+    double band = double(memsize) / log.min_val / 1e6;
+    printf("Threads %d Block %d %s %s Flops:%.3fG PerCoreFlops:%.3fG MemoryBandwidth:%.3fGB/s\n", threads, blocksize,
+           corestr, log.get_log_str(), flops, flops / threads, band);
+
+    /*avector refC(m * n);
+    avector revB(n * k);
+    kernel.mProB.unpackWeight(n, k, &packBs[0], revB.data(), n, UT_Threading::get());
+    gemmref_fp32fp32fp32(m, n, k, A, revB.data(), refC.data(), k, n, n);
+    buffer_error(refC.data(), C, m * n, 0.01f);*/
+  }
+
+  template