Skip to content

Commit

Permalink
more groupsize==None/0 handling
Browse files Browse the repository at this point in the history
  • Loading branch information
Michael Gschwind committed Apr 9, 2024
1 parent 9bdfe0a commit a90ab63
Showing 1 changed file with 12 additions and 6 deletions.
18 changes: 12 additions & 6 deletions quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,6 +537,10 @@ def replace_linear_8da4w(
):
for name, child in module.named_children():
if isinstance(child, nn.Linear):
if group_size is None or group_size == 0:
child_group_size = child.in_features
else:
child_group_size = group_size
if _check_linear_int4_k(child.in_features, group_size) or padding_allowed:
setattr(
module,
Expand All @@ -545,7 +549,7 @@ def replace_linear_8da4w(
child.in_features,
child.out_features,
bias=False,
group_size=group_size,
group_size=child_group_size,
precision=precision,
scales_precision=scales_precision,
),
Expand Down Expand Up @@ -585,25 +589,27 @@ def create_quantized_state_dict(self):
in_features = mod.in_features
if group_size is None or group_size == 0:
group_size = in_features
else:
group_size = self.group_size
# print("in features:", in_features, " out features:", out_features)
# assert out_features % 8 == 0, "require out_features % 8 == 0"
# print(f"linear: {fqn}, in={in_features}, out={out_features}")

assert (
in_features % self.group_size == 0
), f"require in_features:{in_features} % self.group_size:{self.group_size} == 0"
in_features % group_size == 0
), f"require in_features:{in_features} % group_size:{group_size} == 0"

weight = mod.weight.data
"""
if not _check_linear_int4_k(
in_features, self.group_size
in_features, group_size
):
if self.padding_allowed:
print(
f"warning: {fqn} is padded to satisfy in_features % 1024 == 0"
)
padded_in_features = _calc_padded_size_linear_int4(
in_features, self.group_size
in_features, group_size
)
weight = F.pad(
weight, pad=(0, padded_in_features - in_features)
Expand All @@ -620,7 +626,7 @@ def create_quantized_state_dict(self):
zeros,
) = prepare_int4_weight_and_scales_and_zeros(
weight.to(self.precision),
self.group_size,
group_size,
self.scales_precision,
)
cur_state_dict[f"{fqn}.weight"] = weight_int4pack.to("cpu")
Expand Down

0 comments on commit a90ab63

Please sign in to comment.