diff --git a/README.md b/README.md
index 40c0ef0..4150e03 100644
--- a/README.md
+++ b/README.md
@@ -1,6 +1,6 @@
 ⚡🧠💻 Welcome to Spyx! 💻🧠⚡
 ============================
-[![DOI](https://zenodo.org/badge/656877506.svg)](https://zenodo.org/badge/latestdoi/656877506)
+[![DOI](https://zenodo.org/badge/656877506.svg)](https://zenodo.org/badge/latestdoi/656877506) [![PyPI version](https://badge.fury.io/py/spyx.svg)](https://badge.fury.io/py/spyx)
 
 ![README Art](spyx.png "Spyx")
 Spyx is a compact spiking neural network library built on top of DeepMind's Haiku library.
diff --git a/spyx/data.py b/spyx/data.py
index 4695738..d420ce5 100644
--- a/spyx/data.py
+++ b/spyx/data.py
@@ -3,9 +3,7 @@
 import torchvision as tv
 from torch.utils.data import DataLoader, Subset
 from sklearn.model_selection import train_test_split
-from sklearn.model_selection import LeaveOneGroupOut
 from collections import namedtuple
-from itertools import cycle
 import numpy as np
 
 import jax
@@ -14,7 +12,9 @@
 
 State = namedtuple("State", "obs labels")
 
+# should add a class that takes a numpy dataset and reshapes it and compresses it...
+# This should be changed to a higher-order function
 class shift_augment:
     """
     Shift data augmentation tool. Rolls data along specified axes randomly up to a certain amount.
@@ -45,17 +45,19 @@ def rate_code(data, steps, key, max_r=0.75):
 
     data = jnp.array(data, dtype=jnp.float16)
     unrolled_data = jnp.repeat(data, steps, axis=1)
-    return jax.random.bernoulli(key, unrolled_data*max_r).astype(jnp.int8)
+    return jax.random.bernoulli(key, unrolled_data*max_r).astype(jnp.uint8)
 
 
-class MNIST_loader():
+class MNIST_loader(): # change this so that it just returns either rate or temporal mnist...
     """
-    Dataloader for the MNIST dataset, right now it is rate encoded.
+    Dataloader for the MNIST dataset. The data is returned in a non-temporal format, so instead of applying
+    jnp.unpackbits to the data you need to apply the spyx.data.rate_code function to generate the SNN input.
+
     """
 
     # Change this to allow a config dictionary of
-    def __init__(self, time_steps=64, max_rate = 0.75, batch_size=32, val_size=0.3, key=0, download_dir='./MNIST'):
+    def __init__(self, time_steps=64, max_rate = 0.75, batch_size=32, val_size=0.3, subsample_data=1, key=0, download_dir='./MNIST'):
         self.key = jax.random.PRNGKey(key)
         self.sample_T = time_steps
@@ -84,6 +86,10 @@ def __init__(self, time_steps=64, max_rate = 0.75, batch_size=32, val_size=0.3,
             random_state=0,
             shuffle=True
         )
+
+        # to help with trying to do neuroevolution since the full dataset is a bit much for evolving convnets...
+        train_indices = train_indices[:int(len(train_indices)*subsample_data)]
+        val_indices = val_indices[:int(len(val_indices)*subsample_data)]
 
         train_split = Subset(train_val_dataset, train_indices)
@@ -93,7 +99,7 @@ def __init__(self, time_steps=64, max_rate = 0.75, batch_size=32, val_size=0.3,
             collate_fn=tonic.collation.PadTensors(batch_first=True), drop_last=True, shuffle=False))
 
         x_train, y_train = next(train_dl)
-        self.x_train = jnp.array(x_train, dtype=jnp.uint8)
+        self.x_train = jnp.packbits(rate_code(jnp.array(x_train, dtype=jnp.uint8), self.sample_T, self.key), axis=2)
         self.y_train = jnp.array(y_train, dtype=jnp.uint8)
 
         ############################
diff --git a/spyx/fn.py b/spyx/fn.py
index 37d24f7..805d2d3 100644
--- a/spyx/fn.py
+++ b/spyx/fn.py
@@ -4,6 +4,7 @@
 
 from jax import tree_util as tree
 
+### Change all of these to H.O.F.s
 
 class silence_reg:
     """
@@ -42,7 +43,7 @@ class sparsity_reg:
     """
     def __init__(self, max_spikes, norm=optax.huber_loss):
         def _loss(x):
-            return norm(jnp.maximum(0, jnp.mean(x, axis=-1) - max_spikes))
+            return norm(jnp.maximum(0, jnp.mean(x, axis=-1) - max_spikes)) # this may not work for convolution layers....
 
         def _flatten(x):
             return jnp.reshape(x, (x.shape[0], -1))
@@ -70,6 +71,7 @@ def integral_accuracy(traces, targets):
     return jnp.sum(preds == targets) / traces.shape[0], preds
 
 # smoothing can be critical to the performance of your model...
+# change this to be a higher-order function yielding a func with a set smoothing rate.
 @jax.jit
 def integral_crossentropy(traces, targets, smoothing=0.3):
     """
@@ -87,3 +89,19 @@ def integral_crossentropy(traces, targets, smoothing=0.3):
     labels = optax.smooth_labels(jax.nn.one_hot(targets, logits.shape[-1]), smoothing)
     return optax.softmax_cross_entropy(logits, labels).mean()
 
+# convert to function that returns compiled function
+def mse_spikerate(traces, targets, sparsity=0.25, smoothing=0.0):
+    """
+    Calculate the mean squared error of the mean spike rate.
+    Allows for label smoothing to discourage silencing
+    the other neurons in the readout layer.
+
+    Attributes:
+        traces: the output of the final layer of the SNN
+        targets: the integer labels for each class
+        smoothing: [optional] rate at which to smooth labels.
+    """
+
+    logits = jnp.mean(traces, axis=-2) # time axis.
+    labels = optax.smooth_labels(jax.nn.one_hot(targets, logits.shape[-1]), smoothing)
+    return jnp.mean(optax.squared_error(logits, labels * sparsity))
diff --git a/spyx/nn.py b/spyx/nn.py
index 51a5042..704f9d0 100644
--- a/spyx/nn.py
+++ b/spyx/nn.py
@@ -3,6 +3,15 @@
 import haiku as hk
 
 from .axn import Axon
 
+# need to add shape checking/warning
+def PopulationCode(num_classes):
+    """
+    Add population coding to the preceding neuron layer. The preceding layer's output size must be a multiple of
+    the number of classes. Use this for rate coded SNNs where the time steps are too few to get a good spike count.
+    """
+    def _pop_code(x):
+        return jnp.sum(jnp.reshape(x, (-1,num_classes)), axis=-1)
+    return jax.jit(_pop_code)
 class ALIF(hk.RNNCore):
     """
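
A note on the higher-order-function TODOs in the diff above (### Change all of these to H.O.F.s, # change this to be a higher-order function yielding a func with a set smoothing rate., and # convert to function that returns compiled function): below is a rough sketch of what that refactor might look like for the new mse_spikerate loss. It is illustration only, not part of this diff; the name mse_spikerate_hof, the (batch, time, classes) trace shape, and the dummy data are all assumptions.

# Hypothetical sketch only -- not part of this diff.
import jax
import jax.numpy as jnp
import optax

def mse_spikerate_hof(sparsity=0.25, smoothing=0.0):
    # Capture the hyperparameters in a closure and return a jitted loss,
    # mirroring how silence_reg and sparsity_reg already wrap theirs.
    def _loss(traces, targets):
        logits = jnp.mean(traces, axis=-2)  # mean spike rate over the (assumed) time axis
        labels = optax.smooth_labels(jax.nn.one_hot(targets, logits.shape[-1]), smoothing)
        return jnp.mean(optax.squared_error(logits, labels * sparsity))
    return jax.jit(_loss)

# Usage with dummy data (batch=32, time=64, classes=10), purely illustrative:
key = jax.random.PRNGKey(0)
traces = jax.random.bernoulli(key, 0.3, (32, 64, 10)).astype(jnp.float32)
targets = jnp.zeros(32, dtype=jnp.int32)
loss_fn = mse_spikerate_hof(sparsity=0.25, smoothing=0.0)
print(loss_fn(traces, targets))

Building the loss once keeps the training step's signature down to (traces, targets), matching the style of the existing callable regularizers.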