diff --git a/deepspeed/module_inject/auto_tp.py b/deepspeed/module_inject/auto_tp.py
index daf143919558..2e348de63454 100644
--- a/deepspeed/module_inject/auto_tp.py
+++ b/deepspeed/module_inject/auto_tp.py
@@ -11,7 +11,7 @@
 from typing import Optional
 import torch
 from deepspeed import comm as dist
-from .layers import LinearAllreduce, LinearLayer
+from .layers import LinearAllreduce, LinearLayer, LmHeadLinearAllreduce
 from deepspeed.accelerator import get_accelerator
 from .fusedqkv_utils import require_tp_fused_qkvw, prepare_tp_fused_qkvw
@@ -318,6 +318,11 @@ def _replace(self, child, name, conv_linear_layer):
             del data
 
             setattr(child, "replaced", True)
+            if name == "lm_head" or name == 'embed_out':
+                return LmHeadLinearAllreduce(
+                    torch.nn.parameter.Parameter(data_dc, requires_grad=False), dist.get_rank(), dist.get_world_size(),
+                    child.bias if child.bias is None else torch.nn.parameter.Parameter(
+                        child.bias.to(get_accelerator().current_device_name())), self.mp_group)
             return LinearAllreduce(torch.nn.parameter.Parameter(data_dc, requires_grad=False), child.bias if child.bias is None else \
                         torch.nn.parameter.Parameter(child.bias.to(get_accelerator().current_device_name())), self.mp_group)
         else:
@@ -436,3 +441,16 @@ def _replace_module(self, r_module, prev_name='', prev_class_name=''):
                 self.update_mp_params(child)
             self._replace_module(child, name, class_name)
         return r_module
+
+    def _replace_last_linear_module(self, r_module):
+        if hasattr(r_module, "lm_head"):
+            name = "lm_head"
+            child = r_module.lm_head
+        elif hasattr(r_module, "embed_out"):
+            name = "embed_out"
+            child = r_module.embed_out
+        else:
+            return r_module
+        if child.__class__ in self.linear_policies:
+            setattr(r_module, name, self.linear_policies[child.__class__](child, name, self.conv_linear_layer))
+        return r_module
diff --git a/deepspeed/module_inject/layers.py b/deepspeed/module_inject/layers.py
index aa29651ec4cf..7a565560dec9 100644
--- a/deepspeed/module_inject/layers.py
+++ b/deepspeed/module_inject/layers.py
@@ -29,6 +29,36 @@ def forward(self, input):
         return output
 
 
+class LmHeadLinearAllreduce(nn.Module):
+
+    def __init__(
+        self,
+        weight,
+        rank,
+        world_size,
+        bias=None,
+        mp_group=None,
+    ):
+        super(LmHeadLinearAllreduce, self).__init__()
+        self.weight = weight
+        self.bias = bias
+        self.mp_group = mp_group
+        self.rank = rank
+        self.world_size = world_size
+
+    def forward(self, input):
+        assert input.shape[
+            -1] % self.world_size == 0, 'Please ensure that input.shape[-1] is divisible by self.world_size'
+        input_shard = input.shape[-1] // self.world_size
+        output = torch.matmul(input[:, :, self.rank * input_shard:(self.rank + 1) * input_shard],
+                              self.weight.transpose(-1, -2))
+        if self.mp_group is not None:
+            dist.inference_all_reduce(output, group=self.mp_group)
+        if self.bias is not None:
+            output += self.bias
+        return output
+
+
 class LinearLayer(nn.Module):
 
     def __init__(self, weight_shape=None, dtype=torch.half, weight=None, bias=None):
diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py
index f9dd921b9ae1..8666372fa3f4 100644
--- a/deepspeed/module_inject/replace_module.py
+++ b/deepspeed/module_inject/replace_module.py
@@ -275,6 +275,8 @@ def replace_wo_policy(module, all_reduce_linears, prefix="", state_dict=None):
         _autotp.update_linear_policies()
 
         # 4. Replace modules
+        if "lm_head" in all_reduce_linears or "embed_out" in all_reduce_linears:
+            return _autotp._replace_last_linear_module(module)
         return _autotp._replace_module(module)
 
     def replace_fn(child, _policy, layer_id=0, prefix="", state_dict=None):
@@ -304,6 +306,13 @@ def set_lm_head(module):
         if embedding_weight is not None and hasattr(module, "lm_head") and hasattr(
                 module.lm_head, "weight") and module.lm_head.weight.is_meta:
             module.lm_head.weight = embedding_weight
+        # enable tensor parallel for the last linear
+        if hasattr(module, "lm_head") and hasattr(module.lm_head, "weight") and not module.lm_head.weight.is_meta:
+            module = replace_wo_policy(module, ("lm_head", ), 0, "lm_head")
+        elif hasattr(module, "embed_out") and hasattr(module.embed_out,
+                                                      "weight") and not module.embed_out.weight.is_meta:
+            module = replace_wo_policy(module, ("embed_out", ), 0, "embed_out")
+        return module
 
     if checkpoint_dict is not None and not config.replace_with_kernel_inject:
         # AutoTP shard loading
@@ -318,7 +327,7 @@ def set_lm_head(module):
                                              checkpoint=checkpoint_file)
             pbar.update(1)
             gc.collect()
-        set_lm_head(replaced_module)
+        replaced_module = set_lm_head(replaced_module)
     else:
         replaced_module = replace_module(model=model,
                                          orig_class=orig_layer_impl,
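
For context (not part of the diff): a minimal single-process sketch of the arithmetic that LmHeadLinearAllreduce relies on. Each rank multiplies its slice of the hidden dimension against its shard of the lm_head weight, and summing the partial products (what dist.inference_all_reduce does across ranks) reproduces the unsharded projection. All tensor names and sizes below are illustrative, not taken from DeepSpeed.

import torch

torch.manual_seed(0)
batch, seq, hidden, vocab, world_size = 2, 4, 8, 16, 2  # toy sizes; hidden % world_size == 0
hidden_states = torch.randn(batch, seq, hidden)
lm_head_weight = torch.randn(vocab, hidden)             # full (unsharded) [vocab, hidden] weight

# reference: the full projection, as an unsharded lm_head would compute it
reference = torch.matmul(hidden_states, lm_head_weight.transpose(-1, -2))

# simulated tensor-parallel path: slice the hidden dim per rank and sum the partial outputs
input_shard = hidden // world_size
summed = torch.zeros(batch, seq, vocab)
for rank in range(world_size):
    x_slice = hidden_states[:, :, rank * input_shard:(rank + 1) * input_shard]
    w_shard = lm_head_weight[:, rank * input_shard:(rank + 1) * input_shard]  # this rank's weight shard
    summed += torch.matmul(x_slice, w_shard.transpose(-1, -2))  # the sum stands in for the all-reduce

assert torch.allclose(reference, summed, atol=1e-5)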