wip: Add EMA to GAN training
adjavon committed Aug 12, 2024
1 parent 03e6aaa commit 5e963df
Showing 2 changed files with 55 additions and 234 deletions.
2 changes: 2 additions & 0 deletions requirements.txt
@@ -3,3 +3,5 @@ ipykernel
tqdm
captum
git+https://github.com/dlmbl/dlmbl-unet.git
scikit-learn
seaborn
287 changes: 53 additions & 234 deletions solution.py
@@ -505,14 +505,36 @@ def forward(self, x, y):

# %% [markdown] tags=[]
# The `set_requires_grad` helper below toggles whether a module's parameters receive gradients. We will use it during training to freeze the discriminator while the generator is being updated, and vice versa, so that each optimizer step only modifies the network whose turn it is.


# %%
def set_requires_grad(module, value=True):
"""Sets `requires_grad` on a `module`'s parameters to `value`"""
for param in module.parameters():
param.requires_grad = value
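
# %% [markdown] tags=[]
# As a quick sanity check (a minimal sketch, assuming the `discriminator` from the previous section is in scope):

# %%
set_requires_grad(discriminator, False)
assert not any(p.requires_grad for p in discriminator.parameters())
set_requires_grad(discriminator, True)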

# %% [markdown] tags=[]
# We will also keep an exponential moving average (EMA) of the generator's weights. GAN losses oscillate, and so do the generator's weights; a time-averaged copy of the generator tends to produce smoother, more reliable outputs. At every training step we update the EMA weights as
#
# $\theta_{EMA} \leftarrow \beta \, \theta_{EMA} + (1 - \beta) \, \theta$
#
# With $\beta = 0.999$, this averages over roughly the last $1 / (1 - \beta) = 1000$ steps.

# %%
from copy import deepcopy


def exponential_moving_average(model, ema_model, beta=0.999):
"""Update the EMA model's parameters with an exponential moving average"""
for param, ema_param in zip(model.parameters(), ema_model.parameters()):
ema_param.data.mul_(beta).add_((1 - beta) * param.data)


def copy_parameters(source_model, target_model):
"""Copy the parameters of a model to another model"""
for param, target_param in zip(
source_model.parameters(), target_model.parameters()
):
target_param.data.copy_(param.data)
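
# %% [markdown] tags=[]
# (A roughly equivalent one-liner, if you prefer state dicts: `target_model.load_state_dict(source_model.state_dict())`. Note that this also copies buffers, not just parameters.)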


# %%
generator_ema = Generator(deepcopy(unet), style_mapping=deepcopy(style_mapping))
generator_ema = generator_ema.to(device)
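
# %% [markdown] tags=[]
# `generator_ema` is built from deep copies of the `unet` and `style_mapping`, so it starts out with the same weights as the generator. It is never given to an optimizer: its weights will only change through `exponential_moving_average` during training.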

# %% [markdown] tags=[]
# <div class="alert alert-banner alert-info"><h4>Task 3.2: Training!</h4>
@@ -613,6 +635,18 @@ def set_requires_grad(module, value=True):
losses["adv"].append(adv_loss.item())
losses["disc"].append(disc_loss.item())

# EMA update
# TODO - perform the EMA update: average the generator's weights into
# `generator_ema`, then copy the smoothed parameters back into the
# generator (mind the argument order of `exponential_moving_average`!).
...
# %% tags=["solution"]
from tqdm import tqdm # This is a nice library for showing progress bars

@@ -651,7 +685,7 @@ def set_requires_grad(module, value=True):
set_requires_grad(generator, False)
set_requires_grad(discriminator, True)
optimizer_d.zero_grad()
# No need to re-run the generator here: we reuse `x_fake` from above,
# detached so that no gradients flow back into the generator.
discriminator_x = discriminator(x)
discriminator_x_fake = discriminator(x_fake.detach())
# Losses to train the discriminator
@@ -668,6 +702,9 @@ def set_requires_grad(module, value=True):
losses["cycle"].append(cycle_loss.item())
losses["adv"].append(adv_loss.item())
losses["disc"].append(disc_loss.item())
exponential_moving_average(generator, generator_ema)
# Copy the EMA model's parameters to the generator
copy_parameters(generator_ema, generator)


# %% [markdown] tags=[]
@@ -681,6 +718,14 @@ def set_requires_grad(module, value=True):
plt.legend()
plt.show()

# %% [markdown] tags=[]
# <div class="alert alert-block alert-warning"><h3>Questions</h3>
# <ul>
# <li> Do the losses look like what you expected? </li>
# <li> How do these losses differ from the losses you would expect from a classifier? </li>
# <li> Based only on the losses, do you think the model is doing well? </li>
# </ul>
# </div>

# %% [markdown] tags=[]
# We can also look at some examples of the images that the generator is creating.
# %%
@@ -741,7 +786,7 @@ def set_requires_grad(module, value=True):
# %% [markdown]
# Now we need to use these prototypes to create counterfactual images!
# %% tags=["task"]
num_images = len(test_mnist)
counterfactuals = np.zeros((4, num_images, 3, 28, 28))

@@ -819,246 +864,20 @@ def set_requires_grad(module, value=True):
# </ul>
# </div>

# %% [markdown]
# # Part 5: Highlighting Class-Relevant Differences

# %% [markdown]
# At this point we have:
# - A classifier that can differentiate between images of different classes
# - A GAN that has correctly figured out how to change the class of an image
#
# What we don't know is *how* the GAN is modifying the images to change their class.
#
# To start to answer this question, we will use a [Discriminative Attribution from Counterfactuals](https://arxiv.org/abs/2109.13412) method to highlight differences between the "real" and "fake" images that are most important to change the decision of the classifier.

# %% [markdown]
# <div class="alert alert-block alert-info"><h3>Task 5.1 Get sucessfully converted samples</h3>
# The CycleGAN is able to convert some, but not all images into their target types.
# In order to observe and highlight useful differences, we want to observe our attribution method at work only on those examples of synapses:
# <ol>
# <li> That were correctly classified originally</li>
# <li>Whose counterfactuals were also correctly classified</li>
# </ol>
# Let's try putting the two together to see if we can figure out what exactly makes a class.
#
# TODO
# - Get a boolean description of the `real` samples that were correctly predicted
# - Get the target class for the `counterfactual` images (Hint: It isn't `cf_gt`!)
# - Get a boolean description of the `cf` samples that have the target class
# </div>

# %% tags=[]
####### Task 5.1 TODO #######

# Get the samples where the real is correct
correct_real = ...

# HINT GABA is class 1 and ACh is class 0
target = ...

# Get the samples where the counterfactual has reached the target
correct_cf = ...

# Successful conversions
success = np.where(np.logical_and(correct_real, correct_cf))[0]

# Create datasets with only the successes
cf_success_ds = Subset(ds_counterfactual, success)
real_success_ds = Subset(ds_real, success)


# %% tags=["solution"]
########################
# Solution to Task 5.1 #
########################

# Get the samples where the real is correct
correct_real = real_pred == real_gt

# HINT GABA is class 1 and ACh is class 0
target = 1 - real_gt

# Get the samples where the counterfactual has reached the target
correct_cf = cf_pred == target

# Successful conversions
success = np.where(np.logical_and(correct_real, correct_cf))[0]

# Create datasets with only the successes
cf_success_ds = Subset(ds_counterfactual, success)
real_success_ds = Subset(ds_real, success)


# %% [markdown] tags=[]
# To check that we have got it right, let us compute the classifier's accuracy on the successfully converted samples:

# %% tags=[]
model = model.to("cuda")

# %% tags=[]
real_true, real_pred = predict(real_success_ds, "Real")
cf_true, cf_pred = predict(cf_success_ds, "Counterfactuals")

print(
"Accuracy of the classifier on successful real images",
accuracy_score(real_true, real_pred),
)
print(
"Accuracy of the classifier on successful counterfactual images",
accuracy_score(cf_true, cf_pred),
)

# %% [markdown] tags=[]
# ### Creating hybrids from attributions
#
# Now that we have a set of successfully translated counterfactuals, we can use them as a baseline for our attribution.
# If you remember from earlier, `IntegratedGradients` interpolates between the baseline and the sample, accumulating the model's gradients along the way. Here, we're also going to interpolate between the baseline image and the sample image, creating a hybrid!
#
# To do this, we will take the sample image and mask out the pixels with the highest attribution values. We will then replace these masked-out pixels with the equivalent values from the counterfactual. The result is a hybrid image that is like the original everywhere except in the areas that matter most for classification.
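
# %% [markdown] tags=[]
# Concretely, given a mask $m$ derived by thresholding the attribution map, the hybrid is
#
# $\mathrm{hybrid} = (1 - m) \cdot \mathrm{image} + m \cdot \mathrm{counterfactual}$
#
# which is exactly what the `visualize_counterfactuals` function below computes.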

# %% tags=[]
dataloader_real = DataLoader(real_success_ds, batch_size=10)
dataloader_counter = DataLoader(cf_success_ds, batch_size=10)

# %% tags=[]
# %%time
with torch.no_grad():
model.to(device)
# We re-use the `integrated_gradients` object created earlier; if you skipped that section, un-comment the next line.
# integrated_gradients = IntegratedGradients(model)
# Generate attributions, using each counterfactual as its image's baseline
attributions = np.vstack(
[
integrated_gradients.attribute(
real.to(device),
target=target.to(device),
baselines=counterfactual.to(device),
)
.cpu()
.numpy()
for (real, target), (counterfactual, _) in zip(
dataloader_real, dataloader_counter
)
]
)
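
# %% [markdown] tags=[]
# Passing the counterfactual as `baselines=` is what makes this attribution *discriminative*: instead of comparing each image against a neutral reference (e.g. a black image), we compare it against an image that differs only in the parts relevant to its class.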

# %% tags=[]
# Functions for creating an interactive visualization of our attributions
model.cpu()

import matplotlib

cmap = matplotlib.cm.get_cmap("viridis")
colors = cmap([0, 255])


@torch.no_grad()
def get_classifications(image, counter, hybrid):
model.eval()
class_idx = [full_dataset.classes.index(c) for c in classes]
tensor = torch.from_numpy(np.stack([image, counter, hybrid])).float()
with torch.no_grad():
logits = model(tensor)[:, class_idx]
probs = torch.nn.Softmax(dim=1)(logits)
pred, counter_pred, hybrid_pred = probs
return pred.numpy(), counter_pred.numpy(), hybrid_pred.numpy()


def visualize_counterfactuals(idx, threshold=0.1):
image = real_success_ds[idx][0].numpy()
counter = cf_success_ds[idx][0].numpy()
mask = get_mask(attributions[idx], threshold)
hybrid = (1 - mask) * image + mask * counter
nan_mask = copy.deepcopy(mask)
nan_mask[nan_mask != 0] = 1
nan_mask[nan_mask == 0] = np.nan
# PLOT
fig, axes = plt.subplot_mosaic(
"""
mmm.ooo.ccc.hhh
mmm.ooo.ccc.hhh
mmm.ooo.ccc.hhh
....ggg.fff.ppp
""",
figsize=(20, 5),
)
# Attribution mask
viz.visualize_image_attr(
np.transpose(mask, (1, 2, 0)),
np.transpose(image, (1, 2, 0)),
method="blended_heat_map",
sign="absolute_value",
show_colorbar=True,
title="Mask",
use_pyplot=False,
plt_fig_axis=(fig, axes["m"]),
)
# Original
axes["o"].imshow(image.squeeze(), cmap="gray")
axes["o"].set_title("Original", fontsize=24)
# Counterfactual
axes["c"].imshow(counter.squeeze(), cmap="gray")
axes["c"].set_title("Counterfactual", fontsize=24)
# Hybrid
axes["h"].imshow(hybrid.squeeze(), cmap="gray")
axes["h"].set_title("Hybrid", fontsize=24)
# Classifier predictions for the original, counterfactual, and hybrid
pred, counter_pred, hybrid_pred = get_classifications(image, counter, hybrid)
axes["g"].barh(classes, pred, color=colors)
axes["f"].barh(classes, counter_pred, color=colors)
axes["p"].barh(classes, hybrid_pred, color=colors)
for ix in ["m", "o", "c", "h"]:
axes[ix].axis("off")

for ix in ["g", "f", "p"]:
for tick in axes[ix].get_xticklabels():
tick.set_rotation(90)
axes[ix].set_xlim(0, 1)


# %% [markdown] tags=[]
# <div class="alert alert-block alert-info"><h3>Task 5.2: Observing the effect of the changes on the classifier</h3>
# Below is a small widget to interact with the above analysis. As you change the `threshold`, see how the prediction of the hybrid changes.
# At what point does it swap over?
#
# If you want to see different samples, slide through the `idx`.
# </div>

# %% tags=[]
interact(visualize_counterfactuals, idx=(0, 99), threshold=(0.0, 1.0, 0.05))

# %% [markdown]
# If the interactive widget above doesn't work in your environment, no worries: uncomment the following cell and choose your index and threshold by typing them out.

# %% tags=[]
# Choose your own adventure
# idx = 0
# threshold = 0.1

# # Plotting :)
# visualize_counterfactuals(idx, threshold)

# %% [markdown] tags=[]
# <div class="alert alert-warning">
# <h4>Questions</h4>
#
# - Can you find features that define either of the two classes?
# - How consistent are they across the samples?
# - Is there a range of thresholds where most of the hybrids swap over to the target class? (If you want to see that area, try changing the range of thresholds in the slider by setting `threshold=(minimum_value, maximum_value, step_size)`.)
#
# Feel free to discuss your answers on the exercise chat!
# </div>

# %% [markdown] tags=[]
# <div class="alert alert-block alert-success">
# <h1>The End.</h1>
# Go forth and train some GANs!
# </div>

# %% [markdown] tags=[]
# ## Going Further
#
# Here are some ideas for how to continue with this notebook:
