-
Notifications
You must be signed in to change notification settings - Fork 4.2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[AutoTP] Make AutoTP work when num_heads not divisible by number of w…
…orkers (#4011) * allow number of heads not divisible by number of ranks * get num_heads from model config, more robust * simplify logic where num_head itself is sharded * name tweaks * make code more robust where num_attention_heads may not be defined in model_config * support num_key_value_heads < num_attention_heads which is used by llama2 * add test for 5 ranks * change odd rank # to 3 to avoid test skip * add get_shard_size function * modify sharding mechanism according to latest auto TP * fix accuracy issue * fix format * skip tests with fusedqkv * remove skip of fusedqkv tests * skip test fusedqkv with odd number of ranks * support model with n_heads in model_config * fix TestInjectionPolicy::test[fp32-t5] * fix uneven_heads on some fusedqkv types (#12) * odd support fusedqkv * fix format and clear text * better fix when activation size cannot be divided by number of heads * move tp_shard.py under module_inject * Add get_num_kv_heads in tp_shard.py * Refine according to comments * remove old comment * fix bug in getting num_kv_heads * support uneven sharding of lm_head tensor parallel --------- Co-authored-by: Logan Adams <[email protected]> Co-authored-by: Olatunji Ruwase <[email protected]> Co-authored-by: Molly Smith <[email protected]> Co-authored-by: mzl <[email protected]> Co-authored-by: Michael Wyatt <[email protected]> Co-authored-by: Michael Wyatt <[email protected]>
- Loading branch information
1 parent
8fdd9b3
commit f15cccf
Showing
7 changed files
with
121 additions
and
30 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
# DeepSpeed Team | ||
|
||
from deepspeed import comm as dist | ||
global num_kv_heads | ||
|
||
|
||
def set_num_kv_heads(num): | ||
global num_kv_heads | ||
num_kv_heads = num | ||
|
||
|
||
def get_num_kv_heads(): | ||
global num_kv_heads | ||
return num_kv_heads | ||
|
||
|
||
def get_shard_size(total_size, mp_size, rank=None): | ||
global num_kv_heads | ||
# When we have num_kv_heads defined, uneven division is possible, otherwise enforce even division | ||
if num_kv_heads != None: | ||
if (rank == None): | ||
rank = dist.get_rank() | ||
my_slices = (num_kv_heads // mp_size) + (1 if rank < (num_kv_heads % mp_size) else 0) | ||
return total_size * my_slices // num_kv_heads | ||
else: | ||
if total_size % mp_size == 0: | ||
return total_size // mp_size | ||
else: | ||
assert False, f"Number of attention heads ({total_size}) must be divisible by mp_size ({mp_size})" | ||
|
||
|
||
def get_shard_size_list(total_size, mp_size): | ||
shard_sizes = [] | ||
for i in range(mp_size): | ||
shard_sizes.append(get_shard_size(total_size, mp_size, i)) | ||
return shard_sizes |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters