From 3fd89af87305240da18978cc1154887d0822bbe6 Mon Sep 17 00:00:00 2001
From: Larissa Heinrich
Date: Mon, 26 Aug 2024 20:15:11 +0000
Subject: [PATCH] torch inference mode and html changes

---
 solution.py | 70 ++++++++++++++++++++++++++---------------------------
 1 file changed, 35 insertions(+), 35 deletions(-)

diff --git a/solution.py b/solution.py
index 2d5ed2d..1768933 100644
--- a/solution.py
+++ b/solution.py
@@ -849,20 +849,20 @@ def copy_parameters(source_model, target_model):
 predictions = []
 source_labels = []
 target_labels = []
-
-for i, (x, y) in tqdm(enumerate(random_test_mnist), total=num_images):
-    for lbl in range(4):
-        # TODO Create the counterfactual
-        x_fake = generator(x.unsqueeze(0).to(device), ...)
-        # TODO Predict the class of the counterfactual image
-        pred = model(...)
-
-        # TODO Store the source and target labels
-        source_labels.append(...)  # The original label of the image
-        target_labels.append(...)  # The desired label of the counterfactual image
-        # Store the counterfactual image and prediction
-        counterfactuals[lbl][i] = x_fake.cpu().detach().numpy()
-        predictions.append(pred.argmax().item())
+with torch.inference_mode():
+    for i, (x, y) in tqdm(enumerate(random_test_mnist), total=num_images):
+        for lbl in range(4):
+            # TODO Create the counterfactual
+            x_fake = generator(x.unsqueeze(0).to(device), ...)
+            # TODO Predict the class of the counterfactual image
+            pred = model(...)
+
+            # TODO Store the source and target labels
+            source_labels.append(...)  # The original label of the image
+            target_labels.append(...)  # The desired label of the counterfactual image
+            # Store the counterfactual image and prediction
+            counterfactuals[lbl][i] = x_fake.cpu().detach().numpy()
+            predictions.append(pred.argmax().item())
 # %% tags=["solution"]
 num_images = 1000
 random_test_mnist = torch.utils.data.Subset(
@@ -873,22 +873,22 @@ def copy_parameters(source_model, target_model):
 predictions = []
 source_labels = []
 target_labels = []
-
-for i, (x, y) in tqdm(enumerate(random_test_mnist), total=num_images):
-    for lbl in range(4):
-        # Create the counterfactual
-        x_fake = generator(
-            x.unsqueeze(0).to(device), prototypes[lbl].unsqueeze(0).to(device)
-        )
-        # Predict the class of the counterfactual image
-        pred = model(x_fake)
-
-        # Store the source and target labels
-        source_labels.append(y)  # The original label of the image
-        target_labels.append(lbl)  # The desired label of the counterfactual image
-        # Store the counterfactual image and prediction
-        counterfactuals[lbl][i] = x_fake.cpu().detach().numpy()
-        predictions.append(pred.argmax().item())
+with torch.inference_mode():
+    for i, (x, y) in tqdm(enumerate(random_test_mnist), total=num_images):
+        for lbl in range(4):
+            # Create the counterfactual
+            x_fake = generator(
+                x.unsqueeze(0).to(device), prototypes[lbl].unsqueeze(0).to(device)
+            )
+            # Predict the class of the counterfactual image
+            pred = model(x_fake)
+
+            # Store the source and target labels
+            source_labels.append(y)  # The original label of the image
+            target_labels.append(lbl)  # The desired label of the counterfactual image
+            # Store the counterfactual image and prediction
+            counterfactuals[lbl][i] = x_fake.cpu().detach().numpy()
+            predictions.append(pred.argmax().item())
 
 # %% [markdown] tags=[]
 # Let's plot the confusion matrix for the counterfactual images.
@@ -1161,7 +1161,6 @@ def visualize_color_attribution_and_counterfactual(
     plt.legend()
     plt.show()
 
-# %%
 # %% [markdown]
 # Questions
 #
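
For reference, the substantive change in this patch is wrapping the two counterfactual-evaluation loops in torch.inference_mode(), PyTorch's context manager that skips autograd tracking during pure inference and so reduces memory use and overhead. Below is a minimal, self-contained sketch of the same pattern, not part of the patch; the toy linear model and random batch are placeholders standing in for the exercise's classifier and MNIST images.

import torch

# Placeholder model and batch, for illustration only; the real code uses the
# exercise's trained classifier and StarGAN-generated counterfactual images.
model = torch.nn.Linear(28 * 28, 4)
model.eval()
batch = torch.randn(8, 28 * 28)

# Inside inference_mode(), no autograd graph is recorded, so the forward
# passes behave like the patched evaluation loops: predictions only.
with torch.inference_mode():
    logits = model(batch)
    predictions = logits.argmax(dim=1)

print(predictions.tolist())

Unlike torch.no_grad(), tensors created under inference_mode() can never be used in autograd later; that is safe here because the patched loops only store detached NumPy copies of the counterfactuals and integer predictions.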