Skip to content

Commit

Permalink
Fix missing ds_id bug (#5824)
Browse files Browse the repository at this point in the history
Fixes #5495: resolve the missing ds_id bug by applying the solution from #5193
(credit to @getinglxf).

Co-authored-by: Logan Adams <[email protected]>
  • Loading branch information
tjruwase and loadams authored Aug 14, 2024
1 parent 051c993 commit e3177de
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions deepspeed/runtime/zero/stage3.py
Original file line number Diff line number Diff line change
Expand Up @@ -821,10 +821,14 @@ def _create_fp32_partitions(self):

for i, tensor in enumerate(self.fp16_partitioned_groups_flat):
num_elements = self.fp16_partitioned_groups_flat_numel[i]
ds_id_begin = str(self.fp16_partitioned_groups_flat_id[i][0])
ds_id_end = str(self.fp16_partitioned_groups_flat_id[i][-1])
ds_id = ds_id_begin + '_' + ds_id_end

# a partition of the fp32 master weights that will be updated by this process
if self._swappable_optimizer_subgroup(i):
self.fp32_partitioned_groups_flat.append(torch.Tensor())
self.fp32_partitioned_groups_flat[i].ds_id = ds_id
nvme_memory_usage += (fp32_element_size * num_elements)
num_swappable_partitions += 1

Expand Down Expand Up @@ -861,11 +865,9 @@ def _create_fp32_partitions(self):
else:
self.fp32_partitioned_groups_flat.append(self.fp16_partitioned_groups_flat[i].to(
self.device).clone().float().detach())
self.fp32_partitioned_groups_flat[i].ds_id = ds_id

self.fp32_partitioned_groups_flat[i].requires_grad = True # keep this in case internal optimizer uses it
ds_id_begin = str(self.fp16_partitioned_groups_flat_id[i][0])
ds_id_end = str(self.fp16_partitioned_groups_flat_id[i][-1])
self.fp32_partitioned_groups_flat[i].ds_id = ds_id_begin + '_' + ds_id_end

if len(swappable_fp32_tensors) > 0:
self.optimizer_swapper.initialize_parameters(parameters=swappable_fp32_tensors,
Expand Down

0 comments on commit e3177de

Please sign in to comment.