diff --git a/solution.py b/solution.py index f065a11..1e3bc59 100644 --- a/solution.py +++ b/solution.py @@ -494,8 +494,7 @@ def forward(self, x, y): cycle_loss_fn = nn.L1Loss() # %% [markdown] tags=[] -# Stuff about the dataloader - +# To load the data as batches, with shuffling and other useful features, we will use a `DataLoader`. # %% from torch.utils.data import DataLoader @@ -504,7 +503,9 @@ def forward(self, x, y): ) # We will use the same dataset as before # %% [markdown] tags=[] -# TODO - Describe set_requires_grad +# As we stated earlier, it is important to make sure when each network is being trained when working with a GAN. +# Indeed, if we update the weights at the same time, we may lose the adversarial aspect of the training altogether, with information leaking into the generator or discriminator causing them to collaborate when they should be competing! +# `set_requires_grad` is a function that allows us to determine when the weights of a network are trainable (if it is `True`) or not (if it is `False`). # %% def set_requires_grad(module, value=True): """Sets `requires_grad` on a `module`'s parameters to `value`""" @@ -512,8 +513,15 @@ def set_requires_grad(module, value=True): param.requires_grad = value # %% [markdown] tags=[] -# TODO - Describe EMA - +# Another consequence of adversarial training is that it is very unstable. +# While this instability is what leads to finding the best possible solution (which in the case of GANs is on a saddle point), it can also make it difficult to train the model. +# To force some stability back into the training, we will use Exponential Moving Averages (EMA). +# +# In essence, each time we update the generator's weights, we will also update the EMA model's weights as an average of all the generator's previous weights as well as the current update. +# A certain weight is given to the previous weights, which is what ensures that the EMA update remains rather smooth over the training period. +# Each epoch, we will then copy the EMA model's weights back to the generator. +# This is a common technique used in GAN training to stabilize the training process. +# Pay attention to what this does to the loss during the training process! # %% from copy import deepcopy @@ -538,16 +546,19 @@ def copy_parameters(source_model, target_model): # %% [markdown] tags=[] #
+# %% [markdown] tags=[] +# Once you're happy with your choices, run the training loop! 🚂 🚋 🚋 🚋 # %% tags=["task"] from tqdm import tqdm # This is a nice library for showing progress bars @@ -708,8 +719,6 @@ def copy_parameters(source_model, target_model): # %% [markdown] tags=[] -# ...this time again. 🚂 🚋 🚋 🚋 -# # Once training is complete, we can plot the losses to see how well the model is doing. # %% plt.plot(losses["cycle"], label="Cycle loss") @@ -901,7 +910,6 @@ def copy_parameters(source_model, target_model): # Generated attributions on integrated gradients attributions = integrated_gradients.attribute(x, baselines=x_fake, target=y) - # %% Another visualization function def visualize_color_attribution_and_counterfactual( attribution, original_image, counterfactual_image @@ -922,7 +930,6 @@ def visualize_color_attribution_and_counterfactual( ax2.axis("off") plt.show() - # %% for idx in range(batch_size): print("Source class:", y[idx].item()) @@ -966,9 +973,11 @@ def visualize_color_attribution_and_counterfactual( # Let's take a look at the style space. # We will use the style encoder to encode the style of the images and then use PCA to visualize it. # -# TODO # %% +from sklearn.decomposition import PCA + + styles = [] labels = [] for img, label in random_test_mnist: @@ -978,8 +987,6 @@ def visualize_color_attribution_and_counterfactual( labels.append(label) # PCA -from sklearn.decomposition import PCA - pca = PCA(n_components=2) styles_pca = pca.fit_transform(styles) @@ -999,22 +1006,106 @@ def visualize_color_attribution_and_counterfactual( # We know that color is important. Does interpreting the style space as colors help us understand better? # # Let's use the style space to color the PCA plot. +# (Note: there is no code to write here, just run the cell and answer the questions below) # # TODO WIP HERE +# %% +normalized_styles = (styles - np.min(styles, axis=1)) / styles.ptp(axis=1) -# %% [markdown] tags=[] -# ## Going Further -# -# Here are some ideas for how to continue with this notebook: -# -# 1. Improve the classifier. This code uses a VGG network for the classification. On the synapse dataset, we will get a validation accuracy of around 80%. Try to see if you can improve the classifier accuracy. -# * (easy) Data augmentation: The training code for the classifier is quite simple in this example. Enlarge the amount of available training data by adding augmentations (transpose and mirror the images, add noise, change the intensity, etc.). -# * (easy) Network architecture: The VGG network has a few parameters that one can tune. Try a few to see what difference it makes. -# * (easy) Inspect the classifier predictions: Take random samples from the test dataset and classify them. Show the images together with their predicted and actual labels. -# * (medium) Other networks: Try different architectures (e.g., a [ResNet](https://blog.paperspace.com/writing-resnet-from-scratch-in-pytorch/#resnet-from-scratch)) and see if the accuracy can be improved. +# Plot the PCA again! +plt.figure(figsize=(10, 10)) +plt.scatter( + styles_pca[:, 0], + styles_pca[:, 1], + c=normalized_styles, +) +plt.show() +# %% [markdown] +#