Using torchao for quantization
Summary:
This PR adds int4 weight-only, int4 weight-only with GPTQ, int8 weight-only, and 8da4w (int8 dynamic activation, int4 weight) quantization to llama-fast, using torchao.

Test Plan:
Note: int4wo and int4wo-gptq only work for llama2 7b right now, since the stories checkpoints have dimension 288, which needs some padding before the int4 kernels can handle it (see the padding sketch after find_multiple in model.py below).

python generate.py --compile --checkpoint-path ${MODEL_PATH} --prompt "Hello, my name is" --device cuda --dtype bf16
python generate.py --compile --checkpoint-path ${MODEL_PATH} --prompt "Hello, my name is" --device cuda --quantize '{"linear:int8": {}}'
python generate.py --compile --checkpoint-path ${MODEL_PATH} --prompt "Hello, my name is" --device cuda --quantize '{"linear:int4": {"groupsize": 32}}' --dtype bf16
python generate.py --compile --checkpoint-path ${MODEL_PATH} --prompt "Hello, my name is" --device cuda --quantize '{"linear:int4-gptq": {"groupsize": 32}}' --dtype bf16
python generate.py --compile --checkpoint-path ${MODEL_PATH} --prompt "Hello, my name is" --device cuda --quantize '{"linear:8da4w": {"groupsize": 32}}' --dtype fp32

Reviewers:

Subscribers:

Tasks:

Tags:
jerryzh168 committed Apr 11, 2024
1 parent 49bb02d commit 82a0adc
Showing 5 changed files with 130 additions and 60 deletions.
export.py: 8 changes (4 additions, 4 deletions)
@@ -58,7 +58,7 @@ def forward(self, idx, input_pos):
return logits # sample(logits, **sampling_kwargs)


def main(checkpoint_path, device, quantize = "{ }", args = None):
def main(checkpoint_path, device, args = None):
assert checkpoint_path.is_file(), checkpoint_path

print(f"Using device={device}")
@@ -72,12 +72,12 @@ def main(checkpoint_path, device, quantize = "{ }", args = None):
device_sync(device=device) # MKG
print(f"Time to load model: {time.time() - t0:.02f} seconds")

quantize_model(model, args.quantize)
quantize_model(model, args)

# dtype:
if args.dtype:
model.to(dtype=name_to_dtype(args.dtype))

model = model_wrapper(model, device=device)

output_pte_path = args.output_pte_path
@@ -180,7 +180,7 @@ def cli():


args = parser.parse_args()
main(args.checkpoint_path, args.device, args.quantize, args)
main(args.checkpoint_path, args.device, args)

if __name__ == "__main__":
cli()
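
For completeness, the same quantization options should apply at export time. Assuming export.py's flags mirror its argparse destinations above (checkpoint_path, device, quantize, output_pte_path; the flag spelling is inferred and not verified in this diff), an invocation would look like:

python export.py --checkpoint-path ${MODEL_PATH} --device cuda --quantize '{"linear:int8": {}}' --output-pte-path llama.pte
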
generate.py: 26 changes (16 additions, 10 deletions)
@@ -198,6 +198,7 @@ def generate(
draft_model: Transformer,
speculate_k: Optional[int] = 8,
callback=lambda x: x,
precision=torch.float,
**sampling_kwargs,
) -> torch.Tensor:
"""
@@ -214,13 +215,14 @@
max_seq_length = min(T_new, model.config.block_size)

device, dtype = prompt.device, prompt.dtype
model = model.to(device)
max_seq_length = (
max_seq_length + speculate_k + 1 if is_speculative else max_seq_length
)
with torch.device(device):
model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length, dtype=precision)
if is_speculative and draft_model is not model:
draft_model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
draft_model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length, dtype=precision)

# create an empty tensor of the expected final shape and fill in the current tokens
empty = torch.empty(T_new, dtype=dtype, device=device)
@@ -315,8 +317,8 @@ def main(
device="cuda",
dso_path=None,
pte_path=None,
quantize=None,
model_dtype=None,
args=None,
) -> None:
"""Generates text samples based on a pre-trained Transformer model and tokenizer."""
assert (
Expand Down Expand Up @@ -344,7 +346,7 @@ def main(
# print = lambda *args, **kwargs: None

print(f"Using device={device}")
precision = torch.float # bfloat16
precision = torch.float
is_speculative = draft_checkpoint_path is not None
is_chat = "chat" in str(checkpoint_path)

@@ -377,17 +379,20 @@
model = model_

# Add new CLI arg
if quantize:
if args.quantize:
with torch.device(device):
# TODO: fix max_seq_length
model.setup_caches(max_batch_size=1, max_seq_length=2048, dtype=precision)
device_sync(device=device)
t0q = time.time()
quantize_model(model, quantize)
quantize_model(model, args)
device_sync(device=device) # MKG
print(f"Time to quantize model: {time.time() - t0q:.02f} seconds")

# dtype:
if model_dtype:
model.to(dtype=name_to_dtype(model_dtype))
model = model.to(dtype=name_to_dtype(model_dtype))

if is_speculative:
draft_model = _load_model(draft_checkpoint_path, device, precision, use_tp)
else:
@@ -480,6 +485,7 @@ def callback(x):
callback=callback,
temperature=temperature,
top_k=top_k,
precision=name_to_dtype(model_dtype),
)
aggregate_metrics["accept_counts"].append(metrics["accept_counts"])
if i == -1:
@@ -610,7 +616,7 @@ def cli():
args = parser.parse_args()

if args.seed:
torch.manual_seed(args.seed)
torch.manual_seed(args.seed)

main(
args.prompt,
@@ -629,8 +635,8 @@ def cli():
args.device,
args.dso_path,
args.pte_path,
args.quantize,
args.dtype,
args,
)

if __name__ == "__main__":
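
name_to_dtype is used above (both for --dtype and for the precision passed into generate) but is not part of this diff. A plausible minimal version, shown only as an assumption about its behavior, maps the dtype strings from the test plan to torch dtypes:

import torch

# Assumed mapping; the real name_to_dtype may accept more aliases (e.g. "float32").
_DTYPE_BY_NAME = {
    "fp32": torch.float32,
    "fp16": torch.float16,
    "bf16": torch.bfloat16,
}


def name_to_dtype(name: str) -> torch.dtype:
    try:
        return _DTYPE_BY_NAME[name]
    except KeyError:
        raise ValueError(f"unsupported dtype name: {name}") from None
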
model.py: 32 changes (29 additions, 3 deletions)
@@ -12,6 +12,22 @@
from torch.nn import functional as F


def prepare_inputs_for_model(inps, max_new_tokens=1):
# this is because input from lm-eval is 2d
if inps.dim() != 2:
raise ValueError(f"Expected input to be of dim 2, but got {inps.dim()}")

inps = inps.squeeze(0)
# setup inputs in correct format
T = inps.size(0)
T_new = T + max_new_tokens
seq = torch.empty(T_new, dtype=inps.dtype, device=inps.device)
seq[:T] = inps
input_pos = torch.arange(0, T, device=inps.device)
x = seq.index_select(0, input_pos).view(1, -1)
return (x, input_pos)
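
As a quick illustration of the new helper (the token ids below are made up), a 2-D lm-eval style batch of shape (1, T) becomes a (1, T) prompt tensor plus its position indices:

import torch

tokens = torch.tensor([[1, 15043, 29892, 590]])    # shape (1, 4), illustrative ids
x, input_pos = prepare_inputs_for_model(tokens, max_new_tokens=1)
print(x.shape)      # torch.Size([1, 4])
print(input_pos)    # tensor([0, 1, 2, 3])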


def find_multiple(n: int, k: int) -> int:
if n % k == 0:
return n
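
The rest of find_multiple is collapsed in this diff; assuming the hidden branch rounds n up to the next multiple of k, it is the natural tool for the padding issue called out in the test plan, where the stories checkpoints use dimension 288 that the int4 kernels cannot consume directly:

# Assuming find_multiple rounds up when n is not already a multiple of k;
# padding to a multiple of 1024 is only an assumption about the kernel constraint.
find_multiple(288, 1024)    # -> 1024: pad the 288-dim stories weights up
find_multiple(4096, 1024)   # -> 4096: llama2 7b dims already satisfy it
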
@@ -134,7 +150,7 @@ def __init__(self, config: ModelArgs) -> None:
self.max_batch_size = -1
self.max_seq_length = -1

def setup_caches(self, max_batch_size, max_seq_length):
def setup_caches(self, max_batch_size, max_seq_length, dtype=torch.float):
if (
self.max_seq_length >= max_seq_length
and self.max_batch_size >= max_batch_size
@@ -146,7 +162,7 @@ def setup_caches(self, max_batch_size, max_seq_length):
self.max_batch_size = max_batch_size
for b in self.layers:
b.attention.kv_cache = KVCache(
max_batch_size, max_seq_length, self.config.n_local_heads, head_dim
max_batch_size, max_seq_length, self.config.n_local_heads, head_dim, dtype
)

freqs_cis = precompute_freqs_cis(
@@ -218,7 +234,7 @@ def __init__(self, config: ModelArgs):
self.head_dim = config.head_dim
self.n_local_heads = config.n_local_heads
self.dim = config.dim
# self._register_load_state_dict_pre_hook(self.load_hook)
self._register_load_state_dict_pre_hook(self.load_hook)

# def load_hook(self, state_dict, prefix, *args):
# if prefix + "wq.weight" in state_dict:
@@ -227,6 +243,16 @@ def __init__(self, config: ModelArgs):
# wv = state_dict.pop(prefix + "wv.weight")
# state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])

def load_hook(self, state_dict, prefix, *args):
if prefix + "wqkv.weight" in state_dict:
wqkv = state_dict.pop(prefix + "wqkv.weight")
q_size = self.n_head * self.head_dim
kv_size = self.n_local_heads * self.head_dim
wq, wk, wv = torch.split(wqkv, (q_size, kv_size, kv_size), dim=0)
state_dict[prefix + "wq.weight"] = wq
state_dict[prefix + "wk.weight"] = wk
state_dict[prefix + "wv.weight"] = wv

def forward(
self,
x: Tensor,
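
The dtype threaded through setup_caches above ends up in KVCache, so the key/value buffers are allocated in the requested precision (e.g. bf16) rather than defaulting to fp32. KVCache itself is outside this diff; the sketch below approximates a gpt-fast style cache and is not the repo's exact code.

import torch
import torch.nn as nn


class KVCache(nn.Module):
    """Preallocated key/value cache; dtype controls the precision of the buffers."""

    def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, dtype=torch.float):
        super().__init__()
        cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
        self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype))
        self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))

    def update(self, input_pos, k_val, v_val):
        # input_pos: (S,); k_val, v_val: (B, n_heads, S, head_dim)
        self.k_cache[:, :, input_pos] = k_val
        self.v_cache[:, :, input_pos] = v_val
        return self.k_cache, self.v_cache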