diff --git a/.github/workflows/build-notebooks.yaml b/.github/workflows/build-notebooks.yaml new file mode 100644 index 0000000..c68235e --- /dev/null +++ b/.github/workflows/build-notebooks.yaml @@ -0,0 +1,32 @@ +name: Build Notebooks +on: + push: + +jobs: + run: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v2 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: "3.10" + + - name: Install dependencies + run: | + python -m pip install -U pip + python -m pip install jupytext nbconvert + + + - name: Build notebooks + run: | + jupytext --to ipynb --update-metadata '{"jupytext":{"cell_metadata_filter":"all"}}' solution.py + + jupyter nbconvert solution.ipynb --TagRemovePreprocessor.enabled=True --TagRemovePreprocessor.remove_cell_tags solution --to notebook --output exercise.ipynb + jupyter nbconvert solution.ipynb --TagRemovePreprocessor.enabled=True --TagRemovePreprocessor.remove_cell_tags task --to notebook --output solution.ipynb + + - uses: EndBug/add-and-commit@v9 + with: + add: solution.ipynb exercise.ipynb \ No newline at end of file diff --git a/README.md b/README.md index c548c5e..c878975 100644 --- a/README.md +++ b/README.md @@ -1,11 +1,28 @@ -# Exercise 9: Explainable AI and Knowledge Extraction +# Exercise 8: Explainable AI and Knowledge Extraction + +## Overview +The goal of this exercise is to learn how to probe what a pre-trained classifier has learned about the data it was trained on. + +We will be working with a simple example which is a fun derivation on the MNIST dataset that you will have seen in previous exercises in this course. +Unlike regular MNIST, our dataset is classified not by number, but by color! The question is... which colors fall within which class? + +![CMNIST](assets/cmnist.png) + +In this exercise, we will return to conventional, gradient-based attribution methods to see what they can tell us about what the classifier knows. +We will see that, even for such a simple problem, there is some information that these methods do not give us. + +We will then train a generative adversarial network, or GAN, to try to create counterfactual images. +These images are modifications of the originals, which are able to fool the classifier into thinking they come from a different class!. +We will evaluate this GAN using our classifier; Is it really able to change an image's class in a meaningful way? + +Finally, we will combine the two methods — attribution and counterfactual — to get a full explanation of what exactly it is that the classifier is doing. We will likely learn whether it can teach us anything, and whether we should trust it! ## Setup -Before anything else, in the super-repository called `DL-MBL-2023`: +Before anything else, in the super-repository called `DL-MBL-2024`: ``` git pull -git submodule update --init 09_knowledge_extraction +git submodule update --init 08_knowledge_extraction ``` Then, if you have any other exercises still running, please save your progress and shut down those kernels. @@ -13,25 +30,17 @@ This is a GPU-hungry exercise so you're going to need all the GPU memory you can Next, run the setup script. It might take a few minutes. ``` -cd 09_knowledge_extraction -source setup.sh +cd 08_knowledge_extraction +sh setup.sh ``` This will: -- Create a `mamba` environment for this exercise -- Download and unzip data and pre-trained network +- Create a `conda` environment for this exercise +- Download the data and train the classifier we're learning about Feel free to have a look at the `setup.sh` script to see the details. -Next, begin a Jupyter Lab instance: -``` -jupyter lab -``` -...and continue with the instructions in the notebook. +Next, open the exercise notebook! -## Overview +### Acknowledgments -In this exercise we will: -1. Train a classifier to predict, from 2D EM images of synapses, which neurotransmitter is (mostly) used at that synapse -2. Use a gradient-based attribution method to try to find out what parts of the images contribute to the prediction -3. Train a CycleGAN to create counterfactual images -4. Run a discriminative attribution from counterfactuals +This notebook was written by Diane Adjavon, from a previous version written by Jan Funke and modified by Tri Nguyen, using code from Nils Eckstein. \ No newline at end of file diff --git a/assets/cmnist.png b/assets/cmnist.png new file mode 100644 index 0000000..a56d461 Binary files /dev/null and b/assets/cmnist.png differ diff --git a/assets/same_class_diff_color.png b/assets/same_class_diff_color.png new file mode 100644 index 0000000..c5d98c8 Binary files /dev/null and b/assets/same_class_diff_color.png differ diff --git a/assets/same_color_diff_class.png b/assets/same_color_diff_class.png new file mode 100644 index 0000000..775ce42 Binary files /dev/null and b/assets/same_color_diff_class.png differ diff --git a/assets/stargan.png b/assets/stargan.png new file mode 100644 index 0000000..0695b0d Binary files /dev/null and b/assets/stargan.png differ diff --git a/create_environment.sh b/create_environment.sh new file mode 100644 index 0000000..d97a408 --- /dev/null +++ b/create_environment.sh @@ -0,0 +1,5 @@ +# Contains the steps that I used to create the environment, for memory +mamba create -n 08_knowledge_extraction python=3.11 pytorch torchvision pytorch-cuda=12.1 -c conda-forge -c pytorch -c nvidia +mamba activate 08_knowledge_extraction +pip install -r requirements.txt +mamba env export > environment.yaml diff --git a/dac/__init__.py b/dac/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/dac/activations.py b/dac/activations.py deleted file mode 100644 index e9e9f81..0000000 --- a/dac/activations.py +++ /dev/null @@ -1,72 +0,0 @@ -import numpy as np -import torch -import torch.nn as nn -from functools import partial -import cv2 - -from dac.utils import image_to_tensor - -def get_layer_names(net): - layer_names = util.get_model_layers(net, False) - # print(layer_names) - -def save_activation(activations, name, mod, inp, out): - activations[name].append(out.cpu()) - -def get_activation_dict(net, images, activations): - """ - net: The NN object - images: list of 2D (h,w) normalized image arrays. - """ - tensor_images = [] - for im in images: - tensor_images.append(image_to_tensor(im)) - - # Registering hooks for all the Conv2d layers - # Note: Hooks are called EVERY TIME the module performs a forward pass. For modules that are - # called repeatedly at different stages of the forward pass (like RELUs), this will save different - # activations. Editing the forward pass code to save activations is the way to go for these cases. - for name, m in net.named_modules(): - if type(m)==nn.Conv2d or type(m) == nn.Linear: - # partial to assign the layer name to each hook - m.register_forward_hook(partial(save_activation, activations, name)) - - # forward pass through the full dataset - out = [] - for tensor_image in tensor_images: - out.append(net(tensor_image).detach().cpu().numpy()) - - # concatenate all the outputs we saved to get the the activations for each layer for the whole dataset - activations_dict = {name: torch.cat(outputs, 0).cpu().detach().numpy() for name, outputs in activations.items()} - return activations_dict, out - -def get_layer_activations(activations_dict, layer_name): - layer_activation = None - for name, activation in activations_dict.items(): - if name == layer_name: - layer_activation = activation - return layer_activation - -def project_layer_activations_to_input_rescale(layer_activation, input_shape): - """ - Projects the nth activation and the cth channel from layer - to input. layer_activation[n,c,:,:] -> Input - """ - act_shape = np.shape(layer_activation) - n = act_shape[0] - c = act_shape[1] - h = act_shape[2] - w = act_shape[3] - - samples = [i for i in range(n)] - channels = [c for c in range(c)] - - canvas = np.zeros([len(samples), len(channels), input_shape[0], input_shape[1]], - dtype=np.float32) - - for n in samples: - for c in channels: - to_project = layer_activation[n,c,:,:] - canvas[n,c,:,:] = cv2.resize(to_project, (input_shape[1], input_shape[0])) - - return canvas diff --git a/dac/attribute.py b/dac/attribute.py deleted file mode 100644 index 2892b6e..0000000 --- a/dac/attribute.py +++ /dev/null @@ -1,253 +0,0 @@ -from captum.attr import IntegratedGradients, Saliency, DeepLift,\ - GuidedGradCam, InputXGradient,\ - DeepLift, LayerGradCam, GuidedBackprop -import torch -import numpy as np -import os -import scipy -import scipy.ndimage -import sys - -from dac.utils import save_image, normalize_image, image_to_tensor -from dac.activations import project_layer_activations_to_input_rescale -from dac.stereo_gc import get_sgc -from dac_networks import init_network - -torch.manual_seed(123) -np.random.seed(123) - -def get_attribution(real_img, - fake_img, - real_class, - fake_class, - net_module, - checkpoint_path, - input_shape, - channels, - methods=["ig", "grads", "gc", "ggc", "dl", "ingrad", "random", "residual"], - output_classes=6, - downsample_factors=None, - bidirectional=False): - - '''Return (discriminative) attributions for an image pair. - - Args: - - real_img: (''array like'') - - Real image to run attribution on. - - - fake_img: (''array like'') - - Counterfactual image typically created by a cycle GAN. - - real_class: (''int'') - - Class index of real image. Must correspond to networks output class. - - fake_class: (''int'') - - Class index of fake image. Must correspond to networks output class. - - net_module: (''str'') - - Name of network to use. Network is assumed to be specified at - networks/{net_module}.py and have a matching class name. - - checkpoint_path: (''str'') - - Path to network checkpoint - - input_shape: (''tuple of int'') - - Spatial input image shape, must be 2D. - - channels: (''int'') - - Number of input channels - - methods: (''list of str'') - - List of attribution methods to run - - output_classes: (''int'') - - Number of network output classes - - downsample_factors: (''List of tuple of int'') - - Network argument specifying downsample factors - - bidirectional: (''int'') - - Return both attribution directions. - ''' - - imgs = [image_to_tensor(normalize_image(real_img).astype(np.float32)), - image_to_tensor(normalize_image(fake_img).astype(np.float32))] - - classes = [real_class, fake_class] - net = init_network(checkpoint_path, input_shape, net_module, channels, output_classes=output_classes,eval_net=True, require_grad=False, - downsample_factors=downsample_factors) - - attrs = [] - attrs_names = [] - - if "residual" in methods: - res = np.abs(real_img - fake_img) - res = res - np.min(res) - attrs.append(torch.tensor(res/np.max(res))) - attrs_names.append("residual") - - if "random" in methods: - rand = np.abs(np.random.randn(*np.shape(real_img))) - rand = np.abs(scipy.ndimage.filters.gaussian_filter(rand, 4)) - rand = rand - np.min(rand) - rand = rand/np.max(np.abs(rand)) - attrs.append(torch.tensor(rand)) - attrs_names.append("random") - - if "gc" in methods: - net.zero_grad() - last_conv_layer = [(name,module) for name, module in net.named_modules() if type(module) == torch.nn.Conv2d][-1] - layer_name = last_conv_layer[0] - layer = last_conv_layer[1] - layer_gc = LayerGradCam(net, layer) - gc_real = layer_gc.attribute(imgs[0], target=classes[0]) - - gc_real = project_layer_activations_to_input_rescale(gc_real.cpu().detach().numpy(), (input_shape[0], input_shape[1])) - - attrs.append(torch.tensor(gc_real[0,0,:,:])) - attrs_names.append("gc") - - gc_diff_0, gc_diff_1 = get_sgc(real_img, fake_img, real_class, - fake_class, net_module, checkpoint_path, - input_shape, channels, None, output_classes=output_classes, - downsample_factors=downsample_factors) - attrs.append(gc_diff_0) - attrs_names.append("d_gc") - - if bidirectional: - gc_fake = layer_gc.attribute(imgs[1], target=classes[1]) - gc_fake = project_layer_activations_to_input_rescale(gc_fake.cpu().detach().numpy(), (input_shape[0], input_shape[1])) - attrs.append(torch.tensor(gc_fake[0,0,:,:])) - attrs_names.append("gc_fake") - - attrs.append(gc_diff_1) - attrs_names.append("d_gc_inv") - - if "ggc" in methods: - net.zero_grad() - last_conv = [module for module in net.modules() if type(module) == torch.nn.Conv2d][-1] - - # Real - guided_gc = GuidedGradCam(net, last_conv) - ggc_real = guided_gc.attribute(imgs[0], target=classes[0]) - attrs.append(ggc_real[0,0,:,:]) - attrs_names.append("ggc") - - gc_diff_0, gc_diff_1 = get_sgc(real_img, fake_img, real_class, - fake_class, net_module, checkpoint_path, - input_shape, channels, None, output_classes=output_classes, - downsample_factors=downsample_factors) - - # D-gc - net.zero_grad() - gbp = GuidedBackprop(net) - gbp_real = gbp.attribute(imgs[0], target=classes[0]) - ggc_diff_0 = gbp_real[0,0,:,:] * gc_diff_0 - attrs.append(ggc_diff_0) - attrs_names.append("d_ggc") - - if bidirectional: - ggc_fake = guided_gc.attribute(imgs[1], target=classes[1]) - attrs.append(ggc_fake[0,0,:,:]) - attrs_names.append("ggc_fake") - - gbp_fake = gbp.attribute(imgs[1], target=classes[1]) - ggc_diff_1 = gbp_fake[0,0,:,:] * gc_diff_1 - attrs.append(ggc_diff_1) - attrs_names.append("d_ggc_inv") - - # IG - if "ig" in methods: - baseline = image_to_tensor(np.zeros(input_shape, dtype=np.float32)) - net.zero_grad() - ig = IntegratedGradients(net) - ig_real, delta_real = ig.attribute(imgs[0], baseline, target=classes[0], return_convergence_delta=True) - ig_diff_1, delta_diff = ig.attribute(imgs[1], imgs[0], target=classes[1], return_convergence_delta=True) - - attrs.append(ig_real[0,0,:,:]) - attrs_names.append("ig") - - attrs.append(ig_diff_1[0,0,:,:]) - attrs_names.append("d_ig") - - if bidirectional: - ig_fake, delta_fake = ig.attribute(imgs[1], baseline, target=classes[1], return_convergence_delta=True) - attrs.append(ig_fake[0,0,:,:]) - attrs_names.append("ig_fake") - - ig_diff_0, delta_diff = ig.attribute(imgs[0], imgs[1], target=classes[0], return_convergence_delta=True) - attrs.append(ig_diff_0[0,0,:,:]) - attrs_names.append("d_ig_inv") - - - # DL - if "dl" in methods: - net.zero_grad() - dl = DeepLift(net) - dl_real = dl.attribute(imgs[0], target=classes[0]) - dl_diff_1 = dl.attribute(imgs[1], baselines=imgs[0], target=classes[1]) - - attrs.append(dl_real[0,0,:,:]) - attrs_names.append("dl") - - attrs.append(dl_diff_1[0,0,:,:]) - attrs_names.append("d_dl") - - if bidirectional: - dl_fake = dl.attribute(imgs[1], target=classes[1]) - attrs.append(dl_fake[0,0,:,:]) - attrs_names.append("dl_fake") - - dl_diff_0 = dl.attribute(imgs[0], baselines=imgs[1], target=classes[0]) - attrs.append(dl_diff_0[0,0,:,:]) - attrs_names.append("d_dl_inv") - - # INGRAD - if "ingrad" in methods: - net.zero_grad() - saliency = Saliency(net) - grads_real = saliency.attribute(imgs[0], - target=classes[0]) - grads_fake = saliency.attribute(imgs[1], - target=classes[1]) - - - net.zero_grad() - input_x_gradient = InputXGradient(net) - ingrad_real = input_x_gradient.attribute(imgs[0], target=classes[0]) - - ingrad_diff_0 = grads_fake * (imgs[0] - imgs[1]) - - attrs.append(torch.abs(ingrad_real[0,0,:,:])) - attrs_names.append("ingrad") - - attrs.append(torch.abs(ingrad_diff_0[0,0,:,:])) - attrs_names.append("d_ingrad") - - if bidirectional: - ingrad_fake = input_x_gradient.attribute(imgs[1], target=classes[1]) - attrs.append(torch.abs(ingrad_fake[0,0,:,:])) - attrs_names.append("ingrad_fake") - - ingrad_diff_1 = grads_real * (imgs[1] - imgs[0]) - attrs.append(torch.abs(ingrad_diff_1[0,0,:,:])) - attrs_names.append("d_ingrad_inv") - - attrs = [a.detach().cpu().numpy() for a in attrs] - attrs_norm = [a/np.max(np.abs(a)) for a in attrs] - - return attrs_norm, attrs_names diff --git a/dac/dataset.py b/dac/dataset.py deleted file mode 100644 index 6c9346e..0000000 --- a/dac/dataset.py +++ /dev/null @@ -1,79 +0,0 @@ -import json -import os -from shutil import copy -import itertools - -from dac.utils import open_image - - -def parse_predictions(prediction_dir, - real_class, - fake_class): - '''Parse cycle-GAN predictions from prediction dir. - - Args: - - prediction_dir: (''str'') - - Path to cycle-GAN prediction dir - - real_class: (''int'') - - Real class output index - - fake_class: (''int'') - - Fake class output index - ''' - - files = [os.path.join(prediction_dir, f) for f in os.listdir(prediction_dir)] - real_imgs = [f for f in files if f.endswith("real.png")] - fake_imgs = [f for f in files if f.endswith("fake.png")] - pred_files = [f for f in files if f.endswith("aux.json")] - - img_ids = [int(f.split("/")[-1].split("_")[0]) for f in real_imgs] - - ids_to_data = {} - for img_id in img_ids: - real = [f for f in real_imgs if img_id == int(f.split("/")[-1].split("_")[0])] - fake = [f for f in fake_imgs if img_id == int(f.split("/")[-1].split("_")[0])] - aux = [f for f in pred_files if img_id == int(f.split("/")[-1].split("_")[0])] - assert(len(real) == 1) - assert(len(fake) == 1) - assert(len(aux) == 1) - - real = real[0] - fake = fake[0] - aux = aux[0] - aux_data = json.load(open(aux, "r")) - aux_real = aux_data["aux_real"][real_class] - aux_fake = aux_data["aux_fake"][fake_class] - - ids_to_data[img_id] = (real, fake, aux_real, aux_fake) - - return ids_to_data - -def create_filtered_dataset(ids_to_data, data_dir, threshold=0.8): - '''Filter out failed translations (f(x) threshold) and (data[3] > threshold): - copy(data[0], os.path.join(data_dir + f"/real_{idx}.png")) - copy(data[1], os.path.join(data_dir + f"/fake_{idx}.png")) - idx += 1 diff --git a/dac/gradients.py b/dac/gradients.py deleted file mode 100644 index 30a9d2a..0000000 --- a/dac/gradients.py +++ /dev/null @@ -1,52 +0,0 @@ -import torch -from functools import partial - -def hook_fn(in_grads, out_grads, m, i, o): - for grad in i: - try: - in_grads.append(grad) - except AttributeError: - pass - - for grad in o: - try: - out_grads.append(grad.cpu().numpy()) - except AttributeError: - pass - -def get_gradients_from_layer(net, x, y, layer_name=None, normalize=False): - device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - xx = torch.tensor(x, device=device).unsqueeze(0) - yy = torch.tensor([y], device=device) - xx = xx.unsqueeze(0) - in_grads = [] - out_grads = [] - try: - for param in net.features.parameters(): - param.requires_grad = True - except AttributeError: - for param in net.parameters(): - param.requires_grad = True - - if layer_name is None: - layers = [(name,module) for name, module in net.named_modules() if type(module) == torch.nn.Conv2d][-1] - layer_name = layers[0] - layer = layers[1] - else: - layers = [module for name, module in net.named_modules() if name == layer_name] - assert(len(layers) == 1) - layer = layers[0] - - layer.register_backward_hook(partial(hook_fn, in_grads, out_grads)) - - out = net(xx) - out[0][y].backward() - grad = out_grads[0] - if normalize: - max_grad = np.max(np.abs(grad)) - if max_grad>10**(-12): - grad /= max_grad - else: - grad = np.zeros(np.shape(grad)) - - return grad diff --git a/dac/mask.py b/dac/mask.py deleted file mode 100644 index 87ce625..0000000 --- a/dac/mask.py +++ /dev/null @@ -1,64 +0,0 @@ -import numpy as np -import cv2 -import copy - -from dac.utils import normalize_image, save_image -from dac_networks import run_inference, init_network - -def get_mask(attribution, real_img, fake_img, real_class, fake_class, - net_module, checkpoint_path, input_shape, input_nc, output_classes, - downsample_factors=None, sigma=11, struc=10): - """ - attribution: 2D array <= 1 indicating pixel importance - """ - - net = init_network(checkpoint_path, input_shape, net_module, input_nc, eval_net=True, require_grad=False, output_classes=output_classes, - downsample_factors=downsample_factors) - result_dict = {} - img_names = ["attr", "real", "fake", "hybrid", "mask_real", "mask_fake", "mask_residual", "mask_weight"] - imgs_all = [] - - a_min = -1 - a_max = 1 - steps = 200 - a_range = a_max - a_min - step = a_range/float(steps) - for k in range(0,steps+1): - thr = a_min + k * step - copyfrom = copy.deepcopy(real_img) - copyto = copy.deepcopy(fake_img) - copyto_ref = copy.deepcopy(fake_img) - copied_canvas = np.zeros(np.shape(copyfrom)) - mask = np.array(attribution > thr, dtype=np.uint8) - - kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE,(struc,struc)) - mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel) - mask_size = np.sum(mask) - mask_cp = copy.deepcopy(mask) - - mask_weight = cv2.GaussianBlur(mask_cp.astype(np.float), (sigma,sigma),0) - copyto = np.array((copyto * (1 - mask_weight)) + (copyfrom * mask_weight), dtype=np.float) - - copied_canvas += np.array(mask_weight*copyfrom) - copied_canvas_to = np.zeros(np.shape(copyfrom)) - copied_canvas_to += np.array(mask_weight*copyto_ref) - diff_copied = copied_canvas - copied_canvas_to - - fake_img_norm = normalize_image(copy.deepcopy(fake_img)) - out_fake = run_inference(net, fake_img_norm) - - real_img_norm = normalize_image(copy.deepcopy(real_img)) - out_real = run_inference(net, real_img_norm) - - im_copied_norm = normalize_image(copy.deepcopy(copyto)) - out_copyto = run_inference(net, im_copied_norm) - - imgs = [attribution, real_img_norm, fake_img_norm, im_copied_norm, normalize_image(copied_canvas), - normalize_image(copied_canvas_to), normalize_image(diff_copied), mask_weight] - - imgs_all.append(imgs) - - mrf_score = out_copyto[0][real_class] - out_fake[0][real_class] - result_dict[thr] = [float(mrf_score.detach().cpu().numpy()), mask_size] - - return result_dict, img_names, imgs_all diff --git a/dac/stereo_gc.py b/dac/stereo_gc.py deleted file mode 100644 index 50a7b51..0000000 --- a/dac/stereo_gc.py +++ /dev/null @@ -1,97 +0,0 @@ -import collections -import numpy as np -import os -import torch - -from dac.gradients import get_gradients_from_layer -from dac.activations import get_activation_dict, get_layer_activations, project_layer_activations_to_input_rescale -from dac.utils import normalize_image, save_image -from dac_networks import run_inference, init_network - -def get_sgc(real_img, fake_img, real_class, fake_class, - net_module, checkpoint_path, input_shape, - input_nc, layer_name=None, output_classes=6, - downsample_factors=None): - """ - real_img: Unnormalized (0-255) 2D image - - fake_img: Unnormalized (0-255) 2D image - - *_class: Index of real and fake class corresponding to network output - - net_module: Name of file and class name of the network to use. Must be placed in networks subdirectory - - checkpoint_path: Checkpoint of network. - - input_shape: Spatial input shape of network - - input_nc: Number of input channels. - - layer_name: Name of the conv layer to use (defaults to last) - - output_classes: Number of network output classes - - downsample_factors: Network downsample factors - """ - - - if len(np.shape(fake_img)) != len(np.shape(real_img)) !=2: - raise ValueError("Input images need to be two dimensional") - - imgs = [normalize_image(real_img), normalize_image(fake_img)] - classes = [real_class, fake_class] - - if layer_name is None: - net = init_network(checkpoint_path, input_shape, net_module, - input_nc, eval_net=True, require_grad=False, - output_classes=output_classes, - downsample_factors=downsample_factors) - last_conv_layer = [(name,module) for name, module in net.named_modules() if type(module) == torch.nn.Conv2d][-1] - layer_name = last_conv_layer[0] - layer = last_conv_layer[1] - - grads = [] - for x,y in zip(imgs,classes): - grad_net = init_network(checkpoint_path, input_shape, net_module, - input_nc, eval_net=True, require_grad=False, - output_classes=output_classes, - downsample_factors=downsample_factors) - grads.append(get_gradients_from_layer(grad_net, x, y, layer_name)) - - acts_real = collections.defaultdict(list) - acts_fake = collections.defaultdict(list) - - activation_net = init_network(checkpoint_path, input_shape, net_module, - input_nc, eval_net=True, require_grad=False, output_classes=output_classes, - downsample_factors=downsample_factors) - - acts_real, out_real = get_activation_dict(activation_net, [imgs[0]], acts_real) - acts_fake, out_fake = get_activation_dict(activation_net, [imgs[1]], acts_fake) - - acts = [acts_real, acts_fake] - outs = [out_real, out_fake] - - layer_acts = [] - for act in acts: - layer_acts.append(get_layer_activations(act, layer_name)) - - delta_fake = grads[1] * (layer_acts[0] - layer_acts[1]) - delta_real = grads[0] * (layer_acts[1] - layer_acts[0]) - - delta_fake_projected = project_layer_activations_to_input_rescale(delta_fake, (input_shape[0], input_shape[1]))[0,:,:,:] - delta_real_projected = project_layer_activations_to_input_rescale(delta_real, (input_shape[0], input_shape[1]))[0,:,:,:] - - channels = np.shape(delta_fake_projected)[0] - gc_0 = np.zeros(np.shape(delta_fake_projected)[1:]) - gc_1 = np.zeros(np.shape(delta_real_projected)[1:]) - - for c in range(channels): - gc_0 += delta_fake_projected[c,:,:] - gc_1 += delta_real_projected[c,:,:] - - gc_0 = np.abs(gc_0) - gc_1 = np.abs(gc_1) - gc_0 /= np.max(np.abs(gc_0)) - gc_1 /= np.max(np.abs(gc_1)) - device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - return torch.tensor(gc_0, device=device), torch.tensor(gc_1, device=device) diff --git a/dac/utils.py b/dac/utils.py deleted file mode 100644 index b058b26..0000000 --- a/dac/utils.py +++ /dev/null @@ -1,69 +0,0 @@ -import numpy as np -import os -from PIL import Image -import torch - -def flatten_image(pil_image): - """ - pil_image: image as returned from PIL Image - """ - return np.array(pil_image[:,:,0], dtype=np.float32) - -def normalize_image(image): - """ - image: 2D input image - """ - return (image.astype(np.float32)/255. - 0.5)/0.5 - -def open_image(image_path, flatten=True, normalize=True): - im = np.asarray(Image.open(image_path)) - if flatten: - im = flatten_image(im) - if normalize: - im = normalize_image(im) - return im - -def image_to_tensor(image): - device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - image_tensor = torch.tensor(image, device=device) - image_tensor = image_tensor.unsqueeze(0).unsqueeze(0) - return image_tensor - -def save_image(array, image_path, renorm=True, norm=False): - if renorm: - array = (array *0.5 + 0.5)*255 - if norm: - array/=np.max(np.abs(array)) - array *= 255 - - im = Image.fromarray(array) - im = im.convert('RGB') - im.save(image_path) - -def get_all_pairs(classes): - pairs = [] - i = 0 - for i in range(len(classes)): - for k in range(i+1, len(classes)): - pair = (classes[i], classes[k]) - pairs.append(pair) - - return pairs - -def get_image_pairs(base_dir, class_0, class_1): - """ - Experiment datasets are expected to be placed at - /_ - """ - image_dir = f"{base_dir}/{class_0}_{class_1}" - images = os.listdir(image_dir) - real = [os.path.join(image_dir,im) for im in images if "real" in im and im.endswith(".png")] - fake = [os.path.join(image_dir,im) for im in images if "fake" in im and im.endswith(".png")] - paired_images = [] - for r in real: - for f in fake: - if r.split("/")[-1].split("_")[-1] == f.split("/")[-1].split("_")[-1]: - paired_images.append((r,f)) - break - - return paired_images diff --git a/dac_networks/.Vgg2D.py.swp b/dac_networks/.Vgg2D.py.swp deleted file mode 100644 index caa6e5f..0000000 Binary files a/dac_networks/.Vgg2D.py.swp and /dev/null differ diff --git a/dac_networks/ResNet.py b/dac_networks/ResNet.py deleted file mode 100644 index 6184a1f..0000000 --- a/dac_networks/ResNet.py +++ /dev/null @@ -1,89 +0,0 @@ -import torch -from torch import nn -import math - -class ResNet(nn.Module): - def __init__(self, output_classes, input_size=(128,128), input_channels=1): - super(ResNet, self).__init__() - self.in_channels = 12 - size = input_size[0] - self.conv = nn.Conv2d(input_channels, self.in_channels, kernel_size=3, - padding=1, stride=1, bias=True) - self.bn = nn.BatchNorm2d(self.in_channels) - self.relu = nn.ReLU() - - current_channels = self.in_channels - self.layer1 = self.make_layer(ResidualBlock, current_channels, 2, 2) - current_channels *= 2 - size /= 2 - self.layer2 = self.make_layer(ResidualBlock, current_channels, 2, 2) - current_channels *= 2 - size /= 2 - self.layer3 = self.make_layer(ResidualBlock, current_channels, 2, 2) - current_channels *= 2 - size /= 2 - self.layer4 = self.make_layer(ResidualBlock, current_channels, 2, 2) - size /= 2 - size = int(math.ceil(size)) - - fc = [torch.nn.Linear(current_channels*size**2, 4096), - torch.nn.ReLU(), - torch.nn.Dropout(), - torch.nn.Linear(4096, 4096), - torch.nn.ReLU(), - torch.nn.Dropout(), - torch.nn.Linear(4096,output_classes)] - - self.fc = torch.nn.Sequential(*fc) - print(self) - - def make_layer(self, block, out_channels, blocks, stride=1): - downsample = None - if (stride != 1) or self.in_channels != out_channels: - downsample = nn.Sequential( - nn.Conv2d(self.in_channels, out_channels, kernel_size=3, padding=1, stride=stride, bias=True), - nn.BatchNorm2d(out_channels)) - layers = [] - layers.append(block(self.in_channels, out_channels, stride, downsample)) - self.in_channels = out_channels - for i in range(1, blocks): - layers.append(block(out_channels, out_channels)) - return nn.Sequential(*layers) - - def forward(self, x): - out = self.conv(x) - out = self.bn(out) - out = self.layer1(out) - out = self.layer2(out) - out = self.layer3(out) - out = self.layer4(out) - out = out.view(out.size(0), -1) - out = self.fc(out) - return out - -# Residual block -class ResidualBlock(nn.Module): - def __init__(self, in_channels, out_channels, stride=1, downsample=None): - super(ResidualBlock, self).__init__() - # Biases are handled by BN layers - self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, - padding=1, stride=stride, bias=True) - self.bn1 = nn.BatchNorm2d(out_channels) - self.relu = nn.ReLU() - self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, - padding=1, bias=True) - self.bn2 = nn.BatchNorm2d(out_channels) - self.downsample = downsample - - def forward(self, x): - residual = x - out = self.conv1(x) - out = self.bn1(out) - out = nn.ReLU()(out) - out = self.conv2(out) - out = self.bn2(out) - if self.downsample: - residual = self.downsample(x) - out += residual - out = nn.ReLU()(out) - return out diff --git a/dac_networks/Vgg2D.py b/dac_networks/Vgg2D.py deleted file mode 100644 index 9b94a3b..0000000 --- a/dac_networks/Vgg2D.py +++ /dev/null @@ -1,93 +0,0 @@ -import torch - -class Vgg2D(torch.nn.Module): - def __init__( - self, - input_size, - input_channels, - fmaps=12, - downsample_factors=[(2, 2), (2, 2), (2, 2), (2, 2)], - output_classes=6): - - self.input_size = input_size - - super(Vgg2D, self).__init__() - - current_fmaps = 1 - current_size = tuple(input_size) - - features = [] - for i in range(len(downsample_factors)): - - features += [ - torch.nn.Conv2d( - current_fmaps, - fmaps, - kernel_size=3, - padding=1), - torch.nn.BatchNorm2d(fmaps), - torch.nn.ReLU(inplace=True), - torch.nn.Conv2d( - fmaps, - fmaps, - kernel_size=3, - padding=1), - torch.nn.BatchNorm2d(fmaps), - torch.nn.ReLU(inplace=True), - torch.nn.MaxPool2d(downsample_factors[i]) - ] - - current_fmaps = fmaps - fmaps *= 2 - - size = tuple( - int(c/d) - for c, d in zip(current_size, downsample_factors[i])) - check = ( - s*d == c - for s, d, c in zip(size, downsample_factors[i], current_size)) - assert all(check), \ - "Can not downsample %s by chosen downsample factor" % \ - (current_size,) - current_size = size - - self.features = torch.nn.Sequential(*features) - - classifier = [ - torch.nn.Linear( - current_size[0] * - current_size[1] * - current_fmaps, - 4096), - torch.nn.ReLU(inplace=True), - torch.nn.Dropout(), - torch.nn.Linear( - 4096, - 4096), - torch.nn.ReLU(inplace=True), - torch.nn.Dropout(), - torch.nn.Linear( - 4096, - output_classes) - ] - - self.classifier = torch.nn.Sequential(*classifier) - - #print(self) - - def forward(self, raw): - shape = tuple(raw.shape) - if shape[1] != 1: #== (1, 3, 128, 128) - rgb conversion needed for captum. - raw = raw[:,0,:,:].reshape(shape[0],1,shape[2], shape[3]) - - raw_with_channels = raw.reshape( - shape[0], - 1, - shape[2], - shape[3]) - - raw_with_channels = raw - f = self.features(raw_with_channels) - f = f.view(f.size(0), -1) - y = self.classifier(f) - return y diff --git a/dac_networks/__init__.py b/dac_networks/__init__.py deleted file mode 100644 index 6e0685a..0000000 --- a/dac_networks/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .network_utils import init_network, run_inference diff --git a/dac_networks/network_utils.py b/dac_networks/network_utils.py deleted file mode 100644 index 3a6a093..0000000 --- a/dac_networks/network_utils.py +++ /dev/null @@ -1,53 +0,0 @@ -import importlib -import torch -import torch.nn.functional as F - -from dac.utils import image_to_tensor - -def init_network(checkpoint_path=None, input_shape=(128,128), net_module="Vgg2D", - input_nc=1, output_classes=6, gpu_ids=[], eval_net=True, require_grad=False, - downsample_factors=None): - """ - checkpoint_path: Path to train checkpoint to restore weights from - - input_nc: input_channels for aux net - - aux_net: name of aux net - """ - net_mod = importlib.import_module(f"dac_networks.{net_module}") - net_class = getattr(net_mod, f'{net_module}') - if net_module == "Vgg2D": - net = net_class(input_size=input_shape, input_channels=input_nc, output_classes=output_classes, - downsample_factors=downsample_factors) - else: - net = net_class(input_size=input_shape, input_channels=input_nc, output_classes=output_classes) - - device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - net.to(device) - - if eval_net: - net.eval() - - if require_grad: - for param in net.parameters(): - param.requires_grad = True - else: - for param in net.parameters(): - param.requires_grad = False - - if checkpoint_path is not None: - checkpoint = torch.load(checkpoint_path, map_location=device) - try: - net.load_state_dict(checkpoint['model_state_dict']) - except KeyError: - net.load_state_dict(checkpoint) - return net - -def run_inference(net, im): - """ - Net: network object - input_image: Normalized 2D input image. - """ - im_tensor = image_to_tensor(im) - class_probs = F.softmax(net(im_tensor), dim=1) - return class_probs diff --git a/exercise.ipynb b/exercise.ipynb index 064639f..92007b0 100644 --- a/exercise.ipynb +++ b/exercise.ipynb @@ -2,798 +2,213 @@ "cells": [ { "cell_type": "markdown", - "id": "93f0f7f9", + "id": "30c11df5", "metadata": { - "editable": true, "lines_to_next_cell": 0, - "slideshow": { - "slide_type": "" - }, "tags": [] }, "source": [ - "# Exercise 9: Knowledge Extraction from a Convolutional Neural Network\n", + "# Exercise 8: Knowledge Extraction from a Pre-trained Neural Network\n", "\n", - "In the following exercise we will train a convolutional neural network to classify electron microscopy images of Drosophila synapses, based on which neurotransmitter they contain. We will then train a CycleGAN and use a method called Discriminative Attribution from Counterfactuals (DAC) to understand how the network performs its classification, effectively going back from prediction to image data.\n", + "The goal of this exercise is to learn how to probe what a pre-trained classifier has learned about the data it was trained on.\n", "\n", - "![synister.png](assets/synister.png)\n", + "We will be working with a simple example which is a fun derivation on the MNIST dataset that you will have seen in previous exercises in this course.\n", + "Unlike regular MNIST, our dataset is classified not by number, but by color!\n", "\n", - "### Acknowledgments\n", + "We will:\n", + "1. Load a pre-trained classifier and try applying conventional attribution methods\n", + "2. Train a GAN to create counterfactual images - translating images from one class to another\n", + "3. Evaluate the GAN - see how good it is at fooling the classifier\n", + "4. Create attributions from the counterfactual, and learn the differences between the classes.\n", "\n", - "This notebook was written by Jan Funke and modified by Tri Nguyen and Diane Adjavon, using code from Nils Eckstein and a modified version of the [CycleGAN](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix) implementation.\n" - ] - }, - { - "cell_type": "markdown", - "id": "9a25e710", - "metadata": {}, - "source": [ - "
\n", - "Set your python kernel to 09_knowledge_extraction\n", - "
" - ] - }, - { - "cell_type": "markdown", - "id": "f9b96c13", - "metadata": {}, - "source": [ - "

Start here (AKA checkpoint 0)

\n", + "If time permits, we will try to apply this all over again as a bonus exercise to a much more complex and more biologically relevant problem.\n", + "### Acknowledgments\n", "\n", - "
" + "This notebook was written by Diane Adjavon, from a previous version written by Jan Funke and modified by Tri Nguyen, using code from Nils Eckstein.\n" ] }, { "cell_type": "markdown", - "id": "0c339e3d", + "id": "ec2899d4", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] + "lines_to_next_cell": 0 }, "source": [ - "# Part 1: Image Classification\n", - "\n", - "## Training an image classifier\n", - "In this section, we will implement and train a VGG classifier to classify images of synapses into one of six classes, corresponding to the neurotransmitter type that is released at the synapse: GABA, acethylcholine, glutamate, octopamine, serotonin, and dopamine." + "
\n", + "Set your python kernel to 08_knowledge_extraction\n", + "
" ] }, { "cell_type": "markdown", - "id": "7f524106", + "id": "2c084b97", "metadata": {}, "source": [ "\n", - "The data we use for this exercise is located in `data/raw/synapses`, where we have one subdirectory for each neurotransmitter type. Look at a few examples to familiarize yourself with the dataset. You will notice that the dataset is not balanced, i.e., we have much more examples of one class versus another one.\n", - "\n", - "This class imbalance is problematic for training a classifier. Imagine that 99% of our images are of one class, then the classifier would do really well predicting this class all the time, without having learnt anything of substance. It is therefore important to balance the dataset, i.e., present the same number of images per class to the classifier during training.\n", - "\n", - "First, we split the available images into a train, validation, and test dataset with proportions of 0.7, 0.15, and 0.15, respectively. Each image should be returned as a 2D `numpy` array with float values between 0 and 1. The label for each image should be the name of the directory for this class (e.g., `0_gaba`).\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "dca1c9b7", - "metadata": { - "lines_to_next_cell": 2, - "tags": [] - }, - "outputs": [], - "source": [ - "from torch.utils.data import DataLoader, random_split\n", - "from torch.utils.data.sampler import WeightedRandomSampler\n", - "from torchvision.datasets import ImageFolder\n", - "from torchvision import transforms\n", - "import torch\n", - "import numpy as np\n", - "\n", - "transform = transforms.Compose(\n", - " [\n", - " transforms.Grayscale(),\n", - " transforms.ToTensor(),\n", - " transforms.Normalize((0.5,), (0.5,)),\n", - " ]\n", - ")\n", - "\n", - "# create a dataset for all images of all classes\n", - "full_dataset = ImageFolder(root=\"data/raw/synapses\", transform=transform)\n", - "\n", - "# Rename the classes\n", - "full_dataset.classes = [x.split(\"_\")[-1] for x in full_dataset.classes]\n", - "class_to_idx = {x.split(\"_\")[-1]: v for x, v in full_dataset.class_to_idx.items()}\n", - "full_dataset.class_to_idx = class_to_idx\n", - "\n", - "# randomly split the dataset into train, validation, and test\n", - "num_images = len(full_dataset)\n", - "# ~70% for training\n", - "num_training = int(0.7 * num_images)\n", - "# ~15% for validation\n", - "num_validation = int(0.15 * num_images)\n", - "# ~15% for testing\n", - "num_test = num_images - (num_training + num_validation)\n", - "# split the data randomly (but with a fixed random seed)\n", - "train_dataset, validation_dataset, test_dataset = random_split(\n", - " full_dataset,\n", - " [num_training, num_validation, num_test],\n", - " generator=torch.Generator().manual_seed(23061912),\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "2f4f148f", - "metadata": { - "lines_to_next_cell": 2 - }, - "source": [ - "### Creating a Balanced Dataloader\n", - "\n", - "Below define a `sampler` that samples images of classes with skewed probabilities to account for the different number of items in each class.\n", - "\n", - "The sampler\n", - "- Counts the number of samples in each class\n", - "- Gets the weight-per-label as an inverse of the frequency\n", - "- Get the weight-per-sample\n", - "- Create a `WeightedRandomSampler` based on these weights" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "faa2b411", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "# compute class weights in training dataset for balanced sampling\n", - "def balanced_sampler(dataset):\n", - " # Get a list of targets from the dataset\n", - " if isinstance(dataset, torch.utils.data.Subset):\n", - " # A Subset is a specific type of dataset, which does not directly have access to the targets.\n", - " targets = torch.tensor(dataset.dataset.targets)[dataset.indices]\n", - " else:\n", - " targets = dataset.targets\n", - "\n", - " counts = torch.bincount(targets) # Count the number of samples for each class\n", - " label_weights = (\n", - " 1.0 / counts\n", - " ) # Get the weight-per-label as an inverse of the frequency\n", - " weights = label_weights[targets] # Get the weight-per-sample\n", - "\n", - " # Optional: Print the Counts and Weights to make sure lower frequency classes have higher weights\n", - " print(\"Number of images per class:\")\n", - " for c, n, w in zip(full_dataset.classes, counts, label_weights):\n", - " print(f\"\\t{c}:\\tn={n}\\tweight={w}\")\n", - "\n", - " sampler = WeightedRandomSampler(\n", - " weights, len(weights)\n", - " ) # Create a sampler based on these weights\n", - " return sampler\n", + "# Part 1: Setup\n", "\n", - "\n", - "sampler = balanced_sampler(train_dataset)" - ] - }, - { - "cell_type": "markdown", - "id": "ceb0525e", - "metadata": {}, - "source": [ - "We make a `torch` `DataLoader` that takes our `sampler` to create batches of eight images and their corresponding labels.\n", - "Each image should be randomly and equally selected from the six available classes (i.e., for each image sample pick a random class, then pick a random image from this class)." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "a15b4bac", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "# this data loader will serve 8 images in a \"mini-batch\" at a time\n", - "dataloader = DataLoader(train_dataset, batch_size=8, drop_last=True, sampler=sampler)" - ] - }, - { - "cell_type": "markdown", - "id": "5892ab7f", - "metadata": {}, - "source": [ - "The cell below visualizes a single, randomly chosen batch from the training data loader. Feel free to execute this cell multiple times to get a feeling for the dataset and that your sampler gives batches of evenly distributed synapse types." + "In this part of the notebook, we will load the same dataset as in the previous exercise.\n", + "We will also learn to load one of our trained classifiers from a checkpoint." ] }, { "cell_type": "code", "execution_count": null, - "id": "5aab255a", + "id": "9d26a8bb", "metadata": { - "lines_to_next_cell": 2, - "tags": [] + "lines_to_next_cell": 0 }, "outputs": [], "source": [ - "%matplotlib inline\n", - "from matplotlib import pyplot as plt\n", - "\n", + "# loading the data\n", + "from classifier.data import ColoredMNIST\n", "\n", - "def show_batch(x, y):\n", - " fig, axs = plt.subplots(1, x.shape[0], figsize=(14, 14), sharey=True)\n", - " for i in range(x.shape[0]):\n", - " axs[i].imshow(np.squeeze(x[i]), cmap=\"gray\", vmin=-1, vmax=1)\n", - " axs[i].set_title(train_dataset.dataset.classes[y[i].item()])\n", - " axs[i].axis(\"off\")\n", - " plt.show()\n", - "\n", - "\n", - "# show a random batch from the data loader\n", - "# (run this cell repeatedly to see different batches)\n", - "for x, y in dataloader:\n", - " show_batch(x, y)\n", - " break" + "mnist = ColoredMNIST(\"extras/data\", download=True)" ] }, { "cell_type": "markdown", - "id": "025648fb", + "id": "f8a5937c", "metadata": { - "lines_to_next_cell": 2 + "lines_to_next_cell": 0 }, "source": [ - "### Creating a VGG Network, Loss\n", - "\n", - "We will use a VGG network to classify the synapse images. The input to the network will be a 2D image as provided by your dataloader. The output will be a vector of six floats, corresponding to the probability of the input to belong to the six classes.\n", + "Some information about the dataset:\n", + "- The dataset is a colored version of the MNIST dataset.\n", + "- Instead of using the digits as classes, we use the colors.\n", + "- There are four classes - the goal of the exercise is to find out what these are.\n", "\n", - "We have implemented a VGG network below.\n", - "" + "Let's plot some examples" ] }, { "cell_type": "code", "execution_count": null, - "id": "e7e2b968", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "class Vgg2D(torch.nn.Module):\n", - " def __init__(\n", - " self,\n", - " input_size,\n", - " fmaps=12,\n", - " downsample_factors=[(2, 2), (2, 2), (2, 2), (2, 2)],\n", - " output_classes=6,\n", - " ):\n", - " super(Vgg2D, self).__init__()\n", - "\n", - " self.input_size = input_size\n", - "\n", - " current_fmaps, h, w = tuple(input_size)\n", - " current_size = (h, w)\n", - "\n", - " features = []\n", - " for i in range(len(downsample_factors)):\n", - " features += [\n", - " torch.nn.Conv2d(current_fmaps, fmaps, kernel_size=3, padding=1),\n", - " torch.nn.BatchNorm2d(fmaps),\n", - " torch.nn.ReLU(inplace=True),\n", - " torch.nn.Conv2d(fmaps, fmaps, kernel_size=3, padding=1),\n", - " torch.nn.BatchNorm2d(fmaps),\n", - " torch.nn.ReLU(inplace=True),\n", - " torch.nn.MaxPool2d(downsample_factors[i]),\n", - " ]\n", - "\n", - " current_fmaps = fmaps\n", - " fmaps *= 2\n", - "\n", - " size = tuple(\n", - " int(c / d) for c, d in zip(current_size, downsample_factors[i])\n", - " )\n", - " check = (\n", - " s * d == c for s, d, c in zip(size, downsample_factors[i], current_size)\n", - " )\n", - " assert all(check), \"Can not downsample %s by chosen downsample factor\" % (\n", - " current_size,\n", - " )\n", - " current_size = size\n", - "\n", - " self.features = torch.nn.Sequential(*features)\n", - "\n", - " classifier = [\n", - " torch.nn.Linear(current_size[0] * current_size[1] * current_fmaps, 4096),\n", - " torch.nn.ReLU(inplace=True),\n", - " torch.nn.Dropout(),\n", - " torch.nn.Linear(4096, 4096),\n", - " torch.nn.ReLU(inplace=True),\n", - " torch.nn.Dropout(),\n", - " torch.nn.Linear(4096, output_classes),\n", - " ]\n", - "\n", - " self.classifier = torch.nn.Sequential(*classifier)\n", - "\n", - " def forward(self, raw):\n", - " # compute features\n", - " f = self.features(raw)\n", - " f = f.view(f.size(0), -1)\n", - "\n", - " # classify\n", - " y = self.classifier(f)\n", - "\n", - " return y" - ] - }, - { - "cell_type": "markdown", - "id": "c544bd0d", + "id": "9c0ce960", "metadata": {}, - "source": [ - "We'll start by creating the VGG with the default parameters and push it to a GPU if there is one available. Then we'll define the training loss and optimizer.\n", - "The training and evaluation loops have been defined for you, so after that just train your network!" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "4c6fca99", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "# get the size of our images\n", - "for x, y in train_dataset:\n", - " input_size = x.shape\n", - " break\n", - "\n", - "# create the model to train\n", - "model = Vgg2D(input_size)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "4929dd7f", - "metadata": { - "tags": [] - }, "outputs": [], "source": [ - "# use a GPU, if it is available\n", - "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", - "model.to(device)\n", - "print(f\"Will use device {device} for training\")" - ] - }, - { - "cell_type": "markdown", - "id": "73e2d8ad", - "metadata": {}, - "source": [ - "

Task 1.1: Train the VGG Network

\n", + "import matplotlib.pyplot as plt\n", "\n", - "- Choose a loss\n", - "- Create an Adam optimizer and set its learning rate\n", - "
" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "4c29af1d", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "loss = ...\n", - "optimizer = ..." + "# Show some examples\n", + "fig, axs = plt.subplots(4, 4, figsize=(8, 8))\n", + "for i, ax in enumerate(axs.flatten()):\n", + " x, y = mnist[i]\n", + " x = x.permute((1, 2, 0)) # make channels last\n", + " ax.imshow(x)\n", + " ax.set_title(f\"Class {y}\")\n", + " ax.axis(\"off\")" ] }, { "cell_type": "markdown", - "id": "6fb96afe", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, - "source": [ - "The next cell defines some convenience functions for training, validation, and testing:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "c1f21c05", + "id": "0cb834e5", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] + "lines_to_next_cell": 0 }, - "outputs": [], "source": [ - "from tqdm import tqdm\n", - "\n", - "\n", - "def train(dataloader):\n", - " \"\"\"Train the model for one epoch.\"\"\"\n", - "\n", - " # set the model into train mode\n", - " model.train()\n", - "\n", - " epoch_loss = 0\n", - "\n", - " num_batches = 0\n", - " for x, y in tqdm(dataloader, \"train\"):\n", - " x, y = x.to(device), y.to(device)\n", - " optimizer.zero_grad()\n", - "\n", - " y_pred = model(x)\n", - " l = loss(y_pred, y)\n", - " l.backward()\n", - "\n", - " optimizer.step()\n", - "\n", - " epoch_loss += l\n", - " num_batches += 1\n", - "\n", - " return epoch_loss / num_batches" + "We have pre-traiend a classifier for you on this dataset. It is the same architecture classifier as you used in the Failure Modes exercise: a `DenseModel`.\n", + "Let's load that classifier now!" ] }, { "cell_type": "markdown", - "id": "9c473df0", + "id": "a32035d7", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] + "lines_to_next_cell": 0 }, "source": [ - "

Task 1.2: Create a prediction function

\n", - "\n", - "To understand the performance of the classifier, we need to run predictions on the validation dataset so that we can get accuracy during training, and eventually a confusiom natrix. In practice, this will allow us to stop before we overfit, although in this exercise we will probably not be training that long. Then, later, we can use the same prediction function on test data.\n", - "\n", - "\n", - "TODO\n", - "Modify `predict` so that it returns a paired list of predicted class vs ground truth to produce a confusion matrix. You'll need to do the following steps.\n", - "- Get the model output for the batch of data `(x, y)`\n", - "- Turn the model output into a probability\n", - "- Get the class predictions from the probabilities\n", - "- Add the class predictions to a list of all predictions\n", - "- Add the ground truths to a list of all ground truths\n", + "

Task 1.1: Load the classifier

\n", + "We have written a slightly more general version of the `DenseModel` that you used in the previous exercise. Ours requires two inputs:\n", + "- `input_shape`: the shape of the input images, as a tuple\n", + "- `num_classes`: the number of classes in the dataset\n", "\n", - "
\n" + "Create a dense model with the right inputs and load the weights from the checkpoint.\n", + "
" ] }, { "cell_type": "code", "execution_count": null, - "id": "cae63f62", + "id": "47684cce", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] + "lines_to_next_cell": 0, + "tags": [ + "task" + ] }, "outputs": [], "source": [ - "# TODO: return a paired list of predicted class vs ground-truth to produce a confusion matrix\n", - "from tqdm import tqdm\n", - "from sklearn.metrics import accuracy_score\n", - "\n", - "\n", - "def predict(dataset, name, batch_size=32):\n", - " # These data laoders serve images in a \"mini-batch\"\n", - " dataloader = DataLoader(dataset, batch_size=batch_size, drop_last=False)\n", - " #\n", - " ground_truths = []\n", - " predictions = []\n", - " for x, y in tqdm(dataloader, name):\n", - " x, y = x.to(device), y.to(device)\n", - "\n", - " # Get the model output\n", - " # Turn the model output into a probability\n", - " # Get the class predictions from the probabilities\n", + "import torch\n", + "from classifier.model import DenseModel\n", "\n", - " predictions.extend(...) # TODO add predictions to the list\n", - " ground_truths.extend(...) # TODO add ground truths to the list\n", - " return np.array(predictions), np.array(ground_truths)\n", + "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "\n", + "# TODO Load the model with the correct input shape\n", + "model = DenseModel(input_shape=(...), num_classes=4)\n", "\n", - "prediction, ground_truth = predict(test_dataset, \"Test\")\n", - "print(\"Current test accuracy of the network\", accuracy_score(ground_truth, prediction))" + "# TODO modify this with the location of your classifier checkpoint\n", + "checkpoint = torch.load(...)\n", + "model.load_state_dict(checkpoint)\n", + "model = model.to(device)" ] }, { "cell_type": "markdown", - "id": "bfee4910", + "id": "6ecddeb8", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] + "lines_to_next_cell": 0 }, "source": [ - "We are ready to train. After each epoch (roughly going through each training image once), we report the training loss and the validation accuracy." + "Don't take my word for it! Let's see how well the classifier does on the test set." ] }, { "cell_type": "code", "execution_count": null, - "id": "41bc31bd", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, - "outputs": [], - "source": [ - "for epoch in range(3):\n", - " epoch_loss = train(dataloader)\n", - " print(f\"Epoch {epoch}, training loss={epoch_loss}\")\n", - "\n", - " predictions, gt = predict(validation_dataset, \"Validation\")\n", - " accuracy = accuracy_score(gt, predictions)\n", - " print(f\"Epoch {epoch}, validation accuracy={accuracy}\")" - ] - }, - { - "cell_type": "markdown", - "id": "cc91973f", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, - "source": [ - "Let's watch your model train!\n", - "\n", - "\"drawing\"" - ] - }, - { - "cell_type": "markdown", - "id": "7324a440", + "id": "c271ecd9", "metadata": {}, - "source": [ - "And now, let's test it!" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "ef0770ee", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, - "outputs": [], - "source": [ - "predictions, ground_truths = predict(test_dataset, \"Test\")\n", - "accuracy = accuracy_score(ground_truths, predictions)\n", - "print(f\"Final test accuracy: {accuracy}\")" - ] - }, - { - "cell_type": "markdown", - "id": "57241755", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, - "source": [ - "If you're unhappy with the accuracy above (which you should be...) we pre-trained a model for you for many more epochs. You can load it with the next cell." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "953cad3a", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, "outputs": [], "source": [ - "# TODO Run this cell if you want a shortcut\n", - "yes_I_want_the_pretrained_model = True\n", - "\n", - "if yes_I_want_the_pretrained_model:\n", - " checkpoint = torch.load(\n", - " \"checkpoints/synapses/classifier/vgg_checkpoint\", map_location=device\n", - " )\n", - " model.load_state_dict(checkpoint[\"model_state_dict\"])\n", - "\n", - "\n", - "# And check the (hopefully much better) accuracy\n", - "predictions, ground_truths = predict(test_dataset, \"Test\")\n", - "accuracy = accuracy_score(ground_truths, predictions)\n", - "print(f\"Final_final_v2_last_one test accuracy: {accuracy}\")" - ] - }, - { - "cell_type": "markdown", - "id": "45d26644", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, - "source": [ - "### Constructing a confusion matrix\n", - "\n", - "We now have a classifier that can discriminate between images of different types. If you used the images we provided, the classifier is not perfect (you should get an accuracy of around 80%), but pretty good considering that there are six different types of images.\n", - "\n", - "To understand the performance of the classifier beyond a single accuracy number, we should build a confusion matrix that can more elucidate which classes are more/less misclassified and which classes are those wrong predictions confused with.\n", - "
\n" - ] - }, - { - "cell_type": "markdown", - "id": "39ae027f", - "metadata": {}, - "source": [ - "Let's plot the confusion matrix." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "bc315793", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, - "outputs": [], - "source": [ - "import pandas as pd\n", + "from torch.utils.data import DataLoader\n", "from sklearn.metrics import confusion_matrix\n", "import seaborn as sns\n", - "import numpy as np\n", - "\n", - "\n", - "# Plot confusion matrix\n", - "# orginally from Runqi Yang;\n", - "# see https://gist.github.com/hitvoice/36cf44689065ca9b927431546381a3f7\n", - "def cm_analysis(y_true, y_pred, names, labels=None, title=None, figsize=(10, 8)):\n", - " \"\"\"\n", - " Generate matrix plot of confusion matrix with pretty annotations.\n", - "\n", - " Parameters\n", - " ----------\n", - " confusion_matrix: np.array\n", - " labels: list\n", - " List of integer values to determine which classes to consider.\n", - " names: string array, name the order of class labels in the confusion matrix.\n", - " use `clf.classes_` if using scikit-learn models.\n", - " with shape (nclass,).\n", - " ymap: dict: any -> string, length == nclass.\n", - " if not None, map the labels & ys to more understandable strings.\n", - " Caution: original y_true, y_pred and labels must align.\n", - " figsize: the size of the figure plotted.\n", - " \"\"\"\n", - " if labels is not None:\n", - " assert len(names) == len(labels)\n", - " cm = confusion_matrix(y_true, y_pred, labels=labels)\n", - " cm_sum = np.sum(cm, axis=1, keepdims=True)\n", - " cm_perc = cm / cm_sum.astype(float) * 100\n", - " annot = np.empty_like(cm).astype(str)\n", - " nrows, ncols = cm.shape\n", - " for i in range(nrows):\n", - " for j in range(ncols):\n", - " c = cm[i, j]\n", - " p = cm_perc[i, j]\n", - " if i == j:\n", - " s = cm_sum[i]\n", - " annot[i, j] = \"%.1f%%\\n%d/%d\" % (p, c, s)\n", - " elif c == 0:\n", - " annot[i, j] = \"\"\n", - " else:\n", - " annot[i, j] = \"%.1f%%\\n%d\" % (p, c)\n", - " fig, ax = plt.subplots(figsize=figsize)\n", - " ax = sns.heatmap(\n", - " cm_perc, annot=annot, fmt=\"\", vmax=100, xticklabels=names, yticklabels=names\n", - " )\n", - " ax.set_xlabel(\"Predicted\")\n", - " ax.set_ylabel(\"True\")\n", - " if title:\n", - " ax.set_title(title)\n", - "\n", - "\n", - "names = [\"gaba\", \"acetylcholine\", \"glutamate\", \"serotonine\", \"octopamine\", \"dopamine\"]\n", - "cm_analysis(predictions, ground_truths, names=names)" - ] - }, - { - "cell_type": "markdown", - "id": "3c8cf7bb", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, - "source": [ - "
\n", - "

Questions

\n", - "\n", - "- What observations can we make from the confusion matrix?\n", - "- Does the classifier do better on some synapse classes than other?\n", - "- If you have time later, which ideas would you try to train a better predictor?\n", - "\n", - "Let us know your thoughts on the course chat.\n", - "
" - ] - }, - { - "cell_type": "markdown", - "id": "ce4ccb36", - "metadata": {}, - "source": [ - "

Checkpoint 1

\n", "\n", - "We now have:\n", - "- A classifier that is pretty good at predicting neurotransmitters from EM images.\n", + "test_mnist = ColoredMNIST(\"extras/data\", download=True, train=False)\n", + "dataloader = DataLoader(test_mnist, batch_size=32, shuffle=False)\n", "\n", - "This is surprising, since we could not (yet) have made these predictions manually! If you're skeptical, feel free to explore the data a bit more and see for yourself if you can tell the difference betwee, say, GABAergic and glutamatergic synapses.\n", - "\n", - "So this is an interesting situation: The VGG network knows something we don't quite know. In the next section, we will see how we can find and then visualize the relevant differences between images of different types.\n", + "labels = []\n", + "predictions = []\n", + "for x, y in dataloader:\n", + " pred = model(x.to(device))\n", + " labels.extend(y.cpu().numpy())\n", + " predictions.extend(pred.argmax(dim=1).cpu().numpy())\n", "\n", - "This concludes the first section. Let us know on the exercise chat if you have arrived here.\n", - "
" + "cm = confusion_matrix(labels, predictions, normalize=\"true\")\n", + "sns.heatmap(cm, annot=True, fmt=\".2f\")\n", + "plt.ylabel(\"True\")\n", + "plt.xlabel(\"Predicted\")\n", + "plt.show()" ] }, { "cell_type": "markdown", - "id": "be1f14b2", + "id": "46a684f4", "metadata": {}, "source": [ - "# Part 2: Masking the relevant part of the image\n", + "# Part 2: Using Integrated Gradients to find what the classifier knows\n", "\n", - "In this section we will make a first attempt at highlight differences between the \"real\" and \"fake\" images that are most important to change the decision of the classifier.\n" + "In this section we will make a first attempt at highlighting differences between the \"real\" and \"fake\" images that are most important to change the decision of the classifier.\n" ] }, { "cell_type": "markdown", - "id": "41464574", + "id": "0255c073", "metadata": {}, "source": [ "## Attributions through integrated gradients\n", "\n", - "Attribution is the process of finding out, based on the output of a neural network, which pixels in the input are (most) responsible. Another way of thinking about it is: which pixels would need to change in order for the network's output to change.\n", + "Attribution is the process of finding out, based on the output of a neural network, which pixels in the input are (most) responsible for the output. Another way of thinking about it is: which pixels would need to change in order for the network's output to change.\n", "\n", "Here we will look at an example of an attribution method called [Integrated Gradients](https://captum.ai/docs/extension/integrated_gradients). If you have a bit of time, have a look at this [super fun exploration of attribution methods](https://distill.pub/2020/attribution-baselines/), especially the explanations on Integrated Gradients." ] @@ -801,29 +216,26 @@ { "cell_type": "code", "execution_count": null, - "id": "af08ae72", + "id": "e5b162b7", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, "tags": [] }, "outputs": [], "source": [ - "x, y = next(iter(dataloader))\n", + "batch_size = 4\n", + "batch = []\n", + "for i in range(4):\n", + " batch.append(next(image for image in mnist if image[1] == i))\n", + "x = torch.stack([b[0] for b in batch])\n", + "y = torch.tensor([b[1] for b in batch])\n", "x = x.to(device)\n", "y = y.to(device)" ] }, { "cell_type": "markdown", - "id": "9fbf1572", + "id": "6d418ea1", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, "tags": [] }, "source": [ @@ -838,13 +250,11 @@ { "cell_type": "code", "execution_count": null, - "id": "897dd327", + "id": "5ce086ee", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] + "tags": [ + "task" + ] }, "outputs": [], "source": [ @@ -861,12 +271,8 @@ { "cell_type": "code", "execution_count": null, - "id": "31fa10dc", + "id": "e4ba6b3a", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, "tags": [] }, "outputs": [], @@ -878,12 +284,9 @@ }, { "cell_type": "markdown", - "id": "657bf893", + "id": "56e432ae", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, + "lines_to_next_cell": 2, "tags": [] }, "source": [ @@ -893,34 +296,27 @@ { "cell_type": "code", "execution_count": null, - "id": "7c4faa92", + "id": "9561d46f", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, "tags": [] }, "outputs": [], "source": [ "from captum.attr import visualization as viz\n", - "\n", - "\n", - "def unnormalize(image):\n", - " return 0.5 * image + 0.5\n", + "import numpy as np\n", "\n", "\n", "def visualize_attribution(attribution, original_image):\n", " attribution = np.transpose(attribution, (1, 2, 0))\n", - " original_image = np.transpose(unnormalize(original_image), (1, 2, 0))\n", + " original_image = np.transpose(original_image, (1, 2, 0))\n", "\n", " viz.visualize_image_attr_multiple(\n", " attribution,\n", " original_image,\n", - " methods=[\"blended_heat_map\", \"heat_map\"],\n", - " signs=[\"absolute_value\", \"absolute_value\"],\n", + " methods=[\"original_image\", \"heat_map\"],\n", + " signs=[\"all\", \"absolute_value\"],\n", " show_colorbar=True,\n", - " titles=[\"Original and Attribution\", \"Attribution\"],\n", + " titles=[\"Image\", \"Attribution\"],\n", " use_pyplot=True,\n", " )" ] @@ -928,176 +324,90 @@ { "cell_type": "code", "execution_count": null, - "id": "4d050712", + "id": "a55fe8ec", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, "tags": [] }, "outputs": [], "source": [ - "for attr, im in zip(attributions, x.cpu().numpy()):\n", + "for attr, im, lbl in zip(attributions, x.cpu().numpy(), y.cpu().numpy()):\n", + " print(f\"Class {lbl}\")\n", " visualize_attribution(attr, im)" ] }, { "cell_type": "markdown", - "id": "2bd418b1", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, - "source": [ - "### Smoothing the attribution into a mask\n", - "\n", - "The attributions that we see are grainy and difficult to interpret because they are a pixel-wise attribution. We apply some smoothing and thresholding on the attributions so that they represent region masks rather than pixel masks. The following code is runnable with no modification." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "55715f0e", + "id": "1d8c03a0", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] + "lines_to_next_cell": 2 }, - "outputs": [], "source": [ - "import cv2\n", - "import copy\n", "\n", - "\n", - "def smooth_attribution(attrs, struc=10, sigma=11):\n", - " # Morphological closing and Gaussian Blur\n", - " kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (struc, struc))\n", - " mask = cv2.morphologyEx(attrs[0], cv2.MORPH_CLOSE, kernel)\n", - " mask_cp = copy.deepcopy(mask)\n", - " mask_weight = cv2.GaussianBlur(mask_cp.astype(float), (sigma, sigma), 0)\n", - " return mask_weight[np.newaxis]\n", - "\n", - "\n", - "def get_mask(attrs, threshold=0.5):\n", - " smoothed = smooth_attribution(attrs)\n", - " return smoothed > (threshold * smoothed.max())\n", - "\n", - "\n", - "def interactive_attribution(idx=0):\n", - " image = x[idx].cpu().numpy()\n", - " attrs = attributions[idx]\n", - " mask = smooth_attribution(attrs)\n", - " visualize_attribution(mask, image)\n", - " return" + "The attributions are shown as a heatmap. The brighter the pixel, the more important this attribution method thinks that it is.\n", + "As you can see, it is pretty good at recognizing the number within the image.\n", + "As we know, however, it is not the digit itself that is important for the classification, it is the color!\n", + "Although the method is picking up really well on the region of interest, it would be difficult to conclude from this that it is the color that matters." ] }, { "cell_type": "markdown", - "id": "33598839", + "id": "2a24c70a", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] + "lines_to_next_cell": 0 }, "source": [ - "

Task 2.2 Visualizing the results

\n", - "\n", - "The code above creates a small widget to interact with the results of this analysis. Look through the samples for a while before answering the questions below.\n", - "
" + "Something is slightly unfair about this visualization though.\n", + "We are visualizing as if it were grayscale, but both our images and our attributions are in color!\n", + "Can we learn more from the attributions if we visualize them in color?" ] }, { "cell_type": "code", "execution_count": null, - "id": "490db899", - "metadata": { - "editable": true, - "lines_to_next_cell": 2, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, + "id": "6e875faa", + "metadata": {}, "outputs": [], "source": [ - "from ipywidgets import interact\n", + "def visualize_color_attribution(attribution, original_image):\n", + " attribution = np.transpose(attribution, (1, 2, 0))\n", + " original_image = np.transpose(original_image, (1, 2, 0))\n", + "\n", + " fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))\n", + " ax1.imshow(original_image)\n", + " ax1.set_title(\"Image\")\n", + " ax1.axis(\"off\")\n", + " ax2.imshow(np.abs(attribution))\n", + " ax2.set_title(\"Attribution\")\n", + " ax2.axis(\"off\")\n", + " plt.show()\n", "\n", - "interact(\n", - " interactive_attribution,\n", - " idx=(0, dataloader.batch_size - 1),\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "18dce2c2", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, - "source": [ - "HELP! I Can't see any interactive setup!!\n", "\n", - "I got you... just uncomment the next cell and run it to see all of the samples at once." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "eda303d1", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, - "outputs": [], - "source": [ - "# HELP! I Can't see any interative setup!!!\n", - "# for attr, im in zip(attributions, x.cpu().numpy()):\n", - "# visualize_attribution(smooth_attribution(attr), im)" + "for attr, im, lbl in zip(attributions, x.cpu().numpy(), y.cpu().numpy()):\n", + " print(f\"Class {lbl}\")\n", + " visualize_color_attribution(attr, im)" ] }, { "cell_type": "markdown", - "id": "09cc4c08", + "id": "3f73608f", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] + "lines_to_next_cell": 0 }, "source": [ - "
\n", - "

Questions

\n", + "We get some better clues when looking at the attributions in color.\n", + "The highlighting doesn't just happen in the region with number, but also seems to hapen in a channel that matches the color of the image.\n", + "Just based on this, however, we don't get much more information than we got from the images themselves.\n", "\n", - "- Are there some recognisable objects or parts of the synapse that show up in several examples?\n", - "- Are there some objects that seem secondary because they are less strongly highlighted?\n", - "\n", - "Tell us what you see on the chat!\n", - "
" + "If we didn't know in advance, it is unclear whether the color or the number is the most important feature for the classifier." ] }, { "cell_type": "markdown", - "id": "bd34722b", + "id": "a8e71c0b", "metadata": {}, "source": [ "\n", - "### Changing the basline\n", + "### Changing the baseline\n", "\n", "Many existing attribution algorithms are comparative: they show which pixels of the input are responsible for a network output *compared to a baseline*.\n", "The baseline is often set to an all 0 tensor, but the choice of the baseline affects the output.\n", @@ -1111,7 +421,7 @@ "```\n", "To get more details about how to include the baseline.\n", "\n", - "Try using the code above to change the baseline and see how this affects the output.\n", + "Try using the code below to change the baseline and see how this affects the output.\n", "\n", "1. Random noise as a baseline\n", "2. A blurred/noisy version of the original image as a baseline." @@ -1119,7 +429,7 @@ }, { "cell_type": "markdown", - "id": "53feb16f", + "id": "dbb04b6f", "metadata": {}, "source": [ "

Task 2.3: Use random noise as a baseline

\n", @@ -1131,13 +441,11 @@ { "cell_type": "code", "execution_count": null, - "id": "9d6c65e1", + "id": "2fc8f45c", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] + "tags": [ + "task" + ] }, "outputs": [], "source": [ @@ -1147,18 +455,15 @@ "attributions_random = integrated_gradients.attribute(...) # TODO Change\n", "\n", "# Plotting\n", - "for attr, im in zip(attributions_random.cpu().numpy(), x.cpu().numpy()):\n", + "for attr, im, lbl in zip(attributions, x.cpu().numpy(), y.cpu().numpy()):\n", + " print(f\"Class {lbl}\")\n", " visualize_attribution(attr, im)" ] }, { "cell_type": "markdown", - "id": "e97700bc", + "id": "bf7e934c", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, "tags": [] }, "source": [ @@ -1171,13 +476,11 @@ { "cell_type": "code", "execution_count": null, - "id": "b9e5b23e", + "id": "2e14f754", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] + "tags": [ + "task" + ] }, "outputs": [], "source": [ @@ -1189,38 +492,35 @@ "attributions_blurred = integrated_gradients.attribute(...) # TODO Fill\n", "\n", "# Plotting\n", - "for attr, im in zip(attributions_blurred.cpu().numpy(), x.cpu().numpy()):\n", - " visualize_attribution(attr, im)" + "for attr, im, lbl in zip(attributions, x.cpu().numpy(), y.cpu().numpy()):\n", + " print(f\"Class {lbl}\")\n", + " visualize_color_attribution(attr, im)" ] }, { "cell_type": "markdown", - "id": "5cdde305", + "id": "db46361b", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, "tags": [] }, "source": [ "

Questions

\n", - "\n", - "- Are any of the features consistent across baselines? Why do you think that is?\n", - "- What baseline do you like best so far? Why?\n", - "- If you were to design an ideal baseline, what would you choose?\n", + "
    \n", + "
  • What baseline do you like best so far? Why?
  • \n", + "
  • Why do you think some baselines work better than others?
  • \n", + "
  • If you were to design an ideal baseline, what would you choose?
  • \n", + "
\n", "
" ] }, { "cell_type": "markdown", - "id": "1a15cf83", + "id": "e9105812", "metadata": {}, "source": [ "

BONUS Task: Using different attributions.

\n", "\n", "\n", - "\n", "[`captum`](https://captum.ai/tutorials/Resnet_TorchVision_Interpret) has access to various different attribution algorithms.\n", "\n", "Replace `IntegratedGradients` with different attribution methods. Are they consistent with each other?\n", @@ -1229,7 +529,7 @@ }, { "cell_type": "markdown", - "id": "9bb8d816", + "id": "0b2d0f2f", "metadata": {}, "source": [ "

Checkpoint 2

\n", @@ -1237,28 +537,27 @@ "\n", "At this point we have:\n", "\n", - "- Trained a classifier that can predict neurotransmitters from EM-slices of synapses.\n", - "- Found a way to mask the parts of the image that seem to be relevant for the classification, using integrated gradients.\n", + "- Loaded a classifier that classifies MNIST-like images by color, but we don't know how!\n", + "- Tried applying Integrated Gradients to find out what the classifier is looking at - with little success.\n", "- Discovered the effect of changing the baseline on the output of integrated gradients.\n", "\n", + "Coming up in the next section, we will learn how to create counterfactual images.\n", + "These images will change *only what is necessary* in order to change the classification of the image.\n", + "We'll see that using counterfactuals we will be able to disambiguate between color and number as an important feature.\n", "
" ] }, { "cell_type": "markdown", - "id": "a31ef8d6", + "id": "531169e5", "metadata": { - "editable": true, - "lines_to_next_cell": 2, - "slideshow": { - "slide_type": "" - }, - "tags": [] + "lines_to_next_cell": 0 }, "source": [ "# Part 3: Train a GAN to Translate Images\n", "\n", - "To gain insight into how the trained network classify images, we will use [Discriminative Attribution from Counterfactuals](https://arxiv.org/abs/2109.13412), a feature attribution with counterfactual explanations methodology. This method employs a CycleGAN to translate images from one class to another to make counterfactual explanations.\n", + "To gain insight into how the trained network classifies images, we will use [Discriminative Attribution from Counterfactuals](https://arxiv.org/abs/2109.13412), a feature attribution with counterfactual explanations methodology.\n", + "This method employs a StarGAN to translate images from one class to another to make counterfactual explanations.\n", "\n", "**What is a counterfactual?**\n", "\n", @@ -1273,1640 +572,1141 @@ "\n", "**Counterfactual synapses**\n", "\n", - "In this example, we will train a CycleGAN network that translates GABAergic synapses to acetylcholine synapses (you can also train other pairs too by changing the classes below)." + "In this example, we will train a StarGAN network that is able to take any of our special MNIST images and change its class." ] }, { - "cell_type": "code", - "execution_count": null, - "id": "9089850c", + "cell_type": "markdown", + "id": "331e56d6", "metadata": { - "editable": true, - "lines_to_next_cell": 2, - "slideshow": { - "slide_type": "" - }, + "lines_to_next_cell": 0, "tags": [] }, - "outputs": [], "source": [ - "def class_dir(name):\n", - " return f\"{class_to_idx[name]}_{name}\"\n", + "### The model\n", + "![stargan.png](assets/stargan.png)\n", + "\n", + "In the following, we create a [StarGAN model](https://arxiv.org/abs/1711.09020).\n", + "It is a Generative Adversarial model that is trained to turn one class of images X into a different class of images Y.\n", "\n", + "We will not be using the random latent code (green, in the figure), so the model we use is made up of three networks:\n", + "- The generator - this will be the bulk of the model, and will be responsible for transforming the images: we're going to use a `UNet`\n", + "- The discriminator - this will be responsible for telling the difference between real and fake images: we're going to use a `DenseModel`\n", + "- The style encoder - this will be responsible for encoding the style of the image: we're going to use a `DenseModel`\n", "\n", - "classes = [\"gaba\", \"acetylcholine\"]" + "Let's start by creating these!" ] }, { - "cell_type": "markdown", - "id": "36b89586", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, + "cell_type": "code", + "execution_count": null, + "id": "301ee289", + "metadata": {}, + "outputs": [], "source": [ - "## Training a GAN\n", + "from dlmbl_unet import UNet\n", + "from torch import nn\n", + "\n", "\n", - "Yes, really!" + "class Generator(nn.Module):\n", + "\n", + " def __init__(self, generator, style_encoder):\n", + " super().__init__()\n", + " self.generator = generator\n", + " self.style_encoder = style_encoder\n", + "\n", + " def forward(self, x, y):\n", + " \"\"\"\n", + " x: torch.Tensor\n", + " The source image\n", + " y: torch.Tensor\n", + " The style image\n", + " \"\"\"\n", + " style = self.style_encoder(y)\n", + " # Concatenate the style vector with the input image\n", + " style = style.unsqueeze(-1).unsqueeze(-1)\n", + " style = style.expand(-1, -1, x.size(2), x.size(3))\n", + " x = torch.cat([x, style], dim=1)\n", + " return self.generator(x)" ] }, { "cell_type": "markdown", - "id": "aff1b90b", + "id": "4ce023f6", "metadata": { - "lines_to_next_cell": 2 + "lines_to_next_cell": 0 }, "source": [ - "

Creating a specialized dataset

\n", - "\n", - "The CycleGAN works on only 2 classes at a time, but our full dataset has 6. Below, we will use the `Subset` dataset from `torch.utils.data` to get the data from these two classes.\n", - "\n", - "A `Subset` is created as follows:\n", - "```\n", - "subset = Subset(dataset, indices)\n", - "```\n", - "\n", - "And the chosen indices can be obtained again using `subset.indices`.\n", + "

Task 3.1: Create the models

\n", "\n", - "Run the cell below to generate the datasets:\n", - "- `gan_train_dataset`\n", - "- `gan_test_dataset`\n", - "- `gan_val_dataset`\n", + "We are going to create the models for the generator, discriminator, and style mapping.\n", "\n", - "We will use them below to train the CycleGAN.\n", - "
" + "Given the Generator structure above, fill in the missing parts for the unet and the style mapping." ] }, { "cell_type": "code", "execution_count": null, - "id": "a8981d1e", + "id": "c2698719", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] + "lines_to_next_cell": 0, + "tags": [ + "task" + ] }, "outputs": [], "source": [ - "# Utility functions to get a subset of classes\n", - "def get_indices(dataset, classes):\n", - " \"\"\"Get the indices of elements of classA and classB in the dataset.\"\"\"\n", - " indices = []\n", - " for cl in classes:\n", - " indices.append(torch.tensor(dataset.targets) == class_to_idx[cl])\n", - " logical_or = sum(indices) > 0\n", - " return torch.where(logical_or)[0]\n", - "\n", - "\n", - "def set_intersection(a_indices, b_indices):\n", - " \"\"\"Get intersection of two sets\n", - "\n", - " Parameters\n", - " ----------\n", - " a_indices: torch.Tensor\n", - " b_indices: torch.Tensor\n", - "\n", - " Returns\n", - " -------\n", - " intersection: torch.Tensor\n", - " The elements contained in both a_indices and b_indices.\n", - " \"\"\"\n", - " a_cat_b, counts = torch.cat([a_indices, b_indices]).unique(return_counts=True)\n", - " intersection = a_cat_b[torch.where(counts.gt(1))]\n", - " return intersection\n", - "\n", - "\n", - "# Getting training, testing, and validation indices\n", - "gan_idx = get_indices(full_dataset, classes)\n", - "\n", - "gan_train_idx = set_intersection(torch.tensor(train_dataset.indices), gan_idx)\n", - "gan_test_idx = set_intersection(torch.tensor(test_dataset.indices), gan_idx)\n", - "gan_val_idx = set_intersection(torch.tensor(validation_dataset.indices), gan_idx)\n", - "\n", - "# Checking that the subsets are complete\n", - "assert len(gan_train_idx) + len(gan_test_idx) + len(gan_val_idx) == len(gan_idx)\n", - "\n", - "# Generate three datasets based on the above indices.\n", - "from torch.utils.data import Subset\n", + "style_size = ... # TODO choose a size for the style space\n", + "unet_depth = ... # TODO Choose a depth for the UNet\n", + "style_encoder = DenseModel(\n", + " input_shape=..., num_classes=... # How big is the style space?\n", + ")\n", + "unet = UNet(depth=..., in_channels=..., out_channels=..., final_activation=nn.Sigmoid())\n", "\n", - "gan_train_dataset = Subset(full_dataset, gan_train_idx)\n", - "gan_test_dataset = Subset(full_dataset, gan_test_idx)\n", - "gan_val_dataset = Subset(full_dataset, gan_val_idx)" + "generator = Generator(unet, style_encoder=style_encoder)" ] }, { "cell_type": "markdown", - "id": "479b5de4", + "id": "16f87104", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, "tags": [] }, "source": [ - "### The model\n", - "\n", - "![cycle.png](assets/cyclegan.png)\n", - "\n", - "In the following, we create a [CycleGAN model](https://arxiv.org/pdf/1703.10593.pdf). It is a Generative Adversarial model that is trained to turn one class of images X (for us, GABA) into a different class of images Y (for us, Acetylcholine).\n", - "\n", - "It has two generators:\n", - " - Generator G takes a GABA image and tries to turn it into an image of an Acetylcholine synapse. When given an image that is already showing an Acetylcholine synapse, G should just re-create the same image: these are the `identities`.\n", - " - Generator F takes a Acetylcholine image and tries to turn it into an image of an GABA synapse. When given an image that is already showing a GABA synapse, F should just re-create the same image: these are the `identities`.\n", - "\n", - "\n", - "When in training mode, the CycleGAN will also create a `reconstruction`. These are images that are passed through both generators.\n", - "For example, a GABA image will first be transformed by G to Acetylcholine, then F will turn it back into GABA.\n", - "This is achieved by training the network with a cycle-consistency loss. In our example, this is an L2 loss between the `real` GABA image and the `reconstruction` GABA image.\n", - "\n", - "But how do we force the generators to change the class of the input image? We use a discriminator for each.\n", - " - DX tries to recognize fake GABA images: F will need to create images realistic and GABAergic enough to trick it.\n", - " - DY tries to recognize fake Acetylcholine images: G will need to create images realistic and cholinergic enough to trick it." + "

Hyper-parameter choices

\n", + "
    \n", + "
  • Are any of the hyperparameters you choose above constrained in some way?
  • \n", + "
  • What would happen if you chose a depth of 10 for the UNet?
  • \n", + "
  • Is there a minimum size for the style space? Why or why not?
  • \n", + "
" ] }, { - "cell_type": "code", - "execution_count": null, - "id": "d308b66b", + "cell_type": "markdown", + "id": "9f1d1149", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, + "lines_to_next_cell": 0, "tags": [] }, - "outputs": [], "source": [ - "from torch import nn\n", - "import functools\n", - "from cycle_gan.models.networks import ResnetGenerator, NLayerDiscriminator, GANLoss\n", - "\n", - "\n", - "class CycleGAN(nn.Module):\n", - " \"\"\"Cycle GAN\n", + "

Task 3.2: Create the discriminator

\n", "\n", - " Has:\n", - " - Two class names\n", - " - Two Generators\n", - " - Two Discriminators\n", - " \"\"\"\n", - "\n", - " def __init__(\n", - " self, class1, class2, input_nc=1, output_nc=1, ngf=64, ndf=64, use_dropout=False\n", - " ):\n", - " \"\"\"\n", - " class1: str\n", - " Label of the first class\n", - " class2: str\n", - " Label of the second class\n", - " \"\"\"\n", - " super().__init__()\n", - " norm_layer = functools.partial(\n", - " nn.InstanceNorm2d, affine=False, track_running_stats=False\n", - " )\n", - " self.classes = [class1, class2]\n", - " self.inverse_keys = {\n", - " class1: class2,\n", - " class2: class1,\n", - " } # i.e. what is the other key?\n", - " self.generators = nn.ModuleDict(\n", - " {\n", - " classname: ResnetGenerator(\n", - " input_nc,\n", - " output_nc,\n", - " ngf,\n", - " norm_layer=norm_layer,\n", - " use_dropout=use_dropout,\n", - " n_blocks=9,\n", - " )\n", - " for classname in self.classes\n", - " }\n", - " )\n", - " self.discriminators = nn.ModuleDict(\n", - " {\n", - " classname: NLayerDiscriminator(\n", - " input_nc, ndf, n_layers=3, norm_layer=norm_layer\n", - " )\n", - " for classname in self.classes\n", - " }\n", - " )\n", - "\n", - " def forward(self, x, train=True):\n", - " \"\"\"Creates fakes from the reals.\n", - "\n", - " Parameters\n", - " ----------\n", - " x: dict\n", - " classname -> images\n", - " train: boolean\n", - " If false, only the counterfactuals are generated and returned.\n", - " Defaults to True.\n", - "\n", - " Returns\n", - " -------\n", - " fakes: dict\n", - " classname -> images of counterfactual images\n", - " identities: dict\n", - " classname -> images of images passed through their corresponding generator, if train is True\n", - " For example, images of class1 are passed through the generator that creates class1.\n", - " These should be identical to the input.\n", - " Not returned if `train` is `False`\n", - " reconstructions\n", - " classname -> images of reconstructed images (full cycle), if train is True.\n", - " Not returned if `train` is `False`\n", - " \"\"\"\n", - " fakes = {}\n", - " identities = {}\n", - " reconstructions = {}\n", - " for k, batch in x.items():\n", - " inv_k = self.inverse_keys[k]\n", - " # Counterfactual: class changes\n", - " fakes[inv_k] = self.generators[inv_k](batch)\n", - " if train:\n", - " # From counterfactual back to original, class changes again\n", - " reconstructions[k] = self.generators[k](fakes[inv_k])\n", - " # Identites: class does not change\n", - " identities[k] = self.generators[k](batch)\n", - " if train:\n", - " return fakes, identities, reconstructions\n", - " return fakes\n", - "\n", - " def discriminate(self, x):\n", - " \"\"\"Get discriminator opinion on x\n", - "\n", - " Parameters\n", - " ----------\n", - " x: dict\n", - " classname -> images\n", - " \"\"\"\n", - " discrimination = {}\n", - " for k, batch in x.items():\n", - " discrimination[k] = self.discriminators[k](batch)\n", - " return discrimination" + "We want the discriminator to be like a classifier, so it is able to look at an image and tell not only whether it is real, but also which class it came from.\n", + "The discriminator will take as input either a real image or a fake image.\n", + "Fill in the following code to create a discriminator that can classify the images into the correct number of classes.\n", + "
" ] }, { "cell_type": "code", "execution_count": null, - "id": "09c3fa55", + "id": "14e0c929", "metadata": { - "editable": true, - "lines_to_next_cell": 2, - "slideshow": { - "slide_type": "" - }, - "tags": [] + "lines_to_next_cell": 0, + "tags": [ + "task" + ] }, "outputs": [], "source": [ - "cyclegan = CycleGAN(*classes)\n", - "cyclegan.to(device)\n", - "print(f\"Will use device {device} for training\")" + "discriminator = DenseModel(input_shape=..., num_classes=...)" ] }, { "cell_type": "markdown", - "id": "f91db612", + "id": "231a5202", "metadata": { - "editable": true, - "lines_to_next_cell": 2, - "slideshow": { - "slide_type": "" - }, - "tags": [] + "lines_to_next_cell": 0 }, "source": [ - "You will notice above that the `CycleGAN` takes an input in the form of a dictionary, but our datasets and data-loaders return the data in the form of two tensors. Below are two utility functions that will swap from data from one to the other." + "Let's move all models onto the GPU" ] }, { "cell_type": "code", "execution_count": null, - "id": "b6d5d5ee", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, + "id": "c0a2d54d", + "metadata": {}, "outputs": [], "source": [ - "# Utility function to go to/from dictionaries/x,y tensors\n", - "def get_as_xy(dictionary):\n", - " x = torch.cat([arr for arr in dictionary.values()])\n", - " y = []\n", - " for k, v in dictionary.items():\n", - " val = class_labels[k]\n", - " y += [\n", - " val,\n", - " ] * len(v)\n", - " y = torch.Tensor(y).to(x.device)\n", - " return x, y\n", - "\n", - "\n", - "def get_as_dictionary(x, y):\n", - " dictionary = {}\n", - " for k in classes:\n", - " val = class_to_idx[k]\n", - " # Get all of the indices for this class\n", - " this_class_indices = torch.where(y == val)\n", - " dictionary[k] = x[this_class_indices]\n", - " return dictionary" + "generator = generator.to(device)\n", + "discriminator = discriminator.to(device)" ] }, { "cell_type": "markdown", - "id": "8d48e4af", + "id": "4540ef18", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, + "lines_to_next_cell": 0, "tags": [] }, "source": [ + "## Training a GAN\n", "\n", - "### Creating a training loop\n", - "\n", - "Now that we have a model, our next task is to create a training loop for the CycleGAN. This is a bit more difficult than the training loop for our classifier.\n", - "\n", - "Here are some of the things to keep in mind during the next task.\n", - "\n", - "1. The CycleGAN is (obviously) a GAN: a Generative Adversarial Network. What makes an adversarial network \"adversarial\" is that two different networks are working against each other. The loss that is used to optimize this is in our exercise `criterionGAN`. Although the specifics of this loss is beyond the score of this notebook, the idea is simple: the `criterionGAN` compares the output of the discriminator to a boolean-valued target. If we want the discriminator to think that it has seen a real image, we set the target to `True`. If we want the discriminator to think that it has seen a generated image, we set the target to `False`. Note that it isn't important here whether the image *is* real, but **whether we want the discriminator to think it is real at that point**. (This will be important very soon 😉)\n", + "Training an adversarial network is a bit more complicated than training a classifier.\n", + "For starters, we are simultaneously training two different networks that work against each other.\n", + "As such, we need to be careful about how and when we update the weights of each network.\n", "\n", - "2. Since the two networks are fighting each other, it is important to make sure that neither of them can cheat with information espionage. The CycleGAN implementation below is a turn-by-turn fight: we train the generator(s) and the discriminator(s) in alternating steps. When a model is not training, we will restrict its access to information by using `set_requries_grad` to `False`." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "8482184f", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, - "outputs": [], - "source": [ - "from cycle_gan.util.image_pool import ImagePool" + "We will have two different optimizers, one for the Generator and one for the Discriminator.\n" ] }, { "cell_type": "code", "execution_count": null, - "id": "53c14194", + "id": "b9fc6671", "metadata": { - "editable": true, - "lines_to_next_cell": 2, - "slideshow": { - "slide_type": "" - }, - "tags": [] + "lines_to_next_cell": 0 }, "outputs": [], "source": [ - "criterionIdt = nn.L1Loss()\n", - "criterionCycle = nn.L1Loss()\n", - "criterionGAN = GANLoss(\"lsgan\")\n", - "criterionGAN.to(device)\n", - "\n", - "lambda_idt = 1\n", - "pool_size = 32\n", - "\n", - "lambdas = {k: 1 for k in classes}\n", - "image_pools = {classname: ImagePool(pool_size) for classname in classes}\n", - "\n", - "optimizer_g = torch.optim.Adam(cyclegan.generators.parameters(), lr=1e-4)\n", - "optimizer_d = torch.optim.Adam(cyclegan.discriminators.parameters(), lr=1e-4)" + "optimizer_d = torch.optim.Adam(discriminator.parameters(), lr=1e-5)\n", + "optimizer_g = torch.optim.Adam(generator.parameters(), lr=1e-4)" ] }, { "cell_type": "markdown", - "id": "706a5f18", + "id": "196daf45", "metadata": { - "editable": true, - "lines_to_next_cell": 2, - "slideshow": { - "slide_type": "" - }, + "lines_to_next_cell": 0, "tags": [] }, "source": [ - "

Task 3.1: Set up the training losses and gradients

\n", "\n", - "In the code below, there are several spots with multiple options. Choose from among these, and delete or comment out the incorrect option.\n", - "1. In `generator_step`: Choose whether the target to the`criterionGAN` should be `True` or `False`.\n", - "2. In `discriminator_step`: Choose the target to the `criterionGAN` (note that there are two this time, one for the real images and one for the generated images)\n", - "3. In `train_gan`: `set_requires_grad` correctly.\n", + "There are also two different types of losses that we will need.\n", + "**Adversarial loss**\n", + "This loss describes how well the discriminator can tell the difference between real and generated images.\n", + "In our case, this will be a sort of classification loss - we will use Cross Entropy.\n", + "
\n", + "The adversarial loss will be applied differently to the generator and the discriminator! Be very careful!\n", "
" ] }, { "cell_type": "code", "execution_count": null, - "id": "9d36c59f", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, + "id": "1e9ddd12", + "metadata": {}, "outputs": [], "source": [ - "def set_requires_grad(module, value=True):\n", - " \"\"\"Sets `requires_grad` on a `module`'s parameters to `value`\"\"\"\n", - " for param in module.parameters():\n", - " param.requires_grad = value\n", - "\n", - "\n", - "def generator_step(cyclegan, reals):\n", - " \"\"\"Calculate the loss for generators G_X and G_Y\"\"\"\n", - " # Get all generated images\n", - " fakes, identities, reconstructions = cyclegan(reals)\n", - " # Get discriminator opinion\n", - " discrimination = cyclegan.discriminate(fakes)\n", - " loss = 0\n", - " for k in classes:\n", - " # Identity loss\n", - " # G_A should be identity if real_B is fed: ||G_A(B) - B||\n", - " loss_idt = criterionIdt(identities[k], reals[k]) * lambdas[k] * lambda_idt\n", - "\n", - " # GAN loss D_A(G_A(A))\n", - " #################### TODO Choice 1 #####################\n", - " # OPTION 1\n", - " # loss_G = criterionGAN(discrimination[k], False)\n", - " # OPTION2\n", - " # loss_G = criterionGAN(discrimination[k], True)\n", - " #########################################################\n", - "\n", - " # Forward cycle loss || G_B(G_A(A)) - A||\n", - " loss_cycle = criterionCycle(reconstructions[k], reals[k]) * lambdas[k]\n", - " # combined loss and calculate gradients\n", - " loss += loss_G + loss_cycle + loss_idt\n", - " loss.backward()\n", - "\n", - "\n", - "def discriminator_step(cyclegan, reals):\n", - " \"\"\"Calculate the loss for the discriminators D_X and D_Y\"\"\"\n", - " fakes, identities, reconstructions = cyclegan(reals)\n", - " preds_real = cyclegan.discriminate(reals)\n", - " # Get fakes from pool\n", - " fakes = {k: v.detach() for k, v in fakes.items()}\n", - " preds_fake = cyclegan.discriminate(fakes)\n", - " loss = 0\n", - " for k in classes:\n", - " #################### TODO Choice 2 #####################\n", - " # OPTION 1\n", - " # loss_real = criterionGAN(preds_real[k], True)\n", - " # loss_fake = criterionGAN(preds_fake[k], False)\n", - " # OPTION 2\n", - " # loss_real = criterionGAN(preds_real[k], False)\n", - " # loss_fake = criterionGAN(preds_fake[k], True)\n", - " #########################################################\n", - "\n", - " loss += (loss_real + loss_fake) * 0.5\n", - " loss.backward()\n", - "\n", - "\n", - "def train_gan(reals):\n", - " \"\"\"Optimize the network parameters on a batch of images.\n", - "\n", - " reals: Dict[str, torch.Tensor]\n", - " Classname -> Tensor dictionary of images.\n", - " \"\"\"\n", - " #################### TODO Choice 3 #####################\n", - " # OPTION 1\n", - " # set_requires_grad(cyclegan.generators, True)\n", - " # set_requires_grad(cyclegan.discriminators, False)\n", - " # OPTION 2\n", - " # set_requires_grad(cyclegan.generators, False)\n", - " # set_requires_grad(cyclegan.discriminators, True)\n", - " ##########################################################\n", - "\n", - " optimizer_g.zero_grad()\n", - " generator_step(cyclegan, reals)\n", - " optimizer_g.step()\n", - "\n", - " #################### TODO (still) choice 3 #####################\n", - " # OPTION 1\n", - " # set_requires_grad(cyclegan.generators, True)\n", - " # set_requires_grad(cyclegan.discriminators, False)\n", - " # OPTION 2\n", - " # set_requires_grad(cyclegan.generators, False)\n", - " # set_requires_grad(cyclegan.discriminators, True)\n", - " #################################################################\n", - "\n", - " optimizer_d.zero_grad()\n", - " discriminator_step(cyclegan, reals)\n", - " optimizer_d.step()" + "adversarial_loss_fn = nn.CrossEntropyLoss()" ] }, { "cell_type": "markdown", - "id": "30b90f36", + "id": "eade7df1", "metadata": { - "editable": true, - "lines_to_next_cell": 2, - "slideshow": { - "slide_type": "" - }, + "lines_to_next_cell": 0, "tags": [] }, "source": [ - "Let's add a quick plotting function before we begin training..." + "\n", + "**Cycle/reconstruction loss**\n", + "The cycle loss is there to make sure that the generator doesn't output an image that looks nothing like the input!\n", + "Indeed, by training the generator to be able to cycle back to the original image, we are making sure that it makes a minimum number of changes.\n", + "The cycle loss is applied only to the generator.\n" ] }, { "cell_type": "code", "execution_count": null, - "id": "a6e2d5a8", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, + "id": "1deb8b8b", + "metadata": {}, "outputs": [], "source": [ - "def plot_gan_output(sample=None):\n", - " # Get the input from the test dataset\n", - " if sample is None:\n", - " i = np.random.randint(len(gan_test_dataset))\n", - " x, y = gan_test_dataset[i]\n", - " x = x.to(device)\n", - " reals = {classes[y]: x}\n", - " else:\n", - " reals = sample\n", - "\n", - " with torch.no_grad():\n", - " fakes, identities, reconstructions = cyclegan(reals)\n", - " inverse_keys = cyclegan.inverse_keys\n", - " for k in reals.keys():\n", - " inv_k = inverse_keys[k]\n", - " for i in range(len(reals[k])):\n", - " fig, (ax, ax_fake, ax_id, ax_recon) = plt.subplots(1, 4)\n", - " ax.imshow(reals[k][i].squeeze().cpu(), cmap=\"gray\")\n", - " ax_fake.imshow(fakes[inv_k][i].squeeze().cpu(), cmap=\"gray\")\n", - " ax_id.imshow(identities[k][i].squeeze().cpu(), cmap=\"gray\")\n", - " ax_recon.imshow(reconstructions[k][i].squeeze().cpu(), cmap=\"gray\")\n", - " # Name the axes\n", - " ax.set_title(f\"{k.capitalize()}\")\n", - " ax_fake.set_title(\"Counterfactual\")\n", - " ax_id.set_title(\"Identity\")\n", - " ax_recon.set_title(\"Reconstruction\")\n", - " for ax in [ax, ax_fake, ax_id, ax_recon]:\n", - " ax.axis(\"off\")" + "cycle_loss_fn = nn.L1Loss()" ] }, { "cell_type": "markdown", - "id": "519aba30", + "id": "ba4a7f7f", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, + "lines_to_next_cell": 0, "tags": [] }, "source": [ - "

Task 3.2: Training!

\n", - "Let's train the CycleGAN one batch a time, plotting the output every so often to see how it is getting on.\n", - "\n", - "While you watch the model train, consider whether you think it will be successful at generating counterfactuals in the number of steps we give it. What is the minimum number of iterations you think are needed for this to work, and how much time do yo uthink it will take?\n", - "
" + "To load the data as batches, with shuffling and other useful features, we will use a `DataLoader`." ] }, { "cell_type": "code", "execution_count": null, - "id": "597f44ce", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, + "id": "b5b3d5dc", + "metadata": {}, "outputs": [], "source": [ - "# Get a balanced sampler that only considers the two classes\n", - "sampler = balanced_sampler(gan_train_dataset)\n", + "from torch.utils.data import DataLoader\n", + "\n", "dataloader = DataLoader(\n", - " gan_train_dataset, batch_size=8, drop_last=True, sampler=sampler\n", - ")" + " mnist, batch_size=32, drop_last=True, shuffle=True\n", + ") # We will use the same dataset as before" ] }, { - "cell_type": "code", - "execution_count": null, - "id": "7370994c", + "cell_type": "markdown", + "id": "a029e923", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, + "lines_to_next_cell": 0, "tags": [] }, - "outputs": [], "source": [ - "# Number of iterations to train for (note: this is not *nearly* enough to get ideal results)\n", - "iterations = 500\n", - "# Determines how often to plot outputs to see how the network is doing. I recommend scaling your `print_every` to your `iterations`.\n", - "# For example, if you're running `iterations=5` you can `print_every=1`, but `iterations=1000` and `print_every=1` will be a lot of prints.\n", - "print_every = 100" + "As we stated earlier, it is important to make sure when each network is being trained when working with a GAN.\n", + "Indeed, if we update the weights at the same time, we may lose the adversarial aspect of the training altogether, with information leaking into the generator or discriminator causing them to collaborate when they should be competing!\n", + "`set_requires_grad` is a function that allows us to determine when the weights of a network are trainable (if it is `True`) or not (if it is `False`)." ] }, { "cell_type": "code", "execution_count": null, - "id": "861dedd4", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, + "id": "54b4de87", + "metadata": {}, "outputs": [], "source": [ - "for i in tqdm(range(iterations)):\n", - " x, y = next(iter(dataloader))\n", - " x = x.to(device)\n", - " y = y.to(device)\n", - " real = get_as_dictionary(x, y)\n", - " train_gan(real)\n", - " if i % print_every == 0:\n", - " cyclegan.eval() # Set to eval to speed up the plotting\n", - " plot_gan_output(sample=real)\n", - " cyclegan.train() # Set back to train!\n", - " plt.show()" + "def set_requires_grad(module, value=True):\n", + " \"\"\"Sets `requires_grad` on a `module`'s parameters to `value`\"\"\"\n", + " for param in module.parameters():\n", + " param.requires_grad = value" ] }, { "cell_type": "markdown", - "id": "09c3f362", + "id": "014e484e", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, + "lines_to_next_cell": 0, "tags": [] }, "source": [ - "...this time again.\n", + "Another consequence of adversarial training is that it is very unstable.\n", + "While this instability is what leads to finding the best possible solution (which in the case of GANs is on a saddle point), it can also make it difficult to train the model.\n", + "To force some stability back into the training, we will use Exponential Moving Averages (EMA).\n", "\n", - "\"drawing\"" + "In essence, each time we update the generator's weights, we will also update the EMA model's weights as an average of all the generator's previous weights as well as the current update.\n", + "A certain weight is given to the previous weights, which is what ensures that the EMA update remains rather smooth over the training period.\n", + "Each epoch, we will then copy the EMA model's weights back to the generator.\n", + "This is a common technique used in GAN training to stabilize the training process.\n", + "Pay attention to what this does to the loss during the training process!" ] }, { - "cell_type": "markdown", - "id": "6ee205dd", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, + "cell_type": "code", + "execution_count": null, + "id": "f6344c83", + "metadata": {}, + "outputs": [], "source": [ - "

Checkpoint 3

\n", - "You've now learned the basics of what makes up a CycleGAN, and details on how to perform adversarial training.\n", - "The same method can be used to create a CycleGAN with different basic elements.\n", - "For example, you can change the archictecture of the generators, or of the discriminator to better fit your data in the future.\n", + "from copy import deepcopy\n", "\n", - "You know the drill... let us know on the exercise chat!\n", - "
" + "\n", + "def exponential_moving_average(model, ema_model, beta=0.999):\n", + " \"\"\"Update the EMA model's parameters with an exponential moving average\"\"\"\n", + " for param, ema_param in zip(model.parameters(), ema_model.parameters()):\n", + " ema_param.data.mul_(beta).add_((1 - beta) * param.data)\n", + "\n", + "\n", + "def copy_parameters(source_model, target_model):\n", + " \"\"\"Copy the parameters of a model to another model\"\"\"\n", + " for param, target_param in zip(\n", + " source_model.parameters(), target_model.parameters()\n", + " ):\n", + " target_param.data.copy_(param.data)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "08b7b3af", + "metadata": {}, + "outputs": [], + "source": [ + "generator_ema = Generator(deepcopy(unet), style_encoder=deepcopy(style_encoder))\n", + "generator_ema = generator_ema.to(device)" ] }, { "cell_type": "markdown", - "id": "765089a1", + "id": "23fbf680", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, + "lines_to_next_cell": 0, "tags": [] }, "source": [ - "# Part 4: Evaluating the GAN" + "

Task 3.3: Training!

\n", + "You were given several different options in the training code below. In each case, one of the options will work, and the other will not.\n", + "Comment out the option that you think will not work.\n", + "
    \n", + "
  • Choose the values for `set_requires_grad`. Hint: which part of the code is training the generator? Which part is training the discriminator
  • \n", + "
  • Choose the values of `set_requires_grad`, again. Hint: you may want to switch
  • \n", + "
  • Choose the sign of the discriminator loss. Hint: what does the discriminator want to do?
  • \n", + ".
  • Apply the EMA update. Hint: which model do you want to update? You can look again at the code we wrote above.
  • \n", + "
\n", + "Let's train the StarGAN one batch a time.\n", + "While you watch the model train, consider whether you think it will be successful at generating counterfactuals in the number of steps we give it. What is the minimum number of iterations you think are needed for this to work, and how much time do yo uthink it will take?\n", + "
" ] }, { "cell_type": "markdown", - "id": "8959c219", + "id": "9cb8281d", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, + "lines_to_next_cell": 0, "tags": [] }, "source": [ - "\n", - "## That was fun!... let's load a pre-trained model\n", - "\n", - "Training the CycleGAN takes a lot longer than the few iterations that we did above. Since we don't have that kind of time, we are going to load a pre-trained model (for reference, this pre-trained model was trained for 7 days...).\n", - "\n", - "To continue, interrupt the kernel and continue with the next one, which will just use one of the pretrained CycleGAN models for the synapse dataset." + "Once you're happy with your choices, run the training loop! 🚂 🚋 🚋 🚋" ] }, { "cell_type": "code", "execution_count": null, - "id": "0fd97600", + "id": "3b01306d", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] + "lines_to_next_cell": 0, + "tags": [ + "task" + ] }, "outputs": [], "source": [ - "from pathlib import Path\n", - "import torch\n", - "\n", + "from tqdm import tqdm # This is a nice library for showing progress bars\n", "\n", - "def load_pretrained(model, path, classA, classB):\n", - " \"\"\"Load the pre-trained models from the path\"\"\"\n", - " directory = Path(path).expanduser() / f\"{classA}_{classB}\"\n", - " # Load generators\n", - " model.generators[classB].load_state_dict(\n", - " torch.load(directory / \"latest_net_G_A.pth\")\n", - " )\n", - " model.generators[classA].load_state_dict(\n", - " torch.load(directory / \"latest_net_G_B.pth\")\n", - " )\n", - " # Load discriminators\n", - " model.discriminators[classA].load_state_dict(\n", - " torch.load(directory / \"latest_net_D_A.pth\")\n", - " )\n", - " model.discriminators[classB].load_state_dict(\n", - " torch.load(directory / \"latest_net_D_B.pth\")\n", - " )\n", "\n", + "losses = {\"cycle\": [], \"adv\": [], \"disc\": []}\n", "\n", - "load_pretrained(cyclegan, \"./checkpoints/synapses/cycle_gan/\", *classes)" - ] - }, - { - "cell_type": "markdown", - "id": "ee456f57", + "for epoch in range(15):\n", + " for x, y in tqdm(dataloader, desc=f\"Epoch {epoch}\"):\n", + " x = x.to(device)\n", + " y = y.to(device)\n", + " # get the target y by shuffling the classes\n", + " # get the style sources by random sampling\n", + " random_index = torch.randperm(len(y))\n", + " x_style = x[random_index].clone()\n", + " y_target = y[random_index].clone()\n", + "\n", + " # TODO - Choose an option by commenting out what you don't want\n", + " ############\n", + " # Option 1 #\n", + " ############\n", + " set_requires_grad(generator, True)\n", + " set_requires_grad(discriminator, False)\n", + " ############\n", + " # Option 2 #\n", + " ############\n", + " set_requires_grad(generator, False)\n", + " set_requires_grad(discriminator, True)\n", + "\n", + " optimizer_g.zero_grad()\n", + " # Get the fake image\n", + " x_fake = generator(x, x_style)\n", + " # Try to cycle back\n", + " x_cycled = generator(x_fake, x)\n", + " # Discriminate\n", + " discriminator_x_fake = discriminator(x_fake)\n", + " # Losses to train the generator\n", + "\n", + " # 1. make sure the image can be reconstructed\n", + " cycle_loss = cycle_loss_fn(x, x_cycled)\n", + " # 2. make sure the discriminator is fooled\n", + " adv_loss = adversarial_loss_fn(discriminator_x_fake, y_target)\n", + "\n", + " # Optimize the generator\n", + " (cycle_loss + adv_loss).backward()\n", + " optimizer_g.step()\n", + "\n", + " # TODO - Choose an option by commenting out what you don't want\n", + " ############\n", + " # Option 1 #\n", + " ############\n", + " set_requires_grad(generator, True)\n", + " set_requires_grad(discriminator, False)\n", + " ############\n", + " # Option 2 #\n", + " ############\n", + " set_requires_grad(generator, False)\n", + " set_requires_grad(discriminator, True)\n", + " #\n", + " optimizer_d.zero_grad()\n", + " #\n", + " discriminator_x = discriminator(x)\n", + " discriminator_x_fake = discriminator(x_fake.detach())\n", + "\n", + " # TODO - Choose an option by commenting out what you don't want\n", + " # Losses to train the discriminator\n", + " # 1. make sure the discriminator can tell real is real\n", + " # 2. make sure the discriminator can tell fake is fake\n", + " ############\n", + " # Option 1 #\n", + " ############\n", + " real_loss = adversarial_loss_fn(discriminator_x, y)\n", + " fake_loss = -adversarial_loss_fn(discriminator_x_fake, y_target)\n", + " ############\n", + " # Option 2 #\n", + " ############\n", + " real_loss = adversarial_loss_fn(discriminator_x, y)\n", + " fake_loss = adversarial_loss_fn(discriminator_x_fake, y_target)\n", + " #\n", + " disc_loss = (real_loss + fake_loss) * 0.5\n", + " disc_loss.backward()\n", + " # Optimize the discriminator\n", + " optimizer_d.step()\n", + "\n", + " losses[\"cycle\"].append(cycle_loss.item())\n", + " losses[\"adv\"].append(adv_loss.item())\n", + " losses[\"disc\"].append(disc_loss.item())\n", + "\n", + " # EMA update\n", + " # TODO - perform the EMA update\n", + " ############\n", + " # Option 1 #\n", + " ############\n", + " exponential_moving_average(generator, generator_ema)\n", + " ############\n", + " # Option 2 #\n", + " ############\n", + " exponential_moving_average(generator_ema, generator)\n", + " # Copy the EMA model's parameters to the generator\n", + " copy_parameters(generator_ema, generator)" + ] + }, + { + "cell_type": "markdown", + "id": "4c25819b", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, + "lines_to_next_cell": 0, "tags": [] }, "source": [ - "Let's look at some examples. Can you pick up on the differences between original, the counter-factual, and the reconstruction?" + "Once training is complete, we can plot the losses to see how well the model is doing." ] }, { "cell_type": "code", "execution_count": null, - "id": "20adc855", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, + "id": "0d64d32d", + "metadata": {}, "outputs": [], "source": [ - "for i in range(5):\n", - " plot_gan_output()" + "fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 5))\n", + "ax1.plot(losses[\"cycle\"])\n", + "ax1.set_title(\"Cycle loss\")\n", + "ax2.plot(losses[\"adv\"])\n", + "ax2.set_title(\"Adversarial loss\")\n", + "ax3.plot(losses[\"disc\"])\n", + "ax3.set_title(\"Discriminator loss\")\n", + "plt.show()" ] }, { "cell_type": "markdown", - "id": "dfa1b783", + "id": "326ba2b5", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, "tags": [] }, "source": [ - "We're going to apply the CycleGAN to our test dataset, and save the results to be reused later." + "

Questions

\n", + "
    \n", + "
  • Do the losses look like what you expected?
  • \n", + "
  • How do these losses differ from the losses you would expect from a classifier?
  • \n", + "
  • Based only on the losses, do you think the model is doing well?
  • \n", + "
" ] }, { - "cell_type": "code", - "execution_count": null, - "id": "0887b0da", + "cell_type": "markdown", + "id": "3e58ca01", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, + "lines_to_next_cell": 0, "tags": [] }, - "outputs": [], "source": [ - "dataloader = DataLoader(gan_test_dataset, batch_size=32)" + "We can also look at some examples of the images that the generator is creating." ] }, { "cell_type": "code", "execution_count": null, - "id": "67b7c1e8", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, + "id": "1c522efa", + "metadata": {}, "outputs": [], "source": [ - "from skimage.io import imsave\n", - "\n", - "\n", - "def unnormalize(image):\n", - " return ((0.5 * image + 0.5) * 255).astype(np.uint8)\n", - "\n", - "\n", - "@torch.no_grad()\n", - "def apply_gan(dataloader, directory):\n", - " \"\"\"Run CycleGAN on a dataloader and save images to a directory.\"\"\"\n", - " directory = Path(directory)\n", - " inverse_keys = cyclegan.inverse_keys\n", - " cyclegan.eval()\n", - " batch_size = dataloader.batch_size\n", - " n_sample = 0\n", - " for batch, (x, y) in enumerate(tqdm(dataloader)):\n", - " reals = get_as_dictionary(x.to(device), y.to(device))\n", - " fakes, _, recons = cyclegan(reals)\n", - " for k in reals.keys():\n", - " inv_k = inverse_keys[k]\n", - " (directory / f\"real/{k}\").mkdir(parents=True, exist_ok=True)\n", - " (directory / f\"reconstructed/{k}\").mkdir(parents=True, exist_ok=True)\n", - " (directory / f\"counterfactual/{k}\").mkdir(parents=True, exist_ok=True)\n", - " for i, (im_real, im_fake, im_recon) in enumerate(\n", - " zip(reals[k], fakes[inv_k], recons[k])\n", - " ):\n", - " # Save real synapse images\n", - " imsave(\n", - " directory / f\"real/{k}/{k}_{inv_k}_{n_sample}.png\",\n", - " unnormalize(im_real.cpu().numpy().squeeze()),\n", - " )\n", - " # Save fake synapse images\n", - " imsave(\n", - " directory / f\"reconstructed/{k}/{k}_{inv_k}_{n_sample}.png\",\n", - " unnormalize(im_recon.cpu().numpy().squeeze()),\n", - " )\n", - " # Save counterfactual synapse images\n", - " imsave(\n", - " directory / f\"counterfactual/{k}/{k}_{inv_k}_{n_sample}.png\",\n", - " unnormalize(im_fake.cpu().numpy().squeeze()),\n", - " )\n", - " # Count\n", - " n_sample += 1\n", - " return" + "idx = 0\n", + "fig, axs = plt.subplots(1, 4, figsize=(12, 4))\n", + "axs[0].imshow(x[idx].cpu().permute(1, 2, 0).detach().numpy())\n", + "axs[0].set_title(\"Input image\")\n", + "axs[1].imshow(x_style[idx].cpu().permute(1, 2, 0).detach().numpy())\n", + "axs[1].set_title(\"Style image\")\n", + "axs[2].imshow(x_fake[idx].cpu().permute(1, 2, 0).detach().numpy())\n", + "axs[2].set_title(\"Generated image\")\n", + "axs[3].imshow(x_cycled[idx].cpu().permute(1, 2, 0).detach().numpy())\n", + "axs[3].set_title(\"Cycled image\")\n", + "\n", + "for ax in axs:\n", + " ax.axis(\"off\")\n", + "plt.show()" ] }, { "cell_type": "code", "execution_count": null, - "id": "0b4bfcf0", + "id": "30b6dac9", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] + "lines_to_next_cell": 0 }, "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "id": "a3ecbc7b", + "metadata": { + "tags": [] + }, "source": [ - "apply_gan(dataloader, \"test_images/\")" + "

Checkpoint 3

\n", + "You've now learned the basics of what makes up a StarGAN, and details on how to perform adversarial training.\n", + "The same method can be used to create a StarGAN with different basic elements.\n", + "For example, you can change the archictecture of the generators, or of the discriminator to better fit your data in the future.\n", + "\n", + "You know the drill... let us know on the exercise chat when you have arrived here!\n", + "
" ] }, { - "cell_type": "code", - "execution_count": null, - "id": "2eb0e50e", + "cell_type": "markdown", + "id": "e6bdaecb", "metadata": { - "editable": true, - "lines_to_next_cell": 2, - "slideshow": { - "slide_type": "" - }, "tags": [] }, - "outputs": [], "source": [ - "# Clean-up the gpu's memory a bit to avoid Out-of-Memory errors\n", - "cyclegan = cyclegan.cpu()\n", - "torch.cuda.empty_cache()" + "# Part 4: Evaluating the GAN and creating Counterfactuals" ] }, { "cell_type": "markdown", - "id": "483af604", + "id": "7f994579", "metadata": { - "editable": true, - "lines_to_next_cell": 2, - "slideshow": { - "slide_type": "" - }, "tags": [] }, "source": [ - "## Evaluating the GAN\n", + "## Creating counterfactuals\n", "\n", - "The first thing to find out is whether the CycleGAN is successfully converting the images from one neurotransmitter to another.\n", - "We will do this by running the classifier that we trained earlier on generated data.\n", + "The first thing that we want to do is make sure that our GAN is able to create counterfactual images.\n", + "To do this, we have to create them, and then pass them through the classifier to see if they are classified correctly.\n", "\n", - "The data were saved in a directory called `test_images`.\n" + "First, let's get the test dataset, so we can evaluate the GAN on unseen data.\n", + "Then, let's get four prototypical images from the dataset as style sources." ] }, { "cell_type": "code", "execution_count": null, - "id": "c59702f9", + "id": "4e4fe83e", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] + "title": "Loading the test dataset" }, "outputs": [], "source": [ - "def make_dataset(directory):\n", - " \"\"\"Create a dataset from a directory of images with the classes in the same order as the VGG's output.\n", - "\n", - " Parameters\n", - " ----------\n", - " directory: str\n", - " The root directory of the images. It should contain sub-directories named after the classes, in which images are stored.\n", - " \"\"\"\n", - " # Make a dataset with the classes in the correct order\n", - " limited_classes = {k: v for k, v in class_to_idx.items() if k in classes}\n", - " dataset = ImageFolder(root=directory, transform=transform)\n", - " samples = ImageFolder.make_dataset(\n", - " directory, class_to_idx=limited_classes, extensions=\".png\"\n", - " )\n", - " # Sort samples by name\n", - " samples = sorted(samples, key=lambda s: s[0].split(\"_\")[-1])\n", - " dataset.classes = classes\n", - " dataset.class_to_idx = limited_classes\n", - " dataset.samples = samples\n", - " dataset.targets = [s[1] for s in samples]\n", - " return dataset" + "test_mnist = ColoredMNIST(\"extras/data\", download=True, train=False)\n", + "prototypes = {}\n", + "\n", + "\n", + "for i in range(4):\n", + " options = np.where(test_mnist.conditions == i)[0]\n", + " # Note that you can change the image index if you want to use a different prototype.\n", + " image_index = 0\n", + " x, y = test_mnist[options[image_index]]\n", + " prototypes[i] = x" ] }, { "cell_type": "markdown", - "id": "c6bffc67", + "id": "049a6b22", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, + "lines_to_next_cell": 0, "tags": [] }, "source": [ - "

Task 4.1 Get the classifier accuracy on CycleGAN outputs

\n", - "\n", - "Using the saved images, we're going to figure out how good our CycleGAN is at generating images of a new class!\n", - "\n", - "The images (`real`, `reconstructed`, and `counterfactual`) are saved in the `test_images/` directory. Before you start the exercise, have a look at how this directory is organized.\n", - "\n", - "TODO\n", - "- Use the `make_dataset` function to create a dataset for the three different image types that we saved above\n", - " - real\n", - " - reconstructed\n", - " - counterfactual\n", - "
" + "Let's have a look at the prototypes." ] }, { "cell_type": "code", "execution_count": null, - "id": "42906ce7", + "id": "639f37e2", + "metadata": {}, + "outputs": [], + "source": [ + "fig, axs = plt.subplots(1, 4, figsize=(12, 4))\n", + "for i, ax in enumerate(axs):\n", + " ax.imshow(prototypes[i].permute(1, 2, 0))\n", + " ax.axis(\"off\")\n", + " ax.set_title(f\"Prototype {i}\")" + ] + }, + { + "cell_type": "markdown", + "id": "02cb705b", "metadata": { - "editable": true, - "lines_to_next_cell": 2, - "slideshow": { - "slide_type": "" - }, - "tags": [] + "lines_to_next_cell": 0 }, - "outputs": [], "source": [ - "# Dataset of real images\n", - "ds_real = ...\n", - "# Dataset of reconstructed images (full cycle)\n", - "ds_recon = ...\n", - "# Datset of counterfactuals (half-cycle)\n", - "ds_counterfactual = ..." + "Now we need to use these prototypes to create counterfactual images!" ] }, { "cell_type": "markdown", - "id": "c4500183", + "id": "f41a6ce5", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] + "lines_to_next_cell": 0 }, "source": [ - "
\n", - "We get the following accuracies:\n", + "

Task 4: Create counterfactuals

\n", + "In the below, we will store the counterfactual images in the `counterfactuals` array.\n", "\n", - "1. `accuracy_real`: Accuracy of the classifier on the real images, just for the two classes used in the GAN\n", - "2. `accuracy_recon`: Accuracy of the classifier on the reconstruction.\n", - "3. `accuracy_counter`: Accuracy of the classifier on the counterfactual images.\n", - "\n", - "

Questions

\n", - "\n", - "- In a perfect world, what value would we expect for `accuracy_recon`? What do we compare it to and why is it higher/lower?\n", - "- How well is it translating from one class to another? Do we expect `accuracy_counter` to be large or small? Do we want it to be large or small? Why?\n", - "\n", - "Let us know your insights on the exercise chat.\n", - "
" + "
    \n", + "
  • Create a counterfactual image for each of the prototypes.
  • \n", + "
  • Classify the counterfactual image using the classifier.
  • \n", + "
  • Store the source and target labels; which is which?
  • \n", + "
" ] }, { "cell_type": "code", "execution_count": null, - "id": "17b2af0c", + "id": "282f8858", "metadata": { - "editable": true, - "lines_to_next_cell": 2, - "slideshow": { - "slide_type": "" - }, - "tags": [] + "lines_to_next_cell": 0, + "tags": [ + "task" + ] }, "outputs": [], "source": [ - "cf_pred, cf_gt = predict(ds_counterfactual, \"Counterfactuals\")\n", - "recon_pred, recon_gt = predict(ds_recon, \"Reconstructions\")\n", - "real_pred, real_gt = predict(ds_real, \"Real images\")\n", + "num_images = 1000\n", + "random_test_mnist = torch.utils.data.Subset(\n", + " test_mnist, np.random.choice(len(test_mnist), num_images, replace=False)\n", + ")\n", + "counterfactuals = np.zeros((4, num_images, 3, 28, 28))\n", "\n", - "# Get the accuracies\n", - "accuracy_real = accuracy_score(real_gt, real_pred)\n", - "accuracy_recon = accuracy_score(recon_gt, recon_pred)\n", - "accuracy_cf = accuracy_score(cf_gt, cf_pred)\n", + "predictions = []\n", + "source_labels = []\n", + "target_labels = []\n", "\n", - "print(\n", - " f\"Accuracy real: {accuracy_real}\\nAccuracy reconstruction: {accuracy_recon}\\nAccuracy counterfactuals: {accuracy_cf}\\n\"\n", - ")" + "for i, (x, y) in tqdm(enumerate(random_test_mnist), total=num_images):\n", + " for lbl in range(4):\n", + " # TODO Create the counterfactual\n", + " x_fake = generator(x.unsqueeze(0).to(device), ...)\n", + " # TODO Predict the class of the counterfactual image\n", + " pred = model(...)\n", + "\n", + " # TODO Store the source and target labels\n", + " source_labels.append(...) # The original label of the image\n", + " target_labels.append(...) # The desired label of the counterfactual image\n", + " # Store the counterfactual image and prediction\n", + " counterfactuals[lbl][i] = x_fake.cpu().detach().numpy()\n", + " predictions.append(pred.argmax().item())" ] }, { "cell_type": "markdown", - "id": "615c9449", + "id": "ebffc15f", "metadata": { - "editable": true, - "lines_to_next_cell": 2, - "slideshow": { - "slide_type": "" - }, + "lines_to_next_cell": 0, "tags": [] }, "source": [ - "We're going to look at the confusion matrices for the counterfactuals, and compare it to that of the real images." + "Let's plot the confusion matrix for the counterfactual images." ] }, { "cell_type": "code", "execution_count": null, - "id": "4c0e1278", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, + "id": "baac8071", + "metadata": {}, "outputs": [], "source": [ - "labels = [class_to_idx[i] for i in classes]\n", - "print(\"The confusion matrix of the classifier on the counterfactuals\")\n", - "cm_analysis(cf_pred, cf_gt, names=classes, labels=labels)" + "cf_cm = confusion_matrix(target_labels, predictions, normalize=\"true\")\n", + "sns.heatmap(cf_cm, annot=True, fmt=\".2f\")\n", + "plt.ylabel(\"True\")\n", + "plt.xlabel(\"Predicted\")\n", + "plt.show()" ] }, { - "cell_type": "code", - "execution_count": null, - "id": "92401b45", + "cell_type": "markdown", + "id": "88e7ea0c", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, "tags": [] }, - "outputs": [], "source": [ - "print(\"The confusion matrix on the real images... for comparison\")\n", - "cm_analysis(real_pred, real_gt, names=classes, labels=labels)" + "

Questions

\n", + "
    \n", + "
  • How well is our GAN doing at creating counterfactual images?
  • \n", + "
  • Does your choice of prototypes matter? Why or why not?
  • \n", + "
\n", + "
" ] }, { "cell_type": "markdown", - "id": "57f8cca6", + "id": "25972c49", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, "tags": [] }, "source": [ - "
\n", - "

Questions

\n", - "\n", - "- What would you expect the confusion matrix for the counterfactuals to look like? Why?\n", - "- Do the two directions of the CycleGAN work equally as well?\n", - "- Can you think of anything that might have made it more difficult, or easier, to translate in a one direction vs the other?\n", - "\n", - "
" + "Let's also plot some examples of the counterfactual images." ] }, { - "cell_type": "markdown", - "id": "d81bbc95", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, + "cell_type": "code", + "execution_count": null, + "id": "12d49576", + "metadata": {}, + "outputs": [], "source": [ - "

Checkpoint 4

\n", - " We have seen that our CycleGAN network has successfully translated some of the synapses from one class to the other, but there are clearly some things to look out for!\n", - "Take the time to think about the questions above before moving on...\n", - "\n", - "This is the end of Section 4. Let us know on the exercise chat if you have reached this point!\n", - "
" + "for i in np.random.choice(range(num_images), 4):\n", + " fig, axs = plt.subplots(1, 4, figsize=(20, 4))\n", + " for j, ax in enumerate(axs):\n", + " ax.imshow(counterfactuals[j][i].transpose(1, 2, 0))\n", + " ax.axis(\"off\")\n", + " ax.set_title(f\"Class {j}\")" ] }, { "cell_type": "markdown", - "id": "406e8777", + "id": "8e6f04f3", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, "tags": [] }, "source": [ - "# Part 5: Highlighting Class-Relevant Differences" + "

Questions

\n", + "
    \n", + "
  • Can you easily tell which of these images is the original, and which ones are the counterfactuals?
  • \n", + "
  • What is your hypothesis for the features that define each class?
  • \n", + "
\n", + "
" ] }, { "cell_type": "markdown", - "id": "69ee980b", - "metadata": {}, + "id": "50728ff2", + "metadata": { + "lines_to_next_cell": 0 + }, "source": [ "At this point we have:\n", - "- A classifier that can differentiate between neurotransmitters from EM images of synapses\n", - "- A vague idea of which parts of the images it thinks are important for this classification\n", - "- A CycleGAN that is sometimes able to trick the classifier with barely perceptible changes\n", + "- A classifier that can differentiate between image of different classes\n", + "- A GAN that has correctly figured out how to change the class of an image\n", "\n", - "What we don't know, is *how* the CycleGAN is modifying the images to change their class.\n", - "\n", - "To start to answer this question, we will use a [Discriminative Attribution from Counterfactuals](https://arxiv.org/abs/2109.13412) method to highlight differences between the \"real\" and \"fake\" images that are most important to change the decision of the classifier." - ] - }, - { - "cell_type": "markdown", - "id": "f7dbe347", - "metadata": {}, - "source": [ - "

Task 5.1 Get sucessfully converted samples

\n", - "The CycleGAN is able to convert some, but not all images into their target types.\n", - "In order to observe and highlight useful differences, we want to observe our attribution method at work only on those examples of synapses:\n", - "
    \n", - "
  1. That were correctly classified originally
  2. \n", - "
  3. Whose counterfactuals were also correctly classified
  4. \n", - "
\n", - "\n", - "TODO\n", - "- Get a boolean description of the `real` samples that were correctly predicted\n", - "- Get the target class for the `counterfactual` images (Hint: It isn't `cf_gt`!)\n", - "- Get a boolean description of the `cf` samples that have the target class\n", - "
" + "Let's try putting the two together to see if we can figure out what exactly makes a class.\n" ] }, { "cell_type": "code", "execution_count": null, - "id": "28ec78be", - "metadata": { - "editable": true, - "lines_to_next_cell": 2, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, + "id": "dedc0f83", + "metadata": {}, "outputs": [], "source": [ - "####### Task 5.1 TODO #######\n", - "\n", - "# Get the samples where the real is correct\n", - "correct_real = ...\n", + "batch_size = 4\n", + "batch = [random_test_mnist[i] for i in range(batch_size)]\n", + "x = torch.stack([b[0] for b in batch])\n", + "y = torch.tensor([b[1] for b in batch])\n", + "x_fake = torch.tensor(counterfactuals[0, :batch_size])\n", + "x = x.to(device).float()\n", + "y = y.to(device)\n", + "x_fake = x_fake.to(device).float()\n", "\n", - "# HINT GABA is class 1 and ACh is class 0\n", - "target = ...\n", - "\n", - "# Get the samples where the counterfactual has reached the target\n", - "correct_cf = ...\n", - "\n", - "# Successful conversions\n", - "success = np.where(np.logical_and(correct_real, correct_cf))[0]\n", - "\n", - "# Create datasets with only the successes\n", - "cf_success_ds = Subset(ds_counterfactual, success)\n", - "real_success_ds = Subset(ds_real, success)" - ] - }, - { - "cell_type": "markdown", - "id": "5518deea", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, - "source": [ - "To check that we have got it right, let us get the accuracy on the best 100 vs the worst 100 samples:" + "# Generated attributions on integrated gradients\n", + "attributions = integrated_gradients.attribute(x, baselines=x_fake, target=y)" ] }, { "cell_type": "code", "execution_count": null, - "id": "c813f006", + "id": "5446e796", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] + "title": "Another visualization function" }, "outputs": [], "source": [ - "model = model.to(\"cuda\")" + "def visualize_color_attribution_and_counterfactual(\n", + " attribution, original_image, counterfactual_image\n", + "):\n", + " attribution = np.transpose(attribution, (1, 2, 0))\n", + " original_image = np.transpose(original_image, (1, 2, 0))\n", + " counterfactual_image = np.transpose(counterfactual_image, (1, 2, 0))\n", + "\n", + " fig, (ax0, ax1, ax2) = plt.subplots(1, 3, figsize=(15, 5))\n", + " ax0.imshow(original_image)\n", + " ax0.set_title(\"Image\")\n", + " ax0.axis(\"off\")\n", + " ax1.imshow(counterfactual_image)\n", + " ax1.set_title(\"Counterfactual\")\n", + " ax1.axis(\"off\")\n", + " ax2.imshow(np.abs(attribution))\n", + " ax2.set_title(\"Attribution\")\n", + " ax2.axis(\"off\")\n", + " plt.show()" ] }, { "cell_type": "code", "execution_count": null, - "id": "d599f126", + "id": "5e2fb59e", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] + "lines_to_next_cell": 0 }, "outputs": [], "source": [ - "real_true, real_pred = predict(real_success_ds, \"Real\")\n", - "cf_true, cf_pred = predict(cf_success_ds, \"Counterfactuals\")\n", - "\n", - "print(\n", - " \"Accuracy of the classifier on successful real images\",\n", - " accuracy_score(real_true, real_pred),\n", - ")\n", - "print(\n", - " \"Accuracy of the classifier on successful counterfactual images\",\n", - " accuracy_score(cf_true, cf_pred),\n", - ")" + "for idx in range(batch_size):\n", + " print(\"Source class:\", y[idx].item())\n", + " print(\"Target class:\", 0)\n", + " visualize_color_attribution_and_counterfactual(\n", + " attributions[idx].cpu().numpy(), x[idx].cpu().numpy(), x_fake[idx].cpu().numpy()\n", + " )" ] }, { "cell_type": "markdown", - "id": "877db1dc", + "id": "b393a8f1", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] + "lines_to_next_cell": 0 }, "source": [ - "### Creating hybrids from attributions\n", - "\n", - "Now that we have a set of successfully translated counterfactuals, we can use them as a baseline for our attribution.\n", - "If you remember from earlier, `IntegratedGradients` does a interpolation between the model gradients at the baseline and the model gradients at the sample. Here, we're also going to be doing an interpolation between the baseline image and the sample image, creating a hybrid!\n", - "\n", - "To do this, we will take the sample image and mask out all of the pixels in the attribution. We will then replace these masked out pixels by the equivalent values in the counterfactual. So we'll have a hybrid image that is like the original everywhere except in the areas that matter for classification." + "

Questions

\n", + "
    \n", + "
  • Do the attributions explain the differences between the images and their counterfactuals?
  • \n", + "
  • What happens when the \"counterfactual\" and the original image are of the same class? Why do you think this is?
  • \n", + "
  • Do you have a more refined hypothesis for what makes each class unique?
  • \n", + "
\n", + "
" ] }, { - "cell_type": "code", - "execution_count": null, - "id": "dcb7288f", + "cell_type": "markdown", + "id": "5ba47fc6", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] + "lines_to_next_cell": 0 }, - "outputs": [], "source": [ - "dataloader_real = DataLoader(real_success_ds, batch_size=10)\n", - "dataloader_counter = DataLoader(cf_success_ds, batch_size=10)" + "

Checkpoint 4

\n", + "At this point you have:\n", + "- Created a StarGAN that can change the class of an image\n", + "- Evaluated the StarGAN on unseen data\n", + "- Used the StarGAN to create counterfactual images\n", + "- Used the counterfactual images to highlight the differences between classes\n" ] }, { - "cell_type": "code", - "execution_count": null, - "id": "95239b4b", + "cell_type": "markdown", + "id": "2654d788", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] + "lines_to_next_cell": 0 }, - "outputs": [], "source": [ - "%%time\n", - "with torch.no_grad():\n", - " model.to(device)\n", - " # Create an integrated gradients object.\n", - " # integrated_gradients = IntegratedGradients(model)\n", - " # Generated attributions on integrated gradients\n", - " attributions = np.vstack(\n", - " [\n", - " integrated_gradients.attribute(\n", - " real.to(device),\n", - " target=target.to(device),\n", - " baselines=counterfactual.to(device),\n", - " )\n", - " .cpu()\n", - " .numpy()\n", - " for (real, target), (counterfactual, _) in zip(\n", - " dataloader_real, dataloader_counter\n", - " )\n", - " ]\n", - " )" + "# Part 5: Exploring the Style Space, finding the answer\n", + "By now you will have hopefully noticed that it isn't the exact color of the image that determines its class, but that two images with a very similar color can be of different classes!\n", + "\n", + "Here is an example of two images that are very similar in color, but are of different classes.\n", + "![same_color_diff_class](assets/same_color_diff_class.png)\n", + "While both of the images are yellow, the attribution tells us (if you squint!) that one of the yellows has slightly more blue in it!\n", + "\n", + "Conversely, here is an example of two images with very different colors, but that are of the same class:\n", + "![same_class_diff_color](assets/same_class_diff_color.png)\n", + "Here the attribution is empty! Using the discriminative attribution we can see that the significant color change doesn't matter at all!\n", + "\n", + "\n", + "So color is important... but not always? What's going on!?\n", + "There is a final piece of information that we can use to solve the puzzle: the style space." ] }, { - "cell_type": "code", - "execution_count": null, - "id": "8b968d7c", + "cell_type": "markdown", + "id": "76559366", "metadata": {}, - "outputs": [], - "source": [] + "source": [ + "

Task 5.1: Explore the style space

\n", + "Let's take a look at the style space.\n", + "We will use the style encoder to encode the style of the images and then use PCA to visualize it.\n", + "
" + ] }, { "cell_type": "code", "execution_count": null, - "id": "84835390", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, + "id": "f1fdb890", + "metadata": {}, "outputs": [], "source": [ - "# Functions for creating an interactive visualization of our attributions\n", - "model.cpu()\n", - "\n", - "import matplotlib\n", - "\n", - "cmap = matplotlib.cm.get_cmap(\"viridis\")\n", - "colors = cmap([0, 255])\n", - "\n", - "\n", - "@torch.no_grad()\n", - "def get_classifications(image, counter, hybrid):\n", - " model.eval()\n", - " class_idx = [full_dataset.classes.index(c) for c in classes]\n", - " tensor = torch.from_numpy(np.stack([image, counter, hybrid])).float()\n", - " with torch.no_grad():\n", - " logits = model(tensor)[:, class_idx]\n", - " probs = torch.nn.Softmax(dim=1)(logits)\n", - " pred, counter_pred, hybrid_pred = probs\n", - " return pred.numpy(), counter_pred.numpy(), hybrid_pred.numpy()\n", - "\n", - "\n", - "def visualize_counterfactuals(idx, threshold=0.1):\n", - " image = real_success_ds[idx][0].numpy()\n", - " counter = cf_success_ds[idx][0].numpy()\n", - " mask = get_mask(attributions[idx], threshold)\n", - " hybrid = (1 - mask) * image + mask * counter\n", - " nan_mask = copy.deepcopy(mask)\n", - " nan_mask[nan_mask != 0] = 1\n", - " nan_mask[nan_mask == 0] = np.nan\n", - " # PLOT\n", - " fig, axes = plt.subplot_mosaic(\n", - " \"\"\"\n", - " mmm.ooo.ccc.hhh\n", - " mmm.ooo.ccc.hhh\n", - " mmm.ooo.ccc.hhh\n", - " ....ggg.fff.ppp\n", - " \"\"\",\n", - " figsize=(20, 5),\n", + "from sklearn.decomposition import PCA\n", + "\n", + "\n", + "styles = []\n", + "labels = []\n", + "for img, label in random_test_mnist:\n", + " styles.append(\n", + " style_encoder(img.unsqueeze(0).to(device)).cpu().detach().numpy().squeeze()\n", " )\n", - " # Original\n", - " viz.visualize_image_attr(\n", - " np.transpose(mask, (1, 2, 0)),\n", - " np.transpose(image, (1, 2, 0)),\n", - " method=\"blended_heat_map\",\n", - " sign=\"absolute_value\",\n", - " show_colorbar=True,\n", - " title=\"Mask\",\n", - " use_pyplot=False,\n", - " plt_fig_axis=(fig, axes[\"m\"]),\n", + " labels.append(label)\n", + "\n", + "# PCA\n", + "pca = PCA(n_components=2)\n", + "styles_pca = pca.fit_transform(styles)\n", + "\n", + "# Plot the PCA\n", + "markers = [\"o\", \"s\", \"P\", \"^\"]\n", + "plt.figure(figsize=(10, 10))\n", + "for i in range(4):\n", + " plt.scatter(\n", + " styles_pca[np.array(labels) == i, 0],\n", + " styles_pca[np.array(labels) == i, 1],\n", + " marker=markers[i],\n", + " label=f\"Class {i}\",\n", " )\n", - " # Original\n", - " axes[\"o\"].imshow(image.squeeze(), cmap=\"gray\")\n", - " axes[\"o\"].set_title(\"Original\", fontsize=24)\n", - " # Counterfactual\n", - " axes[\"c\"].imshow(counter.squeeze(), cmap=\"gray\")\n", - " axes[\"c\"].set_title(\"Counterfactual\", fontsize=24)\n", - " # Hybrid\n", - " axes[\"h\"].imshow(hybrid.squeeze(), cmap=\"gray\")\n", - " axes[\"h\"].set_title(\"Hybrid\", fontsize=24)\n", - " # Mask\n", - " pred, counter_pred, hybrid_pred = get_classifications(image, counter, hybrid)\n", - " axes[\"g\"].barh(classes, pred, color=colors)\n", - " axes[\"f\"].barh(classes, counter_pred, color=colors)\n", - " axes[\"p\"].barh(classes, hybrid_pred, color=colors)\n", - " for ix in [\"m\", \"o\", \"c\", \"h\"]:\n", - " axes[ix].axis(\"off\")\n", - "\n", - " for ix in [\"g\", \"f\", \"p\"]:\n", - " for tick in axes[ix].get_xticklabels():\n", - " tick.set_rotation(90)\n", - " axes[ix].set_xlim(0, 1)" + "plt.legend()\n", + "plt.show()" ] }, { "cell_type": "markdown", - "id": "c732d7a7", + "id": "b666769e", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] + "lines_to_next_cell": 0 }, "source": [ - "

Task 5.2: Observing the effect of the changes on the classifier

\n", - "Below is a small widget to interact with the above analysis. As you change the `threshold`, see how the prediction of the hybrid changes.\n", - "At what point does it swap over?\n", + "

Task 5.1: Adding color to the style space

\n", + "We know that color is important. Does interpreting the style space as colors help us understand better?\n", "\n", - "If you want to see different samples, slide through the `idx`.\n", + "Let's use the style space to color the PCA plot.\n", + "(Note: there is no code to write here, just run the cell and answer the questions below)\n", "
" ] }, { "cell_type": "code", "execution_count": null, - "id": "23225866", + "id": "e61d0c9b", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] + "lines_to_next_cell": 0 }, "outputs": [], "source": [ - "interact(visualize_counterfactuals, idx=(0, 99), threshold=(0.0, 1.0, 0.05))" + "styles = np.array(styles)\n", + "normalized_styles = (styles - np.min(styles, axis=1, keepdims=True)) / np.ptp(\n", + " styles, axis=1, keepdims=True\n", + ")\n", + "\n", + "# Plot the PCA again!\n", + "plt.figure(figsize=(10, 10))\n", + "for i in range(4):\n", + " plt.scatter(\n", + " styles_pca[np.array(labels) == i, 0],\n", + " styles_pca[np.array(labels) == i, 1],\n", + " c=normalized_styles[np.array(labels) == i],\n", + " marker=markers[i],\n", + " label=f\"Class {i}\",\n", + " )\n", + "plt.legend()\n", + "plt.show()" ] }, { "cell_type": "markdown", - "id": "1ca835c5", - "metadata": {}, - "source": [ - "HELP!!! Interactive (still!) doesn't work. No worries... uncomment the following cell and choose your index and threshold by typing them out." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "771fb28f", + "id": "6f1d3ff3", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] + "lines_to_next_cell": 0 }, - "outputs": [], "source": [ - "# Choose your own adventure\n", - "# idx = 0\n", - "# threshold = 0.1\n", - "\n", - "# # Plotting :)\n", - "# visualize_counterfactuals(idx, threshold)" + "

Questions

\n", + "
    \n", + "
  • Do the colors match those that you have seen in the data?
  • \n", + "
  • Can you see any patterns in the colors? Is the space smooth, for example?
  • \n", + "
" ] }, { "cell_type": "markdown", - "id": "3905e9a7", + "id": "90889399", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] + "lines_to_next_cell": 0 }, "source": [ - "
\n", - "

Questions

\n", - "\n", - "- Can you find features that define either of the two classes?\n", - "- How consistent are they across the samples?\n", - "- Is there a range of thresholds where most of the hybrids swap over to the target class? (If you want to see that area, try to change the range of thresholds in the slider by setting `threshold=(minimum_value, maximum_value, step_size)`\n", + "

Task 5.2: Using the images to color the style space

\n", + "Finally, let's just use the colors from the images themselves!\n", + "The maximum value in the image (since they are \"black-and-color\") can be used as a color!\n", "\n", - "Feel free to discuss your answers on the exercise chat!\n", + "Let's get that color, then plot the style space again.\n", + "(Note: once again, no coding needed here, just run the cell and think about the results with the questions below)\n", "
" ] }, { - "cell_type": "markdown", - "id": "578e5831", + "cell_type": "code", + "execution_count": null, + "id": "f67b3f90", + "metadata": {}, + "outputs": [], + "source": [ + "colors = np.array([np.max(x.numpy(), axis=(1, 2)) for x, _ in random_test_mnist])\n", + "\n", + "# Plot the PCA again!\n", + "plt.figure(figsize=(10, 10))\n", + "for i in range(4):\n", + " plt.scatter(\n", + " styles_pca[np.array(labels) == i, 0],\n", + " styles_pca[np.array(labels) == i, 1],\n", + " c=colors[np.array(labels) == i],\n", + " marker=markers[i],\n", + " label=f\"Class {i}\",\n", + " )\n", + "plt.legend()\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b18b2b81", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] + "lines_to_next_cell": 0 }, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "id": "bf87e80b", + "metadata": {}, "source": [ - "
\n", - "

The End.

\n", - " Go forth and train some GANs!\n", - "
" + "

Questions

\n", + "
    \n", + "
  • Do the colors match those that you have seen in the data?
  • \n", + "
  • Can you see any patterns in the colors?
  • \n", + "
  • Can you guess what the classes correspond to?
  • " ] }, { "cell_type": "markdown", - "id": "2f8cb30e", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, + "id": "11aafcc5", + "metadata": {}, "source": [ - "## Going Further\n", + "

    Checkpoint 5

    \n", + "Congratulations! You have made it to the end of the exercise!\n", + "You have:\n", + "- Created a StarGAN that can change the class of an image\n", + "- Evaluated the StarGAN on unseen data\n", + "- Used the StarGAN to create counterfactual images\n", + "- Used the counterfactual images to highlight the differences between classes\n", + "- Used the style space to understand the differences between classes\n", "\n", - "Here are some ideas for how to continue with this notebook:\n", - "\n", - "1. Improve the classifier. This code uses a VGG network for the classification. On the synapse dataset, we will get a validation accuracy of around 80%. Try to see if you can improve the classifier accuracy.\n", - " * (easy) Data augmentation: The training code for the classifier is quite simple in this example. Enlarge the amount of available training data by adding augmentations (transpose and mirror the images, add noise, change the intensity, etc.).\n", - " * (easy) Network architecture: The VGG network has a few parameters that one can tune. Try a few to see what difference it makes.\n", - " * (easy) Inspect the classifier predictions: Take random samples from the test dataset and classify them. Show the images together with their predicted and actual labels.\n", - " * (medium) Other networks: Try different architectures (e.g., a [ResNet](https://blog.paperspace.com/writing-resnet-from-scratch-in-pytorch/#resnet-from-scratch)) and see if the accuracy can be improved.\n", - "\n", - "2. Explore the CycleGAN.\n", - " * (easy) The example code below shows how to translate between GABA and acetylcholine. Try different combinations. Can you start to see differences between some pairs of classes? Which are the ones where the differences are the most or the least obvious? Can you see any differences that aren't well described by the mask? How would you describe these?\n", - "\n", - "3. Try on your own data!\n", - " * Have a look at how the synapse images are organized in `data/raw/synapses`. Copy the directory structure and use your own images. Depending on your data, you might have to adjust the image size (128x128 for the synapses) and number of channels in the VGG network and CycleGAN code." + "If you have any questions, feel free to ask them in the chat!\n", + "And check the Solutions exercise for a definite answer to how these classes are defined!" ] } ], "metadata": { "jupytext": { - "cell_metadata_filter": "all" - }, - "kernelspec": { - "display_name": "09_knowledge_extraction", - "language": "python", - "name": "python3" + "cell_metadata_filter": "all", + "main_language": "python" } }, "nbformat": 4, diff --git a/extras/train_classifier.py b/extras/train_classifier.py new file mode 100644 index 0000000..fcac3b8 --- /dev/null +++ b/extras/train_classifier.py @@ -0,0 +1,47 @@ +""" +This script was used to train the pre-trained model weights that were given as an option during the exercise. +""" + +from classifier.model import DenseModel +from classifier.data import ColoredMNIST +import torch +from torch.utils.data import DataLoader +from tqdm import tqdm +from pathlib import Path + + +def train_classifier(base_dir, epochs=10): + checkpoint_dir = Path(base_dir) / "checkpoints" + checkpoint_dir.mkdir(exist_ok=True) + data_dir = Path(base_dir) / "data" + data_dir.mkdir(exist_ok=True) + # + model = DenseModel((28, 28, 3), 4) + data = ColoredMNIST(data_dir, download=True, train=True) + dataloader = DataLoader(data, batch_size=32, shuffle=True, pin_memory=True) + + optimizer = torch.optim.Adam(model.parameters(), lr=1e-4) + loss_fn = torch.nn.CrossEntropyLoss() + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model.to(device) + + losses = [] + for epoch in range(epochs): + for x, y in tqdm(dataloader, desc=f"Epoch {epoch}"): + optimizer.zero_grad() + y_pred = model(x.to(device)) + loss = loss_fn(y_pred, y.to(device)) + loss.backward() + optimizer.step() + print(f"Epoch {epoch}: Loss = {loss.item()}") + losses.append(loss.item()) + # TODO save every epoch instead of overwriting? + torch.save(model.state_dict(), checkpoint_dir / "model.pth") + + with open(checkpoint_dir / "losses.txt", "w") as f: + f.write("\n".join(str(l) for l in losses)) + + +if __name__ == "__main__": + this_dir = Path(__file__).parent + train_classifier(base_dir=this_dir, epochs=10) diff --git a/extras/train_gan.py b/extras/train_gan.py new file mode 100644 index 0000000..d0628a3 --- /dev/null +++ b/extras/train_gan.py @@ -0,0 +1,175 @@ +from dlmbl_unet import UNet +from classifier.model import DenseModel +from classifier.data import ColoredMNIST +import torch +from torch import nn +from torch.utils.data import DataLoader +from tqdm import tqdm +from copy import deepcopy +import json +from pathlib import Path + + +class Generator(nn.Module): + def __init__(self, generator, style_mapping): + super().__init__() + self.generator = generator + self.style_mapping = style_mapping + + def forward(self, x, y): + """ + x: torch.Tensor + The source image + y: torch.Tensor + The style image + """ + style = self.style_mapping(y) + # Concatenate the style vector with the input image + style = style.unsqueeze(-1).unsqueeze(-1) + style = style.expand(-1, -1, x.size(2), x.size(3)) + x = torch.cat([x, style], dim=1) + return self.generator(x) + + +def set_requires_grad(module, value=True): + """Sets `requires_grad` on a `module`'s parameters to `value`""" + for param in module.parameters(): + param.requires_grad = value + + +def exponential_moving_average(model, ema_model, beta=0.999): + """Update the EMA model's parameters with an exponential moving average""" + for param, ema_param in zip(model.parameters(), ema_model.parameters()): + ema_param.data.mul_(beta).add_((1 - beta) * param.data) + + +def copy_parameters(source_model, target_model): + """Copy the parameters of a model to another model""" + for param, target_param in zip( + source_model.parameters(), target_model.parameters() + ): + target_param.data.copy_(param.data) + + +if __name__ == "__main__": + save_dir = Path("checkpoints/stargan") + save_dir.mkdir(parents=True, exist_ok=True) + mnist = ColoredMNIST("../data", download=True, train=True) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + size_style = 8 + total_epochs = 14 + unet = UNet( + depth=2, + in_channels=3 + size_style, + out_channels=3, + final_activation=nn.Sigmoid(), + ) + discriminator = DenseModel(input_shape=(3, 28, 28), num_classes=4) + style_mapping = DenseModel(input_shape=(3, 28, 28), num_classes=size_style) + generator = Generator(unet, style_mapping=style_mapping) + generator_ema = Generator(deepcopy(unet), style_mapping=deepcopy(style_mapping)) + + # all models on the GPU + generator = generator.to(device) + generator_ema = generator_ema.to(device) + discriminator = discriminator.to(device) + + cycle_loss_fn = nn.L1Loss() + class_loss_fn = nn.CrossEntropyLoss() + + optimizer_d = torch.optim.Adam(discriminator.parameters(), lr=1e-6) + optimizer_g = torch.optim.Adam(generator.parameters(), lr=1e-4) + + dataloader = DataLoader( + mnist, batch_size=32, drop_last=True, shuffle=True + ) # We will use the same dataset as before + + # Load last existing checkpoint + epoch = 0 + checkpoints = sorted(save_dir.glob("checkpoint_*.pth")) + if len(checkpoints) > 0: + checkpoint = torch.load(checkpoints[-1]) + print(f"Resuming from checkpoint {checkpoints[-1]}") + unet.load_state_dict(checkpoint["unet"]) + discriminator.load_state_dict(checkpoint["discriminator"]) + style_mapping.load_state_dict(checkpoint["style_mapping"]) + optimizer_g.load_state_dict(checkpoint["optimizer_g"]) + optimizer_d.load_state_dict(checkpoint["optimizer_d"]) + epoch = ( + checkpoint["epoch"] + 1 + ) # Start from the next epoch since this checkpoint exists + + losses = {"cycle": [], "adv": [], "disc": []} + for epoch in range(epoch, total_epochs): + for x, y in tqdm(dataloader, desc=f"Epoch {epoch}"): + x = x.to(device) + y = y.to(device) + # get the target y by shuffling the classes + # get the style sources by random sampling + random_index = torch.randperm(len(y)) + x_style = x[random_index].clone() + y_target = y[random_index].clone() + + # Set training gradients correctly + set_requires_grad(generator, True) + set_requires_grad(discriminator, False) + optimizer_g.zero_grad() + # Get the fake image + x_fake = generator(x, x_style) + # Try to cycle back + x_cycled = generator(x_fake, x) + # Discriminate + discriminator_x_fake = discriminator(x_fake) + # Losses to train the generator + + # 1. make sure the image can be reconstructed + cycle_loss = cycle_loss_fn(x, x_cycled) + # 2. make sure the discriminator is fooled + adv_loss = class_loss_fn(discriminator_x_fake, y_target) + + # Optimize the generator + (cycle_loss + adv_loss).backward() + optimizer_g.step() + + # Set training gradients correctly + set_requires_grad(generator, False) + set_requires_grad(discriminator, True) + optimizer_d.zero_grad() + # Discriminate + discriminator_x = discriminator(x) + discriminator_x_fake = discriminator(x_fake.detach()) + # Losses to train the discriminator + # 1. make sure the discriminator can tell real is real + real_loss = class_loss_fn(discriminator_x, y) + # 2. make sure the discriminator can't tell fake is fake + fake_loss = -class_loss_fn(discriminator_x_fake, y_target) + # + disc_loss = (real_loss + fake_loss) * 0.5 + disc_loss.backward() + # Optimize the discriminator + optimizer_d.step() + + losses["cycle"].append(cycle_loss.item()) + losses["adv"].append(adv_loss.item()) + losses["disc"].append(disc_loss.item()) + + # EMA update + exponential_moving_average(generator, generator_ema) + # TODO add logging, add checkpointing + # Copy the EMA model's parameters to the generator + copy_parameters(generator_ema, generator) + # Store checkpoint + torch.save( + { + "unet": unet.state_dict(), + "discriminator": discriminator.state_dict(), + "style_mapping": style_mapping.state_dict(), + "optimizer_g": optimizer_g.state_dict(), + "optimizer_d": optimizer_d.state_dict(), + "epoch": epoch, + }, + save_dir / f"checkpoint_{epoch}.pth", + ) + # Store losses + with open(save_dir / "losses.json", "w") as f: + json.dump(losses, f) diff --git a/extras/validate_classifier.py b/extras/validate_classifier.py new file mode 100644 index 0000000..1f2cedb --- /dev/null +++ b/extras/validate_classifier.py @@ -0,0 +1,47 @@ +""" +This script was used to validate the pre-trained classifier. +""" + +from classifier.model import DenseModel +from classifier.data import ColoredMNIST +import numpy as np +import torch +from torch.utils.data import DataLoader +from tqdm import tqdm + + +def confusion_matrix(labels, predictions): + n_classes = len(set(labels)) + matrix = np.zeros((n_classes, n_classes)) + for label, pred in zip(labels, predictions): + matrix[label, pred] += 1 + return matrix + + +def validate_classifier(checkpoint_dir): + data = ColoredMNIST("../data", download=False, train=False) + dataloader = DataLoader( + data, batch_size=32, shuffle=False, pin_memory=True, drop_last=False + ) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model = DenseModel((28, 28, 3), 4) + model.to(device) + model.load_state_dict(torch.load(f"{checkpoint_dir}/model.pth", weights_only=True)) + + labels = [] + predictions = [] + for x, y in tqdm(dataloader, desc=f"Validation"): + pred = model(x.to(device)) + pred_y = torch.argmax(pred, dim=1) + labels.extend(y.numpy()) + predictions.extend(pred_y.cpu().numpy()) + + # Get confusion matrix + matrix = confusion_matrix(labels, predictions) + # Save matrix as text + np.savetxt(f"{checkpoint_dir}/confusion_matrix.txt", matrix, fmt="%d") + + +if __name__ == "__main__": + validate_classifier(checkpoint_dir="checkpoints") diff --git a/extras/validate_gan.py b/extras/validate_gan.py new file mode 100644 index 0000000..63ff64c --- /dev/null +++ b/extras/validate_gan.py @@ -0,0 +1,74 @@ +# %% +from dlmbl_unet import UNet +from classifier.model import DenseModel +from classifier.data import ColoredMNIST +import torch +from torch import nn +import json +from pathlib import Path +from matplotlib import pyplot as plt +import numpy as np +from train_gan import Generator + +# %% +with open("checkpoints/stargan/losses.json", "r") as f: + losses = json.load(f) + +for key, value in losses.items(): + plt.plot(value, label=key) +plt.legend() + +# %% Plotting an example +# Load the data +mnist = ColoredMNIST("../data", download=True, train=False) + +# %% +# Create the model +style_size = 8 +epoch = 14 +unet = UNet( + depth=2, in_channels=3 + style_size, out_channels=3, final_activation=nn.Sigmoid() +) +style_encoder = DenseModel(input_shape=(3, 28, 28), num_classes=style_size) +# Load model weights +weights = torch.load(f"checkpoints/stargan/checkpoint_{epoch}.pth") +unet.load_state_dict(weights["unet"]) +style_encoder.load_state_dict(weights["style_mapping"]) # Change this to style encoder +generator = Generator(unet, style_encoder) +# %% +# Load one image from the dataset +x, y = mnist[0] +# Load one image from each other class +results = {} +for i in range(len(mnist.classes)): + if i == y: + continue + index = np.where(mnist.conditions == i)[0][0] + style = mnist[index][0] + # Generate the images + generated = generator(x.unsqueeze(0), style.unsqueeze(0)) + results[i] = (style, generated) +# Plot the images +source_style = mnist.classes[y] + +fig, axes = plt.subplots(2, 4, figsize=(12, 3)) +for i, (style, generated) in results.items(): + axes[0, i].imshow(style.permute(1, 2, 0)) + axes[0, i].set_title(mnist.classes[i]) + axes[0, i].axis("off") + axes[1, i].imshow(generated[0].detach().permute(1, 2, 0)) + axes[1, i].set_title(f"{mnist.classes[i]}") + axes[1, i].axis("off") + +# Plot real +axes[1, y].imshow(x.permute(1, 2, 0)) +axes[1, y].set_title(source_style) +axes[1, y].axis("off") +axes[0, y].axis("off") + +# %% +# TODO get prototype images for each class +# TODO convert every image in the dataset + classify result +# TODO plot a confusion matrix + +# %% diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..b57e7e0 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,7 @@ +git+https://github.com/adjavon/classification.git +ipykernel +tqdm +captum +git+https://github.com/dlmbl/dlmbl-unet.git +scikit-learn +seaborn \ No newline at end of file diff --git a/setup.sh b/setup.sh index 12ce1a1..22009e7 100755 --- a/setup.sh +++ b/setup.sh @@ -1,23 +1,15 @@ #!/usr/bin/env -S bash -i echo "Creating conda environment" -mamba env create -f environment.yaml +conda create -n 08_knowledge_extraction -y python=3.11 +eval "$(conda shell.bash hook)" +conda activate 08_knowledge_extraction +# Check if the environment is activated +echo "Environment activated: $(which python)" -# get the CycleGAN code and dependencies -git clone https://github.com/funkey/neuromatch_xai -mv neuromatch_xai/cycle_gan . -rm -rf neuromatch_xai +conda install -y pytorch torchvision pytorch-cuda=12.1 -c pytorch -c nvidia +pip install -r requirements.txt -# Download checkpoints and data -wget 'https://dl-at-mbl-2023-data.s3.us-east-2.amazonaws.com/knowledge_extraction_resources.zip' -O resources.zip -# Unzip the checkpoints and data -unzip -o resources.zip data.zip -unzip -o resources.zip checkpoints.zip -unzip -o checkpoints.zip 'checkpoints/synapses/*' -unzip -o data.zip 'data/raw/synapses/*' -# make sure the order of classes matches the pretrained model -mv data/raw/synapses/gaba data/raw/synapses/0_gaba -mv data/raw/synapses/acetylcholine data/raw/synapses/1_acetylcholine -mv data/raw/synapses/glutamate data/raw/synapses/2_glutamate -mv data/raw/synapses/serotonin data/raw/synapses/3_serotonin -mv data/raw/synapses/octopamine data/raw/synapses/4_octopamine -mv data/raw/synapses/dopamine data/raw/synapses/5_dopamine +echo "Training classifier model" +python extras/train_classifier.py + +conda deactivate diff --git a/solution.ipynb b/solution.ipynb index f3f9237..b0b9e5a 100644 --- a/solution.ipynb +++ b/solution.ipynb @@ -2,871 +2,212 @@ "cells": [ { "cell_type": "markdown", - "id": "93f0f7f9", + "id": "30c11df5", "metadata": { - "editable": true, "lines_to_next_cell": 0, - "slideshow": { - "slide_type": "" - }, "tags": [] }, "source": [ - "# Exercise 9: Knowledge Extraction from a Convolutional Neural Network\n", + "# Exercise 8: Knowledge Extraction from a Pre-trained Neural Network\n", "\n", - "In the following exercise we will train a convolutional neural network to classify electron microscopy images of Drosophila synapses, based on which neurotransmitter they contain. We will then train a CycleGAN and use a method called Discriminative Attribution from Counterfactuals (DAC) to understand how the network performs its classification, effectively going back from prediction to image data.\n", + "The goal of this exercise is to learn how to probe what a pre-trained classifier has learned about the data it was trained on.\n", "\n", - "![synister.png](assets/synister.png)\n", + "We will be working with a simple example which is a fun derivation on the MNIST dataset that you will have seen in previous exercises in this course.\n", + "Unlike regular MNIST, our dataset is classified not by number, but by color!\n", "\n", - "### Acknowledgments\n", + "We will:\n", + "1. Load a pre-trained classifier and try applying conventional attribution methods\n", + "2. Train a GAN to create counterfactual images - translating images from one class to another\n", + "3. Evaluate the GAN - see how good it is at fooling the classifier\n", + "4. Create attributions from the counterfactual, and learn the differences between the classes.\n", "\n", - "This notebook was written by Jan Funke and modified by Tri Nguyen and Diane Adjavon, using code from Nils Eckstein and a modified version of the [CycleGAN](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix) implementation.\n" - ] - }, - { - "cell_type": "markdown", - "id": "9a25e710", - "metadata": {}, - "source": [ - "
    \n", - "Set your python kernel to 09_knowledge_extraction\n", - "
    " - ] - }, - { - "cell_type": "markdown", - "id": "f9b96c13", - "metadata": {}, - "source": [ - "

    Start here (AKA checkpoint 0)

    \n", + "If time permits, we will try to apply this all over again as a bonus exercise to a much more complex and more biologically relevant problem.\n", + "### Acknowledgments\n", "\n", - "
    " + "This notebook was written by Diane Adjavon, from a previous version written by Jan Funke and modified by Tri Nguyen, using code from Nils Eckstein.\n" ] }, { "cell_type": "markdown", - "id": "0c339e3d", + "id": "ec2899d4", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] + "lines_to_next_cell": 0 }, "source": [ - "# Part 1: Image Classification\n", - "\n", - "## Training an image classifier\n", - "In this section, we will implement and train a VGG classifier to classify images of synapses into one of six classes, corresponding to the neurotransmitter type that is released at the synapse: GABA, acethylcholine, glutamate, octopamine, serotonin, and dopamine." + "
    \n", + "Set your python kernel to 08_knowledge_extraction\n", + "
    " ] }, { "cell_type": "markdown", - "id": "7f524106", + "id": "2c084b97", "metadata": {}, "source": [ "\n", - "The data we use for this exercise is located in `data/raw/synapses`, where we have one subdirectory for each neurotransmitter type. Look at a few examples to familiarize yourself with the dataset. You will notice that the dataset is not balanced, i.e., we have much more examples of one class versus another one.\n", + "# Part 1: Setup\n", "\n", - "This class imbalance is problematic for training a classifier. Imagine that 99% of our images are of one class, then the classifier would do really well predicting this class all the time, without having learnt anything of substance. It is therefore important to balance the dataset, i.e., present the same number of images per class to the classifier during training.\n", - "\n", - "First, we split the available images into a train, validation, and test dataset with proportions of 0.7, 0.15, and 0.15, respectively. Each image should be returned as a 2D `numpy` array with float values between 0 and 1. The label for each image should be the name of the directory for this class (e.g., `0_gaba`).\n" + "In this part of the notebook, we will load the same dataset as in the previous exercise.\n", + "We will also learn to load one of our trained classifiers from a checkpoint." ] }, { "cell_type": "code", "execution_count": null, - "id": "dca1c9b7", + "id": "9d26a8bb", "metadata": { - "lines_to_next_cell": 2, - "tags": [] + "lines_to_next_cell": 0 }, "outputs": [], "source": [ - "from torch.utils.data import DataLoader, random_split\n", - "from torch.utils.data.sampler import WeightedRandomSampler\n", - "from torchvision.datasets import ImageFolder\n", - "from torchvision import transforms\n", - "import torch\n", - "import numpy as np\n", - "\n", - "transform = transforms.Compose(\n", - " [\n", - " transforms.Grayscale(),\n", - " transforms.ToTensor(),\n", - " transforms.Normalize((0.5,), (0.5,)),\n", - " ]\n", - ")\n", - "\n", - "# create a dataset for all images of all classes\n", - "full_dataset = ImageFolder(root=\"data/raw/synapses\", transform=transform)\n", - "\n", - "# Rename the classes\n", - "full_dataset.classes = [x.split(\"_\")[-1] for x in full_dataset.classes]\n", - "class_to_idx = {x.split(\"_\")[-1]: v for x, v in full_dataset.class_to_idx.items()}\n", - "full_dataset.class_to_idx = class_to_idx\n", + "# loading the data\n", + "from classifier.data import ColoredMNIST\n", "\n", - "# randomly split the dataset into train, validation, and test\n", - "num_images = len(full_dataset)\n", - "# ~70% for training\n", - "num_training = int(0.7 * num_images)\n", - "# ~15% for validation\n", - "num_validation = int(0.15 * num_images)\n", - "# ~15% for testing\n", - "num_test = num_images - (num_training + num_validation)\n", - "# split the data randomly (but with a fixed random seed)\n", - "train_dataset, validation_dataset, test_dataset = random_split(\n", - " full_dataset,\n", - " [num_training, num_validation, num_test],\n", - " generator=torch.Generator().manual_seed(23061912),\n", - ")" + "mnist = ColoredMNIST(\"extras/data\", download=True)" ] }, { "cell_type": "markdown", - "id": "2f4f148f", - "metadata": { - "lines_to_next_cell": 2 - }, - "source": [ - "### Creating a Balanced Dataloader\n", - "\n", - "Below define a `sampler` that samples images of classes with skewed probabilities to account for the different number of items in each class.\n", - "\n", - "The sampler\n", - "- Counts the number of samples in each class\n", - "- Gets the weight-per-label as an inverse of the frequency\n", - "- Get the weight-per-sample\n", - "- Create a `WeightedRandomSampler` based on these weights" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "faa2b411", + "id": "f8a5937c", "metadata": { - "tags": [] + "lines_to_next_cell": 0 }, - "outputs": [], "source": [ - "# compute class weights in training dataset for balanced sampling\n", - "def balanced_sampler(dataset):\n", - " # Get a list of targets from the dataset\n", - " if isinstance(dataset, torch.utils.data.Subset):\n", - " # A Subset is a specific type of dataset, which does not directly have access to the targets.\n", - " targets = torch.tensor(dataset.dataset.targets)[dataset.indices]\n", - " else:\n", - " targets = dataset.targets\n", - "\n", - " counts = torch.bincount(targets) # Count the number of samples for each class\n", - " label_weights = (\n", - " 1.0 / counts\n", - " ) # Get the weight-per-label as an inverse of the frequency\n", - " weights = label_weights[targets] # Get the weight-per-sample\n", - "\n", - " # Optional: Print the Counts and Weights to make sure lower frequency classes have higher weights\n", - " print(\"Number of images per class:\")\n", - " for c, n, w in zip(full_dataset.classes, counts, label_weights):\n", - " print(f\"\\t{c}:\\tn={n}\\tweight={w}\")\n", + "Some information about the dataset:\n", + "- The dataset is a colored version of the MNIST dataset.\n", + "- Instead of using the digits as classes, we use the colors.\n", + "- There are four classes - the goal of the exercise is to find out what these are.\n", "\n", - " sampler = WeightedRandomSampler(\n", - " weights, len(weights)\n", - " ) # Create a sampler based on these weights\n", - " return sampler\n", - "\n", - "\n", - "sampler = balanced_sampler(train_dataset)" - ] - }, - { - "cell_type": "markdown", - "id": "ceb0525e", - "metadata": {}, - "source": [ - "We make a `torch` `DataLoader` that takes our `sampler` to create batches of eight images and their corresponding labels.\n", - "Each image should be randomly and equally selected from the six available classes (i.e., for each image sample pick a random class, then pick a random image from this class)." + "Let's plot some examples" ] }, { "cell_type": "code", "execution_count": null, - "id": "a15b4bac", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "# this data loader will serve 8 images in a \"mini-batch\" at a time\n", - "dataloader = DataLoader(train_dataset, batch_size=8, drop_last=True, sampler=sampler)" - ] - }, - { - "cell_type": "markdown", - "id": "5892ab7f", + "id": "9c0ce960", "metadata": {}, - "source": [ - "The cell below visualizes a single, randomly chosen batch from the training data loader. Feel free to execute this cell multiple times to get a feeling for the dataset and that your sampler gives batches of evenly distributed synapse types." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "5aab255a", - "metadata": { - "lines_to_next_cell": 2, - "tags": [] - }, "outputs": [], "source": [ - "%matplotlib inline\n", - "from matplotlib import pyplot as plt\n", + "import matplotlib.pyplot as plt\n", "\n", - "\n", - "def show_batch(x, y):\n", - " fig, axs = plt.subplots(1, x.shape[0], figsize=(14, 14), sharey=True)\n", - " for i in range(x.shape[0]):\n", - " axs[i].imshow(np.squeeze(x[i]), cmap=\"gray\", vmin=-1, vmax=1)\n", - " axs[i].set_title(train_dataset.dataset.classes[y[i].item()])\n", - " axs[i].axis(\"off\")\n", - " plt.show()\n", - "\n", - "\n", - "# show a random batch from the data loader\n", - "# (run this cell repeatedly to see different batches)\n", - "for x, y in dataloader:\n", - " show_batch(x, y)\n", - " break" + "# Show some examples\n", + "fig, axs = plt.subplots(4, 4, figsize=(8, 8))\n", + "for i, ax in enumerate(axs.flatten()):\n", + " x, y = mnist[i]\n", + " x = x.permute((1, 2, 0)) # make channels last\n", + " ax.imshow(x)\n", + " ax.set_title(f\"Class {y}\")\n", + " ax.axis(\"off\")" ] }, { "cell_type": "markdown", - "id": "025648fb", + "id": "0cb834e5", "metadata": { - "lines_to_next_cell": 2 + "lines_to_next_cell": 0 }, "source": [ - "### Creating a VGG Network, Loss\n", - "\n", - "We will use a VGG network to classify the synapse images. The input to the network will be a 2D image as provided by your dataloader. The output will be a vector of six floats, corresponding to the probability of the input to belong to the six classes.\n", - "\n", - "We have implemented a VGG network below.\n", - "
    " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "e7e2b968", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "class Vgg2D(torch.nn.Module):\n", - " def __init__(\n", - " self,\n", - " input_size,\n", - " fmaps=12,\n", - " downsample_factors=[(2, 2), (2, 2), (2, 2), (2, 2)],\n", - " output_classes=6,\n", - " ):\n", - " super(Vgg2D, self).__init__()\n", - "\n", - " self.input_size = input_size\n", - "\n", - " current_fmaps, h, w = tuple(input_size)\n", - " current_size = (h, w)\n", - "\n", - " features = []\n", - " for i in range(len(downsample_factors)):\n", - " features += [\n", - " torch.nn.Conv2d(current_fmaps, fmaps, kernel_size=3, padding=1),\n", - " torch.nn.BatchNorm2d(fmaps),\n", - " torch.nn.ReLU(inplace=True),\n", - " torch.nn.Conv2d(fmaps, fmaps, kernel_size=3, padding=1),\n", - " torch.nn.BatchNorm2d(fmaps),\n", - " torch.nn.ReLU(inplace=True),\n", - " torch.nn.MaxPool2d(downsample_factors[i]),\n", - " ]\n", - "\n", - " current_fmaps = fmaps\n", - " fmaps *= 2\n", - "\n", - " size = tuple(\n", - " int(c / d) for c, d in zip(current_size, downsample_factors[i])\n", - " )\n", - " check = (\n", - " s * d == c for s, d, c in zip(size, downsample_factors[i], current_size)\n", - " )\n", - " assert all(check), \"Can not downsample %s by chosen downsample factor\" % (\n", - " current_size,\n", - " )\n", - " current_size = size\n", - "\n", - " self.features = torch.nn.Sequential(*features)\n", - "\n", - " classifier = [\n", - " torch.nn.Linear(current_size[0] * current_size[1] * current_fmaps, 4096),\n", - " torch.nn.ReLU(inplace=True),\n", - " torch.nn.Dropout(),\n", - " torch.nn.Linear(4096, 4096),\n", - " torch.nn.ReLU(inplace=True),\n", - " torch.nn.Dropout(),\n", - " torch.nn.Linear(4096, output_classes),\n", - " ]\n", - "\n", - " self.classifier = torch.nn.Sequential(*classifier)\n", - "\n", - " def forward(self, raw):\n", - " # compute features\n", - " f = self.features(raw)\n", - " f = f.view(f.size(0), -1)\n", - "\n", - " # classify\n", - " y = self.classifier(f)\n", - "\n", - " return y" - ] - }, - { - "cell_type": "markdown", - "id": "c544bd0d", - "metadata": {}, - "source": [ - "We'll start by creating the VGG with the default parameters and push it to a GPU if there is one available. Then we'll define the training loss and optimizer.\n", - "The training and evaluation loops have been defined for you, so after that just train your network!" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "4c6fca99", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "# get the size of our images\n", - "for x, y in train_dataset:\n", - " input_size = x.shape\n", - " break\n", - "\n", - "# create the model to train\n", - "model = Vgg2D(input_size)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "4929dd7f", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "# use a GPU, if it is available\n", - "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", - "model.to(device)\n", - "print(f\"Will use device {device} for training\")" + "We have pre-traiend a classifier for you on this dataset. It is the same architecture classifier as you used in the Failure Modes exercise: a `DenseModel`.\n", + "Let's load that classifier now!" ] }, { "cell_type": "markdown", - "id": "73e2d8ad", - "metadata": {}, - "source": [ - "

    Task 1.1: Train the VGG Network

    \n", - "\n", - "- Choose a loss\n", - "- Create an Adam optimizer and set its learning rate\n", - "
    " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "4c29af1d", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "loss = ...\n", - "optimizer = ..." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "a3fe5b41", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [ - "solution" - ] - }, - "outputs": [], - "source": [ - "############################\n", - "# Solution to Task 1.3 #\n", - "############################\n", - "loss = torch.nn.CrossEntropyLoss()\n", - "optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)" - ] - }, - { - "cell_type": "markdown", - "id": "6fb96afe", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, - "source": [ - "The next cell defines some convenience functions for training, validation, and testing:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "c1f21c05", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, - "outputs": [], - "source": [ - "from tqdm import tqdm\n", - "\n", - "\n", - "def train(dataloader):\n", - " \"\"\"Train the model for one epoch.\"\"\"\n", - "\n", - " # set the model into train mode\n", - " model.train()\n", - "\n", - " epoch_loss = 0\n", - "\n", - " num_batches = 0\n", - " for x, y in tqdm(dataloader, \"train\"):\n", - " x, y = x.to(device), y.to(device)\n", - " optimizer.zero_grad()\n", - "\n", - " y_pred = model(x)\n", - " l = loss(y_pred, y)\n", - " l.backward()\n", - "\n", - " optimizer.step()\n", - "\n", - " epoch_loss += l\n", - " num_batches += 1\n", - "\n", - " return epoch_loss / num_batches" - ] - }, - { - "cell_type": "markdown", - "id": "9c473df0", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, - "source": [ - "

    Task 1.2: Create a prediction function

    \n", - "\n", - "To understand the performance of the classifier, we need to run predictions on the validation dataset so that we can get accuracy during training, and eventually a confusiom natrix. In practice, this will allow us to stop before we overfit, although in this exercise we will probably not be training that long. Then, later, we can use the same prediction function on test data.\n", - "\n", - "\n", - "TODO\n", - "Modify `predict` so that it returns a paired list of predicted class vs ground truth to produce a confusion matrix. You'll need to do the following steps.\n", - "- Get the model output for the batch of data `(x, y)`\n", - "- Turn the model output into a probability\n", - "- Get the class predictions from the probabilities\n", - "- Add the class predictions to a list of all predictions\n", - "- Add the ground truths to a list of all ground truths\n", - "\n", - "
    \n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "cae63f62", + "id": "a32035d7", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] + "lines_to_next_cell": 0 }, - "outputs": [], "source": [ - "# TODO: return a paired list of predicted class vs ground-truth to produce a confusion matrix\n", - "from tqdm import tqdm\n", - "from sklearn.metrics import accuracy_score\n", + "

    Task 1.1: Load the classifier

    \n", + "We have written a slightly more general version of the `DenseModel` that you used in the previous exercise. Ours requires two inputs:\n", + "- `input_shape`: the shape of the input images, as a tuple\n", + "- `num_classes`: the number of classes in the dataset\n", "\n", - "\n", - "def predict(dataset, name, batch_size=32):\n", - " # These data laoders serve images in a \"mini-batch\"\n", - " dataloader = DataLoader(dataset, batch_size=batch_size, drop_last=False)\n", - " #\n", - " ground_truths = []\n", - " predictions = []\n", - " for x, y in tqdm(dataloader, name):\n", - " x, y = x.to(device), y.to(device)\n", - "\n", - " # Get the model output\n", - " # Turn the model output into a probability\n", - " # Get the class predictions from the probabilities\n", - "\n", - " predictions.extend(...) # TODO add predictions to the list\n", - " ground_truths.extend(...) # TODO add ground truths to the list\n", - " return np.array(predictions), np.array(ground_truths)\n", - "\n", - "\n", - "prediction, ground_truth = predict(test_dataset, \"Test\")\n", - "print(\"Current test accuracy of the network\", accuracy_score(ground_truth, prediction))" + "Create a dense model with the right inputs and load the weights from the checkpoint.\n", + "
    " ] }, { "cell_type": "code", "execution_count": null, - "id": "3f9d4714", + "id": "0146821b", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, "tags": [ "solution" ] }, "outputs": [], "source": [ - "#########################\n", - "# Solution for Task 1.4 #\n", - "#########################\n", - "\n", - "from tqdm import tqdm\n", - "from sklearn.metrics import accuracy_score\n", - "\n", - "\n", - "def predict(dataset, name, batch_size=32):\n", - " # These data laoders serve images in a \"mini-batch\"\n", - " dataloader = DataLoader(dataset, batch_size=batch_size, drop_last=False)\n", - "\n", - " ground_truths = []\n", - " predictions = []\n", - " for x, y in tqdm(dataloader, name):\n", - " x, y = x.to(device), y.to(device)\n", - "\n", - " # Get the model output\n", - " logits = model(x)\n", - " # Turn the model output into a probability\n", - " probs = torch.nn.Softmax(dim=1)(logits)\n", - " # Get the class predictions from the probabilities\n", - " batch_predictions = torch.argmax(probs, dim=1)\n", - "\n", - " # append predictions and groundtruth to our big list,\n", - " # converting `tensor` objects to simple values through .item()\n", - " predictions.extend(batch_predictions.cpu().numpy())\n", - " ground_truths.extend(y.cpu().numpy())\n", + "import torch\n", + "from classifier.model import DenseModel\n", "\n", - " return np.array(predictions), np.array(ground_truths)\n", + "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "\n", "\n", - "prediction, ground_truth = predict(test_dataset, \"Test\")\n", - "print(\"Current test accuracy of the network\", accuracy_score(ground_truth, prediction))" + "# Load the model\n", + "model = DenseModel(input_shape=(3, 28, 28), num_classes=4)\n", + "# Load the checkpoint\n", + "checkpoint = torch.load(\"extras/checkpoints/model.pth\")\n", + "model.load_state_dict(checkpoint)\n", + "model = model.to(device)" ] }, { "cell_type": "markdown", - "id": "bfee4910", + "id": "6ecddeb8", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] + "lines_to_next_cell": 0 }, "source": [ - "We are ready to train. After each epoch (roughly going through each training image once), we report the training loss and the validation accuracy." + "Don't take my word for it! Let's see how well the classifier does on the test set." ] }, { "cell_type": "code", "execution_count": null, - "id": "41bc31bd", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, - "outputs": [], - "source": [ - "for epoch in range(3):\n", - " epoch_loss = train(dataloader)\n", - " print(f\"Epoch {epoch}, training loss={epoch_loss}\")\n", - "\n", - " predictions, gt = predict(validation_dataset, \"Validation\")\n", - " accuracy = accuracy_score(gt, predictions)\n", - " print(f\"Epoch {epoch}, validation accuracy={accuracy}\")" - ] - }, - { - "cell_type": "markdown", - "id": "cc91973f", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, - "source": [ - "Let's watch your model train!\n", - "\n", - "\"drawing\"" - ] - }, - { - "cell_type": "markdown", - "id": "7324a440", + "id": "c271ecd9", "metadata": {}, - "source": [ - "And now, let's test it!" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "ef0770ee", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, - "outputs": [], - "source": [ - "predictions, ground_truths = predict(test_dataset, \"Test\")\n", - "accuracy = accuracy_score(ground_truths, predictions)\n", - "print(f\"Final test accuracy: {accuracy}\")" - ] - }, - { - "cell_type": "markdown", - "id": "57241755", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, - "source": [ - "If you're unhappy with the accuracy above (which you should be...) we pre-trained a model for you for many more epochs. You can load it with the next cell." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "953cad3a", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, "outputs": [], "source": [ - "# TODO Run this cell if you want a shortcut\n", - "yes_I_want_the_pretrained_model = True\n", - "\n", - "if yes_I_want_the_pretrained_model:\n", - " checkpoint = torch.load(\n", - " \"checkpoints/synapses/classifier/vgg_checkpoint\", map_location=device\n", - " )\n", - " model.load_state_dict(checkpoint[\"model_state_dict\"])\n", - "\n", - "\n", - "# And check the (hopefully much better) accuracy\n", - "predictions, ground_truths = predict(test_dataset, \"Test\")\n", - "accuracy = accuracy_score(ground_truths, predictions)\n", - "print(f\"Final_final_v2_last_one test accuracy: {accuracy}\")" - ] - }, - { - "cell_type": "markdown", - "id": "45d26644", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, - "source": [ - "### Constructing a confusion matrix\n", - "\n", - "We now have a classifier that can discriminate between images of different types. If you used the images we provided, the classifier is not perfect (you should get an accuracy of around 80%), but pretty good considering that there are six different types of images.\n", - "\n", - "To understand the performance of the classifier beyond a single accuracy number, we should build a confusion matrix that can more elucidate which classes are more/less misclassified and which classes are those wrong predictions confused with.\n", - "
    \n" - ] - }, - { - "cell_type": "markdown", - "id": "39ae027f", - "metadata": {}, - "source": [ - "Let's plot the confusion matrix." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "bc315793", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, - "outputs": [], - "source": [ - "import pandas as pd\n", + "from torch.utils.data import DataLoader\n", "from sklearn.metrics import confusion_matrix\n", "import seaborn as sns\n", - "import numpy as np\n", - "\n", - "\n", - "# Plot confusion matrix\n", - "# orginally from Runqi Yang;\n", - "# see https://gist.github.com/hitvoice/36cf44689065ca9b927431546381a3f7\n", - "def cm_analysis(y_true, y_pred, names, labels=None, title=None, figsize=(10, 8)):\n", - " \"\"\"\n", - " Generate matrix plot of confusion matrix with pretty annotations.\n", - "\n", - " Parameters\n", - " ----------\n", - " confusion_matrix: np.array\n", - " labels: list\n", - " List of integer values to determine which classes to consider.\n", - " names: string array, name the order of class labels in the confusion matrix.\n", - " use `clf.classes_` if using scikit-learn models.\n", - " with shape (nclass,).\n", - " ymap: dict: any -> string, length == nclass.\n", - " if not None, map the labels & ys to more understandable strings.\n", - " Caution: original y_true, y_pred and labels must align.\n", - " figsize: the size of the figure plotted.\n", - " \"\"\"\n", - " if labels is not None:\n", - " assert len(names) == len(labels)\n", - " cm = confusion_matrix(y_true, y_pred, labels=labels)\n", - " cm_sum = np.sum(cm, axis=1, keepdims=True)\n", - " cm_perc = cm / cm_sum.astype(float) * 100\n", - " annot = np.empty_like(cm).astype(str)\n", - " nrows, ncols = cm.shape\n", - " for i in range(nrows):\n", - " for j in range(ncols):\n", - " c = cm[i, j]\n", - " p = cm_perc[i, j]\n", - " if i == j:\n", - " s = cm_sum[i]\n", - " annot[i, j] = \"%.1f%%\\n%d/%d\" % (p, c, s)\n", - " elif c == 0:\n", - " annot[i, j] = \"\"\n", - " else:\n", - " annot[i, j] = \"%.1f%%\\n%d\" % (p, c)\n", - " fig, ax = plt.subplots(figsize=figsize)\n", - " ax = sns.heatmap(\n", - " cm_perc, annot=annot, fmt=\"\", vmax=100, xticklabels=names, yticklabels=names\n", - " )\n", - " ax.set_xlabel(\"Predicted\")\n", - " ax.set_ylabel(\"True\")\n", - " if title:\n", - " ax.set_title(title)\n", - "\n", - "\n", - "names = [\"gaba\", \"acetylcholine\", \"glutamate\", \"serotonine\", \"octopamine\", \"dopamine\"]\n", - "cm_analysis(predictions, ground_truths, names=names)" - ] - }, - { - "cell_type": "markdown", - "id": "3c8cf7bb", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, - "source": [ - "
    \n", - "

    Questions

    \n", - "\n", - "- What observations can we make from the confusion matrix?\n", - "- Does the classifier do better on some synapse classes than other?\n", - "- If you have time later, which ideas would you try to train a better predictor?\n", - "\n", - "Let us know your thoughts on the course chat.\n", - "
    " - ] - }, - { - "cell_type": "markdown", - "id": "ce4ccb36", - "metadata": {}, - "source": [ - "

    Checkpoint 1

    \n", - "\n", - "We now have:\n", - "- A classifier that is pretty good at predicting neurotransmitters from EM images.\n", "\n", - "This is surprising, since we could not (yet) have made these predictions manually! If you're skeptical, feel free to explore the data a bit more and see for yourself if you can tell the difference betwee, say, GABAergic and glutamatergic synapses.\n", + "test_mnist = ColoredMNIST(\"extras/data\", download=True, train=False)\n", + "dataloader = DataLoader(test_mnist, batch_size=32, shuffle=False)\n", "\n", - "So this is an interesting situation: The VGG network knows something we don't quite know. In the next section, we will see how we can find and then visualize the relevant differences between images of different types.\n", + "labels = []\n", + "predictions = []\n", + "for x, y in dataloader:\n", + " pred = model(x.to(device))\n", + " labels.extend(y.cpu().numpy())\n", + " predictions.extend(pred.argmax(dim=1).cpu().numpy())\n", "\n", - "This concludes the first section. Let us know on the exercise chat if you have arrived here.\n", - "
    " + "cm = confusion_matrix(labels, predictions, normalize=\"true\")\n", + "sns.heatmap(cm, annot=True, fmt=\".2f\")\n", + "plt.ylabel(\"True\")\n", + "plt.xlabel(\"Predicted\")\n", + "plt.show()" ] }, { "cell_type": "markdown", - "id": "be1f14b2", + "id": "46a684f4", "metadata": {}, "source": [ - "# Part 2: Masking the relevant part of the image\n", + "# Part 2: Using Integrated Gradients to find what the classifier knows\n", "\n", - "In this section we will make a first attempt at highlight differences between the \"real\" and \"fake\" images that are most important to change the decision of the classifier.\n" + "In this section we will make a first attempt at highlighting differences between the \"real\" and \"fake\" images that are most important to change the decision of the classifier.\n" ] }, { "cell_type": "markdown", - "id": "41464574", + "id": "0255c073", "metadata": {}, "source": [ "## Attributions through integrated gradients\n", "\n", - "Attribution is the process of finding out, based on the output of a neural network, which pixels in the input are (most) responsible. Another way of thinking about it is: which pixels would need to change in order for the network's output to change.\n", + "Attribution is the process of finding out, based on the output of a neural network, which pixels in the input are (most) responsible for the output. Another way of thinking about it is: which pixels would need to change in order for the network's output to change.\n", "\n", "Here we will look at an example of an attribution method called [Integrated Gradients](https://captum.ai/docs/extension/integrated_gradients). If you have a bit of time, have a look at this [super fun exploration of attribution methods](https://distill.pub/2020/attribution-baselines/), especially the explanations on Integrated Gradients." ] @@ -874,29 +215,26 @@ { "cell_type": "code", "execution_count": null, - "id": "af08ae72", + "id": "e5b162b7", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, "tags": [] }, "outputs": [], "source": [ - "x, y = next(iter(dataloader))\n", + "batch_size = 4\n", + "batch = []\n", + "for i in range(4):\n", + " batch.append(next(image for image in mnist if image[1] == i))\n", + "x = torch.stack([b[0] for b in batch])\n", + "y = torch.tensor([b[1] for b in batch])\n", "x = x.to(device)\n", "y = y.to(device)" ] }, { "cell_type": "markdown", - "id": "9fbf1572", + "id": "6d418ea1", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, "tags": [] }, "source": [ @@ -911,35 +249,8 @@ { "cell_type": "code", "execution_count": null, - "id": "897dd327", + "id": "f93e8067", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, - "outputs": [], - "source": [ - "from captum.attr import IntegratedGradients\n", - "\n", - "############### Task 2.1 TODO ############\n", - "# Create an integrated gradients object.\n", - "integrated_gradients = ...\n", - "\n", - "# Generated attributions on integrated gradients\n", - "attributions = ..." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "27a769fd", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, "tags": [ "solution" ] @@ -962,243 +273,143 @@ { "cell_type": "code", "execution_count": null, - "id": "31fa10dc", + "id": "e4ba6b3a", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, "tags": [] }, "outputs": [], "source": [ "attributions = (\n", - " attributions.cpu().numpy()\n", - ") # Move the attributions from the GPU to the CPU, and turn then into numpy arrays for future processing" - ] - }, - { - "cell_type": "markdown", - "id": "657bf893", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, - "source": [ - "Here is an example for an image, and its corresponding attribution." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "7c4faa92", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, - "outputs": [], - "source": [ - "from captum.attr import visualization as viz\n", - "\n", - "\n", - "def unnormalize(image):\n", - " return 0.5 * image + 0.5\n", - "\n", - "\n", - "def visualize_attribution(attribution, original_image):\n", - " attribution = np.transpose(attribution, (1, 2, 0))\n", - " original_image = np.transpose(unnormalize(original_image), (1, 2, 0))\n", - "\n", - " viz.visualize_image_attr_multiple(\n", - " attribution,\n", - " original_image,\n", - " methods=[\"blended_heat_map\", \"heat_map\"],\n", - " signs=[\"absolute_value\", \"absolute_value\"],\n", - " show_colorbar=True,\n", - " titles=[\"Original and Attribution\", \"Attribution\"],\n", - " use_pyplot=True,\n", - " )" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "4d050712", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, - "outputs": [], - "source": [ - "for attr, im in zip(attributions, x.cpu().numpy()):\n", - " visualize_attribution(attr, im)" + " attributions.cpu().numpy()\n", + ") # Move the attributions from the GPU to the CPU, and turn then into numpy arrays for future processing" ] }, { "cell_type": "markdown", - "id": "2bd418b1", + "id": "56e432ae", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, + "lines_to_next_cell": 2, "tags": [] }, "source": [ - "### Smoothing the attribution into a mask\n", - "\n", - "The attributions that we see are grainy and difficult to interpret because they are a pixel-wise attribution. We apply some smoothing and thresholding on the attributions so that they represent region masks rather than pixel masks. The following code is runnable with no modification." + "Here is an example for an image, and its corresponding attribution." ] }, { "cell_type": "code", "execution_count": null, - "id": "55715f0e", + "id": "9561d46f", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, "tags": [] }, "outputs": [], "source": [ - "import cv2\n", - "import copy\n", - "\n", - "\n", - "def smooth_attribution(attrs, struc=10, sigma=11):\n", - " # Morphological closing and Gaussian Blur\n", - " kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (struc, struc))\n", - " mask = cv2.morphologyEx(attrs[0], cv2.MORPH_CLOSE, kernel)\n", - " mask_cp = copy.deepcopy(mask)\n", - " mask_weight = cv2.GaussianBlur(mask_cp.astype(float), (sigma, sigma), 0)\n", - " return mask_weight[np.newaxis]\n", - "\n", + "from captum.attr import visualization as viz\n", + "import numpy as np\n", "\n", - "def get_mask(attrs, threshold=0.5):\n", - " smoothed = smooth_attribution(attrs)\n", - " return smoothed > (threshold * smoothed.max())\n", "\n", + "def visualize_attribution(attribution, original_image):\n", + " attribution = np.transpose(attribution, (1, 2, 0))\n", + " original_image = np.transpose(original_image, (1, 2, 0))\n", "\n", - "def interactive_attribution(idx=0):\n", - " image = x[idx].cpu().numpy()\n", - " attrs = attributions[idx]\n", - " mask = smooth_attribution(attrs)\n", - " visualize_attribution(mask, image)\n", - " return" + " viz.visualize_image_attr_multiple(\n", + " attribution,\n", + " original_image,\n", + " methods=[\"original_image\", \"heat_map\"],\n", + " signs=[\"all\", \"absolute_value\"],\n", + " show_colorbar=True,\n", + " titles=[\"Image\", \"Attribution\"],\n", + " use_pyplot=True,\n", + " )" ] }, { - "cell_type": "markdown", - "id": "33598839", + "cell_type": "code", + "execution_count": null, + "id": "a55fe8ec", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, "tags": [] }, + "outputs": [], "source": [ - "

    Task 2.2 Visualizing the results

    \n", - "\n", - "The code above creates a small widget to interact with the results of this analysis. Look through the samples for a while before answering the questions below.\n", - "
    " + "for attr, im, lbl in zip(attributions, x.cpu().numpy(), y.cpu().numpy()):\n", + " print(f\"Class {lbl}\")\n", + " visualize_attribution(attr, im)" ] }, { - "cell_type": "code", - "execution_count": null, - "id": "490db899", + "cell_type": "markdown", + "id": "1d8c03a0", "metadata": { - "editable": true, - "lines_to_next_cell": 2, - "slideshow": { - "slide_type": "" - }, - "tags": [] + "lines_to_next_cell": 2 }, - "outputs": [], "source": [ - "from ipywidgets import interact\n", "\n", - "interact(\n", - " interactive_attribution,\n", - " idx=(0, dataloader.batch_size - 1),\n", - ")" + "The attributions are shown as a heatmap. The brighter the pixel, the more important this attribution method thinks that it is.\n", + "As you can see, it is pretty good at recognizing the number within the image.\n", + "As we know, however, it is not the digit itself that is important for the classification, it is the color!\n", + "Although the method is picking up really well on the region of interest, it would be difficult to conclude from this that it is the color that matters." ] }, { "cell_type": "markdown", - "id": "18dce2c2", + "id": "2a24c70a", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] + "lines_to_next_cell": 0 }, "source": [ - "HELP! I Can't see any interactive setup!!\n", - "\n", - "I got you... just uncomment the next cell and run it to see all of the samples at once." + "Something is slightly unfair about this visualization though.\n", + "We are visualizing as if it were grayscale, but both our images and our attributions are in color!\n", + "Can we learn more from the attributions if we visualize them in color?" ] }, { "cell_type": "code", "execution_count": null, - "id": "eda303d1", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, + "id": "6e875faa", + "metadata": {}, "outputs": [], "source": [ - "# HELP! I Can't see any interative setup!!!\n", - "# for attr, im in zip(attributions, x.cpu().numpy()):\n", - "# visualize_attribution(smooth_attribution(attr), im)" + "def visualize_color_attribution(attribution, original_image):\n", + " attribution = np.transpose(attribution, (1, 2, 0))\n", + " original_image = np.transpose(original_image, (1, 2, 0))\n", + "\n", + " fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))\n", + " ax1.imshow(original_image)\n", + " ax1.set_title(\"Image\")\n", + " ax1.axis(\"off\")\n", + " ax2.imshow(np.abs(attribution))\n", + " ax2.set_title(\"Attribution\")\n", + " ax2.axis(\"off\")\n", + " plt.show()\n", + "\n", + "\n", + "for attr, im, lbl in zip(attributions, x.cpu().numpy(), y.cpu().numpy()):\n", + " print(f\"Class {lbl}\")\n", + " visualize_color_attribution(attr, im)" ] }, { "cell_type": "markdown", - "id": "09cc4c08", + "id": "3f73608f", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] + "lines_to_next_cell": 0 }, "source": [ - "
    \n", - "

    Questions

    \n", - "\n", - "- Are there some recognisable objects or parts of the synapse that show up in several examples?\n", - "- Are there some objects that seem secondary because they are less strongly highlighted?\n", + "We get some better clues when looking at the attributions in color.\n", + "The highlighting doesn't just happen in the region with number, but also seems to hapen in a channel that matches the color of the image.\n", + "Just based on this, however, we don't get much more information than we got from the images themselves.\n", "\n", - "Tell us what you see on the chat!\n", - "
    " + "If we didn't know in advance, it is unclear whether the color or the number is the most important feature for the classifier." ] }, { "cell_type": "markdown", - "id": "bd34722b", + "id": "a8e71c0b", "metadata": {}, "source": [ "\n", - "### Changing the basline\n", + "### Changing the baseline\n", "\n", "Many existing attribution algorithms are comparative: they show which pixels of the input are responsible for a network output *compared to a baseline*.\n", "The baseline is often set to an all 0 tensor, but the choice of the baseline affects the output.\n", @@ -1212,7 +423,7 @@ "```\n", "To get more details about how to include the baseline.\n", "\n", - "Try using the code above to change the baseline and see how this affects the output.\n", + "Try using the code below to change the baseline and see how this affects the output.\n", "\n", "1. Random noise as a baseline\n", "2. A blurred/noisy version of the original image as a baseline." @@ -1220,7 +431,7 @@ }, { "cell_type": "markdown", - "id": "53feb16f", + "id": "dbb04b6f", "metadata": {}, "source": [ "

    Task 2.3: Use random noise as a baseline

    \n", @@ -1232,35 +443,8 @@ { "cell_type": "code", "execution_count": null, - "id": "9d6c65e1", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, - "outputs": [], - "source": [ - "# Baseline\n", - "random_baselines = ... # TODO Change\n", - "# Generate the attributions\n", - "attributions_random = integrated_gradients.attribute(...) # TODO Change\n", - "\n", - "# Plotting\n", - "for attr, im in zip(attributions_random.cpu().numpy(), x.cpu().numpy()):\n", - " visualize_attribution(attr, im)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "f3f07eb8", + "id": "cde2c2ff", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, "tags": [ "solution" ] @@ -1278,18 +462,15 @@ ")\n", "\n", "# Plotting\n", - "for attr, im in zip(attributions_random.cpu().numpy(), x.cpu().numpy()):\n", - " visualize_attribution(attr, im)" + "for attr, im, lbl in zip(attributions, x.cpu().numpy(), y.cpu().numpy()):\n", + " print(f\"Class {lbl}\")\n", + " visualize_color_attribution(attr, im)" ] }, { "cell_type": "markdown", - "id": "e97700bc", + "id": "bf7e934c", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, "tags": [] }, "source": [ @@ -1302,37 +483,8 @@ { "cell_type": "code", "execution_count": null, - "id": "b9e5b23e", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, - "outputs": [], - "source": [ - "# TODO Import required function\n", - "\n", - "# Baseline\n", - "blurred_baselines = ... # TODO Create blurred version of the images\n", - "# Generate the attributions\n", - "attributions_blurred = integrated_gradients.attribute(...) # TODO Fill\n", - "\n", - "# Plotting\n", - "for attr, im in zip(attributions_blurred.cpu().numpy(), x.cpu().numpy()):\n", - " visualize_attribution(attr, im)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "0ba5b4ff", + "id": "a0cb195e", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, "tags": [ "solution" ] @@ -1352,38 +504,35 @@ ")\n", "\n", "# Plotting\n", - "for attr, im in zip(attributions_blurred.cpu().numpy(), x.cpu().numpy()):\n", - " visualize_attribution(attr, im)" + "for attr, im, lbl in zip(attributions, x.cpu().numpy(), y.cpu().numpy()):\n", + " print(f\"Class {lbl}\")\n", + " visualize_color_attribution(attr, im)" ] }, { "cell_type": "markdown", - "id": "5cdde305", + "id": "db46361b", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, "tags": [] }, "source": [ "

    Questions

    \n", - "\n", - "- Are any of the features consistent across baselines? Why do you think that is?\n", - "- What baseline do you like best so far? Why?\n", - "- If you were to design an ideal baseline, what would you choose?\n", + "
      \n", + "
    • What baseline do you like best so far? Why?
    • \n", + "
    • Why do you think some baselines work better than others?
    • \n", + "
    • If you were to design an ideal baseline, what would you choose?
    • \n", + "
    \n", "
    " ] }, { "cell_type": "markdown", - "id": "1a15cf83", + "id": "e9105812", "metadata": {}, "source": [ "

    BONUS Task: Using different attributions.

    \n", "\n", "\n", - "\n", "[`captum`](https://captum.ai/tutorials/Resnet_TorchVision_Interpret) has access to various different attribution algorithms.\n", "\n", "Replace `IntegratedGradients` with different attribution methods. Are they consistent with each other?\n", @@ -1392,7 +541,7 @@ }, { "cell_type": "markdown", - "id": "9bb8d816", + "id": "0b2d0f2f", "metadata": {}, "source": [ "

    Checkpoint 2

    \n", @@ -1400,28 +549,27 @@ "\n", "At this point we have:\n", "\n", - "- Trained a classifier that can predict neurotransmitters from EM-slices of synapses.\n", - "- Found a way to mask the parts of the image that seem to be relevant for the classification, using integrated gradients.\n", + "- Loaded a classifier that classifies MNIST-like images by color, but we don't know how!\n", + "- Tried applying Integrated Gradients to find out what the classifier is looking at - with little success.\n", "- Discovered the effect of changing the baseline on the output of integrated gradients.\n", "\n", + "Coming up in the next section, we will learn how to create counterfactual images.\n", + "These images will change *only what is necessary* in order to change the classification of the image.\n", + "We'll see that using counterfactuals we will be able to disambiguate between color and number as an important feature.\n", "
    " ] }, { "cell_type": "markdown", - "id": "a31ef8d6", + "id": "531169e5", "metadata": { - "editable": true, - "lines_to_next_cell": 2, - "slideshow": { - "slide_type": "" - }, - "tags": [] + "lines_to_next_cell": 0 }, "source": [ "# Part 3: Train a GAN to Translate Images\n", "\n", - "To gain insight into how the trained network classify images, we will use [Discriminative Attribution from Counterfactuals](https://arxiv.org/abs/2109.13412), a feature attribution with counterfactual explanations methodology. This method employs a CycleGAN to translate images from one class to another to make counterfactual explanations.\n", + "To gain insight into how the trained network classifies images, we will use [Discriminative Attribution from Counterfactuals](https://arxiv.org/abs/2109.13412), a feature attribution with counterfactual explanations methodology.\n", + "This method employs a StarGAN to translate images from one class to another to make counterfactual explanations.\n", "\n", "**What is a counterfactual?**\n", "\n", @@ -1436,1809 +584,1157 @@ "\n", "**Counterfactual synapses**\n", "\n", - "In this example, we will train a CycleGAN network that translates GABAergic synapses to acetylcholine synapses (you can also train other pairs too by changing the classes below)." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "9089850c", - "metadata": { - "editable": true, - "lines_to_next_cell": 2, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, - "outputs": [], - "source": [ - "def class_dir(name):\n", - " return f\"{class_to_idx[name]}_{name}\"\n", - "\n", - "\n", - "classes = [\"gaba\", \"acetylcholine\"]" + "In this example, we will train a StarGAN network that is able to take any of our special MNIST images and change its class." ] }, { "cell_type": "markdown", - "id": "36b89586", + "id": "331e56d6", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, + "lines_to_next_cell": 0, "tags": [] }, "source": [ - "## Training a GAN\n", - "\n", - "Yes, really!" - ] - }, - { - "cell_type": "markdown", - "id": "aff1b90b", - "metadata": { - "lines_to_next_cell": 2 - }, - "source": [ - "

    Creating a specialized dataset

    \n", - "\n", - "The CycleGAN works on only 2 classes at a time, but our full dataset has 6. Below, we will use the `Subset` dataset from `torch.utils.data` to get the data from these two classes.\n", - "\n", - "A `Subset` is created as follows:\n", - "```\n", - "subset = Subset(dataset, indices)\n", - "```\n", + "### The model\n", + "![stargan.png](assets/stargan.png)\n", "\n", - "And the chosen indices can be obtained again using `subset.indices`.\n", + "In the following, we create a [StarGAN model](https://arxiv.org/abs/1711.09020).\n", + "It is a Generative Adversarial model that is trained to turn one class of images X into a different class of images Y.\n", "\n", - "Run the cell below to generate the datasets:\n", - "- `gan_train_dataset`\n", - "- `gan_test_dataset`\n", - "- `gan_val_dataset`\n", + "We will not be using the random latent code (green, in the figure), so the model we use is made up of three networks:\n", + "- The generator - this will be the bulk of the model, and will be responsible for transforming the images: we're going to use a `UNet`\n", + "- The discriminator - this will be responsible for telling the difference between real and fake images: we're going to use a `DenseModel`\n", + "- The style encoder - this will be responsible for encoding the style of the image: we're going to use a `DenseModel`\n", "\n", - "We will use them below to train the CycleGAN.\n", - "
    " + "Let's start by creating these!" ] }, { "cell_type": "code", "execution_count": null, - "id": "a8981d1e", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, + "id": "301ee289", + "metadata": {}, "outputs": [], "source": [ - "# Utility functions to get a subset of classes\n", - "def get_indices(dataset, classes):\n", - " \"\"\"Get the indices of elements of classA and classB in the dataset.\"\"\"\n", - " indices = []\n", - " for cl in classes:\n", - " indices.append(torch.tensor(dataset.targets) == class_to_idx[cl])\n", - " logical_or = sum(indices) > 0\n", - " return torch.where(logical_or)[0]\n", - "\n", - "\n", - "def set_intersection(a_indices, b_indices):\n", - " \"\"\"Get intersection of two sets\n", - "\n", - " Parameters\n", - " ----------\n", - " a_indices: torch.Tensor\n", - " b_indices: torch.Tensor\n", - "\n", - " Returns\n", - " -------\n", - " intersection: torch.Tensor\n", - " The elements contained in both a_indices and b_indices.\n", - " \"\"\"\n", - " a_cat_b, counts = torch.cat([a_indices, b_indices]).unique(return_counts=True)\n", - " intersection = a_cat_b[torch.where(counts.gt(1))]\n", - " return intersection\n", - "\n", - "\n", - "# Getting training, testing, and validation indices\n", - "gan_idx = get_indices(full_dataset, classes)\n", + "from dlmbl_unet import UNet\n", + "from torch import nn\n", "\n", - "gan_train_idx = set_intersection(torch.tensor(train_dataset.indices), gan_idx)\n", - "gan_test_idx = set_intersection(torch.tensor(test_dataset.indices), gan_idx)\n", - "gan_val_idx = set_intersection(torch.tensor(validation_dataset.indices), gan_idx)\n", "\n", - "# Checking that the subsets are complete\n", - "assert len(gan_train_idx) + len(gan_test_idx) + len(gan_val_idx) == len(gan_idx)\n", + "class Generator(nn.Module):\n", "\n", - "# Generate three datasets based on the above indices.\n", - "from torch.utils.data import Subset\n", + " def __init__(self, generator, style_encoder):\n", + " super().__init__()\n", + " self.generator = generator\n", + " self.style_encoder = style_encoder\n", "\n", - "gan_train_dataset = Subset(full_dataset, gan_train_idx)\n", - "gan_test_dataset = Subset(full_dataset, gan_test_idx)\n", - "gan_val_dataset = Subset(full_dataset, gan_val_idx)" + " def forward(self, x, y):\n", + " \"\"\"\n", + " x: torch.Tensor\n", + " The source image\n", + " y: torch.Tensor\n", + " The style image\n", + " \"\"\"\n", + " style = self.style_encoder(y)\n", + " # Concatenate the style vector with the input image\n", + " style = style.unsqueeze(-1).unsqueeze(-1)\n", + " style = style.expand(-1, -1, x.size(2), x.size(3))\n", + " x = torch.cat([x, style], dim=1)\n", + " return self.generator(x)" ] }, { "cell_type": "markdown", - "id": "479b5de4", + "id": "4ce023f6", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] + "lines_to_next_cell": 0 }, "source": [ - "### The model\n", + "

    Task 3.1: Create the models

    \n", "\n", - "![cycle.png](assets/cyclegan.png)\n", + "We are going to create the models for the generator, discriminator, and style mapping.\n", "\n", - "In the following, we create a [CycleGAN model](https://arxiv.org/pdf/1703.10593.pdf). It is a Generative Adversarial model that is trained to turn one class of images X (for us, GABA) into a different class of images Y (for us, Acetylcholine).\n", - "\n", - "It has two generators:\n", - " - Generator G takes a GABA image and tries to turn it into an image of an Acetylcholine synapse. When given an image that is already showing an Acetylcholine synapse, G should just re-create the same image: these are the `identities`.\n", - " - Generator F takes a Acetylcholine image and tries to turn it into an image of an GABA synapse. When given an image that is already showing a GABA synapse, F should just re-create the same image: these are the `identities`.\n", - "\n", - "\n", - "When in training mode, the CycleGAN will also create a `reconstruction`. These are images that are passed through both generators.\n", - "For example, a GABA image will first be transformed by G to Acetylcholine, then F will turn it back into GABA.\n", - "This is achieved by training the network with a cycle-consistency loss. In our example, this is an L2 loss between the `real` GABA image and the `reconstruction` GABA image.\n", - "\n", - "But how do we force the generators to change the class of the input image? We use a discriminator for each.\n", - " - DX tries to recognize fake GABA images: F will need to create images realistic and GABAergic enough to trick it.\n", - " - DY tries to recognize fake Acetylcholine images: G will need to create images realistic and cholinergic enough to trick it." + "Given the Generator structure above, fill in the missing parts for the unet and the style mapping." ] }, { "cell_type": "code", "execution_count": null, - "id": "d308b66b", + "id": "b491022a", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] + "tags": [ + "solution" + ] }, "outputs": [], "source": [ - "from torch import nn\n", - "import functools\n", - "from cycle_gan.models.networks import ResnetGenerator, NLayerDiscriminator, GANLoss\n", - "\n", - "\n", - "class CycleGAN(nn.Module):\n", - " \"\"\"Cycle GAN\n", - "\n", - " Has:\n", - " - Two class names\n", - " - Two Generators\n", - " - Two Discriminators\n", - " \"\"\"\n", - "\n", - " def __init__(\n", - " self, class1, class2, input_nc=1, output_nc=1, ngf=64, ndf=64, use_dropout=False\n", - " ):\n", - " \"\"\"\n", - " class1: str\n", - " Label of the first class\n", - " class2: str\n", - " Label of the second class\n", - " \"\"\"\n", - " super().__init__()\n", - " norm_layer = functools.partial(\n", - " nn.InstanceNorm2d, affine=False, track_running_stats=False\n", - " )\n", - " self.classes = [class1, class2]\n", - " self.inverse_keys = {\n", - " class1: class2,\n", - " class2: class1,\n", - " } # i.e. what is the other key?\n", - " self.generators = nn.ModuleDict(\n", - " {\n", - " classname: ResnetGenerator(\n", - " input_nc,\n", - " output_nc,\n", - " ngf,\n", - " norm_layer=norm_layer,\n", - " use_dropout=use_dropout,\n", - " n_blocks=9,\n", - " )\n", - " for classname in self.classes\n", - " }\n", - " )\n", - " self.discriminators = nn.ModuleDict(\n", - " {\n", - " classname: NLayerDiscriminator(\n", - " input_nc, ndf, n_layers=3, norm_layer=norm_layer\n", - " )\n", - " for classname in self.classes\n", - " }\n", - " )\n", - "\n", - " def forward(self, x, train=True):\n", - " \"\"\"Creates fakes from the reals.\n", - "\n", - " Parameters\n", - " ----------\n", - " x: dict\n", - " classname -> images\n", - " train: boolean\n", - " If false, only the counterfactuals are generated and returned.\n", - " Defaults to True.\n", - "\n", - " Returns\n", - " -------\n", - " fakes: dict\n", - " classname -> images of counterfactual images\n", - " identities: dict\n", - " classname -> images of images passed through their corresponding generator, if train is True\n", - " For example, images of class1 are passed through the generator that creates class1.\n", - " These should be identical to the input.\n", - " Not returned if `train` is `False`\n", - " reconstructions\n", - " classname -> images of reconstructed images (full cycle), if train is True.\n", - " Not returned if `train` is `False`\n", - " \"\"\"\n", - " fakes = {}\n", - " identities = {}\n", - " reconstructions = {}\n", - " for k, batch in x.items():\n", - " inv_k = self.inverse_keys[k]\n", - " # Counterfactual: class changes\n", - " fakes[inv_k] = self.generators[inv_k](batch)\n", - " if train:\n", - " # From counterfactual back to original, class changes again\n", - " reconstructions[k] = self.generators[k](fakes[inv_k])\n", - " # Identites: class does not change\n", - " identities[k] = self.generators[k](batch)\n", - " if train:\n", - " return fakes, identities, reconstructions\n", - " return fakes\n", - "\n", - " def discriminate(self, x):\n", - " \"\"\"Get discriminator opinion on x\n", - "\n", - " Parameters\n", - " ----------\n", - " x: dict\n", - " classname -> images\n", - " \"\"\"\n", - " discrimination = {}\n", - " for k, batch in x.items():\n", - " discrimination[k] = self.discriminators[k](batch)\n", - " return discrimination" + "# Here is an example of a working setup! Note that you can change the hyperparameters as you experiment.\n", + "# Choose your own setup to see what works for you.\n", + "style_encoder = DenseModel(input_shape=(3, 28, 28), num_classes=3)\n", + "unet = UNet(depth=2, in_channels=6, out_channels=3, final_activation=nn.Sigmoid())\n", + "generator = Generator(unet, style_encoder=style_encoder)" ] }, { - "cell_type": "code", - "execution_count": null, - "id": "09c3fa55", + "cell_type": "markdown", + "id": "16f87104", "metadata": { - "editable": true, - "lines_to_next_cell": 2, - "slideshow": { - "slide_type": "" - }, "tags": [] }, - "outputs": [], "source": [ - "cyclegan = CycleGAN(*classes)\n", - "cyclegan.to(device)\n", - "print(f\"Will use device {device} for training\")" + "

    Hyper-parameter choices

    \n", + "
      \n", + "
    • Are any of the hyperparameters you choose above constrained in some way?
    • \n", + "
    • What would happen if you chose a depth of 10 for the UNet?
    • \n", + "
    • Is there a minimum size for the style space? Why or why not?
    • \n", + "
    " ] }, { "cell_type": "markdown", - "id": "f91db612", + "id": "9f1d1149", "metadata": { - "editable": true, - "lines_to_next_cell": 2, - "slideshow": { - "slide_type": "" - }, + "lines_to_next_cell": 0, "tags": [] }, "source": [ - "You will notice above that the `CycleGAN` takes an input in the form of a dictionary, but our datasets and data-loaders return the data in the form of two tensors. Below are two utility functions that will swap from data from one to the other." + "

    Task 3.2: Create the discriminator

    \n", + "\n", + "We want the discriminator to be like a classifier, so it is able to look at an image and tell not only whether it is real, but also which class it came from.\n", + "The discriminator will take as input either a real image or a fake image.\n", + "Fill in the following code to create a discriminator that can classify the images into the correct number of classes.\n", + "
    " ] }, { "cell_type": "code", "execution_count": null, - "id": "b6d5d5ee", + "id": "71695d57", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] + "lines_to_next_cell": 0, + "tags": [ + "solution" + ] }, "outputs": [], "source": [ - "# Utility function to go to/from dictionaries/x,y tensors\n", - "def get_as_xy(dictionary):\n", - " x = torch.cat([arr for arr in dictionary.values()])\n", - " y = []\n", - " for k, v in dictionary.items():\n", - " val = class_labels[k]\n", - " y += [\n", - " val,\n", - " ] * len(v)\n", - " y = torch.Tensor(y).to(x.device)\n", - " return x, y\n", - "\n", - "\n", - "def get_as_dictionary(x, y):\n", - " dictionary = {}\n", - " for k in classes:\n", - " val = class_to_idx[k]\n", - " # Get all of the indices for this class\n", - " this_class_indices = torch.where(y == val)\n", - " dictionary[k] = x[this_class_indices]\n", - " return dictionary" + "discriminator = DenseModel(input_shape=(3, 28, 28), num_classes=4)" ] }, { "cell_type": "markdown", - "id": "8d48e4af", + "id": "231a5202", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] + "lines_to_next_cell": 0 }, "source": [ - "\n", - "### Creating a training loop\n", - "\n", - "Now that we have a model, our next task is to create a training loop for the CycleGAN. This is a bit more difficult than the training loop for our classifier.\n", - "\n", - "Here are some of the things to keep in mind during the next task.\n", - "\n", - "1. The CycleGAN is (obviously) a GAN: a Generative Adversarial Network. What makes an adversarial network \"adversarial\" is that two different networks are working against each other. The loss that is used to optimize this is in our exercise `criterionGAN`. Although the specifics of this loss is beyond the score of this notebook, the idea is simple: the `criterionGAN` compares the output of the discriminator to a boolean-valued target. If we want the discriminator to think that it has seen a real image, we set the target to `True`. If we want the discriminator to think that it has seen a generated image, we set the target to `False`. Note that it isn't important here whether the image *is* real, but **whether we want the discriminator to think it is real at that point**. (This will be important very soon 😉)\n", - "\n", - "2. Since the two networks are fighting each other, it is important to make sure that neither of them can cheat with information espionage. The CycleGAN implementation below is a turn-by-turn fight: we train the generator(s) and the discriminator(s) in alternating steps. When a model is not training, we will restrict its access to information by using `set_requries_grad` to `False`." + "Let's move all models onto the GPU" ] }, { "cell_type": "code", "execution_count": null, - "id": "8482184f", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, + "id": "c0a2d54d", + "metadata": {}, "outputs": [], "source": [ - "from cycle_gan.util.image_pool import ImagePool" + "generator = generator.to(device)\n", + "discriminator = discriminator.to(device)" ] }, { - "cell_type": "code", - "execution_count": null, - "id": "53c14194", + "cell_type": "markdown", + "id": "4540ef18", "metadata": { - "editable": true, - "lines_to_next_cell": 2, - "slideshow": { - "slide_type": "" - }, + "lines_to_next_cell": 0, "tags": [] }, - "outputs": [], "source": [ - "criterionIdt = nn.L1Loss()\n", - "criterionCycle = nn.L1Loss()\n", - "criterionGAN = GANLoss(\"lsgan\")\n", - "criterionGAN.to(device)\n", - "\n", - "lambda_idt = 1\n", - "pool_size = 32\n", + "## Training a GAN\n", "\n", - "lambdas = {k: 1 for k in classes}\n", - "image_pools = {classname: ImagePool(pool_size) for classname in classes}\n", + "Training an adversarial network is a bit more complicated than training a classifier.\n", + "For starters, we are simultaneously training two different networks that work against each other.\n", + "As such, we need to be careful about how and when we update the weights of each network.\n", "\n", - "optimizer_g = torch.optim.Adam(cyclegan.generators.parameters(), lr=1e-4)\n", - "optimizer_d = torch.optim.Adam(cyclegan.discriminators.parameters(), lr=1e-4)" + "We will have two different optimizers, one for the Generator and one for the Discriminator.\n" ] }, { - "cell_type": "markdown", - "id": "706a5f18", + "cell_type": "code", + "execution_count": null, + "id": "b9fc6671", "metadata": { - "editable": true, - "lines_to_next_cell": 2, - "slideshow": { - "slide_type": "" - }, - "tags": [] + "lines_to_next_cell": 0 }, + "outputs": [], "source": [ - "

    Task 3.1: Set up the training losses and gradients

    \n", - "\n", - "In the code below, there are several spots with multiple options. Choose from among these, and delete or comment out the incorrect option.\n", - "1. In `generator_step`: Choose whether the target to the`criterionGAN` should be `True` or `False`.\n", - "2. In `discriminator_step`: Choose the target to the `criterionGAN` (note that there are two this time, one for the real images and one for the generated images)\n", - "3. In `train_gan`: `set_requires_grad` correctly.\n", - "
    " + "optimizer_d = torch.optim.Adam(discriminator.parameters(), lr=1e-5)\n", + "optimizer_g = torch.optim.Adam(generator.parameters(), lr=1e-4)" ] }, { - "cell_type": "code", - "execution_count": null, - "id": "9d36c59f", + "cell_type": "markdown", + "id": "196daf45", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, + "lines_to_next_cell": 0, "tags": [] }, - "outputs": [], "source": [ - "def set_requires_grad(module, value=True):\n", - " \"\"\"Sets `requires_grad` on a `module`'s parameters to `value`\"\"\"\n", - " for param in module.parameters():\n", - " param.requires_grad = value\n", - "\n", - "\n", - "def generator_step(cyclegan, reals):\n", - " \"\"\"Calculate the loss for generators G_X and G_Y\"\"\"\n", - " # Get all generated images\n", - " fakes, identities, reconstructions = cyclegan(reals)\n", - " # Get discriminator opinion\n", - " discrimination = cyclegan.discriminate(fakes)\n", - " loss = 0\n", - " for k in classes:\n", - " # Identity loss\n", - " # G_A should be identity if real_B is fed: ||G_A(B) - B||\n", - " loss_idt = criterionIdt(identities[k], reals[k]) * lambdas[k] * lambda_idt\n", - "\n", - " # GAN loss D_A(G_A(A))\n", - " #################### TODO Choice 1 #####################\n", - " # OPTION 1\n", - " # loss_G = criterionGAN(discrimination[k], False)\n", - " # OPTION2\n", - " # loss_G = criterionGAN(discrimination[k], True)\n", - " #########################################################\n", - "\n", - " # Forward cycle loss || G_B(G_A(A)) - A||\n", - " loss_cycle = criterionCycle(reconstructions[k], reals[k]) * lambdas[k]\n", - " # combined loss and calculate gradients\n", - " loss += loss_G + loss_cycle + loss_idt\n", - " loss.backward()\n", - "\n", - "\n", - "def discriminator_step(cyclegan, reals):\n", - " \"\"\"Calculate the loss for the discriminators D_X and D_Y\"\"\"\n", - " fakes, identities, reconstructions = cyclegan(reals)\n", - " preds_real = cyclegan.discriminate(reals)\n", - " # Get fakes from pool\n", - " fakes = {k: v.detach() for k, v in fakes.items()}\n", - " preds_fake = cyclegan.discriminate(fakes)\n", - " loss = 0\n", - " for k in classes:\n", - " #################### TODO Choice 2 #####################\n", - " # OPTION 1\n", - " # loss_real = criterionGAN(preds_real[k], True)\n", - " # loss_fake = criterionGAN(preds_fake[k], False)\n", - " # OPTION 2\n", - " # loss_real = criterionGAN(preds_real[k], False)\n", - " # loss_fake = criterionGAN(preds_fake[k], True)\n", - " #########################################################\n", - "\n", - " loss += (loss_real + loss_fake) * 0.5\n", - " loss.backward()\n", - "\n", - "\n", - "def train_gan(reals):\n", - " \"\"\"Optimize the network parameters on a batch of images.\n", - "\n", - " reals: Dict[str, torch.Tensor]\n", - " Classname -> Tensor dictionary of images.\n", - " \"\"\"\n", - " #################### TODO Choice 3 #####################\n", - " # OPTION 1\n", - " # set_requires_grad(cyclegan.generators, True)\n", - " # set_requires_grad(cyclegan.discriminators, False)\n", - " # OPTION 2\n", - " # set_requires_grad(cyclegan.generators, False)\n", - " # set_requires_grad(cyclegan.discriminators, True)\n", - " ##########################################################\n", - "\n", - " optimizer_g.zero_grad()\n", - " generator_step(cyclegan, reals)\n", - " optimizer_g.step()\n", - "\n", - " #################### TODO (still) choice 3 #####################\n", - " # OPTION 1\n", - " # set_requires_grad(cyclegan.generators, True)\n", - " # set_requires_grad(cyclegan.discriminators, False)\n", - " # OPTION 2\n", - " # set_requires_grad(cyclegan.generators, False)\n", - " # set_requires_grad(cyclegan.discriminators, True)\n", - " #################################################################\n", - "\n", - " optimizer_d.zero_grad()\n", - " discriminator_step(cyclegan, reals)\n", - " optimizer_d.step()" + "\n", + "There are also two different types of losses that we will need.\n", + "**Adversarial loss**\n", + "This loss describes how well the discriminator can tell the difference between real and generated images.\n", + "In our case, this will be a sort of classification loss - we will use Cross Entropy.\n", + "
    \n", + "The adversarial loss will be applied differently to the generator and the discriminator! Be very careful!\n", + "
    " ] }, { "cell_type": "code", "execution_count": null, - "id": "b43ee77c", - "metadata": { - "editable": true, - "lines_to_next_cell": 2, - "slideshow": { - "slide_type": "" - }, - "tags": [ - "solution" - ] - }, + "id": "1e9ddd12", + "metadata": {}, "outputs": [], "source": [ - "# Solution\n", - "def set_requires_grad(module, value=True):\n", - " \"\"\"Sets `requires_grad` on a `module`'s parameters to `value`\"\"\"\n", - " for param in module.parameters():\n", - " param.requires_grad = value\n", - "\n", - "\n", - "def generator_step(cyclegan, reals):\n", - " \"\"\"Calculate the loss for generators G_X and G_Y\"\"\"\n", - " # Get all generated images\n", - " fakes, identities, reconstructions = cyclegan(reals)\n", - " # Get discriminator opinion\n", - " discrimination = cyclegan.discriminate(fakes)\n", - " loss = 0\n", - " for k in classes:\n", - " # Identity loss\n", - " # G_A should be identity if real_B is fed: ||G_A(B) - B||\n", - " loss_idt = criterionIdt(identities[k], reals[k]) * lambdas[k] * lambda_idt\n", - "\n", - " # GAN loss D_A(G_A(A))\n", - " #################### TODO Choice 1 #####################\n", - " # OPTION 1\n", - " # loss_G = criterionGAN(discrimination[k], False)\n", - " # OPTION2\n", - " loss_G = criterionGAN(discrimination[k], True)\n", - " #########################################################\n", - "\n", - " # Forward cycle loss || G_B(G_A(A)) - A||\n", - " loss_cycle = criterionCycle(reconstructions[k], reals[k]) * lambdas[k]\n", - " # combined loss and calculate gradients\n", - " loss += loss_G + loss_cycle + loss_idt\n", - " loss.backward()\n", - "\n", - "\n", - "def discriminator_step(cyclegan, reals):\n", - " \"\"\"Calculate the loss for the discriminators D_X and D_Y\"\"\"\n", - " fakes, identities, reconstructions = cyclegan(reals)\n", - " preds_real = cyclegan.discriminate(reals)\n", - " # Get fakes from pool\n", - " fakes = {k: v.detach() for k, v in fakes.items()}\n", - " preds_fake = cyclegan.discriminate(fakes)\n", - " loss = 0\n", - " for k in classes:\n", - " #################### TODO Choice 2 #####################\n", - " # OPTION 1\n", - " loss_real = criterionGAN(preds_real[k], True)\n", - " loss_fake = criterionGAN(preds_fake[k], False)\n", - " # OPTION 2\n", - " # loss_real = criterionGAN(preds_real[k], False)\n", - " # loss_fake = criterionGAN(preds_fake[k], True)\n", - " #########################################################\n", - "\n", - " loss += (loss_real + loss_fake) * 0.5\n", - " loss.backward()\n", - "\n", - "\n", - "def train_gan(reals):\n", - " \"\"\"Optimize the network parameters on a batch of images.\n", - "\n", - " reals: Dict[str, torch.Tensor]\n", - " Classname -> Tensor dictionary of images.\n", - " \"\"\"\n", - " #################### TODO Choice 3 #####################\n", - " # OPTION 1\n", - " set_requires_grad(cyclegan.generators, True)\n", - " set_requires_grad(cyclegan.discriminators, False)\n", - " # OPTION 2\n", - " # set_requires_grad(cyclegan.generators, False)\n", - " # set_requires_grad(cyclegan.discriminators, True)\n", - " ##########################################################\n", - "\n", - " optimizer_g.zero_grad()\n", - " generator_step(cyclegan, reals)\n", - " optimizer_g.step()\n", - "\n", - " #################### TODO (still) choice 3 #####################\n", - " # OPTION 1\n", - " # set_requires_grad(cyclegan.generators, True)\n", - " # set_requires_grad(cyclegan.discriminators, False)\n", - " # OPTION 2\n", - " set_requires_grad(cyclegan.generators, False)\n", - " set_requires_grad(cyclegan.discriminators, True)\n", - " #################################################################\n", - "\n", - " optimizer_d.zero_grad()\n", - " discriminator_step(cyclegan, reals)\n", - " optimizer_d.step()" + "adversarial_loss_fn = nn.CrossEntropyLoss()" ] }, { "cell_type": "markdown", - "id": "30b90f36", + "id": "eade7df1", "metadata": { - "editable": true, - "lines_to_next_cell": 2, - "slideshow": { - "slide_type": "" - }, + "lines_to_next_cell": 0, "tags": [] }, "source": [ - "Let's add a quick plotting function before we begin training..." + "\n", + "**Cycle/reconstruction loss**\n", + "The cycle loss is there to make sure that the generator doesn't output an image that looks nothing like the input!\n", + "Indeed, by training the generator to be able to cycle back to the original image, we are making sure that it makes a minimum number of changes.\n", + "The cycle loss is applied only to the generator.\n" ] }, { "cell_type": "code", "execution_count": null, - "id": "a6e2d5a8", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, + "id": "1deb8b8b", + "metadata": {}, "outputs": [], "source": [ - "def plot_gan_output(sample=None):\n", - " # Get the input from the test dataset\n", - " if sample is None:\n", - " i = np.random.randint(len(gan_test_dataset))\n", - " x, y = gan_test_dataset[i]\n", - " x = x.to(device)\n", - " reals = {classes[y]: x}\n", - " else:\n", - " reals = sample\n", - "\n", - " with torch.no_grad():\n", - " fakes, identities, reconstructions = cyclegan(reals)\n", - " inverse_keys = cyclegan.inverse_keys\n", - " for k in reals.keys():\n", - " inv_k = inverse_keys[k]\n", - " for i in range(len(reals[k])):\n", - " fig, (ax, ax_fake, ax_id, ax_recon) = plt.subplots(1, 4)\n", - " ax.imshow(reals[k][i].squeeze().cpu(), cmap=\"gray\")\n", - " ax_fake.imshow(fakes[inv_k][i].squeeze().cpu(), cmap=\"gray\")\n", - " ax_id.imshow(identities[k][i].squeeze().cpu(), cmap=\"gray\")\n", - " ax_recon.imshow(reconstructions[k][i].squeeze().cpu(), cmap=\"gray\")\n", - " # Name the axes\n", - " ax.set_title(f\"{k.capitalize()}\")\n", - " ax_fake.set_title(\"Counterfactual\")\n", - " ax_id.set_title(\"Identity\")\n", - " ax_recon.set_title(\"Reconstruction\")\n", - " for ax in [ax, ax_fake, ax_id, ax_recon]:\n", - " ax.axis(\"off\")" + "cycle_loss_fn = nn.L1Loss()" ] }, { "cell_type": "markdown", - "id": "519aba30", + "id": "ba4a7f7f", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, + "lines_to_next_cell": 0, "tags": [] }, "source": [ - "

    Task 3.2: Training!

    \n", - "Let's train the CycleGAN one batch a time, plotting the output every so often to see how it is getting on.\n", - "\n", - "While you watch the model train, consider whether you think it will be successful at generating counterfactuals in the number of steps we give it. What is the minimum number of iterations you think are needed for this to work, and how much time do yo uthink it will take?\n", - "
    " + "To load the data as batches, with shuffling and other useful features, we will use a `DataLoader`." ] }, { "cell_type": "code", "execution_count": null, - "id": "597f44ce", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, + "id": "b5b3d5dc", + "metadata": {}, "outputs": [], "source": [ - "# Get a balanced sampler that only considers the two classes\n", - "sampler = balanced_sampler(gan_train_dataset)\n", + "from torch.utils.data import DataLoader\n", + "\n", "dataloader = DataLoader(\n", - " gan_train_dataset, batch_size=8, drop_last=True, sampler=sampler\n", - ")" + " mnist, batch_size=32, drop_last=True, shuffle=True\n", + ") # We will use the same dataset as before" ] }, { - "cell_type": "code", - "execution_count": null, - "id": "7370994c", + "cell_type": "markdown", + "id": "a029e923", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, + "lines_to_next_cell": 0, "tags": [] }, - "outputs": [], "source": [ - "# Number of iterations to train for (note: this is not *nearly* enough to get ideal results)\n", - "iterations = 500\n", - "# Determines how often to plot outputs to see how the network is doing. I recommend scaling your `print_every` to your `iterations`.\n", - "# For example, if you're running `iterations=5` you can `print_every=1`, but `iterations=1000` and `print_every=1` will be a lot of prints.\n", - "print_every = 100" + "As we stated earlier, it is important to make sure when each network is being trained when working with a GAN.\n", + "Indeed, if we update the weights at the same time, we may lose the adversarial aspect of the training altogether, with information leaking into the generator or discriminator causing them to collaborate when they should be competing!\n", + "`set_requires_grad` is a function that allows us to determine when the weights of a network are trainable (if it is `True`) or not (if it is `False`)." ] }, { "cell_type": "code", "execution_count": null, - "id": "861dedd4", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, + "id": "54b4de87", + "metadata": {}, "outputs": [], "source": [ - "for i in tqdm(range(iterations)):\n", - " x, y = next(iter(dataloader))\n", - " x = x.to(device)\n", - " y = y.to(device)\n", - " real = get_as_dictionary(x, y)\n", - " train_gan(real)\n", - " if i % print_every == 0:\n", - " cyclegan.eval() # Set to eval to speed up the plotting\n", - " plot_gan_output(sample=real)\n", - " cyclegan.train() # Set back to train!\n", - " plt.show()" + "def set_requires_grad(module, value=True):\n", + " \"\"\"Sets `requires_grad` on a `module`'s parameters to `value`\"\"\"\n", + " for param in module.parameters():\n", + " param.requires_grad = value" ] }, { "cell_type": "markdown", - "id": "09c3f362", + "id": "014e484e", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, + "lines_to_next_cell": 0, "tags": [] }, "source": [ - "...this time again.\n", + "Another consequence of adversarial training is that it is very unstable.\n", + "While this instability is what leads to finding the best possible solution (which in the case of GANs is on a saddle point), it can also make it difficult to train the model.\n", + "To force some stability back into the training, we will use Exponential Moving Averages (EMA).\n", "\n", - "\"drawing\"" + "In essence, each time we update the generator's weights, we will also update the EMA model's weights as an average of all the generator's previous weights as well as the current update.\n", + "A certain weight is given to the previous weights, which is what ensures that the EMA update remains rather smooth over the training period.\n", + "Each epoch, we will then copy the EMA model's weights back to the generator.\n", + "This is a common technique used in GAN training to stabilize the training process.\n", + "Pay attention to what this does to the loss during the training process!" ] }, { - "cell_type": "markdown", - "id": "6ee205dd", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, + "cell_type": "code", + "execution_count": null, + "id": "f6344c83", + "metadata": {}, + "outputs": [], "source": [ - "

    Checkpoint 3

    \n", - "You've now learned the basics of what makes up a CycleGAN, and details on how to perform adversarial training.\n", - "The same method can be used to create a CycleGAN with different basic elements.\n", - "For example, you can change the archictecture of the generators, or of the discriminator to better fit your data in the future.\n", + "from copy import deepcopy\n", + "\n", + "\n", + "def exponential_moving_average(model, ema_model, beta=0.999):\n", + " \"\"\"Update the EMA model's parameters with an exponential moving average\"\"\"\n", + " for param, ema_param in zip(model.parameters(), ema_model.parameters()):\n", + " ema_param.data.mul_(beta).add_((1 - beta) * param.data)\n", + "\n", "\n", - "You know the drill... let us know on the exercise chat!\n", - "
    " + "def copy_parameters(source_model, target_model):\n", + " \"\"\"Copy the parameters of a model to another model\"\"\"\n", + " for param, target_param in zip(\n", + " source_model.parameters(), target_model.parameters()\n", + " ):\n", + " target_param.data.copy_(param.data)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "08b7b3af", + "metadata": {}, + "outputs": [], + "source": [ + "generator_ema = Generator(deepcopy(unet), style_encoder=deepcopy(style_encoder))\n", + "generator_ema = generator_ema.to(device)" ] }, { "cell_type": "markdown", - "id": "765089a1", + "id": "23fbf680", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, + "lines_to_next_cell": 0, "tags": [] }, "source": [ - "# Part 4: Evaluating the GAN" + "

    Task 3.3: Training!

    \n", + "You were given several different options in the training code below. In each case, one of the options will work, and the other will not.\n", + "Comment out the option that you think will not work.\n", + "
      \n", + "
    • Choose the values for `set_requires_grad`. Hint: which part of the code is training the generator? Which part is training the discriminator
    • \n", + "
    • Choose the values of `set_requires_grad`, again. Hint: you may want to switch
    • \n", + "
    • Choose the sign of the discriminator loss. Hint: what does the discriminator want to do?
    • \n", + ".
    • Apply the EMA update. Hint: which model do you want to update? You can look again at the code we wrote above.
    • \n", + "
    \n", + "Let's train the StarGAN one batch a time.\n", + "While you watch the model train, consider whether you think it will be successful at generating counterfactuals in the number of steps we give it. What is the minimum number of iterations you think are needed for this to work, and how much time do yo uthink it will take?\n", + "
    " ] }, { "cell_type": "markdown", - "id": "8959c219", + "id": "9cb8281d", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, + "lines_to_next_cell": 0, "tags": [] }, "source": [ - "\n", - "## That was fun!... let's load a pre-trained model\n", - "\n", - "Training the CycleGAN takes a lot longer than the few iterations that we did above. Since we don't have that kind of time, we are going to load a pre-trained model (for reference, this pre-trained model was trained for 7 days...).\n", - "\n", - "To continue, interrupt the kernel and continue with the next one, which will just use one of the pretrained CycleGAN models for the synapse dataset." + "Once you're happy with your choices, run the training loop! 🚂 🚋 🚋 🚋" ] }, { "cell_type": "code", "execution_count": null, - "id": "0fd97600", + "id": "699b3220", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] + "lines_to_next_cell": 2, + "tags": [ + "solution" + ] }, "outputs": [], "source": [ - "from pathlib import Path\n", - "import torch\n", - "\n", + "from tqdm import tqdm # This is a nice library for showing progress bars\n", "\n", - "def load_pretrained(model, path, classA, classB):\n", - " \"\"\"Load the pre-trained models from the path\"\"\"\n", - " directory = Path(path).expanduser() / f\"{classA}_{classB}\"\n", - " # Load generators\n", - " model.generators[classB].load_state_dict(\n", - " torch.load(directory / \"latest_net_G_A.pth\")\n", - " )\n", - " model.generators[classA].load_state_dict(\n", - " torch.load(directory / \"latest_net_G_B.pth\")\n", - " )\n", - " # Load discriminators\n", - " model.discriminators[classA].load_state_dict(\n", - " torch.load(directory / \"latest_net_D_A.pth\")\n", - " )\n", - " model.discriminators[classB].load_state_dict(\n", - " torch.load(directory / \"latest_net_D_B.pth\")\n", - " )\n", "\n", - "\n", - "load_pretrained(cyclegan, \"./checkpoints/synapses/cycle_gan/\", *classes)" - ] - }, - { - "cell_type": "markdown", - "id": "ee456f57", + "losses = {\"cycle\": [], \"adv\": [], \"disc\": []}\n", + "for epoch in range(15):\n", + " for x, y in tqdm(dataloader, desc=f\"Epoch {epoch}\"):\n", + " x = x.to(device)\n", + " y = y.to(device)\n", + " # get the target y by shuffling the classes\n", + " # get the style sources by random sampling\n", + " random_index = torch.randperm(len(y))\n", + " x_style = x[random_index].clone()\n", + " y_target = y[random_index].clone()\n", + "\n", + " set_requires_grad(generator, True)\n", + " set_requires_grad(discriminator, False)\n", + " optimizer_g.zero_grad()\n", + " # Get the fake image\n", + " x_fake = generator(x, x_style)\n", + " # Try to cycle back\n", + " x_cycled = generator(x_fake, x)\n", + " # Discriminate\n", + " discriminator_x_fake = discriminator(x_fake)\n", + " # Losses to train the generator\n", + "\n", + " # 1. make sure the image can be reconstructed\n", + " cycle_loss = cycle_loss_fn(x, x_cycled)\n", + " # 2. make sure the discriminator is fooled\n", + " adv_loss = adversarial_loss_fn(discriminator_x_fake, y_target)\n", + "\n", + " # Optimize the generator\n", + " (cycle_loss + adv_loss).backward()\n", + " optimizer_g.step()\n", + "\n", + " set_requires_grad(generator, False)\n", + " set_requires_grad(discriminator, True)\n", + " optimizer_d.zero_grad()\n", + " #\n", + " discriminator_x = discriminator(x)\n", + " discriminator_x_fake = discriminator(x_fake.detach())\n", + " # Losses to train the discriminator\n", + " # 1. make sure the discriminator can tell real is real\n", + " real_loss = adversarial_loss_fn(discriminator_x, y)\n", + " # 2. make sure the discriminator can tell fake is fake\n", + " fake_loss = -adversarial_loss_fn(discriminator_x_fake, y_target)\n", + " #\n", + " disc_loss = (real_loss + fake_loss) * 0.5\n", + " disc_loss.backward()\n", + " # Optimize the discriminator\n", + " optimizer_d.step()\n", + "\n", + " losses[\"cycle\"].append(cycle_loss.item())\n", + " losses[\"adv\"].append(adv_loss.item())\n", + " losses[\"disc\"].append(disc_loss.item())\n", + " exponential_moving_average(generator, generator_ema)\n", + " # Copy the EMA model's parameters to the generator\n", + " copy_parameters(generator_ema, generator)" + ] + }, + { + "cell_type": "markdown", + "id": "4c25819b", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, + "lines_to_next_cell": 0, "tags": [] }, "source": [ - "Let's look at some examples. Can you pick up on the differences between original, the counter-factual, and the reconstruction?" + "Once training is complete, we can plot the losses to see how well the model is doing." ] }, { "cell_type": "code", "execution_count": null, - "id": "20adc855", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, + "id": "0d64d32d", + "metadata": {}, "outputs": [], "source": [ - "for i in range(5):\n", - " plot_gan_output()" + "fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 5))\n", + "ax1.plot(losses[\"cycle\"])\n", + "ax1.set_title(\"Cycle loss\")\n", + "ax2.plot(losses[\"adv\"])\n", + "ax2.set_title(\"Adversarial loss\")\n", + "ax3.plot(losses[\"disc\"])\n", + "ax3.set_title(\"Discriminator loss\")\n", + "plt.show()" ] }, { "cell_type": "markdown", - "id": "dfa1b783", + "id": "326ba2b5", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, "tags": [] }, "source": [ - "We're going to apply the CycleGAN to our test dataset, and save the results to be reused later." + "

    Questions

    \n", + "
      \n", + "
    • Do the losses look like what you expected?
    • \n", + "
    • How do these losses differ from the losses you would expect from a classifier?
    • \n", + "
    • Based only on the losses, do you think the model is doing well?
    • \n", + "
    " ] }, { - "cell_type": "code", - "execution_count": null, - "id": "0887b0da", + "cell_type": "markdown", + "id": "3e58ca01", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, + "lines_to_next_cell": 0, "tags": [] }, - "outputs": [], "source": [ - "dataloader = DataLoader(gan_test_dataset, batch_size=32)" + "We can also look at some examples of the images that the generator is creating." ] }, { "cell_type": "code", "execution_count": null, - "id": "67b7c1e8", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, + "id": "1c522efa", + "metadata": {}, "outputs": [], "source": [ - "from skimage.io import imsave\n", - "\n", - "\n", - "def unnormalize(image):\n", - " return ((0.5 * image + 0.5) * 255).astype(np.uint8)\n", - "\n", - "\n", - "@torch.no_grad()\n", - "def apply_gan(dataloader, directory):\n", - " \"\"\"Run CycleGAN on a dataloader and save images to a directory.\"\"\"\n", - " directory = Path(directory)\n", - " inverse_keys = cyclegan.inverse_keys\n", - " cyclegan.eval()\n", - " batch_size = dataloader.batch_size\n", - " n_sample = 0\n", - " for batch, (x, y) in enumerate(tqdm(dataloader)):\n", - " reals = get_as_dictionary(x.to(device), y.to(device))\n", - " fakes, _, recons = cyclegan(reals)\n", - " for k in reals.keys():\n", - " inv_k = inverse_keys[k]\n", - " (directory / f\"real/{k}\").mkdir(parents=True, exist_ok=True)\n", - " (directory / f\"reconstructed/{k}\").mkdir(parents=True, exist_ok=True)\n", - " (directory / f\"counterfactual/{k}\").mkdir(parents=True, exist_ok=True)\n", - " for i, (im_real, im_fake, im_recon) in enumerate(\n", - " zip(reals[k], fakes[inv_k], recons[k])\n", - " ):\n", - " # Save real synapse images\n", - " imsave(\n", - " directory / f\"real/{k}/{k}_{inv_k}_{n_sample}.png\",\n", - " unnormalize(im_real.cpu().numpy().squeeze()),\n", - " )\n", - " # Save fake synapse images\n", - " imsave(\n", - " directory / f\"reconstructed/{k}/{k}_{inv_k}_{n_sample}.png\",\n", - " unnormalize(im_recon.cpu().numpy().squeeze()),\n", - " )\n", - " # Save counterfactual synapse images\n", - " imsave(\n", - " directory / f\"counterfactual/{k}/{k}_{inv_k}_{n_sample}.png\",\n", - " unnormalize(im_fake.cpu().numpy().squeeze()),\n", - " )\n", - " # Count\n", - " n_sample += 1\n", - " return" + "idx = 0\n", + "fig, axs = plt.subplots(1, 4, figsize=(12, 4))\n", + "axs[0].imshow(x[idx].cpu().permute(1, 2, 0).detach().numpy())\n", + "axs[0].set_title(\"Input image\")\n", + "axs[1].imshow(x_style[idx].cpu().permute(1, 2, 0).detach().numpy())\n", + "axs[1].set_title(\"Style image\")\n", + "axs[2].imshow(x_fake[idx].cpu().permute(1, 2, 0).detach().numpy())\n", + "axs[2].set_title(\"Generated image\")\n", + "axs[3].imshow(x_cycled[idx].cpu().permute(1, 2, 0).detach().numpy())\n", + "axs[3].set_title(\"Cycled image\")\n", + "\n", + "for ax in axs:\n", + " ax.axis(\"off\")\n", + "plt.show()" ] }, { "cell_type": "code", "execution_count": null, - "id": "0b4bfcf0", + "id": "30b6dac9", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] + "lines_to_next_cell": 0 }, "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "id": "a3ecbc7b", + "metadata": { + "tags": [] + }, "source": [ - "apply_gan(dataloader, \"test_images/\")" + "

    Checkpoint 3

    \n", + "You've now learned the basics of what makes up a StarGAN, and details on how to perform adversarial training.\n", + "The same method can be used to create a StarGAN with different basic elements.\n", + "For example, you can change the archictecture of the generators, or of the discriminator to better fit your data in the future.\n", + "\n", + "You know the drill... let us know on the exercise chat when you have arrived here!\n", + "
    " ] }, { - "cell_type": "code", - "execution_count": null, - "id": "2eb0e50e", + "cell_type": "markdown", + "id": "e6bdaecb", "metadata": { - "editable": true, - "lines_to_next_cell": 2, - "slideshow": { - "slide_type": "" - }, "tags": [] }, - "outputs": [], "source": [ - "# Clean-up the gpu's memory a bit to avoid Out-of-Memory errors\n", - "cyclegan = cyclegan.cpu()\n", - "torch.cuda.empty_cache()" + "# Part 4: Evaluating the GAN and creating Counterfactuals" ] }, { "cell_type": "markdown", - "id": "483af604", + "id": "7f994579", "metadata": { - "editable": true, - "lines_to_next_cell": 2, - "slideshow": { - "slide_type": "" - }, "tags": [] }, "source": [ - "## Evaluating the GAN\n", + "## Creating counterfactuals\n", "\n", - "The first thing to find out is whether the CycleGAN is successfully converting the images from one neurotransmitter to another.\n", - "We will do this by running the classifier that we trained earlier on generated data.\n", + "The first thing that we want to do is make sure that our GAN is able to create counterfactual images.\n", + "To do this, we have to create them, and then pass them through the classifier to see if they are classified correctly.\n", "\n", - "The data were saved in a directory called `test_images`.\n" + "First, let's get the test dataset, so we can evaluate the GAN on unseen data.\n", + "Then, let's get four prototypical images from the dataset as style sources." ] }, { "cell_type": "code", "execution_count": null, - "id": "c59702f9", + "id": "4e4fe83e", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] + "title": "Loading the test dataset" }, "outputs": [], "source": [ - "def make_dataset(directory):\n", - " \"\"\"Create a dataset from a directory of images with the classes in the same order as the VGG's output.\n", - "\n", - " Parameters\n", - " ----------\n", - " directory: str\n", - " The root directory of the images. It should contain sub-directories named after the classes, in which images are stored.\n", - " \"\"\"\n", - " # Make a dataset with the classes in the correct order\n", - " limited_classes = {k: v for k, v in class_to_idx.items() if k in classes}\n", - " dataset = ImageFolder(root=directory, transform=transform)\n", - " samples = ImageFolder.make_dataset(\n", - " directory, class_to_idx=limited_classes, extensions=\".png\"\n", - " )\n", - " # Sort samples by name\n", - " samples = sorted(samples, key=lambda s: s[0].split(\"_\")[-1])\n", - " dataset.classes = classes\n", - " dataset.class_to_idx = limited_classes\n", - " dataset.samples = samples\n", - " dataset.targets = [s[1] for s in samples]\n", - " return dataset" + "test_mnist = ColoredMNIST(\"extras/data\", download=True, train=False)\n", + "prototypes = {}\n", + "\n", + "\n", + "for i in range(4):\n", + " options = np.where(test_mnist.conditions == i)[0]\n", + " # Note that you can change the image index if you want to use a different prototype.\n", + " image_index = 0\n", + " x, y = test_mnist[options[image_index]]\n", + " prototypes[i] = x" ] }, { "cell_type": "markdown", - "id": "c6bffc67", + "id": "049a6b22", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, + "lines_to_next_cell": 0, "tags": [] }, "source": [ - "

    Task 4.1 Get the classifier accuracy on CycleGAN outputs

    \n", - "\n", - "Using the saved images, we're going to figure out how good our CycleGAN is at generating images of a new class!\n", - "\n", - "The images (`real`, `reconstructed`, and `counterfactual`) are saved in the `test_images/` directory. Before you start the exercise, have a look at how this directory is organized.\n", - "\n", - "TODO\n", - "- Use the `make_dataset` function to create a dataset for the three different image types that we saved above\n", - " - real\n", - " - reconstructed\n", - " - counterfactual\n", - "
    " + "Let's have a look at the prototypes." ] }, { "cell_type": "code", "execution_count": null, - "id": "42906ce7", - "metadata": { - "editable": true, - "lines_to_next_cell": 2, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, + "id": "639f37e2", + "metadata": {}, "outputs": [], "source": [ - "# Dataset of real images\n", - "ds_real = ...\n", - "# Dataset of reconstructed images (full cycle)\n", - "ds_recon = ...\n", - "# Datset of counterfactuals (half-cycle)\n", - "ds_counterfactual = ..." + "fig, axs = plt.subplots(1, 4, figsize=(12, 4))\n", + "for i, ax in enumerate(axs):\n", + " ax.imshow(prototypes[i].permute(1, 2, 0))\n", + " ax.axis(\"off\")\n", + " ax.set_title(f\"Prototype {i}\")" ] }, { - "cell_type": "code", - "execution_count": null, - "id": "98131f0f", + "cell_type": "markdown", + "id": "02cb705b", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [ - "solution" - ] + "lines_to_next_cell": 0 }, - "outputs": [], "source": [ - "############################\n", - "# Solution to Task 4.2 #\n", - "############################\n", - "\n", - "# Dataset of real images\n", - "ds_real = make_dataset(\"test_images/real/\")\n", - "# Dataset of reconstructed images (full cycle)\n", - "ds_recon = make_dataset(\"test_images/reconstructed/\")\n", - "# Datset of counterfactuals (half-cycle)\n", - "ds_counterfactual = make_dataset(\"test_images/counterfactual/\")" + "Now we need to use these prototypes to create counterfactual images!" ] }, { "cell_type": "markdown", - "id": "c4500183", + "id": "f41a6ce5", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] + "lines_to_next_cell": 0 }, "source": [ - "
    \n", - "We get the following accuracies:\n", - "\n", - "1. `accuracy_real`: Accuracy of the classifier on the real images, just for the two classes used in the GAN\n", - "2. `accuracy_recon`: Accuracy of the classifier on the reconstruction.\n", - "3. `accuracy_counter`: Accuracy of the classifier on the counterfactual images.\n", - "\n", - "

    Questions

    \n", + "

    Task 4: Create counterfactuals

    \n", + "In the below, we will store the counterfactual images in the `counterfactuals` array.\n", "\n", - "- In a perfect world, what value would we expect for `accuracy_recon`? What do we compare it to and why is it higher/lower?\n", - "- How well is it translating from one class to another? Do we expect `accuracy_counter` to be large or small? Do we want it to be large or small? Why?\n", - "\n", - "Let us know your insights on the exercise chat.\n", - "
    " + "
      \n", + "
    • Create a counterfactual image for each of the prototypes.
    • \n", + "
    • Classify the counterfactual image using the classifier.
    • \n", + "
    • Store the source and target labels; which is which?
    • \n", + "
    " ] }, { "cell_type": "code", "execution_count": null, - "id": "17b2af0c", + "id": "00616e67", "metadata": { - "editable": true, - "lines_to_next_cell": 2, - "slideshow": { - "slide_type": "" - }, - "tags": [] + "tags": [ + "solution" + ] }, "outputs": [], "source": [ - "cf_pred, cf_gt = predict(ds_counterfactual, \"Counterfactuals\")\n", - "recon_pred, recon_gt = predict(ds_recon, \"Reconstructions\")\n", - "real_pred, real_gt = predict(ds_real, \"Real images\")\n", + "num_images = 1000\n", + "random_test_mnist = torch.utils.data.Subset(\n", + " test_mnist, np.random.choice(len(test_mnist), num_images, replace=False)\n", + ")\n", + "counterfactuals = np.zeros((4, num_images, 3, 28, 28))\n", "\n", - "# Get the accuracies\n", - "accuracy_real = accuracy_score(real_gt, real_pred)\n", - "accuracy_recon = accuracy_score(recon_gt, recon_pred)\n", - "accuracy_cf = accuracy_score(cf_gt, cf_pred)\n", + "predictions = []\n", + "source_labels = []\n", + "target_labels = []\n", "\n", - "print(\n", - " f\"Accuracy real: {accuracy_real}\\nAccuracy reconstruction: {accuracy_recon}\\nAccuracy counterfactuals: {accuracy_cf}\\n\"\n", - ")" + "for i, (x, y) in tqdm(enumerate(random_test_mnist), total=num_images):\n", + " for lbl in range(4):\n", + " # Create the counterfactual\n", + " x_fake = generator(\n", + " x.unsqueeze(0).to(device), prototypes[lbl].unsqueeze(0).to(device)\n", + " )\n", + " # Predict the class of the counterfactual image\n", + " pred = model(x_fake)\n", + "\n", + " # Store the source and target labels\n", + " source_labels.append(y) # The original label of the image\n", + " target_labels.append(lbl) # The desired label of the counterfactual image\n", + " # Store the counterfactual image and prediction\n", + " counterfactuals[lbl][i] = x_fake.cpu().detach().numpy()\n", + " predictions.append(pred.argmax().item())" ] }, { "cell_type": "markdown", - "id": "615c9449", + "id": "ebffc15f", "metadata": { - "editable": true, - "lines_to_next_cell": 2, - "slideshow": { - "slide_type": "" - }, + "lines_to_next_cell": 0, "tags": [] }, "source": [ - "We're going to look at the confusion matrices for the counterfactuals, and compare it to that of the real images." + "Let's plot the confusion matrix for the counterfactual images." ] }, { "cell_type": "code", "execution_count": null, - "id": "4c0e1278", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, + "id": "baac8071", + "metadata": {}, "outputs": [], "source": [ - "labels = [class_to_idx[i] for i in classes]\n", - "print(\"The confusion matrix of the classifier on the counterfactuals\")\n", - "cm_analysis(cf_pred, cf_gt, names=classes, labels=labels)" + "cf_cm = confusion_matrix(target_labels, predictions, normalize=\"true\")\n", + "sns.heatmap(cf_cm, annot=True, fmt=\".2f\")\n", + "plt.ylabel(\"True\")\n", + "plt.xlabel(\"Predicted\")\n", + "plt.show()" ] }, { - "cell_type": "code", - "execution_count": null, - "id": "92401b45", + "cell_type": "markdown", + "id": "88e7ea0c", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, "tags": [] }, - "outputs": [], "source": [ - "print(\"The confusion matrix on the real images... for comparison\")\n", - "cm_analysis(real_pred, real_gt, names=classes, labels=labels)" + "

    Questions

    \n", + "
      \n", + "
    • How well is our GAN doing at creating counterfactual images?
    • \n", + "
    • Does your choice of prototypes matter? Why or why not?
    • \n", + "
    \n", + "
    " ] }, { "cell_type": "markdown", - "id": "57f8cca6", + "id": "25972c49", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, "tags": [] }, "source": [ - "
    \n", - "

    Questions

    \n", - "\n", - "- What would you expect the confusion matrix for the counterfactuals to look like? Why?\n", - "- Do the two directions of the CycleGAN work equally as well?\n", - "- Can you think of anything that might have made it more difficult, or easier, to translate in a one direction vs the other?\n", - "\n", - "
    " + "Let's also plot some examples of the counterfactual images." ] }, { - "cell_type": "markdown", - "id": "d81bbc95", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, + "cell_type": "code", + "execution_count": null, + "id": "12d49576", + "metadata": {}, + "outputs": [], "source": [ - "

    Checkpoint 4

    \n", - " We have seen that our CycleGAN network has successfully translated some of the synapses from one class to the other, but there are clearly some things to look out for!\n", - "Take the time to think about the questions above before moving on...\n", - "\n", - "This is the end of Section 4. Let us know on the exercise chat if you have reached this point!\n", - "
    " + "for i in np.random.choice(range(num_images), 4):\n", + " fig, axs = plt.subplots(1, 4, figsize=(20, 4))\n", + " for j, ax in enumerate(axs):\n", + " ax.imshow(counterfactuals[j][i].transpose(1, 2, 0))\n", + " ax.axis(\"off\")\n", + " ax.set_title(f\"Class {j}\")" ] }, { "cell_type": "markdown", - "id": "406e8777", + "id": "8e6f04f3", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, "tags": [] }, "source": [ - "# Part 5: Highlighting Class-Relevant Differences" + "

    Questions

    \n", + "
      \n", + "
    • Can you easily tell which of these images is the original, and which ones are the counterfactuals?
    • \n", + "
    • What is your hypothesis for the features that define each class?
    • \n", + "
    \n", + "
    " ] }, { "cell_type": "markdown", - "id": "69ee980b", - "metadata": {}, + "id": "50728ff2", + "metadata": { + "lines_to_next_cell": 0 + }, "source": [ "At this point we have:\n", - "- A classifier that can differentiate between neurotransmitters from EM images of synapses\n", - "- A vague idea of which parts of the images it thinks are important for this classification\n", - "- A CycleGAN that is sometimes able to trick the classifier with barely perceptible changes\n", - "\n", - "What we don't know, is *how* the CycleGAN is modifying the images to change their class.\n", + "- A classifier that can differentiate between image of different classes\n", + "- A GAN that has correctly figured out how to change the class of an image\n", "\n", - "To start to answer this question, we will use a [Discriminative Attribution from Counterfactuals](https://arxiv.org/abs/2109.13412) method to highlight differences between the \"real\" and \"fake\" images that are most important to change the decision of the classifier." + "Let's try putting the two together to see if we can figure out what exactly makes a class.\n" ] }, { - "cell_type": "markdown", - "id": "f7dbe347", + "cell_type": "code", + "execution_count": null, + "id": "dedc0f83", "metadata": {}, + "outputs": [], "source": [ - "

    Task 5.1 Get sucessfully converted samples

    \n", - "The CycleGAN is able to convert some, but not all images into their target types.\n", - "In order to observe and highlight useful differences, we want to observe our attribution method at work only on those examples of synapses:\n", - "
      \n", - "
    1. That were correctly classified originally
    2. \n", - "
    3. Whose counterfactuals were also correctly classified
    4. \n", - "
    \n", - "\n", - "TODO\n", - "- Get a boolean description of the `real` samples that were correctly predicted\n", - "- Get the target class for the `counterfactual` images (Hint: It isn't `cf_gt`!)\n", - "- Get a boolean description of the `cf` samples that have the target class\n", - "
    " + "batch_size = 4\n", + "batch = [random_test_mnist[i] for i in range(batch_size)]\n", + "x = torch.stack([b[0] for b in batch])\n", + "y = torch.tensor([b[1] for b in batch])\n", + "x_fake = torch.tensor(counterfactuals[0, :batch_size])\n", + "x = x.to(device).float()\n", + "y = y.to(device)\n", + "x_fake = x_fake.to(device).float()\n", + "\n", + "# Generated attributions on integrated gradients\n", + "attributions = integrated_gradients.attribute(x, baselines=x_fake, target=y)" ] }, { "cell_type": "code", "execution_count": null, - "id": "28ec78be", + "id": "5446e796", "metadata": { - "editable": true, - "lines_to_next_cell": 2, - "slideshow": { - "slide_type": "" - }, - "tags": [] + "title": "Another visualization function" }, "outputs": [], "source": [ - "####### Task 5.1 TODO #######\n", - "\n", - "# Get the samples where the real is correct\n", - "correct_real = ...\n", - "\n", - "# HINT GABA is class 1 and ACh is class 0\n", - "target = ...\n", - "\n", - "# Get the samples where the counterfactual has reached the target\n", - "correct_cf = ...\n", - "\n", - "# Successful conversions\n", - "success = np.where(np.logical_and(correct_real, correct_cf))[0]\n", + "def visualize_color_attribution_and_counterfactual(\n", + " attribution, original_image, counterfactual_image\n", + "):\n", + " attribution = np.transpose(attribution, (1, 2, 0))\n", + " original_image = np.transpose(original_image, (1, 2, 0))\n", + " counterfactual_image = np.transpose(counterfactual_image, (1, 2, 0))\n", "\n", - "# Create datasets with only the successes\n", - "cf_success_ds = Subset(ds_counterfactual, success)\n", - "real_success_ds = Subset(ds_real, success)" + " fig, (ax0, ax1, ax2) = plt.subplots(1, 3, figsize=(15, 5))\n", + " ax0.imshow(original_image)\n", + " ax0.set_title(\"Image\")\n", + " ax0.axis(\"off\")\n", + " ax1.imshow(counterfactual_image)\n", + " ax1.set_title(\"Counterfactual\")\n", + " ax1.axis(\"off\")\n", + " ax2.imshow(np.abs(attribution))\n", + " ax2.set_title(\"Attribution\")\n", + " ax2.axis(\"off\")\n", + " plt.show()" ] }, { "cell_type": "code", "execution_count": null, - "id": "3f1391ba", + "id": "5e2fb59e", "metadata": { - "editable": true, - "lines_to_next_cell": 2, - "slideshow": { - "slide_type": "" - }, - "tags": [ - "solution" - ] + "lines_to_next_cell": 0 }, "outputs": [], "source": [ - "########################\n", - "# Solution to Task 5.1 #\n", - "########################\n", - "\n", - "# Get the samples where the real is correct\n", - "correct_real = real_pred == real_gt\n", - "\n", - "# HINT GABA is class 1 and ACh is class 0\n", - "target = 1 - real_gt\n", - "\n", - "# Get the samples where the counterfactual has reached the target\n", - "correct_cf = cf_pred == target\n", - "\n", - "# Successful conversions\n", - "success = np.where(np.logical_and(correct_real, correct_cf))[0]\n", - "\n", - "# Create datasets with only the successes\n", - "cf_success_ds = Subset(ds_counterfactual, success)\n", - "real_success_ds = Subset(ds_real, success)" + "for idx in range(batch_size):\n", + " print(\"Source class:\", y[idx].item())\n", + " print(\"Target class:\", 0)\n", + " visualize_color_attribution_and_counterfactual(\n", + " attributions[idx].cpu().numpy(), x[idx].cpu().numpy(), x_fake[idx].cpu().numpy()\n", + " )" ] }, { "cell_type": "markdown", - "id": "5518deea", + "id": "b393a8f1", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] + "lines_to_next_cell": 0 }, "source": [ - "To check that we have got it right, let us get the accuracy on the best 100 vs the worst 100 samples:" + "

    Questions

    \n", + "
      \n", + "
    • Do the attributions explain the differences between the images and their counterfactuals?
    • \n", + "
    • What happens when the \"counterfactual\" and the original image are of the same class? Why do you think this is?
    • \n", + "
    • Do you have a more refined hypothesis for what makes each class unique?
    • \n", + "
    \n", + "
    " ] }, { - "cell_type": "code", - "execution_count": null, - "id": "c813f006", + "cell_type": "markdown", + "id": "5ba47fc6", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] + "lines_to_next_cell": 0 }, - "outputs": [], "source": [ - "model = model.to(\"cuda\")" + "

    Checkpoint 4

    \n", + "At this point you have:\n", + "- Created a StarGAN that can change the class of an image\n", + "- Evaluated the StarGAN on unseen data\n", + "- Used the StarGAN to create counterfactual images\n", + "- Used the counterfactual images to highlight the differences between classes\n" ] }, { - "cell_type": "code", - "execution_count": null, - "id": "d599f126", + "cell_type": "markdown", + "id": "2654d788", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] + "lines_to_next_cell": 0 }, - "outputs": [], "source": [ - "real_true, real_pred = predict(real_success_ds, \"Real\")\n", - "cf_true, cf_pred = predict(cf_success_ds, \"Counterfactuals\")\n", + "# Part 5: Exploring the Style Space, finding the answer\n", + "By now you will have hopefully noticed that it isn't the exact color of the image that determines its class, but that two images with a very similar color can be of different classes!\n", "\n", - "print(\n", - " \"Accuracy of the classifier on successful real images\",\n", - " accuracy_score(real_true, real_pred),\n", - ")\n", - "print(\n", - " \"Accuracy of the classifier on successful counterfactual images\",\n", - " accuracy_score(cf_true, cf_pred),\n", - ")" + "Here is an example of two images that are very similar in color, but are of different classes.\n", + "![same_color_diff_class](assets/same_color_diff_class.png)\n", + "While both of the images are yellow, the attribution tells us (if you squint!) that one of the yellows has slightly more blue in it!\n", + "\n", + "Conversely, here is an example of two images with very different colors, but that are of the same class:\n", + "![same_class_diff_color](assets/same_class_diff_color.png)\n", + "Here the attribution is empty! Using the discriminative attribution we can see that the significant color change doesn't matter at all!\n", + "\n", + "\n", + "So color is important... but not always? What's going on!?\n", + "There is a final piece of information that we can use to solve the puzzle: the style space." ] }, { "cell_type": "markdown", - "id": "877db1dc", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, + "id": "76559366", + "metadata": {}, "source": [ - "### Creating hybrids from attributions\n", - "\n", - "Now that we have a set of successfully translated counterfactuals, we can use them as a baseline for our attribution.\n", - "If you remember from earlier, `IntegratedGradients` does a interpolation between the model gradients at the baseline and the model gradients at the sample. Here, we're also going to be doing an interpolation between the baseline image and the sample image, creating a hybrid!\n", - "\n", - "To do this, we will take the sample image and mask out all of the pixels in the attribution. We will then replace these masked out pixels by the equivalent values in the counterfactual. So we'll have a hybrid image that is like the original everywhere except in the areas that matter for classification." + "

    Task 5.1: Explore the style space

    \n", + "Let's take a look at the style space.\n", + "We will use the style encoder to encode the style of the images and then use PCA to visualize it.\n", + "
    " ] }, { "cell_type": "code", "execution_count": null, - "id": "dcb7288f", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, + "id": "f1fdb890", + "metadata": {}, "outputs": [], "source": [ - "dataloader_real = DataLoader(real_success_ds, batch_size=10)\n", - "dataloader_counter = DataLoader(cf_success_ds, batch_size=10)" + "from sklearn.decomposition import PCA\n", + "\n", + "\n", + "styles = []\n", + "labels = []\n", + "for img, label in random_test_mnist:\n", + " styles.append(\n", + " style_encoder(img.unsqueeze(0).to(device)).cpu().detach().numpy().squeeze()\n", + " )\n", + " labels.append(label)\n", + "\n", + "# PCA\n", + "pca = PCA(n_components=2)\n", + "styles_pca = pca.fit_transform(styles)\n", + "\n", + "# Plot the PCA\n", + "markers = [\"o\", \"s\", \"P\", \"^\"]\n", + "plt.figure(figsize=(10, 10))\n", + "for i in range(4):\n", + " plt.scatter(\n", + " styles_pca[np.array(labels) == i, 0],\n", + " styles_pca[np.array(labels) == i, 1],\n", + " marker=markers[i],\n", + " label=f\"Class {i}\",\n", + " )\n", + "plt.legend()\n", + "plt.show()" ] }, { - "cell_type": "code", - "execution_count": null, - "id": "95239b4b", + "cell_type": "markdown", + "id": "b666769e", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] + "lines_to_next_cell": 0 }, - "outputs": [], "source": [ - "%%time\n", - "with torch.no_grad():\n", - " model.to(device)\n", - " # Create an integrated gradients object.\n", - " # integrated_gradients = IntegratedGradients(model)\n", - " # Generated attributions on integrated gradients\n", - " attributions = np.vstack(\n", - " [\n", - " integrated_gradients.attribute(\n", - " real.to(device),\n", - " target=target.to(device),\n", - " baselines=counterfactual.to(device),\n", - " )\n", - " .cpu()\n", - " .numpy()\n", - " for (real, target), (counterfactual, _) in zip(\n", - " dataloader_real, dataloader_counter\n", - " )\n", - " ]\n", - " )" + "

    Task 5.1: Adding color to the style space

    \n", + "We know that color is important. Does interpreting the style space as colors help us understand better?\n", + "\n", + "Let's use the style space to color the PCA plot.\n", + "(Note: there is no code to write here, just run the cell and answer the questions below)\n", + "
    " ] }, { "cell_type": "code", "execution_count": null, - "id": "8b968d7c", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "84835390", + "id": "e61d0c9b", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] + "lines_to_next_cell": 0 }, "outputs": [], "source": [ - "# Functions for creating an interactive visualization of our attributions\n", - "model.cpu()\n", - "\n", - "import matplotlib\n", - "\n", - "cmap = matplotlib.cm.get_cmap(\"viridis\")\n", - "colors = cmap([0, 255])\n", - "\n", - "\n", - "@torch.no_grad()\n", - "def get_classifications(image, counter, hybrid):\n", - " model.eval()\n", - " class_idx = [full_dataset.classes.index(c) for c in classes]\n", - " tensor = torch.from_numpy(np.stack([image, counter, hybrid])).float()\n", - " with torch.no_grad():\n", - " logits = model(tensor)[:, class_idx]\n", - " probs = torch.nn.Softmax(dim=1)(logits)\n", - " pred, counter_pred, hybrid_pred = probs\n", - " return pred.numpy(), counter_pred.numpy(), hybrid_pred.numpy()\n", - "\n", - "\n", - "def visualize_counterfactuals(idx, threshold=0.1):\n", - " image = real_success_ds[idx][0].numpy()\n", - " counter = cf_success_ds[idx][0].numpy()\n", - " mask = get_mask(attributions[idx], threshold)\n", - " hybrid = (1 - mask) * image + mask * counter\n", - " nan_mask = copy.deepcopy(mask)\n", - " nan_mask[nan_mask != 0] = 1\n", - " nan_mask[nan_mask == 0] = np.nan\n", - " # PLOT\n", - " fig, axes = plt.subplot_mosaic(\n", - " \"\"\"\n", - " mmm.ooo.ccc.hhh\n", - " mmm.ooo.ccc.hhh\n", - " mmm.ooo.ccc.hhh\n", - " ....ggg.fff.ppp\n", - " \"\"\",\n", - " figsize=(20, 5),\n", - " )\n", - " # Original\n", - " viz.visualize_image_attr(\n", - " np.transpose(mask, (1, 2, 0)),\n", - " np.transpose(image, (1, 2, 0)),\n", - " method=\"blended_heat_map\",\n", - " sign=\"absolute_value\",\n", - " show_colorbar=True,\n", - " title=\"Mask\",\n", - " use_pyplot=False,\n", - " plt_fig_axis=(fig, axes[\"m\"]),\n", + "styles = np.array(styles)\n", + "normalized_styles = (styles - np.min(styles, axis=1, keepdims=True)) / np.ptp(\n", + " styles, axis=1, keepdims=True\n", + ")\n", + "\n", + "# Plot the PCA again!\n", + "plt.figure(figsize=(10, 10))\n", + "for i in range(4):\n", + " plt.scatter(\n", + " styles_pca[np.array(labels) == i, 0],\n", + " styles_pca[np.array(labels) == i, 1],\n", + " c=normalized_styles[np.array(labels) == i],\n", + " marker=markers[i],\n", + " label=f\"Class {i}\",\n", " )\n", - " # Original\n", - " axes[\"o\"].imshow(image.squeeze(), cmap=\"gray\")\n", - " axes[\"o\"].set_title(\"Original\", fontsize=24)\n", - " # Counterfactual\n", - " axes[\"c\"].imshow(counter.squeeze(), cmap=\"gray\")\n", - " axes[\"c\"].set_title(\"Counterfactual\", fontsize=24)\n", - " # Hybrid\n", - " axes[\"h\"].imshow(hybrid.squeeze(), cmap=\"gray\")\n", - " axes[\"h\"].set_title(\"Hybrid\", fontsize=24)\n", - " # Mask\n", - " pred, counter_pred, hybrid_pred = get_classifications(image, counter, hybrid)\n", - " axes[\"g\"].barh(classes, pred, color=colors)\n", - " axes[\"f\"].barh(classes, counter_pred, color=colors)\n", - " axes[\"p\"].barh(classes, hybrid_pred, color=colors)\n", - " for ix in [\"m\", \"o\", \"c\", \"h\"]:\n", - " axes[ix].axis(\"off\")\n", - "\n", - " for ix in [\"g\", \"f\", \"p\"]:\n", - " for tick in axes[ix].get_xticklabels():\n", - " tick.set_rotation(90)\n", - " axes[ix].set_xlim(0, 1)" + "plt.legend()\n", + "plt.show()" ] }, { "cell_type": "markdown", - "id": "c732d7a7", + "id": "6f1d3ff3", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] + "lines_to_next_cell": 0 }, "source": [ - "

    Task 5.2: Observing the effect of the changes on the classifier

    \n", - "Below is a small widget to interact with the above analysis. As you change the `threshold`, see how the prediction of the hybrid changes.\n", - "At what point does it swap over?\n", - "\n", - "If you want to see different samples, slide through the `idx`.\n", - "
    " + "

    Questions

    \n", + "
      \n", + "
    • Do the colors match those that you have seen in the data?
    • \n", + "
    • Can you see any patterns in the colors? Is the space smooth, for example?
    • \n", + "
    " ] }, { - "cell_type": "code", - "execution_count": null, - "id": "23225866", + "cell_type": "markdown", + "id": "90889399", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] + "lines_to_next_cell": 0 }, - "outputs": [], "source": [ - "interact(visualize_counterfactuals, idx=(0, 99), threshold=(0.0, 1.0, 0.05))" + "

    Task 5.2: Using the images to color the style space

    \n", + "Finally, let's just use the colors from the images themselves!\n", + "The maximum value in the image (since they are \"black-and-color\") can be used as a color!\n", + "\n", + "Let's get that color, then plot the style space again.\n", + "(Note: once again, no coding needed here, just run the cell and think about the results with the questions below)\n", + "
    " ] }, { - "cell_type": "markdown", - "id": "1ca835c5", + "cell_type": "code", + "execution_count": null, + "id": "f67b3f90", "metadata": {}, + "outputs": [], "source": [ - "HELP!!! Interactive (still!) doesn't work. No worries... uncomment the following cell and choose your index and threshold by typing them out." + "colors = np.array([np.max(x.numpy(), axis=(1, 2)) for x, _ in random_test_mnist])\n", + "\n", + "# Plot the PCA again!\n", + "plt.figure(figsize=(10, 10))\n", + "for i in range(4):\n", + " plt.scatter(\n", + " styles_pca[np.array(labels) == i, 0],\n", + " styles_pca[np.array(labels) == i, 1],\n", + " c=colors[np.array(labels) == i],\n", + " marker=markers[i],\n", + " label=f\"Class {i}\",\n", + " )\n", + "plt.legend()\n", + "plt.show()" ] }, { "cell_type": "code", "execution_count": null, - "id": "771fb28f", + "id": "b18b2b81", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] + "lines_to_next_cell": 0 }, "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "id": "bf87e80b", + "metadata": {}, "source": [ - "# Choose your own adventure\n", - "# idx = 0\n", - "# threshold = 0.1\n", - "\n", - "# # Plotting :)\n", - "# visualize_counterfactuals(idx, threshold)" + "

    Questions

    \n", + "
      \n", + "
    • Do the colors match those that you have seen in the data?
    • \n", + "
    • Can you see any patterns in the colors?
    • \n", + "
    • Can you guess what the classes correspond to?
    • " ] }, { "cell_type": "markdown", - "id": "3905e9a7", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, + "id": "11aafcc5", + "metadata": {}, "source": [ - "
      \n", - "

      Questions

      \n", + "

      Checkpoint 5

      \n", + "Congratulations! You have made it to the end of the exercise!\n", + "You have:\n", + "- Created a StarGAN that can change the class of an image\n", + "- Evaluated the StarGAN on unseen data\n", + "- Used the StarGAN to create counterfactual images\n", + "- Used the counterfactual images to highlight the differences between classes\n", + "- Used the style space to understand the differences between classes\n", "\n", - "- Can you find features that define either of the two classes?\n", - "- How consistent are they across the samples?\n", - "- Is there a range of thresholds where most of the hybrids swap over to the target class? (If you want to see that area, try to change the range of thresholds in the slider by setting `threshold=(minimum_value, maximum_value, step_size)`\n", - "\n", - "Feel free to discuss your answers on the exercise chat!\n", - "
      " + "If you have any questions, feel free to ask them in the chat!\n", + "And check the Solutions exercise for a definite answer to how these classes are defined!" ] }, { "cell_type": "markdown", - "id": "578e5831", + "id": "a5c8b45e", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] + "lines_to_next_cell": 0, + "tags": [ + "solution" + ] }, "source": [ - "
      \n", - "

      The End.

      \n", - " Go forth and train some GANs!\n", - "
      " + "The colors for the classes are sampled from matplotlib colormaps! They are the four seasons: spring, summer, autumn, and winter.\n", + "Check your style space again to see if you can see the patterns now!" ] }, { - "cell_type": "markdown", - "id": "2f8cb30e", + "cell_type": "code", + "execution_count": null, + "id": "45e17541", "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] + "tags": [ + "solution" + ] }, + "outputs": [], "source": [ - "## Going Further\n", + "# Let's plot the colormaps\n", + "import matplotlib as mpl\n", + "import numpy as np\n", + "\n", "\n", - "Here are some ideas for how to continue with this notebook:\n", + "def plot_color_gradients(cmap_list):\n", + " gradient = np.linspace(0, 1, 256)\n", + " gradient = np.vstack((gradient, gradient))\n", + "\n", + " # Create figure and adjust figure height to number of colormaps\n", + " nrows = len(cmap_list)\n", + " figh = 0.35 + 0.15 + (nrows + (nrows - 1) * 0.1) * 0.22\n", + " fig, axs = plt.subplots(nrows=nrows + 1, figsize=(6.4, figh))\n", + " fig.subplots_adjust(top=1 - 0.35 / figh, bottom=0.15 / figh, left=0.2, right=0.99)\n", + "\n", + " for ax, name in zip(axs, cmap_list):\n", + " ax.imshow(gradient, aspect=\"auto\", cmap=mpl.colormaps[name])\n", + " ax.text(\n", + " -0.01,\n", + " 0.5,\n", + " name,\n", + " va=\"center\",\n", + " ha=\"right\",\n", + " fontsize=10,\n", + " transform=ax.transAxes,\n", + " )\n", "\n", - "1. Improve the classifier. This code uses a VGG network for the classification. On the synapse dataset, we will get a validation accuracy of around 80%. Try to see if you can improve the classifier accuracy.\n", - " * (easy) Data augmentation: The training code for the classifier is quite simple in this example. Enlarge the amount of available training data by adding augmentations (transpose and mirror the images, add noise, change the intensity, etc.).\n", - " * (easy) Network architecture: The VGG network has a few parameters that one can tune. Try a few to see what difference it makes.\n", - " * (easy) Inspect the classifier predictions: Take random samples from the test dataset and classify them. Show the images together with their predicted and actual labels.\n", - " * (medium) Other networks: Try different architectures (e.g., a [ResNet](https://blog.paperspace.com/writing-resnet-from-scratch-in-pytorch/#resnet-from-scratch)) and see if the accuracy can be improved.\n", + " # Turn off *all* ticks & spines, not just the ones with colormaps.\n", + " for ax in axs:\n", + " ax.set_axis_off()\n", "\n", - "2. Explore the CycleGAN.\n", - " * (easy) The example code below shows how to translate between GABA and acetylcholine. Try different combinations. Can you start to see differences between some pairs of classes? Which are the ones where the differences are the most or the least obvious? Can you see any differences that aren't well described by the mask? How would you describe these?\n", "\n", - "3. Try on your own data!\n", - " * Have a look at how the synapse images are organized in `data/raw/synapses`. Copy the directory structure and use your own images. Depending on your data, you might have to adjust the image size (128x128 for the synapses) and number of channels in the VGG network and CycleGAN code." + "plot_color_gradients([\"spring\", \"summer\", \"autumn\", \"winter\"])" ] } ], "metadata": { "jupytext": { - "cell_metadata_filter": "all" - }, - "kernelspec": { - "display_name": "09_knowledge_extraction", - "language": "python", - "name": "python3" + "cell_metadata_filter": "all", + "main_language": "python" } }, "nbformat": 4, diff --git a/solution.py b/solution.py new file mode 100644 index 0000000..cd92aa3 --- /dev/null +++ b/solution.py @@ -0,0 +1,1139 @@ +# %% [markdown] tags=[] +# # Exercise 8: Knowledge Extraction from a Pre-trained Neural Network +# +# The goal of this exercise is to learn how to probe what a pre-trained classifier has learned about the data it was trained on. +# +# We will be working with a simple example which is a fun derivation on the MNIST dataset that you will have seen in previous exercises in this course. +# Unlike regular MNIST, our dataset is classified not by number, but by color! +# +# We will: +# 1. Load a pre-trained classifier and try applying conventional attribution methods +# 2. Train a GAN to create counterfactual images - translating images from one class to another +# 3. Evaluate the GAN - see how good it is at fooling the classifier +# 4. Create attributions from the counterfactual, and learn the differences between the classes. +# +# If time permits, we will try to apply this all over again as a bonus exercise to a much more complex and more biologically relevant problem. +# ### Acknowledgments +# +# This notebook was written by Diane Adjavon, from a previous version written by Jan Funke and modified by Tri Nguyen, using code from Nils Eckstein. +# +# %% [markdown] +#
      +# Set your python kernel to 08_knowledge_extraction +#
      +# %% [markdown] +# +# # Part 1: Setup +# +# In this part of the notebook, we will load the same dataset as in the previous exercise. +# We will also learn to load one of our trained classifiers from a checkpoint. + +# %% +# loading the data +from classifier.data import ColoredMNIST + +mnist = ColoredMNIST("extras/data", download=True) +# %% [markdown] +# Some information about the dataset: +# - The dataset is a colored version of the MNIST dataset. +# - Instead of using the digits as classes, we use the colors. +# - There are four classes - the goal of the exercise is to find out what these are. +# +# Let's plot some examples +# %% +import matplotlib.pyplot as plt + +# Show some examples +fig, axs = plt.subplots(4, 4, figsize=(8, 8)) +for i, ax in enumerate(axs.flatten()): + x, y = mnist[i] + x = x.permute((1, 2, 0)) # make channels last + ax.imshow(x) + ax.set_title(f"Class {y}") + ax.axis("off") + +# %% [markdown] +# We have pre-traiend a classifier for you on this dataset. It is the same architecture classifier as you used in the Failure Modes exercise: a `DenseModel`. +# Let's load that classifier now! +# %% [markdown] +#

      Task 1.1: Load the classifier

      +# We have written a slightly more general version of the `DenseModel` that you used in the previous exercise. Ours requires two inputs: +# - `input_shape`: the shape of the input images, as a tuple +# - `num_classes`: the number of classes in the dataset +# +# Create a dense model with the right inputs and load the weights from the checkpoint. +#
      +# %% tags=["task"] +import torch +from classifier.model import DenseModel + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +# TODO Load the model with the correct input shape +model = DenseModel(input_shape=(...), num_classes=4) + +# TODO modify this with the location of your classifier checkpoint +checkpoint = torch.load(...) +model.load_state_dict(checkpoint) +model = model.to(device) +# %% tags=["solution"] +import torch +from classifier.model import DenseModel + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +# Load the model +model = DenseModel(input_shape=(3, 28, 28), num_classes=4) +# Load the checkpoint +checkpoint = torch.load("extras/checkpoints/model.pth") +model.load_state_dict(checkpoint) +model = model.to(device) + +# %% [markdown] +# Don't take my word for it! Let's see how well the classifier does on the test set. +# %% +from torch.utils.data import DataLoader +from sklearn.metrics import confusion_matrix +import seaborn as sns + +test_mnist = ColoredMNIST("extras/data", download=True, train=False) +dataloader = DataLoader(test_mnist, batch_size=32, shuffle=False) + +labels = [] +predictions = [] +for x, y in dataloader: + pred = model(x.to(device)) + labels.extend(y.cpu().numpy()) + predictions.extend(pred.argmax(dim=1).cpu().numpy()) + +cm = confusion_matrix(labels, predictions, normalize="true") +sns.heatmap(cm, annot=True, fmt=".2f") +plt.ylabel("True") +plt.xlabel("Predicted") +plt.show() + +# %% [markdown] +# # Part 2: Using Integrated Gradients to find what the classifier knows +# +# In this section we will make a first attempt at highlighting differences between the "real" and "fake" images that are most important to change the decision of the classifier. +# + +# %% [markdown] +# ## Attributions through integrated gradients +# +# Attribution is the process of finding out, based on the output of a neural network, which pixels in the input are (most) responsible for the output. Another way of thinking about it is: which pixels would need to change in order for the network's output to change. +# +# Here we will look at an example of an attribution method called [Integrated Gradients](https://captum.ai/docs/extension/integrated_gradients). If you have a bit of time, have a look at this [super fun exploration of attribution methods](https://distill.pub/2020/attribution-baselines/), especially the explanations on Integrated Gradients. + +# %% tags=[] +batch_size = 4 +batch = [] +for i in range(4): + batch.append(next(image for image in mnist if image[1] == i)) +x = torch.stack([b[0] for b in batch]) +y = torch.tensor([b[1] for b in batch]) +x = x.to(device) +y = y.to(device) + +# %% [markdown] tags=[] +#

      Task 2.1 Get an attribution

      +# +# In this next part, we will get attributions on single batch. We use a library called [captum](https://captum.ai), and focus on the `IntegratedGradients` method. +# Create an `IntegratedGradients` object and run attribution on `x,y` obtained above. +# +#
      + +# %% tags=["task"] +from captum.attr import IntegratedGradients + +############### Task 2.1 TODO ############ +# Create an integrated gradients object. +integrated_gradients = ... + +# Generated attributions on integrated gradients +attributions = ... + +# %% tags=["solution"] +######################### +# Solution for Task 2.1 # +######################### + +from captum.attr import IntegratedGradients + +# Create an integrated gradients object. +integrated_gradients = IntegratedGradients(model) + +# Generated attributions on integrated gradients +attributions = integrated_gradients.attribute(x, target=y) + +# %% tags=[] +attributions = ( + attributions.cpu().numpy() +) # Move the attributions from the GPU to the CPU, and turn then into numpy arrays for future processing + +# %% [markdown] tags=[] +# Here is an example for an image, and its corresponding attribution. + + +# %% tags=[] +from captum.attr import visualization as viz +import numpy as np + + +def visualize_attribution(attribution, original_image): + attribution = np.transpose(attribution, (1, 2, 0)) + original_image = np.transpose(original_image, (1, 2, 0)) + + viz.visualize_image_attr_multiple( + attribution, + original_image, + methods=["original_image", "heat_map"], + signs=["all", "absolute_value"], + show_colorbar=True, + titles=["Image", "Attribution"], + use_pyplot=True, + ) + + +# %% tags=[] +for attr, im, lbl in zip(attributions, x.cpu().numpy(), y.cpu().numpy()): + print(f"Class {lbl}") + visualize_attribution(attr, im) + +# %% [markdown] +# +# The attributions are shown as a heatmap. The brighter the pixel, the more important this attribution method thinks that it is. +# As you can see, it is pretty good at recognizing the number within the image. +# As we know, however, it is not the digit itself that is important for the classification, it is the color! +# Although the method is picking up really well on the region of interest, it would be difficult to conclude from this that it is the color that matters. + + +# %% [markdown] +# Something is slightly unfair about this visualization though. +# We are visualizing as if it were grayscale, but both our images and our attributions are in color! +# Can we learn more from the attributions if we visualize them in color? +# %% +def visualize_color_attribution(attribution, original_image): + attribution = np.transpose(attribution, (1, 2, 0)) + original_image = np.transpose(original_image, (1, 2, 0)) + + fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5)) + ax1.imshow(original_image) + ax1.set_title("Image") + ax1.axis("off") + ax2.imshow(np.abs(attribution)) + ax2.set_title("Attribution") + ax2.axis("off") + plt.show() + + +for attr, im, lbl in zip(attributions, x.cpu().numpy(), y.cpu().numpy()): + print(f"Class {lbl}") + visualize_color_attribution(attr, im) + +# %% [markdown] +# We get some better clues when looking at the attributions in color. +# The highlighting doesn't just happen in the region with number, but also seems to hapen in a channel that matches the color of the image. +# Just based on this, however, we don't get much more information than we got from the images themselves. +# +# If we didn't know in advance, it is unclear whether the color or the number is the most important feature for the classifier. +# %% [markdown] +# +# ### Changing the baseline +# +# Many existing attribution algorithms are comparative: they show which pixels of the input are responsible for a network output *compared to a baseline*. +# The baseline is often set to an all 0 tensor, but the choice of the baseline affects the output. +# (For an interactive illustration of how the baseline affects the output, see [this Distill paper](https://distill.pub/2020/attribution-baselines/)) +# +# You can change the baseline used by the `integrated_gradients` object. +# +# Use the command: +# ``` +# ?integrated_gradients.attribute +# ``` +# To get more details about how to include the baseline. +# +# Try using the code below to change the baseline and see how this affects the output. +# +# 1. Random noise as a baseline +# 2. A blurred/noisy version of the original image as a baseline. + +# %% [markdown] +#

      Task 2.3: Use random noise as a baseline

      +# +# Hint: `torch.rand_like` +#
      + +# %% tags=["task"] +# Baseline +random_baselines = ... # TODO Change +# Generate the attributions +attributions_random = integrated_gradients.attribute(...) # TODO Change + +# Plotting +for attr, im, lbl in zip(attributions, x.cpu().numpy(), y.cpu().numpy()): + print(f"Class {lbl}") + visualize_attribution(attr, im) + +# %% tags=["solution"] +######################### +# Solution for task 2.3 # +######################### +# Baseline +random_baselines = torch.rand_like(x) +# Generate the attributions +attributions_random = integrated_gradients.attribute( + x, target=y, baselines=random_baselines +) + +# Plotting +for attr, im, lbl in zip(attributions, x.cpu().numpy(), y.cpu().numpy()): + print(f"Class {lbl}") + visualize_color_attribution(attr, im) + +# %% [markdown] tags=[] +#

      Task 2.4: Use a blurred image a baseline

      +# +# Hint: `torchvision.transforms.functional` has a useful function for this ;) +#
      + +# %% tags=["task"] +# TODO Import required function + +# Baseline +blurred_baselines = ... # TODO Create blurred version of the images +# Generate the attributions +attributions_blurred = integrated_gradients.attribute(...) # TODO Fill + +# Plotting +for attr, im, lbl in zip(attributions, x.cpu().numpy(), y.cpu().numpy()): + print(f"Class {lbl}") + visualize_color_attribution(attr, im) + +# %% tags=["solution"] +######################### +# Solution for task 2.4 # +######################### +from torchvision.transforms.functional import gaussian_blur + +# Baseline +blurred_baselines = gaussian_blur(x, kernel_size=(5, 5)) +# Generate the attributions +attributions_blurred = integrated_gradients.attribute( + x, target=y, baselines=blurred_baselines +) + +# Plotting +for attr, im, lbl in zip(attributions, x.cpu().numpy(), y.cpu().numpy()): + print(f"Class {lbl}") + visualize_color_attribution(attr, im) + +# %% [markdown] tags=[] +#

      Questions

      +#
        +#
      • What baseline do you like best so far? Why?
      • +#
      • Why do you think some baselines work better than others?
      • +#
      • If you were to design an ideal baseline, what would you choose?
      • +#
      +#
      + +# %% [markdown] +#

      BONUS Task: Using different attributions.

      +# +# +# [`captum`](https://captum.ai/tutorials/Resnet_TorchVision_Interpret) has access to various different attribution algorithms. +# +# Replace `IntegratedGradients` with different attribution methods. Are they consistent with each other? +#
      + +# %% [markdown] +#

      Checkpoint 2

      +# Let us know on the exercise chat when you've reached this point! +# +# At this point we have: +# +# - Loaded a classifier that classifies MNIST-like images by color, but we don't know how! +# - Tried applying Integrated Gradients to find out what the classifier is looking at - with little success. +# - Discovered the effect of changing the baseline on the output of integrated gradients. +# +# Coming up in the next section, we will learn how to create counterfactual images. +# These images will change *only what is necessary* in order to change the classification of the image. +# We'll see that using counterfactuals we will be able to disambiguate between color and number as an important feature. +#
      + +# %% [markdown] +# # Part 3: Train a GAN to Translate Images +# +# To gain insight into how the trained network classifies images, we will use [Discriminative Attribution from Counterfactuals](https://arxiv.org/abs/2109.13412), a feature attribution with counterfactual explanations methodology. +# This method employs a StarGAN to translate images from one class to another to make counterfactual explanations. +# +# **What is a counterfactual?** +# +# You've learned about adversarial examples in the lecture on failure modes. These are the imperceptible or noisy changes to an image that drastically changes a classifier's opinion. +# Counterfactual explanations are the useful cousins of adversarial examples. They are *perceptible* and *informative* changes to an image that changes a classifier's opinion. +# +# In the image below you can see the difference between the two. In the first column are MNIST images along with their classifictaions, and in the second column are counterfactual explanations to *change* that class. You can see that in both cases a human being would (hopefully) agree with the new classification. By comparing the two columns, we can therefore begin to define what makes each digit special. +# +# In contrast, the third and fourth columns show an MNIST image and a corresponding adversarial example. Here the network returns a prediction that most human beings (who aren't being facetious) would strongly disagree with. +# +# +# +# **Counterfactual synapses** +# +# In this example, we will train a StarGAN network that is able to take any of our special MNIST images and change its class. +# %% [markdown] tags=[] +# ### The model +# ![stargan.png](assets/stargan.png) +# +# In the following, we create a [StarGAN model](https://arxiv.org/abs/1711.09020). +# It is a Generative Adversarial model that is trained to turn one class of images X into a different class of images Y. +# +# We will not be using the random latent code (green, in the figure), so the model we use is made up of three networks: +# - The generator - this will be the bulk of the model, and will be responsible for transforming the images: we're going to use a `UNet` +# - The discriminator - this will be responsible for telling the difference between real and fake images: we're going to use a `DenseModel` +# - The style encoder - this will be responsible for encoding the style of the image: we're going to use a `DenseModel` +# +# Let's start by creating these! +# %% +from dlmbl_unet import UNet +from torch import nn + + +class Generator(nn.Module): + + def __init__(self, generator, style_encoder): + super().__init__() + self.generator = generator + self.style_encoder = style_encoder + + def forward(self, x, y): + """ + x: torch.Tensor + The source image + y: torch.Tensor + The style image + """ + style = self.style_encoder(y) + # Concatenate the style vector with the input image + style = style.unsqueeze(-1).unsqueeze(-1) + style = style.expand(-1, -1, x.size(2), x.size(3)) + x = torch.cat([x, style], dim=1) + return self.generator(x) + + +# %% [markdown] +#

      Task 3.1: Create the models

      +# +# We are going to create the models for the generator, discriminator, and style mapping. +# +# Given the Generator structure above, fill in the missing parts for the unet and the style mapping. +# %% tags=["task"] +style_size = ... # TODO choose a size for the style space +unet_depth = ... # TODO Choose a depth for the UNet +style_encoder = DenseModel( + input_shape=..., num_classes=... # How big is the style space? +) +unet = UNet(depth=..., in_channels=..., out_channels=..., final_activation=nn.Sigmoid()) + +generator = Generator(unet, style_encoder=style_encoder) +# %% tags=["solution"] +# Here is an example of a working setup! Note that you can change the hyperparameters as you experiment. +# Choose your own setup to see what works for you. +style_encoder = DenseModel(input_shape=(3, 28, 28), num_classes=3) +unet = UNet(depth=2, in_channels=6, out_channels=3, final_activation=nn.Sigmoid()) +generator = Generator(unet, style_encoder=style_encoder) + +# %% [markdown] tags=[] +#

      Hyper-parameter choices

      +#
        +#
      • Are any of the hyperparameters you choose above constrained in some way?
      • +#
      • What would happen if you chose a depth of 10 for the UNet?
      • +#
      • Is there a minimum size for the style space? Why or why not?
      • +#
      + +# %% [markdown] tags=[] +#

      Task 3.2: Create the discriminator

      +# +# We want the discriminator to be like a classifier, so it is able to look at an image and tell not only whether it is real, but also which class it came from. +# The discriminator will take as input either a real image or a fake image. +# Fill in the following code to create a discriminator that can classify the images into the correct number of classes. +#
      +# %% tags=["task"] +discriminator = DenseModel(input_shape=..., num_classes=...) +# %% tags=["solution"] +discriminator = DenseModel(input_shape=(3, 28, 28), num_classes=4) +# %% [markdown] +# Let's move all models onto the GPU +# %% +generator = generator.to(device) +discriminator = discriminator.to(device) + +# %% [markdown] tags=[] +# ## Training a GAN +# +# Training an adversarial network is a bit more complicated than training a classifier. +# For starters, we are simultaneously training two different networks that work against each other. +# As such, we need to be careful about how and when we update the weights of each network. +# +# We will have two different optimizers, one for the Generator and one for the Discriminator. +# +# %% +optimizer_d = torch.optim.Adam(discriminator.parameters(), lr=1e-5) +optimizer_g = torch.optim.Adam(generator.parameters(), lr=1e-4) +# %% [markdown] tags=[] +# +# There are also two different types of losses that we will need. +# **Adversarial loss** +# This loss describes how well the discriminator can tell the difference between real and generated images. +# In our case, this will be a sort of classification loss - we will use Cross Entropy. +#
      +# The adversarial loss will be applied differently to the generator and the discriminator! Be very careful! +#
      +# %% +adversarial_loss_fn = nn.CrossEntropyLoss() + +# %% [markdown] tags=[] +# +# **Cycle/reconstruction loss** +# The cycle loss is there to make sure that the generator doesn't output an image that looks nothing like the input! +# Indeed, by training the generator to be able to cycle back to the original image, we are making sure that it makes a minimum number of changes. +# The cycle loss is applied only to the generator. +# +# %% +cycle_loss_fn = nn.L1Loss() + +# %% [markdown] tags=[] +# To load the data as batches, with shuffling and other useful features, we will use a `DataLoader`. +# %% +from torch.utils.data import DataLoader + +dataloader = DataLoader( + mnist, batch_size=32, drop_last=True, shuffle=True +) # We will use the same dataset as before + + +# %% [markdown] tags=[] +# As we stated earlier, it is important to make sure when each network is being trained when working with a GAN. +# Indeed, if we update the weights at the same time, we may lose the adversarial aspect of the training altogether, with information leaking into the generator or discriminator causing them to collaborate when they should be competing! +# `set_requires_grad` is a function that allows us to determine when the weights of a network are trainable (if it is `True`) or not (if it is `False`). +# %% +def set_requires_grad(module, value=True): + """Sets `requires_grad` on a `module`'s parameters to `value`""" + for param in module.parameters(): + param.requires_grad = value + + +# %% [markdown] tags=[] +# Another consequence of adversarial training is that it is very unstable. +# While this instability is what leads to finding the best possible solution (which in the case of GANs is on a saddle point), it can also make it difficult to train the model. +# To force some stability back into the training, we will use Exponential Moving Averages (EMA). +# +# In essence, each time we update the generator's weights, we will also update the EMA model's weights as an average of all the generator's previous weights as well as the current update. +# A certain weight is given to the previous weights, which is what ensures that the EMA update remains rather smooth over the training period. +# Each epoch, we will then copy the EMA model's weights back to the generator. +# This is a common technique used in GAN training to stabilize the training process. +# Pay attention to what this does to the loss during the training process! +# %% +from copy import deepcopy + + +def exponential_moving_average(model, ema_model, beta=0.999): + """Update the EMA model's parameters with an exponential moving average""" + for param, ema_param in zip(model.parameters(), ema_model.parameters()): + ema_param.data.mul_(beta).add_((1 - beta) * param.data) + + +def copy_parameters(source_model, target_model): + """Copy the parameters of a model to another model""" + for param, target_param in zip( + source_model.parameters(), target_model.parameters() + ): + target_param.data.copy_(param.data) + + +# %% +generator_ema = Generator(deepcopy(unet), style_encoder=deepcopy(style_encoder)) +generator_ema = generator_ema.to(device) + +# %% [markdown] tags=[] +#

      Task 3.3: Training!

      +# You were given several different options in the training code below. In each case, one of the options will work, and the other will not. +# Comment out the option that you think will not work. +#
        +#
      • Choose the values for `set_requires_grad`. Hint: which part of the code is training the generator? Which part is training the discriminator
      • +#
      • Choose the values of `set_requires_grad`, again. Hint: you may want to switch
      • +#
      • Choose the sign of the discriminator loss. Hint: what does the discriminator want to do?
      • +# .
      • Apply the EMA update. Hint: which model do you want to update? You can look again at the code we wrote above.
      • +#
      +# Let's train the StarGAN one batch a time. +# While you watch the model train, consider whether you think it will be successful at generating counterfactuals in the number of steps we give it. What is the minimum number of iterations you think are needed for this to work, and how much time do yo uthink it will take? +#
      +# %% [markdown] tags=[] +# Once you're happy with your choices, run the training loop! 🚂 🚋 🚋 🚋 +# %% tags=["task"] +from tqdm import tqdm # This is a nice library for showing progress bars + + +losses = {"cycle": [], "adv": [], "disc": []} + +for epoch in range(15): + for x, y in tqdm(dataloader, desc=f"Epoch {epoch}"): + x = x.to(device) + y = y.to(device) + # get the target y by shuffling the classes + # get the style sources by random sampling + random_index = torch.randperm(len(y)) + x_style = x[random_index].clone() + y_target = y[random_index].clone() + + # TODO - Choose an option by commenting out what you don't want + ############ + # Option 1 # + ############ + set_requires_grad(generator, True) + set_requires_grad(discriminator, False) + ############ + # Option 2 # + ############ + set_requires_grad(generator, False) + set_requires_grad(discriminator, True) + + optimizer_g.zero_grad() + # Get the fake image + x_fake = generator(x, x_style) + # Try to cycle back + x_cycled = generator(x_fake, x) + # Discriminate + discriminator_x_fake = discriminator(x_fake) + # Losses to train the generator + + # 1. make sure the image can be reconstructed + cycle_loss = cycle_loss_fn(x, x_cycled) + # 2. make sure the discriminator is fooled + adv_loss = adversarial_loss_fn(discriminator_x_fake, y_target) + + # Optimize the generator + (cycle_loss + adv_loss).backward() + optimizer_g.step() + + # TODO - Choose an option by commenting out what you don't want + ############ + # Option 1 # + ############ + set_requires_grad(generator, True) + set_requires_grad(discriminator, False) + ############ + # Option 2 # + ############ + set_requires_grad(generator, False) + set_requires_grad(discriminator, True) + # + optimizer_d.zero_grad() + # + discriminator_x = discriminator(x) + discriminator_x_fake = discriminator(x_fake.detach()) + + # TODO - Choose an option by commenting out what you don't want + # Losses to train the discriminator + # 1. make sure the discriminator can tell real is real + # 2. make sure the discriminator can tell fake is fake + ############ + # Option 1 # + ############ + real_loss = adversarial_loss_fn(discriminator_x, y) + fake_loss = -adversarial_loss_fn(discriminator_x_fake, y_target) + ############ + # Option 2 # + ############ + real_loss = adversarial_loss_fn(discriminator_x, y) + fake_loss = adversarial_loss_fn(discriminator_x_fake, y_target) + # + disc_loss = (real_loss + fake_loss) * 0.5 + disc_loss.backward() + # Optimize the discriminator + optimizer_d.step() + + losses["cycle"].append(cycle_loss.item()) + losses["adv"].append(adv_loss.item()) + losses["disc"].append(disc_loss.item()) + + # EMA update + # TODO - perform the EMA update + ############ + # Option 1 # + ############ + exponential_moving_average(generator, generator_ema) + ############ + # Option 2 # + ############ + exponential_moving_average(generator_ema, generator) + # Copy the EMA model's parameters to the generator + copy_parameters(generator_ema, generator) +# %% tags=["solution"] +from tqdm import tqdm # This is a nice library for showing progress bars + + +losses = {"cycle": [], "adv": [], "disc": []} +for epoch in range(15): + for x, y in tqdm(dataloader, desc=f"Epoch {epoch}"): + x = x.to(device) + y = y.to(device) + # get the target y by shuffling the classes + # get the style sources by random sampling + random_index = torch.randperm(len(y)) + x_style = x[random_index].clone() + y_target = y[random_index].clone() + + set_requires_grad(generator, True) + set_requires_grad(discriminator, False) + optimizer_g.zero_grad() + # Get the fake image + x_fake = generator(x, x_style) + # Try to cycle back + x_cycled = generator(x_fake, x) + # Discriminate + discriminator_x_fake = discriminator(x_fake) + # Losses to train the generator + + # 1. make sure the image can be reconstructed + cycle_loss = cycle_loss_fn(x, x_cycled) + # 2. make sure the discriminator is fooled + adv_loss = adversarial_loss_fn(discriminator_x_fake, y_target) + + # Optimize the generator + (cycle_loss + adv_loss).backward() + optimizer_g.step() + + set_requires_grad(generator, False) + set_requires_grad(discriminator, True) + optimizer_d.zero_grad() + # + discriminator_x = discriminator(x) + discriminator_x_fake = discriminator(x_fake.detach()) + # Losses to train the discriminator + # 1. make sure the discriminator can tell real is real + real_loss = adversarial_loss_fn(discriminator_x, y) + # 2. make sure the discriminator can tell fake is fake + fake_loss = -adversarial_loss_fn(discriminator_x_fake, y_target) + # + disc_loss = (real_loss + fake_loss) * 0.5 + disc_loss.backward() + # Optimize the discriminator + optimizer_d.step() + + losses["cycle"].append(cycle_loss.item()) + losses["adv"].append(adv_loss.item()) + losses["disc"].append(disc_loss.item()) + exponential_moving_average(generator, generator_ema) + # Copy the EMA model's parameters to the generator + copy_parameters(generator_ema, generator) + + +# %% [markdown] tags=[] +# Once training is complete, we can plot the losses to see how well the model is doing. +# %% +fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 5)) +ax1.plot(losses["cycle"]) +ax1.set_title("Cycle loss") +ax2.plot(losses["adv"]) +ax2.set_title("Adversarial loss") +ax3.plot(losses["disc"]) +ax3.set_title("Discriminator loss") +plt.show() + +# %% [markdown] tags=[] +#

      Questions

      +#
        +#
      • Do the losses look like what you expected?
      • +#
      • How do these losses differ from the losses you would expect from a classifier?
      • +#
      • Based only on the losses, do you think the model is doing well?
      • +#
      + +# %% [markdown] tags=[] +# We can also look at some examples of the images that the generator is creating. +# %% +idx = 0 +fig, axs = plt.subplots(1, 4, figsize=(12, 4)) +axs[0].imshow(x[idx].cpu().permute(1, 2, 0).detach().numpy()) +axs[0].set_title("Input image") +axs[1].imshow(x_style[idx].cpu().permute(1, 2, 0).detach().numpy()) +axs[1].set_title("Style image") +axs[2].imshow(x_fake[idx].cpu().permute(1, 2, 0).detach().numpy()) +axs[2].set_title("Generated image") +axs[3].imshow(x_cycled[idx].cpu().permute(1, 2, 0).detach().numpy()) +axs[3].set_title("Cycled image") + +for ax in axs: + ax.axis("off") +plt.show() + +# %% +# %% [markdown] tags=[] +#

      Checkpoint 3

      +# You've now learned the basics of what makes up a StarGAN, and details on how to perform adversarial training. +# The same method can be used to create a StarGAN with different basic elements. +# For example, you can change the archictecture of the generators, or of the discriminator to better fit your data in the future. +# +# You know the drill... let us know on the exercise chat when you have arrived here! +#
      + +# %% [markdown] tags=[] +# # Part 4: Evaluating the GAN and creating Counterfactuals + +# %% [markdown] tags=[] +# ## Creating counterfactuals +# +# The first thing that we want to do is make sure that our GAN is able to create counterfactual images. +# To do this, we have to create them, and then pass them through the classifier to see if they are classified correctly. +# +# First, let's get the test dataset, so we can evaluate the GAN on unseen data. +# Then, let's get four prototypical images from the dataset as style sources. + +# %% Loading the test dataset +test_mnist = ColoredMNIST("extras/data", download=True, train=False) +prototypes = {} + + +for i in range(4): + options = np.where(test_mnist.conditions == i)[0] + # Note that you can change the image index if you want to use a different prototype. + image_index = 0 + x, y = test_mnist[options[image_index]] + prototypes[i] = x + +# %% [markdown] tags=[] +# Let's have a look at the prototypes. +# %% +fig, axs = plt.subplots(1, 4, figsize=(12, 4)) +for i, ax in enumerate(axs): + ax.imshow(prototypes[i].permute(1, 2, 0)) + ax.axis("off") + ax.set_title(f"Prototype {i}") + +# %% [markdown] +# Now we need to use these prototypes to create counterfactual images! +# %% [markdown] +#

      Task 4: Create counterfactuals

      +# In the below, we will store the counterfactual images in the `counterfactuals` array. +# +#
        +#
      • Create a counterfactual image for each of the prototypes.
      • +#
      • Classify the counterfactual image using the classifier.
      • +#
      • Store the source and target labels; which is which?
      • +#
      +# %% tags=["task"] +num_images = 1000 +random_test_mnist = torch.utils.data.Subset( + test_mnist, np.random.choice(len(test_mnist), num_images, replace=False) +) +counterfactuals = np.zeros((4, num_images, 3, 28, 28)) + +predictions = [] +source_labels = [] +target_labels = [] + +for i, (x, y) in tqdm(enumerate(random_test_mnist), total=num_images): + for lbl in range(4): + # TODO Create the counterfactual + x_fake = generator(x.unsqueeze(0).to(device), ...) + # TODO Predict the class of the counterfactual image + pred = model(...) + + # TODO Store the source and target labels + source_labels.append(...) # The original label of the image + target_labels.append(...) # The desired label of the counterfactual image + # Store the counterfactual image and prediction + counterfactuals[lbl][i] = x_fake.cpu().detach().numpy() + predictions.append(pred.argmax().item()) +# %% tags=["solution"] +num_images = 1000 +random_test_mnist = torch.utils.data.Subset( + test_mnist, np.random.choice(len(test_mnist), num_images, replace=False) +) +counterfactuals = np.zeros((4, num_images, 3, 28, 28)) + +predictions = [] +source_labels = [] +target_labels = [] + +for i, (x, y) in tqdm(enumerate(random_test_mnist), total=num_images): + for lbl in range(4): + # Create the counterfactual + x_fake = generator( + x.unsqueeze(0).to(device), prototypes[lbl].unsqueeze(0).to(device) + ) + # Predict the class of the counterfactual image + pred = model(x_fake) + + # Store the source and target labels + source_labels.append(y) # The original label of the image + target_labels.append(lbl) # The desired label of the counterfactual image + # Store the counterfactual image and prediction + counterfactuals[lbl][i] = x_fake.cpu().detach().numpy() + predictions.append(pred.argmax().item()) + +# %% [markdown] tags=[] +# Let's plot the confusion matrix for the counterfactual images. +# %% +cf_cm = confusion_matrix(target_labels, predictions, normalize="true") +sns.heatmap(cf_cm, annot=True, fmt=".2f") +plt.ylabel("True") +plt.xlabel("Predicted") +plt.show() + +# %% [markdown] tags=[] +#

      Questions

      +#
        +#
      • How well is our GAN doing at creating counterfactual images?
      • +#
      • Does your choice of prototypes matter? Why or why not?
      • +#
      +#
      + +# %% [markdown] tags=[] +# Let's also plot some examples of the counterfactual images. + +# %% +for i in np.random.choice(range(num_images), 4): + fig, axs = plt.subplots(1, 4, figsize=(20, 4)) + for j, ax in enumerate(axs): + ax.imshow(counterfactuals[j][i].transpose(1, 2, 0)) + ax.axis("off") + ax.set_title(f"Class {j}") + +# %% [markdown] tags=[] +#

      Questions

      +#
        +#
      • Can you easily tell which of these images is the original, and which ones are the counterfactuals?
      • +#
      • What is your hypothesis for the features that define each class?
      • +#
      +#
      + +# %% [markdown] +# At this point we have: +# - A classifier that can differentiate between image of different classes +# - A GAN that has correctly figured out how to change the class of an image +# +# Let's try putting the two together to see if we can figure out what exactly makes a class. +# +# %% +batch_size = 4 +batch = [random_test_mnist[i] for i in range(batch_size)] +x = torch.stack([b[0] for b in batch]) +y = torch.tensor([b[1] for b in batch]) +x_fake = torch.tensor(counterfactuals[0, :batch_size]) +x = x.to(device).float() +y = y.to(device) +x_fake = x_fake.to(device).float() + +# Generated attributions on integrated gradients +attributions = integrated_gradients.attribute(x, baselines=x_fake, target=y) + + +# %% Another visualization function +def visualize_color_attribution_and_counterfactual( + attribution, original_image, counterfactual_image +): + attribution = np.transpose(attribution, (1, 2, 0)) + original_image = np.transpose(original_image, (1, 2, 0)) + counterfactual_image = np.transpose(counterfactual_image, (1, 2, 0)) + + fig, (ax0, ax1, ax2) = plt.subplots(1, 3, figsize=(15, 5)) + ax0.imshow(original_image) + ax0.set_title("Image") + ax0.axis("off") + ax1.imshow(counterfactual_image) + ax1.set_title("Counterfactual") + ax1.axis("off") + ax2.imshow(np.abs(attribution)) + ax2.set_title("Attribution") + ax2.axis("off") + plt.show() + + +# %% +for idx in range(batch_size): + print("Source class:", y[idx].item()) + print("Target class:", 0) + visualize_color_attribution_and_counterfactual( + attributions[idx].cpu().numpy(), x[idx].cpu().numpy(), x_fake[idx].cpu().numpy() + ) +# %% [markdown] +#

      Questions

      +#
        +#
      • Do the attributions explain the differences between the images and their counterfactuals?
      • +#
      • What happens when the "counterfactual" and the original image are of the same class? Why do you think this is?
      • +#
      • Do you have a more refined hypothesis for what makes each class unique?
      • +#
      +#
      +# %% [markdown] +#

      Checkpoint 4

      +# At this point you have: +# - Created a StarGAN that can change the class of an image +# - Evaluated the StarGAN on unseen data +# - Used the StarGAN to create counterfactual images +# - Used the counterfactual images to highlight the differences between classes +# +# %% [markdown] +# # Part 5: Exploring the Style Space, finding the answer +# By now you will have hopefully noticed that it isn't the exact color of the image that determines its class, but that two images with a very similar color can be of different classes! +# +# Here is an example of two images that are very similar in color, but are of different classes. +# ![same_color_diff_class](assets/same_color_diff_class.png) +# While both of the images are yellow, the attribution tells us (if you squint!) that one of the yellows has slightly more blue in it! +# +# Conversely, here is an example of two images with very different colors, but that are of the same class: +# ![same_class_diff_color](assets/same_class_diff_color.png) +# Here the attribution is empty! Using the discriminative attribution we can see that the significant color change doesn't matter at all! +# +# +# So color is important... but not always? What's going on!? +# There is a final piece of information that we can use to solve the puzzle: the style space. +# %% [markdown] +#

      Task 5.1: Explore the style space

      +# Let's take a look at the style space. +# We will use the style encoder to encode the style of the images and then use PCA to visualize it. +#
      + +# %% +from sklearn.decomposition import PCA + + +styles = [] +labels = [] +for img, label in random_test_mnist: + styles.append( + style_encoder(img.unsqueeze(0).to(device)).cpu().detach().numpy().squeeze() + ) + labels.append(label) + +# PCA +pca = PCA(n_components=2) +styles_pca = pca.fit_transform(styles) + +# Plot the PCA +markers = ["o", "s", "P", "^"] +plt.figure(figsize=(10, 10)) +for i in range(4): + plt.scatter( + styles_pca[np.array(labels) == i, 0], + styles_pca[np.array(labels) == i, 1], + marker=markers[i], + label=f"Class {i}", + ) +plt.legend() +plt.show() + +# %% [markdown] +#

      Task 5.1: Adding color to the style space

      +# We know that color is important. Does interpreting the style space as colors help us understand better? +# +# Let's use the style space to color the PCA plot. +# (Note: there is no code to write here, just run the cell and answer the questions below) +#
      +# %% +styles = np.array(styles) +normalized_styles = (styles - np.min(styles, axis=1, keepdims=True)) / np.ptp( + styles, axis=1, keepdims=True +) + +# Plot the PCA again! +plt.figure(figsize=(10, 10)) +for i in range(4): + plt.scatter( + styles_pca[np.array(labels) == i, 0], + styles_pca[np.array(labels) == i, 1], + c=normalized_styles[np.array(labels) == i], + marker=markers[i], + label=f"Class {i}", + ) +plt.legend() +plt.show() +# %% [markdown] +#

      Questions

      +#
        +#
      • Do the colors match those that you have seen in the data?
      • +#
      • Can you see any patterns in the colors? Is the space smooth, for example?
      • +#
      +# %% [markdown] +#

      Task 5.2: Using the images to color the style space

      +# Finally, let's just use the colors from the images themselves! +# The maximum value in the image (since they are "black-and-color") can be used as a color! +# +# Let's get that color, then plot the style space again. +# (Note: once again, no coding needed here, just run the cell and think about the results with the questions below) +#
      +# %% +colors = np.array([np.max(x.numpy(), axis=(1, 2)) for x, _ in random_test_mnist]) + +# Plot the PCA again! +plt.figure(figsize=(10, 10)) +for i in range(4): + plt.scatter( + styles_pca[np.array(labels) == i, 0], + styles_pca[np.array(labels) == i, 1], + c=colors[np.array(labels) == i], + marker=markers[i], + label=f"Class {i}", + ) +plt.legend() +plt.show() + +# %% +# %% [markdown] +#

      Questions

      +#
        +#
      • Do the colors match those that you have seen in the data?
      • +#
      • Can you see any patterns in the colors?
      • +#
      • Can you guess what the classes correspond to?
      • + +# %% [markdown] +#

        Checkpoint 5

        +# Congratulations! You have made it to the end of the exercise! +# You have: +# - Created a StarGAN that can change the class of an image +# - Evaluated the StarGAN on unseen data +# - Used the StarGAN to create counterfactual images +# - Used the counterfactual images to highlight the differences between classes +# - Used the style space to understand the differences between classes +# +# If you have any questions, feel free to ask them in the chat! +# And check the Solutions exercise for a definite answer to how these classes are defined! + +# %% [markdown] tags=["solution"] +# The colors for the classes are sampled from matplotlib colormaps! They are the four seasons: spring, summer, autumn, and winter. +# Check your style space again to see if you can see the patterns now! +# %% tags=["solution"] +# Let's plot the colormaps +import matplotlib as mpl +import numpy as np + + +def plot_color_gradients(cmap_list): + gradient = np.linspace(0, 1, 256) + gradient = np.vstack((gradient, gradient)) + + # Create figure and adjust figure height to number of colormaps + nrows = len(cmap_list) + figh = 0.35 + 0.15 + (nrows + (nrows - 1) * 0.1) * 0.22 + fig, axs = plt.subplots(nrows=nrows + 1, figsize=(6.4, figh)) + fig.subplots_adjust(top=1 - 0.35 / figh, bottom=0.15 / figh, left=0.2, right=0.99) + + for ax, name in zip(axs, cmap_list): + ax.imshow(gradient, aspect="auto", cmap=mpl.colormaps[name]) + ax.text( + -0.01, + 0.5, + name, + va="center", + ha="right", + fontsize=10, + transform=ax.transAxes, + ) + + # Turn off *all* ticks & spines, not just the ones with colormaps. + for ax in axs: + ax.set_axis_off() + + +plot_color_gradients(["spring", "summer", "autumn", "winter"])