This is a PyTorch implementation of the paper Progressive Learning and Disentanglement of Hierarchical Representations by Zhiyuan Li et al., ICLR 2020. The official code for proVLAE, implemented in TensorFlow, is available here.
⬆︎ Visualization of results when traversing the latent space of pytorch-proVLAE trained on four datasets: 3D Shapes (top-left), MNIST (top-right), 3DIdent (bottom-left), and MPI3D (bottom-right). The results for the last two datasets are preliminary, with hyperparameter tuning still in progress.
Figure 1: Progressive learning of hierarchical representations. White blocks and solid lines are VAE models at the current progression. α is a fade-in coefficient for blending in the new network component. Gray circles and dashed lines represent (optional) constraining of the future latent variables.
⬆︎ Figure from Zhiyuan Li et al. (ICLR 2020) illustrating the ladder architecture of the VAE and the progressive learning of hierarchical representations.
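The fade-in and the optional constraint on future latent variables described in Figure 1 can be sketched roughly as follows. This is a minimal illustration, not the repository's code: the function names `fade_in_alpha` and `progressive_loss`, the linear ramp, and the exact weighting of the KL terms are assumptions based on the caption above and on the `beta`, `coff`, `pre_kl`, and `fade_in_duration` arguments.

```python
def fade_in_alpha(step, fade_in_duration=5000):
    """Ramp alpha linearly from 0 to 1 over `fade_in_duration` steps.

    The linear shape is an assumption; alpha scales how strongly the
    newly added network component is blended in.
    """
    if fade_in_duration <= 0:
        return 1.0
    return min(1.0, max(0.0, step / fade_in_duration))


def progressive_loss(recon_loss, kl_per_ladder, stage,
                     beta=8.0, coff=0.5, pre_kl=True):
    """Combine the loss terms at progression `stage` (1-based).

    `kl_per_ladder` is ordered by progression: index 0 is the ladder
    trained first, the last index is the ladder added last.
    (Hypothetical helper; the weighting scheme is an assumption.)
    """
    active = kl_per_ladder[:stage]     # ladders already being trained
    inactive = kl_per_ladder[stage:]   # future ladders, not yet active
    loss = recon_loss + beta * sum(active)
    if pre_kl:
        # Optionally constrain the future latent variables.
        loss += coff * sum(inactive)
    return loss


if __name__ == "__main__":
    for step in (0, 2500, 5000, 10000):
        print(step, fade_in_alpha(step))
    print(progressive_loss(1.0, [0.5, 0.2, 0.1], stage=1))
```

Under these assumptions, `stage=1` weights only the first ladder's KL term by `beta`, while the remaining ladders contribute with the smaller `coff` weight when `pre_kl` is enabled.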
We recommend using mamba (via miniforge) for faster installation of dependencies, but you can also use conda.
git clone https://github.com/suzuki-2001/pytorch-proVLAE.git
cd pytorch-proVLAE
mamba env create -f env.yaml # or conda
mamba activate torch-provlae
You can train pytorch-proVLAE with the following command. Sample hyperparameters and training configurations are provided in the scripts directory. If you have a checkpoint file from a pytorch-proVLAE training run, setting the mode argument to "traverse" lets you inspect the latent traversal. When running in this mode, make sure the parameter settings match those used to produce the checkpoint.
# training with distributed data parallel
# tested on NVIDIA V100 PCIe 16GB+32GB and NVIDIA A6000 48GB x2
torchrun --nproc_per_node=2 --master_port=29501 src/train.py \
--distributed \
--mode seq_train \
--dataset shapes3d \
--optim adamw \
--num_ladders 3 \
--batch_size 128 \
--num_epochs 15 \
--learning_rate 5e-4 \
--beta 8 \
--z_dim 3 \
--coff 0.5 \
--pre_kl \
--hidden_dim 32 \
--fade_in_duration 5000 \
--output_dir ./output/shapes3d/ \
--data_path ./data
Argument | Default | Description |
---|---|---|
z_dim | 3 | Dimension of latent variables |
num_ladders | 3 | Number of ladders (hierarchies) in pro-VLAE |
beta | 8.0 | β parameter for pro-VLAE |
learning_rate | 5e-4 | Learning rate |
fade_in_duration | 5000 | Number of steps for fade-in period |
image_size | 64 | Input image size |
chn_num | 3 | Number of input image channels |
train_seq | 1 | Current training sequence number (indep_train mode only) |
batch_size | 100 | Batch size |
num_epochs | 1 | Number of epochs |
hidden_dim | 32 | Hidden layer dimension |
coff | 0.5 | Coefficient for KL divergence |
pre_kl | True | Use inactive ladder loss |
use_kl_annealing | False | Enable KL annealing (see haofuml/cyclical_annealing) |
kl_annealing_mode | "linear" | KL cycle annealing mode (linear, sigmoid, cosine) |
cycle_period | 4 | Number of annealing cycles |
max_kl_weight | 1.0 | Max KL weight |
min_kl_weight | 0.1 | Min KL weight |
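To make the KL annealing arguments more concrete, here is a minimal sketch of a cyclical schedule in the spirit of haofuml/cyclical_annealing. It is not the repository's implementation: the function name `cyclical_kl_weight`, the `ramp_ratio` parameter, the sigmoid steepness, and the reading of `cycle_period` as the number of cycles over the whole run are all assumptions.

```python
import math


def cyclical_kl_weight(step, total_steps, num_cycles=4, mode="linear",
                       min_weight=0.1, max_weight=1.0, ramp_ratio=0.5):
    """Return the KL weight at `step` for a cyclical annealing schedule.

    Each cycle ramps from `min_weight` to `max_weight` during the first
    `ramp_ratio` of the cycle and then holds at the ramp's end value.
    (Hypothetical helper; parameter names are assumptions.)
    """
    cycle_len = total_steps / num_cycles
    pos = (step % cycle_len) / cycle_len   # position in cycle, in [0, 1)
    t = min(pos / ramp_ratio, 1.0)         # ramp progress, clamped at 1

    if mode == "linear":
        shape = t
    elif mode == "sigmoid":
        # Steepness 12 is an arbitrary choice; the curve ends close to 1.
        shape = 1.0 / (1.0 + math.exp(-12.0 * (t - 0.5)))
    elif mode == "cosine":
        shape = 0.5 * (1.0 - math.cos(math.pi * t))
    else:
        raise ValueError(f"unknown annealing mode: {mode}")

    return min_weight + (max_weight - min_weight) * shape


if __name__ == "__main__":
    for step in (0, 250, 500, 999, 1000):
        print(step, round(cyclical_kl_weight(step, total_steps=4000), 3))
```

In a training loop, the returned weight would typically multiply the KL term of the loss; how it interacts with `beta` in this repository is not specified here.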
Argument | Default | Description |
---|---|---|
dataset | "shapes3d" | Dataset to use (mnist, fashionmnist, dsprites, shapes3d, mpi3d, ident3d, celeba, flowers102, dtd, imagenet) |
data_path | "./data" | Path to dataset storage |
output_dir | "outputs" | Output directory |
checkpoint_dir | "checkpoints" | Checkpoint results directory |
recon_dir | "reconstructions" | Reconstruction results directory |
traverse_dir | "travesals" | Traversal results directory |
mode | "seq_train" | Execution mode ("seq_train", "indep_train", "traverse") |
compile_mode | "default" | PyTorch compilation mode |
on_cudnn_benchmark | True | Enable/disable cuDNN benchmark |
optim | "adam" | Optimization algorithm (adam, adamw, sgd, lamb, diffgrad, madgrad) |
distributed | False | Enable distributed data parallel |
num_workers | 4 | Number of workers for data loader |
use_wandb | False | Enable wandb tracking and logging |
wandb_project | "provlae" | Wandb project name |
Mode descriptions:
- seq_train: Sequential training from ladder 1 to num_ladders
- indep_train: Independent training of specified train_seq ladder
- traverse: Visualize latent space using trained model (need checkpoints)
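As a rough illustration of what a latent traversal does, the sketch below sweeps one latent dimension over a range while holding the others fixed. The function name `traversal_grid`, the default range of -3 to 3, and the use of plain Python lists are assumptions for illustration; the repository's traverse mode may build its grid differently.

```python
def traversal_grid(z, dim, low=-3.0, high=3.0, num_steps=7):
    """Return copies of latent vector `z` with entry `dim` swept from low to high.

    All other entries are left unchanged, so decoding each returned vector
    shows the effect of that single latent dimension.
    (Hypothetical helper, not part of the repository.)
    """
    if num_steps < 2:
        raise ValueError("num_steps must be at least 2")
    step = (high - low) / (num_steps - 1)
    grid = []
    for i in range(num_steps):
        z_new = list(z)              # copy so the input is not modified
        z_new[dim] = low + i * step  # overwrite only the traversed dimension
        grid.append(z_new)
    return grid


if __name__ == "__main__":
    base = [0.1, 0.2, 0.3]  # e.g. one ladder's latent code with z_dim = 3
    for row in traversal_grid(base, dim=1, low=-2.0, high=2.0, num_steps=5):
        print(row)
```

Each vector in the grid would then be passed through the trained decoder, and the decoded images placed side by side to form one row of a traversal figure.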
- Fast training: performance tuning follows the PyTorch Performance Tuning Guide by Szymon Migacz (NVIDIA).
- Optimizer: DiffGrad, Lamb, and MADGRAD are implemented by jettify/pytorch-optimizer; the other optimizers are based on the torch.optim package.
Here is a list of the datasets used in the original proVLAE paper, along with additional disentanglement datasets. Except for ImageNet, each dataset is automatically downloaded and preprocessed when you specify its name in the --dataset argument.
- MNIST: mnist
- Disentanglement testing Sprites dataset (dSprites): dsprites
- 3D Shapes: shapes3d
- Large-scale CelebFaces Attributes (CelebA): celeba
- MPI3D Disentanglement Datasets: mpi3d
- 3DIdent: ident3d
- Fashion-MNIST: fashionmnist
- Describable Textures Dataset (DTD): dtd
- 102 Category Flower Dataset: flowers102
- ImageNet: imagenet
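The dataset names above could be validated at the command line along these lines. This is a hypothetical sketch using Python's standard argparse module, not the repository's actual parser; `DATASETS` and `build_parser` are names introduced for illustration, and the defaults are copied from the argument table.

```python
import argparse

# Dataset identifiers as listed above (assumed to be the accepted values).
DATASETS = [
    "mnist", "fashionmnist", "dsprites", "shapes3d", "mpi3d",
    "ident3d", "celeba", "flowers102", "dtd", "imagenet",
]


def build_parser():
    """Build a small parser that rejects unknown dataset names."""
    parser = argparse.ArgumentParser(description="Dataset selection sketch")
    parser.add_argument("--dataset", choices=DATASETS, default="shapes3d",
                        help="Dataset to use")
    parser.add_argument("--data_path", default="./data",
                        help="Path to dataset storage")
    return parser


if __name__ == "__main__":
    args = build_parser().parse_args(["--dataset", "mpi3d"])
    print(args.dataset, args.data_path)
```

With `choices` set, argparse exits with a usage error when an unsupported name is passed, which catches typos before any download starts.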
Hyperparameter optimization (beta, coff, fade-in duration, learning rates) and implementation of disentanglement metrics (MIG for detecting factor splitting, MIG-sup for factor entanglement) are currently under development. Benchmark results will be provided in future updates.
This repository is licensed under the MIT License - see the LICENSE file for details. This follows the license of the original implementation by Zhiyuan Li.
*This repository is a contribution to an AIST (National Institute of Advanced Industrial Science and Technology) project.
Human Informatics and Interaction Research Institute, Neurorehabilitation Research Group
Shosuke Suzuki, Ryusuke Hayashi