diff --git a/sharktank/sharktank/evaluate/perplexity_torch.py b/sharktank/sharktank/evaluate/perplexity_torch.py
index 258e8c9a0..9ece1e4a0 100644
--- a/sharktank/sharktank/evaluate/perplexity_torch.py
+++ b/sharktank/sharktank/evaluate/perplexity_torch.py
@@ -59,11 +59,13 @@ def __init__(
         self,
         device,
         kv_cache_type,
+        activation_dtype=torch.float32,
+        attention_dtype=torch.float32,
     ):
         self.device = device
         self.kv_cache_type = kv_cache_type
-        self.activation_dtype = torch.float32
-        self.attention_dtype = torch.float32
+        self.activation_dtype = activation_dtype
+        self.attention_dtype = attention_dtype

     def timeit(func):
         def wrapper(*args, **kwargs):
@@ -321,21 +323,7 @@ def run_perplexity_torch(

 def main(argv):
     parser = cli.create_parser()
-    parser.add_argument("--kv-cache-type", default="paged", help="KV cache type")
-    parser.add_argument("--device", help="Torch device (or default)")
-    parser.add_argument(
-        "--attention-kernel",
-        type=str,
-        default="decomposed",
-        choices=["decomposed", "torch_sdpa"],
-    )
-    parser.add_argument(
-        "--tensor-parallelism-size",
-        type=int,
-        default=1,
-        help="Number of devices for tensor parallel sharding.",
-    )
     parser.add_argument(
         "--num-prompts",
         type=int,
@@ -343,21 +331,26 @@ def main(argv):
         help="Number of prompts for perplexity test",
     )

+    cli.add_model_options(parser)
     cli.add_input_dataset_options(parser)
     cli.add_tokenizer_options(parser)
     args = cli.parse(parser, args=argv)

     device = torch.device(args.device) if args.device else None
-    kv_cache_type = args.kv_cache_type
     dataset = cli.get_input_dataset(args)
     tokenizer = cli.get_tokenizer(args)

+    # Override flag if dataset disagrees
+    tensor_parallelism_size = (
+        dataset.properties["tensor_parallelism_size"]
+        if "tensor_parallelism_size" in dataset.properties
+        else args.tensor_parallelism_size
+    )
     ppl = run_perplexity_torch(
         dataset=dataset,
         tokenizer=tokenizer,
         device=device,
-        kv_cache_type=kv_cache_type,
-        tensor_parallelism_size=args.tensor_parallelism_size,
+        tensor_parallelism_size=tensor_parallelism_size,
         attention_kernel=args.attention_kernel,
         num_prompts=args.num_prompts,
     )
diff --git a/sharktank/sharktank/examples/export_paged_llm_v1.py b/sharktank/sharktank/examples/export_paged_llm_v1.py
index 686533ca2..24ec55cf5 100644
--- a/sharktank/sharktank/examples/export_paged_llm_v1.py
+++ b/sharktank/sharktank/examples/export_paged_llm_v1.py
@@ -72,17 +72,18 @@ def main():
     tensor_parallelism_size = (
         dataset.properties["tensor_parallelism_size"]
         if "tensor_parallelism_size" in dataset.properties
-        else 1
+        else args.tensor_parallelism_size
     )
     llama_config = LlamaModelConfig(
         hp,
         tensor_parallelism_size=tensor_parallelism_size,
-        use_hf=False,
+        use_hf=args.use_hf,
         static_tables=False,  # Rely on the compiler for hoisting tables.
- kv_cache_type="paged", attention_kernel=args.attention_kernel, block_seq_stride=args.block_seq_stride, + activation_dtype=args.activation_dtype, + attention_dtype=args.attention_dtype, ) llama_config.fake_quant = args.fake_quant diff --git a/sharktank/sharktank/examples/paged_llm_v1.py b/sharktank/sharktank/examples/paged_llm_v1.py index b30acc026..768575441 100644 --- a/sharktank/sharktank/examples/paged_llm_v1.py +++ b/sharktank/sharktank/examples/paged_llm_v1.py @@ -234,22 +234,10 @@ def main(): parser = cli.create_parser() parser.add_argument("prompt", nargs="+", help="Prompt strings") - parser.add_argument("--kv-cache-type", default="paged", help="KV cache type") - parser.add_argument("--device", help="Torch device (or default)") parser.add_argument( "--save_intermediates_path", help="save module forward outputs to safetensors, ex: run_0 will save to run_0_prefill.savetensors", ) - parser.add_argument( - "--activation-dtype", - help="DType to use for activations in the model", - default="float32", - ) - parser.add_argument( - "--use-hf", - action="store_true", - default=False, - ) parser.add_argument( "--tensor-parallelism-size", type=int, @@ -262,18 +250,15 @@ def main(): cli.add_model_options(parser) args = cli.parse(parser) device = torch.device(args.device) if args.device else None - activation_dtype = getattr(torch, args.activation_dtype) - assert isinstance(activation_dtype, torch.dtype) dataset = cli.get_input_dataset(args) tokenizer = cli.get_tokenizer(args) prompts = args.prompt config = LlamaModelConfig( hp=configs.LlamaHParams.from_gguf_props(dataset.properties), block_seq_stride=16, - kv_cache_type=args.kv_cache_type, device=device, - activation_dtype=activation_dtype, - attention_dtype=activation_dtype, + activation_dtype=args.activation_dtype, + attention_dtype=args.activation_dtype, attention_kernel=args.attention_kernel, use_hf=args.use_hf, tensor_parallelism_size=args.tensor_parallelism_size, diff --git a/sharktank/sharktank/utils/cli.py b/sharktank/sharktank/utils/cli.py index cdcdd8c2c..3dc946817 100644 --- a/sharktank/sharktank/utils/cli.py +++ b/sharktank/sharktank/utils/cli.py @@ -11,7 +11,7 @@ import argparse import logging from pathlib import Path - +import torch from ..types import Dataset from . import hf_datasets @@ -31,6 +31,12 @@ def create_parser( def parse(parser: argparse.ArgumentParser, *, args: Sequence[str] | None = None): """Parses arguments and does any prescribed global process setup.""" parsed_args = parser.parse_args(args) + # Set torch dtypes + for attr in ["activation_dtype", "attention_dtype"]: + if hasattr(parsed_args, attr): + dtype = getattr(torch, getattr(parsed_args, attr)) + assert isinstance(dtype, torch.dtype) + setattr(parsed_args, attr, dtype) return parsed_args @@ -79,6 +85,29 @@ def add_model_options(parser: argparse.ArgumentParser): help="Skips export decode", action="store_true", ) + parser.add_argument( + "--use-hf", + action="store_true", + default=False, + ) + parser.add_argument( + "--activation-dtype", + help="DType to use for activations in the model", + default="float16", + ) + parser.add_argument( + "--attention-dtype", + help="DType to use for activations in the model", + default="float16", + ) + parser.add_argument("--device", help="Torch device (or default)") + + parser.add_argument( + "--tensor-parallelism-size", + type=int, + default=1, + help="Number of devices for tensor parallel sharding. Will be overridden by dataset.properties if present", + ) def add_quantization_options(parser: argparse.ArgumentParser):