Skip to content

Commit

Permalink
macaque example
Browse files Browse the repository at this point in the history
  • Loading branch information
agosztolai committed Nov 28, 2023
1 parent ba28dfb commit 36177ec
Show file tree
Hide file tree
Showing 7 changed files with 254 additions and 162 deletions.
2 changes: 1 addition & 1 deletion MARBLE/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -671,7 +671,7 @@ def vector_diffusion(x, t, method="spectral", Lc=None):
return out


def compute_eigendecomposition(A, k=50, eps=1e-8):
def compute_eigendecomposition(A, k=None, eps=1e-8):
"""Eigendecomposition of a square matrix A.
Args:
Expand Down
71 changes: 33 additions & 38 deletions examples/macaque_reaching/decoding.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -9,55 +9,42 @@
"\n",
"This notebook takes the outputs from training MARBLE on macaque data with 20ms bin of spiking rates. It then performs decoding of the kinematic results provided publicly by the LFADS authors.\n",
"\n",
"\n",
"There are three main results files of interest, differing only in the output dimension of the MARBLE architecture:\n",
"1. Marble trained with 3-dim output dimension: marble_embeddings_20ms_out3.pkl - https://dataverse.harvard.edu/api/access/datafile/7062022 \n",
"2. Marble trained with 20-dim output dimension: marble_embeddings_20ms_out20.pkl - https://dataverse.harvard.edu/api/access/datafile/7062023 \n",
"3. Marble trained with 3-dim output dimension with rotation invariant features: marble_embeddings_20ms_rotation_invariant.pkl - https://dataverse.harvard.edu/api/access/datafile/7062087 \n",
"\n",
"\n",
"A fourth file containing the LFADS embeddings and the kinematics data:\n",
"1. kinematics.pkl - https://dataverse.harvard.edu/api/access/datafile/6969885 \n",
"\n",
"I'd like to thank the authors of LFADS for making this data accessible and answering our questions about the data!\n"
"We would like to thank the authors of LFADS for making this data accessible and answering our questions about the data!"
]
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 3,
"id": "cd350b67",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: statannotations in /Users/arnaudon/base/lib/python3.10/site-packages (0.5.0)\n",
"Requirement already satisfied: matplotlib>=2.2.2 in /Users/arnaudon/base/lib/python3.10/site-packages (from statannotations) (3.7.1)\n",
"Requirement already satisfied: seaborn<0.12,>=0.9.0 in /Users/arnaudon/base/lib/python3.10/site-packages (from statannotations) (0.11.2)\n",
"Requirement already satisfied: numpy>=1.12.1 in /Users/arnaudon/base/lib/python3.10/site-packages (from statannotations) (1.23.2)\n",
"Requirement already satisfied: scipy>=1.1.0 in /Users/arnaudon/base/lib/python3.10/site-packages (from statannotations) (1.9.0)\n",
"Requirement already satisfied: pandas>=0.23.0 in /Users/arnaudon/base/lib/python3.10/site-packages (from statannotations) (1.4.3)\n",
"Requirement already satisfied: packaging>=20.0 in /Users/arnaudon/base/lib/python3.10/site-packages (from matplotlib>=2.2.2->statannotations) (21.3)\n",
"Requirement already satisfied: pyparsing>=2.3.1 in /Users/arnaudon/base/lib/python3.10/site-packages (from matplotlib>=2.2.2->statannotations) (3.0.9)\n",
"Requirement already satisfied: cycler>=0.10 in /Users/arnaudon/base/lib/python3.10/site-packages (from matplotlib>=2.2.2->statannotations) (0.11.0)\n",
"Requirement already satisfied: kiwisolver>=1.0.1 in /Users/arnaudon/base/lib/python3.10/site-packages (from matplotlib>=2.2.2->statannotations) (1.4.4)\n",
"Requirement already satisfied: fonttools>=4.22.0 in /Users/arnaudon/base/lib/python3.10/site-packages (from matplotlib>=2.2.2->statannotations) (4.37.1)\n",
"Requirement already satisfied: contourpy>=1.0.1 in /Users/arnaudon/base/lib/python3.10/site-packages (from matplotlib>=2.2.2->statannotations) (1.0.7)\n",
"Requirement already satisfied: pillow>=6.2.0 in /Users/arnaudon/base/lib/python3.10/site-packages (from matplotlib>=2.2.2->statannotations) (9.2.0)\n",
"Requirement already satisfied: python-dateutil>=2.7 in /Users/arnaudon/base/lib/python3.10/site-packages (from matplotlib>=2.2.2->statannotations) (2.8.2)\n",
"Requirement already satisfied: pytz>=2020.1 in /Users/arnaudon/base/lib/python3.10/site-packages (from pandas>=0.23.0->statannotations) (2022.2.1)\n",
"Requirement already satisfied: six>=1.5 in /Users/arnaudon/base/lib/python3.10/site-packages (from python-dateutil>=2.7->matplotlib>=2.2.2->statannotations) (1.16.0)\n"
"Requirement already satisfied: statannotations in /Users/gosztola/opt/anaconda3/envs/MARBLE/lib/python3.9/site-packages (0.6.0)\n",
"Requirement already satisfied: matplotlib>=2.2.2 in /Users/gosztola/opt/anaconda3/envs/MARBLE/lib/python3.9/site-packages (from statannotations) (3.5.2)\n",
"Requirement already satisfied: pandas<2.0.0,>=0.23.0 in /Users/gosztola/opt/anaconda3/envs/MARBLE/lib/python3.9/site-packages (from statannotations) (1.5.3)\n",
"Requirement already satisfied: scipy>=1.1.0 in /Users/gosztola/opt/anaconda3/envs/MARBLE/lib/python3.9/site-packages (from statannotations) (1.9.3)\n",
"Requirement already satisfied: numpy>=1.12.1 in /Users/gosztola/opt/anaconda3/envs/MARBLE/lib/python3.9/site-packages (from statannotations) (1.23.3)\n",
"Requirement already satisfied: seaborn<0.12,>=0.9.0 in /Users/gosztola/opt/anaconda3/envs/MARBLE/lib/python3.9/site-packages (from statannotations) (0.11.2)\n",
"Requirement already satisfied: fonttools>=4.22.0 in /Users/gosztola/opt/anaconda3/envs/MARBLE/lib/python3.9/site-packages (from matplotlib>=2.2.2->statannotations) (4.25.0)\n",
"Requirement already satisfied: cycler>=0.10 in /Users/gosztola/opt/anaconda3/envs/MARBLE/lib/python3.9/site-packages (from matplotlib>=2.2.2->statannotations) (0.11.0)\n",
"Requirement already satisfied: pyparsing>=2.2.1 in /Users/gosztola/opt/anaconda3/envs/MARBLE/lib/python3.9/site-packages (from matplotlib>=2.2.2->statannotations) (3.0.9)\n",
"Requirement already satisfied: packaging>=20.0 in /Users/gosztola/opt/anaconda3/envs/MARBLE/lib/python3.9/site-packages (from matplotlib>=2.2.2->statannotations) (21.3)\n",
"Requirement already satisfied: kiwisolver>=1.0.1 in /Users/gosztola/opt/anaconda3/envs/MARBLE/lib/python3.9/site-packages (from matplotlib>=2.2.2->statannotations) (1.4.2)\n",
"Requirement already satisfied: python-dateutil>=2.7 in /Users/gosztola/opt/anaconda3/envs/MARBLE/lib/python3.9/site-packages (from matplotlib>=2.2.2->statannotations) (2.8.2)\n",
"Requirement already satisfied: pillow>=6.2.0 in /Users/gosztola/opt/anaconda3/envs/MARBLE/lib/python3.9/site-packages (from matplotlib>=2.2.2->statannotations) (9.2.0)\n",
"Requirement already satisfied: pytz>=2020.1 in /Users/gosztola/opt/anaconda3/envs/MARBLE/lib/python3.9/site-packages (from pandas<2.0.0,>=0.23.0->statannotations) (2022.1)\n",
"Requirement already satisfied: six>=1.5 in /Users/gosztola/opt/anaconda3/envs/MARBLE/lib/python3.9/site-packages (from python-dateutil>=2.7->matplotlib>=2.2.2->statannotations) (1.16.0)\n"
]
}
],
"source": [
"!pip install statannotations\n",
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"import pickle\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import matplotlib.pylab as pl\n",
"\n",
Expand All @@ -66,7 +53,7 @@
"\n",
"from sklearn.model_selection import KFold\n",
"from sklearn.svm import SVC\n",
"from sklearn.preprocessing import StandardScaler\n"
"from sklearn.preprocessing import StandardScaler"
]
},
{
Expand All @@ -79,24 +66,32 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 4,
"id": "e9bc422a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"File ‘data/marble_embeddings_20ms_out20.pkl’ already there; not retrieving.\n",
"File ‘data/kinematics.pkl’ already there; not retrieving.\n"
"File ‘marble_embeddings_20ms_out20.pkl’ already there; not retrieving.\n",
"File ‘kinematics.pkl’ already there; not retrieving.\n"
]
}
],
"source": [
"# load MARBLE embeddings\n",
"!wget -nc https://dataverse.harvard.edu/api/access/datafile/7062022 -O data/marble_embeddings_20ms_out20.pkl\n",
"!wget -nc https://dataverse.harvard.edu/api/access/datafile/6969885 -O data/kinematics.pkl\n",
"with open('data/marble_embeddings_20ms_out20.pkl', 'rb') as handle:\n",
"# Marble trained with 3-dim output dimension\n",
"!wget -nc https://dataverse.harvard.edu/api/access/datafile/7062022 -O marble_embeddings_20ms_out20.pkl\n",
"# Marble trained with 3-dim output dimension\n",
"#!wget -nc https://dataverse.harvard.edu/api/access/datafile/7062023 -O marble_embeddings_20ms_out3.pkl\n",
"# Marble trained with 3-dim output dimension with rotation invariant features\n",
"#!wget -nc https://dataverse.harvard.edu/api/access/datafile/7062087 -O marble_embeddings_20ms_out3_rotinv.pkl\n",
"\n",
"# kinematics\n",
"!wget -nc https://dataverse.harvard.edu/api/access/datafile/6969885 -O kinematics.pkl\n",
"\n",
"with open('marble_embeddings_20ms_out20.pkl', 'rb') as handle:\n",
" data = pickle.load(handle)\n",
"\n",
"distance_matrices = data[0]\n",
Expand All @@ -107,7 +102,7 @@
"sample_inds = data[5]\n",
"\n",
"# load kinematic data\n",
"kinematic_data = pickle.load(open('data/kinematics.pkl','rb')) \n",
"kinematic_data = pickle.load(open('kinematics.pkl','rb')) \n",
"\n",
"# define conditions of movement\n",
"conditions=['DownLeft','Left','UpLeft','Up','UpRight','Right','DownRight'] \n"
Expand Down
Loading

0 comments on commit 36177ec

Please sign in to comment.