diff --git a/megablocks/layers/arguments.py b/megablocks/layers/arguments.py
index 0b6d666..6492dfd 100644
--- a/megablocks/layers/arguments.py
+++ b/megablocks/layers/arguments.py
@@ -69,8 +69,8 @@ class Arguments:
     shared_expert_weighted_sum: bool = False  # enable using weighted sum for shared expert output (wieghted by number of experts used)
 
     # Router Z-loss arguments
-    moe_zloss_weight: float = 0 # 1e-3 is a reasonable value
-    moe_zloss_in_fp32 : bool = False
+    moe_zloss_weight: float = 0  # 1e-3 is a reasonable value
+    moe_zloss_in_fp32: bool = False
 
     def __post_init__(self):
         if self.__getattribute__('mlp_impl') == 'grouped':
diff --git a/megablocks/layers/router.py b/megablocks/layers/router.py
index a0b4b4e..a6deae0 100644
--- a/megablocks/layers/router.py
+++ b/megablocks/layers/router.py
@@ -9,22 +9,25 @@
 
 _ROUTER_LOGITS = []
 
+
 def _save_router_logits(logits: torch.Tensor, args: Arguments):
     if args.moe_zloss_weight == 0:
         return
     global _ROUTER_LOGITS
     _ROUTER_LOGITS.append(logits)
 
+
 def clear_router_zloss():
     global _ROUTER_LOGITS
     _ROUTER_LOGITS.clear()
 
-def batched_router_zloss(args : Arguments):
+
+def batched_router_zloss(args: Arguments):
     global _ROUTER_LOGITS
 
     if args.moe_zloss_weight == 0:
         import warnings
-        warnings.warn("Call to batched_router_zloss, but moe_zloss_weight=0")
+        warnings.warn('Call to batched_router_zloss, but moe_zloss_weight=0')
         return 0
 
     logits_per_router = _ROUTER_LOGITS
@@ -33,8 +36,7 @@ def batched_router_zloss(args : Arguments):
         logits_per_router = [logits.float() for logits in logits_per_router]
 
     unscaled_zloss_per_router = torch.stack([
-        torch.logsumexp(logits, dim=1).square().mean()
-        for logits in logits_per_router
+        torch.logsumexp(logits, dim=1).square().mean() for logits in logits_per_router
     ])
 
     return args.moe_zloss_weight * unscaled_zloss_per_router
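
For context, a minimal sketch of how these z-loss helpers might be wired into a training step. Only `Arguments.moe_zloss_weight`, `batched_router_zloss`, and `clear_router_zloss` come from the diff above; `model`, `loss_fn`, `optimizer`, and the batch variables are hypothetical placeholders for whatever training setup is in use:

```python
# Hypothetical usage sketch -- not part of the diff above.
# Assumes `model` is built from megablocks MoE layers configured with the same
# `args` (an Arguments instance with moe_zloss_weight > 0, e.g. 1e-3), so the
# routers record their logits into the module-level buffer during forward().
from megablocks.layers.router import batched_router_zloss, clear_router_zloss


def train_step(model, args, batch, targets, loss_fn, optimizer):
    optimizer.zero_grad()
    output = model(batch)  # each router call saves its logits for the z-loss
    loss = loss_fn(output, targets)
    # batched_router_zloss returns one already-weighted term per router;
    # summing them into the task loss lets the penalty backpropagate.
    loss = loss + batched_router_zloss(args).sum()
    loss.backward()
    optimizer.step()
    clear_router_zloss()  # drop any saved logits before the next step
    return loss.detach()
```

The trailing `clear_router_zloss()` mirrors the intent of the new helper: logits left over from a step whose z-loss was never consumed would otherwise leak into the next step's stack.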