diff --git a/exercise.ipynb b/exercise.ipynb index 52b0294..dbcc8ba 100644 --- a/exercise.ipynb +++ b/exercise.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "c239177c", + "id": "9fad1fb6", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -29,7 +29,7 @@ }, { "cell_type": "markdown", - "id": "192f7d95", + "id": "d2eb0ba6", "metadata": { "lines_to_next_cell": 0 }, @@ -41,7 +41,7 @@ }, { "cell_type": "markdown", - "id": "41b78a2e", + "id": "66fd4eb4", "metadata": {}, "source": [ "\n", @@ -54,7 +54,7 @@ { "cell_type": "code", "execution_count": null, - "id": "1e83c46c", + "id": "8d0c5a17", "metadata": { "lines_to_next_cell": 0 }, @@ -68,7 +68,7 @@ }, { "cell_type": "markdown", - "id": "269f9ace", + "id": "068a0ab7", "metadata": { "lines_to_next_cell": 0 }, @@ -84,7 +84,7 @@ { "cell_type": "code", "execution_count": null, - "id": "a02d8f0b", + "id": "a5706cea", "metadata": {}, "outputs": [], "source": [ @@ -102,7 +102,7 @@ }, { "cell_type": "markdown", - "id": "07af7052", + "id": "9ae13dc9", "metadata": { "lines_to_next_cell": 0 }, @@ -113,7 +113,7 @@ }, { "cell_type": "markdown", - "id": "a3bec292", + "id": "61e909bb", "metadata": { "lines_to_next_cell": 0 }, @@ -130,7 +130,7 @@ { "cell_type": "code", "execution_count": null, - "id": "4c50a466", + "id": "e06d760c", "metadata": { "lines_to_next_cell": 0, "tags": [ @@ -155,7 +155,7 @@ }, { "cell_type": "markdown", - "id": "a4c8ed39", + "id": "be176cbc", "metadata": { "lines_to_next_cell": 0 }, @@ -166,7 +166,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c2d3c97d", + "id": "778c296c", "metadata": { "lines_to_next_cell": 2 }, @@ -192,7 +192,7 @@ }, { "cell_type": "markdown", - "id": "d0b6f156", + "id": "58e55138", "metadata": {}, "source": [ "# Part 2: Using Integrated Gradients to find what the classifier knows\n", @@ -202,7 +202,7 @@ }, { "cell_type": "markdown", - "id": "17c741d3", + "id": "4ca35577", "metadata": {}, "source": [ "## Attributions through integrated gradients\n", @@ -215,7 +215,7 @@ { "cell_type": "code", "execution_count": null, - "id": "95477885", + "id": "e18a3ae4", "metadata": { "tags": [] }, @@ -231,7 +231,7 @@ }, { "cell_type": "markdown", - "id": "350f4b0b", + "id": "aa4b2cb0", "metadata": { "tags": [] }, @@ -247,7 +247,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d0472f9c", + "id": "33463270", "metadata": { "tags": [ "task" @@ -268,7 +268,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c5f3a16c", + "id": "8d0c7872", "metadata": { "tags": [] }, @@ -281,7 +281,7 @@ }, { "cell_type": "markdown", - "id": "a0dff0c8", + "id": "f3e9270c", "metadata": { "lines_to_next_cell": 2, "tags": [] @@ -293,7 +293,7 @@ { "cell_type": "code", "execution_count": null, - "id": "08c82fb5", + "id": "425dbbcc", "metadata": { "tags": [] }, @@ -321,7 +321,7 @@ { "cell_type": "code", "execution_count": null, - "id": "141a0af8", + "id": "5f17d056", "metadata": { "tags": [] }, @@ -333,7 +333,7 @@ }, { "cell_type": "markdown", - "id": "9a3e5ebf", + "id": "fa8198ad", "metadata": { "lines_to_next_cell": 2 }, @@ -347,7 +347,7 @@ }, { "cell_type": "markdown", - "id": "014ff719", + "id": "564385db", "metadata": { "lines_to_next_cell": 0 }, @@ -360,7 +360,7 @@ { "cell_type": "code", "execution_count": null, - "id": "70528897", + "id": "243d9f78", "metadata": {}, "outputs": [], "source": [ @@ -384,7 +384,7 @@ }, { "cell_type": "markdown", - "id": "ce932e89", + "id": "d74a9e52", "metadata": { "lines_to_next_cell": 0 }, @@ -398,7 +398,7 @@ }, { "cell_type": "markdown", - "id": "db289739", + "id": "a950ace4", "metadata": {}, "source": [ "\n", @@ -424,7 +424,7 @@ }, { "cell_type": "markdown", - "id": "c86ebdb5", + "id": "dbe69740", "metadata": {}, "source": [ "

Task 2.3: Use random noise as a baseline

\n", @@ -436,7 +436,7 @@ { "cell_type": "code", "execution_count": null, - "id": "db4cc099", + "id": "084ff537", "metadata": { "tags": [ "task" @@ -456,7 +456,7 @@ }, { "cell_type": "markdown", - "id": "8e295f7c", + "id": "2c0c6205", "metadata": { "tags": [] }, @@ -470,7 +470,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8ba800fc", + "id": "1b06932e", "metadata": { "tags": [ "task" @@ -492,7 +492,7 @@ }, { "cell_type": "markdown", - "id": "29600ea8", + "id": "15b67780", "metadata": { "tags": [] }, @@ -508,7 +508,7 @@ }, { "cell_type": "markdown", - "id": "5e7b80d9", + "id": "46b17b7a", "metadata": {}, "source": [ "

BONUS Task: Using different attributions.

\n", @@ -522,7 +522,7 @@ }, { "cell_type": "markdown", - "id": "a76db362", + "id": "27e47ae9", "metadata": {}, "source": [ "

Checkpoint 2

\n", @@ -542,7 +542,7 @@ }, { "cell_type": "markdown", - "id": "e3ba74a0", + "id": "c7755d0d", "metadata": { "lines_to_next_cell": 0 }, @@ -570,7 +570,7 @@ }, { "cell_type": "markdown", - "id": "fe5fd2fc", + "id": "dd937252", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -593,7 +593,7 @@ { "cell_type": "code", "execution_count": null, - "id": "07417277", + "id": "2a3bb62c", "metadata": {}, "outputs": [], "source": [ @@ -625,7 +625,7 @@ }, { "cell_type": "markdown", - "id": "7ee2ee22", + "id": "fc02905f", "metadata": { "lines_to_next_cell": 0 }, @@ -640,7 +640,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6850fc29", + "id": "d81dccb8", "metadata": { "lines_to_next_cell": 0, "tags": [ @@ -661,7 +661,7 @@ }, { "cell_type": "markdown", - "id": "6ead6efc", + "id": "919cbcdf", "metadata": { "tags": [] }, @@ -676,7 +676,7 @@ }, { "cell_type": "markdown", - "id": "5da40b0a", + "id": "3515f790", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -693,7 +693,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e9476d8e", + "id": "ef21e313", "metadata": { "lines_to_next_cell": 0, "tags": [ @@ -707,7 +707,7 @@ }, { "cell_type": "markdown", - "id": "cc4e1d26", + "id": "825a5b81", "metadata": { "lines_to_next_cell": 0 }, @@ -718,7 +718,7 @@ { "cell_type": "code", "execution_count": null, - "id": "cd3ce1ff", + "id": "7117cd7d", "metadata": {}, "outputs": [], "source": [ @@ -728,7 +728,7 @@ }, { "cell_type": "markdown", - "id": "8e544341", + "id": "52182962", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -746,19 +746,19 @@ { "cell_type": "code", "execution_count": null, - "id": "2e18d801", + "id": "a084fbe2", "metadata": { "lines_to_next_cell": 0 }, "outputs": [], "source": [ - "optimizer_d = torch.optim.Adam(discriminator.parameters(), lr=1e-4)\n", + "optimizer_d = torch.optim.Adam(discriminator.parameters(), lr=1e-5)\n", "optimizer_g = torch.optim.Adam(generator.parameters(), lr=1e-4)" ] }, { "cell_type": "markdown", - "id": "2e41592e", + "id": "30c300ef", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -777,7 +777,7 @@ { "cell_type": "code", "execution_count": null, - "id": "fa7b18ce", + "id": "c74d359f", "metadata": {}, "outputs": [], "source": [ @@ -786,7 +786,7 @@ }, { "cell_type": "markdown", - "id": "ecbf308f", + "id": "3cb1747c", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -802,7 +802,7 @@ { "cell_type": "code", "execution_count": null, - "id": "4d4674ba", + "id": "29b973db", "metadata": {}, "outputs": [], "source": [ @@ -811,7 +811,7 @@ }, { "cell_type": "markdown", - "id": "d25ad125", + "id": "f5a2f065", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -823,7 +823,7 @@ { "cell_type": "code", "execution_count": null, - "id": "529dc669", + "id": "353b2412", "metadata": { "lines_to_next_cell": 1 }, @@ -838,7 +838,7 @@ }, { "cell_type": "markdown", - "id": "531b67c0", + "id": "ea495852", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -852,7 +852,7 @@ { "cell_type": "code", "execution_count": null, - "id": "2125e6b8", + "id": "d0caae29", "metadata": { "lines_to_next_cell": 1 }, @@ -866,7 +866,7 @@ }, { "cell_type": "markdown", - "id": "a74270d4", + "id": "a2dc73d5", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -886,7 +886,7 @@ { "cell_type": "code", "execution_count": null, - "id": "be244060", + "id": "5731e44c", "metadata": {}, "outputs": [], "source": [ @@ -910,7 +910,7 @@ { "cell_type": "code", "execution_count": null, - "id": "baefb71b", + "id": "faf83226", "metadata": {}, "outputs": [], "source": [ @@ -920,13 +920,13 @@ }, { "cell_type": "markdown", - "id": "d00ac9c3", + "id": "5ca6cb80", "metadata": { "lines_to_next_cell": 0, "tags": [] }, "source": [ - "

Task 3.2: Training!

\n", + "

Task 3.3: Training!

\n", "You were given several different options in the training code below. In each case, one of the options will work, and the other will not.\n", "Comment out the option that you think will not work.\n", "
    \n", @@ -942,7 +942,7 @@ }, { "cell_type": "markdown", - "id": "bbf1f4c3", + "id": "f540e9f9", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -954,7 +954,7 @@ { "cell_type": "code", "execution_count": null, - "id": "5cebfa10", + "id": "c4ac820b", "metadata": { "lines_to_next_cell": 0, "tags": [ @@ -1065,7 +1065,7 @@ }, { "cell_type": "markdown", - "id": "adc5fe9c", + "id": "6de959c1", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1077,7 +1077,7 @@ { "cell_type": "code", "execution_count": null, - "id": "3e9c6356", + "id": "ca459374", "metadata": {}, "outputs": [], "source": [ @@ -1090,7 +1090,7 @@ }, { "cell_type": "markdown", - "id": "b482c31e", + "id": "a04ada72", "metadata": { "tags": [] }, @@ -1105,7 +1105,7 @@ }, { "cell_type": "markdown", - "id": "1723a9bf", + "id": "18fb6fef", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1117,7 +1117,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c091294f", + "id": "17119e9f", "metadata": {}, "outputs": [], "source": [ @@ -1136,7 +1136,7 @@ { "cell_type": "code", "execution_count": null, - "id": "cbf2d554", + "id": "f5c9a2db", "metadata": { "lines_to_next_cell": 0 }, @@ -1145,7 +1145,7 @@ }, { "cell_type": "markdown", - "id": "b84d8550", + "id": "8f1af03d", "metadata": { "tags": [] }, @@ -1161,17 +1161,17 @@ }, { "cell_type": "markdown", - "id": "18cbf21a", + "id": "605bf68c", "metadata": { "tags": [] }, "source": [ - "# Part 4: Evaluating the GAN" + "# Part 4: Evaluating the GAN and creating Counterfactuals" ] }, { "cell_type": "markdown", - "id": "d6702bc6", + "id": "784f0d5d", "metadata": { "tags": [] }, @@ -1188,7 +1188,7 @@ { "cell_type": "code", "execution_count": null, - "id": "1427a57c", + "id": "9307cba8", "metadata": { "title": "Loading the test dataset" }, @@ -1199,7 +1199,7 @@ "\n", "\n", "for i in range(4):\n", - " options = np.where(test_mnist.targets == i)[0]\n", + " options = np.where(test_mnist.conditions == i)[0]\n", " # Note that you can change the image index if you want to use a different prototype.\n", " image_index = 0\n", " x, y = test_mnist[options[image_index]]\n", @@ -1208,7 +1208,7 @@ }, { "cell_type": "markdown", - "id": "df29d400", + "id": "74473b00", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1220,7 +1220,7 @@ { "cell_type": "code", "execution_count": null, - "id": "f8a25419", + "id": "a9510356", "metadata": {}, "outputs": [], "source": [ @@ -1233,7 +1233,7 @@ }, { "cell_type": "markdown", - "id": "67dee0fd", + "id": "249c45fb", "metadata": { "lines_to_next_cell": 0 }, @@ -1243,12 +1243,12 @@ }, { "cell_type": "markdown", - "id": "081585ee", + "id": "dd0fb05f", "metadata": { "lines_to_next_cell": 0 }, "source": [ - "

    Task 4.1: Create counterfactuals

    \n", + "

    Task 4: Create counterfactuals

    \n", "In the below, we will store the counterfactual images in the `counterfactuals` array.\n", "\n", "
      \n", @@ -1261,7 +1261,7 @@ { "cell_type": "code", "execution_count": null, - "id": "00e07e71", + "id": "64894033", "metadata": { "lines_to_next_cell": 0, "tags": [ @@ -1297,7 +1297,7 @@ }, { "cell_type": "markdown", - "id": "2d5a8388", + "id": "716001cf", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1309,7 +1309,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c345e081", + "id": "cc1239de", "metadata": {}, "outputs": [], "source": [ @@ -1319,7 +1319,7 @@ }, { "cell_type": "markdown", - "id": "669745a8", + "id": "9347c10b", "metadata": { "tags": [] }, @@ -1334,7 +1334,7 @@ }, { "cell_type": "markdown", - "id": "bb7e45fe", + "id": "f2233521", "metadata": { "tags": [] }, @@ -1345,7 +1345,7 @@ { "cell_type": "code", "execution_count": null, - "id": "88ce9154", + "id": "c7cdfd5f", "metadata": {}, "outputs": [], "source": [ @@ -1359,7 +1359,7 @@ }, { "cell_type": "markdown", - "id": "6533fc00", + "id": "a488e258", "metadata": { "tags": [] }, @@ -1374,15 +1374,7 @@ }, { "cell_type": "markdown", - "id": "782f049f", - "metadata": {}, - "source": [ - "# Part 5: Highlighting Class-Relevant Differences" - ] - }, - { - "cell_type": "markdown", - "id": "0b1ae3b2", + "id": "dec8dfbc", "metadata": { "lines_to_next_cell": 0 }, @@ -1397,7 +1389,7 @@ { "cell_type": "code", "execution_count": null, - "id": "006bf383", + "id": "9558f7b0", "metadata": { "lines_to_next_cell": 1 }, @@ -1419,7 +1411,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e6e2589b", + "id": "24103754", "metadata": { "lines_to_next_cell": 1, "title": "Another visualization function" @@ -1449,7 +1441,7 @@ { "cell_type": "code", "execution_count": null, - "id": "30aa3db1", + "id": "5b543a3c", "metadata": { "lines_to_next_cell": 0 }, @@ -1465,7 +1457,7 @@ }, { "cell_type": "markdown", - "id": "0c29c6b7", + "id": "42ffe1c6", "metadata": { "lines_to_next_cell": 0 }, @@ -1481,7 +1473,7 @@ }, { "cell_type": "markdown", - "id": "5f27f7e2", + "id": "8133616c", "metadata": { "lines_to_next_cell": 0 }, @@ -1496,12 +1488,12 @@ }, { "cell_type": "markdown", - "id": "49fca28b", + "id": "6477c0a4", "metadata": { "lines_to_next_cell": 0 }, "source": [ - "# Part 6: Exploring the Style Space, finding the answer\n", + "# Part 5: Exploring the Style Space, finding the answer\n", "By now you will have hopefully noticed that it isn't the exact color of the image that determines its class, but that two images with a very similar color can be of different classes!\n", "\n", "Here is an example of two images that are very similar in color, but are of different classes.\n", @@ -1520,7 +1512,7 @@ { "cell_type": "code", "execution_count": null, - "id": "0bff81ec", + "id": "391c356d", "metadata": {}, "outputs": [], "source": [ @@ -1533,7 +1525,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d8137940", + "id": "5c2761a6", "metadata": {}, "outputs": [], "source": [ @@ -1566,30 +1558,32 @@ }, { "cell_type": "markdown", - "id": "72af9914", + "id": "1a72be14", "metadata": { "lines_to_next_cell": 0 }, "source": [ - "

      Task 6.2: Adding color to the style space

      \n", + "

      Task 5.1: Adding color to the style space

      \n", "We know that color is important. Does interpreting the style space as colors help us understand better?\n", "\n", "Let's use the style space to color the PCA plot.\n", "(Note: there is no code to write here, just run the cell and answer the questions below)\n", - "
      \n", - "TODO WIP HERE" + "
      " ] }, { "cell_type": "code", "execution_count": null, - "id": "777414b4", + "id": "624d7e7e", "metadata": { "lines_to_next_cell": 0 }, "outputs": [], "source": [ - "normalized_styles = (styles - np.min(styles, axis=1)) / styles.ptp(axis=1)\n", + "styles = np.array(styles)\n", + "normalized_styles = (styles - np.min(styles, axis=1, keepdims=True)) / np.ptp(\n", + " styles, axis=1, keepdims=True\n", + ")\n", "\n", "# Plot the PCA again!\n", "plt.figure(figsize=(10, 10))\n", @@ -1603,7 +1597,7 @@ }, { "cell_type": "markdown", - "id": "a15bc698", + "id": "4168872c", "metadata": { "lines_to_next_cell": 0 }, @@ -1617,14 +1611,14 @@ }, { "cell_type": "markdown", - "id": "bb6dd36e", + "id": "f0e8ce5e", "metadata": { "lines_to_next_cell": 0 }, "source": [ - "

      Using the images to color the style space

      \n", + "

      Task 5.2: Using the images to color the style space

      \n", "Finally, let's just use the colors from the images themselves!\n", - "All of the non-zero values in the image can be averaged to get a color.\n", + "The maximum value in the image (since they are \"black-and-color\") can be used as a color!\n", "\n", "Let's get that color, then plot the style space again.\n", "(Note: once again, no coding needed here, just run the cell and think about the results with the questions below)\n", @@ -1634,7 +1628,7 @@ { "cell_type": "code", "execution_count": null, - "id": "f6b6d2c2", + "id": "98d61014", "metadata": { "lines_to_next_cell": 0 }, @@ -1643,7 +1637,7 @@ }, { "cell_type": "markdown", - "id": "fe266bcb", + "id": "9baf1cbb", "metadata": {}, "source": [ "

      Questions

      \n", @@ -1655,7 +1649,7 @@ }, { "cell_type": "markdown", - "id": "c2f3aff5", + "id": "9e9b79ba", "metadata": {}, "source": [ "

      Checkpoint 5

      \n", diff --git a/solution.ipynb b/solution.ipynb index c52377b..2087e90 100644 --- a/solution.ipynb +++ b/solution.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "c239177c", + "id": "9fad1fb6", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -29,7 +29,7 @@ }, { "cell_type": "markdown", - "id": "192f7d95", + "id": "d2eb0ba6", "metadata": { "lines_to_next_cell": 0 }, @@ -41,7 +41,7 @@ }, { "cell_type": "markdown", - "id": "41b78a2e", + "id": "66fd4eb4", "metadata": {}, "source": [ "\n", @@ -54,7 +54,7 @@ { "cell_type": "code", "execution_count": null, - "id": "1e83c46c", + "id": "8d0c5a17", "metadata": { "lines_to_next_cell": 0 }, @@ -68,7 +68,7 @@ }, { "cell_type": "markdown", - "id": "269f9ace", + "id": "068a0ab7", "metadata": { "lines_to_next_cell": 0 }, @@ -84,7 +84,7 @@ { "cell_type": "code", "execution_count": null, - "id": "a02d8f0b", + "id": "a5706cea", "metadata": {}, "outputs": [], "source": [ @@ -102,7 +102,7 @@ }, { "cell_type": "markdown", - "id": "07af7052", + "id": "9ae13dc9", "metadata": { "lines_to_next_cell": 0 }, @@ -113,7 +113,7 @@ }, { "cell_type": "markdown", - "id": "a3bec292", + "id": "61e909bb", "metadata": { "lines_to_next_cell": 0 }, @@ -130,7 +130,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8f0c2c03", + "id": "9f351427", "metadata": { "tags": [ "solution" @@ -154,7 +154,7 @@ }, { "cell_type": "markdown", - "id": "a4c8ed39", + "id": "be176cbc", "metadata": { "lines_to_next_cell": 0 }, @@ -165,7 +165,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c2d3c97d", + "id": "778c296c", "metadata": { "lines_to_next_cell": 2 }, @@ -191,7 +191,7 @@ }, { "cell_type": "markdown", - "id": "d0b6f156", + "id": "58e55138", "metadata": {}, "source": [ "# Part 2: Using Integrated Gradients to find what the classifier knows\n", @@ -201,7 +201,7 @@ }, { "cell_type": "markdown", - "id": "17c741d3", + "id": "4ca35577", "metadata": {}, "source": [ "## Attributions through integrated gradients\n", @@ -214,7 +214,7 @@ { "cell_type": "code", "execution_count": null, - "id": "95477885", + "id": "e18a3ae4", "metadata": { "tags": [] }, @@ -230,7 +230,7 @@ }, { "cell_type": "markdown", - "id": "350f4b0b", + "id": "aa4b2cb0", "metadata": { "tags": [] }, @@ -246,7 +246,7 @@ { "cell_type": "code", "execution_count": null, - "id": "2a8eac43", + "id": "cdcbfa60", "metadata": { "tags": [ "solution" @@ -270,7 +270,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c5f3a16c", + "id": "8d0c7872", "metadata": { "tags": [] }, @@ -283,7 +283,7 @@ }, { "cell_type": "markdown", - "id": "a0dff0c8", + "id": "f3e9270c", "metadata": { "lines_to_next_cell": 2, "tags": [] @@ -295,7 +295,7 @@ { "cell_type": "code", "execution_count": null, - "id": "08c82fb5", + "id": "425dbbcc", "metadata": { "tags": [] }, @@ -323,7 +323,7 @@ { "cell_type": "code", "execution_count": null, - "id": "141a0af8", + "id": "5f17d056", "metadata": { "tags": [] }, @@ -335,7 +335,7 @@ }, { "cell_type": "markdown", - "id": "9a3e5ebf", + "id": "fa8198ad", "metadata": { "lines_to_next_cell": 2 }, @@ -349,7 +349,7 @@ }, { "cell_type": "markdown", - "id": "014ff719", + "id": "564385db", "metadata": { "lines_to_next_cell": 0 }, @@ -362,7 +362,7 @@ { "cell_type": "code", "execution_count": null, - "id": "70528897", + "id": "243d9f78", "metadata": {}, "outputs": [], "source": [ @@ -386,7 +386,7 @@ }, { "cell_type": "markdown", - "id": "ce932e89", + "id": "d74a9e52", "metadata": { "lines_to_next_cell": 0 }, @@ -400,7 +400,7 @@ }, { "cell_type": "markdown", - "id": "db289739", + "id": "a950ace4", "metadata": {}, "source": [ "\n", @@ -426,7 +426,7 @@ }, { "cell_type": "markdown", - "id": "c86ebdb5", + "id": "dbe69740", "metadata": {}, "source": [ "

      Task 2.3: Use random noise as a baseline

      \n", @@ -438,7 +438,7 @@ { "cell_type": "code", "execution_count": null, - "id": "63c5a503", + "id": "e5710918", "metadata": { "tags": [ "solution" @@ -463,7 +463,7 @@ }, { "cell_type": "markdown", - "id": "8e295f7c", + "id": "2c0c6205", "metadata": { "tags": [] }, @@ -477,7 +477,7 @@ { "cell_type": "code", "execution_count": null, - "id": "0ebfaae1", + "id": "1281007a", "metadata": { "tags": [ "solution" @@ -504,7 +504,7 @@ }, { "cell_type": "markdown", - "id": "29600ea8", + "id": "15b67780", "metadata": { "tags": [] }, @@ -520,7 +520,7 @@ }, { "cell_type": "markdown", - "id": "5e7b80d9", + "id": "46b17b7a", "metadata": {}, "source": [ "

      BONUS Task: Using different attributions.

      \n", @@ -534,7 +534,7 @@ }, { "cell_type": "markdown", - "id": "a76db362", + "id": "27e47ae9", "metadata": {}, "source": [ "

      Checkpoint 2

      \n", @@ -554,7 +554,7 @@ }, { "cell_type": "markdown", - "id": "e3ba74a0", + "id": "c7755d0d", "metadata": { "lines_to_next_cell": 0 }, @@ -582,7 +582,7 @@ }, { "cell_type": "markdown", - "id": "fe5fd2fc", + "id": "dd937252", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -605,7 +605,7 @@ { "cell_type": "code", "execution_count": null, - "id": "07417277", + "id": "2a3bb62c", "metadata": {}, "outputs": [], "source": [ @@ -637,7 +637,7 @@ }, { "cell_type": "markdown", - "id": "7ee2ee22", + "id": "fc02905f", "metadata": { "lines_to_next_cell": 0 }, @@ -652,7 +652,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6454d2e9", + "id": "6196dc49", "metadata": { "tags": [ "solution" @@ -669,7 +669,7 @@ }, { "cell_type": "markdown", - "id": "6ead6efc", + "id": "919cbcdf", "metadata": { "tags": [] }, @@ -684,7 +684,7 @@ }, { "cell_type": "markdown", - "id": "5da40b0a", + "id": "3515f790", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -701,7 +701,7 @@ { "cell_type": "code", "execution_count": null, - "id": "927e677b", + "id": "28c68855", "metadata": { "lines_to_next_cell": 0, "tags": [ @@ -715,7 +715,7 @@ }, { "cell_type": "markdown", - "id": "cc4e1d26", + "id": "825a5b81", "metadata": { "lines_to_next_cell": 0 }, @@ -726,7 +726,7 @@ { "cell_type": "code", "execution_count": null, - "id": "cd3ce1ff", + "id": "7117cd7d", "metadata": {}, "outputs": [], "source": [ @@ -736,7 +736,7 @@ }, { "cell_type": "markdown", - "id": "8e544341", + "id": "52182962", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -754,19 +754,19 @@ { "cell_type": "code", "execution_count": null, - "id": "2e18d801", + "id": "a084fbe2", "metadata": { "lines_to_next_cell": 0 }, "outputs": [], "source": [ - "optimizer_d = torch.optim.Adam(discriminator.parameters(), lr=1e-4)\n", + "optimizer_d = torch.optim.Adam(discriminator.parameters(), lr=1e-5)\n", "optimizer_g = torch.optim.Adam(generator.parameters(), lr=1e-4)" ] }, { "cell_type": "markdown", - "id": "2e41592e", + "id": "30c300ef", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -785,7 +785,7 @@ { "cell_type": "code", "execution_count": null, - "id": "fa7b18ce", + "id": "c74d359f", "metadata": {}, "outputs": [], "source": [ @@ -794,7 +794,7 @@ }, { "cell_type": "markdown", - "id": "ecbf308f", + "id": "3cb1747c", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -810,7 +810,7 @@ { "cell_type": "code", "execution_count": null, - "id": "4d4674ba", + "id": "29b973db", "metadata": {}, "outputs": [], "source": [ @@ -819,7 +819,7 @@ }, { "cell_type": "markdown", - "id": "d25ad125", + "id": "f5a2f065", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -831,7 +831,7 @@ { "cell_type": "code", "execution_count": null, - "id": "529dc669", + "id": "353b2412", "metadata": { "lines_to_next_cell": 1 }, @@ -846,7 +846,7 @@ }, { "cell_type": "markdown", - "id": "531b67c0", + "id": "ea495852", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -860,7 +860,7 @@ { "cell_type": "code", "execution_count": null, - "id": "2125e6b8", + "id": "d0caae29", "metadata": { "lines_to_next_cell": 1 }, @@ -874,7 +874,7 @@ }, { "cell_type": "markdown", - "id": "a74270d4", + "id": "a2dc73d5", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -894,7 +894,7 @@ { "cell_type": "code", "execution_count": null, - "id": "be244060", + "id": "5731e44c", "metadata": {}, "outputs": [], "source": [ @@ -918,7 +918,7 @@ { "cell_type": "code", "execution_count": null, - "id": "baefb71b", + "id": "faf83226", "metadata": {}, "outputs": [], "source": [ @@ -928,13 +928,13 @@ }, { "cell_type": "markdown", - "id": "d00ac9c3", + "id": "5ca6cb80", "metadata": { "lines_to_next_cell": 0, "tags": [] }, "source": [ - "

      Task 3.2: Training!

      \n", + "

      Task 3.3: Training!

      \n", "You were given several different options in the training code below. In each case, one of the options will work, and the other will not.\n", "Comment out the option that you think will not work.\n", "
        \n", @@ -950,7 +950,7 @@ }, { "cell_type": "markdown", - "id": "bbf1f4c3", + "id": "f540e9f9", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -962,7 +962,7 @@ { "cell_type": "code", "execution_count": null, - "id": "0cb9ac26", + "id": "abb9371f", "metadata": { "lines_to_next_cell": 2, "tags": [ @@ -1032,7 +1032,7 @@ }, { "cell_type": "markdown", - "id": "adc5fe9c", + "id": "6de959c1", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1044,7 +1044,7 @@ { "cell_type": "code", "execution_count": null, - "id": "3e9c6356", + "id": "ca459374", "metadata": {}, "outputs": [], "source": [ @@ -1057,7 +1057,7 @@ }, { "cell_type": "markdown", - "id": "b482c31e", + "id": "a04ada72", "metadata": { "tags": [] }, @@ -1072,7 +1072,7 @@ }, { "cell_type": "markdown", - "id": "1723a9bf", + "id": "18fb6fef", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1084,7 +1084,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c091294f", + "id": "17119e9f", "metadata": {}, "outputs": [], "source": [ @@ -1103,7 +1103,7 @@ { "cell_type": "code", "execution_count": null, - "id": "cbf2d554", + "id": "f5c9a2db", "metadata": { "lines_to_next_cell": 0 }, @@ -1112,7 +1112,7 @@ }, { "cell_type": "markdown", - "id": "b84d8550", + "id": "8f1af03d", "metadata": { "tags": [] }, @@ -1128,17 +1128,17 @@ }, { "cell_type": "markdown", - "id": "18cbf21a", + "id": "605bf68c", "metadata": { "tags": [] }, "source": [ - "# Part 4: Evaluating the GAN" + "# Part 4: Evaluating the GAN and creating Counterfactuals" ] }, { "cell_type": "markdown", - "id": "d6702bc6", + "id": "784f0d5d", "metadata": { "tags": [] }, @@ -1155,7 +1155,7 @@ { "cell_type": "code", "execution_count": null, - "id": "1427a57c", + "id": "9307cba8", "metadata": { "title": "Loading the test dataset" }, @@ -1166,7 +1166,7 @@ "\n", "\n", "for i in range(4):\n", - " options = np.where(test_mnist.targets == i)[0]\n", + " options = np.where(test_mnist.conditions == i)[0]\n", " # Note that you can change the image index if you want to use a different prototype.\n", " image_index = 0\n", " x, y = test_mnist[options[image_index]]\n", @@ -1175,7 +1175,7 @@ }, { "cell_type": "markdown", - "id": "df29d400", + "id": "74473b00", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1187,7 +1187,7 @@ { "cell_type": "code", "execution_count": null, - "id": "f8a25419", + "id": "a9510356", "metadata": {}, "outputs": [], "source": [ @@ -1200,7 +1200,7 @@ }, { "cell_type": "markdown", - "id": "67dee0fd", + "id": "249c45fb", "metadata": { "lines_to_next_cell": 0 }, @@ -1210,12 +1210,12 @@ }, { "cell_type": "markdown", - "id": "081585ee", + "id": "dd0fb05f", "metadata": { "lines_to_next_cell": 0 }, "source": [ - "

        Task 4.1: Create counterfactuals

        \n", + "

        Task 4: Create counterfactuals

        \n", "In the below, we will store the counterfactual images in the `counterfactuals` array.\n", "\n", "
          \n", @@ -1228,7 +1228,7 @@ { "cell_type": "code", "execution_count": null, - "id": "5bbffb9f", + "id": "99ecfc15", "metadata": { "tags": [ "solution" @@ -1265,7 +1265,7 @@ }, { "cell_type": "markdown", - "id": "2d5a8388", + "id": "716001cf", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1277,7 +1277,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c345e081", + "id": "cc1239de", "metadata": {}, "outputs": [], "source": [ @@ -1287,7 +1287,7 @@ }, { "cell_type": "markdown", - "id": "669745a8", + "id": "9347c10b", "metadata": { "tags": [] }, @@ -1302,7 +1302,7 @@ }, { "cell_type": "markdown", - "id": "bb7e45fe", + "id": "f2233521", "metadata": { "tags": [] }, @@ -1313,7 +1313,7 @@ { "cell_type": "code", "execution_count": null, - "id": "88ce9154", + "id": "c7cdfd5f", "metadata": {}, "outputs": [], "source": [ @@ -1327,7 +1327,7 @@ }, { "cell_type": "markdown", - "id": "6533fc00", + "id": "a488e258", "metadata": { "tags": [] }, @@ -1342,15 +1342,7 @@ }, { "cell_type": "markdown", - "id": "782f049f", - "metadata": {}, - "source": [ - "# Part 5: Highlighting Class-Relevant Differences" - ] - }, - { - "cell_type": "markdown", - "id": "0b1ae3b2", + "id": "dec8dfbc", "metadata": { "lines_to_next_cell": 0 }, @@ -1365,7 +1357,7 @@ { "cell_type": "code", "execution_count": null, - "id": "006bf383", + "id": "9558f7b0", "metadata": { "lines_to_next_cell": 1 }, @@ -1387,7 +1379,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e6e2589b", + "id": "24103754", "metadata": { "lines_to_next_cell": 1, "title": "Another visualization function" @@ -1417,7 +1409,7 @@ { "cell_type": "code", "execution_count": null, - "id": "30aa3db1", + "id": "5b543a3c", "metadata": { "lines_to_next_cell": 0 }, @@ -1433,7 +1425,7 @@ }, { "cell_type": "markdown", - "id": "0c29c6b7", + "id": "42ffe1c6", "metadata": { "lines_to_next_cell": 0 }, @@ -1449,7 +1441,7 @@ }, { "cell_type": "markdown", - "id": "5f27f7e2", + "id": "8133616c", "metadata": { "lines_to_next_cell": 0 }, @@ -1464,12 +1456,12 @@ }, { "cell_type": "markdown", - "id": "49fca28b", + "id": "6477c0a4", "metadata": { "lines_to_next_cell": 0 }, "source": [ - "# Part 6: Exploring the Style Space, finding the answer\n", + "# Part 5: Exploring the Style Space, finding the answer\n", "By now you will have hopefully noticed that it isn't the exact color of the image that determines its class, but that two images with a very similar color can be of different classes!\n", "\n", "Here is an example of two images that are very similar in color, but are of different classes.\n", @@ -1488,7 +1480,7 @@ { "cell_type": "code", "execution_count": null, - "id": "0bff81ec", + "id": "391c356d", "metadata": {}, "outputs": [], "source": [ @@ -1501,7 +1493,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d8137940", + "id": "5c2761a6", "metadata": {}, "outputs": [], "source": [ @@ -1534,30 +1526,32 @@ }, { "cell_type": "markdown", - "id": "72af9914", + "id": "1a72be14", "metadata": { "lines_to_next_cell": 0 }, "source": [ - "

          Task 6.2: Adding color to the style space

          \n", + "

          Task 5.1: Adding color to the style space

          \n", "We know that color is important. Does interpreting the style space as colors help us understand better?\n", "\n", "Let's use the style space to color the PCA plot.\n", "(Note: there is no code to write here, just run the cell and answer the questions below)\n", - "
          \n", - "TODO WIP HERE" + "
          " ] }, { "cell_type": "code", "execution_count": null, - "id": "777414b4", + "id": "624d7e7e", "metadata": { "lines_to_next_cell": 0 }, "outputs": [], "source": [ - "normalized_styles = (styles - np.min(styles, axis=1)) / styles.ptp(axis=1)\n", + "styles = np.array(styles)\n", + "normalized_styles = (styles - np.min(styles, axis=1, keepdims=True)) / np.ptp(\n", + " styles, axis=1, keepdims=True\n", + ")\n", "\n", "# Plot the PCA again!\n", "plt.figure(figsize=(10, 10))\n", @@ -1571,7 +1565,7 @@ }, { "cell_type": "markdown", - "id": "a15bc698", + "id": "4168872c", "metadata": { "lines_to_next_cell": 0 }, @@ -1585,14 +1579,14 @@ }, { "cell_type": "markdown", - "id": "bb6dd36e", + "id": "f0e8ce5e", "metadata": { "lines_to_next_cell": 0 }, "source": [ - "

          Using the images to color the style space

          \n", + "

          Task 5.2: Using the images to color the style space

          \n", "Finally, let's just use the colors from the images themselves!\n", - "All of the non-zero values in the image can be averaged to get a color.\n", + "The maximum value in the image (since they are \"black-and-color\") can be used as a color!\n", "\n", "Let's get that color, then plot the style space again.\n", "(Note: once again, no coding needed here, just run the cell and think about the results with the questions below)\n", @@ -1602,7 +1596,7 @@ { "cell_type": "code", "execution_count": null, - "id": "0f17c1af", + "id": "75f470fb", "metadata": { "tags": [ "solution" @@ -1610,19 +1604,14 @@ }, "outputs": [], "source": [ - "tol = 1e-6\n", - "\n", - "colors = []\n", - "for x, y in random_test_mnist:\n", - " non_zero = x[x > tol]\n", - " colors.append(non_zero.mean(dim=(1, 2)).cpu().numpy().squeeze())\n", + "colors = [np.max(x.numpy(), axis=(1, 2)) for x, _ in random_test_mnist]\n", "\n", "# Plot the PCA again!\n", "plt.figure(figsize=(10, 10))\n", "plt.scatter(\n", " styles_pca[:, 0],\n", " styles_pca[:, 1],\n", - " c=normalized_styles,\n", + " c=colors,\n", ")\n", "plt.show()" ] @@ -1630,7 +1619,7 @@ { "cell_type": "code", "execution_count": null, - "id": "f6b6d2c2", + "id": "98d61014", "metadata": { "lines_to_next_cell": 0 }, @@ -1639,7 +1628,7 @@ }, { "cell_type": "markdown", - "id": "fe266bcb", + "id": "9baf1cbb", "metadata": {}, "source": [ "

          Questions

          \n", @@ -1651,7 +1640,7 @@ }, { "cell_type": "markdown", - "id": "c2f3aff5", + "id": "9e9b79ba", "metadata": {}, "source": [ "

          Checkpoint 5

          \n", @@ -1669,7 +1658,7 @@ }, { "cell_type": "markdown", - "id": "c3c83fa2", + "id": "ba5fab31", "metadata": { "lines_to_next_cell": 0, "tags": [ @@ -1684,7 +1673,7 @@ { "cell_type": "code", "execution_count": null, - "id": "5594b649", + "id": "00729fac", "metadata": { "tags": [ "solution"