[AutoTP] Make AutoTP work when num_heads not divisible by number of workers #4011

Merged
merged 55 commits into master on Oct 25, 2023

Commits (55)
0706acd
allow number of heads not divisible by number of ranks
delock Jul 20, 2023
0bf785f
get num_heads from model config, more robust
delock Jul 21, 2023
72b9e1a
simplify logic where num_head itself is sharded
delock Jul 21, 2023
5ed9a56
name tweaks
delock Jul 21, 2023
73f499d
make code more robust where num_attention_heads may not be defined in…
delock Jul 21, 2023
48322c7
Merge branch 'master' into gma/uneven_heads
delock Jul 21, 2023
f14e290
Merge branch 'master' into gma/uneven_heads
delock Jul 24, 2023
b62317c
Merge branch 'master' into gma/uneven_heads
loadams Jul 24, 2023
12c0628
support num_key_value_heads < num_attention_heads which is used by ll…
delock Jul 25, 2023
8f23d9b
add test for 5 ranks
delock Jul 25, 2023
9c53bd7
change odd rank # to 3 to avoid test skip
delock Jul 25, 2023
413224b
Merge branch 'master' into gma/uneven_heads
tjruwase Jul 25, 2023
78d6667
Merge branch 'master' into gma/uneven_heads
delock Aug 9, 2023
27fde30
add get_shard_size function
delock Aug 9, 2023
8e1fd27
modify sharding mechanism according to latest auto TP
delock Aug 10, 2023
9a6bc12
Merge branch 'master' into gma/uneven_heads
delock Aug 16, 2023
2dac94f
fix accuracy issue
delock Aug 17, 2023
885f6a3
Merge branch 'master' into gma/uneven_heads
delock Aug 17, 2023
7ffd811
Merge branch 'master' into gma/uneven_heads
molly-smith Aug 18, 2023
40659ba
Merge branch 'master' into gma/uneven_heads
tjruwase Aug 22, 2023
71f9f40
fix format
delock Aug 21, 2023
db9db6b
skip tests with fusedqkv
delock Aug 23, 2023
72531c0
Merge branch 'master' into gma/uneven_heads
delock Aug 23, 2023
9d5eae3
remove skip of fusedqkv tests
delock Aug 23, 2023
25e656d
skip test fusedqkv with odd number of ranks
delock Aug 23, 2023
7f6d7f6
support model with n_heads in model_config
delock Aug 24, 2023
e3a5b77
Merge branch 'master' into gma/uneven_heads
molly-smith Aug 24, 2023
c9ec881
Merge branch 'master' into gma/uneven_heads
delock Aug 26, 2023
f5be257
fix TestInjectionPolicy::test[fp32-t5]
delock Aug 27, 2023
b671040
fix uneven_heads on some fusedqkv types (#12)
inkcherry Aug 28, 2023
d59ff22
better fix when activation size cannot be divided by number of heads
delock Aug 30, 2023
6c3c841
Merge branch 'master' into gma/uneven_heads_rebase
delock Aug 30, 2023
58e8b24
Merge branch 'master' into gma/uneven_heads
molly-smith Sep 1, 2023
4c6b7fa
move tp_shard.py under module_inject
delock Sep 6, 2023
18e1c5d
Merge branch 'master' into gma/uneven_heads
delock Sep 6, 2023
8ef01e2
Add get_num_kv_heads in tp_shard.py
delock Sep 7, 2023
9a61fc2
Merge branch 'master' into gma/uneven_heads
delock Sep 11, 2023
74870db
Merge branch 'master' into gma/uneven_heads
delock Sep 13, 2023
115cc20
Merge branch 'master' into gma/uneven_heads
molly-smith Sep 13, 2023
0781c41
Refine according to comments
delock Sep 14, 2023
194337f
remove old comment
mrwyattii Sep 14, 2023
47d84ca
Merge branch 'master' into gma/uneven_heads
delock Sep 18, 2023
369eb3e
Merge branch 'master' into gma/uneven_heads
mrwyattii Sep 19, 2023
567fb9a
fix bug in getting num_kv_heads
delock Sep 20, 2023
47c83ca
Merge branch 'master' into gma/uneven_heads
molly-smith Sep 20, 2023
d194ab0
Merge branch 'master' into gma/uneven_heads
tjruwase Sep 27, 2023
6db5ddd
Merge branch 'master' into gma/uneven_heads
delock Oct 7, 2023
698b62a
Merge branch 'up-master' into gma/uneven_heads
delock Oct 10, 2023
d75149f
support uneven sharding of lm_head tensor parallel
delock Oct 10, 2023
248532d
Merge branch 'master' into gma/uneven_heads
delock Oct 11, 2023
a9056fd
Merge branch 'master' into gma/uneven_heads
delock Oct 11, 2023
81bd29f
Merge branch 'master' into gma/uneven_heads
delock Oct 12, 2023
693a9fe
Merge branch 'master' into gma/uneven_heads
delock Oct 18, 2023
4c45a5b
Merge branch 'master' into gma/uneven_heads
delock Oct 19, 2023
a7513e1
Merge branch 'master' into gma/uneven_heads
delock Oct 24, 2023
32 changes: 22 additions & 10 deletions deepspeed/module_inject/auto_tp.py
@@ -14,6 +14,7 @@
from .layers import LinearAllreduce, LinearLayer, LmHeadLinearAllreduce
from deepspeed.accelerator import get_accelerator
from .fusedqkv_utils import require_tp_fused_qkvw, prepare_tp_fused_qkvw
from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list


class ReplaceWithTensorSlicing:
@@ -312,8 +313,9 @@ def _replace(self, child, name, conv_linear_layer):

if self.conv_linear_layer:
child.weight.data = child.weight.data.transpose(-1, -2).contiguous()
data = child.weight.data.split(
(weight_shape[0] if self.conv_linear_layer else weight_shape[1]) // self.mp_size, dim=1)
data = child.weight.data.split(get_shard_size_list(
weight_shape[0] if self.conv_linear_layer else weight_shape[1], self.mp_size),
dim=1)
data_dc = data[mp_replace.gpu_index].to(get_accelerator().current_device_name()).clone().detach()
del data

@@ -342,14 +344,15 @@ def _replace(self, child, name, conv_linear_layer):
module_str, child.bias.data, self.mp_size, mp_replace.gpu_index).to(
get_accelerator().current_device_name())
else:
data = child.weight.data.split((weight_shape[0]) // self.mp_size,
data = child.weight.data.split(get_shard_size_list(weight_shape[0], self.mp_size),
dim=1 if self.conv_linear_layer else 0)
data_dc = data[mp_replace.gpu_index].to(get_accelerator().current_device_name()).clone().detach()
del data

if child.bias is not None:
bias_data = child.bias.data.split(
(weight_shape[1] if self.conv_linear_layer else weight_shape[0]) // self.mp_size, dim=0)
bias_data = child.bias.data.split(get_shard_size_list(
weight_shape[1] if self.conv_linear_layer else weight_shape[0], self.mp_size),
dim=0)
bias_data = bias_data[mp_replace.gpu_index].to(get_accelerator().current_device_name())
bias_data_dc = torch.nn.parameter.Parameter(bias_data, requires_grad=False)
del bias_data
@@ -366,13 +369,13 @@ def _slice_embedding(self, child, name, conv_linear_layer):
mp_replace = ReplaceWithTensorSlicing(mp_group=self.mp_group)

if hasattr(child.weight, 'ds_tensor'):
data = child.weight.ds_tensor.data.split(child.weight.shape[1] // self.mp_size, dim=1)
data = child.weight.ds_tensor.data.split(get_shard_size_list(child.weight.shape[1], self.mp_size), dim=1)
else:
data = child.weight.data.split(child.weight.shape[1] // self.mp_size, dim=1)
data = child.weight.data.split(get_shard_size_list(child.weight.shape[1], self.mp_size), dim=1)
data = data[mp_replace.gpu_index].to(get_accelerator().current_device_name())
data = torch.nn.parameter.Parameter(data, requires_grad=False)

new_embedding = nn.Embedding(child.weight.shape[0], child.weight.shape[1] // self.mp_size)
new_embedding = nn.Embedding(child.weight.shape[0], get_shard_size(child.weight.shape[1], self.mp_size))
new_embedding.weight.data.copy_(data)
setattr(child, "replaced", True)
return new_embedding
@@ -386,8 +389,7 @@ def update_mp_params(self, child):
]:
if hasattr(child, param):
param_val = getattr(child, param)
assert param_val % self.mp_size == 0, f"{param} ({param_val}) must be divisible by mp_size ({self.mp_size})"
setattr(child, param, param_val // self.mp_size)
setattr(child, param, get_shard_size(param_val, self.mp_size))
setattr(child, "replaced", True)

def update_linear_policies(self):
@@ -442,6 +444,16 @@ def _replace_module(self, r_module, prev_name='', prev_class_name=''):
self._replace_module(child, name, class_name)
return r_module

def get_model_num_kv_heads(self, config):
num_kv_heads = None
kv_head_names = ['num_key_value_heads', 'num_attention_heads', 'n_heads']
for name in kv_head_names:
if hasattr(config, name):
num_kv_heads = getattr(config, name)
if num_kv_heads != None:
break
return num_kv_heads

def _replace_last_linear_module(self, r_module):
if hasattr(r_module, "lm_head"):
name = "lm_head"
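For illustration (not part of the diff): a minimal standalone sketch of the lookup order used by get_model_num_kv_heads above. The SimpleNamespace configs are hypothetical stand-ins for Hugging Face model configs; the first matching attribute wins, so GQA models shard by their key/value head count rather than their query head count.

from types import SimpleNamespace

def lookup_num_kv_heads(config):
    # Same search order as get_model_num_kv_heads: the first attribute present wins.
    for name in ('num_key_value_heads', 'num_attention_heads', 'n_heads'):
        if hasattr(config, name):
            return getattr(config, name)
    return None

print(lookup_num_kv_heads(SimpleNamespace(num_key_value_heads=8, num_attention_heads=64)))  # 8  (GQA-style config)
print(lookup_num_kv_heads(SimpleNamespace(n_heads=24)))                                     # 24 (MPT-style config)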
9 changes: 5 additions & 4 deletions deepspeed/module_inject/auto_tp_model_utils.py
@@ -6,6 +6,7 @@
from deepspeed import comm as dist
import torch
from typing import Optional
from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list


def build_bloom_alibi_tensor(attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor:
@@ -51,8 +52,8 @@ def build_bloom_alibi_tensor(attention_mask: torch.Tensor, num_heads: int, dtype
arange_tensor = ((attention_mask.cumsum(dim=-1) - 1) * attention_mask)[:, None, :]
alibi = slopes[..., None] * arange_tensor
if dist.is_initialized():
num_heads_per_rank = int(num_heads / dist.get_world_size())
offset = dist.get_rank() * num_heads_per_rank
num_heads_per_rank = get_shard_size(num_heads, dist.get_world_size())
offset = sum(get_shard_size_list(num_heads, dist.get_world_size())[0:dist.get_rank()])
alibi = alibi.view(batch_size, num_heads, 1, seq_length)
alibi = alibi[:, offset:num_heads_per_rank + offset, :, :]
return alibi.reshape(batch_size * num_heads_per_rank, 1, seq_length).to(dtype)
@@ -72,8 +73,8 @@ def build_mpt_atten_bias_tensor(self,
prefix_mask=prefix_mask,
sequence_id=sequence_id)
if dist.is_initialized():
num_heads_per_rank = int(self.config.n_heads / dist.get_world_size())
offset = dist.get_rank() * num_heads_per_rank
num_heads_per_rank = get_shard_size(self.config.n_heads, dist.get_world_size())
offset = sum(get_shard_size_list(self.config.n_heads, dist.get_world_size())[0:dist.get_rank()])
attn_bias = attn_bias[:, offset:num_heads_per_rank + offset, :, :]
return attn_bias, attention_mask

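For illustration (not part of the diff): both hunks above replace rank * num_heads_per_rank with a prefix sum because, once shards can be uneven, a rank's first head index is the sum of the shard sizes of all lower ranks. A small sketch, assuming 32 heads split across 3 ranks:

num_heads, world_size = 32, 3
# Same rule as get_shard_size: the first (num_heads % world_size) ranks take one extra head.
shard_sizes = [num_heads // world_size + (1 if r < num_heads % world_size else 0)
               for r in range(world_size)]                   # [11, 11, 10]
offsets = [sum(shard_sizes[:r]) for r in range(world_size)]  # [0, 11, 22]
for rank in range(world_size):
    print(f"rank {rank}: alibi/attn_bias head slice [{offsets[rank]}, {offsets[rank] + shard_sizes[rank]})")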
20 changes: 10 additions & 10 deletions deepspeed/module_inject/fusedqkv_utils.py
@@ -4,6 +4,7 @@
# DeepSpeed Team
import torch
from deepspeed.utils.logging import warning_once
from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list, get_num_kv_heads
import re


@@ -39,18 +40,19 @@ def prepare_tp_fused_qkvw(module_str, src, mp_size, gpu_index):

def _codegen_type_transpose(input, mp_size, codegen_mp_num=4):
# codegen_mp_num defined in https://github.com/huggingface/transformers/blob/main/src/transformers/models/codegen/modeling_codegen.py
#TODO: assert num_heads % (mp_size*codegen_mp_num) == 0
assert get_num_kv_heads() % (
mp_size * codegen_mp_num) == 0, "codgen autoTP requires num_kv_heads % (mp_size*codegen_mp_num) == 0"
#input : [3*hidden_dim, hidden_dim](weight) or [3*hidden_dim](bias)

shape = input.shape
dst_shape = shape[0] // mp_size
dst_shape = get_shard_size(shape[0], mp_size)
num_mp_blocks = input.reshape(codegen_mp_num, shape[0] // codegen_mp_num, shape[1])

#num_mp_blocks : [codegen_mp_num, 3*hidden_dim/codegen_mp_num, :]
src_split = list(torch.split(num_mp_blocks, num_mp_blocks.shape[1] // 3, dim=1))
src_split = [x.reshape(codegen_mp_num * mp_size, -1, shape[1]) for x in src_split]

split_fusedqkv = split_by_qkvlist_and_refuse(src_split, shape[0] // 3 // mp_size, 0, 1)
split_fusedqkv = split_by_qkvlist_and_refuse(src_split, get_shard_size(shape[0] // 3, mp_size), 0, 1)
tp_fuseqkv_weight = torch.cat(split_fusedqkv, dim=0).reshape(shape[0], -1)

return tp_fuseqkv_weight[gpu_index * dst_shape:(gpu_index + 1) * dst_shape]
@@ -59,18 +61,16 @@ def _glm_type_transpose(input, mp_size):
#input : [3*hidden_dim, hidden_dim](weight) or [3*hidden_dim](bias)

shape = input.shape
dst_shape = shape[0] // mp_size
src_split = torch.split(input, shape[0] // 3, dim=0)

split_fusedqkv = split_by_qkvlist_and_refuse(src_split, shape[0] // 3 // mp_size)
tp_fuseqkv_weight = torch.cat(split_fusedqkv, dim=0)

return tp_fuseqkv_weight[gpu_index * dst_shape:(gpu_index + 1) * dst_shape]
split_fusedqkv = split_by_qkvlist_and_refuse(src_split, get_shard_size_list(shape[0] // 3, mp_size))
return split_fusedqkv[gpu_index]

def _bloom_type_transpose(input, mp_size):
shape = input.shape
dst_shape = shape[0] // mp_size
return input[gpu_index * dst_shape:(gpu_index + 1) * dst_shape]

split_fusedqkv = input.split(get_shard_size_list(shape[0], mp_size), dim=0)
return split_fusedqkv[gpu_index]

def _transpose_fused_qkvw(src, mp_size, fused_qkv_type=None):

8 changes: 4 additions & 4 deletions deepspeed/module_inject/layers.py
@@ -10,6 +10,7 @@

from torch.nn.parameter import Parameter
from deepspeed.accelerator import get_accelerator
from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list


class LinearAllreduce(nn.Module):
@@ -47,10 +48,9 @@ def __init__(
self.world_size = world_size

def forward(self, input):
assert input.shape[
-1] % self.world_size == 0, 'Please ensure that self.world_size is divisible by input.shape[-1]'
input_shard = input.shape[-1] // self.world_size
output = torch.matmul(input[:, :, self.rank * input_shard:(self.rank + 1) * input_shard],
input_shard_size = get_shard_size(input.shape[-1], self.world_size)
input_shard_offset = sum(get_shard_size_list(input.shape[-1], self.world_size)[0:self.rank])
output = torch.matmul(input[:, :, input_shard_offset:input_shard_offset + input_shard_size],
self.weight.transpose(-1, -2))
if self.mp_group is not None:
dist.inference_all_reduce(output, group=self.mp_group)
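For illustration (not part of the diff): the LmHeadLinearAllreduce change above works because a matmul over the full hidden dimension equals the sum of matmuls over disjoint column slices, whether or not the slices are equal. A self-contained toy check (no distributed setup; the shard sizes are made up):

import torch

hidden, vocab = 10, 7
shard_sizes = [4, 3, 3]                 # uneven split of the hidden dimension
x = torch.randn(1, 2, hidden)           # [batch, seq, hidden]
w = torch.randn(vocab, hidden)          # lm_head-style weight

partial_sum, offset = torch.zeros(1, 2, vocab), 0
for size in shard_sizes:
    x_shard = x[:, :, offset:offset + size]     # this "rank's" slice of the input features
    w_shard = w[:, offset:offset + size]        # this "rank's" weight shard
    partial_sum = partial_sum + torch.matmul(x_shard, w_shard.transpose(-1, -2))
    offset += size

# Summing the partial products (what inference_all_reduce does across ranks)
# reproduces the unsharded projection.
assert torch.allclose(partial_sum, torch.matmul(x, w.transpose(-1, -2)), atol=1e-5)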
11 changes: 9 additions & 2 deletions deepspeed/module_inject/replace_module.py
@@ -16,6 +16,7 @@
from .auto_tp import AutoTP, ReplaceWithTensorSlicing, Loading

from deepspeed import comm as dist
from deepspeed.module_inject.tp_shard import set_num_kv_heads

from .load_checkpoint import load_model_with_checkpoint
import time
@@ -271,10 +272,16 @@ def replace_wo_policy(module, all_reduce_linears, prefix="", state_dict=None):
# 2. Set the tensor parallelism config
_autotp.set_tensor_parallel_config(config.tensor_parallel.tp_size, config.tensor_parallel.tp_group)

# 3. Set linear policies
# 3. Try to get num_key_heads from model_config.num_key_value_heads
num_kv_heads = _autotp.get_model_num_kv_heads(model_config)

# 4. When we have num_kv_heads defined, uneven division is possible, otherwise enforce even division
set_num_kv_heads(num_kv_heads)

# 5. Set linear policies
_autotp.update_linear_policies()

# 4. Replace modules
# 6. Replace modules
if "lm_head" in all_reduce_linears or "embed_out" in all_reduce_linears:
return _autotp._replace_last_linear_module(module)
return _autotp._replace_module(module)
39 changes: 39 additions & 0 deletions deepspeed/module_inject/tp_shard.py
@@ -0,0 +1,39 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from deepspeed import comm as dist
global num_kv_heads


def set_num_kv_heads(num):
global num_kv_heads
num_kv_heads = num


def get_num_kv_heads():
global num_kv_heads
return num_kv_heads


def get_shard_size(total_size, mp_size, rank=None):
global num_kv_heads
# When we have num_kv_heads defined, uneven division is possible, otherwise enforce even division
if num_kv_heads != None:
if (rank == None):
rank = dist.get_rank()
my_slices = (num_kv_heads // mp_size) + (1 if rank < (num_kv_heads % mp_size) else 0)
return total_size * my_slices // num_kv_heads
else:
if total_size % mp_size == 0:
return total_size // mp_size
else:
assert False, f"Number of attention heads ({total_size}) must be divisible by mp_size ({mp_size})"


def get_shard_size_list(total_size, mp_size):
shard_sizes = []
for i in range(mp_size):
shard_sizes.append(get_shard_size(total_size, mp_size, i))
return shard_sizes
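A minimal usage sketch of the new module (assuming a DeepSpeed build that includes this PR). get_shard_size_list passes explicit ranks internally, so no process group needs to be initialized; with 32 KV heads and a hidden size of 4096 split across 3 ranks, the first two ranks take 11 heads and the last takes 10:

from deepspeed.module_inject.tp_shard import set_num_kv_heads, get_shard_size_list

set_num_kv_heads(32)
print(get_shard_size_list(4096, 3))  # [1408, 1408, 1280]  -> 11, 11 and 10 heads of width 128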
32 changes: 32 additions & 0 deletions tests/unit/inference/test_inference.py
@@ -558,6 +558,38 @@ def test(
print(local_rank, "deepspeed", ds_output)
assert assert_fn(bs_output, ds_output)

@pytest.mark.world_size(3)
def test_odd_world_size(
self,
model_w_task,
query,
inf_kwargs,
assert_fn,
dtype,
):
invalid_test_msg = validate_test(model_w_task, dtype, enable_cuda_graph=False, enable_triton=False)
if invalid_test_msg:
pytest.skip(invalid_test_msg)

model, task = model_w_task
if model == "Salesforce/codegen-350M-mono":
pytest.skip("codegen does not supported by odd world_size")
local_rank = int(os.getenv("LOCAL_RANK", "0"))
world_size = int(os.getenv("WORLD_SIZE", "3"))

pipe = pipeline(task,
model=model,
device=torch.device(get_accelerator().device_name(local_rank)),
framework="pt")
bs_output = pipe(query, **inf_kwargs)

pipe.model = deepspeed.init_inference(pipe.model, mp_size=world_size, dtype=dtype)
ds_output = pipe(query, **inf_kwargs)

print(local_rank, "baseline", bs_output)
print(local_rank, "deepspeed", ds_output)
assert assert_fn(bs_output, ds_output)


@pytest.mark.nightly
@pytest.mark.parametrize(