diff --git a/.github/workflows/compile_t4-dtype.yml b/.github/workflows/compile_t4-dtype.yml
new file mode 100644
index 000000000..e0e3259c2
--- /dev/null
+++ b/.github/workflows/compile_t4-dtype.yml
@@ -0,0 +1,115 @@
+name: Run compile tests
+
+on:
+  pull_request:
+  push:
+    branches:
+      - main
+  workflow_dispatch:
+
+jobs:
+  test-cuda:
+    uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
+    with:
+      runner: linux.g5.4xlarge.nvidia.gpu
+      gpu-arch-type: cuda
+      gpu-arch-version: "12.1"
+      script: |
+        echo "::group::Print machine info"
+        uname -a
+        if [ $(uname -s) == Darwin ]; then
+          sysctl machdep.cpu.brand_string
+          sysctl machdep.cpu.core_count
+        fi
+        echo "::endgroup::"
+
+        echo "::group::Install requirements"
+        # Install requirements
+        pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121
+        pip install -r requirements.txt
+        echo "::endgroup::"
+
+        echo "::group::Download checkpoints"
+        mkdir -p checkpoints/stories15M
+        pushd checkpoints/stories15M
+        wget https://huggingface.co/karpathy/tinyllamas/resolve/main/stories15M.pt
+        wget https://github.com/karpathy/llama2.c/raw/master/tokenizer.model
+        popd
+        echo "::endgroup::"
+
+        echo "::group::Run inference"
+        export MODEL_PATH=checkpoints/stories15M/stories15M.pt
+        export MODEL_NAME=stories15M
+        export MODEL_DIR=/tmp
+
+        for DTYPE in bfloat16 float16 float32; do
+
+          python generate.py --dtype ${DTYPE} --device cuda --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
+          cat ./output_eager
+          python generate.py --dtype ${DTYPE} --device cuda --compile --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_compiled
+          cat ./output_compiled
+          python export.py --dtype ${DTYPE} --device cuda --checkpoint-path ${MODEL_PATH} --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so
+          python generate.py --dtype ${DTYPE} --device cuda --checkpoint-path ${MODEL_PATH} --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.so > ./output_aoti
+          cat ./output_aoti
+
+          echo "******************************************"
+          echo "******* Emb: channel-wise quantized ******"
+          echo "******************************************"
+          python generate.py --dtype ${DTYPE} --device cuda --quant '{"embedding" : {"bitwidth": 8, "groupsize": 0}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
+          cat ./output_eager
+          python generate.py --dtype ${DTYPE} --device cuda --compile --quant '{"embedding" : {"bitwidth": 8, "groupsize": 0}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_compiled
+          cat ./output_compiled
+          python export.py --dtype ${DTYPE} --device cuda --quant '{"embedding" : {"bitwidth": 8, "groupsize": 0}}' --checkpoint-path ${MODEL_PATH} --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so
+          python generate.py --dtype ${DTYPE} --device cuda --checkpoint-path ${MODEL_PATH} --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.so > ./output_aoti
+          cat ./output_aoti
+
+          echo "******************************************"
+          echo "******** Emb: group-wise quantized *******"
+          echo "******************************************"
+          python generate.py --dtype ${DTYPE} --device cuda --quant '{"embedding" : {"bitwidth": 8, "groupsize": 8}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
+          cat ./output_eager
+          python generate.py --dtype ${DTYPE} --device cuda --compile --quant '{"embedding" : {"bitwidth": 8, "groupsize": 8}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_compiled
+          cat ./output_compiled
+          python export.py --dtype ${DTYPE} --device cuda --quant '{"embedding" : {"bitwidth": 8, "groupsize": 8}}' --checkpoint-path ${MODEL_PATH} --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so
'{"embedding" : {"bitwidth": 8, "groupsize": 8}}' --checkpoint-path ${MODEL_PATH} --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so + python generate.py --dtype ${DTYPE} --device cuda --checkpoint-path ${MODEL_PATH} --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.so > ./output_aoti + cat ./output_aoti + + echo "******************************************" + echo "******* INT8 channel-wise quantized ******" + echo "******************************************" + python generate.py --dtype ${DTYPE} --device cuda --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 0}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager + cat ./output_eager + python generate.py --dtype ${DTYPE} --device cuda --compile --quant '{" linear:int8" : {"bitwidth": 8, "groupsize": 0}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_compiled + cat ./output_compiled + python export.py --dtype ${DTYPE} --device cuda --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 0}}' --checkpoint-path ${MODEL_PATH} --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so + python generate.py --dtype ${DTYPE} --device cuda --checkpoint-path ${MODEL_PATH} --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.so > ./output_aoti + cat ./output_aoti + + echo "******************************************" + echo "******** INT8 group-wise quantized *******" + echo "******************************************" + python generate.py --dtype ${DTYPE} --device cuda --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 8}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager + cat ./output_eager + python generate.py --dtype ${DTYPE} --device cuda --compile --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 8}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_compiled + cat ./output_compiled + python export.py --dtype ${DTYPE} --device cuda --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 8}}' --checkpoint-path ${MODEL_PATH} --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so + python generate.py --dtype ${DTYPE} --device cuda --checkpoint-path ${MODEL_PATH} --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.so > ./output_aoti + cat ./output_aoti + + echo "******************************************" + echo "******** INT4 group-wise quantized *******" + echo "******************************************" + python generate.py --dtype ${DTYPE} --device cuda --quant '{"linear:int4" : {"groupsize": 32}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager + cat ./output_eager + python generate.py --dtype ${DTYPE} --device cuda --compile --quant '{"linear:int4" : {"groupsize": 32}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_compiled + cat ./output_compiled + python export.py --dtype ${DTYPE} --device cuda --quant '{"linear:int4" : {"groupsize": 32}}' --checkpoint-path ${MODEL_PATH} --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so + python generate.py --dtype ${DTYPE} --device cuda --checkpoint-path ${MODEL_PATH} --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.so > ./output_aoti + cat ./output_aoti + + done + + echo "tests complete" + echo "******************************************" + echo "::endgroup::" + diff --git a/.github/workflows/compile_t4.yml b/.github/workflows/compile_t4.yml index e96d42fba..9815f46a1 100644 --- a/.github/workflows/compile_t4.yml +++ b/.github/workflows/compile_t4.yml @@ -93,13 +93,18 @@ jobs: python generate.py --device cuda --checkpoint-path ${MODEL_PATH} --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.so > ./output_aoti cat 
+
+        echo "******************************************"
+        echo "******** INT4 group-wise quantized *******"
+        echo "******************************************"
+        python generate.py --device cuda --quant '{"linear:int4" : {"groupsize": 32}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
+        cat ./output_eager
+        python generate.py --device cuda --compile --quant '{"linear:int4" : {"groupsize": 32}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_compiled
+        cat ./output_compiled
+        python export.py --device cuda --quant '{"linear:int4" : {"groupsize": 32}}' --checkpoint-path ${MODEL_PATH} --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so
+        python generate.py --device cuda --checkpoint-path ${MODEL_PATH} --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.so > ./output_aoti
+        cat ./output_aoti

         echo "tests complete"
         echo "******************************************"
         echo "::endgroup::"
-        # echo "********* EAGER vs TORCH.COMPILE *********"
-        # echo "******************************************"
-        # diff output_eager output_compiled
-        # echo "******************************************"
-        # echo "********* EAGER vs AOT INDUCTOR *********"
-        # echo "******************************************"
-        # diff output_eager output_aoti
+
diff --git a/.github/workflows/test_mps-dtype.yml b/.github/workflows/test_mps-dtype.yml
index 78cb1f789..ee474ee7e 100644
--- a/.github/workflows/test_mps-dtype.yml
+++ b/.github/workflows/test_mps-dtype.yml
@@ -52,14 +52,19 @@ jobs:
           python generate.py --dtype ${DTYPE} --device mps --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
           cat ./output_eager
-          # python generate.py --dtype ${DTYPE} --device mps --quant '{"embedding" : {"bitwidth": 8, "groupsize": 0}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
-          # cat ./output_eager
-          # python generate.py --dtype ${DTYPE} --device mps --quant '{"embedding" : {"bitwidth": 8, "groupsize": 8}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
-          # cat ./output_eager
-          # python generate.py --dtype ${DTYPE} --device mps --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 0}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
-          # cat ./output_eager
-          # python generate.py --dtype ${DTYPE} --device mps --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 8}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
-          # cat ./output_eager
-          # PYTORCH_ENABLE_MPS_FALLBACK=1 python generate.py --dtype ${DTYPE} --device mps --quant '{"linear:int4" : {"groupsize": 32}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
-          # cat ./output_eager
+
+          python generate.py --dtype ${DTYPE} --device mps --quant '{"embedding" : {"bitwidth": 8, "groupsize": 0}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
+          cat ./output_eager
+
+          python generate.py --dtype ${DTYPE} --device mps --quant '{"embedding" : {"bitwidth": 8, "groupsize": 8}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
+          cat ./output_eager
+
+          python generate.py --dtype ${DTYPE} --device mps --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 0}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
+          cat ./output_eager
+
+          python generate.py --dtype ${DTYPE} --device mps --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 8}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
+          cat ./output_eager
+
+          PYTORCH_ENABLE_MPS_FALLBACK=1 python generate.py --dtype ${DTYPE} --device mps --quant '{"linear:int4" : {"groupsize": 32}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
'{"linear:int4" : {"groupsize": 32}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager + cat ./output_eager done \ No newline at end of file diff --git a/.github/workflows/test_mps.yml b/.github/workflows/test_mps.yml index 736433b1d..0c0e0bb94 100644 --- a/.github/workflows/test_mps.yml +++ b/.github/workflows/test_mps.yml @@ -71,6 +71,6 @@ jobs: echo "*** linear int4" echo "************************************************************" - # PYTORCH_ENABLE_MPS_FALLBACK=1 python generate.py --device mps --quant '{"linear:int4" : {"groupsize": 32}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager - # cat ./output_eager + PYTORCH_ENABLE_MPS_FALLBACK=1 python generate.py --device mps --quant '{"linear:int4" : {"groupsize": 32}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager + cat ./output_eager \ No newline at end of file diff --git a/build/builder.py b/build/builder.py index 26dd3cc65..f62c42131 100644 --- a/build/builder.py +++ b/build/builder.py @@ -165,7 +165,6 @@ def _load_model_not_gguf( ): assert not builder_args.gguf_path - use_cuda = "cuda" in builder_args.device with torch.device("meta"): if builder_args.params_path: model = Transformer.from_params(builder_args.params_path) diff --git a/build/gguf_loader.py b/build/gguf_loader.py index 06396b50a..a52c15274 100644 --- a/build/gguf_loader.py +++ b/build/gguf_loader.py @@ -177,11 +177,12 @@ def load_weights(pt_model: torch.nn.Module, weight_map: Dict[str, ReaderTensor], parent, _fqn_last(fqn), WeightOnlyInt4Linear( - in_features, out_features, + "cpu", # TODO: should --device work for gguf load? (yes?!) + in_features, + out_features, bias=False, groupsize=Q4_0.groupsize, inner_k_tiles=inner_k_tiles, - use_cuda=False ) ) else: diff --git a/quantize.py b/quantize.py index f6b1220d4..4890c4b42 100644 --- a/quantize.py +++ b/quantize.py @@ -699,6 +699,13 @@ def _int4_calc_padded_size(k, groupsize=1, innner_k_tiles=1): def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, groupsize): origin_x_size = x.size() x = x.reshape(-1, origin_x_size[-1]) + + # avoid errors in MPSaround bfloat16 until int4pack_mm is in nightlies + # print("MPS workaround active, will produce bogus results") + if "mps" in str(x.device): + new_shape = origin_x_size[:-1] + (out_features,) + return torch.zeros(new_shape, dtype=x.dtype, device=x.device) + c = torch.ops.aten._weight_int4pack_mm( x.to(torch.bfloat16), # TODO: should probably make a warning if x is not already bfloat16 weight_int4pack, @@ -713,22 +720,37 @@ def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, grou def _int4_check_linear_int4_k(k, groupsize = 1, inner_k_tiles = 1): return k % groupsize == 0 and k % (inner_k_tiles * 16) == 0 -def replace_linear_int4(module, groupsize, inner_k_tiles, padding_allowed, use_cuda=False): +def replace_linear_int4( + module, + device, + groupsize, + inner_k_tiles, + padding_allowed, +): for name, child in module.named_children(): if isinstance(child, nn.Linear): if _int4_check_linear_int4_k(child.in_features, groupsize, inner_k_tiles) or padding_allowed: - setattr(module, name, WeightOnlyInt4Linear( - child.in_features, child.out_features, bias=False, - groupsize=groupsize, inner_k_tiles=inner_k_tiles, use_cuda=use_cuda + setattr( + module, + name, + WeightOnlyInt4Linear( + device, + child.in_features, + child.out_features, + bias=False, + groupsize=groupsize, + inner_k_tiles=inner_k_tiles, )) else: - replace_linear_int4(child, groupsize, inner_k_tiles, padding_allowed, 
+            replace_linear_int4(
+                child, device, groupsize, inner_k_tiles, padding_allowed
+            )
 
 
 class WeightOnlyInt4QuantHandler(QuantHandler):
     def __init__(self, mod, device, *, groupsize=128, inner_k_tiles=8, padding_allowed=True):
         self.mod = mod
-        self.device = device,
+        self.device = device
         self.groupsize = groupsize
         self.inner_k_tiles = inner_k_tiles
         self.padding_allowed = padding_allowed
@@ -761,14 +783,16 @@ def create_quantized_state_dict(self):
             weight_int4pack, scales_and_zeros = _int4_prepare_int4_weight_and_scales_and_zeros(
                 weight.to(torch.float), self.groupsize, self.inner_k_tiles
             )
-            cur_state_dict[f"{fqn}.weight"] = weight_int4pack.to('cpu')
-            cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to('cpu')
+            weight_int4pack = weight_int4pack.to(device=self.device)
+            scales_and_zeros = scales_and_zeros.to(device=self.device)
+            cur_state_dict[f"{fqn}.weight"] = weight_int4pack
+            cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros
 
         return cur_state_dict
 
-    def convert_for_runtime(self, use_cuda=False):
-        replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding_allowed, use_cuda)
+    def convert_for_runtime(self):
+        replace_linear_int4(self.mod, self.device, self.groupsize, self.inner_k_tiles, self.padding_allowed)
         return self.mod
 
     def quantized_model(self) -> nn.Module:
@@ -785,8 +809,14 @@ class WeightOnlyInt4Linear(torch.nn.Module):
     weight: torch.Tensor
 
     def __init__(
-        self, in_features: int, out_features: int,
-        bias=True, device=None, dtype=None, groupsize: int = 128, inner_k_tiles: int = 8, use_cuda=True,
+        self,
+        device: str,
+        in_features: int,
+        out_features: int,
+        bias=True,
+        dtype=None,
+        groupsize: int = 128,
+        inner_k_tiles: int = 8,
     ) -> None:
         super().__init__()
         self.padding = not _int4_check_linear_int4_k(in_features, groupsize, inner_k_tiles)
@@ -805,12 +835,20 @@ def __init__(
         assert in_features % (inner_k_tiles * 16) == 0, "require in_features % (innerKTiles * 16) == 0"
         self.register_buffer(
             "weight",
-            torch.empty((out_features // 8, in_features // (inner_k_tiles * 16), 32, inner_k_tiles // 2), dtype=torch.int32)
+            torch.empty(
+                (out_features // 8, in_features // (inner_k_tiles * 16), 32, inner_k_tiles // 2),
+                dtype=torch.int32,
+                device=device,
+            )
         )
         # MKG: torch.float
         self.register_buffer(
             "scales_and_zeros",
-            torch.empty((in_features // groupsize, out_features, 2), dtype=get_precision())
+            torch.empty(
+                (in_features // groupsize, out_features, 2),
+                dtype=get_precision(),
+                device=device,
+            )
         )
@@ -821,7 +859,10 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
             input = F.pad(input, pad=(0, self.in_features - self.origin_in_features))
         return linear_forward_int4(
             input,
-            self.weight, self.scales_and_zeros, self.out_features, self.groupsize
+            self.weight,
+            self.scales_and_zeros,
+            self.out_features,
+            self.groupsize
         )
 
 #########################################################################
@@ -1261,8 +1302,8 @@ def make_names_and_values_dict_func(q, qparams):
         super().__init__()
 
-    def convert_for_runtime(self, use_cuda):
-        replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding, use_cuda)
+    def convert_for_runtime(self):
+        replace_linear_int4(self.mod, self.device, self.groupsize, self.inner_k_tiles, self.padding)
         return self.mod
 
     def quantized_model(self) -> nn.Module:
@@ -1350,6 +1391,7 @@ def quantized_model(self) -> nn.Module:
 class WeightOnlyInt4HqqQuantHandler:
     def __init__(self, mod, device, *, groupsize):
         self.mod = mod
+        self.device = device
         self.groupsize = groupsize
 
     def create_quantized_state_dict(self):
@@ -1373,15 +1415,15 @@ def create_quantized_state_dict(self):
         # we use Int4 packaged in an int8 for now, packing to follow
         # return WeightOnlyInt4QuantHandler(self.mod, self.groupsize).create_quantized_state_dict()
         return WeightOnlyInt8QuantHandler(
-            self.mod, bitwidth=4, groupsize=self.groupsize
+            self.mod, self.device, bitwidth=4, groupsize=self.groupsize
         ).create_quantized_state_dict()
 
     def convert_for_runtime(self):
         # we use Int4 packaged in an int8 for now, packing to follow
         # ALSO: all code must work for CPU, CUDA, MPS
-        # return WeightOnlyInt4GPTQQuantHandler(self.mod, self.groupsize).convert_for_runtime(use_cuda=True)
+        # return WeightOnlyInt4GPTQQuantHandler(self.mod, self.groupsize).convert_for_runtime()
        return WeightOnlyInt4GPTQQuantHandler(
-            self.mod, bitwidth=4, groupsize=self.groupsize
+            self.mod, self.device, bitwidth=4, groupsize=self.groupsize
        ).convert_for_runtime()
 
     def quantized_model(self) -> nn.Module:
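
Reviewer note (not part of the patch): a minimal sketch of how the new device-first signatures in quantize.py are expected to be called after this change, assuming quantize.py and its dependencies are importable from the repo root. The toy dimensions are chosen only to satisfy the int4 divisibility checks.

    # Illustrative sketch only -- exercises the device-first constructor and the
    # device-threaded replace_linear_int4() introduced in this diff.
    import torch
    from quantize import WeightOnlyInt4Linear, replace_linear_int4

    # 256 satisfies in_features % groupsize == 0 and in_features % (inner_k_tiles * 16) == 0.
    lin = WeightOnlyInt4Linear("cpu", 256, 256, bias=False, groupsize=32, inner_k_tiles=8)
    print(lin.weight.device, lin.scales_and_zeros.device)  # buffers land directly on the requested device

    # Recursive replacement takes the device instead of the old use_cuda flag.
    mlp = torch.nn.Sequential(torch.nn.Linear(256, 256, bias=False))
    replace_linear_int4(mlp, "cpu", groupsize=32, inner_k_tiles=8, padding_allowed=True)
    print(type(mlp[0]).__name__)  # WeightOnlyInt4Linear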