diff --git a/.gitignore b/.gitignore
index 58ea5eb..751869b 100644
--- a/.gitignore
+++ b/.gitignore
@@ -171,3 +171,5 @@ slurm*.out
*.jpg
notebooks/figures/
+
+.DS_Store
diff --git a/README.md b/README.md
index 849ba5c..b1bbbf7 100644
--- a/README.md
+++ b/README.md
@@ -1,6 +1,6 @@
-# Conditional Flow Matching
+# TorchCFM: a Conditional Flow Matching library
@@ -23,7 +23,30 @@
## Description
-Conditional Flow Matching (CFM) is a fast way to train continuous normalizing flow (CNF) models. CFM is a simulation-free training objective for continuous normalizing flows that allows conditional generative modeling and speeds up training and inference.
+Conditional Flow Matching (CFM) is a fast way to train continuous normalizing flow (CNF) models. CFM is a simulation-free training objective for continuous normalizing flows that allows conditional generative modeling and speeds up training and inference. CFM's performance closes the gap between CNFs and diffusion models. To spread its use within the machine learning community, we have built a library focused on Flow Matching methods: TorchCFM. TorchCFM is a library showing how Flow Matching methods can be trained and use to deal with image generation, single-cell dynamics and (soon) SO(3) data and tabular data.
+
+
+
+
+
+
+The density, vector field, and trajectories of simulation-free CNF training schemes: mapping 8 Gaussians to two moons (above) and a single Gaussian to two moons (below). Action matching with the same architecture (3x64 MLP with SeLU activations) underfits with the ReLU, SiLU, and SiLU activations as suggested in the [example code](https://github.com/necludov/jam), but it seems to fit better under our training setup (Action-Matching (Swish)).
+
+The models to produce the GIFs are stored in `examples/models` and can be visualized with this notebook: [![notebook](https://img.shields.io/static/v1?label=Run%20in&message=Google%20Colab&color=orange&logo=Google%20Cloud)](https://colab.research.google.com/github/atong01/conditional-flow-matching/blob/master/examples/notebooks/model-comparison-plotting.ipynb).
+
+We also have included an example of unconditional MNIST generation in `examples/notebooks/mnist_example.ipynb` for both deterministic and stochastic generation. [![notebook](https://img.shields.io/static/v1?label=Run%20in&message=Google%20Colab&color=orange&logo=Google%20Cloud)](https://colab.research.google.com/github/atong01/conditional-flow-matching/blob/master/examples/notebooks/mnist_example.ipynb).
+
+## The torchcfm Package
+
+In our version 1 update we have extracted implementations of the relevant flow matching variants into a package `torchcfm`. This allows abstraction of the choice of the conditional distribution `q(z)`. `torchcfm` supplies the following loss functions:
+
+- `ConditionalFlowMatcher`: $z = (x_0, x_1)$, $q(z) = q(x_0) q(x_1)$
+- `ExactOptimalTransportConditionalFlowMatcher`: $z = (x_0, x_1)$, $q(z) = \\pi(x_0, x_1)$ where $\\pi$ is an exact optimal transport joint. This is used in \[Tong et al. 2023a\] and \[Poolidan et al. 2023\] as "OT-CFM" and "Multisample FM with Batch OT" respectively.
+- `TargetConditionalFlowMatcher`: $z = x_1$, $q(z) = q(x_1)$ as defined in Lipman et al. 2023, learns a flow from a standard normal Gaussian to data using conditional flows which optimally transport the Gaussian to the datapoint (Note that this does not result in the marginal flow being optimal transport).
+- `SchrodingerBridgeConditionalFlowMatcher`: $z = (x_0, x_1)$, $q(z) = \\pi\_\\epsilon(x_0, x_1)$ where $\\pi\_\\epsilon$ is a an entropically regularized OT plan, although in practice this is often approximated by a minibatch OT plan (See Tong et al. 2023b). The flow-matching variant of this where the marginals are equivalent to the Schrodinger Bridge marginals is known as `SB-CFM` \[Tong et al. 2023a\]. When the score is also known and the bridge is stochastic is called \[SF\]2M \[Tong et al. 2023b\]
+- `VariancePreservingConditionalFlowMatcher`: $z = (x_0, x_1)$ $q(z) = q(x_0) q(x_1)$ but with conditional Gaussian probability paths which preserve variance over time using a trigonometric interpolation as presented in \[Albergo et al. 2023a\].
+
+## How to cite
This repository contains the code to reproduce the main experiments and illustrations of two preprints:
@@ -64,63 +87,40 @@ A. Tong, N. Malkin, K. Fatras, L. Atanackovic, Y. Zhang, G. Huguet, G. Wolf, Y.
-## Examples
-
-![My Image](assets/8gaussians-to-moons.gif)
-
-![My Image](assets/gaussian-to-moons.gif)
-
-The density, vector field, and trajectories of simulation-free CNF training schemes: mapping 8 Gaussians to two moons (above) and a single Gaussian to two moons (below).
-
-The first two methods, variance-preserving SDE (VP-SDE) and flow matching (FM), require a Gaussian source distribution so do not appear in the above example mapping 8 Gaussians distribution to the two moons distribution. Action matching with the same architecture (3x64 MLP with SeLU activations) underfits with the ReLU, SiLU, and SiLU activations as suggested in the [example code](https://github.com/necludov/jam), but it seems to fit better under our training setup (Action-Matching (Swish)).
-
-The models to produce the GIFs are stored in `examples/models` and can be visualized with this notebook: [![notebook](https://img.shields.io/static/v1?label=Run%20in&message=Google%20Colab&color=orange&logo=Google%20Cloud)](https://colab.research.google.com/github/atong01/conditional-flow-matching/blob/master/examples/notebooks/model-comparison-plotting.ipynb).
-
-We also have included an example of unconditional MNIST generation in `examples/notebooks/mnist_example.ipynb` for both deterministic and stochastic generation. [![notebook](https://img.shields.io/static/v1?label=Run%20in&message=Google%20Colab&color=orange&logo=Google%20Cloud)](https://colab.research.google.com/github/atong01/conditional-flow-matching/blob/master/examples/notebooks/mnist_example.ipynb).
-
-## The torchcfm Package
-
-In our version 1 update we have extracted implementations of the relevant flow matching variants into a package `torchcfm`. This allows abstraction of the choice of the conditional distribution `q(z)`. `torchcfm` supplies the following loss functions:
-
-- `ConditionalFlowMatcher`: $z = (x_0, x_1)$, $q(z) = q(x_0) q(x_1)
-- `ExactOptimalTransportConditionalFlowMatcher`: $z = (x_0, x_1)$, $q(z) = \\pi(x_0, x_1)$ where $\\pi$ is an exact optimal transport joint. This is used in \[Tong et al. 2023a\] and \[Poolidan et al. 2023\] as "OT-CFM" and "Multisample FM with Batch OT" respectively.
-- `TargetConditionalFlowMatcher`: $z = x_1$, $q(z) = q(x_1)$ as defined in Lipman et al. 2023, learns a flow from a standard normal Gaussian to data using conditional flows which optimally transport the Gaussian to the datapoint (Note that this does not result in the marginal flow being optimal transport.
-- `SchrodingerBridgeConditionalFlowMatcher`: $z = (x_0, x_1)$, $q(z) = \\pi\_\\epsilon(x_0, x_1)$ where $\\pi\_\\epsilon$ is a an entropically regularized OT plan, although in practice this is often approximated by a minibatch OT plan (See Tong et al. 2023b). The flow-matching variant of this where the marginals are equivalent to the Schrodinger Bridge marginals is known as `SB-CFM` \[Tong et al. 2023a\]. When the score is also known and the bridge is stochastic is called \[SF\]2M \[Tong et al. 2023b\]
-- `VariancePreservingConditionalFlowMatcher`: $z = (x_0, x_1) $q(z) = q(x_0) q(x_1)$ but with conditional Gaussian probability paths which preserve variance over time using a trigonometric interpolation as presented in \[Albergo et al. 2023a\].
-
## V0 -> V1
-In abstracting out the relevant losses from our `pytorch-lightning` implementation we have moved all of the `pytorch-lightning` code to `runner`. Every command that worked before for lightning should now be run from within the `runner` directory. V0 is now frozen in the `V0` branch.
-
Major Changes:
+- __Added cifar10 examples with an FID of 3.5__
- Added code for the new Simulation-free Score and Flow Matching (SF)2M preprint
- Created `torchcfm` pip installable package
- Moved `pytorch-lightning` implementation and experiments to `runner` directory
- Moved `notebooks` -> `examples`
- Added image generation implementation in both lightning and a notebook in `examples`
-## Related Work
+## Implemented papers
-Relevant papers on simulation-free training of flow models:
+List of implemented papers:
- Flow Matching for Generative Modeling (Lipman et al. 2023) [Paper](https://openreview.net/forum?id=PqvMRDCJT9t)
- Flow Straight and Fast: Learning to Generate and Transfer Data with Rectified Flow (Liu et al. 2023) [Paper](https://openreview.net/forum?id=XVjTT1nw5z) [Code](https://github.com/gnobitab/RectifiedFlow.git)
- Building Normalizing Flows with Stochastic Interpolants (Albergo et al. 2023a) [Paper](https://openreview.net/forum?id=li7qeBbCR1t)
-- Stochastic Interpolants: A Unifying Framework for Flows and Diffusions (Albergo et al. 2023b) [Paper](https://arxiv.org/abs/2303.08797)
- Action Matching: Learning Stochastic Dynamics From Samples (Neklyudov et al. 2022) [Paper](https://arxiv.org/abs/2210.06662) [Code](https://github.com/necludov/jam)
-- Riemannian Flow Matching on General Geometries (Chen et al. 2023) [Paper](https://arxiv.org/abs/2302.03660)
-- Multisample Flow Matching: Straightening Flows with Minibatch Couplings (Pooladian et al. 2023) [Paper](https://arxiv.org/abs/2304.14772)
-
-## Code Contributions
-
-This repo is extracted from a larger private codebase which loses the original commit history which contains work from other authors on the paper.
+- Concurrent work to our OT-CFM method: Multisample Flow Matching: Straightening Flows with Minibatch Couplings (Pooladian et al. 2023) [Paper](https://arxiv.org/abs/2304.14772)
+- Soon: SE(3)-Stochastic Flow Matching for Protein Backbone Generation (Bose et al.) [paper](https://arxiv.org/abs/2310.02391)
+- Soon: Generating and Imputing Tabular Data via Diffusion and Flow-based Gradient-Boosted Trees (Joliecoeur et al.) [paper](https://arxiv.org/abs/2309.09968)
## How to run
Run a simple minimal example here [![Run in Google Colab](https://img.shields.io/static/v1?label=Run%20in&message=Google%20Colab&color=orange&logo=Google%20Cloud)](https://colab.research.google.com/github/atong01/conditional-flow-matching/blob/master/examples/notebooks/training-8gaussians-to-moons.ipynb). Or install the more efficient code locally with these steps.
-Install dependencies
+TorchCFM is now on [pypi](https://pypi.org/project/torchcfm/)! You can install it with:
+
+```bash
+pip install torchcfm
+```
+
+To use the full library with the the different examples, you can install dependencies:
```bash
# clone project
@@ -128,8 +128,8 @@ git clone https://github.com/atong01/conditional-flow-matching.git
cd conditional-flow-matching
# [OPTIONAL] create conda environment
-conda create -n myenv python=3.10
-conda activate myenv
+conda create -n torchcfm python=3.10
+conda activate torchcfm
# install pytorch according to instructions
# https://pytorch.org/get-started/
@@ -138,34 +138,18 @@ conda activate myenv
pip install -r requirements.txt
```
-Note that `torchdyn==1.0.4` is broken on pypi. It may be necessary to install `torchdyn==1.0.3` until this is updated.
-
-Train model with default configuration
+To run our jupyter notebooks, use the following commands after installing our package.
```bash
-cd runner
+# install ipykernel
+conda install -c anaconda ipykernel
-# train on CPU
-python src/train.py trainer=cpu
+# install conda env in jupyter notebook
+python -m ipykernel install --user --name=torchcfm
-# train on GPU
-python src/train.py trainer=gpu
+# launch our notebooks with the torchcfm kernel
```
-Train model with chosen experiment configuration from [configs/experiment/](configs/experiment/)
-
-```bash
-python src/train.py experiment=experiment_name
-```
-
-You can override any parameter from command line like this
-
-```bash
-python src/train.py trainer.max_epochs=20 datamodule.batch_size=64
-```
-
-You can also train a large set of models in parallel with SLURM as shown in `scripts/two-dim-cfm.sh` which trains the models used in the first 3 lines of Table 2.
-
## Project Structure
The directory structure of new project looks like this:
@@ -173,44 +157,13 @@ The directory structure of new project looks like this:
```
│
-├── data <- Project data
-│
-├── logs <- Logs generated by hydra and lightning loggers
-│
-├── examples <- Jupyter notebooks. Naming convention is a number (for ordering),
-│ the creator's initials, and a short `-` delimited description,
-│ e.g. `1.0-jqp-initial-data-exploration.ipynb`.
+├── examples <- Jupyter notebooks
+| ├── cifar10 <- Cifar10 experiments
+│ ├── notebooks <- Diverse examples with notebooks
│
-│── runner <- Shell scripts
-| ├── scripts <- Shell scripts
-│ ├── configs <- Hydra configuration files
-│ │ ├── callbacks <- Callbacks configs
-│ │ ├── debug <- Debugging configs
-│ │ ├── datamodule <- Datamodule configs
-│ │ ├── experiment <- Experiment configs
-│ │ ├── extras <- Extra utilities configs
-│ │ ├── hparams_search <- Hyperparameter search configs
-│ │ ├── hydra <- Hydra configs
-│ │ ├── launcher <- Hydra launcher configs
-│ │ ├── local <- Local configs
-│ │ ├── logger <- Logger configs
-│ │ ├── model <- Model configs
-│ │ ├── paths <- Project paths configs
-│ │ ├── trainer <- Trainer configs
-│ │ │
-│ │ ├── eval.yaml <- Main config for evaluation
-│ │ └── train.yaml <- Main config for training
-│ ├── src <- Source code
-│ │ ├── datamodules <- Lightning datamodules
-│ │ ├── models <- Lightning models
-│ │ ├── utils <- Utility scripts
-│ │ │
-│ │ ├── eval.py <- Run evaluation
-│ │ └── train.py <- Run training
-│ │
-│ ├── tests <- Tests of any kind
+│── runner <- Everything related to the original version (V0) of the library
│
-|── torchcfm <- Shell scripts
+|── torchcfm <- Code base of our Flow Matching methods
| ├── conditional_flow_matching.py <- CFM classes
│ ├── models <- Model architectures
│ │ ├── models <- Models for 2D examples
@@ -224,332 +177,14 @@ The directory structure of new project looks like this:
└── README.md
```
-## ⚡ Your Superpowers
-
-
-Override any config parameter from command line
-
-```bash
-python train.py trainer.max_epochs=20 model.optimizer.lr=1e-4
-```
-
-> **Note**: You can also add new parameters with `+` sign.
-
-```bash
-python train.py +model.new_param="owo"
-```
-
-
-
-
-Train on CPU, GPU, multi-GPU and TPU
-
-```bash
-# train on CPU
-python train.py trainer=cpu
-
-# train on 1 GPU
-python train.py trainer=gpu
-
-# train on TPU
-python train.py +trainer.tpu_cores=8
-
-# train with DDP (Distributed Data Parallel) (4 GPUs)
-python train.py trainer=ddp trainer.devices=4
-
-# train with DDP (Distributed Data Parallel) (8 GPUs, 2 nodes)
-python train.py trainer=ddp trainer.devices=4 trainer.num_nodes=2
-
-# simulate DDP on CPU processes
-python train.py trainer=ddp_sim trainer.devices=2
-
-# accelerate training on mac
-python train.py trainer=mps
-```
-
-> **Warning**: Currently there are problems with DDP mode, read [this issue](https://github.com/ashleve/lightning-hydra-template/issues/393) to learn more.
-
-
-
-
-Train with mixed precision
-
-```bash
-# train with pytorch native automatic mixed precision (AMP)
-python train.py trainer=gpu +trainer.precision=16
-```
-
-
-
-
-
-
-Train model with any logger available in PyTorch Lightning, like W&B or Tensorboard
-
-```yaml
-# set project and entity names in `configs/logger/wandb`
-wandb:
- project: "your_project_name"
- entity: "your_wandb_team_name"
-```
-
-```bash
-# train model with Weights&Biases (link to wandb dashboard should appear in the terminal)
-python train.py logger=wandb
-```
-
-> **Note**: Lightning provides convenient integrations with most popular logging frameworks. Learn more [here](#experiment-tracking).
-
-> **Note**: Using wandb requires you to [setup account](https://www.wandb.com/) first. After that just complete the config as below.
-
-> **Note**: Click [here](https://wandb.ai/hobglob/template-dashboard/) to see example wandb dashboard generated with this template.
-
-
-
-
-Train model with chosen experiment config
-
-```bash
-python train.py experiment=example
-```
-
-> **Note**: Experiment configs are placed in [configs/experiment/](configs/experiment/).
-
-
-
-
-Attach some callbacks to run
-
-```bash
-python train.py callbacks=default
-```
-
-> **Note**: Callbacks can be used for things such as as model checkpointing, early stopping and [many more](https://pytorch-lightning.readthedocs.io/en/latest/extensions/callbacks.html#built-in-callbacks).
-
-> **Note**: Callbacks configs are placed in [configs/callbacks/](configs/callbacks/).
-
-
-
-
-Use different tricks available in Pytorch Lightning
-
-```yaml
-# gradient clipping may be enabled to avoid exploding gradients
-python train.py +trainer.gradient_clip_val=0.5
-
-# run validation loop 4 times during a training epoch
-python train.py +trainer.val_check_interval=0.25
-
-# accumulate gradients
-python train.py +trainer.accumulate_grad_batches=10
-
-# terminate training after 12 hours
-python train.py +trainer.max_time="00:12:00:00"
-```
-
-> **Note**: PyTorch Lightning provides about [40+ useful trainer flags](https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html#trainer-flags).
-
-
-
-
-Easily debug
-
-```bash
-# runs 1 epoch in default debugging mode
-# changes logging directory to `logs/debugs/...`
-# sets level of all command line loggers to 'DEBUG'
-# enforces debug-friendly configuration
-python train.py debug=default
-
-# run 1 train, val and test loop, using only 1 batch
-python train.py debug=fdr
-
-# print execution time profiling
-python train.py debug=profiler
-
-# try overfitting to 1 batch
-python train.py debug=overfit
-
-# raise exception if there are any numerical anomalies in tensors, like NaN or +/-inf
-python train.py +trainer.detect_anomaly=true
-
-# log second gradient norm of the model
-python train.py +trainer.track_grad_norm=2
-
-# use only 20% of the data
-python train.py +trainer.limit_train_batches=0.2 \
-+trainer.limit_val_batches=0.2 +trainer.limit_test_batches=0.2
-```
-
-> **Note**: Visit [configs/debug/](configs/debug/) for different debugging configs.
-
-
-
-
-Resume training from checkpoint
-
-```yaml
-python train.py ckpt_path="/path/to/ckpt/name.ckpt"
-```
-
-> **Note**: Checkpoint can be either path or URL.
-
-> **Note**: Currently loading ckpt doesn't resume logger experiment, but it will be supported in future Lightning release.
-
-
-
-
-Evaluate checkpoint on test dataset
-
-```yaml
-python eval.py ckpt_path="/path/to/ckpt/name.ckpt"
-```
-
-> **Note**: Checkpoint can be either path or URL.
-
-
-
-
-Create a sweep over hyperparameters
-
-```bash
-# this will run 6 experiments one after the other,
-# each with different combination of batch_size and learning rate
-python train.py -m datamodule.batch_size=32,64,128 model.lr=0.001,0.0005
-```
-
-> **Note**: Hydra composes configs lazily at job launch time. If you change code or configs after launching a job/sweep, the final composed configs might be impacted.
-
-
-
-
-Create a sweep over hyperparameters with Optuna
-
-```bash
-# this will run hyperparameter search defined in `configs/hparams_search/mnist_optuna.yaml`
-# over chosen experiment config
-python train.py -m hparams_search=mnist_optuna experiment=example
-```
-
-> **Note**: Using [Optuna Sweeper](https://hydra.cc/docs/next/plugins/optuna_sweeper) doesn't require you to add any boilerplate to your code, everything is defined in a [single config file](configs/hparams_search/mnist_optuna.yaml).
-
-> **Warning**: Optuna sweeps are not failure-resistant (if one job crashes then the whole sweep crashes).
-
-
-
-
-Execute all experiments from folder
-
-```bash
-python train.py -m 'experiment=glob(*)'
-```
-
-> **Note**: Hydra provides special syntax for controlling behavior of multiruns. Learn more [here](https://hydra.cc/docs/next/tutorials/basic/running_your_app/multi-run). The command above executes all experiments from [configs/experiment/](configs/experiment/).
-
-
-
-
-Execute run for multiple different seeds
-
-```bash
-python train.py -m seed=1,2,3,4,5 trainer.deterministic=True logger=csv tags=["benchmark"]
-```
-
-> **Note**: `trainer.deterministic=True` makes pytorch more deterministic but impacts the performance.
-
-
-
-
-Execute sweep on a remote AWS cluster
-
-> **Note**: This should be achievable with simple config using [Ray AWS launcher for Hydra](https://hydra.cc/docs/next/plugins/ray_launcher). Example is not implemented in this template.
-
-
-
-
-
-
-Use Hydra tab completion
-
-> **Note**: Hydra allows you to autocomplete config argument overrides in shell as you write them, by pressing `tab` key. Read the [docs](https://hydra.cc/docs/tutorials/basic/running_your_app/tab_completion).
-
-
-
-
-Apply pre-commit hooks
-
-```bash
-pre-commit run -a
-```
-
-> **Note**: Apply pre-commit hooks to do things like auto-formatting code and configs, performing code analysis or removing output from jupyter notebooks. See [# Best Practices](#best-practices) for more.
-
-
-
-
-Run tests
-
-```bash
-# run all tests
-pytest
-
-# run tests from specific file
-pytest tests/test_train.py
-
-# run all tests except the ones marked as slow
-pytest -k "not slow"
-```
-
-
-
-
-Use tags
-
-Each experiment should be tagged in order to easily filter them across files or in logger UI:
-
-```bash
-python train.py tags=["mnist","experiment_X"]
-```
-
-If no tags are provided, you will be asked to input them from command line:
-
-```bash
->>> python train.py tags=[]
-[2022-07-11 15:40:09,358][src.utils.utils][INFO] - Enforcing tags!
-[2022-07-11 15:40:09,359][src.utils.rich_utils][WARNING] - No tags provided in config. Prompting user to input tags...
-Enter a list of comma separated tags (dev):
-```
-
-If no tags are provided for multirun, an error will be raised:
-
-```bash
->>> python train.py -m +x=1,2,3 tags=[]
-ValueError: Specify tags before launching a multirun!
-```
-
-> **Note**: Appending lists from command line is currently not supported in hydra :(
-
-
+## ❤️ Code Contributions
-
+This toolbox has been created and is maintained by
-## ❤️ Contributions
+- [Alexander Tong](http://alextong.net)
+- [Kilian Fatras](http://kilianfatras.github.io)
-Have a question? Found a bug? Missing a specific feature? Feel free to file a new issue, discussion or PR with respective title and description.
+It was initiated from a larger private codebase which loses the original commit history which contains work from other authors of the papers.
Before making an issue, please verify that:
diff --git a/assets/169_generated_samples_otcfm.gif b/assets/169_generated_samples_otcfm.gif
new file mode 100644
index 0000000..eea1ee0
Binary files /dev/null and b/assets/169_generated_samples_otcfm.gif differ
diff --git a/assets/169_generated_samples_otcfm.png b/assets/169_generated_samples_otcfm.png
new file mode 100644
index 0000000..d33ec09
Binary files /dev/null and b/assets/169_generated_samples_otcfm.png differ
diff --git a/examples/cifar10/README.md b/examples/cifar10/README.md
index d095336..70319f8 100644
--- a/examples/cifar10/README.md
+++ b/examples/cifar10/README.md
@@ -1,13 +1,47 @@
# CIFAR-10 experiments using TorchCFM
-This repository is used to reproduce the CIFAR-10 experiments from [1](https://arxiv.org/abs/2302.00482). It is a repository in construction and we will add more features and details in the future (including FID computations and pre-trained weights). We have followed the experimental details provided in [2](https://openreview.net/forum?id=PqvMRDCJT9t).
+This repository is used to reproduce the CIFAR-10 experiments from [1](https://arxiv.org/abs/2302.00482). We have designed a novel experimental procedure that helps us to reach an __FID of 3.5__ on the Cifar10 dataset.
-To reproduce the experiment and save the weights, install the requirements from the main repository and then run:
+
+
+
+
+To reproduce the experiments and save the weights, install the requirements from the main repository and then run (runs on a single RTX 2080 GPU):
+
+- For the OT-Conditional Flow Matching method:
+
+```bash
+python3 train_cifar10.py --model "otcfm" --lr 2e-4 --ema_decay 0.9999 --batch_size 128 --total_steps 400001 --save_step 20000
+```
+
+- For the Conditional Flow Matching method:
+
+```bash
+python3 train_cifar10.py --model "cfm" --lr 2e-4 --ema_decay 0.9999 --batch_size 128 --total_steps 400001 --save_step 20000
+```
+
+- For the original Flow Matching method:
```bash
-python3 train_cifar10.py
+python3 train_cifar10.py --model "fm" --lr 2e-4 --ema_decay 0.9999 --batch_size 128 --total_steps 400001 --save_step 20000
```
+To compute the FID from the OT-CFM model at end of training, run:
+
+```bash
+python3 compute_fid.py --model "otcfm" --step 400000 --integration_method dopri5
+```
+
+For the other models, change the "otcfm" argument by "cfm" or "fm". For easy reproducibility of our results, you can download the model weights at 400000 iterations here:
+
+- [cfm weights](https://github.com/atong01/conditional-flow-matching/releases/download/1.0.4/cfm_cifar10_weights_step_400000.pt)
+
+- [otcfm weights](https://github.com/atong01/conditional-flow-matching/releases/download/1.0.4/otcfm_cifar10_weights_step_400000.pt)
+
+- [fm weights](https://github.com/atong01/conditional-flow-matching/releases/download/1.0.4/fm_cifar10_weights_step_400000.pt)
+
+To recompute the FID, change the PATH variable with where you have saved the downloaded weights.
+
If you find this code useful in your research, please cite the following papers (expand for BibTeX):
diff --git a/examples/cifar10/compute_fid.py b/examples/cifar10/compute_fid.py
new file mode 100644
index 0000000..8d23bf9
--- /dev/null
+++ b/examples/cifar10/compute_fid.py
@@ -0,0 +1,105 @@
+# Inspired from https://github.com/w86763777/pytorch-ddpm/tree/master.
+
+# Authors: Kilian Fatras
+# Alexander Tong
+
+import os
+import sys
+
+import matplotlib.pyplot as plt
+import torch
+from absl import app, flags
+from cleanfid import fid
+from torchdiffeq import odeint
+from torchdyn.core import NeuralODE
+
+from torchcfm.models.unet.unet import UNetModelWrapper
+
+FLAGS = flags.FLAGS
+# UNet
+flags.DEFINE_integer("num_channel", 128, help="base channel of UNet")
+
+# Training
+flags.DEFINE_bool("parallel", False, help="multi gpu training")
+flags.DEFINE_string("input_dir", "./results", help="output_directory")
+flags.DEFINE_string("model", "otcfm", help="flow matching model type")
+flags.DEFINE_integer("integration_steps", 100, help="number of inference steps")
+flags.DEFINE_string("integration_method", "dopri5", help="integration method to use")
+flags.DEFINE_integer("step", 400000, help="training steps")
+flags.DEFINE_integer("num_gen", 50000, help="number of samples to generate")
+flags.DEFINE_float("tol", 1e-5, help="Integrator tolerance (absolute and relative)")
+FLAGS(sys.argv)
+
+
+# Define the model
+use_cuda = torch.cuda.is_available()
+device = torch.device("cuda:0" if use_cuda else "cpu")
+
+new_net = UNetModelWrapper(
+ dim=(3, 32, 32),
+ num_res_blocks=2,
+ num_channels=FLAGS.num_channel,
+ channel_mult=[1, 2, 2, 2],
+ num_heads=4,
+ num_head_channels=64,
+ attention_resolutions="16",
+ dropout=0.1,
+).to(device)
+
+
+# Load the model
+PATH = f"{FLAGS.input_dir}/{FLAGS.model}/cifar10_weights_step_{FLAGS.step}.pt"
+print("path: ", PATH)
+checkpoint = torch.load(PATH)
+state_dict = checkpoint["ema_model"]
+try:
+ new_net.load_state_dict(state_dict)
+except RuntimeError:
+ from collections import OrderedDict
+
+ new_state_dict = OrderedDict()
+ for k, v in state_dict.items():
+ new_state_dict[k[7:]] = v
+ new_net.load_state_dict(new_state_dict)
+new_net.eval()
+
+
+# Define the integration method if euler is used
+if FLAGS.integration_method == "euler":
+ node = NeuralODE(new_net, solver=FLAGS.integration_method)
+
+
+def gen_1_img(unused_latent):
+ with torch.no_grad():
+ x = torch.randn(500, 3, 32, 32).to(device)
+ if FLAGS.integration_method == "euler":
+ print("Use method: ", FLAGS.integration_method)
+ t_span = torch.linspace(0, 1, FLAGS.integration_steps + 1).to(device)
+ traj = node.trajectory(x, t_span=t_span)
+ else:
+ print("Use method: ", FLAGS.integration_method)
+ t_span = torch.linspace(0, 1, 2).to(device)
+ traj = odeint(
+ new_net, x, t_span, rtol=FLAGS.tol, atol=FLAGS.tol, method=FLAGS.integration_method
+ )
+ traj = traj[-1, :] # .view([-1, 3, 32, 32]).clip(-1, 1)
+ img = (traj * 127.5 + 128).clip(0, 255).to(torch.uint8) # .permute(1, 2, 0)
+ return img
+
+
+print("Start computing FID")
+score = fid.compute_fid(
+ gen=gen_1_img,
+ dataset_name="cifar10",
+ batch_size=500,
+ dataset_res=32,
+ num_gen=FLAGS.num_gen,
+ dataset_split="train",
+ mode="legacy_tensorflow",
+)
+print()
+print("FID has been computed")
+# print()
+# print("Total NFE: ", new_net.nfe)
+print()
+print("FID: ", score)
diff --git a/examples/cifar10/train_cifar10.py b/examples/cifar10/train_cifar10.py
index 742aefc..9a0cead 100644
--- a/examples/cifar10/train_cifar10.py
+++ b/examples/cifar10/train_cifar10.py
@@ -1,112 +1,174 @@
+# Inspired from https://github.com/w86763777/pytorch-ddpm/tree/master.
+
+# Authors: Kilian Fatras
+# Alexander Tong
+
+import copy
import os
-import matplotlib.pyplot as plt
-import numpy as np
import torch
-from timm import scheduler
+from absl import app, flags
from torchdyn.core import NeuralODE
from torchvision import datasets, transforms
-from torchvision.utils import make_grid
-from tqdm import tqdm
+from tqdm import trange
+from utils_cifar import ema, generate_samples, infiniteloop
from torchcfm.conditional_flow_matching import (
+ ConditionalFlowMatcher,
ExactOptimalTransportConditionalFlowMatcher,
+ TargetConditionalFlowMatcher,
)
from torchcfm.models.unet.unet import UNetModelWrapper
-savedir = "weights/reproduced/"
-os.makedirs(savedir, exist_ok=True)
+FLAGS = flags.FLAGS
+
+flags.DEFINE_string("model", "otcfm", help="flow matching model type")
+flags.DEFINE_string("output_dir", "./results/", help="output_directory")
+# UNet
+flags.DEFINE_integer("num_channel", 128, help="base channel of UNet")
+
+# Training
+flags.DEFINE_float("lr", 2e-4, help="target learning rate") # TRY 2e-4
+flags.DEFINE_float("grad_clip", 1.0, help="gradient norm clipping")
+flags.DEFINE_integer(
+ "total_steps", 400001, help="total training steps"
+) # Lipman et al uses 400k but double batch size
+flags.DEFINE_integer("img_size", 32, help="image size")
+flags.DEFINE_integer("warmup", 5000, help="learning rate warmup")
+flags.DEFINE_integer("batch_size", 128, help="batch size") # Lipman et al uses 128
+flags.DEFINE_integer("num_workers", 4, help="workers of Dataloader")
+flags.DEFINE_float("ema_decay", 0.9999, help="ema decay rate")
+flags.DEFINE_bool("parallel", False, help="multi gpu training")
+
+# Evaluation
+flags.DEFINE_integer(
+ "save_step",
+ 20000,
+ help="frequency of saving checkpoints, 0 to disable during training",
+)
+flags.DEFINE_integer(
+ "eval_step", 0, help="frequency of evaluating model, 0 to disable during training"
+)
+flags.DEFINE_integer("num_images", 50000, help="the number of generated images for evaluation")
+
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
-batch_size = 256
-n_epochs = 1000
-
-transform = transforms.Compose(
- [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
-)
-
-trainset = datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)
-trainloader = torch.utils.data.DataLoader(
- trainset, batch_size=batch_size, shuffle=True, num_workers=1
-)
-num_iter_per_epoch = int(50000 / batch_size)
-
-#################################
-# OT-CFM
-#################################
-
-sigma = 0.0
-model = UNetModelWrapper(
- dim=(3, 32, 32),
- num_res_blocks=2,
- num_channels=256,
- channel_mult=[1, 2, 2, 2],
- num_heads=4,
- num_head_channels=64,
- attention_resolutions="16",
- dropout=0,
-).to(device)
-
-if torch.cuda.device_count() > 1:
- print("Let's use", torch.cuda.device_count(), "GPUs!")
- # dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] on 3 GPUs
- model = torch.nn.DataParallel(model).cuda()
-
-
-optimizer = torch.optim.Adam(model.parameters(), lr=0.0005)
-opt_scheduler = scheduler.PolyLRScheduler(
- warmup_t=45000, warmup_lr_init=1e-8, t_initial=196 * n_epochs, optimizer=optimizer
-)
-# FM = ConditionalFlowMatcher(sigma=sigma)
-FM = ExactOptimalTransportConditionalFlowMatcher(sigma=sigma)
-
-for epoch in tqdm(range(n_epochs)):
- for i, data in enumerate(trainloader):
- optimizer.zero_grad()
- x1 = data[0].to(device)
- x0 = torch.randn_like(x1)
- t, xt, ut = FM.sample_location_and_conditional_flow(x0, x1)
- vt = model(t, xt)
- loss = torch.mean((vt - ut) ** 2)
- loss.backward()
- optimizer.step()
- opt_scheduler.step(epoch * (num_iter_per_epoch + 1) + i)
-
- # Saving the weights
- if (epoch + 1) % 100 == 0:
- print(i)
- torch.save(
- {
- "epoch": epoch,
- "model_state_dict": model.state_dict(),
- "optimizer_state_dict": optimizer.state_dict(),
- "loss": loss,
- },
- savedir + f"reproduced_cifar10_weights_epoch_{epoch}.pt",
- )
+def warmup_lr(step):
+ return min(step, FLAGS.warmup) / FLAGS.warmup
-if torch.cuda.device_count() > 1:
- print("Send the model over 1 GPU for inference")
- model = model.module.to(device)
-node = NeuralODE(model, solver="euler", sensitivity="adjoint", atol=1e-4, rtol=1e-4)
+def train(argv):
+ print(
+ "lr, total_steps, ema decay, save_step:",
+ FLAGS.lr,
+ FLAGS.total_steps,
+ FLAGS.ema_decay,
+ FLAGS.save_step,
+ )
-with torch.no_grad():
- traj = node.trajectory(
- torch.randn(60, 3, 32, 32).to(device),
- t_span=torch.linspace(0, 1, 100).to(device),
+ # DATASETS/DATALOADER
+ dataset = datasets.CIFAR10(
+ root="./data",
+ train=True,
+ download=True,
+ transform=transforms.Compose(
+ [
+ transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),
+ transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
+ ]
+ ),
)
-grid = make_grid(
- traj[-1, :].view([-1, 3, 32, 32]).clip(-1, 1),
- value_range=(-1, 1),
- padding=0,
- nrow=10,
-)
+ dataloader = torch.utils.data.DataLoader(
+ dataset,
+ batch_size=FLAGS.batch_size,
+ shuffle=True,
+ num_workers=FLAGS.num_workers,
+ drop_last=True,
+ )
+
+ datalooper = infiniteloop(dataloader)
+
+ # MODELS
+ net_model = UNetModelWrapper(
+ dim=(3, 32, 32),
+ num_res_blocks=2,
+ num_channels=FLAGS.num_channel,
+ channel_mult=[1, 2, 2, 2],
+ num_heads=4,
+ num_head_channels=64,
+ attention_resolutions="16",
+ dropout=0.1,
+ ).to(
+ device
+ ) # new dropout + bs of 128
+
+ ema_model = copy.deepcopy(net_model)
+ optim = torch.optim.Adam(net_model.parameters(), lr=FLAGS.lr)
+ sched = torch.optim.lr_scheduler.LambdaLR(optim, lr_lambda=warmup_lr)
+ if FLAGS.parallel:
+ net_model = torch.nn.DataParallel(net_model)
+ ema_model = torch.nn.DataParallel(ema_model)
+
+ net_node = NeuralODE(net_model, solver="euler", sensitivity="adjoint")
+ ema_node = NeuralODE(ema_model, solver="euler", sensitivity="adjoint")
+ # show model size
+ model_size = 0
+ for param in net_model.parameters():
+ model_size += param.data.nelement()
+ print("Model params: %.2f M" % (model_size / 1024 / 1024))
+
+ #################################
+ # OT-CFM
+ #################################
+
+ sigma = 0.0
+ if FLAGS.model == "otcfm":
+ FM = ExactOptimalTransportConditionalFlowMatcher(sigma=sigma)
+ elif FLAGS.model == "cfm":
+ FM = ConditionalFlowMatcher(sigma=sigma)
+ elif FLAGS.model == "fm":
+ FM = TargetConditionalFlowMatcher(sigma=sigma)
+ else:
+ raise NotImplementedError(
+ f"Unknown model {FLAGS.model}, must be one of ['otcfm', 'cfm', 'fm']"
+ )
-img = grid.detach().cpu() / 2 + 0.5 # unnormalize
-npimg = img.numpy()
-plt.imshow(np.transpose(npimg, (1, 2, 0)))
-plt.savefig(savedir + "generated_cifar_reproduced.png")
+ savedir = FLAGS.output_dir + FLAGS.model + "/"
+ os.makedirs(savedir, exist_ok=True)
+
+ with trange(FLAGS.total_steps, dynamic_ncols=True) as pbar:
+ for step in pbar:
+ optim.zero_grad()
+ x1 = next(datalooper).to(device)
+ x0 = torch.randn_like(x1)
+ t, xt, ut = FM.sample_location_and_conditional_flow(x0, x1)
+ vt = net_model(t, xt)
+ loss = torch.mean((vt - ut) ** 2)
+ loss.backward()
+ torch.nn.utils.clip_grad_norm_(net_model.parameters(), FLAGS.grad_clip) # new
+ optim.step()
+ sched.step()
+ ema(net_model, ema_model, FLAGS.ema_decay) # new
+
+ # sample and Saving the weights
+ if FLAGS.save_step > 0 and step % FLAGS.save_step == 0:
+ generate_samples(net_node, net_model, savedir, step, net_="normal")
+ generate_samples(ema_node, ema_model, savedir, step, net_="ema")
+ torch.save(
+ {
+ "net_model": net_model.state_dict(),
+ "ema_model": ema_model.state_dict(),
+ "sched": sched.state_dict(),
+ "optim": optim.state_dict(),
+ "step": step,
+ },
+ savedir + f"cifar10_weights_step_{step}.pt",
+ )
+
+
+if __name__ == "__main__":
+ app.run(train)
diff --git a/examples/cifar10/utils_cifar.py b/examples/cifar10/utils_cifar.py
new file mode 100644
index 0000000..bc47cbb
--- /dev/null
+++ b/examples/cifar10/utils_cifar.py
@@ -0,0 +1,37 @@
+import torch
+from torchdyn.core import NeuralODE
+
+# from torchvision.transforms import ToPILImage
+from torchvision.utils import make_grid, save_image
+
+use_cuda = torch.cuda.is_available()
+device = torch.device("cuda" if use_cuda else "cpu")
+
+
+def generate_samples(node_, model, savedir, step, net_="normal"):
+ model.eval()
+ with torch.no_grad():
+ traj = node_.trajectory(
+ torch.randn(64, 3, 32, 32).to(device),
+ t_span=torch.linspace(0, 1, 100).to(device),
+ )
+ traj = traj[-1, :].view([-1, 3, 32, 32]).clip(-1, 1)
+ traj = traj / 2 + 0.5
+ save_image(traj, savedir + f"{net_}_generated_FM_images_step_{step}.png", nrow=8)
+
+ model.train()
+
+
+def ema(source, target, decay):
+ source_dict = source.state_dict()
+ target_dict = target.state_dict()
+ for key in source_dict.keys():
+ target_dict[key].data.copy_(
+ target_dict[key].data * decay + source_dict[key].data * (1 - decay)
+ )
+
+
+def infiniteloop(dataloader):
+ while True:
+ for x, y in iter(dataloader):
+ yield x
diff --git a/requirements.txt b/requirements.txt
index 293ba8c..1b42df3 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -8,7 +8,8 @@ scipy
scikit-learn
scprep
scanpy
-timm
-torchdyn>=1.0.5 # 1.0.4 is broken on pypi
+torchdyn>=1.0.7 # 1.0.4 is broken on pypi
pot
-torchdiffeq==0.2.3
+torchdiffeq
+absl-py
+clean-fid
diff --git a/runner/README.md b/runner/README.md
index e69de29..da6c4c1 100644
--- a/runner/README.md
+++ b/runner/README.md
@@ -0,0 +1,393 @@
+In abstracting out the relevant losses from our `pytorch-lightning` implementation we have moved all of the `pytorch-lightning` code to `runner`. Every command that worked before for lightning should now be run from within the `runner` directory. V0 is now frozen in the `V0` branch.
+
+Train model with default configuration
+
+```bash
+cd runner
+
+# train on CPU
+python src/train.py trainer=cpu
+
+# train on GPU
+python src/train.py trainer=gpu
+```
+
+Train model with chosen experiment configuration from [configs/experiment/](configs/experiment/)
+
+```bash
+python src/train.py experiment=experiment_name
+```
+
+You can override any parameter from command line like this
+
+```bash
+python src/train.py trainer.max_epochs=20 datamodule.batch_size=64
+```
+
+You can also train a large set of models in parallel with SLURM as shown in `scripts/two-dim-cfm.sh` which trains the models used in the first 3 lines of Table 2.
+
+## Code Contributions
+
+This repo is extracted from a larger private codebase which loses the original commit history which contains work from other authors of the papers.
+
+## Project Structure
+
+The directory structure of new project looks like this:
+
+```
+│
+│── runner <- Shell scripts
+| ├── data <- Project data
+| ├── logs <- Logs generated by hydra and lightning loggers
+| ├── scripts <- Shell scripts
+│ ├── configs <- Hydra configuration files
+│ │ ├── callbacks <- Callbacks configs
+│ │ ├── debug <- Debugging configs
+│ │ ├── datamodule <- Datamodule configs
+│ │ ├── experiment <- Experiment configs
+│ │ ├── extras <- Extra utilities configs
+│ │ ├── hparams_search <- Hyperparameter search configs
+│ │ ├── hydra <- Hydra configs
+│ │ ├── launcher <- Hydra launcher configs
+│ │ ├── local <- Local configs
+│ │ ├── logger <- Logger configs
+│ │ ├── model <- Model configs
+│ │ ├── paths <- Project paths configs
+│ │ ├── trainer <- Trainer configs
+│ │ │
+│ │ ├── eval.yaml <- Main config for evaluation
+│ │ └── train.yaml <- Main config for training
+│ ├── src <- Source code
+│ │ ├── datamodules <- Lightning datamodules
+│ │ ├── models <- Lightning models
+│ │ ├── utils <- Utility scripts
+│ │ │
+│ │ ├── eval.py <- Run evaluation
+│ │ └── train.py <- Run training
+│ │
+│ ├── tests <- Tests of any kind
+│ └── README.md
+```
+
+## ⚡ Your Superpowers
+
+
+Override any config parameter from command line
+
+```bash
+python train.py trainer.max_epochs=20 model.optimizer.lr=1e-4
+```
+
+> **Note**: You can also add new parameters with `+` sign.
+
+```bash
+python train.py +model.new_param="owo"
+```
+
+
+
+
+Train on CPU, GPU, multi-GPU and TPU
+
+```bash
+# train on CPU
+python train.py trainer=cpu
+
+# train on 1 GPU
+python train.py trainer=gpu
+
+# train on TPU
+python train.py +trainer.tpu_cores=8
+
+# train with DDP (Distributed Data Parallel) (4 GPUs)
+python train.py trainer=ddp trainer.devices=4
+
+# train with DDP (Distributed Data Parallel) (8 GPUs, 2 nodes)
+python train.py trainer=ddp trainer.devices=4 trainer.num_nodes=2
+
+# simulate DDP on CPU processes
+python train.py trainer=ddp_sim trainer.devices=2
+
+# accelerate training on mac
+python train.py trainer=mps
+```
+
+> **Warning**: Currently there are problems with DDP mode, read [this issue](https://github.com/ashleve/lightning-hydra-template/issues/393) to learn more.
+
+
+
+
+Train with mixed precision
+
+```bash
+# train with pytorch native automatic mixed precision (AMP)
+python train.py trainer=gpu +trainer.precision=16
+```
+
+
+
+
+
+
+Train model with any logger available in PyTorch Lightning, like W&B or Tensorboard
+
+```yaml
+# set project and entity names in `configs/logger/wandb`
+wandb:
+ project: "your_project_name"
+ entity: "your_wandb_team_name"
+```
+
+```bash
+# train model with Weights&Biases (link to wandb dashboard should appear in the terminal)
+python train.py logger=wandb
+```
+
+> **Note**: Lightning provides convenient integrations with most popular logging frameworks. Learn more [here](#experiment-tracking).
+
+> **Note**: Using wandb requires you to [setup account](https://www.wandb.com/) first. After that just complete the config as below.
+
+> **Note**: Click [here](https://wandb.ai/hobglob/template-dashboard/) to see example wandb dashboard generated with this template.
+
+
+
+
+Train model with chosen experiment config
+
+```bash
+python train.py experiment=example
+```
+
+> **Note**: Experiment configs are placed in [configs/experiment/](configs/experiment/).
+
+
+
+
+Attach some callbacks to run
+
+```bash
+python train.py callbacks=default
+```
+
+> **Note**: Callbacks can be used for things such as as model checkpointing, early stopping and [many more](https://pytorch-lightning.readthedocs.io/en/latest/extensions/callbacks.html#built-in-callbacks).
+
+> **Note**: Callbacks configs are placed in [configs/callbacks/](configs/callbacks/).
+
+
+
+
+Use different tricks available in Pytorch Lightning
+
+```yaml
+# gradient clipping may be enabled to avoid exploding gradients
+python train.py +trainer.gradient_clip_val=0.5
+
+# run validation loop 4 times during a training epoch
+python train.py +trainer.val_check_interval=0.25
+
+# accumulate gradients
+python train.py +trainer.accumulate_grad_batches=10
+
+# terminate training after 12 hours
+python train.py +trainer.max_time="00:12:00:00"
+```
+
+> **Note**: PyTorch Lightning provides about [40+ useful trainer flags](https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html#trainer-flags).
+
+
+
+
+Easily debug
+
+```bash
+# runs 1 epoch in default debugging mode
+# changes logging directory to `logs/debugs/...`
+# sets level of all command line loggers to 'DEBUG'
+# enforces debug-friendly configuration
+python train.py debug=default
+
+# run 1 train, val and test loop, using only 1 batch
+python train.py debug=fdr
+
+# print execution time profiling
+python train.py debug=profiler
+
+# try overfitting to 1 batch
+python train.py debug=overfit
+
+# raise exception if there are any numerical anomalies in tensors, like NaN or +/-inf
+python train.py +trainer.detect_anomaly=true
+
+# log second gradient norm of the model
+python train.py +trainer.track_grad_norm=2
+
+# use only 20% of the data
+python train.py +trainer.limit_train_batches=0.2 \
++trainer.limit_val_batches=0.2 +trainer.limit_test_batches=0.2
+```
+
+> **Note**: Visit [configs/debug/](configs/debug/) for different debugging configs.
+
+
+
+
+Resume training from checkpoint
+
+```yaml
+python train.py ckpt_path="/path/to/ckpt/name.ckpt"
+```
+
+> **Note**: Checkpoint can be either path or URL.
+
+> **Note**: Currently loading ckpt doesn't resume logger experiment, but it will be supported in future Lightning release.
+
+
+
+
+Evaluate checkpoint on test dataset
+
+```yaml
+python eval.py ckpt_path="/path/to/ckpt/name.ckpt"
+```
+
+> **Note**: Checkpoint can be either path or URL.
+
+
+
+
+Create a sweep over hyperparameters
+
+```bash
+# this will run 6 experiments one after the other,
+# each with different combination of batch_size and learning rate
+python train.py -m datamodule.batch_size=32,64,128 model.lr=0.001,0.0005
+```
+
+> **Note**: Hydra composes configs lazily at job launch time. If you change code or configs after launching a job/sweep, the final composed configs might be impacted.
+
+
+
+
+Create a sweep over hyperparameters with Optuna
+
+```bash
+# this will run hyperparameter search defined in `configs/hparams_search/mnist_optuna.yaml`
+# over chosen experiment config
+python train.py -m hparams_search=mnist_optuna experiment=example
+```
+
+> **Note**: Using [Optuna Sweeper](https://hydra.cc/docs/next/plugins/optuna_sweeper) doesn't require you to add any boilerplate to your code, everything is defined in a [single config file](configs/hparams_search/mnist_optuna.yaml).
+
+> **Warning**: Optuna sweeps are not failure-resistant (if one job crashes then the whole sweep crashes).
+
+
+
+
+Execute all experiments from folder
+
+```bash
+python train.py -m 'experiment=glob(*)'
+```
+
+> **Note**: Hydra provides special syntax for controlling behavior of multiruns. Learn more [here](https://hydra.cc/docs/next/tutorials/basic/running_your_app/multi-run). The command above executes all experiments from [configs/experiment/](configs/experiment/).
+
+
+
+
+Execute run for multiple different seeds
+
+```bash
+python train.py -m seed=1,2,3,4,5 trainer.deterministic=True logger=csv tags=["benchmark"]
+```
+
+> **Note**: `trainer.deterministic=True` makes pytorch more deterministic but impacts the performance.
+
+
+
+
+Execute sweep on a remote AWS cluster
+
+> **Note**: This should be achievable with simple config using [Ray AWS launcher for Hydra](https://hydra.cc/docs/next/plugins/ray_launcher). Example is not implemented in this template.
+
+
+
+
+
+
+Use Hydra tab completion
+
+> **Note**: Hydra allows you to autocomplete config argument overrides in shell as you write them, by pressing `tab` key. Read the [docs](https://hydra.cc/docs/tutorials/basic/running_your_app/tab_completion).
+
+
+
+
+Apply pre-commit hooks
+
+```bash
+pre-commit run -a
+```
+
+> **Note**: Apply pre-commit hooks to do things like auto-formatting code and configs, performing code analysis or removing output from jupyter notebooks. See [# Best Practices](#best-practices) for more.
+
+
+
+
+Run tests
+
+```bash
+# run all tests
+pytest
+
+# run tests from specific file
+pytest tests/test_train.py
+
+# run all tests except the ones marked as slow
+pytest -k "not slow"
+```
+
+
+
+
+Use tags
+
+Each experiment should be tagged in order to easily filter them across files or in logger UI:
+
+```bash
+python train.py tags=["mnist","experiment_X"]
+```
+
+If no tags are provided, you will be asked to input them from command line:
+
+```bash
+>>> python train.py tags=[]
+[2022-07-11 15:40:09,358][src.utils.utils][INFO] - Enforcing tags!
+[2022-07-11 15:40:09,359][src.utils.rich_utils][WARNING] - No tags provided in config. Prompting user to input tags...
+Enter a list of comma separated tags (dev):
+```
+
+If no tags are provided for multirun, an error will be raised:
+
+```bash
+>>> python train.py -m +x=1,2,3 tags=[]
+ValueError: Specify tags before launching a multirun!
+```
+
+> **Note**: Appending lists from command line is currently not supported in hydra :(
+
+
+
+
diff --git a/setup.py b/setup.py
index 29ce294..f997a37 100644
--- a/setup.py
+++ b/setup.py
@@ -4,15 +4,36 @@
from setuptools import find_packages, setup
+install_requires = [
+ "torch>=1.11.0",
+ "torchvision>=0.11.0",
+ "lightning-bolts",
+ "matplotlib",
+ "numpy",
+ "scipy",
+ "scikit-learn",
+ "scprep",
+ "scanpy",
+ "torchdyn",
+ "pot",
+ "torchdiffeq",
+ "absl-py",
+ "clean-fid",
+]
+
version_py = os.path.join(os.path.dirname(__file__), "torchcfm", "version.py")
version = open(version_py).read().strip().split("=")[-1].replace('"', "").strip()
+readme = open("README.md", encoding="utf8").read()
setup(
name="torchcfm",
version=version,
description="Conditional Flow Matching for Fast Continuous Normalizing Flow Training.",
- author="Alexander Tong",
+ author="Alexander Tong, Kilian Fatras",
author_email="alexandertongdev@gmail.com",
url="https://github.com/atong01/conditional-flow-matching",
- install_requires=["torch", "pot", "numpy", "torchdyn"],
+ install_requires=install_requires,
+ license="MIT",
+ long_description=readme,
+ long_description_content_type="text/markdown",
packages=find_packages(),
)
diff --git a/torchcfm/version.py b/torchcfm/version.py
index 976498a..92192ee 100644
--- a/torchcfm/version.py
+++ b/torchcfm/version.py
@@ -1 +1 @@
-__version__ = "1.0.3"
+__version__ = "1.0.4"