\n",
@@ -529,7 +529,7 @@
},
{
"cell_type": "markdown",
- "id": "31c83033",
+ "id": "0b2d0f2f",
"metadata": {},
"source": [
"
Checkpoint 2
\n",
@@ -549,7 +549,7 @@
},
{
"cell_type": "markdown",
- "id": "12b2601b",
+ "id": "531169e5",
"metadata": {
"lines_to_next_cell": 0
},
@@ -577,7 +577,7 @@
},
{
"cell_type": "markdown",
- "id": "35efae25",
+ "id": "331e56d6",
"metadata": {
"lines_to_next_cell": 0,
"tags": []
@@ -600,7 +600,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "55ba1040",
+ "id": "301ee289",
"metadata": {},
"outputs": [],
"source": [
@@ -632,7 +632,7 @@
},
{
"cell_type": "markdown",
- "id": "81ba7c71",
+ "id": "4ce023f6",
"metadata": {
"lines_to_next_cell": 0
},
@@ -647,7 +647,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "e17fa41b",
+ "id": "c2698719",
"metadata": {
"lines_to_next_cell": 0,
"tags": [
@@ -668,7 +668,7 @@
},
{
"cell_type": "markdown",
- "id": "acc2feba",
+ "id": "16f87104",
"metadata": {
"tags": []
},
@@ -683,7 +683,7 @@
},
{
"cell_type": "markdown",
- "id": "a482f224",
+ "id": "9f1d1149",
"metadata": {
"lines_to_next_cell": 0,
"tags": []
@@ -700,7 +700,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "f35e34ba",
+ "id": "14e0c929",
"metadata": {
"lines_to_next_cell": 0,
"tags": [
@@ -714,7 +714,7 @@
},
{
"cell_type": "markdown",
- "id": "19fdd0a9",
+ "id": "231a5202",
"metadata": {
"lines_to_next_cell": 0
},
@@ -725,7 +725,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "7b75bdd6",
+ "id": "c0a2d54d",
"metadata": {},
"outputs": [],
"source": [
@@ -735,7 +735,7 @@
},
{
"cell_type": "markdown",
- "id": "b1dedf50",
+ "id": "4540ef18",
"metadata": {
"lines_to_next_cell": 0,
"tags": []
@@ -753,7 +753,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "fe716560",
+ "id": "b9fc6671",
"metadata": {
"lines_to_next_cell": 0
},
@@ -765,7 +765,7 @@
},
{
"cell_type": "markdown",
- "id": "a1c9bca2",
+ "id": "196daf45",
"metadata": {
"lines_to_next_cell": 0,
"tags": []
@@ -784,7 +784,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "d02a2f0c",
+ "id": "1e9ddd12",
"metadata": {},
"outputs": [],
"source": [
@@ -793,7 +793,7 @@
},
{
"cell_type": "markdown",
- "id": "2f5f91ed",
+ "id": "eade7df1",
"metadata": {
"lines_to_next_cell": 0,
"tags": []
@@ -809,7 +809,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "5bef17c0",
+ "id": "1deb8b8b",
"metadata": {},
"outputs": [],
"source": [
@@ -818,7 +818,7 @@
},
{
"cell_type": "markdown",
- "id": "b8feb471",
+ "id": "ba4a7f7f",
"metadata": {
"lines_to_next_cell": 0,
"tags": []
@@ -830,7 +830,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "1bc928e9",
+ "id": "b5b3d5dc",
"metadata": {},
"outputs": [],
"source": [
@@ -843,7 +843,7 @@
},
{
"cell_type": "markdown",
- "id": "c0a1a77c",
+ "id": "a029e923",
"metadata": {
"lines_to_next_cell": 0,
"tags": []
@@ -857,7 +857,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "d7a41c68",
+ "id": "54b4de87",
"metadata": {},
"outputs": [],
"source": [
@@ -869,7 +869,7 @@
},
{
"cell_type": "markdown",
- "id": "7ff74b67",
+ "id": "014e484e",
"metadata": {
"lines_to_next_cell": 0,
"tags": []
@@ -889,7 +889,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "2002bdc0",
+ "id": "f6344c83",
"metadata": {},
"outputs": [],
"source": [
@@ -913,7 +913,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "e5303510",
+ "id": "08b7b3af",
"metadata": {},
"outputs": [],
"source": [
@@ -923,7 +923,7 @@
},
{
"cell_type": "markdown",
- "id": "28bd8680",
+ "id": "23fbf680",
"metadata": {
"lines_to_next_cell": 0,
"tags": []
@@ -945,7 +945,7 @@
},
{
"cell_type": "markdown",
- "id": "7fbe2fd9",
+ "id": "9cb8281d",
"metadata": {
"lines_to_next_cell": 0,
"tags": []
@@ -957,7 +957,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "8bb28524",
+ "id": "3b01306d",
"metadata": {
"lines_to_next_cell": 0,
"tags": [
@@ -1068,7 +1068,7 @@
},
{
"cell_type": "markdown",
- "id": "a540a4d6",
+ "id": "4c25819b",
"metadata": {
"lines_to_next_cell": 0,
"tags": []
@@ -1080,20 +1080,23 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "9b8fa0a1",
+ "id": "0d64d32d",
"metadata": {},
"outputs": [],
"source": [
- "plt.plot(losses[\"cycle\"], label=\"Cycle loss\")\n",
- "plt.plot(losses[\"adv\"], label=\"Adversarial loss\")\n",
- "plt.plot(losses[\"disc\"], label=\"Discriminator loss\")\n",
- "plt.legend()\n",
+ "fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 5))\n",
+ "ax1.plot(losses[\"cycle\"])\n",
+ "ax1.set_title(\"Cycle loss\")\n",
+ "ax2.plot(losses[\"adv\"])\n",
+ "ax2.set_title(\"Adversarial loss\")\n",
+ "ax3.plot(losses[\"disc\"])\n",
+ "ax3.set_title(\"Discriminator loss\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
- "id": "f42e89a9",
+ "id": "326ba2b5",
"metadata": {
"tags": []
},
@@ -1108,7 +1111,7 @@
},
{
"cell_type": "markdown",
- "id": "a34b2f4d",
+ "id": "3e58ca01",
"metadata": {
"lines_to_next_cell": 0,
"tags": []
@@ -1120,7 +1123,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "810e8d6e",
+ "id": "1c522efa",
"metadata": {},
"outputs": [],
"source": [
@@ -1143,7 +1146,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "a3931621",
+ "id": "30b6dac9",
"metadata": {
"lines_to_next_cell": 0
},
@@ -1152,7 +1155,7 @@
},
{
"cell_type": "markdown",
- "id": "910d5ed6",
+ "id": "a3ecbc7b",
"metadata": {
"tags": []
},
@@ -1168,7 +1171,7 @@
},
{
"cell_type": "markdown",
- "id": "d75728f1",
+ "id": "e6bdaecb",
"metadata": {
"tags": []
},
@@ -1178,7 +1181,7 @@
},
{
"cell_type": "markdown",
- "id": "46ac6b2d",
+ "id": "7f994579",
"metadata": {
"tags": []
},
@@ -1195,13 +1198,13 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "3541f664",
+ "id": "4e4fe83e",
"metadata": {
"title": "Loading the test dataset"
},
"outputs": [],
"source": [
- "test_mnist = ColoredMNIST(\"data\", download=True, train=False)\n",
+ "test_mnist = ColoredMNIST(\"extras/data\", download=True, train=False)\n",
"prototypes = {}\n",
"\n",
"\n",
@@ -1215,7 +1218,7 @@
},
{
"cell_type": "markdown",
- "id": "d8d02278",
+ "id": "049a6b22",
"metadata": {
"lines_to_next_cell": 0,
"tags": []
@@ -1227,7 +1230,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "220450b4",
+ "id": "639f37e2",
"metadata": {},
"outputs": [],
"source": [
@@ -1240,7 +1243,7 @@
},
{
"cell_type": "markdown",
- "id": "d7c8d8a8",
+ "id": "02cb705b",
"metadata": {
"lines_to_next_cell": 0
},
@@ -1250,7 +1253,7 @@
},
{
"cell_type": "markdown",
- "id": "f607ce7c",
+ "id": "f41a6ce5",
"metadata": {
"lines_to_next_cell": 0
},
@@ -1268,7 +1271,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "10b77d39",
+ "id": "282f8858",
"metadata": {
"lines_to_next_cell": 0,
"tags": [
@@ -1304,7 +1307,7 @@
},
{
"cell_type": "markdown",
- "id": "95379712",
+ "id": "ebffc15f",
"metadata": {
"lines_to_next_cell": 0,
"tags": []
@@ -1316,7 +1319,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "df4f63b4",
+ "id": "baac8071",
"metadata": {},
"outputs": [],
"source": [
@@ -1329,7 +1332,7 @@
},
{
"cell_type": "markdown",
- "id": "f7dd387e",
+ "id": "88e7ea0c",
"metadata": {
"tags": []
},
@@ -1344,7 +1347,7 @@
},
{
"cell_type": "markdown",
- "id": "bfeaf7d1",
+ "id": "25972c49",
"metadata": {
"tags": []
},
@@ -1355,7 +1358,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "9dec938b",
+ "id": "12d49576",
"metadata": {},
"outputs": [],
"source": [
@@ -1369,7 +1372,7 @@
},
{
"cell_type": "markdown",
- "id": "bbcf6338",
+ "id": "8e6f04f3",
"metadata": {
"tags": []
},
@@ -1384,7 +1387,7 @@
},
{
"cell_type": "markdown",
- "id": "866b85d4",
+ "id": "50728ff2",
"metadata": {
"lines_to_next_cell": 0
},
@@ -1399,7 +1402,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "2c3bd150",
+ "id": "dedc0f83",
"metadata": {},
"outputs": [],
"source": [
@@ -1419,7 +1422,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "a6c9d35d",
+ "id": "5446e796",
"metadata": {
"title": "Another visualization function"
},
@@ -1448,7 +1451,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "355d3691",
+ "id": "5e2fb59e",
"metadata": {
"lines_to_next_cell": 0
},
@@ -1464,7 +1467,7 @@
},
{
"cell_type": "markdown",
- "id": "d3717907",
+ "id": "b393a8f1",
"metadata": {
"lines_to_next_cell": 0
},
@@ -1480,7 +1483,7 @@
},
{
"cell_type": "markdown",
- "id": "4063399b",
+ "id": "5ba47fc6",
"metadata": {
"lines_to_next_cell": 0
},
@@ -1495,7 +1498,7 @@
},
{
"cell_type": "markdown",
- "id": "587f4083",
+ "id": "2654d788",
"metadata": {
"lines_to_next_cell": 0
},
@@ -1518,7 +1521,7 @@
},
{
"cell_type": "markdown",
- "id": "499c184e",
+ "id": "76559366",
"metadata": {},
"source": [
"
Task 5.1: Explore the style space
\n",
@@ -1530,7 +1533,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "09065024",
+ "id": "f1fdb890",
"metadata": {},
"outputs": [],
"source": [
@@ -1565,7 +1568,7 @@
},
{
"cell_type": "markdown",
- "id": "d6f40f81",
+ "id": "b666769e",
"metadata": {
"lines_to_next_cell": 0
},
@@ -1581,7 +1584,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "28f9efd8",
+ "id": "e61d0c9b",
"metadata": {
"lines_to_next_cell": 0
},
@@ -1608,7 +1611,7 @@
},
{
"cell_type": "markdown",
- "id": "35eb9e2b",
+ "id": "6f1d3ff3",
"metadata": {
"lines_to_next_cell": 0
},
@@ -1622,7 +1625,7 @@
},
{
"cell_type": "markdown",
- "id": "b7e631b9",
+ "id": "90889399",
"metadata": {
"lines_to_next_cell": 0
},
@@ -1639,7 +1642,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "d7bf9f03",
+ "id": "f67b3f90",
"metadata": {},
"outputs": [],
"source": [
@@ -1662,7 +1665,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "6f2fa456",
+ "id": "b18b2b81",
"metadata": {
"lines_to_next_cell": 0
},
@@ -1671,7 +1674,7 @@
},
{
"cell_type": "markdown",
- "id": "4c030783",
+ "id": "bf87e80b",
"metadata": {},
"source": [
"
Questions
\n",
@@ -1683,7 +1686,7 @@
},
{
"cell_type": "markdown",
- "id": "392618f7",
+ "id": "11aafcc5",
"metadata": {},
"source": [
"
Checkpoint 5
\n",
diff --git a/solution.ipynb b/solution.ipynb
index a345b23..b0b9e5a 100644
--- a/solution.ipynb
+++ b/solution.ipynb
@@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "markdown",
- "id": "cabeeff7",
+ "id": "30c11df5",
"metadata": {
"lines_to_next_cell": 0,
"tags": []
@@ -29,7 +29,7 @@
},
{
"cell_type": "markdown",
- "id": "a6549d6e",
+ "id": "ec2899d4",
"metadata": {
"lines_to_next_cell": 0
},
@@ -41,7 +41,7 @@
},
{
"cell_type": "markdown",
- "id": "af277573",
+ "id": "2c084b97",
"metadata": {},
"source": [
"\n",
@@ -54,7 +54,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "d133ee66",
+ "id": "9d26a8bb",
"metadata": {
"lines_to_next_cell": 0
},
@@ -63,12 +63,12 @@
"# loading the data\n",
"from classifier.data import ColoredMNIST\n",
"\n",
- "mnist = ColoredMNIST(\"data\", download=True)"
+ "mnist = ColoredMNIST(\"extras/data\", download=True)"
]
},
{
"cell_type": "markdown",
- "id": "7bf9a7d1",
+ "id": "f8a5937c",
"metadata": {
"lines_to_next_cell": 0
},
@@ -84,7 +84,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "0d4c5c7f",
+ "id": "9c0ce960",
"metadata": {},
"outputs": [],
"source": [
@@ -102,7 +102,7 @@
},
{
"cell_type": "markdown",
- "id": "4189496b",
+ "id": "0cb834e5",
"metadata": {
"lines_to_next_cell": 0
},
@@ -113,7 +113,7 @@
},
{
"cell_type": "markdown",
- "id": "ec85ffc9",
+ "id": "a32035d7",
"metadata": {
"lines_to_next_cell": 0
},
@@ -130,7 +130,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "bbb01724",
+ "id": "0146821b",
"metadata": {
"tags": [
"solution"
@@ -154,7 +154,7 @@
},
{
"cell_type": "markdown",
- "id": "ebf14527",
+ "id": "6ecddeb8",
"metadata": {
"lines_to_next_cell": 0
},
@@ -165,7 +165,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "0fa46d9a",
+ "id": "c271ecd9",
"metadata": {},
"outputs": [],
"source": [
@@ -173,7 +173,7 @@
"from sklearn.metrics import confusion_matrix\n",
"import seaborn as sns\n",
"\n",
- "test_mnist = ColoredMNIST(\"data\", download=True, train=False)\n",
+ "test_mnist = ColoredMNIST(\"extras/data\", download=True, train=False)\n",
"dataloader = DataLoader(test_mnist, batch_size=32, shuffle=False)\n",
"\n",
"labels = []\n",
@@ -192,7 +192,7 @@
},
{
"cell_type": "markdown",
- "id": "35845bc8",
+ "id": "46a684f4",
"metadata": {},
"source": [
"# Part 2: Using Integrated Gradients to find what the classifier knows\n",
@@ -202,7 +202,7 @@
},
{
"cell_type": "markdown",
- "id": "9d861e84",
+ "id": "0255c073",
"metadata": {},
"source": [
"## Attributions through integrated gradients\n",
@@ -215,7 +215,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "811b9852",
+ "id": "e5b162b7",
"metadata": {
"tags": []
},
@@ -233,7 +233,7 @@
},
{
"cell_type": "markdown",
- "id": "38c2b5f2",
+ "id": "6d418ea1",
"metadata": {
"tags": []
},
@@ -249,7 +249,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "fc427029",
+ "id": "f93e8067",
"metadata": {
"tags": [
"solution"
@@ -273,7 +273,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "422bc189",
+ "id": "e4ba6b3a",
"metadata": {
"tags": []
},
@@ -286,7 +286,7 @@
},
{
"cell_type": "markdown",
- "id": "677d8c4a",
+ "id": "56e432ae",
"metadata": {
"lines_to_next_cell": 2,
"tags": []
@@ -298,7 +298,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "c13d35fb",
+ "id": "9561d46f",
"metadata": {
"tags": []
},
@@ -326,7 +326,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "70a3b3b3",
+ "id": "a55fe8ec",
"metadata": {
"tags": []
},
@@ -339,7 +339,7 @@
},
{
"cell_type": "markdown",
- "id": "916906ac",
+ "id": "1d8c03a0",
"metadata": {
"lines_to_next_cell": 2
},
@@ -353,7 +353,7 @@
},
{
"cell_type": "markdown",
- "id": "00494aec",
+ "id": "2a24c70a",
"metadata": {
"lines_to_next_cell": 0
},
@@ -366,7 +366,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "88c9d18e",
+ "id": "6e875faa",
"metadata": {},
"outputs": [],
"source": [
@@ -391,7 +391,7 @@
},
{
"cell_type": "markdown",
- "id": "2110738d",
+ "id": "3f73608f",
"metadata": {
"lines_to_next_cell": 0
},
@@ -405,7 +405,7 @@
},
{
"cell_type": "markdown",
- "id": "3292fbe5",
+ "id": "a8e71c0b",
"metadata": {},
"source": [
"\n",
@@ -431,7 +431,7 @@
},
{
"cell_type": "markdown",
- "id": "46c075dc",
+ "id": "dbb04b6f",
"metadata": {},
"source": [
"
Task 2.3: Use random noise as a baseline
\n",
@@ -443,7 +443,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "fbf0e8de",
+ "id": "cde2c2ff",
"metadata": {
"tags": [
"solution"
@@ -469,7 +469,7 @@
},
{
"cell_type": "markdown",
- "id": "88239eb5",
+ "id": "bf7e934c",
"metadata": {
"tags": []
},
@@ -483,7 +483,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "7ba51c8b",
+ "id": "a0cb195e",
"metadata": {
"tags": [
"solution"
@@ -511,7 +511,7 @@
},
{
"cell_type": "markdown",
- "id": "3a52e78e",
+ "id": "db46361b",
"metadata": {
"tags": []
},
@@ -527,7 +527,7 @@
},
{
"cell_type": "markdown",
- "id": "bf2263d6",
+ "id": "e9105812",
"metadata": {},
"source": [
"
BONUS Task: Using different attributions.
\n",
@@ -541,7 +541,7 @@
},
{
"cell_type": "markdown",
- "id": "31c83033",
+ "id": "0b2d0f2f",
"metadata": {},
"source": [
"
Checkpoint 2
\n",
@@ -561,7 +561,7 @@
},
{
"cell_type": "markdown",
- "id": "12b2601b",
+ "id": "531169e5",
"metadata": {
"lines_to_next_cell": 0
},
@@ -589,7 +589,7 @@
},
{
"cell_type": "markdown",
- "id": "35efae25",
+ "id": "331e56d6",
"metadata": {
"lines_to_next_cell": 0,
"tags": []
@@ -612,7 +612,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "55ba1040",
+ "id": "301ee289",
"metadata": {},
"outputs": [],
"source": [
@@ -644,7 +644,7 @@
},
{
"cell_type": "markdown",
- "id": "81ba7c71",
+ "id": "4ce023f6",
"metadata": {
"lines_to_next_cell": 0
},
@@ -659,7 +659,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "ded9c5d3",
+ "id": "b491022a",
"metadata": {
"tags": [
"solution"
@@ -676,7 +676,7 @@
},
{
"cell_type": "markdown",
- "id": "acc2feba",
+ "id": "16f87104",
"metadata": {
"tags": []
},
@@ -691,7 +691,7 @@
},
{
"cell_type": "markdown",
- "id": "a482f224",
+ "id": "9f1d1149",
"metadata": {
"lines_to_next_cell": 0,
"tags": []
@@ -708,7 +708,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "d48de07d",
+ "id": "71695d57",
"metadata": {
"lines_to_next_cell": 0,
"tags": [
@@ -722,7 +722,7 @@
},
{
"cell_type": "markdown",
- "id": "19fdd0a9",
+ "id": "231a5202",
"metadata": {
"lines_to_next_cell": 0
},
@@ -733,7 +733,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "7b75bdd6",
+ "id": "c0a2d54d",
"metadata": {},
"outputs": [],
"source": [
@@ -743,7 +743,7 @@
},
{
"cell_type": "markdown",
- "id": "b1dedf50",
+ "id": "4540ef18",
"metadata": {
"lines_to_next_cell": 0,
"tags": []
@@ -761,7 +761,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "fe716560",
+ "id": "b9fc6671",
"metadata": {
"lines_to_next_cell": 0
},
@@ -773,7 +773,7 @@
},
{
"cell_type": "markdown",
- "id": "a1c9bca2",
+ "id": "196daf45",
"metadata": {
"lines_to_next_cell": 0,
"tags": []
@@ -792,7 +792,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "d02a2f0c",
+ "id": "1e9ddd12",
"metadata": {},
"outputs": [],
"source": [
@@ -801,7 +801,7 @@
},
{
"cell_type": "markdown",
- "id": "2f5f91ed",
+ "id": "eade7df1",
"metadata": {
"lines_to_next_cell": 0,
"tags": []
@@ -817,7 +817,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "5bef17c0",
+ "id": "1deb8b8b",
"metadata": {},
"outputs": [],
"source": [
@@ -826,7 +826,7 @@
},
{
"cell_type": "markdown",
- "id": "b8feb471",
+ "id": "ba4a7f7f",
"metadata": {
"lines_to_next_cell": 0,
"tags": []
@@ -838,7 +838,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "1bc928e9",
+ "id": "b5b3d5dc",
"metadata": {},
"outputs": [],
"source": [
@@ -851,7 +851,7 @@
},
{
"cell_type": "markdown",
- "id": "c0a1a77c",
+ "id": "a029e923",
"metadata": {
"lines_to_next_cell": 0,
"tags": []
@@ -865,7 +865,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "d7a41c68",
+ "id": "54b4de87",
"metadata": {},
"outputs": [],
"source": [
@@ -877,7 +877,7 @@
},
{
"cell_type": "markdown",
- "id": "7ff74b67",
+ "id": "014e484e",
"metadata": {
"lines_to_next_cell": 0,
"tags": []
@@ -897,7 +897,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "2002bdc0",
+ "id": "f6344c83",
"metadata": {},
"outputs": [],
"source": [
@@ -921,7 +921,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "e5303510",
+ "id": "08b7b3af",
"metadata": {},
"outputs": [],
"source": [
@@ -931,7 +931,7 @@
},
{
"cell_type": "markdown",
- "id": "28bd8680",
+ "id": "23fbf680",
"metadata": {
"lines_to_next_cell": 0,
"tags": []
@@ -953,7 +953,7 @@
},
{
"cell_type": "markdown",
- "id": "7fbe2fd9",
+ "id": "9cb8281d",
"metadata": {
"lines_to_next_cell": 0,
"tags": []
@@ -965,7 +965,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "e66a1fa9",
+ "id": "699b3220",
"metadata": {
"lines_to_next_cell": 2,
"tags": [
@@ -1035,7 +1035,7 @@
},
{
"cell_type": "markdown",
- "id": "a540a4d6",
+ "id": "4c25819b",
"metadata": {
"lines_to_next_cell": 0,
"tags": []
@@ -1047,20 +1047,23 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "9b8fa0a1",
+ "id": "0d64d32d",
"metadata": {},
"outputs": [],
"source": [
- "plt.plot(losses[\"cycle\"], label=\"Cycle loss\")\n",
- "plt.plot(losses[\"adv\"], label=\"Adversarial loss\")\n",
- "plt.plot(losses[\"disc\"], label=\"Discriminator loss\")\n",
- "plt.legend()\n",
+ "fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 5))\n",
+ "ax1.plot(losses[\"cycle\"])\n",
+ "ax1.set_title(\"Cycle loss\")\n",
+ "ax2.plot(losses[\"adv\"])\n",
+ "ax2.set_title(\"Adversarial loss\")\n",
+ "ax3.plot(losses[\"disc\"])\n",
+ "ax3.set_title(\"Discriminator loss\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
- "id": "f42e89a9",
+ "id": "326ba2b5",
"metadata": {
"tags": []
},
@@ -1075,7 +1078,7 @@
},
{
"cell_type": "markdown",
- "id": "a34b2f4d",
+ "id": "3e58ca01",
"metadata": {
"lines_to_next_cell": 0,
"tags": []
@@ -1087,7 +1090,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "810e8d6e",
+ "id": "1c522efa",
"metadata": {},
"outputs": [],
"source": [
@@ -1110,7 +1113,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "a3931621",
+ "id": "30b6dac9",
"metadata": {
"lines_to_next_cell": 0
},
@@ -1119,7 +1122,7 @@
},
{
"cell_type": "markdown",
- "id": "910d5ed6",
+ "id": "a3ecbc7b",
"metadata": {
"tags": []
},
@@ -1135,7 +1138,7 @@
},
{
"cell_type": "markdown",
- "id": "d75728f1",
+ "id": "e6bdaecb",
"metadata": {
"tags": []
},
@@ -1145,7 +1148,7 @@
},
{
"cell_type": "markdown",
- "id": "46ac6b2d",
+ "id": "7f994579",
"metadata": {
"tags": []
},
@@ -1162,13 +1165,13 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "3541f664",
+ "id": "4e4fe83e",
"metadata": {
"title": "Loading the test dataset"
},
"outputs": [],
"source": [
- "test_mnist = ColoredMNIST(\"data\", download=True, train=False)\n",
+ "test_mnist = ColoredMNIST(\"extras/data\", download=True, train=False)\n",
"prototypes = {}\n",
"\n",
"\n",
@@ -1182,7 +1185,7 @@
},
{
"cell_type": "markdown",
- "id": "d8d02278",
+ "id": "049a6b22",
"metadata": {
"lines_to_next_cell": 0,
"tags": []
@@ -1194,7 +1197,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "220450b4",
+ "id": "639f37e2",
"metadata": {},
"outputs": [],
"source": [
@@ -1207,7 +1210,7 @@
},
{
"cell_type": "markdown",
- "id": "d7c8d8a8",
+ "id": "02cb705b",
"metadata": {
"lines_to_next_cell": 0
},
@@ -1217,7 +1220,7 @@
},
{
"cell_type": "markdown",
- "id": "f607ce7c",
+ "id": "f41a6ce5",
"metadata": {
"lines_to_next_cell": 0
},
@@ -1235,7 +1238,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "3d20c0da",
+ "id": "00616e67",
"metadata": {
"tags": [
"solution"
@@ -1272,7 +1275,7 @@
},
{
"cell_type": "markdown",
- "id": "95379712",
+ "id": "ebffc15f",
"metadata": {
"lines_to_next_cell": 0,
"tags": []
@@ -1284,7 +1287,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "df4f63b4",
+ "id": "baac8071",
"metadata": {},
"outputs": [],
"source": [
@@ -1297,7 +1300,7 @@
},
{
"cell_type": "markdown",
- "id": "f7dd387e",
+ "id": "88e7ea0c",
"metadata": {
"tags": []
},
@@ -1312,7 +1315,7 @@
},
{
"cell_type": "markdown",
- "id": "bfeaf7d1",
+ "id": "25972c49",
"metadata": {
"tags": []
},
@@ -1323,7 +1326,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "9dec938b",
+ "id": "12d49576",
"metadata": {},
"outputs": [],
"source": [
@@ -1337,7 +1340,7 @@
},
{
"cell_type": "markdown",
- "id": "bbcf6338",
+ "id": "8e6f04f3",
"metadata": {
"tags": []
},
@@ -1352,7 +1355,7 @@
},
{
"cell_type": "markdown",
- "id": "866b85d4",
+ "id": "50728ff2",
"metadata": {
"lines_to_next_cell": 0
},
@@ -1367,7 +1370,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "2c3bd150",
+ "id": "dedc0f83",
"metadata": {},
"outputs": [],
"source": [
@@ -1387,7 +1390,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "a6c9d35d",
+ "id": "5446e796",
"metadata": {
"title": "Another visualization function"
},
@@ -1416,7 +1419,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "355d3691",
+ "id": "5e2fb59e",
"metadata": {
"lines_to_next_cell": 0
},
@@ -1432,7 +1435,7 @@
},
{
"cell_type": "markdown",
- "id": "d3717907",
+ "id": "b393a8f1",
"metadata": {
"lines_to_next_cell": 0
},
@@ -1448,7 +1451,7 @@
},
{
"cell_type": "markdown",
- "id": "4063399b",
+ "id": "5ba47fc6",
"metadata": {
"lines_to_next_cell": 0
},
@@ -1463,7 +1466,7 @@
},
{
"cell_type": "markdown",
- "id": "587f4083",
+ "id": "2654d788",
"metadata": {
"lines_to_next_cell": 0
},
@@ -1486,7 +1489,7 @@
},
{
"cell_type": "markdown",
- "id": "499c184e",
+ "id": "76559366",
"metadata": {},
"source": [
"
Task 5.1: Explore the style space
\n",
@@ -1498,7 +1501,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "09065024",
+ "id": "f1fdb890",
"metadata": {},
"outputs": [],
"source": [
@@ -1533,7 +1536,7 @@
},
{
"cell_type": "markdown",
- "id": "d6f40f81",
+ "id": "b666769e",
"metadata": {
"lines_to_next_cell": 0
},
@@ -1549,7 +1552,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "28f9efd8",
+ "id": "e61d0c9b",
"metadata": {
"lines_to_next_cell": 0
},
@@ -1576,7 +1579,7 @@
},
{
"cell_type": "markdown",
- "id": "35eb9e2b",
+ "id": "6f1d3ff3",
"metadata": {
"lines_to_next_cell": 0
},
@@ -1590,7 +1593,7 @@
},
{
"cell_type": "markdown",
- "id": "b7e631b9",
+ "id": "90889399",
"metadata": {
"lines_to_next_cell": 0
},
@@ -1607,7 +1610,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "d7bf9f03",
+ "id": "f67b3f90",
"metadata": {},
"outputs": [],
"source": [
@@ -1630,7 +1633,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "6f2fa456",
+ "id": "b18b2b81",
"metadata": {
"lines_to_next_cell": 0
},
@@ -1639,7 +1642,7 @@
},
{
"cell_type": "markdown",
- "id": "4c030783",
+ "id": "bf87e80b",
"metadata": {},
"source": [
"
Questions
\n",
@@ -1651,7 +1654,7 @@
},
{
"cell_type": "markdown",
- "id": "392618f7",
+ "id": "11aafcc5",
"metadata": {},
"source": [
"
Checkpoint 5
\n",
@@ -1669,7 +1672,7 @@
},
{
"cell_type": "markdown",
- "id": "609323f6",
+ "id": "a5c8b45e",
"metadata": {
"lines_to_next_cell": 0,
"tags": [
@@ -1684,7 +1687,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "c69ea188",
+ "id": "45e17541",
"metadata": {
"tags": [
"solution"