diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py
index 3ac6987e9c22..9b7645261eae 100644
--- a/deepspeed/runtime/zero/stage3.py
+++ b/deepspeed/runtime/zero/stage3.py
@@ -821,10 +821,14 @@ def _create_fp32_partitions(self):
         for i, tensor in enumerate(self.fp16_partitioned_groups_flat):
             num_elements = self.fp16_partitioned_groups_flat_numel[i]
+            ds_id_begin = str(self.fp16_partitioned_groups_flat_id[i][0])
+            ds_id_end = str(self.fp16_partitioned_groups_flat_id[i][-1])
+            ds_id = ds_id_begin + '_' + ds_id_end
 
             # a partition of the fp32 master weights that will be updated by this process
             if self._swappable_optimizer_subgroup(i):
                 self.fp32_partitioned_groups_flat.append(torch.Tensor())
+                self.fp32_partitioned_groups_flat[i].ds_id = ds_id
                 nvme_memory_usage += (fp32_element_size * num_elements)
                 num_swappable_partitions += 1
@@ -861,11 +865,9 @@ def _create_fp32_partitions(self):
             else:
                 self.fp32_partitioned_groups_flat.append(self.fp16_partitioned_groups_flat[i].to(
                     self.device).clone().float().detach())
+                self.fp32_partitioned_groups_flat[i].ds_id = ds_id
 
             self.fp32_partitioned_groups_flat[i].requires_grad = True  # keep this in case internal optimizer uses it
-            ds_id_begin = str(self.fp16_partitioned_groups_flat_id[i][0])
-            ds_id_end = str(self.fp16_partitioned_groups_flat_id[i][-1])
-            self.fp32_partitioned_groups_flat[i].ds_id = ds_id_begin + '_' + ds_id_end
 
         if len(swappable_fp32_tensors) > 0:
             self.optimizer_swapper.initialize_parameters(parameters=swappable_fp32_tensors,
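
For context on what the change does: the diff hoists the `ds_id` computation above the swappable/non-swappable branch so that the placeholder tensor appended for NVMe-swappable subgroups gets tagged too, not just the materialized fp32 copy. Below is a minimal standalone sketch of that tagging pattern, assuming hypothetical stand-in lists (it is not DeepSpeed code; `fp16_partitioned_groups_flat_id` here is a mock of the real attribute):

```python
import torch

# Hypothetical stand-in for self.fp16_partitioned_groups_flat_id:
# the parameter ids covered by each flattened subgroup.
fp16_partitioned_groups_flat_id = [[0, 1, 2], [3, 4]]
fp32_partitioned_groups_flat = []

for i, ids in enumerate(fp16_partitioned_groups_flat_id):
    # ds_id spans the first and last parameter id in the subgroup,
    # e.g. "0_2" for the first subgroup above.
    ds_id = str(ids[0]) + '_' + str(ids[-1])

    # Computing ds_id before the branch (as the diff does) means both
    # paths receive the same tag: here we mimic the swappable branch,
    # which appends an empty placeholder rather than a real fp32 copy.
    flat = torch.Tensor()
    flat.ds_id = ds_id  # PyTorch tensors accept custom Python attributes
    fp32_partitioned_groups_flat.append(flat)

assert fp32_partitioned_groups_flat[0].ds_id == '0_2'
assert fp32_partitioned_groups_flat[1].ds_id == '3_4'
```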