diff --git a/deepspeed/runtime/config.py b/deepspeed/runtime/config.py
index c31b9671296f..9fb9eba44a38 100755
--- a/deepspeed/runtime/config.py
+++ b/deepspeed/runtime/config.py
@@ -238,8 +238,10 @@ def get_sparse_gradients_enabled(param_dict):
     return get_scalar_param(param_dict, SPARSE_GRADIENTS, SPARSE_GRADIENTS_DEFAULT)
 
 
-def get_communication_data_type(param_dict):
-    val = get_scalar_param(param_dict, COMMUNICATION_DATA_TYPE, COMMUNICATION_DATA_TYPE_DEFAULT)
+def get_communication_data_type(param_dict,
+                                comm_type=COMMUNICATION_DATA_TYPE,
+                                comm_data_type_default=COMMUNICATION_DATA_TYPE_DEFAULT):
+    val = get_scalar_param(param_dict, comm_type, comm_data_type_default)
     val = val.lower() if val is not None else val
     if val is None:
         return val  # we must determine it by other parameters
@@ -784,6 +786,8 @@ def _initialize_params(self, param_dict):
         self.disable_allgather = get_disable_allgather(param_dict)
         self.communication_data_type = get_communication_data_type(param_dict)
+        self.seq_parallel_communication_data_type = get_communication_data_type(
+            param_dict, SEQ_PARALLEL_COMMUNICATION_DATA_TYPE, SEQ_PARALLEL_COMMUNICATION_DATA_TYPE_DEFAULT)
         self.prescale_gradients = get_prescale_gradients(param_dict)
         self.gradient_predivide_factor = get_gradient_predivide_factor(param_dict)
         self.sparse_gradients_enabled = get_sparse_gradients_enabled(param_dict)
diff --git a/deepspeed/runtime/constants.py b/deepspeed/runtime/constants.py
index 0bdac2557847..cc493ee007c5 100755
--- a/deepspeed/runtime/constants.py
+++ b/deepspeed/runtime/constants.py
@@ -223,6 +223,19 @@
 COMMUNICATION_DATA_TYPE = "communication_data_type"
 COMMUNICATION_DATA_TYPE_DEFAULT = None
 
+###########################################################
+# Gradient communication data type for sequence parallelism
+###########################################################
+# Supported types: ['fp16', 'bf16', 'fp32']
+# Default value is fp32
+# Users can configure in ds_config.json as below example:
+SEQ_PARALLEL_COMMUNICATION_DATA_TYPE_FORMAT = '''
+Optional comm data type for sequence parallelism should be set as:
+"seq_parallel_comm_data_type": "fp32"
+'''
+SEQ_PARALLEL_COMMUNICATION_DATA_TYPE = "seq_parallel_comm_data_type"
+SEQ_PARALLEL_COMMUNICATION_DATA_TYPE_DEFAULT = "fp32"
+
 #########################################
 # Scale/predivide gradients before allreduce
 #########################################
diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py
index d2cb93394a53..e5f9d0ec8d03 100644
--- a/deepspeed/runtime/engine.py
+++ b/deepspeed/runtime/engine.py
@@ -808,6 +808,10 @@ def communication_data_type(self):
 
         return torch.float32
 
+    @communication_data_type.setter
+    def communication_data_type(self, value):
+        self._config.communication_data_type = value
+
     def postscale_gradients(self):
         return not self._config.prescale_gradients
@@ -1114,6 +1118,9 @@ def _configure_distributed_model(self, model):
         self.mp_world_size = groups._get_model_parallel_world_size()
         self.expert_parallel_group = groups._get_expert_parallel_group_dict()
         self.expert_data_parallel_group = groups._get_expert_data_parallel_group_dict()
+        self.sequence_parallel_size = groups._get_sequence_parallel_world_size()
+        if self.sequence_parallel_size > 1:
+            self.communication_data_type = self._config.seq_parallel_communication_data_type
 
         if not (self.amp_enabled() or is_zero_init_model):
             self._broadcast_model()
@@ -2370,7 +2377,7 @@ def _reduce_non_expert_gradients(self, grads, elements_per_buffer):
             if self.pipeline_parallelism:
                 dp_group = self.mpu.get_data_parallel_group()
             else:
-                dp_group = groups._get_data_parallel_group()
+                dp_group = groups._get_sequence_data_parallel_group()
 
             if bucket_type == SparseTensor.type():
                 self.sparse_allreduce_no_retain(bucket, dp_group=dp_group)
@@ -2431,9 +2438,10 @@ def sparse_allreduce(self, sparse, dp_group):
 
         if self.postscale_gradients():
             if self.gradient_average:
-                values.mul_(self.gradient_predivide_factor() / dist.get_world_size(group=dp_group))
+                values.mul_(self.gradient_predivide_factor() /
+                            (dist.get_world_size(group=dp_group) / float(self.sequence_parallel_size)))
             else:
-                values.mul_(1. / dist.get_world_size(group=dp_group))
+                values.mul_(1. / (dist.get_world_size(group=dp_group) / float(self.sequence_parallel_size)))
 
         indices_device_list = self.sparse_all_gather(indices, dp_group)
         values_device_list = self.sparse_all_gather(values, dp_group)
diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py
index 3c5128744848..38539ba57033 100644
--- a/deepspeed/runtime/zero/stage3.py
+++ b/deepspeed/runtime/zero/stage3.py
@@ -219,6 +219,7 @@ def __init__(
         self.reduce_scatter = reduce_scatter
 
         self.dp_process_group = self.parameter_offload.dp_process_group
+        self.sequence_parallel_size = groups._get_sequence_parallel_world_size()
 
         self.all2all_process_group = all2all_process_group
@@ -1177,7 +1178,7 @@ def __avg_scatter_contiguous_grads(self, buffer_to_reduce: Tensor) -> List[Tensor]:
         world_sz = dist.get_world_size(self.dp_process_group)
         rank = dist.get_rank(self.dp_process_group)
-        buffer_to_reduce.div_(world_sz)
+        buffer_to_reduce.div_(world_sz / float(self.sequence_parallel_size))
 
         dist.all_reduce(buffer_to_reduce, group=self.dp_process_group)
@@ -1476,7 +1477,7 @@ def allreduce_bucket(self, bucket, rank=None, log=None):
         if communication_data_type != tensor.dtype:
             tensor_to_allreduce = tensor.to(communication_data_type)
 
-        tensor_to_allreduce.div_(dist.get_world_size(group=self.dp_process_group))
+        tensor_to_allreduce.div_(dist.get_world_size(group=self.dp_process_group) / float(self.sequence_parallel_size))
 
         if rank is None:
             # "All Reducing"
diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py
index 32fe74b02a58..c3b4160ebf31 100755
--- a/deepspeed/runtime/zero/stage_1_and_2.py
+++ b/deepspeed/runtime/zero/stage_1_and_2.py
@@ -31,6 +31,7 @@
 from deepspeed.utils import link_hp_params
 from deepspeed.checkpoint import enable_universal_checkpoint
 
+from deepspeed.utils import groups
 # Toggle this to true to enable correctness test
 # with gradient partitioning and without
 pg_correctness_test = False
@@ -182,7 +183,7 @@ def __init__(self,
         self.device = get_accelerator().current_device_name() if not self.cpu_offload else 'cpu'
 
         self.dp_process_group = dp_process_group
-
+        self.sequence_parallel_size = groups._get_sequence_parallel_world_size()
         #expert parallel group
         self.ep_process_group = expert_parallel_group
@@ -941,9 +942,10 @@ def gradient_reduction_w_predivide(self, tensor):
             dist.all_reduce(tensor_to_allreduce, group=self.dp_process_group)
 
             if self.gradient_predivide_factor != dp_world_size:
-                tensor_to_allreduce.mul_(self.gradient_predivide_factor / dp_world_size)
+                tensor_to_allreduce.mul_(self.gradient_predivide_factor /
+                                         (dp_world_size / float(self.sequence_parallel_size)))
         else:
-            tensor_to_allreduce.div_(dp_world_size)
+            tensor_to_allreduce.div_(dp_world_size / float(self.sequence_parallel_size))
             dist.all_reduce(tensor_to_allreduce, group=self.dp_process_group)
 
         if self.communication_data_type != tensor.dtype and tensor is not tensor_to_allreduce:
@@ -985,7 +987,7 @@ def average_tensor(self, tensor):
                 if self.ipg_bucket_has_moe_params:
                     process_group = self.expert_dp_process_group[param.group_name] if is_moe_param(
                         param) else self.dp_process_group
-                    grad_reduc.data.div_(dist.get_world_size(group=process_group))
+                    grad_reduc.data.div_(dist.get_world_size(group=process_group) / float(self.sequence_parallel_size))
 
                 partition_ids = self.param_to_partition_ids[i][param_id]
                 assert all([p_id < dist.get_world_size(group=process_group) for p_id in partition_ids
@@ -1025,7 +1027,7 @@ def average_tensor(self, tensor):
                     prev_id, prev_process_group = partition_id, process_group
 
             if not self.ipg_bucket_has_moe_params:
-                tensor.div_(dist.get_world_size(group=self.dp_process_group))
+                tensor.div_(dist.get_world_size(group=self.dp_process_group) / float(self.sequence_parallel_size))
 
             tensor_to_reduce = tensor
             if self.communication_data_type != tensor.dtype:
@@ -1395,7 +1397,7 @@ def allreduce_bucket(self, bucket, rank=None, log=None):
 
         tensor_to_allreduce = tensor
 
-        if pg_correctness_test:
+        if pg_correctness_test or self.sequence_parallel_size > 1:
             communication_data_type = torch.float32
         else:
             communication_data_type = self.communication_data_type
@@ -1403,7 +1405,7 @@ def allreduce_bucket(self, bucket, rank=None, log=None):
         if communication_data_type != tensor.dtype:
            tensor_to_allreduce = tensor.to(communication_data_type)
 
-        tensor_to_allreduce.div_(dist.get_world_size(group=self.dp_process_group))
+        tensor_to_allreduce.div_(dist.get_world_size(group=self.dp_process_group) / float(self.sequence_parallel_size))
 
         if rank is None:
             # "All Reducing"
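
Reviewer note: below is a minimal, hypothetical usage sketch of how the new option is consumed from a user config, plus a comment spelling out the rescaling arithmetic this diff implements. Only the "seq_parallel_comm_data_type" key and its "fp32" default are taken from the constants added above; the toy model, ZeRO stage, bf16 setting, and launcher environment are illustrative assumptions, not part of this patch.

# Usage sketch (assumptions: toy model, ZeRO stage 2, bf16, script started with the
# deepspeed launcher so distributed init succeeds). The "seq_parallel_comm_data_type"
# key matches SEQ_PARALLEL_COMMUNICATION_DATA_TYPE in deepspeed/runtime/constants.py;
# supported values are "fp16", "bf16", and "fp32".
import torch
import deepspeed

ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "zero_optimization": {"stage": 2},
    "bf16": {"enabled": True},
    # Gradient communication dtype applied only when sequence parallelism is active
    # (sequence_parallel_size > 1); otherwise "communication_data_type" is used.
    "seq_parallel_comm_data_type": "fp32",
}

model = torch.nn.Linear(8, 8)
engine, _, _, _ = deepspeed.initialize(model=model,
                                       model_parameters=model.parameters(),
                                       config=ds_config)

# Why the divisors change from world_size to world_size / sequence_parallel_size:
# gradients are all-reduced over the combined sequence+data parallel group, and the
# sequence-parallel ranks hold partial gradients (different tokens) that must be
# summed, while only the data-parallel replicas should be averaged.
# Example: dp=4, sp=2 -> reduction group size 8, divisor 8 / 2 = 4.

The fp32 default mirrors SEQ_PARALLEL_COMMUNICATION_DATA_TYPE_DEFAULT; lower-precision settings reduce communication volume at some accuracy risk when the partial gradients are summed.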