fix offloading of lp grad
tohtana committed Aug 20, 2024
commit 37ffa02 (parent: 3f8179d)
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions deepspeed/runtime/zero/stage3.py
@@ -2851,6 +2851,7 @@ def needs_offload(target):
                     torch.empty_like(self.grad_partitions_flat_buffer, device=device))
             self.lp_grad_partitions_flat_pin_buffers.copy_(self.grad_partitions_flat_buffer,
                                                            non_blocking=non_blocking)
+            self.grad_partitions_flat_buffer.data = self.lp_grad_partitions_flat_pin_buffers
         else:
             self.grad_partitions_flat_buffer.data = self.grad_partitions_flat_buffer.data.to(device)
         self.averaged_gradients = {}
@@ -2919,6 +2920,7 @@ def offload_states_back(self, non_blocking: bool = False):

         # contiguous bucket
         if OffloadStateTypeEnum.contiguous_grad_buffer in self.offloaded_states:
+            print(f"loading contiguous_grad_buffer")
             self.__ipg_bucket_flat_buffer = torch.empty_like(self.grad_buffer_meta, device=device)
             # self.__ipg_bucket_flat_buffer.data = self.__ipg_bucket_flat_buffer.data.to(device)
             self.offloaded_states.remove(OffloadStateTypeEnum.contiguous_grad_buffer)
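
The first hunk fixes the offload path: the low-precision gradient partition is copied into a pinned (page-locked) CPU buffer, but before this change the tensor's .data still pointed at the device copy, so the device memory was never released. Below is a minimal, self-contained sketch of that pinned-buffer offload pattern; the function and variable names are hypothetical, not DeepSpeed's API, and it assumes a CUDA device is available.

import torch

def offload_to_pinned(t: torch.Tensor, non_blocking: bool = True) -> torch.Tensor:
    # Page-locked host memory is required for a truly asynchronous
    # device-to-host copy (non_blocking=True).
    pin_buffer = torch.empty(t.shape, dtype=t.dtype, device="cpu", pin_memory=True)
    pin_buffer.copy_(t, non_blocking=non_blocking)
    # Repoint .data at the pinned host buffer; without this step the tensor
    # keeps referencing device memory, which is the omission this commit fixes.
    t.data = pin_buffer
    return t

if torch.cuda.is_available():
    grads = torch.randn(1024, device="cuda")
    offload_to_pinned(grads)
    torch.cuda.synchronize()  # wait for the async copy before reading
    assert grads.device.type == "cpu"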
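The second hunk adds a debug print on the reload path for the contiguous gradient bucket, which is rebuilt from a stored meta tensor: torch.empty_like on a meta tensor allocates a fresh, uninitialized buffer of the same shape and dtype on the target device. A small sketch of that round trip, again with hypothetical names rather than the actual stage3.py attributes:

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

bucket = torch.randn(1024, device=device)
bucket_meta = bucket.to("meta")  # metadata only: shape/dtype, no storage
del bucket                       # release the real buffer while offloaded

# Later, when loading states back, re-materialize an empty device buffer.
bucket = torch.empty_like(bucket_meta, device=device)
assert bucket.shape == bucket_meta.shape and bucket.device.type == device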