diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py
index b0a3ab778f2a..65d1c5ace08f 100644
--- a/deepspeed/runtime/zero/stage3.py
+++ b/deepspeed/runtime/zero/stage3.py
@@ -215,14 +215,12 @@ def __init__(
         self.module = module
         self.elastic_checkpoint = elastic_checkpoint
 
-        self.inf_or_nan_tracker: Tensor = torch.zeros(1,
-                                                      dtype=torch.bool,
-                                                      device=get_accelerator().current_device_name(),
-                                                      requires_grad=False)
+        self.device = get_accelerator().current_device_name() if not self.offload_optimizer else OffloadDeviceEnum.cpu
+
+        self.inf_or_nan_tracker: Tensor = torch.zeros(1, dtype=torch.bool, device=self.device, requires_grad=False)
 
         self.deepspeed_adam_offload = (self.offload_optimizer and type(init_optimizer) == DeepSpeedCPUAdam)
 
-        self.device = get_accelerator().current_device_name() if not self.offload_optimizer else OffloadDeviceEnum.cpu
         ### streams used for overlapping computation with communication
         self.reduce_and_partition_stream = None if get_accelerator().is_synchronized_device() else get_accelerator(
         ).Stream() if overlap_comm else get_accelerator().default_stream()
@@ -2148,7 +2146,8 @@ def has_overflow(self, partition_gradients=True):
                     self.inf_or_nan_tracker += torch.isnan(self.grad_partitions_flat_buffer).any()
                     self.inf_or_nan_tracker = self.inf_or_nan_tracker > 0
 
-            overflow_gpu = self.inf_or_nan_tracker.clone().to(torch.uint8)
+            overflow_gpu = self.inf_or_nan_tracker.clone().to(get_accelerator().current_device_name()).to(
+                torch.uint8)
             self.inf_or_nan_tracker.zero_()
 
             if not get_accelerator().resolves_data_dependency():
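
Intent of the change, for reference: self.device is now computed before the tracker is created, so when the optimizer is offloaded the inf/nan tracker lives on the offload (CPU) device next to the gradient partitions, and has_overflow explicitly moves the tracker back to the accelerator before casting it to uint8 for the overflow check. A minimal standalone sketch of that pattern follows; the names offload_to_cpu, accel_device, and grad_buffer are illustrative assumptions, not DeepSpeed internals.

import torch

# Sketch assumptions: offload_to_cpu mirrors self.offload_optimizer,
# accel_device stands in for get_accelerator().current_device_name(),
# and grad_buffer stands in for self.grad_partitions_flat_buffer.
offload_to_cpu = True
accel_device = "cuda" if torch.cuda.is_available() else "cpu"
device = "cpu" if offload_to_cpu else accel_device

grad_buffer = torch.randn(1024, device=device)
inf_or_nan_tracker = torch.zeros(1, dtype=torch.bool, device=device, requires_grad=False)

# Accumulate inf/nan flags on the tracker's own (possibly CPU) device.
inf_or_nan_tracker |= torch.isinf(grad_buffer).any()
inf_or_nan_tracker |= torch.isnan(grad_buffer).any()

# Move the flag to the accelerator only for the overflow check, analogous to
# the .to(get_accelerator().current_device_name()) call in the second hunk.
overflow_gpu = inf_or_nan_tracker.clone().to(accel_device).to(torch.uint8)
inf_or_nan_tracker.zero_()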