diff --git a/.github/workflows/compile.yml b/.github/workflows/compile.yml
index 815dc3575..2b70beea6 100644
--- a/.github/workflows/compile.yml
+++ b/.github/workflows/compile.yml
@@ -95,6 +95,17 @@ jobs:
           python generate.py --checkpoint-path ${MODEL_PATH} --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.so > ./output_aoti
           cat ./output_aoti
+          echo "******************************************"
+          echo "******** INT4 group-wise quantized *******"
+          echo "******************************************"
+          python generate.py --quant '{"linear:int4" : {"group_size": 32}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
+          cat ./output_eager
+          python generate.py --compile --quant '{"linear:int4" : {"group_size": 32}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_compiled
+          cat ./output_compiled
+          python export.py --quant '{"linear:int4" : {"group_size": 32}}' --checkpoint-path ${MODEL_PATH} --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so
+          python generate.py --checkpoint-path ${MODEL_PATH} --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.so > ./output_aoti
+          cat ./output_aoti
+
           echo "tests complete"
           echo "******************************************"

           # echo "********* EAGER vs TORCH.COMPILE *********"
diff --git a/quantize.py b/quantize.py
index 280e17810..0b89a1641 100644
--- a/quantize.py
+++ b/quantize.py
@@ -53,6 +53,12 @@ def quantize_model(model: nn.Module, quantize_options):
             **q_kwargs
         ).quantized_model()
     elif quantizer == "linear:int4":
+        linears_quantized = True
+        model = WeightOnlyInt4QuantHandler(
+            model,
+            **q_kwargs
+        ).quantized_model()
+    elif quantizer == "linear:a8w4dq":
         linears_quantized = True
         model = Int8DynActInt4WeightQuantHandler(
             model,
@@ -70,6 +76,9 @@ def quantize_model(model: nn.Module, quantize_options):
         assert 0 == 1, f"quantizer {quantizer} not supported"
 
+#########################################################################
+##### Quantization Primitives ######
+
 
 def dynamically_quantize_per_channel(
     x,
     quant_min,
@@ -164,6 +173,115 @@ def dynamically_quantize_per_channel(
     return quant, scales, zero_points
 
 
+
+def get_group_qparams(w, n_bit=4, groupsize=128, *, scales_dtype=torch.float):
+    # needed for GPTQ with padding
+    if groupsize > w.shape[-1]:
+        groupsize = w.shape[-1]
+    assert groupsize > 1
+    assert w.shape[-1] % groupsize == 0
+    assert w.dim() == 2
+
+    to_quant = w.reshape(-1, groupsize)
+    assert torch.isnan(to_quant).sum() == 0
+
+    max_val = to_quant.amax(dim=1, keepdim=True)
+    min_val = to_quant.amin(dim=1, keepdim=True)
+    max_int = 2**n_bit - 1
+    scales = (max_val - min_val).clamp(min=1e-6) / max_int
+    zeros = min_val + scales * (2 ** (n_bit - 1))
+    return scales.to(scales_dtype).reshape(w.shape[0], -1), zeros.to(
+        scales_dtype
+    ).reshape(w.shape[0], -1)
+
+
+def pack_scales_and_zeros(scales, zeros, *, scales_dtype=torch.float):
+    assert scales.shape == zeros.shape
+    assert scales.dtype == scales_dtype
+    assert zeros.dtype == scales_dtype
+    return (
+        torch.cat(
+            [
+                scales.reshape(scales.size(0), scales.size(1), 1),
+                zeros.reshape(zeros.size(0), zeros.size(1), 1),
+            ],
+            2,
+        )
+        .transpose(0, 1)
+        .contiguous()
+    )
+
+
+def unpack_scales_and_zeros(scales_and_zeros):
+    assert len(scales_and_zeros.shape) == 3 and scales_and_zeros.shape[2] == 2
+    assert scales_and_zeros.dtype == torch.float
+    return torch.split(scales_and_zeros.transpose(0, 1), 1, 2)
+
+
+def group_quantize_tensor_from_qparams(w, scales, zeros, n_bit=4, groupsize=128):
+    assert groupsize > 1
+    # needed for GPTQ single column quantize
+    if groupsize > w.shape[-1] and scales.shape[-1] == 1:
+        groupsize = w.shape[-1]
+
+    assert w.shape[-1] % groupsize == 0
+    assert w.dim() == 2
+
+    to_quant = w.reshape(-1, groupsize)
+    assert torch.isnan(to_quant).sum() == 0
+
+    scales = scales.reshape(-1, 1)
+    zeros = zeros.reshape(-1, 1)
+    min_val = zeros - scales * (2 ** (n_bit - 1))
+    max_int = 2**n_bit - 1
+    min_int = 0
+    w_int32 = (
+        to_quant.sub(min_val)
+        .div(scales)
+        .round()
+        .clamp_(min_int, max_int)
+        .to(torch.int32)
+        .reshape_as(w)
+    )
+
+    return w_int32
+
+
+def group_quantize_tensor(w, n_bit=4, groupsize=128):
+    scales, zeros = get_group_qparams(w, n_bit, groupsize)
+    w_int32 = group_quantize_tensor_from_qparams(w, scales, zeros, n_bit, groupsize)
+    scales_and_zeros = pack_scales_and_zeros(scales, zeros)
+    return w_int32, scales_and_zeros
+
+
+def group_dequantize_tensor_from_qparams(
+    w_int32, scales, zeros, n_bit=4, groupsize=128
+):
+    assert groupsize > 1
+    # needed for GPTQ single column dequantize
+    if groupsize > w_int32.shape[-1] and scales.shape[-1] == 1:
+        groupsize = w_int32.shape[-1]
+    assert w_int32.shape[-1] % groupsize == 0
+    assert w_int32.dim() == 2
+
+    w_int32_grouped = w_int32.reshape(-1, groupsize)
+    scales = scales.reshape(-1, 1)
+    zeros = zeros.reshape(-1, 1)
+
+    w_dq = (
+        w_int32_grouped.sub(2 ** (n_bit - 1)).mul(scales).add(zeros).reshape_as(w_int32)
+    )
+    return w_dq
+
+
+def group_dequantize_tensor(w_int32, scales_and_zeros, n_bit=4, groupsize=128):
+    scales, zeros = unpack_scales_and_zeros(scales_and_zeros)
+    return group_dequantize_tensor_from_qparams(
+        w_int32, scales, zeros, n_bit, groupsize
+    )
+
+#########################################################################
+
 class QuantHandler:
     def __init__(self, mod):
         self.mod = mod
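Reviewer note: a minimal round-trip sketch of the group-wise primitives added in the hunk above (`group_quantize_tensor` / `group_dequantize_tensor`). The `quantize` import path and the tensor shapes are illustrative assumptions, not part of this patch.

```python
import torch

# assumes quantize.py from this patch is importable as `quantize`
from quantize import group_quantize_tensor, group_dequantize_tensor

# [out_features, in_features]; in_features must be divisible by groupsize
w = torch.randn(64, 256)

# 4-bit asymmetric quantization: one (scale, zero) pair per group of 32 weights
w_int32, scales_and_zeros = group_quantize_tensor(w, n_bit=4, groupsize=32)
assert int(w_int32.min()) >= 0 and int(w_int32.max()) <= 15  # values fit in 4 bits

# dequantize applies (q - 2**(n_bit - 1)) * scale + zero per group
w_dq = group_dequantize_tensor(w_int32, scales_and_zeros, n_bit=4, groupsize=32)

# reconstruction error should stay within half a step of the largest group scale
max_scale = scales_and_zeros[..., 0].max()
assert (w - w_dq).abs().max() <= max_scale / 2 + 1e-6
```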
@@ -173,6 +291,12 @@ def create_quantized_state_dict(self) -> Dict:  # "StateDict"
 
     def convert_for_runtime(self) -> nn.Module:
         pass
+
+    def quantized_model(self) -> nn.Module:
+        model_updated_state_dict = self.create_quantized_state_dict()
+        self.convert_for_runtime()
+        self.mod.load_state_dict(model_updated_state_dict)
+        return self.mod
 
 
 ##### Weight-only int8 per-channel quantized code ######
@@ -202,7 +326,7 @@ def replace_linear_weight_only_int8_per_channel(module, node_type, group_size=No
             replace_linear_weight_only_int8_per_channel(child, node_type, group_size)
 
 
-class WeightOnlyInt8QuantHandler:
+class WeightOnlyInt8QuantHandler(QuantHandler):
     def __init__(
         self,
         mod,
@@ -349,7 +473,7 @@ def replace_embedding_weight_only_grouped_int8_per_channel(
         )
 
 
-class EmbeddingOnlyInt8QuantHandler:
+class EmbeddingOnlyInt8QuantHandler(QuantHandler):
     def __init__(self, mod, *, bitwidth: int = 8, group_size: Optional[int] = None):
         self.mod = mod
         self.group_size = group_size
@@ -466,6 +590,145 @@ def forward(self, indices: torch.Tensor) -> torch.Tensor:
 
 ##################################################################
 ##### weight only int4 per channel groupwise quantized code ######
+def _int4_prepare_int4_weight_and_scales_and_zeros(weight_bf16, groupsize, inner_k_tiles):
+    weight_int32, scales_and_zeros = group_quantize_tensor(
+        weight_bf16, n_bit=4, groupsize=groupsize
+    )
+    weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(weight_int32, inner_k_tiles)
+    return weight_int4pack, scales_and_zeros
+
+def _int4_calc_padded_size(k, groupsize=1, inner_k_tiles=1):
+    from model import find_multiple
+    return find_multiple(k, 1024)
+
+def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, groupsize):
+    origin_x_size = x.size()
+    x = x.reshape(-1, origin_x_size[-1])
+    c = torch.ops.aten._weight_int4pack_mm(
+        x.to(dtype=torch.bfloat16),
+        weight_int4pack,
+        groupsize,
+        scales_and_zeros.to(dtype=torch.bfloat16)
+    ).to(dtype=x.dtype)
+    new_shape = origin_x_size[:-1] + (out_features,)
+    c = c.reshape(new_shape)
+    return c
+
+
+def _int4_check_linear_int4_k(k, groupsize=1, inner_k_tiles=1):
+    return k % groupsize == 0 and k % (inner_k_tiles * 16) == 0
+
+def replace_linear_int4(module, groupsize, inner_k_tiles, padding_allowed, use_cuda=False):
+    for name, child in module.named_children():
+        if isinstance(child, nn.Linear):
+            if _int4_check_linear_int4_k(child.in_features, groupsize, inner_k_tiles) or padding_allowed:
+                setattr(module, name, WeightOnlyInt4Linear(
+                    child.in_features, child.out_features, bias=False,
+                    groupsize=groupsize, inner_k_tiles=inner_k_tiles, use_cuda=use_cuda
+                ))
+        else:
+            replace_linear_int4(child, groupsize, inner_k_tiles, padding_allowed, use_cuda)
+
+
+class WeightOnlyInt4QuantHandler(QuantHandler):
+    def __init__(self, mod, group_size=128, inner_k_tiles=8, padding_allowed=True):
+        self.mod = mod
+        self.groupsize = group_size
+        self.inner_k_tiles = inner_k_tiles
+        self.padding_allowed = padding_allowed
+        assert group_size in [32, 64, 128, 256]
+        assert inner_k_tiles in [2, 4, 8]
+
+    @torch.no_grad()
+    def create_quantized_state_dict(self):
+        cur_state_dict = self.mod.state_dict()
+        for fqn, mod in self.mod.named_modules():
+            if isinstance(mod, torch.nn.Linear):
+                assert not mod.bias
+                out_features = mod.out_features
+                in_features = mod.in_features
+                assert out_features % 8 == 0, "require out_features % 8 == 0"
+                print(f"linear: {fqn}, in={in_features}, out={out_features}")
+
+                weight = mod.weight.data
+                if not _int4_check_linear_int4_k(in_features, self.groupsize, self.inner_k_tiles):
+                    if self.padding_allowed:
+                        from model import find_multiple
+                        import torch.nn.functional as F
+                        print(f"warning: {fqn} is padded to satisfy in_features % 1024 == 0")
+                        padded_in_features = find_multiple(in_features, 1024)
+                        weight = F.pad(weight, pad=(0, padded_in_features - in_features))
+                    else:
+                        print(f"warning: {fqn} is skipped, int4 requires that in_features is 32, 64, or is divisible by 1024, "
+                              + "and that groupsize and inner_k_tiles*16 evenly divide into it")
+                        continue
+                weight_int4pack, scales_and_zeros = _int4_prepare_int4_weight_and_scales_and_zeros(
+                    weight.to(torch.float), self.groupsize, self.inner_k_tiles
+                )
+                cur_state_dict[f"{fqn}.weight"] = weight_int4pack.to('cpu')
+                cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to('cpu')
+
+        return cur_state_dict
+
+    def convert_for_runtime(self, use_cuda=False):
+        replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding_allowed, use_cuda)
+        return self.mod
+
+    def quantized_model(self) -> nn.Module:
+        model_updated_state_dict = self.create_quantized_state_dict()
+        self.convert_for_runtime()
+        self.mod.load_state_dict(model_updated_state_dict)
+        return self.mod
+
+
+class WeightOnlyInt4Linear(torch.nn.Module):
+    __constants__ = ['in_features', 'out_features']
+    in_features: int
+    out_features: int
+    weight: torch.Tensor
+
+    def __init__(
+            self, in_features: int, out_features: int,
+            bias=True, device=None, dtype=None, groupsize: int = 128, inner_k_tiles: int = 8, use_cuda=True,
+    ) -> None:
+        super().__init__()
+        self.padding = not _int4_check_linear_int4_k(in_features, groupsize, inner_k_tiles)
+        if self.padding:
+            from model import find_multiple
+            self.origin_in_features = in_features
+            in_features = find_multiple(in_features, 1024)
+
+        self.in_features = in_features
+        self.out_features = out_features
+        assert not bias, "require bias=False"
+        self.groupsize = groupsize
+        self.inner_k_tiles = inner_k_tiles
+
+        assert out_features % 8 == 0, "require out_features % 8 == 0"
+        assert in_features % (inner_k_tiles * 16) == 0, "require in_features % (innerKTiles * 16) == 0"
+        self.register_buffer(
+            "weight",
+            torch.empty((out_features // 8, in_features // (inner_k_tiles * 16), 32, inner_k_tiles // 2), dtype=torch.int32)
+        )
+        # MKG: torch.float
+        self.register_buffer(
+            "scales_and_zeros",
+            torch.empty((in_features // groupsize, out_features, 2), dtype=torch.float)
+        )
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        # MKG torch.float
+        input = input.to(torch.float)
+        if self.padding:
+            import torch.nn.functional as F
+            input = F.pad(input, pad=(0, self.in_features - self.origin_in_features))
+        return linear_forward_int4(
+            input,
+            self.weight, self.scales_and_zeros, self.out_features, self.groupsize
+        )
+
+########################################################################
+### Int8 Dynamic Activations 4 Bit Weights
 
 def prepare_int4_weight_and_scales_and_zeros(weight, group_size, precision):
     weight_int8, scales, zeros = group_quantize_tensor_symmetric(
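Reviewer note: a hypothetical usage sketch for the new `WeightOnlyInt4QuantHandler`. It assumes a PyTorch build and device where `torch.ops.aten._convert_weight_to_int4pack` and `torch.ops.aten._weight_int4pack_mm` are available (they are not on every build); the toy module and sizes are made up for illustration.

```python
import torch
import torch.nn as nn

from quantize import WeightOnlyInt4QuantHandler  # assumed import path

class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        # bias=False, and in_features divisible by group_size and inner_k_tiles * 16
        self.fc = nn.Linear(256, 256, bias=False)

    def forward(self, x):
        return self.fc(x)

model = TinyModel()

# group_size must be one of {32, 64, 128, 256}; inner_k_tiles one of {2, 4, 8}
handler = WeightOnlyInt4QuantHandler(model, group_size=32, inner_k_tiles=8)
# quantize the weights, swap nn.Linear for WeightOnlyInt4Linear, reload the state dict
model = handler.quantized_model()

with torch.no_grad():
    y = model(torch.randn(1, 256))  # dispatches to linear_forward_int4
print(y.shape)
```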
@@ -523,7 +786,6 @@ def find_multiple(n: int, *args: Tuple[int]) -> int:
 def _check_linear_int4_k(k, group_size=1):
     return k % group_size == 0
 
-
 def _calc_padded_size_linear_int4(k, groupsize=1):
     return find_multiple(k, groupsize)
 
@@ -560,7 +822,7 @@ def replace_linear_8da4w(
         )
 
 
-class Int8DynActInt4WeightQuantHandler:
+class Int8DynActInt4WeightQuantHandler(QuantHandler):
     def __init__(
         self,
         mod,
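Reviewer note: the CI steps at the top exercise this code path through the `--quant` flag. A rough end-to-end sketch of how that flag is assumed to reach the new `linear:int4` branch of `quantize_model`; the JSON parsing happens in generate.py/export.py and is not shown in this diff, and the stand-in model is illustrative only.

```python
import json

import torch.nn as nn

from quantize import quantize_model  # assumed import path

# stand-in model; the real flow quantizes the transformer loaded from --checkpoint-path
model = nn.Sequential(nn.Linear(256, 256, bias=False))

# same JSON as the CI invocation: --quant '{"linear:int4" : {"group_size": 32}}'
quantize_options = json.loads('{"linear:int4" : {"group_size": 32}}')

# assumed to dispatch to WeightOnlyInt4QuantHandler(model, group_size=32).quantized_model(),
# which replaces eligible nn.Linear modules in place
quantize_model(model, quantize_options)
```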