Skip to content

Commit

Permalink
Merge pull request #2 from suzuki-2001/dev
Browse files Browse the repository at this point in the history
Add KL annealing, cutom loss, wandb tracking
  • Loading branch information
suzuki-2001 authored Nov 21, 2024
2 parents 3eb1e3a + 8dc7308 commit 64b885f
Show file tree
Hide file tree
Showing 9 changed files with 486 additions and 231 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -163,3 +163,4 @@ cython_debug/

data/
output/
wandb/
2 changes: 2 additions & 0 deletions env.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ dependencies:
- pytorch>=2.5.1
- torchvision>=0.20.1
- pip:
- wandb
- pytz
- h5py
- loguru
- imageio
Expand Down
64 changes: 64 additions & 0 deletions src/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from PIL import Image
from requests.exceptions import ConnectionError, HTTPError, RequestException, Timeout
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm


Expand Down Expand Up @@ -46,6 +47,69 @@ def get_config(cls, dataset_name: str) -> "DatasetConfig":
return configs[dataset_name]


def get_dataset(params, logger):
"""Load dataset with distributed support"""
dataset_classes = {
"mnist": MNIST,
"fashionmnist": FashionMNIST,
"shapes3d": Shapes3D,
"dsprites": DSprites,
"celeba": CelebA,
"flowers102": Flowers102,
"dtd": DTD,
"imagenet": ImageNet,
"mpi3d": MPI3D,
"ident3d": Ident3D,
}

if params.dataset not in dataset_classes:
raise ValueError(f"Unknown dataset: {params.dataset}")

dataset_class = dataset_classes[params.dataset]

try:
if params.dataset == "mpi3d":
variant = getattr(params, "mpi3d_variant", "toy")
dataset = dataset_class(root=params.data_path, batch_size=params.batch_size, num_workers=4, variant=variant)
else:
dataset = dataset_class(root=params.data_path, batch_size=params.batch_size, num_workers=4)

config = dataset.get_config()
params.chn_num = config.chn_num
params.image_size = config.image_size

train_loader, test_loader = dataset.get_data_loader()
if params.distributed:
train_sampler = DistributedSampler(
train_loader.dataset,
num_replicas=params.world_size,
rank=params.local_rank,
shuffle=True,
drop_last=True,
)

train_loader = torch.utils.data.DataLoader(
train_loader.dataset,
batch_size=params.batch_size,
sampler=train_sampler,
num_workers=params.num_workers,
pin_memory=True,
drop_last=True,
persistent_workers=True,
)

if params.local_rank == 0:
logger.info(f"Dataset {params.dataset} loaded with distributed sampler")
else:
logger.info(f"Dataset {params.dataset} loaded")

return train_loader, test_loader

except Exception as e:
logger.error(f"Failed to load dataset: {str(e)}")
raise


def download_file(url, filename):
"""Download file with progress bar"""
try:
Expand Down
137 changes: 115 additions & 22 deletions src/provlae.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import math
from math import ceil, log2

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
Expand All @@ -21,6 +23,16 @@ def __init__(
pre_kl=True,
coff=0.5,
train_seq=1,
use_kl_annealing=False,
kl_annealing_mode="linear",
cycle_period=4,
max_kl_weight=1,
min_kl_weight=0.1,
ratio=1.0,
use_capacity_increase=False,
gamma=1000.0,
max_capacity=25,
capacity_max_iter=1e-5,
):
super(ProVLAE, self).__init__()

Expand All @@ -44,6 +56,22 @@ def __init__(
self.fade_in_duration = fade_in_duration
self.train_seq = min(train_seq, self.num_ladders)

# for kl annealing
self.use_kl_annealing = use_kl_annealing
self.kl_annealing_mode = kl_annealing_mode
self.current_epoch = None
self.num_epochs = None
self.cycle_period = cycle_period
self.max_kl_weight = max_kl_weight
self.min_kl_weight = min_kl_weight
self.ratio = ratio

# Improving disentangling in β-VAE with controlled capacity increase
self.use_capacity_increase = use_capacity_increase
self.gamma = gamma
self.C_max = torch.Tensor([max_capacity])
self.C_stop_iter = capacity_max_iter

# Calculate encoder sizes
self.encoder_sizes = [self.target_size]
current_size = self.target_size
Expand All @@ -56,8 +84,7 @@ def __init__(
self.hidden_dims.extend([self.hidden_dims[-1]] * (self.num_ladders - len(self.hidden_dims)))
self.hidden_dims = self.hidden_dims[: self.num_ladders]

# Base setup
self.activation = nn.ELU() # or LeakyReLU
self.activation = nn.LeakyReLU() # or ELU
self.q_dist = Normal
self.x_dist = Bernoulli
self.prior_params = nn.Parameter(torch.zeros(self.z_dim, 2))
Expand Down Expand Up @@ -170,16 +197,70 @@ def _sample_latent(self, z_params):
return z_mean + eps * std, z_mean, z_log_var

def _kl_divergence(self, z_mean, z_log_var):
return -0.5 * torch.sum(1 + z_log_var - z_mean.pow(2) - z_log_var.exp())
return -0.5 * torch.mean(1 + z_log_var - z_mean.pow(2) - z_log_var.exp())

def fade_in_alpha(self, step):
if step > self.fade_in_duration:
return 1.0
return step / self.fade_in_duration

def frange_cycle_linear(self, start, stop, n_epoch, n_cycle=4, ratio=0.5):
L = np.ones(n_epoch)
period = n_epoch / n_cycle
step = (stop - start) / (period / ratio)

for c in range(n_cycle):
v, i = start, 0
while v <= stop and int(i + c * period) < n_epoch:
L[int(i + c * period)] = v
v += step
i += 1
return L

def frange_cycle_sigmoid(self, start, stop, n_epoch, n_cycle=4, ratio=0.5):
L = np.ones(n_epoch)
period = n_epoch / n_cycle
step = (stop - start) / (period * ratio) # step is in [0,1]

# transform into [-6, 6] for plots: v*12.-6.

for c in range(n_cycle):

v, i = start, 0
while v <= stop:
L[int(i + c * period)] = 1.0 / (1.0 + np.exp(-(v * 12.0 - 6.0)))
v += step
i += 1
return L

def frange_cycle_cosine(self, start, stop, n_epoch, n_cycle=4, ratio=0.5):
L = np.ones(n_epoch)
period = n_epoch / n_cycle
step = (stop - start) / (period * ratio) # step is in [0,1]

# transform into [0, pi] for plots:

for c in range(n_cycle):

v, i = start, 0
while v <= stop:
L[int(i + c * period)] = 0.5 - 0.5 * math.cos(v * math.pi)
v += step
i += 1
return L

def cycle_kl_weights(self, epoch, n_epoch, cycle_period=4, max_kl_weight=1.0, min_kl_weight=0.1, ratio=0.5):
if self.kl_annealing_mode == "linear":
kl_weights = self.frange_cycle_linear(min_kl_weight, max_kl_weight, n_epoch, cycle_period, ratio)
if self.kl_annealing_mode == "sigmoid":
kl_weights = self.frange_cycle_sigmoid(min_kl_weight, max_kl_weight, n_epoch, cycle_period, ratio)
if self.kl_annealing_mode == "cosine":
kl_weights = self.frange_cycle_cosine(min_kl_weight, max_kl_weight, n_epoch, cycle_period, ratio)

return kl_weights[epoch]

def encode(self, x):
# Store original size
original_size = x.size()[-2:]
original_size = x.size()[-2:] # Store original size

# Resize to target size
if original_size != (self.target_size, self.target_size):
Expand Down Expand Up @@ -222,10 +303,8 @@ def decode(self, z_list, original_size):
f = f * self.fade_in
features.append(f)

# Start from deepest layer
x = features[-1]

# Progressive decoding with explicit size management
x = features[-1] # Start from deepest layer
for i in range(self.num_ladders - 2, -1, -1):
# Ensure feature maps have matching spatial dimensions
target_size = features[i].size(-1)
Expand All @@ -241,8 +320,7 @@ def decode(self, z_list, original_size):
for up_layer in self.additional_ups:
x = up_layer(x)

# Final convolution
x = self.output_layer(x)
x = self.output_layer(x) # Final convolution

# Resize to original input size
if original_size != (x.size(-2), x.size(-1)):
Expand All @@ -252,35 +330,50 @@ def decode(self, z_list, original_size):

def forward(self, x, step=0):
self.fade_in = self.fade_in_alpha(step)
kl_weight = self.cycle_kl_weights(
epoch=self.current_epoch,
n_epoch=self.num_epochs,
cycle_period=self.cycle_period,
max_kl_weight=self.max_kl_weight,
min_kl_weight=self.min_kl_weight,
ratio=self.ratio,
)

# Encode
z_params, original_size = self.encode(x)

# Calculate KL divergence
latent_losses = []
zs = []
for z, z_mean, z_log_var in z_params:
latent_losses.append(self._kl_divergence(z_mean, z_log_var))
zs.append(z)

latent_loss = sum(latent_losses)

# Decode
x_recon = self.decode(zs, original_size)

# Reconstruction loss
bce_loss = nn.BCEWithLogitsLoss(reduction="sum")
bce_loss = nn.BCEWithLogitsLoss(reduction="mean")
recon_loss = bce_loss(x_recon, x)

# Calculate final loss
if self.pre_kl:
active_latents = latent_losses[self.train_seq - 1 :]
inactive_latents = latent_losses[: self.train_seq - 1]
loss = recon_loss + self.beta * sum(active_latents) + self.coff * sum(inactive_latents)
# prekl loss
active_latents = latent_losses[self.train_seq - 1 :]
inactive_latents = latent_losses[: self.train_seq - 1]
if self.use_kl_annealing:
if self.use_capacity_increase:
# https://arxiv.org/pdf/1804.03599.pdf
self.C_max = self.C_max.to(x.device)
C = torch.clamp(self.C_max / self.C_stop_iter * step, 0, self.C_max.data[0])
kl_term = self.gamma * kl_weight * (sum(active_latents) - C).abs()
else:
# https://openreview.net/forum?id=Sy2fzU9gl
kl_term = kl_weight * self.beta * sum(active_latents)
else:
loss = recon_loss + self.beta * latent_loss
kl_term = self.beta * sum(active_latents)

loss = recon_loss + kl_term
if self.pre_kl:
loss += self.coff * sum(inactive_latents)

return torch.sigmoid(x_recon), loss, latent_loss, recon_loss
return torch.sigmoid(x_recon), loss, kl_term, recon_loss, kl_weight

def inference(self, x):
with torch.no_grad():
Expand Down
2 changes: 1 addition & 1 deletion src/scripts/run_dsprites.sh
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ torchrun --nproc_per_node=2 --master_port=29502 src/train.py \
--batch_size 256 \
--num_epochs 30 \
--learning_rate 5e-4 \
--beta 8 \
--beta 3 \
--z_dim 2 \
--coff 0.5 \
--pre_kl \
Expand Down
8 changes: 4 additions & 4 deletions src/scripts/run_imagenet.sh
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@ torchrun --nproc_per_node=2 --master_port=29501 src/train.py \
--dataset imagenet \
--optim adamw \
--num_ladders 4 \
--batch_size 256 \
--num_epochs 100 \
--batch_size 128 \
--num_epochs 30 \
--learning_rate 5e-4 \
--beta 1 \
--z_dim 4 \
--coff 0.5 \
--z_dim 8 \
--coff 0.1 \
--pre_kl \
--hidden_dim 64 \
--fade_in_duration 5000 \
Expand Down
26 changes: 18 additions & 8 deletions src/scripts/run_mnist.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,29 @@

torchrun --nproc_per_node=2 --master_port=29501 src/train.py \
--distributed \
--mode seq_train \
--dataset mnist \
--mode indep_train \
--train_seq 3 \
--dataset shapes3d \
--optim adamw \
--num_ladders 3 \
--batch_size 64 \
--num_epochs 30 \
--batch_size 128 \
--num_epochs 16 \
--learning_rate 5e-4 \
--beta 3 \
--z_dim 2 \
--beta 1 \
--z_dim 3 \
--coff 0.5 \
--pre_kl \
--hidden_dim 32 \
--fade_in_duration 5000 \
--output_dir ./output/mnist/ \
--data_path ./data/mnist/
--output_dir ./output/shapes3d/ \
--data_path ./data/ \
--use_wandb \
--wandb_project PRO-VLAE \
--use_kl_annealing \
--kl_annealing_mode sigmoid \
--cycle_period 4 \
--ratio 0.5 \
--max_kl_weight 1.0 \
--min_kl_weight 0.1 \
--num_workers 16

Loading

0 comments on commit 64b885f

Please sign in to comment.