Implement get_freqs for RopeWithAttentionSink #7100

Merged
29 changes: 28 additions & 1 deletion examples/models/llama/source_transformation/attention_sink.py
@@ -7,6 +7,8 @@
 # Components for supporting Attention Sink. See
 # https://arxiv.org/abs/2309.17453 for more details about Attention Sink.
 
+from typing import Optional
+
 import torch
 
 from executorch.examples.models.llama.llama_transformer import ModelArgs, Rope
@@ -23,12 +23,37 @@ class RopeWithAttentionSink(Rope):
     in KVCache instead of positions in the actual text.
     """
 
-    def __init__(self, params: ModelArgs):
+    def __init__(
+        self,
+        params: ModelArgs,
+        window_size: int,
+        sink_size: int,
+        eviction_batch_size: int,
+    ):
         super().__init__(params)
         if self.params.use_hf_rope:
             self.apply_rotary_emb_to_k = hf_apply_rotary_emb_to_k
         else:
             self.apply_rotary_emb_to_k = apply_rotary_emb_to_k
+        self.max_seq_length = window_size + sink_size
+        assert self.max_seq_length == self.params.max_seq_len
+        self.eviction_batch_size = eviction_batch_size
+        self.position_shift = 0
+
+    def get_freqs(self, input_pos: Optional[torch.Tensor], seq_len: int):
+        assert input_pos is not None
+
+        input_pos_item = input_pos.item()
+        torch._check_is_size(input_pos_item)
+        if input_pos_item + self.position_shift + seq_len > self.max_seq_length:
+            # There are not enough spaces in the cache to store the new tokens.
+            # We need to evict some old tokens and shift some recent tokens.
+            num_to_evict = max(
+                input_pos_item + self.position_shift - self.max_seq_length + seq_len,
+                self.eviction_batch_size,
+            )
+            self.position_shift -= num_to_evict  # pyre-ignore [8]
+        return super().get_freqs(input_pos + self.position_shift, seq_len)
 
     def rerotate_k(
         self,
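For context, the following is a minimal usage sketch of the new constructor and get_freqs. It is not part of the PR; the import path for RopeWithAttentionSink is assumed from the file location, and the configuration (max_seq_len=256, window_size=252, sink_size=4) mirrors the unit test below.

# Usage sketch (assumption: RopeWithAttentionSink is importable from the module path below).
import torch

from executorch.examples.models.llama.llama_transformer import ModelArgs
from executorch.examples.models.llama.source_transformation.attention_sink import (
    RopeWithAttentionSink,
)

params = ModelArgs(use_kv_cache=True, enable_dynamic_shape=True, max_seq_len=256)
rope = RopeWithAttentionSink(
    params=params,
    window_size=252,
    sink_size=4,  # window_size + sink_size must equal params.max_seq_len (asserted in __init__)
    eviction_batch_size=1,
)

# While the cache has room, positions pass through unchanged.
freqs_cos, freqs_sin = rope.get_freqs(
    input_pos=torch.tensor([0], dtype=torch.int32), seq_len=10
)

# Once input_pos + seq_len would overflow the 256-slot cache, get_freqs evicts
# old tokens and returns frequencies for the shifted positions 246..255.
freqs_cos, freqs_sin = rope.get_freqs(
    input_pos=torch.tensor([250], dtype=torch.int32), seq_len=10
)

This keeps the rotary embedding indexed by position in the KV cache after eviction rather than by position in the text, which is what the class docstring describes.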
51 changes: 49 additions & 2 deletions examples/models/llama/source_transformation/test_attention_sink.py
@@ -17,10 +17,57 @@
 
 class RopeWithAttentionSinkTest(unittest.TestCase):
 
+    def _init_rope(self, params: ModelArgs, eviction_batch_size: int):
+        return RopeWithAttentionSink(
+            params=params,
+            window_size=252,
+            sink_size=4,
+            eviction_batch_size=eviction_batch_size,
+        )
+
     def setUp(self):
         torch.manual_seed(42)
-        self.params = ModelArgs(use_kv_cache=True, enable_dynamic_shape=True)
-        self.rope_with_attention_sink = RopeWithAttentionSink(params=self.params)
+        self.params = ModelArgs(
+            use_kv_cache=True, enable_dynamic_shape=True, max_seq_len=256
+        )
+        self.rope_with_attention_sink = self._init_rope(
+            params=self.params, eviction_batch_size=1
+        )
 
+    @parameterized.expand(
+        [
+            [0, 10, 1, 0],  # No shift
+            [250, 10, 1, 246],  # Some shift
+            [256, 10, 1, 246],  # All shift
+            [0, 10, 30, 0],  # No shift with batch eviction
+            [250, 10, 30, 220],  # Some shift with batch eviction
+            [256, 10, 30, 226],  # All shift with batch eviction
+        ]
+    )
+    def test_get_freqs(
+        self, input_pos, seq_len, eviction_batch_size, expected_result_pos
+    ):
+        self.rope_with_attention_sink = self._init_rope(
+            params=self.params, eviction_batch_size=eviction_batch_size
+        )
+
+        freqs_cos, freqs_sin = self.rope_with_attention_sink.get_freqs(
+            input_pos=torch.tensor([input_pos], dtype=torch.int32),
+            seq_len=seq_len,
+        )
+
+        torch.testing.assert_close(
+            freqs_cos,
+            self.rope_with_attention_sink.freqs_cos.narrow(
+                0, expected_result_pos, seq_len
+            ),
+        )
+        torch.testing.assert_close(
+            freqs_sin,
+            self.rope_with_attention_sink.freqs_sin.narrow(
+                0, expected_result_pos, seq_len
+            ),
+        )
+
     @parameterized.expand(
         [
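To make the expected_result_pos values in test_get_freqs easy to verify, here is a small standalone sketch (a hypothetical helper, not part of the PR) that reproduces the eviction arithmetic from get_freqs with max_seq_length = window_size + sink_size = 252 + 4 = 256 and position_shift starting at 0.

# Hypothetical helper reproducing the position-shift arithmetic from get_freqs.
def expected_rope_position(
    input_pos: int,
    seq_len: int,
    eviction_batch_size: int,
    max_seq_length: int = 256,
    position_shift: int = 0,
) -> int:
    if input_pos + position_shift + seq_len > max_seq_length:
        # Evict at least eviction_batch_size tokens, or as many as needed
        # to make room for the new seq_len tokens.
        num_to_evict = max(
            input_pos + position_shift - max_seq_length + seq_len,
            eviction_batch_size,
        )
        position_shift -= num_to_evict
    return input_pos + position_shift


cases = [(0, 10, 1), (250, 10, 1), (256, 10, 1), (0, 10, 30), (250, 10, 30), (256, 10, 30)]
print([expected_rope_position(*c) for c in cases])  # [0, 246, 246, 0, 220, 226]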