diff --git a/export.py b/export.py
index dbcc6e87c..30a698e0c 100644
--- a/export.py
+++ b/export.py
@@ -58,7 +58,7 @@ def forward(self, idx, input_pos):
         return logits  # sample(logits, **sampling_kwargs)
 
 
-def main(checkpoint_path, device, quantize = "{ }", args = None):
+def main(checkpoint_path, device, args = None):
     assert checkpoint_path.is_file(), checkpoint_path
 
     print(f"Using device={device}")
@@ -72,12 +72,12 @@ def main(checkpoint_path, device, quantize = "{ }", args = None):
     device_sync(device=device)  # MKG
     print(f"Time to load model: {time.time() - t0:.02f} seconds")
 
-    quantize_model(model, args.quantize)
+    quantize_model(model, args)
 
     # dtype:
     if args.dtype:
         model.to(dtype=name_to_dtype(args.dtype))
-    
+
     model = model_wrapper(model, device=device)
 
     output_pte_path = args.output_pte_path
@@ -180,7 +180,7 @@ def cli():
 
     args = parser.parse_args()
 
-    main(args.checkpoint_path, args.device, args.quantize, args)
+    main(args.checkpoint_path, args.device, args)
 
 if __name__ == "__main__":
     cli()
diff --git a/generate.py b/generate.py
index b28e5c7fc..d95dbc34f 100644
--- a/generate.py
+++ b/generate.py
@@ -198,6 +198,7 @@ def generate(
     draft_model: Transformer,
     speculate_k: Optional[int] = 8,
     callback=lambda x: x,
+    precision=torch.float,
     **sampling_kwargs,
 ) -> torch.Tensor:
     """
@@ -214,13 +215,14 @@ def generate(
     max_seq_length = min(T_new, model.config.block_size)
 
     device, dtype = prompt.device, prompt.dtype
+    model = model.to(device)
     max_seq_length = (
         max_seq_length + speculate_k + 1 if is_speculative else max_seq_length
     )
     with torch.device(device):
-        model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
+        model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length, dtype=precision)
         if is_speculative and draft_model is not model:
-            draft_model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
+            draft_model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length, dtype=precision)
 
     # create an empty tensor of the expected final shape and fill in the current tokens
     empty = torch.empty(T_new, dtype=dtype, device=device)
@@ -315,8 +317,8 @@ def main(
     device="cuda",
     dso_path=None,
     pte_path=None,
-    quantize=None,
     model_dtype=None,
+    args=None,
 ) -> None:
     """Generates text samples based on a pre-trained Transformer model and tokenizer."""
     assert (
@@ -344,7 +346,7 @@ def main(
     # print = lambda *args, **kwargs: None
 
     print(f"Using device={device}")
-    precision = torch.float  # bfloat16
+    precision = torch.float
     is_speculative = draft_checkpoint_path is not None
 
     is_chat = "chat" in str(checkpoint_path)
@@ -377,17 +379,20 @@ def main(
         model = model_
 
     # Add new CLI arg
-    if quantize:
+    if args.quantize:
+        with torch.device(device):
+            # TODO: fix max_seq_length
+            model.setup_caches(max_batch_size=1, max_seq_length=2048, dtype=precision)
         device_sync(device=device)
         t0q = time.time()
-        quantize_model(model, quantize)
+        quantize_model(model, args)
         device_sync(device=device)  # MKG
         print(f"Time to quantize model: {time.time() - t0q:.02f} seconds")
 
     # dtype:
     if model_dtype:
-        model.to(dtype=name_to_dtype(model_dtype))
-    
+        model = model.to(dtype=name_to_dtype(model_dtype))
+
     if is_speculative:
         draft_model = _load_model(draft_checkpoint_path, device, precision, use_tp)
     else:
@@ -480,6 +485,7 @@ def callback(x):
                 callback=callback,
                 temperature=temperature,
                 top_k=top_k,
+                precision=name_to_dtype(model_dtype),
             )
             aggregate_metrics["accept_counts"].append(metrics["accept_counts"])
             if i == -1:
@@ -610,7 +616,7 @@ def cli():
     args = parser.parse_args()
 
     if args.seed:
-        torch.manual_seed(args.seed)
+        torch.manual_seed(args.seed)
 
     main(
         args.prompt,
@@ -629,8 +635,8 @@ def cli():
         args.device,
         args.dso_path,
         args.pte_path,
-        args.quantize,
         args.dtype,
+        args,
     )
 
 if __name__ == "__main__":
diff --git a/model.py b/model.py
index 972c1f736..8c40167f7 100644
--- a/model.py
+++ b/model.py
@@ -12,6 +12,22 @@
 from torch.nn import functional as F
 
 
+def prepare_inputs_for_model(inps, max_new_tokens=1):
+    # this is because input from lm-eval is 2d
+    if inps.dim() != 2:
+        raise ValueError(f"Expected input to be of dim 2, but got {inps.dim()}")
+
+    inps = inps.squeeze(0)
+    # setup inputs in correct format
+    T = inps.size(0)
+    T_new = T + max_new_tokens
+    seq = torch.empty(T_new, dtype=inps.dtype, device=inps.device)
+    seq[:T] = inps
+    input_pos = torch.arange(0, T, device=inps.device)
+    x = seq.index_select(0, input_pos).view(1, -1)
+    return (x, input_pos)
+
+
 def find_multiple(n: int, k: int) -> int:
     if n % k == 0:
         return n
@@ -134,7 +150,7 @@ def __init__(self, config: ModelArgs) -> None:
         self.max_batch_size = -1
         self.max_seq_length = -1
 
-    def setup_caches(self, max_batch_size, max_seq_length):
+    def setup_caches(self, max_batch_size, max_seq_length, dtype=torch.float):
         if (
             self.max_seq_length >= max_seq_length
             and self.max_batch_size >= max_batch_size
@@ -146,7 +162,7 @@ def setup_caches(self, max_batch_size, max_seq_length):
         self.max_batch_size = max_batch_size
         for b in self.layers:
             b.attention.kv_cache = KVCache(
-                max_batch_size, max_seq_length, self.config.n_local_heads, head_dim
+                max_batch_size, max_seq_length, self.config.n_local_heads, head_dim, dtype
             )
 
         freqs_cis = precompute_freqs_cis(
@@ -218,7 +234,7 @@ def __init__(self, config: ModelArgs):
         self.head_dim = config.head_dim
         self.n_local_heads = config.n_local_heads
         self.dim = config.dim
-        # self._register_load_state_dict_pre_hook(self.load_hook)
+        self._register_load_state_dict_pre_hook(self.load_hook)
 
     # def load_hook(self, state_dict, prefix, *args):
     #     if prefix + "wq.weight" in state_dict:
@@ -227,6 +243,16 @@ def __init__(self, config: ModelArgs):
     #         wv = state_dict.pop(prefix + "wv.weight")
     #         state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])
 
+    def load_hook(self, state_dict, prefix, *args):
+        if prefix + "wqkv.weight" in state_dict:
+            wqkv = state_dict.pop(prefix + "wqkv.weight")
+            q_size = self.n_head * self.head_dim
+            kv_size = self.n_local_heads * self.head_dim
+            wq, wk, wv = torch.split(wqkv, (q_size, kv_size, kv_size), dim=0)
+            state_dict[prefix + "wq.weight"] = wq
+            state_dict[prefix + "wk.weight"] = wk
+            state_dict[prefix + "wv.weight"] = wv
+
     def forward(
         self,
         x: Tensor,
diff --git a/quantize.py b/quantize.py
index c2093cfba..3064d17ff 100644
--- a/quantize.py
+++ b/quantize.py
@@ -6,13 +6,28 @@
 from functools import reduce
 from math import gcd
-from typing import Dict, Optional, Tuple
+from typing import Dict, Optional, Tuple, Any
 
 import json
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
+from torchao.quantization.quant_api import (
+    apply_weight_only_int8_quant,
+    Int4WeightOnlyGPTQQuantizer,
+    Int4WeightOnlyQuantizer,
+    Quantizer,
+)
+
+
+class Int8WeightOnlyQuantizer(Quantizer):
+    def quantize(
+        self, model: torch.nn.Module, *args: Any, **kwargs: Any
+    ) -> torch.nn.Module:
+        apply_weight_only_int8_quant(model)
+        return model
+
 
 try:
     from GPTQ import GenericGPTQRunner, InputRecorder
@@ -28,7 +43,7 @@ def name_to_dtype(name):
         return name_to_dtype_dict[name]
     else:
         raise RuntimeError("unsupported dtype specified")
-    
+
 name_to_dtype_dict = {
     "fp32" : torch.float,
"fp16" : torch.float16, @@ -38,7 +53,7 @@ def name_to_dtype(name): ########################################################################## ### process quantization dictionary ### -def quantize_model(model: nn.Module, quantize_options): +def quantize_model(model: nn.Module, args): """ Quantize the specified model using the quantizers described by a quantization dict of the form: @@ -49,54 +64,77 @@ def quantize_model(model: nn.Module, quantize_options): } """ + quantize_options = args.quantize linears_quantized = False if isinstance(quantize_options, str): quantize_options = json.loads(quantize_options) - - for quantizer, q_kwargs in quantize_options.items(): - if quantizer == "embedding": + + for qmode, q_kwargs in quantize_options.items(): + if qmode == "embedding": model = EmbeddingOnlyInt8QuantHandler( model, **q_kwargs ).quantized_model() elif linears_quantized: - assert 0==1, "can only specify one linear quantizer" - elif quantizer == "linear:int8": + assert 0==1, "can only specify one linear qmode" + elif qmode == "linear:int8": linears_quantized = True - model = WeightOnlyInt8QuantHandler( - model, - **q_kwargs - ).quantized_model() - elif quantizer == "linear:int4": + quantizer = Int8WeightOnlyQuantizer(**q_kwargs) + model = quantizer.quantize(model) + elif qmode == "linear:8da4w": linears_quantized = True - model = WeightOnlyInt4QuantHandler( - model, - **q_kwargs - ).quantized_model() - elif quantizer == "linear:a8w4dq": + from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer + quantizer = Int8DynActInt4WeightQuantizer(**q_kwargs) + model = quantizer.quantize(model) + elif qmode == "linear:int4": linears_quantized = True - model = Int8DynActInt4WeightQuantHandler( - model, - **q_kwargs - ).quantized_model() - elif quantizer == "linear:gptq": + quantizer = Int4WeightOnlyQuantizer(**q_kwargs) + model = quantizer.quantize(model) + elif qmode == "linear:int4-gptq": linears_quantized = True - model = WeightOnlyInt4GPTQQuantHandler( - model, - **q_kwargs - ).quantized_model() - elif quantizer == "linear:hqq": + from pathlib import Path + from sentencepiece import SentencePieceProcessor + from torchao.quantization.GPTQ import InputRecorder + from model import prepare_inputs_for_model + + checkpoint_path = Path(args.checkpoint_path) + tokenizer_path = checkpoint_path.parent / "tokenizer.model" + assert tokenizer_path.is_file(), tokenizer_path + tokenizer = SentencePieceProcessor( # pyre-ignore[28] + model_file=str(tokenizer_path) + ) + blocksize = 128 + percdamp = 0.01 + calibration_tasks = ["wikitext"] + calibration_limit = 1 + calibration_seq_length = 100 + input_prep_func = prepare_inputs_for_model + pad_calibration_inputs = False + inputs = InputRecorder( + tokenizer, + calibration_seq_length, + input_prep_func, + pad_calibration_inputs, + model.config.vocab_size, + device="cuda", + ).record_inputs( + calibration_tasks, + calibration_limit, + ).get_inputs() + quantizer = Int4WeightOnlyGPTQQuantizer(blocksize, percdamp, **q_kwargs) + model = quantizer.quantize(model, inputs) + elif qmode == "linear:hqq": linears_quantized = True model = WeightOnlyInt4HqqQuantHandler( model, **q_kwargs ).quantized_model() - elif quantizer == "precision": + elif qmode == "precision": model.to(**q_kwargs) else: - assert 0 == 1, f"quantizer {quantizer} not supported" - - + assert 0 == 1, f"qmode {qmode} not supported" + + ######################################################################### ##### Quantization Primitives ###### @@ -314,7 +352,7 @@ def 
@@ -314,7 +352,7 @@ def create_quantized_state_dict(self) -> Dict:  # "StateDict"
 
     def convert_for_runtime(self) -> nn.Module:
         pass
-    
+
     def quantized_model(self) -> nn.Module:
         model_updated_state_dict = self.create_quantized_state_dict()
         self.convert_for_runtime()
@@ -329,7 +367,7 @@ def quantized_model(self) -> nn.Module:
 def replace_linear_weight_only_int8_per_channel(module, node_type, group_size=None):
     if group_size is not None and group_size != 0:
         pass  # group_size = 2 ** group_size
-    
+
     for name, child in module.named_children():
         # print(f"name: {name}")
         if isinstance(child, nn.Linear):
@@ -443,7 +481,7 @@ def __init__(
     ) -> None:
         super().__init__()
         print(f"group size: {group_size}")
-        
+
         self.in_features = in_features
         self.out_features = out_features
         self.register_buffer(
@@ -592,13 +630,13 @@ def forward(self, indices: torch.Tensor) -> torch.Tensor:
                 self.weight, self.scales, None, 0, 0, indices, dtype=self.dtype
             )
-        
+
         # result_weights = self.weight.index_select(0, indices.view(-1))
         # result_scales = self.scales.index_select(0, indices.view(-1))
 
         weight = self.weight
         scales = self.scales.view(weight.shape[0], -1)
-        
+
         result_weights = F.embedding(indices, weight)
         result_scales = F.embedding(indices, scales)
 
@@ -609,10 +647,10 @@ def forward(self, indices: torch.Tensor) -> torch.Tensor:
 
         r = rw_view * rs_view
         return r.view(indices.size() + (-1,))
-        
+
         # r = result_weights.to(dtype=result_scales.dtype).view(list(result_weights.shape[:-1] + (scales.shape[1], -1, )) * result_scales.view(scales.shape[-1] + (scales.shape[1], 1, ))
 
-        
+
 
 #########################################################################
 #####     weight only int4 per channel groupwise quantized code    ######
@@ -1196,7 +1234,7 @@ def quantized_model(self) -> nn.Module:
         self.convert_for_runtime()
         self.mod.load_state_dict(model_updated_state_dict)
         return self.mod
-    
+
 
 # class Int8DynActInt4WeightGPTQQuantHandler(GPTQQuantHandler):
@@ -1281,7 +1319,7 @@ def __init__(self, mod, group_size):
 
     def create_quantized_state_dict(self):
         from hqq.core.quantize import Quantizer  # TODO maybe torchao
-        
+
         for m in self.mod.modules():
             for name, child in m.named_children():
                 if isinstance(child, torch.nn.Linear):
@@ -1309,7 +1347,7 @@ def convert_for_runtime(self):
         return WeightOnlyInt4GPTQQuantHandler(
             self.mod, bitwidth=4, group_size=self.groupsize
         ).convert_for_runtime()
-    
+
     def quantized_model(self) -> nn.Module:
         model_updated_state_dict = self.create_quantized_state_dict()
         self.convert_for_runtime()
@@ -1318,4 +1356,3 @@ def quantized_model(self) -> nn.Module:
 
 
 ##################################################################
-
diff --git a/requirements.txt b/requirements.txt
index 2f1a2a7cc..302ecd0c7 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,3 +1,4 @@
 torch
 sentencepiece
 numpy
+torchao==0.1