Skip to content

Commit

Permalink
Move inf_or_nan_tracker to cpu for cpu offload (#5826)
Browse files Browse the repository at this point in the history
Must use the same device as grad_partitions_flat_buffer

---------

Co-authored-by: Logan Adams <[email protected]>
Co-authored-by: Olatunji Ruwase <[email protected]>
  • Loading branch information
3 people authored Aug 15, 2024
1 parent 9a3ede7 commit 4d4ff0e
Showing 1 changed file with 5 additions and 6 deletions.
11 changes: 5 additions & 6 deletions deepspeed/runtime/zero/stage3.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,14 +215,12 @@ def __init__(
self.module = module
self.elastic_checkpoint = elastic_checkpoint

self.inf_or_nan_tracker: Tensor = torch.zeros(1,
dtype=torch.bool,
device=get_accelerator().current_device_name(),
requires_grad=False)
self.device = get_accelerator().current_device_name() if not self.offload_optimizer else OffloadDeviceEnum.cpu

self.inf_or_nan_tracker: Tensor = torch.zeros(1, dtype=torch.bool, device=self.device, requires_grad=False)

self.deepspeed_adam_offload = (self.offload_optimizer and type(init_optimizer) == DeepSpeedCPUAdam)

self.device = get_accelerator().current_device_name() if not self.offload_optimizer else OffloadDeviceEnum.cpu
### streams used for overlapping computation with communication
self.reduce_and_partition_stream = None if get_accelerator().is_synchronized_device() else get_accelerator(
).Stream() if overlap_comm else get_accelerator().default_stream()
Expand Down Expand Up @@ -2148,7 +2146,8 @@ def has_overflow(self, partition_gradients=True):
self.inf_or_nan_tracker += torch.isnan(self.grad_partitions_flat_buffer).any()
self.inf_or_nan_tracker = self.inf_or_nan_tracker > 0

overflow_gpu = self.inf_or_nan_tracker.clone().to(torch.uint8)
overflow_gpu = self.inf_or_nan_tracker.clone().to(get_accelerator().current_device_name()).to(
torch.uint8)
self.inf_or_nan_tracker.zero_()

if not get_accelerator().resolves_data_dependency():
Expand Down

0 comments on commit 4d4ff0e

Please sign in to comment.