diff --git a/generate.py b/generate.py
index 90e8771..b09d345 100644
--- a/generate.py
+++ b/generate.py
@@ -1,4 +1,3 @@
-import os
 import sys
 import time
 from pathlib import Path
@@ -7,8 +6,7 @@
 import lightning as L
 import torch
 
-from lit_llama.model import LLaMA
-from lit_llama.tokenizer import Tokenizer
+from lit_llama import LLaMA, Tokenizer, as_8_bit_quantized
 
 
 @torch.no_grad()
@@ -104,21 +102,13 @@ def main(
 
     fabric = L.Fabric(accelerator=accelerator, devices=1)
 
-    if quantize:
-        from lit_llama.quantization import quantize
-
-        print("Running quantization. This may take a minute ...")
-        # TODO: Initializing the model directly on the device does not work with quantization
+    with as_8_bit_quantized(fabric.device, enabled=quantize):
+        print("Loading model ...", file=sys.stderr)
+        t0 = time.time()
         model = LLaMA.from_name(model_size)
-        # The output layer can be sensitive to quantization, we keep it in default precision
-        model = quantize(model, skip=("lm_head", "output"))
         checkpoint = torch.load(checkpoint_path)
         model.load_state_dict(checkpoint)
-    else:
-        with fabric.device:
-            model = LLaMA.from_name(model_size)
-            checkpoint = torch.load(checkpoint_path)
-            model.load_state_dict(checkpoint)
+        print(f"Time to load model: {time.time() - t0:.02f} seconds.", file=sys.stderr)
 
     model.eval()
 
@@ -133,6 +123,7 @@ def main(
 
     L.seed_everything(1234)
     t0 = time.time()
+
     for _ in range(num_samples):
         y = generate(
             model,
@@ -144,8 +135,9 @@ def main(
         )[0]  # unpack batch dimension
         print(tokenizer.decode(y))
 
-    print(f"Time for inference: {time.time() - t0:.02f} seconds", file=sys.stderr)
-    print(f"Memory used (GB): {torch.cuda.max_memory_reserved() / 1e9:.02f}", file=sys.stderr)
+    t = time.time() - t0
+    print(f"\n\nTime for inference: {t:.02f} sec total, {max_new_tokens / t:.02f} tokens/sec", file=sys.stderr)
+    print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB", file=sys.stderr)
 
 
 if __name__ == "__main__":
diff --git a/lit_llama/__init__.py b/lit_llama/__init__.py
index c169d4c..c3395f8 100644
--- a/lit_llama/__init__.py
+++ b/lit_llama/__init__.py
@@ -1,2 +1,3 @@
 from lit_llama.model import LLaMAConfig, LLaMA, RMSNorm, build_rope_cache, apply_rope
+from lit_llama.quantization import as_8_bit_quantized
 from lit_llama.tokenizer import Tokenizer
diff --git a/lit_llama/model.py b/lit_llama/model.py
index 3710086..16c479f 100644
--- a/lit_llama/model.py
+++ b/lit_llama/model.py
@@ -184,9 +184,6 @@ def __init__(self, config: LLaMAConfig) -> None:
             )
         )
 
-        # init all weights
-        self.apply(self._init_weights)
-
     def _init_weights(self, module: nn.Module) -> None:
         if isinstance(module, nn.Linear):
             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02 / math.sqrt(2 * self.config.n_layer))
diff --git a/lit_llama/quantization.py b/lit_llama/quantization.py
index a8eb86a..c006316 100644
--- a/lit_llama/quantization.py
+++ b/lit_llama/quantization.py
@@ -1,19 +1,73 @@
 import os
-from typing import Tuple
+from contextlib import contextmanager
+import warnings
 
-import torch.nn as nn
+import torch
 
+# configuration for bitsandbytes before import
 os.environ["BITSANDBYTES_NOWELCOME"] = "1"
+warnings.filterwarnings(
+    "ignore",
+    message="MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization"
+)
+warnings.filterwarnings(
+    "ignore",
+    message="The installed version of bitsandbytes was compiled without GPU support. 8-bit optimizers and GPU quantization are unavailable."
+)
 import bitsandbytes as bnb  # noqa: E402
 
 
-def quantize(model: nn.Module, threshold: float = 6.0, skip: Tuple[str, ...] = ()) -> nn.Module:
-    for name, module in model.named_children():
-        if isinstance(module, nn.Linear) and name not in skip:
-            model._modules[name] = bnb.nn.Linear8bitLt(
-                module.in_features, module.out_features, bias=module.bias, has_fp16_weights=False, threshold=threshold
-            )
+class Linear8bitLt(bnb.nn.Linear8bitLt):
+    """Wraps `bnb.nn.Linear8bitLt` and enables instantiation directly on the device and
+    re-quantization when loading the state dict.
+
+
+    This should only be used for inference. For training, use `bnb.nn.Linear8bitLt` directly.
+    """
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs, has_fp16_weights=False, threshold=6.0)
+        # We quantize the initial weight here so we don't end up filling the device
+        # memory with float32 weights which could lead to OOM.
+        self._quantize_weight(self.weight.data)
 
-        if module.children():
-            quantize(module, threshold=threshold, skip=skip)
-    return model
+    def _load_from_state_dict(self, local_state_dict, *args, **kwargs):
+        # There is only one key that ends with `*.weight`, the other one is the bias
+        weight_key = next(name for name in local_state_dict.keys() if name.endswith("weight"))
+
+        # Load the weight from the state dict and re-quantize it
+        weight = local_state_dict.pop(weight_key)
+        self._quantize_weight(weight)
+
+        # If there is a bias, let nn.Module load it
+        if local_state_dict:
+            super()._load_from_state_dict(local_state_dict, *args, **kwargs)
+
+    def _quantize_weight(self, weight: torch.Tensor) -> None:
+        # This code is taken and adapted from `bnb.nn.Int8Params.cuda()`
+        B = weight.contiguous().half().cuda()
+        CB, CBt, SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B)
+        del CBt
+        del SCBt
+        self.weight.data = CB
+        setattr(self.weight, "CB", CB)
+        setattr(self.weight, "SCB", SCB)
+
+
+@contextmanager
+def as_8_bit_quantized(device: torch.device, enabled: bool = True):
+    """A context manager under which you can instantiate the model with 8-bit quantized tensors
+    being created directly on the given device.
+    """
+
+    with torch.device(device):
+        if not enabled:
+            yield
+            return
+
+        if device.type != "cuda":
+            raise ValueError("Quantization is only supported on the GPU.")
+
+        torch_linear_cls = torch.nn.Linear
+        torch.nn.Linear = Linear8bitLt
+        yield
+        torch.nn.Linear = torch_linear_cls
diff --git a/tests/test_model.py b/tests/test_model.py
index 5311611..8bd3440 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -52,6 +52,7 @@ def test_to_orig_llama(lit_llama, orig_llama) -> None:
     )
 
     llama_model = lit_llama.LLaMA(llama_config)
+    llama_model.apply(llama_model._init_weights)
     orig_llama_model = orig_llama.Transformer(orig_llama_config)
 
     copy_weights(llama_model, orig_llama_model)
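
For reference, a minimal usage sketch of the new `as_8_bit_quantized` context manager, mirroring the pattern this diff introduces in `generate.py`. It is not part of the patch; the model size and checkpoint path are illustrative placeholders.

```python
# Usage sketch (assumed paths/sizes): load a LLaMA checkpoint with 8-bit weights.
import torch

from lit_llama import LLaMA, as_8_bit_quantized

device = torch.device("cuda")

# While the context manager is active, torch.nn.Linear is temporarily replaced by
# the quantized Linear8bitLt subclass, so linear layers are created quantized and
# directly on the target device, and weights are re-quantized on load_state_dict.
with as_8_bit_quantized(device, enabled=True):
    model = LLaMA.from_name("7B")  # placeholder model size
    checkpoint = torch.load("checkpoints/lit-llama/7B/lit-llama.pth")  # placeholder path
    model.load_state_dict(checkpoint)

model.eval()
```

With `enabled=False`, the same code path runs unquantized but still instantiates the model directly on `device`, which is why `generate.py` can drop its previous `if quantize:` / `else:` branches.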