Overhaul exercise for 2024 course #11

Merged · 38 commits · Aug 21, 2024

Changes from 1 commit
04081c0
Begin change to Colored MNIST
adjavon Jul 25, 2024
84f0a44
Update README overview
adjavon Jul 25, 2024
354005a
wip: Add GAN script
adjavon Jul 25, 2024
14d8e72
wip: Update tasks, parts 1-3
adjavon Jul 25, 2024
4232d59
Add workflow for building notebooks
adjavon Jul 25, 2024
bfc68ba
Commit from GitHub Actions (Build Notebooks)
adjavon Jul 25, 2024
f43a1f5
Clean up tags for parts 1 and 2
adjavon Jul 25, 2024
690d2d0
Commit from GitHub Actions (Build Notebooks)
adjavon Jul 25, 2024
ecef44d
Add EMA to UNet and validate GAN
adjavon Jul 28, 2024
7afaef3
Restart training from checkpoint
adjavon Aug 6, 2024
4fc7a43
Add stargan figure
adjavon Aug 6, 2024
f12e6d8
Reduce hard-coding in viewing results
adjavon Aug 6, 2024
343e364
wip: Add explanations about the GAN training
adjavon Aug 6, 2024
3d887bc
Commit from GitHub Actions (Build Notebooks)
adjavon Aug 6, 2024
b3d267d
wip: Add GAN training task
adjavon Aug 12, 2024
3327629
wip: Begin evaluation of the counterfactuals using classifier
adjavon Aug 12, 2024
03e6aaa
Commit from GitHub Actions (Build Notebooks)
adjavon Aug 12, 2024
5e963df
wip: Add EMA to GAN training
adjavon Aug 12, 2024
846525f
Commit from GitHub Actions (Build Notebooks)
adjavon Aug 12, 2024
702c0e3
wip: Add discriminative attribution
adjavon Aug 12, 2024
b4595ab
Commit from GitHub Actions (Build Notebooks)
adjavon Aug 12, 2024
f864649
Finish style space, explanations, and conclusion
adjavon Aug 15, 2024
33a6110
Commit from GitHub Actions (Build Notebooks)
adjavon Aug 15, 2024
c1a6e28
Fix numbering, missing todos, and plotting bug
adjavon Aug 15, 2024
544c6a7
Commit from GitHub Actions (Build Notebooks)
adjavon Aug 15, 2024
559ccf9
Update setup script
adjavon Aug 16, 2024
14d5975
Merge branch '2024' of github.com:dlmbl/knowledge_extraction into 2024
adjavon Aug 16, 2024
5ccd575
Commit from GitHub Actions (Build Notebooks)
adjavon Aug 16, 2024
12a6ff9
Fix environment creation script
afoix Aug 17, 2024
81652c5
update exercise number in the README.md
afoix Aug 17, 2024
4921a76
Commit from GitHub Actions (Build Notebooks)
afoix Aug 17, 2024
2599651
Ben/review (#12)
Ben-Salmon Aug 20, 2024
d759e63
Commit from GitHub Actions (Build Notebooks)
adjavon Aug 20, 2024
83495ec
Fix exercise setup
adjavon Aug 20, 2024
7d92477
Move data to extras
adjavon Aug 20, 2024
b454546
Split loss plot
adjavon Aug 20, 2024
81751d2
Update README
adjavon Aug 20, 2024
0fca9ec
Commit from GitHub Actions (Build Notebooks)
adjavon Aug 20, 2024
Add EMA to UNet and validate GAN
adjavon committed Jul 28, 2024
commit ecef44dc7d02c7c06b5dfccbd71d62e3341430e7
47 changes: 42 additions & 5 deletions extras/train_gan.py
@@ -5,7 +5,9 @@
 from torch import nn
 from torch.utils.data import DataLoader
 from tqdm import tqdm
-
+from copy import deepcopy
+import json
+from pathlib import Path
 
 class Generator(nn.Module):
     def __init__(self, generator, style_mapping):
@@ -34,16 +36,34 @@ def set_requires_grad(module, value=True):
         param.requires_grad = value
 
 
+def exponential_moving_average(model, ema_model, beta=0.999):
+    """Update the EMA model's parameters with an exponential moving average"""
+    for param, ema_param in zip(model.parameters(), ema_model.parameters()):
+        ema_param.data.mul_(beta).add_((1 - beta) * param.data)
+
+
+def copy_parameters(source_model, target_model):
+    """Copy the parameters of a model to another model"""
+    for param, target_param in zip(
+        source_model.parameters(), target_model.parameters()
+    ):
+        target_param.data.copy_(param.data)
+
+
 if __name__ == "__main__":
+    save_dir = Path("checkpoints/stargan")
+    save_dir.mkdir(parents=True, exist_ok=True)
     mnist = ColoredMNIST("../data", download=True, train=True)
-    device = torch.devic("cuda" if torch.cuda.is_available() else "cpu")
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     unet = UNet(depth=2, in_channels=6, out_channels=3, final_activation=nn.Sigmoid())
+    unet_ema = deepcopy(unet)
     discriminator = DenseModel(input_shape=(3, 28, 28), num_classes=4)
     style_mapping = DenseModel(input_shape=(3, 28, 28), num_classes=3)
     generator = Generator(unet, style_mapping=style_mapping)
 
     # all models on the GPU
     generator = generator.to(device)
+    unet_ema = unet_ema.to(device)
     discriminator = discriminator.to(device)
 
     cycle_loss_fn = nn.L1Loss()
@@ -57,7 +77,7 @@ def set_requires_grad(module, value=True):
     )  # We will use the same dataset as before
 
     losses = {"cycle": [], "adv": [], "disc": []}
-    for epoch in range(50):
+    for epoch in range(25):
         for x, y in tqdm(dataloader, desc=f"Epoch {epoch}"):
             x = x.to(device)
             y = y.to(device)
@@ -110,6 +130,23 @@ def set_requires_grad(module, value=True):
             losses["adv"].append(adv_loss.item())
             losses["disc"].append(disc_loss.item())
 
+            # EMA update
+            exponential_moving_average(unet, unet_ema)
-        # TODO add logging, add checkpointing
 
-        # TODO store losses
+        # Copy the EMA model's parameters to the generator
+        copy_parameters(unet_ema, unet)
+        # Store checkpoint
+        torch.save(
+            {
+                "unet": unet.state_dict(),
+                "discriminator": discriminator.state_dict(),
+                "style_mapping": style_mapping.state_dict(),
+                "optimizer_g": optimizer_g.state_dict(),
+                "optimizer_d": optimizer_d.state_dict(),
+                "epoch": epoch,
+            },
+            save_dir / f"checkpoint_{epoch}.pth",
+        )
+        # Store losses
+        with open(save_dir / "losses.json", "w") as f:
+            json.dump(losses, f)
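
For readers new to the EMA trick introduced above: after every batch, each parameter of the EMA copy is pulled a fraction (1 - beta) of the way toward the corresponding raw parameter, so the EMA model tracks a smoothed version of the generator weights rather than the noisy per-step updates typical of GAN training. A minimal standalone sketch (not part of this PR) of how the update behaves:

# Minimal sketch of the EMA update used in train_gan.py (illustration only).
import torch
from torch import nn
from copy import deepcopy


def exponential_moving_average(model, ema_model, beta=0.999):
    for param, ema_param in zip(model.parameters(), ema_model.parameters()):
        ema_param.data.mul_(beta).add_((1 - beta) * param.data)


model = nn.Linear(2, 2)
ema_model = deepcopy(model)  # EMA copy starts identical to the raw model

with torch.no_grad():
    model.weight.add_(1.0)  # stand-in for one optimizer step

exponential_moving_average(model, ema_model, beta=0.9)
# The EMA weight moved only (1 - beta) = 10% of the way toward the new weight:
expected = 0.9 * (model.weight - 1.0) + 0.1 * model.weight
print(torch.allclose(ema_model.weight, expected))  # True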
69 changes: 69 additions & 0 deletions extras/validate_gan.py
@@ -0,0 +1,69 @@
# %%
from dlmbl_unet import UNet
from classifier.model import DenseModel
from classifier.data import ColoredMNIST
import torch
from torch import nn
import json
from pathlib import Path
from matplotlib import pyplot as plt
import numpy as np
from train_gan import Generator

# %%
with open("checkpoints/stargan/losses.json", "r") as f:
    losses = json.load(f)

for key, value in losses.items():
    plt.plot(value, label=key)
plt.legend()

# %%
# Create the model
unet = UNet(depth=2, in_channels=6, out_channels=3, final_activation=nn.Sigmoid())
style_encoder = DenseModel(input_shape=(3, 28, 28), num_classes=3)
# Load model weights
weights = torch.load("checkpoints/stargan/checkpoint_24.pth")  # last checkpoint; training ran epochs 0-24
unet.load_state_dict(weights["unet"])
style_encoder.load_state_dict(weights["style_mapping"]) # Change this to style encoder
generator = Generator(unet, style_encoder)

# %% Plotting an example
# Load the data
mnist = ColoredMNIST("../data", download=True, train=False)

# Load one image from the dataset
x, y = mnist[0]
# Load one image from each other class
results = {}
for i in range(len(mnist.classes)):
    if i == y:
        continue
    index = np.where(mnist.targets == i)[0][0]
    style = mnist[index][0]
    # Generate the images
    generated = generator(x.unsqueeze(0), style.unsqueeze(0))
    results[i] = (style, generated)
# %%
# Plot the images
source_style = mnist.classes[y]

fig, axes = plt.subplots(2, 4, figsize=(12, 3))
for i, (style, generated) in results.items():
    axes[0, i].imshow(style.permute(1, 2, 0))
    axes[0, i].set_title(mnist.classes[i])
    axes[0, i].axis("off")
    axes[1, i].imshow(generated[0].detach().permute(1, 2, 0))
    axes[1, i].set_title(f"{mnist.classes[i]}")
    axes[1, i].axis("off")

# Plot real
axes[1, y].imshow(x.permute(1, 2, 0))
axes[1, y].set_title(source_style)
axes[1, y].axis("off")
axes[0, y].axis("off")

# %%
# TODO get prototype images for each class
# TODO convert every image in the dataset + classify result
# TODO plot a confusion matrix
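
The TODOs above describe a classifier-based evaluation of the counterfactuals. A rough sketch of what that could look like, reusing the generator and mnist objects from this script and assuming a trained DenseModel classifier saved at "checkpoints/classifier.pth" (a hypothetical path; none of this code is in the PR, and the prototype choice is only illustrative):

# %%
# Sketch of the TODOs: convert every test image to each other class,
# classify the result, and accumulate a confusion matrix of target class
# vs. predicted class.
classifier = DenseModel(input_shape=(3, 28, 28), num_classes=4)
classifier.load_state_dict(torch.load("checkpoints/classifier.pth"))  # hypothetical path
classifier.eval()

num_classes = len(mnist.classes)
# One prototype (style) image per class: here simply the first example found.
prototypes = {
    i: mnist[np.where(mnist.targets == i)[0][0]][0] for i in range(num_classes)
}

confusion = np.zeros((num_classes, num_classes), dtype=int)
with torch.no_grad():
    for x, y in mnist:  # slow over the full test set; batch this in practice
        for target in range(num_classes):
            if target == y:
                continue
            counterfactual = generator(x.unsqueeze(0), prototypes[target].unsqueeze(0))
            pred = classifier(counterfactual).argmax(dim=1).item()
            confusion[target, pred] += 1

plt.imshow(confusion)
plt.xlabel("Predicted class")
plt.ylabel("Target class")
plt.colorbar()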