
 

We modified the code to allow more flexible configuration of the VAE architecture by specifying only `z_dim`, the number of ladder layers, and the input image size.
This implementation introduces dynamic size management for arbitrary input image sizes: it automatically calculates the maximum possible number of ladder layers from the input dimensions and adapts the latent space dimensionality accordingly. The network depth and feature map sizes are likewise adjusted automatically by computing appropriate intermediate dimensions, while enforcing a minimum feature map size and handling dimensions correctly during flatten/unflatten operations.
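As a rough sketch of the idea, the maximum ladder depth can be derived from the input resolution as below (the helper name and the assumption that each ladder level halves the spatial size are illustrative, not the repository's exact code):

```python
import math

def max_ladder_layers(image_size: int, min_feature_size: int = 4) -> int:
    """Illustrative: count how many stride-2 levels fit before the feature map
    would shrink below min_feature_size (assumed behavior, not the repo's code)."""
    return max(1, int(math.log2(image_size / min_feature_size)))

print(max_ladder_layers(64))  # a 64x64 input with a 4x4 floor allows up to 4 ladder layers
```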


#### Model Architecture
> Figure 1: Progressive learning of hierarchical representations. White blocks and solid lines are VAE
> models at the current progression. α is a fade-in coefficient for blending in the new network component. Gray circles and dash line represents (optional) constraining of the future latent variables.
![figure-1 in pro-vlae paper](./md/provlae-figure1.png)

⬆︎ The figure from Li et al. (ICLR 2020) illustrates the ladder architecture of the VAE and the progressive learning of hierarchical representations.

 
```bash
torchrun --nproc_per_node=2 --master_port=29501 src/train.py \
--hidden_dim 32 \
--fade_in_duration 5000 \
--output_dir ./output/shapes3d/ \
--data_path ./data
```

</br>

| Parameter | Default | Description |
| --- | --- | --- |
| `num_epochs` | 1 | Number of epochs |
| `hidden_dim` | 32 | Hidden layer dimension |
| `coff` | 0.5 | Coefficient for KL divergence |
| `pre_kl` | True | Use inactive ladder loss |
| `use_kl_annealing` | False | Enable KL annealing (see [haofuml/cyclical_annealing](https://github.com/haofuml/cyclical_annealing)) |
| `kl_annealing_mode` | "linear" | KL cycle annealing mode (linear, sigmoid, cosine) |
| `cycle_period` | 4 | Number of KL annealing cycles |
| `max_kl_weight` | 1.0 | Maximum KL weight |
| `min_kl_weight` | 0.1 | Minimum KL weight |
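A minimal sketch of the linear mode, assuming the KL weight ramps from `min_kl_weight` to `max_kl_weight` over the first half of each cycle and then holds (the schedule details are an assumption; see the linked cyclical_annealing repository for the reference implementation):

```python
def linear_cyclical_kl_weight(step: int, steps_per_cycle: int,
                              min_kl_weight: float = 0.1,
                              max_kl_weight: float = 1.0,
                              ramp_fraction: float = 0.5) -> float:
    """Ramp the KL weight linearly within each cycle, then hold it at the maximum."""
    phase = (step % steps_per_cycle) / steps_per_cycle
    ramp = min(phase / ramp_fraction, 1.0)
    return min_kl_weight + (max_kl_weight - min_kl_weight) * ramp

print([round(linear_cyclical_kl_weight(s, 1000), 2) for s in (0, 250, 999)])  # [0.1, 0.55, 1.0]
```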

</br>

| Parameter | Default | Description |
| --- | --- | --- |
| `optim` | "adam" | Optimization algorithm (adam, adamw, sgd, lamb, diffgrad, madgrad) |
| `distributed` | False | Enable distributed data parallel |
| `num_workers` | 4 | Number of workers for data loader |
| `use_wandb` | False | Enable wandb tracking and logging |
| `wandb_project` | "provlae" | Wandb project name |
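When tracking is enabled, these options presumably map onto the standard wandb API roughly as follows (an illustrative sketch, not the repository's exact logging code):

```python
import wandb  # only needed when wandb tracking is enabled

# Start a run under the configured project and log a few example metrics.
run = wandb.init(project="provlae", config={"hidden_dim": 32, "num_epochs": 1})
run.log({"loss/total": 1.23, "loss/kl": 0.45}, step=100)
run.finish()
```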

</br>

- [Asynchronous GPU Copies](https://pytorch.org/tutorials/intermediate/pinmem_nonblock.html)
- [torch.compile](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html)
- [Tensor Float 32 (>= Ampere)](https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html)
- [Distributed Data Parallel](https://pytorch.org/docs/stable/notes/ddp.html)
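The sketch below shows how these switches are typically enabled in plain PyTorch on a CUDA machine (illustrative only; the repository wires them through its own configuration):

```python
import torch
import torch.nn as nn

torch.set_float32_matmul_precision("high")   # allow TF32 matmuls on Ampere or newer GPUs

model = nn.Sequential(nn.Linear(64, 64), nn.ReLU()).cuda()
model = torch.compile(model)                 # JIT-compile the forward pass

batch = torch.randn(8, 64, pin_memory=True)  # pinned host memory enables async copies
batch = batch.to("cuda", non_blocking=True)  # asynchronous host-to-device transfer
out = model(batch)
```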


- __Optimizer__: DiffGrad, Lamb, and MADGRAD are provided by [jettify/pytorch-optimizer](https://github.com/jettify/pytorch-optimizer); the remaining optimizers come from the [torch.optim](https://pytorch.org/docs/stable/optim.html) package.
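A hedged sketch of how an optimizer name might be resolved to one of these implementations (the mapping and helper name are illustrative, not the repository's exact code):

```python
import torch
import torch.nn as nn
import torch_optimizer  # pip package from jettify/pytorch-optimizer

def build_optimizer(name: str, params, lr: float = 1e-3):
    """Map an optimizer name to a torch.optim or torch_optimizer class (illustrative)."""
    choices = {
        "adam": torch.optim.Adam,
        "adamw": torch.optim.AdamW,
        "sgd": torch.optim.SGD,
        "diffgrad": torch_optimizer.DiffGrad,
        "lamb": torch_optimizer.Lamb,
        "madgrad": torch_optimizer.MADGRAD,
    }
    return choices[name](params, lr=lr)

optimizer = build_optimizer("madgrad", nn.Linear(8, 8).parameters())
```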
