Fix gguf ci #199

Merged · 4 commits · Apr 15, 2024
42 changes: 24 additions & 18 deletions .github/workflows/compile-gguf.yml
@@ -11,7 +11,7 @@ jobs:
   run-tinystories:
     strategy:
       matrix:
-        runner: [ubuntu-latest, macos-14]
+        runner: [macos-14]
     runs-on: ${{matrix.runner}}
     steps:
       - name: Checkout repo
@@ -40,41 +40,47 @@ jobs:
           wget -O ${GGUF_PATH} "https://huggingface.co/TheBloke/Llama-2-7B-GGUF/resolve/main/llama-2-7b.Q4_0.gguf?download=true"
           wget -O ${TOKENIZER_PATH} https://github.com/karpathy/llama2.c/raw/master/tokenizer.model
       - name: Run inference
-        run: |
+        run: |
           export GGUF_PATH=gguf_files/llama-2-7b.Q4_0.gguf
           export TOKENIZER_PATH=gguf_files/tokenizer.model
           export MODEL_NAME=llama-2-7b.Q4_0.gguf
           export MODEL_DIR=/tmp
-          python generate.py --gguf-path ${GGUF_PATH} --tokenizer-path ${TOKENIZER_PATH} --temperature 0 > ./output_eager
+
+          echo "******************************************"
+          echo "******* Embed: not quantized *************"
+          echo "******************************************"
+
+          echo "Running eager"
+          python generate.py --gguf-path ${GGUF_PATH} --tokenizer-path ${TOKENIZER_PATH} --max-new-tokens 20 --temperature 0 > ./output_eager
           cat ./output_eager
-          python generate.py --compile --gguf-path ${GGUF_PATH} --tokenizer-path ${TOKENIZER_PATH} --temperature 0 > ./output_compiled
+
+          echo "Running compiled"
+          python generate.py --compile --gguf-path ${GGUF_PATH} --tokenizer-path ${TOKENIZER_PATH} --max-new-tokens 20 --temperature 0 > ./output_compiled
           cat ./output_compiled
           python export.py --gguf-path ${GGUF_PATH} --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so
           python generate.py --gguf-path ${GGUF_PATH} --tokenizer-path ${TOKENIZER_PATH} --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.so > ./output_aoti
           cat ./output_aoti

           echo "******************************************"
           echo "******* Emb: channel-wise quantized ******"
           echo "******************************************"
-          python generate.py --quant '{"embedding" : {"bitwidth": 8, "groupsize": 0}}' --gguf-path ${GGUF_PATH} --tokenizer-path ${TOKENIZER_PATH} --temperature 0 > ./output_eager
+
+          echo "Running eager"
+          python generate.py --quant '{"embedding" : {"bitwidth": 8, "groupsize": 0}}' --gguf-path ${GGUF_PATH} --tokenizer-path ${TOKENIZER_PATH} --max-new-tokens 20 --temperature 0 > ./output_eager
           cat ./output_eager
-          python generate.py --compile --quant '{"embedding" : {"bitwidth": 8, "groupsize": 0}}' --gguf-path ${GGUF_PATH} --tokenizer-path ${TOKENIZER_PATH} --temperature 0 > ./output_compiled
+
+          echo "Running compiled"
+          python generate.py --compile --quant '{"embedding" : {"bitwidth": 8, "groupsize": 0}}' --gguf-path ${GGUF_PATH} --tokenizer-path ${TOKENIZER_PATH} --max-new-tokens 20 --temperature 0 > ./output_compiled
           cat ./output_compiled
           python export.py --quant '{"embedding" : {"bitwidth": 8, "groupsize": 0}}' --gguf-path ${GGUF_PATH} --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so
           python generate.py --gguf-path ${GGUF_PATH} --tokenizer-path ${TOKENIZER_PATH} --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.so > ./output_aoti
           cat ./output_aoti

           echo "******************************************"
           echo "******** Emb: group-wise quantized *******"
           echo "******************************************"
-          python generate.py --quant '{"embedding" : {"bitwidth": 8, "groupsize": 8}}' --gguf-path ${GGUF_PATH} --tokenizer-path ${TOKENIZER_PATH} --temperature 0 > ./output_eager
+
+          echo "Running eager"
+          python generate.py --quant '{"embedding" : {"bitwidth": 8, "groupsize": 8}}' --gguf-path ${GGUF_PATH} --tokenizer-path ${TOKENIZER_PATH} --max-new-tokens 20 --temperature 0 > ./output_eager
           cat ./output_eager
-          python generate.py --compile --quant '{"embedding" : {"bitwidth": 8, "groupsize": 8}}' --gguf-path ${GGUF_PATH} --tokenizer-path ${TOKENIZER_PATH} --temperature 0 > ./output_compiled
+
+          echo "Running compiled"
+          python generate.py --compile --quant '{"embedding" : {"bitwidth": 8, "groupsize": 8}}' --gguf-path ${GGUF_PATH} --tokenizer-path ${TOKENIZER_PATH} --max-new-tokens 20 --temperature 0 > ./output_compiled
           cat ./output_compiled
           python export.py --quant '{"embedding" : {"bitwidth": 8, "groupsize": 8}}' --gguf-path ${GGUF_PATH} --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so
           python generate.py --gguf-path ${GGUF_PATH} --tokenizer-path ${TOKENIZER_PATH} --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.so > ./output_aoti
           cat ./output_aoti

           echo "tests complete"
           echo "******************************************"

39 changes: 25 additions & 14 deletions build/builder.py
@@ -47,7 +47,7 @@ def __post_init__(self):
(self.pte_path and Path(self.pte_path).is_file())
):
raise RuntimeError("need to specified a valid checkpoint path, checkpoint dir, gguf path, DSO path, or PTE path")

if (self.dso_path and self.pte_path):
raise RuntimeError("specify either DSO path or PTE path, but not both")

@@ -58,7 +58,7 @@ def __post_init__(self):
if (self.gguf_path and (self.dso_path or self.pte_path)):
print("Warning: GGUF path ignored because an exported DSO or PTE path specified")


@classmethod
def from_args(cls, args): # -> BuilderArgs:
return cls(
@@ -79,14 +79,14 @@ def from_args(cls, args): # -> BuilderArgs:
def from_speculative_args(cls, args): # -> BuilderArgs:
speculative_builder_args = BuilderArgs.from_args(args)
# let's limit multi-checkpoint to checker
speculative_builder_args.checkpoint_dir = None
speculative_builder_args.checkpoint_dir = None
speculative_builder_args.checkpoint_path = args.draft_checkpoint_path
speculative_builder_args.gguf_path = None
speculative_builder_args.dso_path = None
speculative_builder_args.pte_path = None
return speculative_builder_args


@dataclass
class TokenizerArgs:
tokenizer_path: Optional[Union[Path, str]] = None
@@ -97,7 +97,7 @@ class TokenizerArgs:
def from_args(cls, args): # -> TokenizerArgs:
is_SentencePiece = True
is_TikToken = False

if args.tokenizer_path:
tokenizer_path = args.tokenizer_path
elif args.checkpoint_path:
@@ -106,7 +106,7 @@ def from_args(cls, args): # -> TokenizerArgs:
tokenizer_path = args.checkpoint_dir / "tokenizer.model"
else:
raise RuntimeError(f"cannot find tokenizer model")

if not tokenizer_path.is_file():
raise RuntimeError(f"did not find tokenizer at {tokenizer_path}")

@@ -127,7 +127,7 @@ def _initialize_tokenizer(tokenizer_args: TokenizerArgs):
raise RuntimeError("TikToken not implemented yet!")
else:
raise RuntimeError("must specify a valid tokenizer in TokenizerArgs")


def device_sync(device):
if "cuda" in device:
@@ -147,17 +147,30 @@ def device_sync(device):
 wd = Path(__file__).parent.parent.resolve()
 sys.path.append(str(wd))

-def _load_model(
+def _load_model(builder_args):
+    if builder_args.gguf_path:
+        model = Transformer.from_gguf(builder_args.gguf_path)
+
+        # TODO: to take advantage of mmap, maybe we write converted gguf to file
+        # and read back in?
+        # TODO: should we add check that builder_args.precision is aligned with quant scheme, e.g., bfloat16
+        # is needed for int4
+        model = model.to(device=builder_args.device, dtype=builder_args.precision)
+        return model.eval()
+    else:
+        return _load_model_not_gguf(builder_args)
+
+def _load_model_not_gguf(
     builder_args
 ):
+    assert not builder_args.gguf_path
+
     use_cuda = "cuda" in builder_args.device
     with torch.device("meta"):
         if builder_args.params_path:
             model = Transformer.from_params(builder_args.params_path)
         elif builder_args.params_table:
             model = Transformer.from_table(builder_args.params_path)
-        elif builder_args.gguf_path:
-            model = Transformer.from_gguf(builder_args.gguf_path)
         else:
             model = Transformer.from_name(builder_args.checkpoint_path.parent.name)

@@ -176,7 +189,7 @@ def _load_model(
mmap=True,
)
)

checkpoint = {}
for key in cps[0].keys():
if not torch.allclose(cps[0][key], cps[1][key]):
@@ -210,7 +223,7 @@ def _initialize_model(
quantize,
):
print("Loading model ...")
t0 = time.time()
t0 = time.time()
model_ = _load_model(
builder_args
)
@@ -261,5 +274,3 @@ def _initialize_model(
model.to(dtype=builder_args.precision)

return model
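
One of the TODOs added in _load_model above asks whether builder_args.precision should be validated against the GGUF quantization scheme (the comment notes bfloat16 is needed for int4). The PR does not implement this; a hypothetical sketch of such a guard, with an illustrative mapping only:

import torch

# Hypothetical guard (not part of this PR): reject precisions that the
# GGUF quantization scheme is assumed not to support.
REQUIRED_PRECISION = {"Q4_0": torch.bfloat16}

def check_precision(quant_scheme: str, precision: torch.dtype) -> None:
    required = REQUIRED_PRECISION.get(quant_scheme)
    if required is not None and precision != required:
        raise ValueError(
            f"GGUF scheme {quant_scheme} expects {required}, got {precision}"
        )

check_precision("Q4_0", torch.bfloat16)  # passes; torch.float32 would raise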


4 changes: 1 addition & 3 deletions build/gguf_loader.py
@@ -13,9 +13,7 @@
 from pathlib import Path
 from typing import Any, Mapping, Dict
 import logging
-from quantize import (
-    WeightOnlyInt4Linear, pack_scales_and_zeros, group_dequantize_tensor_from_qparams
-)
+from quantize import WeightOnlyInt4Linear, pack_scales_and_zeros, group_dequantize_tensor_from_qparams
 from build.gguf_util import F16, F32, Q4_0, Q6_K
 import gguf

18 changes: 9 additions & 9 deletions generate.py
@@ -47,7 +47,7 @@ def from_args(cls, args): # -> GeneratorArgs:
speculate_k = args.speculate_k,
)


def device_sync(device):
if "cuda" in device:
torch.cuda.synchronize(device)
@@ -305,7 +305,7 @@ def encode_tokens(tokenizer, string, bos=True, device="cuda"):
def _main(
builder_args: BuilderArgs,
speculative_builder_args: BuilderArgs,
tokenizer_args: TokenizerArgs,
tokenizer_args: TokenizerArgs,
prompt: str = "Hello, my name is",
chat_mode: bool = False,
num_samples: int = 5,
@@ -332,25 +332,25 @@ def _main(
print(f"Using device={builder_args.device}")
set_precision(builder_args.precision)
is_speculative = speculative_builder_args.checkpoint_path is not None

is_chat = "chat" in str(builder_args.checkpoint_path)
if is_chat:
raise RuntimeError("need to stop filename based kludgery, at a minimum need to look at all pathnames. yuck!")

tokenizer = _initialize_tokenizer(tokenizer_args)

builder_args.setup_caches = False
model = _initialize_model(
builder_args,
quantize
)

# will add a version of _initialize_model in future
# (need additional args)
if is_speculative:
from builder import _load_model
speculative_builder_args = builder_args

draft_model = _load_model(
speculative_builder_args,
)
@@ -478,13 +478,13 @@ def callback(x):
)
print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")


def main(args):
builder_args = BuilderArgs.from_args(args)
speculative_builder_args = BuilderArgs.from_speculative_args(args)
tokenizer_args = TokenizerArgs.from_args(args)
generator_args = GeneratorArgs.from_args(args)

_main(
builder_args,
speculative_builder_args,