From d09b3433fc048869f83ec9b0da20939485020b26 Mon Sep 17 00:00:00 2001
From: Kye
Date: Sat, 23 Dec 2023 00:29:55 -0500
Subject: [PATCH] [TESTS][zeta.quant]

---
 tests/quant/qmoe.py           |  0
 tests/quant/test_bitlinear.py | 38 ++++++++++++++++++++++++
 tests/quant/test_quik.py      | 55 +++++++++++++++++++++++++++++++++++
 zeta/quant/qmoe.py            | 25 ----------------
 4 files changed, 93 insertions(+), 25 deletions(-)
 create mode 100644 tests/quant/qmoe.py
 create mode 100644 tests/quant/test_bitlinear.py
 create mode 100644 tests/quant/test_quik.py

diff --git a/tests/quant/qmoe.py b/tests/quant/qmoe.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/quant/test_bitlinear.py b/tests/quant/test_bitlinear.py
new file mode 100644
index 00000000..64467687
--- /dev/null
+++ b/tests/quant/test_bitlinear.py
@@ -0,0 +1,38 @@
+import pytest
+import torch
+from torch import nn
+from zeta.quant.bitlinear import BitLinear, absmax_quantize
+
+
+def test_bitlinear_reset_parameters():
+    bitlinear = BitLinear(10, 20)
+    old_weight = bitlinear.weight.clone()
+    bitlinear.reset_parameters()
+
+    assert not torch.equal(old_weight, bitlinear.weight)
+
+
+def test_bitlinear_forward_quantization():
+    bitlinear = BitLinear(10, 20)
+    input = torch.randn(128, 10)
+    output = bitlinear(input)
+
+    assert isinstance(output, torch.Tensor)
+    assert output.shape == (128, 20)
+
+    # The quantized forward pass should differ from a plain float linear pass with the same weights
+    assert not torch.allclose(output, nn.functional.linear(input, bitlinear.weight))
+
+
+@pytest.mark.parametrize("bits", [4, 8, 16])
+def test_absmax_quantize_different_bits(bits):
+    x = torch.tensor([1.0, -2.0, 3.0, -4.0])
+    quant, dequant = absmax_quantize(x, bits)
+
+    assert isinstance(quant, torch.Tensor)
+    assert quant.dtype == torch.int8
+    assert torch.allclose(dequant, x, atol=1e-2)
+
+    # Check that the quantized values are within the expected range
+    assert quant.min() >= -(2 ** (bits - 1))
+    assert quant.max() <= 2 ** (bits - 1) - 1
diff --git a/tests/quant/test_quik.py b/tests/quant/test_quik.py
new file mode 100644
index 00000000..df87bcb8
--- /dev/null
+++ b/tests/quant/test_quik.py
@@ -0,0 +1,55 @@
+import pytest
+import torch
+from torch import nn
+from zeta.quant.quick import QUIK
+
+
+def test_quik_initialization():
+    quik = QUIK(10, 20)
+
+    assert isinstance(quik, QUIK)
+    assert quik.in_features == 10
+    assert quik.out_features == 20
+    assert quik.quantize_range == 8
+    assert quik.half_range == 4
+    assert quik.weight.shape == (20, 10)
+    assert quik.bias.shape == (20,)
+
+
+def test_quik_quantize():
+    quik = QUIK(10, 20)
+    x = torch.randn(10, 10)
+    quant_x, zero_act, scale_act = quik.quantize(x)
+
+    assert isinstance(quant_x, torch.Tensor)
+    assert quant_x.dtype == torch.int32
+    assert isinstance(zero_act, torch.Tensor)
+    assert isinstance(scale_act, torch.Tensor)
+
+
+def test_quik_dequantize():
+    quik = QUIK(10, 20)
+    x = torch.randn(10, 10)
+    quant_x, zero_act, scale_act = quik.quantize(x)
+    dequant_x = quik.dequantize(quant_x, zero_act, scale_act, scale_act)
+
+    assert isinstance(dequant_x, torch.Tensor)
+    assert dequant_x.dtype == torch.float32
+
+
+def test_quik_find_zero_scale():
+    quik = QUIK(10, 20)
+    x = torch.randn(10, 10)
+    zero_act, scale_act = quik.find_zero_scale(x)
+
+    assert isinstance(zero_act, torch.Tensor)
+    assert isinstance(scale_act, torch.Tensor)
+
+
+def test_quik_forward():
+    quik = QUIK(10, 20)
+    x = torch.randn(10, 10)
+    output = quik(x)
+
+    assert isinstance(output, torch.Tensor)
+    assert output.shape == (10, 20)
diff --git a/zeta/quant/qmoe.py b/zeta/quant/qmoe.py
index 90a72daa..e575b1e8 100644
--- a/zeta/quant/qmoe.py
+++ b/zeta/quant/qmoe.py
@@ -225,28 +225,3 @@ def forward(self, x):
         if self.ready():
             return quantize(x, self.scale, self.zero, self.maxq)
         return x
-
-
-if __name__ == "__main__":
-    import time
-
-    D = 2048
-    K = 8
-
-    torch.random.manual_seed(0)
-    X = torch.randn(128, 512, D).cuda()
-    W = torch.randn(K, 768, D).cuda()
-    quantizer = QMOEQuantizer()
-    quantizer.configure(2)
-
-    H = hessian(X).repeat(K, 1, 1)
-    Q = batch_gptq(W, H, quantizer)
-    tick = time.time()
-    COUNT = 10
-    for i in range(COUNT):
-        H = hessian(X).repeat(K, 1, 1)
-        Q = batch_gptq(W, H, quantizer)
-    torch.cuda.synchronize()
-    print((time.time() - tick) / COUNT)
-
-    print(Q[0])
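
For context on the schemes these tests exercise: absmax quantization scales a float
tensor so its largest-magnitude entry lands on the edge of a signed b-bit integer grid,
rounds, and dequantizes by dividing the scale back out, while the QUIK tests round-trip a
zero-point (asymmetric) scheme that also returns an offset. A minimal sketch of both,
under assumed conventions (illustrative reimplementations, not zeta's actual
absmax_quantize or QUIK code; all names here are hypothetical):

    import torch


    def absmax_quantize_sketch(x: torch.Tensor, bits: int = 8):
        # Largest representable magnitude on a signed `bits`-bit grid.
        qb = 2 ** (bits - 1) - 1
        # Scale so that max(|x|) maps exactly to qb.
        scale = qb / x.abs().max()
        quant = (x * scale).round()
        # Dequantize from the pre-cast values; worst-case reconstruction
        # error is half a quantization step, i.e. 0.5 * max(|x|) / qb.
        dequant = quant / scale
        # int8 storage mirrors the dtype the test asserts; note that the
        # cast wraps for bits > 8, so wider grids need a wider dtype.
        return quant.to(torch.int8), dequant


    def zero_point_quantize_sketch(x: torch.Tensor, bits: int = 4):
        # Map [min(x), max(x)] onto the unsigned grid [0, 2**bits - 1].
        qmax = 2**bits - 1
        scale = (x.max() - x.min()).clamp(min=1e-8) / qmax
        zero = x.min()
        quant = ((x - zero) / scale).round().to(torch.int32)
        # Dequantization undoes the affine map.
        dequant = quant.to(torch.float32) * scale + zero
        return quant, zero, scale, dequant

With absmax, the reconstruction error is bounded by half a quantization step and shrinks
as bits grows. The zero-point sketch is simplified to a single scale; the real
QUIK.dequantize takes two scale arguments (test_quik_dequantize passes scale_act twice),
suggesting separate activation and weight scales for dequantizing a matmul result.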