update prefix cache at every iteration
tohtana committed May 30, 2024
1 parent 30b9d0b commit 0c8e0e6
Showing 3 changed files with 27 additions and 3 deletions.
4 changes: 2 additions & 2 deletions deepspeed/inference/v2/ragged/prefix_block_map.py
@@ -38,9 +38,9 @@ def lookup(self, tokens: torch.Tensor) -> torch.Tensor:
                 break
         return cached_blocks
 
-    def extend(self, tokens: torch.Tensor, new_block_ids: List[int]) -> None:
+    def extend(self, tokens: torch.Tensor, new_block_ids: List[int], num_already_cached_blocks: int) -> None:
         n_blocks = len(tokens) // self.block_size
-        for i in range(n_blocks):
+        for i in range(num_already_cached_blocks, n_blocks):
             chunk = tokens[:(i + 1) * self.block_size]
             hash = token_ids_to_hash(chunk)
             if hash not in self.tokens_to_blocks:
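The hunk above narrows the range that extend re-hashes: blocks below num_already_cached_blocks were registered on an earlier call and are skipped. A minimal, self-contained sketch of that behaviour follows; SimplePrefixBlockMap and token_prefix_hash are illustrative stand-ins for the real class and for token_ids_to_hash, whose implementations are not part of this diff.

    from typing import Dict, List

    import torch


    def token_prefix_hash(tokens: torch.Tensor) -> int:
        # Stand-in for token_ids_to_hash: hash the token ids of the prefix.
        return hash(tuple(tokens.tolist()))


    class SimplePrefixBlockMap:
        def __init__(self, block_size: int) -> None:
            self.block_size = block_size
            self.tokens_to_blocks: Dict[int, int] = {}

        def extend(self, tokens: torch.Tensor, new_block_ids: List[int], num_already_cached_blocks: int) -> None:
            # Only hash prefixes ending in blocks not yet registered; blocks
            # [0, num_already_cached_blocks) were handled by a previous call.
            n_blocks = len(tokens) // self.block_size
            for i in range(num_already_cached_blocks, n_blocks):
                chunk = tokens[:(i + 1) * self.block_size]
                prefix_hash = token_prefix_hash(chunk)
                if prefix_hash not in self.tokens_to_blocks:
                    self.tokens_to_blocks[prefix_hash] = new_block_ids[i]


    bmap = SimplePrefixBlockMap(block_size=4)
    bmap.extend(torch.arange(8), new_block_ids=[0, 1], num_already_cached_blocks=0)   # hashes blocks 0 and 1
    bmap.extend(torch.arange(12), new_block_ids=[0, 1, 2], num_already_cached_blocks=2)  # only block 2 is new

Calling extend again with a grown token tensor only hashes the newly completed blocks.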
5 changes: 4 additions & 1 deletion deepspeed/inference/v2/ragged/ragged_manager.py
@@ -191,7 +191,10 @@ def update_cache(self, uid: int, tokens: torch.Tensor) -> None:
         Update the KV cache for the given sequence id.
         """
         seq = self.get_sequence(uid)
-        self._block_map.extend(tokens, seq.all_block_ids())
+        num_full_blocks = tokens.numel() // self._kv_configs[0].block_size
+        if num_full_blocks > seq.num_prefix_cache_blocks:
+            self._block_map.extend(tokens, seq.all_block_ids(), seq.num_prefix_cache_blocks)
+            seq.num_prefix_cache_blocks = num_full_blocks
 
     def increment_ref_count(self, block_ids: torch.Tensor) -> None:
         for block_id in block_ids:
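This guard is what the commit title refers to: update_cache can be invoked on every decoding iteration because it is effectively a no-op until the sequence crosses a new full-block boundary. A hypothetical driver loop, assuming only that the manager exposes update_cache(uid, tokens) as shown above, might look like this (decode_loop and next_token_fn are placeholders, not DeepSpeed APIs):

    from typing import Callable

    import torch


    def decode_loop(manager, uid: int, prompt: torch.Tensor, max_new_tokens: int,
                    next_token_fn: Callable[[torch.Tensor], torch.Tensor]) -> torch.Tensor:
        # next_token_fn stands in for the model forward pass.
        tokens = prompt
        for _ in range(max_new_tokens):
            next_token = next_token_fn(tokens)
            tokens = torch.cat([tokens, next_token.reshape(1)])
            # Safe to call every step: it only extends the prefix map when a
            # new full block of tokens has been produced.
            manager.update_cache(uid, tokens)
        return tokens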
21 changes: 21 additions & 0 deletions deepspeed/inference/v2/ragged/sequence_descriptor.py
@@ -95,6 +95,8 @@ class DSSequenceDescriptor(BaseSequenceDescriptor):
     # are stored. Used on flush.
     _tracking_id: int
 
+    _num_prefix_cache_blocks: int
+
     def __init__(self,
                  tracking_id: int,
                  kv_cache_ids: Tuple[torch.Tensor, ...],
@@ -132,6 +134,8 @@ def __init__(self,
             assert self._num_allocation_groups[cache_group] == kv_cache_ids.shape[0]
             assert len(kv_cache_ids.shape) == 2
 
+        self._num_prefix_cache_blocks = 0
+
     @property
     def seen_tokens(self) -> int:
         """
@@ -278,3 +282,20 @@ def free_kv_cache(self, free_ids: Union[List[torch.IntTensor], torch.IntTensor],
         to have the same shape.
         """
         raise NotImplementedError("Partial KV-cache freeing is not yet supported.")
+
+    @property
+    def num_prefix_cache_blocks(self) -> int:
+        """
+        The number of prefix cache blocks for the sequence.
+        """
+        return self._num_prefix_cache_blocks
+
+    @num_prefix_cache_blocks.setter
+    def num_prefix_cache_blocks(self, num_blocks: int) -> None:
+        """
+        Set the number of prefix cache blocks for the sequence.
+
+        Arguments:
+            num_blocks (int): The number of prefix cache blocks for the sequence.
+        """
+        self._num_prefix_cache_blocks = num_blocks
+
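Taken together with the manager change, the new counter behaves roughly as follows across iterations; MiniDescriptor and the loop below are a toy reconstruction for illustration, not DeepSpeed code:

    class MiniDescriptor:
        # Records how many full blocks of this sequence have already been
        # registered in the prefix map, so repeated updates skip work.
        def __init__(self) -> None:
            self._num_prefix_cache_blocks = 0

        @property
        def num_prefix_cache_blocks(self) -> int:
            return self._num_prefix_cache_blocks

        @num_prefix_cache_blocks.setter
        def num_prefix_cache_blocks(self, num_blocks: int) -> None:
            self._num_prefix_cache_blocks = num_blocks


    seq = MiniDescriptor()
    block_size = 4
    for seen_tokens in (3, 4, 5, 8):          # tokens seen so far at each step
        num_full_blocks = seen_tokens // block_size
        if num_full_blocks > seq.num_prefix_cache_blocks:
            # here the real manager would call block_map.extend(...)
            seq.num_prefix_cache_blocks = num_full_blocks
    print(seq.num_prefix_cache_blocks)        # 2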