Skip to content

Commit

Permalink
Fix the sequence-parallelism for the dense model architecture (#4530)
Browse files Browse the repository at this point in the history
Co-authored-by: Masahiro Tanaka <[email protected]>
Co-authored-by: Olatunji Ruwase <[email protected]>
Co-authored-by: Sam Ade Jacobs <[email protected]>
Co-authored-by: Michael Wyatt <[email protected]>
  • Loading branch information
5 people authored Oct 25, 2023
1 parent f15cccf commit ec029e7
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 14 deletions.
8 changes: 6 additions & 2 deletions deepspeed/runtime/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,8 +238,10 @@ def get_sparse_gradients_enabled(param_dict):
return get_scalar_param(param_dict, SPARSE_GRADIENTS, SPARSE_GRADIENTS_DEFAULT)


def get_communication_data_type(param_dict):
val = get_scalar_param(param_dict, COMMUNICATION_DATA_TYPE, COMMUNICATION_DATA_TYPE_DEFAULT)
def get_communication_data_type(param_dict,
comm_type=COMMUNICATION_DATA_TYPE,
comm_data_type_default=COMMUNICATION_DATA_TYPE_DEFAULT):
val = get_scalar_param(param_dict, comm_type, comm_data_type_default)
val = val.lower() if val is not None else val
if val is None:
return val # we must determine it by other parameters
Expand Down Expand Up @@ -784,6 +786,8 @@ def _initialize_params(self, param_dict):

self.disable_allgather = get_disable_allgather(param_dict)
self.communication_data_type = get_communication_data_type(param_dict)
self.seq_parallel_communication_data_type = get_communication_data_type(
param_dict, SEQ_PARALLEL_COMMUNICATION_DATA_TYPE, SEQ_PARALLEL_COMMUNICATION_DATA_TYPE_DEFAULT)
self.prescale_gradients = get_prescale_gradients(param_dict)
self.gradient_predivide_factor = get_gradient_predivide_factor(param_dict)
self.sparse_gradients_enabled = get_sparse_gradients_enabled(param_dict)
Expand Down
13 changes: 13 additions & 0 deletions deepspeed/runtime/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,19 @@
COMMUNICATION_DATA_TYPE = "communication_data_type"
COMMUNICATION_DATA_TYPE_DEFAULT = None

###########################################################
# Gradient communication data type for sequence parallelism
###########################################################
# Supported types: ['fp16', 'bf16','fp32']
# Default value is fp32
# Users can configure in ds_config.json as below example:
SEQ_PARALLEL_COMMUNICATION_DATA_TYPE_FORMAT = '''
Optional comm data type for seq paralleism should be set as:
"seq_parallel_communication_data_type": "fp32"
'''
SEQ_PARALLEL_COMMUNICATION_DATA_TYPE = "seq_parallel_comm_data_type"
SEQ_PARALLEL_COMMUNICATION_DATA_TYPE_DEFAULT = "fp32"

#########################################
# Scale/predivide gradients before allreduce
#########################################
Expand Down
14 changes: 11 additions & 3 deletions deepspeed/runtime/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -808,6 +808,10 @@ def communication_data_type(self):

return torch.float32

@communication_data_type.setter
def communication_data_type(self, value):
self._config.communication_data_type = value

def postscale_gradients(self):
return not self._config.prescale_gradients

Expand Down Expand Up @@ -1114,6 +1118,9 @@ def _configure_distributed_model(self, model):
self.mp_world_size = groups._get_model_parallel_world_size()
self.expert_parallel_group = groups._get_expert_parallel_group_dict()
self.expert_data_parallel_group = groups._get_expert_data_parallel_group_dict()
self.sequence_parallel_size = groups._get_sequence_parallel_world_size()
if self.sequence_parallel_size > 1:
self.communication_data_type = self._config.seq_parallel_communication_data_type

if not (self.amp_enabled() or is_zero_init_model):
self._broadcast_model()
Expand Down Expand Up @@ -2370,7 +2377,7 @@ def _reduce_non_expert_gradients(self, grads, elements_per_buffer):
if self.pipeline_parallelism:
dp_group = self.mpu.get_data_parallel_group()
else:
dp_group = groups._get_data_parallel_group()
dp_group = groups._get_sequence_data_parallel_group()

if bucket_type == SparseTensor.type():
self.sparse_allreduce_no_retain(bucket, dp_group=dp_group)
Expand Down Expand Up @@ -2431,9 +2438,10 @@ def sparse_allreduce(self, sparse, dp_group):

if self.postscale_gradients():
if self.gradient_average:
values.mul_(self.gradient_predivide_factor() / dist.get_world_size(group=dp_group))
values.mul_(self.gradient_predivide_factor() /
(dist.get_world_size(group=dp_group) / float(self.sequence_parallel_size)))
else:
values.mul_(1. / dist.get_world_size(group=dp_group))
values.mul_(1. / (dist.get_world_size(group=dp_group) / float(self.sequence_parallel_size)))

indices_device_list = self.sparse_all_gather(indices, dp_group)
values_device_list = self.sparse_all_gather(values, dp_group)
Expand Down
5 changes: 3 additions & 2 deletions deepspeed/runtime/zero/stage3.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,7 @@ def __init__(
self.reduce_scatter = reduce_scatter

self.dp_process_group = self.parameter_offload.dp_process_group
self.sequence_parallel_size = groups._get_sequence_parallel_world_size()

self.all2all_process_group = all2all_process_group

Expand Down Expand Up @@ -1177,7 +1178,7 @@ def __avg_scatter_contiguous_grads(self, buffer_to_reduce: Tensor) -> List[Tenso

world_sz = dist.get_world_size(self.dp_process_group)
rank = dist.get_rank(self.dp_process_group)
buffer_to_reduce.div_(world_sz)
buffer_to_reduce.div_(world_sz / float(self.sequence_parallel_size))

dist.all_reduce(buffer_to_reduce, group=self.dp_process_group)

Expand Down Expand Up @@ -1476,7 +1477,7 @@ def allreduce_bucket(self, bucket, rank=None, log=None):
if communication_data_type != tensor.dtype:
tensor_to_allreduce = tensor.to(communication_data_type)

tensor_to_allreduce.div_(dist.get_world_size(group=self.dp_process_group))
tensor_to_allreduce.div_(dist.get_world_size(group=self.dp_process_group) / float(self.sequence_parallel_size))

if rank is None:
# "All Reducing"
Expand Down
16 changes: 9 additions & 7 deletions deepspeed/runtime/zero/stage_1_and_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from deepspeed.utils import link_hp_params
from deepspeed.checkpoint import enable_universal_checkpoint

from deepspeed.utils import groups
# Toggle this to true to enable correctness test
# with gradient partitioning and without
pg_correctness_test = False
Expand Down Expand Up @@ -182,7 +183,7 @@ def __init__(self,
self.device = get_accelerator().current_device_name() if not self.cpu_offload else 'cpu'

self.dp_process_group = dp_process_group

self.sequence_parallel_size = groups._get_sequence_parallel_world_size()
#expert parallel group
self.ep_process_group = expert_parallel_group

Expand Down Expand Up @@ -941,9 +942,10 @@ def gradient_reduction_w_predivide(self, tensor):
dist.all_reduce(tensor_to_allreduce, group=self.dp_process_group)

if self.gradient_predivide_factor != dp_world_size:
tensor_to_allreduce.mul_(self.gradient_predivide_factor / dp_world_size)
tensor_to_allreduce.mul_(self.gradient_predivide_factor /
(dp_world_size / float(self.sequence_parallel_size)))
else:
tensor_to_allreduce.div_(dp_world_size)
tensor_to_allreduce.div_(dp_world_size / float(self.sequence_parallel_size))
dist.all_reduce(tensor_to_allreduce, group=self.dp_process_group)

if self.communication_data_type != tensor.dtype and tensor is not tensor_to_allreduce:
Expand Down Expand Up @@ -985,7 +987,7 @@ def average_tensor(self, tensor):
if self.ipg_bucket_has_moe_params:
process_group = self.expert_dp_process_group[param.group_name] if is_moe_param(
param) else self.dp_process_group
grad_reduc.data.div_(dist.get_world_size(group=process_group))
grad_reduc.data.div_(dist.get_world_size(group=process_group) / float(self.sequence_parallel_size))

partition_ids = self.param_to_partition_ids[i][param_id]
assert all([p_id < dist.get_world_size(group=process_group) for p_id in partition_ids
Expand Down Expand Up @@ -1025,7 +1027,7 @@ def average_tensor(self, tensor):
prev_id, prev_process_group = partition_id, process_group

if not self.ipg_bucket_has_moe_params:
tensor.div_(dist.get_world_size(group=self.dp_process_group))
tensor.div_(dist.get_world_size(group=self.dp_process_group) / float(self.sequence_parallel_size))

tensor_to_reduce = tensor
if self.communication_data_type != tensor.dtype:
Expand Down Expand Up @@ -1395,15 +1397,15 @@ def allreduce_bucket(self, bucket, rank=None, log=None):

tensor_to_allreduce = tensor

if pg_correctness_test:
if pg_correctness_test or self.sequence_parallel_size > 1:
communication_data_type = torch.float32
else:
communication_data_type = self.communication_data_type

if communication_data_type != tensor.dtype:
tensor_to_allreduce = tensor.to(communication_data_type)

tensor_to_allreduce.div_(dist.get_world_size(group=self.dp_process_group))
tensor_to_allreduce.div_(dist.get_world_size(group=self.dp_process_group) / float(self.sequence_parallel_size))

if rank is None:
# "All Reducing"
Expand Down

0 comments on commit ec029e7

Please sign in to comment.