Long sequence parallelism (Ulysses) integration with HuggingFace (#5774)
This PR enhances the capabilities of [DeepSpeed long sequence (context) parallelism (aka DS Ulysses)](https://dl.acm.org/doi/10.1145/3662158.3662806) with support for HuggingFace models (and, by extension, other frameworks). With the HF integration, users can apply sequence parallelism to model pre-/mid-/post-training, finetuning, etc. Usage requires both _torch >= 2.2.2 and flash-attention_. ZeRO-1 and ZeRO-2 are supported; ZeRO-3 and SDPA support are in progress. The corresponding PR in HF is [PR32305](huggingface/transformers#32305). --------- Co-authored-by: Logan Adams <[email protected]>
1 parent b65ea50, commit 8b191d7. Showing 8 changed files with 212 additions and 11 deletions.
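For context, a minimal usage sketch (not part of this commit) showing how the new config keys are exercised. It mirrors the unit test further below; the model name, batch size, and parallel sizes are illustrative, and a world size of 4 ranks is assumed.

import deepspeed
from transformers import AutoModel

model = AutoModel.from_pretrained('bert-base-uncased')
ds_engine, _, _, _ = deepspeed.initialize(
    model=model,
    config_params={
        "train_batch_size": 8,
        "data_parallel_size": 2,       # replicate the model across 2 data-parallel ranks
        "sequence_parallel_size": 2,   # shard each sequence across 2 Ulysses ranks
    },
)
sp_group = ds_engine.seq_parallel_group    # sequence-parallel process group
dp_group = ds_engine.data_parallel_group   # data-parallel process group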
@@ -0,0 +1,60 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch

import deepspeed.comm as dist


class _VocabSequenceParallelCrossEntropy(torch.autograd.Function):

    @staticmethod
    def forward(ctx, vocab_seq_parallel_logits, target, sp_group):
        # vocab_seq_parallel_logits: [S/P, B, V]
        # target: [S/P, B]
        # return: [S, B]

        # Need softmax for backward
        softmax = torch.nn.functional.softmax(vocab_seq_parallel_logits, dim=-1)
        ctx.vocab_size = vocab_seq_parallel_logits.size(2)
        loss = torch.nn.functional.nll_loss(softmax.log().view(-1, ctx.vocab_size), target.view(-1), reduction='none')

        sp_world_size = dist.get_world_size(sp_group)
        sp_rank = dist.get_rank(sp_group)
        ctx.sp_world_size = sp_world_size
        ctx.sp_rank = sp_rank
        ctx.seqlen = vocab_seq_parallel_logits.size(0) * sp_world_size
        batch_size = vocab_seq_parallel_logits.size(1)

        # Gather the per-token losses from all sequence-parallel ranks so every rank returns the full [S, B] loss
        loss_all = torch.empty(ctx.seqlen,
                               batch_size,
                               dtype=vocab_seq_parallel_logits.dtype,
                               device=vocab_seq_parallel_logits.device)
        dist.all_gather_into_tensor(loss_all, loss, group=sp_group)

        ctx.save_for_backward(softmax, target)

        return loss_all

    @staticmethod
    def backward(ctx, grad_output):
        softmax, target = ctx.saved_tensors

        step_seqlen = ctx.seqlen // ctx.sp_world_size
        sp_rank = ctx.sp_rank
        grad_output_part = grad_output[step_seqlen * sp_rank:step_seqlen * (sp_rank + 1), :]

        # d(NLL(log softmax(x)))/dx = softmax(x) - one_hot(target); reuse the saved softmax in place
        grad_input = softmax
        grad_2d = grad_input.view(-1, ctx.vocab_size)
        arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=grad_2d.device)

        grad_2d[arange_1d, target.view(-1)] -= 1
        # Scale by the incoming gradient for this rank's slice of the gathered loss
        grad_input.mul_(grad_output_part.unsqueeze(dim=-1))

        return grad_input, None, None, None


def vocab_sequence_parallel_cross_entropy(vocab_parallel_logits, target, sp_group):
    return _VocabSequenceParallelCrossEntropy.apply(vocab_parallel_logits, target, sp_group)
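A hedged usage sketch for the helper above: each of the P sequence-parallel ranks holds a [S/P, B, V] shard of the logits and the matching [S/P, B] targets, and the returned loss covers the full sequence [S, B]. The shapes are illustrative, and sp_group is assumed to be the sequence-parallel process group (e.g. ds_engine.seq_parallel_group from the initialization sketch near the top), so this only runs under that distributed setup.

import torch

S_local, B, V = 128, 4, 32000                                             # per-rank sequence length, batch, vocab
logits = torch.randn(S_local, B, V, device='cuda', requires_grad=True)    # [S/P, B, V]
target = torch.randint(0, V, (S_local, B), device='cuda')                 # [S/P, B]

loss = vocab_sequence_parallel_cross_entropy(logits, target, sp_group)    # [S, B], full sequence on every rank
loss.mean().backward()    # gradients are produced only for the local [S/P, B, V] logits shard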
@@ -0,0 +1,77 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import pytest
import torch
import deepspeed.comm as dist
from deepspeed import initialize
from transformers import AutoModel
from unit.common import DistributedTest
from deepspeed.sequence.layer import _SeqAllToAll
from unit.util import skip_on_arch


# Use mesh device to create the data and sequence parallel groups
class TestUlyssesUtils(DistributedTest):
    world_size = 4

    def test_mesh_device_creation(self) -> None:
        skip_on_arch(min_arch=8)
        model = AutoModel.from_pretrained('bert-base-uncased')
        sp_size = 2
        dp_size = 2
        ds_engine, _, _, _ = initialize(
            model=model,
            config_params={
                "train_batch_size": 8,
                "data_parallel_size": dp_size,
                "sequence_parallel_size": sp_size
            },
        )
        assert ds_engine.seq_parallel_group is not None
        assert ds_engine.data_parallel_group is not None
        assert dist.get_world_size(group=ds_engine.seq_parallel_group) == sp_size
        assert dist.get_world_size(group=ds_engine.data_parallel_group) == dp_size
        assert dist.get_world_size() == sp_size * dp_size


# Sweep b,s,h,d to test all2all consistency
@pytest.mark.parametrize("d0", [2, 4])  # batch or sequence dimension
@pytest.mark.parametrize("d1", [4, 8])  # batch or sequence dimension
@pytest.mark.parametrize("num_heads", [4, 8])
@pytest.mark.parametrize("head_dim", [16, 32])
class TestUlyssesAll2All(DistributedTest):
    world_size = 4

    def test_alltoall_output_consistency(self, d0: int, d1: int, head_dim: int, num_heads: int) -> None:
        skip_on_arch(min_arch=8)
        model = AutoModel.from_pretrained('bert-base-uncased')
        ds_engine, _, _, _ = initialize(model=model, config_params={"train_batch_size": 8}, mesh_param=(2, 2))
        # 4D tensor: b,s,h,d or s,b,h,d
        input_tensor = torch.randn(d0, d1, num_heads, head_dim, device=ds_engine.device)
        scatter_idx = 2
        batch_dim_idx = 0
        outputs = []
        seq_dims = [0]  # seq-first API
        # TODO: Add support for batch first (i.e. seq_dims=[0, 1]) once the PR fixing the bs>1 issue with batch first lands
        ## See discussion in: https://github.com/microsoft/DeepSpeed/issues/5808
        for seq_dim in seq_dims:
            gather_idx = seq_dim
            # first all2all: sequence parallel to head parallel
            s2h_tensor = _SeqAllToAll.apply(ds_engine.seq_parallel_group, input_tensor, scatter_idx, gather_idx,
                                            batch_dim_idx)

            # second all2all: head parallel back to sequence parallel; the round trip should be a no-op
            h2s_tensor = _SeqAllToAll.apply(ds_engine.seq_parallel_group, s2h_tensor, gather_idx, scatter_idx,
                                            batch_dim_idx)
            print(
                f'[{dist.get_rank()}] s={seq_dim} input: {input_tensor.shape} s2h: {s2h_tensor.shape} h2s_tensor: {h2s_tensor.shape}'
            )
            outputs.append(h2s_tensor)

        # Check outputs are the same as the input
        for i in range(len(outputs)):
            assert torch.allclose(input_tensor, outputs[i]), f"Outputs differ for sequence dim {seq_dims[i]}"
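To make the layout change being tested easier to see, here is a single-process sketch (an illustration, not part of the commit and not the actual communication path) of the data each rank holds before and after the two all-to-alls for a sequence-parallel degree of 2, using plain chunk/cat as a stand-in for _SeqAllToAll:

import torch

P = 2                                      # assumed sequence-parallel world size
d0, d1, num_heads, head_dim = 2, 4, 4, 16  # seq shard, batch, heads, head_dim (seq-first layout)

# Each rank starts with a sequence shard covering all heads: [d0, d1, num_heads, head_dim]
shards = [torch.randn(d0, d1, num_heads, head_dim) for _ in range(P)]

# After the first all-to-all (sequence parallel -> head parallel), rank r holds the
# full sequence but only its slice of the heads: [d0 * P, d1, num_heads // P, head_dim]
full_seq = torch.cat(shards, dim=0)
head_shards = list(torch.chunk(full_seq, P, dim=2))

# The second all-to-all (head parallel -> sequence parallel) undoes the transform,
# so every rank gets its original shard back.
full_heads = torch.cat(head_shards, dim=2)
restored = list(torch.chunk(full_heads, P, dim=0))

for original, back in zip(shards, restored):
    assert torch.equal(original, back)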