restore llama-fast version of dyn_quant_per_channel
Michael Gschwind committed Apr 10, 2024
1 parent 745e47f commit d405e8a
Showing 1 changed file with 64 additions and 3 deletions.
67 changes: 64 additions & 3 deletions quantize.py
@@ -79,17 +79,71 @@ def quantize_model(model: nn.Module, quantize_options):
 #########################################################################
 ##### Quantization Primitives ######

-def dynamically_quantize_per_channel(x, quant_min, quant_max, target_dtype):
+def dynamically_quantize_per_channel(
+    x,
+    quant_min,
+    quant_max,
+    target_dtype,
+    group_size: Optional[int] = None,
+    *,
+    scales_dtype=torch.float16,
+    enable_non_multiple_groups=True,
+):
"""
Dynamically quantize per channel. This function is used for quantizing weights,
for linear and embedding layers.
Arguments:
x: input tensor,
quant_min: minimum value after quantization,
quant_max: maximum value after quantization,
target_dtype: target data type for weights after quantization,
group_size: number of elements of the channel to quantize together
Keyword arguments:
scales_dtype: data type of scale,
enable_non_multiple_groups: if True, allow the rowsize to not be a multiple of group size,
with a final group of a size less than group size.
Assumptions:
This function assumes symmetric quantization, axis ==0 and a dense memory format.
"""

     # assumes symmetric quantization
     # assumes axis == 0
     # assumes dense memory format
     # TODO(future): relax ^ as needed

+    x_shape_1 = x.shape[1]
+
+    if group_size is None or group_size == 0:
+        items = x_shape_1
+    elif ((x_shape_1 % group_size) == 0) or not enable_non_multiple_groups:
+        assert group_size > 0, "group size must be positive"
+        assert (
+            x_shape_1 % group_size
+        ) == 0, f"weights dimension 1 = {x_shape_1} must be a multiple of group size {group_size}"
+        items = group_size
+    else:
+        assert group_size > 0, "group size must be positive"
+        print(
+            f"row-size of weight matrix {x_shape_1} is not divisible by group size {group_size}, using nearest neighbor rounding"
+        )
+        assert (
+            x_shape_1 % group_size != 0
+        ), f"expected x.shape[1] to not be a multiple of group size {group_size}, but got {x_shape_1}"
+        padding = group_size - (x_shape_1 % group_size)
+        x = F.pad(x, (0, padding))
+        items = group_size

     # default setup for affine quantization of activations
     eps = torch.finfo(torch.float32).eps

+    x = x.view(x.shape[0], x.shape[1] // items, items)
     # get min and max
-    min_val, max_val = torch.aminmax(x, dim=1)
+    min_val, max_val = torch.aminmax(x, dim=2)
+    # print(f"min_val {min_val}")
+    # print(f"max_val {max_val}")

     # calculate scales and zero_points based on min and max
     # reference: https://fburl.com/code/srbiybme
@@ -109,10 +163,17 @@ def dynamically_quantize_per_channel(x, quant_min, quant_max, target_dtype):
     x_div = x / scales.unsqueeze(-1)
     x_round = torch.round(x_div)
     x_zp = x_round + zero_points.unsqueeze(-1)
-    quant = torch.clamp(x_zp, quant_min, quant_max).to(target_dtype)
+    quant = (
+        torch.clamp(x_zp, quant_min, quant_max).to(target_dtype).view(x.shape[0], -1)
+    )

+    scales = scales.to(dtype=scales_dtype)
+    quant = quant[:, :x_shape_1]

     return quant, scales, zero_points



 def get_group_qparams(w, n_bit=4, groupsize=128):
     # needed for GPTQ with padding
     if groupsize > w.shape[-1]:
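For context, a minimal usage sketch of the restored function (an editorial illustration, not part of the commit; it assumes quantize.py is importable and torch is installed). It exercises the non-multiple-of-group-size path: a row size of 10 with group size 8 is zero-padded to 16 columns, quantized in two groups per row, then trimmed back to 10 columns:

import torch

from quantize import dynamically_quantize_per_channel

# Illustrative shapes: 4 output channels, row size 10 (not a multiple of 8).
w = torch.randn(4, 10)

quant, scales, zero_points = dynamically_quantize_per_channel(
    w, quant_min=-128, quant_max=127, target_dtype=torch.int8, group_size=8
)

print(quant.shape)   # torch.Size([4, 10]) -- padding columns trimmed off
print(scales.shape)  # torch.Size([4, 2]) -- one scale per (row, group)
print(scales.dtype)  # torch.float16, per the scales_dtype default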

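The scale and zero-point computation itself sits in the collapsed region of the diff (old lines 96-108, behind the reference comment). As a rough sketch of the symmetric scheme the docstring describes -- an assumed reconstruction, not the verbatim collapsed code -- the scale maps the largest per-group magnitude onto half the quantized range, and the zero point stays at zero:

import torch

# Assumed shape of the collapsed qparam computation: min_val/max_val are the
# per-group results of torch.aminmax, eps guards against zero scales.
def symmetric_qparams(min_val, max_val, quant_min, quant_max, eps):
    # Symmetric quantization is centered on zero, so only the largest
    # absolute value in each group determines the scale.
    min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
    max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
    scales = torch.max(-min_val_neg, max_val_pos) / ((quant_max - quant_min) / 2)
    scales = torch.clamp(scales, min=eps)
    zero_points = torch.zeros_like(scales, dtype=torch.int64)
    return scales, zero_points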