Skip to content

Commit

Permalink
torch inference mode and html changes
Browse files Browse the repository at this point in the history
  • Loading branch information
neptunes5thmoon committed Aug 26, 2024
1 parent 14e66a9 commit 3fd89af
Showing 1 changed file with 35 additions and 35 deletions.
70 changes: 35 additions & 35 deletions solution.py
Original file line number Diff line number Diff line change
Expand Up @@ -849,20 +849,20 @@ def copy_parameters(source_model, target_model):
predictions = []
source_labels = []
target_labels = []

for i, (x, y) in tqdm(enumerate(random_test_mnist), total=num_images):
for lbl in range(4):
# TODO Create the counterfactual
x_fake = generator(x.unsqueeze(0).to(device), ...)
# TODO Predict the class of the counterfactual image
pred = model(...)

# TODO Store the source and target labels
source_labels.append(...) # The original label of the image
target_labels.append(...) # The desired label of the counterfactual image
# Store the counterfactual image and prediction
counterfactuals[lbl][i] = x_fake.cpu().detach().numpy()
predictions.append(pred.argmax().item())
with torch.inference_mode():
for i, (x, y) in tqdm(enumerate(random_test_mnist), total=num_images):
for lbl in range(4):
# TODO Create the counterfactual
x_fake = generator(x.unsqueeze(0).to(device), ...)
# TODO Predict the class of the counterfactual image
pred = model(...)

# TODO Store the source and target labels
source_labels.append(...) # The original label of the image
target_labels.append(...) # The desired label of the counterfactual image
# Store the counterfactual image and prediction
counterfactuals[lbl][i] = x_fake.cpu().detach().numpy()
predictions.append(pred.argmax().item())
# %% tags=["solution"]
num_images = 1000
random_test_mnist = torch.utils.data.Subset(
Expand All @@ -873,22 +873,22 @@ def copy_parameters(source_model, target_model):
predictions = []
source_labels = []
target_labels = []

for i, (x, y) in tqdm(enumerate(random_test_mnist), total=num_images):
for lbl in range(4):
# Create the counterfactual
x_fake = generator(
x.unsqueeze(0).to(device), prototypes[lbl].unsqueeze(0).to(device)
)
# Predict the class of the counterfactual image
pred = model(x_fake)

# Store the source and target labels
source_labels.append(y) # The original label of the image
target_labels.append(lbl) # The desired label of the counterfactual image
# Store the counterfactual image and prediction
counterfactuals[lbl][i] = x_fake.cpu().detach().numpy()
predictions.append(pred.argmax().item())
with torch.inference_mode():
for i, (x, y) in tqdm(enumerate(random_test_mnist), total=num_images):
for lbl in range(4):
# Create the counterfactual
x_fake = generator(
x.unsqueeze(0).to(device), prototypes[lbl].unsqueeze(0).to(device)
)
# Predict the class of the counterfactual image
pred = model(x_fake)

# Store the source and target labels
source_labels.append(y) # The original label of the image
target_labels.append(lbl) # The desired label of the counterfactual image
# Store the counterfactual image and prediction
counterfactuals[lbl][i] = x_fake.cpu().detach().numpy()
predictions.append(pred.argmax().item())

# %% [markdown] tags=[]
# Let's plot the confusion matrix for the counterfactual images.
Expand Down Expand Up @@ -1161,7 +1161,6 @@ def visualize_color_attribution_and_counterfactual(
plt.legend()
plt.show()

# %%
# %% [markdown]
# <div class="alert alert-block alert-warning"><h3>Questions</h3>
# <ul>
Expand All @@ -1187,15 +1186,16 @@ def visualize_color_attribution_and_counterfactual(
# If you have extra time, you can try to break the StarGAN!
# There are a lot of little things that we did to make sure that it runs correctly - but what if we didn't?
# Some things you might want to try:
# - What happens if you don't use the EMA model?
# - What happens if you change the learning rates?
# - What happens if you add a Sigmoid activation to the output of the style encoder?
# <li> What happens if you don't use the EMA model? </li>
# <li> What happens if you change the learning rates? </li>
# <li> What happens if you add a Sigmoid activation to the output of the style encoder? </li>
# See what else you can think of, and see how finnicky training a GAN can be!

## %% [markdown] tags=["solution"]
# The colors for the classes are sampled from matplotlib colormaps! They are the four seasons: spring, summer, autumn, and winter.
# Check your style space again to see if you can see the patterns now!
# %% tags=["solution"]

## %% tags=["solution"]
# Let's plot the colormaps
import matplotlib as mpl
import numpy as np
Expand Down

0 comments on commit 3fd89af

Please sign in to comment.