
Commit

added discord badge
kmheckel committed Sep 11, 2023
1 parent 03a336a commit 18deb99
Showing 4 changed files with 10 additions and 55 deletions.
README.md: 4 changes (3 additions, 1 deletion)
@@ -1,6 +1,8 @@
⚡🧠💻 Welcome to Spyx! 💻🧠⚡
============================
[![DOI](https://zenodo.org/badge/656877506.svg)](https://zenodo.org/badge/latestdoi/656877506) [![PyPI version](https://badge.fury.io/py/spyx.svg)](https://badge.fury.io/py/spyx)
[![DOI](https://zenodo.org/badge/656877506.svg)](https://zenodo.org/badge/latestdoi/656877506) [![PyPI version](https://badge.fury.io/py/spyx.svg)](https://badge.fury.io/py/spyx)[![](https://dcbadge.vercel.app/api/server/INVITEID)](https://discord.gg/TCYQFWsBwj)


![README Art](spyx.png "Spyx")

Spyx is a compact spiking neural network library built on top of DeepMind's Haiku library.
docs/conf.py: 2 changes (1 addition, 1 deletion)
@@ -14,7 +14,7 @@
project = 'Spyx'
copyright = '2023, Kade Heckel'
author = 'Kade Heckel'
release = 'v0.0.45'
release = 'v0.1.8'

# -- General configuration ---------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
spyx/fn.py: 1 change (1 addition, 0 deletions)
@@ -90,6 +90,7 @@ def integral_crossentropy(traces, targets, smoothing=0.3):
return optax.softmax_cross_entropy(logits, labels).mean()

# convert to function that returns compiled function
@jax.jit
def mse_spikerate(traces, targets, sparsity=0.25, smoothing=0.0):
"""
Calculate the mean squared error of the mean spike rate.
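The comment above suggests turning mse_spikerate into a function that returns a compiled function, while this commit simply applies @jax.jit to it directly. As a rough, hypothetical sketch of the factory pattern the comment hints at (the loss body is truncated in this excerpt, so the computation below is a placeholder rather than spyx.fn's actual implementation), a factory can close over the hyperparameters and return a jit-compiled loss:

import jax
import jax.numpy as jnp

# Hypothetical sketch of "a function that returns a compiled function".
# The loss body is a placeholder, not the actual spyx.fn.mse_spikerate.
def make_mse_spikerate(sparsity=0.25, smoothing=0.0):
    @jax.jit
    def loss_fn(traces, targets):
        rates = jnp.mean(traces, axis=1)                     # mean firing rate per neuron: (batch, classes)
        one_hot = jax.nn.one_hot(targets, rates.shape[-1])   # integer labels -> one-hot targets
        smoothed = one_hot * (1 - smoothing) + smoothing / rates.shape[-1]  # simple label smoothing
        return jnp.mean((rates - sparsity * smoothed) ** 2)  # MSE against the target spike rate
    return loss_fn

loss_fn = make_mse_spikerate(sparsity=0.3)  # hyperparameters baked into the returned, jit-compiled loss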
spyx/nn.py: 58 changes (5 additions, 53 deletions)
@@ -35,7 +35,6 @@ def __init__(self, hidden_shape, beta=None, gamma=None, threshold=1,
self.act = activation

def __call__(self, x, VT):
# this probably needs to be changed to split an array
V, T = jnp.split(VT, 2, -1)

gamma = self.gamma
@@ -85,7 +84,7 @@ def __call__(self, x, Vin):
def initial_state(self, batch_size):
return jnp.zeros((batch_size,) + self.layer_shape, dtype=jnp.float32)

class IF(hk.RNNCore): # bfloat16 covers a wide range of unused values...
class IF(hk.RNNCore):
"""
Integrate and Fire neuron model
@@ -111,11 +110,11 @@ def __call__(self, x, V):

return spikes, V

def initial_state(self, batch_size): # figure out how to make dynamic...
def initial_state(self, batch_size):
return jnp.zeros((batch_size,) + self.hidden_shape, dtype=jnp.float16)


class LIF(hk.RNNCore): # bfloat16 covers a wide range of unused values...
class LIF(hk.RNNCore):
"""
Leaky Integrate and Fire neuron model inspired by the implementation in
snnTorch:
@@ -142,7 +141,6 @@ def __init__(self, hidden_shape: tuple, beta=None, threshold=1,

def __call__(self, x, V):

# numerical stability gremlin...
beta = self.beta
if not beta:
beta = hk.get_parameter("beta", self.hidden_shape, dtype=jnp.float16,
@@ -155,11 +153,11 @@ def __call__(self, x, V):

return spikes, V

def initial_state(self, batch_size): # figure out how to make dynamic...
def initial_state(self, batch_size):
return jnp.zeros((batch_size,) + self.hidden_shape, dtype=jnp.float16)


class RLIF(hk.RNNCore): # bfloat16 covers a wide range of unused values...
class RLIF(hk.RNNCore):
"""
Recurrent LIF Neuron adapted from snnTorch:
@@ -193,49 +191,3 @@ def __call__(self, x, V):
def initial_state(self, batch_size):
return jnp.zeros((batch_size,) + self.hidden_shape, dtype=jnp.float16)

# Current Based (CuBa)
class SC(hk.RNNCore):
"""
Conductance based neuron modeling synaptic conductance.
Adapted from snnTorch:
https://snntorch.readthedocs.io/en/latest/snn.neurons_synaptic.html
"""

def __init__(self, hidden_shape, alpha=None, beta=None, threshold=1,
activation = Axon(),
name="SC"):
super().__init__(name=name)
self.hidden_shape = hidden_shape
self.alpha = alpha
self.beta = beta
self.threshold = threshold
self.act = activation

def __call__(self, x, VI):
V, I = jnp.split(VI, 2, -1)

alpha = self.alpha
beta = self.beta
# threshold adaptation
if not alpha:
alpha = hk.get_parameter("alpha", self.hidden_shape,
init=hk.initializers.TruncatedNormal(0.25, 0.5))
alpha = jax.nn.hard_sigmoid(alpha)
if not beta:
beta = hk.get_parameter("beta", self.hidden_shape,
init=hk.initializers.TruncatedNormal(0.25, 0.5))
beta = jax.nn.hard_sigmoid(beta)
# calculate whether spike is generated, and update membrane potential
spikes = self.act(V - self.threshold)
I = alpha*I + x
V = (beta*V + I - spikes*self.threshold).astype(jnp.float16) # cast may not be needed?

VI = jnp.concatenate([V,I], axis=-1, dtype=jnp.float16)
return spikes, VI

# this is probably borked with the shaping now.
def initial_state(self, batch_size):
return jnp.zeros((batch_size,) + tuple(2*s for s in self.hidden_shape), dtype=jnp.float16)
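
For orientation, here is a minimal, hypothetical sketch of wiring one of these hk.RNNCore neurons (the LIF shown above) into a Haiku network and unrolling it over time; the hk.Linear projection, shapes, and dummy inputs are illustrative assumptions rather than code from this repository.

import haiku as hk
import jax
import jax.numpy as jnp
from spyx import nn as snn

T, B, N = 16, 32, 64                           # illustrative time steps, batch size, hidden units

def net(spike_train):                          # spike_train: (T, B, input_dim), time-major
    core = hk.DeepRNN([hk.Linear(N), snn.LIF((N,))])   # project inputs, then apply LIF dynamics
    state = core.initial_state(spike_train.shape[1])   # zero membrane potentials for the batch
    spikes, _ = hk.dynamic_unroll(core, spike_train, state)
    return spikes                              # (T, B, N) output spike train

net = hk.without_apply_rng(hk.transform(net))
x = jnp.zeros((T, B, 128), dtype=jnp.float16)  # dummy input spikes, illustrative only
params = net.init(jax.random.PRNGKey(0), x)
out = net.apply(params, x)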
