Skip to content

Commit

Permalink
Fix numbering, missing todos, and plotting bug
Browse files Browse the repository at this point in the history
  • Loading branch information
adjavon committed Aug 15, 2024
1 parent 33a6110 commit c1a6e28
Showing 1 changed file with 15 additions and 21 deletions.
36 changes: 15 additions & 21 deletions solution.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,7 +469,7 @@ def forward(self, x, y):
# We will have two different optimizers, one for the Generator and one for the Discriminator.
#
# %%
optimizer_d = torch.optim.Adam(discriminator.parameters(), lr=1e-4)
optimizer_d = torch.optim.Adam(discriminator.parameters(), lr=1e-5)
optimizer_g = torch.optim.Adam(generator.parameters(), lr=1e-4)
# %% [markdown] tags=[]
#
Expand Down Expand Up @@ -545,7 +545,7 @@ def copy_parameters(source_model, target_model):
generator_ema = generator_ema.to(device)

# %% [markdown] tags=[]
# <div class="alert alert-banner alert-info"><h4>Task 3.2: Training!</h4>
# <div class="alert alert-banner alert-info"><h4>Task 3.3: Training!</h4>
# You were given several different options in the training code below. In each case, one of the options will work, and the other will not.
# Comment out the option that you think will not work.
# <ul>
Expand Down Expand Up @@ -760,7 +760,7 @@ def copy_parameters(source_model, target_model):
# </div>

# %% [markdown] tags=[]
# # Part 4: Evaluating the GAN
# # Part 4: Evaluating the GAN and creating Counterfactuals

# %% [markdown] tags=[]
# ## Creating counterfactuals
Expand All @@ -777,7 +777,7 @@ def copy_parameters(source_model, target_model):


for i in range(4):
options = np.where(test_mnist.targets == i)[0]
options = np.where(test_mnist.conditions == i)[0]
# Note that you can change the image index if you want to use a different prototype.
image_index = 0
x, y = test_mnist[options[image_index]]
Expand All @@ -795,7 +795,7 @@ def copy_parameters(source_model, target_model):
# %% [markdown]
# Now we need to use these prototypes to create counterfactual images!
# %% [markdown]
# <div class="alert alert-block alert-info"><h3>Task 4.1: Create counterfactuals</h3>
# <div class="alert alert-block alert-info"><h3>Task 4: Create counterfactuals</h3>
# In the below, we will store the counterfactual images in the `counterfactuals` array.
#
# <ul>
Expand Down Expand Up @@ -887,9 +887,6 @@ def copy_parameters(source_model, target_model):
# </ul>
# </div>

# %% [markdown]
# # Part 5: Highlighting Class-Relevant Differences

# %% [markdown]
# At this point we have:
# - A classifier that can differentiate between image of different classes
Expand Down Expand Up @@ -954,7 +951,7 @@ def visualize_color_attribution_and_counterfactual(
# - Used the counterfactual images to highlight the differences between classes
#
# %% [markdown]
# # Part 6: Exploring the Style Space, finding the answer
# # Part 5: Exploring the Style Space, finding the answer
# By now you will have hopefully noticed that it isn't the exact color of the image that determines its class, but that two images with a very similar color can be of different classes!
#
# Here is an example of two images that are very similar in color, but are of different classes.
Expand Down Expand Up @@ -1002,15 +999,17 @@ def visualize_color_attribution_and_counterfactual(
plt.show()

# %% [markdown]
# <div class="alert alert-block alert-info"><h3>Task 6.2: Adding color to the style space</h3>
# <div class="alert alert-block alert-info"><h3>Task 5.1: Adding color to the style space</h3>
# We know that color is important. Does interpreting the style space as colors help us understand better?
#
# Let's use the style space to color the PCA plot.
# (Note: there is no code to write here, just run the cell and answer the questions below)
# </div>
# TODO WIP HERE
# %%
normalized_styles = (styles - np.min(styles, axis=1)) / styles.ptp(axis=1)
styles = np.array(styles)
normalized_styles = (styles - np.min(styles, axis=1, keepdims=True)) / np.ptp(
styles, axis=1, keepdims=True
)

# Plot the PCA again!
plt.figure(figsize=(10, 10))
Expand All @@ -1027,27 +1026,22 @@ def visualize_color_attribution_and_counterfactual(
# <li> Can you see any patterns in the colors? Is the space smooth, for example?</li>
# </ul>
# %% [markdown]
# <div class="alert alert-block alert-info"><h3>Using the images to color the style space</h3>
# <div class="alert alert-block alert-info"><h3>Task 5.2: Using the images to color the style space</h3>
# Finally, let's just use the colors from the images themselves!
# All of the non-zero values in the image can be averaged to get a color.
# The maximum value in the image (since they are "black-and-color") can be used as a color!
#
# Let's get that color, then plot the style space again.
# (Note: once again, no coding needed here, just run the cell and think about the results with the questions below)
# </div>
# %% tags=["solution"]
tol = 1e-6

colors = []
for x, y in random_test_mnist:
non_zero = x[x > tol]
colors.append(non_zero.mean(dim=(1, 2)).cpu().numpy().squeeze())
colors = [np.max(x.numpy(), axis=(1, 2)) for x, _ in random_test_mnist]

# Plot the PCA again!
plt.figure(figsize=(10, 10))
plt.scatter(
styles_pca[:, 0],
styles_pca[:, 1],
c=normalized_styles,
c=colors,
)
plt.show()

Expand Down

0 comments on commit c1a6e28

Please sign in to comment.