pytorch-proVLAE



This is a PyTorch implementation of the paper "Progressive Learning and Disentanglement of Hierarchical Representations" by Zhiyuan Li et al. (ICLR 2020). The official code for proVLAE, implemented in TensorFlow, is available here.


⬆︎ Visualization of results when traversing the latent space of pytorch-proVLAE trained on four datasets: 3D Shapes (top-left), MNIST (top-right), 3DIdent (bottom-left), and MPI3D (bottom-right). The results for the last two datasets are preliminary, with hyperparameter tuning still in progress.

 

Model Architecture

Figure 1: Progressive learning of hierarchical representations. White blocks and solid lines are VAE models at the current progression. α is a fade-in coefficient for blending in the new network component. Gray circles and dashed lines represent (optional) constraints on the future latent variables.

figure-1 in pro-vlae paper

⬆︎ The figure from Zhiyuan et al. (ICLR 2020) illustrates the ladder architecture of the VAE and the progressive learning of hierarchical representations.
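The fade-in coefficient α in the caption can be illustrated with a small sketch. It assumes a linear ramp from 0 to 1 over `fade_in_duration` steps, which is one common choice; the function names `fade_in_alpha` and `blend` are hypothetical and are not taken from this repository.

```python
def fade_in_alpha(step: int, fade_in_duration: int = 5000) -> float:
    """Fade-in coefficient in [0, 1].

    Assumption: alpha grows linearly with the number of steps taken since
    the new ladder was added, then stays at 1.
    """
    if fade_in_duration <= 0:
        return 1.0
    return min(1.0, max(0.0, step / fade_in_duration))


def blend(new_component: float, alpha: float) -> float:
    """Scale the contribution of the newly added component by alpha."""
    return alpha * new_component


if __name__ == "__main__":
    for step in (0, 2500, 5000, 10000):
        print(step, fade_in_alpha(step))
```

With α close to 0 the new component barely affects the network, so the previously trained ladders are not disturbed when a new one is introduced.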

 

Installation

We recommend using mamba (via miniforge) for faster installation of dependencies, but you can also use conda.

git clone https://github.com/suzuki-2001/pytorch-proVLAE.git
cd pytorch-proVLAE

mamba env create -f env.yaml # or conda
mamba activate torch-provlae

 

Usage

You can train pytorch-proVLAE with the following command. Sample hyperparameters and training configurations are provided in the scripts directory. If you have a checkpoint file from a pytorch-proVLAE training run, setting the mode argument to "traverse" lets you inspect the latent traversal. When running this mode, make sure the parameter settings match those used to produce the checkpoint.


# training with distributed data parallel
# tested on NVIDIA V100 PCIe (16GB + 32GB) and NVIDIA A6000 48GB x2
torchrun --nproc_per_node=2 --master_port=29501 src/train.py \
    --distributed \
    --mode seq_train \
    --dataset shapes3d \
    --optim adamw \
    --num_ladders 3 \
    --batch_size 128 \
    --num_epochs 15 \
    --learning_rate 5e-4 \
    --beta 8 \
    --z_dim 3 \
    --coff 0.5 \
    --pre_kl \
    --hidden_dim 32 \
    --fade_in_duration 5000 \
    --output_dir ./output/shapes3d/ \
    --data_path ./data
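As a rough guide to how `--beta`, `--coff` and `--pre_kl` in the command above may interact, the sketch below follows the objective described in the proVLAE paper: a reconstruction term, a β-weighted KL term for the ladders that are currently trained, and an optional KL term for the ladders that are not yet active. The function `weighted_loss` and its arguments are hypothetical; the exact loss in `src/train.py` may differ.

```python
from typing import Sequence


def weighted_loss(
    recon_loss: float,
    active_kl: Sequence[float],
    inactive_kl: Sequence[float],
    beta: float = 8.0,
    coff: float = 0.5,
    pre_kl: bool = True,
) -> float:
    """Combine reconstruction and KL terms (illustrative assumption).

    active_kl:   KL divergences of the ladders being trained.
    inactive_kl: KL divergences of ladders not yet added to the model.
    """
    loss = recon_loss + beta * sum(active_kl)
    if pre_kl:
        # Assumed role of coff: weight on the KL of inactive ladders.
        loss += coff * sum(inactive_kl)
    return loss


if __name__ == "__main__":
    print(weighted_loss(10.0, [1.0], [2.0, 2.0]))
    print(weighted_loss(10.0, [1.0], [2.0, 2.0], pre_kl=False))
```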

  • Hyper Parameters

| Argument | Default | Description |
| --- | --- | --- |
| z_dim | 3 | Dimension of latent variables |
| num_ladders | 3 | Number of ladders (hierarchies) in pro-VLAE |
| beta | 8.0 | β parameter for pro-VLAE |
| learning_rate | 5e-4 | Learning rate |
| fade_in_duration | 5000 | Number of steps for fade-in period |
| image_size | 64 | Input image size |
| chn_num | 3 | Number of input image channels |
| train_seq | 1 | Current training sequence number (indep_train mode only) |
| batch_size | 100 | Batch size |
| num_epochs | 1 | Number of epochs |
| hidden_dim | 32 | Hidden layer dimension |
| coff | 0.5 | Coefficient for KL divergence |
| pre_kl | True | Use inactive ladder loss |
| use_kl_annealing | False | Enable KL annealing (see haofuml/cyclical_annealing) |
| kl_annealing_mode | "linear" | KL cycle annealing mode (linear, sigmoid, cosine) |
| cycle_period | 4 | Number of annealing cycles |
| max_kl_weight | 1.0 | Max KL weight |
| min_kl_weight | 0.1 | Min KL weight |
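The KL annealing options can be pictured with the following sketch of a cyclical schedule. It is written in the spirit of cyclical annealing and is not this repository's implementation: `kl_weight`, `total_steps`, `ratio` and the sigmoid steepness of 12 are assumptions made for illustration, while `n_cycles`, `min_w`, `max_w` and `mode` stand in for `cycle_period`, `min_kl_weight`, `max_kl_weight` and `kl_annealing_mode`.

```python
import math


def kl_weight(
    step: int,
    total_steps: int,
    n_cycles: int = 4,
    ratio: float = 0.5,
    min_w: float = 0.1,
    max_w: float = 1.0,
    mode: str = "linear",
) -> float:
    """KL weight at a given step of a cyclical schedule (illustrative).

    Each cycle ramps from min_w to max_w during the first `ratio` fraction
    of the cycle, then holds max_w until the cycle restarts.
    """
    period = total_steps / n_cycles
    pos = (step % period) / period   # position inside the current cycle
    t = min(1.0, pos / ratio)        # ramp progress, clipped at 1
    if mode == "linear":
        s = t
    elif mode == "sigmoid":
        s = 1.0 / (1.0 + math.exp(-12.0 * (t - 0.5)))
    elif mode == "cosine":
        s = 0.5 * (1.0 - math.cos(math.pi * t))
    else:
        raise ValueError(f"unknown annealing mode: {mode}")
    return min_w + (max_w - min_w) * s


if __name__ == "__main__":
    for step in (0, 50, 125, 249, 250):
        print(step, round(kl_weight(step, total_steps=1000), 3))
```

The weight restarts at `min_w` at the beginning of every cycle, so the model periodically gets a phase with weak KL pressure before the full weight is applied again.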

  • Training Parameters

| Argument | Default | Description |
| --- | --- | --- |
| dataset | "shapes3d" | Dataset to use (mnist, fashionmnist, dsprites, shapes3d, mpi3d, ident3d, celeba, flowers102, dtd, imagenet) |
| data_path | "./data" | Path to dataset storage |
| output_dir | "outputs" | Output directory |
| checkpoint_dir | "checkpoints" | Checkpoint directory |
| recon_dir | "reconstructions" | Reconstruction results directory |
| traverse_dir | "travesals" | Traversal results directory |
| mode | "seq_train" | Execution mode ("seq_train", "indep_train", "traverse") |
| compile_mode | "default" | PyTorch compilation mode |
| on_cudnn_benchmark | True | Enable/disable cuDNN benchmark |
| optim | "adam" | Optimization algorithm (adam, adamw, sgd, lamb, diffgrad, madgrad) |
| distributed | False | Enable distributed data parallel |
| num_workers | 4 | Number of workers for data loader |
| use_wandb | False | Enable wandb tracking and logging |
| wandb_project | "provlae" | Wandb project name |

Mode descriptions:

  • seq_train: Sequential training from ladder 1 to num_ladders
  • indep_train: Independent training of specified train_seq ladder
  • traverse: Visualize the latent space using a trained model (requires a checkpoint)
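The difference between the three modes can be summarized as the list of training stages each one runs. The helper below is a hypothetical illustration of the descriptions above, not code from `src/train.py`.

```python
from typing import List


def training_stages(mode: str, num_ladders: int = 3, train_seq: int = 1) -> List[int]:
    """Return the train_seq values a run would go through (illustrative)."""
    if mode == "seq_train":
        # Train progressively: stage 1, then 2, ... up to num_ladders.
        return list(range(1, num_ladders + 1))
    if mode == "indep_train":
        # Train only the requested stage.
        if not 1 <= train_seq <= num_ladders:
            raise ValueError("train_seq must be between 1 and num_ladders")
        return [train_seq]
    if mode == "traverse":
        # No training: only a trained checkpoint is visualized.
        return []
    raise ValueError(f"unknown mode: {mode}")


if __name__ == "__main__":
    print(training_stages("seq_train", num_ladders=3))
    print(training_stages("indep_train", num_ladders=3, train_seq=2))
    print(training_stages("traverse"))
```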

 

PyTorch optimization options

 

Dataset

Here is a list of the datasets used in the original proVLAE paper, along with additional disentanglement datasets. Except for ImageNet, each dataset is downloaded and preprocessed automatically when you pass its name to the --dataset argument.

Datasets used in the original proVLAE paper

  1. MNIST: mnist
  2. Disentanglement testing Sprites dataset (dSprites): dsprites
  3. 3D Shapes: shapes3d
  4. Large-scale CelebFaces Attributes (CelebA): celeba

Additional Disentanglement Datasets

  1. MPI3D Disentanglement Datasets: mpi3d
  2. 3DIdent: ident3d

Other Datasets (testing in progress)

  1. Fashion-MNIST: fashionmnist
  2. Describable Textures Dataset (DTD): dtd
  3. 102 Category Flower Dataset: flowers102
  4. ImageNet: imagenet

 

Work in Progress

Hyperparameter optimization (beta, coff, fade-in duration, learning rates) and implementation of disentanglement metrics (MIG for detecting factor splitting, MIG-sup for factor entanglement) are currently under development. Benchmark results will be provided in future updates.

 

License

This repository is licensed under the MIT License; see the LICENSE file for details. This follows the license of the original implementation by Zhiyuan Li.

 


*This repository is a contribution to an AIST (National Institute of Advanced Industrial Science and Technology) project.

Human Informatics and Interaction Research Institute, Neurorehabilitation Research Group
Shosuke Suzuki, Ryusuke Hayashi