diff --git a/examples/macaque_reaching/run_cebra.py b/examples/macaque_reaching/run_cebra.py index 35067a2c..712327f8 100644 --- a/examples/macaque_reaching/run_cebra.py +++ b/examples/macaque_reaching/run_cebra.py @@ -5,7 +5,7 @@ from scipy.signal import savgol_filter from sklearn.decomposition import PCA from sklearn.neighbors import LocalOutlierFactor -from run_marble import load_data, fit_pca, format_data +from macaque_reaching_helpers import fit_pca, format_data from tqdm import tqdm from cebra import CEBRA @@ -17,7 +17,8 @@ def main(): data_file = "data/rate_data_20ms.pkl" metadata_file = "data/trial_ids.pkl" - rates, trial_ids = load_data(data_file, metadata_file) + rates = pickle.load(open(data_file, "rb")) + trial_ids = pickle.load(open(metadata_file, "rb")) # defining the set of conditions conditions = ["DownLeft", "Left", "UpLeft", "Up", "UpRight", "Right", "DownRight"] diff --git a/examples/macaque_reaching/run_marble.py b/examples/macaque_reaching/run_marble.py index f39575fd..649636b8 100644 --- a/examples/macaque_reaching/run_marble.py +++ b/examples/macaque_reaching/run_marble.py @@ -11,7 +11,8 @@ def main(): data_file = "data/rate_data_20ms.pkl" metadata_file = "data/trial_ids.pkl" - rates, trial_ids = load_data(data_file, metadata_file) + rates = pickle.load(open(data_file, "rb")) + trial_ids = pickle.load(open(metadata_file, "rb")) # defining the set of conditions conditions = ["DownLeft", "Left", "UpLeft", "Up", "UpRight", "Right", "DownRight"] diff --git a/examples/macaque_reaching/workflow_and_comparison_of_MARBLE_and_CEBRA.ipynb b/examples/macaque_reaching/workflow_and_comparison_of_MARBLE_and_CEBRA.ipynb new file mode 100644 index 00000000..23f3d6f7 --- /dev/null +++ b/examples/macaque_reaching/workflow_and_comparison_of_MARBLE_and_CEBRA.ipynb @@ -0,0 +1,272 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "a700a730-16ad-4571-8507-5afe7f029d7b", + "metadata": {}, + "source": [ + "# Example workflow and comparison between MARBLE with CEBRA" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "33158156-a63b-4dc7-8ab1-ef0b0986bf88", + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2\n", + "\n", + "! pip install ipympl\n", + "%matplotlib widget\n", + "\n", + "import numpy as np\n", + "import pickle\n", + "from macaque_reaching_helpers import fit_pca, format_data\n", + "import matplotlib as mpl\n", + "\n", + "import MARBLE\n", + "\n", + "!pip install cebra\n", + "from cebra import CEBRA" + ] + }, + { + "cell_type": "markdown", + "id": "67c8f05c-c18d-4be4-b71d-4a7de237e68b", + "metadata": {}, + "source": [ + "## Load data" + ] + }, + { + "cell_type": "markdown", + "id": "119f22a5-f780-48ed-ae05-0937fed60ddc", + "metadata": {}, + "source": [ + "This part is data specific and you will need to adapt it to your own dataset." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "03b7a91c-617e-4070-8525-abd7aabcadec", + "metadata": {}, + "outputs": [], + "source": [ + "!wget -nc https://dataverse.harvard.edu/api/access/datafile/6969883 -O data/rate_data_20ms_100ms.pkl\n", + "\n", + "with open('data/rate_data_20ms_100ms.pkl', 'rb') as handle:\n", + " rates = pickle.load(handle)\n", + "\n", + "!wget -nc https://dataverse.harvard.edu/api/access/datafile/6963200 -O data/trial_ids.pkl\n", + "\n", + "with open('data/trial_ids.pkl', 'rb') as handle:\n", + " trial_ids = pickle.load(handle)\n", + "\n", + "conditions = [\"DownLeft\", \"Left\", \"UpLeft\", \"Up\", \"UpRight\", \"Right\", \"DownRight\"]" + ] + }, + { + "cell_type": "markdown", + "id": "135c737c-7ba5-47d2-9d20-90d4a7b8f50a", + "metadata": {}, + "source": [ + "## Linear dimensionality reduction and filtering of data. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2b9b0b28-eb95-4f34-9afb-5d08a9ba88fe", + "metadata": {}, + "outputs": [], + "source": [ + "pca_n = 5\n", + "filter_data = True\n", + "day = 5 #load one session\n", + "\n", + "pca = fit_pca(rates, day, conditions, filter_data=filter_data, pca_n=pca_n)\n", + " \n", + "pos, vel, timepoints, condition_labels, trial_indexes = format_data(rates, \n", + " trial_ids,\n", + " day, \n", + " conditions, \n", + " pca=pca,\n", + " filter_data=filter_data)" + ] + }, + { + "cell_type": "markdown", + "id": "b0dcca1b-c74a-4726-a70e-2a63f85d6db4", + "metadata": {}, + "source": [ + "## Run CEBRA" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "de9b27b6-81a6-4d01-8c29-1f15122fb823", + "metadata": {}, + "outputs": [], + "source": [ + "cebra_model = CEBRA(model_architecture='offset10-model',\n", + " batch_size=512,\n", + " learning_rate=0.0001,\n", + " temperature=1,\n", + " output_dimension=3,\n", + " max_iterations=5000,\n", + " distance='euclidean',\n", + " conditional='time_delta',\n", + " device='cpu',\n", + " verbose=True,\n", + " time_offsets=10)\n", + "\n", + "pos_all = np.vstack(pos)\n", + "condition_labels = np.hstack(condition_labels)\n", + " \n", + "cebra_model.fit(pos_all, condition_labels)\n", + "cebra_pos = cebra_model.transform(pos_all)" + ] + }, + { + "cell_type": "markdown", + "id": "c2eae737-f1b4-41ba-9c56-4e5f0cf7743f", + "metadata": {}, + "source": [ + "## Run MARBLE" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b16f9b30-916f-4a3a-800e-2b4a46135aaf", + "metadata": {}, + "outputs": [], + "source": [ + "data = MARBLE.construct_dataset(\n", + " anchor=pos,\n", + " vector=vel,\n", + " k=30,\n", + " spacing=0.0,\n", + " delta=1.5,\n", + ")\n", + "\n", + "params = {\n", + " \"epochs\": 120, # optimisation epochs\n", + " \"order\": 2, # order of derivatives\n", + " \"hidden_channels\": 100, # number of internal dimensions in MLP\n", + " \"out_channels\": 3, \n", + " \"inner_product_features\": False,\n", + " \"diffusion\": True,\n", + "}\n", + "\n", + "model = MARBLE.net(data, params=params)\n", + "\n", + "model.fit(data, outdir=\"data/session_{}_20ms\".format(day))\n", + "data = model.transform(data)" + ] + }, + { + "cell_type": "markdown", + "id": "1b9cd38a-643f-48dc-af76-f26539d40f4f", + "metadata": {}, + "source": [ + "## Plot embeddings" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "72e2fc60-c21a-4c6f-819d-e85d31761832", + "metadata": {}, + "outputs": [], + "source": [ + "label = np.hstack(condition_labels)\n", + "\n", + "colors = mpl.cm.viridis(np.linspace(0, 1, 7))\n", + "fig = plt.figure(figsize=(10, 8))\n", + "\n", + "\n", + "ax1 = fig.add_subplot(121, projection='3d')\n", + "\n", + "emb = cebra_pos\n", + "\n", + "for i in range(7):\n", + " # Filter points by condition label\n", + " indices = label == i\n", + " ax1.scatter(\n", + " emb[indices, 0], # x-coordinates\n", + " emb[indices, 1], # y-coordinates\n", + " emb[indices, 2], # z-coordinates\n", + " s=10, # marker size\n", + " color=colors[i], # color for each condition\n", + " label=f'Condition {i}',\n", + " alpha=0.8\n", + " )\n", + "\n", + "ax1.grid(False)\n", + "ax1.set_xticks([])\n", + "ax1.set_yticks([])\n", + "ax1.set_zticks([])\n", + "ax1.legend()\n", + "\n", + "ax2 = fig.add_subplot(121, projection='3d')\n", + "\n", + "emb = data.emb\n", + "\n", + "for i in range(7):\n", + " # Filter points by condition label\n", + " indices = label == i\n", + " ax2.scatter(\n", + " emb[indices, 0], # x-coordinates\n", + " emb[indices, 1], # y-coordinates\n", + " emb[indices, 2], # z-coordinates\n", + " s=10, # marker size\n", + " color=colors[i], # color for each condition\n", + " label=f'Condition {i}',\n", + " alpha=0.8\n", + " )\n", + "\n", + "ax2.grid(False)\n", + "ax2.set_xticks([])\n", + "ax2.set_yticks([])\n", + "ax2.set_zticks([])\n", + "ax2.legend()\n", + "\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8eb30c9f-918d-4d27-a012-fbbb63cfd075", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.7" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}