diff --git a/deepspeed/autotuning/utils.py b/deepspeed/autotuning/utils.py
index 8c9a5fa85bf2..b851353520fb 100644
--- a/deepspeed/autotuning/utils.py
+++ b/deepspeed/autotuning/utils.py
@@ -42,7 +42,7 @@ def find_replace_str(value, replace_dict):
     if not isinstance(value, str):
         return str(value)
 
-    matches = re.findall(r"\$[A-Za-z0-9_]+", value)
+    matches = re.findall(r"\$[\w]+", value)
     for var in matches:
         var_key = var.replace("$", "").lower()
         if var_key == "nvme_path":
diff --git a/deepspeed/module_inject/auto_tp.py b/deepspeed/module_inject/auto_tp.py
index bf9c2d74c635..88f7086518e8 100644
--- a/deepspeed/module_inject/auto_tp.py
+++ b/deepspeed/module_inject/auto_tp.py
@@ -133,7 +133,7 @@ def is_load_module(module):
     load_layers = [nn.Linear, nn.Embedding, nn.LayerNorm]
     load_layer_names = [
        "LPLayerNorm", "SharedEmbedding", "OPTLearnedPositionalEmbedding", "LlamaRMSNorm", "FalconLinear",
-        "MistralRMSNorm", "T5LayerNorm"
+        "MistralRMSNorm", "T5LayerNorm", "MixtralRMSNorm"
     ]
     return module.__class__ in load_layers or module._get_name() in load_layer_names
 
@@ -303,6 +303,9 @@ def tp_parser(model):
                 elif 'self_attention.dense' in layer and 'falcon' in str(
                         type(module)):  # this is a hack to get the right linear layer for this model!
                     gem_list = gem_list + [layer]
+                # Mixtral-8x7B uses w2 * act(w1 * w3) in its expert MLPs; w2 must be replaced with LinearAllreduce.
+                elif 'w2' in layer and 'Mixtral' in str(type(module)):
+                    gem_list = gem_list + [layer]
 
         layer_list = []
         if gem_list != []:
@@ -322,6 +325,9 @@ def _replace(self, child, name, conv_linear_layer):
             return
         weight_shape = child.weight.shape
         mp_replace = ReplaceWithTensorSlicing(mp_group=self.mp_group)
+        # For Mixtral-8x7B, skip replacing the MoE gate linear.
+        if name == "block_sparse_moe.gate":
+            return child
         if name in self.all_reduce_linears:
             # if conv_linear_layer [weight_shape[1], weight_shape[0] // mp_size]
             # else [weight_shape[0], weight_shape[1] // mp_size]
diff --git a/deepspeed/moe/utils.py b/deepspeed/moe/utils.py
index 8e1faffc3541..f52fe2e3442d 100644
--- a/deepspeed/moe/utils.py
+++ b/deepspeed/moe/utils.py
@@ -146,3 +146,7 @@ def split_params_into_different_moe_groups_for_optimizer(
             param_groups.append(param_group)
 
     return param_groups
+
+
+def is_moe_param_group(param_group):
+    return param_group.get('moe', False)
diff --git a/deepspeed/runtime/bf16_optimizer.py b/deepspeed/runtime/bf16_optimizer.py
index 82c8dda423a6..4ec603af1505 100644
--- a/deepspeed/runtime/bf16_optimizer.py
+++ b/deepspeed/runtime/bf16_optimizer.py
@@ -12,13 +12,13 @@
 from deepspeed.runtime.constants import PIPE_REPLICATED
 from deepspeed.runtime import ZeROOptimizer
 from packaging import version as pkg_version
-
 from deepspeed.git_version_info import version
 from deepspeed.runtime.utils import (get_global_norm_of_tensors, clip_tensors_by_global_norm, DummyOptim,
                                      align_dense_tensors, all_gather_dp_groups, bwc_tensor_model_parallel_rank,
-                                     is_model_parallel_parameter, see_memory_usage, graph_process)
-
-from deepspeed.utils import link_hp_params, lazy_init_hp_params_optimizer_state, fragment_address, map_to_flat_opt_states
+                                     is_model_parallel_parameter, see_memory_usage, graph_process,
+                                     get_norm_with_moe_layers)
+from deepspeed.moe.utils import is_moe_param, is_moe_param_group
+from deepspeed.utils import link_hp_params, lazy_init_hp_params_optimizer_state, fragment_address, groups, map_to_flat_opt_states
 from deepspeed.checkpoint import enable_universal_checkpoint
 from deepspeed.checkpoint.constants import (DS_VERSION, PARTITION_COUNT, BASE_OPTIMIZER_STATE,
                                             SINGLE_PARTITION_OF_FP32_GROUPS, CLIP_GRAD, GROUP_PADDINGS,
@@ -40,7 +40,8 @@ def __init__(self,
                  timers=None,
                  grad_acc_dtype=None,
                  graph_harvesting=False,
-                 immediate_grad_update=False):
+                 immediate_grad_update=False,
+                 has_moe_layers=False):
         super().__init__()
         see_memory_usage('begin bf16_optimizer', force=True)
         self.timers = timers
@@ -59,7 +60,11 @@ def __init__(self,
         self.allgather_bucket_size = int(allgather_bucket_size)
         self.dp_process_group = dp_process_group
         self.dp_rank = dist.get_rank(group=self.dp_process_group)
+        self.has_moe_layers = has_moe_layers
+        self.non_expert_gradients = []
         self.real_dp_process_group = [dp_process_group for i in range(len(self.optimizer.param_groups))]
+        if self.has_moe_layers:
+            self._configure_moe_settings()
 
         # Use torch (un)flatten ops
         self.flatten = _flatten_dense_tensors
@@ -90,11 +95,26 @@ def __init__(self,
 
         see_memory_usage('end bf16_optimizer', force=True)
 
+    def _configure_moe_settings(self):
+        assert any(
+            [is_moe_param_group(group) for group in self.optimizer.param_groups]
+        ), "The model has MoE layers, but none of the param groups are marked as MoE. Create a param group with the 'moe' key set to True before creating the optimizer."
+
+        for i, group in enumerate(self.optimizer.param_groups):
+            if is_moe_param_group(group):
+                assert all([is_moe_param(param)
+                            for param in group['params']]), "All params in an MoE group must be MoE params"
+                self.real_dp_process_group[i] = groups._get_expert_data_parallel_group(group['name'])
+        self.expert_gradients = {}
+        if self.has_moe_layers:
+            for key in groups._get_expert_data_parallel_group_dict().keys():
+                self.expert_gradients[key] = []
+
     def _setup_for_real_optimizer(self):
-        dp_world_size = dist.get_world_size(group=self.dp_process_group)
-        self.partition_count = [dp_world_size for i in range(len(self.optimizer.param_groups))]
+        self.partition_count = [dist.get_world_size(group=pg) for pg in self.real_dp_process_group]
 
         for i, param_group in enumerate(self.optimizer.param_groups):
+            real_dp_world_size = dist.get_world_size(group=self.real_dp_process_group[i])
             see_memory_usage(f'before initializing group {i}', force=True)
 
             partition_id = dist.get_rank(group=self.real_dp_process_group[i])
@@ -106,17 +126,16 @@ def _setup_for_real_optimizer(self):
             # create flat bf16 params
             self.bf16_groups_flat.append(
                 self._flatten_dense_tensors_aligned(self.bf16_groups[i],
-                                                    self.nccl_start_alignment_factor * dp_world_size))
-
+                                                    self.nccl_start_alignment_factor * real_dp_world_size))
             # Make bf16 params point to flat tensor storage
             self._update_storage_to_flattened_tensor(tensor_list=self.bf16_groups[i],
                                                      flat_tensor=self.bf16_groups_flat[i])
 
             # divide flat weights into equal sized partitions
-            partition_size = self.bf16_groups_flat[i].numel() // dp_world_size
+            partition_size = self.bf16_groups_flat[i].numel() // real_dp_world_size
             bf16_dp_partitions = [
                 self.bf16_groups_flat[i].narrow(0, dp_index * partition_size, partition_size)
-                for dp_index in range(dp_world_size)
+                for dp_index in range(real_dp_world_size)
             ]
             self.bf16_partitioned_groups.append(bf16_dp_partitions)
 
@@ -127,8 +146,12 @@ def _setup_for_real_optimizer(self):
             num_elem_list = [t.numel() for t in self.bf16_groups[i]]
 
             # create fp32 gradients
-            self.fp32_groups_gradients_flat.append(
-                torch.zeros_like(self.bf16_groups_flat[i], dtype=self.grad_acc_dtype))
+            fp32_flat_buffer = torch.zeros_like(self.bf16_groups_flat[i], dtype=self.grad_acc_dtype)
+            self.fp32_groups_gradients_flat.append(fp32_flat_buffer)
+            if self.has_moe_layers and is_moe_param_group(param_group):
+                self.expert_gradients[param_group['name']].append(fp32_flat_buffer)
+            else:
+                self.non_expert_gradients.append(fp32_flat_buffer)
 
             # track individual fp32 gradients for entire model
             fp32_gradients = self._split_flat_tensor(flat_tensor=self.fp32_groups_gradients_flat[i],
@@ -191,11 +214,12 @@ def _create_param_mapping(self):
         return param_mapping
 
     def _link_all_hp_params(self):
-        dp_world_size = dist.get_world_size(group=self.dp_process_group)
         for i, _ in enumerate(self.optimizer.param_groups):
+            real_dp_world_size = dist.get_world_size(group=self.real_dp_process_group[i])
+
             # Link bf16 and fp32 params in partition
             partition_id = dist.get_rank(group=self.real_dp_process_group[i])
-            partition_size = self.bf16_groups_flat[i].numel() // dp_world_size
+            partition_size = self.bf16_groups_flat[i].numel() // real_dp_world_size
             flat_hp_partition = self.fp32_groups_flat_partition[i]
             link_hp_params(lp_param_list=self.bf16_groups[i],
                            flat_hp_partition=flat_hp_partition,
@@ -257,10 +281,18 @@ def step(self, closure=None):
         if closure is not None:
             raise NotImplementedError(f'{self.__class__} does not support closure.')
 
-        all_groups_norm = get_global_norm_of_tensors(input_tensors=self.get_grads_for_norm(),
-                                                     mpu=self.mpu,
-                                                     norm_type=self.norm_type,
-                                                     use_graph=self.graph_harvesting)
+        non_expert_grads_for_norm, expert_grads_for_norm = self.get_grads_for_norm()
+        non_expert_groups_norm = get_global_norm_of_tensors(input_tensors=non_expert_grads_for_norm,
+                                                            mpu=self.mpu,
+                                                            norm_type=self.norm_type,
+                                                            use_graph=self.graph_harvesting)
+        all_groups_norm = non_expert_groups_norm
+        if self.has_moe_layers:
+            all_groups_norm = get_norm_with_moe_layers(non_expert_groups_norm,
+                                                       mpu=self.mpu,
+                                                       expert_tensors=expert_grads_for_norm,
+                                                       norm_type=self.norm_type)
+
         self._global_grad_norm = all_groups_norm
 
         assert all_groups_norm > 0.
@@ -336,27 +368,55 @@ def update_hp_grads(self, clear_lp_grads=False):
 
     @torch.no_grad()
     def get_grads_for_reduction(self):
-        return self.fp32_groups_gradients_flat
+        if self.has_moe_layers:
+            return self.non_expert_gradients, self.expert_gradients
+        return self.non_expert_gradients, {}
 
     @torch.no_grad()
     def get_grads_for_norm(self, for_clipping=False):
-        grads = []
+        """
+        Returns:
+            tuple[list[Tensor], dict[str, list[Tensor]]] | list[Tensor]:
+            If for_clipping, return a single list of all gradients.
+            Otherwise, return the non-expert gradients and a dict mapping each expert group name to its gradients.
+        """
+        # maps expert group name -> expert gradients
+        expert_grads_for_norm = {}
+
+        # non-expert gradients
+        non_expert_grads_for_norm = []
+        all_grads_for_clip = []
+
         tensor_mp_rank = bwc_tensor_model_parallel_rank(mpu=self.mpu)
+        assert len(self.bf16_groups) == len(self.optimizer.param_groups)
         for i, group in enumerate(self.bf16_groups):
             for j, lp in enumerate(group):
                 if not for_clipping:
                     if hasattr(lp, PIPE_REPLICATED) and lp.ds_pipe_replicated:
                         continue
 
-                    if not (tensor_mp_rank == 0 or is_model_parallel_parameter(lp)):
+                    # Skip duplicated parameters; compute the norm only on ranks with tp_rank == 0.
+                    # Non-duplicated parameters include:
+                    # - parameters sliced by TP: summed via all-reduce over the mp_group.
+                    # - MoE parameters sliced by EP: summed via all-reduce over the ep_group.
+                    if not (tensor_mp_rank == 0 or is_model_parallel_parameter(lp) or is_moe_param(lp)):
                         continue
 
                 if not self.fp32_groups_has_gradients[i][j]:
                     continue
-
-                grads.append(self.fp32_groups_gradients[i][j])
-
-        return grads
+                if not for_clipping:
+                    param_group = self.optimizer.param_groups[i]
+                    if self.has_moe_layers and is_moe_param_group(param_group):
+                        if param_group['name'] not in expert_grads_for_norm:
+                            expert_grads_for_norm[param_group['name']] = []
+                        expert_grads_for_norm[param_group['name']].append(self.fp32_groups_gradients[i][j])
+                    else:
+                        non_expert_grads_for_norm.append(self.fp32_groups_gradients[i][j])
+                else:
+                    all_grads_for_clip.append(self.fp32_groups_gradients[i][j])
+        if not for_clipping:
+            return non_expert_grads_for_norm, expert_grads_for_norm
+        return all_grads_for_clip
 
     @torch.no_grad()
     def update_lp_params(self):
diff --git a/deepspeed/runtime/config.py b/deepspeed/runtime/config.py
index 975fb1f21501..19b169086be1 100755
--- a/deepspeed/runtime/config.py
+++ b/deepspeed/runtime/config.py
@@ -258,10 +258,10 @@ def get_communication_data_type(param_dict,
         return torch.float32
     elif val == "fp16":
         return torch.float16
-    elif val == "bfp16":
+    elif val == "bf16":
         return torch.bfloat16
 
-    raise ValueError(f"Invalid communication_data_type. Supported data types: ['fp16', 'bfp16', 'fp32']. Got: {val}")
+    raise ValueError(f"Invalid communication_data_type. Supported data types: ['fp16', 'bf16', 'fp32']. Got: {val}")
 
 
 def get_prescale_gradients(param_dict):
diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py
index 174e699c5202..bd2e91431aff 100644
--- a/deepspeed/runtime/engine.py
+++ b/deepspeed/runtime/engine.py
@@ -1478,7 +1478,8 @@ def _configure_bf16_optimizer(self, optimizer):
                                      timers=timers,
                                      grad_acc_dtype=self.get_data_types()[1],
                                      graph_harvesting=self.graph_harvesting(),
-                                     immediate_grad_update=self._config.bfloat16_immediate_grad_update)
+                                     immediate_grad_update=self._config.bfloat16_immediate_grad_update,
+                                     has_moe_layers=self.has_moe_layers)
 
         return optimizer
 
@@ -1924,9 +1925,6 @@ def allreduce_gradients(self, bucket_size=MEMORY_OPT_ALLREDUCE_SIZE):
                 self.optimizer.reduce_gradients(pipeline_parallel=self.pipeline_parallelism)
             else:
                 grads = None
-                if hasattr(self.optimizer, "get_grads_for_reduction"):
-                    # This is currently for BF16 optimizer
-                    grads = self.optimizer.get_grads_for_reduction()
                 self.buffered_allreduce_fallback(grads=grads, elements_per_buffer=bucket_size)
 
     @instrument_w_nvtx
@@ -2335,7 +2333,7 @@ def _report_progress(self, step):
         mom = self.get_mom()
         log_dist(f"step={step}, skipped={self.skipped_steps}, lr={lr}, mom={mom}", ranks=[0])
 
-    def allreduce_bucket(self, bucket, dp_group):
+    def allreduce_bucket(self, bucket, dp_group, dp_world_size=None):
         tensor = self.flatten(bucket)
 
         tensor_to_allreduce = tensor
@@ -2343,16 +2341,18 @@ def allreduce_bucket(self, bucket, dp_group):
         if self.communication_data_type != tensor.dtype:
             tensor_to_allreduce = tensor.to(self.communication_data_type)
 
+        if dp_world_size is None:
+            dp_world_size = dist.get_world_size(group=dp_group)
         if self.postscale_gradients():
             if self.gradient_predivide_factor() != 1.0:
                 tensor_to_allreduce.mul_(1.0 / self.gradient_predivide_factor())
 
             dist.all_reduce(tensor_to_allreduce, group=dp_group)
             if self.gradient_average:
-                if self.gradient_predivide_factor() != dist.get_world_size(group=dp_group):
-                    tensor_to_allreduce.mul_(self.gradient_predivide_factor() / dist.get_world_size(group=dp_group))
+                if self.gradient_predivide_factor() != dp_world_size:
+                    tensor_to_allreduce.mul_(self.gradient_predivide_factor() / dp_world_size)
         else:
-            tensor_to_allreduce.mul_(1. / dist.get_world_size(group=dp_group))
+            tensor_to_allreduce.mul_(1. / dp_world_size)
             dist.all_reduce(tensor_to_allreduce, group=dp_group)
 
         if self.communication_data_type != tensor.dtype and tensor is not tensor_to_allreduce:
@@ -2360,23 +2360,23 @@ def allreduce_bucket(self, bucket, dp_group):
 
         return tensor
 
-    def allreduce_and_copy(self, small_bucket, dp_group):
-        allreduced = self.allreduce_bucket(small_bucket, dp_group)
+    def allreduce_and_copy(self, small_bucket, dp_group, dp_world_size=None):
+        allreduced = self.allreduce_bucket(small_bucket, dp_group, dp_world_size)
         for buf, synced in zip(small_bucket, self.unflatten(allreduced, small_bucket)):
             buf.copy_(synced)
 
-    def allreduce_no_retain(self, bucket, dp_group, numel_per_bucket=500000000):
+    def allreduce_no_retain(self, bucket, dp_group, numel_per_bucket=500000000, dp_world_size=None):
         small_bucket = []
         numel = 0
         for tensor in bucket:
             small_bucket.append(tensor)
             numel = numel + tensor.numel()
             if numel > numel_per_bucket:
-                self.allreduce_and_copy(small_bucket, dp_group)
+                self.allreduce_and_copy(small_bucket, dp_group, dp_world_size)
                 small_bucket = []
                 numel = 0
         if len(small_bucket) > 0:
-            self.allreduce_and_copy(small_bucket, dp_group)
+            self.allreduce_and_copy(small_bucket, dp_group, dp_world_size)
 
     def _get_gradients_for_reduction(self):
         non_expert_grads = []
@@ -2427,26 +2427,35 @@ def _reduce_non_expert_gradients(self, grads, elements_per_buffer):
                 self.allreduce_no_retain(dense_bucket, dp_group=dp_group, numel_per_bucket=elements_per_buffer)
 
     def _reduce_expert_gradients(self, expert_grads, elements_per_buffer):
+        # To keep gradient values unaffected by the ep_size setting,
+        # use the full dp_world_size for the all-reduce average.
+        dp_world_size = dist.get_world_size(groups._get_data_parallel_group())
         for ep_name, expert_grads_group in expert_grads.items():
+            ep_dp_group = groups._get_expert_data_parallel_group(ep_name)
             split_sparse_tensor_buckets, split_dense_tensor_buckets = split_half_float_double_sparse(
                 expert_grads_group)
+
             for _, sparse_bucket_tuple in enumerate(split_sparse_tensor_buckets):
                 if sparse_bucket_tuple:
                     bucket_type, sparse_bucket = sparse_bucket_tuple
-                    self.sparse_allreduce_no_retain(sparse_bucket, groups._get_expert_data_parallel_group(ep_name))
+                    self.sparse_allreduce_no_retain(sparse_bucket, dp_group=ep_dp_group, dp_world_size=dp_world_size)
             for _, dense_bucket_tuple in enumerate(split_dense_tensor_buckets):
                 if dense_bucket_tuple:
                     bucket_type, dense_bucket = dense_bucket_tuple
                     # Separate between diff groups
                     self.allreduce_no_retain(dense_bucket,
-                                             dp_group=groups._get_expert_data_parallel_group(ep_name),
-                                             numel_per_bucket=elements_per_buffer)
+                                             dp_group=ep_dp_group,
+                                             numel_per_bucket=elements_per_buffer,
+                                             dp_world_size=dp_world_size)
 
     def buffered_allreduce_fallback(self, grads=None, elements_per_buffer=500000000):
         if grads is None:
-            non_expert_grads, expert_grads = self._get_gradients_for_reduction()
+            if hasattr(self.optimizer, "get_grads_for_reduction"):
+                # This is currently for the BF16 optimizer
+                non_expert_grads, expert_grads = self.optimizer.get_grads_for_reduction()
+            else:
+                non_expert_grads, expert_grads = self._get_gradients_for_reduction()
         else:
             assert not self.has_moe_layers, "attempting to reduce grads in unsupported way w.r.t. MoE"
             non_expert_grads = grads
@@ -2456,8 +2465,8 @@ def buffered_allreduce_fallback(self, grads=None, elements_per_buffer=500000000)
         if self.has_moe_layers:
             self._reduce_expert_gradients(expert_grads, elements_per_buffer)
 
-    def sparse_allreduce_no_retain(self, bucket, dp_group):
-        allreduced_sparses = self.sparse_allreduce_bucket(bucket, dp_group)
+    def sparse_allreduce_no_retain(self, bucket, dp_group, dp_world_size=None):
+        allreduced_sparses = self.sparse_allreduce_bucket(bucket, dp_group, dp_world_size)
         # Densify sparse tensor and copy back to original location
         for tensor in allreduced_sparses:
             if tensor.is_sparse:
@@ -2465,13 +2474,13 @@ def sparse_allreduce_no_retain(self, bucket, dp_group):
             else:
                 tensor.orig_dense_tensor.copy_(tensor.to_dense())
 
-    def sparse_allreduce_bucket(self, bucket, dp_group):
+    def sparse_allreduce_bucket(self, bucket, dp_group, dp_world_size=None):
         sparse_list = []
         for sparse in bucket:
-            sparse_list.append(self.sparse_allreduce(sparse, dp_group))
+            sparse_list.append(self.sparse_allreduce(sparse, dp_group, dp_world_size))
         return sparse_list
 
-    def sparse_allreduce(self, sparse, dp_group):
+    def sparse_allreduce(self, sparse, dp_group, dp_world_size=None):
         original_data_type = sparse.values.dtype
         if self.communication_data_type != sparse.values.dtype:
             if self.communication_data_type in (torch.float16, torch.bfloat16):
@@ -2483,12 +2492,13 @@ def sparse_allreduce(self, sparse, dp_group):
         indices = sparse.indices
         values = sparse.values
 
+        if dp_world_size is None:
+            dp_world_size = dist.get_world_size(group=dp_group)
         if self.postscale_gradients():
             if self.gradient_average:
-                values.mul_(self.gradient_predivide_factor() /
-                            (dist.get_world_size(group=dp_group) / float(self.sequence_parallel_size)))
+                values.mul_(self.gradient_predivide_factor() / (dp_world_size / float(self.sequence_parallel_size)))
         else:
-            values.mul_(1. / (dist.get_world_size(group=dp_group) / float(self.sequence_parallel_size)))
+            values.mul_(1. / (dp_world_size / float(self.sequence_parallel_size)))
 
         indices_device_list = self.sparse_all_gather(indices, dp_group)
         values_device_list = self.sparse_all_gather(values, dp_group)
diff --git a/deepspeed/runtime/fp16/fused_optimizer.py b/deepspeed/runtime/fp16/fused_optimizer.py
index 182f806c839c..416642a89901 100755
--- a/deepspeed/runtime/fp16/fused_optimizer.py
+++ b/deepspeed/runtime/fp16/fused_optimizer.py
@@ -11,12 +11,12 @@
 from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
 
 from deepspeed.runtime import DeepSpeedOptimizer
-from deepspeed.runtime.utils import get_global_norm, get_grad_norm, CheckOverflow, get_weight_norm, required_torch_version
+from deepspeed.runtime.utils import get_global_norm, get_grad_norm, CheckOverflow, get_weight_norm, required_torch_version, get_norm_with_moe_layers
 from deepspeed.runtime.fp16.loss_scaler import INITIAL_LOSS_SCALE, SCALE_WINDOW, MIN_LOSS_SCALE
-from deepspeed.utils import groups, logger, log_dist
-from deepspeed import comm as dist
+from deepspeed.utils import logger, log_dist
 from deepspeed.checkpoint.constants import OPTIMIZER_STATE_DICT, CLIP_GRAD
 from deepspeed.accelerator import get_accelerator
+from deepspeed.moe.utils import is_moe_param_group
 
 OVERFLOW_CHECK_TIMER = 'overflow_check'
 COMPUTE_NORM_TIMER = 'compute_norm'
@@ -237,6 +237,10 @@ def step(self, closure=None):
             return self.overflow
 
         grads_groups_flat = []
+        non_experts_grads_for_norm = []
+        expert_grads_for_norm = {}
+        assert len(self.fp16_groups) == len(self.optimizer.param_groups)
+
         for i, group in enumerate(self.fp16_groups):
             data_type = self.fp32_groups_flat[i].dtype
 
@@ -250,15 +254,25 @@ def step(self, closure=None):
                 p.grad = None
 
             self.fp32_groups_flat[i].grad = grads_groups_flat[i]
+            param_group = self.optimizer.param_groups[i]
+            if self.has_moe_layers and is_moe_param_group(param_group):
+                if param_group['name'] not in expert_grads_for_norm:
+                    expert_grads_for_norm[param_group['name']] = []
+                expert_grads_for_norm[param_group['name']].append(self.fp32_groups_flat[i])
+            else:
+                non_experts_grads_for_norm.append(self.fp32_groups_flat[i])
 
         self.timers(COMPUTE_NORM_TIMER).start()
-        all_groups_norm = get_grad_norm(self.fp32_groups_flat, mpu=self.mpu)
+        all_groups_norm = get_grad_norm(non_experts_grads_for_norm, mpu=self.mpu)
         self.timers(COMPUTE_NORM_TIMER).stop()
 
         if self.has_moe_layers:
-            all_groups_norm = self._get_norm_with_moe_layers(all_groups_norm)
+            all_groups_norm = get_norm_with_moe_layers(all_groups_norm,
+                                                       mpu=self.mpu,
+                                                       expert_tensors=expert_grads_for_norm,
+                                                       norm_type=self.norm_type)
 
         scaled_global_grad_norm = get_global_norm(norm_list=[all_groups_norm])
 
@@ -290,20 +304,6 @@ def step(self, closure=None):
 
         return self.overflow
 
-    def _get_norm_with_moe_layers(self, all_groups_norm):
-        #all_groups_norm_old = all_groups_norm
-        # Need to allreduce (avg) the norms across different ranks because moe params will not be synced during allreduce
-        if self.using_pipeline:
-            pg = self.deepspeed.mpu.get_data_parallel_group()
-        else:
-            pg = groups._get_data_parallel_group()
-        scaled_norm = all_groups_norm * 1.0 / float(dist.get_world_size(group=pg))
-        scaled_norm_tensor = torch.tensor(scaled_norm, device=self.fp32_groups_flat[0].device, dtype=torch.float)
-        dist.all_reduce(scaled_norm_tensor, group=pg)
-        all_groups_norm = scaled_norm_tensor.item()
-        #print(f"old = {all_groups_norm_old} and new = {all_groups_norm} at rank: {deepspeed.comm.get_rank()}")
-        return all_groups_norm
-
     def unscale_and_clip_grads(self, grad_groups_flat, total_norm, apply_scale=True):
         # compute combined scale factor for this group
         combined_scale = self.cur_scale
diff --git a/deepspeed/runtime/utils.py b/deepspeed/runtime/utils.py
index d1ebe4b2f83d..e068f4a48b4a 100755
--- a/deepspeed/runtime/utils.py
+++ b/deepspeed/runtime/utils.py
@@ -205,6 +205,17 @@ def move_to_device(item, device, criterion_func):
         return item
 
 
+def get_norm_with_moe_layers_fast(all_groups_norm, group):
+    # This implementation averages the grad norm across ranks as a fast approximation;
+    # a more precise implementation can be found in 'get_norm_with_moe_layers'.
+    # The norms must be all-reduced (averaged) across ranks because MoE params are not synced during allreduce.
+    scaled_norm = all_groups_norm * 1.0 / float(dist.get_world_size(group=group))
+    scaled_norm_tensor = torch.tensor(scaled_norm, device=get_accelerator().current_device(), dtype=torch.float)
+    dist.all_reduce(scaled_norm_tensor, group=group)
+    all_groups_norm = scaled_norm_tensor.item()
+    return all_groups_norm
+
+
 class CheckOverflow(object):
     '''Checks for overflow in gradient across parallel process'''
 
@@ -861,7 +872,7 @@ def clip_gradients(parameters, max_norm=1.0, global_grad_norm=None, mpu=None, ep
     return global_grad_norm
 
 
-def get_global_norm_of_tensors(input_tensors, norm_type=2, mpu=None, use_graph=False):
+def get_global_norm_of_tensors(input_tensors, norm_type=2, mpu=None, use_graph=False, moe_ep_group=None):
     """Get norm of an iterable of tensors.
 
     This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and
@@ -884,7 +895,9 @@ def get_global_norm_of_tensors(input_tensors, norm_type=2, mpu=None, use_graph=F
         total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)])
         if mpu is not None:
             dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.MAX, group=mpu.get_model_parallel_group())
-        total_norm = total_norm_cuda[0].item()
+        if moe_ep_group is not None:
+            dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.MAX, group=moe_ep_group)
+        total_norm = total_norm_cuda[0].item()
     else:
         if use_graph:
             if 'norm_tensors_compute_buffer' not in graph_cache:
@@ -906,6 +919,9 @@ def _norm_tensors(tensor_list, _compute_buffer, _norm_type):
         total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)]).detach()
         if mpu is not None:
             dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.SUM, group=mpu.get_model_parallel_group())
+        if moe_ep_group is not None:
+            dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.SUM, group=moe_ep_group)
+
         total_norm = total_norm_cuda[0].item()**(1. / norm_type)
 
     if total_norm == float('inf') or total_norm == -float('inf') or total_norm != total_norm:
@@ -1048,3 +1064,45 @@ def required_torch_version(min_version=None, max_version=None):
         return False
 
     return True
+
+
+def get_norm_with_moe_layers(non_expert_norm, mpu, expert_tensors, norm_type=2):
+    """ Compute the global norm with MoE experts
+
+    Inputs:
+        non_expert_norm (float) : the calculated norm of the non-expert params
+        expert_tensors (Dict[str, List[Tensor]]): dictionary mapping expert group name to list of grad tensors
+        norm_type (int): the norm to use
+
+    Returns:
+        if norm is (-/+) inf, returns -1
+        otherwise the global norm (float)
+    """
+
+    def to_tensor(v):
+        return get_accelerator().FloatTensor([float(v)]).detach()
+
+    group_norms = [non_expert_norm]
+    for exp_name, tensors in expert_tensors.items():
+        group_norm = get_global_norm_of_tensors(input_tensors=tensors,
+                                                mpu=mpu,
+                                                norm_type=norm_type,
+                                                use_graph=False,
+                                                moe_ep_group=groups._get_expert_parallel_group(exp_name))
+        group_norms.append(group_norm)
+
+    # check if all norms are valid
+    group_norms = torch.stack([to_tensor(norm) for norm in group_norms])
+    if group_norms.eq(-1).any():
+        return -1
+
+    # combine norms
+    if norm_type == inf:
+        total_norm = group_norms.max().item()
+    else:
+        total_norm = group_norms.pow(norm_type).sum()
+        total_norm = total_norm.item()**(1. / norm_type)
+        if total_norm == float('inf') or total_norm == -float('inf'):
+            total_norm = -1
+
+    return total_norm
diff --git a/docs/_tutorials/accelerator-abstraction-interface.md b/docs/_tutorials/accelerator-abstraction-interface.md
index db1a6005f793..88a43236ce9d 100644
--- a/docs/_tutorials/accelerator-abstraction-interface.md
+++ b/docs/_tutorials/accelerator-abstraction-interface.md
@@ -79,13 +79,13 @@
 torch.distributed.init_process_group(get_accelerator().communication_backend_name())
 ```
 
 # Run DeepSpeed model on different accelerators
-Once a model is ported with DeepSpeed Accelerator Abstraction Interface, we can run this model on different accelerators using extension to DeepSpeed. DeepSpeed check whether certain extension is installed in the environment to decide whether to use the Accelerator backend in that extension. For example if we wish to run model on Intel GPU, we can install _Intel Extension for DeepSpeed_ following the instruction in [link](https://github.com/intel/intel-extension-for-deepspeed/)
+Once a model is ported with the DeepSpeed Accelerator Abstraction Interface, we can run this model on different accelerators using an extension to DeepSpeed. DeepSpeed checks whether a certain extension is installed in the environment to decide whether to use the Accelerator backend in that extension. For example, if we wish to run a model on Intel GPU, we can install _Intel Extension for DeepSpeed_ following the instructions at this [link](https://github.com/intel/intel-extension-for-deepspeed/)
 
-After the extension is installed, install DeepSpeed and run model. The model will be running on top of DeepSpeed. Because DeepSpeed installation is also accelerator related, it is recommended to install DeepSpeed accelerator extension before install DeepSpeed.
+After the extension is installed, install DeepSpeed and run the model. The model will run on top of DeepSpeed. Because DeepSpeed installation is also accelerator related, it is recommended to install the DeepSpeed accelerator extension before installing DeepSpeed.
 
 `CUDA_Accelerator` is the default accelerator in DeepSpeed. If no other DeepSpeed accelerator extension is installed, `CUDA_Accelerator` will be used.
 
-When run a model on different accelerator in a cloud environment, the recommended practice is provision environment for each accelerator in different env with tool such as _anaconda/miniconda/virtualenv_. When run model on different Accelerator, load the env accordingly.
+When running a model on different accelerators in a cloud environment, the recommended practice is to provision an environment for each accelerator in a separate env with tools such as _anaconda/miniconda/virtualenv_. When running a model on a different accelerator, load the corresponding env.
 
 Note that different accelerator may have different 'flavor' of float16 or bfloat16. So it is recommended to make the model configurable for both float16 and bfloat16, in that way model code does not need to be changed when running on different accelerators.
diff --git a/tests/unit/runtime/half_precision/test_bf16.py b/tests/unit/runtime/half_precision/test_bf16.py
index d42a4b62cd10..0af14abc3be5 100644
--- a/tests/unit/runtime/half_precision/test_bf16.py
+++ b/tests/unit/runtime/half_precision/test_bf16.py
@@ -288,8 +288,8 @@ def test(self, stage=2):
             model.step()
 
 
-@pytest.mark.parametrize("comp_type", [torch.float16, torch.bfloat16, torch.float], ids=["fp16", "bfp16", "fp32"])
-@pytest.mark.parametrize("comm_type", [torch.float16, torch.bfloat16, None], ids=["fp16", "bfp16", "default"])
+@pytest.mark.parametrize("comp_type", [torch.float16, torch.bfloat16, torch.float], ids=["fp16", "bf16", "fp32"])
+@pytest.mark.parametrize("comm_type", [torch.float16, torch.bfloat16, None], ids=["fp16", "bf16", "default"])
 class TestZeroDtypeCocktail(DistributedTest):
     world_size = 2
 
@@ -304,7 +304,7 @@ def test(self, comp_type, comm_type):
             if not get_accelerator().is_fp16_supported():
                 pytest.skip("fp16 is not supported")
 
-        type_str = {torch.float16: "fp16", torch.bfloat16: "bfp16"}
+        type_str = {torch.float16: "fp16", torch.bfloat16: "bf16"}
 
         config_dict = {
             "train_micro_batch_size_per_gpu": 2,
diff --git a/tests/unit/runtime/half_precision/test_fp16.py b/tests/unit/runtime/half_precision/test_fp16.py
index e54fe352bf5b..9229794b39f8 100644
--- a/tests/unit/runtime/half_precision/test_fp16.py
+++ b/tests/unit/runtime/half_precision/test_fp16.py
@@ -13,6 +13,7 @@
 from deepspeed.runtime.utils import required_torch_version
 from deepspeed.accelerator import get_accelerator
 from deepspeed.ops.op_builder import CPUAdamBuilder
+from deepspeed.moe.utils import split_params_into_different_moe_groups_for_optimizer
 
 try:
     from apex import amp  # noqa: F401 # type: ignore
@@ -215,8 +216,10 @@ def mock_unscale_and_clip_grads(grads_groups_flat, total_norm, apply_scale=True)
         # initialize MoE
         model = SimpleMoEModel(hidden_dim, ep_size=2)
+        param_group = {'params': [p for p in model.parameters()], 'name': 'random-unique-name'}
+        params = split_params_into_different_moe_groups_for_optimizer(param_group)
         # optimizer = torch.optim.AdamW(params=model.parameters())
-        optimizer = FusedAdam(params=model.parameters())
+        optimizer = FusedAdam(params=params)
         engine, optimizer, _, _ = deepspeed.initialize(config=config_dict,
                                                        model=model,
                                                        optimizer=optimizer,
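Usage sketch (editor's note, not part of the diff): the test change above shows the calling pattern this PR expects. Before constructing the optimizer, parameters are split so that each expert group becomes its own param group tagged with `'moe': True`, which is what the new `is_moe_param_group()` helper and `BF16_Optimizer._configure_moe_settings()` look for. Below is a minimal, hypothetical sketch assuming an MoE model and a DeepSpeed config with `"bf16": {"enabled": true}`; the model object, config contents, and the `'base-group'` name are placeholders, not values from the diff.

```python
import deepspeed
from deepspeed.moe.utils import split_params_into_different_moe_groups_for_optimizer
from deepspeed.ops.adam import FusedAdam


def initialize_with_moe_param_groups(model, config_dict):
    # Start from a single param group; the helper splits expert params out into
    # per-expert-group param groups tagged with 'moe': True, so the MoE-aware
    # norm and expert-gradient reduction paths added in this diff can find them.
    param_group = {'params': list(model.parameters()), 'name': 'base-group'}
    params = split_params_into_different_moe_groups_for_optimizer(param_group)
    optimizer = FusedAdam(params=params)
    # The engine change above forwards has_moe_layers into BF16_Optimizer,
    # which then binds each MoE param group to its expert data-parallel group.
    return deepspeed.initialize(config=config_dict, model=model, optimizer=optimizer)
```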