diff --git a/spyx/axn.py b/spyx/axn.py index 5af98de..7fe9530 100644 --- a/spyx/axn.py +++ b/spyx/axn.py @@ -77,6 +77,36 @@ def __call__(self, U): return self.f(U) +class Triangular: + """ + Triangular activation inspired by Esser et. al. Very simple. https://arxiv.org/abs/1603.08270 + + """ + def __init__(self, scale_factor=0.5): + self.k = scale_factor + + @jax.jit + def _grad(x): + return jnp.maximum(0, 1-jnp.abs(x)) + + @jax.custom_vjp + def f(U): # primal function + return (U>0).astype(jnp.float16) + + # returns value, grad context + def f_fwd(U): + return f(U), U + + # accepts context, primal val + def f_bwd(U, grad): + return (grad * self._grad(self.k * U) , ) + + f.defvjp(f_fwd, f_bwd) + self.f = f + + def __call__(self, U): + return self.f(U) + # Surrogate functions class Arctan: """