Skip to content

Commit

Permalink
support bf16_optimizer moe expert parallel training and moe EP grad_s…
Browse files Browse the repository at this point in the history
…cale/grad_norm fix (#5259)

- bf16 moe EP requires different partitions and this will impact dp
gradient allreduce, zero1 params allgather, as well as gradient_norm
allreduce. Currently, the bf16_optimizer does not correctly partition
the group. fix and support bf16 type training.
- fix calculation of moe ep grad scale and grad_norm for bf16&fp16

---------

Co-authored-by: Olatunji Ruwase <[email protected]>
  • Loading branch information
inkcherry and tjruwase authored Mar 27, 2024
1 parent 19670b4 commit 31cdc51
Show file tree
Hide file tree
Showing 6 changed files with 208 additions and 73 deletions.
4 changes: 4 additions & 0 deletions deepspeed/moe/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,3 +146,7 @@ def split_params_into_different_moe_groups_for_optimizer(
param_groups.append(param_group)

return param_groups


def is_moe_param_group(param_group):
return param_group.get('moe', False)
112 changes: 86 additions & 26 deletions deepspeed/runtime/bf16_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,13 @@
from deepspeed.runtime.constants import PIPE_REPLICATED
from deepspeed.runtime import ZeROOptimizer
from packaging import version as pkg_version

from deepspeed.git_version_info import version
from deepspeed.runtime.utils import (get_global_norm_of_tensors, clip_tensors_by_global_norm, DummyOptim,
align_dense_tensors, all_gather_dp_groups, bwc_tensor_model_parallel_rank,
is_model_parallel_parameter, see_memory_usage, graph_process)

from deepspeed.utils import link_hp_params, lazy_init_hp_params_optimizer_state, fragment_address, map_to_flat_opt_states
is_model_parallel_parameter, see_memory_usage, graph_process,
get_norm_with_moe_layers)
from deepspeed.moe.utils import is_moe_param, is_moe_param_group
from deepspeed.utils import link_hp_params, lazy_init_hp_params_optimizer_state, fragment_address, groups, map_to_flat_opt_states
from deepspeed.checkpoint import enable_universal_checkpoint
from deepspeed.checkpoint.constants import (DS_VERSION, PARTITION_COUNT, BASE_OPTIMIZER_STATE,
SINGLE_PARTITION_OF_FP32_GROUPS, CLIP_GRAD, GROUP_PADDINGS,
Expand All @@ -40,7 +40,8 @@ def __init__(self,
timers=None,
grad_acc_dtype=None,
graph_harvesting=False,
immediate_grad_update=False):
immediate_grad_update=False,
has_moe_layers=False):
super().__init__()
see_memory_usage('begin bf16_optimizer', force=True)
self.timers = timers
Expand All @@ -59,7 +60,11 @@ def __init__(self,
self.allgather_bucket_size = int(allgather_bucket_size)
self.dp_process_group = dp_process_group
self.dp_rank = dist.get_rank(group=self.dp_process_group)
self.has_moe_layers = has_moe_layers
self.non_expert_gradients = []
self.real_dp_process_group = [dp_process_group for i in range(len(self.optimizer.param_groups))]
if self.has_moe_layers:
self._configure_moe_settings()

# Use torch (un)flatten ops
self.flatten = _flatten_dense_tensors
Expand Down Expand Up @@ -90,11 +95,26 @@ def __init__(self,

see_memory_usage('end bf16_optimizer', force=True)

def _configure_moe_settings(self):
assert any(
[is_moe_param_group(group) for group in self.optimizer.param_groups]
), "The model has moe layers, but None of the param groups are marked as MoE. Create a param group with 'moe' key set to True before creating optimizer"

for i, group in enumerate(self.optimizer.param_groups):
if is_moe_param_group(group):
assert all([is_moe_param(param)
for param in group['params']]), "All params in MoE group must be MoE params"
self.real_dp_process_group[i] = groups._get_expert_data_parallel_group(group['name'])
self.expert_gradients = {}
if self.has_moe_layers:
for key in groups._get_expert_data_parallel_group_dict().keys():
self.expert_gradients[key] = []

def _setup_for_real_optimizer(self):
dp_world_size = dist.get_world_size(group=self.dp_process_group)
self.partition_count = [dp_world_size for i in range(len(self.optimizer.param_groups))]
self.partition_count = [dist.get_world_size(group=pg) for pg in self.real_dp_process_group]

for i, param_group in enumerate(self.optimizer.param_groups):
real_dp_world_size = dist.get_world_size(group=self.real_dp_process_group[i])
see_memory_usage(f'before initializing group {i}', force=True)

partition_id = dist.get_rank(group=self.real_dp_process_group[i])
Expand All @@ -106,17 +126,16 @@ def _setup_for_real_optimizer(self):
# create flat bf16 params
self.bf16_groups_flat.append(
self._flatten_dense_tensors_aligned(self.bf16_groups[i],
self.nccl_start_alignment_factor * dp_world_size))

self.nccl_start_alignment_factor * real_dp_world_size))
# Make bf16 params point to flat tensor storage
self._update_storage_to_flattened_tensor(tensor_list=self.bf16_groups[i],
flat_tensor=self.bf16_groups_flat[i])

# divide flat weights into equal sized partitions
partition_size = self.bf16_groups_flat[i].numel() // dp_world_size
partition_size = self.bf16_groups_flat[i].numel() // real_dp_world_size
bf16_dp_partitions = [
self.bf16_groups_flat[i].narrow(0, dp_index * partition_size, partition_size)
for dp_index in range(dp_world_size)
for dp_index in range(real_dp_world_size)
]
self.bf16_partitioned_groups.append(bf16_dp_partitions)

Expand All @@ -127,8 +146,12 @@ def _setup_for_real_optimizer(self):
num_elem_list = [t.numel() for t in self.bf16_groups[i]]

# create fp32 gradients
self.fp32_groups_gradients_flat.append(
torch.zeros_like(self.bf16_groups_flat[i], dtype=self.grad_acc_dtype))
fp32_flat_buffer = torch.zeros_like(self.bf16_groups_flat[i], dtype=self.grad_acc_dtype)
self.fp32_groups_gradients_flat.append(fp32_flat_buffer)
if self.has_moe_layers and is_moe_param_group(param_group):
self.expert_gradients[param_group['name']].append(fp32_flat_buffer)
else:
self.non_expert_gradients.append(fp32_flat_buffer)

# track individual fp32 gradients for entire model
fp32_gradients = self._split_flat_tensor(flat_tensor=self.fp32_groups_gradients_flat[i],
Expand Down Expand Up @@ -191,11 +214,12 @@ def _create_param_mapping(self):
return param_mapping

def _link_all_hp_params(self):
dp_world_size = dist.get_world_size(group=self.dp_process_group)
for i, _ in enumerate(self.optimizer.param_groups):
real_dp_world_size = dist.get_world_size(group=self.real_dp_process_group[i])

# Link bf16 and fp32 params in partition
partition_id = dist.get_rank(group=self.real_dp_process_group[i])
partition_size = self.bf16_groups_flat[i].numel() // dp_world_size
partition_size = self.bf16_groups_flat[i].numel() // real_dp_world_size
flat_hp_partition = self.fp32_groups_flat_partition[i]
link_hp_params(lp_param_list=self.bf16_groups[i],
flat_hp_partition=flat_hp_partition,
Expand Down Expand Up @@ -257,10 +281,18 @@ def step(self, closure=None):
if closure is not None:
raise NotImplementedError(f'{self.__class__} does not support closure.')

all_groups_norm = get_global_norm_of_tensors(input_tensors=self.get_grads_for_norm(),
mpu=self.mpu,
norm_type=self.norm_type,
use_graph=self.graph_harvesting)
non_expert_grads_for_norm, expert_grads_for_norm = self.get_grads_for_norm()
non_expert_groups_norm = get_global_norm_of_tensors(input_tensors=non_expert_grads_for_norm,
mpu=self.mpu,
norm_type=self.norm_type,
use_graph=self.graph_harvesting)
all_groups_norm = non_expert_groups_norm
if self.has_moe_layers:
all_groups_norm = get_norm_with_moe_layers(non_expert_groups_norm,
mpu=self.mpu,
expert_tensors=expert_grads_for_norm,
norm_type=self.norm_type)

self._global_grad_norm = all_groups_norm

assert all_groups_norm > 0.
Expand Down Expand Up @@ -336,27 +368,55 @@ def update_hp_grads(self, clear_lp_grads=False):

@torch.no_grad()
def get_grads_for_reduction(self):
return self.fp32_groups_gradients_flat
if self.has_moe_layers:
return self.non_expert_gradients, self.expert_gradients
return self.non_expert_gradients, {}

@torch.no_grad()
def get_grads_for_norm(self, for_clipping=False):
grads = []
"""
Returns:
tuple[list[Tensor], dict[ep_name, List[Tensor]] | list:
If for_clipping, return all gradients.
Otherwise, separate and return dict of expert_grad and list of non_expert_grad
"""
# (grads, expert_group_name)
expert_grads_for_norm = {}

# grads
non_expert_grads_for_norm = []
all_grads_for_clip = []

tensor_mp_rank = bwc_tensor_model_parallel_rank(mpu=self.mpu)
assert len(self.bf16_groups) == len(self.optimizer.param_groups)
for i, group in enumerate(self.bf16_groups):
for j, lp in enumerate(group):
if not for_clipping:
if hasattr(lp, PIPE_REPLICATED) and lp.ds_pipe_replicated:
continue

if not (tensor_mp_rank == 0 or is_model_parallel_parameter(lp)):
# skip duplicated parameters. perform norm only on cards with tp_rank=0.
# non-duplicated parameters include:
# - Parameters with tp: Use allreducesum of mp_group.
# - Moe Parameters with ep: Use allreducesum of ep_group.
if not (tensor_mp_rank == 0 or is_model_parallel_parameter(lp) or is_moe_param(lp)):
continue

if not self.fp32_groups_has_gradients[i][j]:
continue

grads.append(self.fp32_groups_gradients[i][j])

return grads
if not for_clipping:
param_group = self.optimizer.param_groups[i]
if self.has_moe_layers and is_moe_param_group(param_group):
if param_group['name'] not in expert_grads_for_norm:
expert_grads_for_norm[param_group['name']] = []
expert_grads_for_norm[param_group['name']].append(self.fp32_groups_gradients[i][j])
else:
non_expert_grads_for_norm.append(self.fp32_groups_gradients[i][j])
else:
all_grads_for_clip.append(self.fp32_groups_gradients[i][j])
if not for_clipping:
return non_expert_grads_for_norm, expert_grads_for_norm
return all_grads_for_clip

@torch.no_grad()
def update_lp_params(self):
Expand Down
Loading

0 comments on commit 31cdc51

Please sign in to comment.