Skip to content

Commit

Permalink
wip: Add GAN training task
Browse files Browse the repository at this point in the history
  • Loading branch information
adjavon committed Aug 12, 2024
1 parent 3d887bc commit b3d267d
Showing 1 changed file with 126 additions and 33 deletions.
159 changes: 126 additions & 33 deletions solution.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,7 +397,7 @@ def forward(self, x, y):
# We are going to create the models for the generator, discriminator, and style mapping.
#
# Given the Generator structure above, fill in the missing parts for the unet and the style mapping.
# %%
# %% tags=["task"]
style_size = ... # TODO choose a size for the style space
unet_depth = ... # TODO Choose a depth for the UNet
style_mapping = DenseModel(
Expand Down Expand Up @@ -428,7 +428,7 @@ def forward(self, x, y):
# The discriminator will take as input either a real image or a fake image.
# Fill in the following code to create a discriminator that can classify the images into the correct number of classes.
# </div>
# %% tags=[]
# %% tags=["task"]
discriminator = DenseModel(input_shape=..., num_classes=...)
# %% tags=["solution"]
discriminator = DenseModel(input_shape=(3, 28, 28), num_classes=4)
Expand Down Expand Up @@ -460,7 +460,7 @@ def forward(self, x, y):
# The adversarial loss will be applied differently to the generator and the discriminator! Be very careful!
# </div>
# %%
adverial_loss_fn = nn.CrossEntropyLoss()
adversarial_loss_fn = nn.CrossEntropyLoss()

# %% [markdown] tags=[]
#
Expand All @@ -469,47 +469,135 @@ def forward(self, x, y):
# Indeed, by training the generator to be able to cycle back to the original image, we are making sure that it makes a minimum number of changes.
# The cycle loss is applied only to the generator.
#
# %%
cycle_loss_fn = nn.L1Loss()

# %% [markdown] tags=[]
# Stuff about the dataloader

# %%
from torch.utils.data import DataLoader

dataloader = DataLoader(
mnist, batch_size=32, drop_last=True, shuffle=True
) # We will use the same dataset as before

# %% [markdown] tags=[]
# TODO - Describe set_requires_grad


# %%
def set_requires_grad(module, value=True):
"""Sets `requires_grad` on a `module`'s parameters to `value`"""
for param in module.parameters():
param.requires_grad = value


# %% [markdown] tags=[]
# <div class="alert alert-banner alert-info"><h4>Task 3.2: Training!</h4>
# Let's train the CycleGAN one batch a time, plotting the output every so often to see how it is getting on.
#
# TODO - the task is to choose where to apply set_requires_grad
# <ul>
# <li>Choose the values for `set_requires_grad`. Hint: which part of the code is training the generator? Which part is training the discriminator</li>
# <li>Choose the values of `set_requires_grad`, again. Hint: you may want to switch</li>
# <li>Choose the sign of the discriminator loss. Hint: what does the discriminator want to do?</li>
# </ul>
# Let's train the StarGAN one batch a time.
# While you watch the model train, consider whether you think it will be successful at generating counterfactuals in the number of steps we give it. What is the minimum number of iterations you think are needed for this to work, and how much time do yo uthink it will take?
# </div>
# %% tags=["task"]
from tqdm import tqdm # This is a nice library for showing progress bars


# %% [markdown] tags=[]
# ...this time again.
#
# <img src="assets/model_train.jpg" alt="drawing" width="500px"/>
#
# %%
# TODO also turn this into a standalone script for use during the project phase
from torch.utils.data import DataLoader
from tqdm import tqdm
losses = {"cycle": [], "adv": [], "disc": []}

for epoch in range(15):
for x, y in tqdm(dataloader, desc=f"Epoch {epoch}"):
x = x.to(device)
y = y.to(device)
# get the target y by shuffling the classes
# get the style sources by random sampling
random_index = torch.randperm(len(y))
x_style = x[random_index].clone()
y_target = y[random_index].clone()

def set_requires_grad(module, value=True):
"""Sets `requires_grad` on a `module`'s parameters to `value`"""
for param in module.parameters():
param.requires_grad = value
# TODO - Choose an option by commenting out what you don't want
############
# Option 1 #
############
set_requires_grad(generator, True)
set_requires_grad(discriminator, False)
############
# Option 2 #
############
set_requires_grad(generator, False)
set_requires_grad(discriminator, True)

optimizer_g.zero_grad()
# Get the fake image
x_fake = generator(x, x_style)
# Try to cycle back
x_cycled = generator(x_fake, x)
# Discriminate
discriminator_x_fake = discriminator(x_fake)
# Losses to train the generator

cycle_loss_fn = nn.L1Loss()
class_loss_fn = nn.CrossEntropyLoss()
# 1. make sure the image can be reconstructed
cycle_loss = cycle_loss_fn(x, x_cycled)
# 2. make sure the discriminator is fooled
adv_loss = adversarial_loss_fn(discriminator_x_fake, y_target)

optimizer_d = torch.optim.Adam(discriminator.parameters(), lr=1e-6)
optimizer_g = torch.optim.Adam(generator.parameters(), lr=1e-4)
# Optimize the generator
(cycle_loss + adv_loss).backward()
optimizer_g.step()

# TODO - Choose an option by commenting out what you don't want
############
# Option 1 #
############
set_requires_grad(generator, True)
set_requires_grad(discriminator, False)
############
# Option 2 #
############
set_requires_grad(generator, False)
set_requires_grad(discriminator, True)
#
optimizer_d.zero_grad()
#
discriminator_x = discriminator(x)
discriminator_x_fake = discriminator(x_fake.detach())

# TODO - Choose an option by commenting out what you don't want
# Losses to train the discriminator
# 1. make sure the discriminator can tell real is real
# 2. make sure the discriminator can tell fake is fake
############
# Option 1 #
############
real_loss = adversarial_loss_fn(discriminator_x, y)
fake_loss = -adversarial_loss_fn(discriminator_x_fake, y_target)
############
# Option 2 #
############
real_loss = adversarial_loss_fn(discriminator_x, y)
fake_loss = adversarial_loss_fn(discriminator_x_fake, y_target)
#
disc_loss = (real_loss + fake_loss) * 0.5
disc_loss.backward()
# Optimize the discriminator
optimizer_d.step()

losses["cycle"].append(cycle_loss.item())
losses["adv"].append(adv_loss.item())
losses["disc"].append(disc_loss.item())

# %% tags=["solution"]
from tqdm import tqdm # This is a nice library for showing progress bars

dataloader = DataLoader(
mnist, batch_size=32, drop_last=True, shuffle=True
) # We will use the same dataset as before

losses = {"cycle": [], "adv": [], "disc": []}
for epoch in range(50):
for epoch in range(15):
for x, y in tqdm(dataloader, desc=f"Epoch {epoch}"):
x = x.to(device)
y = y.to(device)
Expand All @@ -533,7 +621,7 @@ def set_requires_grad(module, value=True):
# 1. make sure the image can be reconstructed
cycle_loss = cycle_loss_fn(x, x_cycled)
# 2. make sure the discriminator is fooled
adv_loss = class_loss_fn(discriminator_x_fake, y_target)
adv_loss = adversarial_loss_fn(discriminator_x_fake, y_target)

# Optimize the generator
(cycle_loss + adv_loss).backward()
Expand All @@ -547,9 +635,9 @@ def set_requires_grad(module, value=True):
discriminator_x_fake = discriminator(x_fake.detach())
# Losses to train the discriminator
# 1. make sure the discriminator can tell real is real
real_loss = class_loss_fn(discriminator_x, y)
# 2. make sure the discriminator can't tell fake is fake
fake_loss = -class_loss_fn(discriminator_x_fake, y_target)
real_loss = adversarial_loss_fn(discriminator_x, y)
# 2. make sure the discriminator can tell fake is fake
fake_loss = -adversarial_loss_fn(discriminator_x_fake, y_target)
#
disc_loss = (real_loss + fake_loss) * 0.5
disc_loss.backward()
Expand All @@ -560,15 +648,20 @@ def set_requires_grad(module, value=True):
losses["adv"].append(adv_loss.item())
losses["disc"].append(disc_loss.item())


# %% [markdown] tags=[]
# ...this time again. &#x1F682; &#x1F68B; &#x1F68B; &#x1F68B;
#
# Once training is complete, we can plot the losses to see how well the model is doing.
# %%
plt.plot(losses["cycle"], label="Cycle loss")
plt.plot(losses["adv"], label="Adversarial loss")
plt.plot(losses["disc"], label="Discriminator loss")
plt.legend()
plt.show()
# %% [markdown] tags=[]
# Let's add a quick plotting function before we begin training...

# %% [markdown] tags=[]
# We can also look at some examples of the images that the generator is creating.
# %%
idx = 0
fig, axs = plt.subplots(1, 4, figsize=(12, 4))
Expand All @@ -581,8 +674,8 @@ def set_requires_grad(module, value=True):
ax.axis("off")
plt.show()

# TODO WIP here

# %%
# TODO wip here
# %% [markdown] tags=[]
# <div class="alert alert-block alert-success"><h2>Checkpoint 3</h2>
# You've now learned the basics of what makes up a CycleGAN, and details on how to perform adversarial training.
Expand Down

0 comments on commit b3d267d

Please sign in to comment.