Commit 745e47f: quantization primitives

Michael Gschwind committed Apr 10, 2024
1 parent b186d76

Showing 1 changed file: quantize.py (111 additions, 60 deletions; removed lines below are prefixed with -, added lines with +)
@@ -76,71 +76,20 @@ def quantize_model(model: nn.Module, quantize_options):
        assert 0 == 1, f"quantizer {quantizer} not supported"


-def dynamically_quantize_per_channel(
-    x,
-    quant_min,
-    quant_max,
-    target_dtype,
-    group_size: Optional[int] = None,
-    *,
-    scales_dtype=torch.float16,
-    enable_non_multiple_groups=True,
-):
-    """
-    Dynamically quantize per channel. This function is used for quantizing weights,
-    for linear and embedding layers.
-    Arguments:
-        x: input tensor,
-        quant_min: minimum value after quantization,
-        quant_max: maximum value after quantization,
-        target_dtype: target data type for weights after quantization,
-        group_size: number of elements of the channel to quantize together
-    Keyword arguments:
-        scales_dtype: data type of scale,
-        enable_non_multiple_groups: if True, allow the rowsize to not be a multiple of group size,
-            with a final group of a size less than group size.
-    Assumptions:
-        This function assumes symmetric quantization, axis == 0 and a dense memory format.
-    """
+#########################################################################
+##### Quantization Primitives ######
+
+def dynamically_quantize_per_channel(x, quant_min, quant_max, target_dtype):
+    # assumes symmetric quantization
+    # assumes axis == 0
+    # assumes dense memory format
+    # TODO(future): relax ^ as needed

-    x_shape_1 = x.shape[1]
-
-    if group_size is None or group_size == 0:
-        items = x_shape_1
-    elif ((x_shape_1 % group_size) == 0) or not enable_non_multiple_groups:
-        assert group_size > 0, "group size must be positive"
-        assert (
-            x_shape_1 % group_size
-        ) == 0, f"weights dimension 1 = {x_shape_1} must be a multiple of group size {group_size}"
-        items = group_size
-    else:
-        assert group_size > 0, "group size must be positive"
-        print(
-            f"row-size of weight matrix {x_shape_1} is not divisible by group size {group_size}, using nearest neighbor rounding"
-        )
-        assert (
-            x_shape_1 % group_size != 0
-        ), f"expected x.shape[1] to not be a multiple of group size {group_size}, but got {x_shape_1}"
-        padding = group_size - (x_shape_1 % group_size)
-        x = F.pad(x, (0, padding))
-        items = group_size

    # default setup for affine quantization of activations
    eps = torch.finfo(torch.float32).eps

-    x = x.view(x.shape[0], x.shape[1] // items, items)
    # get min and max
-    min_val, max_val = torch.aminmax(x, dim=2)
-    # print(f"min_val {min_val}")
-    # print(f"max_val {max_val}")
+    min_val, max_val = torch.aminmax(x, dim=1)

    # calculate scales and zero_points based on min and max
    # reference: https://fburl.com/code/srbiybme

@@ -160,15 +109,117 @@ def dynamically_quantize_per_channel(
    x_div = x / scales.unsqueeze(-1)
    x_round = torch.round(x_div)
    x_zp = x_round + zero_points.unsqueeze(-1)
-    quant = (
-        torch.clamp(x_zp, quant_min, quant_max).to(target_dtype).view(x.shape[0], -1)
-    )
+    quant = torch.clamp(x_zp, quant_min, quant_max).to(target_dtype)
+
+    return quant, scales, zero_points
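For orientation, a minimal usage sketch of the new per-channel primitive (illustrative only; the tensor shape and the int8 range are assumptions, not part of the diff):

import torch

w = torch.randn(64, 256)  # hypothetical 2-D weight matrix
q, scales, zero_points = dynamically_quantize_per_channel(w, -128, 127, torch.int8)
# symmetric scheme: zero_points are all zero, one scale per row (output channel)
w_hat = (q.float() - zero_points.unsqueeze(-1)) * scales.unsqueeze(-1)
# reconstruction error should be about half a quantization step per row
print((w - w_hat).abs().max().item(), (scales.max() / 2).item())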

+def get_group_qparams(w, n_bit=4, groupsize=128):
+    # needed for GPTQ with padding
+    if groupsize > w.shape[-1]:
+        groupsize = w.shape[-1]
+    assert groupsize > 1
+    assert w.shape[-1] % groupsize == 0
+    assert w.dim() == 2
+
+    to_quant = w.reshape(-1, groupsize)
+    assert torch.isnan(to_quant).sum() == 0
+
+    max_val = to_quant.amax(dim=1, keepdim=True)
+    min_val = to_quant.amin(dim=1, keepdim=True)
+    max_int = 2**n_bit - 1
+    scales = (max_val - min_val).clamp(min=1e-6) / max_int
+    zeros = min_val + scales * (2 ** (n_bit - 1))
+    return scales.to(torch.bfloat16).reshape(w.shape[0], -1), zeros.to(
+        torch.bfloat16
+    ).reshape(w.shape[0], -1)
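A quick shape check for get_group_qparams (an illustrative sketch; the 8×64 tensor and groupsize 32 are arbitrary):

import torch

w = torch.randn(8, 64)
scales, zeros = get_group_qparams(w, n_bit=4, groupsize=32)
# one (scale, zero) pair per group: 64 columns / groupsize 32 = 2 groups per row
assert scales.shape == (8, 2) and zeros.shape == (8, 2)
assert scales.dtype == torch.bfloat16 and zeros.dtype == torch.bfloat16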


+def pack_scales_and_zeros(scales, zeros):
+    assert scales.shape == zeros.shape
+    assert scales.dtype == torch.bfloat16
+    assert zeros.dtype == torch.bfloat16
+    return (
+        torch.cat(
+            [
+                scales.reshape(scales.size(0), scales.size(1), 1),
+                zeros.reshape(zeros.size(0), zeros.size(1), 1),
+            ],
+            2,
+        )
+        .transpose(0, 1)
+        .contiguous()
+    )

-    scales = scales.to(dtype=scales_dtype)
-    quant = quant[:, :x_shape_1]
-
-    return quant, scales, zero_points
+def unpack_scales_and_zeros(scales_and_zeros):
+    assert len(scales_and_zeros.shape) == 3 and scales_and_zeros.shape[2] == 2
+    assert scales_and_zeros.dtype == torch.float
+    return torch.split(scales_and_zeros.transpose(0, 1), 1, 2)
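pack_scales_and_zeros interleaves each group's scale and zero point into a single (n_groups, n_rows, 2) tensor, and unpack_scales_and_zeros reverses it. A round-trip sketch; the .float() cast is an assumption forced by the dtype asserts (packing wants bfloat16, unpacking wants float32):

import torch

scales, zeros = get_group_qparams(torch.randn(8, 64), n_bit=4, groupsize=32)
packed = pack_scales_and_zeros(scales, zeros)   # shape (2, 8, 2)
s, z = unpack_scales_and_zeros(packed.float())  # each shape (8, 2, 1)
assert torch.equal(s.squeeze(-1), scales.float())
assert torch.equal(z.squeeze(-1), zeros.float())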


+def group_quantize_tensor_from_qparams(w, scales, zeros, n_bit=4, groupsize=128):
+    assert groupsize > 1
+    # needed for GPTQ single column quantize
+    if groupsize > w.shape[-1] and scales.shape[-1] == 1:
+        groupsize = w.shape[-1]
+
+    assert w.shape[-1] % groupsize == 0
+    assert w.dim() == 2
+
+    to_quant = w.reshape(-1, groupsize)
+    assert torch.isnan(to_quant).sum() == 0
+
+    scales = scales.reshape(-1, 1)
+    zeros = zeros.reshape(-1, 1)
+    min_val = zeros - scales * (2 ** (n_bit - 1))
+    max_int = 2**n_bit - 1
+    min_int = 0
+    w_int32 = (
+        to_quant.sub(min_val)
+        .div(scales)
+        .round()
+        .clamp_(min_int, max_int)
+        .to(torch.int32)
+        .reshape_as(w)
+    )
+
+    return w_int32
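The *_from_qparams variant takes scales and zeros from outside instead of computing them, which the GPTQ flow relies on (the comment above covers its single-column case). A sketch with externally supplied qparams (illustrative; reuses get_group_qparams for brevity):

import torch

w = torch.randn(8, 64)
scales, zeros = get_group_qparams(w, n_bit=4, groupsize=32)
codes = group_quantize_tensor_from_qparams(w, scales, zeros, n_bit=4, groupsize=32)
assert codes.dtype == torch.int32 and codes.shape == w.shape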


+def group_quantize_tensor(w, n_bit=4, groupsize=128):
+    scales, zeros = get_group_qparams(w, n_bit, groupsize)
+    w_int32 = group_quantize_tensor_from_qparams(w, scales, zeros, n_bit, groupsize)
+    scales_and_zeros = pack_scales_and_zeros(scales, zeros)
+    return w_int32, scales_and_zeros
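group_quantize_tensor is the one-call wrapper: derive qparams, quantize, pack. What comes back, under the same illustrative 8×64 / groupsize-32 setup:

import torch

w_int32, scales_and_zeros = group_quantize_tensor(torch.randn(8, 64), n_bit=4, groupsize=32)
assert w_int32.dtype == torch.int32 and w_int32.shape == (8, 64)
assert int(w_int32.min()) >= 0 and int(w_int32.max()) <= 15  # 4-bit codes in [0, 2**4 - 1]
assert scales_and_zeros.shape == (2, 8, 2)  # (n_groups, n_rows, scale|zero)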


+def group_dequantize_tensor_from_qparams(
+    w_int32, scales, zeros, n_bit=4, groupsize=128
+):
+    assert groupsize > 1
+    # needed for GPTQ single column dequantize
+    if groupsize > w_int32.shape[-1] and scales.shape[-1] == 1:
+        groupsize = w_int32.shape[-1]
+    assert w_int32.shape[-1] % groupsize == 0
+    assert w_int32.dim() == 2
+
+    w_int32_grouped = w_int32.reshape(-1, groupsize)
+    scales = scales.reshape(-1, 1)
+    zeros = zeros.reshape(-1, 1)
+
+    w_dq = (
+        w_int32_grouped.sub(2 ** (n_bit - 1)).mul(scales).add(zeros).reshape_as(w_int32)
+    )
+    return w_dq


+def group_dequantize_tensor(w_int32, scales_and_zeros, n_bit=4, groupsize=128):
+    scales, zeros = unpack_scales_and_zeros(scales_and_zeros)
+    return group_dequantize_tensor_from_qparams(
+        w_int32, scales, zeros, n_bit, groupsize
+    )
+
+#########################################################################
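Putting the new primitives together, an end-to-end sketch: quantize to 4-bit groups, dequantize, and inspect the reconstruction error (the .float() cast is an assumption, matching the float32 assert in unpack_scales_and_zeros):

import torch

w = torch.randn(8, 64)
w_int32, scales_and_zeros = group_quantize_tensor(w, n_bit=4, groupsize=32)
w_hat = group_dequantize_tensor(w_int32, scales_and_zeros.float(), n_bit=4, groupsize=32)
# expect roughly half a quantization step of error per group, plus bfloat16 rounding
print((w - w_hat).abs().max().item())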

class QuantHandler:
    def __init__(self, mod):
