diff --git a/deepspeed/__init__.py b/deepspeed/__init__.py
index fe0043547860..d8655299282f 100755
--- a/deepspeed/__init__.py
+++ b/deepspeed/__init__.py
@@ -77,6 +77,7 @@ def initialize(args=None,
                dist_init_required: Optional[bool] = None,
                collate_fn=None,
                config=None,
+               mesh_param=None,
                config_params=None):
     """Initialize the DeepSpeed Engine.
@@ -144,10 +145,22 @@ def initialize(args=None,
                           distributed_port=distributed_port,
                           dist_init_required=dist_init_required)
 
+    # TODO: allow reusing an mpu as the mesh device and vice versa
     # Set config using config_params for backwards compat
     if config is None and config_params is not None:
         config = config_params
 
+    mesh_device = None
+    if mesh_param:
+        logger.info(f"Initializing mesh device from mesh_param: {mesh_param}")
+        mesh_device = dist.initialize_mesh_device(mesh_param, ("data_parallel", "sequence_parallel"))
+    # If the config specifies both data_parallel_size and sequence_parallel_size, use them to initialize the mesh device
+    elif config is not None:
+        if "sequence_parallel_size" in config and "data_parallel_size" in config:
+            logger.info(f"Initializing mesh device from config: {config}")
+            mesh_device = dist.initialize_mesh_device((config["data_parallel_size"], config["sequence_parallel_size"]), \
+                                                      ("data_parallel", "sequence_parallel"))
+
     # Check for deepscale_config for backwards compat
     if hasattr(args, "deepscale_config") and args.deepscale_config is not None:
         logger.warning("************ --deepscale_config is deprecated, please use --deepspeed_config ************")
@@ -162,9 +175,8 @@ def initialize(args=None,
         assert config is None, "Not sure how to proceed, we were given deepspeed configs in the deepspeed arguments and deepspeed.initialize() function call"
         config = args.deepspeed_config
     assert config is not None, "DeepSpeed requires --deepspeed_config to specify configuration file"
-
     if not isinstance(model, PipelineModule):
-        config_class = DeepSpeedConfig(config, mpu)
+        config_class = DeepSpeedConfig(config, mpu, mesh_device=mesh_device)
         if config_class.hybrid_engine.enabled:
             engine = DeepSpeedHybridEngine(args=args,
                                            model=model,
@@ -188,6 +200,7 @@ def initialize(args=None,
                                      dist_init_required=dist_init_required,
                                      collate_fn=collate_fn,
                                      config=config,
+                                     mesh_device=mesh_device,
                                      config_class=config_class)
     else:
         assert mpu is None, "mpu must be None with pipeline parallelism"
@@ -208,7 +221,12 @@ def initialize(args=None,
     # Restore zero.Init context if necessary
     zero.partition_parameters.restore_init_context()
 
-    return_items = [engine, engine.optimizer, engine.training_dataloader, engine.lr_scheduler]
+    return_items = [
+        engine,
+        engine.optimizer,
+        engine.training_dataloader,
+        engine.lr_scheduler,
+    ]
     return tuple(return_items)
diff --git a/deepspeed/comm/comm.py b/deepspeed/comm/comm.py
old mode 100644
new mode 100755
index 85b7fab2c548..2895e0f2e011
--- a/deepspeed/comm/comm.py
+++ b/deepspeed/comm/comm.py
@@ -600,6 +600,21 @@ def get_all_ranks_from_group(group=None):
     return group_ranks
 
 
+def initialize_mesh_device(mesh_shape, mesh_dim_names):
+    global cdb
+    assert cdb is not None and cdb.is_initialized(
+    ), 'DeepSpeed backend not set, please initialize it using init_process_group()'
+    mesh_device = None
+    if hasattr(cdb, 'init_device_mesh'):
+        utils.logger.info(f"Initializing mesh device with backend {cdb.name} \
+                          with shape {mesh_shape} and dim names {mesh_dim_names}")
+        mesh_device = cdb.init_device_mesh(mesh_shape, mesh_dim_names)
+    else:
+        if get_rank() == 0:
+            utils.logger.warning_once(f"Backend {cdb.name} does not support mesh device initialization")
+    return mesh_device
+
+
 # Main DeepSpeed Comms. public API.
 def init_distributed(dist_backend=None,
                      auto_mpi_discovery=True,
diff --git a/deepspeed/comm/torch.py b/deepspeed/comm/torch.py
old mode 100644
new mode 100755
index 83754e98f033..ed2645d415c4
--- a/deepspeed/comm/torch.py
+++ b/deepspeed/comm/torch.py
@@ -386,6 +386,14 @@ def _reduce_op(self, op):
                 op = torch.distributed.ReduceOp.BXOR
         return op
 
+    def init_device_mesh(self, mesh_shape, mesh_dim_names):
+        if not required_torch_version(min_version=2.2):
+            raise RuntimeError(f"Current torch version does not have device mesh "
+                               f"API (torch.__version__: {torch.__version__})")
+        return torch.distributed.device_mesh.init_device_mesh(get_accelerator().current_device_name(),
+                                                              mesh_shape,
+                                                              mesh_dim_names=mesh_dim_names)
+
 
 # This will become a light-weight wrapper around torch.distributed functions
 # TODO: create some example to show how this wrapper can help profile communication
diff --git a/deepspeed/runtime/config.py b/deepspeed/runtime/config.py
index b49b4a8b6086..8be2f7ac4055 100755
--- a/deepspeed/runtime/config.py
+++ b/deepspeed/runtime/config.py
@@ -705,7 +705,7 @@ def write_config(self, filename):
 
 class DeepSpeedConfig(object):
 
-    def __init__(self, config: Union[str, dict], mpu=None):
+    def __init__(self, config: Union[str, dict], mpu=None, mesh_device=None):
         super(DeepSpeedConfig, self).__init__()
         if isinstance(config, dict):
             self._param_dict = config
@@ -721,14 +721,16 @@ def __init__(self, config: Union[str, dict], mpu=None):
             )
         try:
             self.global_rank = dist.get_rank()
-            if mpu is None:
-                self.world_size = dist.get_world_size()
-            else:
+            if mpu is not None:
                 self.world_size = mpu.get_data_parallel_world_size()
+            elif mesh_device is not None:
+                self.world_size = dist.get_world_size(mesh_device.get_group(mesh_dim="data_parallel"))
+            else:
+                self.world_size = dist.get_world_size()
         except:
             self.global_rank = 0
             self.world_size = 1
-
+        logger.info(f"Config mesh_device {mesh_device} world_size = {self.world_size}")
         # If elastic-mode enabled, update compute + update _param_dict
         self.elasticity_enabled = elasticity_enabled(self._param_dict)
         if self.elasticity_enabled:
diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py
old mode 100644
new mode 100755
index 27d294b3ae01..61e6da2663cf
--- a/deepspeed/runtime/engine.py
+++ b/deepspeed/runtime/engine.py
@@ -194,6 +194,7 @@ def __init__(self,
                  collate_fn=None,
                  config=None,
                  config_class=None,
+                 mesh_device=None,
                  dont_change_device=False):
         super(DeepSpeedEngine, self).__init__()
         self.dont_change_device = dont_change_device
@@ -233,10 +234,14 @@ def __init__(self,
         self._is_gradient_accumulation_boundary = None
         self.scale_wrt_gas = None
         self.losses = None
+        self.mesh_device = mesh_device
 
         # for debug purposes - can then debug print: debug_get_module_name(module)
         debug_extract_module_and_param_names(model)
 
+        if self.mesh_device:
+            groups.mesh_device = self.mesh_device
+
         self._do_args_sanity_check(args)
         self._configure_with_arguments(args, mpu)
         self._do_sanity_check()
@@ -615,6 +620,9 @@ def random_ltd_initialize(self):
             raise ValueError(f'not yet support')
             #self.lr_scheduler = lr_schedules.WarmupLayerTokenDecayLR(self.optimizer, self.random_ltd_scheduler)
 
+    def get_sequence_parallel_group(self):
+        return self.seq_parallel_group
+
     def wall_clock_breakdown(self):
         return self._config.wall_clock_breakdown
 
@@ -1187,6 +1195,7 @@ def _configure_distributed_model(self, model):
         self.sequence_parallel_size = groups._get_sequence_parallel_world_size()
         if self.sequence_parallel_size > 1:
             self.communication_data_type = self._config.seq_parallel_communication_data_type
+            self.seq_parallel_group = groups._get_sequence_parallel_group()
 
         if not (self.amp_enabled() or is_zero_init_model):
             self._broadcast_model()
diff --git a/deepspeed/sequence/cross_entropy.py b/deepspeed/sequence/cross_entropy.py
new file mode 100644
index 000000000000..baa7bc1ea7a8
--- /dev/null
+++ b/deepspeed/sequence/cross_entropy.py
@@ -0,0 +1,60 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+import torch
+
+import deepspeed.comm as dist
+
+
+class _VocabSequenceParallelCrossEntropy(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx, vocab_seq_parallel_logits, target, sp_group):
+        # vocab_seq_parallel_logits: [S/P, B, V]
+        # target: [S/P, B]
+        # return: [S, B]
+
+        # Need softmax for backward
+        softmax = torch.nn.functional.softmax(vocab_seq_parallel_logits, dim=-1)
+        ctx.vocab_size = vocab_seq_parallel_logits.size(2)
+        loss = torch.nn.functional.nll_loss(softmax.log().view(-1, ctx.vocab_size), target.view(-1), reduction='none')
+
+        sp_world_size = dist.get_world_size(sp_group)
+        sp_rank = dist.get_rank(sp_group)
+        ctx.sp_world_size = sp_world_size
+        ctx.sp_rank = sp_rank
+        ctx.seqlen = vocab_seq_parallel_logits.size(0) * sp_world_size
+        batch_size = vocab_seq_parallel_logits.size(1)
+
+        loss_all = torch.empty(ctx.seqlen,
+                               batch_size,
+                               dtype=vocab_seq_parallel_logits.dtype,
+                               device=vocab_seq_parallel_logits.device)
+        dist.all_gather_into_tensor(loss_all, loss, group=sp_group)
+
+        ctx.save_for_backward(softmax, target)
+
+        return loss_all
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        softmax, target = ctx.saved_tensors
+
+        step_seqlen = ctx.seqlen // ctx.sp_world_size
+        sp_rank = ctx.sp_rank
+        grad_output_part = grad_output[step_seqlen * sp_rank:step_seqlen * (sp_rank + 1), :]
+
+        grad_input = softmax
+        grad_2d = grad_input.view(-1, ctx.vocab_size)
+        arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=grad_2d.device)
+
+        grad_2d[arange_1d, target.view(-1)] -= 1
+        grad_input.mul_(grad_output_part.unsqueeze(dim=-1))
+
+        return grad_input, None, None, None
+
+
+def vocab_sequence_parallel_cross_entropy(vocab_parallel_logits, target, sp_group):
+    return _VocabSequenceParallelCrossEntropy.apply(vocab_parallel_logits, target, sp_group)
diff --git a/deepspeed/utils/groups.py b/deepspeed/utils/groups.py
old mode 100644
new mode 100755
index c49f4520e16e..9dd288ef46db
--- a/deepspeed/utils/groups.py
+++ b/deepspeed/utils/groups.py
@@ -30,6 +30,7 @@
 from deepspeed.utils.bwc import bwc_tensor_model_parallel_world_size, bwc_pipeline_parallel_world_size
 from deepspeed.utils.exceptions import DeprecatedException
 from deepspeed.accelerator import get_accelerator
+
 # Expert parallel group that the current rank belongs to.
 _EXPERT_PARALLEL_GROUP = {}
 # Expert data parallel group that the current rank belongs to.
@@ -47,6 +48,8 @@
 
 _DATA_PARALLEL_GROUP = None
 
+mesh_device = None
+
 # Deprecated groups initialize function.
 def initialize(ep_size=1, mpu=None):
@@ -398,8 +401,11 @@ def _get_data_parallel_group():
     """Get the data parallel group the caller rank belongs to."""
     assert dist.is_initialized(), 'dist is not initialized'
     global mpu
+    if mesh_device is not None:
+        return mesh_device.get_group(mesh_dim="data_parallel")
     if mpu is not None:
         return mpu.get_data_parallel_group()
+
     # Return the clone of dist world group
     return _clone_world_group()
@@ -442,6 +448,8 @@ def _get_expert_data_parallel_rank(group_name):
 
 def _get_data_parallel_world_size():
     """Return world size for the data parallel group."""
+    if mesh_device is not None:
+        return dist.get_world_size(mesh_device.get_group(mesh_dim="data_parallel"))
     global mpu
     if mpu is not None:
         return mpu.get_data_parallel_world_size()
@@ -464,6 +472,8 @@ def _get_data_parallel_rank():
 
 def _get_sequence_parallel_world_size():
     """Return world size for the model parallel group."""
     global mpu
+    if mesh_device is not None:
+        return dist.get_world_size(mesh_device.get_group(mesh_dim="sequence_parallel"))
     if mpu is not None and hasattr(mpu, 'get_sequence_parallel_world_size'):
         return mpu.get_sequence_parallel_world_size()
     return 1
@@ -479,9 +489,11 @@ def _get_sequence_parallel_rank():
 
 def _get_sequence_parallel_group():
     global mpu
-    if mpu is not None and hasattr(mpu, 'get_sequence_parallel_group'):
-        return mpu.get_sequence_parallel_group()
-    return None
+    if mpu is None or not hasattr(mpu, 'get_sequence_parallel_group'):
+        if mesh_device is None:
+            raise KeyError("No sequence parallel group found")
+        return mesh_device.get_group(mesh_dim="sequence_parallel")
+    return mpu.get_sequence_parallel_group()
 
 
 def _get_sequence_data_parallel_world_size():
diff --git a/tests/unit/sequence_parallelism/test_ulysses.py b/tests/unit/sequence_parallelism/test_ulysses.py
new file mode 100644
index 000000000000..915c89e0b00a
--- /dev/null
+++ b/tests/unit/sequence_parallelism/test_ulysses.py
@@ -0,0 +1,77 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+import pytest
+import torch
+import deepspeed.comm as dist
+from deepspeed import initialize
+from transformers import AutoModel
+from unit.common import DistributedTest
+from deepspeed.sequence.layer import _SeqAllToAll
+from unit.util import skip_on_arch
+
+
+# Use a mesh device to create the data parallel and sequence parallel groups
+class TestUlyssesUtils(DistributedTest):
+    world_size = 4
+
+    def test_mesh_device_creation(self) -> None:
+        skip_on_arch(min_arch=8)
+        model = AutoModel.from_pretrained('bert-base-uncased')
+        sp_size = 2
+        dp_size = 2
+        ds_engine, _, _, _ = initialize(
+            model=model,
+            config_params={
+                "train_batch_size": 8,
+                "data_parallel_size": dp_size,
+                "sequence_parallel_size": sp_size
+            },
+        )
+        assert ds_engine.seq_parallel_group is not None
+        assert ds_engine.data_parallel_group is not None
+        assert dist.get_world_size(group=ds_engine.seq_parallel_group) == sp_size
+        assert dist.get_world_size(group=ds_engine.data_parallel_group) == dp_size
+        assert dist.get_world_size() == sp_size * dp_size
+
+
+# Sweep b, s, h, d to test all-to-all consistency
+@pytest.mark.parametrize("d0", [2, 4])  # batch or sequence dimension
+@pytest.mark.parametrize("d1", [4, 8])  # batch or sequence dimension
+@pytest.mark.parametrize("num_heads", [4, 8])
+@pytest.mark.parametrize("head_dim", [16, 32])
+class TestUlyssesAll2All(DistributedTest):
+    world_size = 4
+
+    def test_alltoall_output_consistency(self, d0: int, d1: int, head_dim: int, num_heads: int) -> None:
+        skip_on_arch(min_arch=8)
+        model = AutoModel.from_pretrained('bert-base-uncased')
+        ds_engine, _, _, _ = initialize(model=model, config_params={"train_batch_size": 8}, mesh_param=(2, 2))
+        # 4D tensor: b,s,h,d or s,b,h,d
+        input_tensor = torch.randn(d0, d1, num_heads, head_dim, device=ds_engine.device)
+        scatter_idx = 2
+        batch_dim_idx = 0
+        outputs = []
+        seq_dims = [0]  # sequence-first API
+        # TODO: add batch-first support (seq_dims=[0, 1]) once the bs>1 issue with batch-first layouts is fixed
+        # See discussion in: https://github.com/microsoft/DeepSpeed/issues/5808
+        for seq_dim in seq_dims:
+            gather_idx = seq_dim
+            # first all-to-all: sequence parallel to head parallel
+            s2h_tensor = _SeqAllToAll.apply(ds_engine.seq_parallel_group, input_tensor, scatter_idx, gather_idx,
+                                            batch_dim_idx)
+
+            # second all-to-all: head parallel back to sequence parallel
+            # (the round trip should be a no-op on the data)
+            h2s_tensor = _SeqAllToAll.apply(ds_engine.seq_parallel_group, s2h_tensor, gather_idx, scatter_idx,
+                                            batch_dim_idx)
+            print(
+                f'[{dist.get_rank()}] s={seq_dim} input: {input_tensor.shape} s2h: {s2h_tensor.shape} h2s_tensor: {h2s_tensor.shape}'
+            )
+            outputs.append(h2s_tensor)
+
+        # Check that each output matches the input
+        for i in range(len(outputs)):
+            assert torch.allclose(input_tensor, outputs[i]), f"Outputs differ for sequence dim {seq_dims[i]}"
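
Usage sketch (illustrative, not part of the patch): how a training script could drive the new mesh_param path end to end, assuming a 4-rank job started with the DeepSpeed launcher; the script name, placeholder model, and config values below are made up for the example.

# usage_sketch.py -- run with: deepspeed --num_gpus 4 usage_sketch.py
import torch
import deepspeed
import deepspeed.comm as dist

model = torch.nn.Linear(32, 32)  # placeholder model for illustration

# mesh_param=(data_parallel_size, sequence_parallel_size) builds a 2D device mesh
# with dims ("data_parallel", "sequence_parallel"); alternatively, the same mesh is
# created when "data_parallel_size"/"sequence_parallel_size" appear in the config.
engine, _, _, _ = deepspeed.initialize(model=model,
                                       config_params={"train_batch_size": 8},
                                       mesh_param=(2, 2))

# The engine exposes the sequence-parallel group derived from the mesh.
sp_group = engine.get_sequence_parallel_group()
print(f"rank {dist.get_rank()}: sequence-parallel world size = "
      f"{dist.get_world_size(group=sp_group)}")  # expect 2 on a 4-rank run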