diff --git a/solution.py b/solution.py
index 66f5fb0..b45dbbf 100644
--- a/solution.py
+++ b/solution.py
@@ -1,8 +1,8 @@
-# %% [markdown] editable=true slideshow={"slide_type": ""} tags=[]
-# # Exercise 8: Knowledge Extraction from a Convolutional Neural Network
+# %% [markdown] tags=[]
+# # Exercise 8: Knowledge Extraction from a Pre-trained Neural Network
#
# The goal of this exercise is to learn how to probe what a pre-trained classifier has learned about the data it was trained on.
-
+#
# We will be working with a simple example: a fun variation on the MNIST dataset that you will have seen in previous exercises in this course.
# Unlike regular MNIST, our dataset is classified not by number, but by color!
#
@@ -21,23 +21,25 @@
#
# Task 1.1: Load the classifier
# We have written a slightly more general version of the `DenseModel` that you used in the previous exercise. Ours requires two inputs:
@@ -60,7 +63,7 @@
#
# Create a dense model with the right inputs and load the weights from the checkpoint.
#
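# %% [markdown]
# As a reminder of the general PyTorch pattern (a sketch only; the exact constructor arguments and
# checkpoint path for this exercise are for you to fill in below):
# %%
# model = DenseModel(input_shape=..., num_classes=...)
# weights = torch.load("path/to/checkpoint.pth")  # hypothetical path
# model.load_state_dict(weights)
# model = model.to(device)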
-# %%
+# %% tags=["task"]
import torch
from classifier.model import DenseModel
@@ -100,7 +103,7 @@
#
# Here we will look at an example of an attribution method called [Integrated Gradients](https://captum.ai/docs/extension/integrated_gradients). If you have a bit of time, have a look at this [super fun exploration of attribution methods](https://distill.pub/2020/attribution-baselines/), especially the explanation of Integrated Gradients.
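#
# For reference, the Integrated Gradients attribution for input feature $i$ is defined as
#
# $$\mathrm{IG}_i(x) = (x_i - x_i') \int_0^1 \frac{\partial F\big(x' + \alpha (x - x')\big)}{\partial x_i}\, d\alpha,$$
#
# where $x'$ is a baseline input and $F$ is the model output for the class of interest. In practice,
# the integral is approximated by averaging gradients at a number of interpolation steps between $x'$ and $x$.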
-# %% editable=true slideshow={"slide_type": ""} tags=[]
+# %% tags=[]
batch_size = 4
batch = [mnist[i] for i in range(batch_size)]
x = torch.stack([b[0] for b in batch])
@@ -108,7 +111,7 @@
x = x.to(device)
y = y.to(device)
-# %% [markdown] editable=true slideshow={"slide_type": ""} tags=[]
+# %% [markdown] tags=[]
#
# Task 2.1: Get an attribution
#
# In this next part, we will get attributions on a single batch. We use a library called [captum](https://captum.ai), and focus on the `IntegratedGradients` method.
@@ -116,7 +119,7 @@
#
#
-# %% editable=true slideshow={"slide_type": ""} tags=[]
+# %% tags=["task"]
from captum.attr import IntegratedGradients
############### Task 2.1 TODO ############
@@ -126,7 +129,7 @@
# Generate attributions using Integrated Gradients
attributions = ...
-# %% editable=true slideshow={"slide_type": ""} tags=["solution"]
+# %% tags=["solution"]
#########################
# Solution for Task 2.1 #
#########################
@@ -139,16 +142,16 @@
# Generate attributions using Integrated Gradients
attributions = integrated_gradients.attribute(x, target=y)
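# %% [markdown]
# For intuition, here is a rough, unoptimized sketch of what Integrated Gradients computes under the hood:
# interpolate between a baseline and the input, average the gradients of the target logit along that path,
# and scale by the difference between the input and the baseline. This is only a sketch; captum's
# implementation (used above) is the one to rely on. It assumes the classifier from Task 1.1 is stored in a
# variable called `model` (adjust the name if yours differs).
# %%
def integrated_gradients_sketch(classifier, x, y, baseline=None, n_steps=50):
    """Illustrative (slow) Integrated Gradients; use captum for real work."""
    if baseline is None:
        baseline = torch.zeros_like(x)  # the default "all black" baseline
    total_grads = torch.zeros_like(x)
    for alpha in torch.linspace(0.0, 1.0, n_steps).tolist():
        # A point along the straight path from the baseline to the input
        x_interp = (baseline + alpha * (x - baseline)).detach().requires_grad_(True)
        logits = classifier(x_interp)
        # Gradient of the target-class logit with respect to the interpolated input
        target_logit = logits.gather(1, y.unsqueeze(1)).sum()
        total_grads += torch.autograd.grad(target_logit, x_interp)[0]
    return (x - baseline) * total_grads / n_steps
# Example (uncomment if your classifier is called `model`):
# manual_attributions = integrated_gradients_sketch(model, x, y)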
-# %% editable=true slideshow={"slide_type": ""} tags=[]
+# %% tags=[]
attributions = (
attributions.cpu().numpy()
) # Move the attributions from the GPU to the CPU, and turn them into numpy arrays for further processing
-# %% [markdown] editable=true slideshow={"slide_type": ""} tags=[]
+# %% [markdown] tags=[]
# Here is an example of an image and its corresponding attribution.
-# %% editable=true slideshow={"slide_type": ""} tags=[]
+# %% tags=[]
from captum.attr import visualization as viz
import numpy as np
@@ -168,7 +171,7 @@ def visualize_attribution(attribution, original_image):
)
-# %% editable=true slideshow={"slide_type": ""} tags=[]
+# %% tags=[]
for attr, im in zip(attributions, x.cpu().numpy()):
visualize_attribution(attr, im)
@@ -235,7 +238,7 @@ def visualize_color_attribution(attribution, original_image):
# Hint: `torch.rand_like`
#
-# %% editable=true slideshow={"slide_type": ""} tags=[]
+# %% tags=["task"]
# Baseline
random_baselines = ... # TODO Change
# Generate the attributions
@@ -245,7 +248,7 @@ def visualize_color_attribution(attribution, original_image):
for attr, im in zip(attributions_random.cpu().numpy(), x.cpu().numpy()):
visualize_attribution(attr, im)
-# %% editable=true slideshow={"slide_type": ""} tags=["solution"]
+# %% tags=["solution"]
#########################
# Solution for task 2.3 #
#########################
@@ -260,13 +263,13 @@ def visualize_color_attribution(attribution, original_image):
for attr, im in zip(attributions_random.cpu().numpy(), x.cpu().numpy()):
visualize_color_attribution(attr, im)
-# %% [markdown] editable=true slideshow={"slide_type": ""} tags=[]
+# %% [markdown] tags=[]
#
# Task 2.4: Use a blurred image as a baseline
#
# Hint: `torchvision.transforms.functional` has a useful function for this ;)
#
-# %% editable=true slideshow={"slide_type": ""} tags=[]
+# %% tags=["task"]
# TODO Import required function
# Baseline
@@ -278,7 +281,7 @@ def visualize_color_attribution(attribution, original_image):
for attr, im in zip(attributions_blurred.cpu().numpy(), x.cpu().numpy()):
visualize_color_attribution(attr, im)
-# %% editable=true slideshow={"slide_type": ""} tags=["solution"]
+# %% tags=["solution"]
#########################
# Solution for task 2.4 #
#########################
@@ -295,12 +298,13 @@ def visualize_color_attribution(attribution, original_image):
for attr, im in zip(attributions_blurred.cpu().numpy(), x.cpu().numpy()):
visualize_color_attribution(attr, im)
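# %% [markdown]
# For reference, `torchvision.transforms.functional.gaussian_blur` is one function that fits the hint above.
# A minimal sketch of building a blurred baseline with it (the kernel size and sigma below are arbitrary
# example values, not necessarily what the solution uses):
# %%
# from torchvision.transforms.functional import gaussian_blur
# blurred_baselines = gaussian_blur(x, kernel_size=11, sigma=5.0)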
-# %% [markdown] editable=true slideshow={"slide_type": ""} tags=[]
+# %% [markdown] tags=[]
#
# Questions
-# TODO change these questions now!!
-# - Are any of the features consistent across baselines? Why do you think that is?
-# - What baseline do you like best so far? Why?
-# - If you were to design an ideal baseline, what would you choose?
+#
+# - What baseline do you like best so far? Why?
+# - Why do you think some baselines work better than others?
+# - If you were to design an ideal baseline, what would you choose?
+#
#
# %% [markdown]
@@ -327,7 +331,6 @@ def visualize_color_attribution(attribution, original_image):
# We'll see that using counterfactuals we will be able to disambiguate whether color or number is the important feature.
#
-
# %% [markdown]
# # Part 3: Train a GAN to Translate Images
#
@@ -348,7 +351,7 @@ def visualize_color_attribution(attribution, original_image):
# **Counterfactual images**
#
# In this example, we will train a StarGAN network that is able to take any of our special MNIST images and change its class.
-# %% [markdown] editable=true slideshow={"slide_type": ""} tags=[]
+# %% [markdown] tags=[]
# ### The model
# ![cycle.png](assets/cyclegan.png)
#
@@ -399,13 +402,13 @@ def forward(self, x, y):
unet = UNet(depth=..., in_channels=..., out_channels=..., final_activation=nn.Sigmoid())
generator = Generator(unet, style_mapping=style_mapping)
-# %% tags = ["solution"]
+# %% tags=["solution"]
# Here is an example of a working solution
style_mapping = DenseModel(input_shape=(3, 28, 28), num_classes=3)
unet = UNet(depth=2, in_channels=6, out_channels=3, final_activation=nn.Sigmoid())
generator = Generator(unet, style_mapping=style_mapping)
-# %% [markdown] editable=true slideshow={"slide_type": ""} tags=[]
+# %% [markdown] tags=[]
# Task 3.2: Create the discriminator
#
# We want the discriminator to be like a classifier, so it is able to look at an image and tell not only whether it is real, but also which class it came from.
@@ -422,7 +425,7 @@ def forward(self, x, y):
generator = generator.to(device)
discriminator = discriminator.to(device)
-# %% [markdown] editable=true slideshow={"slide_type": ""} tags=[]
+# %% [markdown] tags=[]
# ## Training a GAN
#
# Yes, really!
@@ -432,7 +435,7 @@ def forward(self, x, y):
# - A cycle loss
# TODO add exercise!
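# %% [markdown]
# In equations, matching the training loop below: for an input $x$ with label $y$, a randomly drawn style
# image $x_s$ whose label $y_t$ becomes the target class, a generator $G$, and a discriminator $D$, the
# generator is trained with
#
# $$\mathcal{L}_{\mathrm{cycle}} = \lVert G(G(x, x_s), x) - x \rVert_1 \qquad \text{and} \qquad
# \mathcal{L}_{\mathrm{adv}} = \mathrm{CE}\big(D(G(x, x_s)),\, y_t\big),$$
#
# while the discriminator is trained to classify real images correctly, $\mathrm{CE}(D(x), y)$, and to
# avoid assigning fakes to their target class.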
-# %% [markdown] editable=true slideshow={"slide_type": ""} tags=[]
+# %% [markdown] tags=[]
#
# Task 3.3: Training!
# Let's train the CycleGAN one batch at a time, plotting the output every so often to see how it is getting on.
#
@@ -440,83 +443,83 @@ def forward(self, x, y):
#
-# %% [markdown] editable=true slideshow={"slide_type": ""} tags=[]
+# %% [markdown] tags=[]
# ...this time again.
#
#
-
+#
# TODO also turn this into a standalone script for use during the project phase
-from torch.utils.data import DataLoader
-from tqdm import tqdm
-
-
-def set_requires_grad(module, value=True):
- """Sets `requires_grad` on a `module`'s parameters to `value`"""
- for param in module.parameters():
- param.requires_grad = value
-
-
-cycle_loss_fn = nn.L1Loss()
-class_loss_fn = nn.CrossEntropyLoss()
-
-optimizer_d = torch.optim.Adam(discriminator.parameters(), lr=1e-6)
-optimizer_g = torch.optim.Adam(generator.parameters(), lr=1e-4)
-
-dataloader = DataLoader(
- mnist, batch_size=32, drop_last=True, shuffle=True
-) # We will use the same dataset as before
-
-losses = {"cycle": [], "adv": [], "disc": []}
-for epoch in range(50):
- for x, y in tqdm(dataloader, desc=f"Epoch {epoch}"):
- x = x.to(device)
- y = y.to(device)
- # get the target y by shuffling the classes
- # get the style sources by random sampling
- random_index = torch.randperm(len(y))
- x_style = x[random_index].clone()
- y_target = y[random_index].clone()
-
- set_requires_grad(generator, True)
- set_requires_grad(discriminator, False)
- optimizer_g.zero_grad()
- # Get the fake image
- x_fake = generator(x, x_style)
- # Try to cycle back
- x_cycled = generator(x_fake, x)
- # Discriminate
- discriminator_x_fake = discriminator(x_fake)
- # Losses to train the generator
-
- # 1. make sure the image can be reconstructed
- cycle_loss = cycle_loss_fn(x, x_cycled)
- # 2. make sure the discriminator is fooled
- adv_loss = class_loss_fn(discriminator_x_fake, y_target)
-
- # Optimize the generator
- (cycle_loss + adv_loss).backward()
- optimizer_g.step()
-
- set_requires_grad(generator, False)
- set_requires_grad(discriminator, True)
- optimizer_d.zero_grad()
- # TODO Do I need to re-do the forward pass?
- discriminator_x = discriminator(x)
- discriminator_x_fake = discriminator(x_fake.detach())
- # Losses to train the discriminator
- # 1. make sure the discriminator can tell real is real
- real_loss = class_loss_fn(discriminator_x, y)
- # 2. make sure the discriminator can't tell fake is fake
- fake_loss = -class_loss_fn(discriminator_x_fake, y_target)
- #
- disc_loss = (real_loss + fake_loss) * 0.5
- disc_loss.backward()
- # Optimize the discriminator
- optimizer_d.step()
-
- losses["cycle"].append(cycle_loss.item())
- losses["adv"].append(adv_loss.item())
- losses["disc"].append(disc_loss.item())
+# from torch.utils.data import DataLoader
+# from tqdm import tqdm
+#
+#
+# def set_requires_grad(module, value=True):
+# """Sets `requires_grad` on a `module`'s parameters to `value`"""
+# for param in module.parameters():
+# param.requires_grad = value
+#
+#
+# cycle_loss_fn = nn.L1Loss()
+# class_loss_fn = nn.CrossEntropyLoss()
+#
+# optimizer_d = torch.optim.Adam(discriminator.parameters(), lr=1e-6)
+# optimizer_g = torch.optim.Adam(generator.parameters(), lr=1e-4)
+#
+# dataloader = DataLoader(
+# mnist, batch_size=32, drop_last=True, shuffle=True
+# ) # We will use the same dataset as before
+#
+# losses = {"cycle": [], "adv": [], "disc": []}
+# for epoch in range(50):
+# for x, y in tqdm(dataloader, desc=f"Epoch {epoch}"):
+# x = x.to(device)
+# y = y.to(device)
+# # get the target y by shuffling the classes
+# # get the style sources by random sampling
+# random_index = torch.randperm(len(y))
+# x_style = x[random_index].clone()
+# y_target = y[random_index].clone()
+#
+# set_requires_grad(generator, True)
+# set_requires_grad(discriminator, False)
+# optimizer_g.zero_grad()
+# # Get the fake image
+# x_fake = generator(x, x_style)
+# # Try to cycle back
+# x_cycled = generator(x_fake, x)
+# # Discriminate
+# discriminator_x_fake = discriminator(x_fake)
+# # Losses to train the generator
+#
+# # 1. make sure the image can be reconstructed
+# cycle_loss = cycle_loss_fn(x, x_cycled)
+# # 2. make sure the discriminator is fooled
+# adv_loss = class_loss_fn(discriminator_x_fake, y_target)
+#
+# # Optimize the generator
+# (cycle_loss + adv_loss).backward()
+# optimizer_g.step()
+#
+# set_requires_grad(generator, False)
+# set_requires_grad(discriminator, True)
+# optimizer_d.zero_grad()
+# # TODO Do I need to re-do the forward pass?
+# discriminator_x = discriminator(x)
+# discriminator_x_fake = discriminator(x_fake.detach())
+# # Losses to train the discriminator
+# # 1. make sure the discriminator can tell real is real
+# real_loss = class_loss_fn(discriminator_x, y)
+# # 2. make sure the discriminator can't tell fake is fake
+# fake_loss = -class_loss_fn(discriminator_x_fake, y_target)
+# #
+# disc_loss = (real_loss + fake_loss) * 0.5
+# disc_loss.backward()
+# # Optimize the discriminator
+# optimizer_d.step()
+#
+# losses["cycle"].append(cycle_loss.item())
+# losses["adv"].append(adv_loss.item())
+# losses["disc"].append(disc_loss.item())
# %%
plt.plot(losses["cycle"], label="Cycle loss")
@@ -524,7 +527,7 @@ def set_requires_grad(module, value=True):
plt.plot(losses["disc"], label="Discriminator loss")
plt.legend()
plt.show()
-# %% [markdown] editable=true slideshow={"slide_type": ""} tags=[]
+# %% [markdown] tags=[]
# Let's add a quick plotting function before we begin training...
# %%
@@ -541,7 +544,7 @@ def set_requires_grad(module, value=True):
# TODO WIP here
-# %% [markdown] editable=true slideshow={"slide_type": ""} tags=[]
+# %% [markdown] tags=[]
#
# Checkpoint 3
# You've now learned the basics of what makes up a CycleGAN, and details on how to perform adversarial training.
# The same method can be used to create a CycleGAN with different basic elements.
@@ -550,10 +553,10 @@ def set_requires_grad(module, value=True):
# You know the drill... let us know on the exercise chat!
#
-# %% [markdown] editable=true slideshow={"slide_type": ""} tags=[]
+# %% [markdown] tags=[]
# # Part 4: Evaluating the GAN
-# %% [markdown] editable=true slideshow={"slide_type": ""} tags=[]
+# %% [markdown] tags=[]
#
# ## That was fun! Now let's load a pre-trained model
#
@@ -561,32 +564,32 @@ def set_requires_grad(module, value=True):
#
# To continue, interrupt the kernel and move on to the next section, which will simply load one of the pre-trained CycleGAN models for this dataset.
-# %% editable=true slideshow={"slide_type": ""} tags=[]
+# %% tags=[]
from pathlib import Path
import torch
# TODO load the pre-trained model
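# %% [markdown]
# The general loading pattern looks something like the sketch below; the checkpoint path is hypothetical
# and will be filled in for the TODO above, and the generator from Part 3 needs to be re-created first.
# %%
# checkpoint_path = Path("checkpoints/generator.pth")  # hypothetical path, see TODO above
# generator.load_state_dict(torch.load(checkpoint_path, map_location=device))
# generator.eval()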
-# %% [markdown] editable=true slideshow={"slide_type": ""} tags=[]
+# %% [markdown] tags=[]
# Let's look at some examples. Can you pick up on the differences between the original, the counterfactual, and the reconstruction?
-# %% editable=true slideshow={"slide_type": ""} tags=[]
+# %% tags=[]
# TODO show some examples
-# %% [markdown] editable=true slideshow={"slide_type": ""} tags=[]
+# %% [markdown] tags=[]
# We're going to apply the GAN to our test dataset.
-# %% editable=true slideshow={"slide_type": ""} tags=[]
+# %% tags=[]
# TODO load the test dataset
-# %% [markdown] editable=true slideshow={"slide_type": ""} tags=[]
+# %% [markdown] tags=[]
# ## Evaluating the GAN
#
# The first thing to find out is whether the CycleGAN is successfully converting the images from one class to another.
# We will do this by running the classifier that we trained earlier on generated data.
#
-# %% [markdown] editable=true slideshow={"slide_type": ""} tags=[]
+# %% [markdown] tags=[]
#
# Task 4.1: Get the classifier accuracy on CycleGAN outputs
#
# Using the saved images, we're going to figure out how good our CycleGAN is at generating images of a new class!
@@ -600,7 +603,7 @@ def set_requires_grad(module, value=True):
# - counterfactual
#
-# %% [markdown] editable=true slideshow={"slide_type": ""} tags=[]
+# %% [markdown] tags=[]
#
# We get the following accuracies:
#
@@ -630,7 +633,7 @@ def predict():
pass
-# %% [markdown] editable=true slideshow={"slide_type": ""} tags=[]
+# %% [markdown] tags=[]
# We're going to look at the confusion matrices for the counterfactuals, and compare them to those of the real images.
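# %% [markdown]
# A sketch of one way to do this (the counterfactual dataset name below is a placeholder; `predict` is
# assumed to take a dataset and a display name and return `(true_labels, predictions)`, as it is used
# further below):
# %%
# from sklearn.metrics import ConfusionMatrixDisplay
# real_true, real_pred = predict(ds_real, "Real")
# cf_true, cf_pred = predict(ds_counterfactual, "Counterfactual")  # hypothetical dataset name
# ConfusionMatrixDisplay.from_predictions(real_true, real_pred)
# ConfusionMatrixDisplay.from_predictions(cf_true, cf_pred)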
# %%
@@ -690,7 +693,7 @@ def predict():
# - Get a boolean array indicating which `cf` samples are classified as the target class
#
-# %% editable=true slideshow={"slide_type": ""} tags=[]
+# %% tags=[]
####### Task 5.1 TODO #######
# Get the samples where the real is correct
@@ -710,7 +713,7 @@ def predict():
real_success_ds = Subset(ds_real, success)
-# %% editable=true slideshow={"slide_type": ""} tags=["solution"]
+# %% tags=["solution"]
########################
# Solution to Task 5.1 #
########################
@@ -732,13 +735,13 @@ def predict():
real_success_ds = Subset(ds_real, success)
-# %% [markdown] editable=true slideshow={"slide_type": ""} tags=[]
+# %% [markdown] tags=[]
# To check that we have got it right, let us get the accuracy on the best 100 vs the worst 100 samples:
-# %% editable=true slideshow={"slide_type": ""} tags=[]
+# %% tags=[]
model = model.to(device)
-# %% editable=true slideshow={"slide_type": ""} tags=[]
+# %% tags=[]
real_true, real_pred = predict(real_success_ds, "Real")
cf_true, cf_pred = predict(cf_success_ds, "Counterfactuals")
@@ -751,7 +754,7 @@ def predict():
accuracy_score(cf_true, cf_pred),
)
-# %% [markdown] editable=true slideshow={"slide_type": ""} tags=[]
+# %% [markdown] tags=[]
# ### Creating hybrids from attributions
#
# Now that we have a set of successfully translated counterfactuals, we can use them as a baseline for our attribution.
@@ -759,11 +762,11 @@ def predict():
#
# To do this, we will take the sample image and mask out the pixels highlighted by the attribution. We will then replace these masked-out pixels with the corresponding values from the counterfactual. The result is a hybrid image that is like the original everywhere except in the areas that matter for classification.
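# %% [markdown]
# As a sketch of the idea (a hypothetical helper for intuition; the implementation below may differ in
# details such as how the attribution is normalized or thresholded):
# %%
def make_hybrid_sketch(real, counterfactual, attribution, threshold=0.1):
    """Hypothetical helper: copy counterfactual pixels wherever the attribution exceeds a threshold."""
    mask = (attribution.abs() > threshold).float()  # 1 where the classifier cares, 0 elsewhere
    return real * (1 - mask) + counterfactual * mask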
-# %% editable=true slideshow={"slide_type": ""} tags=[]
+# %% tags=[]
dataloader_real = DataLoader(real_success_ds, batch_size=10)
dataloader_counter = DataLoader(cf_success_ds, batch_size=10)
-# %% editable=true slideshow={"slide_type": ""} tags=[]
+# %% tags=[]
# %%time
with torch.no_grad():
model.to(device)
@@ -787,7 +790,7 @@ def predict():
# %%
-# %% editable=true slideshow={"slide_type": ""} tags=[]
+# %% tags=[]
# Functions for creating an interactive visualization of our attributions
model.cpu()
@@ -861,7 +864,7 @@ def visualize_counterfactuals(idx, threshold=0.1):
axes[ix].set_xlim(0, 1)
-# %% [markdown] editable=true slideshow={"slide_type": ""} tags=[]
+# %% [markdown] tags=[]
#
# Task 5.2: Observing the effect of the changes on the classifier
# Below is a small widget to interact with the above analysis. As you change the `threshold`, see how the prediction of the hybrid changes.
# At what point does it swap over?
@@ -869,13 +872,13 @@ def visualize_counterfactuals(idx, threshold=0.1):
# If you want to see different samples, slide through the `idx`.
#
-# %% editable=true slideshow={"slide_type": ""} tags=[]
+# %% tags=[]
interact(visualize_counterfactuals, idx=(0, 99), threshold=(0.0, 1.0, 0.05))
# %% [markdown]
# HELP!!! Interactive (still!) doesn't work. No worries... uncomment the following cell and choose your index and threshold by typing them out.
-# %% editable=true slideshow={"slide_type": ""} tags=[]
+# %% tags=[]
# Choose your own adventure
# idx = 0
# threshold = 0.1
@@ -883,7 +886,7 @@ def visualize_counterfactuals(idx, threshold=0.1):
# # Plotting :)
# visualize_counterfactuals(idx, threshold)
-# %% [markdown] editable=true slideshow={"slide_type": ""} tags=[]
+# %% [markdown] tags=[]
#
#
# Questions
#
@@ -894,13 +897,13 @@ def visualize_counterfactuals(idx, threshold=0.1):
# Feel free to discuss your answers on the exercise chat!
#
-# %% [markdown] editable=true slideshow={"slide_type": ""} tags=[]
+# %% [markdown] tags=[]
#
#
# The End.
# Go forth and train some GANs!
#
-# %% [markdown] editable=true slideshow={"slide_type": ""} tags=[]
+# %% [markdown] tags=[]
# ## Going Further
#
# Here are some ideas for how to continue with this notebook: