Skip to content

Commit

Permalink
[TESTS][zeta.quant]
Browse files Browse the repository at this point in the history
  • Loading branch information
Kye committed Dec 23, 2023
1 parent 05f20f5 commit d09b343
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 25 deletions.
Empty file added tests/quant/qmoe.py
Empty file.
38 changes: 38 additions & 0 deletions tests/quant/test_bitlinear.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import pytest
import torch
from torch import nn
from zeta.quant.bitlinear import BitLinear, absmax_quantize


def test_bitlinear_reset_parameters():
    """Check that ``reset_parameters`` leaves BitLinear with different weights."""
    layer = BitLinear(10, 20)
    # Take an independent copy so the reset below cannot modify the snapshot.
    weight_before = layer.weight.clone()

    layer.reset_parameters()

    weights_unchanged = torch.equal(weight_before, layer.weight)
    assert not weights_unchanged


def test_bitlinear_forward_quantization():
    """Run a batch through BitLinear and check the result's type and shape.

    A (128, 10) batch is fed to a ``BitLinear(10, 20)`` layer; the result
    must be a tensor of shape (128, 20) that is not a copy of the input.
    """
    bitlinear = BitLinear(10, 20)
    # Named ``x`` rather than ``input`` to avoid shadowing the builtin.
    x = torch.randn(128, 10)
    output = bitlinear(x)

    assert isinstance(output, torch.Tensor)
    assert output.shape == (128, 20)

    # Check that the output is different from the input. torch.allclose
    # raises a RuntimeError for shapes that cannot be broadcast together
    # (here (128, 20) vs (128, 10)), so values are only compared when the
    # shapes match; differing shapes already prove the tensors differ.
    assert output.shape != x.shape or not torch.allclose(output, x)


@pytest.mark.parametrize("bits", [4, 8, 16])
def test_absmax_quantize_different_bits(bits):
    """Check ``absmax_quantize`` output type, range and round-trip error.

    The round-trip tolerance is derived from ``bits`` instead of being a
    fixed constant: a fixed ``atol=1e-2`` is tighter than the rounding error
    of a 4-bit or 8-bit grid for this input.
    """
    x = torch.tensor([1.0, -2.0, 3.0, -4.0])
    quant, dequant = absmax_quantize(x, bits)

    assert isinstance(quant, torch.Tensor)
    # NOTE(review): int8 cannot represent the full 16-bit range; confirm
    # that an int8 result is really intended for bits=16.
    assert quant.dtype == torch.int8

    q_max = 2 ** (bits - 1) - 1
    # Assumes round-to-nearest on a grid whose step is max|x| / q_max
    # (TODO confirm against the implementation). Under that assumption the
    # worst-case round-trip error is half a step; the small epsilon absorbs
    # floating-point noise when the error lands exactly on that bound.
    half_step = x.abs().max().item() / (2 * q_max)
    assert torch.allclose(dequant, x, atol=half_step + 1e-6)

    # Check that the quantized values are within the expected range
    assert quant.min() >= -(2 ** (bits - 1))
    assert quant.max() <= q_max
55 changes: 55 additions & 0 deletions tests/quant/test_quik.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import pytest
import torch
from torch import nn
from zeta.quant.quick import QUIK


def test_quik_initialization():
    """Check the attributes of a freshly constructed ``QUIK(10, 20)`` layer."""
    in_dim, out_dim = 10, 20
    layer = QUIK(in_dim, out_dim)

    assert isinstance(layer, QUIK)

    # Constructor arguments are exposed as attributes.
    assert layer.in_features == in_dim
    assert layer.out_features == out_dim

    # Quantisation settings when none are passed explicitly.
    assert layer.quantize_range == 8
    assert layer.half_range == 4

    # Parameter shapes: weight is (out_features, in_features).
    assert layer.weight.shape == (out_dim, in_dim)
    assert layer.bias.shape == (out_dim,)


def test_quik_quantize():
    """Check that ``QUIK.quantize`` returns an int32 tensor plus zero/scale tensors."""
    layer = QUIK(10, 20)
    activations = torch.randn(10, 10)

    quantized, zero_point, scale = layer.quantize(activations)

    assert isinstance(quantized, torch.Tensor)
    assert quantized.dtype == torch.int32
    assert isinstance(zero_point, torch.Tensor)
    assert isinstance(scale, torch.Tensor)


def test_quik_dequantize():
    """Check that ``QUIK.dequantize`` turns quantized values back into float32."""
    layer = QUIK(10, 20)
    activations = torch.randn(10, 10)
    quantized, zero_point, scale = layer.quantize(activations)

    # NOTE(review): the activation scale is passed for both of the last two
    # arguments; confirm this matches what ``dequantize`` expects there.
    restored = layer.dequantize(quantized, zero_point, scale, scale)

    assert isinstance(restored, torch.Tensor)
    assert restored.dtype == torch.float32


def test_quik_find_zero_scale():
    """Check that ``QUIK.find_zero_scale`` returns two tensors."""
    layer = QUIK(10, 20)
    activations = torch.randn(10, 10)

    zero_point, scale = layer.find_zero_scale(activations)

    for value in (zero_point, scale):
        assert isinstance(value, torch.Tensor)


def test_quik_forward():
    """Check that a forward pass maps a (10, 10) batch to a (10, 20) tensor."""
    batch, in_dim, out_dim = 10, 10, 20
    layer = QUIK(in_dim, out_dim)

    result = layer(torch.randn(batch, in_dim))

    assert isinstance(result, torch.Tensor)
    assert result.shape == (batch, out_dim)
25 changes: 0 additions & 25 deletions zeta/quant/qmoe.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,28 +225,3 @@ def forward(self, x):
if self.ready():
return quantize(x, self.scale, self.zero, self.maxq)
return x


if __name__ == "__main__":
    # Ad-hoc benchmark: average wall-clock time of one Hessian + batched GPTQ
    # pass on CUDA tensors. Requires a CUDA device.
    import time

    D = 2048
    K = 8

    torch.random.manual_seed(0)
    X = torch.randn(128, 512, D).cuda()
    W = torch.randn(K, 768, D).cuda()
    quantizer = QMOEQuantizer()
    quantizer.configure(2)

    # Warm-up pass, excluded from the measurement.
    H = hessian(X).repeat(K, 1, 1)
    Q = batch_gptq(W, H, quantizer)
    # CUDA work is queued asynchronously; wait for the warm-up to finish so
    # its execution time is not attributed to the timed loop below.
    torch.cuda.synchronize()

    # perf_counter is monotonic and intended for measuring short durations.
    tick = time.perf_counter()
    COUNT = 10
    for _ in range(COUNT):
        H = hessian(X).repeat(K, 1, 1)
        Q = batch_gptq(W, H, quantizer)
    # Wait for all queued GPU work before reading the clock.
    torch.cuda.synchronize()
    print((time.perf_counter() - tick) / COUNT)

    print(Q[0])

0 comments on commit d09b343

Please sign in to comment.