diff --git a/torchchat/utils/quantize.py b/torchchat/utils/quantize.py
index 31c639dfd..6fa8c2e39 100644
--- a/torchchat/utils/quantize.py
+++ b/torchchat/utils/quantize.py
@@ -36,6 +36,7 @@
 from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa
 from torchao.quantization.quant_api import (
     int4_weight_only,
+    int8_weight_only,
     Int4WeightOnlyQuantizer,
     Int8DynActInt4WeightQuantizer,
     quantize_,
@@ -110,12 +111,20 @@ def quantize_model(
         if quantizer not in quantizer_class_dict:
             raise RuntimeError(f"unknown quantizer {quantizer} specified")
         else:
+            ao_quant = True
             # Use tensor subclass API for int4 weight only.
             if device == "cuda" and quantizer == "linear:int4":
                 quantize_(model, int4_weight_only(q_kwargs["groupsize"]))
+            elif quantizer == "linear:int8":
+                print("quantizer is linear int8")
+                quantize_(model, int8_weight_only())
+            else:
+                ao_quant = False
+            if ao_quant:
                 if not support_tensor_subclass:
                     unwrap_tensor_subclass(model)
                 continue
+
             if quantizer in ["linear:a8wxdq", "embedding:wx"]:
                 # These quantizers require float32 input weights. Note that after quantization,
@@ -529,147 +538,6 @@ def linear_int8_et(input, weight, scales):
     )
 
 
-class WeightOnlyInt8Linear(nn.Module):
-    __constants__ = ["in_features", "out_features"]
-    in_features: int
-    out_features: int
-    weight: torch.Tensor
-    scales: torch.Tensor
-
-    def __init__(
-        self,
-        in_features,
-        out_features,
-        bias=None,
-        device=None,
-        dtype=None,
-        *,
-        weight: Optional[torch.Tensor] = None,
-        scales: Optional[torch.Tensor] = None,
-        groupsize: Optional[int] = None,
-    ):
-        super().__init__()
-        if dtype is None:
-            dtype = torch.get_default_dtype()
-
-        if device is None:
-            device = "cpu"
-
-        assert not bias, "Bias is not supported by LinearInt8"
-        self.in_features = in_features
-        self.out_features = out_features
-
-        assert (weight is None) == bool(
-            scales is None
-        ), "must specify both weights and scales, or neither"
-        if weight is None:
-            weight = torch.empty(
-                (out_features, in_features),
-                dtype=torch.int8,
-                device=device,
-            )
-            if groupsize is None or (groupsize == 0):
-                scales = torch.empty(out_features, dtype=dtype, device=device)
-            else:
-                n_groups = (in_features + groupsize - 1) // groupsize
-                scales = torch.empty(out_features, n_groups, dtype=dtype, device=device)
-
-        self.register_buffer("weight", weight.to(device))
-        self.register_buffer("scales", scales.to(device))
-
-        if use_et_backend():
-            self.forward = self.et_forward
-        else:
-            self.forward = self.aoti_forward
-
-    def aoti_forward(self, input: torch.Tensor) -> torch.Tensor:
-        return linear_int8_aoti(input, self.weight, self.scales)
-
-    def et_forward(self, input: torch.Tensor) -> torch.Tensor:
-        return linear_int8_et(input, self.weight, self.scales)
-
-
-class WeightOnlyInt8QuantHandler(QuantHandler):
-    def __init__(
-        self,
-        model: Optional[nn.Module] = None,
-        device=None,
-        precision=None,
-        tokenizer=None,
-        *,
-        node_type: str = "*",
-        bitwidth: Optional[int] = None,
-        groupsize: Optional[int] = None,
-    ):
-        self.model_ = model
-        self.device = device
-        self.groupsize = groupsize
-        self.node_type = node_type
-        if bitwidth is None:
-            self.bitwidth = 8
-        else:
-            self.bitwidth = bitwidth
-
-    @torch.no_grad()
-    def quantize(self, module):
-        # cur_state_dict = state_dict_device(self.model_.state_dict())
-        # dict_device = "cpu"  # self.device
-
-        if self.bitwidth == 4:
-            range_min = -8
-            range_max = 7
-        elif self.bitwidth == 8:
-            range_min = -128
-            range_max = 127
-        else:
-            raise ValueError(f"Unsupported bitwidth {self.bitwidth}")
-
-        for name, child in module.named_children():
-            # print(f"name: {name}")
-            if isinstance(child, nn.Linear):
-                if (
-                    (self.node_type == "*")
-                    or (self.node_type == "output" and name == "output")
-                    or (self.node_type == "!output" and name != "output")
-                ):
-                    # print(f"{name, child}")
-                    input_weight = child.weight.float()
-                    # print(f"{name, child}")
-                    # print(f"in_features: {child.in_features}")
-                    # print(f"out_features: {child.out_features}")
-
-                    # print(f"expanded weight shape {input_weight.shape}")
-                    weight, scales, _ = dynamically_quantize_per_channel(
-                        input_weight,
-                        range_min,
-                        range_max,
-                        torch.int8,
-                        self.groupsize,
-                        scales_dtype=child.weight.dtype,
-                    )
-
-                    setattr(
-                        module,
-                        name,
-                        WeightOnlyInt8Linear(
-                            in_features=child.in_features,
-                            out_features=child.out_features,
-                            device=self.device,
-                            # update variables from quantization
-                            weight=weight,
-                            scales=scales,
-                            groupsize=self.groupsize,
-                        ),
-                    )
-            else:
-                self.quantize(child)
-
-        return module
-
-    def quantized_model(self) -> nn.Module:
-        return self.quantize(self.model_)
-
-
 #########################################################################
 #####          embedding table quantization                        ######
 ###           (unify with torchao in future)                          ###
@@ -886,10 +754,10 @@ def quantized_model(self) -> nn.Module:
 # class references
 quantizer_class_dict = {
     "embedding": EmbeddingOnlyQuantHandler,
-    "linear:int8": WeightOnlyInt8QuantHandler,
     "precision": PrecisionHandler,
     "executor": ExecutorHandler,
     "linear:int4": Int4WeightOnlyQuantizer,
+    "linear:int8": int8_weight_only,
     "linear:a8w4dq": Int8DynActInt4WeightQuantizer,
 }
@@ -917,6 +785,7 @@ def quantized_model(self) -> nn.Module:
         IntxWeightEmbeddingQuantizer,
     )
 
+
     quantizer_class_dict["linear:a8wxdq"] = Int8DynActIntxWeightLinearQuantizer
     quantizer_class_dict["embedding:wx"] = IntxWeightEmbeddingQuantizer
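For context (not part of the patch): a minimal sketch of the torchao tensor-subclass flow that "linear:int8" now routes through, mirroring the quantize_(model, int8_weight_only()) call added in quantize_model above. Only the torchao.quantization.quant_api imports come from the patch; the toy model and shapes below are illustrative assumptions.

# Sketch only -- not part of the patch. Demonstrates the torchao
# tensor-subclass path that replaces the deleted WeightOnlyInt8QuantHandler.
import torch
import torch.nn as nn
from torchao.quantization.quant_api import int8_weight_only, quantize_

model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 8))

# quantize_ rewrites each nn.Linear weight in place as an int8 weight-only
# quantized tensor subclass, instead of swapping in a custom Linear module.
quantize_(model, int8_weight_only())

print(model(torch.randn(2, 64)).shape)  # torch.Size([2, 8])

# When the export backend cannot handle tensor subclasses, the patched
# quantize_model additionally calls unwrap_tensor_subclass(model)
# (guarded by support_tensor_subclass), matching the existing int4 path.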