Commit: sequence parallel for uneven heads

inkcherry committed Aug 21, 2024
1 parent 8b191d7 commit 1e297d3
Showing 4 changed files with 109 additions and 1 deletion.
4 changes: 3 additions & 1 deletion deepspeed/module_inject/tp_shard.py
@@ -24,7 +24,9 @@ def set_n_embd(num):
 
 def get_num_kv_heads():
     global num_kv_heads
-    return num_kv_heads
+    if 'num_kv_heads' in globals():
+        return num_kv_heads
+    return None
 
 
 def get_num_attention_heads():
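
With this change, get_num_kv_heads acts as a sentinel: it returns None until set_num_kv_heads has recorded a head count, instead of raising NameError on the unset module-level global. A minimal standalone sketch of the pattern (illustrative only, not the tp_shard module itself; the body of set_num_kv_heads shown here is an assumption):

def get_num_kv_heads():
    # Return None instead of raising NameError when the global was never set.
    if 'num_kv_heads' in globals():
        return num_kv_heads
    return None

def set_num_kv_heads(num):
    global num_kv_heads
    num_kv_heads = num

assert get_num_kv_heads() is None  # not configured yet -> even path in single_all_to_all
set_num_kv_heads(7)
assert get_num_kv_heads() == 7     # configured -> the uneven path is taken on later calls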
63 changes: 63 additions & 0 deletions deepspeed/sequence/layer.py
@@ -10,6 +10,8 @@

import deepspeed.comm as dist
from deepspeed.accelerator import get_accelerator
from deepspeed.module_inject.tp_shard import get_shard_size_list, set_num_kv_heads, get_num_kv_heads
from deepspeed.utils import groups


def post_all2all(scatter_idx, batch_dim_idx, seq_world_size, bs, seq_len, num_head, head_dim):
@@ -38,8 +40,69 @@ def post_func(input):
    return post_func


def uneven_heads_all2all(input, scatter_idx, gather_idx, group):
    seq_world_size = dist.get_world_size(group)
    inp_shape = list(input.shape)
    if not (scatter_idx < 2):
        # sequence parallel -> head parallel: scatter the (uneven) head dim, gather the sequence dim
        input_splits = get_shard_size_list(input.shape[scatter_idx], seq_world_size)
        input = input.transpose(0, scatter_idx).contiguous()
        local_heads = input_splits[groups._get_sequence_parallel_rank()]
        output_splits = [local_heads] * seq_world_size

        output_shape = [seq_world_size * local_heads] + list(input.shape[1:])
        output = torch.empty(output_shape, device=input.device, dtype=input.dtype)

        dist.all_to_all_single(output, input, output_split_sizes=output_splits,
                               input_split_sizes=input_splits, group=group)

        # [seq_ws * local_heads, ...] to [seq_ws, local_heads, ...]
        output = output.view(seq_world_size, local_heads, *output.shape[1:])
        # [seq_ws, local_heads, b, seq_len, ...] to [seq_ws, seq_len, b, local_heads, ...]
        output = output.transpose(1, 3).contiguous()
        # [seq_ws * local_seq_len, b, local_heads, ...]
        output = output.view(inp_shape[gather_idx] * seq_world_size, *output.shape[2:]).contiguous()
    if scatter_idx < 2:
        # head parallel -> sequence parallel: scatter the sequence dim, gather the (uneven) head dim
        input = input.view(input.shape[0], input.shape[1], -1)
        input = input.reshape(input.shape[0] * input.shape[2], input.shape[1])
        local_seq_len_with_heads = int(input.shape[0] / seq_world_size)
        input_splits = [local_seq_len_with_heads] * seq_world_size
        output_splits = get_shard_size_list(get_num_kv_heads(), seq_world_size)

        coeff = int(local_seq_len_with_heads / output_splits[groups._get_sequence_parallel_rank()])
        # heads_scale_coeff = total_heads / local_heads for this rank (uneven split)
        heads_scale_coeff = get_num_kv_heads() / get_shard_size_list(
            get_num_kv_heads(), seq_world_size)[groups._get_sequence_parallel_rank()]
        output_splits = [i * coeff for i in output_splits]
        output_d1_size = int(heads_scale_coeff * local_seq_len_with_heads)
        total_seq_len = int(inp_shape[gather_idx] * heads_scale_coeff)
        output = torch.empty(output_d1_size, input.shape[1], device=input.device, dtype=input.dtype)

        dist.all_to_all_single(output, input, output_split_sizes=output_splits,
                               input_split_sizes=input_splits, group=group)

        inp_shape[scatter_idx] = inp_shape[scatter_idx] // seq_world_size
        output2_shape = inp_shape[:gather_idx] + [total_seq_len] + inp_shape[gather_idx + 1:]
        output = output.reshape(output2_shape)
    return output
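
The core primitive in both branches is torch.distributed.all_to_all_single with per-rank split sizes along dim 0, which is what lets the head dimension be divided unevenly across ranks. A minimal standalone sketch of that pattern for the scatter branch (illustrative only; the head_splits values and group setup are assumptions, not taken from this commit):

import torch
import torch.distributed as dist

def demo_uneven_all2all(local_chunk, head_splits, group=None):
    # local_chunk: [total_heads, ...] with the head dim already moved to dim 0,
    # as uneven_heads_all2all does via input.transpose(0, scatter_idx).
    world_size = dist.get_world_size(group)
    rank = dist.get_rank(group)
    input_splits = head_splits                        # send head_splits[r] head-rows to rank r
    output_splits = [head_splits[rank]] * world_size  # receive this rank's share from every rank
    out = torch.empty(sum(output_splits), *local_chunk.shape[1:],
                      device=local_chunk.device, dtype=local_chunk.dtype)
    dist.all_to_all_single(out, local_chunk.contiguous(),
                           output_split_sizes=output_splits,
                           input_split_sizes=input_splits,
                           group=group)
    return out  # e.g. head_splits=[2, 1] on 2 ranks: rank 0 ends with 4 head-rows, rank 1 with 2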


def single_all_to_all(input, scatter_idx, gather_idx, batch_dim_idx, group, async_op=False, handle=None, type=None):
    seq_world_size = dist.get_world_size(group)

    num_heads = input.shape[2]

    if get_num_kv_heads() is not None or num_heads % seq_world_size != 0:
        # we assume here that the number of q heads is consistent with the kv heads;
        # otherwise additional logic would be required
        if get_num_kv_heads() is None:
            # record the total head count on the first call;
            # afterwards ``get_num_kv_heads() is not None`` routes every call to the uneven path
            set_num_kv_heads(num_heads)
        assert async_op == False, "uneven head sp does not support async op"
        return uneven_heads_all2all(input, scatter_idx, gather_idx, group)

    if batch_dim_idx == 0:
        # b, s, n, h
        if scatter_idx < 2:
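
single_all_to_all takes the uneven path whenever the head count does not divide evenly by the sequence-parallel world size, or once a head count has been recorded. The per-rank shares come from get_shard_size_list in deepspeed/module_inject/tp_shard.py; a rough stand-in for intuition (this reimplementation only assumes leftover heads go to the lowest-indexed ranks, which may differ from the real helper):

def shard_sizes(num_heads, world_size):
    # e.g. 7 heads on 4 ranks -> [2, 2, 2, 1]; 3 heads on 2 ranks -> [2, 1]
    base, rem = divmod(num_heads, world_size)
    return [base + (1 if r < rem else 0) for r in range(world_size)]

# Dispatch mirrors the check in single_all_to_all:
for num_heads, ws in [(8, 4), (7, 4)]:
    path = "even" if num_heads % ws == 0 else "uneven"
    print(num_heads, ws, shard_sizes(num_heads, ws), path)  # 8 -> even; 7 -> uneven, split [2, 2, 2, 1]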
2 changes: 2 additions & 0 deletions deepspeed/utils/groups.py
@@ -484,6 +484,8 @@ def _get_sequence_parallel_rank():
    global mpu
    if mpu is not None and hasattr(mpu, 'get_sequence_parallel_rank'):
        return mpu.get_sequence_parallel_rank()
    if mesh_device is not None:
        return dist.get_rank(mesh_device.get_group(mesh_dim="sequence_parallel"))
    return 0


41 changes: 41 additions & 0 deletions tests/unit/sequence_parallelism/test_ulysses.py
@@ -11,6 +11,7 @@
from unit.common import DistributedTest
from deepspeed.sequence.layer import _SeqAllToAll
from unit.util import skip_on_arch
from unit.simple_model import *


#Use mesh device to create data and sequence parallel group
@@ -75,3 +76,43 @@ def test_alltoall_output_consistency(self, d0: int, d1: int, head_dim: int, num_
        # Check outputs are the same as input
        for i in range(1, len(outputs)):
            assert torch.allclose(input_tensor, outputs[i]), f"Outputs differ for sequence dim {seq_dims[i]}"


@pytest.mark.parametrize("d0", [2, 4]) #batch or sequence dimension
@pytest.mark.parametrize("d1", [4, 8]) #batch or sequence dimension
@pytest.mark.parametrize("num_heads", [3, 7])
@pytest.mark.parametrize("head_dim", [16, 32])
class TestUlyssesAll2All_odd(DistributedTest):
world_size = 4

def test_alltoall_output_consistency(self, d0: int, d1: int, head_dim: int, num_heads: int) -> None:
skip_on_arch(min_arch=8)
hidden_dim = 10
model = SimpleModel(hidden_dim)
ds_engine, _, _, _ = initialize(model=model, config_params={"train_batch_size": 8}, mesh_param=(2, 2))
#4D tensor : b,s,h,d or s,b,h,d
input_tensor = torch.randn(d0, d1, num_heads, head_dim, device=ds_engine.device)
scatter_idx = 2
batch_dim_idx = 0
outputs = []
seq_dims = [0] #seq first API
#TODO: Add support for batch first (that seq_dims=[0,1]) after PR for bs>1 issue with batch first is fixed
## See discussion in : https://github.com/microsoft/DeepSpeed/issues/5808
for seq_dim in seq_dims:
gather_idx = seq_dim
#first all2all: sequence parallel to head parallel
s2h_tensor = _SeqAllToAll.apply(ds_engine.seq_parallel_group, input_tensor, scatter_idx, gather_idx,
batch_dim_idx)

#No op
# second all2all: head parallel to sequence parallel
h2s_tensor = _SeqAllToAll.apply(ds_engine.seq_parallel_group, s2h_tensor, gather_idx, scatter_idx,
batch_dim_idx)
print(
f'[{dist.get_rank()}] s={seq_dim} input: {input_tensor.shape} s2h: {s2h_tensor.shape} h2s_tensor: {h2s_tensor.shape}'
)
outputs.append(h2s_tensor)

# Check outputs are the same as input
for i in range(1, len(outputs)):
assert torch.allclose(input_tensor, outputs[i]), f"Outputs differ for sequence dim {seq_dims[i]}"
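
For concreteness: with world_size = 4 and mesh_param=(2, 2), the sequence-parallel group has 2 ranks (assuming the second mesh dimension is the sequence-parallel one), so both head counts here exercise the uneven branch. Under the remainder-to-lowest-ranks split assumed in the earlier sketch:

for num_heads in (3, 7):
    base, rem = divmod(num_heads, 2)
    print(num_heads, [base + (1 if r < rem else 0) for r in range(2)])
# 3 heads -> [2, 1], 7 heads -> [4, 3]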
