Commit 8d92277: Updated code for the camera-ready version of the paper

White-Link committed Mar 15, 2021 (1 parent 87f82e8)

Showing 21 changed files with 959 additions and 324 deletions.

140 changes: 119 additions & 21 deletions README.md

# PDE-Driven Spatiotemporal Disentanglement

Official implementation of the paper *PDE-Driven Spatiotemporal Disentanglement* (Jérémie Donà,* Jean-Yves Franceschi,* Sylvain Lamprier, Patrick Gallinari), accepted at ICLR 2021.


## [Article](https://openreview.net/forum?id=vLaHRtHvfFp)

## [Preprint](https://arxiv.org/abs/2008.01352)


## Requirements

All models were trained with Python 3.8.1 and PyTorch 1.4.0 using CUDA 10.1.
The `requirements.txt` file lists Python package dependencies.

We obtained all our models thanks to mixed-precision training with Nvidia's [Apex](https://nvidia.github.io/apex/) (v0.1), allowing to accelerate training on the most recent Nvidia GPU architectures (starting from Volta).
This optimization can be enabled using the command-line options.
We also enabled PyTorch's integrated [mixed-precision training package](https://pytorch.org/docs/stable/amp.html) as an experimental feature, which should provide similar results.
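For instance, PyTorch's native AMP can be enabled by adding the corresponding flag to a standard training command (a sketch; the `--torch_amp` and `--apex_amp` flags are listed under [Training](#training)):
```bash
# Enable PyTorch's native mixed-precision training on top of a regular run;
# the remaining options are described in the Training section below.
python -m var_sep.main --torch_amp --device $NDEVICE --xp_dir $XP_DIR --data_dir $DATA_DIR
```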


## Datasets

Preprocessing scripts are located in the `var_sep/preprocessing` folder for the WaveEq, WaveEq-100 and Moving MNIST datasets:
- `var_sep.preprocessing.wave.gen_wave` generates the WaveEq dataset;
- `var_sep.preprocessing.wave.gen_pixels` chooses pixels to draw from the WaveEq dataset to create the WaveEq-100 dataset.

### Moving MNIST

The training dataset is generated on the fly.
The testing set can be generated as an `.npz` file in the directory `$DIR` with the following command:
```bash
python -m var_sep.preprocessing.mmnist.make_test_set --data_dir $DIR
```
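The generated archive can then be checked quickly; the file name below is hypothetical, use whatever name the script actually writes in `$DIR`:
```bash
# Hypothetical file name: adjust to the actual .npz written by make_test_set.
python -c "import numpy as np; print(np.load('$DIR/mmnist_test.npz').files)"
```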

### 3D Warehouse Chairs

The original multi-view dataset can be downloaded at [https://www.di.ens.fr/willow/research/seeing3Dchairs/data/rendered_chairs.tar](https://www.di.ens.fr/willow/research/seeing3Dchairs/data/rendered_chairs.tar).
In order to train and test our model on this dataset, it should be preprocessed to obtain 64x64 cropped images using the following command:

```bash
python -m var_sep.preprocessing.chairs.gen_chairs --data_dir $DIR
```
where `$DIR` is the directory where the dataset was downloaded and extracted.
The preprocessing script will save the processed images in the same location as the original images in the extracted archive.
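An end-to-end sketch of the download, extraction and preprocessing steps (local paths are placeholders; the archive is assumed to extract into `$DIR`):
```bash
DIR=./data/chairs  # placeholder location
mkdir -p $DIR
# Download and extract the original multi-view dataset.
wget -P $DIR https://www.di.ens.fr/willow/research/seeing3Dchairs/data/rendered_chairs.tar
tar -xf $DIR/rendered_chairs.tar -C $DIR
# Crop and resize the images to 64x64, saved next to the originals.
python -m var_sep.preprocessing.chairs.gen_chairs --data_dir $DIR
```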

### TaxiBJ

We used the preprocessed dataset provided by [MIM's authors in their official repository](https://github.com/Yunbo426/MIM).
It consists of four HDF5 files named `BJ${YEAR}_M32x32_T30_InOut.h5`, where `$YEAR` ranges from 13 to 16.
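A quick sanity check of the downloaded files with `h5py` (a sketch; the files are assumed to sit in the current directory, and only the top-level keys are printed since we make no assumption about the internal layout):
```bash
# List the top-level datasets of each HDF5 file.
python - <<'EOF'
import h5py

for year in range(13, 17):
    with h5py.File(f"BJ{year}_M32x32_T30_InOut.h5", "r") as f:
        print(year, list(f.keys()))
EOF
```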

### SST

We refer the reader to the article in which this dataset was introduced ([https://openreview.net/forum?id=By4HsfWAZ](https://openreview.net/forum?id=By4HsfWAZ)) and its authors, as we do not own the preprocessing script to date.

### WaveEq & WaveEq-100

WaveEq data are generated in the directory `$DIR` by the following command:
```bash
python -m var_sep.preprocessing.wave.gen_wave --data_dir $DIR
```
and sampled pixels are chosen by the following script:
```bash
python -m var_sep.preprocessing.wave.gen_pixels --data_dir $DIR
```
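Both steps can be chained, with pixel selection run after data generation (the directory is a placeholder):
```bash
DIR=./data/waveeq  # placeholder
python -m var_sep.preprocessing.wave.gen_wave --data_dir $DIR    # generate WaveEq
python -m var_sep.preprocessing.wave.gen_pixels --data_dir $DIR  # derive WaveEq-100
```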


## Training

In order to train a model on the GPU indexed by `$NDEVICE` with data directory and save directory respectively given by `$DATA_DIR` and `$XP_DIR`, execute the following command:
```bash
python -m var_sep.main --device $NDEVICE --xp_dir $XP_DIR --data_dir $DATA_DIR
```
Options `--apex_amp` and `--torch_amp` can be used to accelerate training (see [requirements](#requirements)).

Models presented in the paper can be obtained using the following parameters; a complete example combining them with the base command above is sketched after this list:
- for Moving MNIST:
```bash
--data mnist --epochs 800 --beta1 0.5 --scheduler
```
- for 3D Warehouse Chairs:
```bash
--data chairs --epochs 120 --gain_resnet 0.71 --code_size_t 10 --architecture resnet --decoder_architecture dcgan --lamb_ae 1 --lamb_s 1
```
- for TaxiBJ:
```bash
--data taxibj --nt_cond 4 --nt_pred 4 --lr 4e-5 --batch_size 100 --epochs 550 --scheduler --scheduler_decay 0.2 --scheduler_milestones 250 300 350 400 450 --offset 4 --gain_resnet 0.71 --architecture vgg --lamb_ae 45 --lamb_s 0.0001
```
- for SST:
```bash
--data sst --nt_cond 4 --nt_pred 6 --epochs 30 --code_size_t 64 --code_size_s 196 --gain_res 0.2 --offset 0 --gain_resnet 0.71 --architecture encoderSST --decoder_architecture decoderSST --lamb_ae 1 --lamb_s 100 --lamb_t 5e-6 --skipco --n_blocks 2
```
- for WaveEq:
```bash
--data wave --nt_cond 5 --nt_pred 20 --epochs 250 --batch_size 128 --code_size_t 32 --code_size_s 32 --gain_resnet 0.71 --offset 5 --n_blocks 3 --mixing mul --architecture mlp --enc_hidden_size 1200 --dec_hidden_size 1200 --dec_n_layers 4 --lamb_ae 1
```
- for WaveEq-100:
```bash
--data wave_partial --nt_cond 5 --nt_pred 20 --epochs 250 --batch_size 128 --code_size_t 32 --code_size_s 32 --gain_resnet 0.71 --offset 5 --n_blocks 3 --mixing mul --architecture mlp --enc_hidden_size 2400 --dec_hidden_size 150 --lamb_ae 1
```
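For example, putting the base command and the Moving MNIST parameters together (device index and directory paths are placeholders):
```bash
# Full training command for the Moving MNIST model of the paper.
python -m var_sep.main --device 0 --xp_dir ./xp/mmnist --data_dir ./data/mmnist \
    --data mnist --epochs 800 --beta1 0.5 --scheduler
```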

Please also refer to the help message of the program:
```bash
python -m var_sep.main --help
```
which lists options and hyperparameters to train our model.

## Testing

Trained models can be tested as follows.
These evaluations can be run on GPU using the `--device` option of each script.
Please also refer to the help message of each script for more information.

### Moving MNIST

Prediction performance (MSE, PSNR and SSIM) on Moving MNIST over a number `$HOR` of predicted frames is assessed using the following command:
```bash
python -m var_sep.test.mnist.test --xp_dir $XP_DIR --data_dir $DATA_DIR --nt_pred $HOR
```
For instance, the long-term prediction results in the paper correspond to setting `$HOR` to 95.

Disentanglement performance can be computed in the same way:
```bash
python -m var_sep.test.mnist.test_disentanglement --xp_dir $XP_DIR --data_dir $DATA_DIR --nt_pred $HOR
```

### 3D Warehouse Chairs

Disentanglement performance can be computed similarly to Moving MNIST using the following command:
```bash
python -m var_sep.test.chairs.test_disentanglement --xp_dir $XP_DIR --data_dir $DATA_DIR --nt_pred $HOR
```

### TaxiBJ

Prediction MSE can be computed using the following command:
```bash
python -m var_sep.test.taxibj.test --xp_dir $XP_DIR --data_dir $DATA_DIR
```

### SST

Prediction MSE can be computed using the following command:
```bash
python -m var_sep.test.sst.test --xp_dir $XP_DIR --data_dir $DATA_DIR
```

### WaveEq & WaveEq-100

Prediction MSE on both datasets can be computed using the following command:
```bash
python -m var_sep.test.wave.test --xp_dir $XP_DIR --data_dir $DATA_DIR
```
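For instance, models trained on WaveEq and WaveEq-100 can be evaluated in one go (experiment directory names are placeholders):
```bash
# Evaluate both wave models with the same script.
for xp in ./xp/wave ./xp/wave_partial; do
    python -m var_sep.test.wave.test --xp_dir $xp --data_dir $DATA_DIR
done
```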
3 changes: 3 additions & 0 deletions requirements.txt
# Requirements

h5py==2.10.0
numpy==1.18.1
netCDF4==1.5.3
pandas==1.1.4
pillow==7.0.0
pyyaml==5.3
scikit-image==0.16.2
scipy==1.4.1