Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support bf16_optimizer moe expert parallel training and moe EP grad_scale/grad_norm fix #5259

Merged
merged 20 commits into from
Mar 27, 2024
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions deepspeed/moe/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,3 +146,7 @@ def split_params_into_different_moe_groups_for_optimizer(
param_groups.append(param_group)

return param_groups


def is_moe_param_group(param_group):
return param_group.get('moe', False)
112 changes: 86 additions & 26 deletions deepspeed/runtime/bf16_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,13 @@
from deepspeed.runtime.constants import PIPE_REPLICATED
from deepspeed.runtime import ZeROOptimizer
from packaging import version as pkg_version

from deepspeed.git_version_info import version
from deepspeed.runtime.utils import (get_global_norm_of_tensors, clip_tensors_by_global_norm, DummyOptim,
align_dense_tensors, all_gather_dp_groups, bwc_tensor_model_parallel_rank,
is_model_parallel_parameter, see_memory_usage, graph_process)

from deepspeed.utils import link_hp_params, lazy_init_hp_params_optimizer_state, fragment_address, map_to_flat_opt_states
is_model_parallel_parameter, see_memory_usage, graph_process,
get_norm_with_moe_layers)
from deepspeed.moe.utils import is_moe_param, is_moe_param_group
from deepspeed.utils import link_hp_params, lazy_init_hp_params_optimizer_state, fragment_address, groups, map_to_flat_opt_states
from deepspeed.checkpoint import enable_universal_checkpoint
from deepspeed.checkpoint.constants import (DS_VERSION, PARTITION_COUNT, BASE_OPTIMIZER_STATE,
SINGLE_PARTITION_OF_FP32_GROUPS, CLIP_GRAD, GROUP_PADDINGS,
Expand All @@ -40,7 +40,8 @@ def __init__(self,
timers=None,
grad_acc_dtype=None,
graph_harvesting=False,
immediate_grad_update=False):
immediate_grad_update=False,
has_moe_layers=False):
super().__init__()
see_memory_usage('begin bf16_optimizer', force=True)
self.timers = timers
Expand All @@ -59,7 +60,11 @@ def __init__(self,
self.allgather_bucket_size = int(allgather_bucket_size)
self.dp_process_group = dp_process_group
self.dp_rank = dist.get_rank(group=self.dp_process_group)
self.has_moe_layers = has_moe_layers
self.non_expert_gradients = []
self.real_dp_process_group = [dp_process_group for i in range(len(self.optimizer.param_groups))]
if self.has_moe_layers:
self._configure_moe_settings()

# Use torch (un)flatten ops
self.flatten = _flatten_dense_tensors
Expand Down Expand Up @@ -90,11 +95,26 @@ def __init__(self,

see_memory_usage('end bf16_optimizer', force=True)

def _configure_moe_settings(self):
assert any(
[is_moe_param_group(group) for group in self.optimizer.param_groups]
), "The model has moe layers, but None of the param groups are marked as MoE. Create a param group with 'moe' key set to True before creating optimizer"

for i, group in enumerate(self.optimizer.param_groups):
if is_moe_param_group(group):
assert all([is_moe_param(param)
for param in group['params']]), "All params in MoE group must be MoE params"
self.real_dp_process_group[i] = groups._get_expert_data_parallel_group(group['name'])
self.expert_gradients = {}
if self.has_moe_layers:
for key in groups._get_expert_data_parallel_group_dict().keys():
self.expert_gradients[key] = []

def _setup_for_real_optimizer(self):
dp_world_size = dist.get_world_size(group=self.dp_process_group)
self.partition_count = [dp_world_size for i in range(len(self.optimizer.param_groups))]
self.partition_count = [dist.get_world_size(group=pg) for pg in self.real_dp_process_group]

for i, param_group in enumerate(self.optimizer.param_groups):
real_dp_world_size = dist.get_world_size(group=self.real_dp_process_group[i])
see_memory_usage(f'before initializing group {i}', force=True)

partition_id = dist.get_rank(group=self.real_dp_process_group[i])
Expand All @@ -106,17 +126,16 @@ def _setup_for_real_optimizer(self):
# create flat bf16 params
self.bf16_groups_flat.append(
self._flatten_dense_tensors_aligned(self.bf16_groups[i],
self.nccl_start_alignment_factor * dp_world_size))

self.nccl_start_alignment_factor * real_dp_world_size))
# Make bf16 params point to flat tensor storage
self._update_storage_to_flattened_tensor(tensor_list=self.bf16_groups[i],
flat_tensor=self.bf16_groups_flat[i])

# divide flat weights into equal sized partitions
partition_size = self.bf16_groups_flat[i].numel() // dp_world_size
partition_size = self.bf16_groups_flat[i].numel() // real_dp_world_size
bf16_dp_partitions = [
self.bf16_groups_flat[i].narrow(0, dp_index * partition_size, partition_size)
for dp_index in range(dp_world_size)
for dp_index in range(real_dp_world_size)
]
self.bf16_partitioned_groups.append(bf16_dp_partitions)

Expand All @@ -127,8 +146,12 @@ def _setup_for_real_optimizer(self):
num_elem_list = [t.numel() for t in self.bf16_groups[i]]

# create fp32 gradients
self.fp32_groups_gradients_flat.append(
torch.zeros_like(self.bf16_groups_flat[i], dtype=self.grad_acc_dtype))
fp32_flat_buffer = torch.zeros_like(self.bf16_groups_flat[i], dtype=self.grad_acc_dtype)
self.fp32_groups_gradients_flat.append(fp32_flat_buffer)
if self.has_moe_layers and is_moe_param_group(param_group):
self.expert_gradients[param_group['name']].append(fp32_flat_buffer)
else:
self.non_expert_gradients.append(fp32_flat_buffer)

# track individual fp32 gradients for entire model
fp32_gradients = self._split_flat_tensor(flat_tensor=self.fp32_groups_gradients_flat[i],
Expand Down Expand Up @@ -191,11 +214,12 @@ def _create_param_mapping(self):
return param_mapping

def _link_all_hp_params(self):
dp_world_size = dist.get_world_size(group=self.dp_process_group)
for i, _ in enumerate(self.optimizer.param_groups):
real_dp_world_size = dist.get_world_size(group=self.real_dp_process_group[i])

# Link bf16 and fp32 params in partition
partition_id = dist.get_rank(group=self.real_dp_process_group[i])
partition_size = self.bf16_groups_flat[i].numel() // dp_world_size
partition_size = self.bf16_groups_flat[i].numel() // real_dp_world_size
flat_hp_partition = self.fp32_groups_flat_partition[i]
link_hp_params(lp_param_list=self.bf16_groups[i],
flat_hp_partition=flat_hp_partition,
Expand Down Expand Up @@ -257,10 +281,18 @@ def step(self, closure=None):
if closure is not None:
raise NotImplementedError(f'{self.__class__} does not support closure.')

all_groups_norm = get_global_norm_of_tensors(input_tensors=self.get_grads_for_norm(),
mpu=self.mpu,
norm_type=self.norm_type,
use_graph=self.graph_harvesting)
non_expert_grads_for_norm, expert_grads_for_norm = self.get_grads_for_norm()
non_expert_groups_norm = get_global_norm_of_tensors(input_tensors=non_expert_grads_for_norm,
mpu=self.mpu,
norm_type=self.norm_type,
use_graph=self.graph_harvesting)
all_groups_norm = non_expert_groups_norm
if self.has_moe_layers:
all_groups_norm = get_norm_with_moe_layers(non_expert_groups_norm,
mpu=self.mpu,
expert_tensors=expert_grads_for_norm,
norm_type=self.norm_type)

self._global_grad_norm = all_groups_norm

assert all_groups_norm > 0.
Expand Down Expand Up @@ -336,27 +368,55 @@ def update_hp_grads(self, clear_lp_grads=False):

@torch.no_grad()
def get_grads_for_reduction(self):
return self.fp32_groups_gradients_flat
if self.has_moe_layers:
return self.non_expert_gradients, self.expert_gradients
return self.non_expert_gradients, {}

@torch.no_grad()
def get_grads_for_norm(self, for_clipping=False):
grads = []
"""
Returns:
tuple[list[Tensor], dict[ep_name, List[Tensor]] | list:
If for_clipping, return all gradients.
Otherwise, separate and return dict of expert_grad and list of non_expert_grad
"""
# (grads, expert_group_name)
expert_grads_for_norm = {}

# grads
non_expert_grads_for_norm = []
all_grads_for_clip = []

tensor_mp_rank = bwc_tensor_model_parallel_rank(mpu=self.mpu)
assert len(self.bf16_groups) == len(self.optimizer.param_groups)
for i, group in enumerate(self.bf16_groups):
for j, lp in enumerate(group):
if not for_clipping:
if hasattr(lp, PIPE_REPLICATED) and lp.ds_pipe_replicated:
continue

if not (tensor_mp_rank == 0 or is_model_parallel_parameter(lp)):
# skip duplicated parameters. perform norm only on cards with tp_rank=0.
# non-duplicated parameters include:
# - Parameters with tp: Use allreducesum of mp_group.
# - Moe Parameters with ep: Use allreducesum of ep_group.
if not (tensor_mp_rank == 0 or is_model_parallel_parameter(lp) or is_moe_param(lp)):
continue

if not self.fp32_groups_has_gradients[i][j]:
continue

grads.append(self.fp32_groups_gradients[i][j])

return grads
if not for_clipping:
param_group = self.optimizer.param_groups[i]
if self.has_moe_layers and is_moe_param_group(param_group):
if param_group['name'] not in expert_grads_for_norm:
expert_grads_for_norm[param_group['name']] = []
expert_grads_for_norm[param_group['name']].append(self.fp32_groups_gradients[i][j])
else:
non_expert_grads_for_norm.append(self.fp32_groups_gradients[i][j])
else:
all_grads_for_clip.append(self.fp32_groups_gradients[i][j])
if not for_clipping:
return non_expert_grads_for_norm, expert_grads_for_norm
return all_grads_for_clip

@torch.no_grad()
def update_lp_params(self):
Expand Down
Loading
Loading