From c4a592ac8bcb2202a554ab1a4d311fdf5ddf28eb Mon Sep 17 00:00:00 2001 From: Archana Ramalingam <98564406+archana-ramalingam@users.noreply.github.com> Date: Tue, 17 Dec 2024 12:35:11 -0800 Subject: [PATCH] [sharktank] Update block_seq_stride for perplexity CI tests (#707) - Update `block_seq_stride` for perplexity CI tests - Update default value of `block_seq_stride` from `16` to `32` in `export_paged_llm_v1.py` --- .github/workflows/ci_eval.yaml | 4 +- .github/workflows/ci_eval_short.yaml | 2 +- app_tests/integration_tests/llm/utils.py | 1 + .../sharktank/evaluate/perplexity_iree.py | 47 +++++++++++-------- .../sharktank/examples/export_paged_llm_v1.py | 2 +- .../sharktank/layers/configs/llm_configs.py | 2 +- sharktank/sharktank/utils/export_artifacts.py | 5 +- 7 files changed, 36 insertions(+), 27 deletions(-) diff --git a/.github/workflows/ci_eval.yaml b/.github/workflows/ci_eval.yaml index 3b85cb652..a71698774 100644 --- a/.github/workflows/ci_eval.yaml +++ b/.github/workflows/ci_eval.yaml @@ -24,7 +24,7 @@ jobs: test_perplexity_iree: if: ${{ github.repository_owner == 'nod-ai' || github.event_name != 'schedule' }} timeout-minutes: 1000 - name: "Perplexity-IREE" + name: "IREE Perplexity" strategy: matrix: version: [3.11] @@ -83,7 +83,7 @@ jobs: test_perplexity_torch: if: ${{ github.repository_owner == 'nod-ai' || github.event_name != 'schedule' }} timeout-minutes: 1000 - name: "Perplexity-Torch" + name: "Torch Perplexity" strategy: matrix: version: [3.11] diff --git a/.github/workflows/ci_eval_short.yaml b/.github/workflows/ci_eval_short.yaml index edaaee966..d5f8f5682 100644 --- a/.github/workflows/ci_eval_short.yaml +++ b/.github/workflows/ci_eval_short.yaml @@ -23,7 +23,7 @@ concurrency: jobs: test_perplexity_iree: - name: "Llama3.1 8B FP16" + name: "IREE Perplexity" strategy: matrix: version: [3.11] diff --git a/app_tests/integration_tests/llm/utils.py b/app_tests/integration_tests/llm/utils.py index 80b5b3c09..dbbdee10d 100644 --- a/app_tests/integration_tests/llm/utils.py +++ b/app_tests/integration_tests/llm/utils.py @@ -90,6 +90,7 @@ def export_paged_llm_v1(mlir_path, config_path, model_path, batch_sizes): "python", "-m", "sharktank.examples.export_paged_llm_v1", + "--block-seq-stride=16", f"--{model_path.suffix.strip('.')}-file={model_path}", f"--output-mlir={mlir_path}", f"--output-config={config_path}", diff --git a/sharktank/sharktank/evaluate/perplexity_iree.py b/sharktank/sharktank/evaluate/perplexity_iree.py index 6060eb91b..c47726f0e 100644 --- a/sharktank/sharktank/evaluate/perplexity_iree.py +++ b/sharktank/sharktank/evaluate/perplexity_iree.py @@ -68,12 +68,14 @@ def __init__( kv_cache_type, tensor_parallelism_size, attention_kernel, + block_seq_stride, ): self.torch_device = torch_device self.iree_device = iree_device self.iree_hip_target = iree_hip_target self.iree_hal_target_backends = iree_hal_target_backends self.kv_cache_type = kv_cache_type + self.block_seq_stride = block_seq_stride self.activation_dtype = torch.float16 self.attention_dtype = torch.float16 self.tensor_parallelism_size = tensor_parallelism_size @@ -136,6 +138,7 @@ def compile_model(self, weight_path_str): iree_hal_target_backends=self.iree_hal_target_backends, attention_kernel=self.attention_kernel, tensor_parallelism_size=self.tensor_parallelism_size, + block_seq_stride=self.block_seq_stride, ) vmfb_path = export_artifacts.get_artifacts() return vmfb_path @@ -145,7 +148,7 @@ def load_model(self, weight_path, tokenizer, vmfb_path): self.config = LlamaModelConfig( hp=configs.LlamaHParams.from_gguf_props(weight_path.properties), - block_seq_stride=16, + block_seq_stride=self.block_seq_stride, kv_cache_type=self.kv_cache_type, device=self.torch_device, activation_dtype=self.activation_dtype, @@ -394,6 +397,7 @@ def run_perplexity( tensor_parallelism_size, attention_kernel, num_prompts, + block_seq_stride, ): start = time.time() perplexity = Perplexity( @@ -404,6 +408,7 @@ def run_perplexity( kv_cache_type=kv_cache_type, tensor_parallelism_size=tensor_parallelism_size, attention_kernel=attention_kernel, + block_seq_stride=block_seq_stride, ) perplexity.get_prompts(num_prompts=num_prompts) @@ -425,8 +430,18 @@ def run_perplexity( def main(argv): parser = cli.create_parser() - parser.add_argument("--kv-cache-type", default="paged", help="KV cache type") - parser.add_argument("--torch-device", help="Torch device (or default)") + parser.add_argument( + "--attention-kernel", + type=str, + default="decomposed", + choices=["decomposed", "torch_sdpa"], + ) + parser.add_argument( + "--block-seq-stride", + help="Block sequence stride for paged KV cache, must divide evenly into the context length", + type=int, + default=32, + ) parser.add_argument("--iree-device", help="List an IREE device (e.g., 'hip://0')") parser.add_argument( "--iree-hip-target", @@ -440,11 +455,12 @@ def main(argv): default="rocm", help="Specify the iree-hal target backends (e.g., rocm)", ) + parser.add_argument("--kv-cache-type", default="paged", help="KV cache type") parser.add_argument( - "--attention-kernel", - type=str, - default="decomposed", - choices=["decomposed", "torch_sdpa"], + "--num-prompts", + type=int, + default=100, + help="Number of prompts for perplexity test (1 to 100)", ) parser.add_argument( "--tensor-parallelism-size", @@ -452,36 +468,29 @@ def main(argv): default=1, help="Number of devices for tensor parallel sharding", ) - parser.add_argument( - "--num-prompts", - type=int, - default=100, - help="Number of prompts for perplexity test", - ) + parser.add_argument("--torch-device", help="Torch device (or default)") cli.add_tokenizer_options(parser) cli.add_input_dataset_options(parser) args = cli.parse(parser, args=argv) torch_device = torch.device(args.torch_device) if args.torch_device else None - iree_device = args.iree_device - kv_cache_type = args.kv_cache_type weight_path = cli.get_input_dataset(args) tokenizer = cli.get_tokenizer(args) - weight_path_str = str(args.irpa_file) ppl = run_perplexity( weight_path=weight_path, - weight_path_str=weight_path_str, + weight_path_str=str(args.irpa_file), tokenizer=tokenizer, torch_device=torch_device, - iree_device=iree_device, + iree_device=args.iree_device, iree_hip_target=args.iree_hip_target, iree_hal_target_backends=args.iree_hal_target_backends, - kv_cache_type=kv_cache_type, + kv_cache_type=args.kv_cache_type, tensor_parallelism_size=args.tensor_parallelism_size, attention_kernel=args.attention_kernel, num_prompts=args.num_prompts, + block_seq_stride=args.block_seq_stride, ) logger.info(f"\n{json.dumps(ppl, indent=2)}") diff --git a/sharktank/sharktank/examples/export_paged_llm_v1.py b/sharktank/sharktank/examples/export_paged_llm_v1.py index ad297bcce..056d8a98e 100644 --- a/sharktank/sharktank/examples/export_paged_llm_v1.py +++ b/sharktank/sharktank/examples/export_paged_llm_v1.py @@ -49,7 +49,7 @@ def main(): "--block-seq-stride", help="Block sequence stride for paged KV cache, must divide evenly into the context length", type=int, - default="16", + default=32, ) parser.add_argument( "--verbose", diff --git a/sharktank/sharktank/layers/configs/llm_configs.py b/sharktank/sharktank/layers/configs/llm_configs.py index 88f5c344c..6cf79402e 100644 --- a/sharktank/sharktank/layers/configs/llm_configs.py +++ b/sharktank/sharktank/layers/configs/llm_configs.py @@ -144,7 +144,7 @@ class LlamaModelConfig: # Block sequence stride for a paged KV cache. This must divide evenly # into the context length. - block_seq_stride: int = 16 + block_seq_stride: int = 32 # Either "paged" or "direct". kv_cache_type: str = "paged" diff --git a/sharktank/sharktank/utils/export_artifacts.py b/sharktank/sharktank/utils/export_artifacts.py index 0bf252525..75cdbab7a 100644 --- a/sharktank/sharktank/utils/export_artifacts.py +++ b/sharktank/sharktank/utils/export_artifacts.py @@ -92,7 +92,7 @@ def __init__( iree_hal_target_backends: str, attention_kernel: str, tensor_parallelism_size: int, - block_seq_stride: Optional[int] = None, + block_seq_stride: int, ): self.sharktank_dir = str( Path(os.path.dirname(os.path.abspath(__file__))).parent.parent.parent @@ -180,14 +180,13 @@ def export_to_mlir( f"--output-mlir={mlir_path}", f"--output-config={json_path}", f"--bs={str(self.batch_size)}", + f"--block-seq-stride={self.block_seq_stride}", ] if skip_decode: export_args.append("--skip-decode") if self.attention_kernel in ["decomposed", "torch"]: export_args.append("--attention-kernel") export_args.append(self.attention_kernel) - if self.block_seq_stride: - export_args.append(f"--block-seq-stride={self.block_seq_stride}") cwd = self.sharktank_dir cmd = subprocess.list2cmdline(export_args)