diff --git a/README.md b/README.md
index 59b7e28..3d14fc9 100644
--- a/README.md
+++ b/README.md
@@ -40,6 +40,9 @@ This implementation introduces dynamic size management for arbitrary input image
## Installation
We recommend using [mamba](https://mamba.readthedocs.io/en/latest/installation/mamba-installation.html) (via [miniforge](https://github.com/conda-forge/miniforge)) for faster installation of dependencies, but you can also use [conda](https://docs.anaconda.com/miniconda/miniconda-install/).
```bash
+git clone https://github.com/suzuki-2001/pytorch-proVLAE.git
+cd pytorch-proVLAE
+
mamba env create -f env.yaml # or conda
mamba activate torch-provlae
```
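+
+To verify the environment (optional; a quick sanity check — the second value reports CUDA availability):
+
+```bash
+python -c "import torch; print(torch.__version__, torch.cuda.is_available())"
+```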
@@ -48,51 +51,38 @@ mamba activate torch-provlae
## Usage
You can train pytorch-proVLAE with the following command. Sample hyperparameters and training configurations are provided in the [scripts directory](./src/scripts/).
-If you have a checkpoint file from a pythorch-proVLAE training, setting the mode argument to "visualize" allows you to inspect the latent traversal. Please ensure that the parameter settings match those used for the checkpoint file when running this mode.
+If you have a checkpoint file from a pytorch-proVLAE training run, setting the mode argument to "traverse" allows you to inspect the latent traversal. Please ensure that the parameter settings match those used for the checkpoint when running this mode.
```bash
-# all progressive training steps
-python train.py \
- --dataset shapes3d \
- --mode seq_train \
- --batch_size 100 \
- --num_epochs 15 \
- --learning_rate 5e-4 \
- --beta 15 \
- --z_dim 3 \
- --hidden_dim 64 \
- --fade_in_duration 5000 \
- --optim adamw \
- --image_size 64 \
- --chn_num 3 \
- --output_dir ./output/shapes3d/
-
# training with distributed data parallel
-torchrun --nproc_per_node=2 train_ddp.py \
- --distributed True \
+# tested on NVIDIA V100 PCIe (16GB + 32GB) and 2x NVIDIA A6000 (48GB)
+torchrun --nproc_per_node=2 --master_port=29501 src/train.py \
+ --distributed \
--mode seq_train \
- --dataset ident3d \
+ --dataset shapes3d \
+ --optim adamw \
--num_ladders 3 \
--batch_size 128 \
- --num_epochs 30 \
+ --num_epochs 15 \
--learning_rate 5e-4 \
- --beta 1 \
+ --beta 8 \
--z_dim 3 \
--coff 0.5 \
- --hidden_dim 64 \
+ --pre_kl \
+ --hidden_dim 32 \
--fade_in_duration 5000 \
- --output_dir ./output/ident3d/ \
- --optim adamw
+ --output_dir ./output/shapes3d/ \
+ --data_path ./data/shapes3d/
```
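+
+The same entry point also runs on a single GPU without `torchrun` (a minimal sketch mirroring the settings above; adjust to your hardware):
+
+```bash
+python src/train.py \
+ --mode seq_train \
+ --dataset shapes3d \
+ --optim adamw \
+ --num_ladders 3 \
+ --batch_size 100 \
+ --num_epochs 15 \
+ --learning_rate 5e-4 \
+ --beta 8 \
+ --z_dim 3 \
+ --hidden_dim 32 \
+ --fade_in_duration 5000 \
+ --output_dir ./output/shapes3d/ \
+ --data_path ./data/shapes3d/
+```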
+### Hyperparameters
+
| Argument | Default | Description |
|----------|---------------|-------------|
-| `dataset` | "shapes3d" | Dataset to use (mnist, fashionmnist, dsprites, shapes3d, mpi3d, ident3d,celeba, flowers102, dtd, imagenet) |
-| `data_path` | "./data" | Path to dataset storage |
| `z_dim` | 3 | Dimension of latent variables |
| `num_ladders` | 3 | Number of ladders (hierarchies) in pro-VLAE |
| `beta` | 8.0 | β parameter for pro-VLAE |
@@ -103,20 +93,35 @@ torchrun --nproc_per_node=2 train_ddp.py \
| `train_seq` | 1 | Current training sequence number (`indep_train` mode only) |
| `batch_size` | 100 | Batch size |
| `num_epochs` | 1 | Number of epochs |
-| `mode` | "seq_train" | Execution mode ("seq_train", "indep_train", "visualize") |
| `hidden_dim` | 32 | Hidden layer dimension |
| `coff` | 0.5 | Coefficient for KL divergence |
+| `pre_kl` | True | Use inactive ladder loss |
+
+
+
+### Training Parameters
+
+| Argument | Default | Description |
+|----------|---------------|-------------|
+| `dataset` | "shapes3d" | Dataset to use (mnist, fashionmnist, dsprites, shapes3d, mpi3d, ident3d, celeba, flowers102, dtd, imagenet) |
+| `data_path` | "./data" | Path to dataset storage |
-| `output_dir` | "outputs" | Output directory |
+| `output_dir` | "output" | Output directory |
+| `checkpoint_dir` | "checkpoints" | Checkpoint save directory |
+| `recon_dir` | "reconstructions" | Reconstruction results directory |
+| `traverse_dir` | "traversals" | Traversal results directory |
+| `mode` | "seq_train" | Execution mode ("seq_train", "indep_train", "traverse") |
| `compile_mode` | "default" | PyTorch compilation mode |
| `on_cudnn_benchmark` | True | Enable/disable cuDNN benchmark |
| `optim` | "adam" | Optimization algorithm (adam, adamw, sgd, lamb, diffgrad, madgrad) |
+| `distributed` | False | Enable distributed data parallel |
+| `num_workers` | 4 | Number of workers for data loader |
Mode descriptions:
- `seq_train`: Sequential training from ladder 1 to `num_ladders`
- `indep_train`: Independent training of specified `train_seq` ladder
-- `visualize`: Visualize latent space using trained model (need checkpoints)
+- `traverse`: Visualize the latent space with a trained model (requires a checkpoint; see the example below)
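+
+For example, to render traversals from the final checkpoint of the training run above (a sketch; the architecture flags `z_dim`, `num_ladders`, and `hidden_dim` must match those used at training time):
+
+```bash
+python src/train.py \
+ --mode traverse \
+ --dataset shapes3d \
+ --num_ladders 3 \
+ --z_dim 3 \
+ --hidden_dim 32 \
+ --train_seq 3 \
+ --output_dir ./output/shapes3d/ \
+ --data_path ./data/shapes3d/
+```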
diff --git a/models/__init__.py b/models/__init__.py
deleted file mode 100644
index a57f0e7..0000000
--- a/models/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-from .provlae import ProVLAE
diff --git a/scripts/run_dsprites.sh b/scripts/run_dsprites.sh
deleted file mode 100644
index 37036de..0000000
--- a/scripts/run_dsprites.sh
+++ /dev/null
@@ -1,15 +0,0 @@
-#!/bin/bash
-
-python train.py \
- --mode seq_train \
- --dataset dsprites \
- --num_ladders 3 \
- --batch_size 64 \
- --num_epochs 30 \
- --learning_rate 5e-4 \
- --beta 3 \
- --z_dim 3 \
- --hidden_dim 32 \
- --fade_in_duration 5000 \
- --output_dir ./output/dsprites/ \
- --optim adamw
diff --git a/scripts/run_flowers102.sh b/scripts/run_flowers102.sh
deleted file mode 100644
index d308d50..0000000
--- a/scripts/run_flowers102.sh
+++ /dev/null
@@ -1,16 +0,0 @@
-#!/bin/bash
-
-python train.py \
- --mode seq_train \
- --dataset flowers102 \
- --num_ladders 3 \
- --batch_size 32 \
- --num_epochs 30 \
- --learning_rate 5e-4 \
- --beta 3 \
- --z_dim 3 \
- --coff 0.2 \
- --hidden_dim 32 \
- --fade_in_duration 5000 \
- --output_dir ./output/flowers102/ \
- --optim adamw
diff --git a/scripts/run_ident3d.sh b/scripts/run_ident3d.sh
deleted file mode 100644
index 5a32372..0000000
--- a/scripts/run_ident3d.sh
+++ /dev/null
@@ -1,16 +0,0 @@
-#!/bin/bash
-
-python train.py \
- --mode seq_train \
- --dataset ident3d \
- --num_ladders 3 \
- --batch_size 128 \
- --num_epochs 15 \
- --learning_rate 5e-4 \
- --beta 3 \
- --z_dim 3 \
- --coff 0.3 \
- --hidden_dim 32 \
- --fade_in_duration 5000 \
- --output_dir ./output/ident3d/ \
- --optim adamw \
diff --git a/scripts/run_shape3d.sh b/scripts/run_shape3d.sh
deleted file mode 100644
index b45f16a..0000000
--- a/scripts/run_shape3d.sh
+++ /dev/null
@@ -1,15 +0,0 @@
-#!/bin/bash
-
-python train.py \
- --mode seq_train \
- --dataset shapes3d \
- --num_ladders 3 \
- --batch_size 256 \
- --num_epochs 1 \
- --learning_rate 5e-4 \
- --beta 20 \
- --z_dim 3 \
- --hidden_dim 32 \
- --fade_in_duration 5000 \
- --output_dir ./output/shapes3d/ \
- --optim adamw
diff --git a/dataset.py b/src/dataset.py
similarity index 100%
rename from dataset.py
rename to src/dataset.py
diff --git a/src/ddp_utils.py b/src/ddp_utils.py
new file mode 100644
index 0000000..bc5e3e2
--- /dev/null
+++ b/src/ddp_utils.py
@@ -0,0 +1,86 @@
+import sys
+import os
+
+from loguru import logger
+import torch
+import torch.distributed as dist
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.data.distributed import DistributedSampler
+
+
+def setup_logger(rank=-1, world_size=1):
+ """Setup logger for distributed training"""
+ config = {
+ "handlers": [
+ {
+ "sink": sys.stdout,
+ "format": (
+ "{time:YYYY-MM-DD HH:mm:ss} | "
+ "{level: <8} | "
+ "Rank {extra[rank]}/{extra[world_size]} | "
+ "{name}:{line} | "
+ "{message}"
+ ),
+ "level": "DEBUG",
+ "colorize": True,
+ }
+ ]
+ }
+
+ try: # Remove all existing handlers
+ logger.configure(**config)
+ except ValueError:
+ pass
+
+ # Create a new logger instance with rank information
+ return logger.bind(rank=rank, world_size=world_size)
+
+
+def setup_distributed(params):
+ """Initialize distributed training environment with explicit device mapping"""
+ if not params.distributed:
+ return False
+
+ try:
+ if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
+ params.rank = int(os.environ["RANK"])
+ params.world_size = int(os.environ["WORLD_SIZE"])
+ params.local_rank = int(os.environ["LOCAL_RANK"])
+ elif "SLURM_PROCID" in os.environ:
+ params.rank = int(os.environ["SLURM_PROCID"])
+ params.local_rank = params.rank % torch.cuda.device_count()
+ params.world_size = int(os.environ["SLURM_NTASKS"])
+ else:
+ raise ValueError("Not running with distributed environment variables set")
+
+ torch.cuda.set_device(params.local_rank)
+ init_method = "env://"
+ backend = params.dist_backend
+ if backend == "nccl" and not torch.cuda.is_available():
+ backend = "gloo"
+
+ if not dist.is_initialized():
+ dist.init_process_group(
+ backend=backend,
+ init_method=init_method,
+ world_size=params.world_size,
+ rank=params.rank,
+ )
+ torch.cuda.set_device(params.local_rank)
+ dist.barrier(device_ids=[params.local_rank])
+
+ return True
+ except Exception as e:
+ print(f"Failed to initialize distributed training: {e}")
+ return False
+
+
+def cleanup_distributed():
+ """Clean up distributed training resources safely with device mapping"""
+ if dist.is_initialized():
+ try:
+ local_rank = int(os.environ.get("LOCAL_RANK", 0))
+ dist.barrier(device_ids=[local_rank])
+ dist.destroy_process_group()
+ except Exception as e:
+ print(f"Error during distributed cleanup: {e}")
diff --git a/models/provlae.py b/src/provlae.py
similarity index 100%
rename from models/provlae.py
rename to src/provlae.py
diff --git a/src/scripts/run_dsprites.sh b/src/scripts/run_dsprites.sh
new file mode 100644
index 0000000..027eedb
--- /dev/null
+++ b/src/scripts/run_dsprites.sh
@@ -0,0 +1,20 @@
+#!/bin/bash
+
+torchrun --nproc_per_node=2 --master_port=29502 src/train.py \
+ --distributed \
+ --mode seq_train \
+ --dataset dsprites \
+ --optim adamw \
+ --num_ladders 3 \
+ --batch_size 256 \
+ --num_epochs 30 \
+ --learning_rate 5e-4 \
+ --beta 8 \
+ --z_dim 2 \
+ --coff 0.5 \
+ --pre_kl \
+ --hidden_dim 64 \
+ --fade_in_duration 5000 \
+ --output_dir ./output/dsprites/ \
+ --data_path ./data/dsprites/
+
\ No newline at end of file
diff --git a/scripts/run_dtd.sh b/src/scripts/run_dtd.sh
similarity index 53%
rename from scripts/run_dtd.sh
rename to src/scripts/run_dtd.sh
index ad27f15..62de636 100644
--- a/scripts/run_dtd.sh
+++ b/src/scripts/run_dtd.sh
@@ -1,16 +1,19 @@
#!/bin/bash
-python train.py \
+torchrun --nproc_per_node=2 --master_port=29501 src/train.py \
+ --distributed \
--mode seq_train \
--dataset dtd \
+ --optim adamw \
--num_ladders 3 \
- --batch_size 32 \
+ --batch_size 128 \
--num_epochs 30 \
--learning_rate 5e-4 \
- --beta 3 \
+ --beta 8 \
--z_dim 3 \
- --coff 0.2 \
+ --coff 0.5 \
+ --pre_kl \
--hidden_dim 64 \
--fade_in_duration 5000 \
--output_dir ./output/dtd/ \
- --optim adamw
+ --data_path ./data/dtd/
diff --git a/scripts/run_fashionmnist.sh b/src/scripts/run_fashionmnist.sh
similarity index 53%
rename from scripts/run_fashionmnist.sh
rename to src/scripts/run_fashionmnist.sh
index 1241e67..8b5ee3b 100644
--- a/scripts/run_fashionmnist.sh
+++ b/src/scripts/run_fashionmnist.sh
@@ -1,15 +1,20 @@
#!/bin/bash
-python train.py \
+torchrun --nproc_per_node=2 --master_port=29501 src/train.py \
+ --distributed \
--mode seq_train \
--dataset fashionmnist \
+ --optim adamw \
--num_ladders 3 \
- --batch_size 16 \
+ --batch_size 64 \
--num_epochs 30 \
--learning_rate 5e-4 \
--beta 3 \
- --z_dim 3 \
+ --z_dim 2 \
+ --coff 0.5 \
+ --pre_kl \
--hidden_dim 32 \
--fade_in_duration 5000 \
--output_dir ./output/fashionmnist/ \
- --optim adamw
+ --data_path ./data/fashionmnist/
+
\ No newline at end of file
diff --git a/src/scripts/run_flowers102.sh b/src/scripts/run_flowers102.sh
new file mode 100644
index 0000000..a29bfb5
--- /dev/null
+++ b/src/scripts/run_flowers102.sh
@@ -0,0 +1,19 @@
+#!/bin/bash
+
+torchrun --nproc_per_node=2 --master_port=29501 src/train.py \
+ --distributed \
+ --mode seq_train \
+ --dataset flowers102 \
+ --optim adamw \
+ --num_ladders 3 \
+ --batch_size 128 \
+ --num_epochs 30 \
+ --learning_rate 5e-4 \
+ --beta 8 \
+ --z_dim 3 \
+ --coff 0.5 \
+ --pre_kl \
+ --hidden_dim 64 \
+ --fade_in_duration 5000 \
+ --output_dir ./output/flowers102/ \
+ --data_path ./data/flowers102/
diff --git a/scripts/run_ddp.sh b/src/scripts/run_ident3d.sh
similarity index 62%
rename from scripts/run_ddp.sh
rename to src/scripts/run_ident3d.sh
index 8e75605..a856a1c 100644
--- a/scripts/run_ddp.sh
+++ b/src/scripts/run_ident3d.sh
@@ -1,17 +1,19 @@
#!/bin/bash
-torchrun --nproc_per_node=2 train_ddp.py \
- --distributed True \
+torchrun --nproc_per_node=2 --master_port=29501 src/train.py \
+ --distributed \
--mode seq_train \
--dataset ident3d \
+ --optim adamw \
--num_ladders 3 \
--batch_size 128 \
--num_epochs 30 \
--learning_rate 5e-4 \
- --beta 1 \
+ --beta 8 \
--z_dim 3 \
--coff 0.5 \
+ --pre_kl \
--hidden_dim 64 \
--fade_in_duration 5000 \
--output_dir ./output/ident3d/ \
- --optim adamw
+ --data_path ./data/ident3d/
diff --git a/src/scripts/run_imagenet.sh b/src/scripts/run_imagenet.sh
new file mode 100644
index 0000000..cf11d39
--- /dev/null
+++ b/src/scripts/run_imagenet.sh
@@ -0,0 +1,19 @@
+#!/bin/bash
+
+torchrun --nproc_per_node=2 --master_port=29501 src/train.py \
+ --distributed \
+ --mode seq_train \
+ --dataset imagenet \
+ --optim adamw \
+ --num_ladders 4 \
+ --batch_size 256 \
+ --num_epochs 100 \
+ --learning_rate 5e-4 \
+ --beta 1 \
+ --z_dim 4 \
+ --coff 0.5 \
+ --pre_kl \
+ --hidden_dim 64 \
+ --fade_in_duration 5000 \
+ --output_dir ./output/imagenet/ \
+ --data_path ./data/imagenet/
diff --git a/scripts/run_mnist.sh b/src/scripts/run_mnist.sh
similarity index 52%
rename from scripts/run_mnist.sh
rename to src/scripts/run_mnist.sh
index 510c2da..ada96b4 100644
--- a/scripts/run_mnist.sh
+++ b/src/scripts/run_mnist.sh
@@ -1,15 +1,20 @@
#!/bin/bash
-python train.py \
+torchrun --nproc_per_node=2 --master_port=29501 src/train.py \
+ --distributed \
--mode seq_train \
--dataset mnist \
+ --optim adamw \
--num_ladders 3 \
- --batch_size 16 \
+ --batch_size 64 \
--num_epochs 30 \
--learning_rate 5e-4 \
--beta 3 \
- --z_dim 3 \
+ --z_dim 2 \
+ --coff 0.5 \
+ --pre_kl \
--hidden_dim 32 \
--fade_in_duration 5000 \
--output_dir ./output/mnist/ \
- --optim adamw
\ No newline at end of file
+ --data_path ./data/mnist/
+
\ No newline at end of file
diff --git a/scripts/run_mpi3d.sh b/src/scripts/run_mpi3d.sh
similarity index 52%
rename from scripts/run_mpi3d.sh
rename to src/scripts/run_mpi3d.sh
index 85e8aef..b48a7c2 100644
--- a/scripts/run_mpi3d.sh
+++ b/src/scripts/run_mpi3d.sh
@@ -1,16 +1,19 @@
#!/bin/bash
-python train.py \
+torchrun --nproc_per_node=2 --master_port=29501 src/train.py \
+ --distributed \
--mode seq_train \
--dataset mpi3d \
+ --optim adamw \
--num_ladders 3 \
--batch_size 128 \
- --num_epochs 15 \
+ --num_epochs 30 \
--learning_rate 5e-4 \
- --beta 3 \
+ --beta 8 \
--z_dim 3 \
- --coff 0.3 \
- --hidden_dim 32 \
+ --coff 0.5 \
+ --pre_kl \
+ --hidden_dim 64 \
--fade_in_duration 5000 \
--output_dir ./output/mpi3d/ \
- --optim adamw \
+ --data_path ./data/mpi3d/
diff --git a/src/scripts/run_shape3d.sh b/src/scripts/run_shape3d.sh
new file mode 100644
index 0000000..f48c6c0
--- /dev/null
+++ b/src/scripts/run_shape3d.sh
@@ -0,0 +1,19 @@
+#!/bin/bash
+
+torchrun --nproc_per_node=2 --master_port=29501 src/train.py \
+ --distributed \
+ --mode seq_train \
+ --dataset shapes3d \
+ --optim adamw \
+ --num_ladders 3 \
+ --batch_size 128 \
+ --num_epochs 15 \
+ --learning_rate 5e-4 \
+ --beta 8 \
+ --z_dim 3 \
+ --coff 0.5 \
+ --pre_kl \
+ --hidden_dim 64 \
+ --fade_in_duration 5000 \
+ --output_dir ./output/shapes3d/ \
+ --data_path ./data/shapes3d/
diff --git a/train_ddp.py b/src/train.py
similarity index 75%
rename from train_ddp.py
rename to src/train.py
index 3e4390f..4f0528d 100644
--- a/train_ddp.py
+++ b/src/train.py
@@ -1,8 +1,6 @@
import argparse
-import datetime
import os
import sys
-import time
from dataclasses import dataclass, field
import imageio.v3 as imageio
@@ -20,13 +18,23 @@
from tqdm import tqdm
from dataset import DTD, MNIST, MPI3D, CelebA, DSprites, FashionMNIST, Flowers102, Ident3D, ImageNet, Shapes3D
-from models import ProVLAE
+from provlae import ProVLAE
+from ddp_utils import setup_logger, setup_distributed, cleanup_distributed
+from utils import exec_time, add_dataclass_args
+
+
+@dataclass
+class OptimizerParameters:
+ betas: tuple = field(default=(0.9, 0.999))
+ eps: float = field(default=1e-08)
+ weight_decay: float = field(default=0)
+
+ momentum: float = field(default=0) # sgd, madgrad
+ dampening: float = field(default=0) # sgd
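+
+ # These fields become CLI flags via add_dataclass_args (src/utils.py); tuple
+ # defaults such as betas are parsed from comma-separated values, e.g. --betas 0.9,0.999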
@dataclass
class HyperParameters:
- dataset: str = field(default="shapes3d")
- data_path: str = field(default="./data")
z_dim: int = field(default=3)
num_ladders: int = field(default=3)
beta: float = field(default=8.0)
@@ -34,122 +42,47 @@ class HyperParameters:
fade_in_duration: int = field(default=5000)
image_size: int = field(default=64)
chn_num: int = field(default=3)
- train_seq: int = field(default=1)
batch_size: int = field(default=100)
num_epochs: int = field(default=1)
mode: str = field(default="seq_train")
+ train_seq: int = field(default=1) # progress stage
hidden_dim: int = field(default=32)
coff: float = field(default=0.5)
- output_dir: str = field(default="outputs")
+ pre_kl: bool = field(default=True)
+
- # pytorch optimization
+@dataclass
+class TrainingParameters:
+ dataset: str = field(default="shapes3d")
+ data_path: str = field(default="./data")
+ num_workers: int = field(default=4) # data loader
+
+ # output dirs
+ output_dir: str = field(default="output") # results dir
+ checkpoint_dir: str = field(default="checkpoints")
+ recon_dir: str = field(default="reconstructions")
+ traverse_dir: str = field(default="traversals")
+
+ # PyTorch optimization
compile_mode: str = field(default="default") # or max-autotune-no-cudagraphs
on_cudnn_benchmark: bool = field(default=True)
optim: str = field(default="adam")
- # distributed training parameters
+ # Distributed training parameters
distributed: bool = field(default=False)
local_rank: int = field(default=-1)
world_size: int = field(default=1)
dist_backend: str = field(default="nccl")
dist_url: str = field(default="env://")
- @property
- def checkpoint_path(self):
- return os.path.join(self.output_dir, f"checkpoints/model_seq{self.train_seq}.pt")
-
- @property
- def recon_path(self):
- return os.path.join(self.output_dir, f"reconstructions/recon_seq{self.train_seq}.png")
-
- @property
- def traverse_path(self):
- return os.path.join(self.output_dir, f"traversals/traverse_seq{self.train_seq}.gif")
-
- @classmethod
- def from_args(cls):
- parser = argparse.ArgumentParser()
- for field_info in cls.__dataclass_fields__.values():
- parser.add_argument(f"--{field_info.name}", type=field_info.type, default=field_info.default)
- return cls(**vars(parser.parse_args()))
-
-
-def setup_logger(rank=-1, world_size=1):
- """Setup logger for distributed training"""
- config = {
- "handlers": [
- {
- "sink": sys.stdout,
- "format": (
- "{time:YYYY-MM-DD HH:mm:ss} | "
- "{level: <8} | "
- "Rank {extra[rank]}/{extra[world_size]} | "
- "{name}:{line} | "
- "{message}"
- ),
- "level": "DEBUG",
- "colorize": True,
- }
- ]
- }
-
- try: # Remove all existing handlers
- logger.configure(**config)
- except ValueError:
- pass
- # Create a new logger instance with rank information
- return logger.bind(rank=rank, world_size=world_size)
+def parse_arguments():
+ parser = argparse.ArgumentParser()
+ add_dataclass_args(parser, HyperParameters)
+ add_dataclass_args(parser, OptimizerParameters)
+ add_dataclass_args(parser, TrainingParameters)
-
-def setup_distributed(params):
- """Initialize distributed training environment with explicit device mapping"""
- if not params.distributed:
- return False
-
- try:
- if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
- params.rank = int(os.environ["RANK"])
- params.world_size = int(os.environ["WORLD_SIZE"])
- params.local_rank = int(os.environ["LOCAL_RANK"])
- elif "SLURM_PROCID" in os.environ:
- params.rank = int(os.environ["SLURM_PROCID"])
- params.local_rank = params.rank % torch.cuda.device_count()
- params.world_size = int(os.environ["SLURM_NTASKS"])
- else:
- raise ValueError("Not running with distributed environment variables set")
-
- torch.cuda.set_device(params.local_rank)
- init_method = "env://"
- backend = params.dist_backend
- if backend == "nccl" and not torch.cuda.is_available():
- backend = "gloo"
-
- if not dist.is_initialized():
- dist.init_process_group(
- backend=backend,
- init_method=init_method,
- world_size=params.world_size,
- rank=params.rank,
- )
- torch.cuda.set_device(params.local_rank)
- dist.barrier(device_ids=[params.local_rank])
-
- return True
- except Exception as e:
- print(f"Failed to initialize distributed training: {e}")
- return False
-
-
-def cleanup_distributed():
- """Clean up distributed training resources safely with device mapping"""
- if dist.is_initialized():
- try:
- local_rank = int(os.environ.get("LOCAL_RANK", 0))
- dist.barrier(device_ids=[local_rank])
- dist.destroy_process_group()
- except Exception as e:
- print(f"Error during distributed cleanup: {e}")
+ return parser.parse_args()
def get_dataset(params, logger):
@@ -197,7 +130,7 @@ def get_dataset(params, logger):
train_loader.dataset,
batch_size=params.batch_size,
sampler=train_sampler,
- num_workers=4,
+ num_workers=params.num_workers,
pin_memory=True,
drop_last=True,
persistent_workers=True,
@@ -215,31 +148,6 @@ def get_dataset(params, logger):
raise
-def exec_time(func):
- """Decorates a function to measure its execution time in hours and minutes."""
-
- def wrapper(*args, **kwargs):
- start_time = time.time()
- result = func(*args, **kwargs)
- end_time = time.time()
- execution_time = end_time - start_time
-
- logger = kwargs.get("logger") # Get logger from kwargs
- if not logger: # Find logger in positional arguments
- for arg in args:
- if isinstance(arg, type(setup_logger())):
- logger = arg
- break
-
- if logger:
- logger.success(
- f"Training completed ({int(execution_time // 3600)}h {int((execution_time % 3600) // 60)}min)"
- )
- return result
-
- return wrapper
-
-
def load_checkpoint(model, optimizer, scaler, checkpoint_path, device, logger):
"""Load a model checkpoint with proper device management."""
try:
@@ -302,15 +210,13 @@ def create_latent_traversal(model, data_loader, save_path, device, params):
model.fade_in = 1.0
with torch.no_grad():
- # Get a single batch of images
- inputs, _ = next(iter(data_loader))
+ inputs, _ = next(iter(data_loader)) # Get a single batch of images
inputs = inputs[0:1].to(device)
# Get latent representations
with torch.amp.autocast(device_type="cuda", enabled=False):
latent_vars = [z[0] for z in model.inference(inputs)]
- # Traverse values
traverse_range = torch.linspace(-1.5, 1.5, 15).to(device)
# Image layout parameters
@@ -423,7 +329,7 @@ def train_model(model, data_loader, optimizer, params, device, logger, scaler=No
model.train()
global_step = 0
- logger.info("Start training.")
+ logger.info(f"Start training [progress {params.train_seq}]")
for epoch in range(params.num_epochs):
if params.distributed:
data_loader.sampler.set_epoch(epoch)
@@ -459,13 +365,20 @@ def train_model(model, data_loader, optimizer, params, device, logger, scaler=No
global_step += 1
# Save checkpoints and visualizations only on main process
+ checkpoint_path = os.path.join(params.output_dir, params.checkpoint_dir, f"model_seq{params.train_seq}.pt")
if params.local_rank == 0 or not params.distributed:
- os.makedirs(os.path.dirname(params.checkpoint_path), exist_ok=True)
+ os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
model.eval()
with torch.no_grad(), torch.amp.autocast("cuda", dtype=autocast_dtype):
- save_reconstruction(inputs, x_recon, params.recon_path)
- create_latent_traversal(model, data_loader, params.traverse_path, device, params)
+ recon_path = os.path.join(params.output_dir, params.recon_dir, f"recon_seq{params.train_seq}.png")
+ traverse_path = os.path.join(
+ params.output_dir, params.traverse_dir, f"traverse_seq{params.train_seq}.gif"
+ )
+
+ save_reconstruction(inputs, x_recon, recon_path)
+ create_latent_traversal(model, data_loader, traverse_path, device, params)
+
model.train()
checkpoint = {
@@ -475,7 +388,7 @@ def train_model(model, data_loader, optimizer, params, device, logger, scaler=No
"scaler_state_dict": scaler.state_dict() if scaler is not None else None,
"loss": loss.item(),
}
- torch.save(checkpoint, params.checkpoint_path)
+ torch.save(checkpoint, checkpoint_path)
if (epoch + 1) % 5 == 0:
logger.info(f"Epoch: [{epoch+1}/{params.num_epochs}], Loss: {loss.item():.2f}")
@@ -483,26 +396,58 @@ def train_model(model, data_loader, optimizer, params, device, logger, scaler=No
def get_optimizer(model, params):
"""Get the optimizer based on the parameter settings"""
- optimizers = {
- "adam": optim.Adam,
- "adamw": optim.AdamW,
- "sgd": optim.SGD,
- "lamb": jettify_optim.Lamb,
- "diffgrad": jettify_optim.DiffGrad,
- "madgrad": jettify_optim.MADGRAD,
+ optimizer_params = {
+ "params": model.parameters(),
+ "lr": params.learning_rate,
+ }
+
+ # Adam, Lamb, DiffGrad
+ extra_args_common = {
+ "betas": getattr(params, "betas", (0.9, 0.999)),
+ "eps": getattr(params, "eps", 1e-8),
+ "weight_decay": getattr(params, "weight_decay", 0),
+ }
+
+ extra_args_adamw = {
+ "betas": getattr(params, "betas", (0.9, 0.999)),
+ "eps": getattr(params, "eps", 1e-8),
+ "weight_decay": getattr(params, "weight_decay", 0.01),
+ }
+
+ # SGD
+ extra_args_sgd = {
+ "momentum": getattr(params, "momentum", 0),
+ "dampening": getattr(params, "dampening", 0),
+ "weight_decay": getattr(params, "weight_decay", 0),
+ "nesterov": getattr(params, "nesterov", False),
+ }
+
+ # MADGRAD
+ extra_args_madgrad = {
+ "momentum": getattr(params, "momentum", 0.9),
+ "weight_decay": getattr(params, "weight_decay", 0),
+ "eps": getattr(params, "eps", 1e-6),
}
- optimizer = optimizers.get(params.optim.lower())
+ optimizers = {
+ "adam": (optim.Adam, extra_args_common),
+ "adamw": (optim.AdamW, extra_args_adamw),
+ "sgd": (optim.SGD, extra_args_sgd),
+ "lamb": (jettify_optim.Lamb, extra_args_common),
+ "diffgrad": (jettify_optim.DiffGrad, extra_args_common),
+ "madgrad": (jettify_optim.MADGRAD, extra_args_madgrad),
+ }
- if optimizer is None:
- optimizer = optimizers.get("adam")
- logger.warning(f"unsupported optimizer {params.optim}, use Adam optimizer.")
+ optimizer_cls, extra_args = optimizers.get(params.optim.lower(), (optim.Adam, extra_args_common))
+ if params.optim.lower() not in optimizers:
+ logger.warning(f"Unsupported optimizer '{params.optim}', using 'Adam' optimizer instead.")
+ optimizer = optimizer_cls(**optimizer_params, **extra_args)
- return optimizer(model.parameters(), lr=params.learning_rate)
+ return optimizer
def main():
- params = HyperParameters.from_args()
+ params = parse_arguments()
try:
# Setup distributed training
@@ -514,7 +459,6 @@ def main():
device = torch.device(f"cuda:{params.local_rank}" if is_distributed else "cuda")
logger = setup_logger(rank, world_size)
- # GPU config
torch.set_float32_matmul_precision("high")
if params.on_cudnn_benchmark:
torch.backends.cudnn.benchmark = True
@@ -554,6 +498,7 @@ def main():
num_ladders=params.num_ladders,
hidden_dim=params.hidden_dim,
coff=params.coff,
+ pre_kl=params.pre_kl,
).to(device)
if is_distributed:
@@ -570,12 +515,13 @@ def main():
optimizer = get_optimizer(model, params)
if not is_distributed:
model = torch.compile(model, mode=params.compile_mode)
+ logger.debug("model compiled")
# Training mode selection
if params.mode == "seq_train":
if rank == 0:
logger.opt(colors=True).info(
- f"✓ Mode: sequential execution [progress 1 >> {params.num_ladders}]"
+ f"✅ Mode: sequential execution [progress 1 >> {params.num_ladders}]"
)
for i in range(1, params.num_ladders + 1):
@@ -593,7 +539,9 @@ def main():
# Load checkpoint if needed
if params.train_seq >= 2:
- prev_checkpoint = os.path.join(params.output_dir, f"checkpoints/model_seq{params.train_seq-1}.pt")
+ prev_checkpoint = os.path.join(
+ params.output_dir, params.checkpoint_dir, f"model_seq{params.train_seq-1}.pt"
+ )
if os.path.exists(prev_checkpoint):
model, optimizer, scaler = load_checkpoint(
model=model,
@@ -625,9 +573,10 @@ def main():
dist.barrier()
elif params.mode == "indep_train":
+ logger.info(f"Current trainig progress >> {params.train_seq}")
if rank == 0:
logger.opt(colors=True).info(
- f"✓ Mode: independent execution [progress {params.train_seq}]"
+ f"✅ Mode: independent execution [progress {params.train_seq}]"
)
if is_distributed:
@@ -636,7 +585,9 @@ def main():
# Load checkpoint if needed
if params.train_seq >= 2:
- prev_checkpoint = os.path.join(params.output_dir, f"checkpoints/model_seq{params.train_seq-1}.pt")
+ prev_checkpoint = os.path.join(
+ params.output_dir, params.checkpoint_dir, f"model_seq{params.train_seq-1}.pt"
+ )
if os.path.exists(prev_checkpoint):
model, optimizer, scaler = load_checkpoint(
model=model,
@@ -667,21 +618,46 @@ def main():
torch.cuda.synchronize()
dist.barrier()
+ elif params.mode == "traverse":
+ logger.opt(colors=True).info(f"✅ Mode: traverse execution [progress 1 {params.num_ladders}]")
+ try:
+ model, optimizer, scaler = load_checkpoint(
+ model=model,
+ optimizer=optimizer,
+ scaler=scaler,
+ checkpoint_path=os.path.join(
+ params.output_dir, params.checkpoint_dir, f"model_seq{params.train_seq}.pt"
+ ),
+ device=device,
+ logger=logger,
+ )
+ except Exception as e:
+ logger.error(f"Load checkpoint failed: {str(e)}")
+ return
+
+ traverse_path = os.path.join(params.output_dir, params.traverse_dir, f"traverse_seq{params.train_seq}.gif")
+ create_latent_traversal(model, test_loader, traverse_path, device, params)
+ logger.success("Traverse compelted")
else:
logger.error(f"Unsupported mode: {params.mode}, use 'seq_train' or 'indep_train'")
return
+ except KeyboardInterrupt:
+ logger.opt(colors=True).error("Keyboard interrupt")
+
except Exception as e:
- logger.error(f"Training failed: {str(e)}")
- raise
+ logger.opt(colors=True).exception(f"Training failed: {str(e)}")
+
finally:
- try:
- if is_distributed:
- torch.cuda.synchronize()
- cleanup_distributed()
- logger.info("Distributed resources cleaned up successfully")
- except Exception as e:
- logger.error(f"Error during cleanup: {e}")
+ if not is_distributed:
+ logger.info("No distributed resources to clean up (is_distributed=False)")
+ else:
+ try:
+ torch.cuda.synchronize()
+ cleanup_distributed()
+ logger.info("Distributed resources cleaned up successfully")
+ except Exception as e:
+ logger.opt(colors=True).exception(f"Error during cleanup: {e}")
if __name__ == "__main__":
diff --git a/src/utils.py b/src/utils.py
new file mode 100644
index 0000000..b7fdf5e
--- /dev/null
+++ b/src/utils.py
@@ -0,0 +1,67 @@
+import time
+import argparse
+from typing import Any
+
+from ddp_utils import setup_logger
+
+
+def exec_time(func):
+ """Decorates a function to measure its execution time in hours and minutes."""
+
+ def wrapper(*args, **kwargs):
+ start_time = time.time()
+ result = func(*args, **kwargs)
+ end_time = time.time()
+ execution_time = end_time - start_time
+
+ logger = kwargs.get("logger") # Get logger from kwargs
+ if not logger: # Find logger in positional arguments
+ for arg in args:
+ if isinstance(arg, type(setup_logger())):
+ logger = arg
+ break
+
+ if logger:
+ logger.success(
+ f"Training completed ({int(execution_time // 3600)}h {int((execution_time % 3600) // 60)}min)"
+ )
+ return result
+
+ return wrapper
+
+
+def add_dataclass_args(parser: argparse.ArgumentParser, dataclass_type: Any):
+ for field_info in dataclass_type.__dataclass_fields__.values():
+ # Skip properties (those methods marked with @property)
+ if isinstance(field_info.type, property):
+ continue
+
+ # bool type
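+ # note: passing the flag flips the default — a default-False field becomes True,
+ # while a default-True field (e.g. pre_kl) becomes False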
+ if field_info.type is bool:
+ parser.add_argument(
+ f"--{field_info.name}",
+ action="store_true" if not field_info.default else "store_false",
+ help=f"Set {field_info.name} to {not field_info.default}",
+ )
+ # tuple, list, float
+ elif isinstance(field_info.default, tuple):
+ parser.add_argument(
+ f"--{field_info.name}",
+ type=lambda x: tuple(map(float, x.split(","))),
+ default=field_info.default,
+ help=f"Set {field_info.name} to a tuple of floats (e.g., 0.9,0.999)",
+ )
+ elif isinstance(field_info.default, list):
+ parser.add_argument(
+ f"--{field_info.name}",
+ type=lambda x: list(map(float, x.split(","))),
+ default=field_info.default,
+ help=f"Set {field_info.name} to a list of floats (e.g., 0.1,0.2,0.3)",
+ )
+ else:
+ parser.add_argument(
+ f"--{field_info.name}",
+ type=field_info.type,
+ default=field_info.default,
+ help=f"Set {field_info.name} to a value of type {field_info.type.__name__}",
+ )
diff --git a/train.py b/train.py
deleted file mode 100644
index a465e6d..0000000
--- a/train.py
+++ /dev/null
@@ -1,434 +0,0 @@
-import argparse
-import os
-import time
-from dataclasses import dataclass, field
-
-import imageio.v3 as imageio
-import numpy as np
-import torch
-import torch.nn.functional as F
-import torch.optim as optim
-import torch_optimizer as jettify_optim
-import torchvision
-from loguru import logger
-from PIL import Image, ImageDraw, ImageFont
-from torch.nn.parallel import DistributedDataParallel as DDP
-from torch.utils.data.distributed import DistributedSampler
-from tqdm import tqdm
-
-from dataset import DTD, MNIST, MPI3D, CelebA, DSprites, FashionMNIST, Flowers102, Ident3D, ImageNet, Shapes3D
-from models import ProVLAE
-
-
-@dataclass
-class HyperParameters:
- dataset: str = field(default="shapes3d")
- data_path: str = field(default="./data")
- z_dim: int = field(default=3)
- num_ladders: int = field(default=3)
- beta: float = field(default=8.0)
- learning_rate: float = field(default=5e-4)
- fade_in_duration: int = field(default=5000)
- image_size: int = field(default=64)
- chn_num: int = field(default=3)
- train_seq: int = field(default=1)
- batch_size: int = field(default=100)
- num_epochs: int = field(default=1)
- mode: str = field(default="seq_train")
- hidden_dim: int = field(default=32)
- coff: float = field(default=0.5)
- output_dir: str = field(default="outputs")
-
- # pytorch optimization
- compile_mode: str = field(default="default") # or max-autotune-no-cudagraphs
- on_cudnn_benchmark: bool = field(default=True)
- optim: str = field(default="adam")
-
- @property
- def checkpoint_path(self):
- return os.path.join(self.output_dir, f"checkpoints/model_seq{self.train_seq}.pt")
-
- @property
- def recon_path(self):
- return os.path.join(self.output_dir, f"reconstructions/recon_seq{self.train_seq}.png")
-
- @property
- def traverse_path(self):
- return os.path.join(self.output_dir, f"traversals/traverse_seq{self.train_seq}.gif")
-
- @classmethod
- def from_args(cls):
- parser = argparse.ArgumentParser()
- for field_info in cls.__dataclass_fields__.values():
- parser.add_argument(f"--{field_info.name}", type=field_info.type, default=field_info.default)
- return cls(**vars(parser.parse_args()))
-
-
-def get_dataset(params):
- """Load the dataset and return the data loader"""
- dataset_classes = {
- "mnist": MNIST,
- "fashionmnist": FashionMNIST,
- "shapes3d": Shapes3D,
- "dsprites": DSprites,
- "celeba": CelebA,
- "flowers102": Flowers102,
- "dtd": DTD,
- "imagenet": ImageNet,
- "mpi3d": MPI3D,
- "ident3d": Ident3D,
- }
-
- if params.dataset not in dataset_classes:
- raise ValueError(f"Unknown dataset: {params.dataset}. " f"Available datasets: {list(dataset_classes.keys())}")
-
- dataset_class = dataset_classes[params.dataset]
- if params.dataset == "mpi3d":
- # mpi3d variants: toy, real
- variant = getattr(params, "mpi3d_variant", "toy")
- dataset = dataset_class(
- root=params.data_path,
- variant=variant,
- batch_size=params.batch_size,
- num_workers=4,
- )
- else: # other dataset
- dataset = dataset_class(root=params.data_path, batch_size=params.batch_size, num_workers=4)
-
- config = dataset.get_config()
- params.chn_num = config.chn_num
- params.image_size = config.image_size
-
- logger.success("Dataset loaded.")
- return dataset.get_data_loader()
-
-
-def exec_time(func):
- """Decorates a function to measure its execution time in hours and minutes."""
-
- def wrapper(*args, **kwargs):
- start_time = time.time()
- result = func(*args, **kwargs)
- end_time = time.time()
- execution_time = end_time - start_time
-
- logger.success(f"Training completed ({int(execution_time // 3600)}h {int((execution_time % 3600) // 60)}min)")
- return result
-
- return wrapper
-
-
-def load_checkpoint(model, optimizer, scaler, checkpoint_path):
- """
- Load a model checkpoint to resume training or run further inference.
- """
- torch.serialization.add_safe_globals([set])
- checkpoint = torch.load(
- checkpoint_path,
- map_location=torch.device("cpu" if not torch.cuda.is_available() else "cuda"),
- weights_only=True,
- )
-
- model.load_state_dict(checkpoint["model_state_dict"], strict=False)
- optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
- if scaler is not None and "scaler_state_dict" in checkpoint:
- scaler.load_state_dict(checkpoint["scaler_state_dict"])
-
- logger.info(
- f"Loaded checkpoint from '{checkpoint_path}' (Epoch: {checkpoint['epoch']}, Loss: {checkpoint['loss']:.4f})"
- )
-
- return model, optimizer, scaler
-
-
-def save_reconstruction(inputs, reconstructions, save_path):
- """Save a grid of original and reconstructed images"""
- batch_size = min(8, inputs.shape[0])
- inputs = inputs[:batch_size].float()
- reconstructions = reconstructions[:batch_size].float()
- comparison = torch.cat([inputs[:batch_size], reconstructions[:batch_size]])
-
- # Denormalize and convert to numpy
- images = comparison.cpu().detach()
- images = torch.clamp(images, 0, 1)
- grid = torchvision.utils.make_grid(images, nrow=batch_size)
- image = grid.permute(1, 2, 0).numpy()
-
- os.makedirs(os.path.dirname(save_path), exist_ok=True)
- imageio.imwrite(save_path, (image * 255).astype("uint8"))
-
-
-def create_latent_traversal(model, data_loader, save_path, device, params):
- """Create and save organized latent traversal GIF with optimized layout"""
- model.eval()
- model.fade_in = 1.0
- with torch.no_grad():
- # Get a single batch of images
- inputs, _ = next(iter(data_loader))
- inputs = inputs[0:1].to(device)
-
- # Get latent representations
- with torch.amp.autocast(device_type="cuda", enabled=False):
- latent_vars = [z[0] for z in model.inference(inputs)]
-
- # Traverse values
- traverse_range = torch.linspace(-1.5, 1.5, 15).to(device)
-
- # Image layout parameters
- img_size = 96 # Base image size
- padding = 1 # Reduced padding between images
- label_margin = 1 # Margin for labels inside images
- font_size = 7 # Smaller font size for better fit
-
- try:
- font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", font_size)
- except:
- font = ImageFont.load_default()
-
- frames = []
- for t_idx in range(len(traverse_range)):
- current_images = []
-
- # Generate images for each ladder and dimension
- for ladder_idx in range(len(latent_vars) - 1, -1, -1):
- for dim in range(latent_vars[ladder_idx].shape[1]):
- z_mod = [v.clone() for v in latent_vars]
- z_mod[ladder_idx][0, dim] = traverse_range[t_idx]
-
- with torch.amp.autocast(device_type="cuda", enabled=False):
- gen_img = model.generate(z_mod)
- img = gen_img[0].cpu().float()
- img = torch.clamp(img, 0, 1)
-
- # Resize image if needed
- if img.shape[-1] != img_size:
- img = F.interpolate(
- img.unsqueeze(0),
- size=img_size,
- mode="bilinear",
- align_corners=False,
- ).squeeze(0)
-
- # Handle both single-channel and multi-channel images
- if img.shape[0] == 1:
- # Single channel (grayscale) - repeat to create RGB
- img = img.repeat(3, 1, 1)
-
- # Convert to PIL Image for adding text
- img_array = (img.permute(1, 2, 0).numpy() * 255).astype(np.uint8)
- img_pil = Image.fromarray(img_array)
- draw = ImageDraw.Draw(img_pil)
-
- # Add label inside the image
- label = f"L{len(latent_vars)-ladder_idx} z{dim+1}"
- # draw black and white text to create border effect
- for offset in [(0, 1), (1, 0), (0, -1), (-1, 0)]:
- draw.text(
- (label_margin + offset[0], label_margin + offset[1]),
- label,
- (0, 0, 0),
- font=font,
- )
- draw.text((label_margin, label_margin), label, (255, 255, 255), font=font)
-
- # Add value label to bottom-left image
- if ladder_idx == 0 and dim == 0:
- value_label = f"v = {traverse_range[t_idx].item():.2f}"
- for offset in [(0, 1), (1, 0), (0, -1), (-1, 0)]:
- draw.text(
- (
- label_margin + offset[0],
- img_size - font_size - label_margin + offset[1],
- ),
- value_label,
- (0, 0, 0),
- font=font,
- )
- draw.text(
- (label_margin, img_size - font_size - label_margin),
- value_label,
- (255, 255, 255),
- font=font,
- )
-
- # Convert back to tensor
- img_tensor = torch.tensor(np.array(img_pil)).float() / 255.0
- img_tensor = img_tensor.permute(2, 0, 1)
- current_images.append(img_tensor)
-
- # Convert to tensor and create grid
- current_images = torch.stack(current_images)
-
- # Create grid with minimal padding
- grid = torchvision.utils.make_grid(current_images, nrow=params.z_dim, padding=padding, normalize=True)
-
- grid_array = (grid.permute(1, 2, 0).numpy() * 255).astype(np.uint8)
- frames.append(grid_array)
-
- # Save GIF with infinite loop
- os.makedirs(os.path.dirname(save_path), exist_ok=True)
-
- # duration=200 means 5 FPS, loop=0 means infinite loop
- imageio.imwrite(save_path, frames, duration=200, loop=0, format="GIF", optimize=False)
-
-
-@exec_time
-def train_model(model, data_loader, optimizer, params, device, scaler=None, autocast_dtype=torch.float16):
- if torch.cuda.is_available():
- torch.cuda.empty_cache()
-
- model.train()
- global_step = 0
-
- logger.info("Start training.")
- for epoch in range(params.num_epochs):
- with tqdm(
- enumerate(data_loader),
- desc=f"Current epoch [{epoch + 1}/{params.num_epochs}]",
- leave=False,
- total=len(data_loader),
- ) as pbar:
- for batch_idx, (inputs, _) in pbar:
- inputs = inputs.to(device, non_blocking=True)
-
- # Forward pass with autocast
- with torch.amp.autocast(device_type="cuda", dtype=autocast_dtype):
- x_recon, loss, latent_loss, recon_loss = model(inputs, step=global_step)
-
- # Backward pass with appropriate scaling
- optimizer.zero_grad()
- if scaler is not None:
- scaler.scale(loss).backward()
- scaler.step(optimizer)
- scaler.update()
- else:
- loss.backward()
- optimizer.step()
-
- pbar.set_postfix(
- total_loss=f"{loss.item():.2f}",
- latent_loss=f"{latent_loss:.2f}",
- recon_loss=f"{recon_loss:.2f}",
- )
- global_step += 1
-
- # Save checkpoints and visualizations
- os.makedirs(os.path.dirname(params.checkpoint_path), exist_ok=True)
-
- # Model evaluation for visualizations
- model.eval()
- with torch.no_grad(), torch.amp.autocast("cuda", dtype=autocast_dtype):
- save_reconstruction(inputs, x_recon, params.recon_path)
- create_latent_traversal(model, data_loader, params.traverse_path, device, params)
- model.train()
-
- checkpoint = {
- "epoch": epoch + 1,
- "model_state_dict": model.state_dict(),
- "optimizer_state_dict": optimizer.state_dict(),
- "scaler_state_dict": scaler.state_dict() if scaler is not None else None,
- "loss": loss.item(),
- }
- torch.save(checkpoint, params.checkpoint_path)
-
- if (epoch + 1) % 5 == 0:
- logger.info(f"Epoch: [{epoch+1}/{params.num_epochs}], Loss: {loss.item():.2f}")
-
-
-def get_optimizer(model, params):
- optimizers = {
- "adam": optim.Adam,
- "adamw": optim.AdamW,
- "sgd": optim.SGD,
- "lamb": jettify_optim.Lamb,
- "diffgrad": jettify_optim.DiffGrad,
- "madgrad": jettify_optim.MADGRAD,
- }
-
- optimizer = optimizers.get(params.optim.lower())
-
- if optimizer is None:
- optimizer = optimizers.get("adam")
- logger.warning(f"unsupported optimizer {params.optim}, use Adam optimizer.")
-
- return optimizer(model.parameters(), lr=params.learning_rate)
-
-
-def main():
- # Setup
- params = HyperParameters.from_args()
-
- # gpu config
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
- torch.set_float32_matmul_precision("high")
- if params.on_cudnn_benchmark:
- torch.backends.cudnn.benchmark = True
-
- if torch.cuda.get_device_capability()[0] >= 8:
- if not torch.cuda.is_bf16_supported():
- logger.warning("BF16 is not supported, falling back to FP16")
- autocast_dtype = torch.float16
- scaler = torch.amp.GradScaler()
- else: # BF16
- autocast_dtype = torch.bfloat16
- scaler = None
- logger.debug("Using BF16 mixed precision")
- else: # FP16
- autocast_dtype = torch.float16
- scaler = torch.amp.GradScaler()
- logger.debug("Using FP16 mixed precision with gradient scaling")
-
- os.makedirs(params.output_dir, exist_ok=True)
- train_loader, test_loader = get_dataset(params)
-
- # Initialize model
- model = ProVLAE(
- z_dim=params.z_dim,
- beta=params.beta,
- learning_rate=params.learning_rate,
- fade_in_duration=params.fade_in_duration,
- chn_num=params.chn_num,
- train_seq=params.train_seq,
- image_size=params.image_size,
- num_ladders=params.num_ladders,
- hidden_dim=params.hidden_dim,
- coff=params.coff,
- ).to(device)
-
- optimizer = get_optimizer(model, params)
- model = torch.compile(model, mode=params.compile_mode)
-
- # Train model or visualize traverse
- if params.mode == "seq_train":
- logger.opt(colors=True).info(f"✓ Mode: sequential execution [progress 1 >> {params.num_ladders}]")
- for i in range(1, params.num_ladders + 1):
- params.train_seq, model.train_seq = i, i
- if params.train_seq >= 2:
- prev_checkpoint = os.path.join(params.output_dir, f"checkpoints/model_seq{params.train_seq-1}.pt")
- if os.path.exists(prev_checkpoint):
- model, optimizer, scaler = load_checkpoint(model, optimizer, scaler, prev_checkpoint)
- train_model(model, train_loader, optimizer, params, device, scaler, autocast_dtype)
-
- elif params.mode == "indep_train":
- logger.opt(colors=True).info(f"✓ Mode: independent execution [progress {params.train_seq}]")
- if params.train_seq >= 2:
- prev_checkpoint = os.path.join(params.output_dir, f"checkpoints/model_seq{params.train_seq-1}.pt")
- if os.path.exists(prev_checkpoint):
- model, optimizer, scaler = load_checkpoint(model, optimizer, scaler, prev_checkpoint)
-
- train_model(model, train_loader, optimizer, params, device, scaler, autocast_dtype)
-
- elif params.mode == "visualize":
- logger.opt(colors=True).info(f"✓ Mode: visualize latent traversing [progress {params.train_seq}]")
- current_checkpoint = os.path.join(params.output_dir, f"checkpoints/model_seq{params.train_seq}.pt")
- if os.path.exists(current_checkpoint):
- model, _, _ = load_checkpoint(model, optimizer, scaler, current_checkpoint)
- create_latent_traversal(model, test_loader, params.traverse_path, device, params)
- logger.success("Latent traversal visualization saved.")
-
- else:
- logger.error(f"unsupported mode: {params.mode}, use 'train' or 'visualize'.")
-
-
-if __name__ == "__main__":
- main()