diff --git a/solution.py b/solution.py index dab7902..542ddc8 100644 --- a/solution.py +++ b/solution.py @@ -90,6 +90,27 @@ model.load_state_dict(checkpoint) model = model.to(device) +# %% [markdown] +# Don't take my word for it! Let's see how well the classifier does on the test set. +# %% +from torch.utils.data import DataLoader +from sklearn.metrics import confusion_matrix +import seaborn as sns + +test_mnist = ColoredMNIST("data", download=True, train=False) +dataloader = DataLoader(test_mnist, batch_size=32, shuffle=False) + +labels = [] +predictions = [] +for x, y in dataloader: + pred = model(x.to(device)) + labels.extend(y.cpu().numpy()) + predictions.extend(pred.argmax(dim=1).cpu().numpy()) + +cm = confusion_matrix(labels, predictions, normalize="true") +sns.heatmap(cm, annot=True, fmt=".2f") + + # %% [markdown] # # Part 2: Using Integrated Gradients to find what the classifier knows # @@ -675,128 +696,130 @@ def set_requires_grad(module, value=True): plt.show() # %% -# TODO wip here # %% [markdown] tags=[] #

Checkpoint 3

-# You've now learned the basics of what makes up a CycleGAN, and details on how to perform adversarial training.
-# The same method can be used to create a CycleGAN with different basic elements.
+# You've now learned the basics of what makes up a StarGAN, and details on how to perform adversarial training.
+# The same method can be used to create a StarGAN with different basic elements.
# For example, you can change the architecture of the generators, or of the discriminator, to better fit your data in the future.
#
-# You know the drill... let us know on the exercise chat!
+# You know the drill... let us know on the exercise chat when you arrive here!
#
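+# %% [markdown]
+# For instance, any discriminator that maps an image to one logit per class can stand in for the one used here.
+# The sketch below is purely illustrative: it assumes 3x28x28 inputs and four classes, as in our ColoredMNIST setup, and is not the discriminator from this exercise.
+# %%
+import torch.nn as nn
+
+alt_discriminator = nn.Sequential(
+    nn.Conv2d(3, 32, 3, padding=1), nn.LeakyReLU(0.2),
+    nn.Conv2d(32, 64, 3, stride=2, padding=1), nn.LeakyReLU(0.2),  # 28 -> 14
+    nn.Conv2d(64, 128, 3, stride=2, padding=1), nn.LeakyReLU(0.2),  # 14 -> 7
+    nn.AdaptiveAvgPool2d(1), nn.Flatten(),
+    nn.Linear(128, 4),  # one logit per class, StarGAN-style
+)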
# %% [markdown] tags=[]
# # Part 4: Evaluating the GAN

# %% [markdown] tags=[]
+# ## Creating counterfactuals
#
-# ## That was fun!... let's load a pre-trained model
+# The first thing that we want to do is make sure that our GAN is able to create counterfactual images.
+# To do this, we have to create them, and then pass them through the classifier to see whether it assigns them to the target class.
#
-# Training the CycleGAN takes a lot longer than the few iterations that we did above. Since we don't have that kind of time, we are going to load a pre-trained model (for reference, this pre-trained model was trained for 7 days...).
-#
-# To continue, interrupt the kernel and continue with the next one, which will just use one of the pretrained CycleGAN models for the synapse dataset.
+# First, let's get the test dataset, so we can evaluate the GAN on unseen data.
+# Then, let's get four prototypical images from the dataset to use as style sources, one per class.

-# %% tags=[]
-from pathlib import Path
-import torch
-
-# TODO load the pre-trained model
+# %% Loading the test dataset
+test_mnist = ColoredMNIST("data", download=True, train=False)
+prototypes = {}

-# %% [markdown] tags=[]
-# Let's look at some examples. Can you pick up on the differences between original, the counter-factual, and the reconstruction?
-# %% tags=[]
-# TODO show some examples
+for i in range(4):
+    options = np.where(test_mnist.targets == i)[0]
+    # Note that you can change the image index if you want to use a different prototype.
+    image_index = 0
+    x, y = test_mnist[options[image_index]]
+    prototypes[i] = x

# %% [markdown] tags=[]
-# We're going to apply the GAN to our test dataset.
+# Let's have a look at the prototypes.
+# %%
+fig, axs = plt.subplots(1, 4, figsize=(12, 4))
+for i, ax in enumerate(axs):
+    ax.imshow(prototypes[i].permute(1, 2, 0))
+    ax.axis("off")
+    ax.set_title(f"Prototype {i}")

-# %% tags=[]
-# TODO load the test dataset
+# %% [markdown]
+# Now we need to use these prototypes to create counterfactual images!
+#
+# Task 4.1: Create counterfactuals
+#
+# In the cell below, fill in the `...` so that, for every test image and every target class:
+# - if the target class is the same as the source class, the original image is stored as-is;
+# - otherwise, the generator is called with the image and the prototype of the target class;
+# - the classifier is run on the resulting counterfactual to get a prediction.
+#
+# Note that this loops over the full test set, so it can take a few minutes to run.
+# %%
+num_images = len(test_mnist)
+counterfactuals = np.zeros((4, num_images, 3, 28, 28))
+
+predictions = []
+source_labels = []
+target_labels = []
+
+for k, (x, y) in enumerate(test_mnist):
+    for i in range(4):
+        if i == y:
+            # Store the image as is; the "counterfactual" for the source class is the original.
+            counterfactuals[i][k] = ...
+            continue
+        # Create the counterfactual from the image and prototype
+        x_fake = generator(x.unsqueeze(0).to(device), ...)
+        counterfactuals[i][k] = x_fake.cpu().detach().numpy().squeeze()
+        pred = model(...)
+
+        source_labels.append(y)
+        target_labels.append(i)
+        predictions.append(pred.argmax().item())

-# %% [markdown] tags=[]
-# ## Evaluating the GAN
-#
-# The first thing to find out is whether the CycleGAN is successfully converting the images from one neurotransmitter to another.
-# We will do this by running the classifier that we trained earlier on generated data.
-#
+# %% tags=["solution"]
+num_images = len(test_mnist)
+counterfactuals = np.zeros((4, num_images, 3, 28, 28))
+
+predictions = []
+source_labels = []
+target_labels = []
+
+for k, (x, y) in enumerate(test_mnist):
+    for i in range(4):
+        if i == y:
+            # Store the image as is; the "counterfactual" for the source class is the original.
+            counterfactuals[i][k] = x
+            continue
+        # Create the counterfactual from the image and prototype
+        x_fake = generator(
+            x.unsqueeze(0).to(device), prototypes[i].unsqueeze(0).to(device)
+        )
+        counterfactuals[i][k] = x_fake.cpu().detach().numpy().squeeze()
+        pred = model(x_fake)
+
+        source_labels.append(y)
+        target_labels.append(i)
+        predictions.append(pred.argmax().item())

# %% [markdown] tags=[]
-#

Task 4.1 Get the classifier accuracy on CycleGAN outputs

-# -# Using the saved images, we're going to figure out how good our CycleGAN is at generating images of a new class! -# -# The images (`real`, `reconstructed`, and `counterfactual`) are saved in the `test_images/` directory. Before you start the exercise, have a look at how this directory is organized. -# -# TODO -# - Use the `make_dataset` function to create a dataset for the three different image types that we saved above -# - real -# - reconstructed -# - counterfactual -#
+# Let's plot the confusion matrix for the counterfactual images.
+# Each row is a target class and each column is the class that the classifier predicted,
+# so a strong diagonal means that the counterfactuals are successfully fooling the classifier.
+# %%
+cf_cm = confusion_matrix(target_labels, predictions, normalize="true")
+sns.heatmap(cf_cm, annot=True, fmt=".2f")

# %% [markdown] tags=[]
-#
-# We get the following accuracies: -# -# 1. `accuracy_real`: Accuracy of the classifier on the real images, just for the two classes used in the GAN -# 2. `accuracy_recon`: Accuracy of the classifier on the reconstruction. -# 3. `accuracy_counter`: Accuracy of the classifier on the counterfactual images. -# -#

Questions

-# -# - In a perfect world, what value would we expect for `accuracy_recon`? What do we compare it to and why is it higher/lower? -# - How well is it translating from one class to another? Do we expect `accuracy_counter` to be large or small? Do we want it to be large or small? Why? -# -# Let us know your insights on the exercise chat. +#

Questions

+#
+# - How well do the counterfactuals fool the classifier? Is the diagonal of this matrix as strong as it was for the real test images?
+# - Are some target classes harder to reach than others? What might make translation to one class easier or harder?
+#
+# Let us know your thoughts on the exercise chat!
+#
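+# %% [markdown]
+# We can also condense the matrix into a single summary number: the fraction of counterfactuals
+# that the classifier assigns to the intended target class. This is a small extra check that simply
+# reuses the `predictions` and `target_labels` lists computed above.
+# %%
+cf_accuracy = np.mean(np.array(predictions) == np.array(target_labels))
+print(f"Fraction of counterfactuals classified as their target: {cf_accuracy:.2%}")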
-# %% -# TODO make a loop on the data that creates the counterfactual images, given a set of options as input -counterfactuals, reconstructions, targets, labels = ... - - -# %% [markwodn] -# Evaluate the images -# %% -# TODO use the loaded classifier to evaluate the images -# Get the accuracies -def predict(): - # TODO return predictions, labels - pass - # %% [markdown] tags=[] -# We're going to look at the confusion matrices for the counterfactuals, and compare it to that of the real images. +# Let's also plot some examples of the counterfactual images. -# %% -print("The confusion matrix on the real images... for comparison") -# TODO Confusion matrix on the counterfactual images -confusion_matrix = ... -# TODO plot -# %% -print("The confusion matrix on the real images... for comparison") -# TODO Confusion matrix on the real images, for comparison -confusion_matrix = ... -# TODO plot - -# %% [markdown] -#
-#

Questions

-# -# - What would you expect the confusion matrix for the counterfactuals to look like? Why? -# - Do the two directions of the CycleGAN work equally as well? -# - Can you think of anything that might have made it more difficult, or easier, to translate in a one direction vs the other? -# -#
+for i in np.random.choice(range(num_images), 4): + fig, axs = plt.subplots(1, 4, figsize=(20, 4)) + for j, ax in enumerate(axs): + ax.imshow(counterfactuals[j][i].transpose(1, 2, 0)) + ax.axis("off") + ax.set_title(f"Class {j}") -# %% [markdown] -#

Checkpoint 4

-# We have seen that our CycleGAN network has successfully translated some of the synapses from one class to the other, but there are clearly some things to look out for! -# Take the time to think about the questions above before moving on... -# -# This is the end of Section 4. Let us know on the exercise chat if you have reached this point! +# %% [markdown] tags=[] +#

Questions

+#
+# - Can you easily tell which of these images is the original, and which ones are the counterfactuals?
+# - What is your hypothesis for the features that define each class?
+#
+# TODO wip here # %% [markdown] # # Part 5: Highlighting Class-Relevant Differences