\n",
@@ -534,7 +541,7 @@
},
{
"cell_type": "markdown",
- "id": "2dc92a5c",
+ "id": "afc728f6",
"metadata": {},
"source": [
"Checkpoint 2\n",
@@ -554,14 +561,14 @@
},
{
"cell_type": "markdown",
- "id": "c0727f2f",
+ "id": "5731c94d",
"metadata": {
"lines_to_next_cell": 0
},
"source": [
"# Part 3: Train a GAN to Translate Images\n",
"\n",
- "To gain insight into how the trained network classify images, we will use [Discriminative Attribution from Counterfactuals](https://arxiv.org/abs/2109.13412), a feature attribution with counterfactual explanations methodology.\n",
+ "To gain insight into how the trained network classifies images, we will use [Discriminative Attribution from Counterfactuals](https://arxiv.org/abs/2109.13412), a feature attribution with counterfactual explanations methodology.\n",
"This method employs a StarGAN to translate images from one class to another to make counterfactual explanations.\n",
"\n",
"**What is a counterfactual?**\n",
@@ -582,7 +589,7 @@
},
{
"cell_type": "markdown",
- "id": "147a10f1",
+ "id": "017d5942",
"metadata": {
"lines_to_next_cell": 0,
"tags": []
@@ -605,7 +612,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "0b789be2",
+ "id": "16c7b4c1",
"metadata": {},
"outputs": [],
"source": [
@@ -637,7 +644,7 @@
},
{
"cell_type": "markdown",
- "id": "460878cc",
+ "id": "ebf7db5f",
"metadata": {
"lines_to_next_cell": 0
},
@@ -652,7 +659,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "5cf75884",
+ "id": "a1a4dd45",
"metadata": {
"tags": [
"solution"
@@ -669,7 +676,7 @@
},
{
"cell_type": "markdown",
- "id": "dc70737d",
+ "id": "5286f95c",
"metadata": {
"tags": []
},
@@ -684,7 +691,7 @@
},
{
"cell_type": "markdown",
- "id": "6bd563e2",
+ "id": "e16b6706",
"metadata": {
"lines_to_next_cell": 0,
"tags": []
@@ -701,7 +708,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "6fa80433",
+ "id": "91355252",
"metadata": {
"lines_to_next_cell": 0,
"tags": [
@@ -715,7 +722,7 @@
},
{
"cell_type": "markdown",
- "id": "955d9981",
+ "id": "100f8d9d",
"metadata": {
"lines_to_next_cell": 0
},
@@ -726,7 +733,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "7f71dbdc",
+ "id": "7ad040a9",
"metadata": {},
"outputs": [],
"source": [
@@ -736,7 +743,7 @@
},
{
"cell_type": "markdown",
- "id": "bd0d99c9",
+ "id": "9196de07",
"metadata": {
"lines_to_next_cell": 0,
"tags": []
@@ -754,7 +761,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "cb6bf33f",
+ "id": "6e5f2d8f",
"metadata": {
"lines_to_next_cell": 0
},
@@ -766,7 +773,7 @@
},
{
"cell_type": "markdown",
- "id": "803dad9e",
+ "id": "03ae0868",
"metadata": {
"lines_to_next_cell": 0,
"tags": []
@@ -785,7 +792,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "1ed827ff",
+ "id": "b7ca208e",
"metadata": {},
"outputs": [],
"source": [
@@ -794,7 +801,7 @@
},
{
"cell_type": "markdown",
- "id": "5166c91e",
+ "id": "6d4acb54",
"metadata": {
"lines_to_next_cell": 0,
"tags": []
@@ -810,7 +817,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "756f9a51",
+ "id": "18baee07",
"metadata": {},
"outputs": [],
"source": [
@@ -819,7 +826,7 @@
},
{
"cell_type": "markdown",
- "id": "625bb412",
+ "id": "55dbff92",
"metadata": {
"lines_to_next_cell": 0,
"tags": []
@@ -831,10 +838,8 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "8efeda25",
- "metadata": {
- "lines_to_next_cell": 1
- },
+ "id": "a7dfdc87",
+ "metadata": {},
"outputs": [],
"source": [
"from torch.utils.data import DataLoader\n",
@@ -846,7 +851,7 @@
},
{
"cell_type": "markdown",
- "id": "613e2c1f",
+ "id": "410575a9",
"metadata": {
"lines_to_next_cell": 0,
"tags": []
@@ -860,10 +865,8 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "971e4622",
- "metadata": {
- "lines_to_next_cell": 1
- },
+ "id": "3fbe0be1",
+ "metadata": {},
"outputs": [],
"source": [
"def set_requires_grad(module, value=True):\n",
@@ -874,7 +877,7 @@
},
{
"cell_type": "markdown",
- "id": "d86d0ea1",
+ "id": "54e7b00b",
"metadata": {
"lines_to_next_cell": 0,
"tags": []
@@ -894,7 +897,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "ee7e26ce",
+ "id": "654227d1",
"metadata": {},
"outputs": [],
"source": [
@@ -918,7 +921,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "6b0e8161",
+ "id": "54a1ee23",
"metadata": {},
"outputs": [],
"source": [
@@ -928,7 +931,7 @@
},
{
"cell_type": "markdown",
- "id": "854f274b",
+ "id": "d1d4c4d6",
"metadata": {
"lines_to_next_cell": 0,
"tags": []
@@ -950,7 +953,7 @@
},
{
"cell_type": "markdown",
- "id": "7783da15",
+ "id": "973a3066",
"metadata": {
"lines_to_next_cell": 0,
"tags": []
@@ -962,7 +965,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "14812a7c",
+ "id": "81c01fb9",
"metadata": {
"lines_to_next_cell": 2,
"tags": [
@@ -1032,7 +1035,7 @@
},
{
"cell_type": "markdown",
- "id": "5809a842",
+ "id": "06637a58",
"metadata": {
"lines_to_next_cell": 0,
"tags": []
@@ -1044,7 +1047,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "59ea06d6",
+ "id": "79d69313",
"metadata": {},
"outputs": [],
"source": [
@@ -1057,7 +1060,7 @@
},
{
"cell_type": "markdown",
- "id": "86c8ae57",
+ "id": "f8ec10ea",
"metadata": {
"tags": []
},
@@ -1072,7 +1075,7 @@
},
{
"cell_type": "markdown",
- "id": "8316db9c",
+ "id": "5243c266",
"metadata": {
"lines_to_next_cell": 0,
"tags": []
@@ -1084,16 +1087,20 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "d9fbc729",
+ "id": "11477b53",
"metadata": {},
"outputs": [],
"source": [
"idx = 0\n",
"fig, axs = plt.subplots(1, 4, figsize=(12, 4))\n",
"axs[0].imshow(x[idx].cpu().permute(1, 2, 0).detach().numpy())\n",
+ "axs[0].set_title(\"Input image\")\n",
"axs[1].imshow(x_style[idx].cpu().permute(1, 2, 0).detach().numpy())\n",
+ "axs[1].set_title(\"Style image\")\n",
"axs[2].imshow(x_fake[idx].cpu().permute(1, 2, 0).detach().numpy())\n",
+ "axs[2].set_title(\"Generated image\")\n",
"axs[3].imshow(x_cycled[idx].cpu().permute(1, 2, 0).detach().numpy())\n",
+ "axs[3].set_title(\"Cycled image\")\n",
"\n",
"for ax in axs:\n",
" ax.axis(\"off\")\n",
@@ -1103,7 +1110,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "e2a81fb3",
+ "id": "9d8b0179",
"metadata": {
"lines_to_next_cell": 0
},
@@ -1112,7 +1119,7 @@
},
{
"cell_type": "markdown",
- "id": "e039a039",
+ "id": "bc36ab42",
"metadata": {
"tags": []
},
@@ -1128,7 +1135,7 @@
},
{
"cell_type": "markdown",
- "id": "7f4210fd",
+ "id": "35e6b13d",
"metadata": {
"tags": []
},
@@ -1138,7 +1145,7 @@
},
{
"cell_type": "markdown",
- "id": "faf3eac1",
+ "id": "e246771f",
"metadata": {
"tags": []
},
@@ -1155,7 +1162,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "b56b0ac0",
+ "id": "cbb21039",
"metadata": {
"title": "Loading the test dataset"
},
@@ -1175,7 +1182,7 @@
},
{
"cell_type": "markdown",
- "id": "e0ded76f",
+ "id": "88770593",
"metadata": {
"lines_to_next_cell": 0,
"tags": []
@@ -1187,7 +1194,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "fdcd0b4c",
+ "id": "387f7a94",
"metadata": {},
"outputs": [],
"source": [
@@ -1200,7 +1207,7 @@
},
{
"cell_type": "markdown",
- "id": "a0a01596",
+ "id": "67099727",
"metadata": {
"lines_to_next_cell": 0
},
@@ -1210,7 +1217,7 @@
},
{
"cell_type": "markdown",
- "id": "5088af03",
+ "id": "5850a3c5",
"metadata": {
"lines_to_next_cell": 0
},
@@ -1228,7 +1235,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "b4cb730b",
+ "id": "c375ad89",
"metadata": {
"tags": [
"solution"
@@ -1265,7 +1272,7 @@
},
{
"cell_type": "markdown",
- "id": "c87c89df",
+ "id": "049af8ad",
"metadata": {
"lines_to_next_cell": 0,
"tags": []
@@ -1277,17 +1284,20 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "7db32745",
+ "id": "47a3f34c",
"metadata": {},
"outputs": [],
"source": [
"cf_cm = confusion_matrix(target_labels, predictions, normalize=\"true\")\n",
- "sns.heatmap(cf_cm, annot=True, fmt=\".2f\")"
+ "sns.heatmap(cf_cm, annot=True, fmt=\".2f\")\n",
+ "plt.ylabel(\"True\")\n",
+ "plt.xlabel(\"Predicted\")\n",
+ "plt.show()"
]
},
{
"cell_type": "markdown",
- "id": "ed5aafe5",
+ "id": "b3dfc433",
"metadata": {
"tags": []
},
@@ -1302,7 +1312,7 @@
},
{
"cell_type": "markdown",
- "id": "cdba36a8",
+ "id": "64ff01c8",
"metadata": {
"tags": []
},
@@ -1313,7 +1323,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "de504515",
+ "id": "8c55b3fb",
"metadata": {},
"outputs": [],
"source": [
@@ -1327,7 +1337,7 @@
},
{
"cell_type": "markdown",
- "id": "d460f4eb",
+ "id": "708e10ac",
"metadata": {
"tags": []
},
@@ -1342,7 +1352,7 @@
},
{
"cell_type": "markdown",
- "id": "59041d52",
+ "id": "b2eafec3",
"metadata": {
"lines_to_next_cell": 0
},
@@ -1357,10 +1367,8 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "752c4ee3",
- "metadata": {
- "lines_to_next_cell": 1
- },
+ "id": "43209aa2",
+ "metadata": {},
"outputs": [],
"source": [
"batch_size = 4\n",
@@ -1379,9 +1387,8 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "1a401326",
+ "id": "3e969580",
"metadata": {
- "lines_to_next_cell": 1,
"title": "Another visualization function"
},
"outputs": [],
@@ -1409,7 +1416,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "affc1177",
+ "id": "c8b8b46e",
"metadata": {
"lines_to_next_cell": 0
},
@@ -1425,7 +1432,7 @@
},
{
"cell_type": "markdown",
- "id": "194ac43d",
+ "id": "7f80a7f8",
"metadata": {
"lines_to_next_cell": 0
},
@@ -1441,7 +1448,7 @@
},
{
"cell_type": "markdown",
- "id": "f54356bc",
+ "id": "52bdea35",
"metadata": {
"lines_to_next_cell": 0
},
@@ -1456,7 +1463,7 @@
},
{
"cell_type": "markdown",
- "id": "473e32d8",
+ "id": "f0d787ae",
"metadata": {
"lines_to_next_cell": 0
},
@@ -1478,22 +1485,20 @@
]
},
{
- "cell_type": "code",
- "execution_count": null,
- "id": "0d29cfae",
+ "cell_type": "markdown",
+ "id": "39d99dfb",
"metadata": {},
- "outputs": [],
"source": [
- "# Task 6.1: Explore the style space\n",
- "# Let's take a look at the style space.\n",
- "# We will use the style encoder to encode the style of the images and then use PCA to visualize it.\n",
- "# "
+ "Task 5.1: Explore the style space\n",
+ "Let's take a look at the style space.\n",
+ "We will use the style encoder to encode the style of the images and then use PCA to visualize it.\n",
+ ""
]
},
{
"cell_type": "code",
"execution_count": null,
- "id": "3d2d9a2d",
+ "id": "dab223d8",
"metadata": {},
"outputs": [],
"source": [
@@ -1513,20 +1518,22 @@
"styles_pca = pca.fit_transform(styles)\n",
"\n",
"# Plot the PCA\n",
+ "markers = [\"o\", \"s\", \"P\", \"^\"]\n",
"plt.figure(figsize=(10, 10))\n",
"for i in range(4):\n",
" plt.scatter(\n",
" styles_pca[np.array(labels) == i, 0],\n",
" styles_pca[np.array(labels) == i, 1],\n",
+ " marker=markers[i],\n",
" label=f\"Class {i}\",\n",
" )\n",
- "\n",
+ "plt.legend()\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
- "id": "29cd4445",
+ "id": "d6ab7be4",
"metadata": {
"lines_to_next_cell": 0
},
@@ -1542,7 +1549,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "d9a0f9e5",
+ "id": "e678d3af",
"metadata": {
"lines_to_next_cell": 0
},
@@ -1555,17 +1562,21 @@
"\n",
"# Plot the PCA again!\n",
"plt.figure(figsize=(10, 10))\n",
- "plt.scatter(\n",
- " styles_pca[:, 0],\n",
- " styles_pca[:, 1],\n",
- " c=normalized_styles,\n",
- ")\n",
+ "for i in range(4):\n",
+ " plt.scatter(\n",
+ " styles_pca[np.array(labels) == i, 0],\n",
+ " styles_pca[np.array(labels) == i, 1],\n",
+ " c=normalized_styles[np.array(labels) == i],\n",
+ " marker=markers[i],\n",
+ " label=f\"Class {i}\",\n",
+ " )\n",
+ "plt.legend()\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
- "id": "f508f4cc",
+ "id": "06b219da",
"metadata": {
"lines_to_next_cell": 0
},
@@ -1579,7 +1590,7 @@
},
{
"cell_type": "markdown",
- "id": "31527df5",
+ "id": "585cf589",
"metadata": {
"lines_to_next_cell": 0
},
@@ -1596,30 +1607,30 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "94b216f2",
- "metadata": {
- "tags": [
- "solution"
- ]
- },
+ "id": "164eb5e1",
+ "metadata": {},
"outputs": [],
"source": [
- "colors = [np.max(x.numpy(), axis=(1, 2)) for x, _ in random_test_mnist]\n",
+ "colors = np.array([np.max(x.numpy(), axis=(1, 2)) for x, _ in random_test_mnist])\n",
"\n",
"# Plot the PCA again!\n",
"plt.figure(figsize=(10, 10))\n",
- "plt.scatter(\n",
- " styles_pca[:, 0],\n",
- " styles_pca[:, 1],\n",
- " c=colors,\n",
- ")\n",
+ "for i in range(4):\n",
+ " plt.scatter(\n",
+ " styles_pca[np.array(labels) == i, 0],\n",
+ " styles_pca[np.array(labels) == i, 1],\n",
+ " c=colors[np.array(labels) == i],\n",
+ " marker=markers[i],\n",
+ " label=f\"Class {i}\",\n",
+ " )\n",
+ "plt.legend()\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
- "id": "2f77a1be",
+ "id": "9cbe1f3b",
"metadata": {
"lines_to_next_cell": 0
},
@@ -1628,7 +1639,7 @@
},
{
"cell_type": "markdown",
- "id": "06b8ef1a",
+ "id": "0bcd9514",
"metadata": {},
"source": [
"Questions\n",
@@ -1640,7 +1651,7 @@
},
{
"cell_type": "markdown",
- "id": "a3953322",
+ "id": "20be93cd",
"metadata": {},
"source": [
"Checkpoint 5\n",
@@ -1658,7 +1669,7 @@
},
{
"cell_type": "markdown",
- "id": "4c2eb6f3",
+ "id": "9d8664fd",
"metadata": {
"lines_to_next_cell": 0,
"tags": [
@@ -1673,7 +1684,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "b972033b",
+ "id": "fe781dd6",
"metadata": {
"tags": [
"solution"
diff --git a/solution.py b/solution.py
index 318c44e..77e0155 100644
--- a/solution.py
+++ b/solution.py
@@ -109,24 +109,28 @@
cm = confusion_matrix(labels, predictions, normalize="true")
sns.heatmap(cm, annot=True, fmt=".2f")
-
+plt.ylabel("True")
+plt.xlabel("Predicted")
+plt.show()
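The `normalize="true"` option divides each row of the confusion matrix by the number of true samples in that class, so entry (i, j) is the fraction of class-i samples predicted as class j. A minimal sketch with toy labels (hypothetical values, not the notebook's predictions):

```python
import numpy as np
from sklearn.metrics import confusion_matrix

# Toy labels for 4 classes, with some deliberate confusions (hypothetical data)
labels = [0, 0, 1, 1, 2, 2, 3, 3]
predictions = [0, 1, 1, 1, 2, 3, 3, 3]

# normalize="true" makes each row a per-true-class rate
cm = confusion_matrix(labels, predictions, normalize="true")

# Each row sums to 1 (up to float error)
print(cm.sum(axis=1))  # -> [1. 1. 1. 1.]
```

Reading axis labels off a normalized heatmap is why the `plt.ylabel("True")` / `plt.xlabel("Predicted")` lines above are worth adding: rows are true classes, columns are predictions.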
# %% [markdown]
# # Part 2: Using Integrated Gradients to find what the classifier knows
#
-# In this section we will make a first attempt at highlight differences between the "real" and "fake" images that are most important to change the decision of the classifier.
+# In this section we will make a first attempt at highlighting differences between the "real" and "fake" images that are most important to change the decision of the classifier.
#
# %% [markdown]
# ## Attributions through integrated gradients
#
-# Attribution is the process of finding out, based on the output of a neural network, which pixels in the input are (most) responsible. Another way of thinking about it is: which pixels would need to change in order for the network's output to change.
+# Attribution is the process of finding out, based on the output of a neural network, which pixels in the input are (most) responsible for the output. Another way of thinking about it is: which pixels would need to change in order for the network's output to change.
#
# Here we will look at an example of an attribution method called [Integrated Gradients](https://captum.ai/docs/extension/integrated_gradients). If you have a bit of time, have a look at this [super fun exploration of attribution methods](https://distill.pub/2020/attribution-baselines/), especially the explanations on Integrated Gradients.
# %% tags=[]
batch_size = 4
-batch = [mnist[i] for i in range(batch_size)]
+batch = []
+for i in range(4):
+ batch.append(next(image for image in mnist if image[1] == i))
x = torch.stack([b[0] for b in batch])
y = torch.tensor([b[1] for b in batch])
x = x.to(device)
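Integrated Gradients satisfies a "completeness" axiom: the attributions sum to f(input) − f(baseline). For a linear model the path integral can be computed in closed form, which makes a compact sanity check; the weights below are hypothetical and not tied to the notebook's classifier:

```python
import numpy as np

# A toy linear "network": f(x) = w . x  (hypothetical weights)
w = np.array([0.5, -1.0, 2.0])
f = lambda x: float(w @ x)

x = np.array([1.0, 2.0, 3.0])  # input
b = np.zeros_like(x)           # all-zero baseline, the common default

# For a linear model the Integrated Gradients integral is exact:
# IG_i = (x_i - b_i) * w_i
attributions = (x - b) * w

# Completeness: attributions sum to f(x) - f(b)
print(attributions.sum(), f(x) - f(b))  # both 4.5
```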
@@ -193,7 +197,8 @@ def visualize_attribution(attribution, original_image):
# %% tags=[]
-for attr, im in zip(attributions, x.cpu().numpy()):
+for attr, im, lbl in zip(attributions, x.cpu().numpy(), y.cpu().numpy()):
+ print(f"Class {lbl}")
visualize_attribution(attr, im)
# %% [markdown]
@@ -223,7 +228,8 @@ def visualize_color_attribution(attribution, original_image):
plt.show()
-for attr, im in zip(attributions, x.cpu().numpy()):
+for attr, im, lbl in zip(attributions, x.cpu().numpy(), y.cpu().numpy()):
+ print(f"Class {lbl}")
visualize_color_attribution(attr, im)
# %% [markdown]
@@ -234,7 +240,7 @@ def visualize_color_attribution(attribution, original_image):
# If we didn't know in advance, it would be unclear whether the color or the number is the most important feature for the classifier.
# %% [markdown]
#
-# ### Changing the basline
+# ### Changing the baseline
#
# Many existing attribution algorithms are comparative: they show which pixels of the input are responsible for a network output *compared to a baseline*.
# The baseline is often set to an all 0 tensor, but the choice of the baseline affects the output.
@@ -248,7 +254,7 @@ def visualize_color_attribution(attribution, original_image):
# ```
# To get more details about how to include the baseline.
#
-# Try using the code above to change the baseline and see how this affects the output.
+# Try using the code below to change the baseline and see how this affects the output.
#
# 1. Random noise as a baseline
# 2. A blurred/noisy version of the original image as a baseline.
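To see why the baseline matters at all, note that even for a fixed input and a fixed model, changing the baseline changes the attributions. A sketch with the same hypothetical linear model as above (exact IG, no sampling):

```python
import numpy as np

w = np.array([0.5, -1.0, 2.0])  # hypothetical linear model f(x) = w . x
x = np.array([1.0, 2.0, 3.0])

def ig_linear(x, baseline):
    # Exact Integrated Gradients for a linear model
    return (x - baseline) * w

rng = np.random.default_rng(0)
attr_zero = ig_linear(x, np.zeros_like(x))  # all-zero baseline
attr_noise = ig_linear(x, rng.random(3))    # random-noise baseline

# Same input, same model, but a different story about which
# features are "responsible"
print(attr_zero)
print(attr_noise)
```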
@@ -266,7 +272,8 @@ def visualize_color_attribution(attribution, original_image):
attributions_random = integrated_gradients.attribute(...) # TODO Change
# Plotting
-for attr, im in zip(attributions_random.cpu().numpy(), x.cpu().numpy()):
+for attr, im, lbl in zip(attributions_random.cpu().numpy(), x.cpu().numpy(), y.cpu().numpy()):
+ print(f"Class {lbl}")
visualize_attribution(attr, im)
# %% tags=["solution"]
@@ -281,7 +288,8 @@ def visualize_color_attribution(attribution, original_image):
)
# Plotting
-for attr, im in zip(attributions_random.cpu().numpy(), x.cpu().numpy()):
+for attr, im, lbl in zip(attributions_random.cpu().numpy(), x.cpu().numpy(), y.cpu().numpy()):
+ print(f"Class {lbl}")
visualize_color_attribution(attr, im)
# %% [markdown] tags=[]
@@ -299,7 +307,8 @@ def visualize_color_attribution(attribution, original_image):
attributions_blurred = integrated_gradients.attribute(...) # TODO Fill
# Plotting
-for attr, im in zip(attributions_blurred.cpu().numpy(), x.cpu().numpy()):
+for attr, im, lbl in zip(attributions_blurred.cpu().numpy(), x.cpu().numpy(), y.cpu().numpy()):
+ print(f"Class {lbl}")
visualize_color_attribution(attr, im)
# %% tags=["solution"]
@@ -316,7 +325,8 @@ def visualize_color_attribution(attribution, original_image):
)
# Plotting
-for attr, im in zip(attributions_blurred.cpu().numpy(), x.cpu().numpy()):
+for attr, im, lbl in zip(attributions_blurred.cpu().numpy(), x.cpu().numpy(), y.cpu().numpy()):
+ print(f"Class {lbl}")
visualize_color_attribution(attr, im)
# %% [markdown] tags=[]
@@ -355,7 +365,7 @@ def visualize_color_attribution(attribution, original_image):
# %% [markdown]
# # Part 3: Train a GAN to Translate Images
#
-# To gain insight into how the trained network classify images, we will use [Discriminative Attribution from Counterfactuals](https://arxiv.org/abs/2109.13412), a feature attribution with counterfactual explanations methodology.
+# To gain insight into how the trained network classifies images, we will use [Discriminative Attribution from Counterfactuals](https://arxiv.org/abs/2109.13412), a feature attribution with counterfactual explanations methodology.
# This method employs a StarGAN to translate images from one class to another to make counterfactual explanations.
#
# **What is a counterfactual?**
@@ -502,6 +512,7 @@ def forward(self, x, y):
mnist, batch_size=32, drop_last=True, shuffle=True
) # We will use the same dataset as before
+
# %% [markdown] tags=[]
# As we stated earlier, when working with a GAN it is important to control when each network is being trained.
# Indeed, if we update the weights at the same time, we may lose the adversarial aspect of the training altogether, with information leaking into the generator or discriminator causing them to collaborate when they should be competing!
@@ -512,6 +523,7 @@ def set_requires_grad(module, value=True):
for param in module.parameters():
param.requires_grad = value
+
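The helper above exists so that one network can be frozen while the other takes an optimizer step. A minimal sketch of the alternation with two stand-in modules (assuming PyTorch; the real training loop in the exercise also involves the cycle and style losses):

```python
import torch

def set_requires_grad(module, value=True):
    """Set requires_grad on all parameters of a torch.nn.Module."""
    for param in module.parameters():
        param.requires_grad = value

# Two stand-in networks, purely for illustration
generator = torch.nn.Linear(4, 4)
discriminator = torch.nn.Linear(4, 1)

# Generator step: only the generator should accumulate gradients
set_requires_grad(generator, True)
set_requires_grad(discriminator, False)

# Discriminator step: flip the roles
set_requires_grad(generator, False)
set_requires_grad(discriminator, True)
```

Freezing via `requires_grad` (rather than skipping the optimizer step) also avoids wasting memory on gradients for the frozen network.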
# %% [markdown] tags=[]
# Another consequence of adversarial training is that it is very unstable.
# While this instability is what leads to finding the best possible solution (which in the case of GANs is on a saddle point), it can also make it difficult to train the model.
@@ -741,9 +753,13 @@ def copy_parameters(source_model, target_model):
idx = 0
fig, axs = plt.subplots(1, 4, figsize=(12, 4))
axs[0].imshow(x[idx].cpu().permute(1, 2, 0).detach().numpy())
+axs[0].set_title("Input image")
axs[1].imshow(x_style[idx].cpu().permute(1, 2, 0).detach().numpy())
+axs[1].set_title("Style image")
axs[2].imshow(x_fake[idx].cpu().permute(1, 2, 0).detach().numpy())
+axs[2].set_title("Generated image")
axs[3].imshow(x_cycled[idx].cpu().permute(1, 2, 0).detach().numpy())
+axs[3].set_title("Cycled image")
for ax in axs:
ax.axis("off")
@@ -859,6 +875,9 @@ def copy_parameters(source_model, target_model):
# %%
cf_cm = confusion_matrix(target_labels, predictions, normalize="true")
sns.heatmap(cf_cm, annot=True, fmt=".2f")
+plt.ylabel("True")
+plt.xlabel("Predicted")
+plt.show()
# %% [markdown] tags=[]
# Questions
@@ -907,6 +926,7 @@ def copy_parameters(source_model, target_model):
# Generated attributions on integrated gradients
attributions = integrated_gradients.attribute(x, baselines=x_fake, target=y)
+
# %% Another visualization function
def visualize_color_attribution_and_counterfactual(
attribution, original_image, counterfactual_image
@@ -927,6 +947,7 @@ def visualize_color_attribution_and_counterfactual(
ax2.axis("off")
plt.show()
+
# %%
for idx in range(batch_size):
print("Source class:", y[idx].item())
@@ -965,8 +986,8 @@ def visualize_color_attribution_and_counterfactual(
#
# So color is important... but not always? What's going on!?
# There is a final piece of information that we can use to solve the puzzle: the style space.
-# %%
-# Task 6.1: Explore the style space
+# %% [markdown]
+# Task 5.1: Explore the style space
# Let's take a look at the style space.
# We will use the style encoder to encode the style of the images and then use PCA to visualize it.
#
@@ -988,14 +1009,16 @@ def visualize_color_attribution_and_counterfactual(
styles_pca = pca.fit_transform(styles)
# Plot the PCA
+markers = ["o", "s", "P", "^"]
plt.figure(figsize=(10, 10))
for i in range(4):
plt.scatter(
styles_pca[np.array(labels) == i, 0],
styles_pca[np.array(labels) == i, 1],
+ marker=markers[i],
label=f"Class {i}",
)
-
+plt.legend()
plt.show()
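The style-space plot boils down to: stack the style vectors, reduce to 2 components with PCA, scatter each class with its own marker. A self-contained sketch with random stand-in "style" embeddings (the marker list matches the one added in this diff):

```python
import numpy as np
import matplotlib

matplotlib.use("Agg")  # headless-safe backend
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA

rng = np.random.default_rng(42)
# Stand-in style embeddings: 4 classes x 25 samples, 64-dimensional
styles = rng.normal(size=(100, 64))
labels = np.repeat(np.arange(4), 25)

styles_pca = PCA(n_components=2).fit_transform(styles)

markers = ["o", "s", "P", "^"]  # one marker per class
plt.figure(figsize=(6, 6))
for i in range(4):
    mask = labels == i
    plt.scatter(
        styles_pca[mask, 0],
        styles_pca[mask, 1],
        marker=markers[i],
        label=f"Class {i}",
    )
plt.legend()
plt.savefig("styles_pca.png")
```

Distinct markers matter here because the later cells re-color the points by style intensity, at which point color can no longer encode the class.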
# %% [markdown]
@@ -1013,11 +1036,15 @@ def visualize_color_attribution_and_counterfactual(
# Plot the PCA again!
plt.figure(figsize=(10, 10))
-plt.scatter(
- styles_pca[:, 0],
- styles_pca[:, 1],
- c=normalized_styles,
-)
+for i in range(4):
+ plt.scatter(
+ styles_pca[np.array(labels) == i, 0],
+ styles_pca[np.array(labels) == i, 1],
+ c=normalized_styles[np.array(labels) == i],
+ marker=markers[i],
+ label=f"Class {i}",
+ )
+plt.legend()
plt.show()
# %% [markdown]
# Questions
@@ -1033,16 +1060,20 @@ def visualize_color_attribution_and_counterfactual(
# Let's get that color, then plot the style space again.
# (Note: once again, no coding needed here, just run the cell and think about the results with the questions below)
#
-# %% tags=["solution"]
-colors = [np.max(x.numpy(), axis=(1, 2)) for x, _ in random_test_mnist]
+# %%
+colors = np.array([np.max(x.numpy(), axis=(1, 2)) for x, _ in random_test_mnist])
# Plot the PCA again!
plt.figure(figsize=(10, 10))
-plt.scatter(
- styles_pca[:, 0],
- styles_pca[:, 1],
- c=colors,
-)
+for i in range(4):
+ plt.scatter(
+ styles_pca[np.array(labels) == i, 0],
+ styles_pca[np.array(labels) == i, 1],
+ c=colors[np.array(labels) == i],
+ marker=markers[i],
+ label=f"Class {i}",
+ )
+plt.legend()
plt.show()
# %%