Skip to content

Commit

Permalink
pre-commit
Browse files Browse the repository at this point in the history
  • Loading branch information
josejg committed Sep 9, 2024
1 parent abc0638 commit 3ad64e2
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 6 deletions.
4 changes: 2 additions & 2 deletions megablocks/layers/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,8 @@ class Arguments:
shared_expert_weighted_sum: bool = False # enable using weighted sum for shared expert output (weighted by number of experts used)

# Router Z-loss arguments
moe_zloss_weight: float = 0 # 1e-3 is a reasonable value
moe_zloss_in_fp32 : bool = False
moe_zloss_weight: float = 0 # 1e-3 is a reasonable value
moe_zloss_in_fp32: bool = False

def __post_init__(self):
if self.__getattribute__('mlp_impl') == 'grouped':
Expand Down
10 changes: 6 additions & 4 deletions megablocks/layers/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,22 +9,25 @@

_ROUTER_LOGITS = []


def _save_router_logits(logits: torch.Tensor, args: Arguments):
    """Record one router's logits in the module-level ``_ROUTER_LOGITS`` list.

    Nothing is recorded when ``args.moe_zloss_weight`` is 0.

    Args:
        logits: Router logits to store; the tensor is appended as-is.
        args: Configuration object; only ``moe_zloss_weight`` is read here.
    """
    if args.moe_zloss_weight != 0:
        # append() mutates the existing list rather than rebinding the name,
        # so no ``global`` declaration is required.
        _ROUTER_LOGITS.append(logits)


def clear_router_zloss():
    """Empty the module-level ``_ROUTER_LOGITS`` list in place."""
    # Slice deletion removes every element while keeping the same list
    # object, so other references to the list see it emptied too.
    del _ROUTER_LOGITS[:]

def batched_router_zloss(args : Arguments):

def batched_router_zloss(args: Arguments):
global _ROUTER_LOGITS

if args.moe_zloss_weight == 0:
import warnings
warnings.warn("Call to batched_router_zloss, but moe_zloss_weight=0")
warnings.warn('Call to batched_router_zloss, but moe_zloss_weight=0')
return 0

logits_per_router = _ROUTER_LOGITS
Expand All @@ -33,8 +36,7 @@ def batched_router_zloss(args : Arguments):
logits_per_router = [logits.float() for logits in logits_per_router]

unscaled_zloss_per_router = torch.stack([
torch.logsumexp(logits, dim=1).square().mean()
for logits in logits_per_router
torch.logsumexp(logits, dim=1).square().mean() for logits in logits_per_router
])

return args.moe_zloss_weight * unscaled_zloss_per_router
Expand Down

0 comments on commit 3ad64e2

Please sign in to comment.