diff --git a/sharktank/sharktank/export_layer/export_paged_attention.py b/sharktank/sharktank/export_layer/export_paged_attention.py index cb28371bb..186fb5154 100644 --- a/sharktank/sharktank/export_layer/export_paged_attention.py +++ b/sharktank/sharktank/export_layer/export_paged_attention.py @@ -236,7 +236,7 @@ def main(): model = PagedLlamaAttentionBlock( theta=attention_block_theta, block_index=0, - cache=create_kv_cache(llama_config), + cache=create_paged_kv_cache(llama_config), head_count=llama_config.hp.attention_head_count, head_dim=llama_config.hp.attn_head_dim, head_count_kv=llama_config.hp.attention_head_count_kv, diff --git a/sharktank/sharktank/layers/__init__.py b/sharktank/sharktank/layers/__init__.py index 3caf7631d..9842a8291 100644 --- a/sharktank/sharktank/layers/__init__.py +++ b/sharktank/sharktank/layers/__init__.py @@ -6,7 +6,7 @@ from .base import BaseLayer, ThetaLayer from .conv import Conv2DLayer -from .kv_cache import BaseKVCache, DirectKVCache, PagedKVCache +from .kv_cache import PagedKVCache from .causal_llm import BaseCausalLMModel from .linear import LinearLayer from .norm import RMSNormLayer, LayerNorm diff --git a/sharktank/sharktank/layers/kv_cache.py b/sharktank/sharktank/layers/kv_cache.py index f62002f46..6af0e5183 100644 --- a/sharktank/sharktank/layers/kv_cache.py +++ b/sharktank/sharktank/layers/kv_cache.py @@ -22,204 +22,10 @@ from ..types import SplitPrimitiveTensor, ReplicatedTensor from .. import ops -__all__ = [ - "BaseKVCache", - "DirectKVCache", - "PagedKVCache", -] +__all__ = ["PagedKVCache"] -class BaseKVCache(abc.ABC): - """Base class for a KV cache. - - This doesn't do much on its own except to serve as a type-safe base class - unifying the PagedKVCache and DirectKVCache: - - * PagedKVCache is a shared cache which can be used across an arbitrary - number of batches/sequences with random mapping of blocks within a - sequence to backing "page". - * DirectKVCache is a single-batch cache with a fixed batch size and - sequence length where the K/V cache tensors for each transformer block - are densely layed out in memory. - """ - - block_seq_stride: int - transformer_block_count: int - attn_head_count: int - attn_head_dim: int - - @property - @abc.abstractmethod - def pad_sequence_stride(self) -> int: - """Stride that a sequence must be padded to in order to be valid for - the cache. For paged caches, this will typically be a multiple of the - block_seq_stride. For direct caches it may be 1 or a multiple that - is chosen for performance reasons. - """ - ... 
- - @property - def is_paged(self) -> bool: - return isinstance(self, PagedKVCache) - - @property - def is_direct(self) -> bool: - return isinstance(self, DirectKVCache) - - @property - def paged(self) -> "PagedKVCache": - assert isinstance( - self, PagedKVCache - ), f"Attempt to access cache {type(self)} as paged but it is not" - return self - - @property - def direct(self) -> "DirectKVCache": - assert isinstance( - self, DirectKVCache - ), f"Attempt to access cache {type(self)} as direct but it is not" - return self - - -class DirectKVCache(BaseKVCache): - """KVCache for a single batch where the cache tensors are densely laid out.""" - - def __init__( - self, - *, - block_seq_stride: int, - transformer_block_count: int, - attn_head_count: int, - attn_head_dim: int, - seq_length: int, - shard_count: int = 1, - dtype: torch.dtype = torch.float32, - device: Optional[torch.device] = None, - ): - self.block_seq_stride = block_seq_stride - self.transformer_block_count = transformer_block_count - self.attn_head_count = attn_head_count - self.attn_head_dim = attn_head_dim - self.seq_length = seq_length - self.shard_count = shard_count - self.device = device - self.dtype = dtype - - @property - def pad_sequence_stride(self) -> int: - return self.block_seq_stride - - def allocate(self, *, bs: int) -> list[torch.Tensor]: - """Allocates 2*transformer_block_count K/V cache tensors for the - given batch size and sequence length. - - Each tensor has shape: [bs, sl, attn_head_count, attn_head_dim] - """ - allocations = [ - torch.empty( - [ - bs, - self.seq_length, - self.attn_head_count, - self.attn_head_dim, - ], - dtype=self.dtype, - device=self.device, - ) - for _ in range(2 * self.transformer_block_count) - ] - - if self.shard_count == 1: - return allocations - - return [ - ops.reshard_split(allocation, dim=2, count=self.shard_count) - for allocation in allocations - ] - - def read( - self, - state: list[Union[torch.Tensor, SplitPrimitiveTensor]], - *, - read_into_partitions: list[Union[torch.Tensor, SplitPrimitiveTensor]], - transformer_block_index: int, - seq_len: int, - page_ids: Optional[Union[torch.Tensor, ReplicatedTensor]] = None, - ): - """Reads cache partitions from the page table for the given page_ids. - - Args: - state: State struct as returned from allocate(). - read_into_partitions: List of cache partitions to read into in-place. - transformer_block_index: The index of the transformer block accessing - the cache. - page_ids: Tensor of [bs, max_seqlen // block_pos_stride] of page ids - to access. - - Returns a tuple of cache partitions (i.e. k and v caches for the transformer - block), linearized. Note that this reference approach to reading by - materializing linearly may not be terribly efficient unless if the - compiler can fuse the gather. - """ - read_count = len(read_into_partitions) - reads = [] - for i in range(read_count): - reads.append( - state[transformer_block_index * read_count + i][:, :seq_len, :, :] - ) - - return tuple(reads) - - def write_timestep( - self, - state: list[Union[torch.Tensor, SplitPrimitiveTensor]], - # List of [bs, 1, attn_head_count, attn_head_dim] - cache_partitions: list[Union[torch.Tensor, SplitPrimitiveTensor]], - *, - transformer_block_index: int, - # [bs] - seq_positions: Union[torch.Tensor, ReplicatedTensor], - # [bs, max_seqlen // block_pos_stride] - page_ids: Optional[Union[torch.Tensor, ReplicatedTensor]] = None, - ): - """Writes a single batched timestep across all cache partitions. 
- - Note that this internally loops over the batch size, which cannot be - dynamic. - """ - bs, _, _, _ = cache_partitions[0].shape - update_count = len(cache_partitions) - - for b in range(bs): - row_index = torch.tensor([b], dtype=torch.int64) - row_start_pos = seq_positions[row_index].unsqueeze(0) - - for i, update in enumerate(cache_partitions): - cache = state[transformer_block_index * update_count + i] - cache.index_put_((row_index, row_start_pos), update[row_index, 0]) - - def write( - self, - state: list[Union[torch.Tensor, SplitPrimitiveTensor]], - cache_partitions: list[Union[torch.Tensor, SplitPrimitiveTensor]], - *, - transformer_block_index: int, - page_ids: Optional[Union[torch.Tensor, ReplicatedTensor]] = None, - ): - """Writes cache partitions from a linear layout to the page table. - - This is the inverse of the linear read. The same caveat applies if the - in-place scatter cannot be fused. - """ - update_count = len(cache_partitions) - - for idx, update_src in enumerate(cache_partitions): - cache_dest = state[transformer_block_index * update_count + idx] - _, batch_seq_len, _, _ = update_src.shape - cache_dest[:, :batch_seq_len, :, :] = update_src - - -class PagedKVCache(BaseKVCache): +class PagedKVCache: """Implementation of a KV cache on top of a 'page table'. The page table slab is physically represented as a 2D tensor: diff --git a/sharktank/sharktank/layers/llama_attention_block.py b/sharktank/sharktank/layers/llama_attention_block.py index 0cdb5d713..38cb7c0fa 100644 --- a/sharktank/sharktank/layers/llama_attention_block.py +++ b/sharktank/sharktank/layers/llama_attention_block.py @@ -5,6 +5,7 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from typing import Optional +import math import torch import torch.nn.functional as F @@ -29,7 +30,6 @@ def __init__( head_count: int, head_dim: int, head_count_kv: int, - embedding: RotaryEmbeddingLayer, rms_epsilon: float, ): super().__init__(theta) @@ -41,7 +41,6 @@ def __init__( self.add_module("attn_v", LinearLayer(theta("attn_v"))) self.add_module("attn_output", LinearLayer(theta("attn_output"))) - self.embedding = embedding self.head_count = head_count self.head_dim = head_dim self.head_count_kv = head_count_kv @@ -50,6 +49,7 @@ def forward( self, h: torch.Tensor, *, + embedding: RotaryEmbeddingLayer, cache_k: torch.Tensor, cache_v: torch.Tensor, start_index: int, @@ -72,11 +72,11 @@ def forward( # Fast path to start_index based embedding lookup if available. # Falls back to a slower position based index lookup. if start_index is not None: - xq, xk = embedding.forward(xq=xq, xk=xk, start_index=start_index) + xq = embedding.forward(xt=xq, start_index=start_index) + xk = embedding.forward(xt=xk, start_index=start_index) else: - xq, xk = embedding.apply_batched_mask( - xq=xq, xk=xk, mask=embedding_batch_mask - ) + xq = embedding.apply_batched_mask(xt=xq, mask=embedding_batch_mask) + xk = embedding.apply_batched_mask(xt=xk, mask=embedding_batch_mask) # Expand kv heads for GQA. gqa_n_rep = self.head_count // self.head_count_kv @@ -108,9 +108,7 @@ def repeat_kv(x: torch.Tensor) -> torch.Tensor: values = values.transpose(1, 2) # Flash attention. - attn_weights = torch.matmul(xq, keys.transpose(2, 3)) / torch.sqrt( - self.head_dim - ) + attn_weights = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim) # Apply attention mask. 
if attention_mask is not None: diff --git a/sharktank/sharktank/models/grok/grok.py b/sharktank/sharktank/models/grok/grok.py index 077e4e064..0be9ede05 100644 --- a/sharktank/sharktank/models/grok/grok.py +++ b/sharktank/sharktank/models/grok/grok.py @@ -59,7 +59,7 @@ def __init__(self, theta: Theta, config: LlamaModelConfig): ) self.config = config self.hp = hp - self.cache = create_kv_cache(self.config) + self.cache = create_paged_kv_cache(self.config) self.activation_dtype = config.activation_dtype self.add_module( "token_embedding", diff --git a/sharktank/sharktank/models/llama/llama.py b/sharktank/sharktank/models/llama/llama.py index 6fef6704e..0bb6985e7 100644 --- a/sharktank/sharktank/models/llama/llama.py +++ b/sharktank/sharktank/models/llama/llama.py @@ -74,7 +74,7 @@ def __init__(self, theta: Theta, config: LlamaModelConfig): ) self.config = config self.hp = hp - self.cache = create_kv_cache(self.config) + self.cache = create_paged_kv_cache(self.config) self.activation_dtype = config.activation_dtype self.use_hf = config.use_hf self.attention_kernel = config.attention_kernel diff --git a/sharktank/sharktank/models/mixtral/mixtral.py b/sharktank/sharktank/models/mixtral/mixtral.py index e2995dfde..b597c6e99 100644 --- a/sharktank/sharktank/models/mixtral/mixtral.py +++ b/sharktank/sharktank/models/mixtral/mixtral.py @@ -61,7 +61,7 @@ def __init__(self, theta: Theta, config: LlamaModelConfig): ) self.config = config self.hp = hp - self.cache = create_kv_cache(self.config) + self.cache = create_paged_kv_cache(self.config) self.activation_dtype = config.activation_dtype self.add_module( "token_embedding", diff --git a/sharktank/sharktank/utils/create_cache.py b/sharktank/sharktank/utils/create_cache.py index c1691c8a8..f462d9c00 100644 --- a/sharktank/sharktank/utils/create_cache.py +++ b/sharktank/sharktank/utils/create_cache.py @@ -7,28 +7,18 @@ from ..layers import * -def create_kv_cache(config: LlamaModelConfig) -> BaseKVCache: +def create_paged_kv_cache(config: LlamaModelConfig) -> PagedKVCache: + if config.kv_cache_type != "paged": + raise ValueError("Model does not use paged kv cache, cannot create kv cache") + hp = config.hp - if config.kv_cache_type == "direct": - return DirectKVCache( - block_seq_stride=config.block_seq_stride, - transformer_block_count=hp.block_count, - attn_head_count=hp.attention_head_count_kv, - attn_head_dim=hp.attn_head_dim, - seq_length=hp.context_length, - device=config.device, - dtype=config.attention_dtype, - ) - elif config.kv_cache_type == "paged": - return PagedKVCache( - transformer_block_count=hp.block_count, - attn_head_count=hp.attention_head_count_kv, - attn_head_dim=hp.attn_head_dim, - cache_partition_count=2, # One for each of K/V. - block_seq_stride=config.block_seq_stride, - device=config.device, - dtype=config.attention_dtype, - shard_count=config.tensor_parallelism_size, - ) - else: - raise NotImplementedError(f"kv_cache_type = {config.kv_cache_type}") + return PagedKVCache( + transformer_block_count=hp.block_count, + attn_head_count=hp.attention_head_count_kv, + attn_head_dim=hp.attn_head_dim, + cache_partition_count=2, # One for each of K/V. 
+ block_seq_stride=config.block_seq_stride, + device=config.device, + dtype=config.attention_dtype, + shard_count=config.tensor_parallelism_size, + ) diff --git a/sharktank/tests/layers/kv_cache_test.py b/sharktank/tests/layers/kv_cache_test.py index 65b42c986..8512c8768 100644 --- a/sharktank/tests/layers/kv_cache_test.py +++ b/sharktank/tests/layers/kv_cache_test.py @@ -13,230 +13,6 @@ from sharktank.types import * -def test_direct(): - bs = 4 - seq_length = 24 - attn_head_count = 4 - attn_head_dim = 16 - transformer_block_count = 4 - cache = DirectKVCache( - block_seq_stride=4, - transformer_block_count=transformer_block_count, - attn_head_count=attn_head_count, - attn_head_dim=attn_head_dim, - seq_length=seq_length, - dtype=torch.float32, - device=None, - ) - - allocation = cache.allocate(bs=bs) - allocation = [torch.full(t.shape, 0.0, out=t) for t in allocation] - - write_seq_length = seq_length - 5 - - # Write a prefill in: - write_ones = torch.full( - (bs, write_seq_length, attn_head_count, attn_head_dim), 1.0, dtype=torch.float32 - ) - write_twos = torch.full( - (bs, write_seq_length, attn_head_count, attn_head_dim), 2.0, dtype=torch.float32 - ) - cache.write( - allocation, cache_partitions=[write_ones, write_twos], transformer_block_index=1 - ) - - # Check the written values have updated: - read_empty = [ - torch.empty( - (bs, write_seq_length, attn_head_count, attn_head_dim), dtype=torch.float32 - ), - torch.empty( - (bs, write_seq_length, attn_head_count, attn_head_dim), dtype=torch.float32 - ), - ] - read_back = cache.read( - allocation, - read_into_partitions=read_empty, - transformer_block_index=1, - seq_len=write_seq_length, - ) - torch.testing.assert_close(write_ones, read_back[0]) - torch.testing.assert_close(write_twos, read_back[1]) - - # Check the others are still zero: - for i in range(transformer_block_count): - if i == 1: - continue - read_ones = [ - torch.zeros( - (bs, write_seq_length, attn_head_count, attn_head_dim), - dtype=torch.float32, - ), - torch.zeros( - (bs, write_seq_length, attn_head_count, attn_head_dim), - dtype=torch.float32, - ), - ] - read_ones = cache.read( - allocation, - read_into_partitions=read_ones, - transformer_block_index=i, - seq_len=write_seq_length, - ) - torch.testing.assert_close(read_ones[0], torch.full(read_ones[0].shape, 0.0)) - torch.testing.assert_close(read_ones[1], torch.full(read_ones[0].shape, 0.0)) - - # Write timestep - write_threes = torch.full( - (bs, 1, attn_head_count, attn_head_dim), 3.0, dtype=torch.float32 - ) - write_fours = torch.full( - (bs, 1, attn_head_count, attn_head_dim), 4.0, dtype=torch.float32 - ) - write_pos = torch.full((bs,), write_seq_length, dtype=torch.int64) - cache.write_timestep( - allocation, - cache_partitions=[write_threes, write_fours], - transformer_block_index=1, - seq_positions=write_pos, - ) - - read_empty = [ - torch.zeros( - (bs, write_seq_length + 1, attn_head_count, attn_head_dim), - dtype=torch.float32, - ), - torch.zeros( - (bs, write_seq_length + 1, attn_head_count, attn_head_dim), - dtype=torch.float32, - ), - ] - read_back = cache.read( - allocation, - read_into_partitions=read_empty, - transformer_block_index=1, - seq_len=write_seq_length + 1, - ) - - check_concat_0 = torch.concat([write_ones, write_threes], dim=1) - check_concat_1 = torch.concat([write_twos, write_fours], dim=1) - - torch.testing.assert_close(check_concat_0, read_back[0]) - torch.testing.assert_close(check_concat_1, read_back[1]) - - -def test_sharded_direct(): - bs = 4 - seq_length = 24 - attn_head_count = 8 - 
attn_head_dim = 16 - transformer_block_count = 4 - shard_count = 4 - cache = DirectKVCache( - block_seq_stride=4, - transformer_block_count=transformer_block_count, - attn_head_count=attn_head_count, - attn_head_dim=attn_head_dim, - seq_length=seq_length, - shard_count=shard_count, - dtype=torch.float32, - device=None, - ) - - allocation = cache.allocate(bs=bs) - # allocation = [torch.full(t.shape, 0.0, out=t) for t in allocation] - - write_seq_length = seq_length - 5 - - # Write a prefill in: - write_ones = reshard_split( - torch.full( - (bs, write_seq_length, attn_head_count, attn_head_dim), - 1.0, - dtype=torch.float32, - ), - dim=2, - count=shard_count, - ) - - write_twos = reshard_split( - torch.full( - (bs, write_seq_length, attn_head_count, attn_head_dim), - 2.0, - dtype=torch.float32, - ), - dim=2, - count=shard_count, - ) - - cache.write( - allocation, cache_partitions=[write_ones, write_twos], transformer_block_index=1 - ) - - # Check the written values have updated: - read_empty = [ - torch.empty( - (bs, write_seq_length, attn_head_count, attn_head_dim), dtype=torch.float32 - ), - torch.empty( - (bs, write_seq_length, attn_head_count, attn_head_dim), dtype=torch.float32 - ), - ] - read_back = cache.read( - allocation, - read_into_partitions=read_empty, - transformer_block_index=1, - seq_len=write_seq_length, - ) - torch.testing.assert_close(unshard(write_ones), unshard(read_back[0])) - torch.testing.assert_close(unshard(write_twos), unshard(read_back[1])) - - # Write timestep - write_threes = reshard_split( - torch.full((bs, 1, attn_head_count, attn_head_dim), 3.0, dtype=torch.float32), - dim=2, - count=shard_count, - ) - write_fours = reshard_split( - torch.full((bs, 1, attn_head_count, attn_head_dim), 4.0, dtype=torch.float32), - dim=2, - count=shard_count, - ) - - write_pos = replicate( - torch.full((bs,), write_seq_length, dtype=torch.int64), shard_count - ) - cache.write_timestep( - allocation, - cache_partitions=[write_threes, write_fours], - transformer_block_index=1, - seq_positions=write_pos, - ) - - read_empty = [ - torch.zeros( - (bs, write_seq_length + 1, attn_head_count, attn_head_dim), - dtype=torch.float32, - ), - torch.zeros( - (bs, write_seq_length + 1, attn_head_count, attn_head_dim), - dtype=torch.float32, - ), - ] - read_back = cache.read( - allocation, - read_into_partitions=read_empty, - transformer_block_index=1, - seq_len=write_seq_length + 1, - ) - - check_concat_0 = torch.concat([unshard(write_ones), unshard(write_threes)], dim=1) - check_concat_1 = torch.concat([unshard(write_twos), unshard(write_fours)], dim=1) - - torch.testing.assert_close(check_concat_0, unshard(read_back[0])) - torch.testing.assert_close(check_concat_1, unshard(read_back[1])) - - def test_paged(): bs = 4 seq_length = 24 diff --git a/sharktank/tests/layers/paged_llama_attention_block_test.py b/sharktank/tests/layers/paged_llama_attention_block_test.py index bbb52f235..e74a14ad5 100644 --- a/sharktank/tests/layers/paged_llama_attention_block_test.py +++ b/sharktank/tests/layers/paged_llama_attention_block_test.py @@ -57,7 +57,7 @@ def testExportDecomposed(self): dtype=dtype, ) - cache_state = cache.paged.allocate(self.page_count) + cache_state = cache.allocate(self.page_count) cache_state[0] = torch.rand(cache_state[0].shape, dtype=dtype) theta = make_llama_attention_block_theta( @@ -140,7 +140,7 @@ def testExportNondecomposed(self): dtype=dtype, ) - cache_state = cache.paged.allocate(self.page_count) + cache_state = cache.allocate(self.page_count) cache_state[0] = 
torch.rand(cache_state[0].shape, dtype=dtype) theta = make_llama_attention_block_theta( diff --git a/sharktank/tests/models/llama/attention_test.py b/sharktank/tests/models/llama/attention_test.py index 211fab5a0..22013635b 100644 --- a/sharktank/tests/models/llama/attention_test.py +++ b/sharktank/tests/models/llama/attention_test.py @@ -77,7 +77,7 @@ def test(self): input_tensor, embedding=attention_embedding, start_index=0, - cache_state=paged_kv_cache.paged.allocate(128), + cache_state=paged_kv_cache.allocate(128), seq_block_ids=torch.arange(seq_len).view(1, -1), ) diff --git a/sharktank/tests/models/llama/kv_cache_test.py b/sharktank/tests/models/llama/kv_cache_test.py index a80575951..3d43243b0 100644 --- a/sharktank/tests/models/llama/kv_cache_test.py +++ b/sharktank/tests/models/llama/kv_cache_test.py @@ -8,9 +8,9 @@ import torch import torch.nn as nn from sharktank.models.llama.llama import ( + LlamaAttentionBlock, PagedLlamaAttentionBlock, PagedKVCache, - DirectKVCache, ) from sharktank.models.llama.testing import * from sharktank.layers.rotary_embedding import RotaryEmbeddingLayer @@ -48,15 +48,14 @@ def setUp(self): device=self.device, dtype=self.attention_dtype, ) - self.direct_kv_cache = DirectKVCache( - block_seq_stride=self.block_seq_stride, - transformer_block_count=self.head_count, - attn_head_count=self.head_count, - attn_head_dim=self.head_dim, - seq_length=self.max_seq_len, - device=self.device, - dtype=self.attention_dtype, - ) + self.direct_k_cache = [ + torch.empty([self.bs, self.max_seq_len, self.head_count_kv, self.head_dim]) + for _ in range(self.block_count) + ] + self.direct_v_cache = [ + torch.empty([self.bs, self.max_seq_len, self.head_count_kv, self.head_dim]) + for _ in range(self.block_count) + ] self.attention_embedding = RotaryEmbeddingLayer( rope_dimension_count=self.rope_dimension_count, rope_freq_base=self.rope_freq_base, @@ -80,10 +79,8 @@ def setUp(self): ) self.direct_attn_blocks = nn.ModuleList( [ - PagedLlamaAttentionBlock( + LlamaAttentionBlock( theta=self.attention_block_theta, - block_index=n, - cache=self.direct_kv_cache, head_count=self.head_count, head_dim=self.head_dim, head_count_kv=self.head_count_kv, @@ -98,12 +95,6 @@ def setUp(self): [127], ] ) - self.direct_cache_state = self.direct_kv_cache.allocate(bs=1) - self.direct_seq_block_ids = torch.tensor( - [ - [0], - ] - ) self.embedding_batch_mask = self.attention_embedding.compute_batch_mask( self.start_positions, batch_seq_len=1 ) @@ -139,8 +130,8 @@ def testDirectAndPagedKVCachePrefill(self): embedding=self.attention_embedding, start_index=0, attention_mask=self.prefill_attention_mask, - cache_state=self.direct_cache_state, - seq_block_ids=self.direct_seq_block_ids, + cache_k=self.direct_k_cache[block_idx], + cache_v=self.direct_v_cache[block_idx], ) page_table = self.paged_kv_cache.unflatten_page_table(self.paged_cache_state) index_written = self.start_positions.item() @@ -150,13 +141,13 @@ def testDirectAndPagedKVCachePrefill(self): """ page_id = self.paged_seq_block_ids[0][0].item() """ - direct_cache_state is a list of num_transformer_blocks * 2 (one for K and one for V), - so here we index into the first transformer block's keys with self.direct_cache_state[0] - and the first transformer block's values with self.direct_cache_state[1]. 
Each row + direct_cache_state is a list of num_transformer_blocks (one for K and one for V), + so here we index into the first transformer block's keys with self.direct_k_cache[0] + and the first transformer block's values with self.direct_v_cache[0]. Each row in direct_cache_state is a tensor of [bs, seq_len , attn_heads, attn_dim], so we make sure the first 8 (start_position) tensors starting at sequence 0 of the seq_len are written to. """ - updated_direct_cache_state = self.direct_cache_state[0][ + updated_direct_k_cache_state = self.direct_k_cache[0][ :, :index_written ].squeeze(0) """ @@ -172,10 +163,10 @@ def testDirectAndPagedKVCachePrefill(self): first transformer block's K cache for the first 8 (start_positions) tensors starting at sequence 0. """ - updated_paged_cache_state = page_table[page_id][0, 0, :index_written] - assert updated_direct_cache_state.shape == updated_paged_cache_state.shape + updated_paged_k_cache_state = page_table[page_id][0, 0, :index_written] + assert updated_direct_k_cache_state.shape == updated_paged_k_cache_state.shape torch.testing.assert_close( - updated_direct_cache_state, updated_paged_cache_state + updated_direct_k_cache_state, updated_paged_k_cache_state ) paged_prefill_attn_output = paged_input_tensor @@ -246,20 +237,18 @@ def testDirectAndPagedKVCacheDecode(self): embedding=self.attention_embedding, embedding_batch_mask=self.embedding_batch_mask, attention_mask=decode_attention_mask, - cache_state=self.direct_cache_state, - seq_block_ids=self.direct_seq_block_ids, - xk_temp=xk_temp, - xv_temp=xv_temp, + cache_k=self.direct_k_cache, + cache_v=self.direct_v_cache, ) page_table = self.paged_kv_cache.unflatten_page_table(self.paged_cache_state) index_written = self.start_positions.item() page_id = self.paged_seq_block_ids[0][0].item() - updated_direct_cache_state_keys = self.direct_cache_state[0][ + updated_direct_cache_state_keys = self.direct_k_cache[0][ :, index_written ].squeeze(0) updated_paged_cache_state_keys = page_table[page_id][0, 0, index_written] - updated_direct_cache_state_values = self.direct_cache_state[1][ + updated_direct_cache_state_values = self.direct_v_cache[0][ :, index_written ].squeeze(0) updated_paged_cache_state_values = page_table[page_id][0, 1, index_written] diff --git a/sharktank/tests/models/llama/sharded_llama_test.py b/sharktank/tests/models/llama/sharded_llama_test.py index 386061731..e78be1cbe 100644 --- a/sharktank/tests/models/llama/sharded_llama_test.py +++ b/sharktank/tests/models/llama/sharded_llama_test.py @@ -94,7 +94,7 @@ def make_prefill_args(self, model: PagedLlamaModelV1) -> OrderedDict[str, Any]: seq_block_ids = torch.arange( self.batch_size * batch_seq_len // self.config.block_seq_stride ).view(self.batch_size, -1) - cache_state = model.cache.paged.allocate(page_count=self.cache_page_count) + cache_state = model.cache.allocate(page_count=self.cache_page_count) cache_state = [torch.rand_like(cache_state[0])] return OrderedDict( [ @@ -109,14 +109,14 @@ def make_equal_unsharded_and_sharded_prefill_args( self, model: PagedLlamaModelV1, sharded_model: PagedLlamaModelV1 ) -> Tuple[OrderedDict[str, Any], OrderedDict[str, Any]]: prefill_kwargs = self.make_prefill_args(model) - sharded_cache_state = sharded_model.cache.paged.allocate( + sharded_cache_state = sharded_model.cache.allocate( page_count=self.cache_page_count ) assert iterables_equal( prefill_kwargs["cache_state"][0].shape, sharded_cache_state[0].shape ) sharded_prefill_kwargs = deepcopy(prefill_kwargs) - sharded_cache_state = 
sharded_model.cache.paged.shard_state( + sharded_cache_state = sharded_model.cache.shard_state( sharded_prefill_kwargs["cache_state"] ) sharded_prefill_kwargs["cache_state"] = sharded_cache_state @@ -149,7 +149,7 @@ def make_decode_args(self, model: PagedLlamaModelV1) -> OrderedDict[str, Any]: seq_block_ids = torch.arange( self.batch_size * batch_seq_len // self.config.block_seq_stride ).view(self.batch_size, -1) - cache_state = model.cache.paged.allocate(page_count=self.cache_page_count) + cache_state = model.cache.allocate(page_count=self.cache_page_count) cache_state = [torch.rand_like(cache_state[0])] return OrderedDict( [ @@ -166,7 +166,7 @@ def make_equal_unsharded_and_sharded_decode_args( ) -> Tuple[OrderedDict[str, Any], OrderedDict[str, Any]]: decode_kwargs = self.make_decode_args(model) sharded_decode_kwargs = deepcopy(decode_kwargs) - sharded_decode_kwargs["cache_state"] = sharded_model.cache.paged.shard_state( + sharded_decode_kwargs["cache_state"] = sharded_model.cache.shard_state( sharded_decode_kwargs["cache_state"] ) @@ -203,7 +203,7 @@ def testCompareToySizedModelToUnsharded(self): ) expected_cache_state = prefill_kwargs["cache_state"][0] actual_cache_state = ops.unshard( - sharded_model.cache.paged.unflatten_page_table( + sharded_model.cache.unflatten_page_table( sharded_prefill_kwargs["cache_state"] ) ).flatten(start_dim=1) @@ -224,7 +224,7 @@ def testCompareToySizedModelToUnsharded(self): ) expected_decode_cache_state = decode_kwargs["cache_state"][0] actual_decode_cache_state = ops.unshard( - sharded_model.cache.paged.unflatten_page_table( + sharded_model.cache.unflatten_page_table( sharded_decode_kwargs["cache_state"] ) ).flatten(start_dim=1)
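
Usage after this change, as a minimal sketch: the `config` object and the page count
below are illustrative placeholders (assuming a LlamaModelConfig with
kv_cache_type="paged"); only the function and method names come from the patch itself.

    from sharktank.utils.create_cache import create_paged_kv_cache

    # create_paged_kv_cache() replaces create_kv_cache(); a config whose
    # kv_cache_type is not "paged" now raises ValueError instead of
    # constructing the removed DirectKVCache.
    cache = create_paged_kv_cache(config)

    # Call sites drop the old .paged accessor and use the PagedKVCache
    # instance directly, e.g. to allocate the page-table state.
    cache_state = cache.allocate(page_count=128)

    # For tensor-parallel configurations the state is sharded through the
    # cache object as well (previously cache.paged.shard_state).
    cache_state = cache.shard_state(cache_state)

    # The rotary embedding is likewise no longer stored on
    # LlamaAttentionBlock; it is passed to forward() per call and applied to
    # the query and key projections separately:
    #   xq = embedding.forward(xt=xq, start_index=start_index)
    #   xk = embedding.forward(xt=xk, start_index=start_index)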