diff --git a/03_COSDD/solution.ipynb b/03_COSDD/solution.ipynb
index 88d9dd1..bcf3fbc 100755
--- a/03_COSDD/solution.ipynb
+++ b/03_COSDD/solution.ipynb
@@ -25,12 +25,11 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 1,
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
     "import os\n",
-    "import logging\n",
     "\n",
     "import torch\n",
     "import tifffile\n",
@@ -53,7 +52,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 2,
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -120,7 +119,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 3,
+   "execution_count": null,
    "metadata": {
     "tags": [
      "solution"
@@ -359,20 +358,12 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 4,
-   "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "Effective batch size: 16\n"
-     ]
-    }
-   ],
-   "source": [
-    "real_batch_size = 16\n",
-    "n_grad_batches = 1\n",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "real_batch_size = 4\n",
+    "n_grad_batches = 4\n",
     "print(f\"Effective batch size: {real_batch_size * n_grad_batches}\")\n",
     "crop_size = (256, 256)\n",
     "train_split = 0.9\n",
@@ -472,9 +463,9 @@
    "outputs": [],
    "source": [
     "dimensions = ... ### Insert a value here\n",
-    "s_code_channels = 16\n",
+    "s_code_channels = 32\n",
     "\n",
-    "n_layers = 4\n",
+    "n_layers = 6\n",
     "z_dims = [s_code_channels // 2] * n_layers\n",
     "downsampling = [1] * n_layers\n",
     "lvae = LadderVAE(\n",
@@ -492,8 +483,8 @@
     "    s_code_channels=s_code_channels,\n",
     "    kernel_size=5,\n",
     "    noise_direction=...  ### Insert a value here\n",
-    "    n_filters=16,\n",
-    "    n_layers=3,\n",
+    "    n_filters=32,\n",
+    "    n_layers=4,\n",
     "    n_gaussians=4,\n",
     "    dimensions=dimensions,\n",
     ")\n",
@@ -526,35 +517,24 @@
     "    data_mean=low_snr.mean(),\n",
     "    data_std=low_snr.std(),\n",
     "    n_grad_batches=n_grad_batches,\n",
-    "    checkpointed=False,\n",
+    "    checkpointed=True,\n",
     ")"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 5,
+   "execution_count": null,
    "metadata": {
     "tags": [
      "solution"
     ]
    },
-   "outputs": [
-    {
-     "name": "stderr",
-     "output_type": "stream",
-     "text": [
-      "/home/TA-bs/miniforge3/envs/05_image_restoration/lib/python3.10/site-packages/pytorch_lightning/utilities/parsing.py:208: Attribute 'vae' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['vae'])`.\n",
-      "/home/TA-bs/miniforge3/envs/05_image_restoration/lib/python3.10/site-packages/pytorch_lightning/utilities/parsing.py:208: Attribute 'ar_decoder' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['ar_decoder'])`.\n",
-      "/home/TA-bs/miniforge3/envs/05_image_restoration/lib/python3.10/site-packages/pytorch_lightning/utilities/parsing.py:208: Attribute 's_decoder' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['s_decoder'])`.\n",
-      "/home/TA-bs/miniforge3/envs/05_image_restoration/lib/python3.10/site-packages/pytorch_lightning/utilities/parsing.py:208: Attribute 'direct_denoiser' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['direct_denoiser'])`.\n"
-     ]
-    }
-   ],
+   "outputs": [],
    "source": [
     "dimensions = 2 ### Insert a value here\n",
-    "s_code_channels = 16\n",
+    "s_code_channels = 32\n",
     "\n",
-    "n_layers = 4\n",
+    "n_layers = 6\n",
     "z_dims = [s_code_channels // 2] * n_layers\n",
     "downsampling = [1] * n_layers\n",
     "lvae = LadderVAE(\n",
@@ -572,8 +552,8 @@
     "    s_code_channels=s_code_channels,\n",
     "    kernel_size=5,\n",
     "    noise_direction=\"x\",  ### Insert a value here\n",
-    "    n_filters=16,\n",
-    "    n_layers=3,\n",
+    "    n_filters=32,\n",
+    "    n_layers=4,\n",
     "    n_gaussians=4,\n",
     "    dimensions=dimensions,\n",
     ")\n",
@@ -606,7 +586,7 @@
     "    data_mean=low_snr.mean(),\n",
     "    data_std=low_snr.std(),\n",
     "    n_grad_batches=n_grad_batches,\n",
-    "    checkpointed=False,\n",
+    "    checkpointed=True,\n",
     ")"
    ]
   },
@@ -723,30 +703,18 @@
     "    max_time=max_time,  # Remove this time limit to train the model fully\n",
     "    log_every_n_steps=len(train_set) // (n_grad_batches * real_batch_size),\n",
     "    callbacks=[EarlyStopping(patience=patience, monitor=\"val/elbo\")],\n",
-    "    precision=\"bf16-mixed\",\n",
     ")"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 6,
+   "execution_count": null,
    "metadata": {
     "tags": [
      "solution"
     ]
    },
-   "outputs": [
-    {
-     "name": "stderr",
-     "output_type": "stream",
-     "text": [
-      "Using bfloat16 Automatic Mixed Precision (AMP)\n",
-      "GPU available: True (cuda), used: True\n",
-      "TPU available: False, using: 0 TPU cores\n",
-      "HPU available: False, using: 0 HPUs\n"
-     ]
-    }
-   ],
+   "outputs": [],
    "source": [
     "model_name = \"mito-confocal\"  ### Insert a value here\n",
     "checkpoint_path = os.path.join(\"checkpoints\", model_name)\n",
@@ -764,7 +732,6 @@
     "    max_time=max_time,  # Remove this time limit to train the model fully\n",
     "    log_every_n_steps=len(train_set) // (n_grad_batches * real_batch_size),\n",
     "    callbacks=[EarlyStopping(patience=patience, monitor=\"val/elbo\")],\n",
-    "    precision=\"bf16-mixed\",\n",
     ")"
    ]
   },
@@ -798,22 +765,12 @@
     "# Exercise 2. Inference with COSDD"
    ]
   },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "logger = logging.getLogger('pytorch_lightning')\n",
-    "logger.setLevel(logging.WARNING)"
-   ]
-  },
   {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
     "### 2.1. Load test data\n",
-    "The images that we want to denoise are loaded here. These are the same that we used for training, but we'll only load 2 to speed up inference."
+    "The images that we want to denoise are loaded here. These are the same that we used for training, but we'll only load 10 to speed up inference."
    ]
   },
   {
@@ -823,7 +780,7 @@
    "outputs": [],
    "source": [
     "lowsnr_path = \"./../data/mito-confocal-lowsnr.tif\"\n",
-    "n_test_images = 2\n",
+    "n_test_images = 10\n",
     "# load the data\n",
     "test_set = tifffile.imread(lowsnr_path)\n",
     "test_set = test_set[:n_test_images, np.newaxis]\n",
@@ -909,7 +866,6 @@
     "    enable_progress_bar=False,\n",
     "    enable_checkpointing=False,\n",
     "    logger=False,\n",
-    "    precision=\"bf16-mixed\",\n",
     ")"
    ]
   },
@@ -932,7 +888,6 @@
     "    enable_progress_bar=False,\n",
     "    enable_checkpointing=False,\n",
     "    logger=False,\n",
-    "    precision=\"bf16-mixed\",\n",
     ")"
    ]
   },