diff --git a/src/dataset.py b/src/dataset.py index 904e6d4..5ca8d26 100644 --- a/src/dataset.py +++ b/src/dataset.py @@ -13,6 +13,7 @@ from requests.exceptions import ConnectionError, HTTPError, RequestException, Timeout from torch.utils.data import DataLoader, Dataset from tqdm import tqdm +from torch.utils.data.distributed import DistributedSampler @dataclass @@ -46,6 +47,69 @@ def get_config(cls, dataset_name: str) -> "DatasetConfig": return configs[dataset_name] +def get_dataset(params, logger): + """Load dataset with distributed support""" + dataset_classes = { + "mnist": MNIST, + "fashionmnist": FashionMNIST, + "shapes3d": Shapes3D, + "dsprites": DSprites, + "celeba": CelebA, + "flowers102": Flowers102, + "dtd": DTD, + "imagenet": ImageNet, + "mpi3d": MPI3D, + "ident3d": Ident3D, + } + + if params.dataset not in dataset_classes: + raise ValueError(f"Unknown dataset: {params.dataset}") + + dataset_class = dataset_classes[params.dataset] + + try: + if params.dataset == "mpi3d": + variant = getattr(params, "mpi3d_variant", "toy") + dataset = dataset_class(root=params.data_path, batch_size=params.batch_size, num_workers=4, variant=variant) + else: + dataset = dataset_class(root=params.data_path, batch_size=params.batch_size, num_workers=4) + + config = dataset.get_config() + params.chn_num = config.chn_num + params.image_size = config.image_size + + train_loader, test_loader = dataset.get_data_loader() + if params.distributed: + train_sampler = DistributedSampler( + train_loader.dataset, + num_replicas=params.world_size, + rank=params.local_rank, + shuffle=True, + drop_last=True, + ) + + train_loader = torch.utils.data.DataLoader( + train_loader.dataset, + batch_size=params.batch_size, + sampler=train_sampler, + num_workers=params.num_workers, + pin_memory=True, + drop_last=True, + persistent_workers=True, + ) + + if params.local_rank == 0: + logger.info(f"Dataset {params.dataset} loaded with distributed sampler") + else: + logger.info(f"Dataset {params.dataset} loaded") + + return train_loader, test_loader + + except Exception as e: + logger.error(f"Failed to load dataset: {str(e)}") + raise + + def download_file(url, filename): """Download file with progress bar""" try: diff --git a/src/provlae.py b/src/provlae.py index c0abec5..ff00078 100644 --- a/src/provlae.py +++ b/src/provlae.py @@ -1,4 +1,6 @@ from math import ceil, log2 +import numpy as np +import math import torch import torch.nn as nn @@ -21,6 +23,16 @@ def __init__( pre_kl=True, coff=0.5, train_seq=1, + use_kl_annealing=False, + kl_annealing_mode="linear", + cycle_period=4, + max_kl_weight=1, + min_kl_weight=0.1, + ratio=1.0, + use_capacity_increase=False, + gamma=1000.0, + max_capacity=25, + capacity_max_iter=1e-5, ): super(ProVLAE, self).__init__() @@ -44,6 +56,22 @@ def __init__( self.fade_in_duration = fade_in_duration self.train_seq = min(train_seq, self.num_ladders) + # for kl annealing + self.use_kl_annealing = use_kl_annealing + self.kl_annealing_mode = kl_annealing_mode + self.current_epoch = None + self.num_epochs = None + self.cycle_period = cycle_period + self.max_kl_weight = max_kl_weight + self.min_kl_weight = min_kl_weight + self.ratio = ratio + + # Improving disentangling in β-VAE with controlled capacity increase + self.use_capacity_increase = use_capacity_increase + self.gamma = gamma + self.C_max = torch.Tensor([max_capacity]) + self.C_stop_iter = capacity_max_iter + # Calculate encoder sizes self.encoder_sizes = [self.target_size] current_size = self.target_size @@ -56,8 +84,7 @@ def __init__( self.hidden_dims.extend([self.hidden_dims[-1]] * (self.num_ladders - len(self.hidden_dims))) self.hidden_dims = self.hidden_dims[: self.num_ladders] - # Base setup - self.activation = nn.ELU() # or LeakyReLU + self.activation = nn.LeakyReLU() # or ELU self.q_dist = Normal self.x_dist = Bernoulli self.prior_params = nn.Parameter(torch.zeros(self.z_dim, 2)) @@ -170,16 +197,70 @@ def _sample_latent(self, z_params): return z_mean + eps * std, z_mean, z_log_var def _kl_divergence(self, z_mean, z_log_var): - return -0.5 * torch.sum(1 + z_log_var - z_mean.pow(2) - z_log_var.exp()) + return -0.5 * torch.mean(1 + z_log_var - z_mean.pow(2) - z_log_var.exp()) def fade_in_alpha(self, step): if step > self.fade_in_duration: return 1.0 return step / self.fade_in_duration + def frange_cycle_linear(self, start, stop, n_epoch, n_cycle=4, ratio=0.5): + L = np.ones(n_epoch) + period = n_epoch / n_cycle + step = (stop - start) / (period / ratio) + + for c in range(n_cycle): + v, i = start, 0 + while v <= stop and int(i + c * period) < n_epoch: + L[int(i + c * period)] = v + v += step + i += 1 + return L + + def frange_cycle_sigmoid(self, start, stop, n_epoch, n_cycle=4, ratio=0.5): + L = np.ones(n_epoch) + period = n_epoch / n_cycle + step = (stop - start) / (period * ratio) # step is in [0,1] + + # transform into [-6, 6] for plots: v*12.-6. + + for c in range(n_cycle): + + v, i = start, 0 + while v <= stop: + L[int(i + c * period)] = 1.0 / (1.0 + np.exp(-(v * 12.0 - 6.0))) + v += step + i += 1 + return L + + def frange_cycle_cosine(self, start, stop, n_epoch, n_cycle=4, ratio=0.5): + L = np.ones(n_epoch) + period = n_epoch / n_cycle + step = (stop - start) / (period * ratio) # step is in [0,1] + + # transform into [0, pi] for plots: + + for c in range(n_cycle): + + v, i = start, 0 + while v <= stop: + L[int(i + c * period)] = 0.5 - 0.5 * math.cos(v * math.pi) + v += step + i += 1 + return L + + def cycle_kl_weights(self, epoch, n_epoch, cycle_period=4, max_kl_weight=1.0, min_kl_weight=0.1, ratio=0.5): + if self.kl_annealing_mode == "linear": + kl_weights = self.frange_cycle_linear(min_kl_weight, max_kl_weight, n_epoch, cycle_period, ratio) + if self.kl_annealing_mode == "sigmoid": + kl_weights = self.frange_cycle_sigmoid(min_kl_weight, max_kl_weight, n_epoch, cycle_period, ratio) + if self.kl_annealing_mode == "cosine": + kl_weights = self.frange_cycle_cosine(min_kl_weight, max_kl_weight, n_epoch, cycle_period, ratio) + + return kl_weights[epoch] + def encode(self, x): - # Store original size - original_size = x.size()[-2:] + original_size = x.size()[-2:] # Store original size # Resize to target size if original_size != (self.target_size, self.target_size): @@ -222,10 +303,8 @@ def decode(self, z_list, original_size): f = f * self.fade_in features.append(f) - # Start from deepest layer - x = features[-1] - # Progressive decoding with explicit size management + x = features[-1] # Start from deepest layer for i in range(self.num_ladders - 2, -1, -1): # Ensure feature maps have matching spatial dimensions target_size = features[i].size(-1) @@ -241,8 +320,7 @@ def decode(self, z_list, original_size): for up_layer in self.additional_ups: x = up_layer(x) - # Final convolution - x = self.output_layer(x) + x = self.output_layer(x) # Final convolution # Resize to original input size if original_size != (x.size(-2), x.size(-1)): @@ -252,11 +330,17 @@ def decode(self, z_list, original_size): def forward(self, x, step=0): self.fade_in = self.fade_in_alpha(step) + kl_weight = self.cycle_kl_weights( + epoch=self.current_epoch, + n_epoch=self.num_epochs, + cycle_period=self.cycle_period, + max_kl_weight=self.max_kl_weight, + min_kl_weight=self.min_kl_weight, + ratio=self.ratio, + ) - # Encode z_params, original_size = self.encode(x) - # Calculate KL divergence latent_losses = [] zs = [] for z, z_mean, z_log_var in z_params: @@ -264,23 +348,32 @@ def forward(self, x, step=0): zs.append(z) latent_loss = sum(latent_losses) - - # Decode x_recon = self.decode(zs, original_size) # Reconstruction loss - bce_loss = nn.BCEWithLogitsLoss(reduction="sum") + bce_loss = nn.BCEWithLogitsLoss(reduction="mean") recon_loss = bce_loss(x_recon, x) - # Calculate final loss - if self.pre_kl: - active_latents = latent_losses[self.train_seq - 1 :] - inactive_latents = latent_losses[: self.train_seq - 1] - loss = recon_loss + self.beta * sum(active_latents) + self.coff * sum(inactive_latents) + # prekl loss + active_latents = latent_losses[self.train_seq - 1 :] + inactive_latents = latent_losses[: self.train_seq - 1] + if self.use_kl_annealing: + if self.use_capacity_increase: + # https://arxiv.org/pdf/1804.03599.pdf + self.C_max = self.C_max.to(x.device) + C = torch.clamp(self.C_max / self.C_stop_iter * step, 0, self.C_max.data[0]) + kl_term = self.gamma * kl_weight * (sum(active_latents) - C).abs() + else: + # https://openreview.net/forum?id=Sy2fzU9gl + kl_term = kl_weight * self.beta * sum(active_latents) else: - loss = recon_loss + self.beta * latent_loss + kl_term = self.beta * sum(active_latents) + + loss = recon_loss + kl_term + if self.pre_kl: + loss += self.coff * sum(inactive_latents) - return torch.sigmoid(x_recon), loss, latent_loss, recon_loss + return torch.sigmoid(x_recon), loss, kl_term, recon_loss, kl_weight def inference(self, x): with torch.no_grad(): diff --git a/src/scripts/run_dsprites.sh b/src/scripts/run_dsprites.sh index 027eedb..33cee2e 100644 --- a/src/scripts/run_dsprites.sh +++ b/src/scripts/run_dsprites.sh @@ -9,7 +9,7 @@ torchrun --nproc_per_node=2 --master_port=29502 src/train.py \ --batch_size 256 \ --num_epochs 30 \ --learning_rate 5e-4 \ - --beta 8 \ + --beta 3 \ --z_dim 2 \ --coff 0.5 \ --pre_kl \ diff --git a/src/scripts/run_imagenet.sh b/src/scripts/run_imagenet.sh index cf11d39..d1bf44c 100644 --- a/src/scripts/run_imagenet.sh +++ b/src/scripts/run_imagenet.sh @@ -6,12 +6,12 @@ torchrun --nproc_per_node=2 --master_port=29501 src/train.py \ --dataset imagenet \ --optim adamw \ --num_ladders 4 \ - --batch_size 256 \ - --num_epochs 100 \ + --batch_size 128 \ + --num_epochs 30 \ --learning_rate 5e-4 \ --beta 1 \ - --z_dim 4 \ - --coff 0.5 \ + --z_dim 8 \ + --coff 0.1 \ --pre_kl \ --hidden_dim 64 \ --fade_in_duration 5000 \ diff --git a/src/scripts/run_mnist.sh b/src/scripts/run_mnist.sh index bd651ac..df9c01c 100644 --- a/src/scripts/run_mnist.sh +++ b/src/scripts/run_mnist.sh @@ -2,21 +2,29 @@ torchrun --nproc_per_node=2 --master_port=29501 src/train.py \ --distributed \ - --mode seq_train \ - --dataset mnist \ + --mode indep_train \ + --train_seq 3 \ + --dataset shapes3d \ --optim adamw \ --num_ladders 3 \ - --batch_size 64 \ - --num_epochs 5 \ + --batch_size 128 \ + --num_epochs 16 \ --learning_rate 5e-4 \ - --beta 3 \ - --z_dim 2 \ + --beta 1 \ + --z_dim 3 \ --coff 0.5 \ --pre_kl \ - --hidden_dim 64 \ + --hidden_dim 32 \ --fade_in_duration 5000 \ - --output_dir ./output/mnist/ \ - --data_path ./data/mnist/ \ + --output_dir ./output/shapes3d/ \ + --data_path ./data/ \ --use_wandb \ --wandb_project PRO-VLAE \ + --use_kl_annealing \ + --kl_annealing_mode sigmoid \ + --cycle_period 4 \ + --ratio 0.5 \ + --max_kl_weight 1.0 \ + --min_kl_weight 0.1 \ + --num_workers 16 \ No newline at end of file diff --git a/src/train.py b/src/train.py index 39d506a..1a06a08 100644 --- a/src/train.py +++ b/src/train.py @@ -1,27 +1,31 @@ import argparse import os -import sys from dataclasses import dataclass, field +import math import imageio.v3 as imageio import numpy as np import torch import torch.distributed as dist import torch.nn.functional as F -import torch.optim as optim -import torch_optimizer as jettify_optim import torchvision import wandb -from loguru import logger from PIL import Image, ImageDraw, ImageFont from torch.nn.parallel import DistributedDataParallel as DDP -from torch.utils.data.distributed import DistributedSampler from tqdm import tqdm -from dataset import DTD, MNIST, MPI3D, CelebA, DSprites, FashionMNIST, Flowers102, Ident3D, ImageNet, Shapes3D -from ddp_utils import cleanup_distributed, setup_distributed, setup_logger from provlae import ProVLAE -from utils import add_dataclass_args, exec_time +from dataset import get_dataset +from ddp_utils import cleanup_distributed, setup_distributed, setup_logger +from utils import ( + init_wandb, + get_optimizer, + add_dataclass_args, + exec_time, + save_input_image, + save_reconstruction, + load_checkpoint, +) @dataclass @@ -50,6 +54,16 @@ class HyperParameters: hidden_dim: int = field(default=32) coff: float = field(default=0.5) pre_kl: bool = field(default=True) + use_kl_annealing: bool = field(default=False) + kl_annealing_mode: str = field(default="linear") + cycle_period: int = field(default=4) + max_kl_weight: float = field(default=1.0) + min_kl_weight: float = field(default=0.1) + ratio: float = field(default=1.0) + use_capacity_increase: bool = field(default=False) + gamma: float = field(default=1000.0) + max_capacity: int = field(default=25) + capacity_max_iter: float = field(default=1e-5) @dataclass @@ -91,145 +105,6 @@ def parse_arguments(): return parser.parse_args() -def get_dataset(params, logger): - """Load dataset with distributed support""" - dataset_classes = { - "mnist": MNIST, - "fashionmnist": FashionMNIST, - "shapes3d": Shapes3D, - "dsprites": DSprites, - "celeba": CelebA, - "flowers102": Flowers102, - "dtd": DTD, - "imagenet": ImageNet, - "mpi3d": MPI3D, - "ident3d": Ident3D, - } - - if params.dataset not in dataset_classes: - raise ValueError(f"Unknown dataset: {params.dataset}") - - dataset_class = dataset_classes[params.dataset] - - try: - if params.dataset == "mpi3d": - variant = getattr(params, "mpi3d_variant", "toy") - dataset = dataset_class(root=params.data_path, batch_size=params.batch_size, num_workers=4, variant=variant) - else: - dataset = dataset_class(root=params.data_path, batch_size=params.batch_size, num_workers=4) - - config = dataset.get_config() - params.chn_num = config.chn_num - params.image_size = config.image_size - - train_loader, test_loader = dataset.get_data_loader() - if params.distributed: - train_sampler = DistributedSampler( - train_loader.dataset, - num_replicas=params.world_size, - rank=params.local_rank, - shuffle=True, - drop_last=True, - ) - - train_loader = torch.utils.data.DataLoader( - train_loader.dataset, - batch_size=params.batch_size, - sampler=train_sampler, - num_workers=params.num_workers, - pin_memory=True, - drop_last=True, - persistent_workers=True, - ) - - if params.local_rank == 0: - logger.info(f"Dataset {params.dataset} loaded with distributed sampler") - else: - logger.info(f"Dataset {params.dataset} loaded") - - return train_loader, test_loader - - except Exception as e: - logger.error(f"Failed to load dataset: {str(e)}") - raise - - -def load_checkpoint(model, optimizer, scaler, checkpoint_path, device, logger): - """Load a model checkpoint with proper device management.""" - try: - checkpoint = torch.load( - checkpoint_path, - map_location=device, - weights_only=True, - ) - - # Load model state dict - if hasattr(model, "module"): - model.module.load_state_dict(checkpoint["model_state_dict"]) - else: - model.load_state_dict(checkpoint["model_state_dict"], strict=False) - - # Load optimizer state dict - for state in optimizer.state.values(): - for k, v in state.items(): - if isinstance(v, torch.Tensor): - state[k] = v.to(device) - - optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) - - if scaler is not None and "scaler_state_dict" in checkpoint: - scaler.load_state_dict(checkpoint["scaler_state_dict"]) - - logger.info( - f"Loaded checkpoint from '{checkpoint_path}' (Epoch: {checkpoint['epoch']}, Loss: {checkpoint['loss']:.4f})" - ) - - return model, optimizer, scaler - except Exception as e: - logger.error(f"Failed to load checkpoint: {str(e)}") - return model, optimizer, scaler - - -def save_reconstruction(inputs, reconstructions, save_path): - """Save a grid of original and reconstructed images""" - batch_size = min(8, inputs.shape[0]) - inputs = inputs[:batch_size].float() - reconstructions = reconstructions[:batch_size].float() - comparison = torch.cat([inputs[:batch_size], reconstructions[:batch_size]]) - - # Denormalize and convert to numpy - images = comparison.cpu().detach() - images = torch.clamp(images, 0, 1) - grid = torchvision.utils.make_grid(images, nrow=batch_size) - image = grid.permute(1, 2, 0).numpy() - - os.makedirs(os.path.dirname(save_path), exist_ok=True) - imageio.imwrite(save_path, (image * 255).astype("uint8")) - - -def save_input_image(inputs: torch.Tensor, save_dir: str, seq: int, size: int = 96) -> str: - input_path = os.path.join(save_dir, f"traverse_input_seq{seq}.png") - os.makedirs(save_dir, exist_ok=True) - - input_img = inputs[0].cpu().float() - input_img = torch.clamp(input_img, 0, 1) - - if input_img.shape[-1] != size: - input_img = F.interpolate( - input_img.unsqueeze(0), - size=size, - mode="bilinear", - align_corners=False, - ).squeeze(0) - - if input_img.shape[0] == 1: - input_img = input_img.repeat(3, 1, 1) - - input_array = (input_img.permute(1, 2, 0).numpy() * 255).astype(np.uint8) - imageio.imwrite(input_path, input_array) - return input_path - - def create_latent_traversal(model, data_loader, save_path, device, params): """Create and save organized latent traversal GIF with optimized layout""" model.eval() @@ -242,13 +117,16 @@ def create_latent_traversal(model, data_loader, save_path, device, params): inputs, _ = next(iter(data_loader)) # Get a single batch of images inputs = inputs[0:1].to(device) - input_path = save_input_image(inputs.cpu(), os.path.join(params.output_dir, params.input_dir), params.train_seq) + # save traverse inputs + input_path = save_input_image( + inputs.cpu(), os.path.join(params.output_dir, params.input_dir), params.train_seq, params.image_size + ) # Get latent representations with torch.amp.autocast(device_type="cuda", enabled=False): latent_vars = [z[0] for z in model.inference(inputs)] - traverse_range = torch.linspace(-1.5, 1.5, 15).to(device) + traverse_range = torch.linspace(-2.5, 2.5, 10).to(device) # Image layout parameters img_size = 96 # Base image size @@ -356,6 +234,7 @@ def train_model(model, data_loader, optimizer, params, device, logger, scaler=No if hasattr(model, "module"): model.module.to(device) + model.module.num_epochs = params.num_epochs else: model.to(device) @@ -364,6 +243,11 @@ def train_model(model, data_loader, optimizer, params, device, logger, scaler=No logger.info(f"Start training [progress {params.train_seq}]") for epoch in range(params.num_epochs): + if hasattr(model, "module"): + model.module.current_epoch = epoch + else: + model.current_epoch = epoch + if params.distributed: data_loader.sampler.set_epoch(epoch) @@ -378,7 +262,7 @@ def train_model(model, data_loader, optimizer, params, device, logger, scaler=No inputs = inputs.to(device, non_blocking=True) with torch.amp.autocast(device_type="cuda", dtype=autocast_dtype): - x_recon, loss, latent_loss, recon_loss = model(inputs, step=global_step) + x_recon, loss, latent_loss, recon_loss, kl_weight = model(inputs, step=global_step) optimizer.zero_grad() if scaler is not None: @@ -391,16 +275,17 @@ def train_model(model, data_loader, optimizer, params, device, logger, scaler=No if params.local_rank == 0 or not params.distributed: # Only show progress on main process pbar.set_postfix( - total_loss=f"{loss.item():.2f}", - latent_loss=f"{latent_loss:.2f}", - recon_loss=f"{recon_loss:.2f}", + total_loss=f"{loss.item():.5f}", + latent_loss=f"{latent_loss.item():.5f}", + recon_loss=f"{recon_loss.item():.5f}", ) if params.use_wandb and params.distributed: metrics = { - f"loss/rank_{params.local_rank}": loss.item(), - f"latent_loss/rank_{params.local_rank}": latent_loss.item(), - f"recon_loss/rank_{params.local_rank}": recon_loss.item(), + f"ELBO/rank_{params.local_rank}": loss.item(), + f"KL Term/rank_{params.local_rank}": latent_loss.item(), + f"Reconstruction Error/rank_{params.local_rank}": recon_loss.item(), + f"KL Weight/rank_{params.local_rank}": kl_weight.item(), } all_metrics = [None] * params.world_size @@ -412,7 +297,12 @@ def train_model(model, data_loader, optimizer, params, device, logger, scaler=No combined_metrics.update(rank_metrics) wandb.log(combined_metrics, step=global_step) elif params.use_wandb: - metrics = {"loss": loss.item(), "latent_loss": latent_loss.item(), "recon_loss": recon_loss.item()} + metrics = { + "ELBO": loss.item(), + "KL Term": latent_loss.item(), + "Reconstruction Error": recon_loss.item(), + "KL Weight": kl_weight.item(), + } wandb.log(metrics, step=global_step) global_step += 1 @@ -432,7 +322,7 @@ def train_model(model, data_loader, optimizer, params, device, logger, scaler=No save_reconstruction(inputs, x_recon, recon_path) input_path = create_latent_traversal(model, data_loader, traverse_path, device, params) - # reconstruction and traversal images + # reconstruction and traversal images (Media) if params.use_wandb and (params.local_rank == 0 or not params.distributed): wandb.log( { @@ -458,100 +348,17 @@ def train_model(model, data_loader, optimizer, params, device, logger, scaler=No logger.info(f"Epoch: [{epoch+1}/{params.num_epochs}], Loss: {loss.item():.2f}") -def get_optimizer(model, params): - """Get the optimizer based on the parameter settings""" - optimizer_params = { - "params": model.parameters(), - "lr": params.learning_rate, - } - - # Adam, Lamb, DiffGrad - extra_args_common = { - "betas": getattr(params, "betas", (0.9, 0.999)), - "eps": getattr(params, "eps", 1e-8), - "weight_decay": getattr(params, "weight_decay", 0), - } - - extra_args_adamw = { - "betas": getattr(params, "betas", (0.9, 0.999)), - "eps": getattr(params, "eps", 1e-8), - "weight_decay": getattr(params, "weight_decay", 0.01), - } - - # SGD - extra_args_sgd = { - "momentum": getattr(params, "momentum", 0), - "dampening": getattr(params, "dampening", 0), - "weight_decay": getattr(params, "weight_decay", 0), - "nesterov": getattr(params, "nesterov", False), - } - - # MADGRAD - extra_args_madgrad = { - "momentum": getattr(params, "momentum", 0.9), - "weight_decay": getattr(params, "weight_decay", 0), - "eps": getattr(params, "eps", 1e-6), - } - - optimizers = { - "adam": (optim.Adam, extra_args_common), - "adamw": (optim.AdamW, extra_args_adamw), - "sgd": (optim.SGD, extra_args_sgd), - "lamb": (jettify_optim.Lamb, extra_args_common), - "diffgrad": (jettify_optim.DiffGrad, extra_args_common), - "madgrad": (jettify_optim.MADGRAD, extra_args_madgrad), - } - - optimizer_cls, extra_args = optimizers.get(params.optim.lower(), (optim.Adam, extra_args_common)) - if params.optim.lower() not in optimizers: - logger.warning(f"Unsupported optimizer '{params.optim}', using 'Adam' optimizer instead.") - optimizer = optimizer_cls(**optimizer_params, **extra_args) - - return optimizer - - -def init_wandb(params, hash): - if params.use_wandb: - if wandb.run is not None: - wandb.finish() - - run_id = None - if params.local_rank == 0: - logger.debug(f"Current run ID: {hash}") - wandb.init( - project=params.wandb_project, - config=vars(params), - name=f"{params.dataset.upper()}_PROGRESS{params.train_seq}_{hash}", - settings=wandb.Settings(start_method="thread", _disable_stats=True), - ) - run_id = wandb.run.id - - if params.distributed: - object_list = [run_id if params.local_rank == 0 else None] - dist.broadcast_object_list(object_list, src=0) - run_id = object_list[0] - - if params.local_rank != 0: - wandb.init( - project=params.wandb_project, - id=run_id, - resume="allow", - settings=wandb.Settings(start_method="thread", _disable_stats=True), - ) - - def main(): - params = parse_arguments() + params = parse_arguments() # hyperparameter and training config + + """TODO: fix random seed""" try: - # Setup distributed training - is_distributed = setup_distributed(params) + is_distributed = setup_distributed(params) # Setup distributed training rank = params.local_rank if is_distributed else 0 world_size = params.world_size if is_distributed else 1 - - # Setup device and logger device = torch.device(f"cuda:{params.local_rank}" if is_distributed else "cuda") - logger = setup_logger(rank, world_size) + logger = setup_logger(rank, world_size) # ddp logger torch.set_float32_matmul_precision("high") if params.on_cudnn_benchmark: @@ -593,6 +400,16 @@ def main(): hidden_dim=params.hidden_dim, coff=params.coff, pre_kl=params.pre_kl, + use_kl_annealing=params.use_kl_annealing, + kl_annealing_mode=params.kl_annealing_mode, + cycle_period=params.cycle_period, + max_kl_weight=params.max_kl_weight, + min_kl_weight=params.min_kl_weight, + ratio=params.ratio, + use_capacity_increase=params.use_capacity_increase, + gamma=params.gamma, + max_capacity=params.max_capacity, + capacity_max_iter=params.capacity_max_iter, ).to(device) if is_distributed: @@ -614,9 +431,7 @@ def main(): # Training mode selection if params.mode == "seq_train": if rank == 0: - logger.opt(colors=True).info( - f"✅ Mode: sequential execution [progress 1 >> {params.num_ladders}]" - ) + logger.opt(colors=True).info(f"✅ Mode: sequential execution [progress 1 >> {params.num_ladders}]") for i in range(1, params.num_ladders + 1): if is_distributed: @@ -624,11 +439,10 @@ def main(): dist.barrier() # Update sequence number + params.train_seq = i if is_distributed: - params.train_seq = i model.module.train_seq = i else: - params.train_seq = i model.train_seq = i if params.use_wandb: @@ -675,14 +489,16 @@ def main(): elif params.mode == "indep_train": logger.info(f"Current trainig progress >> {params.train_seq}") if rank == 0: - logger.opt(colors=True).info( - f"✅ Mode: independent execution [progress {params.train_seq}]" - ) + logger.opt(colors=True).info(f"✅ Mode: independent execution [progress {params.train_seq}]") if is_distributed: torch.cuda.synchronize() dist.barrier() + if params.use_wandb: + hash_str = os.urandom(8).hex().upper() + init_wandb(params, hash_str) + # Load checkpoint if needed if params.train_seq >= 2: prev_checkpoint = os.path.join( @@ -719,7 +535,7 @@ def main(): dist.barrier() elif params.mode == "traverse": - logger.opt(colors=True).info(f"✅ Mode: traverse execution [progress 1 {params.num_ladders}]") + logger.opt(colors=True).info(f"✅ Mode: traverse execution [progress 1 {params.num_ladders}]") try: model, optimizer, scaler = load_checkpoint( model=model, diff --git a/src/utils.py b/src/utils.py index c513405..716db70 100644 --- a/src/utils.py +++ b/src/utils.py @@ -1,10 +1,51 @@ import argparse import time +import os from typing import Any +import wandb +import numpy as np +import torch +import torch.distributed as dist +import torch.nn.functional as F +import torchvision +import imageio +from loguru import logger +import torch.optim as optim +import torch_optimizer as jettify_optim + from ddp_utils import setup_logger +def init_wandb(params, hash): + if params.use_wandb: + if wandb.run is not None: + wandb.finish() + + run_id = None + if params.local_rank == 0: + wandb.init( + project=params.wandb_project, + config=vars(params), + name=f"{params.dataset.upper()}_PROGRESS{params.train_seq}_{hash}", + settings=wandb.Settings(start_method="thread", _disable_stats=True), + ) + run_id = wandb.run.id + + if params.distributed: + object_list = [run_id if params.local_rank == 0 else None] + dist.broadcast_object_list(object_list, src=0) + run_id = object_list[0] + + if params.local_rank != 0: + wandb.init( + project=params.wandb_project, + id=run_id, + resume="allow", + settings=wandb.Settings(start_method="thread", _disable_stats=True), + ) + + def exec_time(func): """Decorates a function to measure its execution time in hours and minutes.""" @@ -65,3 +106,131 @@ def add_dataclass_args(parser: argparse.ArgumentParser, dataclass_type: Any): default=field_info.default, help=f"Set {field_info.name} to a value of type {field_info.type.__name__}", ) + + +def load_checkpoint(model, optimizer, scaler, checkpoint_path, device, logger): + """Load a model checkpoint with proper device management.""" + try: + checkpoint = torch.load( + checkpoint_path, + map_location=device, + weights_only=True, + ) + + # Load model state dict + if hasattr(model, "module"): + model.module.load_state_dict(checkpoint["model_state_dict"]) + else: + model.load_state_dict(checkpoint["model_state_dict"], strict=False) + + # Load optimizer state dict + for state in optimizer.state.values(): + for k, v in state.items(): + if isinstance(v, torch.Tensor): + state[k] = v.to(device) + + optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) + + if scaler is not None and "scaler_state_dict" in checkpoint: + scaler.load_state_dict(checkpoint["scaler_state_dict"]) + + logger.info( + f"Loaded checkpoint from '{checkpoint_path}' (Epoch: {checkpoint['epoch']}, Loss: {checkpoint['loss']:.4f})" + ) + + return model, optimizer, scaler + except Exception as e: + logger.error(f"Failed to load checkpoint: {str(e)}") + return model, optimizer, scaler + + +def save_reconstruction(inputs, reconstructions, save_path): + """Save a grid of original and reconstructed images""" + batch_size = min(8, inputs.shape[0]) + inputs = inputs[:batch_size].float() + reconstructions = reconstructions[:batch_size].float() + comparison = torch.cat([inputs[:batch_size], reconstructions[:batch_size]]) + + # Denormalize and convert to numpy + images = comparison.cpu().detach() + images = torch.clamp(images, 0, 1) + grid = torchvision.utils.make_grid(images, nrow=batch_size) + image = grid.permute(1, 2, 0).numpy() + + os.makedirs(os.path.dirname(save_path), exist_ok=True) + imageio.imwrite(save_path, (image * 255).astype("uint8")) + + +def save_input_image(inputs: torch.Tensor, save_dir: str, seq: int, size: int = 96) -> str: + input_path = os.path.join(save_dir, f"traverse_input_seq{seq}.png") + os.makedirs(save_dir, exist_ok=True) + + input_img = inputs[0].cpu().float() + input_img = torch.clamp(input_img, 0, 1) + + if input_img.shape[-1] != size: + input_img = F.interpolate( + input_img.unsqueeze(0), + size=size, + mode="bilinear", + align_corners=False, + ).squeeze(0) + + if input_img.shape[0] == 1: + input_img = input_img.repeat(3, 1, 1) + + input_array = (input_img.permute(1, 2, 0).numpy() * 255).astype(np.uint8) + imageio.imwrite(input_path, input_array) + return input_path + + +def get_optimizer(model, params): + """Get the optimizer based on the parameter settings""" + optimizer_params = { + "params": model.parameters(), + "lr": params.learning_rate, + } + + # Adam, Lamb, DiffGrad + extra_args_common = { + "betas": getattr(params, "betas", (0.9, 0.999)), + "eps": getattr(params, "eps", 1e-8), + "weight_decay": getattr(params, "weight_decay", 0), + } + + extra_args_adamw = { + "betas": getattr(params, "betas", (0.9, 0.999)), + "eps": getattr(params, "eps", 1e-8), + "weight_decay": getattr(params, "weight_decay", 0.01), + } + + # SGD + extra_args_sgd = { + "momentum": getattr(params, "momentum", 0), + "dampening": getattr(params, "dampening", 0), + "weight_decay": getattr(params, "weight_decay", 0), + "nesterov": getattr(params, "nesterov", False), + } + + # MADGRAD + extra_args_madgrad = { + "momentum": getattr(params, "momentum", 0.9), + "weight_decay": getattr(params, "weight_decay", 0), + "eps": getattr(params, "eps", 1e-6), + } + + optimizers = { + "adam": (optim.Adam, extra_args_common), + "adamw": (optim.AdamW, extra_args_adamw), + "sgd": (optim.SGD, extra_args_sgd), + "lamb": (jettify_optim.Lamb, extra_args_common), + "diffgrad": (jettify_optim.DiffGrad, extra_args_common), + "madgrad": (jettify_optim.MADGRAD, extra_args_madgrad), + } + + optimizer_cls, extra_args = optimizers.get(params.optim.lower(), (optim.Adam, extra_args_common)) + if params.optim.lower() not in optimizers: + logger.warning(f"Unsupported optimizer '{params.optim}', using 'Adam' optimizer instead.") + optimizer = optimizer_cls(**optimizer_params, **extra_args) + + return optimizer