Commit 78433cb
Bug fix for FP16 (BF16 may still be incorrect)
sunjiweiswift committed Dec 18, 2024
1 parent a590ad6 commit 78433cb
Showing 1 changed file: test/xpu/test_linalg_xpu.py, with 49 additions and 4 deletions.
@@ -6,7 +6,7 @@
 from torch.testing._internal.common_dtype import floating_and_complex_types_and
 from torch.testing._internal.common_cuda import tf32_on_and_off
 from torch.testing._internal.common_mkldnn import bf32_on_and_off
-from torch.testing._internal.common_quantization import _group_quantize_tensor, _dynamically_quantize_per_channel
+from torch.testing._internal.common_quantization import _dynamically_quantize_per_channel
 from torch.testing import make_tensor
 import unittest
 import itertools
@@ -174,9 +174,53 @@ def _test(m, k, n, transpose_a, transpose_b, test_equal=True):
 
     @unittest.skipIf(IS_WINDOWS, "Skipped on Windows!")
     @parametrize("m", [1])
-    @parametrize("k", [1024, 2048])
-    @parametrize("n", [48, 64])
+    @parametrize("k", [32])
+    @parametrize("n", [32])
     def _int4_mm(self, device, m, k, n):
+        def _group_quantize_tensor(w, n_bit=4, q_group_size=16):
+            assert w.dim() == 2
+            w = w.transpose(0, 1).contiguous()
+            assert q_group_size > 1
+            assert w.shape[-1] % q_group_size == 0
+
+            to_quant = w.reshape(-1, q_group_size)
+            assert torch.isnan(to_quant).sum() == 0
+
+            max_val = to_quant.amax(dim=1, keepdim=True)
+            min_val = to_quant.amin(dim=1, keepdim=True)
+            max_int = 2 ** n_bit - 1
+            min_int = 0
+            scales = (max_val - min_val).clamp(min=1e-6) / max_int
+            assert torch.isnan(scales).sum() == 0
+
+            zeros = min_val + scales * (2 ** (n_bit - 1))
+            assert torch.isnan(zeros).sum() == 0
+
+            out = to_quant.sub(min_val).div(scales).round().clamp_(min_int, max_int)
+            assert torch.isnan(out).sum() == 0
+
+            out = out.to(dtype=torch.int32).reshape(w.shape)
+            if out.device.type != 'cpu' and out.device.type != 'xpu':
+                out = (out[::, ::2] << 4 | out[::, 1::2]).to(torch.uint8)
+
+            # Scales and zeros for the same q-group should be contiguous, so we can
+            # load as a 32-bit word
+            scales = scales.view(w.shape[0], -1)
+            zeros = zeros.view(w.shape[0], -1)
+            scales_and_zeros = (
+                torch.cat(
+                    [
+                        scales.reshape(scales.size(0), scales.size(1), 1),
+                        zeros.reshape(zeros.size(0), zeros.size(1), 1),
+                    ],
+                    2,
+                )
+            )
+
+            if out.device.type != 'xpu':
+                scales_and_zeros = scales_and_zeros.transpose(0, 1).contiguous()
+            return out, scales_and_zeros
+
         def convert_weight_to_int4pack(b):
             b_tmp, b_scales_and_zeros = _group_quantize_tensor(
                 b, n_bit=4, q_group_size=q_group
@@ -231,7 +275,8 @@ def weight_int4pack_mm(a, b_int4pack, b_scales_and_zeros):
             b_scales_and_zeros = b_scales_and_zeros_bf16.to(dtype=dtype)
             ref = torch.mm(a, b)
             res = weight_int4pack_mm(a, b_int4pack, b_scales_and_zeros)
-
+            print(ref)
+            print(res)
             mean_err = ((res - ref).abs() / ref).mean()
             self.assertTrue(mean_err < 0.05)

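For reference, a minimal standalone sketch of the convention the new helper encodes (not part of the commit; the shape, group size, and tolerance are assumptions). Because zeros = min_val + 8 * scales, an int4 code q dequantizes as (q - 8) * scale + zero, which is presumably the layout the int4 matmul kernel under test consumes:

# Sketch only: re-derive the per-group quantities exactly as the
# _group_quantize_tensor helper above does, then check that
# (q - 8) * scale + zero recovers the input to within half a step.
import torch

w = torch.randn(32, 32).to(torch.float16)  # assumed toy shape
q_group = 16                               # assumed group size

groups = w.transpose(0, 1).contiguous().reshape(-1, q_group).float()
min_val = groups.amin(dim=1, keepdim=True)
max_val = groups.amax(dim=1, keepdim=True)
scales = (max_val - min_val).clamp(min=1e-6) / 15  # max_int = 2**4 - 1
zeros = min_val + scales * 8                       # 2**(4 - 1)
q = groups.sub(min_val).div(scales).round().clamp_(0, 15)

# (q - 8) * scales + zeros == q * scales + min_val, i.e. plain dequantization.
deq = (q - 8) * scales + zeros
assert (deq - groups).abs().max() <= scales.max() / 2 + 1e-4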