From 58cf675f1ba27228a95adf72124888be9027371a Mon Sep 17 00:00:00 2001
From: Connor Holmes
Date: Sat, 11 Nov 2023 04:47:17 +0000
Subject: [PATCH] Formatting

---
 .../inference_transformer_base.py                     |  2 +-
 deepspeed/inference/v2/ragged/kv_cache.py             |  6 +++---
 deepspeed/inference/v2/ragged/ragged_manager.py       |  6 +++---
 deepspeed/inference/v2/ragged/sequence_descriptor.py  | 12 ++++++++----
 .../v2/kernels/ragged_ops/ragged_testing_utils.py     |  2 +-
 5 files changed, 16 insertions(+), 12 deletions(-)

diff --git a/deepspeed/inference/v2/model_implementations/inference_transformer_base.py b/deepspeed/inference/v2/model_implementations/inference_transformer_base.py
index b4a2ff21ced3..8ea26f21de4e 100644
--- a/deepspeed/inference/v2/model_implementations/inference_transformer_base.py
+++ b/deepspeed/inference/v2/model_implementations/inference_transformer_base.py
@@ -392,7 +392,7 @@ def kv_cache_config(self) -> Tuple[KVCacheConfig, ...]:
                                                   cache_shape=cache_shape,
                                                   cache_dtype=self.activation_dtype,
                                                   max_blocks_per_allocation_group=max_blocks)
-        return (self._kv_cache_config,)
+        return (self._kv_cache_config, )
 
     def prepare_batch(self, wrapped_batch: RaggedBatchWrapper) -> None:
         """
diff --git a/deepspeed/inference/v2/ragged/kv_cache.py b/deepspeed/inference/v2/ragged/kv_cache.py
index fd71004dd988..50da350b6506 100644
--- a/deepspeed/inference/v2/ragged/kv_cache.py
+++ b/deepspeed/inference/v2/ragged/kv_cache.py
@@ -132,9 +132,9 @@ def __init__(self,
             head_size = config.cache_shape[2]
 
             alloc_shape = (num_caches, num_blocks, config.block_size, 2, num_heads, head_size)
-            inference_logger().info(f"Allocating KV-cache {cache_group_id} with shape: {alloc_shape} consisting of {num_blocks} blocks.")
-            caches.append(torch.empty(alloc_shape,
-                                      dtype=config.cache_dtype,
+            inference_logger().info(
+                f"Allocating KV-cache {cache_group_id} with shape: {alloc_shape} consisting of {num_blocks} blocks.")
+            caches.append(torch.empty(alloc_shape, dtype=config.cache_dtype,
                                       device=get_accelerator().current_device()))
             allocators.append(BlockedAllocator(num_blocks))
 
diff --git a/deepspeed/inference/v2/ragged/ragged_manager.py b/deepspeed/inference/v2/ragged/ragged_manager.py
index 103dea8c7929..4c3faef12173 100644
--- a/deepspeed/inference/v2/ragged/ragged_manager.py
+++ b/deepspeed/inference/v2/ragged/ragged_manager.py
@@ -158,14 +158,14 @@ def _create_sequence(self, uid: int) -> DSSequenceDescriptor:
                 f"Unable to create tracking slot for sequence {uid} since the metadata buffers are full.")
 
         seq_block_ids = tuple(all_block_ids[tracking_slot] for all_block_ids in self._all_block_ids)
-        seq_block_ids_shadow = tuple(all_block_ids_shadow[tracking_slot] for all_block_ids_shadow in
-                                     self._all_block_ids_shadow)
+        seq_block_ids_shadow = tuple(all_block_ids_shadow[tracking_slot]
+                                     for all_block_ids_shadow in self._all_block_ids_shadow)
 
         self._seqs[uid] = DSSequenceDescriptor(tracking_slot,
                                                seq_block_ids,
                                                seq_block_ids_shadow,
                                                max_context=self._config.max_context)
 
-        # TODO(cmikeh2): Debug call here might be unecessary and is potentially on critical path.
+        # TODO(cmikeh2): Debug call here might be unnecessary and is potentially on critical path.
         logger.debug(f"Created sequence {uid} with tracking slot {tracking_slot}.")
         return self._seqs[uid]
diff --git a/deepspeed/inference/v2/ragged/sequence_descriptor.py b/deepspeed/inference/v2/ragged/sequence_descriptor.py
index 73104d7154c4..7b3c3dcb0f11 100644
--- a/deepspeed/inference/v2/ragged/sequence_descriptor.py
+++ b/deepspeed/inference/v2/ragged/sequence_descriptor.py
@@ -123,8 +123,10 @@ def __init__(self,
         self._seen_tokens = 0
         self._in_flight_tokens = 0
 
-        self._num_allocation_groups = tuple(kv_cache_ids_shadow.shape[0] for kv_cache_ids_shadow in kv_cache_ids_shadow)
-        self._blocks_per_allocation_group = tuple(torch.zeros(num_groups, dtype=torch.int32, device="cpu") for num_groups in self._num_allocation_groups)
+        self._num_allocation_groups = tuple(kv_cache_ids_shadow.shape[0]
+                                            for kv_cache_ids_shadow in kv_cache_ids_shadow)
+        self._blocks_per_allocation_group = tuple(
+            torch.zeros(num_groups, dtype=torch.int32, device="cpu") for num_groups in self._num_allocation_groups)
 
         for cache_group, kv_cache_ids in enumerate(kv_cache_ids):
             assert self._num_allocation_groups[cache_group] == kv_cache_ids.shape[0]
@@ -202,7 +204,8 @@ def all_block_ids(self, cache_group: int = 0) -> torch.Tensor:
             cache_group (int): The cache group to query.
         """
         block_ids = []
-        for allocation_group, num_blocks in zip(self._kv_cache_ids[cache_group], self._blocks_per_allocation_group[cache_group]):
+        for allocation_group, num_blocks in zip(self._kv_cache_ids[cache_group],
+                                                self._blocks_per_allocation_group[cache_group]):
             block_ids.append(allocation_group[:num_blocks])
 
         return torch.cat(block_ids)
@@ -239,7 +242,8 @@ def extend_kv_cache(self, new_ids: Union[List[torch.IntTensor], torch.IntTensor]
             new_ids = [new_ids]
 
         if len(new_ids) != self._num_allocation_groups[cache_group]:
-            raise ValueError(f"Only {len(new_ids)} allocation groups provided, expected {self._num_allocation_groups[cache_group]}")
+            raise ValueError(
+                f"Only {len(new_ids)} allocation groups provided, expected {self._num_allocation_groups[cache_group]}")
 
         for group_id, new_group_ids in enumerate(new_ids):
             new_blocks = new_group_ids.numel()
diff --git a/tests/unit/inference/v2/kernels/ragged_ops/ragged_testing_utils.py b/tests/unit/inference/v2/kernels/ragged_ops/ragged_testing_utils.py
index e553c81b1f5a..be7454fee4aa 100644
--- a/tests/unit/inference/v2/kernels/ragged_ops/ragged_testing_utils.py
+++ b/tests/unit/inference/v2/kernels/ragged_ops/ragged_testing_utils.py
@@ -135,7 +135,7 @@ def build_batch_and_manager(
                                 memory_config=memory_config)
 
     batch = RaggedBatchWrapper(config)
-    state_manager = DSStateManager(config, (kv_config,))
+    state_manager = DSStateManager(config, (kv_config, ))
 
     # At the beginning of operation, the design of the allocator is such that it will return
    # linear blocks of memory. The following will "warm up" the allocator so that we can be