diff --git a/exercise.ipynb b/exercise.ipynb index 8f802ba..41bfda6 100644 --- a/exercise.ipynb +++ b/exercise.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "e998cbda", + "id": "b3ddb066", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -29,7 +29,7 @@ }, { "cell_type": "markdown", - "id": "f3b46176", + "id": "7f43c1e3", "metadata": { "lines_to_next_cell": 0 }, @@ -41,7 +41,7 @@ }, { "cell_type": "markdown", - "id": "b0ad2695", + "id": "a9aaf840", "metadata": {}, "source": [ "\n", @@ -54,7 +54,7 @@ { "cell_type": "code", "execution_count": null, - "id": "774d942d", + "id": "3f6c5bc0", "metadata": { "lines_to_next_cell": 0 }, @@ -68,7 +68,7 @@ }, { "cell_type": "markdown", - "id": "32c74ae3", + "id": "5dd19fe5", "metadata": { "lines_to_next_cell": 0 }, @@ -84,7 +84,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8e2bfb78", + "id": "8c709838", "metadata": {}, "outputs": [], "source": [ @@ -102,7 +102,7 @@ }, { "cell_type": "markdown", - "id": "2e368025", + "id": "6b04f969", "metadata": { "lines_to_next_cell": 0 }, @@ -113,7 +113,7 @@ }, { "cell_type": "markdown", - "id": "b4ba9ba1", + "id": "27ea9906", "metadata": { "lines_to_next_cell": 0 }, @@ -130,7 +130,7 @@ { "cell_type": "code", "execution_count": null, - "id": "ecc51041", + "id": "3bf64cda", "metadata": { "lines_to_next_cell": 0, "tags": [ @@ -155,7 +155,7 @@ }, { "cell_type": "markdown", - "id": "358f92e4", + "id": "2ad014ac", "metadata": {}, "source": [ "# Part 2: Using Integrated Gradients to find what the classifier knows\n", @@ -165,7 +165,7 @@ }, { "cell_type": "markdown", - "id": "23375b54", + "id": "e40eeba7", "metadata": {}, "source": [ "## Attributions through integrated gradients\n", @@ -178,7 +178,7 @@ { "cell_type": "code", "execution_count": null, - "id": "0bc95c12", + "id": "e7aca710", "metadata": { "tags": [] }, @@ -194,7 +194,7 @@ }, { "cell_type": "markdown", - "id": "ce061847", + "id": "44d286aa", "metadata": { "tags": [] }, @@ -210,7 +210,7 @@ { "cell_type": "code", "execution_count": null, - "id": "3fe7d564", + "id": "ec387f82", "metadata": { "tags": [ "task" @@ -231,7 +231,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c3b8fada", + "id": "8ea56240", "metadata": { "tags": [] }, @@ -244,7 +244,7 @@ }, { "cell_type": "markdown", - "id": "1749ba9c", + "id": "bc50850e", "metadata": { "lines_to_next_cell": 2, "tags": [] @@ -256,7 +256,7 @@ { "cell_type": "code", "execution_count": null, - "id": "b11c4963", + "id": "e7447933", "metadata": { "tags": [] }, @@ -284,7 +284,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e4a2e4ba", + "id": "5cb527bf", "metadata": { "tags": [] }, @@ -296,7 +296,7 @@ }, { "cell_type": "markdown", - "id": "35dbc255", + "id": "25ecec3e", "metadata": { "lines_to_next_cell": 2 }, @@ -310,7 +310,7 @@ }, { "cell_type": "markdown", - "id": "cb45a3b7", + "id": "2b43f05d", "metadata": { "lines_to_next_cell": 0 }, @@ -323,7 +323,7 @@ { "cell_type": "code", "execution_count": null, - "id": "9fe579e8", + "id": "a7b21894", "metadata": {}, "outputs": [], "source": [ @@ -347,7 +347,7 @@ }, { "cell_type": "markdown", - "id": "6db49a33", + "id": "f97eace2", "metadata": { "lines_to_next_cell": 0 }, @@ -361,7 +361,7 @@ }, { "cell_type": "markdown", - "id": "68c48063", + "id": "ed5b7d6e", "metadata": {}, "source": [ "\n", @@ -387,7 +387,7 @@ }, { "cell_type": "markdown", - "id": "b4f45692", + "id": "bf13ae8d", "metadata": {}, "source": [ "

Task 2.3: Use random noise as a baseline

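For orientation, here is a minimal sketch of the idea, assuming the classifier is `model` and there is a batch `x` with labels `y`, using the `captum` attribution API; all of these are placeholder names rather than the notebook's own:

```python
import torch
from captum.attr import IntegratedGradients

# Placeholders: `model` (the trained classifier), `x` (input batch), `y` (labels).
integrated_gradients = IntegratedGradients(model)

# Instead of the default all-zeros baseline, integrate from random noise
# with the same shape as the input (uniform in [0, 1) for normalized images).
random_baseline = torch.rand_like(x)
attributions = integrated_gradients.attribute(x, baselines=random_baseline, target=y)
```

Averaging the attributions over several random draws usually gives a more stable map than a single baseline.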
\n", @@ -399,7 +399,7 @@ { "cell_type": "code", "execution_count": null, - "id": "00a40c0c", + "id": "60d6691c", "metadata": { "tags": [ "task" @@ -419,7 +419,7 @@ }, { "cell_type": "markdown", - "id": "24db5ea4", + "id": "f24c00a3", "metadata": { "tags": [] }, @@ -433,7 +433,7 @@ { "cell_type": "code", "execution_count": null, - "id": "01485873", + "id": "3835de1a", "metadata": { "tags": [ "task" @@ -455,7 +455,7 @@ }, { "cell_type": "markdown", - "id": "341fe9b8", + "id": "10a6cfcc", "metadata": { "tags": [] }, @@ -471,7 +471,7 @@ }, { "cell_type": "markdown", - "id": "0b0e6145", + "id": "25f3d08e", "metadata": {}, "source": [ "

BONUS Task: Using different attributions.

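As a starting point, other gradient-based attribution methods in `captum` share the same `attribute` call shape, so trying one can be as simple as the sketch below (again with `model`, `x`, and `y` as placeholder names):

```python
from captum.attr import InputXGradient, Saliency

# Same calling pattern as IntegratedGradients above; neither needs a baseline.
saliency_attr = Saliency(model).attribute(x, target=y)
ixg_attr = InputXGradient(model).attribute(x, target=y)
```

Comparing these maps against the integrated-gradients ones is a quick way to see how much the choice of method changes the story.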
\n", @@ -485,7 +485,7 @@ }, { "cell_type": "markdown", - "id": "0f67562c", + "id": "65d946a8", "metadata": {}, "source": [ "

Checkpoint 2

\n", @@ -505,7 +505,7 @@ }, { "cell_type": "markdown", - "id": "003fed33", + "id": "04602cf9", "metadata": { "lines_to_next_cell": 0 }, @@ -533,22 +533,22 @@ }, { "cell_type": "markdown", - "id": "1c99e326", + "id": "ed173d7c", "metadata": { "lines_to_next_cell": 0, "tags": [] }, "source": [ "### The model\n", - "![cycle.png](assets/cyclegan.png)\n", + "![stargan.png](assets/stargan.png)\n", "\n", "In the following, we create a [StarGAN model](https://arxiv.org/abs/1711.09020).\n", "It is a Generative Adversarial model that is trained to turn one class of images X into a different class of images Y.\n", "\n", - "The model is made up of three networks:\n", + "We will not be using the random latent code (green, in the figure), so the model we use is made up of three networks:\n", "- The generator - this will be the bulk of the model, and will be responsible for transforming the images: we're going to use a `UNet`\n", "- The discriminator - this will be responsible for telling the difference between real and fake images: we're going to use a `DenseModel`\n", - "- The style mapping - this will be responsible for encoding the style of the image: we're going to use a `DenseModel`\n", + "- The style encoder - this will be responsible for encoding the style of the image: we're going to use a `DenseModel`\n", "\n", "Let's start by creating these!" ] @@ -556,10 +556,8 @@ { "cell_type": "code", "execution_count": null, - "id": "3fa2a39a", - "metadata": { - "lines_to_next_cell": 1 - }, + "id": "03a51bad", + "metadata": {}, "outputs": [], "source": [ "from dlmbl_unet import UNet\n", @@ -567,10 +565,11 @@ "\n", "\n", "class Generator(nn.Module):\n", - " def __init__(self, generator, style_mapping):\n", + "\n", + " def __init__(self, generator, style_encoder):\n", " super().__init__()\n", " self.generator = generator\n", - " self.style_mapping = style_mapping\n", + " self.style_encoder = style_encoder\n", "\n", " def forward(self, x, y):\n", " \"\"\"\n", @@ -579,7 +578,7 @@ " y: torch.Tensor\n", " The style image\n", " \"\"\"\n", - " style = self.style_mapping(y)\n", + " style = self.style_encoder(y)\n", " # Concatenate the style vector with the input image\n", " style = style.unsqueeze(-1).unsqueeze(-1)\n", " style = style.expand(-1, -1, x.size(2), x.size(3))\n", @@ -589,7 +588,7 @@ }, { "cell_type": "markdown", - "id": "11c69ace", + "id": "e6c6168d", "metadata": { "lines_to_next_cell": 0 }, @@ -604,12 +603,14 @@ { "cell_type": "code", "execution_count": null, - "id": "734e1e36", + "id": "9d0ef49f", "metadata": { "lines_to_next_cell": 0 }, "outputs": [], "source": [ + "style_size = ... # TODO choose a size for the style space\n", + "unet_depth = ... # TODO Choose a depth for the UNet\n", "style_mapping = DenseModel(\n", " input_shape=..., num_classes=... # How big is the style space?\n", ")\n", @@ -620,7 +621,22 @@ }, { "cell_type": "markdown", - "id": "74b2fe60", + "id": "bd761ef3", + "metadata": { + "tags": [] + }, + "source": [ + "

Hyper-parameter choices

\n", + "
    \n", + "
  • Are any of the hyperparameters you choose above constrained in some way?
  • \n", + "
  • What would happen if you chose a depth of 10 for the UNet?
  • \n", + "
  • Is there a minimum size for the style space? Why or why not?
  • \n", + "
" + ] + }, + { + "cell_type": "markdown", + "id": "d1220bb6", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -637,7 +653,7 @@ { "cell_type": "code", "execution_count": null, - "id": "4416d6eb", + "id": "71482197", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -649,7 +665,7 @@ }, { "cell_type": "markdown", - "id": "b20d0919", + "id": "709affba", "metadata": { "lines_to_next_cell": 0 }, @@ -660,7 +676,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6bc98d13", + "id": "7059545e", "metadata": {}, "outputs": [], "source": [ @@ -670,24 +686,89 @@ }, { "cell_type": "markdown", - "id": "2cc4a339", + "id": "b1a7581c", "metadata": { + "lines_to_next_cell": 0, "tags": [] }, "source": [ "## Training a GAN\n", "\n", - "Yes, really!\n", + "Training an adversarial network is a bit more complicated than training a classifier.\n", + "For starters, we are simultaneously training two different networks that work against each other.\n", + "As such, we need to be careful about how and when we update the weights of each network.\n", + "\n", + "We will have two different optimizers, one for the Generator and one for the Discriminator.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7805887e", + "metadata": { + "lines_to_next_cell": 0 + }, + "outputs": [], + "source": [ + "optimizer_d = torch.optim.Adam(discriminator.parameters(), lr=1e-4)\n", + "optimizer_g = torch.optim.Adam(generator.parameters(), lr=1e-4)" + ] + }, + { + "cell_type": "markdown", + "id": "1bad28d8", + "metadata": { + "lines_to_next_cell": 0, + "tags": [] + }, + "source": [ + "\n", + "There are also two different types of losses that we will need.\n", + "**Adversarial loss**\n", + "This loss describes how well the discriminator can tell the difference between real and generated images.\n", + "In our case, this will be a sort of classification loss - we will use Cross Entropy.\n", + "
\n", + "The adversarial loss will be applied differently to the generator and the discriminator! Be very careful!\n", + "
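To make that asymmetry concrete, here is an illustrative sketch; the exact pairing used in the training loop below may differ, and `x`, `y`, `x_style`, and `y_target` are placeholder names:

```python
import torch.nn as nn

adversarial_loss_fn = nn.CrossEntropyLoss()

# Placeholders: x (real images of class y), x_style (an image of the target
# class y_target), with generator and discriminator as defined above.
x_fake = generator(x, x_style)

# Discriminator step: real images should be classified as their true class.
# If fakes enter this loss, detach() them so the update stops at the discriminator.
d_loss = adversarial_loss_fn(discriminator(x), y)

# Generator step: fakes should be (mis)classified as the *target* class,
# and the gradient must flow through the discriminator back into the generator.
g_loss = adversarial_loss_fn(discriminator(x_fake), y_target)
```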
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a757512e", + "metadata": {}, + "outputs": [], + "source": [ + "adverial_loss_fn = nn.CrossEntropyLoss()" + ] + }, + { + "cell_type": "markdown", + "id": "5c590737", + "metadata": { + "tags": [] + }, + "source": [ + "\n", + "**Cycle/reconstruction loss**\n", + "The cycle loss is there to make sure that the generator doesn't output an image that looks nothing like the input!\n", + "Indeed, by training the generator to be able to cycle back to the original image, we are making sure that it makes a minimum number of changes.\n", + "The cycle loss is applied only to the generator.\n", "\n", - "TODO about the losses:\n", - "- An adversarial loss\n", - "- A cycle loss\n", - "TODO add exercise!" + "cycle_loss_fn = nn.L1Loss()" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "0def44d4", + "metadata": {}, + "outputs": [], + "source": [] + }, { "cell_type": "markdown", - "id": "87761838", + "id": "3a0c1d2e", "metadata": { "lines_to_next_cell": 2, "tags": [] @@ -702,16 +783,25 @@ }, { "cell_type": "markdown", - "id": "bcc737d6", + "id": "9f577571", "metadata": { + "lines_to_next_cell": 0, "tags": [] }, "source": [ "...this time again.\n", "\n", - "\"drawing\"\n", - "\n", - "TODO also turn this into a standalong script for use during the project phase\n", + "\"drawing\"\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d3077e49", + "metadata": {}, + "outputs": [], + "source": [ + "# TODO also turn this into a standalone script for use during the project phase\n", "from torch.utils.data import DataLoader\n", "from tqdm import tqdm\n", "\n", @@ -788,7 +878,7 @@ { "cell_type": "code", "execution_count": null, - "id": "86957c62", + "id": "b232bd07", "metadata": { "lines_to_next_cell": 0 }, @@ -803,7 +893,7 @@ }, { "cell_type": "markdown", - "id": "efd44cf5", + "id": "16de7380", "metadata": { "tags": [] }, @@ -814,7 +904,7 @@ { "cell_type": "code", "execution_count": null, - "id": "22c3f513", + "id": "856af9da", "metadata": {}, "outputs": [], "source": [ @@ -834,7 +924,7 @@ }, { "cell_type": "markdown", - "id": "87b45015", + "id": "f7240ca5", "metadata": { "tags": [] }, @@ -850,7 +940,7 @@ }, { "cell_type": "markdown", - "id": "d4e7a929", + "id": "67168867", "metadata": { "tags": [] }, @@ -860,7 +950,7 @@ }, { "cell_type": "markdown", - "id": "7d02cc75", + "id": "c6bdbfde", "metadata": { "tags": [] }, @@ -876,7 +966,7 @@ { "cell_type": "code", "execution_count": null, - "id": "a539070f", + "id": "a8543304", "metadata": { "tags": [] }, @@ -890,7 +980,7 @@ }, { "cell_type": "markdown", - "id": "d1b2507b", + "id": "940b48d6", "metadata": { "tags": [] }, @@ -901,7 +991,7 @@ { "cell_type": "code", "execution_count": null, - "id": "b2ab6b33", + "id": "8b9425d2", "metadata": { "tags": [] }, @@ -912,7 +1002,7 @@ }, { "cell_type": "markdown", - "id": "7de66a63", + "id": "42f81f13", "metadata": { "tags": [] }, @@ -923,7 +1013,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6fcc912a", + "id": "33fbfc83", "metadata": { "tags": [] }, @@ -934,7 +1024,7 @@ }, { "cell_type": "markdown", - "id": "929e292b", + "id": "00ded88d", "metadata": { "tags": [] }, @@ -947,7 +1037,7 @@ }, { "cell_type": "markdown", - "id": "7abe7429", + "id": "f7475dc3", "metadata": { "tags": [] }, @@ -968,7 +1058,7 @@ }, { "cell_type": "markdown", - "id": "55bb626d", + "id": "97a88ddb", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -993,7 +1083,7 @@ { "cell_type": "code", "execution_count": null, - "id": 
"67390c1b", + "id": "2f82fa67", "metadata": {}, "outputs": [], "source": [ @@ -1004,7 +1094,7 @@ { "cell_type": "code", "execution_count": null, - "id": "2930d6cd", + "id": "b93db0b2", "metadata": { "lines_to_next_cell": 0, "title": "[markwodn]" @@ -1017,7 +1107,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c0ae9923", + "id": "5c7ccc7b", "metadata": {}, "outputs": [], "source": [ @@ -1030,7 +1120,7 @@ }, { "cell_type": "markdown", - "id": "5d2739a2", + "id": "d47955f7", "metadata": { "tags": [] }, @@ -1041,7 +1131,7 @@ { "cell_type": "code", "execution_count": null, - "id": "933e724b", + "id": "94284732", "metadata": { "lines_to_next_cell": 0 }, @@ -1056,7 +1146,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8367d7e7", + "id": "cb6f9edc", "metadata": {}, "outputs": [], "source": [ @@ -1068,7 +1158,7 @@ }, { "cell_type": "markdown", - "id": "28279f41", + "id": "8aba5707", "metadata": {}, "source": [ "
\n", @@ -1083,7 +1173,7 @@ }, { "cell_type": "markdown", - "id": "db7e8748", + "id": "b9713122", "metadata": {}, "source": [ "

Checkpoint 4

\n", @@ -1096,7 +1186,7 @@ }, { "cell_type": "markdown", - "id": "ca69811f", + "id": "183344be", "metadata": {}, "source": [ "# Part 5: Highlighting Class-Relevant Differences" @@ -1104,7 +1194,7 @@ }, { "cell_type": "markdown", - "id": "3a84225c", + "id": "83417bff", "metadata": {}, "source": [ "At this point we have:\n", @@ -1119,7 +1209,7 @@ }, { "cell_type": "markdown", - "id": "fd9cd294", + "id": "737ae577", "metadata": {}, "source": [ "

Task 5.1: Get successfully converted samples</h4>

\n", @@ -1140,7 +1230,7 @@ { "cell_type": "code", "execution_count": null, - "id": "eb49e17c", + "id": "84c56d18", "metadata": { "lines_to_next_cell": 2, "tags": [] @@ -1168,7 +1258,7 @@ }, { "cell_type": "markdown", - "id": "df1543ab", + "id": "8737c833", "metadata": { "tags": [] }, @@ -1179,7 +1269,7 @@ { "cell_type": "code", "execution_count": null, - "id": "31a46e04", + "id": "ee8f6090", "metadata": { "tags": [] }, @@ -1191,7 +1281,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e54ae384", + "id": "b33a0107", "metadata": { "tags": [] }, @@ -1212,7 +1302,7 @@ }, { "cell_type": "markdown", - "id": "ccbc04c1", + "id": "2edae8d4", "metadata": { "tags": [] }, @@ -1228,7 +1318,7 @@ { "cell_type": "code", "execution_count": null, - "id": "53050f11", + "id": "79d46ed5", "metadata": { "tags": [] }, @@ -1241,7 +1331,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c71cb0f8", + "id": "0ec9b3cf", "metadata": { "tags": [] }, @@ -1272,7 +1362,7 @@ { "cell_type": "code", "execution_count": null, - "id": "76caab37", + "id": "c387ba61", "metadata": {}, "outputs": [], "source": [] @@ -1280,7 +1370,7 @@ { "cell_type": "code", "execution_count": null, - "id": "35991baf", + "id": "8b9e843e", "metadata": { "tags": [] }, @@ -1361,7 +1451,7 @@ }, { "cell_type": "markdown", - "id": "a270e2d8", + "id": "837d2a6a", "metadata": { "tags": [] }, @@ -1377,7 +1467,7 @@ { "cell_type": "code", "execution_count": null, - "id": "34e7801c", + "id": "01f878a8", "metadata": { "tags": [] }, @@ -1388,7 +1478,7 @@ }, { "cell_type": "markdown", - "id": "e4009e6f", + "id": "28aceac4", "metadata": {}, "source": [ "HELP!!! Interactive (still!) doesn't work. No worries... uncomment the following cell and choose your index and threshold by typing them out." @@ -1397,7 +1487,7 @@ { "cell_type": "code", "execution_count": null, - "id": "7adaa4d4", + "id": "2ae84d44", "metadata": { "tags": [] }, @@ -1413,7 +1503,7 @@ }, { "cell_type": "markdown", - "id": "33544547", + "id": "8ff5ceb0", "metadata": { "tags": [] }, @@ -1431,7 +1521,7 @@ }, { "cell_type": "markdown", - "id": "4ed9c11a", + "id": "ca976c6b", "metadata": { "tags": [] }, @@ -1444,7 +1534,7 @@ }, { "cell_type": "markdown", - "id": "7a1577b8", + "id": "bd96b144", "metadata": { "tags": [] }, diff --git a/solution.ipynb b/solution.ipynb index ee650be..231e6f7 100644 --- a/solution.ipynb +++ b/solution.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "e998cbda", + "id": "b3ddb066", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -29,7 +29,7 @@ }, { "cell_type": "markdown", - "id": "f3b46176", + "id": "7f43c1e3", "metadata": { "lines_to_next_cell": 0 }, @@ -41,7 +41,7 @@ }, { "cell_type": "markdown", - "id": "b0ad2695", + "id": "a9aaf840", "metadata": {}, "source": [ "\n", @@ -54,7 +54,7 @@ { "cell_type": "code", "execution_count": null, - "id": "774d942d", + "id": "3f6c5bc0", "metadata": { "lines_to_next_cell": 0 }, @@ -68,7 +68,7 @@ }, { "cell_type": "markdown", - "id": "32c74ae3", + "id": "5dd19fe5", "metadata": { "lines_to_next_cell": 0 }, @@ -84,7 +84,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8e2bfb78", + "id": "8c709838", "metadata": {}, "outputs": [], "source": [ @@ -102,7 +102,7 @@ }, { "cell_type": "markdown", - "id": "2e368025", + "id": "6b04f969", "metadata": { "lines_to_next_cell": 0 }, @@ -113,7 +113,7 @@ }, { "cell_type": "markdown", - "id": "b4ba9ba1", + "id": "27ea9906", "metadata": { "lines_to_next_cell": 0 }, @@ -130,7 +130,7 @@ { "cell_type": "code", "execution_count": null, - "id": 
"bcfac6b2", + "id": "6457422b", "metadata": { "tags": [ "solution" @@ -154,7 +154,7 @@ }, { "cell_type": "markdown", - "id": "358f92e4", + "id": "2ad014ac", "metadata": {}, "source": [ "# Part 2: Using Integrated Gradients to find what the classifier knows\n", @@ -164,7 +164,7 @@ }, { "cell_type": "markdown", - "id": "23375b54", + "id": "e40eeba7", "metadata": {}, "source": [ "## Attributions through integrated gradients\n", @@ -177,7 +177,7 @@ { "cell_type": "code", "execution_count": null, - "id": "0bc95c12", + "id": "e7aca710", "metadata": { "tags": [] }, @@ -193,7 +193,7 @@ }, { "cell_type": "markdown", - "id": "ce061847", + "id": "44d286aa", "metadata": { "tags": [] }, @@ -209,7 +209,7 @@ { "cell_type": "code", "execution_count": null, - "id": "56f04f69", + "id": "55d4cbcc", "metadata": { "tags": [ "solution" @@ -233,7 +233,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c3b8fada", + "id": "8ea56240", "metadata": { "tags": [] }, @@ -246,7 +246,7 @@ }, { "cell_type": "markdown", - "id": "1749ba9c", + "id": "bc50850e", "metadata": { "lines_to_next_cell": 2, "tags": [] @@ -258,7 +258,7 @@ { "cell_type": "code", "execution_count": null, - "id": "b11c4963", + "id": "e7447933", "metadata": { "tags": [] }, @@ -286,7 +286,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e4a2e4ba", + "id": "5cb527bf", "metadata": { "tags": [] }, @@ -298,7 +298,7 @@ }, { "cell_type": "markdown", - "id": "35dbc255", + "id": "25ecec3e", "metadata": { "lines_to_next_cell": 2 }, @@ -312,7 +312,7 @@ }, { "cell_type": "markdown", - "id": "cb45a3b7", + "id": "2b43f05d", "metadata": { "lines_to_next_cell": 0 }, @@ -325,7 +325,7 @@ { "cell_type": "code", "execution_count": null, - "id": "9fe579e8", + "id": "a7b21894", "metadata": {}, "outputs": [], "source": [ @@ -349,7 +349,7 @@ }, { "cell_type": "markdown", - "id": "6db49a33", + "id": "f97eace2", "metadata": { "lines_to_next_cell": 0 }, @@ -363,7 +363,7 @@ }, { "cell_type": "markdown", - "id": "68c48063", + "id": "ed5b7d6e", "metadata": {}, "source": [ "\n", @@ -389,7 +389,7 @@ }, { "cell_type": "markdown", - "id": "b4f45692", + "id": "bf13ae8d", "metadata": {}, "source": [ "

Task 2.3: Use random noise as a baseline

\n", @@ -401,7 +401,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c11ff6ef", + "id": "6e85e3e4", "metadata": { "tags": [ "solution" @@ -426,7 +426,7 @@ }, { "cell_type": "markdown", - "id": "24db5ea4", + "id": "f24c00a3", "metadata": { "tags": [] }, @@ -440,7 +440,7 @@ { "cell_type": "code", "execution_count": null, - "id": "428f4870", + "id": "12743143", "metadata": { "tags": [ "solution" @@ -467,7 +467,7 @@ }, { "cell_type": "markdown", - "id": "341fe9b8", + "id": "10a6cfcc", "metadata": { "tags": [] }, @@ -483,7 +483,7 @@ }, { "cell_type": "markdown", - "id": "0b0e6145", + "id": "25f3d08e", "metadata": {}, "source": [ "

BONUS Task: Using different attributions.

\n", @@ -497,7 +497,7 @@ }, { "cell_type": "markdown", - "id": "0f67562c", + "id": "65d946a8", "metadata": {}, "source": [ "

Checkpoint 2

\n", @@ -517,7 +517,7 @@ }, { "cell_type": "markdown", - "id": "003fed33", + "id": "04602cf9", "metadata": { "lines_to_next_cell": 0 }, @@ -545,22 +545,22 @@ }, { "cell_type": "markdown", - "id": "1c99e326", + "id": "ed173d7c", "metadata": { "lines_to_next_cell": 0, "tags": [] }, "source": [ "### The model\n", - "![cycle.png](assets/cyclegan.png)\n", + "![stargan.png](assets/stargan.png)\n", "\n", "In the following, we create a [StarGAN model](https://arxiv.org/abs/1711.09020).\n", "It is a Generative Adversarial model that is trained to turn one class of images X into a different class of images Y.\n", "\n", - "The model is made up of three networks:\n", + "We will not be using the random latent code (green, in the figure), so the model we use is made up of three networks:\n", "- The generator - this will be the bulk of the model, and will be responsible for transforming the images: we're going to use a `UNet`\n", "- The discriminator - this will be responsible for telling the difference between real and fake images: we're going to use a `DenseModel`\n", - "- The style mapping - this will be responsible for encoding the style of the image: we're going to use a `DenseModel`\n", + "- The style encoder - this will be responsible for encoding the style of the image: we're going to use a `DenseModel`\n", "\n", "Let's start by creating these!" ] @@ -568,10 +568,8 @@ { "cell_type": "code", "execution_count": null, - "id": "3fa2a39a", - "metadata": { - "lines_to_next_cell": 1 - }, + "id": "03a51bad", + "metadata": {}, "outputs": [], "source": [ "from dlmbl_unet import UNet\n", @@ -579,10 +577,11 @@ "\n", "\n", "class Generator(nn.Module):\n", - " def __init__(self, generator, style_mapping):\n", + "\n", + " def __init__(self, generator, style_encoder):\n", " super().__init__()\n", " self.generator = generator\n", - " self.style_mapping = style_mapping\n", + " self.style_encoder = style_encoder\n", "\n", " def forward(self, x, y):\n", " \"\"\"\n", @@ -591,7 +590,7 @@ " y: torch.Tensor\n", " The style image\n", " \"\"\"\n", - " style = self.style_mapping(y)\n", + " style = self.style_encoder(y)\n", " # Concatenate the style vector with the input image\n", " style = style.unsqueeze(-1).unsqueeze(-1)\n", " style = style.expand(-1, -1, x.size(2), x.size(3))\n", @@ -601,7 +600,7 @@ }, { "cell_type": "markdown", - "id": "11c69ace", + "id": "e6c6168d", "metadata": { "lines_to_next_cell": 0 }, @@ -616,12 +615,14 @@ { "cell_type": "code", "execution_count": null, - "id": "734e1e36", + "id": "9d0ef49f", "metadata": { "lines_to_next_cell": 0 }, "outputs": [], "source": [ + "style_size = ... # TODO choose a size for the style space\n", + "unet_depth = ... # TODO Choose a depth for the UNet\n", "style_mapping = DenseModel(\n", " input_shape=..., num_classes=... # How big is the style space?\n", ")\n", @@ -633,7 +634,7 @@ { "cell_type": "code", "execution_count": null, - "id": "347455b7", + "id": "ff22f753", "metadata": { "tags": [ "solution" @@ -641,15 +642,31 @@ }, "outputs": [], "source": [ - "# Here is an example of a working exercise\n", - "style_mapping = DenseModel(input_shape=(3, 28, 28), num_classes=3)\n", + "# Here is an example of a working setup! 
Note that you can change the hyperparameters as you experiment.\n", + "# Choose your own setup to see what works for you.\n", + "style_encoder = DenseModel(input_shape=(3, 28, 28), num_classes=3)\n", "unet = UNet(depth=2, in_channels=6, out_channels=3, final_activation=nn.Sigmoid())\n", - "generator = Generator(unet, style_mapping=style_mapping)" + "generator = Generator(unet, style_encoder=style_encoder)" + ] + }, + { + "cell_type": "markdown", + "id": "bd761ef3", + "metadata": { + "tags": [] + }, + "source": [ + "

Hyper-parameter choices

\n", + "
    \n", + "
  • Are any of the hyperparameters you choose above constrained in some way?
  • \n", + "
  • What would happen if you chose a depth of 10 for the UNet?
  • \n", + "
  • Is there a minimum size for the style space? Why or why not?
  • \n", + "
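One way to reason about the depth question: each UNet level halves the spatial resolution, so the input size caps the usable depth. A quick check for the 28x28 images used here (assuming halving per level, which you should verify against the `dlmbl_unet` implementation):

```python
size = 28
depth = 0
while size >= 2:  # need at least a 2x2 map to downsample again
    depth += 1
    size //= 2
print(depth)  # prints 4: a depth of 10 would shrink 28x28 past a single pixel
```

By a similar argument, the style space has a soft lower bound: it must be expressive enough to separate the classes the generator is asked to produce.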
" ] }, { "cell_type": "markdown", - "id": "74b2fe60", + "id": "d1220bb6", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -666,7 +683,7 @@ { "cell_type": "code", "execution_count": null, - "id": "4416d6eb", + "id": "71482197", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -679,7 +696,7 @@ { "cell_type": "code", "execution_count": null, - "id": "0a3291bf", + "id": "7ef652d9", "metadata": { "lines_to_next_cell": 0, "tags": [ @@ -693,7 +710,7 @@ }, { "cell_type": "markdown", - "id": "b20d0919", + "id": "709affba", "metadata": { "lines_to_next_cell": 0 }, @@ -704,7 +721,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6bc98d13", + "id": "7059545e", "metadata": {}, "outputs": [], "source": [ @@ -714,24 +731,89 @@ }, { "cell_type": "markdown", - "id": "2cc4a339", + "id": "b1a7581c", "metadata": { + "lines_to_next_cell": 0, "tags": [] }, "source": [ "## Training a GAN\n", "\n", - "Yes, really!\n", + "Training an adversarial network is a bit more complicated than training a classifier.\n", + "For starters, we are simultaneously training two different networks that work against each other.\n", + "As such, we need to be careful about how and when we update the weights of each network.\n", "\n", - "TODO about the losses:\n", - "- An adversarial loss\n", - "- A cycle loss\n", - "TODO add exercise!" + "We will have two different optimizers, one for the Generator and one for the Discriminator.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7805887e", + "metadata": { + "lines_to_next_cell": 0 + }, + "outputs": [], + "source": [ + "optimizer_d = torch.optim.Adam(discriminator.parameters(), lr=1e-4)\n", + "optimizer_g = torch.optim.Adam(generator.parameters(), lr=1e-4)" ] }, { "cell_type": "markdown", - "id": "87761838", + "id": "1bad28d8", + "metadata": { + "lines_to_next_cell": 0, + "tags": [] + }, + "source": [ + "\n", + "There are also two different types of losses that we will need.\n", + "**Adversarial loss**\n", + "This loss describes how well the discriminator can tell the difference between real and generated images.\n", + "In our case, this will be a sort of classification loss - we will use Cross Entropy.\n", + "
\n", + "The adversarial loss will be applied differently to the generator and the discriminator! Be very careful!\n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a757512e", + "metadata": {}, + "outputs": [], + "source": [ + "adverial_loss_fn = nn.CrossEntropyLoss()" + ] + }, + { + "cell_type": "markdown", + "id": "5c590737", + "metadata": { + "tags": [] + }, + "source": [ + "\n", + "**Cycle/reconstruction loss**\n", + "The cycle loss is there to make sure that the generator doesn't output an image that looks nothing like the input!\n", + "Indeed, by training the generator to be able to cycle back to the original image, we are making sure that it makes a minimum number of changes.\n", + "The cycle loss is applied only to the generator.\n", + "\n", + "cycle_loss_fn = nn.L1Loss()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0def44d4", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "id": "3a0c1d2e", "metadata": { "lines_to_next_cell": 2, "tags": [] @@ -746,16 +828,25 @@ }, { "cell_type": "markdown", - "id": "bcc737d6", + "id": "9f577571", "metadata": { + "lines_to_next_cell": 0, "tags": [] }, "source": [ "...this time again.\n", "\n", - "\"drawing\"\n", - "\n", - "TODO also turn this into a standalong script for use during the project phase\n", + "\"drawing\"\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d3077e49", + "metadata": {}, + "outputs": [], + "source": [ + "# TODO also turn this into a standalone script for use during the project phase\n", "from torch.utils.data import DataLoader\n", "from tqdm import tqdm\n", "\n", @@ -832,7 +923,7 @@ { "cell_type": "code", "execution_count": null, - "id": "86957c62", + "id": "b232bd07", "metadata": { "lines_to_next_cell": 0 }, @@ -847,7 +938,7 @@ }, { "cell_type": "markdown", - "id": "efd44cf5", + "id": "16de7380", "metadata": { "tags": [] }, @@ -858,7 +949,7 @@ { "cell_type": "code", "execution_count": null, - "id": "22c3f513", + "id": "856af9da", "metadata": {}, "outputs": [], "source": [ @@ -878,7 +969,7 @@ }, { "cell_type": "markdown", - "id": "87b45015", + "id": "f7240ca5", "metadata": { "tags": [] }, @@ -894,7 +985,7 @@ }, { "cell_type": "markdown", - "id": "d4e7a929", + "id": "67168867", "metadata": { "tags": [] }, @@ -904,7 +995,7 @@ }, { "cell_type": "markdown", - "id": "7d02cc75", + "id": "c6bdbfde", "metadata": { "tags": [] }, @@ -920,7 +1011,7 @@ { "cell_type": "code", "execution_count": null, - "id": "a539070f", + "id": "a8543304", "metadata": { "tags": [] }, @@ -934,7 +1025,7 @@ }, { "cell_type": "markdown", - "id": "d1b2507b", + "id": "940b48d6", "metadata": { "tags": [] }, @@ -945,7 +1036,7 @@ { "cell_type": "code", "execution_count": null, - "id": "b2ab6b33", + "id": "8b9425d2", "metadata": { "tags": [] }, @@ -956,7 +1047,7 @@ }, { "cell_type": "markdown", - "id": "7de66a63", + "id": "42f81f13", "metadata": { "tags": [] }, @@ -967,7 +1058,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6fcc912a", + "id": "33fbfc83", "metadata": { "tags": [] }, @@ -978,7 +1069,7 @@ }, { "cell_type": "markdown", - "id": "929e292b", + "id": "00ded88d", "metadata": { "tags": [] }, @@ -991,7 +1082,7 @@ }, { "cell_type": "markdown", - "id": "7abe7429", + "id": "f7475dc3", "metadata": { "tags": [] }, @@ -1012,7 +1103,7 @@ }, { "cell_type": "markdown", - "id": "55bb626d", + "id": "97a88ddb", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1037,7 +1128,7 @@ { "cell_type": "code", "execution_count": null, - "id": "67390c1b", + "id": "2f82fa67", "metadata": {}, "outputs": [], "source": [ @@ -1048,7 +1139,7 @@ { "cell_type": 
"code", "execution_count": null, - "id": "2930d6cd", + "id": "b93db0b2", "metadata": { "lines_to_next_cell": 0, "title": "[markwodn]" @@ -1061,7 +1152,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c0ae9923", + "id": "5c7ccc7b", "metadata": {}, "outputs": [], "source": [ @@ -1074,7 +1165,7 @@ }, { "cell_type": "markdown", - "id": "5d2739a2", + "id": "d47955f7", "metadata": { "tags": [] }, @@ -1085,7 +1176,7 @@ { "cell_type": "code", "execution_count": null, - "id": "933e724b", + "id": "94284732", "metadata": { "lines_to_next_cell": 0 }, @@ -1100,7 +1191,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8367d7e7", + "id": "cb6f9edc", "metadata": {}, "outputs": [], "source": [ @@ -1112,7 +1203,7 @@ }, { "cell_type": "markdown", - "id": "28279f41", + "id": "8aba5707", "metadata": {}, "source": [ "
\n", @@ -1127,7 +1218,7 @@ }, { "cell_type": "markdown", - "id": "db7e8748", + "id": "b9713122", "metadata": {}, "source": [ "

Checkpoint 4

\n", @@ -1140,7 +1231,7 @@ }, { "cell_type": "markdown", - "id": "ca69811f", + "id": "183344be", "metadata": {}, "source": [ "# Part 5: Highlighting Class-Relevant Differences" @@ -1148,7 +1239,7 @@ }, { "cell_type": "markdown", - "id": "3a84225c", + "id": "83417bff", "metadata": {}, "source": [ "At this point we have:\n", @@ -1163,7 +1254,7 @@ }, { "cell_type": "markdown", - "id": "fd9cd294", + "id": "737ae577", "metadata": {}, "source": [ "

Task 5.1: Get successfully converted samples</h4>

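One possible shape for this selection step, as a hedged sketch (`model`, `x`, `x_fake`, and `y_target` are placeholders for whatever the cells above produce):

```python
import torch

# Keep only the images whose translation actually fooled the classifier.
with torch.no_grad():
    logits_fake = model(x_fake)  # the classifier trained in Part 1
converted = logits_fake.argmax(dim=1) == y_target
x_fake_successful = x_fake[converted]
x_real_successful = x[converted]
```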
\n", @@ -1184,7 +1275,7 @@ { "cell_type": "code", "execution_count": null, - "id": "eb49e17c", + "id": "84c56d18", "metadata": { "lines_to_next_cell": 2, "tags": [] @@ -1213,7 +1304,7 @@ { "cell_type": "code", "execution_count": null, - "id": "7659deb0", + "id": "37413116", "metadata": { "lines_to_next_cell": 2, "tags": [ @@ -1245,7 +1336,7 @@ }, { "cell_type": "markdown", - "id": "df1543ab", + "id": "8737c833", "metadata": { "tags": [] }, @@ -1256,7 +1347,7 @@ { "cell_type": "code", "execution_count": null, - "id": "31a46e04", + "id": "ee8f6090", "metadata": { "tags": [] }, @@ -1268,7 +1359,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e54ae384", + "id": "b33a0107", "metadata": { "tags": [] }, @@ -1289,7 +1380,7 @@ }, { "cell_type": "markdown", - "id": "ccbc04c1", + "id": "2edae8d4", "metadata": { "tags": [] }, @@ -1305,7 +1396,7 @@ { "cell_type": "code", "execution_count": null, - "id": "53050f11", + "id": "79d46ed5", "metadata": { "tags": [] }, @@ -1318,7 +1409,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c71cb0f8", + "id": "0ec9b3cf", "metadata": { "tags": [] }, @@ -1349,7 +1440,7 @@ { "cell_type": "code", "execution_count": null, - "id": "76caab37", + "id": "c387ba61", "metadata": {}, "outputs": [], "source": [] @@ -1357,7 +1448,7 @@ { "cell_type": "code", "execution_count": null, - "id": "35991baf", + "id": "8b9e843e", "metadata": { "tags": [] }, @@ -1438,7 +1529,7 @@ }, { "cell_type": "markdown", - "id": "a270e2d8", + "id": "837d2a6a", "metadata": { "tags": [] }, @@ -1454,7 +1545,7 @@ { "cell_type": "code", "execution_count": null, - "id": "34e7801c", + "id": "01f878a8", "metadata": { "tags": [] }, @@ -1465,7 +1556,7 @@ }, { "cell_type": "markdown", - "id": "e4009e6f", + "id": "28aceac4", "metadata": {}, "source": [ "HELP!!! Interactive (still!) doesn't work. No worries... uncomment the following cell and choose your index and threshold by typing them out." @@ -1474,7 +1565,7 @@ { "cell_type": "code", "execution_count": null, - "id": "7adaa4d4", + "id": "2ae84d44", "metadata": { "tags": [] }, @@ -1490,7 +1581,7 @@ }, { "cell_type": "markdown", - "id": "33544547", + "id": "8ff5ceb0", "metadata": { "tags": [] }, @@ -1508,7 +1599,7 @@ }, { "cell_type": "markdown", - "id": "4ed9c11a", + "id": "ca976c6b", "metadata": { "tags": [] }, @@ -1521,7 +1612,7 @@ }, { "cell_type": "markdown", - "id": "7a1577b8", + "id": "bd96b144", "metadata": { "tags": [] },