From b3d267dd21d7de906048ef7050add5ff78b40ed2 Mon Sep 17 00:00:00 2001 From: Diane Adjavon Date: Mon, 12 Aug 2024 14:33:38 -0400 Subject: [PATCH] wip: Add GAN training task --- solution.py | 159 +++++++++++++++++++++++++++++++++++++++++----------- 1 file changed, 126 insertions(+), 33 deletions(-) diff --git a/solution.py b/solution.py index bc0e50e..dab7902 100644 --- a/solution.py +++ b/solution.py @@ -397,7 +397,7 @@ def forward(self, x, y): # We are going to create the models for the generator, discriminator, and style mapping. # # Given the Generator structure above, fill in the missing parts for the unet and the style mapping. -# %% +# %% tags=["task"] style_size = ... # TODO choose a size for the style space unet_depth = ... # TODO Choose a depth for the UNet style_mapping = DenseModel( @@ -428,7 +428,7 @@ def forward(self, x, y): # The discriminator will take as input either a real image or a fake image. # Fill in the following code to create a discriminator that can classify the images into the correct number of classes. # -# %% tags=[] +# %% tags=["task"] discriminator = DenseModel(input_shape=..., num_classes=...) # %% tags=["solution"] discriminator = DenseModel(input_shape=(3, 28, 28), num_classes=4) @@ -460,7 +460,7 @@ def forward(self, x, y): # The adversarial loss will be applied differently to the generator and the discriminator! Be very careful! # # %% -adverial_loss_fn = nn.CrossEntropyLoss() +adversarial_loss_fn = nn.CrossEntropyLoss() # %% [markdown] tags=[] # @@ -469,47 +469,135 @@ def forward(self, x, y): # Indeed, by training the generator to be able to cycle back to the original image, we are making sure that it makes a minimum number of changes. # The cycle loss is applied only to the generator. # +# %% cycle_loss_fn = nn.L1Loss() +# %% [markdown] tags=[] +# Stuff about the dataloader + # %% +from torch.utils.data import DataLoader + +dataloader = DataLoader( + mnist, batch_size=32, drop_last=True, shuffle=True +) # We will use the same dataset as before + +# %% [markdown] tags=[] +# TODO - Describe set_requires_grad + + +# %% +def set_requires_grad(module, value=True): + """Sets `requires_grad` on a `module`'s parameters to `value`""" + for param in module.parameters(): + param.requires_grad = value + # %% [markdown] tags=[] #

Task 3.2: Training!

-# Let's train the CycleGAN one batch a time, plotting the output every so often to see how it is getting on. # +# TODO - the task is to choose where to apply set_requires_grad +# +# Let's train the StarGAN one batch a time. # While you watch the model train, consider whether you think it will be successful at generating counterfactuals in the number of steps we give it. What is the minimum number of iterations you think are needed for this to work, and how much time do yo uthink it will take? #
+# %% tags=["task"] +from tqdm import tqdm # This is a nice library for showing progress bars -# %% [markdown] tags=[] -# ...this time again. -# -# drawing -# -# %% -# TODO also turn this into a standalone script for use during the project phase -from torch.utils.data import DataLoader -from tqdm import tqdm +losses = {"cycle": [], "adv": [], "disc": []} +for epoch in range(15): + for x, y in tqdm(dataloader, desc=f"Epoch {epoch}"): + x = x.to(device) + y = y.to(device) + # get the target y by shuffling the classes + # get the style sources by random sampling + random_index = torch.randperm(len(y)) + x_style = x[random_index].clone() + y_target = y[random_index].clone() -def set_requires_grad(module, value=True): - """Sets `requires_grad` on a `module`'s parameters to `value`""" - for param in module.parameters(): - param.requires_grad = value + # TODO - Choose an option by commenting out what you don't want + ############ + # Option 1 # + ############ + set_requires_grad(generator, True) + set_requires_grad(discriminator, False) + ############ + # Option 2 # + ############ + set_requires_grad(generator, False) + set_requires_grad(discriminator, True) + optimizer_g.zero_grad() + # Get the fake image + x_fake = generator(x, x_style) + # Try to cycle back + x_cycled = generator(x_fake, x) + # Discriminate + discriminator_x_fake = discriminator(x_fake) + # Losses to train the generator -cycle_loss_fn = nn.L1Loss() -class_loss_fn = nn.CrossEntropyLoss() + # 1. make sure the image can be reconstructed + cycle_loss = cycle_loss_fn(x, x_cycled) + # 2. make sure the discriminator is fooled + adv_loss = adversarial_loss_fn(discriminator_x_fake, y_target) -optimizer_d = torch.optim.Adam(discriminator.parameters(), lr=1e-6) -optimizer_g = torch.optim.Adam(generator.parameters(), lr=1e-4) + # Optimize the generator + (cycle_loss + adv_loss).backward() + optimizer_g.step() + + # TODO - Choose an option by commenting out what you don't want + ############ + # Option 1 # + ############ + set_requires_grad(generator, True) + set_requires_grad(discriminator, False) + ############ + # Option 2 # + ############ + set_requires_grad(generator, False) + set_requires_grad(discriminator, True) + # + optimizer_d.zero_grad() + # + discriminator_x = discriminator(x) + discriminator_x_fake = discriminator(x_fake.detach()) + + # TODO - Choose an option by commenting out what you don't want + # Losses to train the discriminator + # 1. make sure the discriminator can tell real is real + # 2. make sure the discriminator can tell fake is fake + ############ + # Option 1 # + ############ + real_loss = adversarial_loss_fn(discriminator_x, y) + fake_loss = -adversarial_loss_fn(discriminator_x_fake, y_target) + ############ + # Option 2 # + ############ + real_loss = adversarial_loss_fn(discriminator_x, y) + fake_loss = adversarial_loss_fn(discriminator_x_fake, y_target) + # + disc_loss = (real_loss + fake_loss) * 0.5 + disc_loss.backward() + # Optimize the discriminator + optimizer_d.step() + + losses["cycle"].append(cycle_loss.item()) + losses["adv"].append(adv_loss.item()) + losses["disc"].append(disc_loss.item()) + +# %% tags=["solution"] +from tqdm import tqdm # This is a nice library for showing progress bars -dataloader = DataLoader( - mnist, batch_size=32, drop_last=True, shuffle=True -) # We will use the same dataset as before losses = {"cycle": [], "adv": [], "disc": []} -for epoch in range(50): +for epoch in range(15): for x, y in tqdm(dataloader, desc=f"Epoch {epoch}"): x = x.to(device) y = y.to(device) @@ -533,7 +621,7 @@ def set_requires_grad(module, value=True): # 1. make sure the image can be reconstructed cycle_loss = cycle_loss_fn(x, x_cycled) # 2. make sure the discriminator is fooled - adv_loss = class_loss_fn(discriminator_x_fake, y_target) + adv_loss = adversarial_loss_fn(discriminator_x_fake, y_target) # Optimize the generator (cycle_loss + adv_loss).backward() @@ -547,9 +635,9 @@ def set_requires_grad(module, value=True): discriminator_x_fake = discriminator(x_fake.detach()) # Losses to train the discriminator # 1. make sure the discriminator can tell real is real - real_loss = class_loss_fn(discriminator_x, y) - # 2. make sure the discriminator can't tell fake is fake - fake_loss = -class_loss_fn(discriminator_x_fake, y_target) + real_loss = adversarial_loss_fn(discriminator_x, y) + # 2. make sure the discriminator can tell fake is fake + fake_loss = -adversarial_loss_fn(discriminator_x_fake, y_target) # disc_loss = (real_loss + fake_loss) * 0.5 disc_loss.backward() @@ -560,15 +648,20 @@ def set_requires_grad(module, value=True): losses["adv"].append(adv_loss.item()) losses["disc"].append(disc_loss.item()) + +# %% [markdown] tags=[] +# ...this time again. 🚂 🚋 🚋 🚋 +# +# Once training is complete, we can plot the losses to see how well the model is doing. # %% plt.plot(losses["cycle"], label="Cycle loss") plt.plot(losses["adv"], label="Adversarial loss") plt.plot(losses["disc"], label="Discriminator loss") plt.legend() plt.show() -# %% [markdown] tags=[] -# Let's add a quick plotting function before we begin training... +# %% [markdown] tags=[] +# We can also look at some examples of the images that the generator is creating. # %% idx = 0 fig, axs = plt.subplots(1, 4, figsize=(12, 4)) @@ -581,8 +674,8 @@ def set_requires_grad(module, value=True): ax.axis("off") plt.show() -# TODO WIP here - +# %% +# TODO wip here # %% [markdown] tags=[] #

Checkpoint 3

# You've now learned the basics of what makes up a CycleGAN, and details on how to perform adversarial training.