From 64f36603eee751bfa4cae7c56467d835a5220686 Mon Sep 17 00:00:00 2001 From: HandH1998 <1335248067@qq.com> Date: Mon, 29 Apr 2024 10:39:19 +0800 Subject: [PATCH 1/7] add w4a8 gemm --- marlin/marlin_cuda.cpp | 99 ++++ marlin/w4a8_marlin_cuda_kernel.cu | 954 ++++++++++++++++++++++++++++++ 2 files changed, 1053 insertions(+) create mode 100644 marlin/w4a8_marlin_cuda_kernel.cu diff --git a/marlin/marlin_cuda.cpp b/marlin/marlin_cuda.cpp index a304506..4ae6a51 100644 --- a/marlin/marlin_cuda.cpp +++ b/marlin/marlin_cuda.cpp @@ -1,4 +1,5 @@ /* + * Modified by HandH1998 * Copyright (C) Marlin.2024 Elias Frantar (elias.frantar@ist.ac.at) * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -38,6 +39,27 @@ int marlin_cuda( int max_par = 16 ); +int w4a8_marlin_cuda( + const void* A, + const void* B, + void* C, // int32 reduce buffer + void* D, // half + void* s1, + void* s2, + void* s3, + int prob_m, + int prob_n, + int prob_k, + void* workspace, + int groupsize = -1, + int dev = 0, + cudaStream_t stream = 0, + int thread_k = -1, + int thread_n = -1, + int sms = -1, + int max_par = 16 +); + const int ERR_PROB_SHAPE = 1; const int ERR_KERN_SHAPE = 2; @@ -88,6 +110,83 @@ void mul( } } +void w4a8_mul( + const torch::Tensor& A, + const torch::Tensor& B, + torch::Tensor& C, + torch::Tensor& D, + const torch::Tensor& s1, + const torch::Tensor& s2, + const torch::Tensor& s3, + torch::Tensor& workspace, + int thread_k = -1, + int thread_n = -1, + int sms = -1, + int max_par = 8 +) { + int prob_m = A.size(0); + int prob_n = C.size(1); + int prob_k = A.size(1); + int groupsize = (s3.numel() == 0) ? -1 : prob_k / s3.size(0); + if (groupsize != -1 && groupsize * s3.size(0) != prob_k) + AT_ERROR("k=", prob_k, " not compatible with ", s3.size(0), " groups."); + if (workspace.numel() < prob_n / 128 * max_par) + AT_ERROR("workspace must be of size at least ", prob_n / 128 * max_par, "."); + int dev = A.get_device(); + int err; + if (s3.numel() == 0) { + err = w4a8_marlin_cuda( + A.data_ptr(), + B.data_ptr(), + C.data_ptr(), + D.data_ptr(), + s1.data_ptr(), + s2.data_ptr(), + nullptr, + prob_m, prob_n, prob_k, + workspace.data_ptr(), + groupsize, + dev, + at::cuda::getCurrentCUDAStream(dev), + thread_k, + thread_n, + sms, + max_par + ); + } else { + err = w4a8_marlin_cuda( + A.data_ptr(), + B.data_ptr(), + C.data_ptr(), + D.data_ptr(), + s1.data_ptr(), + s2.data_ptr(), + s3.data_ptr(), + prob_m, prob_n, prob_k, + workspace.data_ptr(), + groupsize, + dev, + at::cuda::getCurrentCUDAStream(dev), + thread_k, + thread_n, + sms, + max_par + ); + } + + if (err == ERR_PROB_SHAPE) { + AT_ERROR( + "Problem (m=", prob_m, ", n=", prob_n, ", k=", prob_k, ")", + " not compatible with thread_k=", thread_k, ", thread_n=", thread_n, "." + ); + } else if (err == ERR_KERN_SHAPE) { + AT_ERROR( + "No kernel implementation for thread_k=", thread_k, ", thread_n=", thread_n, ", groupsize=", groupsize, "." 
+ ); + } +} + PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("mul", &mul, "Marlin FP16xINT4 matmul."); + m.def("w4a8_mul", &w4a8_mul, "Marlin INT8xINT4 matmul."); } diff --git a/marlin/w4a8_marlin_cuda_kernel.cu b/marlin/w4a8_marlin_cuda_kernel.cu new file mode 100644 index 0000000..af7eb90 --- /dev/null +++ b/marlin/w4a8_marlin_cuda_kernel.cu @@ -0,0 +1,954 @@ +/* + * Adapted from https://github.com/IST-DASLab/marlin/blob/master/marlin/marlin_cuda_kernel.cu + * Modified by HandH1998 + * Copyright (C) 2024 HandH1998 + * Copyright (C) Marlin.2024 Elias Frantar (elias.frantar@ist.ac.at) + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +#ifndef MARLIN_CUDA_KERNEL_CUH +#define MARLIN_CUDA_KERNEL_CUH + + +#include +#include +#include +#include +#include + + +constexpr int ceildiv(int a, int b) { + return (a + b - 1) / b; +} + +// Instances of `Vec` are used to organize groups of >>registers<<, as needed for instance as inputs to tensor core +// operations. Consequently, all corresponding index accesses must be compile-time constants, which is why we +// extensively use `#pragma unroll` throughout the kernel code to guarantee this. +template +struct Vec { + T elems[n]; + __device__ T& operator[](int i) { + return elems[i]; + } +}; + +using I4 = Vec; + +// Matrix fragments for tensor core instructions; their precise layout is documented here: +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-integer-type +using FragA = Vec; +using FragB = Vec; +using FragC = Vec; +using FragS_GROUP = Vec; // weight per-group quantization scales +using FragS_CHANNEL = Vec; // weight per-channel quantization scales or activaton per-token quantization scales + +// Predicated asynchronous global->shared copy; used for inputs A where we apply predication to handle batchsizes that +// are not multiples of 16. +__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, bool pred = true) { + const int BYTES = 16; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.cg.shared.global [%1], [%2], %3;\n" + "}\n" :: "r"((int) pred), "r"(smem), "l"(glob_ptr), "n"(BYTES) + ); +} + +// Asynchronous global->shared copy with a cache hint indicating that the values may be evicted immediately; used for +// quantized weights B, which are only accessed precisely once and should thus not pollute the L2 cache which we need +// for inputs A and outputs C. 
+__device__ inline void cp_async4_stream(void* smem_ptr, const void* glob_ptr) { + const int BYTES = 16; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " .reg .b64 p;\n" + " createpolicy.fractional.L2::evict_first.b64 p, 1.0;" + " cp.async.cg.shared.global.L2::cache_hint [%0], [%1], %2, p;\n" + "}\n" :: "r"(smem), "l"(glob_ptr), "n"(BYTES) + ); +} +// NOTE(HandH1998): cp.async.cg only support BYTES = 16, however, cp.async.ca can support BYTES = 4, 8, 16 +// as s1's shape is equal to prob_m, we need set s1 to float type, and cp_size = 1 float, i.e., 4 BYTES +// Asynchronous global->shared copy for activation quantizaton scales s1 +__device__ inline void cp_async1_stream(void* smem_ptr, const void* glob_ptr) { + const int BYTES = 4; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " .reg .b64 p;\n" + " createpolicy.fractional.L2::evict_first.b64 p, 1.0;" + " cp.async.ca.shared.global.L2::cache_hint [%0], [%1], %2, p;\n" + "}\n" :: "r"(smem), "l"(glob_ptr), "n"(BYTES) + ); +} + +// Async copy fence. +__device__ inline void cp_async_fence() { + asm volatile("cp.async.commit_group;\n" ::); +} + +// Wait until at most `n` async copy stages are still pending. +template +__device__ inline void cp_async_wait() { + asm volatile("cp.async.wait_group %0;\n" :: "n"(n)); +} + +// m16n8k16 tensor core mma instruction with int8 inputs and int32 output/accumulation. +__device__ inline void mma(const FragA& a_frag, const FragB& frag_b, FragC& frag_c) { + const uint32_t* a = reinterpret_cast(&a_frag); + const uint32_t* b = reinterpret_cast(&frag_b); + int* c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.satfinite.s32.s8.s8.s32 " + "{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n" + : "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(b[0]), + "r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3]) + ); +} + +// Instruction for loading a full 16x16 matrix fragment of operand A from shared memory, directly in tensor core layout. +__device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) { + uint32_t* a = reinterpret_cast(&frag_a); + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0,%1}, [%2];\n" + : "=r"(a[0]), "=r"(a[1]) : "r"(smem) + ); +} + +inline __device__ half2 float2_to_half2(float2 f) { + uint32_t res; + // NOTE(HandH1998): h0,h1 should be uint16_t, not half + uint16_t h0, h1; + asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(h0) : "f"(f.x)); + asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(h1) : "f"(f.y)); + asm volatile("mov.b32 %0, {%1, %2};\n" : "=r"(res) : "h"(h0), "h"(h1)); + return reinterpret_cast(res); +} + +inline __device__ float int32_to_float(int h) { + float res; + asm volatile("cvt.rn.f32.s32 %0, %1;\n" : "=f"(res) : "r"(h)); + return res; +} + +// Efficiently dequantize an int32 value into a full B-fragment of 4 int8 values. 
+// We mostly follow the strategy in the link below, with some small changes: +// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h +__device__ inline FragB dequant_per_channel(int q) { + static constexpr int MASK = 0xf0f0f0f0; + FragB frag_b; + frag_b[0] = (q & MASK); + return frag_b; +} + +// Lookup-table based 3-input logical operation; explicitly used for dequantization as the compiler does not seem to +// automatically recognize it in all cases. +template +__device__ inline uint32_t lop3(uint32_t a, uint32_t b, uint32_t c) { + uint32_t res; + asm volatile( + "lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(res) : "r"(a), "r"(b), "r"(c), "n"(lut) + ); + return res; +} + +// TODO(HandH1998): optimize dequant_per_group, as it doesn't have a very good performance for now +__device__ inline FragB dequant_per_group(int q, FragS_GROUP& frag_s, int i) { + // convert 4 int8 to 4 half + static constexpr uint32_t LO = 0x000f000f; + static constexpr uint32_t HI = 0x00f000f0; + static constexpr uint32_t EX = 0x64006400; + // Guarantee that the `(a & b) | c` operations are LOP3s. + uint32_t t0 = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); + uint32_t t1 = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); + // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point directly into `SUB` and `ADD`. + static constexpr uint32_t SUB = 0x64086408; + static constexpr uint32_t MUL = 0x2c002c00; + static constexpr uint32_t ADD = 0xd480d480; + *reinterpret_cast(&t0) = __hsub2( + *reinterpret_cast(&t0), + *reinterpret_cast(&SUB) + ); + *reinterpret_cast(&t1) = __hfma2( + *reinterpret_cast(&t1), + *reinterpret_cast(&MUL), *reinterpret_cast(&ADD) + ); + + uint16_t s = reinterpret_cast(&frag_s)[i]; + uint32_t double_s; + // pack 2xfp16 to half2 + asm volatile("mov.b32 %0, {%1, %2};\n" : "=r"(double_s) : "h"(s), "h"(s)); + // dequant and convert 4 half to 4 uint8 (be placed at the low 8 bits of 4 half, respectively) + static constexpr uint32_t MAGIC_NUM = 0x64806480; + *reinterpret_cast(&t0) = __hfma2( + *reinterpret_cast(&t0), + *reinterpret_cast(&double_s), *reinterpret_cast(&MAGIC_NUM) + ); + *reinterpret_cast(&t1) = __hfma2( + *reinterpret_cast(&t1), + *reinterpret_cast(&double_s), *reinterpret_cast(&MAGIC_NUM) + ); + // take out the 4 uint8 from 4 half, then convert them to 4 int8 and pack 4 int8 into 1 uint32 + FragB frag_b; + uint32_t uint8s; + static constexpr uint32_t MASK_0246 = 0x6420; + static constexpr uint32_t UINT8s_TO_INT8s_MASK = 0x80808080; + asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(uint8s) : "r"(t0), "r"(t1), "n"(MASK_0246)); + frag_b[0] = (uint8s ^ UINT8s_TO_INT8s_MASK); + return frag_b; +} + +// Wait until barrier reaches `count`, then lock for current threadblock. +__device__ inline void barrier_acquire(int* lock, int count) { + if (threadIdx.x == 0) { + int state = -1; + do + // Guarantee that subsequent writes by this threadblock will be visible globally. + asm volatile ("ld.global.acquire.gpu.b32 %0, [%1];\n" : "=r"(state) : "l"(lock)); + while (state != count); + } + __syncthreads(); +} + +// Release barrier and increment visitation count. +__device__ inline void barrier_release(int* lock, bool reset = false) { + __syncthreads(); + if (threadIdx.x == 0) { + if (reset) { + lock[0] = 0; + return; + } + int val = 1; + // Make sure that all writes since acquiring this barrier are visible globally, while releasing the barrier. 
+ asm volatile ("fence.acq_rel.gpu;\n"); + asm volatile ("red.relaxed.gpu.global.add.s32 [%0], %1;\n" : : "l"(lock), "r"(val)); + } +} + + +template < + const int threads, // number of threads in a threadblock + const int thread_m_blocks, // number of 16x16 blocks in the m dimension (batchsize) of the threadblock + const int thread_n_blocks, // same for n dimension (output) + const int thread_k_blocks, // same for k dimension (reduction) + const int stages, // number of stages for the async global->shared fetch pipeline + const int group_blocks = -1 // number of consecutive 16x16 blocks with a separate quantization scale +> +__global__ void Marlin( + const int4* __restrict__ A, // int8 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + int4* __restrict__ C, // int32 global_reduce buffer of shape (max_par*16*4)xn , as int8 tensor core's output is int32 dtype + int4* __restrict__ D, // fp16 output buffer of shape mxn + const float* __restrict__ s1, // fp32 activation per-token quantization scales of shape mx1 + const int4* __restrict__ s2, // fp32 weight per-channel quantization scales of shape 1xn + const int4* __restrict__ s3, // fp16 weight per-group quantization scales of shape (k/groupsize)xn, when group_blocks=-1, it should be nullptr + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int* locks // extra global storage for barrier synchronization +) { + // Each threadblock processes one "stripe" of the B matrix with (roughly) the same size, which might involve multiple + // column "slices" (of width 16 * `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM example: + // 0 1 3 + // 0 2 3 + // 1 2 4 + // While this kind of partitioning makes things somewhat more complicated, it ensures good utilization of all SMs + // for many kinds of shape and GPU configurations, while requiring as few slow global cross-threadblock reductions as + // possible. + + // For larger GEMMs we run multiple batchsize 64 versions in parallel for a better partitioning with less reductions + int parallel = 1; + if (prob_m > 16 * thread_m_blocks) { + parallel = prob_m / (16 * thread_m_blocks); + prob_m = 16 * thread_m_blocks; + } + + int k_tiles = prob_k / 16 / thread_k_blocks; + int n_tiles = prob_n / 16 / thread_n_blocks; + int iters = ceildiv(k_tiles * n_tiles * parallel, gridDim.x); + + // Ensure that the number of tiles in each stripe is a multiple of the groupsize; this avoids an annoying special case + // where a stripe starts in the middle of group. 
+ if constexpr (group_blocks != -1) + iters = (group_blocks / thread_k_blocks) * ceildiv(iters, (group_blocks / thread_k_blocks)); + + int slice_row = (iters * blockIdx.x) % k_tiles; + int slice_col_par = (iters * blockIdx.x) / k_tiles; + int slice_col = slice_col_par; + int slice_iters; // number of threadblock tiles in the current slice + int slice_count = 0; // total number of active threadblocks in the current slice + int slice_idx; // index of threadblock in current slice; numbered bottom to top + + // We can easily implement parallel problem execution by just remapping indices and advancing global pointers + if (slice_col_par >= n_tiles) { + A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_k / 16; + C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 4; + D += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8; + s1 += (slice_col_par / n_tiles) * 16 * thread_m_blocks; + locks += (slice_col_par / n_tiles) * n_tiles; + slice_col = slice_col_par % n_tiles; + } + + // Compute all information about the current slice which is required for synchronization. + auto init_slice = [&] () { + slice_iters = iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); + if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) + slice_iters = 0; + if (slice_iters == 0) + return; + if (slice_row + slice_iters > k_tiles) + slice_iters = k_tiles - slice_row; + slice_count = 1; + slice_idx = 0; + int col_first = iters * ceildiv(k_tiles * slice_col_par, iters); + if (col_first <= k_tiles * (slice_col_par + 1)) { + int col_off = col_first - k_tiles * slice_col_par; + slice_count = ceildiv(k_tiles - col_off, iters); + if (col_off > 0) + slice_count++; + int delta_first = iters * blockIdx.x - col_first; + if (delta_first < 0 || (col_off == 0 && delta_first == 0)) + slice_idx = slice_count - 1; + else { + slice_idx = slice_count - 1 - delta_first / iters; + if (col_off > 0) + slice_idx--; + } + } + if (slice_col == n_tiles) { + A += 16 * thread_m_blocks * prob_k / 16; + C += 16 * thread_m_blocks * prob_n / 4; + D += 16 * thread_m_blocks * prob_n / 8; + s1 += 16 * thread_m_blocks; + locks += n_tiles; + slice_col = 0; + } + }; + init_slice(); + + int a_gl_stride = prob_k / 16; // stride of the A matrix in global memory + // We typically use `constexpr` to indicate that this value is a compile-time constant + constexpr int a_sh_stride = 16 * thread_k_blocks / 16; // stride of an A matrix tile in shared memory + constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 16; // delta between subsequent A tiles in global memory + int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o); // between subsequent accesses within a tile + constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o); // between shared memory writes + constexpr int a_sh_rd_delta_o = 1 * ((threads / 32) / (thread_n_blocks / 4)); // between shared memory tile reads + constexpr int a_sh_rd_delta_i = a_sh_stride * 16; // within a shared memory tile + constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks); // overall size of a tile + constexpr int a_sh_wr_iters = ceildiv(a_sh_stage, a_sh_wr_delta); // number of shared write iterations for a tile + + int b_gl_stride = 16 * prob_n / 32; + constexpr int b_sh_stride = 32 * thread_n_blocks / 4; + int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; + int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride); + constexpr int b_sh_wr_delta = threads; + constexpr int b_sh_rd_delta = threads; + constexpr int b_sh_stage = b_sh_stride * 
thread_k_blocks; + constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; + + constexpr int s1_sh_stride = 16 * thread_m_blocks; + + constexpr int s2_sh_stride = 16 * thread_n_blocks / 4; + + int s3_gl_stride = prob_n / 8; + constexpr int s3_sh_stride = 16 * thread_n_blocks / 8; + constexpr int s3_sh_stage = s3_sh_stride; + int s3_gl_rd_delta = s3_gl_stride; + + // Global A read index of current thread. + int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o); + a_gl_rd += a_gl_rd_delta_o * slice_row; + // Shared write index of current thread. + int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o); + // Shared read index. + // NOTE(HandH1998): int8 input a only need 16 threads to load 16x16 matrix + int a_sh_rd = a_sh_stride * ((threadIdx.x % 32) % 16); + a_sh_rd += 1 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); + + int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride) + (threadIdx.x % b_sh_stride); + b_gl_rd += b_sh_stride * slice_col; + b_gl_rd += b_gl_rd_delta_o * slice_row; + int b_sh_wr = threadIdx.x; + int b_sh_rd = threadIdx.x; + + int s1_gl_rd = threadIdx.x; + // NOTE(HandH1998): activation scale s1 need shuffle to [0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15] + // for example, 0, 8 row scales serve for thread 0, 1, 2, 3. For more details, refer to mma operand A layout + // as s1's size is not fixed, we can not shuffle before inference + // we shuffle it when fetching s1 from global memory to shared memory, that's why s1_sh_wr is like this + int s1_sh_wr = (threadIdx.x / 16) * 16 + (threadIdx.x % 8) * 2 + (threadIdx.x % 16) / 8; + int s1_sh_rd = (threadIdx.x % 32) / 4; + bool s1_sh_wr_pred = threadIdx.x < prob_m; + + int s2_gl_rd = s2_sh_stride * slice_col + threadIdx.x; + int s2_sh_wr = threadIdx.x; + int s2_sh_rd = 16 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + 2 * ((threadIdx.x % 32) % 4); + bool s2_sh_wr_pred = threadIdx.x < s2_sh_stride; + + int s3_gl_rd, s3_sh_wr, s3_sh_rd; + bool s3_sh_wr_pred; + if constexpr (group_blocks != -1) { + s3_gl_rd = s3_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + s3_sh_stride * slice_col + threadIdx.x; + s3_sh_wr = threadIdx.x; + // NOTE(HandH1998): s3_sh_rd is related to mma output C + s3_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 4; + s3_sh_wr_pred = threadIdx.x < s3_sh_stride; + } + + // Precompute which thread should not read memory in which iterations; this is needed if there are more threads than + // required for a certain tilesize or when the batchsize is not a multiple of 16. + bool a_sh_wr_pred[a_sh_wr_iters]; + #pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) + a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; + + // To ensure that writing and reading A tiles to/from shared memory, the latter in fragment format, is fully bank + // conflict free, we need to use a rather fancy XOR-based layout. The key here is that neither reads nor writes of + // the 16-byte `int4` blocks of 8 consecutive threads involve the same shared memory banks. Further, it seems (based + // on NSight-Compute) that each warp must also write a consecutive memory segment? 
+ auto transform_a = [&] (int i) { + int row = i / a_gl_rd_delta_o; + return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row; + }; + // Since the computation of this remapping is non-trivial and, due to our main loop unrolls, all shared memory + // accesses are static, we simply precompute both transformed reads and writes. + int a_sh_wr_trans[a_sh_wr_iters]; + #pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) + a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); + int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks]; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) { + #pragma unroll + for (int j = 0; j < thread_m_blocks; j++) + a_sh_rd_trans[i][j] = transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); + } + + // Since B-accesses have non-constant stride they have to be computed at runtime; we break dependicies between + // subsequent accesses with a tile by maintining multiple pointers (we have enough registers), a tiny optimization. + const int4* B_ptr[b_sh_wr_iters]; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; + + extern __shared__ int4 sh[]; + // Shared memory storage for global fetch pipelines. + // NOTE(HandH1998): stages need >= 4, otherwise, sh_s1 = sh + max(stages * a_sh_stage + stages * b_sh_stage, 4 * stages * a_sh_stage) + int4* sh_a = sh; + int4* sh_b = sh_a + (stages * a_sh_stage); + int4* sh_s1 = sh_b + (stages * b_sh_stage); + int4* sh_s2 = sh_s1 + s1_sh_stride; + int4* sh_s3 = sh_s2 + s2_sh_stride; + + // Register storage for double buffer of shared memory reads. + FragA frag_a[2][thread_m_blocks]; + I4 frag_b_quant[2]; + FragC frag_c[thread_m_blocks][4][2]; + FragS_GROUP frag_s3[2][4]; + FragS_CHANNEL frag_s1[thread_m_blocks]; + FragS_CHANNEL frag_s2[2][4]; + + // Zero accumulators. + auto zero_accums = [&] () { + #pragma unroll + for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) + reinterpret_cast(frag_c)[i] = 0; + }; + + // Asynchronously fetch the next A, B and s tile from global to the next shared memory pipeline location. + auto fetch_to_shared = [&] (int pipe, int a_off, bool pred = true) { + if (pred) { + int4* sh_a_stage = sh_a + a_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) { + cp_async4_pred( + &sh_a_stage[a_sh_wr_trans[i]], + &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off], + a_sh_wr_pred[i] + ); + } + int4* sh_b_stage = sh_b + b_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) { + cp_async4_stream(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr], B_ptr[i]); + B_ptr[i] += b_gl_rd_delta_o; + } + // Only fetch scales if this tile starts a new group + if constexpr (group_blocks != -1) { + if (pipe % (group_blocks / thread_k_blocks) == 0) { + int4* sh_s3_stage = sh_s3 + s3_sh_stage * pipe; + if (s3_sh_wr_pred) + cp_async4_stream(&sh_s3_stage[s3_sh_wr], &s3[s3_gl_rd]); + s3_gl_rd += s3_gl_rd_delta; + } + } + } + // Insert a fence even when we are winding down the pipeline to ensure that waiting is also correct at this point. + cp_async_fence(); + }; + + // Wait until the next thread tile has been loaded to shared memory. + auto wait_for_stage = [&] () { + // We only have `stages - 2` active fetches since we are double buffering and can only issue the next fetch when + // it is guaranteed that the previous shared memory load is fully complete (as it may otherwise be overwritten). 
+ cp_async_wait(); + __syncthreads(); + }; + + // Load the next sub-tile from the current location in the shared memory pipe into the current register buffer. + auto fetch_to_registers = [&] (int k, int pipe) { + // It may seem inefficient that we reload the groups for every sub-tile; however, this does not seem to be a + // significant bottleneck, while some theoretically better attempts have lead to bad instruction ordering by the + // compiler and correspondingly a noticable drop in performance. + if constexpr (group_blocks != -1) { + int4* sh_s3_stage = sh_s3 + s3_sh_stage * ((group_blocks / thread_k_blocks) * (pipe / (group_blocks / thread_k_blocks))); + reinterpret_cast(&frag_s3[k % 2])[0] = sh_s3_stage[s3_sh_rd]; + } + int4* sh_a_stage = sh_a + a_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) + ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); + int4* sh_b_stage = sh_b + b_sh_stage * pipe; + frag_b_quant[k % 2] = *reinterpret_cast(&sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd]); + }; + + // Execute the actual tensor core matmul of a sub-tile. + auto matmul = [&] (int k) { + // We have the m dimension as the inner loop in order to encourage overlapping dequantization and matmul operations. + #pragma unroll + for (int j = 0; j < 4; j++) { + int b_quant = frag_b_quant[k % 2][j]; + // int b_quant_shift = b_quant << 4; + FragB frag_b0, frag_b1; + // If there are no groups, we can just scale the final output once and can avoid doing so for each weight. + if constexpr (group_blocks != -1) { + int b_quant_shift = b_quant >> 8; + frag_b0 = dequant_per_group(b_quant, frag_s3[k % 2][j], 0); + frag_b1 = dequant_per_group(b_quant_shift, frag_s3[k % 2][j], 1); + } else { + int b_quant_shift = b_quant << 4; + frag_b0 = dequant_per_channel(b_quant); + frag_b1 = dequant_per_channel(b_quant_shift); + } + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]); + mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]); + } + } + }; + + // Since we slice across the k dimension of a tile in order to increase the number of warps while keeping the n + // dimension of a tile reasonable, we have multiple warps that accumulate their partial sums of the same output + // location; which we have to reduce over in the end. We do in shared memory. + auto thread_block_reduce = [&] () { + constexpr int red_off = threads / b_sh_stride / 2; + if (red_off >= 1) { + int red_idx = threadIdx.x / b_sh_stride; + constexpr int red_sh_stride = b_sh_stride * 4 * 2; + constexpr int red_sh_delta = b_sh_stride; + int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride) + (threadIdx.x % b_sh_stride); + + // Parallel logarithmic shared memory reduction. We make sure to avoid any unnecessary read or write iterations, + // e.g., for two warps we write only once by warp 1 and read only once by warp 0. 
+ + #pragma unroll + for (int m_block = 0; m_block < thread_m_blocks; m_block++) { + #pragma unroll + for (int i = red_off; i > 0; i /= 2) { + if (i <= red_idx && red_idx < 2 * i) { + #pragma unroll + for (int j = 0; j < 4 * 2; j++) { + int red_sh_wr = red_sh_delta * j + (red_sh_rd - red_sh_stride * i); + if (i < red_off) { + int* c_rd = reinterpret_cast(&sh[red_sh_delta * j + red_sh_rd]); + int* c_wr = reinterpret_cast(&sh[red_sh_wr]); + #pragma unroll + for (int k = 0; k < 4; k++) + reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += c_rd[k] + c_wr[k]; + } + sh[red_sh_wr] = reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; + } + } + __syncthreads(); + } + if (red_idx == 0) { + #pragma unroll + for (int i = 0; i < 4 * 2; i++) { + int* c_rd = reinterpret_cast(&sh[red_sh_delta * i + red_sh_rd]); + #pragma unroll + for (int j = 0; j < 4; j++) + reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += c_rd[j]; + } + } + __syncthreads(); + } + } + }; + + // Since multiple threadblocks may process parts of the same column slice, we finally have to globally reduce over + // the results. As the striped partioning minimizes the number of such reductions and our outputs are usually rather + // small, we perform this reduction serially in L2 cache. + auto global_reduce = [&] (bool first = false, bool last = false) { + // We are very careful here to reduce directly in the output buffer to maximize L2 cache utilization in this step. + // To do this, we write out results in FP16 (but still reduce with FP32 compute). + constexpr int active_threads = 32 * thread_n_blocks / 4; + if (threadIdx.x < active_threads) { + int c_gl_stride = prob_n / 4; + int c_gl_wr_delta_o = 8 * c_gl_stride; + int c_gl_wr_delta_i = 8 * (active_threads / 32); + int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + 8 * (threadIdx.x / 32) + (threadIdx.x % 4) * 2; + c_gl_wr += (4 * thread_n_blocks) * slice_col; + constexpr int c_sh_wr_delta = active_threads * 2; + int c_sh_wr = 2 * threadIdx.x; + + int row = (threadIdx.x % 32) / 4; + + if (!first) { + // Interestingly, doing direct global accesses here really seems to mess up the compiler and lead to slowdowns, + // hence we also use async-copies even though these fetches are not actually asynchronous. 
+ #pragma unroll + for (int i = 0; i < thread_m_blocks * 4; i++) { + cp_async4_pred( + &sh[c_sh_wr + c_sh_wr_delta * i], + &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)], + i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m + ); + cp_async4_pred( + &sh[c_sh_wr + c_sh_wr_delta * i + 1], + &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2) + 1], + i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m + ); + } + cp_async_fence(); + cp_async_wait<0>(); + } + + #pragma unroll + for (int i = 0; i < thread_m_blocks * 4; i++) { + if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) { + if (!first) { + int4 d_red1 = sh[c_sh_wr + i * c_sh_wr_delta]; + int4 d_red2 = sh[c_sh_wr + i * c_sh_wr_delta + 1]; + #pragma unroll + for (int j = 0; j < 4; j++) { + reinterpret_cast(&frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] += + reinterpret_cast(&d_red1)[j]; + } + #pragma unroll + for (int j = 0; j < 4; j++) { + reinterpret_cast(&frag_c)[4 * 2 * 4 * (i / 4) + 4 * (j + 4) + (i % 4)] += + reinterpret_cast(&d_red2)[j]; + } + } + if (!last) { + int4 d1, d2; + #pragma unroll + for (int j = 0; j < 4; j++) { + reinterpret_cast(&d1)[j] = + reinterpret_cast(&frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]; + } + #pragma unroll + for (int j = 0; j < 4; j++) { + reinterpret_cast(&d2)[j] = + reinterpret_cast(&frag_c)[4 * 2 * 4 * (i / 4) + 4 * (j + 4) + (i % 4)]; + } + C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] = d1; + C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2) + 1] = d2; + } + } + } + } + }; + + // Write out the reduce final result in the correct layout. We only actually reshuffle matrix fragments in this step, + // the reduction above is performed in fragment layout. + auto write_result = [&] () { + int d_gl_stride = prob_n / 8; + constexpr int d_sh_stride = 2 * thread_n_blocks + 1; + int d_gl_wr_delta = d_gl_stride * (threads / (2 * thread_n_blocks)); + constexpr int d_sh_rd_delta = d_sh_stride * (threads / (2 * thread_n_blocks)); + + int d_gl_wr = d_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + (threadIdx.x % (2 * thread_n_blocks)); + d_gl_wr += (2 * thread_n_blocks) * slice_col; + int d_sh_wr = (4 * d_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4; + d_sh_wr += 32 * (threadIdx.x / 32); + int d_sh_rd = d_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + (threadIdx.x % (2 * thread_n_blocks)); + + int d_gl_wr_end = d_gl_stride * prob_m; + + // We first reorder in shared memory to guarantee the most efficient final global write patterns + auto write = [&] (int idx, int c0, int c1, float a_s, FragS_CHANNEL& w_s) { + float2 deq_res; + deq_res.x = int32_to_float(c0) * w_s[0] * a_s; + deq_res.y = int32_to_float(c1) * w_s[1] * a_s; + ((half2*) sh)[idx] = float2_to_half2(deq_res); + }; + + if (threadIdx.x / 32 < thread_n_blocks / 4) { + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + #pragma unroll + for (int j = 0; j < 4; j++) { + int wr = d_sh_wr + 8 * j; + write(wr + (4 * d_sh_stride) * 0 + 0, frag_c[i][j][0][0], frag_c[i][j][0][1], frag_s1[i][0], frag_s2[j / 2][2 * (j % 2) + 0]); + write(wr + (4 * d_sh_stride) * 8 + 0, frag_c[i][j][0][2], frag_c[i][j][0][3], frag_s1[i][1], frag_s2[j / 2][2 * (j % 2) + 0]); + write(wr + (4 * d_sh_stride) * 0 + 4, frag_c[i][j][1][0], frag_c[i][j][1][1], frag_s1[i][0], frag_s2[j / 2][2 * (j % 2) + 1]); + write(wr + (4 * d_sh_stride) * 8 + 4, frag_c[i][j][1][2], frag_c[i][j][1][3], frag_s1[i][1], frag_s2[j / 2][2 * (j % 2) + 1]); + } + d_sh_wr += 
16 * (4 * d_sh_stride); + } + } + __syncthreads(); + + #pragma unroll + for (int i = 0; i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); i++) { + if (d_gl_wr < d_gl_wr_end) { + D[d_gl_wr] = sh[d_sh_rd]; + d_gl_wr += d_gl_wr_delta; + d_sh_rd += d_sh_rd_delta; + } + } + }; + + // Start global fetch and register load pipelines. + auto start_pipes = [&] () { + #pragma unroll + for (int i = 0; i < stages - 1; i++) + fetch_to_shared(i, i, i < slice_iters); + zero_accums(); + wait_for_stage(); + fetch_to_registers(0, 0); + a_gl_rd += a_gl_rd_delta_o * (stages - 1); + }; + start_pipes(); + + // Main loop. + while (slice_iters) { + // We unroll over both the global fetch and the register load pipeline to ensure all shared memory accesses are + // static. Note that both pipelines have even length meaning that the next iteration will always start at index 0. + #pragma unroll + for (int pipe = 0; pipe < stages;) { + #pragma unroll + for (int k = 0; k < b_sh_wr_iters; k++) { + fetch_to_registers(k + 1, pipe % stages); + if (k == b_sh_wr_iters - 2) { + fetch_to_shared((pipe + stages - 1) % stages, pipe, slice_iters >= stages); + pipe++; + wait_for_stage(); + } + matmul(k); + } + slice_iters--; + if (slice_iters == 0) + break; + } + a_gl_rd += a_gl_rd_delta_o * stages; + + // Process results and, if necessary, proceed to the next column slice. While this pattern may not be the most + // readable, other ways of writing the loop seemed to noticeably worse performance after compliation. + if (slice_iters == 0) { + cp_async_wait<0>(); + bool last = slice_idx == slice_count - 1; + // For per-column scales, we only fetch them here in the final step before write-out + if (last) { + if (s1_sh_wr_pred) { + cp_async1_stream(&sh_s1[s1_sh_wr], &s1[s1_gl_rd]); + } + if (s2_sh_wr_pred) { + cp_async4_stream(&sh_s2[s2_sh_wr], &s2[s2_gl_rd]); + } + cp_async_fence(); + } + thread_block_reduce(); + if (last) { + cp_async_wait<0>(); + __syncthreads(); + if (threadIdx.x / 32 < thread_n_blocks / 4) { + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + frag_s1[i][0] = *reinterpret_cast(&sh_s1[16 * i + 2 * s1_sh_rd]); + frag_s1[i][1] = *reinterpret_cast(&sh_s1[16 * i + 2 * s1_sh_rd + 1]); + } + reinterpret_cast(&frag_s2)[0] = sh_s2[s2_sh_rd + 0]; + reinterpret_cast(&frag_s2)[1] = sh_s2[s2_sh_rd + 1]; + reinterpret_cast(&frag_s2)[2] = sh_s2[s2_sh_rd + 8]; + reinterpret_cast(&frag_s2)[3] = sh_s2[s2_sh_rd + 9]; + } + } + if (slice_count > 1) { // only globally reduce if there is more than one block in a slice + barrier_acquire(&locks[slice_col], slice_idx); + global_reduce(slice_idx == 0, last); + barrier_release(&locks[slice_col], last); + } + if (last) // only the last block in a slice actually writes the result + write_result(); + slice_row = 0; + slice_col_par++; + slice_col++; + init_slice(); + if (slice_iters) { + a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o); + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; + if (slice_col == 0) { + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] -= b_gl_stride; + } + s3_gl_rd = s3_sh_stride * slice_col + threadIdx.x; + s2_gl_rd = s2_sh_stride * slice_col + threadIdx.x; + start_pipes(); + } + } + } +} + + +// 8 warps are a good choice since every SM has 4 schedulers and having more than 1 warp per schedule allows some more +// latency hiding. 
At the same time, we want relatively few warps to have many registers per warp and small tiles. +const int THREADS = 256; +const int STAGES = 4; // 4 pipeline stages fit into shared memory +const int SHARED_MEM = 96 * 1024; // max shared memory on compute capability 8.6 (< 8.0) + +#define CALL_IF(THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, GROUP_BLOCKS) \ + else if ( \ + thread_m_blocks == THREAD_M_BLOCKS && thread_n_blocks == THREAD_N_BLOCKS && thread_k_blocks == THREAD_K_BLOCKS && \ + group_blocks == GROUP_BLOCKS \ + ) { \ + cudaFuncSetAttribute( \ + Marlin, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, \ + SHARED_MEM \ + ); \ + Marlin< \ + THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, STAGES, GROUP_BLOCKS \ + ><<>>( \ + A_ptr, B_ptr, C_ptr, D_ptr, s1_ptr, s2_ptr, s3_ptr, \ + prob_m, prob_n, prob_k, \ + locks \ + ); \ + } + +const int ERR_PROB_SHAPE = 1; +const int ERR_KERN_SHAPE = 2; + +int w4a8_marlin_cuda( + const void* A, + const void* B, + void* C, // int32 reduce buffer + void* D, // half + void* s1, + void* s2, + void* s3, + int prob_m, + int prob_n, + int prob_k, + void* workspace, + int groupsize = -1, + int dev = 0, + cudaStream_t stream = 0, + int thread_k = -1, + int thread_n = -1, + int sms = -1, + int max_par = 16 +) { + int tot_m = prob_m; + int tot_m_blocks = ceildiv(tot_m, 16); + int pad = 16 * tot_m_blocks - tot_m; + + if (sms == -1) + cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev); + if (thread_k == -1 || thread_n == -1) { + if (prob_m <= 16) { + // For small batchizes, better partioning is slightly more important than better compute utilization + thread_k = 128; + thread_n = 128; + } else { + thread_k = 64; + thread_n = 256; + } + } + + int thread_k_blocks = thread_k / 16; + int thread_n_blocks = thread_n / 16; + int group_blocks = (groupsize == -1) ? -1 : groupsize / 16; + int blocks = sms; + + if (groupsize == -1) + assert(s3 == nullptr); + if (prob_n % thread_n != 0 || prob_k % thread_k != 0 || (group_blocks != -1 && prob_k % group_blocks != 0)) + return ERR_PROB_SHAPE; + if (prob_m == 0 || prob_n == 0 || prob_k == 0) + return 0; + + const int4* A_ptr = (const int4*) A; + const int4* B_ptr = (const int4*) B; + int4* C_ptr = (int4*) C; + int4* D_ptr = (int4*) D; + const float* s1_ptr = (const float*) s1; + const int4* s2_ptr = (const int4*) s2; + const int4* s3_ptr = (const int4*) s3; + + int* locks = (int*) workspace; + + int ret = 0; + for (int i = 0; i < tot_m_blocks; i += 4) { + int thread_m_blocks = tot_m_blocks - i; + prob_m = tot_m - 16 * i; + int par = 1; + if (thread_m_blocks > 4) { + // Note that parallel > 1 currently only works for inputs without any padding + par = (16 * thread_m_blocks - pad) / 64; + if (par > max_par) + par = max_par; + prob_m = 64 * par; + i += 4 * (par - 1); + thread_m_blocks = 4; + } + + // For compilation speed, we only define the kernel configurations that have seemed useful (in terms of performance) + // in our testing, however many more are, in principle, possible. 
+ if (false) {} + CALL_IF(1, 8, 8, -1) + CALL_IF(1, 8, 8, 8) + CALL_IF(1, 16, 4, -1) + CALL_IF(1, 16, 4, 8) + CALL_IF(2, 16, 4, -1) + CALL_IF(2, 16, 4, 8) + CALL_IF(3, 16, 4, -1) + CALL_IF(3, 16, 4, 8) + CALL_IF(4, 16, 4, -1) + CALL_IF(4, 16, 4, 8) + else + ret = ERR_KERN_SHAPE; + + A_ptr += 16 * thread_m_blocks * (prob_k / 16) * par; + D_ptr += 16 * thread_m_blocks * (prob_n / 8) * par; + s1_ptr += 16 * thread_m_blocks * par; + } + + return ret; +} + + +#endif From 7101bce4f2fcf37907946fab9e7b849ade2dfb53 Mon Sep 17 00:00:00 2001 From: HandH1998 <1335248067@qq.com> Date: Mon, 29 Apr 2024 10:40:22 +0800 Subject: [PATCH 2/7] add W4A8Layer --- marlin/__init__.py | 183 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 183 insertions(+) diff --git a/marlin/__init__.py b/marlin/__init__.py index b5b7758..f0e2724 100644 --- a/marlin/__init__.py +++ b/marlin/__init__.py @@ -1,3 +1,4 @@ +# Modified by HandH1998 # Copyright (C) Marlin.2024 Elias Frantar (elias.frantar@ist.ac.at) # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -34,6 +35,23 @@ def mul(A, B, C, s, workspace, thread_k=-1, thread_n=-1, sms=-1, max_par=16): """ marlin_cuda.mul(A, B, C, s, workspace, thread_k, thread_n, sms, max_par) +def w4a8_mul(A, B, C, D, s1, s2, s3, workspace, thread_k=-1, thread_n=-1, sms=-1, max_par=16): + """INT8xINT4 multiply based on Marlin kernel; can be used within `torch.compile`. + @A: `torch.int8` input matrix of shape `(m, k)` in standard row-major layout + @B: `torch.int` weight matrix of original shape `(k, n)` in the specified format; see `Layer.pack()` + @C: `torch.int` reduce buffer of shape `(max_par * 64, n)` in standard row-major layout + @D: `torch.half` out matrix of shape `(m, n)` in standard row-major layout + @s1: `torch.float` activation per-token quantization scales of shape `(m, 1)` + @s2: `torch.float` weight per-channel quantization scales of shape `(1, n)` + @s3: `torch.half` weight per-group quantization scales of shape `(m / groupsize, n)`, it should be empty when group_size != -1 + @workspace: `torch.int` tensor with at least `n / 128 * max_par` entries that are all zero + @thread_k: `k` size of a thread_tile in `B` (can usually be left as auto -1) + @thread_n: `n` size of a thread_tile in `B` (can usually be left as auto -1) + @sms: number of SMs to use for the kernel (can usually be left as auto -1) + @max_par: maximum number of batch 64 problems to solve in parallel for large input sizes + """ + marlin_cuda.w4a8_mul(A, B, C, D, s1, s2, s3, workspace, thread_k, thread_n, sms, max_par) + # Precompute permutations for Marlin weight and scale shuffling @@ -139,6 +157,171 @@ def pack(self, linear, scales): self.B[:, :] = q.to(self.B.device) self.s[:, :] = s.to(self.s.device) +class W4A8Layer(nn.Module): + """PyTorch compatible Marlin layer; 4-bit (symmetric grouped) linear layer without bias.""" + + def __init__(self, infeatures, outfeatures, groupsize=-1): + """Create an empty Marlin layer. 
+ @infeatures: number of input features (must be divisible by 128) + @outfeatures: number of output features (must be divisible by 256) + @groupsize: quantization groupsize (must be -1 or 128) + """ + super().__init__() + if groupsize not in [-1, 128]: + raise ValueError('Only groupsize -1 and 128 are supported.') + if infeatures % 128 != 0 or outfeatures % 256 != 0: + raise ValueError('`infeatures` must be divisible by 128 and `outfeatures` by 256.') + if groupsize == -1: + groupsize = infeatures + if infeatures % groupsize != 0: + raise ValueError('`infeatures` must be divisible by `groupsize`.') + self.k = infeatures + self.n = outfeatures + self.groupsize = groupsize + self.max_par = 16 + self.register_buffer('B', torch.empty((self.k // 16, self.n * 16 // 8), dtype=torch.int)) + self.register_buffer( + "s_channel", + torch.empty( + (1, self.n), + dtype=torch.float, + ), + ) + # if self.groupsize != self.k: + self.register_buffer( + "s_group", + torch.empty( + (self.k // self.groupsize, self.n), dtype=torch.half + ), + ) + self.register_buffer( + "reduce_buffer", + torch.zeros((self.max_par * 16 * 4, self.n), dtype=torch.int), + persistent=False, + ) + # 128 is currently the minimum `tile_n`, hence it gives the maximum workspace size; 16 is the default `max_par` + self.register_buffer('workspace', torch.zeros(self.n // 128 * 16, dtype=torch.int), persistent=False) + self._perm, self._scale_perm, self._scale_perm_single = self._get_perms() + + # activation int8 quantization + def dynamic_quant(self, x: torch.Tensor): + quant_scale = x.abs().max(dim=-1, keepdim=True)[0].div(127.0).to(torch.float) + x = (x / quant_scale).round().clamp(-128, 127).to(torch.int8) + return x, quant_scale + + def forward(self, A): + out_shape = A.shape[:-1] + (self.n,) + A = A.reshape(-1, A.shape[-1]).half() + quant_A, s1 = self.dynamic_quant(A) + D = torch.empty(A.shape[0], self.n, dtype=A.dtype, device=A.device) + mul( + quant_A, + self.B, + self.reduce_buffer, + D, + s1, + self.s_channel, + self.s_group, + self.workspace, + max_par=self.max_par + ) + D = D.reshape(out_shape) + return D + + def _get_perms(self): + perm = [] + for i in range(32): + perm1 = [] + col = i // 4 + for block in [0, 1]: + for row in [ + 4 * (i % 4), + 4 * (i % 4) + 1, + 4 * (i % 4) + 2, + 4 * (i % 4) + 3 + ]: + perm1.append(16 * row + col + 8 * block) + for j in range(4): + perm.extend([p + 256 * j for p in perm1]) + + perm = np.array(perm) + if self.groupsize == self.k: + interleave = np.array([4, 0, 5, 1, 6, 2, 7, 3]) + else: + interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7]) + # interleave = np.array([4, 0, 5, 1, 6, 2, 7, 3]) + perm = perm.reshape((-1, 8))[:, interleave].ravel() + perm = torch.from_numpy(perm) + scale_perm = [] + for i in range(8): + scale_perm.extend([i + 8 * j for j in range(8)]) + scale_perm_single = [] + for i in range(4): + scale_perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) + return perm, scale_perm, scale_perm_single + + def pack(self, linear, scales, s_extra=None): + """Pack a fake-quantized linear layer into this actual Marlin representation. 
+ @linear: fake-quantized `torch.nn.Linear` layer to convert (must be of type `torch.half`) + @scales: corresponding quantization scales of shape `(infeatures, groups)` + @s_extra: corresponding quantization scales of shape `(1, outfeatures)` + """ + if self.groupsize != self.k: + assert s_extra is not None, "s_extra is needed" + if linear.weight.dtype != torch.half: + raise ValueError('Only `torch.half` weights are supported.') + tile = 16 + maxq = 15 + s = scales.t() + w = linear.weight.data.t() + if self.groupsize != self.k: + w = w.reshape((-1, self.groupsize, self.n)) + w = w.permute(1, 0, 2) + w = w.reshape((self.groupsize, -1)) + s = s.reshape((1, -1)) + w = torch.round(w / s).int() + # convert int8 to uint8 only for per-group quantization + if self.groupsize != self.k: + w += (maxq + 1) // 2 + w = torch.clamp(w, 0, maxq) + if self.groupsize != self.k: + s_extra = s_extra.reshape(1, -1).to(dtype=torch.float) + s = ( + s.reshape(-1, self.n) / (s_extra) + ).to(dtype=torch.half) + w = w.reshape((self.groupsize, -1, self.n)) + w = w.permute(1, 0, 2) + w = w.reshape((self.k, self.n)).contiguous() + s = s.reshape((-1, len(self._scale_perm)))[:, self._scale_perm] + s_extra = s_extra.reshape((-1, len(self._scale_perm_single)))[ + :, self._scale_perm_single + ] + s_extra = s_extra.reshape((-1, self.n)).contiguous() + else: + s = (s / 16.0).reshape((-1, len(self._scale_perm_single)))[:, self._scale_perm_single] + s = s.reshape((-1, self.n)).contiguous() + w = w.reshape((self.k // tile, tile, self.n // tile, tile)) + w = w.permute((0, 2, 1, 3)) + w = w.reshape((self.k // tile, self.n * tile)) + res = w + res = res.reshape((-1, self._perm.numel()))[:, self._perm].reshape(res.shape) + q = np.zeros((res.shape[0], res.shape[1] // 8), dtype=np.uint32) + res = res.cpu().numpy().astype(np.uint32) + if self.groupsize != self.k: + for i in range(8): + q |= res[:, i::8] << 4 * i + else: + for i in range(8): + q |= (res[:, i::8] & 0xF) << 4 * i + q = torch.from_numpy(q.astype(np.int32)).to(w.device) + self.B[:, :] = q.to(self.B.device) + if self.groupsize != self.k: + self.s_group[:, :] = s.to(self.s_group.device) + self.s_channel[:, :] = s_extra.to(self.s_channel.device) + else: + self.s_group = torch.tensor([], dtype=torch.half, device=self.s_channel.device) + self.s_channel[:, :] = s.to(self.s_channel.device) + def replace_linear(module, name_filter=lambda n: True, groupsize=-1, name=''): """Recursively replace all `torch.nn.Linear` layers by empty Marlin layers. 
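For orientation, a minimal usage sketch of the new `W4A8Layer` under the constraints documented above (`infeatures` divisible by 128, `outfeatures` by 256, groupsize -1 or 128). Illustration only: `per_group_scales` and `per_channel_scales` are placeholders for the scales produced by whichever quantizer prepared the fake-quantized weights (see `gen_quant4()` in test_w4a8.py, added in the next patch, for one way to derive them).

    import torch
    import torch.nn as nn
    import marlin

    k, n, groupsize = 1024, 512, 128
    linear = nn.Linear(k, n, bias=False).half()          # fake-quantized fp16 weights, as expected by pack()
    layer = marlin.W4A8Layer(k, n, groupsize=groupsize)
    # per_group_scales: per-group weight scales, per_channel_scales: per-channel weight scales (placeholders).
    layer.pack(linear, per_group_scales, per_channel_scales)
    layer = layer.cuda()
    x = torch.randn(16, k, dtype=torch.half, device='cuda')
    y = layer(x)    # forward() quantizes activations per token to int8 and returns an fp16 (16, n) output
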
From 113f8983dda85a6ada3147fdf3a683305c198446 Mon Sep 17 00:00:00 2001 From: HandH1998 <1335248067@qq.com> Date: Mon, 29 Apr 2024 10:41:02 +0800 Subject: [PATCH 3/7] add tests for w4a8 --- test_w4a8.py | 182 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 182 insertions(+) create mode 100644 test_w4a8.py diff --git a/test_w4a8.py b/test_w4a8.py new file mode 100644 index 0000000..cbc5b00 --- /dev/null +++ b/test_w4a8.py @@ -0,0 +1,182 @@ +import unittest + +import numpy as np +import torch +import torch.nn as nn + +import marlin + + +seed = 0 +np.random.seed(seed) +torch.random.manual_seed(seed) + +DEV = torch.device('cuda:0') + + +def gen_quant4(m, n, groupsize=-1): + tile = 16 + maxq = 2 ** 4 - 1 + w = torch.randn((m, n), dtype=torch.half, device=DEV) + if groupsize != -1: + w = w.reshape((-1, groupsize, n)) + w = w.permute(1, 0, 2) + w = w.reshape((groupsize, -1)) + s = torch.max(torch.abs(w), 0, keepdim=True)[0] + s *= 2 / maxq + w = torch.round(w / s).int() + w += (maxq + 1) // 2 + w = torch.clamp(w, 0, maxq) + ref = (w - (maxq + 1) // 2).half() * s + if groupsize != -1: + def reshape(w): + w = w.reshape((groupsize, -1, n)) + w = w.permute(1, 0, 2) + w = w.reshape((m, n)).contiguous() + return w + ref = reshape(ref) + w = reshape(w) + s = s.reshape((-1, n)).contiguous() + linear = nn.Linear(m, n) + linear.weight.data = ref.t() + s_extra = ref.t().abs().max(dim=-1, keepdim=True)[0].div(127.0).to(torch.float) + s_extra = s_extra.reshape(1, n) + fake_quant_ref = (ref / s_extra).round().clamp(-128, 127).to(torch.int8) + fake_quant_ref = (fake_quant_ref * s_extra.float()).half() + # Workaround to test some special cases that are forbidden by the API + layer = marlin.W4A8Layer(256, 256, groupsize=groupsize) + if groupsize == -1: + groupsize = m + layer.k = m + layer.n = n + layer.groupsize = groupsize + layer.B = torch.empty((m // 16, n * 16 // 8), dtype=torch.int, device=DEV) + layer.s_group = torch.empty((m // groupsize, n), dtype=torch.half, device=DEV) + layer.s_channel = torch.empty((1, n), dtype=torch.float, device=DEV) + if groupsize == m: + layer.pack(linear, s.t()) + else: + layer.pack(linear, s.t(), s_extra) + q = layer.B + s2 = layer.s_channel + s3 = layer.s_group + return ref, q, fake_quant_ref, s2, s3 + + + +class Test(unittest.TestCase): + + def run_problem(self, m, n, k, thread_k, thread_n, groupsize=-1): + print('% 5d % 6d % 6d % 4d % 4d % 4d' % (m, n, k, thread_k, thread_n, groupsize)) + if k == groupsize: + groupsize = -1 + A = torch.randint(-128, 127, (m, k), dtype=torch.int8, device=DEV) + B_ref, B, fake_quant_B, s2, s3 = gen_quant4(k, n, groupsize=groupsize) + s1_ref = torch.rand((m, 1), dtype=torch.float, device=DEV) + s1 = s1_ref + max_par = 16 + C = torch.zeros((16 * 4 * max_par, n), dtype=torch.int32, device=DEV) + D = torch.zeros((m, n), dtype=torch.half, device=DEV) + if groupsize == -1: + D_ref = torch.matmul(A.half() * s1_ref.half(), B_ref) + else: + D_ref = torch.matmul(A.half() * s1_ref.half(), fake_quant_B) + workspace = torch.zeros(n // 128 * 16, device=DEV) + marlin.w4a8_mul(A, B, C, D, s1, s2, s3, workspace, thread_k, thread_n, -1, max_par=max_par) + torch.cuda.synchronize() + self.assertLess(torch.mean(torch.abs(D - D_ref)) / torch.mean(torch.abs(D_ref)), 0.003) + + def test_tiles(self): + print() + for m in [1, 2, 3, 4, 8, 12, 16, 24, 32, 48, 64, 118, 128, 152, 768, 1024]: + for thread_k, thread_n in [(64, 256), (128, 128)]: + if m > 16 and thread_k == 128: + continue + self.run_problem(m, 2 * 256, 1024, thread_k, thread_n) + + def 
test_k_stages_divisibility(self): + print() + for k in [3 * 64 + 64 * 4 * 2 + 64 * i for i in range(1, 4)]: + self.run_problem(16, 2 * 256, k, 64, 256) + + def test_very_few_stages(self): + print() + for k in [64, 128, 192]: + self.run_problem(16, 2 * 256, k, 64, 256) + + def test_llama_shapes(self): + print() + MODELS = { + ' 7B': [ + (4096, 3 * 4096), + (4096, 4096), + (4096, 2 * 10752), + (10752, 4096) + ], + '13B': [ + (5120, 3 * 5120), + (5120, 5120), + (5120, 2 * 13568), + (13568, 5120) + ], + '33B': [ + (6656, 3 * 6656), + (6656, 6656), + (6656, 2 * 17664), + (17664, 6656) + ], + '70B': [ + (8192, 3 * 8192), + (8192, 8192), + (8192, 2 * 21760), + (21760, 8192) + ] + } + for _, layers in MODELS.items(): + for layer in layers: + for thread_k, thread_n in [(128, 128)]: + for batch in [1, 16]: + self.run_problem(batch, layer[1], layer[0], thread_k, thread_n, 128) + + def test_errors(self): + print() + m, n, k = 16, 256, 64 + A = torch.randint(-128, 127, (m, k), dtype=torch.int8, device=DEV) + B_ref, B, fake_quant_B, s2, s3 = gen_quant4(k, n) + s1 = torch.rand((m, 1), dtype=torch.float, device=DEV) + s3 = torch.tensor([], dtype=torch.half, device=DEV) + max_par = 16 + C = torch.zeros((16 * 4 * max_par, n), dtype=torch.int32, device=DEV) + D = torch.zeros((m, n), dtype=torch.half, device=DEV) + workspace = torch.zeros(n // 128, device=DEV) + err = False + try: + marlin.w4a8_mul(A, B, C, D, s1, s2, s3, workspace, 128, 128, -1, max_par=max_par) + except: + err = True + self.assertTrue(err) + err = False + try: + marlin.w4a8_mul(A, B, C, D, s1, s2, s3, workspace, 256, 256, -1, max_par=max_par) + except: + err = True + self.assertTrue(err) + s1 = torch.zeros((2, n), dtype=torch.half, device=DEV) + err = False + try: + marlin.w4a8_mul(A, B, C, D, s1, s2, s3, workspace, 256, 256, -1, max_par=max_par) + except: + err = True + self.assertTrue(err) + + def test_groups(self): + print() + for m in [16]: + for groupsize in [128]: + for n, k in [(256, 512), (256, 1024), (256 * 128, 1024)]: + for thread_shape in [(128, 128), (64, 256)]: + self.run_problem(m, n, k, *thread_shape, groupsize) + + +if __name__ == '__main__': + unittest.main() \ No newline at end of file From 7b23534676b7f56b539a2c66e304e669d1d33360 Mon Sep 17 00:00:00 2001 From: HandH1998 <1335248067@qq.com> Date: Mon, 29 Apr 2024 10:41:23 +0800 Subject: [PATCH 4/7] benchmark for w4a8 gemm --- bench_w4a8.py | 148 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 148 insertions(+) create mode 100644 bench_w4a8.py diff --git a/bench_w4a8.py b/bench_w4a8.py new file mode 100644 index 0000000..29231b3 --- /dev/null +++ b/bench_w4a8.py @@ -0,0 +1,148 @@ +import sys + +import numpy as np +import torch +import marlin + +import time + +def benchmark(f, warmup=1, iter=10): + for i in range(warmup + iter): + f() + # We do not synchronize here in order to hide the kernel launch overhead during benchmarkining as this will also + # happen during realistic model inference as many launches are submitted to the kernel queue. + if i == warmup - 1: + torch.cuda.synchronize() + tick = time.time() + torch.cuda.synchronize() + res = (time.time() - tick) / iter + # Make sure there is enough to "cool down" the GPU in between benchmarks to avoid throttling for later runs when + # we execute many benchmarks consecutively + time.sleep(1.) 
+ return res + +def get_problem(m, n, k, groupsize=-1): + if groupsize == -1: + groupsize = k + dev = torch.device('cuda:0') + A = torch.randint(-128, 127, (m, k), dtype=torch.int8, device=dev) + A_ref = torch.randn((m, k), dtype=torch.half, device=dev) + B = torch.randint(low=-2**31, high=2**31, size=(k * n // 8,), device=dev) + B_ref = torch.randn((k, n), dtype=torch.half, device=dev) + max_par = 16 + C = torch.zeros((16 * 4 * max_par, n), dtype=torch.int32, device=dev) + D = torch.zeros((m, n), dtype=torch.half, device=dev) + s1 = torch.ones((m, 1), dtype=torch.float, device=dev) + s2 = torch.ones((1, n), dtype=torch.float, device=dev) + if groupsize == k: + s3 = torch.tensor([], dtype=torch.half, device=dev) + else: + s3 = torch.ones((k // groupsize, n), dtype=torch.half, device=dev) + torch.cuda.synchronize() + return A, B, C, D, A_ref, B_ref, s1, s2, s3 + +def benchmark_dense(A, B, D): + res = benchmark(lambda: torch.matmul(A, B, out=D)) + return { + 's': res, + 'TFLOP/s': 2 * A.numel() * D.shape[1] / res / 10 ** 12, + 'GB/s': (2 * A.numel() + 2 * B.numel() + 2 * D.numel()) / res / 10 ** 9 + } + +def benchmark_quant(A, B, C, D, s1, s2, s3, thread_k, thread_n, sms): + workspace = torch.zeros(D.shape[1] // 128 * 16, device=torch.device('cuda:0')) + res = benchmark(lambda: marlin.w4a8_mul(A, B, C, D, s1, s2, s3, workspace, thread_k, thread_n, sms)) + return { + 's': res, + 'TFLOP/s': 2 * A.numel() * D.shape[1] / res / 10 ** 12, + 'GB/s': (A.numel() + 4 * B.numel() + 2 * D.numel() + 4 * C.numel() + 4 * s1.numel() + 4 * s2.numel() + 2 * s3.numel()) / res / 10 ** 9 + } + +# Pass the SM count for known GPUs to avoid the kernel having to query this information (this is very minor) +gpu = torch.cuda.get_device_name(0) +if 'A100' in gpu: + SMS = 108 +elif 'A10' in gpu: + SMS = 72 +elif '3090' in gpu: + SMS = 82 +elif 'A6000' in gpu: + SMS = 84 +else: + SMS = -1 + +MODELS = { + 'ideal': [ + (4 * 256 * SMS, 256 * SMS) + ], + 'Llama7B': [ + (4096, 3 * 4096), + (4096, 4096), + (4096, 2 * 10752), + (10752, 4096) + ], + 'Llama13B': [ + (5120, 3 * 5120), + (5120, 5120), + (5120, 2 * 13568), + (13568, 5120) + ], + 'Llama33B': [ + (6656, 3 * 6656), + (6656, 6656), + (6656, 2 * 17664), + (17664, 6656) + ], + 'Llama65B': [ + (8192, 3 * 8192), + (8192, 8192), + (8192, 2 * 21760), + (21760, 8192) + ], + 'Falcon180B': [ + # Note that parallel attention and FC allows layer fusions + (14848, 14848 * 5 + 1024), + (14848 * 5, 14848) + ] +} + +# Set to true in order to run a more complete benchmark sweep; the default is reproduce README experiments +ALL = False + +for groupsize in [-1, 128] if ALL else [128]: + print('groupsize=%d' % groupsize) + print() + for model, layers in MODELS.items(): + print(model) + if ALL: + batchsizes = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 12288] + else: + batchsizes = [1, 2, 4, 8, 16, 32, 64, 128] + for batch in batchsizes: + if not ALL and model != 'ideal' and batch != 16: + continue + tot_q = {'s': 0, 'TFLOP/s': 0, 'GB/s': 0, 'speedup': 0} + for layer in layers: + A, B, C, D, A_ref, B_ref, s1, s2, s3 = get_problem(batch, layer[1], layer[0], groupsize) + res_d = benchmark_dense(A_ref, B_ref, D) + if model == 'ideal' and batch == 16: + # This is a special case constructed to be optimal for a thread-shape different than the default one + res_q = benchmark_quant(A, B, C, D, s1, s2, s3, 64, 256, SMS) + else: + res_q = benchmark_quant(A, B, C, D, s1, s2, s3, -1, -1, SMS) + res_q['speedup'] = res_d['s'] / res_q['s'] + tot_q['s'] += res_q['s'] + for k in tot_q: 
+                    if k != 's':
+                        tot_q[k] += res_q[k] * res_q['s']
+            for k in tot_q:
+                if k != 's':
+                    tot_q[k] /= tot_q['s']
+            print('batch=%04d: s=%.5f, TFLOP/s=%07.3f, GB/s=%08.3f, speedup=%.2f' % (
+                batch,
+                tot_q['s'],
+                tot_q['TFLOP/s'],
+                tot_q['GB/s'],
+                tot_q['speedup']
+            ))
+        print()

From d9bd464388b155f35de90dd28b852a1aed633be2 Mon Sep 17 00:00:00 2001
From: HandH1998 <1335248067@qq.com>
Date: Mon, 29 Apr 2024 10:41:59 +0800
Subject: [PATCH 5/7] modify setup.py

---
 setup.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/setup.py b/setup.py
index 9d870e5..6e88584 100644
--- a/setup.py
+++ b/setup.py
@@ -9,8 +9,10 @@
     description='Highly optimized FP16xINT4 CUDA matmul kernel.',
     install_requires=['numpy', 'torch'],
     packages=['marlin'],
-    ext_modules=[cpp_extension.CUDAExtension(
-        'marlin_cuda', ['marlin/marlin_cuda.cpp', 'marlin/marlin_cuda_kernel.cu']
-    )],
+    ext_modules=[
+        cpp_extension.CUDAExtension(
+            'marlin_cuda', ['marlin/marlin_cuda.cpp', 'marlin/marlin_cuda_kernel.cu', 'marlin/w4a8_marlin_cuda_kernel.cu']
+        )
+    ],
     cmdclass={'build_ext': cpp_extension.BuildExtension},
 )

From 6ce9c7c3c46dfe53496bfb179028e885b68b5abe Mon Sep 17 00:00:00 2001
From: HandH1998 <1335248067@qq.com>
Date: Tue, 7 May 2024 17:15:42 +0800
Subject: [PATCH 6/7] clean code

---
 marlin/marlin_cuda.cpp | 24 +-----------------------
 1 file changed, 1 insertion(+), 23 deletions(-)

diff --git a/marlin/marlin_cuda.cpp b/marlin/marlin_cuda.cpp
index 4ae6a51..a155f4e 100644
--- a/marlin/marlin_cuda.cpp
+++ b/marlin/marlin_cuda.cpp
@@ -133,28 +133,7 @@ void w4a8_mul(
   if (workspace.numel() < prob_n / 128 * max_par)
     AT_ERROR("workspace must be of size at least ", prob_n / 128 * max_par, ".");
   int dev = A.get_device();
-  int err;
-  if (s3.numel() == 0) {
-    err = w4a8_marlin_cuda(
-      A.data_ptr(),
-      B.data_ptr(),
-      C.data_ptr(),
-      D.data_ptr(),
-      s1.data_ptr(),
-      s2.data_ptr(),
-      nullptr,
-      prob_m, prob_n, prob_k,
-      workspace.data_ptr(),
-      groupsize,
-      dev,
-      at::cuda::getCurrentCUDAStream(dev),
-      thread_k,
-      thread_n,
-      sms,
-      max_par
-    );
-  } else {
-    err = w4a8_marlin_cuda(
+  int err = w4a8_marlin_cuda(
       A.data_ptr(),
       B.data_ptr(),
       C.data_ptr(),
@@ -172,7 +151,6 @@ void w4a8_mul(
       sms,
       max_par
   );
-  }
 
   if (err == ERR_PROB_SHAPE) {
     AT_ERROR(

From 18fbf009e02ee2c0e6c4494cd43ed8b8a7baeb1f Mon Sep 17 00:00:00 2001
From: HandH1998 <1335248067@qq.com>
Date: Thu, 20 Jun 2024 11:58:13 +0800
Subject: [PATCH 7/7] fix issues

---
 marlin/w4a8_marlin_cuda_kernel.cu | 13 ++++++++++---
 1 file changed, 10 insertions(+), 3 deletions(-)

diff --git a/marlin/w4a8_marlin_cuda_kernel.cu b/marlin/w4a8_marlin_cuda_kernel.cu
index af7eb90..e602c08 100644
--- a/marlin/w4a8_marlin_cuda_kernel.cu
+++ b/marlin/w4a8_marlin_cuda_kernel.cu
@@ -607,6 +607,9 @@ __global__ void Marlin(
   // Since multiple threadblocks may process parts of the same column slice, we finally have to globally reduce over
   // the results. As the striped partioning minimizes the number of such reductions and our outputs are usually rather
   // small, we perform this reduction serially in L2 cache.
+  // global_reduce works on INT32 elements, which are the results of INT8 GEMM.
+  // This is why we need another INT32 matrix `C` to reduce instead of the
+  // original half matrix `D`.
   auto global_reduce = [&] (bool first = false, bool last = false) {
     // We are very careful here to reduce directly in the output buffer to maximize L2 cache utilization in this step.
     // To do this, we write out results in FP16 (but still reduce with FP32 compute).
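For readers skimming the patch, the comment added in the hunk above captures the key design point of the W4A8 path: the INT8 MMA pipeline produces INT32 partial sums, so the cross-threadblock reduction has to run on the INT32 buffer `C`, and only the final result is scaled down into the half-precision output `D`. The standalone reference kernel below is a minimal sketch of that idea, not the Marlin kernel's actual code path: the naive one-thread-per-output layout, the kernel name `w4a8_reference`, and the way the per-token scale `s1` and per-channel scale `s2` are folded in at the end are illustrative assumptions (B is also shown as plain INT8 rather than packed 4-bit values).

#include <cstdint>
#include <cuda_fp16.h>

// Reference semantics only: accumulate INT8 products exactly in INT32,
// keep the INT32 value in the reduce buffer C, and convert to half (D)
// only after applying the per-token (s1) and per-channel (s2) scales.
__global__ void w4a8_reference(const int8_t* A, const int8_t* B, int32_t* C, __half* D,
                               const float* s1, const float* s2, int M, int N, int K) {
  int m = blockIdx.y * blockDim.y + threadIdx.y;
  int n = blockIdx.x * blockDim.x + threadIdx.x;
  if (m >= M || n >= N) return;
  int32_t acc = 0;
  for (int k = 0; k < K; ++k)
    acc += int32_t(A[m * K + k]) * int32_t(B[k * N + n]);   // exact in INT32
  C[m * N + n] = acc;                                        // what a global reduce would combine
  D[m * N + n] = __float2half(float(acc) * s1[m] * s2[n]);   // narrow to FP16 only at the end
}

Reducing partial results in INT32 keeps the accumulation exact no matter how many threadblocks contribute to a column slice; converting to half before the reduction would discard low-order bits for large K.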
@@ -828,7 +831,7 @@ __global__ void Marlin(
 // latency hiding. At the same time, we want relatively few warps to have many registers per warp and small tiles.
 const int THREADS = 256;
 const int STAGES = 4; // 4 pipeline stages fit into shared memory
-const int SHARED_MEM = 96 * 1024; // max shared memory on compute capability 8.6 (< 8.0)
+// const int SHARED_MEM = 96 * 1024; // max shared memory on compute capability 8.6 (< 8.0)
 
 #define CALL_IF(THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, GROUP_BLOCKS) \
   else if ( \
@@ -838,11 +841,11 @@ const int SHARED_MEM = 96 * 1024; // max shared memory on compute capability 8.6
     cudaFuncSetAttribute( \
       Marlin<THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, STAGES, GROUP_BLOCKS>, \
       cudaFuncAttributeMaxDynamicSharedMemorySize, \
-      SHARED_MEM \
+      max_shared_mem \
     ); \
     Marlin< \
       THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, STAGES, GROUP_BLOCKS \
-    ><<<blocks, THREADS, SHARED_MEM, stream>>>( \
+    ><<<blocks, THREADS, max_shared_mem, stream>>>( \
       A_ptr, B_ptr, C_ptr, D_ptr, s1_ptr, s2_ptr, s3_ptr, \
       prob_m, prob_n, prob_k, \
       locks \
@@ -878,6 +881,10 @@ int w4a8_marlin_cuda(
   if (sms == -1)
     cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev);
 
+  int max_shared_mem = 0;
+  cudaDeviceGetAttribute(&max_shared_mem,
+                         cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
+
   if (thread_k == -1 || thread_n == -1) {
     if (prob_m <= 16) {
       // For small batchizes, better partioning is slightly more important than better compute utilization
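The hunks above replace the hard-coded 96 KiB `SHARED_MEM` constant with a per-device limit queried at run time, so the kernel opts in to however much dynamic shared memory the GPU actually allows (the opt-in maximum differs across compute capabilities). The snippet below is a minimal, self-contained sketch of that opt-in pattern; the toy kernel `uses_dyn_smem`, its launch shape, and the printf are made up for illustration, while the two attributes `cudaDevAttrMaxSharedMemoryPerBlockOptin` and `cudaFuncAttributeMaxDynamicSharedMemorySize` are the ones used in the patch.

#include <cstdio>
#include <cuda_runtime.h>

// Toy kernel whose only purpose is to consume dynamic shared memory.
__global__ void uses_dyn_smem(int* out) {
  extern __shared__ int buf[];
  buf[threadIdx.x] = threadIdx.x;
  __syncthreads();
  if (threadIdx.x == 0) out[0] = buf[blockDim.x - 1];
}

int main() {
  int dev = 0, max_shared_mem = 0;
  cudaDeviceGetAttribute(&max_shared_mem,
                         cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
  // Without this opt-in, a launch requesting more than the default 48 KiB
  // of dynamic shared memory fails at launch time.
  cudaFuncSetAttribute(uses_dyn_smem,
                       cudaFuncAttributeMaxDynamicSharedMemorySize,
                       max_shared_mem);
  int* out;
  cudaMalloc(&out, sizeof(int));
  uses_dyn_smem<<<1, 256, max_shared_mem>>>(out);
  cudaDeviceSynchronize();
  cudaFree(out);
  printf("opt-in dynamic shared memory limit: %d bytes\n", max_shared_mem);
  return 0;
}

Querying the limit instead of assuming 96 KiB lets the same CALL_IF launch path use the larger shared-memory carve-outs available on A100-class devices while still working on GPUs that offer less.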