From ae60130fa8caa5f77a0925382c4961b5d62cc0d9 Mon Sep 17 00:00:00 2001 From: Cateek Date: Thu, 22 Aug 2024 22:12:44 +0000 Subject: [PATCH] care updates --- 01_CARE/exercise.ipynb | 1027 +++++++++++++++++++++++++++++++++++++ 01_CARE/solution.ipynb | 1104 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 2131 insertions(+) create mode 100644 01_CARE/exercise.ipynb create mode 100755 01_CARE/solution.ipynb diff --git a/01_CARE/exercise.ipynb b/01_CARE/exercise.ipynb new file mode 100644 index 0000000..a16f38b --- /dev/null +++ b/01_CARE/exercise.ipynb @@ -0,0 +1,1027 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Content-aware image restoration\n", + "\n", + "Fluorescence microscopy is constrained by the microscope's optics, fluorophore chemistry, and the sample's photon tolerance. These constraints require balancing imaging speed, resolution, light exposure, and depth. CARE demonstrates how Deep learning can extend the range of biological phenomena observable by microscopy when any of these factor becomes limiting.\n", + "\n", + "**Reference**: Weigert, et al. \"Content-aware image restoration: pushing the limits of fluorescence microscopy.\" Nature methods 15.12 (2018): 1090-1097. doi:[10.1038/s41592-018-0216-7](https://www.nature.com/articles/s41592-018-0216-7)\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### CARE\n", + "\n", + "In this first exercise we will train a CARE model for a 2D denoising task. CARE stands for Content-Aware image REstoration, and is a supervised method in which we use pairs of degraded and high quality image to train a particular task. The original paper demonstrated improvement of image quality on a variety of tasks such as image restoration or resolution improvement. Here, we will apply CARE to denoise images acquired at low laser power in order to recover the biological structures present in the data!\n", + "\n", + "

\n",
+    "*Denoising example*
\n", + "\n", + "We'll use the UNet model that we built in the semantic segmentation exercise and use a different set of functions to train the model for restoration rather than segmentation.\n", + "\n", + "\n", + "

<div class=\"alert alert-block alert-success\"><h3>Objectives</h3>
\n", + " \n", + "- Train a UNet on a new task!\n", + "- Understand how to train CARE\n", + " \n", + "
</div>\n",
+    " "
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "
<div class=\"alert alert-danger\">\n",
+    "    Set your python kernel to <code>05_image_restoration</code>\n",
+    "</div>
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext tensorboard\n", + "\n", + "\n", + "import tifffile\n", + "import numpy as np\n", + "from pathlib import Path\n", + "from typing import Union, List, Tuple\n", + "from torch.utils.data import Dataset, DataLoader\n", + "import torch.nn\n", + "import torch.optim\n", + "from torch import no_grad, cuda\n", + "import matplotlib.pyplot as plt\n", + "from torch.utils.tensorboard import SummaryWriter\n", + "from datetime import datetime\n", + "from dlmbl_unet import UNet\n", + "\n", + "%matplotlib inline" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": { + "jp-MarkdownHeadingCollapsed": true + }, + "source": [ + "
\n", + "\n", + "## Part 1: Set-up the data\n", + "\n", + "CARE is a fully supervised algorithm, therefore we need image pairs for training. In practice this is best achieved by acquiring each image twice, once with short exposure time or low laser power to obtain a noisy low-SNR (signal-to-noise ratio) image, and once with high SNR.\n", + "\n", + "Here, we will be using high SNR images of Human U2OS cells taken from the Broad Bioimage Benchmark Collection ([BBBC006v1](https://bbbc.broadinstitute.org/BBBC006)). The low SNR images were created by synthetically adding strong read-out and shot noise, and applying pixel binning of 2x2, thus mimicking acquisitions at a very low light level.\n", + "\n", + "Since the image pairs were synthetically created in this example, they are already aligned perfectly. Note that when working with real paired acquisitions, the low and high SNR images are not pixel-perfect aligned so they would often need to be co-registered before training a CARE model." + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Split the dataset into training and validation\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Define the paths\n", + "root_path = Path(\"./../data\")\n", + "root_path = root_path / \"denoising-CARE_U2OS.unzip\" / \"data\" / \"U2OS\"\n", + "assert root_path.exists(), f\"Path {root_path} does not exist\"\n", + "\n", + "train_images_path = root_path / \"train\" / \"low\"\n", + "train_targets_path = root_path / \"train\" / \"GT\"\n", + "test_image_path = root_path / \"test\" / \"low\"\n", + "test_target_path = root_path / \"test\" / \"GT\"\n", + "\n", + "\n", + "image_files = list(Path(train_images_path).rglob(\"*.tif\"))\n", + "target_files = list(Path(train_targets_path).rglob(\"*.tif\"))\n", + "assert len(image_files) == len(\n", + " target_files\n", + "), \"Number of images and targets do not match\"\n", + "\n", + "print(f\"Total size of train dataset: {len(image_files)}\")\n", + "\n", + "# Split the train data into train and validation\n", + "seed = 42\n", + "train_files_percentage = 0.8\n", + "np.random.seed(seed)\n", + "shuffled_indices = np.random.permutation(len(image_files))\n", + "image_files = np.array(image_files)[shuffled_indices]\n", + "target_files = np.array(target_files)[shuffled_indices]\n", + "assert all(\n", + " [i.name == j.name for i, j in zip(image_files, target_files)]\n", + "), \"Files do not match\"\n", + "\n", + "train_image_files = image_files[: int(train_files_percentage * len(image_files))]\n", + "train_target_files = target_files[: int(train_files_percentage * len(target_files))]\n", + "val_image_files = image_files[int(train_files_percentage * len(image_files)) :]\n", + "val_target_files = target_files[int(train_files_percentage * len(target_files)) :]\n", + "assert all(\n", + " [i.name == j.name for i, j in zip(train_image_files, train_target_files)]\n", + "), \"Train files do not match\"\n", + "assert all(\n", + " [i.name == j.name for i, j in zip(val_image_files, val_target_files)]\n", + "), \"Val files do not match\"\n", + "\n", + "print(f\"Train dataset size: {len(train_image_files)}\")\n", + "print(f\"Validation dataset size: {len(val_image_files)}\")\n", + "\n", + "# Read the test files\n", + "test_image_files = list(test_image_path.rglob(\"*.tif\"))\n", + "test_target_files = list(test_target_path.rglob(\"*.tif\"))\n", + "print(f\"Number of test files: {len(test_image_files)}\")" + ] + }, + { + 
"attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Patching function\n", + "\n", + "In the majority of cases microscopy images are too large to be processed at once and need to be divided into smaller patches. We will define a function that takes image and target arrays and extract random (paired) patches from them.\n", + "\n", + "The method is a bit scary because accessing the whole patch coordinates requires some magical python expressions. \n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def create_patches(\n", + " image_array: np.ndarray,\n", + " target_array: np.ndarray,\n", + " patch_size: Union[List[int], Tuple[int, ...]],\n", + ") -> Tuple[np.ndarray, np.ndarray]:\n", + " \"\"\"\n", + " Create random patches from an array and a target.\n", + "\n", + " The method calculates how many patches the image can be divided into and then\n", + " extracts an equal number of random patches.\n", + "\n", + " Important: the images should have an extra dimension before the spatial dimensions.\n", + " if you try it with only 2D or 3D images, don't forget to add an extra dimension\n", + " using `image = image[np.newaxis, ...]`\n", + " \"\"\"\n", + " # random generator\n", + " rng = np.random.default_rng()\n", + " image_patches = []\n", + " target_patches = []\n", + "\n", + " # iterate over the number of samples in the input array\n", + " for s in range(image_array.shape[0]):\n", + " # calculate the number of patches we can extract\n", + " sample = image_array[s]\n", + " target_sample = target_array[s]\n", + " n_patches = np.ceil(np.prod(sample.shape) / np.prod(patch_size)).astype(int)\n", + "\n", + " # iterate over the number of patches\n", + " for _ in range(n_patches):\n", + " # get random coordinates for the patch and create the crop coordinates\n", + " crop_coords = [\n", + " rng.integers(0, sample.shape[i] - patch_size[i], endpoint=True)\n", + " for i in range(len(patch_size))\n", + " ]\n", + "\n", + " # extract patch from the data\n", + " patch = (\n", + " sample[\n", + " (\n", + " ...,\n", + " *[\n", + " slice(c, c + patch_size[i])\n", + " for i, c in enumerate(crop_coords)\n", + " ],\n", + " )\n", + " ]\n", + " .copy()\n", + " .astype(np.float32)\n", + " )\n", + "\n", + " # same for the target patch\n", + " target_patch = (\n", + " target_sample[\n", + " (\n", + " ...,\n", + " *[\n", + " slice(c, c + patch_size[i])\n", + " for i, c in enumerate(crop_coords)\n", + " ],\n", + " )\n", + " ]\n", + " .copy()\n", + " .astype(np.float32)\n", + " )\n", + "\n", + " # add the patch pair to the list\n", + " image_patches.append(patch)\n", + " target_patches.append(target_patch)\n", + "\n", + " # return stack of patches\n", + " return np.stack(image_patches), np.stack(target_patches)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Create patches\n", + "\n", + "To train the network, we will use patches of size 128x128. We first need to load the data, stack it and then call our patching function." 
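+    ,
+    "\n",
+    "As an aside, the tuple-of-slices indexing inside `create_patches` is less magical than it looks; here is a tiny standalone sketch of the same idea (values made up for illustration):\n",
+    "\n",
+    "```python\n",
+    "import numpy as np\n",
+    "\n",
+    "sample = np.arange(16).reshape(4, 4)\n",
+    "crop_coords, patch_size = (1, 2), (2, 2)\n",
+    "# one slice per spatial dimension, gathered into a tuple index\n",
+    "slices = tuple(slice(c, c + s) for c, s in zip(crop_coords, patch_size))\n",
+    "print(sample[slices])  # the 2x2 patch at rows 1:3, columns 2:4\n",
+    "```"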
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Load images and stack them into arrays\n", + "train_images_array = np.stack([tifffile.imread(str(f)) for f in train_image_files])\n", + "train_targets_array = np.stack([tifffile.imread(str(f)) for f in train_target_files])\n", + "val_images_array = np.stack([tifffile.imread(str(f)) for f in val_image_files])\n", + "val_targets_array = np.stack([tifffile.imread(str(f)) for f in val_target_files])\n", + "\n", + "test_images_array = np.stack([tifffile.imread(str(f)) for f in test_image_files])\n", + "test_targets_array = np.stack([tifffile.imread(str(f)) for f in test_target_files])\n", + "\n", + "\n", + "print(f\"Train images array shape: {train_images_array.shape}\")\n", + "print(f\"Validation images array shape: {val_images_array.shape}\")\n", + "print(f\"Test array shape: {test_images_array.shape}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Create patches\n", + "patch_size = (128, 128)\n", + "\n", + "train_images_patches, train_targets_patches = create_patches(\n", + " train_images_array, train_targets_array, patch_size\n", + ")\n", + "assert (\n", + " train_images_patches.shape[0] == train_targets_patches.shape[0]\n", + "), \"Number of patches do not match\"\n", + "\n", + "val_images_patches, val_targets_patches = create_patches(\n", + " val_images_array, val_targets_array, patch_size\n", + ")\n", + "assert (\n", + " val_images_patches.shape[0] == val_targets_patches.shape[0]\n", + "), \"Number of patches do not match\"\n", + "\n", + "print(f\"Train images patches shape: {train_images_patches.shape}\")\n", + "print(f\"Validation images patches shape: {val_images_patches.shape}\")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Visualize training patches" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig, ax = plt.subplots(3, 2, figsize=(15, 15))\n", + "ax[0, 0].imshow(train_images_patches[0], cmap=\"magma\")\n", + "ax[0, 0].set_title(\"Train image\")\n", + "ax[0, 1].imshow(train_targets_patches[0], cmap=\"magma\")\n", + "ax[0, 1].set_title(\"Train target\")\n", + "ax[1, 0].imshow(train_images_patches[1], cmap=\"magma\")\n", + "ax[1, 0].set_title(\"Train image\")\n", + "ax[1, 1].imshow(train_targets_patches[1], cmap=\"magma\")\n", + "ax[1, 1].set_title(\"Train target\")\n", + "ax[2, 0].imshow(train_images_patches[2], cmap=\"magma\")\n", + "ax[2, 0].set_title(\"Train image\")\n", + "ax[2, 1].imshow(train_targets_patches[2], cmap=\"magma\")\n", + "ax[2, 1].set_title(\"Train target\")\n", + "plt.tight_layout()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "source": [ + "### Dataset class\n", + "\n", + "In modern deep learning libraries, the data is often wrapped into a class called a `Dataset`. Instances of that class are then used to extract the patches before feeding them to the network.\n", + "\n", + "Here, the class will be wrapped around our pre-computed stacks of patches. Our `CAREDataset` class is built on top of the PyTorch `Dataset` class (we say it \"inherits\" from `Dataset`, the \"parent\" class). 
That means that it has some functions hidden from us that are defined in the PyTorch repository, and that we also need to implement specific pre-defined methods, such as `__len__` and `__getitem__`. The advantage is that PyTorch knows what to do with a `Dataset` \"child\" class, since its behaviour is defined in the `Dataset` class, while we can handle the operations that are specific to our own data in the methods we implement."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "
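For intuition, here is a minimal toy `Dataset` (unrelated to our data) showing the two methods PyTorch expects us to provide:\n",
+    "\n",
+    "```python\n",
+    "from torch.utils.data import Dataset\n",
+    "\n",
+    "class Squares(Dataset):  # toy example only\n",
+    "    def __len__(self):\n",
+    "        return 5  # PyTorch calls this for len(dataset)\n",
+    "\n",
+    "    def __getitem__(self, index):\n",
+    "        return index, index ** 2  # and this for dataset[index]\n",
+    "```\n",
+    "\n",
+    "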

<div class=\"alert alert-block alert-info\"><h3>Question: Normalization</h3>
\n",
+    "\n",
+    "In the following cell we calculate the mean and standard deviation of the input and target images so that we can normalize them.\n",
+    "Why is normalization important? \n",
+    "Should we normalize the input and ground truth data the same way? \n",
+    "\n",
+    "</div>
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + } + }, + "outputs": [], + "source": [ + "# Calculate the mean and std of the train dataset\n", + "train_mean = train_images_array.mean()\n", + "train_std = train_images_array.std()\n", + "target_mean = train_targets_array.mean()\n", + "target_std = train_targets_array.std()\n", + "print(f\"Train mean: {train_mean}, std: {train_std}\")\n", + "print(f\"Target mean: {target_mean}, std: {target_std}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "These functions will be used to normalize the data and perform data augmentation as it is loaded." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def normalize(\n", + " image: np.ndarray,\n", + " mean: float = 0.0,\n", + " std: float = 1.0,\n", + ") -> np.ndarray:\n", + " \"\"\"\n", + " Normalize an image with given mean and standard deviation.\n", + "\n", + " Parameters\n", + " ----------\n", + " image : np.ndarray\n", + " Array containing single image or patch, 2D or 3D.\n", + " mean : float, optional\n", + " Mean value for normalization, by default 0.0.\n", + " std : float, optional\n", + " Standard deviation value for normalization, by default 1.0.\n", + "\n", + " Returns\n", + " -------\n", + " np.ndarray\n", + " Normalized array.\n", + " \"\"\"\n", + " return (image - mean) / std\n", + "\n", + "\n", + "def _flip_and_rotate(\n", + " image: np.ndarray, rotate_state: int, flip_state: int\n", + ") -> np.ndarray:\n", + " \"\"\"\n", + " Apply the given number of 90 degrees rotations and flip to an array.\n", + "\n", + " Parameters\n", + " ----------\n", + " image : np.ndarray\n", + " Array containing single image or patch, 2D or 3D.\n", + " rotate_state : int\n", + " Number of 90 degree rotations to apply.\n", + " flip_state : int\n", + " 0 or 1, whether to flip the array or not.\n", + "\n", + " Returns\n", + " -------\n", + " np.ndarray\n", + " Flipped and rotated array.\n", + " \"\"\"\n", + " rotated = np.rot90(image, k=rotate_state, axes=(-2, -1))\n", + " flipped = np.flip(rotated, axis=-1) if flip_state == 1 else rotated\n", + " return flipped.copy()\n", + "\n", + "\n", + "def augment_batch(\n", + " patch: np.ndarray,\n", + " target: np.ndarray,\n", + " seed: int = 42,\n", + ") -> Tuple[np.ndarray, ...]:\n", + " \"\"\"\n", + " Apply augmentation function to patches and masks.\n", + "\n", + " Parameters\n", + " ----------\n", + " patch : np.ndarray\n", + " Array containing single image or patch, 2D or 3D with masked pixels.\n", + " original_image : np.ndarray\n", + " Array containing original image or patch, 2D or 3D.\n", + " mask : np.ndarray\n", + " Array containing only masked pixels, 2D or 3D.\n", + " seed : int, optional\n", + " Seed for random number generator, controls the rotation and flipping.\n", + "\n", + " Returns\n", + " -------\n", + " Tuple[np.ndarray, ...]\n", + " Tuple of augmented arrays.\n", + " \"\"\"\n", + " rng = np.random.default_rng(seed=seed)\n", + " rotate_state = rng.integers(0, 4)\n", + " flip_state = rng.integers(0, 2)\n", + " return (\n", + " _flip_and_rotate(patch, rotate_state, flip_state),\n", + " _flip_and_rotate(target, rotate_state, flip_state),\n", + " )" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Defining the Dataset\n", + "\n", + "Here we're defining the basic pytorch dataset class that will be used to load the data. 
This class will be used to load the data and apply the normalization and augmentation functions to the data as it is loaded.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + } + }, + "outputs": [], + "source": [ + "# Define a Dataset\n", + "class CAREDataset(Dataset): # CAREDataset inherits from the PyTorch Dataset class\n", + " def __init__(\n", + " self, image_data: np.ndarray, target_data: np.ndarray, apply_augmentations=False\n", + " ):\n", + " # these are the \"members\" of the CAREDataset\n", + " self.image_data = image_data\n", + " self.target_data = target_data\n", + " self.patch_augment = apply_augmentations\n", + "\n", + " def __len__(self):\n", + " \"\"\"Return the total number of patches.\n", + "\n", + " This method is called when applying `len(...)` to an instance of our class\n", + " \"\"\"\n", + " return self.image_data.shape[\n", + " 0\n", + " ] # Your code here, define the total number of patches\n", + "\n", + " def __getitem__(self, index):\n", + " \"\"\"Return a single pair of patches.\"\"\"\n", + "\n", + " # get patch\n", + " patch = self.image_data[index]\n", + "\n", + " # get target\n", + " target = self.target_data[index]\n", + "\n", + " # Apply transforms\n", + " if self.patch_augment:\n", + " patch, target = augment_batch(patch=patch, target=target)\n", + "\n", + " # Normalize the patch\n", + " patch = normalize(patch, train_mean, train_std)\n", + " target = normalize(target, target_mean, target_std)\n", + "\n", + " return patch[np.newaxis].astype(np.float32), target[np.newaxis].astype(\n", + " np.float32\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# test the dataset\n", + "train_dataset = CAREDataset(\n", + " image_data=train_images_patches, target_data=train_targets_patches\n", + ")\n", + "val_dataset = CAREDataset(\n", + " image_data=val_images_patches, target_data=val_targets_patches\n", + ")\n", + "\n", + "# what is the dataset length?\n", + "assert len(train_dataset) == train_images_patches.shape[0], \"Dataset length is wrong\"\n", + "\n", + "# check the normalization\n", + "assert train_dataset[42][0].max() <= 10, \"Patch isn't normalized properly\"\n", + "assert train_dataset[42][1].max() <= 10, \"Target patch isn't normalized properly\"\n", + "\n", + "# check the get_item function\n", + "assert train_dataset[42][0].shape == (1, *patch_size), \"Patch size is wrong\"\n", + "assert train_dataset[42][1].shape == (1, *patch_size), \"Target patch size is wrong\"\n", + "assert train_dataset[42][0].dtype == np.float32, \"Patch dtype is wrong\"\n", + "assert train_dataset[42][1].dtype == np.float32, \"Target patch dtype is wrong\"\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The training and validation data are stored as an instance of a `Dataset`. \n", + "This describes how each image should be loaded.\n", + "Now we will prepare them to be fed into the model with a `Dataloader`.\n", + "\n", + "This will use the Dataset to load individual images and organise them into batches.\n", + "The Dataloader will shuffle the data at the start of each epoch, outputting different random batches." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Instantiate the dataset and create a DataLoader\n", + "train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)\n", + "val_dataloader = DataLoader(val_dataset, batch_size=8, shuffle=False)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "
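If you want to reassure yourself that the loader behaves as expected, a quick optional check is to pull a single batch:\n",
+    "\n",
+    "```python\n",
+    "x, y = next(iter(train_dataloader))\n",
+    "print(x.shape, y.shape)  # both should be torch.Size([32, 1, 128, 128])\n",
+    "```\n",
+    "\n",
+    "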

<div class=\"alert alert-block alert-success\"><h3>Checkpoint 1: Data</h3>
\n", + "\n", + "In this section, we prepared paired training data. \n", + "The steps were:\n", + "1) Loading the images.\n", + "2) Cropping them into patches.\n", + "3) Checking the patches visually.\n", + "4) Creating an instance of a pytorch dataset and dataloader.\n", + "\n", + "You'll see a similar preparation procedure followed for most deep learning vision tasks.\n", + "\n", + "Next, we'll use this data to train a denoising model.\n", + "
</div>\n",
+    "\n",
+    "
\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "## Part 2: Training the model\n", + "\n", + "Image restoration task is very similar to the semantic segmentation task we have done in the previous exercise. We can use the same UNet model and just need to adapt a few things.\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "![image](nb_data/carenet.png)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Instantiate the model\n", + "\n", + "We'll be using the model from the previous exercise, so we need to load the relevant module" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Load the model\n", + "model = UNet(depth=3, in_channels=1, out_channels=1)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "source": [ + "
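Before training, it is worth checking that the network maps an image to an image of the same size. A quick sketch (this assumes the dlmbl_unet `UNet` pads its convolutions so that a 128x128 input stays 128x128 at depth 3):\n",
+    "\n",
+    "```python\n",
+    "import torch\n",
+    "\n",
+    "with torch.no_grad():\n",
+    "    dummy = torch.zeros(1, 1, 128, 128)  # (batch, channels, height, width)\n",
+    "    print(model(dummy).shape)  # expected: torch.Size([1, 1, 128, 128])\n",
+    "```\n",
+    "\n",
+    "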

<div class=\"alert alert-block alert-info\"><h3>Task 1: Loss function</h3>
\n",
+    "\n",
+    "CARE trains image-to-image, so we need a different loss function than the one used for the segmentation task (image to mask). Can you think of a suitable loss function?\n",
+    "\n",
+    "*hint: look in the `torch.nn` module of PyTorch ([link](https://pytorch.org/docs/stable/nn.html#loss-functions)).*\n",
+    "\n",
+    "</div>
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [ + "task" + ] + }, + "outputs": [], + "source": [ + "loss = #### YOUR CODE HERE ####" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "source": [ + "
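Whichever loss you pick, PyTorch losses are modules: you instantiate them once and then call them like a function on a (prediction, target) pair. A quick sketch with an arbitrary loss (not necessarily the right answer here):\n",
+    "\n",
+    "```python\n",
+    "import torch\n",
+    "\n",
+    "criterion = torch.nn.L1Loss()\n",
+    "value = criterion(torch.zeros(2, 1, 8, 8), torch.ones(2, 1, 8, 8))\n",
+    "print(value)  # tensor(1.)\n",
+    "```\n",
+    "\n",
+    "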

<div class=\"alert alert-block alert-info\"><h3>Task 2: Optimizer</h3>
\n",
+    "\n",
+    "Similarly, define the optimizer. No need to be too inventive here!\n",
+    "\n",
+    "</div>
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [ + "task" + ] + }, + "outputs": [], + "source": [ + "optimizer = #### YOUR CODE HERE ####" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Training\n", + "\n", + "Here we will train a CARE model using classes and functions you defined in the previous tasks.\n", + "We're using the same training loop as in the semantic segmentation exercise.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%tensorboard --logdir runs" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Training loop\n", + "n_epochs = 10\n", + "device = \"cuda\" if cuda.is_available() else \"cpu\"\n", + "model.to(device)\n", + "\n", + "# tensorboard\n", + "tb_logger = SummaryWriter(\"runs/Unet\"+datetime.now().strftime('%d%H-%M%S'))\n", + "\n", + "train_losses = []\n", + "val_losses = []\n", + "\n", + "for epoch in range(n_epochs):\n", + " model.train()\n", + " for i, (image_batch, target_batch) in enumerate(train_dataloader):\n", + " batch = image_batch.to(device)\n", + " target = target_batch.to(device)\n", + "\n", + " optimizer.zero_grad()\n", + " output = model(batch)\n", + " train_loss = loss(output, target)\n", + " train_loss.backward()\n", + " optimizer.step()\n", + "\n", + " if i % 10 == 0:\n", + " print(f\"Epoch: {epoch}, Batch: {i}, Loss: {train_loss.item()}\")\n", + "\n", + " model.eval()\n", + "\n", + " with no_grad():\n", + " val_loss = 0\n", + " for i, (batch, target) in enumerate(val_dataloader):\n", + " batch = batch.to(device)\n", + " target = target.to(device)\n", + "\n", + " output = model(batch)\n", + " val_loss = loss(output, target)\n", + "\n", + " # log tensorboard\n", + " step = epoch * len(train_dataloader)\n", + " tb_logger.add_scalar(tag=\"train_loss\", scalar_value=train_loss, global_step=step)\n", + " tb_logger.add_scalar(tag=\"val_loss\", scalar_value=val_loss, global_step=step)\n", + "\n", + " # we always log the last validation images\n", + " tb_logger.add_images(tag=\"val_input\", img_tensor=batch.to(\"cpu\"), global_step=step)\n", + " tb_logger.add_images(tag=\"val_target\", img_tensor=target.to(\"cpu\"), global_step=step)\n", + " tb_logger.add_images(\n", + " tag=\"val_prediction\", img_tensor=output.to(\"cpu\"), global_step=step\n", + " )\n", + "\n", + " print(f\"Validation loss: {val_loss.item()}\")\n", + "\n", + " # Save the losses for plotting\n", + " train_losses.append(train_loss.item())\n", + " val_losses.append(val_loss.item())\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Plot the loss" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Plot training and validation losses\n", + "plt.figure(figsize=(10, 5))\n", + "plt.plot(train_losses)\n", + "plt.plot(val_losses)\n", + "plt.xlabel(\"Iterations\")\n", + "plt.ylabel(\"Loss\")\n", + "plt.legend([\"Train loss\", \"Validation loss\"])\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "

<div class=\"alert alert-block alert-success\"><h3>Checkpoint 2: Training</h3>
\n", + "\n", + "In this section, we created and trained a UNet for denoising.\n", + "We:\n", + "1) Instantiated the model with random weights.\n", + "2) Chose a loss function to compare the output image to the ground truth clean image.\n", + "3) Chose an optimizer to minimize that loss function.\n", + "4) Trained the model with this optimizer.\n", + "5) Examined the training and validation loss curves to see how well our model trained.\n", + "\n", + "Next, we'll load a test set of noisy images and see how well our model denoises them.\n", + "
</div>\n",
+    "\n",
+    "
\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "## Part 3: Predicting on the test dataset\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "# Define the dataset for the test data\n", + "test_dataset = CAREDataset(\n", + " image_data=test_images_array, target_data=test_targets_array\n", + ")\n", + "test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "source": [ + "

<div class=\"alert alert-block alert-info\"><h3>Task 3: Predict using the correct mean/std</h3>
\n",
+    "\n",
+    "In Part 1 we normalized the inputs and the targets before feeding them into the model. This means that the model will output normalized clean images, but we'd like them to be on the same scale as the real clean images.\n",
+    "\n",
+    "Recall the variables we used to normalize the data in Part 1, and use them to denormalize the output of the model.\n",
+    "\n",
+    "</div>
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def denormalize(\n", + " image: np.ndarray,\n", + " mean: float = 0.0,\n", + " std: float = 1.0,\n", + ") -> np.ndarray:\n", + " \"\"\"\n", + " Denormalize an image with given mean and standard deviation.\n", + "\n", + " Parameters\n", + " ----------\n", + " image : np.ndarray\n", + " Array containing single image or patch, 2D or 3D.\n", + " mean : float, optional\n", + " Mean value for normalization, by default 0.0.\n", + " std : float, optional\n", + " Standard deviation value for normalization, by default 1.0.\n", + "\n", + " Returns\n", + " -------\n", + " np.ndarray\n", + " Denormalized array.\n", + " \"\"\"\n", + " return image * std + mean" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [ + "task" + ] + }, + "outputs": [], + "source": [ + "# Define the prediction loop\n", + "predictions = []\n", + "\n", + "model.eval()\n", + "with no_grad():\n", + " for i, (image_batch, target_batch) in enumerate(test_dataloader):\n", + " image_batch = image_batch.to(device)\n", + " target_batch = target_batch.to(device)\n", + " output = model(image_batch)\n", + "\n", + " # Save the predictions for visualization\n", + " predictions.append(denormalize(output.cpu().numpy(), #### YOUR CODE HERE ####))" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Visualize the predictions" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig, ax = plt.subplots(3, 2, figsize=(15, 15))\n", + "ax[0, 0].imshow(test_images_array[0].squeeze(), cmap=\"magma\")\n", + "ax[0, 0].set_title(\"Test image\")\n", + "ax[0, 1].imshow(predictions[0][0].squeeze(), cmap=\"magma\")\n", + "ax[0, 1].set_title(\"Prediction\")\n", + "ax[1, 0].imshow(test_images_array[1].squeeze(), cmap=\"magma\")\n", + "ax[1, 0].set_title(\"Test image\")\n", + "ax[1, 1].imshow(predictions[1][0].squeeze(), cmap=\"magma\")\n", + "ax[1, 1].set_title(\"Prediction\")\n", + "ax[2, 0].imshow(test_images_array[2].squeeze(), cmap=\"magma\")\n", + "ax[2, 0].set_title(\"Test image\")\n", + "ax[2, 1].imshow(predictions[2][0].squeeze(), cmap=\"magma\")\n", + "plt.tight_layout()" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "
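As a last sanity check, it's easy to convince yourself that `denormalize` inverts `normalize` with a quick round trip (optional):\n",
+    "\n",
+    "```python\n",
+    "x = np.random.rand(4, 4).astype(np.float32)\n",
+    "x_back = denormalize(normalize(x, train_mean, train_std), train_mean, train_std)\n",
+    "print(np.allclose(x, x_back))  # True, up to floating-point rounding\n",
+    "```\n",
+    "\n",
+    "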

<div class=\"alert alert-block alert-success\"><h3>Checkpoint 3: Predicting</h3>
\n",
+    "\n",
+    "In this section, we evaluated the performance of our denoiser.\n",
+    "We:\n",
+    "1) Created a CAREDataset and Dataloader for a prediction loop.\n",
+    "2) Ran a prediction loop on the test data.\n",
+    "3) Examined the outputs.\n",
+    "\n",
+    "This notebook has shown how matched pairs of noisy and clean images can train a UNet to denoise, but what if we don't have any clean images? In the next notebook, we'll try Noise2Void, a method for training a UNet to denoise with only noisy images.\n",
+    "</div>
" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "05_image_restoration", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/01_CARE/solution.ipynb b/01_CARE/solution.ipynb new file mode 100755 index 0000000..fcff2a3 --- /dev/null +++ b/01_CARE/solution.ipynb @@ -0,0 +1,1104 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Content-aware image restoration\n", + "\n", + "Fluorescence microscopy is constrained by the microscope's optics, fluorophore chemistry, and the sample's photon tolerance. These constraints require balancing imaging speed, resolution, light exposure, and depth. CARE demonstrates how Deep learning can extend the range of biological phenomena observable by microscopy when any of these factor becomes limiting.\n", + "\n", + "**Reference**: Weigert, et al. \"Content-aware image restoration: pushing the limits of fluorescence microscopy.\" Nature methods 15.12 (2018): 1090-1097. doi:[10.1038/s41592-018-0216-7](https://www.nature.com/articles/s41592-018-0216-7)\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### CARE\n", + "\n", + "In this first exercise we will train a CARE model for a 2D denoising task. CARE stands for Content-Aware image REstoration, and is a supervised method in which we use pairs of degraded and high quality image to train a particular task. The original paper demonstrated improvement of image quality on a variety of tasks such as image restoration or resolution improvement. Here, we will apply CARE to denoise images acquired at low laser power in order to recover the biological structures present in the data!\n", + "\n", + "

\n",
+    "*Denoising example*
\n", + "\n", + "We'll use the UNet model that we built in the semantic segmentation exercise and use a different set of functions to train the model for restoration rather than segmentation.\n", + "\n", + "\n", + "

<div class=\"alert alert-block alert-success\"><h3>Objectives</h3>
\n", + " \n", + "- Train a UNet on a new task!\n", + "- Understand how to train CARE\n", + " \n", + "
</div>\n",
+    " "
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "
<div class=\"alert alert-danger\">\n",
+    "    Set your python kernel to <code>05_image_restoration</code>\n",
+    "</div>
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext tensorboard\n", + "\n", + "\n", + "import tifffile\n", + "import numpy as np\n", + "from pathlib import Path\n", + "from typing import Union, List, Tuple\n", + "from torch.utils.data import Dataset, DataLoader\n", + "import torch.nn\n", + "import torch.optim\n", + "from torch import no_grad, cuda\n", + "import matplotlib.pyplot as plt\n", + "from torch.utils.tensorboard import SummaryWriter\n", + "from datetime import datetime\n", + "from dlmbl_unet import UNet\n", + "\n", + "%matplotlib inline" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": { + "jp-MarkdownHeadingCollapsed": true + }, + "source": [ + "
\n", + "\n", + "## Part 1: Set-up the data\n", + "\n", + "CARE is a fully supervised algorithm, therefore we need image pairs for training. In practice this is best achieved by acquiring each image twice, once with short exposure time or low laser power to obtain a noisy low-SNR (signal-to-noise ratio) image, and once with high SNR.\n", + "\n", + "Here, we will be using high SNR images of Human U2OS cells taken from the Broad Bioimage Benchmark Collection ([BBBC006v1](https://bbbc.broadinstitute.org/BBBC006)). The low SNR images were created by synthetically adding strong read-out and shot noise, and applying pixel binning of 2x2, thus mimicking acquisitions at a very low light level.\n", + "\n", + "Since the image pairs were synthetically created in this example, they are already aligned perfectly. Note that when working with real paired acquisitions, the low and high SNR images are not pixel-perfect aligned so they would often need to be co-registered before training a CARE model." + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Split the dataset into training and validation\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Define the paths\n", + "root_path = Path(\"./../data\")\n", + "root_path = root_path / \"denoising-CARE_U2OS.unzip\" / \"data\" / \"U2OS\"\n", + "assert root_path.exists(), f\"Path {root_path} does not exist\"\n", + "\n", + "train_images_path = root_path / \"train\" / \"low\"\n", + "train_targets_path = root_path / \"train\" / \"GT\"\n", + "test_image_path = root_path / \"test\" / \"low\"\n", + "test_target_path = root_path / \"test\" / \"GT\"\n", + "\n", + "\n", + "image_files = list(Path(train_images_path).rglob(\"*.tif\"))\n", + "target_files = list(Path(train_targets_path).rglob(\"*.tif\"))\n", + "assert len(image_files) == len(\n", + " target_files\n", + "), \"Number of images and targets do not match\"\n", + "\n", + "print(f\"Total size of train dataset: {len(image_files)}\")\n", + "\n", + "# Split the train data into train and validation\n", + "seed = 42\n", + "train_files_percentage = 0.8\n", + "np.random.seed(seed)\n", + "shuffled_indices = np.random.permutation(len(image_files))\n", + "image_files = np.array(image_files)[shuffled_indices]\n", + "target_files = np.array(target_files)[shuffled_indices]\n", + "assert all(\n", + " [i.name == j.name for i, j in zip(image_files, target_files)]\n", + "), \"Files do not match\"\n", + "\n", + "train_image_files = image_files[: int(train_files_percentage * len(image_files))]\n", + "train_target_files = target_files[: int(train_files_percentage * len(target_files))]\n", + "val_image_files = image_files[int(train_files_percentage * len(image_files)) :]\n", + "val_target_files = target_files[int(train_files_percentage * len(target_files)) :]\n", + "assert all(\n", + " [i.name == j.name for i, j in zip(train_image_files, train_target_files)]\n", + "), \"Train files do not match\"\n", + "assert all(\n", + " [i.name == j.name for i, j in zip(val_image_files, val_target_files)]\n", + "), \"Val files do not match\"\n", + "\n", + "print(f\"Train dataset size: {len(train_image_files)}\")\n", + "print(f\"Validation dataset size: {len(val_image_files)}\")\n", + "\n", + "# Read the test files\n", + "test_image_files = list(test_image_path.rglob(\"*.tif\"))\n", + "test_target_files = list(test_target_path.rglob(\"*.tif\"))\n", + "print(f\"Number of test files: {len(test_image_files)}\")" + ] + }, + { + 
"attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Patching function\n", + "\n", + "In the majority of cases microscopy images are too large to be processed at once and need to be divided into smaller patches. We will define a function that takes image and target arrays and extract random (paired) patches from them.\n", + "\n", + "The method is a bit scary because accessing the whole patch coordinates requires some magical python expressions. \n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def create_patches(\n", + " image_array: np.ndarray,\n", + " target_array: np.ndarray,\n", + " patch_size: Union[List[int], Tuple[int, ...]],\n", + ") -> Tuple[np.ndarray, np.ndarray]:\n", + " \"\"\"\n", + " Create random patches from an array and a target.\n", + "\n", + " The method calculates how many patches the image can be divided into and then\n", + " extracts an equal number of random patches.\n", + "\n", + " Important: the images should have an extra dimension before the spatial dimensions.\n", + " if you try it with only 2D or 3D images, don't forget to add an extra dimension\n", + " using `image = image[np.newaxis, ...]`\n", + " \"\"\"\n", + " # random generator\n", + " rng = np.random.default_rng()\n", + " image_patches = []\n", + " target_patches = []\n", + "\n", + " # iterate over the number of samples in the input array\n", + " for s in range(image_array.shape[0]):\n", + " # calculate the number of patches we can extract\n", + " sample = image_array[s]\n", + " target_sample = target_array[s]\n", + " n_patches = np.ceil(np.prod(sample.shape) / np.prod(patch_size)).astype(int)\n", + "\n", + " # iterate over the number of patches\n", + " for _ in range(n_patches):\n", + " # get random coordinates for the patch and create the crop coordinates\n", + " crop_coords = [\n", + " rng.integers(0, sample.shape[i] - patch_size[i], endpoint=True)\n", + " for i in range(len(patch_size))\n", + " ]\n", + "\n", + " # extract patch from the data\n", + " patch = (\n", + " sample[\n", + " (\n", + " ...,\n", + " *[\n", + " slice(c, c + patch_size[i])\n", + " for i, c in enumerate(crop_coords)\n", + " ],\n", + " )\n", + " ]\n", + " .copy()\n", + " .astype(np.float32)\n", + " )\n", + "\n", + " # same for the target patch\n", + " target_patch = (\n", + " target_sample[\n", + " (\n", + " ...,\n", + " *[\n", + " slice(c, c + patch_size[i])\n", + " for i, c in enumerate(crop_coords)\n", + " ],\n", + " )\n", + " ]\n", + " .copy()\n", + " .astype(np.float32)\n", + " )\n", + "\n", + " # add the patch pair to the list\n", + " image_patches.append(patch)\n", + " target_patches.append(target_patch)\n", + "\n", + " # return stack of patches\n", + " return np.stack(image_patches), np.stack(target_patches)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Create patches\n", + "\n", + "To train the network, we will use patches of size 128x128. We first need to load the data, stack it and then call our patching function." 
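+    ,
+    "\n",
+    "As an aside, the tuple-of-slices indexing inside `create_patches` is less magical than it looks; here is a tiny standalone sketch of the same idea (values made up for illustration):\n",
+    "\n",
+    "```python\n",
+    "import numpy as np\n",
+    "\n",
+    "sample = np.arange(16).reshape(4, 4)\n",
+    "crop_coords, patch_size = (1, 2), (2, 2)\n",
+    "# one slice per spatial dimension, gathered into a tuple index\n",
+    "slices = tuple(slice(c, c + s) for c, s in zip(crop_coords, patch_size))\n",
+    "print(sample[slices])  # the 2x2 patch at rows 1:3, columns 2:4\n",
+    "```"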
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Load images and stack them into arrays\n", + "train_images_array = np.stack([tifffile.imread(str(f)) for f in train_image_files])\n", + "train_targets_array = np.stack([tifffile.imread(str(f)) for f in train_target_files])\n", + "val_images_array = np.stack([tifffile.imread(str(f)) for f in val_image_files])\n", + "val_targets_array = np.stack([tifffile.imread(str(f)) for f in val_target_files])\n", + "\n", + "test_images_array = np.stack([tifffile.imread(str(f)) for f in test_image_files])\n", + "test_targets_array = np.stack([tifffile.imread(str(f)) for f in test_target_files])\n", + "\n", + "\n", + "print(f\"Train images array shape: {train_images_array.shape}\")\n", + "print(f\"Validation images array shape: {val_images_array.shape}\")\n", + "print(f\"Test array shape: {test_images_array.shape}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Create patches\n", + "patch_size = (128, 128)\n", + "\n", + "train_images_patches, train_targets_patches = create_patches(\n", + " train_images_array, train_targets_array, patch_size\n", + ")\n", + "assert (\n", + " train_images_patches.shape[0] == train_targets_patches.shape[0]\n", + "), \"Number of patches do not match\"\n", + "\n", + "val_images_patches, val_targets_patches = create_patches(\n", + " val_images_array, val_targets_array, patch_size\n", + ")\n", + "assert (\n", + " val_images_patches.shape[0] == val_targets_patches.shape[0]\n", + "), \"Number of patches do not match\"\n", + "\n", + "print(f\"Train images patches shape: {train_images_patches.shape}\")\n", + "print(f\"Validation images patches shape: {val_images_patches.shape}\")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Visualize training patches" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig, ax = plt.subplots(3, 2, figsize=(15, 15))\n", + "ax[0, 0].imshow(train_images_patches[0], cmap=\"magma\")\n", + "ax[0, 0].set_title(\"Train image\")\n", + "ax[0, 1].imshow(train_targets_patches[0], cmap=\"magma\")\n", + "ax[0, 1].set_title(\"Train target\")\n", + "ax[1, 0].imshow(train_images_patches[1], cmap=\"magma\")\n", + "ax[1, 0].set_title(\"Train image\")\n", + "ax[1, 1].imshow(train_targets_patches[1], cmap=\"magma\")\n", + "ax[1, 1].set_title(\"Train target\")\n", + "ax[2, 0].imshow(train_images_patches[2], cmap=\"magma\")\n", + "ax[2, 0].set_title(\"Train image\")\n", + "ax[2, 1].imshow(train_targets_patches[2], cmap=\"magma\")\n", + "ax[2, 1].set_title(\"Train target\")\n", + "plt.tight_layout()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "source": [ + "### Dataset class\n", + "\n", + "In modern deep learning libraries, the data is often wrapped into a class called a `Dataset`. Instances of that class are then used to extract the patches before feeding them to the network.\n", + "\n", + "Here, the class will be wrapped around our pre-computed stacks of patches. Our `CAREDataset` class is built on top of the PyTorch `Dataset` class (we say it \"inherits\" from `Dataset`, the \"parent\" class). 
That means that it has some functions hidden from us that are defined in the PyTorch repository, and that we also need to implement specific pre-defined methods, such as `__len__` and `__getitem__`. The advantage is that PyTorch knows what to do with a `Dataset` \"child\" class, since its behaviour is defined in the `Dataset` class, while we can handle the operations that are specific to our own data in the methods we implement."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "
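For intuition, here is a minimal toy `Dataset` (unrelated to our data) showing the two methods PyTorch expects us to provide:\n",
+    "\n",
+    "```python\n",
+    "from torch.utils.data import Dataset\n",
+    "\n",
+    "class Squares(Dataset):  # toy example only\n",
+    "    def __len__(self):\n",
+    "        return 5  # PyTorch calls this for len(dataset)\n",
+    "\n",
+    "    def __getitem__(self, index):\n",
+    "        return index, index ** 2  # and this for dataset[index]\n",
+    "```\n",
+    "\n",
+    "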

<div class=\"alert alert-block alert-info\"><h3>Question: Normalization</h3>
\n",
+    "\n",
+    "In the following cell we calculate the mean and standard deviation of the input and target images so that we can normalize them.\n",
+    "Why is normalization important? \n",
+    "Should we normalize the input and ground truth data the same way? \n",
+    "\n",
+    "</div>
" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "tags": [ + "solution" + ] + }, + "source": [ + "Normalization brings the data's values into a standardized range, making the magnitude of gradients suitable for the default learning rate. \n", + "The target noise-free images have a much higher intensity than the noisy input images.\n", + "They need to be normalized using their own statistics to bring them into the same range." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + } + }, + "outputs": [], + "source": [ + "# Calculate the mean and std of the train dataset\n", + "train_mean = train_images_array.mean()\n", + "train_std = train_images_array.std()\n", + "target_mean = train_targets_array.mean()\n", + "target_std = train_targets_array.std()\n", + "print(f\"Train mean: {train_mean}, std: {train_std}\")\n", + "print(f\"Target mean: {target_mean}, std: {target_std}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "These functions will be used to normalize the data and perform data augmentation as it is loaded." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def normalize(\n", + " image: np.ndarray,\n", + " mean: float = 0.0,\n", + " std: float = 1.0,\n", + ") -> np.ndarray:\n", + " \"\"\"\n", + " Normalize an image with given mean and standard deviation.\n", + "\n", + " Parameters\n", + " ----------\n", + " image : np.ndarray\n", + " Array containing single image or patch, 2D or 3D.\n", + " mean : float, optional\n", + " Mean value for normalization, by default 0.0.\n", + " std : float, optional\n", + " Standard deviation value for normalization, by default 1.0.\n", + "\n", + " Returns\n", + " -------\n", + " np.ndarray\n", + " Normalized array.\n", + " \"\"\"\n", + " return (image - mean) / std\n", + "\n", + "\n", + "def _flip_and_rotate(\n", + " image: np.ndarray, rotate_state: int, flip_state: int\n", + ") -> np.ndarray:\n", + " \"\"\"\n", + " Apply the given number of 90 degrees rotations and flip to an array.\n", + "\n", + " Parameters\n", + " ----------\n", + " image : np.ndarray\n", + " Array containing single image or patch, 2D or 3D.\n", + " rotate_state : int\n", + " Number of 90 degree rotations to apply.\n", + " flip_state : int\n", + " 0 or 1, whether to flip the array or not.\n", + "\n", + " Returns\n", + " -------\n", + " np.ndarray\n", + " Flipped and rotated array.\n", + " \"\"\"\n", + " rotated = np.rot90(image, k=rotate_state, axes=(-2, -1))\n", + " flipped = np.flip(rotated, axis=-1) if flip_state == 1 else rotated\n", + " return flipped.copy()\n", + "\n", + "\n", + "def augment_batch(\n", + " patch: np.ndarray,\n", + " target: np.ndarray,\n", + " seed: int = 42,\n", + ") -> Tuple[np.ndarray, ...]:\n", + " \"\"\"\n", + " Apply augmentation function to patches and masks.\n", + "\n", + " Parameters\n", + " ----------\n", + " patch : np.ndarray\n", + " Array containing single image or patch, 2D or 3D with masked pixels.\n", + " original_image : np.ndarray\n", + " Array containing original image or patch, 2D or 3D.\n", + " mask : np.ndarray\n", + " Array containing only masked pixels, 2D or 3D.\n", + " seed : int, optional\n", + " Seed for random number generator, controls the rotation and flipping.\n", + "\n", + " Returns\n", + " -------\n", + " Tuple[np.ndarray, ...]\n", + " Tuple of augmented arrays.\n", + " \"\"\"\n", + " rng = np.random.default_rng(seed=seed)\n", + " 
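# note: with the default seed=42, every call draws the same rotation and flip\n",
+    "    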
rotate_state = rng.integers(0, 4)\n", + " flip_state = rng.integers(0, 2)\n", + " return (\n", + " _flip_and_rotate(patch, rotate_state, flip_state),\n", + " _flip_and_rotate(target, rotate_state, flip_state),\n", + " )" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Defining the Dataset\n", + "\n", + "Here we're defining the basic pytorch dataset class that will be used to load the data. This class will be used to load the data and apply the normalization and augmentation functions to the data as it is loaded.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + } + }, + "outputs": [], + "source": [ + "# Define a Dataset\n", + "class CAREDataset(Dataset): # CAREDataset inherits from the PyTorch Dataset class\n", + " def __init__(\n", + " self, image_data: np.ndarray, target_data: np.ndarray, apply_augmentations=False\n", + " ):\n", + " # these are the \"members\" of the CAREDataset\n", + " self.image_data = image_data\n", + " self.target_data = target_data\n", + " self.patch_augment = apply_augmentations\n", + "\n", + " def __len__(self):\n", + " \"\"\"Return the total number of patches.\n", + "\n", + " This method is called when applying `len(...)` to an instance of our class\n", + " \"\"\"\n", + " return self.image_data.shape[\n", + " 0\n", + " ] # Your code here, define the total number of patches\n", + "\n", + " def __getitem__(self, index):\n", + " \"\"\"Return a single pair of patches.\"\"\"\n", + "\n", + " # get patch\n", + " patch = self.image_data[index]\n", + "\n", + " # get target\n", + " target = self.target_data[index]\n", + "\n", + " # Apply transforms\n", + " if self.patch_augment:\n", + " patch, target = augment_batch(patch=patch, target=target)\n", + "\n", + " # Normalize the patch\n", + " patch = normalize(patch, train_mean, train_std)\n", + " target = normalize(target, target_mean, target_std)\n", + "\n", + " return patch[np.newaxis].astype(np.float32), target[np.newaxis].astype(\n", + " np.float32\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# test the dataset\n", + "train_dataset = CAREDataset(\n", + " image_data=train_images_patches, target_data=train_targets_patches\n", + ")\n", + "val_dataset = CAREDataset(\n", + " image_data=val_images_patches, target_data=val_targets_patches\n", + ")\n", + "\n", + "# what is the dataset length?\n", + "assert len(train_dataset) == train_images_patches.shape[0], \"Dataset length is wrong\"\n", + "\n", + "# check the normalization\n", + "assert train_dataset[42][0].max() <= 10, \"Patch isn't normalized properly\"\n", + "assert train_dataset[42][1].max() <= 10, \"Target patch isn't normalized properly\"\n", + "\n", + "# check the get_item function\n", + "assert train_dataset[42][0].shape == (1, *patch_size), \"Patch size is wrong\"\n", + "assert train_dataset[42][1].shape == (1, *patch_size), \"Target patch size is wrong\"\n", + "assert train_dataset[42][0].dtype == np.float32, \"Patch dtype is wrong\"\n", + "assert train_dataset[42][1].dtype == np.float32, \"Target patch dtype is wrong\"\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The training and validation data are stored as an instance of a `Dataset`. 
\n", + "This describes how each image should be loaded.\n", + "Now we will prepare them to be fed into the model with a `Dataloader`.\n", + "\n", + "This will use the Dataset to load individual images and organise them into batches.\n", + "The Dataloader will shuffle the data at the start of each epoch, outputting different random batches." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Instantiate the dataset and create a DataLoader\n", + "train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)\n", + "val_dataloader = DataLoader(val_dataset, batch_size=8, shuffle=False)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "
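If you want to reassure yourself that the loader behaves as expected, a quick optional check is to pull a single batch:\n",
+    "\n",
+    "```python\n",
+    "x, y = next(iter(train_dataloader))\n",
+    "print(x.shape, y.shape)  # both should be torch.Size([32, 1, 128, 128])\n",
+    "```\n",
+    "\n",
+    "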

<div class=\"alert alert-block alert-success\"><h3>Checkpoint 1: Data</h3>
\n", + "\n", + "In this section, we prepared paired training data. \n", + "The steps were:\n", + "1) Loading the images.\n", + "2) Cropping them into patches.\n", + "3) Checking the patches visually.\n", + "4) Creating an instance of a pytorch dataset and dataloader.\n", + "\n", + "You'll see a similar preparation procedure followed for most deep learning vision tasks.\n", + "\n", + "Next, we'll use this data to train a denoising model.\n", + "
</div>\n",
+    "\n",
+    "
\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "## Part 2: Training the model\n", + "\n", + "Image restoration task is very similar to the semantic segmentation task we have done in the previous exercise. We can use the same UNet model and just need to adapt a few things.\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "![image](nb_data/carenet.png)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Instantiate the model\n", + "\n", + "We'll be using the model from the previous exercise, so we need to load the relevant module" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Load the model\n", + "model = UNet(depth=3, in_channels=1, out_channels=1)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "source": [ + "
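Before training, it is worth checking that the network maps an image to an image of the same size. A quick sketch (this assumes the dlmbl_unet `UNet` pads its convolutions so that a 128x128 input stays 128x128 at depth 3):\n",
+    "\n",
+    "```python\n",
+    "import torch\n",
+    "\n",
+    "with torch.no_grad():\n",
+    "    dummy = torch.zeros(1, 1, 128, 128)  # (batch, channels, height, width)\n",
+    "    print(model(dummy).shape)  # expected: torch.Size([1, 1, 128, 128])\n",
+    "```\n",
+    "\n",
+    "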

<div class=\"alert alert-block alert-info\"><h3>Task 1: Loss function</h3>
\n",
+    "\n",
+    "CARE trains image-to-image, so we need a different loss function than the one used for the segmentation task (image to mask). Can you think of a suitable loss function?\n",
+    "\n",
+    "*hint: look in the `torch.nn` module of PyTorch ([link](https://pytorch.org/docs/stable/nn.html#loss-functions)).*\n",
+    "\n",
+    "</div>
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [ + "task" + ] + }, + "outputs": [], + "source": [ + "loss = #### YOUR CODE HERE ####" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "solution" + ] + }, + "outputs": [], + "source": [ + "loss = torch.nn.MSELoss()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "source": [ + "
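For reference, the mean squared error chosen above penalizes pixel-wise intensity differences between the prediction $\\hat{y}$ and the clean target $y$:\n",
+    "\n",
+    "$$\\\\mathrm{MSE}(\\\\hat{y}, y) = \\\\frac{1}{N} \\\\sum_{i=1}^{N} (\\\\hat{y}_i - y_i)^2$$\n",
+    "\n",
+    "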

<div class=\"alert alert-block alert-info\"><h3>Task 2: Optimizer</h3>
\n",
+    "\n",
+    "Similarly, define the optimizer. No need to be too inventive here!\n",
+    "\n",
+    "</div>
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [ + "task" + ] + }, + "outputs": [], + "source": [ + "optimizer = #### YOUR CODE HERE ####" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [ + "solution" + ] + }, + "outputs": [], + "source": [ + "optimizer = torch.optim.Adam(\n", + " model.parameters(), lr=1e-4\n", + ")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Training\n", + "\n", + "Here we will train a CARE model using classes and functions you defined in the previous tasks.\n", + "We're using the same training loop as in the semantic segmentation exercise.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%tensorboard --logdir runs" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Training loop\n", + "n_epochs = 10\n", + "device = \"cuda\" if cuda.is_available() else \"cpu\"\n", + "model.to(device)\n", + "\n", + "# tensorboard\n", + "tb_logger = SummaryWriter(\"runs/Unet\"+datetime.now().strftime('%d%H-%M%S'))\n", + "\n", + "train_losses = []\n", + "val_losses = []\n", + "\n", + "for epoch in range(n_epochs):\n", + " model.train()\n", + " for i, (image_batch, target_batch) in enumerate(train_dataloader):\n", + " batch = image_batch.to(device)\n", + " target = target_batch.to(device)\n", + "\n", + " optimizer.zero_grad()\n", + " output = model(batch)\n", + " train_loss = loss(output, target)\n", + " train_loss.backward()\n", + " optimizer.step()\n", + "\n", + " if i % 10 == 0:\n", + " print(f\"Epoch: {epoch}, Batch: {i}, Loss: {train_loss.item()}\")\n", + "\n", + " model.eval()\n", + "\n", + " with no_grad():\n", + " val_loss = 0\n", + " for i, (batch, target) in enumerate(val_dataloader):\n", + " batch = batch.to(device)\n", + " target = target.to(device)\n", + "\n", + " output = model(batch)\n", + " val_loss = loss(output, target)\n", + "\n", + " # log tensorboard\n", + " step = epoch * len(train_dataloader)\n", + " tb_logger.add_scalar(tag=\"train_loss\", scalar_value=train_loss, global_step=step)\n", + " tb_logger.add_scalar(tag=\"val_loss\", scalar_value=val_loss, global_step=step)\n", + "\n", + " # we always log the last validation images\n", + " tb_logger.add_images(tag=\"val_input\", img_tensor=batch.to(\"cpu\"), global_step=step)\n", + " tb_logger.add_images(tag=\"val_target\", img_tensor=target.to(\"cpu\"), global_step=step)\n", + " tb_logger.add_images(\n", + " tag=\"val_prediction\", img_tensor=output.to(\"cpu\"), global_step=step\n", + " )\n", + "\n", + " print(f\"Validation loss: {val_loss.item()}\")\n", + "\n", + " # Save the losses for plotting\n", + " train_losses.append(train_loss.item())\n", + " val_losses.append(val_loss.item())\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Plot the loss" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Plot training and validation losses\n", + "plt.figure(figsize=(10, 5))\n", + "plt.plot(train_losses)\n", + "plt.plot(val_losses)\n", + "plt.xlabel(\"Iterations\")\n", + "plt.ylabel(\"Loss\")\n", + "plt.legend([\"Train loss\", \"Validation loss\"])\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + 
"metadata": {}, + "source": [ + "

Checkpoint 2: Training

\n", + "\n", + "In this section, we created and trained a UNet for denoising.\n", + "We:\n", + "1) Instantiated the model with random weights.\n", + "2) Chose a loss function to compare the output image to the ground truth clean image.\n", + "3) Chose an optimizer to minimize that loss function.\n", + "4) Trained the model with this optimizer.\n", + "5) Examined the training and validation loss curves to see how well our model trained.\n", + "\n", + "Next, we'll load a test set of noisy images and see how well our model denoises them.\n", + "
\n", + "\n", + "
\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "## Part 3: Predicting on the test dataset\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "# Define the dataset for the test data\n", + "test_dataset = CAREDataset(\n", + " image_data=test_images_array, target_data=test_targets_array\n", + ")\n", + "test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "source": [ + "

<div class=\"alert alert-block alert-info\"><h3>Task 3: Predict using the correct mean/std</h3>\n", + "\n", + "In Part 1 we normalized the inputs and the targets before feeding them into the model. This means that the model will output normalized clean images, but we'd like them to be on the same intensity scale as the real clean images.\n", + "\n", + "Recall the variables we used to normalize the data in Part 1, and use them to denormalize the output of the model.\n", + "\n", + "</div>
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def denormalize(\n", + " image: np.ndarray,\n", + " mean: float = 0.0,\n", + " std: float = 1.0,\n", + ") -> np.ndarray:\n", + " \"\"\"\n", + " Denormalize an image with given mean and standard deviation.\n", + "\n", + " Parameters\n", + " ----------\n", + " image : np.ndarray\n", + " Array containing single image or patch, 2D or 3D.\n", + " mean : float, optional\n", + " Mean value for normalization, by default 0.0.\n", + " std : float, optional\n", + " Standard deviation value for normalization, by default 1.0.\n", + "\n", + " Returns\n", + " -------\n", + " np.ndarray\n", + " Denormalized array.\n", + " \"\"\"\n", + " return image * std + mean" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [ + "task" + ] + }, + "outputs": [], + "source": [ + "# Define the prediction loop\n", + "predictions = []\n", + "\n", + "model.eval()\n", + "with no_grad():\n", + " for i, (image_batch, target_batch) in enumerate(test_dataloader):\n", + " image_batch = image_batch.to(device)\n", + " target_batch = target_batch.to(device)\n", + " output = model(image_batch)\n", + "\n", + " # Save the predictions for visualization\n", + " predictions.append(denormalize(output.cpu().numpy(), #### YOUR CODE HERE ####))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [ + "solution" + ] + }, + "outputs": [], + "source": [ + "# Define the prediction loop\n", + "predictions = []\n", + "\n", + "model.eval()\n", + "with no_grad():\n", + " for i, (image_batch, target_batch) in enumerate(test_dataloader):\n", + " image_batch = image_batch.to(device)\n", + " target_batch = target_batch.to(device)\n", + " output = model(image_batch)\n", + "\n", + " # Save the predictions for visualization\n", + " predictions.append(denormalize(output.cpu().numpy(), train_mean, train_std))" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Visualize the predictions" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig, ax = plt.subplots(3, 2, figsize=(15, 15))\n", + "ax[0, 0].imshow(test_images_array[0].squeeze(), cmap=\"magma\")\n", + "ax[0, 0].set_title(\"Test image\")\n", + "ax[0, 1].imshow(predictions[0][0].squeeze(), cmap=\"magma\")\n", + "ax[0, 1].set_title(\"Prediction\")\n", + "ax[1, 0].imshow(test_images_array[1].squeeze(), cmap=\"magma\")\n", + "ax[1, 0].set_title(\"Test image\")\n", + "ax[1, 1].imshow(predictions[1][0].squeeze(), cmap=\"magma\")\n", + "ax[1, 1].set_title(\"Prediction\")\n", + "ax[2, 0].imshow(test_images_array[2].squeeze(), cmap=\"magma\")\n", + "ax[2, 0].set_title(\"Test image\")\n", + "ax[2, 1].imshow(predictions[2][0].squeeze(), cmap=\"magma\")\n", + "plt.tight_layout()" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "

<div class=\"alert alert-block alert-success\"><h2>Checkpoint 3: Predicting</h2>\n", + "\n", + "In this section, we evaluated the performance of our denoiser.\n", + "We:\n", + "1) Created a CAREDataset and DataLoader for a prediction loop.\n", + "2) Ran a prediction loop on the test data.\n", + "3) Examined the outputs.\n", + "\n", + "This notebook has shown how matched pairs of noisy and clean images can train a UNet to denoise, but what if we don't have any clean images? In the next notebook, we'll try Noise2Void, a method for training a UNet to denoise with only noisy images.\n", + "</div>
" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "05_image_restoration", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +}