From ea0d81143c6ba0801828919ea53888843d0fb19f Mon Sep 17 00:00:00 2001
From: mmhab <132277730+mmhab@users.noreply.github.com>
Date: Tue, 2 Jan 2024 15:19:08 +0200
Subject: [PATCH] optimize grad_norm calculation in stage3.py (#4436)

reduce the synchronization between the device and the host by removing
.item() from the loops that calculate the total norm.

---------

Co-authored-by: Olatunji Ruwase
Co-authored-by: Michael Wyatt
Co-authored-by: Michael Wyatt
Co-authored-by: Shaden Smith
---
 deepspeed/runtime/zero/stage3.py | 24 ++++++++++++++++--------
 1 file changed, 16 insertions(+), 8 deletions(-)

diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py
index fa4e64faf5a5..ce4137028195 100644
--- a/deepspeed/runtime/zero/stage3.py
+++ b/deepspeed/runtime/zero/stage3.py
@@ -1328,7 +1328,7 @@ def complete_grad_norm_calculation_for_cpu_offload(self, params):
                 param_id = self.get_param_id(p)
                 if param_id in self.norm_for_param_grads.keys():
                     param_norm = self.norm_for_param_grads[param_id]
-                    total_norm += param_norm.item()**2
+                    total_norm += param_norm**2
 
         # Sum across all model parallel GPUs.
         total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)])
@@ -1337,10 +1337,14 @@
 
         self._model_parallel_all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.SUM)
 
-        total_norm = total_norm_cuda[0].item()**(1. / norm_type)
+        total_norm = total_norm_cuda[0]**(1. / norm_type)
 
-        if total_norm == float('inf') or total_norm == -float('inf') or total_norm != total_norm:
-            total_norm = -1
+        norm_is_inf = total_norm.isinf()
+        norm_is_nan = total_norm.isnan()
+        inf_or_nan = norm_is_nan.logical_or(norm_is_inf)
+
+        err = torch.tensor(-1.0, device=self.device, dtype=torch.float)
+        total_norm = inf_or_nan * err + inf_or_nan.logical_not() * total_norm
 
         return total_norm
 
@@ -1669,7 +1673,7 @@ def get_grad_norm_direct(self, gradients, params, norm_type=2):
 
             # Take max across all GPUs.
             self._model_parallel_all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX)
-            total_norm = total_norm_cuda[0].item()
+            total_norm = total_norm_cuda[0]
         else:
             # if dist.get_rank() == 0:
             #    logger.info(f"Total Norm beginning {total_norm}")
@@ -1690,10 +1694,14 @@
 
             self._model_parallel_all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.SUM)
 
-            total_norm = total_norm_cuda.item()**(1. / norm_type)
+            total_norm = total_norm_cuda**(1. / norm_type)
+
+        norm_is_inf = total_norm.isinf()
+        norm_is_nan = total_norm.isnan()
+        inf_or_nan = norm_is_nan.logical_or(norm_is_inf)
 
-        if total_norm == float('inf') or total_norm == -float('inf') or total_norm != total_norm:
-            total_norm = -1
+        err = torch.tensor(-1.0, device=self.device, dtype=torch.float)
+        total_norm = inf_or_nan * err + inf_or_nan.logical_not() * total_norm
 
         return total_norm
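
Note for readers outside the DeepSpeed code base: the sketch below illustrates the pattern this patch applies, namely accumulating squared norms as tensors and handling the inf/nan case with tensor ops, so that no per-step .item() call forces a device-to-host synchronization. It is a minimal standalone example under stated assumptions, not DeepSpeed code; the function name total_grad_norm is invented here, and the final inf/nan selection uses torch.where rather than the arithmetic masking in the patch.

# Minimal standalone sketch of the sync-free grad-norm pattern (not DeepSpeed code).
import torch


def total_grad_norm(grads, norm_type=2.0):
    """Return the total gradient norm as a 0-dim tensor that stays on its device."""
    device = grads[0].device
    total = torch.zeros((), device=device, dtype=torch.float)
    for g in grads:
        # Accumulate tensors instead of calling .item() inside the loop;
        # each .item() would block the host until the device drained its queued work.
        total += g.float().norm(norm_type) ** norm_type
    total = total ** (1.0 / norm_type)

    # Tensor-side stand-in for `if norm is inf/nan: norm = -1`, again without a host sync.
    inf_or_nan = total.isinf().logical_or(total.isnan())
    err = torch.tensor(-1.0, device=device, dtype=torch.float)
    return torch.where(inf_or_nan, err, total)


if __name__ == "__main__":
    grads = [torch.randn(8, 8), torch.randn(4)]
    print(total_grad_norm(grads))      # a finite norm
    grads[0][0, 0] = float("nan")
    print(total_grad_norm(grads))      # tensor(-1.) sentinel, computed on device

torch.where is used for the final selection in this sketch because multiplying by the negated mask would let a nan or inf value propagate (0 * nan is nan); the selection itself is still a tensor op and introduces no additional host synchronization.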