From 84286de8ab5be0c73928a0059f50c7e2b650e4b1 Mon Sep 17 00:00:00 2001 From: mihir-db <141708001+mihir-db@users.noreply.github.com> Date: Thu, 17 Oct 2024 08:27:50 -0700 Subject: [PATCH] Update router.py (#158) --- megablocks/layers/router.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/megablocks/layers/router.py b/megablocks/layers/router.py index a6deae0..2c9dcd9 100644 --- a/megablocks/layers/router.py +++ b/megablocks/layers/router.py @@ -79,9 +79,8 @@ def __init__(self, args: Arguments): args.init_method(self.layer.weight) def jitter(self, x: torch.Tensor): - assert isinstance(self.args.moe_jitter_eps, float) - low = 1.0 - self.args.moe_jitter_eps - high = 1.0 + self.args.moe_jitter_eps + low: float = 1.0 - self.args.moe_jitter_eps + high: float = 1.0 + self.args.moe_jitter_eps noise = torch.rand(x.size(), dtype=x.dtype, device=x.device) return low + noise * (high - low)