From 1e297d3eb30b3bd7cf6f3910dbaebc9ba2ded243 Mon Sep 17 00:00:00 2001
From: inkcherry
Date: Wed, 21 Aug 2024 11:46:09 +0000
Subject: [PATCH] sequence parallel for uneven heads

---
 deepspeed/module_inject/tp_shard.py            |  4 +-
 deepspeed/sequence/layer.py                    | 63 +++++++++++++++++++
 deepspeed/utils/groups.py                      |  2 +
 .../unit/sequence_parallelism/test_ulysses.py  | 41 ++++++++++++
 4 files changed, 109 insertions(+), 1 deletion(-)

diff --git a/deepspeed/module_inject/tp_shard.py b/deepspeed/module_inject/tp_shard.py
index 6758c7a657f63..57be0c7938560 100644
--- a/deepspeed/module_inject/tp_shard.py
+++ b/deepspeed/module_inject/tp_shard.py
@@ -24,7 +24,9 @@ def set_n_embd(num):
 
 def get_num_kv_heads():
     global num_kv_heads
-    return num_kv_heads
+    if 'num_kv_heads' in globals():
+        return num_kv_heads
+    return None
 
 
 def get_num_attention_heads():
diff --git a/deepspeed/sequence/layer.py b/deepspeed/sequence/layer.py
index e809fe1118b51..a1de668feccae 100644
--- a/deepspeed/sequence/layer.py
+++ b/deepspeed/sequence/layer.py
@@ -10,6 +10,8 @@
 
 import deepspeed.comm as dist
 from deepspeed.accelerator import get_accelerator
+from deepspeed.module_inject.tp_shard import get_shard_size_list, set_num_kv_heads, get_num_kv_heads
+from deepspeed.utils import groups
 
 
 def post_all2all(scatter_idx, batch_dim_idx, seq_world_size, bs, seq_len, num_head, head_dim):
@@ -38,8 +40,69 @@ def post_func(input):
     return post_func
 
 
+def uneven_heads_all2all(input, scatter_idx, gather_idx, group):
+    seq_world_size = dist.get_world_size(group)
+    inp_shape = list(input.shape)
+    if not (scatter_idx < 2):
+        input_splits = get_shard_size_list(input.shape[scatter_idx], seq_world_size)
+        input = input.transpose(0, scatter_idx).contiguous()
+        local_heads = input_splits[groups._get_sequence_parallel_rank()]
+        output_splits = [local_heads] * seq_world_size
+
+        output_shape = [seq_world_size * local_heads] + list(input.shape[1:])
+        output = torch.empty(output_shape, device=input.device, dtype=input.dtype)
+
+        dist.all_to_all_single(output, input, output_split_sizes=output_splits,
+                               input_split_sizes=input_splits, group=group)
+
+        # [seq_ws * local_heads, ...] -> [seq_ws, local_heads, ...]
+        output = output.view(seq_world_size, local_heads, *output.shape[1:])
+        # [seq_ws, local_heads, b, seq_len, ...] -> [seq_ws, seq_len, b, local_heads, ...]
+        output = output.transpose(1, 3).contiguous()
+        # [seq_ws * local_seq_len, b, local_heads, ...]
+        output = output.view(inp_shape[gather_idx] * seq_world_size, *output.shape[2:]).contiguous()
+    if scatter_idx < 2:
+        input = input.view(input.shape[0], input.shape[1], -1)
+        input = input.reshape(input.shape[0] * input.shape[2], input.shape[1])
+        local_seq_len_with_heads = int(input.shape[0] / seq_world_size)
+        input_splits = [local_seq_len_with_heads] * seq_world_size
+        output_splits = get_shard_size_list(get_num_kv_heads(), seq_world_size)
+
+        coeff = int(local_seq_len_with_heads / output_splits[groups._get_sequence_parallel_rank()])
+        # scaling factor for the uneven split across seq_world_size: total_heads / local_heads
+        heads_scale_coeff = get_num_kv_heads() / get_shard_size_list(
+            get_num_kv_heads(), seq_world_size)[groups._get_sequence_parallel_rank()]
+        output_splits = [i * coeff for i in output_splits]
+        output_d1_size = int(heads_scale_coeff * local_seq_len_with_heads)
+        total_seq_len = int(inp_shape[gather_idx] * heads_scale_coeff)
+        output = torch.empty(output_d1_size, input.shape[1], device=input.device, dtype=input.dtype)
+
+        dist.all_to_all_single(output, input, output_split_sizes=output_splits,
+                               input_split_sizes=input_splits, group=group)
+
+        inp_shape[scatter_idx] = inp_shape[scatter_idx] // seq_world_size
+        output2_shape = inp_shape[:gather_idx] + \
+            [total_seq_len] + \
+            inp_shape[gather_idx + 1:]
+        output = output.reshape(output2_shape)
+    return output
+
+
 def single_all_to_all(input, scatter_idx, gather_idx, batch_dim_idx, group, async_op=False, handle=None, type=None):
     seq_world_size = dist.get_world_size(group)
+
+    num_heads = input.shape[2]
+
+    if get_num_kv_heads() is not None or num_heads % seq_world_size != 0:
+        # We assume here that the number of query heads matches the number of KV heads;
+        # otherwise additional logic would be required.
+        if get_num_kv_heads() is None:
+            # Record the total head count on the first call; afterwards
+            # ``get_num_kv_heads() is not None`` re-enters the uneven path.
+            set_num_kv_heads(num_heads)
+        assert not async_op, "uneven-head sequence parallelism does not support async op"
+        return uneven_heads_all2all(input, scatter_idx, gather_idx, group)
+
     if batch_dim_idx == 0:
         # b, s, n, h
         if scatter_idx < 2:
diff --git a/deepspeed/utils/groups.py b/deepspeed/utils/groups.py
index 9dd288ef46dbb..e9550a0ec25ad 100755
--- a/deepspeed/utils/groups.py
+++ b/deepspeed/utils/groups.py
@@ -484,6 +484,8 @@ def _get_sequence_parallel_rank():
     global mpu
     if mpu is not None and hasattr(mpu, 'get_sequence_parallel_rank'):
         return mpu.get_sequence_parallel_rank()
+    if mesh_device is not None:
+        return dist.get_rank(mesh_device.get_group(mesh_dim="sequence_parallel"))
     return 0
 
 
diff --git a/tests/unit/sequence_parallelism/test_ulysses.py b/tests/unit/sequence_parallelism/test_ulysses.py
index 915c89e0b00a9..4a6d11c1c6294 100644
--- a/tests/unit/sequence_parallelism/test_ulysses.py
+++ b/tests/unit/sequence_parallelism/test_ulysses.py
@@ -11,6 +11,7 @@
 from unit.common import DistributedTest
 from deepspeed.sequence.layer import _SeqAllToAll
 from unit.util import skip_on_arch
+from unit.simple_model import *
 
 
 #Use mesh device to create data and sequence parallel group
@@ -75,3 +76,43 @@ def test_alltoall_output_consistency(self, d0: int, d1: int, head_dim: int, num_
         # Check outputs are the same as input
         for i in range(1, len(outputs)):
             assert torch.allclose(input_tensor, outputs[i]), f"Outputs differ for sequence dim {seq_dims[i]}"
+
+
+@pytest.mark.parametrize("d0", [2, 4])  # batch or sequence dimension
+@pytest.mark.parametrize("d1", [4, 8])  # batch or sequence dimension
+@pytest.mark.parametrize("num_heads", [3, 7])
+@pytest.mark.parametrize("head_dim", [16, 32])
+class TestUlyssesAll2All_odd(DistributedTest):
+    world_size = 4
+
+    def test_alltoall_output_consistency(self, d0: int, d1: int, head_dim: int, num_heads: int) -> None:
+        skip_on_arch(min_arch=8)
+        hidden_dim = 10
+        model = SimpleModel(hidden_dim)
+        ds_engine, _, _, _ = initialize(model=model, config_params={"train_batch_size": 8}, mesh_param=(2, 2))
+        # 4D tensor: b,s,h,d or s,b,h,d
+        input_tensor = torch.randn(d0, d1, num_heads, head_dim, device=ds_engine.device)
+        scatter_idx = 2
+        batch_dim_idx = 0
+        outputs = []
+        seq_dims = [0]  # seq-first API
+        #TODO: add support for batch-first (i.e. seq_dims = [0, 1]) once the bs > 1 issue with batch-first is fixed
+        ## See discussion in: https://github.com/microsoft/DeepSpeed/issues/5808
+        for seq_dim in seq_dims:
+            gather_idx = seq_dim
+            # first all2all: sequence parallel to head parallel
+            s2h_tensor = _SeqAllToAll.apply(ds_engine.seq_parallel_group, input_tensor, scatter_idx, gather_idx,
+                                            batch_dim_idx)
+
+            # second all2all: head parallel back to sequence parallel;
+            # the round trip should be a no-op and reproduce the input
+            h2s_tensor = _SeqAllToAll.apply(ds_engine.seq_parallel_group, s2h_tensor, gather_idx, scatter_idx,
+                                            batch_dim_idx)
+            print(
+                f'[{dist.get_rank()}] s={seq_dim} input: {input_tensor.shape} s2h: {s2h_tensor.shape} h2s_tensor: {h2s_tensor.shape}'
+            )
+            outputs.append(h2s_tensor)
+
+        # Check outputs are the same as input
+        for i in range(len(outputs)):
+            assert torch.allclose(input_tensor, outputs[i]), f"Outputs differ for sequence dim {seq_dims[i]}"
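
Reviewer note on the shape arithmetic: a minimal sketch of how the uneven head split behaves. The shard_sizes helper below is a stand-in for get_shard_size_list (assumed here to split heads as evenly as possible, with the remainder going to the lower ranks), and the concrete numbers are illustrative only. It mirrors how the patch derives heads_scale_coeff (total_heads / local_heads) and then scales the size of the gathered dimension by the same factor (total_seq_len).

    # Stand-in for deepspeed.module_inject.tp_shard.get_shard_size_list; the real helper
    # may distribute the remainder differently (assumption for illustration only).
    def shard_sizes(total_heads, world_size):
        base, rem = divmod(total_heads, world_size)
        return [base + (1 if rank < rem else 0) for rank in range(world_size)]

    total_heads = 7      # num_heads value taken from the new test
    seq_world_size = 4   # sequence-parallel world size used by the new test
    local_seq_len = 8    # size of the dimension being gathered (illustrative)

    for rank in range(seq_world_size):
        local_heads = shard_sizes(total_heads, seq_world_size)[rank]
        # Same ratio the patch stores in heads_scale_coeff: total heads over this rank's heads.
        heads_scale_coeff = total_heads / local_heads
        # The gather dimension grows by the same factor (total_seq_len in the patch).
        total_seq_len = int(local_seq_len * heads_scale_coeff)
        print(f"rank {rank}: local_heads={local_heads}, "
              f"heads_scale_coeff={heads_scale_coeff}, total_seq_len={total_seq_len}")

With 7 heads over 4 ranks this prints shard sizes [2, 2, 2, 1], so ranks hold different head counts and the per-rank scale factors (3.5 vs. 7.0) differ, which is why the second all-to-all needs per-rank output split sizes rather than a single uniform split.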