diff --git a/solution.py b/solution.py
index dab7902..542ddc8 100644
--- a/solution.py
+++ b/solution.py
@@ -90,6 +90,27 @@
model.load_state_dict(checkpoint)
model = model.to(device)
+# %% [markdown]
+# Don't take my word for it! Let's see how well the classifier does on the test set.
+# %%
+from torch.utils.data import DataLoader
+from sklearn.metrics import confusion_matrix
+import seaborn as sns
+
+test_mnist = ColoredMNIST("data", download=True, train=False)
+dataloader = DataLoader(test_mnist, batch_size=32, shuffle=False)
+
+labels = []
+predictions = []
+model.eval()
+with torch.no_grad():
+    for x, y in dataloader:
+        pred = model(x.to(device))
+        labels.extend(y.cpu().numpy())
+        predictions.extend(pred.argmax(dim=1).cpu().numpy())
+
+cm = confusion_matrix(labels, predictions, normalize="true")
+sns.heatmap(cm, annot=True, fmt=".2f")
+
+
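+# %% [markdown]
+# The confusion matrix gives the per-class picture; as a one-number summary we can also compute the overall accuracy (a minimal sketch, assuming `numpy` is already imported as `np` earlier in the notebook, as the later cells do):
+# %%
+# Overall accuracy: fraction of test images whose prediction matches the label.
+accuracy = np.mean(np.array(labels) == np.array(predictions))
+print(f"Overall test accuracy: {accuracy:.3f}")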
# %% [markdown]
# # Part 2: Using Integrated Gradients to find what the classifier knows
#
@@ -675,128 +696,130 @@ def set_requires_grad(module, value=True):
plt.show()
# %%
-# TODO wip here
# %% [markdown] tags=[]
#
# ### Checkpoint 3
-# You've now learned the basics of what makes up a CycleGAN, and details on how to perform adversarial training.
-# The same method can be used to create a CycleGAN with different basic elements.
+# You've now learned the basics of what makes up a StarGAN, and details on how to perform adversarial training.
+# The same method can be used to create a StarGAN with different basic elements.
+# For example, you can change the architecture of the generators, or of the discriminator, to better fit your data in the future.
#
-# You know the drill... let us know on the exercise chat!
+# You know the drill... let us know on the exercise chat when you have arrived here!
#
# %% [markdown] tags=[]
# # Part 4: Evaluating the GAN
# %% [markdown] tags=[]
+# ## Creating counterfactuals
#
-# ## That was fun!... let's load a pre-trained model
+# The first thing that we want to do is make sure that our GAN is able to create counterfactual images.
+# To do this, we have to create them, and then pass them through the classifier to see whether they are assigned the target class.
#
-# Training the CycleGAN takes a lot longer than the few iterations that we did above. Since we don't have that kind of time, we are going to load a pre-trained model (for reference, this pre-trained model was trained for 7 days...).
-#
-# To continue, interrupt the kernel and continue with the next one, which will just use one of the pretrained CycleGAN models for the synapse dataset.
+# First, let's get the test dataset, so we can evaluate the GAN on unseen data.
+# Then, let's pick four prototypical images from the dataset, one per class, to use as style sources.
-# %% tags=[]
-from pathlib import Path
-import torch
-
-# TODO load the pre-trained model
+# %% Loading the test dataset
+test_mnist = ColoredMNIST("data", download=True, train=False)
+prototypes = {}
-# %% [markdown] tags=[]
-# Let's look at some examples. Can you pick up on the differences between original, the counter-factual, and the reconstruction?
-# %% tags=[]
-# TODO show some examples
+for i in range(4):
+ options = np.where(test_mnist.targets == i)[0]
+ # Note that you can change the image index if you want to use a different prototype.
+ image_index = 0
+ x, y = test_mnist[options[image_index]]
+ prototypes[i] = x
# %% [markdown] tags=[]
-# We're going to apply the GAN to our test dataset.
+# Let's have a look at the prototypes.
+# %%
+fig, axs = plt.subplots(1, 4, figsize=(12, 4))
+for i, ax in enumerate(axs):
+ ax.imshow(prototypes[i].permute(1, 2, 0))
+ ax.axis("off")
+ ax.set_title(f"Prototype {i}")
-# %% tags=[]
-# TODO load the test dataset
+# %% [markdown]
+# Now we need to use these prototypes to create counterfactual images!
+#
+# ### Task 4.1: Create counterfactuals
+# In the loop below, fill in the blanks (`...`): store the original image when the target class `i` matches the source label `y`, pass the matching style prototype to the generator, and run the classifier on the generated counterfactual.
+# %%
+num_images = len(test_mnist)
+counterfactuals = np.zeros((4, num_images, 3, 28, 28))
+
+predictions = []
+source_labels = []
+target_labels = []
+
+for k, (x, y) in enumerate(test_mnist):
+    for i in range(4):
+        if i == y:
+            # Store the original image as is.
+            counterfactuals[i][k] = ...
+            continue
+        # Create the counterfactual from the image and prototype
+        x_fake = generator(x.unsqueeze(0).to(device), ...)
+        counterfactuals[i][k] = x_fake.cpu().detach().numpy()
+        pred = model(...)
+
+        source_labels.append(y)
+        target_labels.append(i)
+        predictions.append(pred.argmax().item())
-# %% [markdown] tags=[]
-# ## Evaluating the GAN
-#
-# The first thing to find out is whether the CycleGAN is successfully converting the images from one neurotransmitter to another.
-# We will do this by running the classifier that we trained earlier on generated data.
-#
+# %% tags=["solution"]
+num_images = len(test_mnist)
+counterfactuals = np.zeros((4, num_images, 3, 28, 28))
+
+predictions = []
+source_labels = []
+target_labels = []
+
+for k, (x, y) in enumerate(test_mnist):
+    for i in range(4):
+        if i == y:
+            # Store the original image as is.
+            counterfactuals[i][k] = x
+            continue
+        # Create the counterfactual from the image and prototype
+        x_fake = generator(
+            x.unsqueeze(0).to(device), prototypes[i].unsqueeze(0).to(device)
+        )
+        counterfactuals[i][k] = x_fake.cpu().detach().numpy()
+        pred = model(x_fake)
+
+        source_labels.append(y)
+        target_labels.append(i)
+        predictions.append(pred.argmax().item())
# %% [markdown] tags=[]
-# Task 4.1 Get the classifier accuracy on CycleGAN outputs
-#
-# Using the saved images, we're going to figure out how good our CycleGAN is at generating images of a new class!
-#
-# The images (`real`, `reconstructed`, and `counterfactual`) are saved in the `test_images/` directory. Before you start the exercise, have a look at how this directory is organized.
-#
-# TODO
-# - Use the `make_dataset` function to create a dataset for the three different image types that we saved above
-# - real
-# - reconstructed
-# - counterfactual
-#
+# Let's plot the confusion matrix for the counterfactual images.
+# %%
+cf_cm = confusion_matrix(target_labels, predictions, normalize="true")
+sns.heatmap(cf_cm, annot=True, fmt=".2f")
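+# %% [markdown]
+# To read the matrix at a glance (a small sketch re-using `cf_cm` from the cell above): since the rows are normalized over the target labels, the diagonal entries are the fraction of counterfactuals that the classifier assigns to their intended target class.
+# %%
+# Per-target-class success rate: diagonal of the normalized confusion matrix.
+for target, rate in enumerate(np.diag(cf_cm)):
+    print(f"Target class {target}: {rate:.2%} of counterfactuals classified as intended")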
# %% [markdown] tags=[]
-#
-# We get the following accuracies:
-#
-# 1. `accuracy_real`: Accuracy of the classifier on the real images, just for the two classes used in the GAN
-# 2. `accuracy_recon`: Accuracy of the classifier on the reconstruction.
-# 3. `accuracy_counter`: Accuracy of the classifier on the counterfactual images.
-#
-#
-# ### Questions
-#
-# - In a perfect world, what value would we expect for `accuracy_recon`? What do we compare it to and why is it higher/lower?
-# - How well is it translating from one class to another? Do we expect `accuracy_counter` to be large or small? Do we want it to be large or small? Why?
-#
-# Let us know your insights on the exercise chat.
+#
+# ### Questions
+#
+# - How well is our GAN doing at creating counterfactual images?
+# - Do you think that the prototypes used matter? Why or why not?
+#
#
-# %%
-# TODO make a loop on the data that creates the counterfactual images, given a set of options as input
-counterfactuals, reconstructions, targets, labels = ...
-
-
-# %% [markwodn]
-# Evaluate the images
-# %%
-# TODO use the loaded classifier to evaluate the images
-# Get the accuracies
-def predict():
- # TODO return predictions, labels
- pass
-
# %% [markdown] tags=[]
-# We're going to look at the confusion matrices for the counterfactuals, and compare it to that of the real images.
+# Let's also plot some examples of the counterfactual images.
-# %%
-print("The confusion matrix on the real images... for comparison")
-# TODO Confusion matrix on the counterfactual images
-confusion_matrix = ...
-# TODO plot
-# %%
-print("The confusion matrix on the real images... for comparison")
-# TODO Confusion matrix on the real images, for comparison
-confusion_matrix = ...
-# TODO plot
-
-# %% [markdown]
-#
-#
-# ### Questions
-#
-# - What would you expect the confusion matrix for the counterfactuals to look like? Why?
-# - Do the two directions of the CycleGAN work equally as well?
-# - Can you think of anything that might have made it more difficult, or easier, to translate in a one direction vs the other?
-#
-#
+for i in np.random.choice(num_images, 4, replace=False):
+ fig, axs = plt.subplots(1, 4, figsize=(20, 4))
+ for j, ax in enumerate(axs):
+ ax.imshow(counterfactuals[j][i].transpose(1, 2, 0))
+ ax.axis("off")
+ ax.set_title(f"Class {j}")
-# %% [markdown]
-#
-# ### Checkpoint 4
-# We have seen that our CycleGAN network has successfully translated some of the synapses from one class to the other, but there are clearly some things to look out for!
-# Take the time to think about the questions above before moving on...
-#
-# This is the end of Section 4. Let us know on the exercise chat if you have reached this point!
+# %% [markdown] tags=[]
+#
+# ### Questions
+#
+# - Can you easily tell which of these images is the original, and which ones are the counterfactuals?
+# - What is your hypothesis for the features that define each class?
+#
#
+# TODO wip here
# %% [markdown]
# # Part 5: Highlighting Class-Relevant Differences