diff --git a/exercise.ipynb b/exercise.ipynb index 665f54b..92007b0 100644 --- a/exercise.ipynb +++ b/exercise.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "cabeeff7", + "id": "30c11df5", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -29,7 +29,7 @@ }, { "cell_type": "markdown", - "id": "a6549d6e", + "id": "ec2899d4", "metadata": { "lines_to_next_cell": 0 }, @@ -41,7 +41,7 @@ }, { "cell_type": "markdown", - "id": "af277573", + "id": "2c084b97", "metadata": {}, "source": [ "\n", @@ -54,7 +54,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d133ee66", + "id": "9d26a8bb", "metadata": { "lines_to_next_cell": 0 }, @@ -63,12 +63,12 @@ "# loading the data\n", "from classifier.data import ColoredMNIST\n", "\n", - "mnist = ColoredMNIST(\"data\", download=True)" + "mnist = ColoredMNIST(\"extras/data\", download=True)" ] }, { "cell_type": "markdown", - "id": "7bf9a7d1", + "id": "f8a5937c", "metadata": { "lines_to_next_cell": 0 }, @@ -84,7 +84,7 @@ { "cell_type": "code", "execution_count": null, - "id": "0d4c5c7f", + "id": "9c0ce960", "metadata": {}, "outputs": [], "source": [ @@ -102,7 +102,7 @@ }, { "cell_type": "markdown", - "id": "4189496b", + "id": "0cb834e5", "metadata": { "lines_to_next_cell": 0 }, @@ -113,7 +113,7 @@ }, { "cell_type": "markdown", - "id": "ec85ffc9", + "id": "a32035d7", "metadata": { "lines_to_next_cell": 0 }, @@ -130,7 +130,7 @@ { "cell_type": "code", "execution_count": null, - "id": "85c5021a", + "id": "47684cce", "metadata": { "lines_to_next_cell": 0, "tags": [ @@ -155,7 +155,7 @@ }, { "cell_type": "markdown", - "id": "ebf14527", + "id": "6ecddeb8", "metadata": { "lines_to_next_cell": 0 }, @@ -166,7 +166,7 @@ { "cell_type": "code", "execution_count": null, - "id": "0fa46d9a", + "id": "c271ecd9", "metadata": {}, "outputs": [], "source": [ @@ -174,7 +174,7 @@ "from sklearn.metrics import confusion_matrix\n", "import seaborn as sns\n", "\n", - "test_mnist = ColoredMNIST(\"data\", download=True, train=False)\n", + "test_mnist = ColoredMNIST(\"extras/data\", download=True, train=False)\n", "dataloader = DataLoader(test_mnist, batch_size=32, shuffle=False)\n", "\n", "labels = []\n", @@ -193,7 +193,7 @@ }, { "cell_type": "markdown", - "id": "35845bc8", + "id": "46a684f4", "metadata": {}, "source": [ "# Part 2: Using Integrated Gradients to find what the classifier knows\n", @@ -203,7 +203,7 @@ }, { "cell_type": "markdown", - "id": "9d861e84", + "id": "0255c073", "metadata": {}, "source": [ "## Attributions through integrated gradients\n", @@ -216,7 +216,7 @@ { "cell_type": "code", "execution_count": null, - "id": "811b9852", + "id": "e5b162b7", "metadata": { "tags": [] }, @@ -234,7 +234,7 @@ }, { "cell_type": "markdown", - "id": "38c2b5f2", + "id": "6d418ea1", "metadata": { "tags": [] }, @@ -250,7 +250,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e678e018", + "id": "5ce086ee", "metadata": { "tags": [ "task" @@ -271,7 +271,7 @@ { "cell_type": "code", "execution_count": null, - "id": "422bc189", + "id": "e4ba6b3a", "metadata": { "tags": [] }, @@ -284,7 +284,7 @@ }, { "cell_type": "markdown", - "id": "677d8c4a", + "id": "56e432ae", "metadata": { "lines_to_next_cell": 2, "tags": [] @@ -296,7 +296,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c13d35fb", + "id": "9561d46f", "metadata": { "tags": [] }, @@ -324,7 +324,7 @@ { "cell_type": "code", "execution_count": null, - "id": "70a3b3b3", + "id": "a55fe8ec", "metadata": { "tags": [] }, @@ -337,7 +337,7 @@ }, { "cell_type": "markdown", - "id": "916906ac", + "id": "1d8c03a0", "metadata": { "lines_to_next_cell": 2 }, @@ -351,7 +351,7 @@ }, { "cell_type": "markdown", - "id": "00494aec", + "id": "2a24c70a", "metadata": { "lines_to_next_cell": 0 }, @@ -364,7 +364,7 @@ { "cell_type": "code", "execution_count": null, - "id": "88c9d18e", + "id": "6e875faa", "metadata": {}, "outputs": [], "source": [ @@ -389,7 +389,7 @@ }, { "cell_type": "markdown", - "id": "2110738d", + "id": "3f73608f", "metadata": { "lines_to_next_cell": 0 }, @@ -403,7 +403,7 @@ }, { "cell_type": "markdown", - "id": "3292fbe5", + "id": "a8e71c0b", "metadata": {}, "source": [ "\n", @@ -429,7 +429,7 @@ }, { "cell_type": "markdown", - "id": "46c075dc", + "id": "dbb04b6f", "metadata": {}, "source": [ "

Task 2.3: Use random noise as a baseline

\n", @@ -441,7 +441,7 @@ { "cell_type": "code", "execution_count": null, - "id": "bba71667", + "id": "2fc8f45c", "metadata": { "tags": [ "task" @@ -462,7 +462,7 @@ }, { "cell_type": "markdown", - "id": "88239eb5", + "id": "bf7e934c", "metadata": { "tags": [] }, @@ -476,7 +476,7 @@ { "cell_type": "code", "execution_count": null, - "id": "5d81a759", + "id": "2e14f754", "metadata": { "tags": [ "task" @@ -499,7 +499,7 @@ }, { "cell_type": "markdown", - "id": "3a52e78e", + "id": "db46361b", "metadata": { "tags": [] }, @@ -515,7 +515,7 @@ }, { "cell_type": "markdown", - "id": "bf2263d6", + "id": "e9105812", "metadata": {}, "source": [ "

BONUS Task: Using different attributions.

\n", @@ -529,7 +529,7 @@ }, { "cell_type": "markdown", - "id": "31c83033", + "id": "0b2d0f2f", "metadata": {}, "source": [ "

Checkpoint 2

\n", @@ -549,7 +549,7 @@ }, { "cell_type": "markdown", - "id": "12b2601b", + "id": "531169e5", "metadata": { "lines_to_next_cell": 0 }, @@ -577,7 +577,7 @@ }, { "cell_type": "markdown", - "id": "35efae25", + "id": "331e56d6", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -600,7 +600,7 @@ { "cell_type": "code", "execution_count": null, - "id": "55ba1040", + "id": "301ee289", "metadata": {}, "outputs": [], "source": [ @@ -632,7 +632,7 @@ }, { "cell_type": "markdown", - "id": "81ba7c71", + "id": "4ce023f6", "metadata": { "lines_to_next_cell": 0 }, @@ -647,7 +647,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e17fa41b", + "id": "c2698719", "metadata": { "lines_to_next_cell": 0, "tags": [ @@ -668,7 +668,7 @@ }, { "cell_type": "markdown", - "id": "acc2feba", + "id": "16f87104", "metadata": { "tags": [] }, @@ -683,7 +683,7 @@ }, { "cell_type": "markdown", - "id": "a482f224", + "id": "9f1d1149", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -700,7 +700,7 @@ { "cell_type": "code", "execution_count": null, - "id": "f35e34ba", + "id": "14e0c929", "metadata": { "lines_to_next_cell": 0, "tags": [ @@ -714,7 +714,7 @@ }, { "cell_type": "markdown", - "id": "19fdd0a9", + "id": "231a5202", "metadata": { "lines_to_next_cell": 0 }, @@ -725,7 +725,7 @@ { "cell_type": "code", "execution_count": null, - "id": "7b75bdd6", + "id": "c0a2d54d", "metadata": {}, "outputs": [], "source": [ @@ -735,7 +735,7 @@ }, { "cell_type": "markdown", - "id": "b1dedf50", + "id": "4540ef18", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -753,7 +753,7 @@ { "cell_type": "code", "execution_count": null, - "id": "fe716560", + "id": "b9fc6671", "metadata": { "lines_to_next_cell": 0 }, @@ -765,7 +765,7 @@ }, { "cell_type": "markdown", - "id": "a1c9bca2", + "id": "196daf45", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -784,7 +784,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d02a2f0c", + "id": "1e9ddd12", "metadata": {}, "outputs": [], "source": [ @@ -793,7 +793,7 @@ }, { "cell_type": "markdown", - "id": "2f5f91ed", + "id": "eade7df1", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -809,7 +809,7 @@ { "cell_type": "code", "execution_count": null, - "id": "5bef17c0", + "id": "1deb8b8b", "metadata": {}, "outputs": [], "source": [ @@ -818,7 +818,7 @@ }, { "cell_type": "markdown", - "id": "b8feb471", + "id": "ba4a7f7f", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -830,7 +830,7 @@ { "cell_type": "code", "execution_count": null, - "id": "1bc928e9", + "id": "b5b3d5dc", "metadata": {}, "outputs": [], "source": [ @@ -843,7 +843,7 @@ }, { "cell_type": "markdown", - "id": "c0a1a77c", + "id": "a029e923", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -857,7 +857,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d7a41c68", + "id": "54b4de87", "metadata": {}, "outputs": [], "source": [ @@ -869,7 +869,7 @@ }, { "cell_type": "markdown", - "id": "7ff74b67", + "id": "014e484e", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -889,7 +889,7 @@ { "cell_type": "code", "execution_count": null, - "id": "2002bdc0", + "id": "f6344c83", "metadata": {}, "outputs": [], "source": [ @@ -913,7 +913,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e5303510", + "id": "08b7b3af", "metadata": {}, "outputs": [], "source": [ @@ -923,7 +923,7 @@ }, { "cell_type": "markdown", - "id": "28bd8680", + "id": "23fbf680", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -945,7 +945,7 @@ }, { "cell_type": "markdown", - "id": "7fbe2fd9", + "id": "9cb8281d", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -957,7 +957,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8bb28524", + "id": "3b01306d", "metadata": { "lines_to_next_cell": 0, "tags": [ @@ -1068,7 +1068,7 @@ }, { "cell_type": "markdown", - "id": "a540a4d6", + "id": "4c25819b", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1080,20 +1080,23 @@ { "cell_type": "code", "execution_count": null, - "id": "9b8fa0a1", + "id": "0d64d32d", "metadata": {}, "outputs": [], "source": [ - "plt.plot(losses[\"cycle\"], label=\"Cycle loss\")\n", - "plt.plot(losses[\"adv\"], label=\"Adversarial loss\")\n", - "plt.plot(losses[\"disc\"], label=\"Discriminator loss\")\n", - "plt.legend()\n", + "fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 5))\n", + "ax1.plot(losses[\"cycle\"])\n", + "ax1.set_title(\"Cycle loss\")\n", + "ax2.plot(losses[\"adv\"])\n", + "ax2.set_title(\"Adversarial loss\")\n", + "ax3.plot(losses[\"disc\"])\n", + "ax3.set_title(\"Discriminator loss\")\n", "plt.show()" ] }, { "cell_type": "markdown", - "id": "f42e89a9", + "id": "326ba2b5", "metadata": { "tags": [] }, @@ -1108,7 +1111,7 @@ }, { "cell_type": "markdown", - "id": "a34b2f4d", + "id": "3e58ca01", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1120,7 +1123,7 @@ { "cell_type": "code", "execution_count": null, - "id": "810e8d6e", + "id": "1c522efa", "metadata": {}, "outputs": [], "source": [ @@ -1143,7 +1146,7 @@ { "cell_type": "code", "execution_count": null, - "id": "a3931621", + "id": "30b6dac9", "metadata": { "lines_to_next_cell": 0 }, @@ -1152,7 +1155,7 @@ }, { "cell_type": "markdown", - "id": "910d5ed6", + "id": "a3ecbc7b", "metadata": { "tags": [] }, @@ -1168,7 +1171,7 @@ }, { "cell_type": "markdown", - "id": "d75728f1", + "id": "e6bdaecb", "metadata": { "tags": [] }, @@ -1178,7 +1181,7 @@ }, { "cell_type": "markdown", - "id": "46ac6b2d", + "id": "7f994579", "metadata": { "tags": [] }, @@ -1195,13 +1198,13 @@ { "cell_type": "code", "execution_count": null, - "id": "3541f664", + "id": "4e4fe83e", "metadata": { "title": "Loading the test dataset" }, "outputs": [], "source": [ - "test_mnist = ColoredMNIST(\"data\", download=True, train=False)\n", + "test_mnist = ColoredMNIST(\"extras/data\", download=True, train=False)\n", "prototypes = {}\n", "\n", "\n", @@ -1215,7 +1218,7 @@ }, { "cell_type": "markdown", - "id": "d8d02278", + "id": "049a6b22", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1227,7 +1230,7 @@ { "cell_type": "code", "execution_count": null, - "id": "220450b4", + "id": "639f37e2", "metadata": {}, "outputs": [], "source": [ @@ -1240,7 +1243,7 @@ }, { "cell_type": "markdown", - "id": "d7c8d8a8", + "id": "02cb705b", "metadata": { "lines_to_next_cell": 0 }, @@ -1250,7 +1253,7 @@ }, { "cell_type": "markdown", - "id": "f607ce7c", + "id": "f41a6ce5", "metadata": { "lines_to_next_cell": 0 }, @@ -1268,7 +1271,7 @@ { "cell_type": "code", "execution_count": null, - "id": "10b77d39", + "id": "282f8858", "metadata": { "lines_to_next_cell": 0, "tags": [ @@ -1304,7 +1307,7 @@ }, { "cell_type": "markdown", - "id": "95379712", + "id": "ebffc15f", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1316,7 +1319,7 @@ { "cell_type": "code", "execution_count": null, - "id": "df4f63b4", + "id": "baac8071", "metadata": {}, "outputs": [], "source": [ @@ -1329,7 +1332,7 @@ }, { "cell_type": "markdown", - "id": "f7dd387e", + "id": "88e7ea0c", "metadata": { "tags": [] }, @@ -1344,7 +1347,7 @@ }, { "cell_type": "markdown", - "id": "bfeaf7d1", + "id": "25972c49", "metadata": { "tags": [] }, @@ -1355,7 +1358,7 @@ { "cell_type": "code", "execution_count": null, - "id": "9dec938b", + "id": "12d49576", "metadata": {}, "outputs": [], "source": [ @@ -1369,7 +1372,7 @@ }, { "cell_type": "markdown", - "id": "bbcf6338", + "id": "8e6f04f3", "metadata": { "tags": [] }, @@ -1384,7 +1387,7 @@ }, { "cell_type": "markdown", - "id": "866b85d4", + "id": "50728ff2", "metadata": { "lines_to_next_cell": 0 }, @@ -1399,7 +1402,7 @@ { "cell_type": "code", "execution_count": null, - "id": "2c3bd150", + "id": "dedc0f83", "metadata": {}, "outputs": [], "source": [ @@ -1419,7 +1422,7 @@ { "cell_type": "code", "execution_count": null, - "id": "a6c9d35d", + "id": "5446e796", "metadata": { "title": "Another visualization function" }, @@ -1448,7 +1451,7 @@ { "cell_type": "code", "execution_count": null, - "id": "355d3691", + "id": "5e2fb59e", "metadata": { "lines_to_next_cell": 0 }, @@ -1464,7 +1467,7 @@ }, { "cell_type": "markdown", - "id": "d3717907", + "id": "b393a8f1", "metadata": { "lines_to_next_cell": 0 }, @@ -1480,7 +1483,7 @@ }, { "cell_type": "markdown", - "id": "4063399b", + "id": "5ba47fc6", "metadata": { "lines_to_next_cell": 0 }, @@ -1495,7 +1498,7 @@ }, { "cell_type": "markdown", - "id": "587f4083", + "id": "2654d788", "metadata": { "lines_to_next_cell": 0 }, @@ -1518,7 +1521,7 @@ }, { "cell_type": "markdown", - "id": "499c184e", + "id": "76559366", "metadata": {}, "source": [ "

Task 5.1: Explore the style space

\n", @@ -1530,7 +1533,7 @@ { "cell_type": "code", "execution_count": null, - "id": "09065024", + "id": "f1fdb890", "metadata": {}, "outputs": [], "source": [ @@ -1565,7 +1568,7 @@ }, { "cell_type": "markdown", - "id": "d6f40f81", + "id": "b666769e", "metadata": { "lines_to_next_cell": 0 }, @@ -1581,7 +1584,7 @@ { "cell_type": "code", "execution_count": null, - "id": "28f9efd8", + "id": "e61d0c9b", "metadata": { "lines_to_next_cell": 0 }, @@ -1608,7 +1611,7 @@ }, { "cell_type": "markdown", - "id": "35eb9e2b", + "id": "6f1d3ff3", "metadata": { "lines_to_next_cell": 0 }, @@ -1622,7 +1625,7 @@ }, { "cell_type": "markdown", - "id": "b7e631b9", + "id": "90889399", "metadata": { "lines_to_next_cell": 0 }, @@ -1639,7 +1642,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d7bf9f03", + "id": "f67b3f90", "metadata": {}, "outputs": [], "source": [ @@ -1662,7 +1665,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6f2fa456", + "id": "b18b2b81", "metadata": { "lines_to_next_cell": 0 }, @@ -1671,7 +1674,7 @@ }, { "cell_type": "markdown", - "id": "4c030783", + "id": "bf87e80b", "metadata": {}, "source": [ "

Questions

\n", @@ -1683,7 +1686,7 @@ }, { "cell_type": "markdown", - "id": "392618f7", + "id": "11aafcc5", "metadata": {}, "source": [ "

Checkpoint 5

\n", diff --git a/solution.ipynb b/solution.ipynb index a345b23..b0b9e5a 100644 --- a/solution.ipynb +++ b/solution.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "cabeeff7", + "id": "30c11df5", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -29,7 +29,7 @@ }, { "cell_type": "markdown", - "id": "a6549d6e", + "id": "ec2899d4", "metadata": { "lines_to_next_cell": 0 }, @@ -41,7 +41,7 @@ }, { "cell_type": "markdown", - "id": "af277573", + "id": "2c084b97", "metadata": {}, "source": [ "\n", @@ -54,7 +54,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d133ee66", + "id": "9d26a8bb", "metadata": { "lines_to_next_cell": 0 }, @@ -63,12 +63,12 @@ "# loading the data\n", "from classifier.data import ColoredMNIST\n", "\n", - "mnist = ColoredMNIST(\"data\", download=True)" + "mnist = ColoredMNIST(\"extras/data\", download=True)" ] }, { "cell_type": "markdown", - "id": "7bf9a7d1", + "id": "f8a5937c", "metadata": { "lines_to_next_cell": 0 }, @@ -84,7 +84,7 @@ { "cell_type": "code", "execution_count": null, - "id": "0d4c5c7f", + "id": "9c0ce960", "metadata": {}, "outputs": [], "source": [ @@ -102,7 +102,7 @@ }, { "cell_type": "markdown", - "id": "4189496b", + "id": "0cb834e5", "metadata": { "lines_to_next_cell": 0 }, @@ -113,7 +113,7 @@ }, { "cell_type": "markdown", - "id": "ec85ffc9", + "id": "a32035d7", "metadata": { "lines_to_next_cell": 0 }, @@ -130,7 +130,7 @@ { "cell_type": "code", "execution_count": null, - "id": "bbb01724", + "id": "0146821b", "metadata": { "tags": [ "solution" @@ -154,7 +154,7 @@ }, { "cell_type": "markdown", - "id": "ebf14527", + "id": "6ecddeb8", "metadata": { "lines_to_next_cell": 0 }, @@ -165,7 +165,7 @@ { "cell_type": "code", "execution_count": null, - "id": "0fa46d9a", + "id": "c271ecd9", "metadata": {}, "outputs": [], "source": [ @@ -173,7 +173,7 @@ "from sklearn.metrics import confusion_matrix\n", "import seaborn as sns\n", "\n", - "test_mnist = ColoredMNIST(\"data\", download=True, train=False)\n", + "test_mnist = ColoredMNIST(\"extras/data\", download=True, train=False)\n", "dataloader = DataLoader(test_mnist, batch_size=32, shuffle=False)\n", "\n", "labels = []\n", @@ -192,7 +192,7 @@ }, { "cell_type": "markdown", - "id": "35845bc8", + "id": "46a684f4", "metadata": {}, "source": [ "# Part 2: Using Integrated Gradients to find what the classifier knows\n", @@ -202,7 +202,7 @@ }, { "cell_type": "markdown", - "id": "9d861e84", + "id": "0255c073", "metadata": {}, "source": [ "## Attributions through integrated gradients\n", @@ -215,7 +215,7 @@ { "cell_type": "code", "execution_count": null, - "id": "811b9852", + "id": "e5b162b7", "metadata": { "tags": [] }, @@ -233,7 +233,7 @@ }, { "cell_type": "markdown", - "id": "38c2b5f2", + "id": "6d418ea1", "metadata": { "tags": [] }, @@ -249,7 +249,7 @@ { "cell_type": "code", "execution_count": null, - "id": "fc427029", + "id": "f93e8067", "metadata": { "tags": [ "solution" @@ -273,7 +273,7 @@ { "cell_type": "code", "execution_count": null, - "id": "422bc189", + "id": "e4ba6b3a", "metadata": { "tags": [] }, @@ -286,7 +286,7 @@ }, { "cell_type": "markdown", - "id": "677d8c4a", + "id": "56e432ae", "metadata": { "lines_to_next_cell": 2, "tags": [] @@ -298,7 +298,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c13d35fb", + "id": "9561d46f", "metadata": { "tags": [] }, @@ -326,7 +326,7 @@ { "cell_type": "code", "execution_count": null, - "id": "70a3b3b3", + "id": "a55fe8ec", "metadata": { "tags": [] }, @@ -339,7 +339,7 @@ }, { "cell_type": "markdown", - "id": "916906ac", + "id": "1d8c03a0", "metadata": { "lines_to_next_cell": 2 }, @@ -353,7 +353,7 @@ }, { "cell_type": "markdown", - "id": "00494aec", + "id": "2a24c70a", "metadata": { "lines_to_next_cell": 0 }, @@ -366,7 +366,7 @@ { "cell_type": "code", "execution_count": null, - "id": "88c9d18e", + "id": "6e875faa", "metadata": {}, "outputs": [], "source": [ @@ -391,7 +391,7 @@ }, { "cell_type": "markdown", - "id": "2110738d", + "id": "3f73608f", "metadata": { "lines_to_next_cell": 0 }, @@ -405,7 +405,7 @@ }, { "cell_type": "markdown", - "id": "3292fbe5", + "id": "a8e71c0b", "metadata": {}, "source": [ "\n", @@ -431,7 +431,7 @@ }, { "cell_type": "markdown", - "id": "46c075dc", + "id": "dbb04b6f", "metadata": {}, "source": [ "

Task 2.3: Use random noise as a baseline

\n", @@ -443,7 +443,7 @@ { "cell_type": "code", "execution_count": null, - "id": "fbf0e8de", + "id": "cde2c2ff", "metadata": { "tags": [ "solution" @@ -469,7 +469,7 @@ }, { "cell_type": "markdown", - "id": "88239eb5", + "id": "bf7e934c", "metadata": { "tags": [] }, @@ -483,7 +483,7 @@ { "cell_type": "code", "execution_count": null, - "id": "7ba51c8b", + "id": "a0cb195e", "metadata": { "tags": [ "solution" @@ -511,7 +511,7 @@ }, { "cell_type": "markdown", - "id": "3a52e78e", + "id": "db46361b", "metadata": { "tags": [] }, @@ -527,7 +527,7 @@ }, { "cell_type": "markdown", - "id": "bf2263d6", + "id": "e9105812", "metadata": {}, "source": [ "

BONUS Task: Using different attributions.

\n", @@ -541,7 +541,7 @@ }, { "cell_type": "markdown", - "id": "31c83033", + "id": "0b2d0f2f", "metadata": {}, "source": [ "

Checkpoint 2

\n", @@ -561,7 +561,7 @@ }, { "cell_type": "markdown", - "id": "12b2601b", + "id": "531169e5", "metadata": { "lines_to_next_cell": 0 }, @@ -589,7 +589,7 @@ }, { "cell_type": "markdown", - "id": "35efae25", + "id": "331e56d6", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -612,7 +612,7 @@ { "cell_type": "code", "execution_count": null, - "id": "55ba1040", + "id": "301ee289", "metadata": {}, "outputs": [], "source": [ @@ -644,7 +644,7 @@ }, { "cell_type": "markdown", - "id": "81ba7c71", + "id": "4ce023f6", "metadata": { "lines_to_next_cell": 0 }, @@ -659,7 +659,7 @@ { "cell_type": "code", "execution_count": null, - "id": "ded9c5d3", + "id": "b491022a", "metadata": { "tags": [ "solution" @@ -676,7 +676,7 @@ }, { "cell_type": "markdown", - "id": "acc2feba", + "id": "16f87104", "metadata": { "tags": [] }, @@ -691,7 +691,7 @@ }, { "cell_type": "markdown", - "id": "a482f224", + "id": "9f1d1149", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -708,7 +708,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d48de07d", + "id": "71695d57", "metadata": { "lines_to_next_cell": 0, "tags": [ @@ -722,7 +722,7 @@ }, { "cell_type": "markdown", - "id": "19fdd0a9", + "id": "231a5202", "metadata": { "lines_to_next_cell": 0 }, @@ -733,7 +733,7 @@ { "cell_type": "code", "execution_count": null, - "id": "7b75bdd6", + "id": "c0a2d54d", "metadata": {}, "outputs": [], "source": [ @@ -743,7 +743,7 @@ }, { "cell_type": "markdown", - "id": "b1dedf50", + "id": "4540ef18", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -761,7 +761,7 @@ { "cell_type": "code", "execution_count": null, - "id": "fe716560", + "id": "b9fc6671", "metadata": { "lines_to_next_cell": 0 }, @@ -773,7 +773,7 @@ }, { "cell_type": "markdown", - "id": "a1c9bca2", + "id": "196daf45", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -792,7 +792,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d02a2f0c", + "id": "1e9ddd12", "metadata": {}, "outputs": [], "source": [ @@ -801,7 +801,7 @@ }, { "cell_type": "markdown", - "id": "2f5f91ed", + "id": "eade7df1", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -817,7 +817,7 @@ { "cell_type": "code", "execution_count": null, - "id": "5bef17c0", + "id": "1deb8b8b", "metadata": {}, "outputs": [], "source": [ @@ -826,7 +826,7 @@ }, { "cell_type": "markdown", - "id": "b8feb471", + "id": "ba4a7f7f", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -838,7 +838,7 @@ { "cell_type": "code", "execution_count": null, - "id": "1bc928e9", + "id": "b5b3d5dc", "metadata": {}, "outputs": [], "source": [ @@ -851,7 +851,7 @@ }, { "cell_type": "markdown", - "id": "c0a1a77c", + "id": "a029e923", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -865,7 +865,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d7a41c68", + "id": "54b4de87", "metadata": {}, "outputs": [], "source": [ @@ -877,7 +877,7 @@ }, { "cell_type": "markdown", - "id": "7ff74b67", + "id": "014e484e", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -897,7 +897,7 @@ { "cell_type": "code", "execution_count": null, - "id": "2002bdc0", + "id": "f6344c83", "metadata": {}, "outputs": [], "source": [ @@ -921,7 +921,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e5303510", + "id": "08b7b3af", "metadata": {}, "outputs": [], "source": [ @@ -931,7 +931,7 @@ }, { "cell_type": "markdown", - "id": "28bd8680", + "id": "23fbf680", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -953,7 +953,7 @@ }, { "cell_type": "markdown", - "id": "7fbe2fd9", + "id": "9cb8281d", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -965,7 +965,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e66a1fa9", + "id": "699b3220", "metadata": { "lines_to_next_cell": 2, "tags": [ @@ -1035,7 +1035,7 @@ }, { "cell_type": "markdown", - "id": "a540a4d6", + "id": "4c25819b", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1047,20 +1047,23 @@ { "cell_type": "code", "execution_count": null, - "id": "9b8fa0a1", + "id": "0d64d32d", "metadata": {}, "outputs": [], "source": [ - "plt.plot(losses[\"cycle\"], label=\"Cycle loss\")\n", - "plt.plot(losses[\"adv\"], label=\"Adversarial loss\")\n", - "plt.plot(losses[\"disc\"], label=\"Discriminator loss\")\n", - "plt.legend()\n", + "fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 5))\n", + "ax1.plot(losses[\"cycle\"])\n", + "ax1.set_title(\"Cycle loss\")\n", + "ax2.plot(losses[\"adv\"])\n", + "ax2.set_title(\"Adversarial loss\")\n", + "ax3.plot(losses[\"disc\"])\n", + "ax3.set_title(\"Discriminator loss\")\n", "plt.show()" ] }, { "cell_type": "markdown", - "id": "f42e89a9", + "id": "326ba2b5", "metadata": { "tags": [] }, @@ -1075,7 +1078,7 @@ }, { "cell_type": "markdown", - "id": "a34b2f4d", + "id": "3e58ca01", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1087,7 +1090,7 @@ { "cell_type": "code", "execution_count": null, - "id": "810e8d6e", + "id": "1c522efa", "metadata": {}, "outputs": [], "source": [ @@ -1110,7 +1113,7 @@ { "cell_type": "code", "execution_count": null, - "id": "a3931621", + "id": "30b6dac9", "metadata": { "lines_to_next_cell": 0 }, @@ -1119,7 +1122,7 @@ }, { "cell_type": "markdown", - "id": "910d5ed6", + "id": "a3ecbc7b", "metadata": { "tags": [] }, @@ -1135,7 +1138,7 @@ }, { "cell_type": "markdown", - "id": "d75728f1", + "id": "e6bdaecb", "metadata": { "tags": [] }, @@ -1145,7 +1148,7 @@ }, { "cell_type": "markdown", - "id": "46ac6b2d", + "id": "7f994579", "metadata": { "tags": [] }, @@ -1162,13 +1165,13 @@ { "cell_type": "code", "execution_count": null, - "id": "3541f664", + "id": "4e4fe83e", "metadata": { "title": "Loading the test dataset" }, "outputs": [], "source": [ - "test_mnist = ColoredMNIST(\"data\", download=True, train=False)\n", + "test_mnist = ColoredMNIST(\"extras/data\", download=True, train=False)\n", "prototypes = {}\n", "\n", "\n", @@ -1182,7 +1185,7 @@ }, { "cell_type": "markdown", - "id": "d8d02278", + "id": "049a6b22", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1194,7 +1197,7 @@ { "cell_type": "code", "execution_count": null, - "id": "220450b4", + "id": "639f37e2", "metadata": {}, "outputs": [], "source": [ @@ -1207,7 +1210,7 @@ }, { "cell_type": "markdown", - "id": "d7c8d8a8", + "id": "02cb705b", "metadata": { "lines_to_next_cell": 0 }, @@ -1217,7 +1220,7 @@ }, { "cell_type": "markdown", - "id": "f607ce7c", + "id": "f41a6ce5", "metadata": { "lines_to_next_cell": 0 }, @@ -1235,7 +1238,7 @@ { "cell_type": "code", "execution_count": null, - "id": "3d20c0da", + "id": "00616e67", "metadata": { "tags": [ "solution" @@ -1272,7 +1275,7 @@ }, { "cell_type": "markdown", - "id": "95379712", + "id": "ebffc15f", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1284,7 +1287,7 @@ { "cell_type": "code", "execution_count": null, - "id": "df4f63b4", + "id": "baac8071", "metadata": {}, "outputs": [], "source": [ @@ -1297,7 +1300,7 @@ }, { "cell_type": "markdown", - "id": "f7dd387e", + "id": "88e7ea0c", "metadata": { "tags": [] }, @@ -1312,7 +1315,7 @@ }, { "cell_type": "markdown", - "id": "bfeaf7d1", + "id": "25972c49", "metadata": { "tags": [] }, @@ -1323,7 +1326,7 @@ { "cell_type": "code", "execution_count": null, - "id": "9dec938b", + "id": "12d49576", "metadata": {}, "outputs": [], "source": [ @@ -1337,7 +1340,7 @@ }, { "cell_type": "markdown", - "id": "bbcf6338", + "id": "8e6f04f3", "metadata": { "tags": [] }, @@ -1352,7 +1355,7 @@ }, { "cell_type": "markdown", - "id": "866b85d4", + "id": "50728ff2", "metadata": { "lines_to_next_cell": 0 }, @@ -1367,7 +1370,7 @@ { "cell_type": "code", "execution_count": null, - "id": "2c3bd150", + "id": "dedc0f83", "metadata": {}, "outputs": [], "source": [ @@ -1387,7 +1390,7 @@ { "cell_type": "code", "execution_count": null, - "id": "a6c9d35d", + "id": "5446e796", "metadata": { "title": "Another visualization function" }, @@ -1416,7 +1419,7 @@ { "cell_type": "code", "execution_count": null, - "id": "355d3691", + "id": "5e2fb59e", "metadata": { "lines_to_next_cell": 0 }, @@ -1432,7 +1435,7 @@ }, { "cell_type": "markdown", - "id": "d3717907", + "id": "b393a8f1", "metadata": { "lines_to_next_cell": 0 }, @@ -1448,7 +1451,7 @@ }, { "cell_type": "markdown", - "id": "4063399b", + "id": "5ba47fc6", "metadata": { "lines_to_next_cell": 0 }, @@ -1463,7 +1466,7 @@ }, { "cell_type": "markdown", - "id": "587f4083", + "id": "2654d788", "metadata": { "lines_to_next_cell": 0 }, @@ -1486,7 +1489,7 @@ }, { "cell_type": "markdown", - "id": "499c184e", + "id": "76559366", "metadata": {}, "source": [ "

Task 5.1: Explore the style space

\n", @@ -1498,7 +1501,7 @@ { "cell_type": "code", "execution_count": null, - "id": "09065024", + "id": "f1fdb890", "metadata": {}, "outputs": [], "source": [ @@ -1533,7 +1536,7 @@ }, { "cell_type": "markdown", - "id": "d6f40f81", + "id": "b666769e", "metadata": { "lines_to_next_cell": 0 }, @@ -1549,7 +1552,7 @@ { "cell_type": "code", "execution_count": null, - "id": "28f9efd8", + "id": "e61d0c9b", "metadata": { "lines_to_next_cell": 0 }, @@ -1576,7 +1579,7 @@ }, { "cell_type": "markdown", - "id": "35eb9e2b", + "id": "6f1d3ff3", "metadata": { "lines_to_next_cell": 0 }, @@ -1590,7 +1593,7 @@ }, { "cell_type": "markdown", - "id": "b7e631b9", + "id": "90889399", "metadata": { "lines_to_next_cell": 0 }, @@ -1607,7 +1610,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d7bf9f03", + "id": "f67b3f90", "metadata": {}, "outputs": [], "source": [ @@ -1630,7 +1633,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6f2fa456", + "id": "b18b2b81", "metadata": { "lines_to_next_cell": 0 }, @@ -1639,7 +1642,7 @@ }, { "cell_type": "markdown", - "id": "4c030783", + "id": "bf87e80b", "metadata": {}, "source": [ "

Questions

\n", @@ -1651,7 +1654,7 @@ }, { "cell_type": "markdown", - "id": "392618f7", + "id": "11aafcc5", "metadata": {}, "source": [ "

Checkpoint 5

\n", @@ -1669,7 +1672,7 @@ }, { "cell_type": "markdown", - "id": "609323f6", + "id": "a5c8b45e", "metadata": { "lines_to_next_cell": 0, "tags": [ @@ -1684,7 +1687,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c69ea188", + "id": "45e17541", "metadata": { "tags": [ "solution"