Support w4a8 marlin gemm #31

Open · wants to merge 7 commits into base: master
148 changes: 148 additions & 0 deletions bench_w4a8.py
@@ -0,0 +1,148 @@
import sys

import numpy as np
import torch
import marlin

import time

def benchmark(f, warmup=1, iter=10):
for i in range(warmup + iter):
f()
        # We do not synchronize here in order to hide the kernel launch overhead during benchmarking, as this will
        # also happen during realistic model inference when many launches are submitted to the kernel queue.
if i == warmup - 1:
torch.cuda.synchronize()
tick = time.time()
torch.cuda.synchronize()
res = (time.time() - tick) / iter
    # Make sure there is enough time to "cool down" the GPU between benchmarks to avoid throttling of later runs
    # when many benchmarks are executed consecutively.
time.sleep(1.)
return res

def get_problem(m, n, k, groupsize=-1):
if groupsize == -1:
groupsize = k
dev = torch.device('cuda:0')
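    # Problem tensors (shapes follow the `w4a8_mul` docstring in marlin/__init__.py):
    #   A  - int8 per-token-quantized activations, (m, k); A_ref/B_ref - fp16 dense references for the baseline matmul
    #   B  - 4-bit weights packed eight per int32 word, (k * n // 8,)
    #   C  - int32 reduce buffer for partial results of large problems, (64 * max_par, n)
    #   D  - fp16 output, (m, n)
    #   s1 - fp32 per-token activation scales, (m, 1); s2 - fp32 per-channel weight scales, (1, n)
    #   s3 - fp16 per-group weight scales, (k // groupsize, n); empty when groupsize == -1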
A = torch.randint(-128, 127, (m, k), dtype=torch.int8, device=dev)
A_ref = torch.randn((m, k), dtype=torch.half, device=dev)
B = torch.randint(low=-2**31, high=2**31, size=(k * n // 8,), device=dev)
B_ref = torch.randn((k, n), dtype=torch.half, device=dev)
max_par = 16
C = torch.zeros((16 * 4 * max_par, n), dtype=torch.int32, device=dev)
D = torch.zeros((m, n), dtype=torch.half, device=dev)
s1 = torch.ones((m, 1), dtype=torch.float, device=dev)
s2 = torch.ones((1, n), dtype=torch.float, device=dev)
if groupsize == k:
s3 = torch.tensor([], dtype=torch.half, device=dev)
else:
s3 = torch.ones((k // groupsize, n), dtype=torch.half, device=dev)
torch.cuda.synchronize()
return A, B, C, D, A_ref, B_ref, s1, s2, s3

def benchmark_dense(A, B, D):
res = benchmark(lambda: torch.matmul(A, B, out=D))
return {
's': res,
'TFLOP/s': 2 * A.numel() * D.shape[1] / res / 10 ** 12,
'GB/s': (2 * A.numel() + 2 * B.numel() + 2 * D.numel()) / res / 10 ** 9
}

def benchmark_quant(A, B, C, D, s1, s2, s3, thread_k, thread_n, sms):
workspace = torch.zeros(D.shape[1] // 128 * 16, device=torch.device('cuda:0'))
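    # Memory traffic estimate used for the GB/s figure below: 1 byte per int8 A element, 4 bytes per int32
    # packed-B word, reduce-buffer entry and fp32 scale (s1, s2), 2 bytes per fp16 D and s3 element.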
res = benchmark(lambda: marlin.w4a8_mul(A, B, C, D, s1, s2, s3, workspace, thread_k, thread_n, sms))
return {
's': res,
'TFLOP/s': 2 * A.numel() * D.shape[1] / res / 10 ** 12,
'GB/s': (A.numel() + 4 * B.numel() + 2 * D.numel() + 4 * C.numel() + 4 * s1.numel() + 4 * s2.numel() + 2 * s3.numel()) / res / 10 ** 9
}

# Pass the SM count for known GPUs to avoid the kernel having to query this information (this is very minor)
gpu = torch.cuda.get_device_name(0)
if 'A100' in gpu:
SMS = 108
elif 'A10' in gpu:
SMS = 72
elif '3090' in gpu:
SMS = 82
elif 'A6000' in gpu:
SMS = 84
else:
SMS = -1

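# Per-model GEMM shapes given as (k, n) = (infeatures, outfeatures); factors such as 3 * 4096 or 2 * 10752
# denote projections fused into a single wider layer (see the Falcon180B note below).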
MODELS = {
'ideal': [
(4 * 256 * SMS, 256 * SMS)
],
'Llama7B': [
(4096, 3 * 4096),
(4096, 4096),
(4096, 2 * 10752),
(10752, 4096)
],
'Llama13B': [
(5120, 3 * 5120),
(5120, 5120),
(5120, 2 * 13568),
(13568, 5120)
],
'Llama33B': [
(6656, 3 * 6656),
(6656, 6656),
(6656, 2 * 17664),
(17664, 6656)
],
'Llama65B': [
(8192, 3 * 8192),
(8192, 8192),
(8192, 2 * 21760),
(21760, 8192)
],
'Falcon180B': [
        # Note that parallel attention and FC allow layer fusion
(14848, 14848 * 5 + 1024),
(14848 * 5, 14848)
]
}

# Set to True to run a more complete benchmark sweep; the default reproduces the README experiments
ALL = False

for groupsize in [-1, 128] if ALL else [128]:
print('groupsize=%d' % groupsize)
print()
for model, layers in MODELS.items():
print(model)
if ALL:
batchsizes = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 12288]
else:
batchsizes = [1, 2, 4, 8, 16, 32, 64, 128]
for batch in batchsizes:
if not ALL and model != 'ideal' and batch != 16:
continue
tot_q = {'s': 0, 'TFLOP/s': 0, 'GB/s': 0, 'speedup': 0}
for layer in layers:
A, B, C, D, A_ref, B_ref, s1, s2, s3 = get_problem(batch, layer[1], layer[0], groupsize)
res_d = benchmark_dense(A_ref, B_ref, D)
if model == 'ideal' and batch == 16:
# This is a special case constructed to be optimal for a thread-shape different than the default one
res_q = benchmark_quant(A, B, C, D, s1, s2, s3, 64, 256, SMS)
else:
res_q = benchmark_quant(A, B, C, D, s1, s2, s3, -1, -1, SMS)
res_q['speedup'] = res_d['s'] / res_q['s']
tot_q['s'] += res_q['s']
for k in tot_q:
if k != 's':
tot_q[k] += res_q[k] * res_q['s']
for k in tot_q:
if k != 's':
tot_q[k] /= tot_q['s']
print('batch=%04d: s=%.5f, TFLOP/s=%07.3f, GB/s=%08.3f, speedup=%.2f' % (
batch,
tot_q['s'],
tot_q['TFLOP/s'],
tot_q['GB/s'],
tot_q['speedup']
))
print()
183 changes: 183 additions & 0 deletions marlin/__init__.py
@@ -1,3 +1,4 @@
# Modified by HandH1998
# Copyright (C) Marlin.2024 Elias Frantar ([email protected])
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -34,6 +35,23 @@ def mul(A, B, C, s, workspace, thread_k=-1, thread_n=-1, sms=-1, max_par=16):
"""
marlin_cuda.mul(A, B, C, s, workspace, thread_k, thread_n, sms, max_par)

def w4a8_mul(A, B, C, D, s1, s2, s3, workspace, thread_k=-1, thread_n=-1, sms=-1, max_par=16):
"""INT8xINT4 multiply based on Marlin kernel; can be used within `torch.compile`.
@A: `torch.int8` input matrix of shape `(m, k)` in standard row-major layout
    @B: `torch.int` weight matrix of original shape `(k, n)` in the specified format; see `W4A8Layer.pack()`
@C: `torch.int` reduce buffer of shape `(max_par * 64, n)` in standard row-major layout
@D: `torch.half` out matrix of shape `(m, n)` in standard row-major layout
@s1: `torch.float` activation per-token quantization scales of shape `(m, 1)`
@s2: `torch.float` weight per-channel quantization scales of shape `(1, n)`
    @s3: `torch.half` weight per-group quantization scales of shape `(k / groupsize, n)`; must be an empty tensor when `groupsize == -1` (per-channel quantization only)
@workspace: `torch.int` tensor with at least `n / 128 * max_par` entries that are all zero
@thread_k: `k` size of a thread_tile in `B` (can usually be left as auto -1)
@thread_n: `n` size of a thread_tile in `B` (can usually be left as auto -1)
@sms: number of SMs to use for the kernel (can usually be left as auto -1)
@max_par: maximum number of batch 64 problems to solve in parallel for large input sizes
"""
marlin_cuda.w4a8_mul(A, B, C, D, s1, s2, s3, workspace, thread_k, thread_n, sms, max_par)
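# Illustrative usage sketch (assumes `layer` is a packed `W4A8Layer`, defined below, and `x` is a fp16
# activation matrix of shape (m, k)):
#   quant_A, s1 = layer.dynamic_quant(x)            # (m, k) int8 activations + (m, 1) fp32 scales
#   D = torch.empty((x.shape[0], layer.n), dtype=torch.half, device=x.device)
#   w4a8_mul(quant_A, layer.B, layer.reduce_buffer, D, s1, layer.s_channel, layer.s_group, layer.workspace)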


# Precompute permutations for Marlin weight and scale shuffling

@@ -139,6 +157,171 @@ def pack(self, linear, scales):
self.B[:, :] = q.to(self.B.device)
self.s[:, :] = s.to(self.s.device)

class W4A8Layer(nn.Module):
"""PyTorch compatible Marlin layer; 4-bit (symmetric grouped) linear layer without bias."""

def __init__(self, infeatures, outfeatures, groupsize=-1):
"""Create an empty Marlin layer.
@infeatures: number of input features (must be divisible by 128)
@outfeatures: number of output features (must be divisible by 256)
@groupsize: quantization groupsize (must be -1 or 128)
"""
super().__init__()
if groupsize not in [-1, 128]:
raise ValueError('Only groupsize -1 and 128 are supported.')
if infeatures % 128 != 0 or outfeatures % 256 != 0:
raise ValueError('`infeatures` must be divisible by 128 and `outfeatures` by 256.')
if groupsize == -1:
groupsize = infeatures
if infeatures % groupsize != 0:
raise ValueError('`infeatures` must be divisible by `groupsize`.')
self.k = infeatures
self.n = outfeatures
self.groupsize = groupsize
self.max_par = 16
self.register_buffer('B', torch.empty((self.k // 16, self.n * 16 // 8), dtype=torch.int))
self.register_buffer(
"s_channel",
torch.empty(
(1, self.n),
dtype=torch.float,
),
)
# if self.groupsize != self.k:
self.register_buffer(
"s_group",
torch.empty(
(self.k // self.groupsize, self.n), dtype=torch.half
),
)
self.register_buffer(
"reduce_buffer",
torch.zeros((self.max_par * 16 * 4, self.n), dtype=torch.int),
persistent=False,
)
# 128 is currently the minimum `tile_n`, hence it gives the maximum workspace size; 16 is the default `max_par`
self.register_buffer('workspace', torch.zeros(self.n // 128 * 16, dtype=torch.int), persistent=False)
self._perm, self._scale_perm, self._scale_perm_single = self._get_perms()

# activation int8 quantization
def dynamic_quant(self, x: torch.Tensor):
quant_scale = x.abs().max(dim=-1, keepdim=True)[0].div(127.0).to(torch.float)
x = (x / quant_scale).round().clamp(-128, 127).to(torch.int8)
return x, quant_scale
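    # Worked example of the per-token scheme above: a row whose max-abs value is 0.635 gets
    # quant_scale = 0.635 / 127 = 0.005, so an entry of 0.25 maps to round(0.25 / 0.005) = 50.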

def forward(self, A):
out_shape = A.shape[:-1] + (self.n,)
A = A.reshape(-1, A.shape[-1]).half()
quant_A, s1 = self.dynamic_quant(A)
D = torch.empty(A.shape[0], self.n, dtype=A.dtype, device=A.device)
        w4a8_mul(
quant_A,
self.B,
self.reduce_buffer,
D,
s1,
self.s_channel,
self.s_group,
self.workspace,
max_par=self.max_par
)
D = D.reshape(out_shape)
return D

def _get_perms(self):
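        # Builds the column permutation used by `pack()` to reorder packed weights into the Marlin
        # tile layout; note that the nibble interleave pattern differs between per-channel-only
        # (groupsize == k) and per-group quantization.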
perm = []
for i in range(32):
perm1 = []
col = i // 4
for block in [0, 1]:
for row in [
4 * (i % 4),
4 * (i % 4) + 1,
4 * (i % 4) + 2,
4 * (i % 4) + 3
]:
perm1.append(16 * row + col + 8 * block)
for j in range(4):
perm.extend([p + 256 * j for p in perm1])

perm = np.array(perm)
if self.groupsize == self.k:
interleave = np.array([4, 0, 5, 1, 6, 2, 7, 3])
else:
interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7])
# interleave = np.array([4, 0, 5, 1, 6, 2, 7, 3])
perm = perm.reshape((-1, 8))[:, interleave].ravel()
perm = torch.from_numpy(perm)
scale_perm = []
for i in range(8):
scale_perm.extend([i + 8 * j for j in range(8)])
scale_perm_single = []
for i in range(4):
scale_perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
return perm, scale_perm, scale_perm_single

def pack(self, linear, scales, s_extra=None):
"""Pack a fake-quantized linear layer into this actual Marlin representation.
@linear: fake-quantized `torch.nn.Linear` layer to convert (must be of type `torch.half`)
        @scales: corresponding quantization scales of shape `(outfeatures, groups)`
        @s_extra: corresponding per-channel quantization scales of shape `(1, outfeatures)`; required when the layer uses grouped quantization
"""
if self.groupsize != self.k:
assert s_extra is not None, "s_extra is needed"
if linear.weight.dtype != torch.half:
raise ValueError('Only `torch.half` weights are supported.')
tile = 16
maxq = 15
s = scales.t()
w = linear.weight.data.t()
if self.groupsize != self.k:
w = w.reshape((-1, self.groupsize, self.n))
w = w.permute(1, 0, 2)
w = w.reshape((self.groupsize, -1))
s = s.reshape((1, -1))
w = torch.round(w / s).int()
        # shift signed int4 values into the unsigned range [0, maxq] only for per-group quantization
if self.groupsize != self.k:
w += (maxq + 1) // 2
w = torch.clamp(w, 0, maxq)
if self.groupsize != self.k:
s_extra = s_extra.reshape(1, -1).to(dtype=torch.float)
s = (
s.reshape(-1, self.n) / (s_extra)
).to(dtype=torch.half)
w = w.reshape((self.groupsize, -1, self.n))
w = w.permute(1, 0, 2)
w = w.reshape((self.k, self.n)).contiguous()
s = s.reshape((-1, len(self._scale_perm)))[:, self._scale_perm]
s_extra = s_extra.reshape((-1, len(self._scale_perm_single)))[
:, self._scale_perm_single
]
s_extra = s_extra.reshape((-1, self.n)).contiguous()
else:
s = (s / 16.0).reshape((-1, len(self._scale_perm_single)))[:, self._scale_perm_single]
s = s.reshape((-1, self.n)).contiguous()
w = w.reshape((self.k // tile, tile, self.n // tile, tile))
w = w.permute((0, 2, 1, 3))
w = w.reshape((self.k // tile, self.n * tile))
res = w
res = res.reshape((-1, self._perm.numel()))[:, self._perm].reshape(res.shape)
q = np.zeros((res.shape[0], res.shape[1] // 8), dtype=np.uint32)
res = res.cpu().numpy().astype(np.uint32)
if self.groupsize != self.k:
for i in range(8):
q |= res[:, i::8] << 4 * i
else:
for i in range(8):
q |= (res[:, i::8] & 0xF) << 4 * i
q = torch.from_numpy(q.astype(np.int32)).to(w.device)
self.B[:, :] = q.to(self.B.device)
if self.groupsize != self.k:
self.s_group[:, :] = s.to(self.s_group.device)
self.s_channel[:, :] = s_extra.to(self.s_channel.device)
else:
self.s_group = torch.tensor([], dtype=torch.half, device=self.s_channel.device)
self.s_channel[:, :] = s.to(self.s_channel.device)

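# Illustrative packing sketch (hypothetical names): given a fake-quantized fp16 `torch.nn.Linear`
# `lin` with per-group `scales` and per-channel `s_extra` produced by the quantizer,
#   layer = W4A8Layer(lin.in_features, lin.out_features, groupsize=128).cuda()
#   layer.pack(lin, scales, s_extra)
#   y = layer(x)   # fp16 input `x` is quantized to int8 per token inside forward()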

def replace_linear(module, name_filter=lambda n: True, groupsize=-1, name=''):
"""Recursively replace all `torch.nn.Linear` layers by empty Marlin layers.