From 31cdc51627048dd1e4697faf639b50e499991638 Mon Sep 17 00:00:00 2001 From: inkcherry Date: Wed, 27 Mar 2024 23:26:41 +0800 Subject: [PATCH] support bf16_optimizer moe expert parallel training and moe EP grad_scale/grad_norm fix (#5259) - bf16 moe EP requires different partitions and this will impact dp gradient allreduce, zero1 params allgather, as well as gradient_norm allreduce. Currently, the bf16_optimizer does not correctly partition the group. fix and support bf16 type training. - fix calculation of moe ep grad scale and grad_norm for bf16&fp16 --------- Co-authored-by: Olatunji Ruwase --- deepspeed/moe/utils.py | 4 + deepspeed/runtime/bf16_optimizer.py | 112 ++++++++++++++---- deepspeed/runtime/engine.py | 60 ++++++---- deepspeed/runtime/fp16/fused_optimizer.py | 38 +++--- deepspeed/runtime/utils.py | 62 +++++++++- .../unit/runtime/half_precision/test_fp16.py | 5 +- 6 files changed, 208 insertions(+), 73 deletions(-) diff --git a/deepspeed/moe/utils.py b/deepspeed/moe/utils.py index 8e1faffc3541d..f52fe2e3442de 100644 --- a/deepspeed/moe/utils.py +++ b/deepspeed/moe/utils.py @@ -146,3 +146,7 @@ def split_params_into_different_moe_groups_for_optimizer( param_groups.append(param_group) return param_groups + + +def is_moe_param_group(param_group): + return param_group.get('moe', False) diff --git a/deepspeed/runtime/bf16_optimizer.py b/deepspeed/runtime/bf16_optimizer.py index 82c8dda423a60..4ec603af1505b 100644 --- a/deepspeed/runtime/bf16_optimizer.py +++ b/deepspeed/runtime/bf16_optimizer.py @@ -12,13 +12,13 @@ from deepspeed.runtime.constants import PIPE_REPLICATED from deepspeed.runtime import ZeROOptimizer from packaging import version as pkg_version - from deepspeed.git_version_info import version from deepspeed.runtime.utils import (get_global_norm_of_tensors, clip_tensors_by_global_norm, DummyOptim, align_dense_tensors, all_gather_dp_groups, bwc_tensor_model_parallel_rank, - is_model_parallel_parameter, see_memory_usage, graph_process) - -from deepspeed.utils import link_hp_params, lazy_init_hp_params_optimizer_state, fragment_address, map_to_flat_opt_states + is_model_parallel_parameter, see_memory_usage, graph_process, + get_norm_with_moe_layers) +from deepspeed.moe.utils import is_moe_param, is_moe_param_group +from deepspeed.utils import link_hp_params, lazy_init_hp_params_optimizer_state, fragment_address, groups, map_to_flat_opt_states from deepspeed.checkpoint import enable_universal_checkpoint from deepspeed.checkpoint.constants import (DS_VERSION, PARTITION_COUNT, BASE_OPTIMIZER_STATE, SINGLE_PARTITION_OF_FP32_GROUPS, CLIP_GRAD, GROUP_PADDINGS, @@ -40,7 +40,8 @@ def __init__(self, timers=None, grad_acc_dtype=None, graph_harvesting=False, - immediate_grad_update=False): + immediate_grad_update=False, + has_moe_layers=False): super().__init__() see_memory_usage('begin bf16_optimizer', force=True) self.timers = timers @@ -59,7 +60,11 @@ def __init__(self, self.allgather_bucket_size = int(allgather_bucket_size) self.dp_process_group = dp_process_group self.dp_rank = dist.get_rank(group=self.dp_process_group) + self.has_moe_layers = has_moe_layers + self.non_expert_gradients = [] self.real_dp_process_group = [dp_process_group for i in range(len(self.optimizer.param_groups))] + if self.has_moe_layers: + self._configure_moe_settings() # Use torch (un)flatten ops self.flatten = _flatten_dense_tensors @@ -90,11 +95,26 @@ def __init__(self, see_memory_usage('end bf16_optimizer', force=True) + def _configure_moe_settings(self): + assert any( + [is_moe_param_group(group) for group in self.optimizer.param_groups] + ), "The model has moe layers, but None of the param groups are marked as MoE. Create a param group with 'moe' key set to True before creating optimizer" + + for i, group in enumerate(self.optimizer.param_groups): + if is_moe_param_group(group): + assert all([is_moe_param(param) + for param in group['params']]), "All params in MoE group must be MoE params" + self.real_dp_process_group[i] = groups._get_expert_data_parallel_group(group['name']) + self.expert_gradients = {} + if self.has_moe_layers: + for key in groups._get_expert_data_parallel_group_dict().keys(): + self.expert_gradients[key] = [] + def _setup_for_real_optimizer(self): - dp_world_size = dist.get_world_size(group=self.dp_process_group) - self.partition_count = [dp_world_size for i in range(len(self.optimizer.param_groups))] + self.partition_count = [dist.get_world_size(group=pg) for pg in self.real_dp_process_group] for i, param_group in enumerate(self.optimizer.param_groups): + real_dp_world_size = dist.get_world_size(group=self.real_dp_process_group[i]) see_memory_usage(f'before initializing group {i}', force=True) partition_id = dist.get_rank(group=self.real_dp_process_group[i]) @@ -106,17 +126,16 @@ def _setup_for_real_optimizer(self): # create flat bf16 params self.bf16_groups_flat.append( self._flatten_dense_tensors_aligned(self.bf16_groups[i], - self.nccl_start_alignment_factor * dp_world_size)) - + self.nccl_start_alignment_factor * real_dp_world_size)) # Make bf16 params point to flat tensor storage self._update_storage_to_flattened_tensor(tensor_list=self.bf16_groups[i], flat_tensor=self.bf16_groups_flat[i]) # divide flat weights into equal sized partitions - partition_size = self.bf16_groups_flat[i].numel() // dp_world_size + partition_size = self.bf16_groups_flat[i].numel() // real_dp_world_size bf16_dp_partitions = [ self.bf16_groups_flat[i].narrow(0, dp_index * partition_size, partition_size) - for dp_index in range(dp_world_size) + for dp_index in range(real_dp_world_size) ] self.bf16_partitioned_groups.append(bf16_dp_partitions) @@ -127,8 +146,12 @@ def _setup_for_real_optimizer(self): num_elem_list = [t.numel() for t in self.bf16_groups[i]] # create fp32 gradients - self.fp32_groups_gradients_flat.append( - torch.zeros_like(self.bf16_groups_flat[i], dtype=self.grad_acc_dtype)) + fp32_flat_buffer = torch.zeros_like(self.bf16_groups_flat[i], dtype=self.grad_acc_dtype) + self.fp32_groups_gradients_flat.append(fp32_flat_buffer) + if self.has_moe_layers and is_moe_param_group(param_group): + self.expert_gradients[param_group['name']].append(fp32_flat_buffer) + else: + self.non_expert_gradients.append(fp32_flat_buffer) # track individual fp32 gradients for entire model fp32_gradients = self._split_flat_tensor(flat_tensor=self.fp32_groups_gradients_flat[i], @@ -191,11 +214,12 @@ def _create_param_mapping(self): return param_mapping def _link_all_hp_params(self): - dp_world_size = dist.get_world_size(group=self.dp_process_group) for i, _ in enumerate(self.optimizer.param_groups): + real_dp_world_size = dist.get_world_size(group=self.real_dp_process_group[i]) + # Link bf16 and fp32 params in partition partition_id = dist.get_rank(group=self.real_dp_process_group[i]) - partition_size = self.bf16_groups_flat[i].numel() // dp_world_size + partition_size = self.bf16_groups_flat[i].numel() // real_dp_world_size flat_hp_partition = self.fp32_groups_flat_partition[i] link_hp_params(lp_param_list=self.bf16_groups[i], flat_hp_partition=flat_hp_partition, @@ -257,10 +281,18 @@ def step(self, closure=None): if closure is not None: raise NotImplementedError(f'{self.__class__} does not support closure.') - all_groups_norm = get_global_norm_of_tensors(input_tensors=self.get_grads_for_norm(), - mpu=self.mpu, - norm_type=self.norm_type, - use_graph=self.graph_harvesting) + non_expert_grads_for_norm, expert_grads_for_norm = self.get_grads_for_norm() + non_expert_groups_norm = get_global_norm_of_tensors(input_tensors=non_expert_grads_for_norm, + mpu=self.mpu, + norm_type=self.norm_type, + use_graph=self.graph_harvesting) + all_groups_norm = non_expert_groups_norm + if self.has_moe_layers: + all_groups_norm = get_norm_with_moe_layers(non_expert_groups_norm, + mpu=self.mpu, + expert_tensors=expert_grads_for_norm, + norm_type=self.norm_type) + self._global_grad_norm = all_groups_norm assert all_groups_norm > 0. @@ -336,27 +368,55 @@ def update_hp_grads(self, clear_lp_grads=False): @torch.no_grad() def get_grads_for_reduction(self): - return self.fp32_groups_gradients_flat + if self.has_moe_layers: + return self.non_expert_gradients, self.expert_gradients + return self.non_expert_gradients, {} @torch.no_grad() def get_grads_for_norm(self, for_clipping=False): - grads = [] + """ + Returns: + tuple[list[Tensor], dict[ep_name, List[Tensor]] | list: + If for_clipping, return all gradients. + Otherwise, separate and return dict of expert_grad and list of non_expert_grad + """ + # (grads, expert_group_name) + expert_grads_for_norm = {} + + # grads + non_expert_grads_for_norm = [] + all_grads_for_clip = [] + tensor_mp_rank = bwc_tensor_model_parallel_rank(mpu=self.mpu) + assert len(self.bf16_groups) == len(self.optimizer.param_groups) for i, group in enumerate(self.bf16_groups): for j, lp in enumerate(group): if not for_clipping: if hasattr(lp, PIPE_REPLICATED) and lp.ds_pipe_replicated: continue - if not (tensor_mp_rank == 0 or is_model_parallel_parameter(lp)): + # skip duplicated parameters. perform norm only on cards with tp_rank=0. + # non-duplicated parameters include: + # - Parameters with tp: Use allreducesum of mp_group. + # - Moe Parameters with ep: Use allreducesum of ep_group. + if not (tensor_mp_rank == 0 or is_model_parallel_parameter(lp) or is_moe_param(lp)): continue if not self.fp32_groups_has_gradients[i][j]: continue - - grads.append(self.fp32_groups_gradients[i][j]) - - return grads + if not for_clipping: + param_group = self.optimizer.param_groups[i] + if self.has_moe_layers and is_moe_param_group(param_group): + if param_group['name'] not in expert_grads_for_norm: + expert_grads_for_norm[param_group['name']] = [] + expert_grads_for_norm[param_group['name']].append(self.fp32_groups_gradients[i][j]) + else: + non_expert_grads_for_norm.append(self.fp32_groups_gradients[i][j]) + else: + all_grads_for_clip.append(self.fp32_groups_gradients[i][j]) + if not for_clipping: + return non_expert_grads_for_norm, expert_grads_for_norm + return all_grads_for_clip @torch.no_grad() def update_lp_params(self): diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 174e699c52029..bd2e91431aff3 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -1478,7 +1478,8 @@ def _configure_bf16_optimizer(self, optimizer): timers=timers, grad_acc_dtype=self.get_data_types()[1], graph_harvesting=self.graph_harvesting(), - immediate_grad_update=self._config.bfloat16_immediate_grad_update) + immediate_grad_update=self._config.bfloat16_immediate_grad_update, + has_moe_layers=self.has_moe_layers) return optimizer @@ -1924,9 +1925,6 @@ def allreduce_gradients(self, bucket_size=MEMORY_OPT_ALLREDUCE_SIZE): self.optimizer.reduce_gradients(pipeline_parallel=self.pipeline_parallelism) else: grads = None - if hasattr(self.optimizer, "get_grads_for_reduction"): - # This is currently for BF16 optimizer - grads = self.optimizer.get_grads_for_reduction() self.buffered_allreduce_fallback(grads=grads, elements_per_buffer=bucket_size) @instrument_w_nvtx @@ -2335,7 +2333,7 @@ def _report_progress(self, step): mom = self.get_mom() log_dist(f"step={step}, skipped={self.skipped_steps}, lr={lr}, mom={mom}", ranks=[0]) - def allreduce_bucket(self, bucket, dp_group): + def allreduce_bucket(self, bucket, dp_group, dp_world_size=None): tensor = self.flatten(bucket) tensor_to_allreduce = tensor @@ -2343,16 +2341,18 @@ def allreduce_bucket(self, bucket, dp_group): if self.communication_data_type != tensor.dtype: tensor_to_allreduce = tensor.to(self.communication_data_type) + if dp_world_size is None: + dp_world_size = dist.get_world_size(group=dp_group) if self.postscale_gradients(): if self.gradient_predivide_factor() != 1.0: tensor_to_allreduce.mul_(1.0 / self.gradient_predivide_factor()) dist.all_reduce(tensor_to_allreduce, group=dp_group) if self.gradient_average: - if self.gradient_predivide_factor() != dist.get_world_size(group=dp_group): - tensor_to_allreduce.mul_(self.gradient_predivide_factor() / dist.get_world_size(group=dp_group)) + if self.gradient_predivide_factor() != dp_world_size: + tensor_to_allreduce.mul_(self.gradient_predivide_factor() / dp_world_size) else: - tensor_to_allreduce.mul_(1. / dist.get_world_size(group=dp_group)) + tensor_to_allreduce.mul_(1. / dp_world_size) dist.all_reduce(tensor_to_allreduce, group=dp_group) if self.communication_data_type != tensor.dtype and tensor is not tensor_to_allreduce: @@ -2360,23 +2360,23 @@ def allreduce_bucket(self, bucket, dp_group): return tensor - def allreduce_and_copy(self, small_bucket, dp_group): - allreduced = self.allreduce_bucket(small_bucket, dp_group) + def allreduce_and_copy(self, small_bucket, dp_group, dp_world_size=None): + allreduced = self.allreduce_bucket(small_bucket, dp_group, dp_world_size) for buf, synced in zip(small_bucket, self.unflatten(allreduced, small_bucket)): buf.copy_(synced) - def allreduce_no_retain(self, bucket, dp_group, numel_per_bucket=500000000): + def allreduce_no_retain(self, bucket, dp_group, numel_per_bucket=500000000, dp_world_size=None): small_bucket = [] numel = 0 for tensor in bucket: small_bucket.append(tensor) numel = numel + tensor.numel() if numel > numel_per_bucket: - self.allreduce_and_copy(small_bucket, dp_group) + self.allreduce_and_copy(small_bucket, dp_group, dp_world_size) small_bucket = [] numel = 0 if len(small_bucket) > 0: - self.allreduce_and_copy(small_bucket, dp_group) + self.allreduce_and_copy(small_bucket, dp_group, dp_world_size) def _get_gradients_for_reduction(self): non_expert_grads = [] @@ -2427,26 +2427,35 @@ def _reduce_non_expert_gradients(self, grads, elements_per_buffer): self.allreduce_no_retain(dense_bucket, dp_group=dp_group, numel_per_bucket=elements_per_buffer) def _reduce_expert_gradients(self, expert_grads, elements_per_buffer): + # to maintain the gradients value unaffected by ep_size setting, + # utilize dp_world_size for allreduce average + dp_world_size = dist.get_world_size(groups._get_data_parallel_group()) for ep_name, expert_grads_group in expert_grads.items(): + ep_dp_group = groups._get_expert_data_parallel_group(ep_name) split_sparse_tensor_buckets, split_dense_tensor_buckets = split_half_float_double_sparse( expert_grads_group) for _, sparse_bucket_tuple in enumerate(split_sparse_tensor_buckets): if sparse_bucket_tuple: bucket_type, sparse_bucket = sparse_bucket_tuple - self.sparse_allreduce_no_retain(sparse_bucket, groups._get_expert_data_parallel_group(ep_name)) + self.sparse_allreduce_no_retain(sparse_bucket, dp_group=ep_dp_group, dp_world_size=dp_world_size) for _, dense_bucket_tuple in enumerate(split_dense_tensor_buckets): if dense_bucket_tuple: bucket_type, dense_bucket = dense_bucket_tuple # Separate between diff groups self.allreduce_no_retain(dense_bucket, - dp_group=groups._get_expert_data_parallel_group(ep_name), - numel_per_bucket=elements_per_buffer) + dp_group=ep_dp_group, + numel_per_bucket=elements_per_buffer, + dp_world_size=dp_world_size) def buffered_allreduce_fallback(self, grads=None, elements_per_buffer=500000000): if grads is None: - non_expert_grads, expert_grads = self._get_gradients_for_reduction() + if hasattr(self.optimizer, "get_grads_for_reduction"): + # This is currently for BF16 optimizer + non_expert_grads, expert_grads = self.optimizer.get_grads_for_reduction() + else: + non_expert_grads, expert_grads = self._get_gradients_for_reduction() else: assert not self.has_moe_layers, "attempting to reduce grads in unsupported way w.r.t. MoE" non_expert_grads = grads @@ -2456,8 +2465,8 @@ def buffered_allreduce_fallback(self, grads=None, elements_per_buffer=500000000) if self.has_moe_layers: self._reduce_expert_gradients(expert_grads, elements_per_buffer) - def sparse_allreduce_no_retain(self, bucket, dp_group): - allreduced_sparses = self.sparse_allreduce_bucket(bucket, dp_group) + def sparse_allreduce_no_retain(self, bucket, dp_group, dp_world_size=None): + allreduced_sparses = self.sparse_allreduce_bucket(bucket, dp_group, dp_world_size) # Densify sparse tensor and copy back to original location for tensor in allreduced_sparses: if tensor.is_sparse: @@ -2465,13 +2474,13 @@ def sparse_allreduce_no_retain(self, bucket, dp_group): else: tensor.orig_dense_tensor.copy_(tensor.to_dense()) - def sparse_allreduce_bucket(self, bucket, dp_group): + def sparse_allreduce_bucket(self, bucket, dp_group, dp_world_size=None): sparse_list = [] for sparse in bucket: - sparse_list.append(self.sparse_allreduce(sparse, dp_group)) + sparse_list.append(self.sparse_allreduce(sparse, dp_group, dp_world_size)) return sparse_list - def sparse_allreduce(self, sparse, dp_group): + def sparse_allreduce(self, sparse, dp_group, dp_world_size=None): original_data_type = sparse.values.dtype if self.communication_data_type != sparse.values.dtype: if self.communication_data_type in (torch.float16, torch.bfloat16): @@ -2483,12 +2492,13 @@ def sparse_allreduce(self, sparse, dp_group): indices = sparse.indices values = sparse.values + if dp_world_size is None: + dp_world_size = dist.get_world_size(group=dp_group) if self.postscale_gradients(): if self.gradient_average: - values.mul_(self.gradient_predivide_factor() / - (dist.get_world_size(group=dp_group) / float(self.sequence_parallel_size))) + values.mul_(self.gradient_predivide_factor() / (dp_world_size / float(self.sequence_parallel_size))) else: - values.mul_(1. / (dist.get_world_size(group=dp_group) / float(self.sequence_parallel_size))) + values.mul_(1. / (dp_world_size / float(self.sequence_parallel_size))) indices_device_list = self.sparse_all_gather(indices, dp_group) values_device_list = self.sparse_all_gather(values, dp_group) diff --git a/deepspeed/runtime/fp16/fused_optimizer.py b/deepspeed/runtime/fp16/fused_optimizer.py index 182f806c839c3..416642a89901c 100755 --- a/deepspeed/runtime/fp16/fused_optimizer.py +++ b/deepspeed/runtime/fp16/fused_optimizer.py @@ -11,12 +11,12 @@ from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors from deepspeed.runtime import DeepSpeedOptimizer -from deepspeed.runtime.utils import get_global_norm, get_grad_norm, CheckOverflow, get_weight_norm, required_torch_version +from deepspeed.runtime.utils import get_global_norm, get_grad_norm, CheckOverflow, get_weight_norm, required_torch_version, get_norm_with_moe_layers from deepspeed.runtime.fp16.loss_scaler import INITIAL_LOSS_SCALE, SCALE_WINDOW, MIN_LOSS_SCALE -from deepspeed.utils import groups, logger, log_dist -from deepspeed import comm as dist +from deepspeed.utils import logger, log_dist from deepspeed.checkpoint.constants import OPTIMIZER_STATE_DICT, CLIP_GRAD from deepspeed.accelerator import get_accelerator +from deepspeed.moe.utils import is_moe_param_group OVERFLOW_CHECK_TIMER = 'overflow_check' COMPUTE_NORM_TIMER = 'compute_norm' @@ -237,6 +237,10 @@ def step(self, closure=None): return self.overflow grads_groups_flat = [] + non_experts_grads_for_norm = [] + expert_grads_for_norm = {} + assert len(self.fp16_groups) == len(self.optimizer.param_groups) + for i, group in enumerate(self.fp16_groups): data_type = self.fp32_groups_flat[i].dtype @@ -250,15 +254,25 @@ def step(self, closure=None): p.grad = None self.fp32_groups_flat[i].grad = grads_groups_flat[i] + param_group = self.optimizer.param_groups[i] + if self.has_moe_layers and is_moe_param_group(param_group): + if param_group['name'] not in expert_grads_for_norm: + expert_grads_for_norm[param_group['name']] = [] + expert_grads_for_norm[param_group['name']].append(self.fp32_groups_flat[i]) + else: + non_experts_grads_for_norm.append(self.fp32_groups_flat[i]) self.timers(COMPUTE_NORM_TIMER).start() - all_groups_norm = get_grad_norm(self.fp32_groups_flat, mpu=self.mpu) + all_groups_norm = get_grad_norm(non_experts_grads_for_norm, mpu=self.mpu) self.timers(COMPUTE_NORM_TIMER).stop() if self.has_moe_layers: - all_groups_norm = self._get_norm_with_moe_layers(all_groups_norm) + all_groups_norm = get_norm_with_moe_layers(all_groups_norm, + mpu=self.mpu, + expert_tensors=expert_grads_for_norm, + norm_type=self.norm_type) scaled_global_grad_norm = get_global_norm(norm_list=[all_groups_norm]) @@ -290,20 +304,6 @@ def step(self, closure=None): return self.overflow - def _get_norm_with_moe_layers(self, all_groups_norm): - #all_groups_norm_old = all_groups_norm - # Need to allreduce (avg) the norms across different ranks because moe params will not be synced during allreduce - if self.using_pipeline: - pg = self.deepspeed.mpu.get_data_parallel_group() - else: - pg = groups._get_data_parallel_group() - scaled_norm = all_groups_norm * 1.0 / float(dist.get_world_size(group=pg)) - scaled_norm_tensor = torch.tensor(scaled_norm, device=self.fp32_groups_flat[0].device, dtype=torch.float) - dist.all_reduce(scaled_norm_tensor, group=pg) - all_groups_norm = scaled_norm_tensor.item() - #print(f"old = {all_groups_norm_old} and new = {all_groups_norm} at rank: {deepspeed.comm.get_rank()}") - return all_groups_norm - def unscale_and_clip_grads(self, grad_groups_flat, total_norm, apply_scale=True): # compute combined scale factor for this group combined_scale = self.cur_scale diff --git a/deepspeed/runtime/utils.py b/deepspeed/runtime/utils.py index d1ebe4b2f83db..e068f4a48b4af 100755 --- a/deepspeed/runtime/utils.py +++ b/deepspeed/runtime/utils.py @@ -205,6 +205,17 @@ def move_to_device(item, device, criterion_func): return item +def get_norm_with_moe_layers_fast(all_groups_norm, group): + # This implementation standardizes the grad_norm across ranks. A more precise implementation can be found in 'get_norm_with_moe_layers'. + # Need to allreduce (avg) the norms across different ranks because moe params will not be synced during allreduce + scaled_norm = all_groups_norm * 1.0 / float(dist.get_world_size(group=group)) + scaled_norm_tensor = torch.tensor(scaled_norm, device=get_accelerator().current_device(), dtype=torch.float) + dist.all_reduce(scaled_norm_tensor, group=group) + all_groups_norm = scaled_norm_tensor.item() + #print(f"old = {all_groups_norm_old} and new = {all_groups_norm} at rank: {deepspeed.comm.get_rank()}") + return all_groups_norm + + class CheckOverflow(object): '''Checks for overflow in gradient across parallel process''' @@ -861,7 +872,7 @@ def clip_gradients(parameters, max_norm=1.0, global_grad_norm=None, mpu=None, ep return global_grad_norm -def get_global_norm_of_tensors(input_tensors, norm_type=2, mpu=None, use_graph=False): +def get_global_norm_of_tensors(input_tensors, norm_type=2, mpu=None, use_graph=False, moe_ep_group=None): """Get norm of an iterable of tensors. This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and @@ -884,7 +895,9 @@ def get_global_norm_of_tensors(input_tensors, norm_type=2, mpu=None, use_graph=F total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)]) if mpu is not None: dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.MAX, group=mpu.get_model_parallel_group()) - total_norm = total_norm_cuda[0].item() + if moe_ep_group is not None: + dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.MAX, group=moe_ep_group) + total_norm = total_norm_cuda[0].item() else: if use_graph: if 'norm_tensors_compute_buffer' not in graph_cache: @@ -906,6 +919,9 @@ def _norm_tensors(tensor_list, _compute_buffer, _norm_type): total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)]).detach() if mpu is not None: dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.SUM, group=mpu.get_model_parallel_group()) + if moe_ep_group is not None: + dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.SUM, group=moe_ep_group) + total_norm = total_norm_cuda[0].item()**(1. / norm_type) if total_norm == float('inf') or total_norm == -float('inf') or total_norm != total_norm: @@ -1048,3 +1064,45 @@ def required_torch_version(min_version=None, max_version=None): return False return True + + +def get_norm_with_moe_layers(non_expert_norm, mpu, expert_tensors, norm_type=2): + """ Compute the global norm with MoE experts + + Inputs: + non_expert_norm (float) : the calculated norm of the non-expert params + expert_tensors (Dict[ep_name, List[Tensor]): Dictionary of expert group name to list of grad tensors + norm_type (int): the norm to use + + Returns: + if norm is (-/+) inf, returns -1 + otherwise the global norm (float) + """ + + def to_tensor(v): + return get_accelerator().FloatTensor(float(v)).detach() + + group_norms = [non_expert_norm] + for exp_name, tensors in expert_tensors.items(): + group_norm = get_global_norm_of_tensors(input_tensors=tensors, + mpu=mpu, + norm_type=norm_type, + use_graph=False, + moe_ep_group=groups._get_expert_parallel_group(exp_name)) + group_norms.append(group_norm) + + # check if all norms are valid + group_norms = torch.stack([to_tensor(norm) for norm in group_norms]) + if group_norms.eq(-1).any(): + return -1 + + # combine norms + if norm_type == inf: + total_norm = group_norms.max().item() + else: + total_norm = group_norms.pow(norm_type).sum() + total_norm = total_norm.item()**(1. / norm_type) + if total_norm == float('inf') or total_norm == -float('inf'): + total_norm = -1 + + return total_norm diff --git a/tests/unit/runtime/half_precision/test_fp16.py b/tests/unit/runtime/half_precision/test_fp16.py index e54fe352bf5b9..9229794b39f87 100644 --- a/tests/unit/runtime/half_precision/test_fp16.py +++ b/tests/unit/runtime/half_precision/test_fp16.py @@ -13,6 +13,7 @@ from deepspeed.runtime.utils import required_torch_version from deepspeed.accelerator import get_accelerator from deepspeed.ops.op_builder import CPUAdamBuilder +from deepspeed.moe.utils import split_params_into_different_moe_groups_for_optimizer try: from apex import amp # noqa: F401 # type: ignore @@ -215,8 +216,10 @@ def mock_unscale_and_clip_grads(grads_groups_flat, total_norm, apply_scale=True) # initialize MoE model = SimpleMoEModel(hidden_dim, ep_size=2) + param_group = {'params': [p for p in model.parameters()], 'name': 'random-unique-name'} + params = split_params_into_different_moe_groups_for_optimizer(param_group) # optimizer = torch.optim.AdamW(params=model.parameters()) - optimizer = FusedAdam(params=model.parameters()) + optimizer = FusedAdam(params=params) engine, optimizer, _, _ = deepspeed.initialize(config=config_dict, model=model, optimizer=optimizer,