Skip to content

Commit

Permalink
Removed nnfabrik import and redefined utility functions
Browse files Browse the repository at this point in the history
  • Loading branch information
fededagos committed Mar 11, 2024
1 parent b2c48d0 commit 2a3b7e9
Showing 1 changed file with 71 additions and 27 deletions.
98 changes: 71 additions & 27 deletions notebooks/readouts_tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -40,16 +40,11 @@
"outputs": [],
"source": [
"import warnings\n",
"\n",
"warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
"\n",
"import random\n",
"import torch\n",
"import torch.nn as nn\n",
"import numpy as np\n",
"from nnfabrik.builder import get_data\n",
"from nnfabrik.utility.nn_helpers import set_random_seed, get_dims_for_loader_dict\n",
"from neuralpredictors.utils import get_module_output\n",
"\n",
"warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
"random_seed = 42"
]
Expand Down Expand Up @@ -86,17 +81,11 @@
"metadata": {},
"outputs": [],
"source": [
"# in_shapes_dict = {\n",
"# k: get_module_output(core, v[in_name])[1:]\n",
"# for k, v in session_shape_dict.items()\n",
"# }\n",
"\n",
"in_shapes_dict = {\n",
" '21067-10-18': torch.Size([64, 144, 256]),\n",
" '22846-10-16': torch.Size([64, 144, 256])\n",
"}\n",
"\n",
"# n_neurons_dict = {k: v[out_name][1] for k, v in session_shape_dict.items()}\n",
"n_neurons_dict = {'21067-10-18': 8372, '22846-10-16': 7344}"
]
},
Expand Down Expand Up @@ -276,7 +265,7 @@
"Let us import some data and define a sample core from which we will readout.\n",
"You can download the data [here](https://gin.g-node.org/cajal/Sensorium2022/src/master/static22846-10-16-GrayImageNet-94c6ff995dac583098847cfecd43e7b6.zip) and [here](https://gin.g-node.org/cajal/Sensorium2022/src/master/static21067-10-18-GrayImageNet-94c6ff995dac583098847cfecd43e7b6.zip).\n",
"\n",
"We will also need some utility functions from `nnfabrik` and `sensorium`, so you can run the next cell to install them if you have not done so already."
"We will also need some utility functions from `sensorium`, so you can run the next cell to install them if you have not done so already."
]
},
{
Expand All @@ -286,23 +275,78 @@
"outputs": [],
"source": [
"%%capture \n",
"!pip install git+https://github.com/sinzlab/sensorium.git\n",
"!pip install nnfabrik"
"!pip install git+https://github.com/sinzlab/sensorium.git"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We will also need some helper function to load the data, which we define in the next cell:"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"# The following are minimal adaptations of three utility functions found in nnfabrik that we need to initialise\n",
"# the core and readouts later on.\n",
"\n",
"def get_data(dataset_fn, dataset_config):\n",
" \"\"\"\n",
" See https://github.com/sinzlab/nnfabrik/blob/5b6e7379cb5724a787cdd482ee987b8bc0dfacf3/nnfabrik/builder.py#L87\n",
" for the original implementation and documentation if you are interested.\n",
" \"\"\"\n",
" return dataset_fn(**dataset_config)\n",
"\n",
"def get_dims_for_loader_dict(dataloaders):\n",
" \"\"\"\n",
" See https://github.com/sinzlab/nnfabrik/blob/5b6e7379cb5724a787cdd482ee987b8bc0dfacf3/nnfabrik/utility/nn_helpers.py#L39\n",
" for the original implementation and docstring if you are interested.\n",
" \"\"\"\n",
" \n",
" def get_io_dims(data_loader):\n",
" items = next(iter(data_loader))\n",
" if hasattr(items, \"_asdict\"): # if it's a named tuple\n",
" items = items._asdict()\n",
"\n",
" if hasattr(items, \"items\"): # if dict like\n",
" return {k: v.shape for k, v in items.items()}\n",
" else:\n",
" return (v.shape for v in items)\n",
"\n",
" return {k: get_io_dims(v) for k, v in dataloaders.items()}\n",
"\n",
"\n",
"def set_random_seed(seed: int, deterministic: bool = True):\n",
" \"\"\"\n",
" See https://github.com/sinzlab/nnfabrik/blob/5b6e7379cb5724a787cdd482ee987b8bc0dfacf3/nnfabrik/utility/nn_helpers.py#L53\n",
" for the original implementation and docstring if you are intereseted.\n",
" \"\"\"\n",
" random.seed(seed)\n",
" np.random.seed(seed)\n",
" if deterministic:\n",
" torch.backends.cudnn.benchmark = False\n",
" torch.backends.cudnn.deterministic = True\n",
" torch.manual_seed(seed) # this sets both CPU and CUDA seeds for PyTorch\n"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"## Load the data: you can modify this if you have stored it in another location\n",
"from sensorium.datasets import static_loaders\n",
"\n",
"filenames = [\n",
" '../../data/static21067-10-18-GrayImageNet-94c6ff995dac583098847cfecd43e7b6.zip', \n",
" '../../data/static22846-10-16-GrayImageNet-94c6ff995dac583098847cfecd43e7b6.zip'\n",
" ]\n",
"\n",
"dataset_fn = 'sensorium.datasets.static_loaders'\n",
"dataset_config = {'paths': filenames,\n",
" 'normalize': True,\n",
" 'include_behavior': False,\n",
Expand All @@ -312,12 +356,12 @@
" 'cuda': True if device == 'cuda' else False,\n",
" }\n",
"\n",
"dataloaders = get_data(dataset_fn, dataset_config)"
"dataloaders = get_data(dataset_fn=static_loaders, dataset_config=dataset_config)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -342,7 +386,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -383,7 +427,7 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 14,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -413,7 +457,7 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 15,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -443,7 +487,7 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 16,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -495,7 +539,7 @@
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": 17,
"metadata": {},
"outputs": [
{
Expand All @@ -520,7 +564,7 @@
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": 18,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -550,7 +594,7 @@
},
{
"cell_type": "code",
"execution_count": 18,
"execution_count": 19,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -578,7 +622,7 @@
},
{
"cell_type": "code",
"execution_count": 19,
"execution_count": 20,
"metadata": {},
"outputs": [
{
Expand Down

0 comments on commit 2a3b7e9

Please sign in to comment.