Fix a convergence issue in TP topology caused by incorrect grad_norm. #5411

Merged: 9 commits, Apr 16, 2024
Changes from 7 commits
65 changes: 56 additions & 9 deletions deepspeed/runtime/fp16/fused_optimizer.py
@@ -9,15 +9,16 @@

import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

from deepspeed.runtime.base_optimizer import DeepSpeedOptimizer
from deepspeed.runtime.utils import get_global_norm, get_grad_norm, CheckOverflow, get_weight_norm, get_norm_with_moe_layers
from deepspeed.runtime.utils import get_global_norm, get_flattened_grad_norm, CheckOverflow, get_weight_norm, get_norm_with_moe_layers, is_model_parallel_parameter
from deepspeed.runtime.fp16.loss_scaler import INITIAL_LOSS_SCALE, SCALE_WINDOW, MIN_LOSS_SCALE
from deepspeed.utils import logger, log_dist
from deepspeed.utils.torch import required_torch_version
from deepspeed.checkpoint.constants import OPTIMIZER_STATE_DICT, CLIP_GRAD
from deepspeed.accelerator import get_accelerator
from deepspeed.moe.utils import is_moe_param_group
from deepspeed.runtime.constants import PIPE_REPLICATED
from deepspeed.utils.bwc import bwc_tensor_model_parallel_rank

OVERFLOW_CHECK_TIMER = 'overflow_check'
COMPUTE_NORM_TIMER = 'compute_norm'
@@ -64,6 +65,8 @@ def __init__(self,
self.fp16_groups_flat = []
self.fp32_groups_flat = []

self.flatten_grad_norm_mask_list = []
self.has_executed_step = False
self._global_grad_norm = 0.

# loop to deal with groups
@@ -206,6 +209,40 @@ def override_loss_scale(self, loss_scale):
self.custom_loss_scaler = True
self.external_loss_scale = loss_scale

def _require_avoid_recompute_norm(self, p, tensor_model_parallel_rank):
# filter out tensors replicated by pipeline or tensor model parallelism
if hasattr(p, PIPE_REPLICATED) and p.ds_pipe_replicated:
return True
if (tensor_model_parallel_rank > 0) and not is_model_parallel_parameter(p):
return True

def _get_norm_mask_idx(self, group):
"""The function preserves the parallel information for norm
from unflattened gradients.

Args:
group (Iterable[Tensor] ): params group

Returns:
torch.Tensor: A 2D tensor containing index ranges for each group,
where each row represents a [start index, end index) range.
"""
group_mask_idx_list = []
grad_flat_st_idx = 0
grad_flat_en_idx = 0

for p in group:
grad_flat_en_idx = grad_flat_st_idx + p.numel()
if p.grad is not None and self._require_avoid_recompute_norm(p, bwc_tensor_model_parallel_rank(self.mpu)):
# merge range
if len(group_mask_idx_list) > 0 and grad_flat_st_idx == group_mask_idx_list[-1][-1]:
group_mask_idx_list[-1][-1] = grad_flat_en_idx
else:
group_mask_idx_list.append([grad_flat_st_idx, grad_flat_en_idx])
grad_flat_st_idx = grad_flat_en_idx

return torch.tensor(group_mask_idx_list, device=get_accelerator().current_device())
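
As an editorial aside (not part of the PR), here is a minimal, self-contained sketch of the range-building and merging behaviour above, using made-up parameter sizes and mask flags; build_mask_ranges is a hypothetical helper, not DeepSpeed API:

# Illustration only (not DeepSpeed code): mimics the merging logic of
# _get_norm_mask_idx for hypothetical parameter sizes and mask flags.
import torch

def build_mask_ranges(numels, should_mask):
    ranges, start = [], 0
    for n, masked in zip(numels, should_mask):
        end = start + n
        if masked:
            if ranges and ranges[-1][1] == start:
                ranges[-1][1] = end  # touching ranges are merged into one
            else:
                ranges.append([start, end])
        start = end
    return torch.tensor(ranges)

# Two adjacent masked parameters collapse into a single [4, 12) range.
print(build_mask_ranges([4, 3, 5], [False, True, True]))  # tensor([[ 4, 12]])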

def step(self, closure=None):
"""
Not supporting closure.
@@ -251,23 +288,32 @@ def step(self, closure=None):
for p in group
]))

for p in group:
p.grad = None

self.fp32_groups_flat[i].grad = grads_groups_flat[i]
param_group = self.optimizer.param_groups[i]

# split expert and non_expert grads for norm
if self.has_moe_layers and is_moe_param_group(param_group):
if param_group['name'] not in expert_grads_for_norm:
expert_grads_for_norm[param_group['name']] = []

expert_grads_for_norm[param_group['name']].append(self.fp32_groups_flat[i])
else:
# retrieve the mask required for calculating the norm of the flattened grad
# perform this collection only once (on the first step)
if not self.has_executed_step:
cur_flat_grad_norm_mask = self._get_norm_mask_idx(group)
self.flatten_grad_norm_mask_list.append(cur_flat_grad_norm_mask)

non_experts_grads_for_norm.append(self.fp32_groups_flat[i])

self.timers(COMPUTE_NORM_TIMER).start()
for p in group:
p.grad = None

all_groups_norm = get_grad_norm(non_experts_grads_for_norm, mpu=self.mpu)
self.timers(COMPUTE_NORM_TIMER).start()

self.timers(COMPUTE_NORM_TIMER).stop()
all_groups_norm = get_flattened_grad_norm(non_experts_grads_for_norm,
mpu=self.mpu,
grad_norm_mask=self.flatten_grad_norm_mask_list)

if self.has_moe_layers:
all_groups_norm = get_norm_with_moe_layers(all_groups_norm,
@@ -276,6 +322,7 @@
norm_type=self.norm_type)

scaled_global_grad_norm = get_global_norm(norm_list=[all_groups_norm])
self.timers(COMPUTE_NORM_TIMER).stop()

# Stash unscaled gradient norm
self._global_grad_norm = scaled_global_grad_norm / self.cur_scale
@@ -298,7 +345,7 @@
updated_params = _unflatten_dense_tensors(self.fp32_groups_flat[i], self.fp16_groups[i])
for p, q in zip(self.fp16_groups[i], updated_params):
p.data.copy_(q.data)

self.has_executed_step = True
self.timers(UPDATE_FP16_TIMER).stop()

self.timers.log(STEP_TIMERS)
57 changes: 24 additions & 33 deletions deepspeed/runtime/utils.py
@@ -17,7 +17,6 @@

import torch
from deepspeed import comm as dist

try:
from torch._six import inf
except ModuleNotFoundError:
@@ -385,7 +384,7 @@ def clip_grad_norm_(parameters, max_norm, norm_type=2, mpu=None):
return total_norm


def get_grad_norm(parameters, norm_type=2, mpu=None):
def get_flattened_grad_norm(parameters, norm_type=2, mpu=None, grad_norm_mask=None):
"""Get grad norm of an iterable of parameters.

This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and
@@ -397,7 +396,8 @@ def get_grad_norm(parameters, norm_type=2, mpu=None):
single Tensor that will have gradients normalized
norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
infinity norm.

grad_norm_mask (List[Tensor]): A list of Tensors, where each Tensor is a
2D tensor of [start_index, end_index) ranges to exclude from the norm.
Returns:
Total norm of the parameters (viewed as a single vector).
"""
@@ -415,18 +415,28 @@ def get_grad_norm(parameters, norm_type=2, mpu=None):
total_norm = total_norm_cuda[0].item()
else:
total_norm = 0.
tensor_mp_rank = bwc_tensor_model_parallel_rank(mpu=mpu)
for p in parameters:
# Pipeline parallelism may replicate parameters. Avoid multi-counting.
if hasattr(p, PIPE_REPLICATED) and p.ds_pipe_replicated:
continue

# Filter to avoid over-counting replicated tensors from tensor
# model parallelism
if (tensor_mp_rank > 0) and not is_model_parallel_parameter(p):
continue
for idx, p in enumerate(parameters):
if grad_norm_mask is not None and len(grad_norm_mask[idx]) > 0:
# Use grad_norm_mask to avoid redundant norm computation over the flattened
# gradient: it excludes parameters replicated by pipeline parallelism and
# tensors replicated by tensor model parallelism.

# Loop-free construction of a mask tensor from the list of ranges,
# logically equivalent to the following loop:
#
#   mask_tensor_ = torch.zeros_like(p, device=p.device, dtype=bool)
#   for mask_idx in grad_norm_mask[idx]:
#       mask_tensor_[mask_idx[0]:mask_idx[1]] = True
[Review] Contributor: Please clean up this block of comments.
[Review] Author: cleaned up and kept some for readability.

cum_sum_pairs = torch.tensor([1, -1], device=get_accelerator().current_device(),
dtype=p.dtype).repeat(grad_norm_mask[idx].shape[0], 1)
mask_tensor = torch.zeros(p.shape[0] + 1, device=get_accelerator().current_device(), dtype=p.dtype)
mask_tensor = mask_tensor.scatter_(0, grad_norm_mask[idx].view(-1),
cum_sum_pairs.view(-1)).cumsum(0).bool()[:-1]
# assert torch.equal(mask_tensor_, mask_tensor)
[Review] Contributor: Please delete this if no longer needed.
[Review] Author: deleted.

param_norm = torch.masked_fill(p.grad.data, mask_tensor, 0).float().norm(norm_type)

param_norm = p.grad.data.float().norm(norm_type)
else:
param_norm = p.grad.data.float().norm(norm_type)
total_norm += param_norm.item()**norm_type

# Sum across all model parallel GPUs.
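
As an editorial aside (not part of the PR), a self-contained sketch with made-up sizes and ranges, checking that the scatter-and-cumsum construction above reproduces the loop-based mask and showing how the mask zeroes out replicated slices before the norm:

# Illustration only (not DeepSpeed code).
import torch

flat_numel = 10
ranges = torch.tensor([[2, 5], [7, 9]])  # [start, end) slices to exclude

# Loop-based reference mask.
ref_mask = torch.zeros(flat_numel, dtype=torch.bool)
for start, end in ranges.tolist():
    ref_mask[start:end] = True

# Loop-free version: scatter +1 at each start and -1 at each end,
# then a cumulative sum marks everything inside the ranges.
pairs = torch.tensor([1., -1.]).repeat(ranges.shape[0], 1)
mask = torch.zeros(flat_numel + 1)
mask = mask.scatter_(0, ranges.view(-1), pairs.view(-1)).cumsum(0).bool()[:-1]
assert torch.equal(ref_mask, mask)

# Masked slices contribute zero to the norm.
grad = torch.randn(flat_numel)
param_norm = torch.masked_fill(grad, mask, 0).norm(2)

Note that the range merging in _get_norm_mask_idx matters for this construction: two touching ranges would place both a -1 and a +1 at the same index, and scatter_ keeps only one of them.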
@@ -814,25 +824,6 @@ def get_only_unique_item(items):
return unique_item


def clip_gradients(parameters, max_norm=1.0, global_grad_norm=None, mpu=None, eps=1e-6):
"""Clip the gradient of a list of parameters.
Args:
parameters: List of parameters whose .grad will be clipped.
global_grad_norm (float, optional): Precomputed gradient norm. Defaults to None.
mpu (optional): model parallelism unit. Defaults to None.
eps (float, optional): epsilon value added to grad norm. Defaults to 1e-6
Returns:
float: the global gradient norm
"""
if global_grad_norm is None:
global_grad_norm = get_grad_norm(parameters, mpu=mpu)
clip_coef = max_norm / (global_grad_norm + eps)
if clip_coef < 1:
for p in parameters:
p.grad.detach().mul_(clip_coef)
return global_grad_norm


def get_global_norm_of_tensors(input_tensors, norm_type=2, mpu=None, use_graph=False, moe_ep_group=None):
"""Get norm of an iterable of tensors.
