wip: Begin evaluation of the counterfactuals using classifier
adjavon committed Aug 12, 2024
1 parent b3d267d commit 3327629
207 changes: 115 additions & 92 deletions solution.py
@@ -90,6 +90,27 @@
model.load_state_dict(checkpoint)
model = model.to(device)

# %% [markdown]
# Don't take my word for it! Let's see how well the classifier does on the test set.
# %%
from torch.utils.data import DataLoader
from sklearn.metrics import confusion_matrix
import seaborn as sns

test_mnist = ColoredMNIST("data", download=True, train=False)
dataloader = DataLoader(test_mnist, batch_size=32, shuffle=False)

# Run the classifier over the test set without tracking gradients.
model.eval()
labels = []
predictions = []
with torch.no_grad():
    for x, y in dataloader:
        pred = model(x.to(device))
        labels.extend(y.cpu().numpy())
        predictions.extend(pred.argmax(dim=1).cpu().numpy())

# Row-normalized confusion matrix: each row sums to 1 over the true class.
cm = confusion_matrix(labels, predictions, normalize="true")
sns.heatmap(cm, annot=True, fmt=".2f")
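
# %% [markdown]
# For a single summary number, we can also compute the overall accuracy from the same
# predictions; this is a small added sketch, using only the `labels` and `predictions`
# lists collected above.
# %%
import numpy as np

# Fraction of test images whose predicted class matches the true one.
accuracy = np.mean(np.array(labels) == np.array(predictions))
print(f"Overall test accuracy: {accuracy:.3f}")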


# %% [markdown]
# # Part 2: Using Integrated Gradients to find what the classifier knows
#
@@ -675,128 +696,130 @@ def set_requires_grad(module, value=True):
plt.show()

# %% [markdown] tags=[]
# <div class="alert alert-block alert-success"><h2>Checkpoint 3</h2>
# You've now learned the basics of what makes up a StarGAN, and details on how to perform adversarial training.
# The same method can be used to create a StarGAN with different basic elements.
# For example, you can change the architecture of the generators, or of the discriminator, to better fit your data in the future.
#
# You know the drill... let us know on the exercise chat when you have arrived here!
# </div>

# %% [markdown] tags=[]
# # Part 4: Evaluating the GAN

# %% [markdown] tags=[]
# ## Creating counterfactuals
#
# The first thing that we want to do is make sure that our GAN is able to create counterfactual images.
# To do this, we have to create them, and then pass them through the classifier to see whether it assigns them to the intended target class.
#
# First, let's get the test dataset, so we can evaluate the GAN on unseen data.
# Then, let's get four prototypical images from the dataset as style sources.

# %% Loading the test dataset
test_mnist = ColoredMNIST("data", download=True, train=False)
prototypes = {}

for i in range(4):
    options = np.where(test_mnist.targets == i)[0]
    # Note that you can change the image index if you want to use a different prototype.
    image_index = 0
    x, y = test_mnist[options[image_index]]
    prototypes[i] = x

# %% [markdown] tags=[]
# Let's have a look at the prototypes.
# %%
fig, axs = plt.subplots(1, 4, figsize=(12, 4))
for i, ax in enumerate(axs):
    ax.imshow(prototypes[i].permute(1, 2, 0))
    ax.axis("off")
    ax.set_title(f"Prototype {i}")

# %% [markdown]
# Now we need to use these prototypes to create counterfactual images!
# TODO make a task here!
# %%
num_images = len(test_mnist)
counterfactuals = np.zeros((4, num_images, 3, 28, 28))

predictions = []
source_labels = []
target_labels = []

for idx, (x, y) in enumerate(test_mnist):
    for i in range(4):
        if i == y:
            # Store the image as is.
            x_fake = ...
        else:
            # Create the counterfactual from the image and prototype
            x_fake = generator(x.unsqueeze(0).to(device), ...)
        counterfactuals[i][idx] = x_fake.cpu().detach().numpy()
        pred = model(...)

        source_labels.append(y)
        target_labels.append(i)
        predictions.append(pred.argmax().item())

# %% tags=["solution"]
num_images = len(test_mnist)
counterfactuals = np.zeros((4, num_images, 3, 28, 28))

predictions = []
source_labels = []
target_labels = []

for idx, (x, y) in enumerate(test_mnist):
    for i in range(4):
        if i == y:
            # Store the image as is.
            x_fake = x.unsqueeze(0).to(device)
        else:
            # Create the counterfactual from the image and prototype
            x_fake = generator(
                x.unsqueeze(0).to(device), prototypes[i].unsqueeze(0).to(device)
            )
        counterfactuals[i][idx] = x_fake.cpu().detach().numpy()
        pred = model(x_fake)

        source_labels.append(y)
        target_labels.append(i)
        predictions.append(pred.argmax().item())

# %% [markdown] tags=[]
# <div class="alert alert-block alert-info"><h3>Task 4.1 Get the classifier accuracy on CycleGAN outputs</h3>
#
# Using the saved images, we're going to figure out how good our CycleGAN is at generating images of a new class!
#
# The images (`real`, `reconstructed`, and `counterfactual`) are saved in the `test_images/` directory. Before you start the exercise, have a look at how this directory is organized.
#
# TODO
# - Use the `make_dataset` function to create a dataset for the three different image types that we saved above
# - real
# - reconstructed
# - counterfactual
# </div>
# Let's plot the confusion matrix for the counterfactual images.
# %%
cf_cm = confusion_matrix(target_labels, predictions, normalize="true")
sns.heatmap(cf_cm, annot=True, fmt=".2f")
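
# %% [markdown]
# As an added summary (assuming the `target_labels` and `predictions` lists from the loop
# above), we can also report a per-target "translation success rate": the fraction of
# counterfactuals that the classifier assigns to the class we asked for.
# %%
target_arr = np.array(target_labels)
pred_arr = np.array(predictions)
for i in range(4):
    success = np.mean(pred_arr[target_arr == i] == i)
    print(f"Target class {i}: {success:.2%} of counterfactuals classified as {i}")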

# %% [markdown] tags=[]
# <div class="alert alert-banner alert-warning">
# We get the following accuracies:
#
# 1. `accuracy_real`: Accuracy of the classifier on the real images, just for the two classes used in the GAN
# 2. `accuracy_recon`: Accuracy of the classifier on the reconstruction.
# 3. `accuracy_counter`: Accuracy of the classifier on the counterfactual images.
#
# <h3>Questions</h3>
#
# - In a perfect world, what value would we expect for `accuracy_recon`? What do we compare it to and why is it higher/lower?
# - How well is it translating from one class to another? Do we expect `accuracy_counter` to be large or small? Do we want it to be large or small? Why?
#
# Let us know your insights on the exercise chat.
# <div class="alert alert-block alert-warning"><h3>Questions</h3>
# <ul>
# <li> How well is our GAN doing at creating counterfactual images? </li>
# <li> Do you think that the prototypes used matter? Why or why not? </li>
# </ul>
# </div>
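# %% [markdown]
# To probe the second question empirically, here is a minimal sketch: regenerate a few
# counterfactuals using a different style source for each class and compare the
# predictions. The index choices are arbitrary; `generator`, `model`, `device`, and
# `test_mnist` are assumed from the cells above.
# %%
alternate_prototypes = {}
for i in range(4):
    options = np.where(test_mnist.targets == i)[0]
    # Pick a different example of each class than the one used before.
    alternate_prototypes[i] = test_mnist[options[1]][0]

x, y = test_mnist[0]
with torch.no_grad():
    for i in range(4):
        x_fake = generator(
            x.unsqueeze(0).to(device), alternate_prototypes[i].unsqueeze(0).to(device)
        )
        pred = model(x_fake).argmax().item()
        print(f"Target {i} with an alternate prototype -> classified as {pred}")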
# %% [markdown] tags=[]
# Let's also plot some examples of the counterfactual images.

# %%
print("The confusion matrix on the real images... for comparison")
# TODO Confusion matrix on the counterfactual images
confusion_matrix = ...
# TODO plot
# %%
print("The confusion matrix on the real images... for comparison")
# TODO Confusion matrix on the real images, for comparison
confusion_matrix = ...
# TODO plot

for i in np.random.choice(range(num_images), 4):
    fig, axs = plt.subplots(1, 4, figsize=(20, 4))
    for j, ax in enumerate(axs):
        ax.imshow(counterfactuals[j][i].transpose(1, 2, 0))
        ax.axis("off")
        ax.set_title(f"Class {j}")

# %% [markdown] tags=[]
# <div class="alert alert-block alert-info"><h3>Questions</h3>
# <ul>
# <li>Can you easily tell which of these images is the original, and which ones are the counterfactuals?</li>
# <li>What is your hypothesis for the features that define each class?</li>
# </ul>
# </div>
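
# %% [markdown]
# To make the first question a fair test, here is a small added sketch: it shows the four
# versions of one image in shuffled order, without titles, so you can try to spot the
# original before revealing it. It assumes `counterfactuals` and `source_labels` were
# filled in order by the loop above (four entries per image).
# %%
i = np.random.randint(num_images)
true_class = int(source_labels[4 * i])  # the source label of image i
order = np.random.permutation(4)
fig, axs = plt.subplots(1, 4, figsize=(20, 4))
for ax, j in zip(axs, order):
    ax.imshow(counterfactuals[j][i].transpose(1, 2, 0))
    ax.axis("off")
plt.show()
print(f"The original was panel {int(np.where(order == true_class)[0][0])} (from the left, 0-indexed)")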

# TODO wip here
# %% [markdown]
# # Part 5: Highlighting Class-Relevant Differences
