Brian Karrer3 · Evangelos A. Theodorou1 · Ricky T. Q. Chen3
1Georgia Tech 2Weizmann Institute of Science 3FAIR, Meta
Generalized Schrödinger Bridge Matching (GSBM) is a new matching algorithm for learning diffusion models between two distributions with task-specific optimality structures. Examples of task-specific structures include mean-field interactions in population propagation (1st and 2nd figures), geometric priors given a LiDAR manifold (3rd figure), and latent-guided unpaired image translation (right figure).
```bash
conda env create -f environment.yml
pip install -e .
```
```bash
python train.py experiment=$EXP seed=0,1,2,3,4 -m
```
where `EXP` is one of the settings in `configs/experiment/*.yaml`. The commands for reproducing the results shown in our paper can be found in `scripts/train.sh`. By default, checkpoints and figures are saved under the `outputs` folder.
Download the official AFHQ dataset from stargan-v2, then preprocess the images with
```bash
python afhq_preprocess.py --dir $DIR_AFHQ
```
where `DIR_AFHQ` is the path to the AFHQ dataset (e.g., `../stargan-v2/data/afhq`).
Download the LiDAR data and place it in the `data` folder. All downloaded files will be stored under the `data` folder.
See `notebooks/afhq_sample.ipynb`.
We train GSBM with 4 nodes, each with 8 32GB V100 GPUs:
```bash
python train.py experiment=afhq nnodes=4 -m
```
To sample from a checkpoint `$CKPT` saved under `outputs/multiruns/afhq/$CKPT`, run
```bash
python afhq_sample.py --ckpt $CKPT --transfer $TRNSF \
    [--nfe $NFE] [--batch-size $BATCH]
```
where `TRNSF` is either `cat2dog` or `dog2cat`. By default, we set `NFE=1000` and `BATCH=512`. To optionally parallelize sampling across multiple devices, add `--partition 0_4` to partition the dataset into 4 subsets (indices 0, 1, 2, 3) and run only the first partition (index 0); similarly, `--partition 1_4` runs the second partition, and so on. The reconstructed images will be saved under the parent directory of `outputs/multiruns/afhq/$CKPT`, in folders named `samples` and `trajs`.
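The `--partition` flag follows an `INDEX_TOTAL` convention. As a rough illustration of the idea (the helpers below are hypothetical and not the repo's actual parser; the real split may differ), partitioning a dataset of `num_samples` items could look like:

```python
def parse_partition(spec: str):
    """Parse an INDEX_TOTAL string such as '0_4' into (index, total)."""
    index_str, total_str = spec.split("_")
    index, total = int(index_str), int(total_str)
    assert 0 <= index < total, "partition index must lie in [0, total)"
    return index, total

def partition_indices(num_samples: int, spec: str):
    """Return the sample indices handled by one worker.

    Samples are split into `total` contiguous chunks of (near-)equal size;
    chunk `index` is processed by this worker.
    """
    index, total = parse_partition(spec)
    chunk = (num_samples + total - 1) // total  # ceil division
    start = index * chunk
    return list(range(start, min(start + chunk, num_samples)))
```

For example, `partition_indices(10, "1_4")` selects the second chunk of a 10-sample dataset, i.e. indices 3 through 5.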
GSBM alternately solves the Conditional Stochastic Optimal Control (CondSOC) problem and the resulting marginal Matching problem. We implement GSBM in PyTorch Lightning with the following configurations:
- We solve CondSOC and Matching in the validation and training epochs, respectively. `pl.Trainer` is instantiated with `num_sanity_val_steps=-1` and `check_val_every_n_epoch=1` so that a validation epoch is executed before the initial training epoch and after each subsequent training epoch.
- The results of CondSOC are gathered in `validation_epoch_end` and stored as `train_data`, which is then used to initialize `train_dataloader`. We set `reload_dataloaders_every_n_epochs=1` to refresh `train_dataloader` with the latest CondSOC results.
- For multi-GPU training, we distribute the CondSOC optimization across devices by setting `replace_sampler_ddp=False` and then instantiating `val_dataloader` on each device with a different seed.
- The training direction (forward or backward) is alternated in `training_epoch_end`, which is called after the validation epoch.
The overall procedure follows:
```
[validate epoch (sanity)] CondSOC with random coupling
→ [training epoch #0] Matching forward drift
→ [validate epoch #0] CondSOC given forward model coupling
→ [training epoch #1] Matching backward drift
→ [validate epoch #1] CondSOC given backward model coupling
→ [training epoch #2] Matching forward drift
→ ...
```
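The alternation above can be mimicked with a toy loop (a sketch only, not the repo's actual Lightning hooks; `gsbm_schedule` is an illustrative name):

```python
def gsbm_schedule(num_train_epochs: int):
    """Enumerate the (phase, description) sequence produced by running a
    validation epoch before training (num_sanity_val_steps=-1) and after
    every training epoch (check_val_every_n_epoch=1), while flipping the
    training direction at the end of each training epoch."""
    # Sanity validation epoch: CondSOC starts from a random coupling.
    events = [("validate", "CondSOC with random coupling")]
    direction = "forward"
    for _ in range(num_train_epochs):
        # Training epoch: Matching fits the drift in the current direction.
        events.append(("train", f"Matching {direction} drift"))
        # Validation epoch: CondSOC uses the coupling of the model just trained.
        events.append(("validate", f"CondSOC given {direction} model coupling"))
        # training_epoch_end: alternate the direction for the next epoch.
        direction = "backward" if direction == "forward" else "forward"
    return events
```

Running `gsbm_schedule(3)` reproduces the sequence listed above: sanity CondSOC, then forward Matching, CondSOC on the forward coupling, backward Matching, and so on.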
If you wish to implement GSBM for your own distribution matching tasks, we recommend fine-tuning the CondSOC optimization independently, as in `notebooks/example_CondSOC.ipynb`. Once you are happy with the CondSOC results, you can seamlessly integrate it into the main GSBM algorithm.
If you find this repository helpful for your publications, please consider citing our paper:
```bibtex
@inproceedings{liu2024gsbm,
  title={{Generalized Schr{\"o}dinger bridge matching}},
  author={Liu, Guan-Horng and Lipman, Yaron and Nickel, Maximilian and Karrer, Brian and Theodorou, Evangelos A and Chen, Ricky TQ},
  booktitle={International Conference on Learning Representations},
  year={2024}
}
```
The majority of `generalized-schrodinger-bridge-matching` is licensed under CC BY-NC; however, portions of the project are adapted from other sources and are available under separate license terms: files from https://github.com/ghliu/deepgsb are licensed under the Apache 2.0 license, and files from https://github.com/openai/guided-diffusion are licensed under the MIT license.