some flag cleanup and centralization #860

Merged (3 commits, Jan 23, 2025)
sharktank/sharktank/evaluate/perplexity_torch.py (12 additions, 19 deletions)
@@ -59,11 +59,13 @@ def __init__(
self,
device,
kv_cache_type,
activation_dtype=torch.float32,
attention_dtype=torch.float32,
):
self.device = device
self.kv_cache_type = kv_cache_type
self.activation_dtype = torch.float32
self.attention_dtype = torch.float32
self.activation_dtype = activation_dtype
self.attention_dtype = attention_dtype

def timeit(func):
def wrapper(*args, **kwargs):
@@ -321,43 +323,34 @@ def run_perplexity_torch(

def main(argv):
parser = cli.create_parser()
parser.add_argument("--kv-cache-type", default="paged", help="KV cache type")
parser.add_argument("--device", help="Torch device (or default)")
parser.add_argument(
"--attention-kernel",
type=str,
default="decomposed",
choices=["decomposed", "torch_sdpa"],
)

parser.add_argument(
"--tensor-parallelism-size",
type=int,
default=1,
help="Number of devices for tensor parallel sharding.",
)
parser.add_argument(
"--num-prompts",
type=int,
default=100,
help="Number of prompts for perplexity test",
)

cli.add_model_options(parser)
cli.add_input_dataset_options(parser)
cli.add_tokenizer_options(parser)
args = cli.parse(parser, args=argv)

device = torch.device(args.device) if args.device else None
kv_cache_type = args.kv_cache_type
dataset = cli.get_input_dataset(args)
tokenizer = cli.get_tokenizer(args)
# Override flag if dataset disagrees
tensor_parallelism_size = (
dataset.properties["tensor_parallelism_size"]
if "tensor_parallelism_size" in dataset.properties
else args.tensor_parallelism_size
)

ppl = run_perplexity_torch(
dataset=dataset,
tokenizer=tokenizer,
device=device,
kv_cache_type=kv_cache_type,
tensor_parallelism_size=args.tensor_parallelism_size,
tensor_parallelism_size=tensor_parallelism_size,
attention_kernel=args.attention_kernel,
num_prompts=args.num_prompts,
)
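For reference, a minimal sketch of the precedence rule this change adopts, assuming `dataset.properties` behaves like a plain dict (the helper name below is hypothetical; the PR inlines the conditional):

```python
# Hypothetical helper mirroring the inline conditional in main():
# a tensor_parallelism_size recorded in the dataset properties wins,
# otherwise the value from --tensor-parallelism-size is used.
def resolve_tensor_parallelism_size(properties: dict, flag_value: int) -> int:
    return properties.get("tensor_parallelism_size", flag_value)


assert resolve_tensor_parallelism_size({"tensor_parallelism_size": 8}, 1) == 8
assert resolve_tensor_parallelism_size({}, 4) == 4
```

export_paged_llm_v1.py switches to the same rule in the next file: a dataset without a recorded value no longer forces the size back to 1, the flag is honored instead.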
sharktank/sharktank/examples/export_paged_llm_v1.py (4 additions, 3 deletions)
@@ -72,17 +72,18 @@ def main():
tensor_parallelism_size = (
dataset.properties["tensor_parallelism_size"]
if "tensor_parallelism_size" in dataset.properties
else 1
else args.tensor_parallelism_size
)

llama_config = LlamaModelConfig(
hp,
tensor_parallelism_size=tensor_parallelism_size,
use_hf=False,
use_hf=args.use_hf,
static_tables=False, # Rely on the compiler for hoisting tables.
kv_cache_type="paged",
attention_kernel=args.attention_kernel,
block_seq_stride=args.block_seq_stride,
activation_dtype=args.activation_dtype,
attention_dtype=args.attention_dtype,
)
llama_config.fake_quant = args.fake_quant

sharktank/sharktank/examples/paged_llm_v1.py (2 additions, 17 deletions)
@@ -234,22 +234,10 @@ def main():

parser = cli.create_parser()
parser.add_argument("prompt", nargs="+", help="Prompt strings")
parser.add_argument("--kv-cache-type", default="paged", help="KV cache type")
parser.add_argument("--device", help="Torch device (or default)")
parser.add_argument(
"--save_intermediates_path",
help="save module forward outputs to safetensors, ex: run_0 will save to run_0_prefill.savetensors",
)
parser.add_argument(
"--activation-dtype",
help="DType to use for activations in the model",
default="float32",
)
parser.add_argument(
"--use-hf",
action="store_true",
default=False,
)
parser.add_argument(
"--tensor-parallelism-size",
type=int,
@@ -262,18 +250,15 @@
cli.add_model_options(parser)
args = cli.parse(parser)
device = torch.device(args.device) if args.device else None
activation_dtype = getattr(torch, args.activation_dtype)
assert isinstance(activation_dtype, torch.dtype)
dataset = cli.get_input_dataset(args)
tokenizer = cli.get_tokenizer(args)
prompts = args.prompt
config = LlamaModelConfig(
hp=configs.LlamaHParams.from_gguf_props(dataset.properties),
block_seq_stride=16,
kv_cache_type=args.kv_cache_type,
device=device,
activation_dtype=activation_dtype,
attention_dtype=activation_dtype,
activation_dtype=args.activation_dtype,
attention_dtype=args.activation_dtype,
attention_kernel=args.attention_kernel,
use_hf=args.use_hf,
tensor_parallelism_size=args.tensor_parallelism_size,
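The local `getattr(torch, args.activation_dtype)` conversion can be dropped here because `cli.parse` now performs the string-to-`torch.dtype` conversion centrally (see the cli.py hunk below). A rough stand-alone sketch of the mechanism, using a bare `argparse` parser as a stand-in for `cli.create_parser`:

```python
import argparse

import torch

# Stand-in parser; in sharktank these flags are registered by cli.add_model_options().
parser = argparse.ArgumentParser()
parser.add_argument("--activation-dtype", default="float16")
parser.add_argument("--attention-dtype", default="float16")
args = parser.parse_args(["--activation-dtype", "float32"])

# The same conversion cli.parse() now applies: look the flag value up as an
# attribute on the torch module and replace the string with the torch.dtype.
for attr in ("activation_dtype", "attention_dtype"):
    dtype = getattr(torch, getattr(args, attr))
    assert isinstance(dtype, torch.dtype)
    setattr(args, attr, dtype)

print(args.activation_dtype)  # torch.float32
print(args.attention_dtype)   # torch.float16
```

Note that the centralized default is `float16`, whereas the flag this file previously defined defaulted to `float32`.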
sharktank/sharktank/utils/cli.py (30 additions, 1 deletion)
@@ -11,7 +11,7 @@
import argparse
import logging
from pathlib import Path

import torch
from ..types import Dataset

from . import hf_datasets
@@ -31,6 +31,12 @@ def create_parser(
def parse(parser: argparse.ArgumentParser, *, args: Sequence[str] | None = None):
"""Parses arguments and does any prescribed global process setup."""
parsed_args = parser.parse_args(args)
# Set torch dtypes
for attr in ["activation_dtype", "attention_dtype"]:
if hasattr(parsed_args, attr):
dtype = getattr(torch, getattr(parsed_args, attr))
assert isinstance(dtype, torch.dtype)
setattr(parsed_args, attr, dtype)
return parsed_args


@@ -79,6 +85,29 @@ def add_model_options(parser: argparse.ArgumentParser):
help="Skips export decode",
action="store_true",
)
parser.add_argument(
"--use-hf",
action="store_true",
default=False,
)
parser.add_argument(
"--activation-dtype",
help="DType to use for activations in the model",
default="float16",
)
parser.add_argument(
"--attention-dtype",
help="DType to use for activations in the model",
default="float16",
)
parser.add_argument("--device", help="Torch device (or default)")

parser.add_argument(
"--tensor-parallelism-size",
type=int,
default=1,
help="Number of devices for tensor parallel sharding. Will be overridden by dataset.properties if present",
)


def add_quantization_options(parser: argparse.ArgumentParser):
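Taken together, any script that calls `cli.add_model_options(parser)` now picks up the shared `--device`, `--use-hf`, `--activation-dtype`, `--attention-dtype`, and `--tensor-parallelism-size` flags, and `cli.parse` hands back real `torch.dtype` objects. A small usage sketch, under the assumption that the parser has no other required arguments:

```python
import torch

from sharktank.utils import cli

parser = cli.create_parser()
cli.add_model_options(parser)
args = cli.parse(parser, args=["--activation-dtype", "bfloat16"])

assert args.activation_dtype == torch.bfloat16  # converted by cli.parse()
assert args.attention_dtype == torch.float16    # new centralized default
assert args.tensor_parallelism_size == 1        # flag default; dataset properties may override downstream
```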