Merge pull request #21 from agosztolai/version2
Version2
agosztolai authored Nov 27, 2023
2 parents 38cb48a + f8551f4 commit 5717092
Showing 7 changed files with 1,862 additions and 11 deletions.
2 changes: 2 additions & 0 deletions MARBLE/default_params.yaml
@@ -19,6 +19,8 @@ out_channels: 3 # number of output channels (if null, then =hidden_channels)
bias: True # learn bias parameters in MLP
vec_norm: False
batch_norm: False # batch normalisation
emb_norm: False # spherical output
skip_connections: True # use skips in MLP

# other params
seed: 0 # seed for reproducibility
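The two new options are plain entries in the params dictionary, so they can also be switched on from user code. A minimal sketch of the assumed usage, relying on the package-level MARBLE.net entry point; the `data` object is a pre-built MARBLE dataset (e.g. from MARBLE.construct_dataset) whose construction is omitted here:

# Hedged sketch: enable the new options when constructing the network.
import MARBLE

params = {
    "skip_connections": True,  # route the encoder through the new SkipMLP layer
    "emb_norm": True,          # normalise embeddings to the unit sphere
}

model = MARBLE.net(data, params=params)  # `data`: pre-built MARBLE dataset (placeholder)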
28 changes: 27 additions & 1 deletion MARBLE/layers.py
@@ -1,11 +1,37 @@
"""Layer module."""
import torch
from torch import nn
from torch.nn.functional import normalize
from torch.nn.functional import normalize, relu
from torch_geometric.nn.conv import MessagePassing

from MARBLE import geometry as g

class SkipMLP(nn.Module):
    """MLP with a concatenation-style skip connection: the input is appended to the
    hidden-layer output and mapped back to the target width by a final linear layer."""
def __init__(self, channel_list, dropout=0.0, bias=True):
super(SkipMLP, self).__init__()
self.layers = nn.ModuleList()
self.in_channels = channel_list[0]
for i in range(len(channel_list) - 1):
self.layers.append(nn.Linear(channel_list[i], channel_list[i + 1], bias=bias))
self.layers.append(nn.Dropout(dropout))

# Output layer adjustment for concatenated skip connection
final_out_features = channel_list[-1] + channel_list[0]
self.output_layer = nn.Linear(final_out_features, channel_list[-1], bias=bias)

def forward(self, x):
identity = x
for layer in self.layers:
if isinstance(layer, nn.Linear):
x = relu(layer(x))
else:
x = layer(x)

# Concatenate the input (identity) with the output
x = torch.cat([identity, x], dim=1)
x = self.output_layer(x)
return x


class Diffusion(nn.Module):
"""Diffusion with learned t."""
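For orientation, the skip connection in the new layer is concatenation-style: the raw input is appended to the hidden-layer output and the final linear layer projects the concatenation back to the target width. A standalone sketch with illustrative (non-default) channel sizes:

import torch
from MARBLE.layers import SkipMLP

mlp = SkipMLP(channel_list=[16, 64, 64, 3], dropout=0.1, bias=True)

x = torch.randn(32, 16)   # batch of 32 vectors with 16 input channels
out = mlp(x)              # ReLU hidden layers -> concat with input -> linear head
print(out.shape)          # torch.Size([32, 3])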
37 changes: 27 additions & 10 deletions MARBLE/main.py
@@ -42,6 +42,7 @@ class net(nn.Module):
out_channels: number of output channels (if null, then =hidden_channels) (default=3)
bias: learn bias parameters in MLP (default=True)
vec_norm: normalise features to unit length (default=False)
emb_norm: normalise MLP output to unit length (default=False)
batch_norm: batch normalisation (default=False)
seed: seed for reproducibility (default=0)
processes: number of cpus (default=1)
@@ -62,7 +63,7 @@ def __init__(self, data, loadpath=None, params=None, verbose=True):
if loadpath is not None:
if Path(loadpath).is_dir():
loadpath = max(glob.glob(f"{loadpath}/best_model*"))
self.params = torch.load(loadpath)["params"]
self.params = torch.load(loadpath, map_location=torch.device('cuda' if torch.cuda.is_available() else 'cpu'))["params"]
else:
if params is not None:
if isinstance(params, str) and Path(params).exists():
@@ -145,6 +146,8 @@ def check_parameters(self, data):
"bias",
"batch_norm",
"vec_norm",
"emb_norm",
"skip_connections",
"seed",
"n_sampled_nb",
"processes",
@@ -201,12 +204,23 @@ def setup_layers(self):
+ [self.params["out_channels"]]
)

self.enc = MLP(
channel_list=channel_list,
dropout=self.params["dropout"],
norm=self.params["batch_norm"],
bias=self.params["bias"],
)
if self.params['skip_connections']:
self.enc = layers.SkipMLP(
channel_list=channel_list,
dropout=self.params["dropout"],
bias=self.params["bias"],
)
else:
self.enc = MLP(
channel_list=channel_list,
dropout=self.params["dropout"],
bias=self.params["bias"],
)






def forward(self, data, n_id, adjs=None):
"""Forward pass.
@@ -267,10 +281,13 @@ def forward(self, data, n_id, adjs=None):
if self.params["include_positions"]:
out = torch.hstack(
[data.pos[n_id[: size[1]]], out] # pylint: disable=undefined-loop-variable
)
)

emb = self.enc(out)

if self.params['emb_norm']: # spherical output
emb = F.normalize(emb)

return emb, mask[: size[1]]

def evaluate(self, data):
@@ -398,7 +415,7 @@ def load_model(self, loadpath):
Args:
loadpath: directory with models to load best model, or specific model path
"""
checkpoint = torch.load(loadpath)
checkpoint = torch.load(loadpath, map_location=torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
self._epoch = checkpoint["epoch"]
self.load_state_dict(checkpoint["model_state_dict"])
self.optimizer_state_dict = checkpoint["optimizer_state_dict"]
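The map_location change means a checkpoint trained on a GPU machine can now be restored on a CPU-only machine. A hedged sketch of both loading paths; the checkpoint path and the `data` object below are placeholders:

import MARBLE

# Rebuild the network from the parameters stored in the checkpoint...
model = MARBLE.net(data, loadpath="outputs/best_model.pth")  # path is hypothetical

# ...or restore weights into an existing network; torch.load now maps the tensors
# onto CUDA if available and onto the CPU otherwise.
model.load_model("outputs/best_model.pth")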
745 changes: 745 additions & 0 deletions examples/rat_task/Demo_consistency.ipynb

Large diffs are not rendered by default.

966 changes: 966 additions & 0 deletions examples/rat_task/Demo_decoding.ipynb

Large diffs are not rendered by default.

Binary file added examples/rat_task/rat_data.pkl
Binary file not shown.
95 changes: 95 additions & 0 deletions examples/rat_task/rat_utils.py
@@ -0,0 +1,95 @@
import sys
import numpy as np
import matplotlib.pyplot as plt

from elephant.kernels import GaussianKernel
from elephant.statistics import instantaneous_rate
from quantities import ms
import neo

from sklearn.decomposition import PCA
import sklearn.metrics

import MARBLE
import cebra

def prepare_marble(spikes, labels, pca=None, pca_n=10, skip=1):
    """Convert binary spike trains to smoothed firing rates, project them onto
    `pca_n` principal components, and build a MARBLE dataset from the PCA rates
    and their time derivatives. Returns the dataset, the truncated labels and the
    fitted PCA object (pass `pca` back in to reuse it for a test split)."""

    s_interval = 1  # sampling period of the rate estimate, in ms

gk = GaussianKernel(10 * ms)
rates = []
for sp in spikes:
sp_times = np.where(sp)[0]
st = neo.SpikeTrain(sp_times, units="ms", t_stop=len(sp))
r = instantaneous_rate(st, kernel=gk, sampling_period=s_interval * ms).magnitude
rates.append(r.T)

rates = np.vstack(rates)

if pca is None:
pca = PCA(n_components=pca_n)
rates_pca = pca.fit_transform(rates.T)
else:
rates_pca = pca.transform(rates.T)

vel_rates_pca = np.diff(rates_pca, axis=0)
print(pca.explained_variance_ratio_)

rates_pca = rates_pca[:-1,:] # skip last

labels = labels[:rates_pca.shape[0]]

data = MARBLE.construct_dataset(
rates_pca,
features=vel_rates_pca,
k=15,
stop_crit=0.0,
delta=1.5,
compute_laplacian=True,
local_gauges=False,
)

return data, labels, pca


def find_sequences(vector):
    """Return (start, end) index pairs of the maximal runs of identical values in `vector`."""
sequences = []
start_index = 0

for i in range(1, len(vector)):
if vector[i] != vector[i - 1]:
sequences.append((start_index, i - 1))
start_index = i

# Add the last sequence
sequences.append((start_index, len(vector) - 1))

return sequences

# Decoding with a kNN decoder. For this simple demo we use a fixed number of neighbours (36).
def decoding_pos_dir(embedding_train, embedding_test, label_train, label_test):
    """Decode position and direction from embeddings with cosine-distance kNN decoders,
    returning R2 scores, median position errors, the predictions and per-run errors."""
pos_decoder = cebra.KNNDecoder(n_neighbors=36, metric="cosine")
dir_decoder = cebra.KNNDecoder(n_neighbors=36, metric="cosine")

pos_decoder.fit(embedding_train, label_train[:,0])
dir_decoder.fit(embedding_train, label_train[:,1])

pos_pred = pos_decoder.predict(embedding_test)
dir_pred = dir_decoder.predict(embedding_test)

prediction = np.stack([pos_pred, dir_pred],axis = 1)

test_score = sklearn.metrics.r2_score(label_test[:,:2], prediction)
pos_test_err = np.median(abs(prediction[:,0] - label_test[:, 0]))
pos_test_score = sklearn.metrics.r2_score(label_test[:, 0], prediction[:,0])

prediction_error = abs(prediction[:,0] - label_test[:, 0])

# prediction error by back and forth
sequences = find_sequences(label_test[:,1])
errors = []
for seq in sequences:
errors.append(np.median(abs(prediction[seq,0] - label_test[seq, 0])))

return test_score, pos_test_err, pos_test_score, prediction, prediction_error, np.array(errors)
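
A hedged sketch of how these helpers fit together in the decoding demo; the spike arrays, label arrays and the MARBLE embeddings below are placeholders for objects produced in the notebooks:

# Build train/test datasets, fitting the PCA on the training rates only.
data_train, labels_train, pca = prepare_marble(spikes_train, labels_train)
data_test, labels_test, _ = prepare_marble(spikes_test, labels_test, pca=pca)

# ... fit MARBLE on data_train and embed both splits into emb_train / emb_test ...

results = decoding_pos_dir(emb_train, emb_test, labels_train, labels_test)
test_r2, pos_err, pos_r2, prediction, prediction_error, per_run_errors = results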
