loop over cache_partitions to enable fusion (#677)
Co-authored-by: Rob Suderman <[email protected]>
dan-garvey and rsuderman authored Dec 11, 2024
1 parent 1e26b20 commit 7e62c25
1 changed file: sharktank/sharktank/layers/kv_cache.py (16 additions, 36 deletions)
@@ -447,30 +447,25 @@ def write_timestep(
         bs, *_ = seq_positions.shape
         assert len(cache_partitions) == self.cache_partition_count

-        partition_count = len(cache_partitions)
-
-        # [bs, partitions, atten_head_count, attn_head_dim]
-        cache_partitions = ops.cat(cache_partitions, dim=1)
-
-        # [bs, 1]
-        page_index = seq_positions // self.block_seq_stride
-
-        page_id = ops.gather(page_ids, dim=1, index=page_index.unsqueeze(1))
-        page_offset = (seq_positions % self.block_seq_stride).unsqueeze(1)
-
-        # [1, partitions]
-        partitions = torch.arange(0, self.cache_partition_count).unsqueeze(0)
-
-        # [bs, partitions]
-        page_id = page_id.repeat(1, partition_count)
-        transformer_block = torch.full(
-            (bs, partition_count), transformer_block_index, device=device
-        )
-        page_offset = page_offset.repeat(1, partition_count)
-        partitions = partitions.repeat(bs, 1)
-
-        indices = (page_id, transformer_block, partitions, page_offset)
-        page_table.index_put_(indices=indices, values=cache_partitions)
+        # [bs, 1, atten_head_count, attn_head_dim]
+        for idx, cache_partition in enumerate(cache_partitions):
+            # [bs, 1]
+            page_index = seq_positions // self.block_seq_stride
+
+            page_id = ops.gather(page_ids, dim=1, index=page_index.unsqueeze(1))
+            page_offset = (seq_positions % self.block_seq_stride).unsqueeze(1)
+
+            # [1, 1]
+            partitions = torch.tensor(idx).unsqueeze(0)
+
+            # [bs, 1]
+            transformer_block = torch.full(
+                (bs, 1), transformer_block_index, device=device
+            )
+            partitions = partitions.repeat(bs, 1)
+
+            indices = (page_id, transformer_block, partitions, page_offset)
+            page_table.index_put_(indices=indices, values=cache_partition)

         return
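A minimal, self-contained sketch of the per-partition index_put_ pattern the new loop uses. The table layout and sizes below are hypothetical stand-ins for sharktank's unflattened page table, not its real API; the point is that scattering each K/V partition separately leaves the compiler one small gather/scatter per partition to fuse, where the removed code first materialized a concatenation across partitions.

    import torch

    # Hypothetical sizes; the real layer derives these from its cache config.
    bs, block_seq_stride = 2, 4
    attn_head_count, attn_head_dim = 3, 5
    page_count, transformer_block_count, cache_partition_count = 8, 2, 2

    # Assumed table layout: [page, transformer_block, partition, offset, head, head_dim]
    page_table = torch.zeros(
        page_count,
        transformer_block_count,
        cache_partition_count,
        block_seq_stride,
        attn_head_count,
        attn_head_dim,
    )

    seq_positions = torch.tensor([5, 2])       # current write position per sequence
    page_ids = torch.tensor([[0, 1], [2, 3]])  # pages owned by each sequence
    transformer_block_index = 1
    cache_partitions = [                       # one [bs, 1, head, dim] tensor each for K and V
        torch.randn(bs, 1, attn_head_count, attn_head_dim)
        for _ in range(cache_partition_count)
    ]

    for idx, cache_partition in enumerate(cache_partitions):
        page_index = seq_positions // block_seq_stride                 # [bs]
        page_id = torch.gather(page_ids, 1, page_index.unsqueeze(1))   # [bs, 1]
        page_offset = (seq_positions % block_seq_stride).unsqueeze(1)  # [bs, 1]
        partitions = torch.full((bs, 1), idx)                          # [bs, 1]
        transformer_block = torch.full((bs, 1), transformer_block_index)
        indices = (page_id, transformer_block, partitions, page_offset)
        # Scatter this partition's timestep into the paged cache.
        page_table.index_put_(indices=indices, values=cache_partition)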

@@ -490,14 +485,6 @@ def write(
         page_table = self.unflatten_page_table(state)  # 6D

         bs, block_seq_len, *_ = page_ids.shape
-        # Blocks dim 1,2 according to the configured block stride.
-        blocked_shape = [
-            bs,
-            block_seq_len,
-            self.block_seq_stride,
-            self.attn_head_count,
-            self.attn_head_dim,
-        ]

         # Reshape the page cache into sub-blocks so that we can index at the
         # granularity of the transformer_block and cache partition.
@@ -513,21 +500,14 @@ def write(
             transformer_block_index * transformer_block_stride
         )

-        part_block_views = []
-        subblock_ids_kv = []
         for index, partition in enumerate(cache_partitions):
             part_block_view = partition.unflatten(
                 1, (block_seq_len, self.block_seq_stride)
             )
             part_block_view = part_block_view.flatten(0, 1)
-            part_block_views.append(part_block_view)

             subblock_ids = (
                 (base_subblock_ids + index) if index > 0 else base_subblock_ids
             ).flatten(0, 1)
-            subblock_ids_kv.append(subblock_ids)

-        subblock_ids = ops.cat(subblock_ids_kv)
-        part_block_view = ops.cat(part_block_views, dim=0)
-
-        subblock_table.index_copy_(0, subblock_ids, part_block_view)
+
+            subblock_table.index_copy_(0, subblock_ids, part_block_view)
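The same idea in the batched write path, as a minimal sketch with hypothetical sizes and subblock ids (the real code derives them from the cache config): each partition is copied into the subblock table inside the loop, instead of accumulating views and ids in Python lists and concatenating them into one large index_copy_.

    import torch

    # Hypothetical sizes; not sharktank's real subblock layout.
    bs, block_seq_len, block_seq_stride, feat = 2, 3, 4, 8
    subblock_table = torch.zeros(2 * bs * block_seq_len, block_seq_stride, feat)

    # Assumed id scheme: even rows hold partition 0 (K), odd rows partition 1 (V).
    base_subblock_ids = torch.arange(0, 2 * bs * block_seq_len, 2).reshape(bs, block_seq_len)

    cache_partitions = [
        torch.randn(bs, block_seq_len * block_seq_stride, feat) for _ in range(2)
    ]

    for index, partition in enumerate(cache_partitions):
        # [bs, seq, feat] -> [bs, blocks, stride, feat] -> [bs * blocks, stride, feat]
        part_block_view = partition.unflatten(1, (block_seq_len, block_seq_stride))
        part_block_view = part_block_view.flatten(0, 1)

        subblock_ids = (
            (base_subblock_ids + index) if index > 0 else base_subblock_ids
        ).flatten(0, 1)

        # One contiguous copy per partition; no cross-partition cat to fuse around.
        subblock_table.index_copy_(0, subblock_ids, part_block_view)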
