From 03e6aaa35b0362c6fb0f72c7aa92abbbc97b44ce Mon Sep 17 00:00:00 2001 From: adjavon Date: Mon, 12 Aug 2024 20:02:07 +0000 Subject: [PATCH] Commit from GitHub Actions (Build Notebooks) --- exercise.ipynb | 591 +++++++++++++++++++++++++---------------------- solution.ipynb | 605 ++++++++++++++++++++++++++----------------------- 2 files changed, 641 insertions(+), 555 deletions(-) diff --git a/exercise.ipynb b/exercise.ipynb index 41bfda6..4998440 100644 --- a/exercise.ipynb +++ b/exercise.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "b3ddb066", + "id": "eab4778f", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -29,7 +29,7 @@ }, { "cell_type": "markdown", - "id": "7f43c1e3", + "id": "c62087c9", "metadata": { "lines_to_next_cell": 0 }, @@ -41,7 +41,7 @@ }, { "cell_type": "markdown", - "id": "a9aaf840", + "id": "43cb388c", "metadata": {}, "source": [ "\n", @@ -54,7 +54,7 @@ { "cell_type": "code", "execution_count": null, - "id": "3f6c5bc0", + "id": "37c4f359", "metadata": { "lines_to_next_cell": 0 }, @@ -68,7 +68,7 @@ }, { "cell_type": "markdown", - "id": "5dd19fe5", + "id": "f4f0b771", "metadata": { "lines_to_next_cell": 0 }, @@ -84,7 +84,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8c709838", + "id": "2748b7dc", "metadata": {}, "outputs": [], "source": [ @@ -102,7 +102,7 @@ }, { "cell_type": "markdown", - "id": "6b04f969", + "id": "3d712049", "metadata": { "lines_to_next_cell": 0 }, @@ -113,7 +113,7 @@ }, { "cell_type": "markdown", - "id": "27ea9906", + "id": "21a9fe70", "metadata": { "lines_to_next_cell": 0 }, @@ -130,7 +130,7 @@ { "cell_type": "code", "execution_count": null, - "id": "3bf64cda", + "id": "2e7a7de0", "metadata": { "lines_to_next_cell": 0, "tags": [ @@ -155,7 +155,44 @@ }, { "cell_type": "markdown", - "id": "2ad014ac", + "id": "cecfa46d", + "metadata": { + "lines_to_next_cell": 0 + }, + "source": [ + "Don't take my word for it! Let's see how well the classifier does on the test set." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b93253d6", + "metadata": { + "lines_to_next_cell": 2 + }, + "outputs": [], + "source": [ + "from torch.utils.data import DataLoader\n", + "from sklearn.metrics import confusion_matrix\n", + "import seaborn as sns\n", + "\n", + "test_mnist = ColoredMNIST(\"data\", download=True, train=False)\n", + "dataloader = DataLoader(test_mnist, batch_size=32, shuffle=False)\n", + "\n", + "labels = []\n", + "predictions = []\n", + "for x, y in dataloader:\n", + " pred = model(x.to(device))\n", + " labels.extend(y.cpu().numpy())\n", + " predictions.extend(pred.argmax(dim=1).cpu().numpy())\n", + "\n", + "cm = confusion_matrix(labels, predictions, normalize=\"true\")\n", + "sns.heatmap(cm, annot=True, fmt=\".2f\")" + ] + }, + { + "cell_type": "markdown", + "id": "426d8618", "metadata": {}, "source": [ "# Part 2: Using Integrated Gradients to find what the classifier knows\n", @@ -165,7 +202,7 @@ }, { "cell_type": "markdown", - "id": "e40eeba7", + "id": "dc39b0d7", "metadata": {}, "source": [ "## Attributions through integrated gradients\n", @@ -178,7 +215,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e7aca710", + "id": "39661efa", "metadata": { "tags": [] }, @@ -194,7 +231,7 @@ }, { "cell_type": "markdown", - "id": "44d286aa", + "id": "ec39c8fe", "metadata": { "tags": [] }, @@ -210,7 +247,7 @@ { "cell_type": "code", "execution_count": null, - "id": "ec387f82", + "id": "f884ed8b", "metadata": { "tags": [ "task" @@ -231,7 +268,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8ea56240", + "id": "48d39aca", "metadata": { "tags": [] }, @@ -244,7 +281,7 @@ }, { "cell_type": "markdown", - "id": "bc50850e", + "id": "7ceb951f", "metadata": { "lines_to_next_cell": 2, "tags": [] @@ -256,7 +293,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e7447933", + "id": "5deccc78", "metadata": { "tags": [] }, @@ -284,7 +321,7 @@ { "cell_type": "code", "execution_count": null, - "id": "5cb527bf", + "id": "59d12539", "metadata": { "tags": [] }, @@ -296,7 +333,7 @@ }, { "cell_type": "markdown", - "id": "25ecec3e", + "id": "88ad18f6", "metadata": { "lines_to_next_cell": 2 }, @@ -310,7 +347,7 @@ }, { "cell_type": "markdown", - "id": "2b43f05d", + "id": "631be1d6", "metadata": { "lines_to_next_cell": 0 }, @@ -323,7 +360,7 @@ { "cell_type": "code", "execution_count": null, - "id": "a7b21894", + "id": "13ffacb0", "metadata": {}, "outputs": [], "source": [ @@ -347,7 +384,7 @@ }, { "cell_type": "markdown", - "id": "f97eace2", + "id": "db5e1b05", "metadata": { "lines_to_next_cell": 0 }, @@ -361,7 +398,7 @@ }, { "cell_type": "markdown", - "id": "ed5b7d6e", + "id": "bbd4268a", "metadata": {}, "source": [ "\n", @@ -387,7 +424,7 @@ }, { "cell_type": "markdown", - "id": "bf13ae8d", + "id": "d382b20b", "metadata": {}, "source": [ "

Task 2.3: Use random noise as a baseline

\n", @@ -399,7 +436,7 @@ { "cell_type": "code", "execution_count": null, - "id": "60d6691c", + "id": "660863df", "metadata": { "tags": [ "task" @@ -419,7 +456,7 @@ }, { "cell_type": "markdown", - "id": "f24c00a3", + "id": "c1eb0219", "metadata": { "tags": [] }, @@ -433,7 +470,7 @@ { "cell_type": "code", "execution_count": null, - "id": "3835de1a", + "id": "c56b4eb8", "metadata": { "tags": [ "task" @@ -455,7 +492,7 @@ }, { "cell_type": "markdown", - "id": "10a6cfcc", + "id": "1176883b", "metadata": { "tags": [] }, @@ -471,7 +508,7 @@ }, { "cell_type": "markdown", - "id": "25f3d08e", + "id": "30b0ecb9", "metadata": {}, "source": [ "

BONUS Task: Using different attributions.

\n", @@ -485,7 +522,7 @@ }, { "cell_type": "markdown", - "id": "65d946a8", + "id": "accb5960", "metadata": {}, "source": [ "

Checkpoint 2

\n", @@ -505,7 +542,7 @@ }, { "cell_type": "markdown", - "id": "04602cf9", + "id": "aa54fc73", "metadata": { "lines_to_next_cell": 0 }, @@ -533,7 +570,7 @@ }, { "cell_type": "markdown", - "id": "ed173d7c", + "id": "b72ac61f", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -556,7 +593,7 @@ { "cell_type": "code", "execution_count": null, - "id": "03a51bad", + "id": "0cf84860", "metadata": {}, "outputs": [], "source": [ @@ -588,7 +625,7 @@ }, { "cell_type": "markdown", - "id": "e6c6168d", + "id": "b7126106", "metadata": { "lines_to_next_cell": 0 }, @@ -603,9 +640,12 @@ { "cell_type": "code", "execution_count": null, - "id": "9d0ef49f", + "id": "75766e24", "metadata": { - "lines_to_next_cell": 0 + "lines_to_next_cell": 0, + "tags": [ + "task" + ] }, "outputs": [], "source": [ @@ -621,7 +661,7 @@ }, { "cell_type": "markdown", - "id": "bd761ef3", + "id": "c0b9a3b5", "metadata": { "tags": [] }, @@ -636,7 +676,7 @@ }, { "cell_type": "markdown", - "id": "d1220bb6", + "id": "d2d19ccb", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -653,10 +693,12 @@ { "cell_type": "code", "execution_count": null, - "id": "71482197", + "id": "379a1c73", "metadata": { "lines_to_next_cell": 0, - "tags": [] + "tags": [ + "task" + ] }, "outputs": [], "source": [ @@ -665,7 +707,7 @@ }, { "cell_type": "markdown", - "id": "709affba", + "id": "c2761ac5", "metadata": { "lines_to_next_cell": 0 }, @@ -676,7 +718,7 @@ { "cell_type": "code", "execution_count": null, - "id": "7059545e", + "id": "df419c3c", "metadata": {}, "outputs": [], "source": [ @@ -686,7 +728,7 @@ }, { "cell_type": "markdown", - "id": "b1a7581c", + "id": "9b4e8069", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -704,7 +746,7 @@ { "cell_type": "code", "execution_count": null, - "id": "7805887e", + "id": "07fb5440", "metadata": { "lines_to_next_cell": 0 }, @@ -716,7 +758,7 @@ }, { "cell_type": "markdown", - "id": "1bad28d8", + "id": "4f4f88ce", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -735,17 +777,18 @@ { "cell_type": "code", "execution_count": null, - "id": "a757512e", + "id": "eae1b681", "metadata": {}, "outputs": [], "source": [ - "adverial_loss_fn = nn.CrossEntropyLoss()" + "adversarial_loss_fn = nn.CrossEntropyLoss()" ] }, { "cell_type": "markdown", - "id": "5c590737", + "id": "d45aa99e", "metadata": { + "lines_to_next_cell": 0, "tags": [] }, "source": [ @@ -753,77 +796,105 @@ "**Cycle/reconstruction loss**\n", "The cycle loss is there to make sure that the generator doesn't output an image that looks nothing like the input!\n", "Indeed, by training the generator to be able to cycle back to the original image, we are making sure that it makes a minimum number of changes.\n", - "The cycle loss is applied only to the generator.\n", - "\n", - "cycle_loss_fn = nn.L1Loss()" + "The cycle loss is applied only to the generator.\n" ] }, { "cell_type": "code", "execution_count": null, - "id": "0def44d4", + "id": "c20c35b7", "metadata": {}, "outputs": [], - "source": [] + "source": [ + "cycle_loss_fn = nn.L1Loss()" + ] }, { "cell_type": "markdown", - "id": "3a0c1d2e", + "id": "6d10813e", "metadata": { - "lines_to_next_cell": 2, "tags": [] }, "source": [ - "

Task 3.2: Training!

\n", - "Let's train the CycleGAN one batch a time, plotting the output every so often to see how it is getting on.\n", + "Stuff about the dataloader" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0337c819", + "metadata": {}, + "outputs": [], + "source": [ + "from torch.utils.data import DataLoader\n", "\n", - "While you watch the model train, consider whether you think it will be successful at generating counterfactuals in the number of steps we give it. What is the minimum number of iterations you think are needed for this to work, and how much time do yo uthink it will take?\n", - "
" + "dataloader = DataLoader(\n", + " mnist, batch_size=32, drop_last=True, shuffle=True\n", + ") # We will use the same dataset as before" ] }, { "cell_type": "markdown", - "id": "9f577571", + "id": "feb14b16", "metadata": { - "lines_to_next_cell": 0, + "lines_to_next_cell": 2, "tags": [] }, "source": [ - "...this time again.\n", - "\n", - "\"drawing\"\n" + "TODO - Describe set_requires_grad" ] }, { "cell_type": "code", "execution_count": null, - "id": "d3077e49", + "id": "21f19dc7", "metadata": {}, "outputs": [], "source": [ - "# TODO also turn this into a standalone script for use during the project phase\n", - "from torch.utils.data import DataLoader\n", - "from tqdm import tqdm\n", - "\n", - "\n", "def set_requires_grad(module, value=True):\n", " \"\"\"Sets `requires_grad` on a `module`'s parameters to `value`\"\"\"\n", " for param in module.parameters():\n", - " param.requires_grad = value\n", - "\n", - "\n", - "cycle_loss_fn = nn.L1Loss()\n", - "class_loss_fn = nn.CrossEntropyLoss()\n", + " param.requires_grad = value" + ] + }, + { + "cell_type": "markdown", + "id": "58161b77", + "metadata": { + "lines_to_next_cell": 0, + "tags": [] + }, + "source": [ + "

Task 3.2: Training!

\n", "\n", - "optimizer_d = torch.optim.Adam(discriminator.parameters(), lr=1e-6)\n", - "optimizer_g = torch.optim.Adam(generator.parameters(), lr=1e-4)\n", + "TODO - the task is to choose where to apply set_requires_grad\n", + "
    \n", + "
  • Choose the values for `set_requires_grad`. Hint: which part of the code is training the generator? Which part is training the discriminator
  • \n", + "
  • Choose the values of `set_requires_grad`, again. Hint: you may want to switch
  • \n", + "
  • Choose the sign of the discriminator loss. Hint: what does the discriminator want to do?
  • \n", + "
\n", + "Let's train the StarGAN one batch a time.\n", + "While you watch the model train, consider whether you think it will be successful at generating counterfactuals in the number of steps we give it. What is the minimum number of iterations you think are needed for this to work, and how much time do yo uthink it will take?\n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cc4f6fbc", + "metadata": { + "tags": [ + "task" + ] + }, + "outputs": [], + "source": [ + "from tqdm import tqdm # This is a nice library for showing progress bars\n", "\n", - "dataloader = DataLoader(\n", - " mnist, batch_size=32, drop_last=True, shuffle=True\n", - ") # We will use the same dataset as before\n", "\n", "losses = {\"cycle\": [], \"adv\": [], \"disc\": []}\n", - "for epoch in range(50):\n", + "\n", + "for epoch in range(15):\n", " for x, y in tqdm(dataloader, desc=f\"Epoch {epoch}\"):\n", " x = x.to(device)\n", " y = y.to(device)\n", @@ -833,8 +904,18 @@ " x_style = x[random_index].clone()\n", " y_target = y[random_index].clone()\n", "\n", + " # TODO - Choose an option by commenting out what you don't want\n", + " ############\n", + " # Option 1 #\n", + " ############\n", " set_requires_grad(generator, True)\n", " set_requires_grad(discriminator, False)\n", + " ############\n", + " # Option 2 #\n", + " ############\n", + " set_requires_grad(generator, False)\n", + " set_requires_grad(discriminator, True)\n", + "\n", " optimizer_g.zero_grad()\n", " # Get the fake image\n", " x_fake = generator(x, x_style)\n", @@ -847,23 +928,43 @@ " # 1. make sure the image can be reconstructed\n", " cycle_loss = cycle_loss_fn(x, x_cycled)\n", " # 2. make sure the discriminator is fooled\n", - " adv_loss = class_loss_fn(discriminator_x_fake, y_target)\n", + " adv_loss = adversarial_loss_fn(discriminator_x_fake, y_target)\n", "\n", " # Optimize the generator\n", " (cycle_loss + adv_loss).backward()\n", " optimizer_g.step()\n", "\n", + " # TODO - Choose an option by commenting out what you don't want\n", + " ############\n", + " # Option 1 #\n", + " ############\n", + " set_requires_grad(generator, True)\n", + " set_requires_grad(discriminator, False)\n", + " ############\n", + " # Option 2 #\n", + " ############\n", " set_requires_grad(generator, False)\n", " set_requires_grad(discriminator, True)\n", + " #\n", " optimizer_d.zero_grad()\n", - " # TODO Do I need to re-do the forward pass?\n", + " #\n", " discriminator_x = discriminator(x)\n", " discriminator_x_fake = discriminator(x_fake.detach())\n", + "\n", + " # TODO - Choose an option by commenting out what you don't want\n", " # Losses to train the discriminator\n", " # 1. make sure the discriminator can tell real is real\n", - " real_loss = class_loss_fn(discriminator_x, y)\n", - " # 2. make sure the discriminator can't tell fake is fake\n", - " fake_loss = -class_loss_fn(discriminator_x_fake, y_target)\n", + " # 2. make sure the discriminator can tell fake is fake\n", + " ############\n", + " # Option 1 #\n", + " ############\n", + " real_loss = adversarial_loss_fn(discriminator_x, y)\n", + " fake_loss = -adversarial_loss_fn(discriminator_x_fake, y_target)\n", + " ############\n", + " # Option 2 #\n", + " ############\n", + " real_loss = adversarial_loss_fn(discriminator_x, y)\n", + " fake_loss = adversarial_loss_fn(discriminator_x_fake, y_target)\n", " #\n", " disc_loss = (real_loss + fake_loss) * 0.5\n", " disc_loss.backward()\n", @@ -876,12 +977,23 @@ ] }, { - "cell_type": "code", - "execution_count": null, - "id": "b232bd07", + "cell_type": "markdown", + "id": "99753362", "metadata": { - "lines_to_next_cell": 0 + "lines_to_next_cell": 0, + "tags": [] }, + "source": [ + "...this time again. 🚂 🚋 🚋 🚋\n", + "\n", + "Once training is complete, we can plot the losses to see how well the model is doing." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "99070716", + "metadata": {}, "outputs": [], "source": [ "plt.plot(losses[\"cycle\"], label=\"Cycle loss\")\n", @@ -893,18 +1005,19 @@ }, { "cell_type": "markdown", - "id": "16de7380", + "id": "ce337ff3", "metadata": { + "lines_to_next_cell": 0, "tags": [] }, "source": [ - "Let's add a quick plotting function before we begin training..." + "We can also look at some examples of the images that the generator is creating." ] }, { "cell_type": "code", "execution_count": null, - "id": "856af9da", + "id": "5d2443f5", "metadata": {}, "outputs": [], "source": [ @@ -917,30 +1030,38 @@ "\n", "for ax in axs:\n", " ax.axis(\"off\")\n", - "plt.show()\n", - "\n", - "# TODO WIP here" + "plt.show()" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "726f77db", + "metadata": { + "lines_to_next_cell": 0 + }, + "outputs": [], + "source": [] + }, { "cell_type": "markdown", - "id": "f7240ca5", + "id": "ed4e3ca8", "metadata": { "tags": [] }, "source": [ "

Checkpoint 3

\n", - "You've now learned the basics of what makes up a CycleGAN, and details on how to perform adversarial training.\n", - "The same method can be used to create a CycleGAN with different basic elements.\n", + "You've now learned the basics of what makes up a StarGAN, and details on how to perform adversarial training.\n", + "The same method can be used to create a StarGAN with different basic elements.\n", "For example, you can change the archictecture of the generators, or of the discriminator to better fit your data in the future.\n", "\n", - "You know the drill... let us know on the exercise chat!\n", + "You know the drill... let us know on the exercise chat when you have arrived here!\n", "
" ] }, { "cell_type": "markdown", - "id": "67168867", + "id": "f77b54db", "metadata": { "tags": [] }, @@ -950,243 +1071,181 @@ }, { "cell_type": "markdown", - "id": "c6bdbfde", + "id": "cd268191", "metadata": { "tags": [] }, "source": [ + "## Creating counterfactuals\n", "\n", - "## That was fun!... let's load a pre-trained model\n", + "The first thing that we want to do is make sure that our GAN is able to create counterfactual images.\n", + "To do this, we have to create them, and then pass them through the classifier to see if they are classified correctly.\n", "\n", - "Training the CycleGAN takes a lot longer than the few iterations that we did above. Since we don't have that kind of time, we are going to load a pre-trained model (for reference, this pre-trained model was trained for 7 days...).\n", - "\n", - "To continue, interrupt the kernel and continue with the next one, which will just use one of the pretrained CycleGAN models for the synapse dataset." + "First, let's get the test dataset, so we can evaluate the GAN on unseen data.\n", + "Then, let's get four prototypical images from the dataset as style sources." ] }, { "cell_type": "code", "execution_count": null, - "id": "a8543304", + "id": "3a4b48f7", "metadata": { - "tags": [] + "title": "Loading the test dataset" }, "outputs": [], "source": [ - "from pathlib import Path\n", - "import torch\n", + "test_mnist = ColoredMNIST(\"data\", download=True, train=False)\n", + "prototypes = {}\n", + "\n", "\n", - "# TODO load the pre-trained model" + "for i in range(4):\n", + " options = np.where(test_mnist.targets == i)[0]\n", + " # Note that you can change the image index if you want to use a different prototype.\n", + " image_index = 0\n", + " x, y = test_mnist[options[image_index]]\n", + " prototypes[i] = x" ] }, { "cell_type": "markdown", - "id": "940b48d6", + "id": "cf374cec", "metadata": { + "lines_to_next_cell": 0, "tags": [] }, "source": [ - "Let's look at some examples. Can you pick up on the differences between original, the counter-factual, and the reconstruction?" + "Let's have a look at the prototypes." ] }, { "cell_type": "code", "execution_count": null, - "id": "8b9425d2", - "metadata": { - "tags": [] - }, + "id": "55b9457b", + "metadata": {}, "outputs": [], "source": [ - "# TODO show some examples" + "fig, axs = plt.subplots(1, 4, figsize=(12, 4))\n", + "for i, ax in enumerate(axs):\n", + " ax.imshow(prototypes[i].permute(1, 2, 0))\n", + " ax.axis(\"off\")\n", + " ax.set_title(f\"Prototype {i}\")" ] }, { "cell_type": "markdown", - "id": "42f81f13", + "id": "8883baa5", "metadata": { - "tags": [] + "lines_to_next_cell": 0 }, "source": [ - "We're going to apply the GAN to our test dataset." + "Now we need to use these prototypes to create counterfactual images!\n", + "TODO make a task here!" 
] }, { "cell_type": "code", "execution_count": null, - "id": "33fbfc83", - "metadata": { - "tags": [] - }, + "id": "65460b37", + "metadata": {}, "outputs": [], "source": [ - "# TODO load the test dataset" - ] - }, - { - "cell_type": "markdown", - "id": "00ded88d", - "metadata": { - "tags": [] - }, - "source": [ - "## Evaluating the GAN\n", + "num_images = len(test_mnist)\n", + "counterfactuals = np.zeros((4, num_images, 3, 28, 28))\n", "\n", - "The first thing to find out is whether the CycleGAN is successfully converting the images from one neurotransmitter to another.\n", - "We will do this by running the classifier that we trained earlier on generated data.\n" - ] - }, - { - "cell_type": "markdown", - "id": "f7475dc3", - "metadata": { - "tags": [] - }, - "source": [ - "

Task 4.1 Get the classifier accuracy on CycleGAN outputs

\n", - "\n", - "Using the saved images, we're going to figure out how good our CycleGAN is at generating images of a new class!\n", + "predictions = []\n", + "source_labels = []\n", + "target_labels = []\n", "\n", - "The images (`real`, `reconstructed`, and `counterfactual`) are saved in the `test_images/` directory. Before you start the exercise, have a look at how this directory is organized.\n", + "for x, y in test_mnist:\n", + " for i in range(4):\n", + " if i == y:\n", + " # Store the image as is.\n", + " counterfactuals[i] = ...\n", + " # Create the counterfactual from the image and prototype\n", + " x_fake = generator(x.unsqueeze(0).to(device), ...)\n", + " counterfactuals[i] = x_fake.cpu().detach().numpy()\n", + " pred = model(...)\n", "\n", - "TODO\n", - "- Use the `make_dataset` function to create a dataset for the three different image types that we saved above\n", - " - real\n", - " - reconstructed\n", - " - counterfactual\n", - "
" + " source_labels.append(y)\n", + " target_labels.append(i)\n", + " predictions.append(pred.argmax().item())" ] }, { "cell_type": "markdown", - "id": "97a88ddb", + "id": "3b176c31", "metadata": { "lines_to_next_cell": 0, "tags": [] }, "source": [ - "
\n", - "We get the following accuracies:\n", - "\n", - "1. `accuracy_real`: Accuracy of the classifier on the real images, just for the two classes used in the GAN\n", - "2. `accuracy_recon`: Accuracy of the classifier on the reconstruction.\n", - "3. `accuracy_counter`: Accuracy of the classifier on the counterfactual images.\n", - "\n", - "

Questions

\n", - "\n", - "- In a perfect world, what value would we expect for `accuracy_recon`? What do we compare it to and why is it higher/lower?\n", - "- How well is it translating from one class to another? Do we expect `accuracy_counter` to be large or small? Do we want it to be large or small? Why?\n", - "\n", - "Let us know your insights on the exercise chat.\n", - "
" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "2f82fa67", - "metadata": {}, - "outputs": [], - "source": [ - "# TODO make a loop on the data that creates the counterfactual images, given a set of options as input\n", - "counterfactuals, reconstructions, targets, labels = ..." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "b93db0b2", - "metadata": { - "lines_to_next_cell": 0, - "title": "[markwodn]" - }, - "outputs": [], - "source": [ - "# Evaluate the images" + "Let's plot the confusion matrix for the counterfactual images." ] }, { "cell_type": "code", "execution_count": null, - "id": "5c7ccc7b", + "id": "a9709066", "metadata": {}, "outputs": [], "source": [ - "# TODO use the loaded classifier to evaluate the images\n", - "# Get the accuracies\n", - "def predict():\n", - " # TODO return predictions, labels\n", - " pass" + "cf_cm = confusion_matrix(target_labels, predictions, normalize=\"true\")\n", + "sns.heatmap(cf_cm, annot=True, fmt=\".2f\")" ] }, { "cell_type": "markdown", - "id": "d47955f7", + "id": "51805f97", "metadata": { "tags": [] }, "source": [ - "We're going to look at the confusion matrices for the counterfactuals, and compare it to that of the real images." + "

Questions

\n", + "
    \n", + "
  • How well is our GAN doing at creating counterfactual images?
  • \n", + "
  • Do you think that the prototypes used matter? Why or why not?
  • \n", + "
\n", + "
" ] }, { - "cell_type": "code", - "execution_count": null, - "id": "94284732", + "cell_type": "markdown", + "id": "e767437a", "metadata": { - "lines_to_next_cell": 0 + "tags": [] }, - "outputs": [], - "source": [ - "print(\"The confusion matrix on the real images... for comparison\")\n", - "# TODO Confusion matrix on the counterfactual images\n", - "confusion_matrix = ...\n", - "# TODO plot" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "cb6f9edc", - "metadata": {}, - "outputs": [], "source": [ - "print(\"The confusion matrix on the real images... for comparison\")\n", - "# TODO Confusion matrix on the real images, for comparison\n", - "confusion_matrix = ...\n", - "# TODO plot" - ] - }, - { - "cell_type": "markdown", - "id": "8aba5707", - "metadata": {}, - "source": [ - "
\n", - "

Questions

\n", - "\n", - "- What would you expect the confusion matrix for the counterfactuals to look like? Why?\n", - "- Do the two directions of the CycleGAN work equally as well?\n", - "- Can you think of anything that might have made it more difficult, or easier, to translate in a one direction vs the other?\n", + "Let's also plot some examples of the counterfactual images.\n", "\n", - "
" + "for i in np.random.choice(range(num_images), 4):\n", + " fig, axs = plt.subplots(1, 4, figsize=(20, 4))\n", + " for j, ax in enumerate(axs):\n", + " ax.imshow(counterfactuals[j][i].transpose(1, 2, 0))\n", + " ax.axis(\"off\")\n", + " ax.set_title(f\"Class {j}\")" ] }, { "cell_type": "markdown", - "id": "b9713122", - "metadata": {}, + "id": "545bc176", + "metadata": { + "lines_to_next_cell": 0, + "tags": [] + }, "source": [ - "

Checkpoint 4

\n", - " We have seen that our CycleGAN network has successfully translated some of the synapses from one class to the other, but there are clearly some things to look out for!\n", - "Take the time to think about the questions above before moving on...\n", + "

Questions

\n", + "
    \n", + "
  • Can you easily tell which of these images is the original, and which ones are the counterfactuals?
  • \n", + "
  • What is your hypothesis for the features that define each class?
  • \n", + "
\n", + "
\n", "\n", - "This is the end of Section 4. Let us know on the exercise chat if you have reached this point!\n", - "
" + "TODO wip here" ] }, { "cell_type": "markdown", - "id": "183344be", + "id": "069a2183", "metadata": {}, "source": [ "# Part 5: Highlighting Class-Relevant Differences" @@ -1194,7 +1253,7 @@ }, { "cell_type": "markdown", - "id": "83417bff", + "id": "7b2c0480", "metadata": {}, "source": [ "At this point we have:\n", @@ -1209,7 +1268,7 @@ }, { "cell_type": "markdown", - "id": "737ae577", + "id": "81f91fa8", "metadata": {}, "source": [ "

Task 5.1 Get sucessfully converted samples

\n", @@ -1230,7 +1289,7 @@ { "cell_type": "code", "execution_count": null, - "id": "84c56d18", + "id": "18d4c038", "metadata": { "lines_to_next_cell": 2, "tags": [] @@ -1258,7 +1317,7 @@ }, { "cell_type": "markdown", - "id": "8737c833", + "id": "b34b1014", "metadata": { "tags": [] }, @@ -1269,7 +1328,7 @@ { "cell_type": "code", "execution_count": null, - "id": "ee8f6090", + "id": "f95678e3", "metadata": { "tags": [] }, @@ -1281,7 +1340,7 @@ { "cell_type": "code", "execution_count": null, - "id": "b33a0107", + "id": "17e89469", "metadata": { "tags": [] }, @@ -1302,7 +1361,7 @@ }, { "cell_type": "markdown", - "id": "2edae8d4", + "id": "13e5deff", "metadata": { "tags": [] }, @@ -1318,7 +1377,7 @@ { "cell_type": "code", "execution_count": null, - "id": "79d46ed5", + "id": "13af9caa", "metadata": { "tags": [] }, @@ -1331,7 +1390,7 @@ { "cell_type": "code", "execution_count": null, - "id": "0ec9b3cf", + "id": "696dfe89", "metadata": { "tags": [] }, @@ -1362,7 +1421,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c387ba61", + "id": "d3246960", "metadata": {}, "outputs": [], "source": [] @@ -1370,7 +1429,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8b9e843e", + "id": "7720e77b", "metadata": { "tags": [] }, @@ -1451,7 +1510,7 @@ }, { "cell_type": "markdown", - "id": "837d2a6a", + "id": "43c02c9f", "metadata": { "tags": [] }, @@ -1467,7 +1526,7 @@ { "cell_type": "code", "execution_count": null, - "id": "01f878a8", + "id": "4294368b", "metadata": { "tags": [] }, @@ -1478,7 +1537,7 @@ }, { "cell_type": "markdown", - "id": "28aceac4", + "id": "91185a47", "metadata": {}, "source": [ "HELP!!! Interactive (still!) doesn't work. No worries... uncomment the following cell and choose your index and threshold by typing them out." @@ -1487,7 +1546,7 @@ { "cell_type": "code", "execution_count": null, - "id": "2ae84d44", + "id": "95d17b88", "metadata": { "tags": [] }, @@ -1503,7 +1562,7 @@ }, { "cell_type": "markdown", - "id": "8ff5ceb0", + "id": "9e017ac3", "metadata": { "tags": [] }, @@ -1521,7 +1580,7 @@ }, { "cell_type": "markdown", - "id": "ca976c6b", + "id": "92d3a2f0", "metadata": { "tags": [] }, @@ -1534,7 +1593,7 @@ }, { "cell_type": "markdown", - "id": "bd96b144", + "id": "5478001b", "metadata": { "tags": [] }, diff --git a/solution.ipynb b/solution.ipynb index 231e6f7..d85e10d 100644 --- a/solution.ipynb +++ b/solution.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "b3ddb066", + "id": "eab4778f", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -29,7 +29,7 @@ }, { "cell_type": "markdown", - "id": "7f43c1e3", + "id": "c62087c9", "metadata": { "lines_to_next_cell": 0 }, @@ -41,7 +41,7 @@ }, { "cell_type": "markdown", - "id": "a9aaf840", + "id": "43cb388c", "metadata": {}, "source": [ "\n", @@ -54,7 +54,7 @@ { "cell_type": "code", "execution_count": null, - "id": "3f6c5bc0", + "id": "37c4f359", "metadata": { "lines_to_next_cell": 0 }, @@ -68,7 +68,7 @@ }, { "cell_type": "markdown", - "id": "5dd19fe5", + "id": "f4f0b771", "metadata": { "lines_to_next_cell": 0 }, @@ -84,7 +84,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8c709838", + "id": "2748b7dc", "metadata": {}, "outputs": [], "source": [ @@ -102,7 +102,7 @@ }, { "cell_type": "markdown", - "id": "6b04f969", + "id": "3d712049", "metadata": { "lines_to_next_cell": 0 }, @@ -113,7 +113,7 @@ }, { "cell_type": "markdown", - "id": "27ea9906", + "id": "21a9fe70", "metadata": { "lines_to_next_cell": 0 }, @@ -130,7 +130,7 @@ { "cell_type": "code", "execution_count": null, - "id": 
"6457422b", + "id": "07029615", "metadata": { "tags": [ "solution" @@ -154,7 +154,44 @@ }, { "cell_type": "markdown", - "id": "2ad014ac", + "id": "cecfa46d", + "metadata": { + "lines_to_next_cell": 0 + }, + "source": [ + "Don't take my word for it! Let's see how well the classifier does on the test set." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b93253d6", + "metadata": { + "lines_to_next_cell": 2 + }, + "outputs": [], + "source": [ + "from torch.utils.data import DataLoader\n", + "from sklearn.metrics import confusion_matrix\n", + "import seaborn as sns\n", + "\n", + "test_mnist = ColoredMNIST(\"data\", download=True, train=False)\n", + "dataloader = DataLoader(test_mnist, batch_size=32, shuffle=False)\n", + "\n", + "labels = []\n", + "predictions = []\n", + "for x, y in dataloader:\n", + " pred = model(x.to(device))\n", + " labels.extend(y.cpu().numpy())\n", + " predictions.extend(pred.argmax(dim=1).cpu().numpy())\n", + "\n", + "cm = confusion_matrix(labels, predictions, normalize=\"true\")\n", + "sns.heatmap(cm, annot=True, fmt=\".2f\")" + ] + }, + { + "cell_type": "markdown", + "id": "426d8618", "metadata": {}, "source": [ "# Part 2: Using Integrated Gradients to find what the classifier knows\n", @@ -164,7 +201,7 @@ }, { "cell_type": "markdown", - "id": "e40eeba7", + "id": "dc39b0d7", "metadata": {}, "source": [ "## Attributions through integrated gradients\n", @@ -177,7 +214,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e7aca710", + "id": "39661efa", "metadata": { "tags": [] }, @@ -193,7 +230,7 @@ }, { "cell_type": "markdown", - "id": "44d286aa", + "id": "ec39c8fe", "metadata": { "tags": [] }, @@ -209,7 +246,7 @@ { "cell_type": "code", "execution_count": null, - "id": "55d4cbcc", + "id": "4a6a5200", "metadata": { "tags": [ "solution" @@ -233,7 +270,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8ea56240", + "id": "48d39aca", "metadata": { "tags": [] }, @@ -246,7 +283,7 @@ }, { "cell_type": "markdown", - "id": "bc50850e", + "id": "7ceb951f", "metadata": { "lines_to_next_cell": 2, "tags": [] @@ -258,7 +295,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e7447933", + "id": "5deccc78", "metadata": { "tags": [] }, @@ -286,7 +323,7 @@ { "cell_type": "code", "execution_count": null, - "id": "5cb527bf", + "id": "59d12539", "metadata": { "tags": [] }, @@ -298,7 +335,7 @@ }, { "cell_type": "markdown", - "id": "25ecec3e", + "id": "88ad18f6", "metadata": { "lines_to_next_cell": 2 }, @@ -312,7 +349,7 @@ }, { "cell_type": "markdown", - "id": "2b43f05d", + "id": "631be1d6", "metadata": { "lines_to_next_cell": 0 }, @@ -325,7 +362,7 @@ { "cell_type": "code", "execution_count": null, - "id": "a7b21894", + "id": "13ffacb0", "metadata": {}, "outputs": [], "source": [ @@ -349,7 +386,7 @@ }, { "cell_type": "markdown", - "id": "f97eace2", + "id": "db5e1b05", "metadata": { "lines_to_next_cell": 0 }, @@ -363,7 +400,7 @@ }, { "cell_type": "markdown", - "id": "ed5b7d6e", + "id": "bbd4268a", "metadata": {}, "source": [ "\n", @@ -389,7 +426,7 @@ }, { "cell_type": "markdown", - "id": "bf13ae8d", + "id": "d382b20b", "metadata": {}, "source": [ "

Task 2.3: Use random noise as a baseline

\n", @@ -401,7 +438,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6e85e3e4", + "id": "c91ab0cd", "metadata": { "tags": [ "solution" @@ -426,7 +463,7 @@ }, { "cell_type": "markdown", - "id": "f24c00a3", + "id": "c1eb0219", "metadata": { "tags": [] }, @@ -440,7 +477,7 @@ { "cell_type": "code", "execution_count": null, - "id": "12743143", + "id": "f3b761f9", "metadata": { "tags": [ "solution" @@ -467,7 +504,7 @@ }, { "cell_type": "markdown", - "id": "10a6cfcc", + "id": "1176883b", "metadata": { "tags": [] }, @@ -483,7 +520,7 @@ }, { "cell_type": "markdown", - "id": "25f3d08e", + "id": "30b0ecb9", "metadata": {}, "source": [ "

BONUS Task: Using different attributions.

\n", @@ -497,7 +534,7 @@ }, { "cell_type": "markdown", - "id": "65d946a8", + "id": "accb5960", "metadata": {}, "source": [ "

Checkpoint 2

\n", @@ -517,7 +554,7 @@ }, { "cell_type": "markdown", - "id": "04602cf9", + "id": "aa54fc73", "metadata": { "lines_to_next_cell": 0 }, @@ -545,7 +582,7 @@ }, { "cell_type": "markdown", - "id": "ed173d7c", + "id": "b72ac61f", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -568,7 +605,7 @@ { "cell_type": "code", "execution_count": null, - "id": "03a51bad", + "id": "0cf84860", "metadata": {}, "outputs": [], "source": [ @@ -600,7 +637,7 @@ }, { "cell_type": "markdown", - "id": "e6c6168d", + "id": "b7126106", "metadata": { "lines_to_next_cell": 0 }, @@ -615,26 +652,7 @@ { "cell_type": "code", "execution_count": null, - "id": "9d0ef49f", - "metadata": { - "lines_to_next_cell": 0 - }, - "outputs": [], - "source": [ - "style_size = ... # TODO choose a size for the style space\n", - "unet_depth = ... # TODO Choose a depth for the UNet\n", - "style_mapping = DenseModel(\n", - " input_shape=..., num_classes=... # How big is the style space?\n", - ")\n", - "unet = UNet(depth=..., in_channels=..., out_channels=..., final_activation=nn.Sigmoid())\n", - "\n", - "generator = Generator(unet, style_mapping=style_mapping)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "ff22f753", + "id": "2e7dd95c", "metadata": { "tags": [ "solution" @@ -651,7 +669,7 @@ }, { "cell_type": "markdown", - "id": "bd761ef3", + "id": "c0b9a3b5", "metadata": { "tags": [] }, @@ -666,7 +684,7 @@ }, { "cell_type": "markdown", - "id": "d1220bb6", + "id": "d2d19ccb", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -683,20 +701,7 @@ { "cell_type": "code", "execution_count": null, - "id": "71482197", - "metadata": { - "lines_to_next_cell": 0, - "tags": [] - }, - "outputs": [], - "source": [ - "discriminator = DenseModel(input_shape=..., num_classes=...)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "7ef652d9", + "id": "5f596a72", "metadata": { "lines_to_next_cell": 0, "tags": [ @@ -710,7 +715,7 @@ }, { "cell_type": "markdown", - "id": "709affba", + "id": "c2761ac5", "metadata": { "lines_to_next_cell": 0 }, @@ -721,7 +726,7 @@ { "cell_type": "code", "execution_count": null, - "id": "7059545e", + "id": "df419c3c", "metadata": {}, "outputs": [], "source": [ @@ -731,7 +736,7 @@ }, { "cell_type": "markdown", - "id": "b1a7581c", + "id": "9b4e8069", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -749,7 +754,7 @@ { "cell_type": "code", "execution_count": null, - "id": "7805887e", + "id": "07fb5440", "metadata": { "lines_to_next_cell": 0 }, @@ -761,7 +766,7 @@ }, { "cell_type": "markdown", - "id": "1bad28d8", + "id": "4f4f88ce", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -780,17 +785,18 @@ { "cell_type": "code", "execution_count": null, - "id": "a757512e", + "id": "eae1b681", "metadata": {}, "outputs": [], "source": [ - "adverial_loss_fn = nn.CrossEntropyLoss()" + "adversarial_loss_fn = nn.CrossEntropyLoss()" ] }, { "cell_type": "markdown", - "id": "5c590737", + "id": "d45aa99e", "metadata": { + "lines_to_next_cell": 0, "tags": [] }, "source": [ @@ -798,77 +804,105 @@ "**Cycle/reconstruction loss**\n", "The cycle loss is there to make sure that the generator doesn't output an image that looks nothing like the input!\n", "Indeed, by training the generator to be able to cycle back to the original image, we are making sure that it makes a minimum number of changes.\n", - "The cycle loss is applied only to the generator.\n", - "\n", - "cycle_loss_fn = nn.L1Loss()" + "The cycle loss is applied only to the generator.\n" ] }, { "cell_type": "code", "execution_count": null, - 
"id": "0def44d4", + "id": "c20c35b7", "metadata": {}, "outputs": [], - "source": [] + "source": [ + "cycle_loss_fn = nn.L1Loss()" + ] }, { "cell_type": "markdown", - "id": "3a0c1d2e", + "id": "6d10813e", "metadata": { - "lines_to_next_cell": 2, "tags": [] }, "source": [ - "

Task 3.2: Training!

\n", - "Let's train the CycleGAN one batch a time, plotting the output every so often to see how it is getting on.\n", + "Stuff about the dataloader" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0337c819", + "metadata": {}, + "outputs": [], + "source": [ + "from torch.utils.data import DataLoader\n", "\n", - "While you watch the model train, consider whether you think it will be successful at generating counterfactuals in the number of steps we give it. What is the minimum number of iterations you think are needed for this to work, and how much time do yo uthink it will take?\n", - "
" + "dataloader = DataLoader(\n", + " mnist, batch_size=32, drop_last=True, shuffle=True\n", + ") # We will use the same dataset as before" ] }, { "cell_type": "markdown", - "id": "9f577571", + "id": "feb14b16", "metadata": { - "lines_to_next_cell": 0, + "lines_to_next_cell": 2, "tags": [] }, "source": [ - "...this time again.\n", - "\n", - "\"drawing\"\n" + "TODO - Describe set_requires_grad" ] }, { "cell_type": "code", "execution_count": null, - "id": "d3077e49", + "id": "21f19dc7", "metadata": {}, "outputs": [], "source": [ - "# TODO also turn this into a standalone script for use during the project phase\n", - "from torch.utils.data import DataLoader\n", - "from tqdm import tqdm\n", - "\n", - "\n", "def set_requires_grad(module, value=True):\n", " \"\"\"Sets `requires_grad` on a `module`'s parameters to `value`\"\"\"\n", " for param in module.parameters():\n", - " param.requires_grad = value\n", - "\n", - "\n", - "cycle_loss_fn = nn.L1Loss()\n", - "class_loss_fn = nn.CrossEntropyLoss()\n", + " param.requires_grad = value" + ] + }, + { + "cell_type": "markdown", + "id": "58161b77", + "metadata": { + "lines_to_next_cell": 0, + "tags": [] + }, + "source": [ + "

Task 3.2: Training!

\n", "\n", - "optimizer_d = torch.optim.Adam(discriminator.parameters(), lr=1e-6)\n", - "optimizer_g = torch.optim.Adam(generator.parameters(), lr=1e-4)\n", + "TODO - the task is to choose where to apply set_requires_grad\n", + "
    \n", + "
  • Choose the values for `set_requires_grad`. Hint: which part of the code is training the generator? Which part is training the discriminator
  • \n", + "
  • Choose the values of `set_requires_grad`, again. Hint: you may want to switch
  • \n", + "
  • Choose the sign of the discriminator loss. Hint: what does the discriminator want to do?
  • \n", + "
\n", + "Let's train the StarGAN one batch a time.\n", + "While you watch the model train, consider whether you think it will be successful at generating counterfactuals in the number of steps we give it. What is the minimum number of iterations you think are needed for this to work, and how much time do yo uthink it will take?\n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "934d3c68", + "metadata": { + "lines_to_next_cell": 2, + "tags": [ + "solution" + ] + }, + "outputs": [], + "source": [ + "from tqdm import tqdm # This is a nice library for showing progress bars\n", "\n", - "dataloader = DataLoader(\n", - " mnist, batch_size=32, drop_last=True, shuffle=True\n", - ") # We will use the same dataset as before\n", "\n", "losses = {\"cycle\": [], \"adv\": [], \"disc\": []}\n", - "for epoch in range(50):\n", + "for epoch in range(15):\n", " for x, y in tqdm(dataloader, desc=f\"Epoch {epoch}\"):\n", " x = x.to(device)\n", " y = y.to(device)\n", @@ -892,7 +926,7 @@ " # 1. make sure the image can be reconstructed\n", " cycle_loss = cycle_loss_fn(x, x_cycled)\n", " # 2. make sure the discriminator is fooled\n", - " adv_loss = class_loss_fn(discriminator_x_fake, y_target)\n", + " adv_loss = adversarial_loss_fn(discriminator_x_fake, y_target)\n", "\n", " # Optimize the generator\n", " (cycle_loss + adv_loss).backward()\n", @@ -906,9 +940,9 @@ " discriminator_x_fake = discriminator(x_fake.detach())\n", " # Losses to train the discriminator\n", " # 1. make sure the discriminator can tell real is real\n", - " real_loss = class_loss_fn(discriminator_x, y)\n", - " # 2. make sure the discriminator can't tell fake is fake\n", - " fake_loss = -class_loss_fn(discriminator_x_fake, y_target)\n", + " real_loss = adversarial_loss_fn(discriminator_x, y)\n", + " # 2. make sure the discriminator can tell fake is fake\n", + " fake_loss = -adversarial_loss_fn(discriminator_x_fake, y_target)\n", " #\n", " disc_loss = (real_loss + fake_loss) * 0.5\n", " disc_loss.backward()\n", @@ -921,12 +955,23 @@ ] }, { - "cell_type": "code", - "execution_count": null, - "id": "b232bd07", + "cell_type": "markdown", + "id": "99753362", "metadata": { - "lines_to_next_cell": 0 + "lines_to_next_cell": 0, + "tags": [] }, + "source": [ + "...this time again. 🚂 🚋 🚋 🚋\n", + "\n", + "Once training is complete, we can plot the losses to see how well the model is doing." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "99070716", + "metadata": {}, "outputs": [], "source": [ "plt.plot(losses[\"cycle\"], label=\"Cycle loss\")\n", @@ -938,18 +983,19 @@ }, { "cell_type": "markdown", - "id": "16de7380", + "id": "ce337ff3", "metadata": { + "lines_to_next_cell": 0, "tags": [] }, "source": [ - "Let's add a quick plotting function before we begin training..." + "We can also look at some examples of the images that the generator is creating." ] }, { "cell_type": "code", "execution_count": null, - "id": "856af9da", + "id": "5d2443f5", "metadata": {}, "outputs": [], "source": [ @@ -962,30 +1008,38 @@ "\n", "for ax in axs:\n", " ax.axis(\"off\")\n", - "plt.show()\n", - "\n", - "# TODO WIP here" + "plt.show()" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "726f77db", + "metadata": { + "lines_to_next_cell": 0 + }, + "outputs": [], + "source": [] + }, { "cell_type": "markdown", - "id": "f7240ca5", + "id": "ed4e3ca8", "metadata": { "tags": [] }, "source": [ "

Checkpoint 3

\n", - "You've now learned the basics of what makes up a CycleGAN, and details on how to perform adversarial training.\n", - "The same method can be used to create a CycleGAN with different basic elements.\n", + "You've now learned the basics of what makes up a StarGAN, and details on how to perform adversarial training.\n", + "The same method can be used to create a StarGAN with different basic elements.\n", "For example, you can change the archictecture of the generators, or of the discriminator to better fit your data in the future.\n", "\n", - "You know the drill... let us know on the exercise chat!\n", + "You know the drill... let us know on the exercise chat when you have arrived here!\n", "
" ] }, { "cell_type": "markdown", - "id": "67168867", + "id": "f77b54db", "metadata": { "tags": [] }, @@ -995,243 +1049,216 @@ }, { "cell_type": "markdown", - "id": "c6bdbfde", + "id": "cd268191", "metadata": { "tags": [] }, "source": [ + "## Creating counterfactuals\n", "\n", - "## That was fun!... let's load a pre-trained model\n", + "The first thing that we want to do is make sure that our GAN is able to create counterfactual images.\n", + "To do this, we have to create them, and then pass them through the classifier to see if they are classified correctly.\n", "\n", - "Training the CycleGAN takes a lot longer than the few iterations that we did above. Since we don't have that kind of time, we are going to load a pre-trained model (for reference, this pre-trained model was trained for 7 days...).\n", - "\n", - "To continue, interrupt the kernel and continue with the next one, which will just use one of the pretrained CycleGAN models for the synapse dataset." + "First, let's get the test dataset, so we can evaluate the GAN on unseen data.\n", + "Then, let's get four prototypical images from the dataset as style sources." ] }, { "cell_type": "code", "execution_count": null, - "id": "a8543304", + "id": "3a4b48f7", "metadata": { - "tags": [] + "title": "Loading the test dataset" }, "outputs": [], "source": [ - "from pathlib import Path\n", - "import torch\n", + "test_mnist = ColoredMNIST(\"data\", download=True, train=False)\n", + "prototypes = {}\n", "\n", - "# TODO load the pre-trained model" + "\n", + "for i in range(4):\n", + " options = np.where(test_mnist.targets == i)[0]\n", + " # Note that you can change the image index if you want to use a different prototype.\n", + " image_index = 0\n", + " x, y = test_mnist[options[image_index]]\n", + " prototypes[i] = x" ] }, { "cell_type": "markdown", - "id": "940b48d6", + "id": "cf374cec", "metadata": { + "lines_to_next_cell": 0, "tags": [] }, "source": [ - "Let's look at some examples. Can you pick up on the differences between original, the counter-factual, and the reconstruction?" + "Let's have a look at the prototypes." ] }, { "cell_type": "code", "execution_count": null, - "id": "8b9425d2", - "metadata": { - "tags": [] - }, + "id": "55b9457b", + "metadata": {}, "outputs": [], "source": [ - "# TODO show some examples" + "fig, axs = plt.subplots(1, 4, figsize=(12, 4))\n", + "for i, ax in enumerate(axs):\n", + " ax.imshow(prototypes[i].permute(1, 2, 0))\n", + " ax.axis(\"off\")\n", + " ax.set_title(f\"Prototype {i}\")" ] }, { "cell_type": "markdown", - "id": "42f81f13", + "id": "8883baa5", "metadata": { - "tags": [] + "lines_to_next_cell": 0 }, "source": [ - "We're going to apply the GAN to our test dataset." + "Now we need to use these prototypes to create counterfactual images!\n", + "TODO make a task here!" ] }, { "cell_type": "code", "execution_count": null, - "id": "33fbfc83", - "metadata": { - "tags": [] - }, + "id": "65460b37", + "metadata": {}, "outputs": [], "source": [ - "# TODO load the test dataset" - ] - }, - { - "cell_type": "markdown", - "id": "00ded88d", - "metadata": { - "tags": [] - }, - "source": [ - "## Evaluating the GAN\n", - "\n", - "The first thing to find out is whether the CycleGAN is successfully converting the images from one neurotransmitter to another.\n", - "We will do this by running the classifier that we trained earlier on generated data.\n" - ] - }, - { - "cell_type": "markdown", - "id": "f7475dc3", - "metadata": { - "tags": [] - }, - "source": [ - "

Task 4.1 Get the classifier accuracy on CycleGAN outputs

\n", + "num_images = len(test_mnist)\n", + "counterfactuals = np.zeros((4, num_images, 3, 28, 28))\n", "\n", - "Using the saved images, we're going to figure out how good our CycleGAN is at generating images of a new class!\n", + "predictions = []\n", + "source_labels = []\n", + "target_labels = []\n", "\n", - "The images (`real`, `reconstructed`, and `counterfactual`) are saved in the `test_images/` directory. Before you start the exercise, have a look at how this directory is organized.\n", + "for x, y in test_mnist:\n", + " for i in range(4):\n", + " if i == y:\n", + " # Store the image as is.\n", + " counterfactuals[i] = ...\n", + " # Create the counterfactual from the image and prototype\n", + " x_fake = generator(x.unsqueeze(0).to(device), ...)\n", + " counterfactuals[i] = x_fake.cpu().detach().numpy()\n", + " pred = model(...)\n", "\n", - "TODO\n", - "- Use the `make_dataset` function to create a dataset for the three different image types that we saved above\n", - " - real\n", - " - reconstructed\n", - " - counterfactual\n", - "
" + " source_labels.append(y)\n", + " target_labels.append(i)\n", + " predictions.append(pred.argmax().item())" ] }, { - "cell_type": "markdown", - "id": "97a88ddb", + "cell_type": "code", + "execution_count": null, + "id": "7da0a992", "metadata": { - "lines_to_next_cell": 0, - "tags": [] + "tags": [ + "solution" + ] }, + "outputs": [], "source": [ - "
\n", - "We get the following accuracies:\n", + "num_images = len(test_mnist)\n", + "counterfactuals = np.zeros((4, num_images, 3, 28, 28))\n", "\n", - "1. `accuracy_real`: Accuracy of the classifier on the real images, just for the two classes used in the GAN\n", - "2. `accuracy_recon`: Accuracy of the classifier on the reconstruction.\n", - "3. `accuracy_counter`: Accuracy of the classifier on the counterfactual images.\n", + "predictions = []\n", + "source_labels = []\n", + "target_labels = []\n", "\n", - "

Questions

\n", + "for x, y in test_mnist:\n", + " for i in range(4):\n", + " if i == y:\n", + " # Store the image as is.\n", + " counterfactuals[i] = x\n", + " # Create the counterfactual\n", + " x_fake = generator(\n", + " x.unsqueeze(0).to(device), prototypes[i].unsqueeze(0).to(device)\n", + " )\n", + " counterfactuals[i] = x_fake.cpu().detach().numpy()\n", + " pred = model(x_fake)\n", "\n", - "- In a perfect world, what value would we expect for `accuracy_recon`? What do we compare it to and why is it higher/lower?\n", - "- How well is it translating from one class to another? Do we expect `accuracy_counter` to be large or small? Do we want it to be large or small? Why?\n", - "\n", - "Let us know your insights on the exercise chat.\n", - "
" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "2f82fa67", - "metadata": {}, - "outputs": [], - "source": [ - "# TODO make a loop on the data that creates the counterfactual images, given a set of options as input\n", - "counterfactuals, reconstructions, targets, labels = ..." + " source_labels.append(y)\n", + " target_labels.append(i)\n", + " predictions.append(pred.argmax().item())" ] }, { - "cell_type": "code", - "execution_count": null, - "id": "b93db0b2", + "cell_type": "markdown", + "id": "3b176c31", "metadata": { "lines_to_next_cell": 0, - "title": "[markwodn]" + "tags": [] }, - "outputs": [], "source": [ - "# Evaluate the images" + "Let's plot the confusion matrix for the counterfactual images." ] }, { "cell_type": "code", "execution_count": null, - "id": "5c7ccc7b", + "id": "a9709066", "metadata": {}, "outputs": [], "source": [ - "# TODO use the loaded classifier to evaluate the images\n", - "# Get the accuracies\n", - "def predict():\n", - " # TODO return predictions, labels\n", - " pass" + "cf_cm = confusion_matrix(target_labels, predictions, normalize=\"true\")\n", + "sns.heatmap(cf_cm, annot=True, fmt=\".2f\")" ] }, { "cell_type": "markdown", - "id": "d47955f7", + "id": "51805f97", "metadata": { "tags": [] }, "source": [ - "We're going to look at the confusion matrices for the counterfactuals, and compare it to that of the real images." + "

Questions

\n", + "
    \n", + "
  • How well is our GAN doing at creating counterfactual images?
  • \n", + "
  • Do you think that the prototypes used matter? Why or why not?
  • \n", + "
\n", + "
" ] }, { - "cell_type": "code", - "execution_count": null, - "id": "94284732", + "cell_type": "markdown", + "id": "e767437a", "metadata": { - "lines_to_next_cell": 0 + "tags": [] }, - "outputs": [], - "source": [ - "print(\"The confusion matrix on the real images... for comparison\")\n", - "# TODO Confusion matrix on the counterfactual images\n", - "confusion_matrix = ...\n", - "# TODO plot" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "cb6f9edc", - "metadata": {}, - "outputs": [], "source": [ - "print(\"The confusion matrix on the real images... for comparison\")\n", - "# TODO Confusion matrix on the real images, for comparison\n", - "confusion_matrix = ...\n", - "# TODO plot" - ] - }, - { - "cell_type": "markdown", - "id": "8aba5707", - "metadata": {}, - "source": [ - "
\n", - "

Questions

\n", - "\n", - "- What would you expect the confusion matrix for the counterfactuals to look like? Why?\n", - "- Do the two directions of the CycleGAN work equally as well?\n", - "- Can you think of anything that might have made it more difficult, or easier, to translate in a one direction vs the other?\n", + "Let's also plot some examples of the counterfactual images.\n", "\n", - "
" + "for i in np.random.choice(range(num_images), 4):\n", + " fig, axs = plt.subplots(1, 4, figsize=(20, 4))\n", + " for j, ax in enumerate(axs):\n", + " ax.imshow(counterfactuals[j][i].transpose(1, 2, 0))\n", + " ax.axis(\"off\")\n", + " ax.set_title(f\"Class {j}\")" ] }, { "cell_type": "markdown", - "id": "b9713122", - "metadata": {}, + "id": "545bc176", + "metadata": { + "lines_to_next_cell": 0, + "tags": [] + }, "source": [ - "

Checkpoint 4

\n", - " We have seen that our CycleGAN network has successfully translated some of the synapses from one class to the other, but there are clearly some things to look out for!\n", - "Take the time to think about the questions above before moving on...\n", + "

Questions

\n", + "
    \n", + "
  • Can you easily tell which of these images is the original, and which ones are the counterfactuals?
  • \n", + "
  • What is your hypothesis for the features that define each class?
  • \n", + "
\n", + "
\n", "\n", - "This is the end of Section 4. Let us know on the exercise chat if you have reached this point!\n", - "
" + "TODO wip here" ] }, { "cell_type": "markdown", - "id": "183344be", + "id": "069a2183", "metadata": {}, "source": [ "# Part 5: Highlighting Class-Relevant Differences" @@ -1239,7 +1266,7 @@ }, { "cell_type": "markdown", - "id": "83417bff", + "id": "7b2c0480", "metadata": {}, "source": [ "At this point we have:\n", @@ -1254,7 +1281,7 @@ }, { "cell_type": "markdown", - "id": "737ae577", + "id": "81f91fa8", "metadata": {}, "source": [ "

Task 5.1 Get sucessfully converted samples

\n", @@ -1275,7 +1302,7 @@ { "cell_type": "code", "execution_count": null, - "id": "84c56d18", + "id": "18d4c038", "metadata": { "lines_to_next_cell": 2, "tags": [] @@ -1304,7 +1331,7 @@ { "cell_type": "code", "execution_count": null, - "id": "37413116", + "id": "338b7d53", "metadata": { "lines_to_next_cell": 2, "tags": [ @@ -1336,7 +1363,7 @@ }, { "cell_type": "markdown", - "id": "8737c833", + "id": "b34b1014", "metadata": { "tags": [] }, @@ -1347,7 +1374,7 @@ { "cell_type": "code", "execution_count": null, - "id": "ee8f6090", + "id": "f95678e3", "metadata": { "tags": [] }, @@ -1359,7 +1386,7 @@ { "cell_type": "code", "execution_count": null, - "id": "b33a0107", + "id": "17e89469", "metadata": { "tags": [] }, @@ -1380,7 +1407,7 @@ }, { "cell_type": "markdown", - "id": "2edae8d4", + "id": "13e5deff", "metadata": { "tags": [] }, @@ -1396,7 +1423,7 @@ { "cell_type": "code", "execution_count": null, - "id": "79d46ed5", + "id": "13af9caa", "metadata": { "tags": [] }, @@ -1409,7 +1436,7 @@ { "cell_type": "code", "execution_count": null, - "id": "0ec9b3cf", + "id": "696dfe89", "metadata": { "tags": [] }, @@ -1440,7 +1467,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c387ba61", + "id": "d3246960", "metadata": {}, "outputs": [], "source": [] @@ -1448,7 +1475,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8b9e843e", + "id": "7720e77b", "metadata": { "tags": [] }, @@ -1529,7 +1556,7 @@ }, { "cell_type": "markdown", - "id": "837d2a6a", + "id": "43c02c9f", "metadata": { "tags": [] }, @@ -1545,7 +1572,7 @@ { "cell_type": "code", "execution_count": null, - "id": "01f878a8", + "id": "4294368b", "metadata": { "tags": [] }, @@ -1556,7 +1583,7 @@ }, { "cell_type": "markdown", - "id": "28aceac4", + "id": "91185a47", "metadata": {}, "source": [ "HELP!!! Interactive (still!) doesn't work. No worries... uncomment the following cell and choose your index and threshold by typing them out." @@ -1565,7 +1592,7 @@ { "cell_type": "code", "execution_count": null, - "id": "2ae84d44", + "id": "95d17b88", "metadata": { "tags": [] }, @@ -1581,7 +1608,7 @@ }, { "cell_type": "markdown", - "id": "8ff5ceb0", + "id": "9e017ac3", "metadata": { "tags": [] }, @@ -1599,7 +1626,7 @@ }, { "cell_type": "markdown", - "id": "ca976c6b", + "id": "92d3a2f0", "metadata": { "tags": [] }, @@ -1612,7 +1639,7 @@ }, { "cell_type": "markdown", - "id": "bd96b144", + "id": "5478001b", "metadata": { "tags": [] },