Skip to content

Commit

Permalink
Merge branch 'master' into fix-z3-sp-arg
Browse files Browse the repository at this point in the history
  • Loading branch information
loadams authored Aug 14, 2024
2 parents 6f2d491 + e3177de commit 875c0bc
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions deepspeed/runtime/zero/stage3.py
Original file line number Diff line number Diff line change
Expand Up @@ -821,10 +821,14 @@ def _create_fp32_partitions(self):

for i, tensor in enumerate(self.fp16_partitioned_groups_flat):
num_elements = self.fp16_partitioned_groups_flat_numel[i]
ds_id_begin = str(self.fp16_partitioned_groups_flat_id[i][0])
ds_id_end = str(self.fp16_partitioned_groups_flat_id[i][-1])
ds_id = ds_id_begin + '_' + ds_id_end

# a partition of the fp32 master weights that will be updated by this process
if self._swappable_optimizer_subgroup(i):
self.fp32_partitioned_groups_flat.append(torch.Tensor())
self.fp32_partitioned_groups_flat[i].ds_id = ds_id
nvme_memory_usage += (fp32_element_size * num_elements)
num_swappable_partitions += 1

Expand Down Expand Up @@ -861,11 +865,9 @@ def _create_fp32_partitions(self):
else:
self.fp32_partitioned_groups_flat.append(self.fp16_partitioned_groups_flat[i].to(
self.device).clone().float().detach())
self.fp32_partitioned_groups_flat[i].ds_id = ds_id

self.fp32_partitioned_groups_flat[i].requires_grad = True # keep this in case internal optimizer uses it
ds_id_begin = str(self.fp16_partitioned_groups_flat_id[i][0])
ds_id_end = str(self.fp16_partitioned_groups_flat_id[i][-1])
self.fp32_partitioned_groups_flat[i].ds_id = ds_id_begin + '_' + ds_id_end

if len(swappable_fp32_tensors) > 0:
self.optimizer_swapper.initialize_parameters(parameters=swappable_fp32_tensors,
Expand Down

0 comments on commit 875c0bc

Please sign in to comment.