Speed up quantization in generate.py (OpenGVLab#35)
Co-authored-by: Carlos Mocholí <[email protected]>
awaelchli and carmocca authored Mar 28, 2023
1 parent c409960 commit ccaeeba
Showing 5 changed files with 76 additions and 31 deletions.
26 changes: 9 additions & 17 deletions generate.py
@@ -1,4 +1,3 @@
import os
import sys
import time
from pathlib import Path
@@ -7,8 +6,7 @@
import lightning as L
import torch

from lit_llama.model import LLaMA
from lit_llama.tokenizer import Tokenizer
from lit_llama import LLaMA, Tokenizer, as_8_bit_quantized


@torch.no_grad()
@@ -104,21 +102,13 @@ def main(

    fabric = L.Fabric(accelerator=accelerator, devices=1)

    if quantize:
        from lit_llama.quantization import quantize

        print("Running quantization. This may take a minute ...")
    # TODO: Initializing the model directly on the device does not work with quantization
    with as_8_bit_quantized(fabric.device, enabled=quantize):
        print("Loading model ...", file=sys.stderr)
        t0 = time.time()
        model = LLaMA.from_name(model_size)
        # The output layer can be sensitive to quantization, we keep it in default precision
        model = quantize(model, skip=("lm_head", "output"))
        checkpoint = torch.load(checkpoint_path)
        model.load_state_dict(checkpoint)
    else:
        with fabric.device:
            model = LLaMA.from_name(model_size)
            checkpoint = torch.load(checkpoint_path)
            model.load_state_dict(checkpoint)
    print(f"Time to load model: {time.time() - t0:.02f} seconds.", file=sys.stderr)

    model.eval()

@@ -133,6 +123,7 @@ def main(

    L.seed_everything(1234)
    t0 = time.time()

    for _ in range(num_samples):
        y = generate(
            model,
@@ -144,8 +135,9 @@ def main(
        )[0] # unpack batch dimension
        print(tokenizer.decode(y))

print(f"Time for inference: {time.time() - t0:.02f} seconds", file=sys.stderr)
print(f"Memory used (GB): {torch.cuda.max_memory_reserved() / 1e9:.02f}", file=sys.stderr)
t = time.time() - t0
print(f"\n\nTime for inference: {t:.02f} sec total, {max_new_tokens / t:.02f} tokens/sec", file=sys.stderr)
print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB", file=sys.stderr)


if __name__ == "__main__":
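
For reference, a minimal usage sketch of the new loading path outside of generate.py (the checkpoint path and model size below are placeholders, not values from this commit). Building the model under `as_8_bit_quantized` means the `nn.Linear` layers are created already quantized on the GPU, and loading the checkpoint re-quantizes the weights in place, which is what removes the separate quantization pass:

# Hypothetical usage sketch; checkpoint_path and "7B" are placeholders.
import torch
from lit_llama import LLaMA, as_8_bit_quantized

checkpoint_path = "checkpoints/lit-llama/7B/state_dict.pth"  # placeholder path
device = torch.device("cuda")

with as_8_bit_quantized(device, enabled=True):
    # nn.Linear is temporarily swapped for Linear8bitLt, so layers are built
    # with int8 weights instead of float32 weights that get converted later
    model = LLaMA.from_name("7B")
    model.load_state_dict(torch.load(checkpoint_path))  # weights are re-quantized on load
model.eval()
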
1 change: 1 addition & 0 deletions lit_llama/__init__.py
@@ -1,2 +1,3 @@
from lit_llama.model import LLaMAConfig, LLaMA, RMSNorm, build_rope_cache, apply_rope
from lit_llama.quantization import as_8_bit_quantized
from lit_llama.tokenizer import Tokenizer
3 changes: 0 additions & 3 deletions lit_llama/model.py
@@ -184,9 +184,6 @@ def __init__(self, config: LLaMAConfig) -> None:
            )
        )

        # init all weights
        self.apply(self._init_weights)

    def _init_weights(self, module: nn.Module) -> None:
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02 / math.sqrt(2 * self.config.n_layer))
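
With `self.apply(self._init_weights)` removed from the constructor, a freshly built `LLaMA` no longer spends time randomly initializing weights that a checkpoint overwrites anyway. Code that does need random initialization, such as the updated test further down or training from scratch, applies it explicitly. A minimal sketch, assuming `LLaMAConfig` provides usable default field values:

# Minimal sketch; assumes LLaMAConfig has defaults, as in lit_llama.model.
from lit_llama import LLaMA, LLaMAConfig

config = LLaMAConfig()
model = LLaMA(config)
model.apply(model._init_weights)  # only needed when not loading a pretrained checkpoint
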
76 changes: 65 additions & 11 deletions lit_llama/quantization.py
@@ -1,19 +1,73 @@
import os
from typing import Tuple
from contextlib import contextmanager
import warnings

import torch.nn as nn
import torch

# configuration for bitsandbytes before import
os.environ["BITSANDBYTES_NOWELCOME"] = "1"
warnings.filterwarnings(
    "ignore",
    message="MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization"
)
warnings.filterwarnings(
    "ignore",
    message="The installed version of bitsandbytes was compiled without GPU support. 8-bit optimizers and GPU quantization are unavailable."
)
import bitsandbytes as bnb # noqa: E402


def quantize(model: nn.Module, threshold: float = 6.0, skip: Tuple[str, ...] = ()) -> nn.Module:
    for name, module in model.named_children():
        if isinstance(module, nn.Linear) and name not in skip:
            model._modules[name] = bnb.nn.Linear8bitLt(
                module.in_features, module.out_features, bias=module.bias, has_fp16_weights=False, threshold=threshold
            )
class Linear8bitLt(bnb.nn.Linear8bitLt):
"""Wraps `bnb.nn.Linear8bitLt` and enables instantiation directly on the device and
re-quantizaton when loading the state dict.
This should only be used for inference. For training, use `bnb.nn.Linear8bitLt` directly.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs, has_fp16_weights=False, threshold=6.0)
# We quantize the initial weight here so we don't end up filling the device
# memory with float32 weights which could lead to OOM.
self._quantize_weight(self.weight.data)

        if module.children():
            quantize(module, threshold=threshold, skip=skip)
    return model
    def _load_from_state_dict(self, local_state_dict, *args, **kwargs):
        # There is only one key that ends with `*.weight`, the other one is the bias
        weight_key = next(name for name in local_state_dict.keys() if name.endswith("weight"))

        # Load the weight from the state dict and re-quantize it
        weight = local_state_dict.pop(weight_key)
        self._quantize_weight(weight)

        # If there is a bias, let nn.Module load it
        if local_state_dict:
            super()._load_from_state_dict(local_state_dict, *args, **kwargs)

    def _quantize_weight(self, weight: torch.Tensor) -> None:
        # This code is taken and adapted from `bnb.nn.Int8Params.cuda()`
        B = weight.contiguous().half().cuda()
        CB, CBt, SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B)
        del CBt
        del SCBt
        self.weight.data = CB
        setattr(self.weight, "CB", CB)
        setattr(self.weight, "SCB", SCB)


@contextmanager
def as_8_bit_quantized(device: torch.device, enabled: bool = True):
"""A context manager under which you can instantiate the model with 8-bit quantized tensors
being created directly on the given device.
"""

with torch.device(device):
if not enabled:
yield
return

if device.type != "cuda":
raise ValueError("Quantization is only supported on the GPU.")

torch_linear_cls = torch.nn.Linear
torch.nn.Linear = Linear8bitLt
yield
torch.nn.Linear = torch_linear_cls
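
One design note: `as_8_bit_quantized` patches `torch.nn.Linear` globally and only restores it after the `yield`, so if model construction inside the block raises, the patch stays in place. A hedged sketch of the same class-swap idea with `try`/`finally` so the original class is restored even on error (`_swap_linear` is a hypothetical helper, not part of this commit):

from contextlib import contextmanager

import torch


@contextmanager
def _swap_linear(replacement_cls, enabled: bool = True):
    # Hypothetical variant of the as_8_bit_quantized class swap; not part of this commit.
    if not enabled:
        yield
        return
    original = torch.nn.Linear
    torch.nn.Linear = replacement_cls
    try:
        yield
    finally:
        # Restore the original nn.Linear even if construction inside the block fails
        torch.nn.Linear = original
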
1 change: 1 addition & 0 deletions tests/test_model.py
@@ -52,6 +52,7 @@ def test_to_orig_llama(lit_llama, orig_llama) -> None:
    )

    llama_model = lit_llama.LLaMA(llama_config)
    llama_model.apply(llama_model._init_weights)
    orig_llama_model = orig_llama.Transformer(orig_llama_config)

    copy_weights(llama_model, orig_llama_model)
