optimize grad_norm calculation in stage3.py (#4436)
Reduce synchronization between the device and the host by removing .item() from the loops that calculate the total norm.
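
The pattern this targets, as a minimal standalone sketch (the gradient list, sizes, and device selection below are illustrative, not the actual stage3.py data structures): calling .item() inside the accumulation loop copies each partial norm back to the host and blocks until the device has caught up, whereas accumulating tensors defers, and can avoid entirely, that round trip.

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
# Illustrative stand-in for per-parameter gradients.
grads = [torch.randn(1024, device=device) for _ in range(8)]

# Before: one device-to-host synchronization per iteration.
total_sq = 0.0
for g in grads:
    total_sq += g.norm(2).item() ** 2   # .item() blocks here every time

# After: the accumulation stays on the device; nothing blocks until the
# value is explicitly read on the host (if it ever is).
total_sq_t = torch.zeros((), device=device)
for g in grads:
    total_sq_t += g.norm(2) ** 2
total_norm = total_sq_t ** 0.5          # still a 0-d tensor on the device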

---------

Co-authored-by: Olatunji Ruwase <[email protected]>
Co-authored-by: Michael Wyatt <[email protected]>
Co-authored-by: Shaden Smith <[email protected]>
5 people authored Jan 2, 2024
1 parent 4034205 commit ea0d811
Showing 1 changed file with 16 additions and 8 deletions.
deepspeed/runtime/zero/stage3.py
@@ -1328,7 +1328,7 @@ def complete_grad_norm_calculation_for_cpu_offload(self, params):
                 param_id = self.get_param_id(p)
                 if param_id in self.norm_for_param_grads.keys():
                     param_norm = self.norm_for_param_grads[param_id]
-                    total_norm += param_norm.item()**2
+                    total_norm += param_norm**2

         # Sum across all model parallel GPUs.
         total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)])
@@ -1337,10 +1337,14 @@ def complete_grad_norm_calculation_for_cpu_offload(self, params):

         self._model_parallel_all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.SUM)

-        total_norm = total_norm_cuda[0].item()**(1. / norm_type)
+        total_norm = total_norm_cuda[0]**(1. / norm_type)

-        if total_norm == float('inf') or total_norm == -float('inf') or total_norm != total_norm:
-            total_norm = -1
+        norm_is_inf = total_norm.isinf()
+        norm_is_nan = total_norm.isnan()
+        inf_or_nan = norm_is_nan.logical_or(norm_is_inf)
+
+        err = torch.tensor(-1.0, device=self.device, dtype=torch.float)
+        total_norm = inf_or_nan * err + inf_or_nan.logical_not() * total_norm

         return total_norm
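
The second half of this hunk replaces a host-side if on the norm's value, which requires reading the tensor back, with a non-finite check computed entirely on the device. A minimal standalone sketch of the same idea follows; it uses torch.where for the selection rather than the multiply-and-add blend in the diff, which is this sketch's own choice, not the diff's exact form.

import torch

def device_side_norm_check(total_norm: torch.Tensor) -> torch.Tensor:
    # True when the norm overflowed to +/-inf or became NaN.
    inf_or_nan = total_norm.isnan().logical_or(total_norm.isinf())
    # Sentinel that stands in for an unusable norm.
    err = torch.tensor(-1.0, device=total_norm.device, dtype=total_norm.dtype)
    # Element-wise selection on the device; no .item(), no Python branch on the value.
    return torch.where(inf_or_nan, err, total_norm)

print(device_side_norm_check(torch.tensor(3.5)))           # tensor(3.5000)
print(device_side_norm_check(torch.tensor(float("inf"))))  # tensor(-1.)

torch.where is used in the sketch because multiplying a non-finite value by zero still yields NaN, so a multiply-and-add selection has to be handled with care; the intent is the same either way: the check runs on the device, and only a deliberate read, if any, crosses back to the host.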

@@ -1669,7 +1673,7 @@ def get_grad_norm_direct(self, gradients, params, norm_type=2):

             # Take max across all GPUs.
             self._model_parallel_all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX)
-            total_norm = total_norm_cuda[0].item()
+            total_norm = total_norm_cuda[0]
         else:
             # if dist.get_rank() == 0:
             #    logger.info(f"Total Norm beginning {total_norm}")
@@ -1690,10 +1694,14 @@ def get_grad_norm_direct(self, gradients, params, norm_type=2):

             self._model_parallel_all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.SUM)

-            total_norm = total_norm_cuda.item()**(1. / norm_type)
+            total_norm = total_norm_cuda**(1. / norm_type)

-        if total_norm == float('inf') or total_norm == -float('inf') or total_norm != total_norm:
-            total_norm = -1
+        norm_is_inf = total_norm.isinf()
+        norm_is_nan = total_norm.isnan()
+        inf_or_nan = norm_is_nan.logical_or(norm_is_inf)
+
+        err = torch.tensor(-1.0, device=self.device, dtype=torch.float)
+        total_norm = inf_or_nan * err + inf_or_nan.logical_not() * total_norm

         return total_norm
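
One consequence for callers, sketched with hypothetical names (clip_gradients_ and its epsilon are illustrative and the -1.0 error sentinel is deliberately not handled here): get_grad_norm_direct now hands back a 0-d tensor rather than a Python float, so a downstream clipping step can also stay on the device, and a host-side value is only materialized if the caller explicitly asks for one.

import torch

def clip_gradients_(grads, total_norm: torch.Tensor, clip: float = 1.0) -> None:
    # Entirely device-side: no .item() and no Python branch on the norm's value.
    clip_coef = torch.clamp(clip / (total_norm + 1e-6), max=1.0)
    for g in grads:
        g.mul_(clip_coef)

# Only when a plain float is genuinely needed (e.g. for logging) does a single,
# caller-chosen synchronization happen:
# logged_norm = total_norm.item()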
