From 7e62c25a3f307a9a0e3191fc9ebdd875aacab1fc Mon Sep 17 00:00:00 2001
From: Daniel Garvey <34486624+dan-garvey@users.noreply.github.com>
Date: Wed, 11 Dec 2024 16:12:42 -0600
Subject: [PATCH] loop over cache_partitions to enable fusion (#677)

Co-authored-by: Rob Suderman
---
 sharktank/sharktank/layers/kv_cache.py | 52 ++++++++------------------
 1 file changed, 16 insertions(+), 36 deletions(-)

diff --git a/sharktank/sharktank/layers/kv_cache.py b/sharktank/sharktank/layers/kv_cache.py
index c73b7a8f4..46e94ff90 100644
--- a/sharktank/sharktank/layers/kv_cache.py
+++ b/sharktank/sharktank/layers/kv_cache.py
@@ -447,30 +447,25 @@ def write_timestep(
         bs, *_ = seq_positions.shape
         assert len(cache_partitions) == self.cache_partition_count
 
-        partition_count = len(cache_partitions)
+        # [bs, 1, atten_head_count, attn_head_dim]
+        for idx, cache_partition in enumerate(cache_partitions):
+            # [bs, 1]
+            page_index = seq_positions // self.block_seq_stride
 
-        # [bs, partitions, atten_head_count, attn_head_dim]
-        cache_partitions = ops.cat(cache_partitions, dim=1)
+            page_id = ops.gather(page_ids, dim=1, index=page_index.unsqueeze(1))
+            page_offset = (seq_positions % self.block_seq_stride).unsqueeze(1)
 
-        # [bs, 1]
-        page_index = seq_positions // self.block_seq_stride
+            # [1, 1]
+            partitions = torch.tensor(idx).unsqueeze(0)
 
-        page_id = ops.gather(page_ids, dim=1, index=page_index.unsqueeze(1))
-        page_offset = (seq_positions % self.block_seq_stride).unsqueeze(1)
-
-        # [1, partitions]
-        partitions = torch.arange(0, self.cache_partition_count).unsqueeze(0)
-
-        # [bs, partitions]
-        page_id = page_id.repeat(1, partition_count)
-        transformer_block = torch.full(
-            (bs, partition_count), transformer_block_index, device=device
-        )
-        page_offset = page_offset.repeat(1, partition_count)
-        partitions = partitions.repeat(bs, 1)
+            # [bs, 1]
+            transformer_block = torch.full(
+                (bs, 1), transformer_block_index, device=device
+            )
+            partitions = partitions.repeat(bs, 1)
 
-        indices = (page_id, transformer_block, partitions, page_offset)
-        page_table.index_put_(indices=indices, values=cache_partitions)
+            indices = (page_id, transformer_block, partitions, page_offset)
+            page_table.index_put_(indices=indices, values=cache_partition)
 
         return
 
@@ -490,14 +485,6 @@ def write(
         page_table = self.unflatten_page_table(state)  # 6D
 
         bs, block_seq_len, *_ = page_ids.shape
-        # Blocks dim 1,2 according to the configured block stride.
-        blocked_shape = [
-            bs,
-            block_seq_len,
-            self.block_seq_stride,
-            self.attn_head_count,
-            self.attn_head_dim,
-        ]
 
         # Reshape the page cache into sub-blocks so that we can index at the
         # granularity of the transformer_block and cache partition.
@@ -513,21 +500,14 @@ def write(
             transformer_block_index * transformer_block_stride
         )
 
-        part_block_views = []
-        subblock_ids_kv = []
         for index, partition in enumerate(cache_partitions):
             part_block_view = partition.unflatten(
                 1, (block_seq_len, self.block_seq_stride)
             )
             part_block_view = part_block_view.flatten(0, 1)
-            part_block_views.append(part_block_view)
 
             subblock_ids = (
                 (base_subblock_ids + index) if index > 0 else base_subblock_ids
             ).flatten(0, 1)
-            subblock_ids_kv.append(subblock_ids)
-
-        subblock_ids = ops.cat(subblock_ids_kv)
-        part_block_view = ops.cat(part_block_views, dim=0)
-        subblock_table.index_copy_(0, subblock_ids, part_block_view)
+            subblock_table.index_copy_(0, subblock_ids, part_block_view)
 
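
For anyone who wants to poke at the indexing pattern outside of sharktank, below is a minimal, self-contained plain-PyTorch sketch of the write path the patch moves to: one small index_put_ per cache partition instead of one large scatter over an ops.cat of all partitions. All names and shapes here (the 6D page_table layout, block_seq_stride, two partitions standing in for K and V) are illustrative assumptions that mirror the diff; this is not the sharktank API itself.

import torch

bs = 2                     # batch size
block_seq_stride = 4       # tokens stored per page block
attn_head_count = 3
attn_head_dim = 8
cache_partition_count = 2  # assumption: one partition for K, one for V
transformer_block_index = 0

# Page table laid out as in the diff's unflattened view (shapes assumed):
# [page, transformer_block, partition, block_offset, head, head_dim]
page_table = torch.zeros(
    16, 1, cache_partition_count, block_seq_stride, attn_head_count, attn_head_dim
)
page_ids = torch.arange(bs * 3).reshape(bs, 3)  # 3 pages per sequence
seq_positions = torch.tensor([5, 9])            # current token position per sequence

# One [bs, 1, attn_head_count, attn_head_dim] tensor per partition.
cache_partitions = [
    torch.randn(bs, 1, attn_head_count, attn_head_dim)
    for _ in range(cache_partition_count)
]

# The patched pattern: a separate, small index_put_ per partition. Each
# write uses a constant partition coordinate, so no concatenation of the
# partitions is materialized first.
for idx, cache_partition in enumerate(cache_partitions):
    page_index = seq_positions // block_seq_stride                    # [bs]
    page_id = torch.gather(page_ids, 1, page_index.unsqueeze(1))      # [bs, 1]
    page_offset = (seq_positions % block_seq_stride).unsqueeze(1)     # [bs, 1]
    transformer_block = torch.full((bs, 1), transformer_block_index)  # [bs, 1]
    partitions = torch.tensor(idx).unsqueeze(0).repeat(bs, 1)         # [bs, 1]

    indices = (page_id, transformer_block, partitions, page_offset)
    page_table.index_put_(indices=indices, values=cache_partition)

The trade-off is cache_partition_count scatters instead of one, but each scatter's partition coordinate becomes a compile-time constant and the producer of each partition can feed its write directly, which appears to be what enables the fusion named in the patch subject.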