Commit f864649: Finish style space, explanations, and conclusion
adjavon committed Aug 15, 2024 (1 parent: b4595ab)
Showing 1 changed file with 119 additions and 28 deletions: solution.py
@@ -494,8 +494,7 @@ def forward(self, x, y):
cycle_loss_fn = nn.L1Loss()

# %% [markdown] tags=[]
# To load the data as batches, with shuffling and other useful features, we will use a `DataLoader`.
# %%
from torch.utils.data import DataLoader

@@ -504,16 +503,25 @@ def forward(self, x, y):
) # We will use the same dataset as before

# %% [markdown] tags=[]
# As we stated earlier, when working with a GAN it is important to control which network is being trained at any given time.
# If we updated the weights of both networks at the same time, we could lose the adversarial aspect of the training altogether, with information leaking into the generator or discriminator and causing them to collaborate when they should be competing!
# `set_requires_grad` is a function that allows us to determine when the weights of a network are trainable (if it is `True`) or not (if it is `False`).
# %%
def set_requires_grad(module, value=True):
"""Sets `requires_grad` on a `module`'s parameters to `value`"""
for param in module.parameters():
param.requires_grad = value

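# %% [markdown] tags=[]
# As a quick illustration (a sketch, assuming the `generator` defined earlier in this notebook), freezing and unfreezing a network looks like this:
# %%
set_requires_grad(generator, False)  # freeze: gradients will no longer be computed
assert not any(p.requires_grad for p in generator.parameters())
set_requires_grad(generator, True)  # unfreeze before training it again
assert all(p.requires_grad for p in generator.parameters())
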
# %% [markdown] tags=[]
# Another consequence of adversarial training is that it is very unstable.
# While this instability is what drives the search for the best possible solution (which for GANs sits at a saddle point of the min-max objective), it can also make the model difficult to train.
# To force some stability back into the training, we will use Exponential Moving Averages (EMA).
#
# In essence, each time we update the generator's weights, we also update the EMA model's weights as a weighted average of the EMA model's current weights and the generator's new weights.
# Most of the weight is kept on the running average, which is what keeps the EMA model's parameters smooth over the training period.
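# Concretely, with a smoothing factor `beta` (typically close to 1), each parameter follows the update `w_ema = beta * w_ema + (1 - beta) * w`.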
# Each epoch, we will then copy the EMA model's weights back to the generator.
# This is a common technique used in GAN training to stabilize the training process.
# Pay attention to what this does to the loss during the training process!
# %%
from copy import deepcopy

@@ -538,16 +546,19 @@ def copy_parameters(source_model, target_model):

# %% [markdown] tags=[]
# <div class="alert alert-banner alert-info"><h4>Task 3.2: Training!</h4>
#
# You are given several different options in the training code below. In each case, one of the options will work, and the other will not.
# Comment out the option that you think will not work.
# <ul>
# <li>Choose the values for `set_requires_grad`. Hint: which part of the code is training the generator? Which part is training the discriminator?</li>
# <li>Choose the values of `set_requires_grad`, again. Hint: you may want to switch.</li>
# <li>Choose the sign of the discriminator loss. Hint: what does the discriminator want to do?</li>
# <li>Apply the EMA update. Hint: which model do you want to update? You can look again at the code we wrote above.</li>
# </ul>
# Let's train the StarGAN one batch at a time.
# While you watch the model train, consider whether you think it will be successful at generating counterfactuals in the number of steps we give it. What is the minimum number of iterations you think are needed for this to work, and how much time do you think it will take?
# </div>
# %% [markdown] tags=[]
# Once you're happy with your choices, run the training loop! &#x1F682; &#x1F68B; &#x1F68B; &#x1F68B;
# %% tags=["task"]
from tqdm import tqdm # This is a nice library for showing progress bars

@@ -708,8 +719,6 @@ def copy_parameters(source_model, target_model):


# %% [markdown] tags=[]
# Once training is complete, we can plot the losses to see how well the model is doing.
# %%
plt.plot(losses["cycle"], label="Cycle loss")
@@ -901,7 +910,6 @@ def copy_parameters(source_model, target_model):
# Generated attributions on integrated gradients
attributions = integrated_gradients.attribute(x, baselines=x_fake, target=y)

# %% Another visualization function
def visualize_color_attribution_and_counterfactual(
attribution, original_image, counterfactual_image
@@ -922,7 +930,6 @@ def visualize_color_attribution_and_counterfactual(
ax2.axis("off")
plt.show()

# %%
for idx in range(batch_size):
print("Source class:", y[idx].item())
@@ -966,9 +973,11 @@ def visualize_color_attribution_and_counterfactual(
# Let's take a look at the style space.
# We will use the style encoder to encode the style of the images and then use PCA to visualize it.
# </div>

# %%
from sklearn.decomposition import PCA


styles = []
labels = []
for img, label in random_test_mnist:
@@ -978,8 +987,6 @@ def visualize_color_attribution_and_counterfactual(
labels.append(label)

# PCA
pca = PCA(n_components=2)
styles_pca = pca.fit_transform(styles)

@@ -999,22 +1006,106 @@ def visualize_color_attribution_and_counterfactual(
# We know that color is important. Does interpreting the style space as colors help us understand better?
#
# Let's use the style space to color the PCA plot.
# (Note: there is no code to write here, just run the cell and answer the questions below)
# </div>
# %%
# Normalize each style vector to [0, 1] so it can be used as an RGB color
# (this assumes the style space is 3-dimensional)
styles = np.array(styles)
normalized_styles = (styles - styles.min(axis=1, keepdims=True)) / np.ptp(
    styles, axis=1, keepdims=True
)

# Plot the PCA again!
plt.figure(figsize=(10, 10))
plt.scatter(
styles_pca[:, 0],
styles_pca[:, 1],
c=normalized_styles,
)
plt.show()
# %% [markdown]
# <div class="alert alert-block alert-warning"><h3>Questions</h3>
# <ul>
# <li> Do the colors match those that you have seen in the data?</li>
# <li> Can you see any patterns in the colors? Is the space smooth, for example?</li>
# </ul>
# </div>
# %% [markdown]
# <div class="alert alert-block alert-info"><h3>Using the images to color the style space</h3>
# Finally, let's just use the colors from the images themselves!
# All of the non-zero values in the image can be averaged to get a color.
#
# Let's get that color, then plot the style space again.
# (Note: once again, no coding needed here, just run the cell and think about the results with the questions below)
# </div>
# %% tags=["solution"]
tol = 1e-6

colors = []
for x, y in random_test_mnist:
    # Pixels where the digit is drawn (assuming the background is zero in every channel)
    mask = x.sum(dim=0) > tol
    # Average those pixels per channel to get one RGB color per image
    colors.append(x[:, mask].mean(dim=1).cpu().numpy().squeeze())

# Plot the PCA again!
plt.figure(figsize=(10, 10))
plt.scatter(
styles_pca[:, 0],
styles_pca[:, 1],
c=colors,
)
plt.show()

# %% [markdown]
# <div class="alert alert-block alert-warning"><h3>Questions</h3>
# <ul>
# <li> Do the colors match those that you have seen in the data?</li>
# <li> Can you see any patterns in the colors?</li>
# <li> Can you guess what the classes correspond to?</li>
# </ul>
# </div>

# %% [markdown]
# <div class="alert alert-block alert-success"><h2>Checkpoint 5</h2>
# Congratulations! You have made it to the end of the exercise!
# You have:
# - Created a StarGAN that can change the class of an image
# - Evaluated the StarGAN on unseen data
# - Used the StarGAN to create counterfactual images
# - Used the counterfactual images to highlight the differences between classes
# - Used the style space to understand the differences between classes
#
# If you have any questions, feel free to ask them in the chat!
# And check the Solutions exercise for a definitive answer to how these classes are defined!
# </div>

# %% [markdown] tags=["solution"]
# The colors for the classes are sampled from matplotlib colormaps! They are the four seasons: spring, summer, autumn, and winter.
# Check your style space again to see if you can see the patterns now!
# %% tags=["solution"]
# Let's plot the colormaps
import matplotlib as mpl
import numpy as np


def plot_color_gradients(cmap_list):
gradient = np.linspace(0, 1, 256)
gradient = np.vstack((gradient, gradient))

# Create figure and adjust figure height to number of colormaps
nrows = len(cmap_list)
figh = 0.35 + 0.15 + (nrows + (nrows - 1) * 0.1) * 0.22
fig, axs = plt.subplots(nrows=nrows + 1, figsize=(6.4, figh))
fig.subplots_adjust(top=1 - 0.35 / figh, bottom=0.15 / figh, left=0.2, right=0.99)

for ax, name in zip(axs, cmap_list):
ax.imshow(gradient, aspect="auto", cmap=mpl.colormaps[name])
ax.text(
-0.01,
0.5,
name,
va="center",
ha="right",
fontsize=10,
transform=ax.transAxes,
)

# Turn off *all* ticks & spines, not just the ones with colormaps.
for ax in axs:
ax.set_axis_off()


plot_color_gradients(["spring", "summer", "autumn", "winter"])
