
Commit

refine comments
inkcherry committed Apr 16, 2024
1 parent fc537b8 commit df976ca
Showing 1 changed file with 5 additions and 8 deletions.
13 changes: 5 additions & 8 deletions deepspeed/runtime/utils.py
@@ -416,23 +416,20 @@ def get_flattened_grad_norm(parameters, norm_type=2, mpu=None, grad_norm_mask=None):
     else:
         total_norm = 0.
         for idx, p in enumerate(parameters):
+            # Use grad_norm_mask to avoid redundant computation of flattened gradient norm
             if grad_norm_mask is not None and len(grad_norm_mask[idx]) > 0:
-                # Use grad_norm_mask to avoid redundant computation of flattened gradient norm
-                # # including, Pipeline parallelism may replicate parameters.
-                # # replicated tensors from tensor model parallelism

-                # A loop-free implementation to create a mask tensor based on a range list,
+                # A loop-free implementation to create a mask tensor based on a range list
                 # which is logically equivalent to the following implementation.

                 # # mask_tensor_ = torch.zeros_like(p, device=p.device, dtype=bool)
-                # #for mask_idx in grad_norm_mask[idx]:
-                # # mask_tensor_[mask_idx[0]:mask_idx[1]] = True
+                # # for mask_idx in grad_norm_mask[idx]:
+                # #     mask_tensor_[mask_idx[0]:mask_idx[1]] = True
                 cum_sum_pairs = torch.tensor([1, -1], device=get_accelerator().current_device(),
                                              dtype=p.dtype).repeat(grad_norm_mask[idx].shape[0], 1)
                 mask_tensor = torch.zeros(p.shape[0] + 1, device=get_accelerator().current_device(), dtype=p.dtype)
                 mask_tensor = mask_tensor.scatter_(0, grad_norm_mask[idx].view(-1),
                                                    cum_sum_pairs.view(-1)).cumsum(0).bool()[:-1]
-                # assert torch.equal(mask_tensor_, mask_tensor)

                 param_norm = torch.masked_fill(p.grad.data, mask_tensor, 0).float().norm(norm_type)

             else:
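For context, the comment being refined documents a loop-free way to build the exclusion mask: each [start, end) pair in grad_norm_mask[idx] scatters a +1 at start and a -1 at end into a zero tensor, and the running sum of that tensor is nonzero exactly inside the ranges. Below is a minimal standalone sketch of the same trick, not DeepSpeed's actual code: build_mask_loop, build_mask_loop_free, numel, and ranges are illustrative names and toy values. It checks that the loop-free construction matches the commented-out loop.

import torch


def build_mask_loop(numel, ranges):
    # Reference version, mirroring the commented-out loop in the diff.
    mask = torch.zeros(numel, dtype=torch.bool)
    for start, end in ranges:
        mask[start:end] = True
    return mask


def build_mask_loop_free(numel, ranges):
    # Scatter [1, -1] at every (start, end) boundary, take the running sum,
    # and drop the extra trailing slot; positions inside a range become True.
    ranges = torch.as_tensor(ranges)  # shape (num_ranges, 2)
    cum_sum_pairs = torch.tensor([1, -1]).repeat(ranges.shape[0], 1)
    mask = torch.zeros(numel + 1, dtype=torch.long)
    return mask.scatter_(0, ranges.view(-1), cum_sum_pairs.view(-1)).cumsum(0).bool()[:-1]


numel, ranges = 12, [[1, 4], [6, 9]]
assert torch.equal(build_mask_loop(numel, ranges), build_mask_loop_free(numel, ranges))

# The mask then zeroes the excluded elements before taking the norm,
# analogous to the masked_fill call in the diff.
grad = torch.randn(numel)
print(torch.masked_fill(grad, build_mask_loop_free(numel, ranges), 0).norm(2))

One caveat of the sketch: scatter_ overwrites rather than accumulates at duplicate indices, so the construction assumes the range boundaries are distinct, i.e. the ranges neither overlap nor touch.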

0 comments on commit df976ca
