[sharktank] Update block_seq_stride for perplexity CI tests (#707)
- Update `block_seq_stride` for perplexity CI tests
- Update default value of `block_seq_stride` from `16` to `32` in
`export_paged_llm_v1.py`
archana-ramalingam authored Dec 17, 2024
1 parent aab7161 commit c4a592a
Showing 7 changed files with 36 additions and 27 deletions.
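
For context, `block_seq_stride` sets how many token positions each page of the paged KV cache holds, which is why it must divide evenly into the context length. A minimal sketch of that arithmetic, assuming an illustrative 4096-token context (the helper name is ours, not sharktank's):

```python
def num_cache_pages(context_length: int, block_seq_stride: int) -> int:
    """Pages needed per sequence in a paged KV cache (illustrative helper)."""
    if context_length % block_seq_stride != 0:
        # Mirrors the constraint called out in the --block-seq-stride help text.
        raise ValueError("block_seq_stride must divide the context length evenly")
    return context_length // block_seq_stride

# Doubling the stride from 16 to 32 halves the page count per sequence:
assert num_cache_pages(4096, block_seq_stride=16) == 256
assert num_cache_pages(4096, block_seq_stride=32) == 128
```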
4 changes: 2 additions & 2 deletions .github/workflows/ci_eval.yaml

@@ -24,7 +24,7 @@ jobs:
   test_perplexity_iree:
     if: ${{ github.repository_owner == 'nod-ai' || github.event_name != 'schedule' }}
     timeout-minutes: 1000
-    name: "Perplexity-IREE"
+    name: "IREE Perplexity"
     strategy:
       matrix:
         version: [3.11]
@@ -83,7 +83,7 @@ jobs:
   test_perplexity_torch:
     if: ${{ github.repository_owner == 'nod-ai' || github.event_name != 'schedule' }}
     timeout-minutes: 1000
-    name: "Perplexity-Torch"
+    name: "Torch Perplexity"
     strategy:
       matrix:
         version: [3.11]
2 changes: 1 addition & 1 deletion .github/workflows/ci_eval_short.yaml

@@ -23,7 +23,7 @@ concurrency:

 jobs:
   test_perplexity_iree:
-    name: "Llama3.1 8B FP16"
+    name: "IREE Perplexity"
     strategy:
       matrix:
         version: [3.11]
1 change: 1 addition & 0 deletions app_tests/integration_tests/llm/utils.py

@@ -90,6 +90,7 @@ def export_paged_llm_v1(mlir_path, config_path, model_path, batch_sizes):
         "python",
         "-m",
         "sharktank.examples.export_paged_llm_v1",
+        "--block-seq-stride=16",
         f"--{model_path.suffix.strip('.')}-file={model_path}",
         f"--output-mlir={mlir_path}",
         f"--output-config={config_path}",
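
Note that this test helper pins `--block-seq-stride=16`, so the integration tests keep their existing cache geometry while the exporter's default moves to 32 in the diffs below.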
47 changes: 28 additions & 19 deletions sharktank/sharktank/evaluate/perplexity_iree.py

@@ -68,12 +68,14 @@ def __init__(
         kv_cache_type,
         tensor_parallelism_size,
         attention_kernel,
+        block_seq_stride,
     ):
         self.torch_device = torch_device
         self.iree_device = iree_device
         self.iree_hip_target = iree_hip_target
         self.iree_hal_target_backends = iree_hal_target_backends
         self.kv_cache_type = kv_cache_type
+        self.block_seq_stride = block_seq_stride
         self.activation_dtype = torch.float16
         self.attention_dtype = torch.float16
         self.tensor_parallelism_size = tensor_parallelism_size
@@ -136,6 +138,7 @@ def compile_model(self, weight_path_str):
             iree_hal_target_backends=self.iree_hal_target_backends,
             attention_kernel=self.attention_kernel,
             tensor_parallelism_size=self.tensor_parallelism_size,
+            block_seq_stride=self.block_seq_stride,
         )
         vmfb_path = export_artifacts.get_artifacts()
         return vmfb_path
@@ -145,7 +148,7 @@ def load_model(self, weight_path, tokenizer, vmfb_path):

         self.config = LlamaModelConfig(
             hp=configs.LlamaHParams.from_gguf_props(weight_path.properties),
-            block_seq_stride=16,
+            block_seq_stride=self.block_seq_stride,
             kv_cache_type=self.kv_cache_type,
             device=self.torch_device,
             activation_dtype=self.activation_dtype,
@@ -394,6 +397,7 @@ def run_perplexity(
     tensor_parallelism_size,
     attention_kernel,
     num_prompts,
+    block_seq_stride,
 ):
     start = time.time()
     perplexity = Perplexity(
@@ -404,6 +408,7 @@
         kv_cache_type=kv_cache_type,
         tensor_parallelism_size=tensor_parallelism_size,
         attention_kernel=attention_kernel,
+        block_seq_stride=block_seq_stride,
     )

     perplexity.get_prompts(num_prompts=num_prompts)
@@ -425,8 +430,18 @@

 def main(argv):
     parser = cli.create_parser()
-    parser.add_argument("--kv-cache-type", default="paged", help="KV cache type")
-    parser.add_argument("--torch-device", help="Torch device (or default)")
+    parser.add_argument(
+        "--attention-kernel",
+        type=str,
+        default="decomposed",
+        choices=["decomposed", "torch_sdpa"],
+    )
+    parser.add_argument(
+        "--block-seq-stride",
+        help="Block sequence stride for paged KV cache, must divide evenly into the context length",
+        type=int,
+        default=32,
+    )
     parser.add_argument("--iree-device", help="List an IREE device (e.g., 'hip://0')")
     parser.add_argument(
         "--iree-hip-target",
@@ -440,48 +455,42 @@ def main(argv):
         default="rocm",
         help="Specify the iree-hal target backends (e.g., rocm)",
     )
+    parser.add_argument("--kv-cache-type", default="paged", help="KV cache type")
     parser.add_argument(
-        "--attention-kernel",
-        type=str,
-        default="decomposed",
-        choices=["decomposed", "torch_sdpa"],
+        "--num-prompts",
+        type=int,
+        default=100,
+        help="Number of prompts for perplexity test (1 to 100)",
     )
     parser.add_argument(
         "--tensor-parallelism-size",
         type=int,
         default=1,
         help="Number of devices for tensor parallel sharding",
     )
-    parser.add_argument(
-        "--num-prompts",
-        type=int,
-        default=100,
-        help="Number of prompts for perplexity test",
-    )
+    parser.add_argument("--torch-device", help="Torch device (or default)")

     cli.add_tokenizer_options(parser)
     cli.add_input_dataset_options(parser)
     args = cli.parse(parser, args=argv)

     torch_device = torch.device(args.torch_device) if args.torch_device else None
-    iree_device = args.iree_device
-    kv_cache_type = args.kv_cache_type
     weight_path = cli.get_input_dataset(args)
     tokenizer = cli.get_tokenizer(args)
-    weight_path_str = str(args.irpa_file)

     ppl = run_perplexity(
         weight_path=weight_path,
-        weight_path_str=weight_path_str,
+        weight_path_str=str(args.irpa_file),
         tokenizer=tokenizer,
         torch_device=torch_device,
-        iree_device=iree_device,
+        iree_device=args.iree_device,
         iree_hip_target=args.iree_hip_target,
         iree_hal_target_backends=args.iree_hal_target_backends,
-        kv_cache_type=kv_cache_type,
+        kv_cache_type=args.kv_cache_type,
         tensor_parallelism_size=args.tensor_parallelism_size,
         attention_kernel=args.attention_kernel,
         num_prompts=args.num_prompts,
+        block_seq_stride=args.block_seq_stride,
     )

     logger.info(f"\n{json.dumps(ppl, indent=2)}")
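
The remaining `main()` churn regroups the argument definitions and inlines single-use locals (`iree_device`, `kv_cache_type`, `weight_path_str`); the new flag flows from `main` through `run_perplexity` into both the compiled and eager model configs. A hedged sketch of driving the updated entry point end to end — paths are placeholders, the tokenizer flag name is assumed from `cli.add_tokenizer_options`, and `gfx942` is only an example target:

```python
from sharktank.evaluate import perplexity_iree

perplexity_iree.main([
    "--irpa-file=/path/to/llama3_8b_fp16.irpa",                # placeholder path
    "--tokenizer-config-json=/path/to/tokenizer_config.json",  # assumed flag name
    "--iree-device=hip://0",
    "--iree-hip-target=gfx942",        # example target, not a verified default
    "--iree-hal-target-backends=rocm",
    "--num-prompts=10",
    # Omitting --block-seq-stride picks up the new default of 32; pass
    # --block-seq-stride=16 to reproduce the pre-change cache geometry.
])
```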
2 changes: 1 addition & 1 deletion sharktank/sharktank/examples/export_paged_llm_v1.py

@@ -49,7 +49,7 @@ def main():
         "--block-seq-stride",
         help="Block sequence stride for paged KV cache, must divide evenly into the context length",
         type=int,
-        default="16",
+        default=32,
     )
     parser.add_argument(
         "--verbose",
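
One subtlety in the old line: `default="16"` was a string, but argparse runs string defaults through the `type` callable, so it still parsed to the integer 16; the new int literal simply skips that conversion. A self-contained check of both behaviors:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--block-seq-stride", type=int, default=32)

# The default is used as-is when the flag is absent...
assert parser.parse_args([]).block_seq_stride == 32
# ...and type=int converts values supplied on the command line.
assert parser.parse_args(["--block-seq-stride=16"]).block_seq_stride == 16
```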
2 changes: 1 addition & 1 deletion sharktank/sharktank/layers/configs/llm_configs.py

@@ -144,7 +144,7 @@ class LlamaModelConfig:

     # Block sequence stride for a paged KV cache. This must divide evenly
     # into the context length.
-    block_seq_stride: int = 16
+    block_seq_stride: int = 32

     # Either "paged" or "direct".
     kv_cache_type: str = "paged"
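
With the dataclass default raised to match, eager-mode runs that construct `LlamaModelConfig` directly and IREE runs exported through `export_paged_llm_v1.py` now agree on a stride of 32 unless a caller overrides it, as `perplexity_iree.py` above does via its new flag.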
5 changes: 2 additions & 3 deletions sharktank/sharktank/utils/export_artifacts.py

@@ -92,7 +92,7 @@ def __init__(
         iree_hal_target_backends: str,
         attention_kernel: str,
         tensor_parallelism_size: int,
-        block_seq_stride: Optional[int] = None,
+        block_seq_stride: int,
     ):
         self.sharktank_dir = str(
             Path(os.path.dirname(os.path.abspath(__file__))).parent.parent.parent
@@ -180,14 +180,13 @@ def export_to_mlir(
             f"--output-mlir={mlir_path}",
             f"--output-config={json_path}",
             f"--bs={str(self.batch_size)}",
+            f"--block-seq-stride={self.block_seq_stride}",
         ]
         if skip_decode:
             export_args.append("--skip-decode")
         if self.attention_kernel in ["decomposed", "torch"]:
             export_args.append("--attention-kernel")
             export_args.append(self.attention_kernel)
-        if self.block_seq_stride:
-            export_args.append(f"--block-seq-stride={self.block_seq_stride}")

         cwd = self.sharktank_dir
         cmd = subprocess.list2cmdline(export_args)
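
Tightening `block_seq_stride` from `Optional[int] = None` to a required `int`, and moving `--block-seq-stride` into the unconditional argument list, removes the silent fallback: an export can no longer omit the flag and drift to whatever default `export_paged_llm_v1.py` ships with.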
