(Do not merge) (CPU) aggregation of few recent fixes/optimizations #3920

Closed — wants to merge 102 commits

Commits (102)
c52d1e2
add show_straggler argument to log_summary()
delock May 19, 2023
de368db
Show straggler effect logging in separate table
delock May 20, 2023
6884e33
fix formatting
delock May 20, 2023
206d455
add docs for log_summary with straggler effect
delock May 21, 2023
c586171
Merge branch 'master' into gma/log_summary_straggler
delock May 24, 2023
35c72df
Merge branch 'master' into gma/log_summary_straggler
delock May 25, 2023
21975fd
Merge branch 'master' into gma/log_summary_straggler
delock May 26, 2023
4630189
Merge branch 'master' into gma/log_summary_straggler
delock Jun 2, 2023
7b4db63
Merge branch 'master' into gma/log_summary_straggler
delock Jun 7, 2023
4a9ad5d
fix opt-350m shard loading issue in AutoTP
sywangyi May 24, 2023
05f9732
Merge branch 'master' into gma/log_summary_straggler
delock Jun 12, 2023
becc759
Merge branch 'gma/log_summary_straggler' into gma/run-opt-branch
delock Jun 12, 2023
d5552ef
init version of CCLBacked allreduce_latency
delock Jun 26, 2023
f0ea3eb
remove torch-ccl as dependency
delock Jun 26, 2023
6caf695
init allreduce for latency without actual reduce operation
delock Jun 28, 2023
f689f22
first version of SHM based direct allreduce
delock Jun 28, 2023
bc48c7e
tweak reduce kernel
delock Jun 29, 2023
53b4846
SHM allreduce support 2-8 ranks
delock Jun 29, 2023
4b10db9
clean up
delock Jun 30, 2023
1c10c66
remove oneCCL binding for pytorch from workflow, use gloo to bootstra…
delock Jun 30, 2023
3151996
add gpt-neox autotp support
delock Jun 30, 2023
c5dd6dc
fallback to oneccl if input is too large
delock Jun 30, 2023
f402f0b
code clean up
delock Jun 30, 2023
6ecf721
first clean up code
delock Jun 30, 2023
d6a3ac8
add checks for allreduce_low_latency, remove warning
delock Jun 30, 2023
3364a93
remove redundant declaration, fix 2 ranks
delock Jun 30, 2023
afc67f6
remove avx512f path
delock Jul 3, 2023
6f05cbf
check whether buffer size is divisible by 16
delock Jul 4, 2023
4a38410
autoTP linear allreduce should go to allreduce_low_latency
delock Jul 4, 2023
c88f3bd
cleanup profile code
delock Jul 4, 2023
bd5ea3f
make SHM buffer larger
delock Jul 4, 2023
3feebc1
add mfence to ensure memory order before each buffer state update
delock Jul 5, 2023
c377d93
Merge branch 'up-master' into gma/run-opt-branch
delock Jul 6, 2023
0ee639f
fix typo
delock Jul 6, 2023
a75064a
fix typo
delock Jul 6, 2023
9a11603
add temp print cmd
delock Jul 6, 2023
4107634
fix error in HBM path
delock Jul 6, 2023
3723eb1
remove cmd print
delock Jul 6, 2023
accf0e3
remove buffer reuse in linear allreduce since it does not bring visib…
delock Jul 6, 2023
0d2699a
add support for fakenuma
delock Jul 10, 2023
0ac06bc
fix llama meta data error when model device is meta and LLaMa lm_head…
baodii Jul 6, 2023
e86fc83
fix llama meta data error when model device is meta and LLaMa lm_head
delock Jul 10, 2023
0be3b7e
support FP32 SHM allreduce
delock Jul 11, 2023
0706acd
allow number of heads not divisible by number of ranks
delock Jul 20, 2023
43d2c67
Update replace_module.py
baodii Jul 21, 2023
1e90f03
Update replace_module.py
baodii Jul 21, 2023
0bf785f
get num_heads from model config, more robust
delock Jul 21, 2023
72b9e1a
simplify logic where num_head itself is sharded
delock Jul 21, 2023
5ed9a56
name tweaks
delock Jul 21, 2023
73f499d
make code more robust where num_attention_heads may not be defined in…
delock Jul 21, 2023
48322c7
Merge branch 'master' into gma/uneven_heads
delock Jul 21, 2023
f14e290
Merge branch 'master' into gma/uneven_heads
delock Jul 24, 2023
b62317c
Merge branch 'master' into gma/uneven_heads
loadams Jul 24, 2023
12c0628
support num_key_value_heads < num_attention_heads which is used by ll…
delock Jul 25, 2023
8f23d9b
add test for 5 ranks
delock Jul 25, 2023
9c53bd7
change odd rank # to 3 to avoid test skip
delock Jul 25, 2023
413224b
Merge branch 'master' into gma/uneven_heads
tjruwase Jul 25, 2023
a04fa97
Run SHM allreduce's reduce kernel with openmp to further improve perf…
delock Jul 27, 2023
c432325
Merge branch 'baodi/fix_llama' into gma/run-opt-branch
delock Jul 27, 2023
78d6667
Merge branch 'master' into gma/uneven_heads
delock Aug 9, 2023
27fde30
add get_shard_size function
delock Aug 9, 2023
8e1fd27
modify sharding mechanism according to latest auto TP
delock Aug 10, 2023
e21231d
Add optimizations for lm_head & embed_out (#11)
blzheng Aug 16, 2023
9a6bc12
Merge branch 'master' into gma/uneven_heads
delock Aug 16, 2023
2dac94f
fix accuracy issue
delock Aug 17, 2023
885f6a3
Merge branch 'master' into gma/uneven_heads
delock Aug 17, 2023
7ffd811
Merge branch 'master' into gma/uneven_heads
molly-smith Aug 18, 2023
40659ba
Merge branch 'master' into gma/uneven_heads
tjruwase Aug 22, 2023
71f9f40
fix format
delock Aug 21, 2023
db9db6b
skip tests with fusedqkv
delock Aug 23, 2023
72531c0
Merge branch 'master' into gma/uneven_heads
delock Aug 23, 2023
9d5eae3
remove skip of fusedqkv tests
delock Aug 23, 2023
25e656d
skip test fusedqkv with odd number of ranks
delock Aug 23, 2023
590da97
Merge branch 'gma/uneven_heads' into gma/run-opt-branch-rebase
delock Aug 24, 2023
5eba475
fix lm head overridden issue, move it from checkpoint in-loop loading …
sywangyi Aug 24, 2023
6f9e4f2
Merge branch 'master' into checkpt_lm_head
loadams Aug 24, 2023
6ef9093
change all_reduce_low_latency to inference_all_reduce
delock Aug 25, 2023
5efdc8f
merge lm-head updates from lyj/lmhead_tp branch
delock Aug 25, 2023
6f9565b
Merge branch 'checkpt_lm_head' into gma/run-opt-branch
delock Aug 29, 2023
370ad5e
Merge branch 'master' into gma/run-opt-branch-rebase
delock Aug 29, 2023
b66b020
Merge branch 'master' into gma/run-opt-branch-rebase
delock Aug 30, 2023
aa64514
cherry pick fix for activation size not divisible by attention heads
delock Aug 30, 2023
8b0a887
Support uneven sharding for lm_head
delock Sep 5, 2023
f15e6d4
fix CPU loading model OOM. (#13)
Yejing-Lai Sep 13, 2023
5718a88
Merge branch 'up-master' into gma/run-opt-branch-rebase
delock Sep 13, 2023
f02d40f
merge latest change in uneven_heads
delock Sep 13, 2023
3a8ad63
move tp_shard to module_inject
delock Sep 13, 2023
f0ef3ea
support baichuan model (#14)
baodii Sep 27, 2023
5ab9d58
Merge branch 'up-master' into gma/run-opt-branch-rebase
delock Oct 13, 2023
a05bd5b
Merge branch 'up-master' into gma/run-opt-branch-rebase
delock Oct 20, 2023
91f56a2
fix bug in lm_head, cherry pick from #4522
delock Oct 20, 2023
2016f30
Merge branch 'gma/run-opt-branch-rebase' into gma/run-opt-branch
delock Oct 27, 2023
57ff508
fix uneven heads issue (#25)
Yejing-Lai Oct 31, 2023
09a348c
fix imbalance autotp issue (#31)
Yejing-Lai Nov 16, 2023
14f5058
fix split shape < 64 issue & add num_kv_heads to mp_params (#33)
Yejing-Lai Nov 21, 2023
8d60432
Baodi/support baichuan (#23)
baodii Nov 23, 2023
0ebb1ed
fix Baichuan-7B qkv order error (#35)
baodii Nov 23, 2023
547ac96
fix baichuan lm_head replace issue (#34)
Yejing-Lai Nov 23, 2023
cd070bf
shorten inference_all_reduce call stack (#37)
Liangliang-Ma Nov 27, 2023
e8ab894
Enable starcode autotp (#38)
Yejing-Lai Dec 15, 2023
092b0f2
fix falcon-40b accuracy issue (#39)
Yejing-Lai Jan 4, 2024
94873fe
fix t5 and mistral model load from config meta tensor bug (#42)
daisyden Jan 4, 2024
4 changes: 2 additions & 2 deletions csrc/cpu/comm/ccl.cpp
@@ -537,7 +537,7 @@ static void parallel_memcpy(void* to, void* from, size_t n_bytes)
}
}

void inference_all_reduce(torch::Tensor& data, py::object op, std::vector<int> group, bool async_op)
void inference_all_reduce(torch::Tensor& data, py::object op, bool async_op)
{
static py::object ReduceOp = py::module_::import("deepspeed.comm").attr("ReduceOp");
static auto ReduceOpSum = (int)py::int_(ReduceOp.attr("SUM").attr("value"));
@@ -562,7 +562,7 @@ void inference_all_reduce(torch::Tensor& data, py::object op, std::vector<int> g
data.numel(),
get_ccl_datatype(data.scalar_type()),
get_ccl_reduce_op(op, data),
_get_comm_from_group(group))
_get_comm_from_group())
.wait());
return;
}
46 changes: 31 additions & 15 deletions deepspeed/comm/ccl.py
@@ -61,7 +61,8 @@ def is_initialized(self):

def run_collective(self, name, **kwargs):
if name in self.available_coll:
kwargs['group'] = self.get_all_ranks_from_group(kwargs['group'])
if 'group' in kwargs:
kwargs['group'] = self.get_all_ranks_from_group(kwargs['group'])
if 'dst' in kwargs:
kwargs['dst'] = kwargs['group'].index(kwargs['dst'])
if 'src' in kwargs:
@@ -71,23 +72,38 @@ def run_collective(self, name, **kwargs):
return CCLHandler(self.ccl_comm_op)
else:
func = "super(CCLBackend, self)." + name
return eval(func)(*(kwargs.values()))
eval(func)(*(kwargs.values()))
return CCLHandler(self.ccl_comm_op)

def all_reduce(self, tensor, op=ReduceOp.SUM, group=None, async_op=False):
use_caching = False
if use_caching:
match_id = f"{tensor.size()}-{op}"
return self.run_collective(name="all_reduce_caching",
tensor=tensor,
op=op,
match_id=match_id,
group=group,
async_op=async_op)
name = "all_reduce_caching"
if name in self.available_coll:
group = self.get_all_ranks_from_group(group)
return self.ccl_comm_op.all_reduce_caching(tensor, op, match_id, group, async_op)
else:
return self.run_collective(name=name,
tensor=tensor,
op=op,
match_id=match_id,
group=group,
async_op=async_op)
else:
return self.run_collective(name="all_reduce", tensor=tensor, op=op, group=group, async_op=async_op)
name = "all_reduce"
if name in self.available_coll:
group = self.get_all_ranks_from_group(group)
return self.ccl_comm_op.all_reduce(tensor, op, group, async_op)
else:
return self.run_collective(name=name, tensor=tensor, op=op, group=group, async_op=async_op)

def inference_all_reduce(self, tensor, op=ReduceOp.SUM, group=None, async_op=False):
return self.run_collective(name="inference_all_reduce", tensor=tensor, op=op, group=group, async_op=async_op)
name = "inference_all_reduce"
if name in self.available_coll:
return self.ccl_comm_op.inference_all_reduce(tensor, op, async_op)
else:
return self.run_collective(name=name, tensor=tensor, op=op, group=None, async_op=async_op)

def broadcast(self, tensor, src, group=None, async_op=False):
return self.run_collective(name="broadcast", tensor=tensor, src=src, group=group, async_op=async_op)
@@ -120,11 +136,11 @@ def all_to_all_single(self, output, input, output_split_sizes, input_split_sizes
input_split_sizes=input_split_sizes,
group=group)

def send(self, tensor, dst, group=None, async_op=False):
return self.run_collective(name="send", tensor=tensor, dst=dst, group=group, async_op=async_op)
def send(self, tensor, dst, group=None, tag=0):
return self.run_collective(name="send", tensor=tensor, dst=dst, group=group, tag=tag)

def recv(self, tensor, src, group=None, async_op=False):
return self.run_collective(name="recv", tensor=tensor, src=src, group=group, async_op=async_op)
def recv(self, tensor, src, group=None, tag=0):
return self.run_collective(name="recv", tensor=tensor, src=src, group=group, tag=tag)

def gather(self, tensor, gather_list, dst, group=None, async_op=False):
return self.run_collective(name="gather", tensor=tensor, gather_list=gather_list, dst=dst, group=group)
@@ -170,7 +186,7 @@ def get_all_ranks_from_group(self, group):
while True:
results.append(super(CCLBackend, self).get_global_rank(group, rank))
rank += 1
except ValueError:
except (ValueError, RuntimeError):
pass
if tuple(results) not in self.groups:
self._new_group(results, group)
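
For orientation, a minimal usage sketch of the renamed collective (hedged: it assumes DeepSpeed was built with this CPU/CCL backend and is run under a distributed launcher; the tensor shape and dtype are arbitrary). When the compiled op exposes inference_all_reduce, the call goes straight to the SHM kernel and skips the process-group lookup; otherwise it falls back to run_collective as shown above.

import torch
import deepspeed.comm as dist

dist.init_distributed(dist_backend="ccl")      # CPU backend assumed
hidden = torch.ones(1, 32, 4096, dtype=torch.bfloat16)
dist.inference_all_reduce(hidden)              # SHM fast path if available, oneCCL otherwise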
7 changes: 6 additions & 1 deletion deepspeed/inference/engine.py
@@ -26,7 +26,8 @@
from ..module_inject.auto_tp import AutoTP

from ..module_inject.replace_policy import generic_policies
from ..module_inject.auto_tp_model_utils import build_bloom_alibi_tensor, build_mpt_atten_bias_tensor, build_mpt_alibi_tensor

from ..module_inject.auto_tp_model_utils import build_bloom_alibi_tensor, build_mpt_atten_bias_tensor, build_mpt_alibi_tensor, get_alibi_mask
from ..ops.transformer.inference.ds_attention import DeepSpeedSelfAttention
from ..model_implementations.transformers.ds_transformer import DeepSpeedTransformerInference

@@ -220,6 +221,10 @@ def build_alibi_tensor(self):
if hasattr(self.module.transformer, 'build_mpt_alibi_tensor'):
self.module.transformer.build_mpt_alibi_tensor_orig = self.module.transformer.build_mpt_alibi_tensor
self.module.transformer.__class__.build_mpt_alibi_tensor = build_mpt_alibi_tensor
if hasattr(self.module, 'model'):
if hasattr(self.module.model, 'get_alibi_mask'):
self.module.model.get_alibi_mask_orig = self.module.model.get_alibi_mask
self.module.model.__class__.get_alibi_mask = get_alibi_mask

def build_attn_bias(self):
if hasattr(self.module, 'transformer'):
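
A small sketch of the monkey-patch pattern the hunk above applies for get_alibi_mask (the Model class here is a stand-in, not the real transformers module): the original bound method is stashed on the instance as get_alibi_mask_orig, and the replacement is installed on the class, so later calls on self resolve to the patched version, which can still reach the saved original.

class Model:
    def get_alibi_mask(self, tensor, seq_len):
        return "full mask"

def get_alibi_mask(self, tensor, seq_len):
    # patched version: delegate to the saved original, then post-process
    return self.get_alibi_mask_orig(tensor, seq_len) + " (sliced per rank)"

m = Model()
m.get_alibi_mask_orig = m.get_alibi_mask       # keep the original bound method
Model.get_alibi_mask = get_alibi_mask          # class-level override
print(m.get_alibi_mask(None, 8))               # full mask (sliced per rank)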
30 changes: 18 additions & 12 deletions deepspeed/module_inject/auto_tp.py
@@ -14,6 +14,7 @@
from .layers import LinearAllreduce, LinearLayer, LmHeadLinearAllreduce
from deepspeed.accelerator import get_accelerator
from .fusedqkv_utils import require_tp_fused_qkvw, prepare_tp_fused_qkvw
from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list


class ReplaceWithTensorSlicing:
@@ -120,7 +121,9 @@ class Loading():

def is_load_module(module):
load_layers = [nn.Linear, nn.Embedding, nn.LayerNorm]
load_layer_names = ["LPLayerNorm", "SharedEmbedding", "OPTLearnedPositionalEmbedding", "LlamaRMSNorm"]
load_layer_names = [
"LPLayerNorm", "SharedEmbedding", "OPTLearnedPositionalEmbedding", "LlamaRMSNorm", "RMSNorm", "MistralRMSNorm", "T5LayerNorm",
]
return module.__class__ in load_layers or module._get_name() in load_layer_names

def load_buffer(module, state_dict, prefix):
@@ -312,8 +315,9 @@ def _replace(self, child, name, conv_linear_layer):

if self.conv_linear_layer:
child.weight.data = child.weight.data.transpose(-1, -2).contiguous()
data = child.weight.data.split(
(weight_shape[0] if self.conv_linear_layer else weight_shape[1]) // self.mp_size, dim=1)
data = child.weight.data.split(get_shard_size_list(
weight_shape[0] if self.conv_linear_layer else weight_shape[1], self.mp_size, name),
dim=1)
data_dc = data[mp_replace.gpu_index].to(get_accelerator().current_device_name()).clone().detach()
del data

Expand Down Expand Up @@ -342,14 +346,15 @@ def _replace(self, child, name, conv_linear_layer):
module_str, child.bias.data, self.mp_size, mp_replace.gpu_index).to(
get_accelerator().current_device_name())
else:
data = child.weight.data.split((weight_shape[0]) // self.mp_size,
data = child.weight.data.split(get_shard_size_list(weight_shape[0], self.mp_size, name),
dim=1 if self.conv_linear_layer else 0)
data_dc = data[mp_replace.gpu_index].to(get_accelerator().current_device_name()).clone().detach()
del data

if child.bias is not None:
bias_data = child.bias.data.split(
(weight_shape[1] if self.conv_linear_layer else weight_shape[0]) // self.mp_size, dim=0)
bias_data = child.bias.data.split(get_shard_size_list(
weight_shape[1] if self.conv_linear_layer else weight_shape[0], self.mp_size, name),
dim=0)
bias_data = bias_data[mp_replace.gpu_index].to(get_accelerator().current_device_name())
bias_data_dc = torch.nn.parameter.Parameter(bias_data, requires_grad=False)
del bias_data
@@ -366,13 +371,14 @@ def _slice_embedding(self, child, name, conv_linear_layer):
mp_replace = ReplaceWithTensorSlicing(mp_group=self.mp_group)

if hasattr(child.weight, 'ds_tensor'):
data = child.weight.ds_tensor.data.split(child.weight.shape[1] // self.mp_size, dim=1)
data = child.weight.ds_tensor.data.split(get_shard_size_list(child.weight.shape[1], self.mp_size, name),
dim=1)
else:
data = child.weight.data.split(child.weight.shape[1] // self.mp_size, dim=1)
data = child.weight.data.split(get_shard_size_list(child.weight.shape[1], self.mp_size, name), dim=1)
data = data[mp_replace.gpu_index].to(get_accelerator().current_device_name())
data = torch.nn.parameter.Parameter(data, requires_grad=False)

new_embedding = nn.Embedding(child.weight.shape[0], child.weight.shape[1] // self.mp_size)
new_embedding = nn.Embedding(child.weight.shape[0], get_shard_size(child.weight.shape[1], self.mp_size))
new_embedding.weight.data.copy_(data)
setattr(child, "replaced", True)
return new_embedding
@@ -382,12 +388,12 @@ def update_mp_params(self, child):
return
for param in [
"n_heads", "inner_dim", "num_heads", "num_kv", "num_attention_heads", "num_attn_heads",
"all_head_size", "embed_dim", "hidden_size", "num_key_value_heads"
"all_head_size", "embed_dim", "hidden_size", "num_key_value_heads", "num_kv_heads"
]:
if hasattr(child, param):
param_val = getattr(child, param)
assert param_val % self.mp_size == 0, f"{param} ({param_val}) must be divisible by mp_size ({self.mp_size})"
setattr(child, param, param_val // self.mp_size)
#assert param_val % self.mp_size == 0, f"{param} ({param_val}) must be divisible by mp_size ({self.mp_size})"
setattr(child, param, get_shard_size(param_val, self.mp_size))
setattr(child, "replaced", True)

def update_linear_policies(self):
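
The get_shard_size/get_shard_size_list helpers used throughout this hunk live in deepspeed/module_inject/tp_shard.py, which is not part of this excerpt. A rough sketch of the behavior the call sites assume (uneven split, with the remainder given to the lowest ranks; the real helpers presumably read the current rank and the registered num_kv_heads from distributed state rather than taking a rank argument):

def get_shard_size_list(total_size, mp_size, name=None):
    # split total_size over mp_size ranks as evenly as possible
    base, remainder = divmod(total_size, mp_size)
    return [base + (1 if r < remainder else 0) for r in range(mp_size)]

def get_shard_size(total_size, mp_size, name=None, rank=0):
    # rank is explicit here only to keep the sketch self-contained
    return get_shard_size_list(total_size, mp_size, name)[rank]

print(get_shard_size_list(10, 4))  # [3, 3, 2, 2]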
17 changes: 13 additions & 4 deletions deepspeed/module_inject/auto_tp_model_utils.py
@@ -6,6 +6,7 @@
from deepspeed import comm as dist
import torch
from typing import Optional
from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list


def build_bloom_alibi_tensor(attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor:
@@ -51,14 +52,22 @@ def build_bloom_alibi_tensor(attention_mask: torch.Tensor, num_heads: int, dtype
arange_tensor = ((attention_mask.cumsum(dim=-1) - 1) * attention_mask)[:, None, :]
alibi = slopes[..., None] * arange_tensor
if dist.is_initialized():
num_heads_per_rank = int(num_heads / dist.get_world_size())
offset = dist.get_rank() * num_heads_per_rank
num_heads_per_rank = get_shard_size(num_heads, dist.get_world_size())
offset = sum(get_shard_size_list(num_heads, dist.get_world_size())[0:dist.get_rank()])
alibi = alibi.view(batch_size, num_heads, 1, seq_length)
alibi = alibi[:, offset:num_heads_per_rank + offset, :, :]
return alibi.reshape(batch_size * num_heads_per_rank, 1, seq_length).to(dtype)
else:
return alibi.reshape(batch_size * num_heads, 1, seq_length).to(dtype)

def get_alibi_mask(self, tensor, seq_length_with_past):
mask = self.get_alibi_mask_orig(tensor, seq_length_with_past)
if not self.training and dist.is_initialized():
num_heads_per_rank = get_shard_size(self.n_head, dist.get_world_size())
offset = sum(get_shard_size_list(self.n_head, dist.get_world_size())[0:dist.get_rank()])
mask = mask[offset:num_heads_per_rank + offset, :seq_length_with_past, :seq_length_with_past]

return mask

def build_mpt_atten_bias_tensor(self,
device,
@@ -72,8 +81,8 @@ def build_mpt_atten_bias_tensor(self,
prefix_mask=prefix_mask,
sequence_id=sequence_id)
if dist.is_initialized():
num_heads_per_rank = int(self.config.n_heads / dist.get_world_size())
offset = dist.get_rank() * num_heads_per_rank
num_heads_per_rank = get_shard_size(self.config.n_heads, dist.get_world_size())
offset = sum(get_shard_size_list(self.config.n_heads, dist.get_world_size())[0:dist.get_rank()])
attn_bias = attn_bias[:, offset:num_heads_per_rank + offset, :, :]
return attn_bias, attention_mask

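
A worked example of the offset arithmetic above, using the sharding sketch given earlier (numbers invented): with 10 attention heads over 4 ranks the shard sizes are [3, 3, 2, 2], so rank 2 keeps heads 6 and 7.

shard_sizes = [3, 3, 2, 2]              # get_shard_size_list(10, 4) in the sketch above
rank = 2
offset = sum(shard_sizes[:rank])        # 6, mirrors sum(...)[0:dist.get_rank()]
num_heads_per_rank = shard_sizes[rank]  # 2
print(offset, offset + num_heads_per_rank)  # 6 8 -> this rank slices heads [6, 8)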
38 changes: 27 additions & 11 deletions deepspeed/module_inject/fusedqkv_utils.py
@@ -4,6 +4,7 @@
# DeepSpeed Team
import torch
from deepspeed.utils.logging import warning_once
from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list, get_num_kv_heads, get_n_embd
import re


@@ -16,7 +17,8 @@ def split_by_qkvlist_and_refuse(qkv_list, split_size, split_dim=0, cat_dim=0):


def require_tp_fused_qkvw(name, mp_size):
fused_qkvw_name_list = ['qkv_proj', 'query_key_value', 'attn.Wqkv']
# 'c_attn' is for starcoder
fused_qkvw_name_list = ['qkv_proj', 'query_key_value', 'attn.Wqkv', 'self_attn.W_pack', 'c_attn']

if mp_size == 1:
return False
@@ -35,22 +37,27 @@ def prepare_tp_fused_qkvw(module_str, src, mp_size, gpu_index):
'GLMBlock': 'glmtype',
"MPTBlock": 'glmtype',
"MptBlock": 'glmtype',
"FalconDecoderLayer": 'bloomtype',
"BaichuanLayer": 'glmtype',
"DecoderLayer": 'glmtype',
"GPTBigCodeBlock": 'bigcodetype' # starcoder
}

def _codegen_type_transpose(input, mp_size, codegen_mp_num=4):
# codegen_mp_num defined in https://github.com/huggingface/transformers/blob/main/src/transformers/models/codegen/modeling_codegen.py
#TODO: assert num_heads % (mp_size*codegen_mp_num) == 0
assert get_num_kv_heads() % (
mp_size * codegen_mp_num) == 0, "codgen autoTP requires num_kv_heads % (mp_size*codegen_mp_num) == 0"
#input : [3*hidden_dim, hidden_dim](weight) or [3*hidden_dim](bias)

shape = input.shape
dst_shape = shape[0] // mp_size
dst_shape = get_shard_size(shape[0], mp_size)
num_mp_blocks = input.reshape(codegen_mp_num, shape[0] // codegen_mp_num, shape[1])

#num_mp_blocks : [codegen_mp_num, 3*hidden_dim/codegen_mp_num, :]
src_split = list(torch.split(num_mp_blocks, num_mp_blocks.shape[1] // 3, dim=1))
src_split = [x.reshape(codegen_mp_num * mp_size, -1, shape[1]) for x in src_split]

split_fusedqkv = split_by_qkvlist_and_refuse(src_split, shape[0] // 3 // mp_size, 0, 1)
split_fusedqkv = split_by_qkvlist_and_refuse(src_split, get_shard_size(shape[0] // 3, mp_size), 0, 1)
tp_fuseqkv_weight = torch.cat(split_fusedqkv, dim=0).reshape(shape[0], -1)

return tp_fuseqkv_weight[gpu_index * dst_shape:(gpu_index + 1) * dst_shape]
@@ -59,18 +66,25 @@ def _glm_type_transpose(input, mp_size):
#input : [3*hidden_dim, hidden_dim](weight) or [3*hidden_dim](bias)

shape = input.shape
dst_shape = shape[0] // mp_size
src_split = torch.split(input, shape[0] // 3, dim=0)

split_fusedqkv = split_by_qkvlist_and_refuse(src_split, shape[0] // 3 // mp_size)
tp_fuseqkv_weight = torch.cat(split_fusedqkv, dim=0)

return tp_fuseqkv_weight[gpu_index * dst_shape:(gpu_index + 1) * dst_shape]
split_fusedqkv = split_by_qkvlist_and_refuse(src_split, get_shard_size_list(shape[0] // 3, mp_size))
return split_fusedqkv[gpu_index]

def _bloom_type_transpose(input, mp_size):
shape = input.shape
dst_shape = shape[0] // mp_size
return input[gpu_index * dst_shape:(gpu_index + 1) * dst_shape]

split_fusedqkv = input.split(get_shard_size_list(shape[0], mp_size), dim=0)
return split_fusedqkv[gpu_index]

def _bigcode_type_transpose(input, mp_size):
n_embd = get_n_embd()
q = input[:n_embd]
kv = input[n_embd:]
shape = q.shape

split_q = q.split(get_shard_size_list(shape[0], mp_size), dim=0)
return torch.cat((split_q[gpu_index], kv), dim=0)

def _transpose_fused_qkvw(src, mp_size, fused_qkv_type=None):

@@ -85,6 +99,8 @@ def _transpose_fused_qkvw(src, mp_size, fused_qkv_type=None):
return _codegen_type_transpose(src, mp_size)
elif fused_qkv_type == 'glmtype':
return _glm_type_transpose(src, mp_size)
elif fused_qkv_type == 'bigcodetype':
return _bigcode_type_transpose(src, mp_size)

raise ValueError("unknown fused_qkv_type")

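
To illustrate the bigcode-style split above (starcoder's c_attn packs a query block of n_embd rows followed by a single shared key/value block): only the query rows are sharded, and every rank keeps the full kv block. Sizes below are invented for the example.

import torch

n_embd, kv_dim, mp_size, gpu_index = 8, 2, 4, 1
fused = torch.arange(n_embd + 2 * kv_dim)       # packed [q | k | v] rows
q, kv = fused[:n_embd], fused[n_embd:]
split_q = q.split([2, 2, 2, 2], dim=0)          # get_shard_size_list(8, 4)
shard = torch.cat((split_q[gpu_index], kv), dim=0)
print(shard.tolist())                           # [2, 3, 8, 9, 10, 11]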
8 changes: 4 additions & 4 deletions deepspeed/module_inject/layers.py
@@ -10,6 +10,7 @@

from torch.nn.parameter import Parameter
from deepspeed.accelerator import get_accelerator
from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list


class LinearAllreduce(nn.Module):
@@ -47,10 +48,9 @@ def __init__(
self.world_size = world_size

def forward(self, input):
assert input.shape[
-1] % self.world_size == 0, 'Please ensure that self.world_size is divisible by input.shape[-1]'
input_shard = input.shape[-1] // self.world_size
output = torch.matmul(input[:, :, self.rank * input_shard:(self.rank + 1) * input_shard],
input_shard_size = get_shard_size(input.shape[-1], self.world_size, "lm_head")
input_shard_offset = sum(get_shard_size_list(input.shape[-1], self.world_size, "lm_head")[0:self.rank])
output = torch.matmul(input[:, :, input_shard_offset:input_shard_offset + input_shard_size],
self.weight.transpose(-1, -2))
if self.mp_group is not None:
dist.inference_all_reduce(output, group=self.mp_group)
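
As a quick sanity check of the sharded lm_head math in the hunk above: multiplying each hidden-dimension slice of the input by the matching weight shard and summing the partial products reproduces the full matmul, which is exactly what inference_all_reduce accumulates across ranks. The shapes below are invented for the check.

import torch

x = torch.randn(1, 4, 10)      # [batch, seq, hidden]
w = torch.randn(32, 10)        # [vocab, hidden], sharded along hidden
shards = [3, 3, 2, 2]          # uneven hidden split over 4 "ranks"
partials, offset = [], 0
for size in shards:
    partials.append(torch.matmul(x[..., offset:offset + size],
                                 w[:, offset:offset + size].transpose(-1, -2)))
    offset += size
full = torch.matmul(x, w.transpose(-1, -2))
assert torch.allclose(sum(partials), full, atol=1e-4)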