Merge branch 'main' into fix-ppl-test
archana-ramalingam authored Jan 24, 2025
2 parents 7893014 + 782ec5b commit a8e9699
Showing 12 changed files with 54 additions and 86 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci-sdxl.yaml
@@ -37,7 +37,7 @@ env:
jobs:
install-and-test:
name: Install and test
runs-on: mi300x-4
runs-on: mi300x-3

steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
4 changes: 2 additions & 2 deletions .github/workflows/ci-sglang-benchmark.yml
@@ -40,7 +40,7 @@ jobs:
matrix:
version: [3.11]
fail-fast: false
runs-on: mi300x-4
runs-on: mi300x-3
defaults:
run:
shell: bash
@@ -97,7 +97,7 @@ jobs:
matrix:
version: [3.11]
fail-fast: false
runs-on: mi300x-4
runs-on: mi300x-3
defaults:
run:
shell: bash
2 changes: 1 addition & 1 deletion .github/workflows/ci-sglang-integration-tests.yml
@@ -29,7 +29,7 @@ jobs:
matrix:
version: [3.11]
fail-fast: false
runs-on: mi300x-4
runs-on: mi300x-3
defaults:
run:
shell: bash
2 changes: 1 addition & 1 deletion .github/workflows/pkgci_shark_ai.yml
@@ -26,7 +26,7 @@ jobs:
matrix:
version: [3.11]
fail-fast: false
runs-on: mi300x-4
runs-on: mi300x-3
# runs-on: ubuntu-latest # everything else works but this throws an "out of resources" during model loading
# TODO: make a copy of this that runs on standard runners with tiny llama instead of a 8b model
defaults:
39 changes: 14 additions & 25 deletions sharktank/sharktank/evaluate/perplexity_iree.py
@@ -70,15 +70,17 @@ def __init__(
attention_kernel,
block_seq_stride,
use_attention_mask,
activation_dtype=torch.float16,
attention_dtype=torch.float16,
):
self.torch_device = torch_device
self.iree_device = iree_device
self.iree_hip_target = iree_hip_target
self.iree_hal_target_device = iree_hal_target_device
self.kv_cache_type = kv_cache_type
self.block_seq_stride = block_seq_stride
self.activation_dtype = torch.float16
self.attention_dtype = torch.float16
self.activation_dtype = activation_dtype
self.attention_dtype = attention_dtype
self.tensor_parallelism_size = tensor_parallelism_size
self.attention_kernel = attention_kernel
self.use_attention_mask = use_attention_mask
@@ -435,18 +437,6 @@ def run_perplexity(

def main(argv):
parser = cli.create_parser()
parser.add_argument(
"--attention-kernel",
type=str,
default="decomposed",
choices=["decomposed", "torch"],
)
parser.add_argument(
"--block-seq-stride",
help="Block sequence stride for paged KV cache, must divide evenly into the context length",
type=int,
default=32,
)
parser.add_argument("--iree-device", help="List an IREE device (e.g., 'hip://0')")
parser.add_argument(
"--iree-hip-target",
@@ -460,31 +450,31 @@ def main(argv):
default="hip",
help="Specify the iree-hal target device (e.g., hip, cpu)",
)
parser.add_argument("--kv-cache-type", default="paged", help="KV cache type")
parser.add_argument(
"--num-prompts",
type=int,
default=100,
help="Number of prompts for perplexity test (1 to 100)",
)
parser.add_argument(
"--tensor-parallelism-size",
type=int,
default=1,
help="Number of devices for tensor parallel sharding",
)
parser.add_argument("--torch-device", help="Torch device (or default)")

cli.add_model_options(parser)
cli.add_tokenizer_options(parser)
cli.add_input_dataset_options(parser)
args = cli.parse(parser, args=argv)

torch_device = torch.device(args.torch_device) if args.torch_device else None
torch_device = torch.device(args.device) if args.device else None
weight_path = cli.get_input_dataset(args)
tokenizer = cli.get_tokenizer(args)

use_attention_mask = True

# Override flag if dataset disagrees
tensor_parallelism_size = (
weight_path.properties["tensor_parallelism_size"]
if "tensor_parallelism_size" in weight_path.properties
else args.tensor_parallelism_size
)

ppl = run_perplexity(
weight_path=weight_path,
weight_path_str=str(args.irpa_file),
@@ -493,8 +483,7 @@ def main(argv):
iree_device=args.iree_device,
iree_hip_target=args.iree_hip_target,
iree_hal_target_device=args.iree_hal_target_device,
kv_cache_type=args.kv_cache_type,
tensor_parallelism_size=args.tensor_parallelism_size,
tensor_parallelism_size=tensor_parallelism_size,
attention_kernel=args.attention_kernel,
num_prompts=args.num_prompts,
block_seq_stride=args.block_seq_stride,
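
With the duplicated --attention-kernel, --block-seq-stride, --kv-cache-type, and --tensor-parallelism-size flags removed, this script now picks those options up from the shared cli.add_model_options helper, and the activation/attention dtypes become constructor parameters instead of hard-coded float16 values. A minimal sketch of an invocation after this change; the model path is a placeholder, and --irpa-file is assumed to be registered by cli.add_input_dataset_options (tokenizer flags omitted):

```python
# Hypothetical invocation of the refactored entry point; paths and flag
# values are placeholders, not part of this commit.
from sharktank.evaluate import perplexity_iree

perplexity_iree.main([
    "--irpa-file=/models/llama3_8b_f16.irpa",  # placeholder path
    "--iree-device=hip://0",
    "--iree-hip-target=gfx942",
    "--iree-hal-target-device=hip",
    "--attention-kernel=torch",   # now registered by cli.add_model_options
    "--block-seq-stride=32",      # now registered by cli.add_model_options
    "--num-prompts=4",
    # tokenizer options from cli.add_tokenizer_options are omitted here
])
```
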
6 changes: 1 addition & 5 deletions sharktank/sharktank/evaluate/perplexity_torch.py
@@ -58,12 +58,10 @@ class Perplexity_torch:
def __init__(
self,
device,
kv_cache_type,
activation_dtype=torch.float32,
attention_dtype=torch.float32,
):
self.device = device
self.kv_cache_type = kv_cache_type
self.activation_dtype = activation_dtype
self.attention_dtype = attention_dtype

@@ -115,7 +113,6 @@ def load_model(self, dataset, tokenizer, tensor_parallelism_size, attention_kernel):
self.config = LlamaModelConfig(
hp=configs.LlamaHParams.from_gguf_props(dataset.properties),
block_seq_stride=16,
kv_cache_type=self.kv_cache_type,
device=self.device,
activation_dtype=self.activation_dtype,
attention_dtype=self.attention_dtype,
@@ -298,14 +295,13 @@ def run_perplexity_torch(
dataset,
tokenizer,
device,
kv_cache_type,
tensor_parallelism_size,
attention_kernel,
num_prompts,
):
start = time.time()

perplexity = Perplexity_torch(device=device, kv_cache_type=kv_cache_type)
perplexity = Perplexity_torch(device=device)
perplexity.get_prompts(num_prompts=num_prompts)
perplexity.load_model(dataset, tokenizer, tensor_parallelism_size, attention_kernel)
ppl = perplexity.get_perplexity()
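
With kv_cache_type dropped from Perplexity_torch, the cache layout now comes from the model config built inside load_model() rather than being threaded through the evaluator. A condensed sketch of the resulting call sequence, mirroring the hunks above (the inputs are assumed to be produced by the existing cli helpers):

```python
import torch
from sharktank.evaluate.perplexity_torch import Perplexity_torch

def eval_ppl(dataset, tokenizer, tensor_parallelism_size, attention_kernel, num_prompts=10):
    # Sketch only: note there is no kv_cache_type argument anymore.
    perplexity = Perplexity_torch(device=torch.device("cpu"))
    perplexity.get_prompts(num_prompts=num_prompts)
    perplexity.load_model(dataset, tokenizer, tensor_parallelism_size, attention_kernel)
    return perplexity.get_perplexity()
```
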
6 changes: 0 additions & 6 deletions sharktank/sharktank/examples/export_paged_llm_v1.py
@@ -45,12 +45,6 @@ def main():
type=lambda arg: [int(bs) for bs in arg.split(",")],
default="4",
)
parser.add_argument(
"--block-seq-stride",
help="Block sequence stride for paged KV cache, must divide evenly into the context length",
type=int,
default=32,
)
parser.add_argument(
"--verbose",
help="Include verbose logging",
18 changes: 6 additions & 12 deletions sharktank/sharktank/examples/paged_llm_v1.py
@@ -42,8 +42,8 @@ def __init__(
):
self.model = model
self.tokenizer = tokenizer
if model.cache.is_paged:
self.shared_cache_state = model.cache.paged.allocate(page_cache_size)
if self.model.config.kv_cache_type == "paged":
self.shared_cache_state = model.cache.allocate(page_cache_size)
self.free_pages = list(range(1, page_cache_size))
else:
self.shared_cache_state = None
@@ -63,18 +63,18 @@ def begin_batch(self, prompts: list[str]):
if self.shared_cache_state is not None:
cache_state = self.shared_cache_state
else:
cache_state = self.model.cache.direct.allocate(bs=len(prompts))
cache_state = self.model.cache.allocate(bs=len(prompts))
return Batch(self, token_ids, seq_lens, cache_state)

def alloc_page(self) -> int:
if self.model.cache.is_direct:
if self.model.config.kv_cache_type == "direct":
# We don't allocate block ids for the direct cache.
return 0

return self.free_pages.pop()

def release_page(self, index: int):
if self.model.cache.is_direct:
if self.model.config.kv_cache_type == "direct":
return
self.free_pages.append(index)

@@ -238,12 +238,6 @@ def main():
"--save_intermediates_path",
help="save module forward outputs to safetensors, ex: run_0 will save to run_0_prefill.savetensors",
)
parser.add_argument(
"--tensor-parallelism-size",
type=int,
default=1,
help="How many devices are involved for tensor parallel sharding.",
)
cli.add_input_dataset_options(parser)
cli.add_tokenizer_options(parser)
cli.add_quantization_options(parser)
@@ -255,7 +249,7 @@ def main():
prompts = args.prompt
config = LlamaModelConfig(
hp=configs.LlamaHParams.from_gguf_props(dataset.properties),
block_seq_stride=16,
block_seq_stride=args.block_seq_stride,
device=device,
activation_dtype=args.activation_dtype,
attention_dtype=args.activation_dtype,
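
The CLI driver now branches on model.config.kv_cache_type and calls a single model.cache.allocate(), instead of probing cache.is_paged / cache.is_direct and reaching into cache.paged or cache.direct. A condensed sketch of the allocation pattern that results (the model object is assumed to expose config.kv_cache_type and a unified cache.allocate(), as in the hunks above; the helper name is hypothetical):

```python
# Condensed sketch of the unified cache-allocation pattern shown above.
def allocate_cache_state(model, page_cache_size: int, batch_size: int):
    if model.config.kv_cache_type == "paged":
        # One shared, page-granular state; page 0 stays reserved, so the
        # free list starts at 1.
        state = model.cache.allocate(page_cache_size)
        free_pages = list(range(1, page_cache_size))
    else:
        # Direct cache: per-batch state, no page bookkeeping needed.
        state = model.cache.allocate(bs=batch_size)
        free_pages = []
    return state, free_pages
```
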
27 changes: 8 additions & 19 deletions sharktank/sharktank/export_layer/export_paged_attention.py
@@ -44,25 +44,14 @@ def paged_attention(
# Full sequence length.
kv_seq_len = seq_block_ids.shape[1] * attention_block.cache.block_seq_stride

if attention_block.cache.is_paged:
xk, xv = attention_block.transact_cache_paged(
xk_cache_update=xk,
xv_cache_update=xv,
seq_block_ids=seq_block_ids,
kv_seq_len=kv_seq_len,
start_positions=start_positions,
cache_state=cache_state,
)
elif attention_block.cache.is_direct:
xk, xv = attention_block.transact_cache_direct(
xk_cache_update=xk,
xv_cache_update=xv,
start_positions=start_positions,
kv_seq_len=kv_seq_len,
cache_state=cache_state,
)
else:
raise NotImplementedError(f"Unsupported KV cache type: {type(cache)}")
xk, xv = attention_block.transact_cache(
xk_cache_update=xk,
xv_cache_update=xv,
seq_block_ids=seq_block_ids,
kv_seq_len=kv_seq_len,
start_positions=start_positions,
cache_state=cache_state,
)

# Expand kv heads for GQA.
gqa_n_rep = attention_block.head_count // attention_block.head_count_kv
6 changes: 6 additions & 0 deletions sharktank/sharktank/utils/cli.py
@@ -108,6 +108,12 @@ def add_model_options(parser: argparse.ArgumentParser):
default=1,
help="Number of devices for tensor parallel sharding. Will be overridden by dataset.properties if present",
)
parser.add_argument(
"--block-seq-stride",
help="Block sequence stride for paged KV cache, must divide evenly into the context length",
type=int,
default=32,
)


def add_quantization_options(parser: argparse.ArgumentParser):
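
Because --block-seq-stride now lives in cli.add_model_options, every tool that builds its parser through the shared helpers inherits the flag (default 32) without re-declaring it, which is what allowed the local copies in perplexity_iree.py and export_paged_llm_v1.py to be deleted. A small sketch of a consumer picking it up:

```python
# Sketch: parsers built via the shared helpers now accept --block-seq-stride.
from sharktank.utils import cli

parser = cli.create_parser()
cli.add_model_options(parser)
args = cli.parse(parser, args=["--block-seq-stride=16"])
assert args.block_seq_stride == 16  # default is 32 when the flag is omitted
```
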
20 changes: 10 additions & 10 deletions sharktank/sharktank/utils/load_llm.py
@@ -45,11 +45,11 @@ def begin_batch(
token_ids = torch.tensor(token_ids, device=self.model.device)
seq_lens = torch.tensor(seq_lens, device=self.model.device)

if self.model.cache.is_paged:
cache_state = self.model.cache.paged.allocate(page_cache_size)
if self.model.config.kv_cache_type == "paged":
cache_state = self.model.cache.allocate(page_cache_size)
self.free_pages = list(range(1, page_cache_size))
else:
cache_state = self.model.cache.direct.allocate(bs=len(prompts))
elif self.model.config.kv_cache_type == "direct":
cache_state = self.model.cache.allocate(bs=1)
return Batch(self, token_ids, seq_lens, cache_state)

def begin_eval_batch(
Expand All @@ -59,22 +59,22 @@ def begin_eval_batch(
bs: int,
page_cache_size: int = 128,
):
if self.model.cache.is_paged:
cache_state = self.model.cache.paged.allocate(page_cache_size)
if self.model.config.kv_cache_type == "paged":
cache_state = self.model.cache.allocate(page_cache_size)
self.free_pages = list(range(1, page_cache_size))
else:
cache_state = self.model.cache.direct.allocate(bs=bs)
elif self.model.config.kv_cache_type == "direct":
cache_state = self.model.cache.allocate(bs=1)
return Batch(self, token_batch, seq_lens_batch, cache_state)

def alloc_page(self) -> int:
if self.model.cache.is_direct:
if self.model.config.kv_cache_type == "direct":
# We don't allocate block ids for the direct cache.
return 0

return self.free_pages.pop()

def release_page(self, index: int):
if self.model.cache.is_direct:
if self.model.config.kv_cache_type == "direct":
return
self.free_pages.append(index)

8 changes: 4 additions & 4 deletions sharktank/tests/evaluate/perplexity_iree_test.py
@@ -85,7 +85,7 @@ def test_llama3_8B_f16(self):
f"--iree-hal-target-device={self.iree_hal_target_device}",
f"--iree-hip-target={self.iree_hip_target}",
f"--tensor-parallelism-size=1",
f"--attention-kernel=torch_sdpa",
f"--attention-kernel=torch",
f"--num-prompts={self.batch_size}",
]
)
@@ -157,7 +157,7 @@ def test_llama3_8B_fp8(self):
f"--iree-hal-target-device={self.iree_hal_target_device}",
f"--iree-hip-target={self.iree_hip_target}",
f"--tensor-parallelism-size=1",
f"--attention-kernel=torch_sdpa",
f"--attention-kernel=torch",
f"--num-prompts={self.batch_size}",
]
)
@@ -231,7 +231,7 @@ def test_llama3_405B_f16(self):
f"--iree-hal-target-device={self.iree_hal_target_device}",
f"--iree-hip-target={self.iree_hip_target}",
f"--tensor-parallelism-size={self.tensor_parallelism_size}",
f"--attention-kernel=torch_sdpa",
f"--attention-kernel=torch",
f"--num-prompts={self.batch_size}",
]
)
@@ -303,7 +303,7 @@ def test_llama3_405B_fp8(self):
f"--iree-hal-target-device={self.iree_hal_target_device}",
f"--iree-hip-target={self.iree_hip_target}",
f"--tensor-parallelism-size={self.tensor_parallelism_size}",
f"--attention-kernel=torch_sdpa",
f"--attention-kernel=torch",
f"--num-prompts={self.batch_size}",
]
)
