thesis scaries
kmheckel committed Aug 21, 2023
1 parent 7ca0be0 commit 4ba9f3a
Showing 4 changed files with 42 additions and 9 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -1,6 +1,6 @@
 ⚡🧠💻 Welcome to Spyx! 💻🧠⚡
 ============================
-[![DOI](https://zenodo.org/badge/656877506.svg)](https://zenodo.org/badge/latestdoi/656877506)
+[![DOI](https://zenodo.org/badge/656877506.svg)](https://zenodo.org/badge/latestdoi/656877506) [![PyPI version](https://badge.fury.io/py/spyx.svg)](https://badge.fury.io/py/spyx)
 ![README Art](spyx.png "Spyx")
 
 Spyx is a compact spiking neural network library built on top of DeepMind's Haiku library.
20 changes: 13 additions & 7 deletions spyx/data.py
@@ -3,9 +3,7 @@
 import torchvision as tv
 from torch.utils.data import DataLoader, Subset
 from sklearn.model_selection import train_test_split
-from sklearn.model_selection import LeaveOneGroupOut
 from collections import namedtuple
-from itertools import cycle
 
 import numpy as np
 import jax
@@ -14,7 +12,9 @@

 State = namedtuple("State", "obs labels")
 
+# should add a class that takes a numpy dataset and reshapes it and compresses it...
 
+# This should be changed to a higher-order function
 class shift_augment:
     """
     Shift data augmentation tool. Rolls data along specified axes randomly up to a certain amount.
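As context for the augmentation above, a minimal sketch of what roll-based shift augmentation can look like in JAX; `random_shift`, its defaults, and the axis handling are illustrative assumptions, not Spyx's actual implementation:

```python
import jax
import jax.numpy as jnp

def random_shift(key, x, max_shift=8, axes=(1, 2)):
    # Illustrative helper (not part of Spyx): roll each listed axis by a
    # random offset drawn uniformly from [-max_shift, max_shift].
    shifts = jax.random.randint(key, (len(axes),), -max_shift, max_shift + 1)
    return jnp.roll(x, tuple(shifts), axis=axes)

key = jax.random.PRNGKey(0)
batch = jnp.ones((32, 28, 28))                    # (batch, height, width)
shifted = random_shift(key, batch, max_shift=4)
```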
@@ -45,17 +45,19 @@ def rate_code(data, steps, key, max_r=0.75):

     data = jnp.array(data, dtype=jnp.float16)
     unrolled_data = jnp.repeat(data, steps, axis=1)
-    return jax.random.bernoulli(key, unrolled_data*max_r).astype(jnp.int8)
+    return jax.random.bernoulli(key, unrolled_data*max_r).astype(jnp.uint8)
 
 
-class MNIST_loader():
+class MNIST_loader(): # change this so that it just returns either rate or temporal mnist...
     """
-    Dataloader for the MNIST dataset, right now it is rate encoded.
+    Dataloader for the MNIST dataset. The data is returned in a non-temporal format, so instead of applying
+    jnp.unpackbits to the data you need to apply the spyx.data.rate_code function to generate the SNN input.
     """
 
-    def __init__(self, time_steps=64, max_rate = 0.75, batch_size=32, val_size=0.3, key=0, download_dir='./MNIST'):
+    # Change this to allow a config dictionary of
+    def __init__(self, time_steps=64, max_rate = 0.75, batch_size=32, val_size=0.3, subsample_data=1, key=0, download_dir='./MNIST'):
 
         self.key = jax.random.PRNGKey(key)
         self.sample_T = time_steps
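Since the loader now stores spikes with `jnp.packbits` (see the hunk below), here is a hedged sketch of the full rate-coding round trip using only the operations visible in this diff; the shapes and the packed axis are illustrative:

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
images = jnp.full((32, 1, 28, 28), 0.5, dtype=jnp.float16)  # pixels normalized to [0, 1]

# Bernoulli rate coding, as in rate_code: every pixel fires with
# probability pixel * max_r at each of `steps` time steps.
steps, max_r = 64, 0.75
unrolled = jnp.repeat(images, steps, axis=1)                # (32, 64, 28, 28)
spikes = jax.random.bernoulli(key, unrolled * max_r).astype(jnp.uint8)

# uint8 (rather than int8) spikes can be bit-packed for an 8x storage saving,
# then unpacked on demand before feeding the SNN.
packed = jnp.packbits(spikes, axis=2)
restored = jnp.unpackbits(packed, axis=2, count=spikes.shape[2])
assert (restored == spikes).all()
```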
@@ -84,6 +86,10 @@ def __init__(self, time_steps=64, max_rate = 0.75, batch_size=32, val_size=0.3,
             random_state=0,
             shuffle=True
         )
+
+        # to help with trying to do neuroevolution since the full dataset is a bit much for evolving convnets...
+        train_indices = train_indices[:int(len(train_indices)*subsample_data)]
+        val_indices = val_indices[:int(len(val_indices)*subsample_data)]
 
 
         train_split = Subset(train_val_dataset, train_indices)
@@ -93,7 +99,7 @@ def __init__(self, time_steps=64, max_rate = 0.75, batch_size=32, val_size=0.3,
             collate_fn=tonic.collation.PadTensors(batch_first=True), drop_last=True, shuffle=False))
 
         x_train, y_train = next(train_dl)
-        self.x_train = jnp.array(x_train, dtype=jnp.uint8)
+        self.x_train = jnp.packbits(rate_code(jnp.array(x_train, dtype=jnp.uint8), self.sample_T, self.key), axis=2)
         self.y_train = jnp.array(y_train, dtype=jnp.uint8)
         ############################

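Given the `__init__` signature shown above, the new `subsample_data` fraction would be used along these lines (a usage sketch; the argument values are arbitrary):

```python
from spyx.data import MNIST_loader

# Keep only 10% of the train/val indices so that fitness evaluations stay cheap
# when evolving convnets, per the comment in the diff.
mnist = MNIST_loader(time_steps=64, batch_size=32, subsample_data=0.1, key=42)
```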
20 changes: 19 additions & 1 deletion spyx/fn.py
@@ -4,6 +4,7 @@

 from jax import tree_util as tree
 
+### Change all of these to H.O.F.s
 
 class silence_reg:
     """
@@ -42,7 +43,7 @@ class sparsity_reg:
"""
def __init__(self, max_spikes, norm=optax.huber_loss):
def _loss(x):
return norm(jnp.maximum(0, jnp.mean(x, axis=-1) - max_spikes))
return norm(jnp.maximum(0, jnp.mean(x, axis=-1) - max_spikes)) # this may not work for convolution layers....

def _flatten(x):
return jnp.reshape(x, (x.shape[0], -1))
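On the convolution caveat flagged above: one way to make the penalty layout-agnostic is to flatten all non-batch dimensions before taking the mean, sketched below; `spike_rate_penalty` is a hypothetical stand-in, not the library's `sparsity_reg`:

```python
import jax.numpy as jnp
import optax

def spike_rate_penalty(spikes, max_spikes, norm=optax.huber_loss):
    # Flatten (batch, ...) to (batch, features) so dense and conv outputs
    # are penalized the same way, then hinge on the excess mean rate.
    flat = jnp.reshape(spikes, (spikes.shape[0], -1))
    excess = jnp.maximum(0, jnp.mean(flat, axis=-1) - max_spikes)
    return jnp.mean(norm(excess))
```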
@@ -70,6 +71,7 @@ def integral_accuracy(traces, targets):
     return jnp.sum(preds == targets) / traces.shape[0], preds
 
 # smoothing can be critical to the performance of your model...
+# change this to be a higher-order function yielding a func with a set smoothing rate.
 @jax.jit
 def integral_crossentropy(traces, targets, smoothing=0.3):
     """
@@ -87,3 +89,19 @@ def integral_crossentropy(traces, targets, smoothing=0.3):
     labels = optax.smooth_labels(jax.nn.one_hot(targets, logits.shape[-1]), smoothing)
     return optax.softmax_cross_entropy(logits, labels).mean()
 
+# convert to function that returns compiled function
+def mse_spikerate(traces, targets, sparsity=0.25, smoothing=0.0):
+    """
+    Calculate the mean squared error of the mean spike rate.
+    Allows for label smoothing to discourage silencing
+    the other neurons in the readout layer.
+
+    Attributes:
+        traces: the output of the final layer of the SNN
+        targets: the integer labels for each class
+        sparsity: [optional] scales the smoothed one-hot targets to the desired mean spike rate
+        smoothing: [optional] rate at which to smooth labels.
+    """
+
+    logits = jnp.mean(traces, axis=-2) # time axis.
+    labels = optax.smooth_labels(jax.nn.one_hot(targets, logits.shape[-1]), smoothing)
+    return jnp.mean(optax.squared_error(logits, labels * sparsity))
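The TODOs in this file all point the same way; as a sketch of the intended refactor (one plausible shape, not the final API), a higher-order version of `integral_crossentropy` would close over the smoothing rate and return a compiled loss:

```python
import jax
import jax.numpy as jnp
import optax

def integral_crossentropy(smoothing=0.3):
    @jax.jit
    def _loss(traces, targets):
        logits = jnp.mean(traces, axis=-2)  # average spikes over the time axis
        labels = optax.smooth_labels(jax.nn.one_hot(targets, logits.shape[-1]), smoothing)
        return optax.softmax_cross_entropy(logits, labels).mean()
    return _loss

loss_fn = integral_crossentropy(smoothing=0.3)  # loss_fn(traces, targets) -> scalar
```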
9 changes: 9 additions & 0 deletions spyx/nn.py
@@ -3,6 +3,15 @@
import haiku as hk
from .axn import Axon

# need to add shape checking/warning
def PopulationCode(num_classes):
"""
Add population coding to the preceding neuron layer. Preceding layer's output shape must be a multiple of
the number of classes. Use this for rate coded SNNs where the time steps are too few to get a good spike count.
"""
def _pop_code(x):
return jnp.sum(jnp.reshape(x, (-1,num_classes)), axis=-1)
return jax.jit(_pop_code)

class ALIF(hk.RNNCore):
"""
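Toward the shape-checking TODO above, a hedged sketch; it additionally keeps the batch dimension explicit, which the committed reshape appears to fold into the population axis, and it assumes neuron `i` votes for class `i % num_classes` (interleaved grouping):

```python
import jax
import jax.numpy as jnp

def PopulationCode(num_classes):
    def _pop_code(x):
        # Shapes are static under jit, so this check traces fine.
        if x.shape[-1] % num_classes != 0:
            raise ValueError(
                f"Output size {x.shape[-1]} is not a multiple of {num_classes} classes."
            )
        # (batch, pop_size, num_classes) -> sum the population votes per class.
        return jnp.sum(jnp.reshape(x, (x.shape[0], -1, num_classes)), axis=1)
    return jax.jit(_pop_code)
```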
