From 287fa5eb98d63fbf1ef4598036ff8d371a9ff953 Mon Sep 17 00:00:00 2001 From: inkcherry Date: Thu, 11 Apr 2024 16:54:08 +0800 Subject: [PATCH 1/7] fix grad norm for tp --- deepspeed/runtime/fp16/fused_optimizer.py | 40 ++++++++++++++++++++--- deepspeed/runtime/utils.py | 28 +++++++++------- 2 files changed, 51 insertions(+), 17 deletions(-) diff --git a/deepspeed/runtime/fp16/fused_optimizer.py b/deepspeed/runtime/fp16/fused_optimizer.py index 9ed250252e17..ddf1cc4d2eca 100755 --- a/deepspeed/runtime/fp16/fused_optimizer.py +++ b/deepspeed/runtime/fp16/fused_optimizer.py @@ -9,14 +9,15 @@ import torch from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors - from deepspeed.runtime.base_optimizer import DeepSpeedOptimizer -from deepspeed.runtime.utils import get_global_norm, get_grad_norm, CheckOverflow, get_weight_norm, required_torch_version, get_norm_with_moe_layers +from deepspeed.runtime.utils import get_global_norm, get_flattened_grad_norm, CheckOverflow, get_weight_norm, required_torch_version, get_norm_with_moe_layers, is_model_parallel_parameter from deepspeed.runtime.fp16.loss_scaler import INITIAL_LOSS_SCALE, SCALE_WINDOW, MIN_LOSS_SCALE from deepspeed.utils import logger, log_dist from deepspeed.checkpoint.constants import OPTIMIZER_STATE_DICT, CLIP_GRAD from deepspeed.accelerator import get_accelerator from deepspeed.moe.utils import is_moe_param_group +from deepspeed.runtime.constants import PIPE_REPLICATED +from deepspeed.utils.bwc import bwc_tensor_model_parallel_rank OVERFLOW_CHECK_TIMER = 'overflow_check' COMPUTE_NORM_TIMER = 'compute_norm' @@ -205,6 +206,29 @@ def override_loss_scale(self, loss_scale): self.custom_loss_scaler = True self.external_loss_scale = loss_scale + def _require_avoid_recompute_norm(self, p, tensor_model_parallel_rank): + + if hasattr(p, PIPE_REPLICATED) and p.ds_pipe_replicated: + return True + # Filter to avoid over-counting replicated tensors from tensor + # model parallelism + if (tensor_model_parallel_rank > 0) and not is_model_parallel_parameter(p): + return True + + def _get_flat_grad_norm_mask_idx(self, group): + group_mask_idx_list = [] + grad_flat_st_idx = 0 + grad_flat_en_idx = 0 + + for p in group: + grad_flat_en_idx = grad_flat_st_idx + p.numel() + if p.grad is None or self._require_avoid_recompute_norm(p, bwc_tensor_model_parallel_rank(self.mpu)): + group_mask_idx_list.append((grad_flat_st_idx, grad_flat_en_idx)) + else: + grad_flat_st_idx = grad_flat_en_idx + p.grad = None + return group_mask_idx_list + def step(self, closure=None): """ Not supporting closure. 
@@ -237,6 +261,7 @@ def step(self, closure=None): return self.overflow grads_groups_flat = [] + flatten_grad_norm_mask_list = [] non_experts_grads_for_norm = [] expert_grads_for_norm = {} assert len(self.fp16_groups) == len(self.optimizer.param_groups) @@ -250,11 +275,14 @@ def step(self, closure=None): for p in group ])) - for p in group: - p.grad = None + # retrieves the required mask for calculating the norm of flat_grad + cur_flat_grad_norm_mask = self._get_flat_grad_norm_mask_idx(group) + flatten_grad_norm_mask_list.append(cur_flat_grad_norm_mask) self.fp32_groups_flat[i].grad = grads_groups_flat[i] param_group = self.optimizer.param_groups[i] + + # split expert and non_expert grads for norm if self.has_moe_layers and is_moe_param_group(param_group): if param_group['name'] not in expert_grads_for_norm: expert_grads_for_norm[param_group['name']] = [] @@ -264,7 +292,9 @@ def step(self, closure=None): self.timers(COMPUTE_NORM_TIMER).start() - all_groups_norm = get_grad_norm(non_experts_grads_for_norm, mpu=self.mpu) + all_groups_norm = get_flattened_grad_norm(non_experts_grads_for_norm, + mpu=self.mpu, + grad_norm_mask=flatten_grad_norm_mask_list) self.timers(COMPUTE_NORM_TIMER).stop() diff --git a/deepspeed/runtime/utils.py b/deepspeed/runtime/utils.py index 9d561f7271eb..255dfcadf6d5 100755 --- a/deepspeed/runtime/utils.py +++ b/deepspeed/runtime/utils.py @@ -386,7 +386,7 @@ def clip_grad_norm_(parameters, max_norm, norm_type=2, mpu=None): return total_norm -def get_grad_norm(parameters, norm_type=2, mpu=None): +def get_flattened_grad_norm(parameters, norm_type=2, mpu=None, grad_norm_mask=None): """Get grad norm of an iterable of parameters. This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and @@ -398,6 +398,8 @@ def get_grad_norm(parameters, norm_type=2, mpu=None): single Tensor that will have gradients normalized norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for infinity norm. + grad_norm_mask (List[List[Tuple[int, int]]]): A list of lists, where + each inner list contains tuples(start_idx, end_idx) of a flattened grad Returns: Total norm of the parameters (viewed as a single vector). @@ -416,18 +418,20 @@ def get_grad_norm(parameters, norm_type=2, mpu=None): total_norm = total_norm_cuda[0].item() else: total_norm = 0. - tensor_mp_rank = bwc_tensor_model_parallel_rank(mpu=mpu) - for p in parameters: - # Pipeline parallelism may replicate parameters. Avoid multi-counting. - if hasattr(p, PIPE_REPLICATED) and p.ds_pipe_replicated: - continue - - # Filter to avoid over-counting replicated tensors from tensor - # model parallelism - if (tensor_mp_rank > 0) and not is_model_parallel_parameter(p): - continue + for id, p in enumerate(parameters): + # Use grad_norm_mask to avoid redundant computation of flattened gradient norm + # # including, Pipeline parallelism may replicate parameters. + # # replicated tensors from tensor model parallelism + mask_tensor = torch.ones_like(p, device=p.device, dtype=bool) + for mask_idx in grad_norm_mask[id]: + mask_tensor[mask_idx[0]:mask_idx[1]] = 0 + + # assert torch.allclose(tmp_mask,grad_norm_mask[id]) + if grad_norm_mask is not None: + param_norm = (p.grad.data * mask_tensor).float().norm(norm_type) + else: + param_norm = p.grad.data.float().norm(norm_type) - param_norm = p.grad.data.float().norm(norm_type) total_norm += param_norm.item()**norm_type # Sum across all model parallel GPUs. 
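Patch 1 replaces the per-parameter skip logic in get_grad_norm with a precomputed list of (start, end) index ranges marking the replicated slices of each flattened gradient (pipeline-replicated parameters and tensor-parallel duplicates on rank > 0); the norm is then taken over the flattened tensor with those slices zeroed out. The following is a minimal standalone sketch of that masking idea in plain PyTorch, with a hypothetical helper name, not the patched DeepSpeed function itself:

import torch

def flattened_masked_norm(flat_grad, mask_ranges, norm_type=2.0):
    # Zero the [start, end) slices that correspond to replicated parameters
    # before taking the norm of the flattened gradient, so duplicated
    # shards do not get counted more than once across ranks.
    masked = flat_grad.clone().float()
    for start, end in mask_ranges:
        masked[start:end] = 0
    return masked.norm(norm_type)

# Example: a 10-element flattened gradient whose slice [3, 6) is replicated.
flat_grad = torch.arange(10, dtype=torch.float16)
norm_without_duplicates = flattened_masked_norm(flat_grad, [(3, 6)])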
From a7e8a7fe4ec614e2b110b626d80eb8214332798e Mon Sep 17 00:00:00 2001 From: inkcherry Date: Fri, 12 Apr 2024 16:53:37 +0800 Subject: [PATCH 2/7] refine code --- deepspeed/runtime/fp16/fused_optimizer.py | 21 ++++++++++++++------- deepspeed/runtime/utils.py | 19 +++++++++---------- 2 files changed, 23 insertions(+), 17 deletions(-) diff --git a/deepspeed/runtime/fp16/fused_optimizer.py b/deepspeed/runtime/fp16/fused_optimizer.py index ddf1cc4d2eca..c6dbe7f7931d 100755 --- a/deepspeed/runtime/fp16/fused_optimizer.py +++ b/deepspeed/runtime/fp16/fused_optimizer.py @@ -207,15 +207,23 @@ def override_loss_scale(self, loss_scale): self.external_loss_scale = loss_scale def _require_avoid_recompute_norm(self, p, tensor_model_parallel_rank): - + # for filtering replicated tensors from tensor if hasattr(p, PIPE_REPLICATED) and p.ds_pipe_replicated: return True - # Filter to avoid over-counting replicated tensors from tensor - # model parallelism if (tensor_model_parallel_rank > 0) and not is_model_parallel_parameter(p): return True - def _get_flat_grad_norm_mask_idx(self, group): + def _clear_grads_and_get_norm_mask_list(self, group): + """The function preserves the parallel information for norm + from unflattened gradients and clears the unflattened gradients. + + Args: + group (Iterable[Tensor] ): params group + + Returns: + List[Tuple[int, int]: list of indices to avoid redundant norm computation + for the flattened gradient associated with this group + """ group_mask_idx_list = [] grad_flat_st_idx = 0 grad_flat_en_idx = 0 @@ -276,7 +284,7 @@ def step(self, closure=None): ])) # retrieves the required mask for calculating the norm of flat_grad - cur_flat_grad_norm_mask = self._get_flat_grad_norm_mask_idx(group) + cur_flat_grad_norm_mask = self._clear_grads_and_get_norm_mask_list(group) flatten_grad_norm_mask_list.append(cur_flat_grad_norm_mask) self.fp32_groups_flat[i].grad = grads_groups_flat[i] @@ -296,8 +304,6 @@ def step(self, closure=None): mpu=self.mpu, grad_norm_mask=flatten_grad_norm_mask_list) - self.timers(COMPUTE_NORM_TIMER).stop() - if self.has_moe_layers: all_groups_norm = get_norm_with_moe_layers(all_groups_norm, mpu=self.mpu, @@ -305,6 +311,7 @@ def step(self, closure=None): norm_type=self.norm_type) scaled_global_grad_norm = get_global_norm(norm_list=[all_groups_norm]) + self.timers(COMPUTE_NORM_TIMER).stop() # Stash unscaled gradient norm self._global_grad_norm = scaled_global_grad_norm / self.cur_scale diff --git a/deepspeed/runtime/utils.py b/deepspeed/runtime/utils.py index 255dfcadf6d5..3b76bdaeee4c 100755 --- a/deepspeed/runtime/utils.py +++ b/deepspeed/runtime/utils.py @@ -419,16 +419,15 @@ def get_flattened_grad_norm(parameters, norm_type=2, mpu=None, grad_norm_mask=No else: total_norm = 0. for id, p in enumerate(parameters): - # Use grad_norm_mask to avoid redundant computation of flattened gradient norm - # # including, Pipeline parallelism may replicate parameters. - # # replicated tensors from tensor model parallelism - mask_tensor = torch.ones_like(p, device=p.device, dtype=bool) - for mask_idx in grad_norm_mask[id]: - mask_tensor[mask_idx[0]:mask_idx[1]] = 0 - - # assert torch.allclose(tmp_mask,grad_norm_mask[id]) - if grad_norm_mask is not None: - param_norm = (p.grad.data * mask_tensor).float().norm(norm_type) + + if grad_norm_mask is not None and len(grad_norm_mask[id]) > 0: + # Use grad_norm_mask to avoid redundant computation of flattened gradient norm + # # including, Pipeline parallelism may replicate parameters. 
+ # # replicated tensors from tensor model parallelism + mask_tensor = torch.ones_like(p, device=p.device, dtype=bool) + for mask_idx in grad_norm_mask[id]: + mask_tensor[mask_idx[0]:mask_idx[1]] = 0 + param_norm = (p.grad.data * mask_tensor).float().norm(norm_type) else: param_norm = p.grad.data.float().norm(norm_type) From ea41928e65c06dff6373b8f7420ff5b9cd557e14 Mon Sep 17 00:00:00 2001 From: inkcherry Date: Fri, 12 Apr 2024 17:20:30 +0800 Subject: [PATCH 3/7] remove unnecessary clip_gradients fun --- deepspeed/runtime/utils.py | 25 +++---------------------- 1 file changed, 3 insertions(+), 22 deletions(-) diff --git a/deepspeed/runtime/utils.py b/deepspeed/runtime/utils.py index 3b76bdaeee4c..578af4161024 100755 --- a/deepspeed/runtime/utils.py +++ b/deepspeed/runtime/utils.py @@ -418,14 +418,14 @@ def get_flattened_grad_norm(parameters, norm_type=2, mpu=None, grad_norm_mask=No total_norm = total_norm_cuda[0].item() else: total_norm = 0. - for id, p in enumerate(parameters): + for idx, p in enumerate(parameters): - if grad_norm_mask is not None and len(grad_norm_mask[id]) > 0: + if grad_norm_mask is not None and len(grad_norm_mask[idx]) > 0: # Use grad_norm_mask to avoid redundant computation of flattened gradient norm # # including, Pipeline parallelism may replicate parameters. # # replicated tensors from tensor model parallelism mask_tensor = torch.ones_like(p, device=p.device, dtype=bool) - for mask_idx in grad_norm_mask[id]: + for mask_idx in grad_norm_mask[idx]: mask_tensor[mask_idx[0]:mask_idx[1]] = 0 param_norm = (p.grad.data * mask_tensor).float().norm(norm_type) else: @@ -818,25 +818,6 @@ def get_only_unique_item(items): return unique_item -def clip_gradients(parameters, max_norm=1.0, global_grad_norm=None, mpu=None, eps=1e-6): - """Clip the gradient of a list of parameters. - Args: - parameters: List of parameters whose .grad will be clipped. - global_grad_norm (float, optional): Precomputed gradient norm. Defaults to None. - mpu (optional): model parallelism unit. Defaults to None. - eps (float, optional): epsilon value added to grad norm. Defaults to 1e-6 - Returns: - float: the global gradient norm - """ - if global_grad_norm is None: - global_grad_norm = get_grad_norm(parameters, mpu=mpu) - clip_coef = max_norm / (global_grad_norm + eps) - if clip_coef < 1: - for p in parameters: - p.grad.detach().mul_(clip_coef) - return global_grad_norm - - def get_global_norm_of_tensors(input_tensors, norm_type=2, mpu=None, use_graph=False, moe_ep_group=None): """Get norm of an iterable of tensors. From e74b7ca686a747def0d13ceb5c3a5692b99f7cb4 Mon Sep 17 00:00:00 2001 From: inkcherry Date: Mon, 15 Apr 2024 10:25:07 +0800 Subject: [PATCH 4/7] improve perf by loop-free implementations --- deepspeed/runtime/fp16/fused_optimizer.py | 34 +++++++++++++---------- deepspeed/runtime/utils.py | 21 +++++++++----- 2 files changed, 34 insertions(+), 21 deletions(-) diff --git a/deepspeed/runtime/fp16/fused_optimizer.py b/deepspeed/runtime/fp16/fused_optimizer.py index c6dbe7f7931d..77824ef32149 100755 --- a/deepspeed/runtime/fp16/fused_optimizer.py +++ b/deepspeed/runtime/fp16/fused_optimizer.py @@ -64,6 +64,8 @@ def __init__(self, self.fp16_groups_flat = [] self.fp32_groups_flat = [] + self.flatten_grad_norm_mask_list = [] + self.has_executed_step = False self._global_grad_norm = 0. 
# loop to deal with groups @@ -213,16 +215,16 @@ def _require_avoid_recompute_norm(self, p, tensor_model_parallel_rank): if (tensor_model_parallel_rank > 0) and not is_model_parallel_parameter(p): return True - def _clear_grads_and_get_norm_mask_list(self, group): + def _get_norm_mask_idx(self, group): """The function preserves the parallel information for norm - from unflattened gradients and clears the unflattened gradients. + from unflattened gradients. Args: group (Iterable[Tensor] ): params group Returns: - List[Tuple[int, int]: list of indices to avoid redundant norm computation - for the flattened gradient associated with this group + torch.Tensor: A 2D tensor containing index ranges for each group, + where each row represents a [start index, end index]. """ group_mask_idx_list = [] grad_flat_st_idx = 0 @@ -231,11 +233,10 @@ def _clear_grads_and_get_norm_mask_list(self, group): for p in group: grad_flat_en_idx = grad_flat_st_idx + p.numel() if p.grad is None or self._require_avoid_recompute_norm(p, bwc_tensor_model_parallel_rank(self.mpu)): - group_mask_idx_list.append((grad_flat_st_idx, grad_flat_en_idx)) + group_mask_idx_list.append([grad_flat_st_idx, grad_flat_en_idx]) else: grad_flat_st_idx = grad_flat_en_idx - p.grad = None - return group_mask_idx_list + return torch.tensor(group_mask_idx_list, device=get_accelerator().current_device()) def step(self, closure=None): """ @@ -269,7 +270,6 @@ def step(self, closure=None): return self.overflow grads_groups_flat = [] - flatten_grad_norm_mask_list = [] non_experts_grads_for_norm = [] expert_grads_for_norm = {} assert len(self.fp16_groups) == len(self.optimizer.param_groups) @@ -283,10 +283,6 @@ def step(self, closure=None): for p in group ])) - # retrieves the required mask for calculating the norm of flat_grad - cur_flat_grad_norm_mask = self._clear_grads_and_get_norm_mask_list(group) - flatten_grad_norm_mask_list.append(cur_flat_grad_norm_mask) - self.fp32_groups_flat[i].grad = grads_groups_flat[i] param_group = self.optimizer.param_groups[i] @@ -294,15 +290,25 @@ def step(self, closure=None): if self.has_moe_layers and is_moe_param_group(param_group): if param_group['name'] not in expert_grads_for_norm: expert_grads_for_norm[param_group['name']] = [] + expert_grads_for_norm[param_group['name']].append(self.fp32_groups_flat[i]) else: + # retrieves the required mask for calculating the norm of flat_grad + # perform this collect operation only once + if not self.has_executed_step: + cur_flat_grad_norm_mask = self._get_norm_mask_idx(group) + self.flatten_grad_norm_mask_list.append(cur_flat_grad_norm_mask) + non_experts_grads_for_norm.append(self.fp32_groups_flat[i]) + for p in group: + p.grad = None + self.timers(COMPUTE_NORM_TIMER).start() all_groups_norm = get_flattened_grad_norm(non_experts_grads_for_norm, mpu=self.mpu, - grad_norm_mask=flatten_grad_norm_mask_list) + grad_norm_mask=self.flatten_grad_norm_mask_list) if self.has_moe_layers: all_groups_norm = get_norm_with_moe_layers(all_groups_norm, @@ -334,7 +340,7 @@ def step(self, closure=None): updated_params = _unflatten_dense_tensors(self.fp32_groups_flat[i], self.fp16_groups[i]) for p, q in zip(self.fp16_groups[i], updated_params): p.data.copy_(q.data) - + self.has_executed_step = True self.timers(UPDATE_FP16_TIMER).stop() self.timers.log(STEP_TIMERS) diff --git a/deepspeed/runtime/utils.py b/deepspeed/runtime/utils.py index 578af4161024..048e34bae4b5 100755 --- a/deepspeed/runtime/utils.py +++ b/deepspeed/runtime/utils.py @@ -18,7 +18,6 @@ import torch from deepspeed import 
comm as dist - try: from torch._six import inf except ModuleNotFoundError: @@ -419,18 +418,26 @@ def get_flattened_grad_norm(parameters, norm_type=2, mpu=None, grad_norm_mask=No else: total_norm = 0. for idx, p in enumerate(parameters): - if grad_norm_mask is not None and len(grad_norm_mask[idx]) > 0: # Use grad_norm_mask to avoid redundant computation of flattened gradient norm # # including, Pipeline parallelism may replicate parameters. # # replicated tensors from tensor model parallelism - mask_tensor = torch.ones_like(p, device=p.device, dtype=bool) - for mask_idx in grad_norm_mask[idx]: - mask_tensor[mask_idx[0]:mask_idx[1]] = 0 - param_norm = (p.grad.data * mask_tensor).float().norm(norm_type) + + # A loop-free implementation to create a mask tensor based on a range list, + # which is logically equivalent to the following implementation. + + # # mask_tensor = torch.zeros_like(p, device=p.device, dtype=bool) + # # for mask_idx in grad_norm_mask[idx]: + # # mask_tensor[mask_idx[0]:mask_idx[1]] = True + cum_sum_pairs = torch.tensor([1, -1], device=get_accelerator().current_device(), + dtype=p.dtype).repeat(grad_norm_mask[idx].shape[0], 1) + mask_tensor = torch.zeros_like(p, device=get_accelerator().current_device(), dtype=p.dtype) + mask_tensor = mask_tensor.scatter_(0, grad_norm_mask[idx].view(-1), + cum_sum_pairs.view(-1)).cumsum(0).bool() + + param_norm = torch.masked_fill(p.grad.data, mask_tensor, 0).float().norm(norm_type) else: param_norm = p.grad.data.float().norm(norm_type) - total_norm += param_norm.item()**norm_type # Sum across all model parallel GPUs. From 79cc4cef214db44fdf9ac16ac71b591aa09aaaba Mon Sep 17 00:00:00 2001 From: inkcherry Date: Mon, 15 Apr 2024 10:32:08 +0800 Subject: [PATCH 5/7] Modify the comments. --- deepspeed/runtime/utils.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/deepspeed/runtime/utils.py b/deepspeed/runtime/utils.py index 048e34bae4b5..d213e00c42d6 100755 --- a/deepspeed/runtime/utils.py +++ b/deepspeed/runtime/utils.py @@ -397,9 +397,8 @@ def get_flattened_grad_norm(parameters, norm_type=2, mpu=None, grad_norm_mask=No single Tensor that will have gradients normalized norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for infinity norm. - grad_norm_mask (List[List[Tuple[int, int]]]): A list of lists, where - each inner list contains tuples(start_idx, end_idx) of a flattened grad - + grad_norm_mask (List[Tensor]): A list of Tensor, where + each Tensor is a 2D Tensor containing ranges of [start_index, end_index]. Returns: Total norm of the parameters (viewed as a single vector). 
""" From 3ebed5ea11cafc2a07aa1a98aa58ed832eebfa60 Mon Sep 17 00:00:00 2001 From: inkcherry Date: Mon, 15 Apr 2024 14:46:17 +0800 Subject: [PATCH 6/7] update --- deepspeed/runtime/fp16/fused_optimizer.py | 12 ++++++++---- deepspeed/runtime/utils.py | 13 +++++++------ 2 files changed, 15 insertions(+), 10 deletions(-) diff --git a/deepspeed/runtime/fp16/fused_optimizer.py b/deepspeed/runtime/fp16/fused_optimizer.py index 77824ef32149..a98681b3e9f8 100755 --- a/deepspeed/runtime/fp16/fused_optimizer.py +++ b/deepspeed/runtime/fp16/fused_optimizer.py @@ -232,10 +232,14 @@ def _get_norm_mask_idx(self, group): for p in group: grad_flat_en_idx = grad_flat_st_idx + p.numel() - if p.grad is None or self._require_avoid_recompute_norm(p, bwc_tensor_model_parallel_rank(self.mpu)): - group_mask_idx_list.append([grad_flat_st_idx, grad_flat_en_idx]) - else: - grad_flat_st_idx = grad_flat_en_idx + if p.grad is not None and self._require_avoid_recompute_norm(p, bwc_tensor_model_parallel_rank(self.mpu)): + # merge range + if len(group_mask_idx_list) > 0 and grad_flat_st_idx == group_mask_idx_list[-1][-1]: + group_mask_idx_list[-1][-1] = grad_flat_en_idx + else: + group_mask_idx_list.append([grad_flat_st_idx, grad_flat_en_idx]) + grad_flat_st_idx = grad_flat_en_idx + return torch.tensor(group_mask_idx_list, device=get_accelerator().current_device()) def step(self, closure=None): diff --git a/deepspeed/runtime/utils.py b/deepspeed/runtime/utils.py index d213e00c42d6..ee1b5655dfce 100755 --- a/deepspeed/runtime/utils.py +++ b/deepspeed/runtime/utils.py @@ -425,16 +425,17 @@ def get_flattened_grad_norm(parameters, norm_type=2, mpu=None, grad_norm_mask=No # A loop-free implementation to create a mask tensor based on a range list, # which is logically equivalent to the following implementation. - # # mask_tensor = torch.zeros_like(p, device=p.device, dtype=bool) - # # for mask_idx in grad_norm_mask[idx]: - # # mask_tensor[mask_idx[0]:mask_idx[1]] = True + # # mask_tensor_ = torch.zeros_like(p, device=p.device, dtype=bool) + # #for mask_idx in grad_norm_mask[idx]: + # # mask_tensor_[mask_idx[0]:mask_idx[1]] = True cum_sum_pairs = torch.tensor([1, -1], device=get_accelerator().current_device(), dtype=p.dtype).repeat(grad_norm_mask[idx].shape[0], 1) - mask_tensor = torch.zeros_like(p, device=get_accelerator().current_device(), dtype=p.dtype) + mask_tensor = torch.zeros(p.shape[0] + 1, device=get_accelerator().current_device(), dtype=p.dtype) mask_tensor = mask_tensor.scatter_(0, grad_norm_mask[idx].view(-1), - cum_sum_pairs.view(-1)).cumsum(0).bool() - + cum_sum_pairs.view(-1)).cumsum(0).bool()[:-1] + # assert torch.equal(mask_tensor_, mask_tensor) param_norm = torch.masked_fill(p.grad.data, mask_tensor, 0).float().norm(norm_type) + else: param_norm = p.grad.data.float().norm(norm_type) total_norm += param_norm.item()**norm_type From df976ca607b2bc11101c4bb267a7850241a724f1 Mon Sep 17 00:00:00 2001 From: inkcherry Date: Tue, 16 Apr 2024 14:37:08 +0800 Subject: [PATCH 7/7] refine comments --- deepspeed/runtime/utils.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/deepspeed/runtime/utils.py b/deepspeed/runtime/utils.py index c8de55ea50c2..7744b2ee8b98 100755 --- a/deepspeed/runtime/utils.py +++ b/deepspeed/runtime/utils.py @@ -416,23 +416,20 @@ def get_flattened_grad_norm(parameters, norm_type=2, mpu=None, grad_norm_mask=No else: total_norm = 0. 
for idx, p in enumerate(parameters): + # Use grad_norm_mask to avoid redundant computation of flattened gradient norm if grad_norm_mask is not None and len(grad_norm_mask[idx]) > 0: - # Use grad_norm_mask to avoid redundant computation of flattened gradient norm - # # including, Pipeline parallelism may replicate parameters. - # # replicated tensors from tensor model parallelism - # A loop-free implementation to create a mask tensor based on a range list, + # A loop-free implementation to create a mask tensor based on a range list # which is logically equivalent to the following implementation. - # # mask_tensor_ = torch.zeros_like(p, device=p.device, dtype=bool) - # #for mask_idx in grad_norm_mask[idx]: - # # mask_tensor_[mask_idx[0]:mask_idx[1]] = True + # # for mask_idx in grad_norm_mask[idx]: + # # mask_tensor_[mask_idx[0]:mask_idx[1]] = True cum_sum_pairs = torch.tensor([1, -1], device=get_accelerator().current_device(), dtype=p.dtype).repeat(grad_norm_mask[idx].shape[0], 1) mask_tensor = torch.zeros(p.shape[0] + 1, device=get_accelerator().current_device(), dtype=p.dtype) mask_tensor = mask_tensor.scatter_(0, grad_norm_mask[idx].view(-1), cum_sum_pairs.view(-1)).cumsum(0).bool()[:-1] - # assert torch.equal(mask_tensor_, mask_tensor) + param_norm = torch.masked_fill(p.grad.data, mask_tensor, 0).float().norm(norm_type) else:
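The end state of the series (patches 4-7) builds each mask without a Python loop: the [start, end] boundary pairs are scattered as +1/-1 markers into a zero tensor of length numel + 1, and a cumulative sum turns every position inside a range into a nonzero (True) entry; adjacent ranges are merged beforehand in _get_norm_mask_idx so no boundary index is scattered twice. The following is a small self-contained sketch of that trick checked against the naive loop, with illustrative function names only, not the patched DeepSpeed code:

import torch

def build_mask_loop(length, ranges):
    # Reference version: loop over the [start, end) ranges.
    mask = torch.zeros(length, dtype=torch.bool)
    for start, end in ranges:
        mask[start:end] = True
    return mask

def build_mask_loop_free(length, ranges):
    # Loop-free version: scatter +1 at each start index and -1 at each end
    # index, then cumsum so every position inside a range becomes nonzero.
    # Ranges must be non-overlapping and pre-merged so no boundary repeats.
    idx = torch.tensor(ranges, dtype=torch.long).view(-1)
    src = torch.tensor([1.0, -1.0]).repeat(len(ranges), 1).view(-1)
    # The extra trailing slot absorbs an end index equal to the tensor
    # length, mirroring the numel + 1 sizing in the patch.
    mask = torch.zeros(length + 1).scatter_(0, idx, src).cumsum(0).bool()
    return mask[:-1]

ranges = [[3, 6], [8, 10]]
assert torch.equal(build_mask_loop(12, ranges), build_mask_loop_free(12, ranges))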