neuron models for e-prop implementation #28

Open · wants to merge 3 commits into base: main
1 change: 1 addition & 0 deletions .gitignore
@@ -7,6 +7,7 @@

.ipynb_checkpoints
*/.ipynb_checkpoints/*
docs/examples/surrogate_gradient/data

# datasets
.h5
1 change: 1 addition & 0 deletions docs/examples/surrogate_gradient/shd_eprop.ipynb

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions docs/index.rst
@@ -25,6 +25,7 @@ Be sure to go give it a star on Github: https://github.com/kmheckel/spyx
examples/surrogate_gradient/shd_sg_neuron_model_comparison
examples/surrogate_gradient/shd_sg_surrogate_comparison
examples/surrogate_gradient/shd_sg_template
examples/surrogate_gradient/shd_eprop

Indices and tables
==================
2 changes: 1 addition & 1 deletion setup.py
@@ -25,7 +25,7 @@
'loaders' : [
'tonic',
'torchvision',
'sklearn'
'scikit-learn'
]
}

32 changes: 29 additions & 3 deletions spyx/axn.py
@@ -13,7 +13,7 @@ def custom(bwd=lambda x: x,

It is assumed that the input to this layer has already had its threshold subtracted within the neuron model dynamics.

The default behavior is a Heaviside forward activation with a stragiht through estimator surrogate gradient.
The default behavior is a Heaviside forward activation with a straight through estimator surrogate gradient.

:bwd: Function that calculates the gradient to be used in the backwards pass.
:fwd: Forward activation/spiking function. Default is the Heaviside function centered at 0.
@@ -69,10 +69,10 @@ def triangular(k=2):
:return: JIT compiled triangular surrogate gradient function.
"""

def grad_traingle(x):
def grad_triangle(x):
return jnp.maximum(0, 1-jnp.abs(k*x))

return custom(grad_traingle, heaviside)
return custom(grad_triangle, heaviside)


def arctan(k=2):
@@ -119,3 +119,29 @@ def grad_superspike(x):
return 1 / (1 + k*jnp.abs(x))**2

return custom(grad_superspike, heaviside)

def abs_linear(dampening_factor=0.3):
"""
This function implements the SpikeFunction surrogate-gradient activation for a spiking neuron.

It was introduced in Bellec, Guillaume, et al., "Long short-term memory and learning-to-learn in networks of spiking neurons,"
arXiv:1803.09574, 25 Dec. 2018, https://doi.org/10.48550/arXiv.1803.09574.

:v_scaled: Membrane potential relative to the threshold, normalized by the threshold (the argument of the returned forward and backward functions).
:dampening_factor: The dampening factor for the surrogate gradient,
which can improve the stability of the training process
for deep networks. Default is 0.3.
"""
def fwd(v_scaled):
z_ = jnp.greater(v_scaled, 0.)
z_ = z_.astype(jnp.float32)
return z_

def grad(v_scaled):
dz_dv_scaled = jnp.maximum(1 - jnp.abs(v_scaled), 0).astype(v_scaled.dtype)
dz_dv_scaled *= dampening_factor

return dz_dv_scaled

return custom(grad, fwd)
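
For context on how the new surrogate behaves, here is a minimal sketch (assuming this branch is installed so that abs_linear is importable from spyx.axn; the input values are made up for illustration):

import jax
import jax.numpy as jnp
from spyx.axn import abs_linear

# Heaviside forward pass, abs-linear (triangular) backward pass scaled by the dampening factor
spike_fn = abs_linear(dampening_factor=0.3)

# v_scaled: membrane potential relative to the threshold, normalized by the threshold
v_scaled = jnp.array([-0.5, 0.1, 0.8])

spikes = spike_fn(v_scaled)                              # forward pass -> [0., 1., 1.]
grads = jax.grad(lambda v: spike_fn(v).sum())(v_scaled)
# surrogate gradient -> 0.3 * max(1 - |v_scaled|, 0) = [0.15, 0.27, 0.06]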
208 changes: 205 additions & 3 deletions spyx/nn.py
@@ -1,24 +1,25 @@
import jax
import jax.numpy as jnp
import haiku as hk
from .axn import superspike
from .axn import superspike, abs_linear

from collections.abc import Sequence
from typing import Optional, Union
import warnings

from collections import namedtuple

# needs to be fixed.
class ALIF(hk.RNNCore):
"""
Adaptive LIF Neuron based on the model used in LSNNs:
Adaptive LIF Neuron based on the model used in LSNNs

Bellec, G., Salaj, D., Subramoney, A., Legenstein, R. & Maass, W.
Long short-term memory and learning-to-learn in networks of spiking neurons.
32nd Conference on Neural Information Processing Systems (2018).

"""


def __init__(self, hidden_shape, beta=None, gamma=None,
threshold = 1,
activation = superspike(),
@@ -80,6 +81,207 @@ def __call__(self, x, VT):
# not sure if this is borked.
def initial_state(self, batch_size): # this might need to be fixed to match CuBaLIF...
return jnp.zeros((batch_size,) + tuple(2*s for s in self.hidden_shape))


CustomALIFStateTuple = namedtuple('CustomALIFStateTuple', ('s', 'z', 'r', 'z_local'))


class RecurrentLIFLight(hk.RNNCore):
"""
Recurrent adaptive leaky integrate-and-fire (ALIF) neuron model with threshold adaptation.
It reduces to a plain LIF neuron when beta is set to 0.

Original code from https://github.com/IGITUGraz/eligibility_propagation for RecurrentLIFLight
Copyright 2019-2020, the e-prop team:
Guillaume Bellec, Franz Scherr, Anand Subramoney, Elias Hajek, Darjan Salaj, Robert Legenstein, Wolfgang Maass
from the Institute for theoretical computer science, TU Graz, Austria.

Params
------
n_rec: int
Number of recurrent neurons.
tau: float or array
Membrane time constant (ms)
thr: float or array
Firing threshold.
dt: float
Time step (ms)
dtype:
Data type.
dampening_factor: float
Dampening factor for the surrogate gradient (see abs_linear).
tau_adaptation: float or array
Time constant for threshold adaptation (ALIF model)
beta: float or array
Scaling of the adaptation variable's contribution to the threshold (ALIF model).
tag: str
Tag appended to the recurrent weight parameter name.
stop_gradients: bool
Whether to stop gradients through the recurrent spike variable:
if True, gradients are truncated as in e-prop; if False, exact BPTT is applied.
w_rec_init: array or initializer
Initial value (or Haiku initializer) for the recurrent weights.
n_refractory: int
Refractory period in time steps.
rec: bool
Whether to include recurrent connections.
name: str
Name of the Haiku module.
"""

def __init__(self,
n_rec, tau=20., thr=.615, dt=1., dtype=jnp.float32, dampening_factor=0.3,
tau_adaptation=200., beta=.16, tag='',
stop_gradients=False, w_rec_init=None, n_refractory=1, rec=True,
name="RecurrentLIFLight"):
super().__init__(name=name)

self.n_refractory = n_refractory
self.tau_adaptation = tau_adaptation
self.beta = beta
self.decay_b = jnp.exp(-dt / tau_adaptation)

if jnp.isscalar(tau): tau = jnp.ones(n_rec, dtype=dtype) * jnp.mean(tau)
if jnp.isscalar(thr): thr = jnp.ones(n_rec, dtype=dtype) * jnp.mean(thr)

tau = jnp.array(tau, dtype=dtype)
dt = jnp.array(dt, dtype=dtype)
self.rec = rec

self.dampening_factor = dampening_factor
self.stop_gradients = stop_gradients
self.dt = dt
self.n_rec = n_rec
self.data_type = dtype

self._num_units = self.n_rec

self.tau = tau
self._decay = jnp.exp(-dt / tau)
self.thr = thr

if rec:
init_w_rec_var = w_rec_init if w_rec_init is not None else hk.initializers.TruncatedNormal(1./jnp.sqrt(n_rec))
self.w_rec_var = hk.get_parameter("w_rec" + tag, (n_rec, n_rec), dtype, init_w_rec_var)

self.recurrent_disconnect_mask = jnp.diag(jnp.ones(n_rec, dtype=bool))

self.w_rec_val = jnp.where(self.recurrent_disconnect_mask, jnp.zeros_like(self.w_rec_var), self.w_rec_var)

self.built = True

def initial_state(self, batch_size, dtype=jnp.float32):
"""
Initialize the state of the neuron model.

:batch_size: int
Batch size.
:dtype:
Data type.
"""
n_rec = self.n_rec

s0 = jnp.zeros(shape=(batch_size, n_rec, 2), dtype=dtype)
z0 = jnp.zeros(shape=(batch_size, n_rec), dtype=dtype)
z_local0 = jnp.zeros(shape=(batch_size, n_rec), dtype=dtype)
r0 = jnp.zeros(shape=(batch_size, n_rec), dtype=dtype)

return CustomALIFStateTuple(s=s0, z=z0, r=r0, z_local=z_local0)

def compute_z(self, v, b):
"""
Compute the spike output from the membrane potential and the adaptation variable, using the adaptive threshold and the abs_linear surrogate-gradient activation.
"""
adaptive_thr = self.thr + b * self.beta
v_scaled = (v - adaptive_thr) / self.thr
z = abs_linear(self.dampening_factor)(v_scaled)
z = z * 1 / self.dt

return z

def __call__(self, inputs, state):
decay = self._decay

z = state.z
z_local = state.z_local
s = state.s

if self.stop_gradients:
z = jax.lax.stop_gradient(z)

i_in = inputs.reshape(-1, self.n_rec)

if self.rec:
if len(self.w_rec_val.shape) == 3:
i_rec = jnp.einsum('bi,bij->bj', z, self.w_rec_val)
else:
i_rec = jnp.matmul(z, self.w_rec_val)

i_t = i_in + i_rec
else:
i_t = i_in

v, b = s[..., 0], s[..., 1]
new_b = self.decay_b * b + z_local

I_reset = z * self.thr * self.dt
new_v = decay * v + i_t - I_reset

is_refractory = state.r > 0
zeros_like_spikes = jnp.zeros_like(z)
z_computed = self.compute_z(new_v, new_b)
new_z = jnp.where(is_refractory, zeros_like_spikes, z_computed)
new_z_local = jnp.where(is_refractory, zeros_like_spikes, z_computed)
new_r = state.r + self.n_refractory * new_z - 1
new_r = jnp.clip(new_r, 0., float(self.n_refractory))

if self.stop_gradients:
new_r = jax.lax.stop_gradient(new_r)
new_s = jnp.stack((new_v, new_b), axis=-1)

new_state = CustomALIFStateTuple(s=new_s, z=new_z, r=new_r, z_local=new_z_local)
return new_z, new_state
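

For reviewers, a rough usage sketch of the new core (shapes, time constants, and the input projection are illustrative; it assumes the cell is wired inside hk.transform like the other Spyx cores):

import jax
import jax.numpy as jnp
import haiku as hk
from spyx.nn import RecurrentLIFLight

def net(spikes):  # spikes: [time, batch, n_in]
    # project inputs to the recurrent layer's width, since the core expects n_rec features
    projected = hk.BatchApply(hk.Linear(128, with_bias=False))(spikes)
    core = RecurrentLIFLight(n_rec=128, tau=20., thr=0.615,
                             tau_adaptation=200., beta=0.16,
                             stop_gradients=True)  # True -> e-prop-style truncated gradients
    state = core.initial_state(batch_size=spikes.shape[1])
    out, _ = hk.dynamic_unroll(core, projected, state, time_major=True)
    return out  # recurrent spike trains, [time, batch, n_rec]

model = hk.transform(net)
key = jax.random.PRNGKey(0)
dummy = jax.random.bernoulli(key, 0.1, (50, 4, 700)).astype(jnp.float32)  # toy spike raster
params = model.init(key, dummy)
spike_out = model.apply(params, key, dummy)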


class LeakyLinear(hk.RNNCore):
"""
Leaky real-valued readout neuron, adapted from the e-prop reference code: https://github.com/IGITUGraz/eligibility_propagation

To be replaced with Linear + LI in the future.

"""
def __init__(self, n_in, n_out, kappa, dtype=jnp.float32, name="LeakyLinear"):
super().__init__(name=name)
self.n_in = n_in
self.n_out = n_out
self.kappa = kappa

self.dtype = dtype

self.weights = hk.get_parameter("weights", shape=[n_in, n_out], dtype=dtype,
init=hk.initializers.TruncatedNormal(1./jnp.sqrt(n_in)))

# self.weights = hk.get_parameter("weights", shape=[n_in, n_out], dtype=dtype,
# init=hk.initializers.Constant(
# jnp.eye(n_in, n_out)
# ))

self._num_units = self.n_out
self.built = True


def initial_state(self, batch_size, dtype=jnp.float32):
s0 = jnp.zeros(shape=(batch_size, self.n_out), dtype=dtype)
return s0

def __call__(self, inputs, state, scope=None, dtype=jnp.float32):
if len(self.weights.shape) == 3:
outputs = jnp.einsum('bi,bij->bj', inputs, self.weights)
else:
outputs = jnp.matmul(inputs, self.weights)
new_s = self.kappa * state + (1 - self.kappa) * outputs
return new_s, new_s
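
A small sketch of this readout on top of the recurrent spikes (kappa and the shapes are made up, and it assumes the same hk.transform wiring as in the sketch above): each step computes an exponential moving average of the linear projection, new_s = kappa * s + (1 - kappa) * (inputs @ weights), so a kappa close to 1 yields a slowly decaying trace of the spike train.

import jax.numpy as jnp
import haiku as hk
from spyx.nn import LeakyLinear

def readout(spike_trains):  # spike_trains: [time, batch, n_rec]; call inside hk.transform
    cell = LeakyLinear(n_in=128, n_out=20, kappa=0.9)
    state = cell.initial_state(spike_trains.shape[1])
    traces, _ = hk.dynamic_unroll(cell, spike_trains, state, time_major=True)
    return traces[-1]  # final leaky trace, e.g. used as class logits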

class LI(hk.RNNCore):
"""