diff --git a/.github/workflows/et.yml b/.github/workflows/et.yml
index 37227f8b2..f61507e4b 100644
--- a/.github/workflows/et.yml
+++ b/.github/workflows/et.yml
@@ -60,6 +60,13 @@ jobs:
           wget https://github.com/karpathy/llama2.c/raw/master/tokenizer.model
           wget https://github.com/karpathy/llama2.c/raw/master/tokenizer.bin
           popd
+
+          mkdir gguf_files
+          export GGUF_PATH=gguf_files/TinyLlama-1.1B-openorca.Q4_0.gguf
+          export GGUF_TOKENIZER_PATH=gguf_files/tokenizer.model
+          wget -O ${GGUF_PATH} "https://huggingface.co/TheBloke/TinyLlama-1.1B-1T-OpenOrca-GGUF/resolve/main/tinyllama-1.1b-1t-openorca.Q4_0.gguf?download=true"
+          wget -O ${GGUF_TOKENIZER_PATH} https://github.com/karpathy/llama2.c/raw/master/tokenizer.model
+
       - name: Run inference
         run: |
           export MODEL_PATH=${PWD}/checkpoints/stories15M/stories15M.pt
@@ -75,7 +82,7 @@ jobs:
           echo "Tests complete."
 
       - name: Run inference
-        run: | 
+        run: |
           export MODEL_PATH=checkpoints/stories15M/stories15M.pt
           export MODEL_NAME=stories15M
           export MODEL_DIR=/tmp
@@ -121,3 +128,15 @@ jobs:
 
           echo "tests complete"
           echo "******************************************"
+      - name: Run GGUF export + inference
+        run: |
+          # Each step runs in its own shell, so export every variable this step uses.
+          export GGUF_PATH=gguf_files/TinyLlama-1.1B-openorca.Q4_0.gguf
+          export GGUF_TOKENIZER_PATH=gguf_files/tokenizer.model
+          export MODEL_NAME=tinyllama_gguf
+
+          python torchchat.py export --gguf-path ${GGUF_PATH} --output-pte-path ${PWD}/${MODEL_NAME}.pte
+          python torchchat.py generate --gguf-path ${GGUF_PATH} --pte-path ${PWD}/${MODEL_NAME}.pte --tokenizer-path ${GGUF_TOKENIZER_PATH} --temperature 0 --max-new-tokens 20 > ${PWD}/output_et
+          cat ${PWD}/output_et
+
+          echo "Tests complete."
diff --git a/build/builder.py b/build/builder.py
index ac50c6c16..10d6c3717 100644
--- a/build/builder.py
+++ b/build/builder.py
@@ -9,7 +9,7 @@
 import time
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Optional, Union
+from typing import Any, Optional, Union
 
 import torch
 import torch._dynamo.config
@@ -29,6 +29,7 @@ class BuilderArgs:
     params_path: Optional[Union[Path, str]] = None
     params_table: Optional[str] = None
     gguf_path: Optional[Union[Path, str]] = None
+    gguf_kwargs: Optional[dict[str, Any]] = None
     dso_path: Optional[Union[Path, str]] = None
     pte_path: Optional[Union[Path, str]] = None
     device: str = "cpu"
@@ -91,6 +92,7 @@ def from_args(cls, args):  # -> BuilderArgs:
             params_path=args.params_path,
             params_table=args.params_table,
             gguf_path=args.gguf_path,
+            gguf_kwargs=None,
             dso_path=args.dso_path,
             pte_path=args.pte_path,
             device=args.device,
@@ -174,9 +176,36 @@ def device_sync(device):
 sys.path.append(str(wd))
 
 
+# TODO: remove these once ET supports _weight_int4pack_mm
+def _set_gguf_kwargs(builder_args, is_et, context: str):
+    """Populate ``builder_args.gguf_kwargs`` ahead of loading a GGUF model.
+
+    Only logs and returns when ``gguf_path`` is unset.  For ExecuTorch
+    (``is_et``) it requests ``load_as_quantized=False``, since ET does not
+    support ``_weight_int4pack_mm`` yet.
+    """
+    assert context in ["export", "generate"]
+    assert builder_args.gguf_kwargs is None
+
+    if builder_args.gguf_path is None:
+        print("No gguf_path provided, so ignoring set_gguf_kwargs.")
+        return
+
+    builder_args.gguf_kwargs = {}
+    if is_et:
+        builder_args.gguf_kwargs["load_as_quantized"] = False
+
+
+def _unset_gguf_kwargs(builder_args):
+    """Reset ``gguf_kwargs`` to None so ``_set_gguf_kwargs`` may run again."""
+    builder_args.gguf_kwargs = None
+
+
 def _load_model_gguf(builder_args):
     assert builder_args.gguf_path
-    model = Transformer.from_gguf(builder_args.gguf_path)
+    # gguf_kwargs stays None unless _set_gguf_kwargs populated it.
+    kwargs = builder_args.gguf_kwargs or {}
+    model = Transformer.from_gguf(builder_args.gguf_path, **kwargs)
     return model
 
 
@@ -254,6 +283,15 @@ def _initialize_model(
 ):
     print("Loading model ...")
     t0 = time.time()
+
+    if builder_args.gguf_path and (builder_args.dso_path or builder_args.pte_path):
+        print("Setting gguf_kwargs for generate.")
+        is_dso = builder_args.dso_path is not None
+        is_pte = builder_args.pte_path is not None
+        assert not (is_dso and is_pte)
+        assert builder_args.gguf_kwargs is None
+        _set_gguf_kwargs(builder_args, is_et=is_pte, context="generate")
+
     model_ = _load_model(builder_args)
     device_sync(device=builder_args.device)
     print(f"Time to load model: {time.time() - t0:.02f} seconds")
diff --git a/build/gguf_loader.py b/build/gguf_loader.py
index f98e326da..43603d2a7 100644
--- a/build/gguf_loader.py
+++ b/build/gguf_loader.py
@@ -139,7 +139,7 @@ def load_model(gguf_file: str) -> torch.nn.Module:
     return model
 
 
-def load_model_and_state_dict(gguf_file: str, load_as_quantized: bool, *, inner_k_tiles = 8) -> torch.nn.Module:
+def load_model_and_state_dict(gguf_file: str, load_as_quantized: bool = True, *, load_state_dict: bool = True, inner_k_tiles = 8) -> tuple[torch.nn.Module, dict]:
     """
     Parses the GGUF file and returns an nn.Module on meta device along with a state_dict
     that can be loaded into it.
@@ -174,14 +174,14 @@ def load_model_and_state_dict(gguf_file: str, load_as_quantized: bool, *, inner_
             in_features = mod.in_features
             assert all(t.shape == (in_features, out_features))
 
-            q, s, z = Q4_0.unpack(t)
-            scales_and_zeros = pack_scales_and_zeros(s, z)
-            weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
-                q, inner_k_tiles
-            )
-
-            state_dict[f"{fqn}.weight"] = weight_int4pack
-            state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros
+            if load_state_dict:
+                q, s, z = Q4_0.unpack(t)
+                scales_and_zeros = pack_scales_and_zeros(s, z)
+                weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
+                    q, inner_k_tiles
+                )
+                state_dict[f"{fqn}.weight"] = weight_int4pack
+                state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros
 
             parent = _fqn_lookup(_fqn_up(fqn), model)
             setattr(
@@ -197,8 +197,10 @@ def load_model_and_state_dict(gguf_file: str, load_as_quantized: bool, *, inner_
                 ),
             )
         else:
-            state_dict[f"{fqn}.weight"] = to_float(t)
+            if load_state_dict:
+                state_dict[f"{fqn}.weight"] = to_float(t)
 
+    assert bool(state_dict) == load_state_dict
     return model, state_dict
 
 
diff --git a/build/model.py b/build/model.py
index c3f0af512..ae6038707 100644
--- a/build/model.py
+++ b/build/model.py
@@ -246,10 +246,11 @@ def from_params(cls, params_path: str):
         return cls(ModelArgs.from_params(params_path))
 
     @classmethod
-    def from_gguf(cls, gguf_path: str):
+    def from_gguf(cls, gguf_path: str, **kwargs):
         from build.gguf_loader import load_model_and_state_dict
-        model, state_dict = load_model_and_state_dict(gguf_path, load_as_quantized=True, inner_k_tiles=8)
-        model.load_state_dict(state_dict, assign=True)
+        model, state_dict = load_model_and_state_dict(gguf_path, **kwargs)
+        if state_dict:
+            model.load_state_dict(state_dict, assign=True)
         return model
 
 
diff --git a/export.py b/export.py
index 45c61dbb4..3723e4b65 100644
--- a/export.py
+++ b/export.py
@@ -9,7 +9,7 @@
 
 import torch
 
-from build.builder import _initialize_model, BuilderArgs
+from build.builder import _initialize_model, BuilderArgs, _set_gguf_kwargs, _unset_gguf_kwargs
 from cli import add_arguments_for_export, arg_init, check_args
 
 from export_aoti import export_model as export_model_aoti
@@ -42,24 +42,48 @@ def main(args):
     print(f"Using device={builder_args.device}")
     set_precision(builder_args.precision)
 
+
     builder_args.dso_path = None
     builder_args.pte_path = None
     builder_args.setup_caches = True
-    model = _initialize_model(
-        builder_args,
-        quantize,
-    )
 
     output_pte_path = args.output_pte_path
     output_dso_path = args.output_dso_path
 
+    # TODO: clean this up
+    # This mess is because ET does not support _weight_int4pack_mm right now
+    if not builder_args.gguf_path:
+        model = _initialize_model(
+            builder_args,
+            quantize,
+        )
+        model_to_pte = model
+        model_to_dso = model
+    else:
+        if output_pte_path:
+            _set_gguf_kwargs(builder_args, is_et=True, context="export")
+            model_to_pte = _initialize_model(
+                builder_args,
+                quantize,
+            )
+            _unset_gguf_kwargs(builder_args)
+
+        if output_dso_path:
+            _set_gguf_kwargs(builder_args, is_et=False, context="export")
+            model_to_dso = _initialize_model(
+                builder_args,
+                quantize,
+            )
+            _unset_gguf_kwargs(builder_args)
+
+
     with torch.no_grad():
         if output_pte_path:
             output_pte_path = str(os.path.abspath(output_pte_path))
             print(f">{output_pte_path}<")
             if executorch_export_available:
                 print(f"Exporting model using Executorch to {output_pte_path}")
-                export_model_et(model, builder_args.device, args.output_pte_path, args)
+                export_model_et(model_to_pte, builder_args.device, args.output_pte_path, args)
             else:
                 print(
                     "Export with executorch requested but Executorch could not be loaded"
@@ -68,7 +92,7 @@ def main(args):
         if output_dso_path:
             output_dso_path = str(os.path.abspath(output_dso_path))
             print(f"Exporting model using AOT Inductor to {output_dso_path}")
-            export_model_aoti(model, builder_args.device, output_dso_path, args)
+            export_model_aoti(model_to_dso, builder_args.device, output_dso_path, args)
 
 
 if __name__ == "__main__":