From 5e963dfab93b6dc457bfbd850431abd84588d777 Mon Sep 17 00:00:00 2001 From: Diane Adjavon Date: Mon, 12 Aug 2024 17:06:53 -0400 Subject: [PATCH] wip: Add EMA to GAN training --- requirements.txt | 2 + solution.py | 287 +++++++++-------------------------------------- 2 files changed, 55 insertions(+), 234 deletions(-) diff --git a/requirements.txt b/requirements.txt index 7f2c196..b57e7e0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,3 +3,5 @@ ipykernel tqdm captum git+https://github.com/dlmbl/dlmbl-unet.git +scikit-learn +seaborn \ No newline at end of file diff --git a/solution.py b/solution.py index 542ddc8..9edd9f7 100644 --- a/solution.py +++ b/solution.py @@ -505,14 +505,36 @@ def forward(self, x, y): # %% [markdown] tags=[] # TODO - Describe set_requires_grad - - # %% def set_requires_grad(module, value=True): """Sets `requires_grad` on a `module`'s parameters to `value`""" for param in module.parameters(): param.requires_grad = value +# %% [markdown] tags=[] +# TODO - Describe EMA + +# %% +from copy import deepcopy + + +def exponential_moving_average(model, ema_model, beta=0.999): + """Update the EMA model's parameters with an exponential moving average""" + for param, ema_param in zip(model.parameters(), ema_model.parameters()): + ema_param.data.mul_(beta).add_((1 - beta) * param.data) + + +def copy_parameters(source_model, target_model): + """Copy the parameters of a model to another model""" + for param, target_param in zip( + source_model.parameters(), target_model.parameters() + ): + target_param.data.copy_(param.data) + + +# %% +generator_ema = Generator(deepcopy(unet), style_mapping=deepcopy(style_mapping)) +generator_ema = generator_ema.to(device) # %% [markdown] tags=[] #

Task 3.2: Training!

@@ -613,6 +635,18 @@ def set_requires_grad(module, value=True): losses["adv"].append(adv_loss.item()) losses["disc"].append(disc_loss.item()) + # EMA update + # TODO - perform the EMA update + ############ + # Option 1 # + ############ + exponential_moving_average(generator, generator_ema) + ############ + # Option 2 # + ############ + exponential_moving_average(generator_ema, generator) + # Copy the EMA model's parameters to the generator + copy_parameters(generator_ema, generator) # %% tags=["solution"] from tqdm import tqdm # This is a nice library for showing progress bars @@ -651,7 +685,7 @@ def set_requires_grad(module, value=True): set_requires_grad(generator, False) set_requires_grad(discriminator, True) optimizer_d.zero_grad() - # TODO Do I need to re-do the forward pass? + # discriminator_x = discriminator(x) discriminator_x_fake = discriminator(x_fake.detach()) # Losses to train the discriminator @@ -668,6 +702,9 @@ def set_requires_grad(module, value=True): losses["cycle"].append(cycle_loss.item()) losses["adv"].append(adv_loss.item()) losses["disc"].append(disc_loss.item()) + exponential_moving_average(generator, generator_ema) + # Copy the EMA model's parameters to the generator + copy_parameters(generator_ema, generator) # %% [markdown] tags=[] @@ -681,6 +718,14 @@ def set_requires_grad(module, value=True): plt.legend() plt.show() +# %% [markdown] tags=[] +#

Questions

+#
    +#
  • Do the losses look like what you expected?
  • +#
  • How do these losses differ from the losses you would expect from a classifier?
  • +#
  • Based only on the losses, do you think the model is doing well?
  • +#
+ # %% [markdown] tags=[] # We can also look at some examples of the images that the generator is creating. # %% @@ -741,7 +786,7 @@ def set_requires_grad(module, value=True): # %% [markdown] # Now we need to use these prototypes to create counterfactual images! # TODO make a task here! -# %% +# %% tags=["task"] num_images = len(test_mnist) counterfactuals = np.zeros((4, num_images, 3, 28, 28)) @@ -819,246 +864,20 @@ def set_requires_grad(module, value=True): # #
-# TODO wip here # %% [markdown] # # Part 5: Highlighting Class-Relevant Differences # %% [markdown] # At this point we have: -# - A classifier that can differentiate between neurotransmitters from EM images of synapses -# - A vague idea of which parts of the images it thinks are important for this classification -# - A CycleGAN that is sometimes able to trick the classifier with barely perceptible changes -# -# What we don't know, is *how* the CycleGAN is modifying the images to change their class. +# - A classifier that can differentiate between image of different classes +# - A GAN that has correctly figured out how to change the class of an image # -# To start to answer this question, we will use a [Discriminative Attribution from Counterfactuals](https://arxiv.org/abs/2109.13412) method to highlight differences between the "real" and "fake" images that are most important to change the decision of the classifier. - -# %% [markdown] -#

Task 5.1 Get sucessfully converted samples

-# The CycleGAN is able to convert some, but not all images into their target types. -# In order to observe and highlight useful differences, we want to observe our attribution method at work only on those examples of synapses: -#
    -#
  1. That were correctly classified originally
  2. -#
  3. Whose counterfactuals were also correctly classified
  4. -#
+# Let's try putting the two together to see if we can figure out what exactly makes a class. # -# TODO -# - Get a boolean description of the `real` samples that were correctly predicted -# - Get the target class for the `counterfactual` images (Hint: It isn't `cf_gt`!) -# - Get a boolean description of the `cf` samples that have the target class -#
- -# %% tags=[] -####### Task 5.1 TODO ####### - -# Get the samples where the real is correct -correct_real = ... - -# HINT GABA is class 1 and ACh is class 0 -target = ... - -# Get the samples where the counterfactual has reached the target -correct_cf = ... - -# Successful conversions -success = np.where(np.logical_and(correct_real, correct_cf))[0] - -# Create datasets with only the successes -cf_success_ds = Subset(ds_counterfactual, success) -real_success_ds = Subset(ds_real, success) - - -# %% tags=["solution"] -######################## -# Solution to Task 5.1 # -######################## - -# Get the samples where the real is correct -correct_real = real_pred == real_gt - -# HINT GABA is class 1 and ACh is class 0 -target = 1 - real_gt - -# Get the samples where the counterfactual has reached the target -correct_cf = cf_pred == target - -# Successful conversions -success = np.where(np.logical_and(correct_real, correct_cf))[0] - -# Create datasets with only the successes -cf_success_ds = Subset(ds_counterfactual, success) -real_success_ds = Subset(ds_real, success) # %% [markdown] tags=[] -# To check that we have got it right, let us get the accuracy on the best 100 vs the worst 100 samples: - -# %% tags=[] -model = model.to("cuda") - -# %% tags=[] -real_true, real_pred = predict(real_success_ds, "Real") -cf_true, cf_pred = predict(cf_success_ds, "Counterfactuals") - -print( - "Accuracy of the classifier on successful real images", - accuracy_score(real_true, real_pred), -) -print( - "Accuracy of the classifier on successful counterfactual images", - accuracy_score(cf_true, cf_pred), -) - -# %% [markdown] tags=[] -# ### Creating hybrids from attributions -# -# Now that we have a set of successfully translated counterfactuals, we can use them as a baseline for our attribution. -# If you remember from earlier, `IntegratedGradients` does a interpolation between the model gradients at the baseline and the model gradients at the sample. Here, we're also going to be doing an interpolation between the baseline image and the sample image, creating a hybrid! -# -# To do this, we will take the sample image and mask out all of the pixels in the attribution. We will then replace these masked out pixels by the equivalent values in the counterfactual. So we'll have a hybrid image that is like the original everywhere except in the areas that matter for classification. - -# %% tags=[] -dataloader_real = DataLoader(real_success_ds, batch_size=10) -dataloader_counter = DataLoader(cf_success_ds, batch_size=10) - -# %% tags=[] -# %%time -with torch.no_grad(): - model.to(device) - # Create an integrated gradients object. - # integrated_gradients = IntegratedGradients(model) - # Generated attributions on integrated gradients - attributions = np.vstack( - [ - integrated_gradients.attribute( - real.to(device), - target=target.to(device), - baselines=counterfactual.to(device), - ) - .cpu() - .numpy() - for (real, target), (counterfactual, _) in zip( - dataloader_real, dataloader_counter - ) - ] - ) - -# %% - -# %% tags=[] -# Functions for creating an interactive visualization of our attributions -model.cpu() - -import matplotlib - -cmap = matplotlib.cm.get_cmap("viridis") -colors = cmap([0, 255]) - - -@torch.no_grad() -def get_classifications(image, counter, hybrid): - model.eval() - class_idx = [full_dataset.classes.index(c) for c in classes] - tensor = torch.from_numpy(np.stack([image, counter, hybrid])).float() - with torch.no_grad(): - logits = model(tensor)[:, class_idx] - probs = torch.nn.Softmax(dim=1)(logits) - pred, counter_pred, hybrid_pred = probs - return pred.numpy(), counter_pred.numpy(), hybrid_pred.numpy() - - -def visualize_counterfactuals(idx, threshold=0.1): - image = real_success_ds[idx][0].numpy() - counter = cf_success_ds[idx][0].numpy() - mask = get_mask(attributions[idx], threshold) - hybrid = (1 - mask) * image + mask * counter - nan_mask = copy.deepcopy(mask) - nan_mask[nan_mask != 0] = 1 - nan_mask[nan_mask == 0] = np.nan - # PLOT - fig, axes = plt.subplot_mosaic( - """ - mmm.ooo.ccc.hhh - mmm.ooo.ccc.hhh - mmm.ooo.ccc.hhh - ....ggg.fff.ppp - """, - figsize=(20, 5), - ) - # Original - viz.visualize_image_attr( - np.transpose(mask, (1, 2, 0)), - np.transpose(image, (1, 2, 0)), - method="blended_heat_map", - sign="absolute_value", - show_colorbar=True, - title="Mask", - use_pyplot=False, - plt_fig_axis=(fig, axes["m"]), - ) - # Original - axes["o"].imshow(image.squeeze(), cmap="gray") - axes["o"].set_title("Original", fontsize=24) - # Counterfactual - axes["c"].imshow(counter.squeeze(), cmap="gray") - axes["c"].set_title("Counterfactual", fontsize=24) - # Hybrid - axes["h"].imshow(hybrid.squeeze(), cmap="gray") - axes["h"].set_title("Hybrid", fontsize=24) - # Mask - pred, counter_pred, hybrid_pred = get_classifications(image, counter, hybrid) - axes["g"].barh(classes, pred, color=colors) - axes["f"].barh(classes, counter_pred, color=colors) - axes["p"].barh(classes, hybrid_pred, color=colors) - for ix in ["m", "o", "c", "h"]: - axes[ix].axis("off") - - for ix in ["g", "f", "p"]: - for tick in axes[ix].get_xticklabels(): - tick.set_rotation(90) - axes[ix].set_xlim(0, 1) - - -# %% [markdown] tags=[] -#

Task 5.2: Observing the effect of the changes on the classifier

-# Below is a small widget to interact with the above analysis. As you change the `threshold`, see how the prediction of the hybrid changes. -# At what point does it swap over? -# -# If you want to see different samples, slide through the `idx`. -#
- -# %% tags=[] -interact(visualize_counterfactuals, idx=(0, 99), threshold=(0.0, 1.0, 0.05)) - -# %% [markdown] -# HELP!!! Interactive (still!) doesn't work. No worries... uncomment the following cell and choose your index and threshold by typing them out. - -# %% tags=[] -# Choose your own adventure -# idx = 0 -# threshold = 0.1 - -# # Plotting :) -# visualize_counterfactuals(idx, threshold) - -# %% [markdown] tags=[] -#
-#

Questions

-# -# - Can you find features that define either of the two classes? -# - How consistent are they across the samples? -# - Is there a range of thresholds where most of the hybrids swap over to the target class? (If you want to see that area, try to change the range of thresholds in the slider by setting `threshold=(minimum_value, maximum_value, step_size)` -# -# Feel free to discuss your answers on the exercise chat! -#
- -# %% [markdown] tags=[] -#
-#

The End.

-# Go forth and train some GANs! -#
- -# %% [markdown] tags=[] +# TODO # ## Going Further # # Here are some ideas for how to continue with this notebook: