[Llama] Change KVCache layout
Groverkss committed Jan 22, 2025
1 parent 8a48bbe commit 2909678
Showing 2 changed files with 68 additions and 56 deletions.
sharktank/sharktank/layers/kv_cache.py: 123 changes (68 additions, 55 deletions)

@@ -35,8 +35,8 @@ class PagedKVCache:
     * transformer block
     * cache partition (K or V cache)
-    * block sequence stride (number of sequence positions per block)
     * attention heads
+    * block sequence stride (number of sequence positions per block)
     * attention dimensionality

     Note that the internal page structure matches the organization of the
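For orientation, here is a minimal sketch of the 6D page-table view the updated docstring describes. All sizes are invented for illustration; only the axis ordering reflects the diff.

import torch

# Invented example sizes.
page_count, layers, kv_heads, stride, head_dim = 32, 4, 8, 16, 128

# One row per cache page; each row flattens the five in-page axes.
page_slab = torch.zeros(page_count, layers * 2 * kv_heads * stride * head_dim)

# New ordering: attention heads (axis 3) now precede the sequence stride (axis 4).
pages_6d = page_slab.unflatten(1, (layers, 2, kv_heads, stride, head_dim))
print(pages_6d.shape)  # torch.Size([32, 4, 2, 8, 16, 128])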
@@ -74,8 +74,8 @@ def __init__(
         self.sub_page_dims = [
             self.transformer_block_count,
             2,
-            self.block_seq_stride,
             self.attn_head_count // self.shard_count,
+            self.block_seq_stride,
             self.attn_head_dim,
         ]
         self.page_slab_flat_dim = math.prod(self.sub_page_dims)
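Reordering sub_page_dims leaves the flattened page size unchanged; only the memory contiguity differs. A quick check with invented sizes:

import math

dims_old = [4, 2, 16, 8, 128]  # layers, K/V, stride, heads, head_dim (before)
dims_new = [4, 2, 8, 16, 128]  # layers, K/V, heads, stride, head_dim (after)
assert math.prod(dims_old) == math.prod(dims_new) == 131072

# After the change, the positions of one head within a block are contiguous
# (head-major) rather than all heads for one position (position-major).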
@@ -96,7 +96,7 @@ def unflatten_page_table(
         shards = [
             shard.unflatten(1, self.sub_page_dims) for shard in page_slab.shards
         ]
-        return SplitPrimitiveTensor(ts=shards, shard_dim=4)
+        return SplitPrimitiveTensor(ts=shards, shard_dim=3)

     def shard_state(
         self, state: List[torch.Tensor]
@@ -114,13 +114,13 @@ def shard_state(
                 -1,
                 self.transformer_block_count,
                 2,
-                self.block_seq_stride,
                 self.attn_head_count,
+                self.block_seq_stride,
                 self.attn_head_dim,
             ]
         )
         sharded_page_table = ops.reshard_split(
-            page_table, dim=4, count=self.shard_count
+            page_table, dim=3, count=self.shard_count
         )
         shards = [
             ops.flatten(shard, start_dim=1) for shard in sharded_page_table.shards
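The reshard dimension drops from 4 to 3 because the attention-head axis moved one slot left in the 6D view. A sketch using plain torch.chunk as a stand-in for ops.reshard_split (the substitution and the sizes are ours):

import torch

page_count, layers, kv_heads, stride, head_dim = 32, 4, 8, 16, 128
page_table = torch.zeros(page_count, layers, 2, kv_heads, stride, head_dim)

# Heads now live at dim 3, so a tensor-parallel split over heads uses dim=3.
shards = torch.chunk(page_table, chunks=2, dim=3)
print(shards[0].shape)  # torch.Size([32, 4, 2, 4, 16, 128])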
@@ -176,14 +176,6 @@ def read(
         page_table = self.unflatten_page_table(state)  # 6D

         bs, block_seq_len, *_ = page_ids.shape
-        # Blocks dim 1,2 according to the configured block stride.
-        blocked_shape = [
-            bs,
-            block_seq_len,
-            self.block_seq_stride,
-            self.attn_head_count // self.shard_count,
-            self.attn_head_dim,
-        ]

         # Reshape the page cache into sub-blocks so that we can index at the
         # granularity of the transformer_block and cache partition.
@@ -209,8 +201,14 @@ def read_cache_partition(index: int):
             # copy of the sub-blocks by collapsing the first two dims so we have
             # a linear list.
             selected = (
+                # Read Layout is: (bs, block_seq_len), kv_head_count, block_seq_stride, head_dim
+                # Output Layout is: bs, (block_seq_len, block_seq_stride), kv_head_count, head_dim
                 ops.index_select(subblock_table, 0, subblock_ids.flatten(0, 1))
-                .unflatten(0, blocked_shape[0:2])
+                # bs, block_seq_len, kv_head_count, block_seq_stride, head_dim
+                .unflatten(0, (bs, block_seq_len))
+                # bs, block_seq_len, block_seq_stride, kv_head_count, head_dim
+                .transpose(2, 3)
+                # bs, (block_seq_len, block_seq_stride), kv_head_count, head_dim
+                .flatten(1, 2)
             )
             return selected
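A self-contained check of the read-side transposition, with a dummy tensor standing in for the real subblock table (shapes invented):

import torch

bs, block_seq_len, kv_heads, stride, head_dim = 2, 3, 4, 5, 6

# Rows as produced by index_select: bs * block_seq_len selected sub-blocks,
# each stored head-major per the new cache layout.
selected_rows = torch.randn(bs * block_seq_len, kv_heads, stride, head_dim)

out = (
    selected_rows
    .unflatten(0, (bs, block_seq_len))  # bs, blocks, heads, stride, dim
    .transpose(2, 3)                    # bs, blocks, stride, heads, dim
    .flatten(1, 2)                      # bs, seq_len, heads, dim
)
print(out.shape)  # torch.Size([2, 15, 4, 6])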
@@ -234,47 +232,55 @@ def write_timestep(
         # [bs, max_seqlen // block_pos_stride]
         page_ids: Union[torch.Tensor, ReplicatedTensor],
     ):
-        """Writes a single batched timestep across all cache partitions.
-
-        Note that this internally loops over the batch size, which cannot be
-        dynamic.
-        """
+        """Writes a single batched timestep across all cache partitions."""
         device = self.device
         page_table = self.unflatten_page_table(state)  # 6D
         bs, *_ = seq_positions.shape

-        # [bs, 1, atten_head_count, attn_head_dim]
-        for idx, cache_partition in enumerate([key, value]):
-            # [bs, 1]
-            page_index = seq_positions // self.block_seq_stride
-
-            page_id = ops.gather(page_ids, dim=1, index=page_index.unsqueeze(1))
-            page_offset = (seq_positions % self.block_seq_stride).unsqueeze(1)
-
-            # [1, 1]
-            if isinstance(seq_positions, ReplicatedTensor):
-                partitions = [
-                    torch.tensor(idx).unsqueeze(0)
-                    for _ in range(seq_positions.shard_count)
-                ]
+        page_index = seq_positions // self.block_seq_stride
+        page_id = ops.gather(page_ids, dim=1, index=page_index.unsqueeze(1))
+        page_offset = (seq_positions % self.block_seq_stride).unsqueeze(1)

+        # This could probably be written much better, but after fighting
+        # a lot with torch-mlir and dynamo, this is a hacky enough
+        # version that works. We are doing a lot of praying to the compiler
+        # here for scatter fusion.
+        # vLLM does this write using a custom kernel to write the key and
+        # value partitions.
+        #
+        # Here, we are trying to get the index to be a broadcasted version
+        # of the partition layout in which we are going to write in:
+        # [bs, kv_head_count, 1]
+        page_id = page_id.unsqueeze(-1).expand(bs, self.attn_head_count, 1)
+        head_id = (
+            torch.arange(0, self.attn_head_count, dtype=page_ids.dtype)
+            .unsqueeze(0)
+            .unsqueeze(-1)
+            .expand(bs, self.attn_head_count, 1)
+        )
+        page_offset = page_offset.unsqueeze(-1).expand(bs, self.attn_head_count, 1)

-                transformer_block = [
-                    torch.full((bs, 1), transformer_block_index, device=device)
-                    for _ in range(seq_positions.shard_count)
-                ]
+        # This is a hack. Without it, IREE generates a separate dispatch for
+        # each layer, and fails to compile all dispatches except
+        # for layer 0.
+        block_idx = torch.full(
+            (bs, self.attn_head_count, 1), transformer_block_index, dtype=page_ids.dtype
+        )

-                partitions = ReplicatedTensor(ts=partitions)
-                transformer_block = ReplicatedTensor(ts=transformer_block)
-            else:
-                partitions = torch.tensor(idx).unsqueeze(0)
-                transformer_block = torch.full(
-                    (bs, 1), transformer_block_index, device=device
-                )
+        for idx, cache_partition in enumerate([key, value]):
+            # Input Layout: bs, 1, kv_head_count, attn_head_dim
+            # Partition Layout: (bs, 1), kv_head_count, block_seq_stride, head_dim

-            partitions = partitions.repeat(bs, 1)
+            # Same hack as above.
+            part_idx = torch.full(
+                (bs, self.attn_head_count, 1), idx, dtype=page_ids.dtype
+            )

-            indices = (page_id, transformer_block, partitions, page_offset)
-            page_table.index_put_(indices=indices, values=cache_partition)
+            # bs, kv_head_count, 1, attn_head_dim
+            partition_view = cache_partition.transpose(1, 2)
+            page_table[
+                page_id, block_idx, part_idx, head_id, page_offset
+            ] = partition_view
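The write above leans on PyTorch advanced indexing with broadcastable index tensors; the following stripped-down sketch shows the same trick in isolation (all names and sizes are invented):

import torch

bs, layers, kv_heads, stride, head_dim = 2, 4, 8, 16, 32
pages = torch.zeros(10, layers, 2, kv_heads, stride, head_dim)

# One (page, offset) pair per batch element, broadcast across heads.
page_id = torch.tensor([[3], [7]]).unsqueeze(-1).expand(bs, kv_heads, 1)
offset = torch.tensor([[0], [5]]).unsqueeze(-1).expand(bs, kv_heads, 1)
head_id = torch.arange(kv_heads).view(1, kv_heads, 1).expand(bs, kv_heads, 1)
layer_id = torch.full((bs, kv_heads, 1), 2)
part_id = torch.full((bs, kv_heads, 1), 0)  # 0 = key partition

# New timestep viewed head-major: bs, kv_heads, 1, head_dim.
new_kv = torch.randn(bs, 1, kv_heads, head_dim).transpose(1, 2)

# All five index tensors broadcast to (bs, kv_heads, 1); each indexed entry
# selects a head_dim-sized slot that receives one row of new_kv.
pages[page_id, layer_id, part_id, head_id, offset] = new_kv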

     def write(
         self,

@@ -308,15 +314,22 @@ def write(
             transformer_block_index * transformer_block_stride
         )

-        key_reshaped = key.unflatten(1, (block_seq_len, self.block_seq_stride)).flatten(
-            0, 1
-        )
-        value_reshaped = value.unflatten(
-            1, (block_seq_len, self.block_seq_stride)
-        ).flatten(0, 1)
+        def reshape_input(x):
+            # Input Layout: bs, (block_seq_len, block_seq_stride), kv_head_count, head_dim
+            # Write Layout: (bs, block_seq_len), kv_head_count, block_seq_stride, head_dim
+
+            return (
+                x
+                # bs, block_seq_len, block_seq_stride, kv_head_count, head_dim
+                .unflatten(1, (block_seq_len, self.block_seq_stride))
+                # bs, block_seq_len, kv_head_count, block_seq_stride, head_dim
+                .transpose(2, 3)
+                # (bs, block_seq_len), kv_head_count, block_seq_stride, head_dim
+                .flatten(0, 1)
+            )

         key_ids = base_subblock_ids.flatten(0, 1)
         value_ids = base_subblock_ids.flatten(0, 1) + 1

-        subblock_table.index_copy_(0, key_ids, key_reshaped)
-        subblock_table.index_copy_(0, value_ids, value_reshaped)
+        subblock_table.index_copy_(0, key_ids, reshape_input(key))
+        subblock_table.index_copy_(0, value_ids, reshape_input(value))
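The batch write path can be checked the same way; the chain below repeats the diff's reshape_input transformation inline, while the table, ids, and sizes are invented:

import torch

bs, block_seq_len, stride, kv_heads, head_dim = 2, 3, 4, 5, 6
key = torch.randn(bs, block_seq_len * stride, kv_heads, head_dim)

rows = (
    key
    .unflatten(1, (block_seq_len, stride))  # bs, blocks, stride, heads, dim
    .transpose(2, 3)                        # bs, blocks, heads, stride, dim
    .flatten(0, 1)                          # (bs*blocks), heads, stride, dim
)

# Each flattened row lands in one sub-block of the table.
subblock_table = torch.zeros(100, kv_heads, stride, head_dim)
ids = torch.arange(bs * block_seq_len)  # stand-in for base_subblock_ids
subblock_table.index_copy_(0, ids, rows)
print(rows.shape)  # torch.Size([6, 5, 4, 6])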
sharktank/sharktank/utils/create_cache.py: 1 change (0 additions, 1 deletion)

@@ -16,7 +16,6 @@ def create_paged_kv_cache(config: LlamaModelConfig) -> PagedKVCache:
         transformer_block_count=hp.block_count,
         attn_head_count=hp.attention_head_count_kv,
         attn_head_dim=hp.attn_head_dim,
-        cache_partition_count=2,  # One for each of K/V.
         block_seq_stride=config.block_seq_stride,
         device=config.device,
         dtype=config.attention_dtype,