diff --git a/03_COSDD/solution.ipynb b/03_COSDD/solution.ipynb index 88d9dd1..bcf3fbc 100755 --- a/03_COSDD/solution.ipynb +++ b/03_COSDD/solution.ipynb @@ -25,12 +25,11 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import os\n", - "import logging\n", "\n", "import torch\n", "import tifffile\n", @@ -53,7 +52,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -120,7 +119,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": { "tags": [ "solution" @@ -359,20 +358,12 @@ }, { "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Effective batch size: 16\n" - ] - } - ], - "source": [ - "real_batch_size = 16\n", - "n_grad_batches = 1\n", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "real_batch_size = 4\n", + "n_grad_batches = 4\n", "print(f\"Effective batch size: {real_batch_size * n_grad_batches}\")\n", "crop_size = (256, 256)\n", "train_split = 0.9\n", @@ -472,9 +463,9 @@ "outputs": [], "source": [ "dimensions = ... ### Insert a value here\n", - "s_code_channels = 16\n", + "s_code_channels = 32\n", "\n", - "n_layers = 4\n", + "n_layers = 6\n", "z_dims = [s_code_channels // 2] * n_layers\n", "downsampling = [1] * n_layers\n", "lvae = LadderVAE(\n", @@ -492,8 +483,8 @@ " s_code_channels=s_code_channels,\n", " kernel_size=5,\n", " noise_direction=... ### Insert a value here\n", - " n_filters=16,\n", - " n_layers=3,\n", + " n_filters=32,\n", + " n_layers=4,\n", " n_gaussians=4,\n", " dimensions=dimensions,\n", ")\n", @@ -526,35 +517,24 @@ " data_mean=low_snr.mean(),\n", " data_std=low_snr.std(),\n", " n_grad_batches=n_grad_batches,\n", - " checkpointed=False,\n", + " checkpointed=True,\n", ")" ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "metadata": { "tags": [ "solution" ] }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/TA-bs/miniforge3/envs/05_image_restoration/lib/python3.10/site-packages/pytorch_lightning/utilities/parsing.py:208: Attribute 'vae' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['vae'])`.\n", - "/home/TA-bs/miniforge3/envs/05_image_restoration/lib/python3.10/site-packages/pytorch_lightning/utilities/parsing.py:208: Attribute 'ar_decoder' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['ar_decoder'])`.\n", - "/home/TA-bs/miniforge3/envs/05_image_restoration/lib/python3.10/site-packages/pytorch_lightning/utilities/parsing.py:208: Attribute 's_decoder' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['s_decoder'])`.\n", - "/home/TA-bs/miniforge3/envs/05_image_restoration/lib/python3.10/site-packages/pytorch_lightning/utilities/parsing.py:208: Attribute 'direct_denoiser' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['direct_denoiser'])`.\n" - ] - } - ], + "outputs": [], "source": [ "dimensions = 2 ### Insert a value here\n", - "s_code_channels = 16\n", + "s_code_channels = 32\n", "\n", - "n_layers = 4\n", + "n_layers = 6\n", "z_dims = [s_code_channels // 2] * n_layers\n", "downsampling = [1] * n_layers\n", "lvae = LadderVAE(\n", @@ -572,8 +552,8 @@ " s_code_channels=s_code_channels,\n", " kernel_size=5,\n", " noise_direction=\"x\", ### Insert a value here\n", - " n_filters=16,\n", - " n_layers=3,\n", + " n_filters=32,\n", + " n_layers=4,\n", " n_gaussians=4,\n", " dimensions=dimensions,\n", ")\n", @@ -606,7 +586,7 @@ " data_mean=low_snr.mean(),\n", " data_std=low_snr.std(),\n", " n_grad_batches=n_grad_batches,\n", - " checkpointed=False,\n", + " checkpointed=True,\n", ")" ] }, @@ -723,30 +703,18 @@ " max_time=max_time, # Remove this time limit to train the model fully\n", " log_every_n_steps=len(train_set) // (n_grad_batches * real_batch_size),\n", " callbacks=[EarlyStopping(patience=patience, monitor=\"val/elbo\")],\n", - " precision=\"bf16-mixed\",\n", ")" ] }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "metadata": { "tags": [ "solution" ] }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Using bfloat16 Automatic Mixed Precision (AMP)\n", - "GPU available: True (cuda), used: True\n", - "TPU available: False, using: 0 TPU cores\n", - "HPU available: False, using: 0 HPUs\n" - ] - } - ], + "outputs": [], "source": [ "model_name = \"mito-confocal\" ### Insert a value here\n", "checkpoint_path = os.path.join(\"checkpoints\", model_name)\n", @@ -764,7 +732,6 @@ " max_time=max_time, # Remove this time limit to train the model fully\n", " log_every_n_steps=len(train_set) // (n_grad_batches * real_batch_size),\n", " callbacks=[EarlyStopping(patience=patience, monitor=\"val/elbo\")],\n", - " precision=\"bf16-mixed\",\n", ")" ] }, @@ -798,22 +765,12 @@ "# Exercise 2. Inference with COSDD" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "logger = logging.getLogger('pytorch_lightning')\n", - "logger.setLevel(logging.WARNING)" - ] - }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 2.1. Load test data\n", - "The images that we want to denoise are loaded here. These are the same that we used for training, but we'll only load 2 to speed up inference." + "The images that we want to denoise are loaded here. These are the same that we used for training, but we'll only load 10 to speed up inference." ] }, { @@ -823,7 +780,7 @@ "outputs": [], "source": [ "lowsnr_path = \"./../data/mito-confocal-lowsnr.tif\"\n", - "n_test_images = 2\n", + "n_test_images = 10\n", "# load the data\n", "test_set = tifffile.imread(lowsnr_path)\n", "test_set = test_set[:n_test_images, np.newaxis]\n", @@ -909,7 +866,6 @@ " enable_progress_bar=False,\n", " enable_checkpointing=False,\n", " logger=False,\n", - " precision=\"bf16-mixed\",\n", ")" ] }, @@ -932,7 +888,6 @@ " enable_progress_bar=False,\n", " enable_checkpointing=False,\n", " logger=False,\n", - " precision=\"bf16-mixed\",\n", ")" ] },