diff --git a/.github/workflows/compile_t4.yml b/.github/workflows/compile_t4.yml
index 0ac41f543..65f795a71 100644
--- a/.github/workflows/compile_t4.yml
+++ b/.github/workflows/compile_t4.yml
@@ -52,13 +52,13 @@ jobs:
           echo "******************************************"
           echo "******* Emb: channel-wise quantized ******"
           echo "******************************************"
-          # python generate.py --device cuda --quant '{"embedding" : {"bitwidth": 8, "groupsize": 0}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
-          # cat ./output_eager
-          # python generate.py --device cuda --compile --quant '{"embedding" : {"bitwidth": 8, "groupsize": 0}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_compiled
-          # cat ./output_compiled
-          # python export.py --device cuda --quant '{"embedding" : {"bitwidth": 8, "groupsize": 0}}' --checkpoint-path ${MODEL_PATH} --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so
-          # python generate.py --device cuda --checkpoint-path ${MODEL_PATH} --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.so > ./output_aoti
-          # cat ./output_aoti
+          python generate.py --device cuda --quant '{"embedding" : {"bitwidth": 8, "groupsize": 0}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
+          cat ./output_eager
+          python generate.py --device cuda --compile --quant '{"embedding" : {"bitwidth": 8, "groupsize": 0}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_compiled
+          cat ./output_compiled
+          python export.py --device cuda --quant '{"embedding" : {"bitwidth": 8, "groupsize": 0}}' --checkpoint-path ${MODEL_PATH} --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so
+          python generate.py --device cuda --checkpoint-path ${MODEL_PATH} --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.so > ./output_aoti
+          cat ./output_aoti

           echo "******************************************"
           echo "******** Emb: group-wise quantized *******"
diff --git a/.github/workflows/test_mps-dtype.yml b/.github/workflows/test_mps-dtype.yml
index 482c3c2ec..78cb1f789 100644
--- a/.github/workflows/test_mps-dtype.yml
+++ b/.github/workflows/test_mps-dtype.yml
@@ -52,14 +52,14 @@ jobs:
           python generate.py --dtype ${DTYPE} --device mps --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
           cat ./output_eager

-          # python generate.py --dtype ${DTYPE} --device mps --quant '{"embedding" : {"bitwidth": 8, "group_size": 0}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
+          # python generate.py --dtype ${DTYPE} --device mps --quant '{"embedding" : {"bitwidth": 8, "groupsize": 0}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
           # cat ./output_eager
-          # python generate.py --dtype ${DTYPE} --device mps --quant '{"embedding" : {"bitwidth": 8, "group_size": 8}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
+          # python generate.py --dtype ${DTYPE} --device mps --quant '{"embedding" : {"bitwidth": 8, "groupsize": 8}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
           # cat ./output_eager
-          # python generate.py --dtype ${DTYPE} --device mps --quant '{"linear:int8" : {"bitwidth": 8, "group_size": 0}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
+          # python generate.py --dtype ${DTYPE} --device mps --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 0}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
           # cat ./output_eager
-          # python generate.py --dtype ${DTYPE} --device mps --quant '{"linear:int8" : {"bitwidth": 8, "group_size": 8}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
+          # python generate.py --dtype ${DTYPE} --device mps --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 8}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
           # cat ./output_eager
-          # PYTORCH_ENABLE_MPS_FALLBACK=1 python generate.py --dtype ${DTYPE} --device mps --quant '{"linear:int4" : {"group_size": 32}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
+          # PYTORCH_ENABLE_MPS_FALLBACK=1 python generate.py --dtype ${DTYPE} --device mps --quant '{"linear:int4" : {"groupsize": 32}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
           # cat ./output_eager
         done
\ No newline at end of file
diff --git a/.github/workflows/test_mps.yml b/.github/workflows/test_mps.yml
index 9b15f4778..f8e166790 100644
--- a/.github/workflows/test_mps.yml
+++ b/.github/workflows/test_mps.yml
@@ -48,14 +48,14 @@ jobs:

           python generate.py --device mps --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
           cat ./output_eager

-          # python generate.py --device mps --quant '{"embedding" : {"bitwidth": 8, "group_size": 0}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
-          # cat ./output_eager
-          # python generate.py --device mps --quant '{"embedding" : {"bitwidth": 8, "group_size": 8}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
-          # cat ./output_eager
-          # python generate.py --device mps --quant '{"linear:int8" : {"bitwidth": 8, "group_size": 0}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
+          python generate.py --device mps --quant '{"embedding" : {"bitwidth": 8, "groupsize": 0}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
+          cat ./output_eager
+          python generate.py --device mps --quant '{"embedding" : {"bitwidth": 8, "groupsize": 8}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
+          cat ./output_eager
+          # python generate.py --device mps --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 0}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
           # cat ./output_eager
-          # python generate.py --device mps --quant '{"linear:int8" : {"bitwidth": 8, "group_size": 8}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
+          # python generate.py --device mps --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 8}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
           # cat ./output_eager
-          # PYTORCH_ENABLE_MPS_FALLBACK=1 python generate.py --device mps --quant '{"linear:int4" : {"group_size": 32}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
+          # PYTORCH_ENABLE_MPS_FALLBACK=1 python generate.py --device mps --quant '{"linear:int4" : {"groupsize": 32}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
           # cat ./output_eager
\ No newline at end of file
diff --git a/build/builder.py b/build/builder.py
index 020a8f462..2f758a0a8 100644
--- a/build/builder.py
+++ b/build/builder.py
@@ -259,7 +259,7 @@ def _initialize_model(

     if quantize:
         t0q = time.time()
-        quantize_model(model, quantize)
+        quantize_model(model, builder_args.device, quantize)
         device_sync(device=builder_args.device)
         print(f"Time to quantize model: {time.time() - t0q:.02f} seconds")

diff --git a/generate.py b/generate.py
index e22308bc5..ad6085582 100644
--- a/generate.py
+++ b/generate.py
@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 import itertools
 import sys
+import os
 import time
 from pathlib import Path
 from typing import Optional, Tuple
@@ -333,9 +334,9 @@ def _main(
     set_precision(builder_args.precision)
     is_speculative = speculative_builder_args.checkpoint_path is not None

-    is_chat = "chat" in str(builder_args.checkpoint_path)
+    is_chat = "chat" in str(os.path.basename(builder_args.checkpoint_path))
     if is_chat:
-        raise RuntimeError("need to stop filename based kludgery, at a minimum need to look at all pathnames. yuck!")
+        raise RuntimeError("need to stop filename based kludgery, at a minimum need to look at all pathnames. in particular, this now fails because chat is part of the pathname, yuck!")

     tokenizer = _initialize_tokenizer(tokenizer_args)

diff --git a/quantize.py b/quantize.py
index f36f1a372..2d596b795 100644
--- a/quantize.py
+++ b/quantize.py
@@ -55,7 +55,7 @@ def name_to_dtype(name):
 ##########################################################################
 ### process quantization dictionary ###

-def quantize_model(model: nn.Module, quantize_options):
+def quantize_model(model: nn.Module, device, quantize_options):
     """
     Quantize the specified model using the quantizers described by
     a quantization dict of the form:
@@ -74,6 +74,7 @@ def quantize_model(model: nn.Module, quantize_options):
         if quantizer == "embedding":
             model = EmbeddingOnlyInt8QuantHandler(
                 model,
+                device,
                 **q_kwargs
             ).quantized_model()
         elif linears_quantized:
@@ -82,30 +83,35 @@ def quantize_model(model: nn.Module, quantize_options):
             linears_quantized = True
             model = WeightOnlyInt8QuantHandler(
                 model,
+                device,
                 **q_kwargs
             ).quantized_model()
         elif quantizer == "linear:int4":
             linears_quantized = True
             model = WeightOnlyInt4QuantHandler(
                 model,
+                device,
                 **q_kwargs
             ).quantized_model()
         elif quantizer == "linear:a8w4dq":
             linears_quantized = True
             model = Int8DynActInt4WeightQuantHandler(
                 model,
+                device,
                 **q_kwargs
             ).quantized_model()
         elif quantizer == "linear:gptq":
             linears_quantized = True
             model = WeightOnlyInt4GPTQQuantHandler(
                 model,
+                device,
                 **q_kwargs
             ).quantized_model()
         elif quantizer == "linear:hqq":
             linears_quantized = True
             model = WeightOnlyInt4HqqQuantHandler(
                 model,
+                device,
                 **q_kwargs
             ).quantized_model()
         elif quantizer == "precision":
@@ -371,12 +377,14 @@ class WeightOnlyInt8QuantHandler(QuantHandler):
     def __init__(
         self,
         mod,
+        device,
         *,
         node_type: str = "*",
         bitwidth: Optional[int] = None,
         groupsize: Optional[int] = None,
     ):
         self.mod = mod
+        self.device = device
         self.groupsize = groupsize
         self.node_type = node_type
         if bitwidth is None:
@@ -494,7 +502,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:


 def replace_embedding_weight_only_grouped_int8_per_channel(
-    module, bitwidth: int = 8, groupsize: Optional[int] = None, packed = False
+    module, device, bitwidth: int = 8, groupsize: Optional[int] = None, packed = False
 ):
     for name, child in module.named_children():
         # print(f"name: {name}")
@@ -505,6 +513,7 @@ def replace_embedding_weight_only_grouped_int8_per_channel(
                 module,
                 name,
                 QuantizedGroupEmbedding(
+                    device=device,
                     vocab_size=child.weight.shape[0],
                     embedding_dim=child.weight.shape[1],
                     groupsize=groupsize,
@@ -518,10 +527,11 @@ class EmbeddingOnlyInt8QuantHandler(QuantHandler):
-    def __init__(self, mod, *, bitwidth: int = 8, groupsize: Optional[int] = None, packed = False):
+    def __init__(self, mod, device, *, bitwidth: int = 8, groupsize: Optional[int] = None, packed = False):
         if isinstance(packed, str):
             packed = (packed == "True")
         self.mod = mod
+        self.device = device
         self.groupsize = groupsize
         self.bitwidth = bitwidth
         self.packed = packed
@@ -565,7 +575,7 @@ def create_quantized_state_dict(self, packed=False) -> Dict:

                 if packed:
                     if weight.shape[-1] %2 != 0:
-                        raise RUntimeError("automatic padding not implemented yet")
+                        raise RuntimeError("automatic padding not implemented yet")

                     weight_range_shifted = weight.add(8).view(torch.uint8)
                     weight_view = weight_range_shifted.view(
@@ -578,6 +588,8 @@ def create_quantized_state_dict(self, packed=False) -> Dict:
                     weight_packed = weight_even + weight_odd
                     weight = weight_packed

+                weight = weight.to(device=self.device)
+                scales = scales.to(device=self.device)
                 # Update state dict
                 cur_state_dict[f"{fqn}.weight"] = weight
                 # squeeze makes groupsize=rowsize unidimensional
@@ -587,7 +599,7 @@ def create_quantized_state_dict(self, packed=False) -> Dict:

     def convert_for_runtime(self) -> nn.Module:
         replace_embedding_weight_only_grouped_int8_per_channel(
-            self.mod, self.bitwidth, self.groupsize, self.packed
+            self.mod, self.device, self.bitwidth, self.groupsize, self.packed
         )
         return self.mod

@@ -601,10 +613,10 @@ def quantized_model(self) -> nn.Module:
 class QuantizedGroupEmbedding(torch.nn.Module):
     def __init__(
         self,
+        device,
         vocab_size: int,
         embedding_dim: int,
         groupsize: Optional[int] = None,
-        device=None,
         dtype=torch.half,
         packed=False,
     ) -> None:
@@ -616,20 +628,20 @@ def __init__(
         self.packed = packed
         if not packed:
             self.register_buffer(
-                "weight", torch.empty((vocab_size, embedding_dim), dtype=torch.int8)
+                "weight", torch.empty((vocab_size, embedding_dim), dtype=torch.int8, device=device)
             )
         else: # packed
             self.register_buffer(
-                "weight", torch.empty((vocab_size, embedding_dim//2), dtype=torch.uint8)
+                "weight", torch.empty((vocab_size, embedding_dim//2), dtype=torch.uint8, device=device)
             )
         groups_per_row = (embedding_dim + groupsize - 1) // groupsize
         if groups_per_row > 1:
             self.register_buffer(
-                "scales", torch.ones((vocab_size, groups_per_row), dtype=torch.float16)
+                "scales", torch.ones((vocab_size, groups_per_row), dtype=torch.float16, device=device)
             )
         else:
             self.register_buffer(
-                "scales", torch.ones((vocab_size,), dtype=torch.float16)
+                "scales", torch.ones((vocab_size,), dtype=torch.float16, device=device)
             )

     @torch.no_grad()
@@ -712,8 +724,9 @@ def replace_linear_int4(module, groupsize, inner_k_tiles, padding_allowed, use_c


 class WeightOnlyInt4QuantHandler:
-    def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding_allowed=True):
+    def __init__(self, mod, device, *, groupsize=128, inner_k_tiles=8, padding_allowed=True):
         self.mod = mod
+        self.device = device
         self.groupsize = groupsize
         self.inner_k_tiles = inner_k_tiles
         self.padding_allowed = padding_allowed
@@ -908,12 +921,15 @@ class Int8DynActInt4WeightQuantHandler(QuantHandler):
     def __init__(
         self,
         mod,
+        device,
+        *,
         groupsize=256,
         padding_allowed=False,
         precision=torch.float32,
         scales_precision=torch.float32,
     ):
         self.mod = mod
+        self.device = device
         self.groupsize = groupsize
         self.padding_allowed = padding_allowed
         self.precision = precision
@@ -1209,9 +1225,10 @@ def convert_for_runtime(self) -> "nn.Module":


 class WeightOnlyInt4GPTQQuantHandler(GPTQQuantHandler):
-    def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding=True):
+    def __init__(self, mod, device, *, groupsize=128, inner_k_tiles=8, padding=True):
         from build.model import find_multiple
         self.mod = mod
+        self.device = device
         self.groupsize = groupsize
         self.inner_k_tiles = inner_k_tiles
         self.padding = padding
@@ -1329,7 +1346,7 @@ def quantized_model(self) -> nn.Module:
 ### WIP: HQQ ###

 class WeightOnlyInt4HqqQuantHandler:
-    def __init__(self, mod, groupsize):
+    def __init__(self, mod, device, *, groupsize):
         self.mod = mod
         self.groupsize = groupsize
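
For reference, a minimal sketch of how the device-aware API added by this patch is meant to be driven. The helper name quantize_for_device and the flat import from quantize are illustrative assumptions, not part of the patch; the call mirrors the updated call site in build/builder.py, where quantize_model(model, builder_args.device, quantize) modifies the model in place.

import json

import torch.nn as nn

from quantize import quantize_model  # assumes quantize.py at the repo root is importable


def quantize_for_device(model: nn.Module, device: str, quant_json: str) -> nn.Module:
    # Hypothetical helper: parse the same JSON accepted by the --quant CLI flag,
    # e.g. '{"embedding" : {"bitwidth": 8, "groupsize": 0}}', then quantize the
    # model in place, passing the target device through so the quantized weight
    # and scale buffers are allocated on that device (cuda, mps, or cpu).
    quantize_options = json.loads(quant_json)
    quantize_model(model, device, quantize_options)
    return model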