Version2 #20

Closed · wants to merge 2 commits
1 change: 1 addition & 0 deletions MARBLE/default_params.yaml
@@ -19,6 +19,7 @@ out_channels: 3 # number of output channels (if null, then =hidden_channels)
bias: True # learn bias parameters in MLP
vec_norm: False
batch_norm: False # batch normalisation
emb_norm: False # spherical output

# other params
seed: 0 # seed for reproducibility
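
For reference, a minimal sketch of enabling the new option from Python (a sketch under assumptions: `data` is a prebuilt MARBLE dataset and `net` is exposed at the package level, matching the `net(data, params=...)` signature in main.py below):

import MARBLE

# hypothetical usage: enable spherical (unit-norm) embeddings via the new flag
params = {"emb_norm": True}
model = MARBLE.net(data, params=params)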
29 changes: 29 additions & 0 deletions MARBLE/layers.py
@@ -6,6 +6,35 @@

from MARBLE import geometry as g

class SkipMLP(nn.Module):
    """MLP with skip connections."""

    def __init__(self, channel_list, dropout=0.0, bias=True):
        super().__init__()
        assert len(channel_list) > 1, "Channel list must have at least two elements for an MLP."
        self.layers = nn.ModuleList()
        self.dropout = dropout
        self.in_channels = channel_list[0]
        for i in range(len(channel_list) - 1):
            self.layers.append(nn.Linear(channel_list[i], channel_list[i + 1], bias=bias))
            if i < len(channel_list) - 2:  # no activation or dropout after the last layer
                self.layers.append(nn.ReLU(inplace=True))
                if dropout > 0:
                    self.layers.append(nn.Dropout(dropout))

    def forward(self, x):
        for layer in self.layers:
            if isinstance(layer, nn.Linear):
                identity = x  # save the input of each linear layer
                x = layer(x)
                if x.shape[1] == identity.shape[1]:  # add the skip connection when widths match
                    x += identity
            else:
                x = layer(x)  # activation or dropout
        return x
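
A minimal sketch of how this class behaves (toy shapes, not from the PR): when a layer's input and output widths match, the skip connection fires; otherwise the layer acts as a plain linear map.

import torch

mlp = SkipMLP([16, 16, 3], dropout=0.1)
x = torch.randn(8, 16)
out = mlp(x)  # first layer adds its input (16 -> 16); the final 16 -> 3 layer does not
print(out.shape)  # torch.Size([8, 3])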


class Diffusion(nn.Module):
"""Diffusion with learned t."""
27 changes: 21 additions & 6 deletions MARBLE/main.py
@@ -42,6 +42,7 @@ class net(nn.Module):
out_channels: number of output channels (if null, then =hidden_channels) (default=3)
bias: learn bias parameters in MLP (default=True)
vec_norm: normalise features to unit length (default=False)
emb_norm: normalise MLP output to unit length (default=False)
batch_norm: batch normalisation (default=False)
seed: seed for reproducibility (default=0)
processes: number of cpus (default=1)
@@ -62,7 +63,7 @@ def __init__(self, data, loadpath=None, params=None, verbose=True):
    if loadpath is not None:
        if Path(loadpath).is_dir():
            loadpath = max(glob.glob(f"{loadpath}/best_model*"))
        self.params = torch.load(
            loadpath, map_location=torch.device("cuda" if torch.cuda.is_available() else "cpu")
        )["params"]
    else:
        if params is not None:
            if isinstance(params, str) and Path(params).exists():
@@ -145,6 +146,7 @@ def check_parameters(self, data):
"bias",
"batch_norm",
"vec_norm",
"emb_norm",
"seed",
"n_sampled_nb",
"processes",
@@ -201,12 +203,22 @@ def setup_layers(self):
+ [self.params["out_channels"]]
)

# self.enc = MLP(
#     channel_list=channel_list,
#     dropout=self.params["dropout"],
#     # norm=self.params["batch_norm"],
#     bias=self.params["bias"],
# )

self.enc = layers.SkipMLP(
    channel_list=channel_list,
    dropout=self.params["dropout"],
    # norm=self.params["batch_norm"],  # SkipMLP takes no norm argument
    bias=self.params["bias"],
)

def forward(self, data, n_id, adjs=None):
"""Forward pass.
@@ -267,10 +279,13 @@ def forward(self, data, n_id, adjs=None):
if self.params["include_positions"]:
    out = torch.hstack(
        [data.pos[n_id[: size[1]]], out]  # pylint: disable=undefined-loop-variable
    )

emb = self.enc(out)

if self.params["emb_norm"]:  # spherical output
    emb = F.normalize(emb)

return emb, mask[: size[1]]
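
A quick sketch of the new emb_norm branch in isolation (toy tensor; F is torch.nn.functional, as used elsewhere in main.py):

import torch
import torch.nn.functional as F

emb = torch.randn(5, 3)
emb = F.normalize(emb)  # L2-normalise each row (the default)
print(emb.norm(dim=1))  # all ones: embeddings lie on the unit sphere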

def evaluate(self, data):
@@ -398,7 +413,7 @@ def load_model(self, loadpath):
    Args:
        loadpath: directory from which to load the best model, or a path to a specific model
    """
    checkpoint = torch.load(
        loadpath, map_location=torch.device("cuda" if torch.cuda.is_available() else "cpu")
    )
    self._epoch = checkpoint["epoch"]
    self.load_state_dict(checkpoint["model_state_dict"])
    self.optimizer_state_dict = checkpoint["optimizer_state_dict"]
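
The added map_location makes GPU-saved checkpoints loadable on CPU-only machines. A small illustration (the filename is hypothetical):

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
checkpoint = torch.load("outputs/best_model.pt", map_location=device)  # hypothetical path
print(checkpoint.keys())  # epoch, model_state_dict, optimizer_state_dict, params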
745 changes: 745 additions & 0 deletions examples/rat_task/Demo_consistency.ipynb

Large diffs are not rendered by default.

966 changes: 966 additions & 0 deletions examples/rat_task/Demo_decoding.ipynb

Large diffs are not rendered by default.

Binary file added examples/rat_task/rat_data.pkl
Binary file not shown.
95 changes: 95 additions & 0 deletions examples/rat_task/rat_utils.py
@@ -0,0 +1,95 @@
import numpy as np

from elephant.kernels import GaussianKernel
from elephant.statistics import instantaneous_rate
from quantities import ms
import neo

from sklearn.decomposition import PCA
import sklearn

import MARBLE
import cebra

def prepare_marble(spikes, labels, pca=None, pca_n=10, skip=1):
    """Convert binary spike trains into smoothed rates, reduce with PCA and build a MARBLE dataset."""

    s_interval = 1  # sampling period in ms

    # smooth spikes into instantaneous firing rates with a 10 ms Gaussian kernel
    gk = GaussianKernel(10 * ms)
    rates = []
    for sp in spikes:
        sp_times = np.where(sp)[0]
        st = neo.SpikeTrain(sp_times, units="ms", t_stop=len(sp))
        r = instantaneous_rate(st, kernel=gk, sampling_period=s_interval * ms).magnitude
        rates.append(r.T)

    rates = np.vstack(rates)

    # fit PCA on the training session, or reuse a previously fitted one
    if pca is None:
        pca = PCA(n_components=pca_n)
        rates_pca = pca.fit_transform(rates.T)
    else:
        rates_pca = pca.transform(rates.T)

    vel_rates_pca = np.diff(rates_pca, axis=0)
    print(pca.explained_variance_ratio_)

    rates_pca = rates_pca[:-1, :]  # drop last sample so positions align with the velocity features

    labels = labels[: rates_pca.shape[0]]

    data = MARBLE.construct_dataset(
        rates_pca,
        features=vel_rates_pca,
        k=15,
        stop_crit=0.0,
        delta=1.5,
        compute_laplacian=True,
        local_gauges=False,
    )

    return data, labels, pca
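
A hedged usage sketch (array shapes are assumptions, not from the PR): spikes is an iterable of binary spike vectors sampled at 1 ms, and labels carries one row per time bin.

import numpy as np

# hypothetical toy input: 30 neurons, 5000 ms of binary spikes
spikes = (np.random.rand(30, 5000) < 0.02).astype(float)
labels = np.random.rand(5000, 2)  # e.g. position and direction per bin

data, labels, pca = prepare_marble(spikes, labels, pca_n=10)
# reuse the fitted PCA to embed a held-out session in the same space:
# data_test, labels_test, _ = prepare_marble(spikes_test, labels_test, pca=pca)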


def find_sequences(vector):
    """Return (start, end) index pairs of maximal constant runs in vector."""
    sequences = []
    start_index = 0

    for i in range(1, len(vector)):
        if vector[i] != vector[i - 1]:
            sequences.append((start_index, i - 1))
            start_index = i

    # add the last sequence
    sequences.append((start_index, len(vector) - 1))

    return sequences
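
For instance (values chosen for illustration):

find_sequences([0, 0, 1, 1, 0])  # -> [(0, 1), (2, 3), (4, 4)]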

# Decoding function with a kNN decoder. For this simple demo we fix n_neighbors=36.
def decoding_pos_dir(embedding_train, embedding_test, label_train, label_test):
    pos_decoder = cebra.KNNDecoder(n_neighbors=36, metric="cosine")
    dir_decoder = cebra.KNNDecoder(n_neighbors=36, metric="cosine")

    pos_decoder.fit(embedding_train, label_train[:, 0])
    dir_decoder.fit(embedding_train, label_train[:, 1])

    pos_pred = pos_decoder.predict(embedding_test)
    dir_pred = dir_decoder.predict(embedding_test)

    prediction = np.stack([pos_pred, dir_pred], axis=1)

    test_score = sklearn.metrics.r2_score(label_test[:, :2], prediction)
    pos_test_err = np.median(abs(prediction[:, 0] - label_test[:, 0]))
    pos_test_score = sklearn.metrics.r2_score(label_test[:, 0], prediction[:, 0])

    prediction_error = abs(prediction[:, 0] - label_test[:, 0])

    # median position error per back-and-forth traversal (runs of constant direction)
    sequences = find_sequences(label_test[:, 1])
    errors = []
    for start, end in sequences:
        errors.append(np.median(abs(prediction[start : end + 1, 0] - label_test[start : end + 1, 0])))

    return test_score, pos_test_err, pos_test_score, prediction, prediction_error, np.array(errors)
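
A hedged sketch of calling the decoder (the embedding and label arrays are assumptions; any embedding, e.g. the MARBLE one, works):

# label columns: [position, direction]
test_score, pos_err, pos_score, pred, pred_err, seq_errs = decoding_pos_dir(
    emb_train, emb_test, label_train, label_test
)
print(f"position R2: {pos_score:.3f}, median position error: {pos_err:.3f}")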