Skip to content

Commit

Permalink
Merge branch 'master' into master
Browse files Browse the repository at this point in the history
  • Loading branch information
loadams authored Mar 27, 2024
2 parents 67260b4 + 8d98e17 commit 50b46fe
Show file tree
Hide file tree
Showing 7 changed files with 215 additions and 74 deletions.
8 changes: 7 additions & 1 deletion deepspeed/module_inject/auto_tp.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def is_load_module(module):
load_layers = [nn.Linear, nn.Embedding, nn.LayerNorm]
load_layer_names = [
"LPLayerNorm", "SharedEmbedding", "OPTLearnedPositionalEmbedding", "LlamaRMSNorm", "FalconLinear",
"MistralRMSNorm", "T5LayerNorm"
"MistralRMSNorm", "T5LayerNorm", "MixtralRMSNorm"
]
return module.__class__ in load_layers or module._get_name() in load_layer_names

Expand Down Expand Up @@ -303,6 +303,9 @@ def tp_parser(model):
elif 'self_attention.dense' in layer and 'falcon' in str(
type(module)): # this is a hack to get the right linear layer for this model!
gem_list = gem_list + [layer]
# Mixtral-8x7B experts compute w2(act(w1(x)) * w3(x)); w2 is the output projection, so it must be replaced with LinearAllreduce.
elif 'w2' in layer and 'Mixtral' in str(type(module)):
gem_list = gem_list + [layer]

layer_list = []
if gem_list != []:
Expand All @@ -322,6 +325,9 @@ def _replace(self, child, name, conv_linear_layer):
return
weight_shape = child.weight.shape
mp_replace = ReplaceWithTensorSlicing(mp_group=self.mp_group)
# For Mixtral-8x7B, skip replacing the MoE gate linear layer.
if name == "block_sparse_moe.gate":
return child
if name in self.all_reduce_linears:
# if conv_linear_layer [weight_shape[1], weight_shape[0] // mp_size]
# else [weight_shape[0], weight_shape[1] // mp_size]
Expand Down
4 changes: 4 additions & 0 deletions deepspeed/moe/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,3 +146,7 @@ def split_params_into_different_moe_groups_for_optimizer(
param_groups.append(param_group)

return param_groups


def is_moe_param_group(param_group):
    """Report whether an optimizer param group is flagged as a MoE group.

    Returns the value stored under the group's 'moe' key, or False when the
    key is absent.
    """
    moe_flag = param_group.get('moe', False)
    return moe_flag
112 changes: 86 additions & 26 deletions deepspeed/runtime/bf16_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,13 @@
from deepspeed.runtime.constants import PIPE_REPLICATED
from deepspeed.runtime import ZeROOptimizer
from packaging import version as pkg_version

from deepspeed.git_version_info import version
from deepspeed.runtime.utils import (get_global_norm_of_tensors, clip_tensors_by_global_norm, DummyOptim,
align_dense_tensors, all_gather_dp_groups, bwc_tensor_model_parallel_rank,
is_model_parallel_parameter, see_memory_usage, graph_process)

from deepspeed.utils import link_hp_params, lazy_init_hp_params_optimizer_state, fragment_address, map_to_flat_opt_states
is_model_parallel_parameter, see_memory_usage, graph_process,
get_norm_with_moe_layers)
from deepspeed.moe.utils import is_moe_param, is_moe_param_group
from deepspeed.utils import link_hp_params, lazy_init_hp_params_optimizer_state, fragment_address, groups, map_to_flat_opt_states
from deepspeed.checkpoint import enable_universal_checkpoint
from deepspeed.checkpoint.constants import (DS_VERSION, PARTITION_COUNT, BASE_OPTIMIZER_STATE,
SINGLE_PARTITION_OF_FP32_GROUPS, CLIP_GRAD, GROUP_PADDINGS,
Expand All @@ -40,7 +40,8 @@ def __init__(self,
timers=None,
grad_acc_dtype=None,
graph_harvesting=False,
immediate_grad_update=False):
immediate_grad_update=False,
has_moe_layers=False):
super().__init__()
see_memory_usage('begin bf16_optimizer', force=True)
self.timers = timers
Expand All @@ -59,7 +60,11 @@ def __init__(self,
self.allgather_bucket_size = int(allgather_bucket_size)
self.dp_process_group = dp_process_group
self.dp_rank = dist.get_rank(group=self.dp_process_group)
self.has_moe_layers = has_moe_layers
self.non_expert_gradients = []
self.real_dp_process_group = [dp_process_group for i in range(len(self.optimizer.param_groups))]
if self.has_moe_layers:
self._configure_moe_settings()

# Use torch (un)flatten ops
self.flatten = _flatten_dense_tensors
Expand Down Expand Up @@ -90,11 +95,26 @@ def __init__(self,

see_memory_usage('end bf16_optimizer', force=True)

def _configure_moe_settings(self):
    """Point MoE param groups at their expert data-parallel process groups.

    Asserts that at least one optimizer param group is marked as MoE and that
    every group so marked contains only MoE params. For each MoE group, the
    matching slot in ``self.real_dp_process_group`` is replaced with the
    expert data-parallel group looked up by the group's 'name'. Finally,
    ``self.expert_gradients`` is initialised with one empty list per expert
    data-parallel group when MoE layers are enabled.
    """
    param_groups = self.optimizer.param_groups
    assert any(
        is_moe_param_group(pg) for pg in param_groups
    ), "The model has moe layers, but None of the param groups are marked as MoE. Create a param group with 'moe' key set to True before creating optimizer"

    for idx, pg in enumerate(param_groups):
        if not is_moe_param_group(pg):
            continue
        assert all(is_moe_param(p) for p in pg['params']), "All params in MoE group must be MoE params"
        self.real_dp_process_group[idx] = groups._get_expert_data_parallel_group(pg['name'])

    # One gradient bucket per expert data-parallel group.
    self.expert_gradients = {}
    if self.has_moe_layers:
        self.expert_gradients.update(
            {ep_name: [] for ep_name in groups._get_expert_data_parallel_group_dict().keys()})

def _setup_for_real_optimizer(self):
dp_world_size = dist.get_world_size(group=self.dp_process_group)
self.partition_count = [dp_world_size for i in range(len(self.optimizer.param_groups))]
self.partition_count = [dist.get_world_size(group=pg) for pg in self.real_dp_process_group]

for i, param_group in enumerate(self.optimizer.param_groups):
real_dp_world_size = dist.get_world_size(group=self.real_dp_process_group[i])
see_memory_usage(f'before initializing group {i}', force=True)

partition_id = dist.get_rank(group=self.real_dp_process_group[i])
Expand All @@ -106,17 +126,16 @@ def _setup_for_real_optimizer(self):
# create flat bf16 params
self.bf16_groups_flat.append(
self._flatten_dense_tensors_aligned(self.bf16_groups[i],
self.nccl_start_alignment_factor * dp_world_size))

self.nccl_start_alignment_factor * real_dp_world_size))
# Make bf16 params point to flat tensor storage
self._update_storage_to_flattened_tensor(tensor_list=self.bf16_groups[i],
flat_tensor=self.bf16_groups_flat[i])

# divide flat weights into equal sized partitions
partition_size = self.bf16_groups_flat[i].numel() // dp_world_size
partition_size = self.bf16_groups_flat[i].numel() // real_dp_world_size
bf16_dp_partitions = [
self.bf16_groups_flat[i].narrow(0, dp_index * partition_size, partition_size)
for dp_index in range(dp_world_size)
for dp_index in range(real_dp_world_size)
]
self.bf16_partitioned_groups.append(bf16_dp_partitions)

Expand All @@ -127,8 +146,12 @@ def _setup_for_real_optimizer(self):
num_elem_list = [t.numel() for t in self.bf16_groups[i]]

# create fp32 gradients
self.fp32_groups_gradients_flat.append(
torch.zeros_like(self.bf16_groups_flat[i], dtype=self.grad_acc_dtype))
fp32_flat_buffer = torch.zeros_like(self.bf16_groups_flat[i], dtype=self.grad_acc_dtype)
self.fp32_groups_gradients_flat.append(fp32_flat_buffer)
if self.has_moe_layers and is_moe_param_group(param_group):
self.expert_gradients[param_group['name']].append(fp32_flat_buffer)
else:
self.non_expert_gradients.append(fp32_flat_buffer)

# track individual fp32 gradients for entire model
fp32_gradients = self._split_flat_tensor(flat_tensor=self.fp32_groups_gradients_flat[i],
Expand Down Expand Up @@ -191,11 +214,12 @@ def _create_param_mapping(self):
return param_mapping

def _link_all_hp_params(self):
dp_world_size = dist.get_world_size(group=self.dp_process_group)
for i, _ in enumerate(self.optimizer.param_groups):
real_dp_world_size = dist.get_world_size(group=self.real_dp_process_group[i])

# Link bf16 and fp32 params in partition
partition_id = dist.get_rank(group=self.real_dp_process_group[i])
partition_size = self.bf16_groups_flat[i].numel() // dp_world_size
partition_size = self.bf16_groups_flat[i].numel() // real_dp_world_size
flat_hp_partition = self.fp32_groups_flat_partition[i]
link_hp_params(lp_param_list=self.bf16_groups[i],
flat_hp_partition=flat_hp_partition,
Expand Down Expand Up @@ -257,10 +281,18 @@ def step(self, closure=None):
if closure is not None:
raise NotImplementedError(f'{self.__class__} does not support closure.')

all_groups_norm = get_global_norm_of_tensors(input_tensors=self.get_grads_for_norm(),
mpu=self.mpu,
norm_type=self.norm_type,
use_graph=self.graph_harvesting)
non_expert_grads_for_norm, expert_grads_for_norm = self.get_grads_for_norm()
non_expert_groups_norm = get_global_norm_of_tensors(input_tensors=non_expert_grads_for_norm,
mpu=self.mpu,
norm_type=self.norm_type,
use_graph=self.graph_harvesting)
all_groups_norm = non_expert_groups_norm
if self.has_moe_layers:
all_groups_norm = get_norm_with_moe_layers(non_expert_groups_norm,
mpu=self.mpu,
expert_tensors=expert_grads_for_norm,
norm_type=self.norm_type)

self._global_grad_norm = all_groups_norm

assert all_groups_norm > 0.
Expand Down Expand Up @@ -336,27 +368,55 @@ def update_hp_grads(self, clear_lp_grads=False):

@torch.no_grad()
def get_grads_for_reduction(self):
    """Return the flat gradient buffers to reduce, split by expert membership.

    Returns:
        tuple: ``(non_expert_gradients, expert_gradients)``. The first item is
        ``self.non_expert_gradients``. The second is ``self.expert_gradients``
        when ``self.has_moe_layers`` is set, otherwise an empty dict.
    """
    # The pre-merge `return self.fp32_groups_gradients_flat` is intentionally
    # gone: it made the MoE-aware return below unreachable.
    # `self.expert_gradients` is only read when MoE is enabled.
    expert_gradients = self.expert_gradients if self.has_moe_layers else {}
    return self.non_expert_gradients, expert_gradients

@torch.no_grad()
def get_grads_for_norm(self, for_clipping=False):
    """Collect the fp32 gradients used for norm computation or for clipping.

    Args:
        for_clipping: When True, gather every populated gradient into a single
            list and skip the duplicate-parameter filtering applied for norms.

    Returns:
        list[Tensor] | tuple[list[Tensor], dict[str, list[Tensor]]]:
            If ``for_clipping`` is True, the list of all populated gradients.
            Otherwise ``(non_expert_grads, expert_grads)``, where
            ``expert_grads`` maps a MoE param group's 'name' to the gradients
            belonging to that group.
    """
    # expert group name -> gradients of that group
    expert_grads_for_norm = {}
    non_expert_grads_for_norm = []
    all_grads_for_clip = []

    tensor_mp_rank = bwc_tensor_model_parallel_rank(mpu=self.mpu)
    assert len(self.bf16_groups) == len(self.optimizer.param_groups)
    for i, group in enumerate(self.bf16_groups):
        param_group = self.optimizer.param_groups[i]
        is_expert_group = self.has_moe_layers and is_moe_param_group(param_group)
        for j, lp in enumerate(group):
            if not for_clipping:
                if hasattr(lp, PIPE_REPLICATED) and lp.ds_pipe_replicated:
                    continue

                # Skip duplicated parameters: compute the norm only on ranks with tp_rank == 0.
                # Non-duplicated parameters are always kept:
                #  - tensor-parallel parameters (summed via all-reduce over the mp group)
                #  - MoE parameters (summed via all-reduce over the ep group)
                if not (tensor_mp_rank == 0 or is_model_parallel_parameter(lp) or is_moe_param(lp)):
                    continue

            if not self.fp32_groups_has_gradients[i][j]:
                continue

            grad = self.fp32_groups_gradients[i][j]
            if for_clipping:
                all_grads_for_clip.append(grad)
            elif is_expert_group:
                expert_grads_for_norm.setdefault(param_group['name'], []).append(grad)
            else:
                non_expert_grads_for_norm.append(grad)

    if for_clipping:
        return all_grads_for_clip
    return non_expert_grads_for_norm, expert_grads_for_norm

@torch.no_grad()
def update_lp_params(self):
Expand Down
Loading

0 comments on commit 50b46fe

Please sign in to comment.