diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index 7a45a305957e..2f4e37a8e5cc 100644 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -1339,7 +1339,7 @@ def complete_grad_norm_calculation_for_cpu_offload(self, params): norm_is_nan = total_norm.isnan() inf_or_nan = norm_is_nan.logical_or(norm_is_inf) - err = torch.tensor(-1.0, device=self.current_device, dtype=torch.float) + err = torch.tensor(-1.0, device=self.device, dtype=torch.float) total_norm = err.where(inf_or_nan,total_norm) return total_norm