Implement Router Z-loss (#151)
* Router zloss

* pre-commit

* Add zloss tests
josejg authored Sep 9, 2024
1 parent 66d7894 commit cc7614e
Showing 4 changed files with 116 additions and 6 deletions.
4 changes: 4 additions & 0 deletions megablocks/layers/arguments.py
@@ -68,6 +68,10 @@ class Arguments:
int] = None # hidden size of the shared expert IF we want to set it to something different from hidden_size
shared_expert_weighted_sum: bool = False # enable using weighted sum for shared expert output (weighted by number of experts used)

# Router Z-loss arguments
moe_zloss_weight: float = 0 # 1e-3 is a reasonable value
moe_zloss_in_fp32: bool = False

def __post_init__(self):
if self.__getattribute__('mlp_impl') == 'grouped':
grouped_gemm.assert_grouped_gemm_is_available()
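For context, a minimal sketch (not part of this diff) of how the new options could be enabled when constructing Arguments; the field values here are illustrative assumptions:

from megablocks.layers.arguments import Arguments

# Illustrative configuration: a non-zero moe_zloss_weight turns the router
# z-loss on (0, the default, disables it), and moe_zloss_in_fp32 casts the
# saved router logits to float32 before the logsumexp for numerical stability.
args = Arguments(
    hidden_size=1024,          # example sizes, not from this commit
    ffn_hidden_size=4096,
    moe_num_experts=8,
    moe_top_k=2,
    moe_zloss_weight=1e-3,     # the inline comment above suggests 1e-3 as a reasonable value
    moe_zloss_in_fp32=True,
)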
38 changes: 37 additions & 1 deletion megablocks/layers/router.py
@@ -7,6 +7,40 @@
from megablocks.layers import common
from megablocks.layers.arguments import Arguments

_ROUTER_LOGITS = []


def _save_router_logits(logits: torch.Tensor, args: Arguments):
if args.moe_zloss_weight == 0:
return
global _ROUTER_LOGITS
_ROUTER_LOGITS.append(logits)


def clear_router_zloss():
global _ROUTER_LOGITS
_ROUTER_LOGITS.clear()


def batched_router_zloss(args: Arguments):
global _ROUTER_LOGITS

if args.moe_zloss_weight == 0:
import warnings
warnings.warn('Call to batched_router_zloss, but moe_zloss_weight=0')
return 0

logits_per_router = _ROUTER_LOGITS

if args.moe_zloss_in_fp32:
logits_per_router = [logits.float() for logits in logits_per_router]

unscaled_zloss_per_router = torch.stack([
torch.logsumexp(logits, dim=1).square().mean() for logits in logits_per_router
])

return args.moe_zloss_weight * unscaled_zloss_per_router


# NOTE: To enable end-to-end benchmarking without convergence we
# support a flag to force the router to assign tokens uniformly
@@ -60,7 +94,9 @@ def forward(self, x: torch.Tensor):
if self.training and self.args.moe_jitter_eps is not None:
x = x * self.jitter(x)

scores = self.layer(x.view(-1, x.shape[-1])).softmax(dim=-1)
logits = self.layer(x.view(-1, x.shape[-1]))
_save_router_logits(logits, self.args)
scores = logits.softmax(dim=-1)
expert_weights, expert_indices = self._top_k(scores)
if self.args.moe_normalize_expert_weights:
expert_weights = expert_weights / torch.norm(
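A rough usage sketch (the training-step wrapper and its arguments are assumptions, not part of this commit): during the forward pass each router stashes its logits via _save_router_logits whenever moe_zloss_weight > 0; batched_router_zloss then returns one weighted loss per router, where each unscaled term is the mean over tokens of the squared logsumexp of the logits, and the module-level buffer should be cleared once per step:

from megablocks.layers.moe import batched_load_balancing_loss, clear_load_balancing_loss
from megablocks.layers.router import batched_router_zloss, clear_router_zloss

def training_step(layer, x, args, optimizer):
    # Forward pass: LearnedRouter.forward appends its logits to _ROUTER_LOGITS
    # (a no-op when moe_zloss_weight == 0).
    out, _ = layer(x)
    task_loss = out.sum()  # placeholder objective

    # batched_router_zloss returns a tensor with one weighted z-loss per router
    # call recorded this step; sum them into a single scalar penalty.
    loss = task_loss + batched_load_balancing_loss(args) + batched_router_zloss(args).sum()
    loss.backward()

    optimizer.step()
    optimizer.zero_grad(set_to_none=True)

    # Both losses accumulate into module-level lists, so clear them every step
    # to avoid mixing logits from different iterations.
    clear_load_balancing_loss()
    clear_router_zloss()
    return loss.detach()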
36 changes: 36 additions & 0 deletions tests/layers/dmoe_test.py
@@ -10,6 +10,7 @@
from megablocks.layers.arguments import Arguments
from megablocks.layers.dmoe import dMoE
from megablocks.layers.moe import MoE, batched_load_balancing_loss, clear_load_balancing_loss
from megablocks.layers.router import batched_router_zloss, clear_router_zloss
from tests.layers.architectures import FFN

# min size: (1, 2, 128, 2, 1)
@@ -50,6 +51,7 @@ def construct_moes(
moe_capacity_factor: int = 1,
moe_top_k: int = 1,
mlp_impl: str = 'sparse',
moe_zloss_weight: float = 0,
):
init_method = partial(torch.nn.init.normal_, mean=0.0, std=0.1)
args = Arguments(
@@ -64,6 +66,7 @@
mlp_impl=mlp_impl,
fp16=False,
bf16=True,
moe_zloss_weight=moe_zloss_weight,
)

mlp = FFN(args)
@@ -142,6 +145,39 @@ def test_dmoe_forward_backward(
clear_load_balancing_loss()


@pytest.mark.gpu
@pytest.mark.parametrize(('bs', 'sl', 'hs', 'num_experts', 'top_k', 'mlp_impl'), _FORWARD_TESTS)
def test_dmoe_forward_backward_with_zloss(
bs: int,
sl: int,
hs: int,
num_experts: int,
top_k: int,
mlp_impl: str,
):
x = torch.randn(sl, bs, hs).to(torch.bfloat16).cuda()
x.requires_grad_(True)

args, _, _, layer = construct_moes(
hidden_size=hs,
ffn_hidden_size=hs * 2,
moe_num_experts=num_experts,
moe_top_k=top_k,
mlp_impl=mlp_impl,
moe_zloss_weight=1e-3,
)

out, _ = layer(x)
assert out.shape == x.shape
loss = out.sum() + batched_load_balancing_loss(args) + batched_router_zloss(args)
loss.backward()
assert x.grad is not None
layer.zero_grad(set_to_none=True)
x.grad = None
clear_load_balancing_loss()
clear_router_zloss()


@pytest.mark.gpu
@pytest.mark.parametrize(('bs', 'sl', 'hs'), _DENSE_TESTS)
def test_dmoe_forward_vs_baseline(
44 changes: 39 additions & 5 deletions tests/layers/moe_test.py
@@ -8,6 +8,7 @@

from megablocks.layers.arguments import Arguments
from megablocks.layers.moe import MoE, batched_load_balancing_loss, clear_load_balancing_loss
from megablocks.layers.router import batched_router_zloss, clear_router_zloss
from tests.layers.architectures import FFN

_FORWARD_TESTS = (
@@ -33,11 +34,12 @@


def construct_moe(
hidden_size,
ffn_hidden_size,
moe_num_experts=1,
moe_capacity_factor=1,
moe_top_k=1,
hidden_size: int,
ffn_hidden_size: int,
moe_num_experts: int = 1,
moe_capacity_factor: int = 1,
moe_top_k: int = 1,
moe_zloss_weight: float = 0,
):
init_method = partial(torch.nn.init.normal_, mean=0.0, std=0.1)
args = Arguments(
@@ -47,6 +49,7 @@ def construct_moe(
moe_capacity_factor=moe_capacity_factor,
moe_top_k=moe_top_k,
init_method=init_method,
moe_zloss_weight=moe_zloss_weight,
)

mlp = FFN(args)
@@ -109,6 +112,37 @@ def test_moe_forward_backward(
clear_load_balancing_loss()


@pytest.mark.gpu
@pytest.mark.parametrize(('bs', 'sl', 'hs', 'num_experts', 'top_k'), _FORWARD_TESTS)
def test_moe_forward_backward_with_zloss(
bs: int,
sl: int,
hs: int,
num_experts: int,
top_k: int,
):
x = torch.randn(sl, bs, hs).half().cuda()
x.requires_grad_(True)

args, _, layer = construct_moe(
hidden_size=hs,
ffn_hidden_size=hs * 2,
moe_num_experts=num_experts,
moe_top_k=top_k,
moe_zloss_weight=1e-3,
)

out, _ = layer(x)
assert out.shape == x.shape

loss = out.sum() + batched_load_balancing_loss(args)
loss.backward()
layer.zero_grad(set_to_none=True)
x.grad = None
clear_load_balancing_loss()
clear_router_zloss()


@pytest.mark.gpu
@pytest.mark.parametrize(('bs', 'sl', 'hs'), _DENSE_TESTS)
def test_moe_forward_vs_dense(bs: int, sl: int, hs: int):
