optimize grad_norm calculation
mmhab committed Oct 2, 2023
1 parent 0636c74 commit a56838a
Showing 1 changed file with 16 additions and 8 deletions.

deepspeed/runtime/zero/stage3.py
@@ -1279,7 +1279,7 @@ def complete_grad_norm_calculation_for_cpu_offload(self, params):
                 param_id = self.get_param_id(p)
                 if param_id in self.norm_for_param_grads.keys():
                     param_norm = self.norm_for_param_grads[param_id]
-                    total_norm += param_norm.item()**2
+                    total_norm += param_norm**2
 
         # Sum across all model parallel GPUs.
         total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)])
@@ -1288,10 +1288,14 @@ def complete_grad_norm_calculation_for_cpu_offload(self, params):
 
         self._model_parallel_all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.SUM)
 
-        total_norm = total_norm_cuda[0].item()**(1. / norm_type)
+        total_norm = total_norm_cuda[0]**(1. / norm_type)
 
-        if total_norm == float('inf') or total_norm == -float('inf') or total_norm != total_norm:
-            total_norm = -1
+        norm_is_inf = total_norm.isinf()
+        norm_is_nan = total_norm.isnan()
+        inf_or_nan = norm_is_nan.logical_or(norm_is_inf)
+
+        err = torch.tensor(-1.0, device=self.current_device, dtype=torch.float)
+        total_norm = err.where(inf_or_nan,total_norm)
 
         return total_norm
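Note: the hunk above replaces the Python-level inf/nan check with tensor ops, so total_norm stays a device tensor instead of being synced to the host. A minimal standalone sketch of the same guard (illustrative only; guard_norm is a hypothetical helper, not part of DeepSpeed):

import torch

# Hypothetical helper reproducing the commit's tensor-level inf/nan guard.
def guard_norm(total_norm: torch.Tensor) -> torch.Tensor:
    inf_or_nan = total_norm.isnan().logical_or(total_norm.isinf())
    err = torch.tensor(-1.0, device=total_norm.device, dtype=torch.float)
    # err.where(cond, other) == torch.where(cond, err, other):
    # take -1.0 where the norm is inf/nan, the norm itself elsewhere.
    return err.where(inf_or_nan, total_norm)

print(guard_norm(torch.tensor(3.0)))           # tensor(3.)
print(guard_norm(torch.tensor(float('nan'))))  # tensor(-1.)
print(guard_norm(torch.tensor(float('inf'))))  # tensor(-1.)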

@@ -1622,7 +1626,7 @@ def get_grad_norm_direct(self, gradients, params, norm_type=2):
 
             # Take max across all GPUs.
             self._model_parallel_all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX)
-            total_norm = total_norm_cuda[0].item()
+            total_norm = total_norm_cuda[0]
         else:
             # if dist.get_rank() == 0:
             #     logger.info(f"Total Norm beginning {total_norm}")
@@ -1643,10 +1647,14 @@ def get_grad_norm_direct(self, gradients, params, norm_type=2):
 
             self._model_parallel_all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.SUM)
 
-            total_norm = total_norm_cuda.item()**(1. / norm_type)
+            total_norm = total_norm_cuda**(1. / norm_type)
 
-        if total_norm == float('inf') or total_norm == -float('inf') or total_norm != total_norm:
-            total_norm = -1
+        norm_is_inf = total_norm.isinf()
+        norm_is_nan = total_norm.isnan()
+        inf_or_nan = norm_is_nan.logical_or(norm_is_inf)
+
+        err = torch.tensor(-1.0, device=self.current_device, dtype=torch.float)
+        total_norm = err.where(inf_or_nan,total_norm)
 
         return total_norm
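Throughout the commit, intermediate .item() calls are also dropped. Each .item() forces a blocking device-to-host copy, so removing them lets the norm computation stay asynchronous on the accelerator until a single sync point. A rough micro-example of the pattern (illustrative; the tensor sizes and count are made up, and it falls back to CPU when no accelerator is present):

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
norms = [torch.linalg.vector_norm(torch.randn(1 << 20, device=device)) for _ in range(8)]

# Old pattern: one blocking device-to-host sync per parameter norm.
total_old = sum(n.item() ** 2 for n in norms)

# New pattern: accumulate on-device; at most one sync at the very end.
total_new = sum(n ** 2 for n in norms)

print(total_old, total_new.item())  # same value up to float precision, far fewer syncs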

