Long sequence parallelism (Ulysses) integration with HuggingFace (#5774)
This PR enhances the capabilities of [DeepSpeed long sequence (context)
parallelism (aka DS
Ulysses)](https://dl.acm.org/doi/10.1145/3662158.3662806) with support
for HuggingFace (and, by extension, other framework) models. With the HF
integration, users can apply sequence parallelism to model
pre/mid/post-training, fine-tuning, etc. Usage requires both _torch
>= 2.2.2 and flash-attention_. ZeRO-1 and ZeRO-2 are supported; ZeRO-3 and
SDPA support are in progress. The corresponding HF PR is
[PR32305](huggingface/transformers#32305).
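
A minimal usage sketch (assumptions: a 4-rank launch via the deepspeed launcher, an illustrative
`bert-base-uncased` model, and illustrative batch sizes; the config keys, the `mesh_param` argument,
and the `get_sequence_parallel_group()` accessor come from the changes below):

# Sketch: enable DS Ulysses sequence parallelism through deepspeed.initialize (4 ranks assumed).
import deepspeed
from transformers import AutoModel

model = AutoModel.from_pretrained("bert-base-uncased")

# Option 1: declare the parallel sizes in the DeepSpeed config ...
ds_config = {
    "train_batch_size": 8,
    "data_parallel_size": 2,
    "sequence_parallel_size": 2,
}
engine, _, _, _ = deepspeed.initialize(model=model, config_params=ds_config)

# Option 2: ... or pass the mesh shape directly as (data_parallel, sequence_parallel).
# engine, _, _, _ = deepspeed.initialize(model=model,
#                                        config_params={"train_batch_size": 8},
#                                        mesh_param=(2, 2))

sp_group = engine.get_sequence_parallel_group()  # group used for Ulysses all-to-all collectives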

---------

Co-authored-by: Logan Adams <[email protected]>
samadejacobs and loadams authored Aug 21, 2024
1 parent b65ea50 commit 8b191d7
Showing 8 changed files with 212 additions and 11 deletions.
24 changes: 21 additions & 3 deletions deepspeed/__init__.py
@@ -77,6 +77,7 @@ def initialize(args=None,
dist_init_required: Optional[bool] = None,
collate_fn=None,
config=None,
mesh_param=None,
config_params=None):
"""Initialize the DeepSpeed Engine.
@@ -144,10 +145,22 @@ def initialize(args=None,
distributed_port=distributed_port,
dist_init_required=dist_init_required)

# TODO: allow reusing an mpu as the mesh device and vice versa
# Set config using config_params for backwards compat
if config is None and config_params is not None:
config = config_params

mesh_device = None
if mesh_param:
logger.info(f"mesh_param to Initialize mesh device: {mesh_param}")
mesh_device = dist.initialize_mesh_device(mesh_param, ("data_parallel", "sequence_parallel"))
# If the config defines both sequence and data parallel sizes, use them to initialize the mesh device
elif config is not None:
if "sequence_parallel_size" in config and "data_parallel_size" in config:
logger.info(f"config to Initialize mesh device: {config}")
mesh_device = dist.initialize_mesh_device((config["data_parallel_size"], config["sequence_parallel_size"]), \
("data_parallel", "sequence_parallel"))

# Check for deepscale_config for backwards compat
if hasattr(args, "deepscale_config") and args.deepscale_config is not None:
logger.warning("************ --deepscale_config is deprecated, please use --deepspeed_config ************")
@@ -162,9 +175,8 @@
assert config is None, "Not sure how to proceed, we were given deepspeed configs in the deepspeed arguments and deepspeed.initialize() function call"
config = args.deepspeed_config
assert config is not None, "DeepSpeed requires --deepspeed_config to specify configuration file"

if not isinstance(model, PipelineModule):
config_class = DeepSpeedConfig(config, mpu)
config_class = DeepSpeedConfig(config, mpu, mesh_device=mesh_device)
if config_class.hybrid_engine.enabled:
engine = DeepSpeedHybridEngine(args=args,
model=model,
@@ -188,6 +200,7 @@
dist_init_required=dist_init_required,
collate_fn=collate_fn,
config=config,
mesh_device=mesh_device,
config_class=config_class)
else:
assert mpu is None, "mpu must be None with pipeline parallelism"
@@ -208,7 +221,12 @@
# Restore zero.Init context if necessary
zero.partition_parameters.restore_init_context()

return_items = [engine, engine.optimizer, engine.training_dataloader, engine.lr_scheduler]
return_items = [
engine,
engine.optimizer,
engine.training_dataloader,
engine.lr_scheduler,
]
return tuple(return_items)


15 changes: 15 additions & 0 deletions deepspeed/comm/comm.py
100644 → 100755
@@ -600,6 +600,21 @@ def get_all_ranks_from_group(group=None):
return group_ranks


def initialize_mesh_device(mesh_shape, mesh_dim_names):
global cdb
assert cdb is not None and cdb.is_initialized(
), 'DeepSpeed backend not set, please initialize it using init_process_group()'
mesh_device = None
if hasattr(cdb, 'init_device_mesh'):
utils.logger.info(f"Initializing mesh device with backend {cdb.name} \
with shape {mesh_shape} and dim names {mesh_dim_names}")
mesh_device = cdb.init_device_mesh(mesh_shape, mesh_dim_names)
else:
if get_rank() == 0:
utils.logger.warning_once(f"Backend {cdb.name} does not support mesh device initialization")
return mesh_device


# Main DeepSpeed Comms. public API.
def init_distributed(dist_backend=None,
auto_mpi_discovery=True,
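
For reference, a minimal sketch of the new comm helper in isolation (assuming deepspeed.init_distributed()
has already set up the backend in a 4-rank job; the 2x2 shape is illustrative):

import deepspeed
import deepspeed.comm as dist

deepspeed.init_distributed()
# Build a 2x2 device mesh; returns None (with a warning) when the backend lacks device-mesh support.
mesh = dist.initialize_mesh_device((2, 2), ("data_parallel", "sequence_parallel"))
if mesh is not None:
    dp_group = mesh.get_group(mesh_dim="data_parallel")
    sp_group = mesh.get_group(mesh_dim="sequence_parallel")
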
8 changes: 8 additions & 0 deletions deepspeed/comm/torch.py
100644 → 100755
@@ -386,6 +386,14 @@ def _reduce_op(self, op):
op = torch.distributed.ReduceOp.BXOR
return op

def init_device_mesh(self, mesh_shape, mesh_dim_names):
if not required_torch_version(min_version=2.2):
raise RuntimeError(f"Current torch version does not have device mesh"
f"api (torch.__version__: {torch.__version__})")
return torch.distributed.device_mesh.init_device_mesh(get_accelerator().current_device_name(),
mesh_shape,
mesh_dim_names=mesh_dim_names)


# This will become a light-weight wrapper around torch.distributed functions
# TODO: create some example to show how this wrapper can help profile communication
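
The wrapper defers to the upstream torch device-mesh API and is roughly equivalent to the following
sketch (torch >= 2.2, run inside an initialized 4-rank job; the "cuda" device type and 2x2 shape are
illustrative):

from torch.distributed.device_mesh import init_device_mesh

# What the backend call resolves to on a CUDA accelerator (sketch only).
mesh = init_device_mesh("cuda", (2, 2), mesh_dim_names=("data_parallel", "sequence_parallel"))
sp_group = mesh.get_group(mesh_dim="sequence_parallel")
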
12 changes: 7 additions & 5 deletions deepspeed/runtime/config.py
@@ -705,7 +705,7 @@ def write_config(self, filename):

class DeepSpeedConfig(object):

def __init__(self, config: Union[str, dict], mpu=None):
def __init__(self, config: Union[str, dict], mpu=None, mesh_device=None):
super(DeepSpeedConfig, self).__init__()
if isinstance(config, dict):
self._param_dict = config
@@ -721,14 +721,16 @@ def __init__(self, config: Union[str, dict], mpu=None):
)
try:
self.global_rank = dist.get_rank()
if mpu is None:
self.world_size = dist.get_world_size()
else:
if mpu is not None:
self.world_size = mpu.get_data_parallel_world_size()
elif mesh_device is not None:
self.world_size = dist.get_world_size(mesh_device.get_group(mesh_dim="data_parallel"))
else:
self.world_size = dist.get_world_size()
except:
self.global_rank = 0
self.world_size = 1

logger.info(f"Config mesh_device {mesh_device} world_size = {self.world_size}")
# If elastic-mode enabled, update compute + update _param_dict
self.elasticity_enabled = elasticity_enabled(self._param_dict)
if self.elasticity_enabled:
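
With a mesh device, the world size used for batch-size bookkeeping is the data-parallel group size
rather than the global world size, so sequence-parallel ranks do not inflate the effective batch.
A sketch of the arithmetic with illustrative numbers (dp=2, sp=2 on 4 ranks):

# DeepSpeed's batch-size relation, using the data-parallel size resolved from the mesh:
micro_batch_per_gpu = 2
gradient_accumulation_steps = 2
data_parallel_size = 2  # world_size seen by DeepSpeedConfig, not the global 4
train_batch_size = micro_batch_per_gpu * gradient_accumulation_steps * data_parallel_size  # = 8
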
9 changes: 9 additions & 0 deletions deepspeed/runtime/engine.py
100644 → 100755
@@ -194,6 +194,7 @@ def __init__(self,
collate_fn=None,
config=None,
config_class=None,
mesh_device=None,
dont_change_device=False):
super(DeepSpeedEngine, self).__init__()
self.dont_change_device = dont_change_device
@@ -233,10 +234,14 @@ def __init__(self,
self._is_gradient_accumulation_boundary = None
self.scale_wrt_gas = None
self.losses = None
self.mesh_device = mesh_device

# for debug purposes - can then debug print: debug_get_module_name(module)
debug_extract_module_and_param_names(model)

if self.mesh_device:
groups.mesh_device = self.mesh_device

self._do_args_sanity_check(args)
self._configure_with_arguments(args, mpu)
self._do_sanity_check()
@@ -615,6 +620,9 @@ def random_ltd_initialize(self):
raise ValueError(f'not yet support')
#self.lr_scheduler = lr_schedules.WarmupLayerTokenDecayLR(self.optimizer, self.random_ltd_scheduler)

def get_sequence_parallel_group(self):
return self.seq_parallel_group

def wall_clock_breakdown(self):
return self._config.wall_clock_breakdown

@@ -1187,6 +1195,7 @@ def _configure_distributed_model(self, model):
self.sequence_parallel_size = groups._get_sequence_parallel_world_size()
if self.sequence_parallel_size > 1:
self.communication_data_type = self._config.seq_parallel_communication_data_type
self.seq_parallel_group = groups._get_sequence_parallel_group()

if not (self.amp_enabled() or is_zero_init_model):
self._broadcast_model()
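
Once sequence_parallel_size > 1, the new accessor exposes the group for collectives along the
sequence dimension; a brief sketch, assuming the engine from the usage example above:

import deepspeed.comm as dist

sp_group = engine.get_sequence_parallel_group()
sp_rank = dist.get_rank(group=sp_group)      # this rank's position within the sequence-parallel group
sp_size = dist.get_world_size(sp_group)      # number of sequence shards
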
60 changes: 60 additions & 0 deletions deepspeed/sequence/cross_entropy.py
@@ -0,0 +1,60 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch

import deepspeed.comm as dist


class _VocabSequenceParallelCrossEntropy(torch.autograd.Function):

@staticmethod
def forward(ctx, vocab_seq_parallel_logits, target, sp_group):
# vocab_seq_parallel_logits: [S/P, B, V]
# target: [S/P, B]
# return: [S, B]

# Need softmax for backward
softmax = torch.nn.functional.softmax(vocab_seq_parallel_logits, dim=-1)
ctx.vocab_size = vocab_seq_parallel_logits.size(2)
loss = torch.nn.functional.nll_loss(softmax.log().view(-1, ctx.vocab_size), target.view(-1), reduction='none')

sp_world_size = dist.get_world_size(sp_group)
sp_rank = dist.get_rank(sp_group)
ctx.sp_world_size = sp_world_size
ctx.sp_rank = sp_rank
ctx.seqlen = vocab_seq_parallel_logits.size(0) * sp_world_size
batch_size = vocab_seq_parallel_logits.size(1)

loss_all = torch.empty(ctx.seqlen,
batch_size,
dtype=vocab_seq_parallel_logits.dtype,
device=vocab_seq_parallel_logits.device)
dist.all_gather_into_tensor(loss_all, loss, group=sp_group)

ctx.save_for_backward(softmax, target)

return loss_all

@staticmethod
def backward(ctx, grad_output):
softmax, target = ctx.saved_tensors

step_seqlen = ctx.seqlen // ctx.sp_world_size
sp_rank = ctx.sp_rank
grad_output_part = grad_output[step_seqlen * sp_rank:step_seqlen * (sp_rank + 1), :]

grad_input = softmax
grad_2d = grad_input.view(-1, ctx.vocab_size)
arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=grad_2d.device)

grad_2d[arange_1d, target.view(-1)] -= 1
grad_input.mul_(grad_output_part.unsqueeze(dim=-1))

return grad_input, None, None, None


def vocab_sequence_parallel_cross_entropy(vocab_parallel_logits, target, sp_group):
return _VocabSequenceParallelCrossEntropy.apply(vocab_parallel_logits, target, sp_group)
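
A shape-level usage sketch (all sizes, the device, and sp_group — taken from the engine above — are
illustrative; each rank holds its S/P slice of the sequence):

import torch
from deepspeed.sequence.cross_entropy import vocab_sequence_parallel_cross_entropy

S, P, B, V = 1024, 4, 2, 32000  # full sequence length, SP degree, batch, vocab (illustrative)
logits_shard = torch.randn(S // P, B, V, device="cuda")          # [S/P, B, V] on this rank
target_shard = torch.randint(0, V, (S // P, B), device="cuda")   # [S/P, B] on this rank
loss = vocab_sequence_parallel_cross_entropy(logits_shard, target_shard, sp_group)
# loss: [S, B] per-token losses, gathered across the sequence-parallel group.
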
18 changes: 15 additions & 3 deletions deepspeed/utils/groups.py
100644 → 100755
@@ -30,6 +30,7 @@
from deepspeed.utils.bwc import bwc_tensor_model_parallel_world_size, bwc_pipeline_parallel_world_size
from deepspeed.utils.exceptions import DeprecatedException
from deepspeed.accelerator import get_accelerator

# Expert parallel group that the current rank belongs to.
_EXPERT_PARALLEL_GROUP = {}
# Expert data parallel group that the current rank belongs to.
@@ -47,6 +48,8 @@

_DATA_PARALLEL_GROUP = None

mesh_device = None


# Deprecated groups initialize function.
def initialize(ep_size=1, mpu=None):
@@ -398,8 +401,11 @@ def _get_data_parallel_group():
"""Get the data parallel group the caller rank belongs to."""
assert dist.is_initialized(), 'dist is not initialized'
global mpu
if mesh_device is not None:
return mesh_device.get_group(mesh_dim="data_parallel")
if mpu is not None:
return mpu.get_data_parallel_group()

# Return the clone of dist world group
return _clone_world_group()

@@ -442,6 +448,8 @@ def _get_expert_data_parallel_rank(group_name):

def _get_data_parallel_world_size():
"""Return world size for the data parallel group."""
if mesh_device is not None:
return dist.get_world_size(mesh_device.get_group(mesh_dim="data_parallel"))
global mpu
if mpu is not None:
return mpu.get_data_parallel_world_size()
@@ -464,6 +472,8 @@ def _get_data_parallel_rank():
def _get_sequence_parallel_world_size():
"""Return world size for the model parallel group."""
global mpu
if mesh_device is not None:
return dist.get_world_size(mesh_device.get_group(mesh_dim="sequence_parallel"))
if mpu is not None and hasattr(mpu, 'get_sequence_parallel_world_size'):
return mpu.get_sequence_parallel_world_size()
return 1
@@ -479,9 +489,11 @@ def _get_sequence_parallel_rank():

def _get_sequence_parallel_group():
global mpu
if mpu is not None and hasattr(mpu, 'get_sequence_parallel_group'):
return mpu.get_sequence_parallel_group()
return None
if mpu is None or not hasattr(mpu, 'get_sequence_parallel_group'):
if mesh_device is None:
raise KeyError("No sequence parallel group found")
return mesh_device.get_group(mesh_dim="sequence_parallel")
return mpu.get_sequence_parallel_group()


def _get_sequence_data_parallel_world_size():
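
When a mesh device is registered (as the engine does above), these helpers resolve against it before
falling back to an mpu or the world group. A quick sketch of what that exposes — the underscore-prefixed
functions are internal and shown for illustration only:

from deepspeed.utils import groups

# With a (dp=2, sp=2) mesh registered by the engine:
dp_size = groups._get_data_parallel_world_size()      # -> 2, from the "data_parallel" mesh dim
sp_size = groups._get_sequence_parallel_world_size()  # -> 2, from the "sequence_parallel" mesh dim
sp_group = groups._get_sequence_parallel_group()      # mesh group for the "sequence_parallel" dim
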
77 changes: 77 additions & 0 deletions tests/unit/sequence_parallelism/test_ulysses.py
@@ -0,0 +1,77 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import pytest
import torch
import deepspeed.comm as dist
from deepspeed import initialize
from transformers import AutoModel
from unit.common import DistributedTest
from deepspeed.sequence.layer import _SeqAllToAll
from unit.util import skip_on_arch


# Use mesh device to create data and sequence parallel groups
class TestUlyssesUtils(DistributedTest):
world_size = 4

def test_mesh_device_creation(self) -> None:
skip_on_arch(min_arch=8)
model = AutoModel.from_pretrained('bert-base-uncased')
sp_size = 2
dp_size = 2
ds_engine, _, _, _ = initialize(
model=model,
config_params={
"train_batch_size": 8,
"data_parallel_size": dp_size,
"sequence_parallel_size": sp_size
},
)
assert ds_engine.seq_parallel_group is not None
assert ds_engine.data_parallel_group is not None
assert dist.get_world_size(group=ds_engine.seq_parallel_group) == sp_size
assert dist.get_world_size(group=ds_engine.data_parallel_group) == dp_size
assert dist.get_world_size() == sp_size * dp_size


# Sweep b, s, h, d to test all-to-all consistency
@pytest.mark.parametrize("d0", [2, 4]) #batch or sequence dimension
@pytest.mark.parametrize("d1", [4, 8]) #batch or sequence dimension
@pytest.mark.parametrize("num_heads", [4, 8])
@pytest.mark.parametrize("head_dim", [16, 32])
class TestUlyssesAll2All(DistributedTest):
world_size = 4

def test_alltoall_output_consistency(self, d0: int, d1: int, head_dim: int, num_heads: int) -> None:
skip_on_arch(min_arch=8)
model = AutoModel.from_pretrained('bert-base-uncased')
ds_engine, _, _, _ = initialize(model=model, config_params={"train_batch_size": 8}, mesh_param=(2, 2))
#4D tensor : b,s,h,d or s,b,h,d
input_tensor = torch.randn(d0, d1, num_heads, head_dim, device=ds_engine.device)
scatter_idx = 2
batch_dim_idx = 0
outputs = []
seq_dims = [0] #seq first API
#TODO: Add batch-first support (i.e. seq_dims=[0, 1]) once the bs>1 issue with batch-first layout is fixed
# See discussion in: https://github.com/microsoft/DeepSpeed/issues/5808
for seq_dim in seq_dims:
gather_idx = seq_dim
#first all2all: sequence parallel to head parallel
s2h_tensor = _SeqAllToAll.apply(ds_engine.seq_parallel_group, input_tensor, scatter_idx, gather_idx,
batch_dim_idx)

# second all2all: head parallel back to sequence parallel
# The two all-to-alls are inverses, so the round trip should be a no-op.
h2s_tensor = _SeqAllToAll.apply(ds_engine.seq_parallel_group, s2h_tensor, gather_idx, scatter_idx,
batch_dim_idx)
print(
f'[{dist.get_rank()}] s={seq_dim} input: {input_tensor.shape} s2h: {s2h_tensor.shape} h2s_tensor: {h2s_tensor.shape}'
)
outputs.append(h2s_tensor)

# Check outputs are the same as input
for i in range(len(outputs)):
assert torch.allclose(input_tensor, outputs[i]), f"Outputs differ for sequence dim {seq_dims[i]}"
