Fix a convergence issue in TP topology caused by incorrect grad_norm. #5411
Merged
Changes from 7 of 9 commits

Commits
287fa5e  fix grad norm for tp (inkcherry)
a7e8a7f  refine code (inkcherry)
ea41928  remove unnecessary clip_gradients fun (inkcherry)
e74b7ca  improve perf by loop-free implementations (inkcherry)
79cc4ce  Modify the comments. (inkcherry)
3ebed5e  update (inkcherry)
fc537b8  Merge remote-tracking branch 'master' into tp_grad_fix (inkcherry)
df976ca  refine comments (inkcherry)
a40263f  Merge branch 'master' into tp_grad_fix (conglongli)
@@ -17,7 +17,6 @@

import torch
from deepspeed import comm as dist

try:
    from torch._six import inf
except ModuleNotFoundError:
@@ -385,7 +384,7 @@ def clip_grad_norm_(parameters, max_norm, norm_type=2, mpu=None):
     return total_norm


-def get_grad_norm(parameters, norm_type=2, mpu=None):
+def get_flattened_grad_norm(parameters, norm_type=2, mpu=None, grad_norm_mask=None):
     """Get grad norm of an iterable of parameters.

     This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and
@@ -397,7 +396,8 @@ def get_grad_norm(parameters, norm_type=2, mpu=None):
             single Tensor that will have gradients normalized
         norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
             infinity norm.
-
+        grad_norm_mask (List[Tensor]): A list of Tensor, where
+            each Tensor is a 2D Tensor containing ranges of [start_index, end_index].
     Returns:
         Total norm of the parameters (viewed as a single vector).
     """
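To make the new grad_norm_mask argument concrete, here is a small illustration with made-up numbers (plain CPU tensors, hypothetical values, not DeepSpeed code): each row of a mask entry is a [start_index, end_index) range of the flattened gradient that is zeroed out before the norm is taken, so slices that should not contribute to this rank's norm (for example because another rank already accounts for them) are not counted twice.

import torch

# Flattened gradient of 10 elements (made-up values).
flat_grad = torch.arange(10, dtype=torch.float32)

# One grad_norm_mask entry for this parameter: two [start, end) ranges that
# should be excluded from this rank's norm contribution.
mask_ranges = torch.tensor([[2, 5], [7, 9]])

# Build the boolean mask the obvious way; the diff below constructs the same
# mask without a Python-level loop.
mask = torch.zeros_like(flat_grad, dtype=torch.bool)
for start, end in mask_ranges.tolist():
    mask[start:end] = True

# Masked positions are filled with zero before taking the norm, mirroring
# torch.masked_fill(p.grad.data, mask_tensor, 0).float().norm(norm_type).
masked_norm = torch.masked_fill(flat_grad, mask, 0).norm(2)
full_norm = flat_grad.norm(2)
print(masked_norm, full_norm)  # masked_norm excludes elements 2..4 and 7..8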
@@ -415,18 +415,28 @@ def get_grad_norm(parameters, norm_type=2, mpu=None):
         total_norm = total_norm_cuda[0].item()
     else:
         total_norm = 0.
-        tensor_mp_rank = bwc_tensor_model_parallel_rank(mpu=mpu)
-        for p in parameters:
-            # Pipeline parallelism may replicate parameters. Avoid multi-counting.
-            if hasattr(p, PIPE_REPLICATED) and p.ds_pipe_replicated:
-                continue
-
-            # Filter to avoid over-counting replicated tensors from tensor
-            # model parallelism
-            if (tensor_mp_rank > 0) and not is_model_parallel_parameter(p):
-                continue
-
-            param_norm = p.grad.data.float().norm(norm_type)
+        for idx, p in enumerate(parameters):
+            if grad_norm_mask is not None and len(grad_norm_mask[idx]) > 0:
+                # Use grad_norm_mask to avoid redundant computation of the flattened
+                # gradient norm, e.g. for parameters replicated by pipeline parallelism
+                # or duplicated across tensor model parallel ranks.
+
+                # A loop-free implementation to create a mask tensor based on a range list,
+                # which is logically equivalent to the following implementation:
+                # mask_tensor_ = torch.zeros_like(p, device=p.device, dtype=bool)
+                # for mask_idx in grad_norm_mask[idx]:
+                #     mask_tensor_[mask_idx[0]:mask_idx[1]] = True
+                cum_sum_pairs = torch.tensor([1, -1], device=get_accelerator().current_device(),
+                                             dtype=p.dtype).repeat(grad_norm_mask[idx].shape[0], 1)
+                mask_tensor = torch.zeros(p.shape[0] + 1, device=get_accelerator().current_device(), dtype=p.dtype)
+                mask_tensor = mask_tensor.scatter_(0, grad_norm_mask[idx].view(-1),
+                                                   cum_sum_pairs.view(-1)).cumsum(0).bool()[:-1]
+                # assert torch.equal(mask_tensor_, mask_tensor)
+                param_norm = torch.masked_fill(p.grad.data, mask_tensor, 0).float().norm(norm_type)
+            else:
+                param_norm = p.grad.data.float().norm(norm_type)
             total_norm += param_norm.item()**norm_type

     # Sum across all model parallel GPUs.

Review comment: Please delete this if no longer needed
Reply: deleted
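The scatter-plus-cumsum trick in the hunk above is easy to check in isolation. The sketch below is not DeepSpeed code: the helper names are invented, it runs on plain CPU tensors, and it uses scatter_add_ rather than scatter_ so that intervals sharing an endpoint also come out right. It builds the same interval mask without a Python-level loop and compares it against the explicit loop.

import torch

def range_list_to_mask(ranges: torch.Tensor, length: int) -> torch.Tensor:
    """Return a bool mask of `length` elements that is True inside every
    [start, end) interval of `ranges`, an [R, 2] integer tensor, without a
    Python-level loop over the intervals."""
    # Scatter +1 at every start index and -1 at every end index, then take a
    # running sum: the count is > 0 exactly inside the intervals.
    deltas = torch.tensor([1, -1], dtype=torch.long).repeat(ranges.shape[0], 1)
    edges = torch.zeros(length + 1, dtype=torch.long)
    edges.scatter_add_(0, ranges.reshape(-1), deltas.reshape(-1))
    return edges.cumsum(0)[:-1] > 0

def range_list_to_mask_loop(ranges: torch.Tensor, length: int) -> torch.Tensor:
    """Reference implementation with an explicit loop, for comparison."""
    mask = torch.zeros(length, dtype=torch.bool)
    for start, end in ranges.tolist():
        mask[start:end] = True
    return mask

ranges = torch.tensor([[2, 5], [5, 7], [10, 12]])
assert torch.equal(range_list_to_mask(ranges, 16), range_list_to_mask_loop(ranges, 16))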
@@ -814,25 +824,6 @@ def get_only_unique_item(items):
     return unique_item


-def clip_gradients(parameters, max_norm=1.0, global_grad_norm=None, mpu=None, eps=1e-6):
-    """Clip the gradient of a list of parameters.
-    Args:
-        parameters: List of parameters whose .grad will be clipped.
-        global_grad_norm (float, optional): Precomputed gradient norm. Defaults to None.
-        mpu (optional): model parallelism unit. Defaults to None.
-        eps (float, optional): epsilon value added to grad norm. Defaults to 1e-6
-    Returns:
-        float: the global gradient norm
-    """
-    if global_grad_norm is None:
-        global_grad_norm = get_grad_norm(parameters, mpu=mpu)
-    clip_coef = max_norm / (global_grad_norm + eps)
-    if clip_coef < 1:
-        for p in parameters:
-            p.grad.detach().mul_(clip_coef)
-    return global_grad_norm
-
-
 def get_global_norm_of_tensors(input_tensors, norm_type=2, mpu=None, use_graph=False, moe_ep_group=None):
     """Get norm of an iterable of tensors.

conglongli marked this conversation as resolved.
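As a point of reference (not part of this PR), the deleted helper implemented the standard clip-by-global-norm rule, clip_coef = max_norm / (total_norm + eps) applied only when the coefficient is below one, which is essentially what torch.nn.utils.clip_grad_norm_ provides for the single-process case. A minimal sketch with a toy model, assuming no model parallelism and no precomputed norm:

import torch

# Toy model and a backward pass so that .grad is populated.
model = torch.nn.Linear(4, 2)
model(torch.randn(8, 4)).sum().backward()

# Clip by global L2 norm: gradients are scaled down whenever their total norm
# exceeds max_norm; the pre-clip norm is returned.
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
print(float(total_norm))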
Review comment: Please clean up this block of comments
Reply: cleaned up and kept some for readability