Skip to content

Commit

Permalink
added triangular activation
Browse files Browse the repository at this point in the history
  • Loading branch information
kmheckel committed Aug 9, 2023
1 parent 0982033 commit 5ca159c
Showing 1 changed file with 30 additions and 0 deletions.
30 changes: 30 additions & 0 deletions spyx/axn.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,36 @@ def __call__(self, U):
return self.f(U)


class Triangular:
"""
Triangular activation inspired by Esser et. al. Very simple. https://arxiv.org/abs/1603.08270
"""
def __init__(self, scale_factor=0.5):
self.k = scale_factor

@jax.jit
def _grad(x):
return jnp.maximum(0, 1-jnp.abs(x))

@jax.custom_vjp
def f(U): # primal function
return (U>0).astype(jnp.float16)

# returns value, grad context
def f_fwd(U):
return f(U), U

# accepts context, primal val
def f_bwd(U, grad):
return (grad * self._grad(self.k * U) , )

f.defvjp(f_fwd, f_bwd)
self.f = f

def __call__(self, U):
return self.f(U)

# Surrogate functions
class Arctan:
"""
Expand Down

0 comments on commit 5ca159c

Please sign in to comment.