enable starcode((kv_head=1)) autotp #4896

Merged · 2 commits · Jan 5, 2024
15 changes: 13 additions & 2 deletions deepspeed/module_inject/fusedqkv_utils.py
@@ -4,7 +4,7 @@
 # DeepSpeed Team
 import torch
 from deepspeed.utils.logging import warning_once
-from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list, get_num_kv_heads
+from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list, get_num_kv_heads, get_n_embd
 import re
@@ -17,7 +17,7 @@ def split_by_qkvlist_and_refuse(qkv_list, split_size, split_dim=0, cat_dim=0):
 def require_tp_fused_qkvw(name, mp_size):
-    fused_qkvw_name_list = ['qkv_proj', 'query_key_value', 'attn.Wqkv', 'self_attn.W_pack']
+    fused_qkvw_name_list = ['qkv_proj', 'query_key_value', 'attn.Wqkv', 'self_attn.W_pack', 'c_attn']

     if mp_size == 1:
         return False
@@ -38,6 +38,7 @@ def prepare_tp_fused_qkvw(module_str, src, mp_size, gpu_index):
         "MptBlock": 'glmtype',
         "BaichuanLayer": 'glmtype',
         "DecoderLayer": 'glmtype',
+        "GPTBigCodeBlock": 'bigcodetype'
     }

     def _codegen_type_transpose(input, mp_size, codegen_mp_num=4):
@@ -74,6 +75,14 @@ def _bloom_type_transpose(input, mp_size):
         split_fusedqkv = input.split(get_shard_size_list(shape[0], mp_size), dim=0)
         return split_fusedqkv[gpu_index]

+    def _bigcode_type_transpose(input, mp_size):
+        n_embd = get_n_embd()
+        q = input[:n_embd]
+        kv = input[n_embd:]
+        shape = q.shape
+        split_q = q.split(get_shard_size_list(shape[0], mp_size), dim=0)
+        return torch.cat((split_q[gpu_index], kv), dim=0)
+
     def _transpose_fused_qkvw(src, mp_size, fused_qkv_type=None):

         # suppose num_heads=n, q(n)_w means the n-th q head linear weight, the weight format are as following
@@ -87,6 +96,8 @@ def _transpose_fused_qkvw(src, mp_size, fused_qkv_type=None):
             return _codegen_type_transpose(src, mp_size)
         elif fused_qkv_type == 'glmtype':
             return _glm_type_transpose(src, mp_size)
+        elif fused_qkv_type == 'bigcodetype':
+            return _bigcode_type_transpose(src, mp_size)

         raise ValueError("unknown fused_qkv_type")
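For context, here is a minimal standalone sketch of the sharding rule that _bigcode_type_transpose implements (not part of the PR; the toy sizes are assumptions). StarCoder's GPTBigCode attention is multi-query: the fused c_attn weight stacks n_embd query rows on top of a single shared K/V block, so only the query rows are split across ranks, while every rank keeps the full K/V slice.

import torch

# Toy sizes, assumed for illustration only; real StarCoder uses n_embd=6144.
n_embd, head_dim, mp_size, gpu_index = 8, 2, 2, 0

# Fused c_attn weight: n_embd query rows, then 2*head_dim shared K/V rows.
w = torch.randn(n_embd + 2 * head_dim, n_embd)

q, kv = w[:n_embd], w[n_embd:]               # same slicing as the PR
split_q = q.split(n_embd // mp_size, dim=0)  # even split here; the PR uses get_shard_size_list
shard = torch.cat((split_q[gpu_index], kv), dim=0)
assert shard.shape == (n_embd // mp_size + 2 * head_dim, n_embd)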
14 changes: 13 additions & 1 deletion deepspeed/module_inject/replace_module.py
@@ -16,7 +16,7 @@
 from .auto_tp import AutoTP, ReplaceWithTensorSlicing, Loading

 from deepspeed import comm as dist
-from deepspeed.module_inject.tp_shard import set_num_kv_heads
+from deepspeed.module_inject.tp_shard import set_num_kv_heads, set_n_embd

 from .load_checkpoint import load_model_with_checkpoint
 import time
@@ -278,6 +278,18 @@ def replace_wo_policy(module, all_reduce_linears, prefix="", state_dict=None):
         # 4. When we have num_kv_heads defined, uneven division is possible, otherwise enforce even division
         set_num_kv_heads(num_kv_heads)

+        # 4.1 Get n_embd
+        n_embd = None
+        multi_query_n_embd_names = ['n_embd']
+        for name in multi_query_n_embd_names:
+            if hasattr(model_config, name):
+                n_embd = getattr(model_config, name)
+            if n_embd != None:
+                break
+
+        # 4.2 set n_embd
+        set_n_embd(n_embd)
+
         # 5. Set linear policies
         _autotp.update_linear_policies()
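Restated as a self-contained snippet (not from the PR; the config objects are hypothetical stand-ins), the step-4.1 probe returns the first matching attribute and falls through to None, so set_n_embd(None) leaves architectures without a GPT-2-style n_embd field unaffected:

from types import SimpleNamespace

def probe_n_embd(model_config, names=('n_embd',)):
    # Mirrors the loop in step 4.1: first matching attribute wins, else None.
    for name in names:
        if hasattr(model_config, name):
            return getattr(model_config, name)
    return None

bigcode_like = SimpleNamespace(n_embd=6144)     # stand-in for a GPTBigCode config
llama_like = SimpleNamespace(hidden_size=4096)  # stand-in for a config without n_embd
assert probe_n_embd(bigcode_like) == 6144
assert probe_n_embd(llama_like) is None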
10 changes: 10 additions & 0 deletions deepspeed/module_inject/tp_shard.py
@@ -12,6 +12,11 @@ def set_num_kv_heads(num):
     num_kv_heads = num


+def set_n_embd(num):
+    global n_embd
+    n_embd = num
+
+
 def get_num_kv_heads():
     global num_kv_heads
     return num_kv_heads
@@ -32,6 +37,11 @@ def get_shard_size(total_size, mp_size, rank=None):
         assert False, f"Number of attention heads ({total_size}) must be divisible by mp_size ({mp_size})"


+def get_n_embd():
+    global n_embd
+    return n_embd
+
+
 def get_shard_size_list(total_size, mp_size):
     shard_sizes = []
     for i in range(mp_size):
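The truncated get_shard_size_list body above is pre-existing context, not changed by this PR. As a rough stand-in (assumed; DeepSpeed's real implementation shards by KV heads via get_shard_size), its output is a per-rank size list with any remainder spread over the leading ranks:

# Simplified stand-in, not DeepSpeed's actual logic: one shard size per rank,
# remainder distributed to the first ranks.
def shard_size_list(total_size, mp_size):
    base, rem = divmod(total_size, mp_size)
    return [base + (1 if i < rem else 0) for i in range(mp_size)]

print(shard_size_list(6144, 4))  # [1536, 1536, 1536, 1536]
print(shard_size_list(10, 4))    # [3, 3, 2, 2]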