HQQ Integration: dequant kernel #507

Draft: wants to merge 7 commits into base `main`
102 changes: 102 additions & 0 deletions tests/hqq_dequant/README.md
@@ -0,0 +1,102 @@
## HQQ Integration: dequant kernel

Standalone asymmetric dequant kernel for `hqq` quantization as a first step towards integrating `hqq` as an alternative quantization backend.

Currently supports the following `hqq` [`BaseQuantizeConfig`](https://github.com/mobiusml/hqq/blob/aad68687e042ed628b5a655969406d501a203949/hqq/core/quantize.py#L872-L935) settings (a usage sketch follows the list):
- `nbits` = `{4, 8}`
  - Quantization bit-width; `{1, 2, 3}` bits are not yet supported.
- `axis` = `{0, 1}`
  - Axis along which weights are quantized.
  - Anecdotal evidence suggests better accuracy with `axis=0`.
  - Not all built-in `hqq` dequant implementations are available for both axes -- this kernel supports both.
- `group_size`
  - Grouping size of the weights during quantization.
  - The kernel should work for any power-of-2 group size, but has only been tested with common sizes (`64`, `128`).
  - Both manual and `autotune` kernel variants are provided, which should ease downstream interoperability with `torch.compile`.
- `quant_zero`
  - Additional quantization of the zero-points.
  - Currently only supports `nbits=8` scalar scale / zero quantization of the zero-points, which is the default setting of [`hqq.BaseQuantizeConfig`](https://github.com/mobiusml/hqq/blob/aad68687e042ed628b5a655969406d501a203949/hqq/core/quantize.py#L920-L924).
- `quant_scale`
  - Additional quantization of the scales.
  - Not currently supported, as the default setting for [`hqq.BaseQuantizeConfig`](https://github.com/mobiusml/hqq/blob/aad68687e042ed628b5a655969406d501a203949/hqq/core/quantize.py#L876) is `quant_scale=False` (scales are not additionally quantized).
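
A minimal usage sketch for quantizing a layer with these settings, assuming the same `hqq` API that the benchmark script below relies on (`hqq_base_quant_config`, `HQQLinear`); the concrete values are illustrative:

```python
import torch

from hqq.core.quantize import HQQBackend, HQQLinear, hqq_base_quant_config

# settings supported by this kernel (values are illustrative)
quant_cfg = hqq_base_quant_config(
    nbits=4,          # {4, 8}
    group_size=64,    # any power of 2; tested with 64 and 128
    axis=1,           # {0, 1}
    quant_zero=True,  # 8-bit scalar quantization of the zero-points (hqq default)
)

linear = torch.nn.Linear(4096, 4096, dtype=torch.bfloat16, device="cuda")
HQQLinear.set_backend(HQQBackend.PYTORCH)
hqq_linear = HQQLinear(
    linear,
    device="cuda",
    quant_config=quant_cfg,
    compute_dtype=torch.bfloat16,
    del_orig=False,
)
W_ref = hqq_linear.dequantize()  # hqq's own dequantization, used as the reference
```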

## Accuracy
See `test_hqq_dequant.py` for comprehensive tests across `dtypes`, `group_sizes`, `axis`, and other relevant params.

Run with:
```
pytest -sv test_hqq_dequant.py
```
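
As a rough sketch of what such a test checks (this is not the actual test code; the config values and tolerances are illustrative, and the `hqq_dequant` call mirrors the benchmark script below):

```python
import torch
from dequant import hqq_dequant  # kernel added in this PR
from hqq.core.quantize import HQQLinear, hqq_base_quant_config
from utils import patch_hqq_packing  # packing patch used by the benchmark script

dtype, device = torch.bfloat16, "cuda"
nbits, group_size, axis = 4, 64, 1

linear = torch.nn.Linear(4096, 4096, dtype=dtype, device=device)
quant_cfg = hqq_base_quant_config(
    nbits=nbits, group_size=group_size, axis=axis, quant_zero=False
)

# reference: hqq's built-in dequantization
ref = HQQLinear(linear, device=device, quant_config=quant_cfg,
                compute_dtype=dtype, del_orig=False)
W_ref = ref.dequantize()

# triton kernel: re-quantize under the packing patch, then dequantize with the kernel
with patch_hqq_packing():
    test = HQQLinear(linear, device=device, quant_config=quant_cfg,
                     compute_dtype=dtype, del_orig=False)
meta = test.meta
W_tri = hqq_dequant(
    test.W_q,
    scales=meta["scale"].view(-1),
    zeros=meta["zero"].view(-1),
    group_size=group_size,
    nbits=nbits,
    axis=axis,
    output_shape=meta["shape"],
    quant_zero=False,
    qz_scales=None,
    qz_zeros=None,
    block_size=64,
    autotune=False,
)
torch.testing.assert_close(W_tri.float(), W_ref.float(), rtol=1e-2, atol=1e-2)  # illustrative tolerances
```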

## Performance
Please take these numbers with a grain of salt, as I only benchmarked against `HQQBackend.PYTORCH` on my laptop (RTX 3050); timings are in milliseconds as reported by `triton.testing.do_bench`:

```
python benchmark_hqq_dequant.py
```

| shape | axis | group_size | nbits | dtype | quant_scale | quant_zero | block_size | hqq(HQQBackend.PYTORCH) | triton | speedup |
|-------------|------|------------|-------|----------------|-------------|------------|------------|-------------------------|---------|---------|
| (4096, 4096) | 1 | 64 | 4 | torch.bfloat16 | False | False | 32 | 15.3904 | 2.3977 | 6.42x |
| (4096, 4096) | 1 | 64 | 4 | torch.bfloat16 | False | False | 64 | 15.3313 | 2.3957 | 6.40x |
| (4096, 4096) | 1 | 64 | 4 | torch.bfloat16 | False | False | 128 | 15.3985 | 2.3967 | 6.42x |
| (4096, 4096) | 1 | 64 | 4 | torch.bfloat16 | False | False | 256 | 15.4044 | 2.3986 | 6.42x |
| (4096, 4096) | 1 | 64 | 4 | torch.bfloat16 | False | False | 512 | 15.4192 | 2.4153 | 6.38x |
| (4096, 4096) | 1 | 64 | 4 | torch.bfloat16 | False | False | 1024 | 15.4055 | 25.1655 | 0.61x |
| (4096, 4096) | 1 | 64 | 4 | torch.bfloat16 | False | False | autotune | 15.3446 | 2.3976 | 6.40x |
| (4096, 4096) | 1 | 64 | 4 | torch.bfloat16 | False | True | 32 | 15.5533 | 2.3839 | 6.52x |
| (4096, 4096) | 1 | 64 | 4 | torch.bfloat16 | False | True | 64 | 15.6986 | 2.3869 | 6.58x |
| (4096, 4096) | 1 | 64 | 4 | torch.bfloat16 | False | True | 128 | 15.5906 | 2.3807 | 6.55x |
| (4096, 4096) | 1 | 64 | 4 | torch.bfloat16 | False | True | 256 | 15.6426 | 2.3936 | 6.54x |
| (4096, 4096) | 1 | 64 | 4 | torch.bfloat16 | False | True | 512 | 15.5842 | 2.4072 | 6.47x |
| (4096, 4096) | 1 | 64 | 4 | torch.bfloat16 | False | True | 1024 | 15.6129 | 38.3974 | 0.41x |
| (4096, 4096) | 1 | 64 | 4 | torch.bfloat16 | False | True | autotune | 15.5552 | 2.3805 | 6.53x |
| (4096, 4096) | 1 | 128 | 4 | torch.bfloat16 | False | False | 32 | 15.3647 | 2.3708 | 6.48x |
| (4096, 4096) | 1 | 128 | 4 | torch.bfloat16 | False | False | 64 | 15.4205 | 2.3707 | 6.50x |
| (4096, 4096) | 1 | 128 | 4 | torch.bfloat16 | False | False | 128 | 15.3875 | 2.3736 | 6.48x |
| (4096, 4096) | 1 | 128 | 4 | torch.bfloat16 | False | False | 256 | 15.4178 | 2.3885 | 6.45x |
| (4096, 4096) | 1 | 128 | 4 | torch.bfloat16 | False | False | 512 | 15.3764 | 5.5952 | 2.75x |
| (4096, 4096) | 1 | 128 | 4 | torch.bfloat16 | False | False | 1024 | 15.3659 | 28.3112 | 0.54x |
| (4096, 4096) | 1 | 128 | 4 | torch.bfloat16 | False | False | autotune | 15.3566 | 2.3720 | 6.47x |
| (4096, 4096) | 1 | 128 | 4 | torch.bfloat16 | False | True | 32 | 15.4933 | 2.3652 | 6.55x |
| (4096, 4096) | 1 | 128 | 4 | torch.bfloat16 | False | True | 64 | 15.6100 | 2.3629 | 6.61x |
| (4096, 4096) | 1 | 128 | 4 | torch.bfloat16 | False | True | 128 | 15.5169 | 2.3707 | 6.55x |
| (4096, 4096) | 1 | 128 | 4 | torch.bfloat16 | False | True | 256 | 15.5769 | 2.3819 | 6.54x |
| (4096, 4096) | 1 | 128 | 4 | torch.bfloat16 | False | True | 512 | 15.5484 | 46.7231 | 0.33x |
| (4096, 4096) | 1 | 128 | 4 | torch.bfloat16 | False | True | 1024 | 15.4976 | 39.2632 | 0.39x |
| (4096, 4096) | 1 | 128 | 4 | torch.bfloat16 | False | True | autotune | 15.5105 | 2.3612 | 6.57x |
| (4096, 4096) | 0 | 64 | 4 | torch.bfloat16 | False | False | 32 | 17.7245 | 2.3934 | 7.41x |
| (4096, 4096) | 0 | 64 | 4 | torch.bfloat16 | False | False | 64 | 17.7356 | 2.3985 | 7.39x |
| (4096, 4096) | 0 | 64 | 4 | torch.bfloat16 | False | False | 128 | 17.7039 | 2.3962 | 7.39x |
| (4096, 4096) | 0 | 64 | 4 | torch.bfloat16 | False | False | 256 | 17.7170 | 2.4007 | 7.38x |
| (4096, 4096) | 0 | 64 | 4 | torch.bfloat16 | False | False | 512 | 17.7893 | 2.4305 | 7.32x |
| (4096, 4096) | 0 | 64 | 4 | torch.bfloat16 | False | False | 1024 | 17.7887 | 3.4368 | 5.18x |
| (4096, 4096) | 0 | 64 | 4 | torch.bfloat16 | False | False | autotune | 17.8211 | 2.3958 | 7.44x |
| (4096, 4096) | 0 | 64 | 4 | torch.bfloat16 | False | True | 32 | 17.9001 | 2.3820 | 7.51x |
| (4096, 4096) | 0 | 64 | 4 | torch.bfloat16 | False | True | 64 | 18.0115 | 2.3831 | 7.56x |
| (4096, 4096) | 0 | 64 | 4 | torch.bfloat16 | False | True | 128 | 17.9640 | 2.3884 | 7.52x |
| (4096, 4096) | 0 | 64 | 4 | torch.bfloat16 | False | True | 256 | 17.9970 | 2.3892 | 7.53x |
| (4096, 4096) | 0 | 64 | 4 | torch.bfloat16 | False | True | 512 | 17.9618 | 2.4060 | 7.47x |
| (4096, 4096) | 0 | 64 | 4 | torch.bfloat16 | False | True | 1024 | 18.0256 | 41.0300 | 0.44x |
| (4096, 4096) | 0 | 64 | 4 | torch.bfloat16 | False | True | autotune | 18.0029 | 2.3838 | 7.55x |
| (4096, 4096) | 0 | 128 | 4 | torch.bfloat16 | False | False | 32 | 15.3639 | 2.3799 | 6.46x |
| (4096, 4096) | 0 | 128 | 4 | torch.bfloat16 | False | False | 64 | 15.4093 | 2.3827 | 6.47x |
| (4096, 4096) | 0 | 128 | 4 | torch.bfloat16 | False | False | 128 | 15.3549 | 2.3800 | 6.45x |
| (4096, 4096) | 0 | 128 | 4 | torch.bfloat16 | False | False | 256 | 15.4489 | 2.3996 | 6.44x |
| (4096, 4096) | 0 | 128 | 4 | torch.bfloat16 | False | False | 512 | 15.3766 | 3.7026 | 4.15x |
| (4096, 4096) | 0 | 128 | 4 | torch.bfloat16 | False | False | 1024 | 15.4355 | 26.2775 | 0.59x |
| (4096, 4096) | 0 | 128 | 4 | torch.bfloat16 | False | False | autotune | 15.3563 | 2.3682 | 6.48x |
| (4096, 4096) | 0 | 128 | 4 | torch.bfloat16 | False | True | 32 | 15.6545 | 2.3809 | 6.58x |
| (4096, 4096) | 0 | 128 | 4 | torch.bfloat16 | False | True | 64 | 15.5018 | 2.3688 | 6.54x |
| (4096, 4096) | 0 | 128 | 4 | torch.bfloat16 | False | True | 128 | 15.5865 | 2.3731 | 6.57x |
| (4096, 4096) | 0 | 128 | 4 | torch.bfloat16 | False | True | 256 | 15.5484 | 2.3861 | 6.52x |
| (4096, 4096) | 0 | 128 | 4 | torch.bfloat16 | False | True | 512 | 15.6000 | 44.5326 | 0.35x |
| (4096, 4096) | 0 | 128 | 4 | torch.bfloat16 | False | True | 1024 | 15.5037 | 41.6425 | 0.37x |
| (4096, 4096) | 0 | 128 | 4 | torch.bfloat16 | False | True | autotune | 15.5015 | 2.3781 | 6.52x |


## Notes
The kernel requires `triton >= 3.0.0`, which is not compatible with stable `xformers`:
- This required fixing the `triton` import in `unsloth.__init__.py` per this [PR](https://github.com/unslothai/unsloth/pull/227).
- I initially tried to add the kernels under `unsloth.kernels`, but the `import xformers` in `unsloth.models.__init__.py` errors out because the `xformers` `triton` kernels are incompatible with `triton >= 3.0.0`.
- Note that `xformers` is technically not required with `torch >= 2.3`, since an equivalent of `xformers.attn_bias.LowerTriangularMask` is available under `torch.nn.attention.bias` (see the sketch below).
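
For reference, a minimal sketch of that `torch`-native replacement, assuming `torch >= 2.3`; the shapes and the choice of the lower-right causal variant are illustrative and not taken from `unsloth`:

```python
import torch
from torch.nn.attention.bias import causal_lower_right
from torch.nn.functional import scaled_dot_product_attention

q = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.bfloat16)
k = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.bfloat16)
v = torch.randn_like(k)

# torch-native causal bias standing in for xformers.attn_bias.LowerTriangularMask;
# for equal query/key lengths this is the usual lower-triangular causal mask
attn_bias = causal_lower_right(q.size(-2), k.size(-2))
out = scaled_dot_product_attention(q, k, v, attn_mask=attn_bias)
```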
162 changes: 162 additions & 0 deletions tests/hqq_dequant/benchmark_hqq_dequant.py
@@ -0,0 +1,162 @@
import itertools

import torch
from dequant import hqq_dequant
from hqq.core.quantize import HQQBackend, HQQLinear, Quantizer, hqq_base_quant_config
from tabulate import tabulate
from triton.testing import do_bench
from utils import patch_hqq_packing

torch.manual_seed(0)

SHAPES = [(4096, 4096)]
AXES = [1, 0]
GROUP_SIZES = [64, 128]
NBITS = [4]
DTYPES = [torch.bfloat16]
QUANT_SCALES = [False]
QUANT_ZEROS = [False, True]
NUM_ITERS = [10]
BLOCK_SIZES = [32, 64, 128, 256, 512, 1024, "autotune"]
device = "cuda"
backend = HQQBackend.PYTORCH
HQQLinear.set_backend(backend)


BENCH_CONFIGS = list(
    itertools.product(
        SHAPES,
        AXES,
        GROUP_SIZES,
        NBITS,
        DTYPES,
        QUANT_SCALES,
        QUANT_ZEROS,
        BLOCK_SIZES,
        NUM_ITERS,
    )
)


def run_bench(fn, num_iters):
    # do_bench reports the time in ms for `num_iters` back-to-back calls of `fn`
    times = do_bench(lambda: [fn() for _ in range(num_iters)])
    return times


data = []

for (
    shape,
    axis,
    group_size,
    nbits,
    dtype,
    quant_scale,
    quant_zero,
    block_size,
    num_iters,
) in BENCH_CONFIGS:
    # skip configs the ATEN backend cannot run (it requires axis=0)
    if backend == HQQBackend.ATEN and axis == 1:
        continue
    M, N = shape

    linear = torch.nn.Linear(M, N, dtype=dtype, device=device)

    quant_cfg = hqq_base_quant_config(
        nbits=nbits, group_size=group_size, axis=axis, quant_zero=quant_zero
    )
    hqq_linear = HQQLinear(
        linear,
        device=device,
        quant_config=quant_cfg,
        compute_dtype=dtype,
        del_orig=False,
    )
    if quant_zero:
        # zero-points dequantized with hqq's Quantizer (not used in the timed path)
        zero_q = hqq_linear.meta["zero_q"]
        meta_zero = hqq_linear.meta["meta_zero"]
        z_ref = Quantizer.dequantize(zero_q, meta_zero)

    HEADERS = [
        "shape",
        "axis",
        "group_size",
        "nbits",
        "dtype",
        "quant_scale",
        "quant_zero",
        "block_size",
        f"hqq({backend})",
        "triton",
        "speedup",
    ]

    common_args = [
        shape,
        axis,
        group_size,
        nbits,
        dtype,
        quant_scale,
        quant_zero,
    ]

    # reference: hqq's built-in dequantization for the selected backend
    ref_fn = lambda: hqq_linear.dequantize()
    ref_t = run_bench(ref_fn, num_iters)

    # re-quantize under `patch_hqq_packing` before extracting the kernel's inputs
    with patch_hqq_packing():
        hqq_linear = HQQLinear(
            linear,
            device=device,
            quant_config=quant_cfg,
            compute_dtype=dtype,
            del_orig=False,
        )
    q = hqq_linear.W_q

    meta = hqq_linear.meta
    output_shape = meta["shape"]

    scales = meta["scale"].view(-1)
    if quant_zero:
        zeros = meta["zero_q"].view(-1)
        meta_zero = meta["meta_zero"]
        qz_scale = meta_zero["scale"]
        qz_zero = meta_zero["zero"]
    else:
        zeros = meta["zero"].view(-1)
        qz_scale, qz_zero = None, None

    num_groups = len(scales)
    assert num_groups == (M * N) // group_size

    # "autotune" lets the kernel choose its own block size; otherwise clamp to num_groups
    autotune = block_size == "autotune"
    block_size = min(block_size, num_groups) if not autotune else None

    test_fn = lambda: hqq_dequant(
        q,
        scales=scales,
        zeros=zeros,
        group_size=group_size,
        nbits=nbits,
        axis=axis,
        output_shape=output_shape,
        quant_zero=quant_zero,
        qz_scales=qz_scale,
        qz_zeros=qz_zero,
        block_size=block_size,
        autotune=autotune,
    )

    test_t = run_bench(test_fn, num_iters=num_iters)
    data.append(
        [
            *common_args,
            "autotune" if autotune else block_size,
            ref_t,
            test_t,
            f"{ref_t / test_t:.2f}x",
        ]
    )

print(tabulate(data, headers=HEADERS, floatfmt=".4f"))