diff --git a/notebooks/readouts_tutorial.ipynb b/notebooks/readouts_tutorial.ipynb index 46683cf8..b4667052 100644 --- a/notebooks/readouts_tutorial.ipynb +++ b/notebooks/readouts_tutorial.ipynb @@ -40,16 +40,11 @@ "outputs": [], "source": [ "import warnings\n", - "\n", - "warnings.filterwarnings(\"ignore\", category=UserWarning)\n", - "\n", + "import random\n", "import torch\n", - "import torch.nn as nn\n", "import numpy as np\n", - "from nnfabrik.builder import get_data\n", - "from nnfabrik.utility.nn_helpers import set_random_seed, get_dims_for_loader_dict\n", - "from neuralpredictors.utils import get_module_output\n", "\n", + "warnings.filterwarnings(\"ignore\", category=UserWarning)\n", "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", "random_seed = 42" ] @@ -86,17 +81,11 @@ "metadata": {}, "outputs": [], "source": [ - "# in_shapes_dict = {\n", - "# k: get_module_output(core, v[in_name])[1:]\n", - "# for k, v in session_shape_dict.items()\n", - "# }\n", - "\n", "in_shapes_dict = {\n", " '21067-10-18': torch.Size([64, 144, 256]),\n", " '22846-10-16': torch.Size([64, 144, 256])\n", "}\n", "\n", - "# n_neurons_dict = {k: v[out_name][1] for k, v in session_shape_dict.items()}\n", "n_neurons_dict = {'21067-10-18': 8372, '22846-10-16': 7344}" ] }, @@ -276,7 +265,7 @@ "Let us import some data and define a sample core from which we will readout.\n", "You can download the data [here](https://gin.g-node.org/cajal/Sensorium2022/src/master/static22846-10-16-GrayImageNet-94c6ff995dac583098847cfecd43e7b6.zip) and [here](https://gin.g-node.org/cajal/Sensorium2022/src/master/static21067-10-18-GrayImageNet-94c6ff995dac583098847cfecd43e7b6.zip).\n", "\n", - "We will also need some utility functions from `nnfabrik` and `sensorium`, so you can run the next cell to install them if you have not done so already." + "We will also need some utility functions from `sensorium`, so you can run the next cell to install them if you have not done so already." ] }, { @@ -286,8 +275,14 @@ "outputs": [], "source": [ "%%capture \n", - "!pip install git+https://github.com/sinzlab/sensorium.git\n", - "!pip install nnfabrik" + "!pip install git+https://github.com/sinzlab/sensorium.git" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We will also need some helper function to load the data, which we define in the next cell:" ] }, { @@ -295,14 +290,63 @@ "execution_count": 10, "metadata": {}, "outputs": [], + "source": [ + "# The following are minimal adaptations of three utility functions found in nnfabrik that we need to initialise\n", + "# the core and readouts later on.\n", + "\n", + "def get_data(dataset_fn, dataset_config):\n", + " \"\"\"\n", + " See https://github.com/sinzlab/nnfabrik/blob/5b6e7379cb5724a787cdd482ee987b8bc0dfacf3/nnfabrik/builder.py#L87\n", + " for the original implementation and documentation if you are interested.\n", + " \"\"\"\n", + " return dataset_fn(**dataset_config)\n", + "\n", + "def get_dims_for_loader_dict(dataloaders):\n", + " \"\"\"\n", + " See https://github.com/sinzlab/nnfabrik/blob/5b6e7379cb5724a787cdd482ee987b8bc0dfacf3/nnfabrik/utility/nn_helpers.py#L39\n", + " for the original implementation and docstring if you are interested.\n", + " \"\"\"\n", + " \n", + " def get_io_dims(data_loader):\n", + " items = next(iter(data_loader))\n", + " if hasattr(items, \"_asdict\"): # if it's a named tuple\n", + " items = items._asdict()\n", + "\n", + " if hasattr(items, \"items\"): # if dict like\n", + " return {k: v.shape for k, v in items.items()}\n", + " else:\n", + " return (v.shape for v in items)\n", + "\n", + " return {k: get_io_dims(v) for k, v in dataloaders.items()}\n", + "\n", + "\n", + "def set_random_seed(seed: int, deterministic: bool = True):\n", + " \"\"\"\n", + " See https://github.com/sinzlab/nnfabrik/blob/5b6e7379cb5724a787cdd482ee987b8bc0dfacf3/nnfabrik/utility/nn_helpers.py#L53\n", + " for the original implementation and docstring if you are intereseted.\n", + " \"\"\"\n", + " random.seed(seed)\n", + " np.random.seed(seed)\n", + " if deterministic:\n", + " torch.backends.cudnn.benchmark = False\n", + " torch.backends.cudnn.deterministic = True\n", + " torch.manual_seed(seed) # this sets both CPU and CUDA seeds for PyTorch\n" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], "source": [ "## Load the data: you can modify this if you have stored it in another location\n", + "from sensorium.datasets import static_loaders\n", + "\n", "filenames = [\n", " '../../data/static21067-10-18-GrayImageNet-94c6ff995dac583098847cfecd43e7b6.zip', \n", " '../../data/static22846-10-16-GrayImageNet-94c6ff995dac583098847cfecd43e7b6.zip'\n", " ]\n", "\n", - "dataset_fn = 'sensorium.datasets.static_loaders'\n", "dataset_config = {'paths': filenames,\n", " 'normalize': True,\n", " 'include_behavior': False,\n", @@ -312,12 +356,12 @@ " 'cuda': True if device == 'cuda' else False,\n", " }\n", "\n", - "dataloaders = get_data(dataset_fn, dataset_config)" + "dataloaders = get_data(dataset_fn=static_loaders, dataset_config=dataset_config)" ] }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 12, "metadata": {}, "outputs": [], "source": [ @@ -342,7 +386,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 13, "metadata": {}, "outputs": [], "source": [ @@ -383,7 +427,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 14, "metadata": {}, "outputs": [ { @@ -413,7 +457,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 15, "metadata": {}, "outputs": [ { @@ -443,7 +487,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 16, "metadata": {}, "outputs": [ { @@ -495,7 +539,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 17, "metadata": {}, "outputs": [ { @@ -520,7 +564,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 18, "metadata": {}, "outputs": [ { @@ -550,7 +594,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 19, "metadata": {}, "outputs": [ { @@ -578,7 +622,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 20, "metadata": {}, "outputs": [ {