
Commit

add attention_sink.py
add KVCacheWithAttentionSink

Pull Request resolved: #6579

This PR adds `KVCacheWithAttentionSink`, which is required for `AttentionSink`. It keeps the first `sink_size` tokens as attention sinks and maintains a sliding window with `window_size` for new tokens.
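
A minimal sketch of how the new cache might be wired up in eager mode. The sizes and the `ModelArgs`/`RopeWithAttentionSink` constructor arguments below are illustrative assumptions, not taken from this PR; only the `KVCacheWithAttentionSink` and `evict_tokens` signatures come from the diff.

```python
import torch

from executorch.examples.models.llama.llama_transformer import ModelArgs
from executorch.examples.models.llama.source_transformation.attention_sink import (
    KVCacheWithAttentionSink,
    RopeWithAttentionSink,
)

# Assumed: ModelArgs defaults plus these flags are enough for eager-mode experiments.
params = ModelArgs(use_kv_cache=True, enable_dynamic_shape=True)
# Assumed constructor signature for RopeWithAttentionSink; see the full file for the real one.
rope = RopeWithAttentionSink(params=params)

kv_cache = KVCacheWithAttentionSink(
    n_heads=params.n_heads,
    head_dim=params.dim // params.n_heads,
    transpose_cache=False,
    enable_dynamic_shape=params.enable_dynamic_shape,
    rope=rope,
    window_size=124,        # sliding window for recent tokens (illustrative)
    sink_size=4,            # first tokens pinned as attention sinks (illustrative)
    eviction_batch_size=1,
)

# Before writing tokens starting at `input_pos`, evict if needed and shift the
# incoming positions by the returned (non-positive) offset.
input_pos = torch.tensor([130], dtype=torch.int64)
shift = kv_cache.evict_tokens(input_pos, seq_len=1)
adjusted_pos = input_pos + shift
```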

Note: I am trying to implement and verify `AttentionSink` in eager mode first, so the current implementation may still have some lowering errors. These will be resolved when we are ready to deploy `AttentionSink` to edge.
ghstack-source-id: 255715047
@exported-using-ghexport

Differential Revision: [D65235798](https://our.internmc.facebook.com/intern/diff/D65235798/)

Co-authored-by: Lunwen He <[email protected]>
pytorchbot and helunwencser authored Dec 2, 2024
1 parent 2d499b3 commit 9d084c4
Showing 2 changed files with 586 additions and 8 deletions.
121 changes: 120 additions & 1 deletion examples/models/llama/source_transformation/attention_sink.py
@@ -11,7 +11,7 @@

import torch

-from executorch.examples.models.llama.llama_transformer import ModelArgs, Rope
+from executorch.examples.models.llama.llama_transformer import KVCache, ModelArgs, Rope
from executorch.examples.models.llama.rope import (
apply_rotary_emb_to_k,
hf_apply_rotary_emb_to_k,
@@ -87,3 +87,122 @@ def rerotate_k(
)

return self.apply_rotary_emb_to_k(k, rerotation_cos, rerotation_sin)


class KVCacheWithAttentionSink(KVCache):
    """
    KV cache that supports attention sink. It keeps the initial few tokens as attention sinks.
    For other tokens, it uses a sliding window to keep the most recent tokens.

    Parameters:
        window_size: the size of the sliding window
        sink_size: the number of initial tokens to keep as attention sinks
        eviction_batch_size: the number of tokens to evict in a batch when there is not enough space in the KV cache
    """

    def __init__(
        self,
        n_heads: int,
        head_dim: int,
        transpose_cache: bool,
        enable_dynamic_shape: bool,
        rope: RopeWithAttentionSink,
        window_size: int,
        sink_size: int,
        eviction_batch_size: int,
        max_batch_size: int = 1,
        dtype=torch.float32,
    ):
        super().__init__(
            max_batch_size=max_batch_size,
            max_seq_length=window_size + sink_size,
            n_heads=n_heads,
            head_dim=head_dim,
            transpose_cache=transpose_cache,
            enable_dynamic_shape=enable_dynamic_shape,
            dtype=dtype,
        )
        self.rope = rope
        self.window_size = window_size
        self.sink_size = sink_size
        self.eviction_batch_size = eviction_batch_size
        self.position_shift = 0

    def evict_tokens(self, input_pos: torch.Tensor, seq_len: int) -> int:
        """
        Evict old tokens from the cache to make room for new tokens.

        Parameters:
            input_pos: the start position of the incoming tokens in the actual sequence
            seq_len: the length of the incoming sequence

        Returns:
            the updated position shift (a non-positive number) to apply to the
            positions of incoming tokens
        """
        input_pos_item = input_pos.item()
        torch._check_is_size(input_pos_item)
        if input_pos_item + self.position_shift + seq_len > self.max_seq_length:
            # There is not enough space in the cache to store the new tokens.
            # We need to evict some old tokens and shift some recent tokens.
            num_to_evict = max(
                input_pos_item + self.position_shift - self.max_seq_length + seq_len,
                self.eviction_batch_size,
            )
            num_to_keep = (
                input_pos_item + self.position_shift - self.sink_size - num_to_evict
            )
            num_empty_space = self.window_size - num_to_keep
            dim_to_slice = 2 if self.transpose_cache else 1
            k_to_keep = self.k_cache.narrow(
                dim_to_slice,
                self.sink_size + num_to_evict,  # pyre-ignore [6]
                num_to_keep,  # pyre-ignore [6]
            )
            # Re-rotate the kept keys so their RoPE embeddings match their new positions.
            if self.transpose_cache:
                k_to_keep = self.rope.rerotate_k(
                    k=k_to_keep.transpose(1, 2),
                    original_position=(  # pyre-ignore [6]
                        self.sink_size + num_to_evict
                    ),
                    new_position=self.sink_size,
                ).transpose(1, 2)
            else:
                k_to_keep = self.rope.rerotate_k(
                    k=k_to_keep,
                    original_position=(  # pyre-ignore [6]
                        self.sink_size + num_to_evict
                    ),
                    new_position=self.sink_size,
                )
            self.k_cache = torch.cat(
                [
                    self.k_cache.narrow(dim_to_slice, 0, self.sink_size),
                    k_to_keep,
                    torch.zeros_like(
                        self.k_cache.narrow(
                            dim_to_slice, 0, num_empty_space  # pyre-ignore [6]
                        )
                    ),
                ],
                dim=dim_to_slice,
            )
            self.v_cache = torch.cat(
                [
                    self.v_cache.narrow(dim_to_slice, 0, self.sink_size),
                    self.v_cache.narrow(
                        dim_to_slice,
                        self.sink_size + num_to_evict,  # pyre-ignore [6]
                        num_to_keep,  # pyre-ignore [6]
                    ),
                    torch.zeros_like(
                        self.v_cache.narrow(
                            dim_to_slice, 0, num_empty_space  # pyre-ignore [6]
                        )
                    ),
                ],
                dim=dim_to_slice,
            )
            self.position_shift -= num_to_evict  # pyre-ignore [8]
        return self.position_shift
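
To make the eviction bookkeeping concrete, here is a standalone worked example that mirrors the arithmetic in `evict_tokens` with made-up sizes; it does not call the class itself.

```python
# Made-up configuration: 4 sink tokens plus a 124-token sliding window.
sink_size = 4
window_size = 124
max_seq_length = sink_size + window_size  # 128, as set in __init__
eviction_batch_size = 8

position_shift = 0   # no evictions have happened yet
input_pos = 126      # next token would land at position 126 of the real sequence
seq_len = 3          # three new tokens arriving

if input_pos + position_shift + seq_len > max_seq_length:  # 129 > 128 -> evict
    # Evict at least enough to fit the new tokens, but never less than a batch.
    num_to_evict = max(
        input_pos + position_shift - max_seq_length + seq_len,  # 1
        eviction_batch_size,                                    # 8
    )                                                           # -> 8
    # Recent tokens that survive and slide toward the sink region.
    num_to_keep = input_pos + position_shift - sink_size - num_to_evict  # 114
    # Slots left zeroed at the end of the cache after the shift.
    num_empty_space = window_size - num_to_keep                          # 10
    position_shift -= num_to_evict                                       # -8

# The caller writes the new tokens at input_pos + position_shift = 118, so the
# cache holds sink (4) + kept (114) + new (3) = 121 of its 128 slots.
```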
