From a90ab63c23bb090cf31d3f65672e3725f027d8fd Mon Sep 17 00:00:00 2001 From: Michael Gschwind Date: Tue, 9 Apr 2024 15:24:17 -0700 Subject: [PATCH] more groupsize==None/0 handling --- quantize.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/quantize.py b/quantize.py index 6299d280c..3f34c00b9 100644 --- a/quantize.py +++ b/quantize.py @@ -537,6 +537,10 @@ def replace_linear_8da4w( ): for name, child in module.named_children(): if isinstance(child, nn.Linear): + if group_size is None or group_size == 0: + child_group_size = child.in_features + else: + child_group_size = group_size if _check_linear_int4_k(child.in_features, group_size) or padding_allowed: setattr( module, @@ -545,7 +549,7 @@ def replace_linear_8da4w( child.in_features, child.out_features, bias=False, - group_size=group_size, + group_size=child_group_size, precision=precision, scales_precision=scales_precision, ), @@ -585,25 +589,27 @@ def create_quantized_state_dict(self): in_features = mod.in_features if group_size is None or group_size == 0: group_size = in_features + else: + group_size = self.group_size # print("in features:", in_features, " out features:", out_features) # assert out_features % 8 == 0, "require out_features % 8 == 0" # print(f"linear: {fqn}, in={in_features}, out={out_features}") assert ( - in_features % self.group_size == 0 - ), f"require in_features:{in_features} % self.group_size:{self.group_size} == 0" + in_features % group_size == 0 + ), f"require in_features:{in_features} % group_size:{group_size} == 0" weight = mod.weight.data """ if not _check_linear_int4_k( - in_features, self.group_size + in_features, group_size ): if self.padding_allowed: print( f"warning: {fqn} is padded to satisfy in_features % 1024 == 0" ) padded_in_features = _calc_padded_size_linear_int4( - in_features, self.group_size + in_features, group_size ) weight = F.pad( weight, pad=(0, padded_in_features - in_features) @@ -620,7 +626,7 @@ def create_quantized_state_dict(self): zeros, ) = prepare_int4_weight_and_scales_and_zeros( weight.to(self.precision), - self.group_size, + group_size, self.scales_precision, ) cur_state_dict[f"{fqn}.weight"] = weight_int4pack.to("cpu")