Skip to content

Commit

Permalink
CEBRA MARBLE comparison
Browse files Browse the repository at this point in the history
  • Loading branch information
agosztolai committed Nov 4, 2024
1 parent 4c4f60a commit 5447765
Show file tree
Hide file tree
Showing 3 changed files with 277 additions and 3 deletions.
5 changes: 3 additions & 2 deletions examples/macaque_reaching/run_cebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from scipy.signal import savgol_filter
from sklearn.decomposition import PCA
from sklearn.neighbors import LocalOutlierFactor
from run_marble import load_data, fit_pca, format_data
from macaque_reaching_helpers import fit_pca, format_data
from tqdm import tqdm
from cebra import CEBRA

Expand All @@ -17,7 +17,8 @@ def main():
data_file = "data/rate_data_20ms.pkl"
metadata_file = "data/trial_ids.pkl"

rates, trial_ids = load_data(data_file, metadata_file)
rates = pickle.load(open(data_file, "rb"))
trial_ids = pickle.load(open(metadata_file, "rb"))

# defining the set of conditions
conditions = ["DownLeft", "Left", "UpLeft", "Up", "UpRight", "Right", "DownRight"]
Expand Down
3 changes: 2 additions & 1 deletion examples/macaque_reaching/run_marble.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ def main():
data_file = "data/rate_data_20ms.pkl"
metadata_file = "data/trial_ids.pkl"

rates, trial_ids = load_data(data_file, metadata_file)
rates = pickle.load(open(data_file, "rb"))
trial_ids = pickle.load(open(metadata_file, "rb"))

# defining the set of conditions
conditions = ["DownLeft", "Left", "UpLeft", "Up", "UpRight", "Right", "DownRight"]
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,272 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "a700a730-16ad-4571-8507-5afe7f029d7b",
"metadata": {},
"source": [
"# Example workflow and comparison between MARBLE with CEBRA"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "33158156-a63b-4dc7-8ab1-ef0b0986bf88",
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2\n",
"\n",
"! pip install ipympl\n",
"%matplotlib widget\n",
"\n",
"import numpy as np\n",
"import pickle\n",
"from macaque_reaching_helpers import fit_pca, format_data\n",
"import matplotlib as mpl\n",
"\n",
"import MARBLE\n",
"\n",
"!pip install cebra\n",
"from cebra import CEBRA"
]
},
{
"cell_type": "markdown",
"id": "67c8f05c-c18d-4be4-b71d-4a7de237e68b",
"metadata": {},
"source": [
"## Load data"
]
},
{
"cell_type": "markdown",
"id": "119f22a5-f780-48ed-ae05-0937fed60ddc",
"metadata": {},
"source": [
"This part is data specific and you will need to adapt it to your own dataset."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "03b7a91c-617e-4070-8525-abd7aabcadec",
"metadata": {},
"outputs": [],
"source": [
"!wget -nc https://dataverse.harvard.edu/api/access/datafile/6969883 -O data/rate_data_20ms_100ms.pkl\n",
"\n",
"with open('data/rate_data_20ms_100ms.pkl', 'rb') as handle:\n",
" rates = pickle.load(handle)\n",
"\n",
"!wget -nc https://dataverse.harvard.edu/api/access/datafile/6963200 -O data/trial_ids.pkl\n",
"\n",
"with open('data/trial_ids.pkl', 'rb') as handle:\n",
" trial_ids = pickle.load(handle)\n",
"\n",
"conditions = [\"DownLeft\", \"Left\", \"UpLeft\", \"Up\", \"UpRight\", \"Right\", \"DownRight\"]"
]
},
{
"cell_type": "markdown",
"id": "135c737c-7ba5-47d2-9d20-90d4a7b8f50a",
"metadata": {},
"source": [
"## Linear dimensionality reduction and filtering of data. "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2b9b0b28-eb95-4f34-9afb-5d08a9ba88fe",
"metadata": {},
"outputs": [],
"source": [
"pca_n = 5\n",
"filter_data = True\n",
"day = 5 #load one session\n",
"\n",
"pca = fit_pca(rates, day, conditions, filter_data=filter_data, pca_n=pca_n)\n",
" \n",
"pos, vel, timepoints, condition_labels, trial_indexes = format_data(rates, \n",
" trial_ids,\n",
" day, \n",
" conditions, \n",
" pca=pca,\n",
" filter_data=filter_data)"
]
},
{
"cell_type": "markdown",
"id": "b0dcca1b-c74a-4726-a70e-2a63f85d6db4",
"metadata": {},
"source": [
"## Run CEBRA"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "de9b27b6-81a6-4d01-8c29-1f15122fb823",
"metadata": {},
"outputs": [],
"source": [
"cebra_model = CEBRA(model_architecture='offset10-model',\n",
" batch_size=512,\n",
" learning_rate=0.0001,\n",
" temperature=1,\n",
" output_dimension=3,\n",
" max_iterations=5000,\n",
" distance='euclidean',\n",
" conditional='time_delta',\n",
" device='cpu',\n",
" verbose=True,\n",
" time_offsets=10)\n",
"\n",
"pos_all = np.vstack(pos)\n",
"condition_labels = np.hstack(condition_labels)\n",
" \n",
"cebra_model.fit(pos_all, condition_labels)\n",
"cebra_pos = cebra_model.transform(pos_all)"
]
},
{
"cell_type": "markdown",
"id": "c2eae737-f1b4-41ba-9c56-4e5f0cf7743f",
"metadata": {},
"source": [
"## Run MARBLE"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b16f9b30-916f-4a3a-800e-2b4a46135aaf",
"metadata": {},
"outputs": [],
"source": [
"data = MARBLE.construct_dataset(\n",
" anchor=pos,\n",
" vector=vel,\n",
" k=30,\n",
" spacing=0.0,\n",
" delta=1.5,\n",
")\n",
"\n",
"params = {\n",
" \"epochs\": 120, # optimisation epochs\n",
" \"order\": 2, # order of derivatives\n",
" \"hidden_channels\": 100, # number of internal dimensions in MLP\n",
" \"out_channels\": 3, \n",
" \"inner_product_features\": False,\n",
" \"diffusion\": True,\n",
"}\n",
"\n",
"model = MARBLE.net(data, params=params)\n",
"\n",
"model.fit(data, outdir=\"data/session_{}_20ms\".format(day))\n",
"data = model.transform(data)"
]
},
{
"cell_type": "markdown",
"id": "1b9cd38a-643f-48dc-af76-f26539d40f4f",
"metadata": {},
"source": [
"## Plot embeddings"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "72e2fc60-c21a-4c6f-819d-e85d31761832",
"metadata": {},
"outputs": [],
"source": [
"label = np.hstack(condition_labels)\n",
"\n",
"colors = mpl.cm.viridis(np.linspace(0, 1, 7))\n",
"fig = plt.figure(figsize=(10, 8))\n",
"\n",
"\n",
"ax1 = fig.add_subplot(121, projection='3d')\n",
"\n",
"emb = cebra_pos\n",
"\n",
"for i in range(7):\n",
" # Filter points by condition label\n",
" indices = label == i\n",
" ax1.scatter(\n",
" emb[indices, 0], # x-coordinates\n",
" emb[indices, 1], # y-coordinates\n",
" emb[indices, 2], # z-coordinates\n",
" s=10, # marker size\n",
" color=colors[i], # color for each condition\n",
" label=f'Condition {i}',\n",
" alpha=0.8\n",
" )\n",
"\n",
"ax1.grid(False)\n",
"ax1.set_xticks([])\n",
"ax1.set_yticks([])\n",
"ax1.set_zticks([])\n",
"ax1.legend()\n",
"\n",
"ax2 = fig.add_subplot(121, projection='3d')\n",
"\n",
"emb = data.emb\n",
"\n",
"for i in range(7):\n",
" # Filter points by condition label\n",
" indices = label == i\n",
" ax2.scatter(\n",
" emb[indices, 0], # x-coordinates\n",
" emb[indices, 1], # y-coordinates\n",
" emb[indices, 2], # z-coordinates\n",
" s=10, # marker size\n",
" color=colors[i], # color for each condition\n",
" label=f'Condition {i}',\n",
" alpha=0.8\n",
" )\n",
"\n",
"ax2.grid(False)\n",
"ax2.set_xticks([])\n",
"ax2.set_yticks([])\n",
"ax2.set_zticks([])\n",
"ax2.legend()\n",
"\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8eb30c9f-918d-4d27-a012-fbbb63cfd075",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

0 comments on commit 5447765

Please sign in to comment.