Skip to content

Commit

Permalink
Implement get_freqs for RopeWithAttentionSink
Browse files Browse the repository at this point in the history
This PR implements the `get_freqs` function for `RopeWithAttentionSink`. It returns the `freqs_cos` and `freqs_sin` for given `input_pos` and `seq_len` after shifting tokens in the pre-computed `freqs_cos` and `freq_sin`.

Differential Revision: [D66525306](https://our.internmc.facebook.com/intern/diff/D66525306/)

ghstack-source-id: 255582545
Pull Request resolved: #7100
  • Loading branch information
helunwencser authored and kirklandsign committed Nov 27, 2024
1 parent 6b73841 commit 2590823
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 3 deletions.
29 changes: 28 additions & 1 deletion examples/models/llama/source_transformation/attention_sink.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
# Components for supporting Attention Sink. See
# https://arxiv.org/abs/2309.17453 for more details about Attention Sink.

from typing import Optional

import torch

from executorch.examples.models.llama.llama_transformer import ModelArgs, Rope
Expand All @@ -23,12 +25,37 @@ class RopeWithAttentionSink(Rope):
in KVCache instead of positions in the actual text.
"""

def __init__(self, params: ModelArgs):
def __init__(
self,
params: ModelArgs,
window_size: int,
sink_size: int,
eviction_batch_size: int,
):
super().__init__(params)
if self.params.use_hf_rope:
self.apply_rotary_emb_to_k = hf_apply_rotary_emb_to_k
else:
self.apply_rotary_emb_to_k = apply_rotary_emb_to_k
self.max_seq_length = window_size + sink_size
assert self.max_seq_length == self.params.max_seq_len
self.eviction_batch_size = eviction_batch_size
self.position_shift = 0

def get_freqs(self, input_pos: Optional[torch.Tensor], seq_len: int):
assert input_pos is not None

input_pos_item = input_pos.item()
torch._check_is_size(input_pos_item)
if input_pos_item + self.position_shift + seq_len > self.max_seq_length:
# There are not enough spaces in the cache to store the new tokens.
# We need to evict some old tokens and shift some recent tokens.
num_to_evict = max(
input_pos_item + self.position_shift - self.max_seq_length + seq_len,
self.eviction_batch_size,
)
self.position_shift -= num_to_evict # pyre-ignore [8]
return super().get_freqs(input_pos + self.position_shift, seq_len)

def rerotate_k(
self,
Expand Down
51 changes: 49 additions & 2 deletions examples/models/llama/source_transformation/test_attention_sink.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,57 @@

class RopeWithAttentionSinkTest(unittest.TestCase):

def _init_rope(self, params: ModelArgs, eviction_batch_size: int):
return RopeWithAttentionSink(
params=params,
window_size=252,
sink_size=4,
eviction_batch_size=eviction_batch_size,
)

def setUp(self):
torch.manual_seed(42)
self.params = ModelArgs(use_kv_cache=True, enable_dynamic_shape=True)
self.rope_with_attention_sink = RopeWithAttentionSink(params=self.params)
self.params = ModelArgs(
use_kv_cache=True, enable_dynamic_shape=True, max_seq_len=256
)
self.rope_with_attention_sink = self._init_rope(
params=self.params, eviction_batch_size=1
)

@parameterized.expand(
[
[0, 10, 1, 0], # No shift
[250, 10, 1, 246], # Some shift
[256, 10, 1, 246], # All shift
[0, 10, 30, 0], # No shift with batch eviction
[250, 10, 30, 220], # Some shift with batch eviction
[256, 10, 30, 226], # All shift with batch eviction
]
)
def test_get_freqs(
self, input_pos, seq_len, eviction_batch_size, expected_result_pos
):
self.rope_with_attention_sink = self._init_rope(
params=self.params, eviction_batch_size=eviction_batch_size
)

freqs_cos, freqs_sin = self.rope_with_attention_sink.get_freqs(
input_pos=torch.tensor([input_pos], dtype=torch.int32),
seq_len=seq_len,
)

torch.testing.assert_close(
freqs_cos,
self.rope_with_attention_sink.freqs_cos.narrow(
0, expected_result_pos, seq_len
),
)
torch.testing.assert_close(
freqs_sin,
self.rope_with_attention_sink.freqs_sin.narrow(
0, expected_result_pos, seq_len
),
)

@parameterized.expand(
[
Expand Down

0 comments on commit 2590823

Please sign in to comment.