diff --git a/src/dataset.py b/src/dataset.py
index 904e6d4..5ca8d26 100644
--- a/src/dataset.py
+++ b/src/dataset.py
@@ -13,6 +13,7 @@
from requests.exceptions import ConnectionError, HTTPError, RequestException, Timeout
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
+from torch.utils.data.distributed import DistributedSampler
@dataclass
@@ -46,6 +47,69 @@ def get_config(cls, dataset_name: str) -> "DatasetConfig":
return configs[dataset_name]
+def get_dataset(params, logger):
+ """Load dataset with distributed support"""
+ dataset_classes = {
+ "mnist": MNIST,
+ "fashionmnist": FashionMNIST,
+ "shapes3d": Shapes3D,
+ "dsprites": DSprites,
+ "celeba": CelebA,
+ "flowers102": Flowers102,
+ "dtd": DTD,
+ "imagenet": ImageNet,
+ "mpi3d": MPI3D,
+ "ident3d": Ident3D,
+ }
+
+ if params.dataset not in dataset_classes:
+ raise ValueError(f"Unknown dataset: {params.dataset}")
+
+ dataset_class = dataset_classes[params.dataset]
+
+ try:
+ if params.dataset == "mpi3d":
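+            # MPI3D ships in several variants (e.g. "toy", "realistic", "real"); fall back to "toy"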
+ variant = getattr(params, "mpi3d_variant", "toy")
+ dataset = dataset_class(root=params.data_path, batch_size=params.batch_size, num_workers=4, variant=variant)
+ else:
+ dataset = dataset_class(root=params.data_path, batch_size=params.batch_size, num_workers=4)
+
+ config = dataset.get_config()
+ params.chn_num = config.chn_num
+ params.image_size = config.image_size
+
+ train_loader, test_loader = dataset.get_data_loader()
+ if params.distributed:
+ train_sampler = DistributedSampler(
+ train_loader.dataset,
+ num_replicas=params.world_size,
+ rank=params.local_rank,
+ shuffle=True,
+ drop_last=True,
+ )
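+            # NOTE: train.py calls sampler.set_epoch(epoch) each epoch so every rank reshuffles its shard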
+
+            train_loader = DataLoader(
+ train_loader.dataset,
+ batch_size=params.batch_size,
+ sampler=train_sampler,
+ num_workers=params.num_workers,
+ pin_memory=True,
+ drop_last=True,
+ persistent_workers=True,
+ )
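+            # NOTE: batch_size is per process; the effective global batch is batch_size * world_size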
+
+ if params.local_rank == 0:
+ logger.info(f"Dataset {params.dataset} loaded with distributed sampler")
+ else:
+ logger.info(f"Dataset {params.dataset} loaded")
+
+ return train_loader, test_loader
+
+ except Exception as e:
+ logger.error(f"Failed to load dataset: {str(e)}")
+ raise
+
+
def download_file(url, filename):
"""Download file with progress bar"""
try:
diff --git a/src/provlae.py b/src/provlae.py
index c0abec5..ff00078 100644
--- a/src/provlae.py
+++ b/src/provlae.py
@@ -1,4 +1,6 @@
from math import ceil, log2
+import math
+import numpy as np
import torch
import torch.nn as nn
@@ -21,6 +23,16 @@ def __init__(
pre_kl=True,
coff=0.5,
train_seq=1,
+ use_kl_annealing=False,
+ kl_annealing_mode="linear",
+ cycle_period=4,
+        max_kl_weight=1.0,
+ min_kl_weight=0.1,
+ ratio=1.0,
+ use_capacity_increase=False,
+ gamma=1000.0,
+ max_capacity=25,
+        capacity_max_iter=1e5,
):
super(ProVLAE, self).__init__()
@@ -44,6 +56,22 @@ def __init__(
self.fade_in_duration = fade_in_duration
self.train_seq = min(train_seq, self.num_ladders)
+ # for kl annealing
+ self.use_kl_annealing = use_kl_annealing
+ self.kl_annealing_mode = kl_annealing_mode
+ self.current_epoch = None
+ self.num_epochs = None
+ self.cycle_period = cycle_period
+ self.max_kl_weight = max_kl_weight
+ self.min_kl_weight = min_kl_weight
+ self.ratio = ratio
+
+ # Improving disentangling in β-VAE with controlled capacity increase
+ self.use_capacity_increase = use_capacity_increase
+ self.gamma = gamma
+ self.C_max = torch.Tensor([max_capacity])
+ self.C_stop_iter = capacity_max_iter
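+        # capacity C ramps linearly from 0 to C_max over C_stop_iter training steps (Burgess et al., 2018)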
+
# Calculate encoder sizes
self.encoder_sizes = [self.target_size]
current_size = self.target_size
@@ -56,8 +84,7 @@ def __init__(
self.hidden_dims.extend([self.hidden_dims[-1]] * (self.num_ladders - len(self.hidden_dims)))
self.hidden_dims = self.hidden_dims[: self.num_ladders]
- # Base setup
- self.activation = nn.ELU() # or LeakyReLU
+ self.activation = nn.LeakyReLU() # or ELU
self.q_dist = Normal
self.x_dist = Bernoulli
self.prior_params = nn.Parameter(torch.zeros(self.z_dim, 2))
@@ -170,16 +197,70 @@ def _sample_latent(self, z_params):
return z_mean + eps * std, z_mean, z_log_var
def _kl_divergence(self, z_mean, z_log_var):
- return -0.5 * torch.sum(1 + z_log_var - z_mean.pow(2) - z_log_var.exp())
+ return -0.5 * torch.mean(1 + z_log_var - z_mean.pow(2) - z_log_var.exp())
def fade_in_alpha(self, step):
if step > self.fade_in_duration:
return 1.0
return step / self.fade_in_duration
+    def frange_cycle_linear(self, start, stop, n_epoch, n_cycle=4, ratio=0.5):
+        L = np.ones(n_epoch) * stop  # plateau at `stop` outside the ramp
+        period = n_epoch / n_cycle
+        step = (stop - start) / (period * ratio)  # ramp spans the first `ratio` of each cycle
+
+        for c in range(n_cycle):
+            v, i = start, 0
+            while v <= stop and int(i + c * period) < n_epoch:
+                L[int(i + c * period)] = v
+                v += step
+                i += 1
+        return L
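+    # Illustrative schedule (assuming n_epoch=8, n_cycle=2, ratio=0.5, start=0.0, stop=1.0):
+    #   frange_cycle_linear -> [0.0, 0.5, 1.0, 1.0, 0.0, 0.5, 1.0, 1.0]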
+
+    def frange_cycle_sigmoid(self, start, stop, n_epoch, n_cycle=4, ratio=0.5):
+        L = np.ones(n_epoch) * stop  # plateau at `stop` outside the ramp
+        period = n_epoch / n_cycle
+        step = (stop - start) / (period * ratio)  # ramp spans the first `ratio` of each cycle
+
+        for c in range(n_cycle):
+            v, i = start, 0
+            while v <= stop and int(i + c * period) < n_epoch:
+                # map v into [-6, 6] so the sigmoid saturates at both ends of the ramp
+                L[int(i + c * period)] = 1.0 / (1.0 + np.exp(-(v * 12.0 - 6.0)))
+                v += step
+                i += 1
+        return L
+
+    def frange_cycle_cosine(self, start, stop, n_epoch, n_cycle=4, ratio=0.5):
+        L = np.ones(n_epoch) * stop  # plateau at `stop` outside the ramp
+        period = n_epoch / n_cycle
+        step = (stop - start) / (period * ratio)  # ramp spans the first `ratio` of each cycle
+
+        for c in range(n_cycle):
+            v, i = start, 0
+            while v <= stop and int(i + c * period) < n_epoch:
+                # map v into [0, pi]; 0.5 - 0.5*cos(v*pi) rises smoothly from 0 to 1
+                L[int(i + c * period)] = 0.5 - 0.5 * math.cos(v * math.pi)
+                v += step
+                i += 1
+        return L
+
+ def cycle_kl_weights(self, epoch, n_epoch, cycle_period=4, max_kl_weight=1.0, min_kl_weight=0.1, ratio=0.5):
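+        """Look up this epoch's KL weight from the precomputed cyclical schedule (cf. Fu et al., 2019)."""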
+        if self.kl_annealing_mode == "linear":
+            kl_weights = self.frange_cycle_linear(min_kl_weight, max_kl_weight, n_epoch, cycle_period, ratio)
+        elif self.kl_annealing_mode == "sigmoid":
+            kl_weights = self.frange_cycle_sigmoid(min_kl_weight, max_kl_weight, n_epoch, cycle_period, ratio)
+        elif self.kl_annealing_mode == "cosine":
+            kl_weights = self.frange_cycle_cosine(min_kl_weight, max_kl_weight, n_epoch, cycle_period, ratio)
+        else:
+            raise ValueError(f"Unknown kl_annealing_mode: {self.kl_annealing_mode}")
+
+        return float(kl_weights[epoch])
+
def encode(self, x):
- # Store original size
- original_size = x.size()[-2:]
+ original_size = x.size()[-2:] # Store original size
# Resize to target size
if original_size != (self.target_size, self.target_size):
@@ -222,10 +303,8 @@ def decode(self, z_list, original_size):
f = f * self.fade_in
features.append(f)
- # Start from deepest layer
- x = features[-1]
-
# Progressive decoding with explicit size management
+ x = features[-1] # Start from deepest layer
for i in range(self.num_ladders - 2, -1, -1):
# Ensure feature maps have matching spatial dimensions
target_size = features[i].size(-1)
@@ -241,8 +320,7 @@ def decode(self, z_list, original_size):
for up_layer in self.additional_ups:
x = up_layer(x)
- # Final convolution
- x = self.output_layer(x)
+ x = self.output_layer(x) # Final convolution
# Resize to original input size
if original_size != (x.size(-2), x.size(-1)):
@@ -252,11 +330,17 @@ def decode(self, z_list, original_size):
def forward(self, x, step=0):
self.fade_in = self.fade_in_alpha(step)
+        if self.use_kl_annealing and self.current_epoch is not None:
+            kl_weight = self.cycle_kl_weights(
+                epoch=self.current_epoch,
+                n_epoch=self.num_epochs,
+                cycle_period=self.cycle_period,
+                max_kl_weight=self.max_kl_weight,
+                min_kl_weight=self.min_kl_weight,
+                ratio=self.ratio,
+            )
+        else:
+            kl_weight = 1.0  # annealing disabled (or epoch counters unset): constant KL weight
- # Encode
z_params, original_size = self.encode(x)
- # Calculate KL divergence
latent_losses = []
zs = []
for z, z_mean, z_log_var in z_params:
@@ -264,23 +348,32 @@ def forward(self, x, step=0):
zs.append(z)
latent_loss = sum(latent_losses)
-
- # Decode
x_recon = self.decode(zs, original_size)
# Reconstruction loss
- bce_loss = nn.BCEWithLogitsLoss(reduction="sum")
+ bce_loss = nn.BCEWithLogitsLoss(reduction="mean")
recon_loss = bce_loss(x_recon, x)
- # Calculate final loss
- if self.pre_kl:
- active_latents = latent_losses[self.train_seq - 1 :]
- inactive_latents = latent_losses[: self.train_seq - 1]
- loss = recon_loss + self.beta * sum(active_latents) + self.coff * sum(inactive_latents)
+        # pre-KL: split the per-ladder KL terms into active ladders (at/after train_seq) and the rest
+ active_latents = latent_losses[self.train_seq - 1 :]
+ inactive_latents = latent_losses[: self.train_seq - 1]
+ if self.use_kl_annealing:
+ if self.use_capacity_increase:
+ # https://arxiv.org/pdf/1804.03599.pdf
+ self.C_max = self.C_max.to(x.device)
+ C = torch.clamp(self.C_max / self.C_stop_iter * step, 0, self.C_max.data[0])
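+                # |KL - C| pulls the KL toward the scheduled capacity C (Burgess et al., 2018)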
+ kl_term = self.gamma * kl_weight * (sum(active_latents) - C).abs()
+ else:
+ # https://openreview.net/forum?id=Sy2fzU9gl
+ kl_term = kl_weight * self.beta * sum(active_latents)
else:
- loss = recon_loss + self.beta * latent_loss
+ kl_term = self.beta * sum(active_latents)
+
+ loss = recon_loss + kl_term
+ if self.pre_kl:
+ loss += self.coff * sum(inactive_latents)
- return torch.sigmoid(x_recon), loss, latent_loss, recon_loss
+ return torch.sigmoid(x_recon), loss, kl_term, recon_loss, kl_weight
def inference(self, x):
with torch.no_grad():
diff --git a/src/scripts/run_dsprites.sh b/src/scripts/run_dsprites.sh
index 027eedb..33cee2e 100644
--- a/src/scripts/run_dsprites.sh
+++ b/src/scripts/run_dsprites.sh
@@ -9,7 +9,7 @@ torchrun --nproc_per_node=2 --master_port=29502 src/train.py \
--batch_size 256 \
--num_epochs 30 \
--learning_rate 5e-4 \
- --beta 8 \
+ --beta 3 \
--z_dim 2 \
--coff 0.5 \
--pre_kl \
diff --git a/src/scripts/run_imagenet.sh b/src/scripts/run_imagenet.sh
index cf11d39..d1bf44c 100644
--- a/src/scripts/run_imagenet.sh
+++ b/src/scripts/run_imagenet.sh
@@ -6,12 +6,12 @@ torchrun --nproc_per_node=2 --master_port=29501 src/train.py \
--dataset imagenet \
--optim adamw \
--num_ladders 4 \
- --batch_size 256 \
- --num_epochs 100 \
+ --batch_size 128 \
+ --num_epochs 30 \
--learning_rate 5e-4 \
--beta 1 \
- --z_dim 4 \
- --coff 0.5 \
+ --z_dim 8 \
+ --coff 0.1 \
--pre_kl \
--hidden_dim 64 \
--fade_in_duration 5000 \
diff --git a/src/scripts/run_mnist.sh b/src/scripts/run_mnist.sh
index bd651ac..df9c01c 100644
--- a/src/scripts/run_mnist.sh
+++ b/src/scripts/run_mnist.sh
@@ -2,21 +2,29 @@
torchrun --nproc_per_node=2 --master_port=29501 src/train.py \
--distributed \
- --mode seq_train \
- --dataset mnist \
+ --mode indep_train \
+ --train_seq 3 \
+ --dataset shapes3d \
--optim adamw \
--num_ladders 3 \
- --batch_size 64 \
- --num_epochs 5 \
+ --batch_size 128 \
+ --num_epochs 16 \
--learning_rate 5e-4 \
- --beta 3 \
- --z_dim 2 \
+ --beta 1 \
+ --z_dim 3 \
--coff 0.5 \
--pre_kl \
- --hidden_dim 64 \
+ --hidden_dim 32 \
--fade_in_duration 5000 \
- --output_dir ./output/mnist/ \
- --data_path ./data/mnist/ \
+ --output_dir ./output/shapes3d/ \
+ --data_path ./data/ \
--use_wandb \
--wandb_project PRO-VLAE \
+ --use_kl_annealing \
+ --kl_annealing_mode sigmoid \
+ --cycle_period 4 \
+ --ratio 0.5 \
+ --max_kl_weight 1.0 \
+ --min_kl_weight 0.1 \
+ --num_workers 16
\ No newline at end of file
diff --git a/src/train.py b/src/train.py
index 39d506a..1a06a08 100644
--- a/src/train.py
+++ b/src/train.py
@@ -1,27 +1,31 @@
import argparse
import os
-import sys
from dataclasses import dataclass, field
+import math
import imageio.v3 as imageio
import numpy as np
import torch
import torch.distributed as dist
import torch.nn.functional as F
-import torch.optim as optim
-import torch_optimizer as jettify_optim
import torchvision
import wandb
-from loguru import logger
from PIL import Image, ImageDraw, ImageFont
from torch.nn.parallel import DistributedDataParallel as DDP
-from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm
-from dataset import DTD, MNIST, MPI3D, CelebA, DSprites, FashionMNIST, Flowers102, Ident3D, ImageNet, Shapes3D
-from ddp_utils import cleanup_distributed, setup_distributed, setup_logger
from provlae import ProVLAE
-from utils import add_dataclass_args, exec_time
+from dataset import get_dataset
+from ddp_utils import cleanup_distributed, setup_distributed, setup_logger
+from utils import (
+ init_wandb,
+ get_optimizer,
+ add_dataclass_args,
+ exec_time,
+ save_input_image,
+ save_reconstruction,
+ load_checkpoint,
+)
@dataclass
@@ -50,6 +54,16 @@ class HyperParameters:
hidden_dim: int = field(default=32)
coff: float = field(default=0.5)
pre_kl: bool = field(default=True)
+ use_kl_annealing: bool = field(default=False)
+ kl_annealing_mode: str = field(default="linear")
+ cycle_period: int = field(default=4)
+ max_kl_weight: float = field(default=1.0)
+ min_kl_weight: float = field(default=0.1)
+ ratio: float = field(default=1.0)
+ use_capacity_increase: bool = field(default=False)
+ gamma: float = field(default=1000.0)
+ max_capacity: int = field(default=25)
+    capacity_max_iter: float = field(default=1e5)
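+    # capacity_max_iter is counted in training steps (global_step), not epochs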
@dataclass
@@ -91,145 +105,6 @@ def parse_arguments():
return parser.parse_args()
-def get_dataset(params, logger):
- """Load dataset with distributed support"""
- dataset_classes = {
- "mnist": MNIST,
- "fashionmnist": FashionMNIST,
- "shapes3d": Shapes3D,
- "dsprites": DSprites,
- "celeba": CelebA,
- "flowers102": Flowers102,
- "dtd": DTD,
- "imagenet": ImageNet,
- "mpi3d": MPI3D,
- "ident3d": Ident3D,
- }
-
- if params.dataset not in dataset_classes:
- raise ValueError(f"Unknown dataset: {params.dataset}")
-
- dataset_class = dataset_classes[params.dataset]
-
- try:
- if params.dataset == "mpi3d":
- variant = getattr(params, "mpi3d_variant", "toy")
- dataset = dataset_class(root=params.data_path, batch_size=params.batch_size, num_workers=4, variant=variant)
- else:
- dataset = dataset_class(root=params.data_path, batch_size=params.batch_size, num_workers=4)
-
- config = dataset.get_config()
- params.chn_num = config.chn_num
- params.image_size = config.image_size
-
- train_loader, test_loader = dataset.get_data_loader()
- if params.distributed:
- train_sampler = DistributedSampler(
- train_loader.dataset,
- num_replicas=params.world_size,
- rank=params.local_rank,
- shuffle=True,
- drop_last=True,
- )
-
- train_loader = torch.utils.data.DataLoader(
- train_loader.dataset,
- batch_size=params.batch_size,
- sampler=train_sampler,
- num_workers=params.num_workers,
- pin_memory=True,
- drop_last=True,
- persistent_workers=True,
- )
-
- if params.local_rank == 0:
- logger.info(f"Dataset {params.dataset} loaded with distributed sampler")
- else:
- logger.info(f"Dataset {params.dataset} loaded")
-
- return train_loader, test_loader
-
- except Exception as e:
- logger.error(f"Failed to load dataset: {str(e)}")
- raise
-
-
-def load_checkpoint(model, optimizer, scaler, checkpoint_path, device, logger):
- """Load a model checkpoint with proper device management."""
- try:
- checkpoint = torch.load(
- checkpoint_path,
- map_location=device,
- weights_only=True,
- )
-
- # Load model state dict
- if hasattr(model, "module"):
- model.module.load_state_dict(checkpoint["model_state_dict"])
- else:
- model.load_state_dict(checkpoint["model_state_dict"], strict=False)
-
- # Load optimizer state dict
- for state in optimizer.state.values():
- for k, v in state.items():
- if isinstance(v, torch.Tensor):
- state[k] = v.to(device)
-
- optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
-
- if scaler is not None and "scaler_state_dict" in checkpoint:
- scaler.load_state_dict(checkpoint["scaler_state_dict"])
-
- logger.info(
- f"Loaded checkpoint from '{checkpoint_path}' (Epoch: {checkpoint['epoch']}, Loss: {checkpoint['loss']:.4f})"
- )
-
- return model, optimizer, scaler
- except Exception as e:
- logger.error(f"Failed to load checkpoint: {str(e)}")
- return model, optimizer, scaler
-
-
-def save_reconstruction(inputs, reconstructions, save_path):
- """Save a grid of original and reconstructed images"""
- batch_size = min(8, inputs.shape[0])
- inputs = inputs[:batch_size].float()
- reconstructions = reconstructions[:batch_size].float()
- comparison = torch.cat([inputs[:batch_size], reconstructions[:batch_size]])
-
- # Denormalize and convert to numpy
- images = comparison.cpu().detach()
- images = torch.clamp(images, 0, 1)
- grid = torchvision.utils.make_grid(images, nrow=batch_size)
- image = grid.permute(1, 2, 0).numpy()
-
- os.makedirs(os.path.dirname(save_path), exist_ok=True)
- imageio.imwrite(save_path, (image * 255).astype("uint8"))
-
-
-def save_input_image(inputs: torch.Tensor, save_dir: str, seq: int, size: int = 96) -> str:
- input_path = os.path.join(save_dir, f"traverse_input_seq{seq}.png")
- os.makedirs(save_dir, exist_ok=True)
-
- input_img = inputs[0].cpu().float()
- input_img = torch.clamp(input_img, 0, 1)
-
- if input_img.shape[-1] != size:
- input_img = F.interpolate(
- input_img.unsqueeze(0),
- size=size,
- mode="bilinear",
- align_corners=False,
- ).squeeze(0)
-
- if input_img.shape[0] == 1:
- input_img = input_img.repeat(3, 1, 1)
-
- input_array = (input_img.permute(1, 2, 0).numpy() * 255).astype(np.uint8)
- imageio.imwrite(input_path, input_array)
- return input_path
-
-
def create_latent_traversal(model, data_loader, save_path, device, params):
"""Create and save organized latent traversal GIF with optimized layout"""
model.eval()
@@ -242,13 +117,16 @@ def create_latent_traversal(model, data_loader, save_path, device, params):
inputs, _ = next(iter(data_loader)) # Get a single batch of images
inputs = inputs[0:1].to(device)
- input_path = save_input_image(inputs.cpu(), os.path.join(params.output_dir, params.input_dir), params.train_seq)
+ # save traverse inputs
+ input_path = save_input_image(
+ inputs.cpu(), os.path.join(params.output_dir, params.input_dir), params.train_seq, params.image_size
+ )
# Get latent representations
with torch.amp.autocast(device_type="cuda", enabled=False):
latent_vars = [z[0] for z in model.inference(inputs)]
- traverse_range = torch.linspace(-1.5, 1.5, 15).to(device)
+ traverse_range = torch.linspace(-2.5, 2.5, 10).to(device)
# Image layout parameters
img_size = 96 # Base image size
@@ -356,6 +234,7 @@ def train_model(model, data_loader, optimizer, params, device, logger, scaler=No
if hasattr(model, "module"):
model.module.to(device)
+ model.module.num_epochs = params.num_epochs
else:
        model.to(device)
+        model.num_epochs = params.num_epochs
@@ -364,6 +243,11 @@ def train_model(model, data_loader, optimizer, params, device, logger, scaler=No
logger.info(f"Start training [progress {params.train_seq}]")
for epoch in range(params.num_epochs):
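+        # expose epoch counters to the model; forward() uses them to look up the cyclical KL weight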
+ if hasattr(model, "module"):
+ model.module.current_epoch = epoch
+ else:
+ model.current_epoch = epoch
+
if params.distributed:
data_loader.sampler.set_epoch(epoch)
@@ -378,7 +262,7 @@ def train_model(model, data_loader, optimizer, params, device, logger, scaler=No
inputs = inputs.to(device, non_blocking=True)
with torch.amp.autocast(device_type="cuda", dtype=autocast_dtype):
- x_recon, loss, latent_loss, recon_loss = model(inputs, step=global_step)
+ x_recon, loss, latent_loss, recon_loss, kl_weight = model(inputs, step=global_step)
optimizer.zero_grad()
if scaler is not None:
@@ -391,16 +275,17 @@ def train_model(model, data_loader, optimizer, params, device, logger, scaler=No
if params.local_rank == 0 or not params.distributed: # Only show progress on main process
pbar.set_postfix(
- total_loss=f"{loss.item():.2f}",
- latent_loss=f"{latent_loss:.2f}",
- recon_loss=f"{recon_loss:.2f}",
+ total_loss=f"{loss.item():.5f}",
+ latent_loss=f"{latent_loss.item():.5f}",
+ recon_loss=f"{recon_loss.item():.5f}",
)
if params.use_wandb and params.distributed:
metrics = {
- f"loss/rank_{params.local_rank}": loss.item(),
- f"latent_loss/rank_{params.local_rank}": latent_loss.item(),
- f"recon_loss/rank_{params.local_rank}": recon_loss.item(),
+ f"ELBO/rank_{params.local_rank}": loss.item(),
+ f"KL Term/rank_{params.local_rank}": latent_loss.item(),
+ f"Reconstruction Error/rank_{params.local_rank}": recon_loss.item(),
+                    f"KL Weight/rank_{params.local_rank}": kl_weight,
}
all_metrics = [None] * params.world_size
@@ -412,7 +297,12 @@ def train_model(model, data_loader, optimizer, params, device, logger, scaler=No
combined_metrics.update(rank_metrics)
wandb.log(combined_metrics, step=global_step)
elif params.use_wandb:
- metrics = {"loss": loss.item(), "latent_loss": latent_loss.item(), "recon_loss": recon_loss.item()}
+ metrics = {
+ "ELBO": loss.item(),
+ "KL Term": latent_loss.item(),
+ "Reconstruction Error": recon_loss.item(),
+                    "KL Weight": kl_weight,
+ }
wandb.log(metrics, step=global_step)
global_step += 1
@@ -432,7 +322,7 @@ def train_model(model, data_loader, optimizer, params, device, logger, scaler=No
save_reconstruction(inputs, x_recon, recon_path)
input_path = create_latent_traversal(model, data_loader, traverse_path, device, params)
- # reconstruction and traversal images
+ # reconstruction and traversal images (Media)
if params.use_wandb and (params.local_rank == 0 or not params.distributed):
wandb.log(
{
@@ -458,100 +348,17 @@ def train_model(model, data_loader, optimizer, params, device, logger, scaler=No
logger.info(f"Epoch: [{epoch+1}/{params.num_epochs}], Loss: {loss.item():.2f}")
-def get_optimizer(model, params):
- """Get the optimizer based on the parameter settings"""
- optimizer_params = {
- "params": model.parameters(),
- "lr": params.learning_rate,
- }
-
- # Adam, Lamb, DiffGrad
- extra_args_common = {
- "betas": getattr(params, "betas", (0.9, 0.999)),
- "eps": getattr(params, "eps", 1e-8),
- "weight_decay": getattr(params, "weight_decay", 0),
- }
-
- extra_args_adamw = {
- "betas": getattr(params, "betas", (0.9, 0.999)),
- "eps": getattr(params, "eps", 1e-8),
- "weight_decay": getattr(params, "weight_decay", 0.01),
- }
-
- # SGD
- extra_args_sgd = {
- "momentum": getattr(params, "momentum", 0),
- "dampening": getattr(params, "dampening", 0),
- "weight_decay": getattr(params, "weight_decay", 0),
- "nesterov": getattr(params, "nesterov", False),
- }
-
- # MADGRAD
- extra_args_madgrad = {
- "momentum": getattr(params, "momentum", 0.9),
- "weight_decay": getattr(params, "weight_decay", 0),
- "eps": getattr(params, "eps", 1e-6),
- }
-
- optimizers = {
- "adam": (optim.Adam, extra_args_common),
- "adamw": (optim.AdamW, extra_args_adamw),
- "sgd": (optim.SGD, extra_args_sgd),
- "lamb": (jettify_optim.Lamb, extra_args_common),
- "diffgrad": (jettify_optim.DiffGrad, extra_args_common),
- "madgrad": (jettify_optim.MADGRAD, extra_args_madgrad),
- }
-
- optimizer_cls, extra_args = optimizers.get(params.optim.lower(), (optim.Adam, extra_args_common))
- if params.optim.lower() not in optimizers:
- logger.warning(f"Unsupported optimizer '{params.optim}', using 'Adam' optimizer instead.")
- optimizer = optimizer_cls(**optimizer_params, **extra_args)
-
- return optimizer
-
-
-def init_wandb(params, hash):
- if params.use_wandb:
- if wandb.run is not None:
- wandb.finish()
-
- run_id = None
- if params.local_rank == 0:
- logger.debug(f"Current run ID: {hash}")
- wandb.init(
- project=params.wandb_project,
- config=vars(params),
- name=f"{params.dataset.upper()}_PROGRESS{params.train_seq}_{hash}",
- settings=wandb.Settings(start_method="thread", _disable_stats=True),
- )
- run_id = wandb.run.id
-
- if params.distributed:
- object_list = [run_id if params.local_rank == 0 else None]
- dist.broadcast_object_list(object_list, src=0)
- run_id = object_list[0]
-
- if params.local_rank != 0:
- wandb.init(
- project=params.wandb_project,
- id=run_id,
- resume="allow",
- settings=wandb.Settings(start_method="thread", _disable_stats=True),
- )
-
-
def main():
- params = parse_arguments()
+ params = parse_arguments() # hyperparameter and training config
+
+    # TODO: fix random seed
try:
- # Setup distributed training
- is_distributed = setup_distributed(params)
+ is_distributed = setup_distributed(params) # Setup distributed training
rank = params.local_rank if is_distributed else 0
world_size = params.world_size if is_distributed else 1
-
- # Setup device and logger
device = torch.device(f"cuda:{params.local_rank}" if is_distributed else "cuda")
- logger = setup_logger(rank, world_size)
+ logger = setup_logger(rank, world_size) # ddp logger
torch.set_float32_matmul_precision("high")
if params.on_cudnn_benchmark:
@@ -593,6 +400,16 @@ def main():
hidden_dim=params.hidden_dim,
coff=params.coff,
pre_kl=params.pre_kl,
+ use_kl_annealing=params.use_kl_annealing,
+ kl_annealing_mode=params.kl_annealing_mode,
+ cycle_period=params.cycle_period,
+ max_kl_weight=params.max_kl_weight,
+ min_kl_weight=params.min_kl_weight,
+ ratio=params.ratio,
+ use_capacity_increase=params.use_capacity_increase,
+ gamma=params.gamma,
+ max_capacity=params.max_capacity,
+ capacity_max_iter=params.capacity_max_iter,
).to(device)
if is_distributed:
@@ -614,9 +431,7 @@ def main():
# Training mode selection
if params.mode == "seq_train":
if rank == 0:
- logger.opt(colors=True).info(
- f"✅ Mode: sequential execution [progress 1 >> {params.num_ladders}]"
- )
+ logger.opt(colors=True).info(f"✅ Mode: sequential execution [progress 1 >> {params.num_ladders}]")
for i in range(1, params.num_ladders + 1):
if is_distributed:
@@ -624,11 +439,10 @@ def main():
dist.barrier()
# Update sequence number
+ params.train_seq = i
if is_distributed:
- params.train_seq = i
model.module.train_seq = i
else:
- params.train_seq = i
model.train_seq = i
if params.use_wandb:
@@ -675,14 +489,16 @@ def main():
elif params.mode == "indep_train":
logger.info(f"Current trainig progress >> {params.train_seq}")
if rank == 0:
- logger.opt(colors=True).info(
- f"✅ Mode: independent execution [progress {params.train_seq}]"
- )
+ logger.opt(colors=True).info(f"✅ Mode: independent execution [progress {params.train_seq}]")
if is_distributed:
torch.cuda.synchronize()
dist.barrier()
+ if params.use_wandb:
+ hash_str = os.urandom(8).hex().upper()
+ init_wandb(params, hash_str)
+
# Load checkpoint if needed
if params.train_seq >= 2:
prev_checkpoint = os.path.join(
@@ -719,7 +535,7 @@ def main():
dist.barrier()
elif params.mode == "traverse":
- logger.opt(colors=True).info(f"✅ Mode: traverse execution [progress 1 {params.num_ladders}]")
+        logger.opt(colors=True).info(f"✅ Mode: traverse execution [progress 1 >> {params.num_ladders}]")
try:
model, optimizer, scaler = load_checkpoint(
model=model,
diff --git a/src/utils.py b/src/utils.py
index c513405..716db70 100644
--- a/src/utils.py
+++ b/src/utils.py
@@ -1,10 +1,51 @@
import argparse
import time
+import os
from typing import Any
+import wandb
+import numpy as np
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+import torchvision
+import imageio.v3 as imageio
+from loguru import logger
+import torch.optim as optim
+import torch_optimizer as jettify_optim
+
from ddp_utils import setup_logger
+def init_wandb(params, hash):
+ if params.use_wandb:
+ if wandb.run is not None:
+ wandb.finish()
+
+ run_id = None
+ if params.local_rank == 0:
+ wandb.init(
+ project=params.wandb_project,
+ config=vars(params),
+ name=f"{params.dataset.upper()}_PROGRESS{params.train_seq}_{hash}",
+ settings=wandb.Settings(start_method="thread", _disable_stats=True),
+ )
+ run_id = wandb.run.id
+
+ if params.distributed:
+ object_list = [run_id if params.local_rank == 0 else None]
+ dist.broadcast_object_list(object_list, src=0)
+ run_id = object_list[0]
+
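+            # non-zero ranks attach to the broadcast run id so all ranks log into a single W&B run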
+ if params.local_rank != 0:
+ wandb.init(
+ project=params.wandb_project,
+ id=run_id,
+ resume="allow",
+ settings=wandb.Settings(start_method="thread", _disable_stats=True),
+ )
+
+
def exec_time(func):
"""Decorates a function to measure its execution time in hours and minutes."""
@@ -65,3 +106,131 @@ def add_dataclass_args(parser: argparse.ArgumentParser, dataclass_type: Any):
default=field_info.default,
help=f"Set {field_info.name} to a value of type {field_info.type.__name__}",
)
+
+
+def load_checkpoint(model, optimizer, scaler, checkpoint_path, device, logger):
+ """Load a model checkpoint with proper device management."""
+ try:
+ checkpoint = torch.load(
+ checkpoint_path,
+ map_location=device,
+ weights_only=True,
+ )
+
+ # Load model state dict
+ if hasattr(model, "module"):
+ model.module.load_state_dict(checkpoint["model_state_dict"])
+ else:
+ model.load_state_dict(checkpoint["model_state_dict"], strict=False)
+
+ # Load optimizer state dict
+ for state in optimizer.state.values():
+ for k, v in state.items():
+ if isinstance(v, torch.Tensor):
+ state[k] = v.to(device)
+
+ optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
+
+ if scaler is not None and "scaler_state_dict" in checkpoint:
+ scaler.load_state_dict(checkpoint["scaler_state_dict"])
+
+ logger.info(
+ f"Loaded checkpoint from '{checkpoint_path}' (Epoch: {checkpoint['epoch']}, Loss: {checkpoint['loss']:.4f})"
+ )
+
+ return model, optimizer, scaler
+ except Exception as e:
+ logger.error(f"Failed to load checkpoint: {str(e)}")
+ return model, optimizer, scaler
+
+
+def save_reconstruction(inputs, reconstructions, save_path):
+ """Save a grid of original and reconstructed images"""
+ batch_size = min(8, inputs.shape[0])
+ inputs = inputs[:batch_size].float()
+ reconstructions = reconstructions[:batch_size].float()
+ comparison = torch.cat([inputs[:batch_size], reconstructions[:batch_size]])
+
+ # Denormalize and convert to numpy
+ images = comparison.cpu().detach()
+ images = torch.clamp(images, 0, 1)
+ grid = torchvision.utils.make_grid(images, nrow=batch_size)
+ image = grid.permute(1, 2, 0).numpy()
+
+ os.makedirs(os.path.dirname(save_path), exist_ok=True)
+ imageio.imwrite(save_path, (image * 255).astype("uint8"))
+
+
+def save_input_image(inputs: torch.Tensor, save_dir: str, seq: int, size: int = 96) -> str:
+ input_path = os.path.join(save_dir, f"traverse_input_seq{seq}.png")
+ os.makedirs(save_dir, exist_ok=True)
+
+ input_img = inputs[0].cpu().float()
+ input_img = torch.clamp(input_img, 0, 1)
+
+ if input_img.shape[-1] != size:
+ input_img = F.interpolate(
+ input_img.unsqueeze(0),
+ size=size,
+ mode="bilinear",
+ align_corners=False,
+ ).squeeze(0)
+
+ if input_img.shape[0] == 1:
+ input_img = input_img.repeat(3, 1, 1)
+
+ input_array = (input_img.permute(1, 2, 0).numpy() * 255).astype(np.uint8)
+ imageio.imwrite(input_path, input_array)
+ return input_path
+
+
+def get_optimizer(model, params):
+ """Get the optimizer based on the parameter settings"""
+ optimizer_params = {
+ "params": model.parameters(),
+ "lr": params.learning_rate,
+ }
+
+ # Adam, Lamb, DiffGrad
+ extra_args_common = {
+ "betas": getattr(params, "betas", (0.9, 0.999)),
+ "eps": getattr(params, "eps", 1e-8),
+ "weight_decay": getattr(params, "weight_decay", 0),
+ }
+
+ extra_args_adamw = {
+ "betas": getattr(params, "betas", (0.9, 0.999)),
+ "eps": getattr(params, "eps", 1e-8),
+ "weight_decay": getattr(params, "weight_decay", 0.01),
+ }
+
+ # SGD
+ extra_args_sgd = {
+ "momentum": getattr(params, "momentum", 0),
+ "dampening": getattr(params, "dampening", 0),
+ "weight_decay": getattr(params, "weight_decay", 0),
+ "nesterov": getattr(params, "nesterov", False),
+ }
+
+ # MADGRAD
+ extra_args_madgrad = {
+ "momentum": getattr(params, "momentum", 0.9),
+ "weight_decay": getattr(params, "weight_decay", 0),
+ "eps": getattr(params, "eps", 1e-6),
+ }
+
+ optimizers = {
+ "adam": (optim.Adam, extra_args_common),
+ "adamw": (optim.AdamW, extra_args_adamw),
+ "sgd": (optim.SGD, extra_args_sgd),
+ "lamb": (jettify_optim.Lamb, extra_args_common),
+ "diffgrad": (jettify_optim.DiffGrad, extra_args_common),
+ "madgrad": (jettify_optim.MADGRAD, extra_args_madgrad),
+ }
+
+ optimizer_cls, extra_args = optimizers.get(params.optim.lower(), (optim.Adam, extra_args_common))
+ if params.optim.lower() not in optimizers:
+ logger.warning(f"Unsupported optimizer '{params.optim}', using 'Adam' optimizer instead.")
+ optimizer = optimizer_cls(**optimizer_params, **extra_args)
+
+ return optimizer