Skip to content

Commit

Permalink
add moe topk(k>2) gate support (#5881)
Browse files Browse the repository at this point in the history
Some users need top-k gating with k > 2 to train MoE models; see, for example,
https://huggingface.co/Qwen/Qwen2-57B-A14B/blob/main/config.json. This
PR adds support for top-k (k > 2) gates.

- add topk (k>2) support
- add drop token policy based on position and probabilities.
- unit tests

---------

Co-authored-by: Kurt Chen <[email protected]>
Co-authored-by: Jin, Youzhi <[email protected]>
Co-authored-by: Olatunji Ruwase <[email protected]>
Co-authored-by: Logan Adams <[email protected]>
Co-authored-by: Masahiro Tanaka <[email protected]>
  • Loading branch information
6 people authored Aug 15, 2024
1 parent 30428d0 commit 9a3ede7
Show file tree
Hide file tree
Showing 2 changed files with 131 additions and 6 deletions.
88 changes: 83 additions & 5 deletions deepspeed/moe/sharded_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,8 @@ def einsum(rule, a, b):
return a.unsqueeze(2) * b.unsqueeze(1)
elif rule == 'se,se->s':
return torch.bmm(a.unsqueeze(1), b.unsqueeze(2)).reshape(-1)
elif rule == 'se,sec->sec':
return a.unsqueeze(2) * b
elif rule == 'sec,sm->ecm':
s = a.shape[0]
e = a.shape[1]
Expand Down Expand Up @@ -191,8 +193,8 @@ def top1gating(logits: Tensor,
if noisy_gate_policy == 'RSample':
logits_w_noise = logits + gumbel_rsample(logits.shape, device=logits.device)
# everything is in fp32 in this function
gates = F.softmax(logits, dim=1)

gates = F.softmax(logits, dim=1)
capacity = _capacity(gates, torch.tensor(capacity_factor), torch.tensor(min_capacity))

# Create a mask for 1st's expert per token
Expand Down Expand Up @@ -369,6 +371,81 @@ def top2gating(logits: Tensor,
return l_aux, combine_weights, dispatch_mask, exp_counts.detach().to('cpu')


def topkgating(
    logits: Tensor,
    k: int,
    capacity_factor: float,
    min_capacity: int,
    drop_tokens: bool = True,
    ep_group: Union[torch.distributed.ProcessGroup, None] = None,
    drop_policy: str = "probs",
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
    """Implements TopKGating on logits.

    Args:
        logits: Router scores. This function treats dim 0 as tokens and dim 1
            as experts (top-k and softmax are taken along dim 1).
        k: Number of experts selected per token.
        capacity_factor: Multiplied by ``k`` and passed to ``_capacity`` to size
            the per-expert capacity when ``drop_tokens`` is True.
        min_capacity: Lower bound passed to ``_capacity``.
        drop_tokens: If True, assignments beyond the expert capacity are
            removed according to ``drop_policy``. If False, nothing is dropped
            and the capacity is set to the largest per-expert assignment count.
        ep_group: Optional process group; when given (and ``drop_tokens`` is
            False) the capacity is max-reduced across the group.
        drop_policy: ``'probs'`` keeps, per expert, the assignments with the
            highest router scores; ``'position'`` keeps the earliest tokens.
            Only consulted when ``drop_tokens`` is True.

    Returns:
        ``(l_aux, combine_weights, dispatch_mask, exp_counts)`` where
        ``combine_weights`` is indexed (token, expert, capacity slot),
        ``dispatch_mask`` is ``combine_weights.bool()`` and ``exp_counts`` is
        the per-expert assignment count computed before any dropping.

    Raises:
        ValueError: If ``drop_tokens`` is True and ``drop_policy`` is neither
            ``'probs'`` nor ``'position'``.
    """

    # everything is in fp32 in this function
    # get topk gates
    top_gate, top_idx = torch.topk(logits, k=k, dim=1)
    # gating decisions
    gates = F.softmax(logits, dim=1)
    num_experts = int(gates.shape[1])

    # Boolean (token, expert) mask of the top-k selections.
    mask = torch.zeros_like(gates, dtype=torch.bool).scatter_(1, top_idx, 1)

    exp_counts = torch.sum(mask, dim=0).detach().to(logits.device)

    # Compute l_aux
    me = torch.mean(gates, dim=0)
    ce = torch.mean(mask.float(), dim=0)
    l_aux = torch.mean(me * ce) * num_experts * num_experts / k

    if drop_tokens:
        # Calculate configured capacity and remove locations outside capacity from mask
        capacity = _capacity(gates, torch.tensor(capacity_factor * k), torch.tensor(min_capacity))
        # update mask and locations by capacity

        if drop_policy == 'probs':
            # Unselected entries are filled with -inf (not 0) so they can never
            # outrank a selected entry whose logit is zero or negative; with a
            # zero fill, such routed tokens were dropped even when the expert
            # was under capacity.
            topk_masked_gates = torch.full_like(logits, float('-inf')).scatter(1, top_idx, top_gate)
            _, capacity_indices = torch.topk(topk_masked_gates, k=capacity, dim=0, sorted=False)
            capacity_mask = torch.zeros_like(logits).scatter(0, capacity_indices, 1)
            # Any -inf filler rows picked by the top-k are discarded here.
            mask = torch.logical_and(mask, capacity_mask)
            locations = torch.cumsum(mask, dim=0) - 1

        elif drop_policy == "position":
            locations = torch.cumsum(mask, dim=0) - 1
            mask *= torch.lt(locations, capacity)
        else:
            raise ValueError(f"Invalid drop_policy: {drop_policy}")

    else:
        # Do not drop tokens - set capacity according to current expert assignments
        new_capacity = torch.max(exp_counts)
        if ep_group is not None:
            dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX, group=ep_group)
        if groups._get_expert_model_parallel_world_size() == 1:
            # If the non-expert is tensor-parallel, we need to pad the capacity to 'tp'.
            # This is since we are going to activate drop_tokens() to drop duplicate tokens.
            tp = 1 if groups.mpu is None else bwc_tensor_model_parallel_world_size(mpu=groups.mpu)
            new_capacity = torch.ceil(new_capacity / tp).mul(tp).to(new_capacity.dtype)
        capacity = new_capacity
        # Slot index of every assignment within its expert. Without this the
        # no-drop path reached the dispatch step with `locations` unbound.
        locations = torch.cumsum(mask, dim=0) - 1

    # normalize gates
    gates_masked = gates * mask
    gates_s = torch.sum(gates_masked, dim=-1, keepdim=True)
    # Clamp so tokens with no surviving expert do not divide by zero.
    denom_s = torch.clamp(gates_s, min=torch.finfo(gates_masked.dtype).eps)
    gates_masked = gates_masked / denom_s

    # dispatch_mask
    locations_sc = _one_hot_to_float((locations * mask), capacity)

    combine_weights = torch.einsum("se,sec->sec", gates_masked, locations_sc)

    dispatch_mask = combine_weights.bool()

    return l_aux, combine_weights, dispatch_mask, exp_counts


class TopKGate(Module):
"""Gate module which implements Top2Gating as described in Gshard_.
::
Expand Down Expand Up @@ -401,9 +478,6 @@ def __init__(self,
top2_2nd_expert_sampling: bool = True) -> None:
super().__init__()

# Only top-1 and top-2 are supported at the moment.
if k != 1 and k != 2:
raise ValueError('Only top-1 and top-2 gatings are supported.')
self.wg = torch.nn.Linear(model_dim, num_experts, bias=False)
self.ep_group = ep_group
self.k = k
Expand Down Expand Up @@ -441,9 +515,13 @@ def forward(self,
self.min_capacity, used_token, self.noisy_gate_policy if self.training else None,
self.drop_tokens, self.use_rts, self.ep_group, use_tutel)

else:
elif self.k == 2:
gate_output = top2gating(logits, self.capacity_factor if self.training else self.eval_capacity_factor,
self.min_capacity, self.drop_tokens, self.ep_group, self.top2_2nd_expert_sampling)
else:
gate_output = topkgating(logits, self.k,
self.capacity_factor if self.training else self.eval_capacity_factor,
self.min_capacity, self.drop_tokens, self.ep_group)

if self.wall_clock_breakdown:
self.timers(TOPK_GATE_TIMER).stop()
Expand Down
49 changes: 48 additions & 1 deletion tests/unit/moe/test_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from unit.simple_model import SimplePRMoEModel, SimpleMoEModel, sequence_dataloader
import deepspeed.comm as dist
from deepspeed import get_accelerator
from deepspeed.moe.sharded_moe import top1gating
from deepspeed.moe.sharded_moe import top1gating, topkgating
from deepspeed.moe.utils import split_params_into_different_moe_groups_for_optimizer, is_moe_param
from deepspeed.utils.torch import required_torch_version

Expand Down Expand Up @@ -191,3 +191,50 @@ def test(self):
drop_tokens=False,
use_rts=True,
use_tutel=False)


class TestTopkGate(DistributedTest):
    """Compares the dispatch mask returned by ``topkgating`` with hand-written masks."""

    def test(self):

        def assert_dispatch(scores, capacity, expected_sec, actual):
            # Expand sparse [token, expert, slot] triples into a dense mask
            # and compare it with the gate's dispatch mask.
            num_tokens, num_experts = scores.shape
            expected = torch.zeros(num_tokens, num_experts, capacity)
            token_ids, expert_ids, slot_ids = expected_sec.unbind(dim=1)
            expected[token_ids, expert_ids, slot_ids] = 1
            assert torch.equal(expected, actual)

        # Each rank scales the scores by a different positive factor.
        rank_scale = dist.get_rank() + 1

        # s=4 e=4 topk=2 cap=2 (s*topk/e)
        scores_top2 = torch.tensor([[0.11, 0.2, 0.1, 0.3], [0.3, 0.4, 0.11, 0.1], [0.11, 0.1, 0.6, 0.5],
                                    [0.1, 0.11, 0.7, 0.8]]).mul(rank_scale)
        expected_top2 = (
            ('probs', [[0, 1, 0], [1, 0, 0], [1, 1, 1], [2, 2, 0], [2, 3, 0], [3, 2, 1], [3, 3, 1]]),
            ('position', [[0, 1, 0], [0, 3, 0], [1, 0, 0], [1, 1, 1], [2, 2, 0], [2, 3, 1], [3, 2, 1]]),
        )
        for policy, triples in expected_top2:
            dispatch = topkgating(scores_top2, 2, 1, min_capacity=1, drop_policy=policy)[2]
            assert_dispatch(scores_top2, 2, torch.tensor(triples), dispatch)

        # s=4 e=6 topk=3 cap=2 (s*topk/e)
        scores_top3 = torch.tensor([[0.5858, 0.4801, 0.6269, 0.5397, 0.9722, 0.7034],
                                    [0.5445, 0.6332, 0.4519, 0.6308, 0.0519, 0.6450],
                                    [0.4874, 0.8110, 0.7467, 0.8474, 0.0277, 0.3068],
                                    [0.8570, 0.6714, 0.5310, 0.3274, 0.4836, 0.9892]]).mul(rank_scale)

        # top3 full mask    # prob_mask       # position_mask
        # 0 0 1 0 1 1       # 0 0 1 0 1 1     # 0 0 1 0 1 1
        # 0 1 0 1 0 1       # 0 0 0 1 0 0     # 0 1 0 1 0 1
        # 0 1 1 1 0 0       # 0 1 1 1 0 0     # 0 1 1 1 0 0
        # 1 1 0 0 0 1       # 1 1 0 0 0 1     # 1 0 0 0 0 0
        expected_top3 = (
            ('probs', [[0, 2, 0], [0, 4, 0], [0, 5, 0], [1, 3, 0], [2, 1, 0], [2, 2, 1], [2, 3, 1], [3, 0, 0],
                       [3, 1, 1], [3, 5, 1]]),
            ('position', [[0, 2, 0], [0, 4, 0], [0, 5, 0], [1, 1, 0], [1, 3, 0], [1, 5, 1], [2, 1, 1], [2, 2, 1],
                          [2, 3, 1], [3, 0, 0]]),
        )
        for policy, triples in expected_top3:
            dispatch = topkgating(scores_top3, 3, 1, min_capacity=1, drop_policy=policy)[2]
            assert_dispatch(scores_top3, 2, torch.tensor(triples), dispatch)

0 comments on commit 9a3ede7

Please sign in to comment.