diff --git a/deepspeed/inference/v2/ragged/prefix_block_map.py b/deepspeed/inference/v2/ragged/prefix_block_map.py
index 85a2e83140f6..d59c750366b2 100644
--- a/deepspeed/inference/v2/ragged/prefix_block_map.py
+++ b/deepspeed/inference/v2/ragged/prefix_block_map.py
@@ -38,9 +38,9 @@ def lookup(self, tokens: torch.Tensor) -> torch.Tensor:
                 break
         return cached_blocks
 
-    def extend(self, tokens: torch.Tensor, new_block_ids: List[int]) -> None:
+    def extend(self, tokens: torch.Tensor, new_block_ids: List[int], num_already_cached_blocks: int) -> None:
         n_blocks = len(tokens) // self.block_size
-        for i in range(n_blocks):
+        for i in range(num_already_cached_blocks, n_blocks):
             chunk = tokens[:(i + 1) * self.block_size]
             hash = token_ids_to_hash(chunk)
             if hash not in self.tokens_to_blocks:
diff --git a/deepspeed/inference/v2/ragged/ragged_manager.py b/deepspeed/inference/v2/ragged/ragged_manager.py
index 3de4059283eb..fa07f663d54a 100644
--- a/deepspeed/inference/v2/ragged/ragged_manager.py
+++ b/deepspeed/inference/v2/ragged/ragged_manager.py
@@ -191,7 +191,10 @@ def update_cache(self, uid: int, tokens: torch.Tensor) -> None:
         Update the KV cache for the given sequence id.
         """
         seq = self.get_sequence(uid)
-        self._block_map.extend(tokens, seq.all_block_ids())
+        num_full_blocks = tokens.numel() // self._kv_configs[0].block_size
+        if num_full_blocks > seq.num_prefix_cache_blocks:
+            self._block_map.extend(tokens, seq.all_block_ids(), seq.num_prefix_cache_blocks)
+            seq.num_prefix_cache_blocks = num_full_blocks
 
     def increment_ref_count(self, block_ids: torch.Tensor) -> None:
         for block_id in block_ids:
diff --git a/deepspeed/inference/v2/ragged/sequence_descriptor.py b/deepspeed/inference/v2/ragged/sequence_descriptor.py
index 6b9f65255eec..25c1072ff042 100644
--- a/deepspeed/inference/v2/ragged/sequence_descriptor.py
+++ b/deepspeed/inference/v2/ragged/sequence_descriptor.py
@@ -95,6 +95,8 @@ class DSSequenceDescriptor(BaseSequenceDescriptor):
     # are stored. Used on flush.
     _tracking_id: int
 
+    _num_prefix_cache_blocks: int
+
     def __init__(self,
                  tracking_id: int,
                  kv_cache_ids: Tuple[torch.Tensor, ...],
@@ -132,6 +134,8 @@ def __init__(self,
             assert self._num_allocation_groups[cache_group] == kv_cache_ids.shape[0]
             assert len(kv_cache_ids.shape) == 2
 
+        self._num_prefix_cache_blocks = 0
+
     @property
     def seen_tokens(self) -> int:
         """
@@ -278,3 +282,20 @@ def free_kv_cache(self, free_ids: Union[List[torch.IntTensor], torch.IntTensor],
         to have the same shape.
         """
         raise NotImplementedError("Partial KV-cache freeing is not yet supported.")
+
+    @property
+    def num_prefix_cache_blocks(self) -> int:
+        """
+        The number of prefix cache blocks for the sequence.
+        """
+        return self._num_prefix_cache_blocks
+
+    @num_prefix_cache_blocks.setter
+    def num_prefix_cache_blocks(self, num_blocks: int) -> None:
+        """
+        Set the number of prefix cache blocks for the sequence.
+
+        Arguments:
+            num_blocks (int): The number of prefix cache blocks for the sequence.
+        """
+        self._num_prefix_cache_blocks = num_blocks
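
For context, here is a minimal, self-contained sketch of the pattern the diff implements: the caller tracks how many full blocks of a sequence have already been registered in the prefix map and passes that count to `extend`, so each update only hashes the newly completed blocks instead of re-hashing every prefix. The `ToyPrefixBlockMap` class and the `token_ids_to_hash` stand-in below are illustrative assumptions, not the DeepSpeed implementations touched by this diff.

```python
from typing import Dict, List

import torch


def token_ids_to_hash(tokens: torch.Tensor) -> int:
    # Hypothetical stand-in for the real hashing helper: hash the token ids as a tuple.
    return hash(tuple(tokens.tolist()))


class ToyPrefixBlockMap:
    """Minimal sketch of a prefix map keyed by the hash of each token-prefix."""

    def __init__(self, block_size: int) -> None:
        self.block_size = block_size
        self.tokens_to_blocks: Dict[int, int] = {}

    def extend(self, tokens: torch.Tensor, new_block_ids: List[int], num_already_cached_blocks: int) -> None:
        # Only hash prefixes for blocks that were not registered on a previous call.
        n_blocks = len(tokens) // self.block_size
        for i in range(num_already_cached_blocks, n_blocks):
            chunk = tokens[:(i + 1) * self.block_size]
            prefix_hash = token_ids_to_hash(chunk)
            if prefix_hash not in self.tokens_to_blocks:
                self.tokens_to_blocks[prefix_hash] = new_block_ids[i]


# Usage: the caller remembers how many full blocks are already cached, so a
# subsequent update with the same count skips all previously hashed prefixes.
block_map = ToyPrefixBlockMap(block_size=4)
tokens = torch.arange(12)
block_ids = [0, 1, 2]

num_cached = 0
num_full = tokens.numel() // block_map.block_size  # 3 full blocks
if num_full > num_cached:
    block_map.extend(tokens, block_ids, num_cached)
    num_cached = num_full  # later calls start hashing at block index 3
```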