Commit

Formatting
cmikeh2 committed Nov 11, 2023
1 parent f191146 commit 58cf675
Showing 5 changed files with 16 additions and 12 deletions.
@@ -392,7 +392,7 @@ def kv_cache_config(self) -> Tuple[KVCacheConfig, ...]:
             cache_shape=cache_shape,
             cache_dtype=self.activation_dtype,
             max_blocks_per_allocation_group=max_blocks)
-        return (self._kv_cache_config,)
+        return (self._kv_cache_config, )
 
     def prepare_batch(self, wrapped_batch: RaggedBatchWrapper) -> None:
         """
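The only functional line in this hunk builds a one-element tuple; the commit just normalizes it to the `(x, )` spelling the project's formatter emits (presumably yapf). A minimal sketch, not DeepSpeed code, of why the trailing comma rather than the parentheses is what makes the tuple:

```python
# Not DeepSpeed code: a minimal illustration of single-element tuples.
config = {"cache_dtype": "fp16"}  # hypothetical stand-in for a KVCacheConfig

not_a_tuple = (config)  # parentheses alone are just grouping
one_tuple = (config, )  # the trailing comma builds the tuple

assert not isinstance(not_a_tuple, tuple)
assert isinstance(one_tuple, tuple) and len(one_tuple) == 1
```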
6 changes: 3 additions & 3 deletions deepspeed/inference/v2/ragged/kv_cache.py
@@ -132,9 +132,9 @@ def __init__(self,
             head_size = config.cache_shape[2]
 
             alloc_shape = (num_caches, num_blocks, config.block_size, 2, num_heads, head_size)
-            inference_logger().info(f"Allocating KV-cache {cache_group_id} with shape: {alloc_shape} consisting of {num_blocks} blocks.")
-            caches.append(torch.empty(alloc_shape,
-                                      dtype=config.cache_dtype,
+            inference_logger().info(
+                f"Allocating KV-cache {cache_group_id} with shape: {alloc_shape} consisting of {num_blocks} blocks.")
+            caches.append(torch.empty(alloc_shape, dtype=config.cache_dtype,
                                       device=get_accelerator().current_device()))
             allocators.append(BlockedAllocator(num_blocks))
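For context, the reflowed call allocates one flat tensor per cache group. A hedged sketch of the same blocked-allocation pattern with hypothetical sizes (the real values come from `KVCacheConfig`):

```python
import torch

# Hypothetical sizes; the real values come from DeepSpeed's KVCacheConfig.
num_caches = 1     # caches in this allocation group
num_blocks = 128   # the cache is paged into fixed-size blocks
block_size = 64    # tokens stored per block
num_heads = 8
head_size = 64

# The literal 2 presumably holds the key and value planes side by side.
alloc_shape = (num_caches, num_blocks, block_size, 2, num_heads, head_size)
cache = torch.empty(alloc_shape, dtype=torch.float16, device="cpu")

print(f"Allocating KV-cache with shape: {alloc_shape} "
      f"consisting of {num_blocks} blocks "
      f"({cache.numel() * cache.element_size() / 2**20:.0f} MiB).")
```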
6 changes: 3 additions & 3 deletions deepspeed/inference/v2/ragged/ragged_manager.py
@@ -158,14 +158,14 @@ def _create_sequence(self, uid: int) -> DSSequenceDescriptor:
                 f"Unable to create tracking slot for sequence {uid} since the metadata buffers are full.")
 
         seq_block_ids = tuple(all_block_ids[tracking_slot] for all_block_ids in self._all_block_ids)
-        seq_block_ids_shadow = tuple(all_block_ids_shadow[tracking_slot] for all_block_ids_shadow in
-                                     self._all_block_ids_shadow)
+        seq_block_ids_shadow = tuple(all_block_ids_shadow[tracking_slot]
+                                     for all_block_ids_shadow in self._all_block_ids_shadow)
 
         self._seqs[uid] = DSSequenceDescriptor(tracking_slot,
                                                seq_block_ids,
                                                seq_block_ids_shadow,
                                                max_context=self._config.max_context)
-        # TODO(cmikeh2): Debug call here might be unecessary and is potentially on critical path.
+        # TODO(cmikeh2): Debug call here might be unnecessary and is potentially on critical path.
         logger.debug(f"Created sequence {uid} with tracking slot {tracking_slot}.")
         return self._seqs[uid]
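The surrounding method hands a new sequence a tracking slot plus per-slot views into preallocated block-id buffers. A sketch of that slot-view pattern under assumed buffer shapes (all sizes hypothetical):

```python
import torch

# All sizes are hypothetical; DeepSpeed sizes these buffers from its config.
max_seqs, max_blocks = 4, 16
all_block_ids = (torch.zeros(max_seqs, max_blocks, dtype=torch.int32), )
all_block_ids_shadow = (torch.zeros(max_seqs, max_blocks, dtype=torch.int32), )

tracking_slot = 0  # in the real manager this comes from a slot allocator

# Each sequence gets row *views* into the shared buffers, not copies.
seq_block_ids = tuple(ids[tracking_slot] for ids in all_block_ids)
seq_block_ids_shadow = tuple(ids[tracking_slot] for ids in all_block_ids_shadow)

seq_block_ids[0][0] = 7  # writes through to the shared buffer
assert all_block_ids[0][tracking_slot, 0] == 7
```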
12 changes: 8 additions & 4 deletions deepspeed/inference/v2/ragged/sequence_descriptor.py
@@ -123,8 +123,10 @@ def __init__(self,
         self._seen_tokens = 0
         self._in_flight_tokens = 0
 
-        self._num_allocation_groups = tuple(kv_cache_ids_shadow.shape[0] for kv_cache_ids_shadow in kv_cache_ids_shadow)
-        self._blocks_per_allocation_group = tuple(torch.zeros(num_groups, dtype=torch.int32, device="cpu") for num_groups in self._num_allocation_groups)
+        self._num_allocation_groups = tuple(kv_cache_ids_shadow.shape[0]
+                                            for kv_cache_ids_shadow in kv_cache_ids_shadow)
+        self._blocks_per_allocation_group = tuple(
+            torch.zeros(num_groups, dtype=torch.int32, device="cpu") for num_groups in self._num_allocation_groups)
 
         for cache_group, kv_cache_ids in enumerate(kv_cache_ids):
             assert self._num_allocation_groups[cache_group] == kv_cache_ids.shape[0]
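A small sketch of the bookkeeping this hunk reflows, with made-up shapes: one shadow tensor per cache group, whose leading dimension is that group's allocation-group count:

```python
import torch

# Made-up shapes: two cache groups, with 2 and 1 allocation groups of
# capacity 16 blocks each.
kv_cache_ids_shadow = (torch.zeros(2, 16, dtype=torch.int32),
                       torch.zeros(1, 16, dtype=torch.int32))

num_allocation_groups = tuple(shadow.shape[0] for shadow in kv_cache_ids_shadow)
blocks_per_allocation_group = tuple(
    torch.zeros(n, dtype=torch.int32, device="cpu") for n in num_allocation_groups)

assert num_allocation_groups == (2, 1)
assert [t.shape[0] for t in blocks_per_allocation_group] == [2, 1]
```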
@@ -202,7 +204,8 @@ def all_block_ids(self, cache_group: int = 0) -> torch.Tensor:
             cache_group (int): The cache group to query.
         """
         block_ids = []
-        for allocation_group, num_blocks in zip(self._kv_cache_ids[cache_group], self._blocks_per_allocation_group[cache_group]):
+        for allocation_group, num_blocks in zip(self._kv_cache_ids[cache_group],
+                                                self._blocks_per_allocation_group[cache_group]):
             block_ids.append(allocation_group[:num_blocks])
         return torch.cat(block_ids)
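A hedged illustration of what `all_block_ids` computes: each allocation group owns a fixed-capacity row of block ids, of which only the first `num_blocks` are live, so the rows are sliced before concatenation. Values are invented:

```python
import torch

# Invented data: two allocation groups with capacity 4, holding 2 and 1
# live blocks respectively.
kv_cache_ids = torch.tensor([[3, 9, 0, 0],
                             [5, 0, 0, 0]], dtype=torch.int32)
blocks_per_allocation_group = torch.tensor([2, 1], dtype=torch.int32)

block_ids = []
for allocation_group, num_blocks in zip(kv_cache_ids, blocks_per_allocation_group):
    block_ids.append(allocation_group[:num_blocks])  # only the live prefix

assert torch.cat(block_ids).tolist() == [3, 9, 5]
```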

@@ -239,7 +242,8 @@ def extend_kv_cache(self, new_ids: Union[List[torch.IntTensor], torch.IntTensor]
             new_ids = [new_ids]
 
         if len(new_ids) != self._num_allocation_groups[cache_group]:
-            raise ValueError(f"Only {len(new_ids)} allocation groups provided, expected {self._num_allocation_groups[cache_group]}")
+            raise ValueError(
+                f"Only {len(new_ids)} allocation groups provided, expected {self._num_allocation_groups[cache_group]}")
 
         for group_id, new_group_ids in enumerate(new_ids):
             new_blocks = new_group_ids.numel()
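A sketch of the guard being rewrapped here, with hypothetical values: a bare tensor is first normalized to a one-element list, then checked against the allocation-group count:

```python
import torch

# Hypothetical values, not DeepSpeed state: one allocation group in group 0.
num_allocation_groups = (1, )
cache_group = 0

new_ids = torch.tensor([11, 12], dtype=torch.int32)  # block ids to append
if isinstance(new_ids, torch.Tensor):
    new_ids = [new_ids]  # normalize the single-group shorthand

if len(new_ids) != num_allocation_groups[cache_group]:
    raise ValueError(f"Only {len(new_ids)} allocation groups provided, "
                     f"expected {num_allocation_groups[cache_group]}")

print("extend accepted:", [t.tolist() for t in new_ids])
```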
@@ -135,7 +135,7 @@ def build_batch_and_manager(
                                      memory_config=memory_config)
 
     batch = RaggedBatchWrapper(config)
-    state_manager = DSStateManager(config, (kv_config,))
+    state_manager = DSStateManager(config, (kv_config, ))
 
     # At the beginning of operation, the design of the allocator is such that it will return
     # linear blocks of memory. The following will "warm up" the allocator so that we can be
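The truncated comment describes allocator warm-up: a fresh blocked allocator hands out linear block ids. A toy model (not DeepSpeed's `BlockedAllocator`) of why a freshly initialized free list yields contiguous ids until frees reshuffle it:

```python
from collections import deque

# Toy stand-in for a blocked allocator; not DeepSpeed's BlockedAllocator.
class ToyBlockedAllocator:
    def __init__(self, num_blocks: int):
        self._free = deque(range(num_blocks))  # free list starts in order

    def allocate(self, n: int) -> list:
        return [self._free.popleft() for _ in range(n)]

    def free(self, blocks) -> None:
        self._free.extend(blocks)  # returned blocks go to the back

alloc = ToyBlockedAllocator(8)
assert alloc.allocate(3) == [0, 1, 2]  # linear while the allocator is fresh
alloc.free([0])
alloc.allocate(6)  # drains 3..7, then recycles 0: no longer contiguous
```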
