diff --git a/test/xpu/test_linalg_xpu.py b/test/xpu/test_linalg_xpu.py
index 012d30f98..206f362bf 100644
--- a/test/xpu/test_linalg_xpu.py
+++ b/test/xpu/test_linalg_xpu.py
@@ -6,7 +6,7 @@
 from torch.testing._internal.common_dtype import floating_and_complex_types_and
 from torch.testing._internal.common_cuda import tf32_on_and_off
 from torch.testing._internal.common_mkldnn import bf32_on_and_off
-from torch.testing._internal.common_quantization import _group_quantize_tensor, _dynamically_quantize_per_channel
+from torch.testing._internal.common_quantization import _dynamically_quantize_per_channel
 from torch.testing import make_tensor
 import unittest
 import itertools
@@ -174,9 +174,53 @@ def _test(m, k, n, transpose_a, transpose_b, test_equal=True):
 
 @unittest.skipIf(IS_WINDOWS, "Skipped on Windows!")
 @parametrize("m", [1])
-@parametrize("k", [1024, 2048])
-@parametrize("n", [48, 64])
+@parametrize("k", [32])
+@parametrize("n", [32])
 def _int4_mm(self, device, m, k, n):
+    def _group_quantize_tensor(w, n_bit=4, q_group_size=16):
+        assert w.dim() == 2
+        w = w.transpose(0, 1).contiguous()
+        assert q_group_size > 1
+        assert w.shape[-1] % q_group_size == 0
+
+        to_quant = w.reshape(-1, q_group_size)
+        assert torch.isnan(to_quant).sum() == 0
+
+        max_val = to_quant.amax(dim=1, keepdim=True)
+        min_val = to_quant.amin(dim=1, keepdim=True)
+        max_int = 2 ** n_bit - 1
+        min_int = 0
+        scales = (max_val - min_val).clamp(min=1e-6) / max_int
+        assert torch.isnan(scales).sum() == 0
+
+        zeros = min_val + scales * (2 ** (n_bit - 1))
+        assert torch.isnan(zeros).sum() == 0
+
+        out = to_quant.sub(min_val).div(scales).round().clamp_(min_int, max_int)
+        assert torch.isnan(out).sum() == 0
+
+        out = out.to(dtype=torch.int32).reshape(w.shape)
+        if out.device.type != 'cpu' or out.device.type != 'xpu':
+            out = (out[::, ::2] << 4 | out[::, 1::2]).to(torch.uint8)
+
+        # Scales and zeros for the same q-group should be contiguous, so we can
+        # load as a 32-bit word
+        scales = scales.view(w.shape[0], -1)
+        zeros = zeros.view(w.shape[0], -1)
+        scales_and_zeros = (
+            torch.cat(
+                [
+                    scales.reshape(scales.size(0), scales.size(1), 1),
+                    zeros.reshape(zeros.size(0), zeros.size(1), 1),
+                ],
+                2,
+            )
+        )
+
+        if out.device.type != 'xpu':
+            scales_and_zeros = scales_and_zeros.transpose(0, 1).contiguous()
+        return out, scales_and_zeros
+
     def convert_weight_to_int4pack(b):
         b_tmp, b_scales_and_zeros = _group_quantize_tensor(
             b, n_bit=4, q_group_size=q_group
@@ -231,7 +275,8 @@ def weight_int4pack_mm(a, b_int4pack, b_scales_and_zeros):
     b_scales_and_zeros = b_scales_and_zeros_bf16.to(dtype=dtype)
     ref = torch.mm(a, b)
     res = weight_int4pack_mm(a, b_int4pack, b_scales_and_zeros)
-
+    print(ref)
+    print(res)
     mean_err = ((res - ref).abs() / ref).mean()
     self.assertTrue(mean_err < 0.05)
 
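
Not part of the patch: a minimal standalone sketch of the asymmetric per-group 4-bit scheme that the inlined _group_quantize_tensor above implements. It skips the transpose, uint8 nibble packing, and scales_and_zeros layout handling, and the function name quantize_dequantize is illustrative only, not a helper used by the test.

import torch

def quantize_dequantize(w, n_bit=4, q_group_size=16):
    # Group values along the flattened tensor, q_group_size per group.
    to_quant = w.reshape(-1, q_group_size)
    max_val = to_quant.amax(dim=1, keepdim=True)
    min_val = to_quant.amin(dim=1, keepdim=True)
    max_int = 2 ** n_bit - 1
    # Per-group scale and zero point, same formulas as the helper in the diff.
    scales = (max_val - min_val).clamp(min=1e-6) / max_int
    zeros = min_val + scales * (2 ** (n_bit - 1))
    q = to_quant.sub(min_val).div(scales).round().clamp_(0, max_int)
    # Dequantize: (q - 2**(n_bit - 1)) * scale + zero, i.e. q * scale + min_val.
    deq = (q - 2 ** (n_bit - 1)) * scales + zeros
    return deq.reshape(w.shape)

w = torch.randn(32, 32)
err = (w - quantize_dequantize(w)).abs().max()
print(err)  # bounded by half a quantization step of the widest group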