diff --git a/deepspeed/runtime/fp16/fused_optimizer.py b/deepspeed/runtime/fp16/fused_optimizer.py
index af8050c4646a..bf1693307ea7 100755
--- a/deepspeed/runtime/fp16/fused_optimizer.py
+++ b/deepspeed/runtime/fp16/fused_optimizer.py
@@ -9,15 +9,16 @@
 
 import torch
 from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
-
 from deepspeed.runtime.base_optimizer import DeepSpeedOptimizer
-from deepspeed.runtime.utils import get_global_norm, get_grad_norm, CheckOverflow, get_weight_norm, get_norm_with_moe_layers
+from deepspeed.runtime.utils import get_global_norm, get_flattened_grad_norm, CheckOverflow, get_weight_norm, get_norm_with_moe_layers, is_model_parallel_parameter
 from deepspeed.runtime.fp16.loss_scaler import INITIAL_LOSS_SCALE, SCALE_WINDOW, MIN_LOSS_SCALE
 from deepspeed.utils import logger, log_dist
 from deepspeed.utils.torch import required_torch_version
 from deepspeed.checkpoint.constants import OPTIMIZER_STATE_DICT, CLIP_GRAD
 from deepspeed.accelerator import get_accelerator
 from deepspeed.moe.utils import is_moe_param_group
+from deepspeed.runtime.constants import PIPE_REPLICATED
+from deepspeed.utils.bwc import bwc_tensor_model_parallel_rank
 
 OVERFLOW_CHECK_TIMER = 'overflow_check'
 COMPUTE_NORM_TIMER = 'compute_norm'
@@ -64,6 +65,8 @@ def __init__(self,
 
         self.fp16_groups_flat = []
         self.fp32_groups_flat = []
+        self.flatten_grad_norm_mask_list = []
+        self.has_executed_step = False
         self._global_grad_norm = 0.
 
         # loop to deal with groups
@@ -206,6 +209,40 @@ def override_loss_scale(self, loss_scale):
         self.custom_loss_scaler = True
         self.external_loss_scale = loss_scale
 
+    def _require_avoid_recompute_norm(self, p, tensor_model_parallel_rank):
+        # for filtering replicated tensors from tensor
+        if hasattr(p, PIPE_REPLICATED) and p.ds_pipe_replicated:
+            return True
+        if (tensor_model_parallel_rank > 0) and not is_model_parallel_parameter(p):
+            return True
+
+    def _get_norm_mask_idx(self, group):
+        """The function preserves the parallel information for norm
+        from unflattened gradients.
+
+        Args:
+            group (Iterable[Tensor] ): params group
+
+        Returns:
+            torch.Tensor: A 2D tensor containing index ranges for each group,
+                where each row represents a [start index, end index].
+        """
+        group_mask_idx_list = []
+        grad_flat_st_idx = 0
+        grad_flat_en_idx = 0
+
+        for p in group:
+            grad_flat_en_idx = grad_flat_st_idx + p.numel()
+            if p.grad is not None and self._require_avoid_recompute_norm(p, bwc_tensor_model_parallel_rank(self.mpu)):
+                # merge range
+                if len(group_mask_idx_list) > 0 and grad_flat_st_idx == group_mask_idx_list[-1][-1]:
+                    group_mask_idx_list[-1][-1] = grad_flat_en_idx
+                else:
+                    group_mask_idx_list.append([grad_flat_st_idx, grad_flat_en_idx])
+            grad_flat_st_idx = grad_flat_en_idx
+
+        return torch.tensor(group_mask_idx_list, device=get_accelerator().current_device())
+
     def step(self, closure=None):
         """
         Not supporting closure.
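Reviewer note (not part of the patch): the range bookkeeping in `_get_norm_mask_idx` can be checked in isolation. The sketch below re-implements only the merge logic under simplified assumptions; the helper name `build_mask_ranges`, the plain `numels`/`should_mask` inputs, and the sample sizes are invented for illustration, whereas the real method walks an fp16 param group and decides what to mask via `_require_avoid_recompute_norm`, `bwc_tensor_model_parallel_rank` and `is_model_parallel_parameter`.

import torch


def build_mask_ranges(numels, should_mask):
    """Return merged [start, end) ranges over a flattened gradient for the
    params whose norm contribution should be skipped (hypothetical helper)."""
    ranges = []
    start = 0
    for numel, masked in zip(numels, should_mask):
        end = start + numel
        if masked:
            # Contiguous masked params collapse into a single range, mirroring
            # the merge branch in _get_norm_mask_idx.
            if ranges and start == ranges[-1][-1]:
                ranges[-1][-1] = end
            else:
                ranges.append([start, end])
        start = end
    return torch.tensor(ranges)


# Three params of sizes 4, 3 and 5; the last two are treated as replicated.
print(build_mask_ranges([4, 3, 5], [False, True, True]))  # tensor([[ 4, 12]])

Because adjacent masked params merge into one range, the later scatter in get_flattened_grad_norm only has to touch two indices per surviving range.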
@@ -251,23 +288,32 @@ def step(self, closure=None):
                     for p in group
                 ]))
 
-            for p in group:
-                p.grad = None
-
             self.fp32_groups_flat[i].grad = grads_groups_flat[i]
             param_group = self.optimizer.param_groups[i]
+
+            # split expert and non_expert grads for norm
             if self.has_moe_layers and is_moe_param_group(param_group):
                 if param_group['name'] not in expert_grads_for_norm:
                     expert_grads_for_norm[param_group['name']] = []
+
                 expert_grads_for_norm[param_group['name']].append(self.fp32_groups_flat[i])
             else:
+                # retrieves the required mask for calculating the norm of flat_grad
+                # perform this collect operation only once
+                if not self.has_executed_step:
+                    cur_flat_grad_norm_mask = self._get_norm_mask_idx(group)
+                    self.flatten_grad_norm_mask_list.append(cur_flat_grad_norm_mask)
+
                 non_experts_grads_for_norm.append(self.fp32_groups_flat[i])
 
-        self.timers(COMPUTE_NORM_TIMER).start()
+            for p in group:
+                p.grad = None
 
-        all_groups_norm = get_grad_norm(non_experts_grads_for_norm, mpu=self.mpu)
+        self.timers(COMPUTE_NORM_TIMER).start()
 
-        self.timers(COMPUTE_NORM_TIMER).stop()
+        all_groups_norm = get_flattened_grad_norm(non_experts_grads_for_norm,
+                                                  mpu=self.mpu,
+                                                  grad_norm_mask=self.flatten_grad_norm_mask_list)
 
         if self.has_moe_layers:
             all_groups_norm = get_norm_with_moe_layers(all_groups_norm,
@@ -276,6 +322,7 @@ def step(self, closure=None):
                                                        norm_type=self.norm_type)
 
         scaled_global_grad_norm = get_global_norm(norm_list=[all_groups_norm])
+        self.timers(COMPUTE_NORM_TIMER).stop()
 
         # Stash unscaled gradient norm
         self._global_grad_norm = scaled_global_grad_norm / self.cur_scale
@@ -298,7 +345,7 @@ def step(self, closure=None):
             updated_params = _unflatten_dense_tensors(self.fp32_groups_flat[i], self.fp16_groups[i])
             for p, q in zip(self.fp16_groups[i], updated_params):
                 p.data.copy_(q.data)
-
+        self.has_executed_step = True
         self.timers(UPDATE_FP16_TIMER).stop()
 
         self.timers.log(STEP_TIMERS)
diff --git a/deepspeed/runtime/utils.py b/deepspeed/runtime/utils.py
index c55f8a0e2995..7744b2ee8b98 100755
--- a/deepspeed/runtime/utils.py
+++ b/deepspeed/runtime/utils.py
@@ -17,7 +17,6 @@
 
 import torch
 from deepspeed import comm as dist
-
 try:
     from torch._six import inf
 except ModuleNotFoundError:
@@ -385,7 +384,7 @@ def clip_grad_norm_(parameters, max_norm, norm_type=2, mpu=None):
     return total_norm
 
 
-def get_grad_norm(parameters, norm_type=2, mpu=None):
+def get_flattened_grad_norm(parameters, norm_type=2, mpu=None, grad_norm_mask=None):
     """Get grad norm of an iterable of parameters.
 
     This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and
@@ -397,7 +396,8 @@ def get_grad_norm(parameters, norm_type=2, mpu=None):
             single Tensor that will have gradients normalized
         norm_type (float or int): type of the used p-norm. Can be
             ``'inf'`` for infinity norm.
-
+        grad_norm_mask (List[Tensor]): A list of Tensor, where
+            each Tensor is a 2D Tensor containing ranges of [start_index, end_index].
     Returns:
         Total norm of the parameters (viewed as a single vector).
     """
@@ -415,18 +415,25 @@ def get_grad_norm(parameters, norm_type=2, mpu=None):
         total_norm = total_norm_cuda[0].item()
     else:
         total_norm = 0.
-        tensor_mp_rank = bwc_tensor_model_parallel_rank(mpu=mpu)
-        for p in parameters:
-            # Pipeline parallelism may replicate parameters. Avoid multi-counting.
-            if hasattr(p, PIPE_REPLICATED) and p.ds_pipe_replicated:
-                continue
-
-            # Filter to avoid over-counting replicated tensors from tensor
-            # model parallelism
-            if (tensor_mp_rank > 0) and not is_model_parallel_parameter(p):
-                continue
+        for idx, p in enumerate(parameters):
+            # Use grad_norm_mask to avoid redundant computation of flattened gradient norm
+            if grad_norm_mask is not None and len(grad_norm_mask[idx]) > 0:
+
+                # A loop-free implementation to create a mask tensor based on a range list
+                # which is logically equivalent to the following implementation.
+                # # mask_tensor_ = torch.zeros_like(p, device=p.device, dtype=bool)
+                # # for mask_idx in grad_norm_mask[idx]:
+                # #     mask_tensor_[mask_idx[0]:mask_idx[1]] = True
+                cum_sum_pairs = torch.tensor([1, -1], device=get_accelerator().current_device(),
+                                             dtype=p.dtype).repeat(grad_norm_mask[idx].shape[0], 1)
+                mask_tensor = torch.zeros(p.shape[0] + 1, device=get_accelerator().current_device(), dtype=p.dtype)
+                mask_tensor = mask_tensor.scatter_(0, grad_norm_mask[idx].view(-1),
+                                                   cum_sum_pairs.view(-1)).cumsum(0).bool()[:-1]
+
+                param_norm = torch.masked_fill(p.grad.data, mask_tensor, 0).float().norm(norm_type)
 
-            param_norm = p.grad.data.float().norm(norm_type)
+            else:
+                param_norm = p.grad.data.float().norm(norm_type)
             total_norm += param_norm.item()**norm_type
 
     # Sum across all model parallel GPUs.
@@ -814,25 +821,6 @@ def get_only_unique_item(items):
     return unique_item
 
 
-def clip_gradients(parameters, max_norm=1.0, global_grad_norm=None, mpu=None, eps=1e-6):
-    """Clip the gradient of a list of parameters.
-    Args:
-        parameters: List of parameters whose .grad will be clipped.
-        global_grad_norm (float, optional): Precomputed gradient norm. Defaults to None.
-        mpu (optional): model parallelism unit. Defaults to None.
-        eps (float, optional): epsilon value added to grad norm. Defaults to 1e-6
-    Returns:
-        float: the global gradient norm
-    """
-    if global_grad_norm is None:
-        global_grad_norm = get_grad_norm(parameters, mpu=mpu)
-    clip_coef = max_norm / (global_grad_norm + eps)
-    if clip_coef < 1:
-        for p in parameters:
-            p.grad.detach().mul_(clip_coef)
-    return global_grad_norm
-
-
 def get_global_norm_of_tensors(input_tensors, norm_type=2, mpu=None, use_graph=False, moe_ep_group=None):
     """Get norm of an iterable of tensors.
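Reviewer note (not part of the patch): the scatter_/cumsum construction in `get_flattened_grad_norm` builds the boolean mask as a 0/1 step function over the flattened indices. A minimal self-contained sketch of the same idea, with invented lengths, ranges and variable names, and an integer dtype in place of `p.dtype`:

import torch

flat_len = 10
ranges = torch.tensor([[2, 5], [7, 9]])  # [start, end) index pairs to mask out

# +1 at every range start, -1 at every range end; the running sum is then 1
# inside a range and 0 outside, which is exactly the mask we want.
deltas = torch.tensor([1, -1]).repeat(ranges.shape[0], 1)
mask = torch.zeros(flat_len + 1, dtype=torch.long)
mask = mask.scatter_(0, ranges.view(-1), deltas.view(-1)).cumsum(0).bool()[:-1]

# Loop-based reference, matching the commented-out version in the patch.
expected = torch.zeros(flat_len, dtype=torch.bool)
for start, end in ranges.tolist():
    expected[start:end] = True

assert torch.equal(mask, expected)
print(mask)
# tensor([False, False,  True,  True,  True, False, False,  True,  True, False])

The extra trailing element (flat_len + 1) gives the scatter somewhere to place the -1 of a range that ends at the last index; the [:-1] then drops it again, so the mask matches the flattened gradient's length.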