From 7c62a14c0e55fbc0644ab9494f2793ba38b17e5a Mon Sep 17 00:00:00 2001
From: Michael Gschwind
Date: Tue, 9 Apr 2024 19:03:48 -0700
Subject: [PATCH] all ops to int4_mm are bfloat16

---
 quantize.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/quantize.py b/quantize.py
index 4a641e13a..985f66200 100644
--- a/quantize.py
+++ b/quantize.py
@@ -604,7 +604,7 @@ def _int4_calc_padded_size(k, groupsize=1, innner_k_tiles=1):
 def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, groupsize):
     origin_x_size = x.size()
     x = x.reshape(-1, origin_x_size[-1])
-    c = torch.ops.aten._weight_int4pack_mm(x, weight_int4pack, groupsize, scales_and_zeros)
+    c = torch.ops.aten._weight_int4pack_mm(x.to(dtype=torch.bfloat16), weight_int4pack, groupsize, scales_and_zeros.to(dtype=torch.bfloat16)).to(dtype=x.dtype)
     new_shape = origin_x_size[:-1] + (out_features,)
     c = c.reshape(new_shape)
     return c
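
Note: the sketch below illustrates the dtype round-trip this patch introduces; it uses a plain torch.matmul as a stand-in for torch.ops.aten._weight_int4pack_mm (the real op requires weights packed via torch.ops.aten._convert_weight_to_int4pack), and the function name is hypothetical. Inputs are cast to bfloat16 before the matmul and the result is cast back to the caller's dtype.

import torch

# Stand-in for the int4 kernel's dtype handling: cast operands to bfloat16,
# run the matmul, then restore the caller's original dtype.
def linear_forward_bf16_roundtrip(x, weight):
    c = torch.matmul(x.to(dtype=torch.bfloat16), weight.to(dtype=torch.bfloat16))
    return c.to(dtype=x.dtype)

x = torch.randn(2, 8, dtype=torch.float16)
w = torch.randn(8, 4, dtype=torch.float16)
out = linear_forward_bf16_roundtrip(x, w)
assert out.dtype == torch.float16  # result comes back in the input's dtype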