some flag cleanup and centralization (#860)
dan-garvey authored Jan 23, 2025
1 parent 05a94f9 commit 780bf3f
Showing 4 changed files with 48 additions and 40 deletions.
31 changes: 12 additions & 19 deletions sharktank/sharktank/evaluate/perplexity_torch.py
@@ -59,11 +59,13 @@ def __init__(
self,
device,
kv_cache_type,
activation_dtype=torch.float32,
attention_dtype=torch.float32,
):
self.device = device
self.kv_cache_type = kv_cache_type
self.activation_dtype = torch.float32
self.attention_dtype = torch.float32
self.activation_dtype = activation_dtype
self.attention_dtype = attention_dtype

def timeit(func):
def wrapper(*args, **kwargs):
@@ -321,43 +323,34 @@ def run_perplexity_torch(

def main(argv):
parser = cli.create_parser()
parser.add_argument("--kv-cache-type", default="paged", help="KV cache type")
parser.add_argument("--device", help="Torch device (or default)")
parser.add_argument(
"--attention-kernel",
type=str,
default="decomposed",
choices=["decomposed", "torch_sdpa"],
)

parser.add_argument(
"--tensor-parallelism-size",
type=int,
default=1,
help="Number of devices for tensor parallel sharding.",
)
parser.add_argument(
"--num-prompts",
type=int,
default=100,
help="Number of prompts for perplexity test",
)

cli.add_model_options(parser)
cli.add_input_dataset_options(parser)
cli.add_tokenizer_options(parser)
args = cli.parse(parser, args=argv)

device = torch.device(args.device) if args.device else None
kv_cache_type = args.kv_cache_type
dataset = cli.get_input_dataset(args)
tokenizer = cli.get_tokenizer(args)
# Override flag if dataset disagrees
tensor_parallelism_size = (
dataset.properties["tensor_parallelism_size"]
if "tensor_parallelism_size" in dataset.properties
else args.tensor_parallelism_size
)

ppl = run_perplexity_torch(
dataset=dataset,
tokenizer=tokenizer,
device=device,
kv_cache_type=kv_cache_type,
tensor_parallelism_size=args.tensor_parallelism_size,
tensor_parallelism_size=tensor_parallelism_size,
attention_kernel=args.attention_kernel,
num_prompts=args.num_prompts,
)
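The override added above gives the dataset's recorded sharding precedence over the command-line flag: if the exported dataset carries a tensor_parallelism_size property, that value wins; otherwise the --tensor-parallelism-size argument is used. A minimal sketch of that precedence rule, using a plain dict to stand in for Dataset.properties (the resolve_tensor_parallelism_size helper is hypothetical, not part of sharktank):

import argparse

def resolve_tensor_parallelism_size(properties: dict, cli_value: int) -> int:
    # Hypothetical helper mirroring the inline conditional in the diff:
    # a value recorded in the dataset overrides the CLI flag.
    return properties.get("tensor_parallelism_size", cli_value)

parser = argparse.ArgumentParser()
parser.add_argument("--tensor-parallelism-size", type=int, default=1)
args = parser.parse_args(["--tensor-parallelism-size", "4"])

# Dataset exported with 8-way sharding: the property wins.
print(resolve_tensor_parallelism_size({"tensor_parallelism_size": 8}, args.tensor_parallelism_size))  # 8
# Unsharded dataset: fall back to the flag.
print(resolve_tensor_parallelism_size({}, args.tensor_parallelism_size))  # 4
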
7 changes: 4 additions & 3 deletions sharktank/sharktank/examples/export_paged_llm_v1.py
@@ -72,17 +72,18 @@ def main():
tensor_parallelism_size = (
dataset.properties["tensor_parallelism_size"]
if "tensor_parallelism_size" in dataset.properties
else 1
else args.tensor_parallelism_size
)

llama_config = LlamaModelConfig(
hp,
tensor_parallelism_size=tensor_parallelism_size,
use_hf=False,
use_hf=args.use_hf,
static_tables=False, # Rely on the compiler for hoisting tables.
kv_cache_type="paged",
attention_kernel=args.attention_kernel,
block_seq_stride=args.block_seq_stride,
activation_dtype=args.activation_dtype,
attention_dtype=args.attention_dtype,
)
llama_config.fake_quant = args.fake_quant

19 changes: 2 additions & 17 deletions sharktank/sharktank/examples/paged_llm_v1.py
@@ -234,22 +234,10 @@ def main():

parser = cli.create_parser()
parser.add_argument("prompt", nargs="+", help="Prompt strings")
parser.add_argument("--kv-cache-type", default="paged", help="KV cache type")
parser.add_argument("--device", help="Torch device (or default)")
parser.add_argument(
"--save_intermediates_path",
help="save module forward outputs to safetensors, ex: run_0 will save to run_0_prefill.savetensors",
)
parser.add_argument(
"--activation-dtype",
help="DType to use for activations in the model",
default="float32",
)
parser.add_argument(
"--use-hf",
action="store_true",
default=False,
)
parser.add_argument(
"--tensor-parallelism-size",
type=int,
@@ -262,18 +250,15 @@
cli.add_model_options(parser)
args = cli.parse(parser)
device = torch.device(args.device) if args.device else None
activation_dtype = getattr(torch, args.activation_dtype)
assert isinstance(activation_dtype, torch.dtype)
dataset = cli.get_input_dataset(args)
tokenizer = cli.get_tokenizer(args)
prompts = args.prompt
config = LlamaModelConfig(
hp=configs.LlamaHParams.from_gguf_props(dataset.properties),
block_seq_stride=16,
kv_cache_type=args.kv_cache_type,
device=device,
activation_dtype=activation_dtype,
attention_dtype=activation_dtype,
activation_dtype=args.activation_dtype,
attention_dtype=args.activation_dtype,
attention_kernel=args.attention_kernel,
use_hf=args.use_hf,
tensor_parallelism_size=args.tensor_parallelism_size,
31 changes: 30 additions & 1 deletion sharktank/sharktank/utils/cli.py
@@ -11,7 +11,7 @@
import argparse
import logging
from pathlib import Path

import torch
from ..types import Dataset

from . import hf_datasets
@@ -31,6 +31,12 @@ def create_parser(
def parse(parser: argparse.ArgumentParser, *, args: Sequence[str] | None = None):
"""Parses arguments and does any prescribed global process setup."""
parsed_args = parser.parse_args(args)
# Set torch dtypes
for attr in ["activation_dtype", "attention_dtype"]:
if hasattr(parsed_args, attr):
dtype = getattr(torch, getattr(parsed_args, attr))
assert isinstance(dtype, torch.dtype)
setattr(parsed_args, attr, dtype)
return parsed_args


@@ -79,6 +85,29 @@ def add_model_options(parser: argparse.ArgumentParser):
help="Skips export decode",
action="store_true",
)
parser.add_argument(
"--use-hf",
action="store_true",
default=False,
)
parser.add_argument(
"--activation-dtype",
help="DType to use for activations in the model",
default="float16",
)
parser.add_argument(
"--attention-dtype",
help="DType to use for activations in the model",
default="float16",
)
parser.add_argument("--device", help="Torch device (or default)")

parser.add_argument(
"--tensor-parallelism-size",
type=int,
default=1,
help="Number of devices for tensor parallel sharding. Will be overridden by dataset.properties if present",
)


def add_quantization_options(parser: argparse.ArgumentParser):
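With these additions, a script only needs cli.add_model_options(parser) to pick up --use-hf, --activation-dtype, --attention-dtype, --device, and --tensor-parallelism-size, and cli.parse hands back the dtype flags already converted from strings to torch.dtype objects. A minimal standalone sketch of that conversion step, assuming only argparse and torch rather than the sharktank cli helpers themselves:

import argparse
import torch

parser = argparse.ArgumentParser()
parser.add_argument("--activation-dtype", default="float16")
parser.add_argument("--attention-dtype", default="float16")
args = parser.parse_args(["--activation-dtype", "bfloat16"])

# The same conversion cli.parse now performs centrally: look the string up
# in the torch namespace and check that the result really is a dtype.
for attr in ["activation_dtype", "attention_dtype"]:
    dtype = getattr(torch, getattr(args, attr))
    assert isinstance(dtype, torch.dtype)
    setattr(args, attr, dtype)

print(args.activation_dtype, args.attention_dtype)  # torch.bfloat16 torch.float16
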
