Long sequence parallelism (Ulysses) integration with HuggingFace #5774

Merged Aug 21, 2024 · 30 commits

Commits
97e8b41
Use PyTorch mesh tensor to create grps
samadejacobs Apr 9, 2024
e4a42c0
Add sequence parallel grp and test example scripts
samadejacobs Apr 19, 2024
b2e09aa
Remove debug statement
samadejacobs Apr 22, 2024
8b9a25f
Set seq parallel group in DS engine
samadejacobs Apr 28, 2024
02b97db
Pass mesh device to DS config
samadejacobs Apr 30, 2024
c9f2505
Add sp CE
samadejacobs May 28, 2024
69f16c3
Remove debug statment
samadejacobs Jun 4, 2024
d79107a
Set flash attn as required and required torch version
samadejacobs Jul 15, 2024
dcbf869
Clean up debug statement, clarify log msg
samadejacobs Jul 16, 2024
a404452
Remove deprecated test file
samadejacobs Jul 16, 2024
a211078
Remove deprecated test bash script
samadejacobs Jul 16, 2024
453cbee
Merge remote-tracking branch 'origin/master' into uly-hf
samadejacobs Jul 17, 2024
c5ba526
Code clean up, fix where batch first in all2all
samadejacobs Jul 20, 2024
00b05f8
Delete tests/small_model_debugging/test_ulysses_hf.py
samadejacobs Jul 22, 2024
2f4bd5a
Fix sp in ds engine to align with new (latest) HF transformer flash_attn
samadejacobs Jul 29, 2024
fc119ca
Warn is backend does not support mesh device, fail fast if seq parall…
samadejacobs Jul 31, 2024
d4b19d1
Repalce cuda with get_accelerator, set minimum torch version for devi…
samadejacobs Jul 31, 2024
073c381
Merge branch 'master' into uly-hf
loadams Jul 31, 2024
686a184
Add unit tests
samadejacobs Aug 2, 2024
1a22cd1
Merge branch 'uly-hf' of https://github.com/samadejacobs/DeepSpeed in…
samadejacobs Aug 2, 2024
fc9330d
Add support to (optionally) set device mesh config (dp_size, sp_size)…
samadejacobs Aug 7, 2024
7a59d3c
Merge branch 'master' into uly-hf
samadejacobs Aug 7, 2024
aaa79eb
Merge branch 'master' into uly-hf
samadejacobs Aug 8, 2024
43806b4
Add (now/newly required) batch_dim_idx to unit test
samadejacobs Aug 9, 2024
f4bd5f4
Fix formating
samadejacobs Aug 13, 2024
7a64336
Merge branch 'master' into uly-hf
loadams Aug 14, 2024
14f51ad
Merge branch 'master' into uly-hf
loadams Aug 15, 2024
76c67c0
Merge branch 'master' into uly-hf
loadams Aug 19, 2024
cb7c20e
Update test_ulysses.py
samadejacobs Aug 20, 2024
513d479
Merge branch 'master' into uly-hf
samadejacobs Aug 20, 2024
26 changes: 22 additions & 4 deletions deepspeed/__init__.py
@@ -77,6 +77,7 @@ def initialize(args=None,
dist_init_required: Optional[bool] = None,
collate_fn=None,
config=None,
mesh_param=None,
config_params=None):
"""Initialize the DeepSpeed Engine.

@@ -143,11 +144,23 @@ def initialize(args=None,
dist.init_distributed(dist_backend=dist_backend,
distributed_port=distributed_port,
dist_init_required=dist_init_required)


# TODO: allow reusing an mpu as the mesh device and vice versa
# Set config using config_params for backwards compat
if config is None and config_params is not None:
config = config_params

mesh_device = None
if mesh_param:
logger.info(f"Initializing mesh device from mesh_param: {mesh_param}")
mesh_device = dist.initialize_mesh_device(mesh_param, ("data_parallel", "sequence_parallel"))
# If the config specifies both data_parallel_size and sequence_parallel_size, use them to initialize the mesh device
elif config is not None:
if "sequence_parallel_size" in config and "data_parallel_size" in config:
logger.info(f"Initializing mesh device from config: {config}")
mesh_device = dist.initialize_mesh_device((config["data_parallel_size"], config["sequence_parallel_size"]), \
("data_parallel", "sequence_parallel"))

# Check for deepscale_config for backwards compat
if hasattr(args, "deepscale_config") and args.deepscale_config is not None:
logger.warning("************ --deepscale_config is deprecated, please use --deepspeed_config ************")
@@ -162,9 +175,8 @@ def initialize(args=None,
assert config is None, "Not sure how to proceed, we were given deepspeed configs in the deepspeed arguments and deepspeed.initialize() function call"
config = args.deepspeed_config
assert config is not None, "DeepSpeed requires --deepspeed_config to specify configuration file"

if not isinstance(model, PipelineModule):
config_class = DeepSpeedConfig(config, mpu)
config_class = DeepSpeedConfig(config, mpu, mesh_device=mesh_device)
if config_class.hybrid_engine.enabled:
engine = DeepSpeedHybridEngine(args=args,
model=model,
@@ -188,6 +200,7 @@ def initialize(args=None,
dist_init_required=dist_init_required,
collate_fn=collate_fn,
config=config,
mesh_device=mesh_device,
config_class=config_class)
else:
assert mpu is None, "mpu must be None with pipeline parallelism"
@@ -208,7 +221,12 @@ def initialize(args=None,
# Restore zero.Init context if necessary
zero.partition_parameters.restore_init_context()

return_items = [engine, engine.optimizer, engine.training_dataloader, engine.lr_scheduler]
return_items = [
engine,
engine.optimizer,
engine.training_dataloader,
engine.lr_scheduler,
]
return tuple(return_items)


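A usage sketch of the two entry points this hunk adds, mirroring the unit tests later in this PR. The model, batch size, and 2x2 mesh are illustrative; a real run needs dp_size * sp_size ranks and torch >= 2.2 for device-mesh support.

import deepspeed
from transformers import AutoModel

model = AutoModel.from_pretrained('bert-base-uncased')  # placeholder HF model

# Option 1: pass the mesh shape directly as (dp_size, sp_size)
engine, _, _, _ = deepspeed.initialize(model=model,
                                       config={"train_batch_size": 8},
                                       mesh_param=(2, 2))

# Option 2: declare the sizes in the DeepSpeed config instead
engine, _, _, _ = deepspeed.initialize(model=model,
                                       config={
                                           "train_batch_size": 8,
                                           "data_parallel_size": 2,
                                           "sequence_parallel_size": 2
                                       })
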
15 changes: 15 additions & 0 deletions deepspeed/comm/comm.py
100644 → 100755
@@ -600,6 +600,21 @@ def get_all_ranks_from_group(group=None):
return group_ranks


def initialize_mesh_device(mesh_shape, mesh_dim_names):
global cdb
assert cdb is not None and cdb.is_initialized(
), 'DeepSpeed backend not set, please initialize it using init_process_group()'
mesh_device = None
if hasattr(cdb, 'init_device_mesh'):
utils.logger.info(f"Initializing mesh device with backend {cdb.name} \
with shape {mesh_shape} and dim names {mesh_dim_names}")
mesh_device = cdb.init_device_mesh(mesh_shape, mesh_dim_names)
else:
if get_rank() == 0:
utils.logger.warning_once(f"Backend {cdb.name} does not support mesh device initialization")
return mesh_device


# Main DeepSpeed Comms. public API.
def init_distributed(dist_backend=None,
auto_mpi_discovery=True,
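A minimal sketch of calling the new comm-level API directly, assuming the process group is already up (e.g. via deepspeed.comm.init_distributed()) and the backend implements init_device_mesh; otherwise the call logs a warning and returns None.

import deepspeed.comm as dist

dist.init_distributed()
mesh = dist.initialize_mesh_device((2, 2), ("data_parallel", "sequence_parallel"))
if mesh is not None:  # None when the backend has no init_device_mesh
    dp_group = mesh.get_group(mesh_dim="data_parallel")
    sp_group = mesh.get_group(mesh_dim="sequence_parallel")
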
7 changes: 7 additions & 0 deletions deepspeed/comm/torch.py
100644 → 100755
@@ -386,6 +386,13 @@ def _reduce_op(self, op):
op = torch.distributed.ReduceOp.BXOR
return op

def init_device_mesh(self, mesh_shape, mesh_dim_names):
if not required_torch_version(min_version=2.2):
raise RuntimeError(f"Current torch version does not have the device mesh "
f"API (torch.__version__: {torch.__version__})")
return torch.distributed.device_mesh.init_device_mesh(
get_accelerator().current_device_name(), mesh_shape, mesh_dim_names=mesh_dim_names)


# This will become a light-weight wrapper around torch.distributed functions
# TODO: create some example to show how this wrapper can help profile communication
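For reference, a sketch of the torch >= 2.2 API this wrapper delegates to; the "cuda" device type stands in for whatever get_accelerator().current_device_name() returns on a given accelerator, and a 4-rank launch (e.g. torchrun) is assumed.

from torch.distributed.device_mesh import init_device_mesh

# Build a 2x2 mesh over 4 ranks with named dimensions
mesh = init_device_mesh("cuda", (2, 2), mesh_dim_names=("data_parallel", "sequence_parallel"))
dp_group = mesh.get_group(mesh_dim="data_parallel")
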
12 changes: 7 additions & 5 deletions deepspeed/runtime/config.py
@@ -705,7 +705,7 @@ def write_config(self, filename):

class DeepSpeedConfig(object):

def __init__(self, config: Union[str, dict], mpu=None):
def __init__(self, config: Union[str, dict], mpu=None, mesh_device=None):
super(DeepSpeedConfig, self).__init__()
if isinstance(config, dict):
self._param_dict = config
@@ -721,14 +721,16 @@ def __init__(self, config: Union[str, dict], mpu=None):
)
try:
self.global_rank = dist.get_rank()
if mpu is None:
self.world_size = dist.get_world_size()
else:
if mpu is not None:
self.world_size = mpu.get_data_parallel_world_size()
elif mesh_device is not None:
self.world_size = dist.get_world_size(mesh_device.get_group(mesh_dim="data_parallel"))
else:
self.world_size = dist.get_world_size()
except:
self.global_rank = 0
self.world_size = 1

logger.info(f"Config mesh_device {mesh_device} world_size = {self.world_size}")
# If elastic-mode enabled, update compute + update _param_dict
self.elasticity_enabled = elasticity_enabled(self._param_dict)
if self.elasticity_enabled:
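A small worked example of what this world-size change means for batch sizing, assuming the usual DeepSpeed relation train_batch_size = train_micro_batch_size_per_gpu * gradient_accumulation_steps * world_size; with a mesh, world_size becomes the data-parallel group size rather than the global rank count.

# 4 ranks arranged as a (data_parallel=2, sequence_parallel=2) mesh
world_size = 2  # dist.get_world_size(mesh_device.get_group(mesh_dim="data_parallel"))
train_batch_size = 8
train_micro_batch_size_per_gpu = 4
gradient_accumulation_steps = train_batch_size // (train_micro_batch_size_per_gpu * world_size)  # == 1
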
9 changes: 9 additions & 0 deletions deepspeed/runtime/engine.py
100644 → 100755
@@ -192,6 +192,7 @@ def __init__(self,
collate_fn=None,
config=None,
config_class=None,
mesh_device=None,
dont_change_device=False):
super(DeepSpeedEngine, self).__init__()
self.dont_change_device = dont_change_device
@@ -231,10 +232,14 @@ def __init__(self,
self._is_gradient_accumulation_boundary = None
self.scale_wrt_gas = None
self.losses = None
self.mesh_device = mesh_device

# for debug purposes - can then debug print: debug_get_module_name(module)
debug_extract_module_and_param_names(model)

if self.mesh_device:
groups.mesh_device = self.mesh_device

self._do_args_sanity_check(args)
self._configure_with_arguments(args, mpu)
self._do_sanity_check()
@@ -574,6 +579,9 @@ def random_ltd_initialize(self):
raise ValueError(f'not yet support')
#self.lr_scheduler = lr_schedules.WarmupLayerTokenDecayLR(self.optimizer, self.random_ltd_scheduler)

def get_sequence_parallel_group(self):
return self.seq_parallel_group

def wall_clock_breakdown(self):
return self._config.wall_clock_breakdown

@@ -1143,6 +1151,7 @@ def _configure_distributed_model(self, model):
self.sequence_parallel_size = groups._get_sequence_parallel_world_size()
if self.sequence_parallel_size > 1:
self.communication_data_type = self._config.seq_parallel_communication_data_type
self.seq_parallel_group = groups._get_sequence_parallel_group()

if not (self.amp_enabled() or is_zero_init_model):
self._broadcast_model()
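A sketch of how a training script might consume the new engine accessor; ds_engine and input_ids are assumed to exist already, and the sequence sharding shown is illustrative rather than something the engine performs on its own.

import deepspeed.comm as dist

sp_group = ds_engine.get_sequence_parallel_group()
sp_rank = dist.get_rank(group=sp_group)
sp_size = dist.get_world_size(group=sp_group)

# Illustrative: give each sequence-parallel rank a contiguous slice of the sequence
seq_len = input_ids.shape[1]
shard = seq_len // sp_size
local_input_ids = input_ids[:, sp_rank * shard:(sp_rank + 1) * shard]
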
60 changes: 60 additions & 0 deletions deepspeed/sequence/cross_entropy.py
@@ -0,0 +1,60 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch

import deepspeed.comm as dist


class _VocabSequenceParallelCrossEntropy(torch.autograd.Function):

    @staticmethod
    def forward(ctx, vocab_seq_parallel_logits, target, sp_group):
        # vocab_seq_parallel_logits: [S/P, B, V]
        # target: [S/P, B]
        # return: [S, B]

        # Need softmax for backward
        softmax = torch.nn.functional.softmax(vocab_seq_parallel_logits, dim=-1)
        ctx.vocab_size = vocab_seq_parallel_logits.size(2)
        loss = torch.nn.functional.nll_loss(softmax.log().view(-1, ctx.vocab_size), target.view(-1), reduction='none')

        sp_world_size = dist.get_world_size(sp_group)
        sp_rank = dist.get_rank(sp_group)
        ctx.sp_world_size = sp_world_size
        ctx.sp_rank = sp_rank
        ctx.seqlen = vocab_seq_parallel_logits.size(0) * sp_world_size
        batch_size = vocab_seq_parallel_logits.size(1)

        loss_all = torch.empty(ctx.seqlen,
                               batch_size,
                               dtype=vocab_seq_parallel_logits.dtype,
                               device=vocab_seq_parallel_logits.device)
        dist.all_gather_into_tensor(loss_all, loss, group=sp_group)

        ctx.save_for_backward(softmax, target)

        return loss_all

    @staticmethod
    def backward(ctx, grad_output):
        softmax, target = ctx.saved_tensors

        step_seqlen = ctx.seqlen // ctx.sp_world_size
        sp_rank = ctx.sp_rank
        grad_output_part = grad_output[step_seqlen * sp_rank:step_seqlen * (sp_rank + 1), :]

        grad_input = softmax
        grad_2d = grad_input.view(-1, ctx.vocab_size)
        arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=grad_2d.device)

        grad_2d[arange_1d, target.view(-1)] -= 1
        grad_input.mul_(grad_output_part.unsqueeze(dim=-1))

        return grad_input, None, None, None


def vocab_sequence_parallel_cross_entropy(vocab_parallel_logits, target, sp_group):
    return _VocabSequenceParallelCrossEntropy.apply(vocab_parallel_logits, target, sp_group)
18 changes: 15 additions & 3 deletions deepspeed/utils/groups.py
100644 → 100755
@@ -30,6 +30,7 @@
from deepspeed.utils.bwc import bwc_tensor_model_parallel_world_size, bwc_pipeline_parallel_world_size
from deepspeed.utils.exceptions import DeprecatedException
from deepspeed.accelerator import get_accelerator

# Expert parallel group that the current rank belongs to.
_EXPERT_PARALLEL_GROUP = {}
# Expert data parallel group that the current rank belongs to.
@@ -47,6 +48,8 @@

_DATA_PARALLEL_GROUP = None

mesh_device = None


# Deprecated groups initialize function.
def initialize(ep_size=1, mpu=None):
@@ -398,8 +401,11 @@ def _get_data_parallel_group():
"""Get the data parallel group the caller rank belongs to."""
assert dist.is_initialized(), 'dist is not initialized'
global mpu
if mesh_device is not None:
return mesh_device.get_group(mesh_dim="data_parallel")
if mpu is not None:
return mpu.get_data_parallel_group()

# Return the clone of dist world group
return _clone_world_group()

@@ -442,6 +448,8 @@ def _get_expert_data_parallel_rank(group_name):

def _get_data_parallel_world_size():
"""Return world size for the data parallel group."""
if mesh_device is not None:
return dist.get_world_size(mesh_device.get_group(mesh_dim="data_parallel"))
global mpu
if mpu is not None:
return mpu.get_data_parallel_world_size()
@@ -464,6 +472,8 @@ def _get_data_parallel_rank():
def _get_sequence_parallel_world_size():
"""Return world size for the model parallel group."""
global mpu
if mesh_device is not None:
return dist.get_world_size(mesh_device.get_group(mesh_dim="sequence_parallel"))
if mpu is not None and hasattr(mpu, 'get_sequence_parallel_world_size'):
return mpu.get_sequence_parallel_world_size()
return 1
@@ -479,9 +489,11 @@ def _get_sequence_parallel_rank():

def _get_sequence_parallel_group():
global mpu
if mpu is not None and hasattr(mpu, 'get_sequence_parallel_group'):
return mpu.get_sequence_parallel_group()
return None
if mpu is None or not hasattr(mpu, 'get_sequence_parallel_group'):
if mesh_device is None:
raise KeyError("No sequence parallel group found")
return mesh_device.get_group(mesh_dim="sequence_parallel")
return mpu.get_sequence_parallel_group()


def _get_sequence_data_parallel_world_size():
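A minimal sketch of the new resolution order in these helpers once the engine has set groups.mesh_device: the mesh is consulted first, then an mpu, then the defaults. The helpers are internal (underscore-prefixed) and the 2x2 sizes are hypothetical; this is shown only to illustrate the fallback.

from deepspeed.utils import groups

# After deepspeed.initialize(..., mesh_param=(2, 2)) on 4 ranks:
assert groups._get_data_parallel_world_size() == 2
assert groups._get_sequence_parallel_world_size() == 2
sp_group = groups._get_sequence_parallel_group()  # resolved from the mesh, no mpu required
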
68 changes: 68 additions & 0 deletions tests/unit/sequence_parallelism/test_ulysses.py
@@ -0,0 +1,68 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import pytest
import torch
import torch.distributed as dist
from deepspeed import initialize
from transformers import AutoModel
from unit.common import DistributedTest
from deepspeed.sequence.layer import _SeqAllToAll


# Use of torch.distributed mesh device to create data and sequence parallel group
class TestUlyssesUtils(DistributedTest):
    world_size = 4

    def test_mesh_device_creation(self) -> None:
        model = AutoModel.from_pretrained('bert-base-uncased')
        sp_size = 2
        dp_size = 2
        ds_engine, _, _, _ = initialize(model=model,
                                        config_params={
                                            "train_batch_size": 8,
                                            "data_parallel_size": dp_size,
                                            "sequence_parallel_size": sp_size
                                        })
        assert ds_engine.seq_parallel_group is not None
        assert ds_engine.data_parallel_group is not None
        assert dist.get_world_size(group=ds_engine.seq_parallel_group) == sp_size
        assert dist.get_world_size(group=ds_engine.data_parallel_group) == dp_size
        assert dist.get_world_size() == sp_size * dp_size


# Sweep b,s,h,d to test all2all consistency
@pytest.mark.parametrize("d0", [2, 4])  # batch or sequence dimension
@pytest.mark.parametrize("d1", [4, 8])  # batch or sequence dimension
@pytest.mark.parametrize("num_heads", [4, 8])
@pytest.mark.parametrize("head_dim", [16, 32])
class TestUlyssesAll2All(DistributedTest):
    world_size = 4

    def test_alltoall_output_consistency(self, d0: int, d1: int, head_dim: int, num_heads: int) -> None:
        model = AutoModel.from_pretrained('bert-base-uncased')
        ds_engine, _, _, _ = initialize(model=model,
                                        config_params={"train_batch_size": 8},
                                        mesh_param=(2, 2))
        # 4D tensor: b,s,h,d or s,b,h,d
        input_tensor = torch.randn(d0, d1, num_heads, head_dim, device=ds_engine.device)
        scatter_idx = 2
        batch_dim_idx = 0
        outputs = []
        seq_dims = [0]  # seq first API
        # TODO: Add support for batch first (that is, seq_dims=[0, 1]) after PR for bs>1 issue with batch first is fixed
        # See discussion in: https://github.com/microsoft/DeepSpeed/issues/5808
        for seq_dim in seq_dims:
            gather_idx = seq_dim
            # first all2all: sequence parallel to head parallel
            s2h_tensor = _SeqAllToAll.apply(ds_engine.seq_parallel_group, input_tensor, scatter_idx, gather_idx,
                                            batch_dim_idx)

            # No op
            # second all2all: head parallel to sequence parallel
            h2s_tensor = _SeqAllToAll.apply(ds_engine.seq_parallel_group, s2h_tensor, gather_idx, scatter_idx,
                                            batch_dim_idx)
            print(f'[{dist.get_rank()}] s={seq_dim} input: {input_tensor.shape} s2h: {s2h_tensor.shape} h2s_tensor: {h2s_tensor.shape}')
            outputs.append(h2s_tensor)

        # Check outputs are the same as input
        for i in range(1, len(outputs)):
            assert torch.allclose(input_tensor, outputs[i]), f"Outputs differ for sequence dim {seq_dims[i]}"
