From e3177de666f9aca1a9e16c57ffe38f22ebf54d38 Mon Sep 17 00:00:00 2001 From: Olatunji Ruwase Date: Wed, 14 Aug 2024 07:24:15 -0400 Subject: [PATCH] Fix missing ds_id bug (#5824) Fix #5495 - Fix missing ds_id bug by copying solution from #5193 (credit to @getinglxf) Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com> --- deepspeed/runtime/zero/stage3.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index 3ac6987e9c22..9b7645261eae 100644 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -821,10 +821,14 @@ def _create_fp32_partitions(self): for i, tensor in enumerate(self.fp16_partitioned_groups_flat): num_elements = self.fp16_partitioned_groups_flat_numel[i] + ds_id_begin = str(self.fp16_partitioned_groups_flat_id[i][0]) + ds_id_end = str(self.fp16_partitioned_groups_flat_id[i][-1]) + ds_id = ds_id_begin + '_' + ds_id_end # a partition of the fp32 master weights that will be updated by this process if self._swappable_optimizer_subgroup(i): self.fp32_partitioned_groups_flat.append(torch.Tensor()) + self.fp32_partitioned_groups_flat[i].ds_id = ds_id nvme_memory_usage += (fp32_element_size * num_elements) num_swappable_partitions += 1 @@ -861,11 +865,9 @@ def _create_fp32_partitions(self): else: self.fp32_partitioned_groups_flat.append(self.fp16_partitioned_groups_flat[i].to( self.device).clone().float().detach()) + self.fp32_partitioned_groups_flat[i].ds_id = ds_id self.fp32_partitioned_groups_flat[i].requires_grad = True # keep this in case internal optimizer uses it - ds_id_begin = str(self.fp16_partitioned_groups_flat_id[i][0]) - ds_id_end = str(self.fp16_partitioned_groups_flat_id[i][-1]) - self.fp32_partitioned_groups_flat[i].ds_id = ds_id_begin + '_' + ds_id_end if len(swappable_fp32_tensors) > 0: self.optimizer_swapper.initialize_parameters(parameters=swappable_fp32_tensors,