Fix gguf ci (#199)
* load quantized gguf

* add comments

* remove AOTI

* remove ubuntu
metascroy authored and malfet committed Jul 17, 2024
1 parent da28543 commit f17d2dd
Showing 5 changed files with 85 additions and 80 deletions.
42 changes: 24 additions & 18 deletions .github/workflows/compile-gguf.yml
@@ -11,7 +11,7 @@ jobs:
run-tinystories:
strategy:
matrix:
runner: [ubuntu-latest, macos-14]
runner: [macos-14]
runs-on: ${{matrix.runner}}
steps:
- name: Checkout repo
@@ -40,41 +40,47 @@ jobs:
wget -O ${GGUF_PATH} "https://huggingface.co/TheBloke/Llama-2-7B-GGUF/resolve/main/llama-2-7b.Q4_0.gguf?download=true"
wget -O ${TOKENIZER_PATH} https://github.com/karpathy/llama2.c/raw/master/tokenizer.model
- name: Run inference
run: |
run: |
export GGUF_PATH=gguf_files/llama-2-7b.Q4_0.gguf
export TOKENIZER_PATH=gguf_files/tokenizer.model
export MODEL_NAME=llama-2-7b.Q4_0.gguf
export MODEL_DIR=/tmp
python generate.py --gguf-path ${GGUF_PATH} --tokenizer-path ${TOKENIZER_PATH} --temperature 0 > ./output_eager
echo "******************************************"
echo "******* Embed: not quantized *************"
echo "******************************************"
echo "Running eager"
python generate.py --gguf-path ${GGUF_PATH} --tokenizer-path ${TOKENIZER_PATH} --max-new-tokens 20 --temperature 0 > ./output_eager
cat ./output_eager
python generate.py --compile --gguf-path ${GGUF_PATH} --tokenizer-path ${TOKENIZER_PATH} --temperature 0 > ./output_compiled
echo "Running compiled"
python generate.py --compile --gguf-path ${GGUF_PATH} --tokenizer-path ${TOKENIZER_PATH} --max-new-tokens 20 --temperature 0 > ./output_compiled
cat ./output_compiled
python export.py --gguf-path ${GGUF_PATH} --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so
python generate.py --gguf-path ${GGUF_PATH} --tokenizer-path ${TOKENIZER_PATH} --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.so > ./output_aoti
cat ./output_aoti
echo "******************************************"
echo "******* Emb: channel-wise quantized ******"
echo "******************************************"
python generate.py --quant '{"embedding" : {"bitwidth": 8, "groupsize": 0}}' --gguf-path ${GGUF_PATH} --tokenizer-path ${TOKENIZER_PATH} --temperature 0 > ./output_eager
echo "Running eager"
python generate.py --quant '{"embedding" : {"bitwidth": 8, "groupsize": 0}}' --gguf-path ${GGUF_PATH} --tokenizer-path ${TOKENIZER_PATH} --max-new-tokens 20 --temperature 0 > ./output_eager
cat ./output_eager
python generate.py --compile --quant '{"embedding" : {"bitwidth": 8, "groupsize": 0}}' --gguf-path ${GGUF_PATH} --tokenizer-path ${TOKENIZER_PATH} --temperature 0 > ./output_compiled
echo "Running compiled"
python generate.py --compile --quant '{"embedding" : {"bitwidth": 8, "groupsize": 0}}' --gguf-path ${GGUF_PATH} --tokenizer-path ${TOKENIZER_PATH} --max-new-tokens 20 --temperature 0 > ./output_compiled
cat ./output_compiled
python export.py --quant '{"embedding" : {"bitwidth": 8, "groupsize": 0}}' --gguf-path ${GGUF_PATH} --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so
python generate.py --gguf-path ${GGUF_PATH} --tokenizer-path ${TOKENIZER_PATH} --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.so > ./output_aoti
cat ./output_aoti
echo "******************************************"
echo "******** Emb: group-wise quantized *******"
echo "******************************************"
python generate.py --quant '{"embedding" : {"bitwidth": 8, "groupsize": 8}}' --gguf-path ${GGUF_PATH} --tokenizer-path ${TOKENIZER_PATH} --temperature 0 > ./output_eager
echo "Running eager"
python generate.py --quant '{"embedding" : {"bitwidth": 8, "groupsize": 8}}' --gguf-path ${GGUF_PATH} --tokenizer-path ${TOKENIZER_PATH} --max-new-tokens 20 --temperature 0 > ./output_eager
cat ./output_eager
python generate.py --compile --quant '{"embedding" : {"bitwidth": 8, "groupsize": 8}}' --gguf-path ${GGUF_PATH} --tokenizer-path ${TOKENIZER_PATH} --temperature 0 > ./output_compiled
echo "Running compiled"
python generate.py --compile --quant '{"embedding" : {"bitwidth": 8, "groupsize": 8}}' --gguf-path ${GGUF_PATH} --tokenizer-path ${TOKENIZER_PATH} --max-new-tokens 20 --temperature 0 > ./output_compiled
cat ./output_compiled
python export.py --quant '{"embedding" : {"bitwidth": 8, "groupsize": 8}}' --gguf-path ${GGUF_PATH} --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so
python generate.py --gguf-path ${GGUF_PATH} --tokenizer-path ${TOKENIZER_PATH} --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.so > ./output_aoti
cat ./output_aoti
echo "tests complete"
echo "******************************************"
39 changes: 25 additions & 14 deletions build/builder.py
@@ -47,7 +47,7 @@ def __post_init__(self):
(self.pte_path and Path(self.pte_path).is_file())
):
raise RuntimeError("need to specified a valid checkpoint path, checkpoint dir, gguf path, DSO path, or PTE path")

if (self.dso_path and self.pte_path):
raise RuntimeError("specify either DSO path or PTE path, but not both")

@@ -58,7 +58,7 @@ def __post_init__(self):
if (self.gguf_path and (self.dso_path or self.pte_path)):
print("Warning: GGUF path ignored because an exported DSO or PTE path specified")


@classmethod
def from_args(cls, args): # -> BuilderArgs:
return cls(
@@ -79,14 +79,14 @@ def from_args(cls, args): # -> BuilderArgs:
def from_speculative_args(cls, args): # -> BuilderArgs:
speculative_builder_args = BuilderArgs.from_args(args)
# let's limit multi-checkpoint to checker
speculative_builder_args.checkpoint_dir = None
speculative_builder_args.checkpoint_dir = None
speculative_builder_args.checkpoint_path = args.draft_checkpoint_path
speculative_builder_args.gguf_path = None
speculative_builder_args.dso_path = None
speculative_builder_args.pte_path = None
return speculative_builder_args


@dataclass
class TokenizerArgs:
tokenizer_path: Optional[Union[Path, str]] = None
@@ -97,7 +97,7 @@ class TokenizerArgs:
def from_args(cls, args): # -> TokenizerArgs:
is_SentencePiece = True
is_TikToken = False

if args.tokenizer_path:
tokenizer_path = args.tokenizer_path
elif args.checkpoint_path:
@@ -106,7 +106,7 @@ def from_args(cls, args): # -> TokenizerArgs:
tokenizer_path = args.checkpoint_dir / "tokenizer.model"
else:
raise RuntimeError(f"cannot find tokenizer model")

if not tokenizer_path.is_file():
raise RuntimeError(f"did not find tokenizer at {tokenizer_path}")

@@ -127,7 +127,7 @@ def _initialize_tokenizer(tokenizer_args: TokenizerArgs):
raise RuntimeError("TikToken not implemented yet!")
else:
raise RuntimeError("must specify a valid tokenizer in TokenizerArgs")


def device_sync(device):
if "cuda" in device:
@@ -147,17 +147,30 @@ def device_sync(device):
wd = Path(__file__).parent.parent.resolve()
sys.path.append(str(wd))

def _load_model(
def _load_model(builder_args):
if builder_args.gguf_path:
model = Transformer.from_gguf(builder_args.gguf_path)

# TODO: to take advantage of mmap, maybe we write converted gguf to file
# and read back in?
# TODO: should we add check that builder_args.precision is aligned with quant scheme, e.g., bfloat16
# is needed for int4
model = model.to(device=builder_args.device, dtype=builder_args.precision)
return model.eval()
else:
return _load_model_not_gguf(builder_args)

def _load_model_not_gguf(
builder_args
):
assert not builder_args.gguf_path

use_cuda = "cuda" in builder_args.device
with torch.device("meta"):
if builder_args.params_path:
model = Transformer.from_params(builder_args.params_path)
elif builder_args.params_table:
model = Transformer.from_table(builder_args.params_path)
elif builder_args.gguf_path:
model = Transformer.from_gguf(builder_args.gguf_path)
else:
model = Transformer.from_name(builder_args.checkpoint_path.parent.name)

@@ -176,7 +189,7 @@ def _load_model(
mmap=True,
)
)

checkpoint = {}
for key in cps[0].keys():
if not torch.allclose(cps[0][key], cps[1][key]):
@@ -210,7 +223,7 @@ def _initialize_model(
quantize,
):
print("Loading model ...")
t0 = time.time()
t0 = time.time()
model_ = _load_model(
builder_args
)
@@ -261,5 +274,3 @@ def _initialize_model(
model.to(dtype=builder_args.precision)

return model
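
The TODO in the new _load_model above suggests checking that builder_args.precision is compatible with the GGUF quantization scheme (e.g., bfloat16 for int4). Purely as an illustration of what such a guard could look like — this helper is hypothetical and not part of the commit:

    import torch

    def _check_precision_for_gguf(precision, gguf_quant: str = "Q4_0"):
        # Hypothetical guard sketched from the TODO: 4-bit GGUF weights are
        # dequantized against bfloat16/float32 activations, so other dtypes
        # are rejected up front rather than failing inside the kernels.
        if gguf_quant.startswith("Q4") and precision not in (torch.bfloat16, torch.float32):
            raise RuntimeError(
                f"GGUF {gguf_quant} weights expect bfloat16 or float32, got {precision}"
            )

If adopted, a check like this would run just before the model.to(device=..., dtype=...) call in _load_model.
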


4 changes: 1 addition & 3 deletions build/gguf_loader.py
@@ -13,9 +13,7 @@
from pathlib import Path
from typing import Any, Mapping, Dict
import logging
from quantize import (
WeightOnlyInt4Linear, pack_scales_and_zeros, group_dequantize_tensor_from_qparams
)
from quantize import WeightOnlyInt4Linear, pack_scales_and_zeros, group_dequantize_tensor_from_qparams
from build.gguf_util import F16, F32, Q4_0, Q6_K
import gguf

18 changes: 9 additions & 9 deletions generate.py
@@ -47,7 +47,7 @@ def from_args(cls, args): # -> GeneratorArgs:
speculate_k = args.speculate_k,
)


def device_sync(device):
if "cuda" in device:
torch.cuda.synchronize(device)
@@ -305,7 +305,7 @@ def encode_tokens(tokenizer, string, bos=True, device="cuda"):
def _main(
builder_args: BuilderArgs,
speculative_builder_args: BuilderArgs,
tokenizer_args: TokenizerArgs,
tokenizer_args: TokenizerArgs,
prompt: str = "Hello, my name is",
chat_mode: bool = False,
num_samples: int = 5,
@@ -332,25 +332,25 @@ def _main(
print(f"Using device={builder_args.device}")
set_precision(builder_args.precision)
is_speculative = speculative_builder_args.checkpoint_path is not None

is_chat = "chat" in str(builder_args.checkpoint_path)
if is_chat:
raise RuntimeError("need to stop filename based kludgery, at a minimum need to look at all pathnames. yuck!")

tokenizer = _initialize_tokenizer(tokenizer_args)

builder_args.setup_caches = False
model = _initialize_model(
builder_args,
quantize
)

# will add a version of _initialize_model in future
# (need additional args)
if is_speculative:
from builder import _load_model
speculative_builder_args = builder_args

draft_model = _load_model(
speculative_builder_args,
)
@@ -478,13 +478,13 @@ def callback(x):
)
print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")


def main(args):
builder_args = BuilderArgs.from_args(args)
speculative_builder_args = BuilderArgs.from_speculative_args(args)
tokenizer_args = TokenizerArgs.from_args(args)
generator_args = GeneratorArgs.from_args(args)

_main(
builder_args,
speculative_builder_args,