Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dev #18

Merged
merged 32 commits into from
Aug 14, 2024
Merged

Dev #18

Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
cd01d8b
Update docs
davidnabergoj Aug 13, 2024
f91f14f
Update docs
davidnabergoj Aug 13, 2024
8fa97ef
Update rtd
davidnabergoj Aug 13, 2024
98fc157
Add example notebook
davidnabergoj Aug 13, 2024
5be8350
Remove identity initialization
davidnabergoj Aug 13, 2024
6bd0099
Add documentation
davidnabergoj Aug 13, 2024
90d90a6
Update index.rst
davidnabergoj Aug 13, 2024
614816b
Add notebooks
davidnabergoj Aug 13, 2024
3a61063
Modify sphinx files
davidnabergoj Aug 13, 2024
3c2349c
Remove __init__.py files in subdirectories, change imports accordingly
davidnabergoj Aug 13, 2024
01e7ef5
Remove __init__.py files in subdirectories
davidnabergoj Aug 13, 2024
3186153
Update dcs
davidnabergoj Aug 14, 2024
eb12968
Update docs
davidnabergoj Aug 14, 2024
8e6c76c
Add requirements.txt for sphinx
davidnabergoj Aug 14, 2024
73cad4b
Add torchflows to requirements.txt for sphinx
davidnabergoj Aug 14, 2024
cd23c3c
Fix typo, remove torchflows from requirements.txt
davidnabergoj Aug 14, 2024
7eb1d93
Add torchflows to requirements.txt
davidnabergoj Aug 14, 2024
d87bce3
Fix underline length
davidnabergoj Aug 14, 2024
7f80493
Add copy button
davidnabergoj Aug 14, 2024
f65ca4c
Update requirements.txt
davidnabergoj Aug 14, 2024
b8a7001
Update docs
davidnabergoj Aug 14, 2024
61a9043
Add references and docstrings for autoregressive flows
davidnabergoj Aug 14, 2024
b8d79c9
Rename headers
davidnabergoj Aug 14, 2024
3cd5804
Add bijection docs
davidnabergoj Aug 14, 2024
e2090b2
Use section in rst
davidnabergoj Aug 14, 2024
cf770f7
Add continuous NF docs
davidnabergoj Aug 14, 2024
59a4e53
Add continuous bijection docs
davidnabergoj Aug 14, 2024
f5cdad2
Separate architectures by type
davidnabergoj Aug 14, 2024
eff1a15
Add residual flow docs
davidnabergoj Aug 14, 2024
94b88c4
Update docs
davidnabergoj Aug 14, 2024
a489c46
Update docs
davidnabergoj Aug 14, 2024
8e892ea
Merge pull request #17 from davidnabergoj/docs
davidnabergoj Aug 14, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .readthedocs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ build:
python: "3.11"

sphinx:
configuration: docs/conf.py
configuration: docs/source/conf.py

python:
install:
Expand Down
66 changes: 7 additions & 59 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,16 +30,20 @@ print(log_prob.shape) # (100,)
print(x_new.shape) # (50, 3)
```

We provide more examples [here](examples/).
Check examples and documentation, including the list of supported architectures [here](torchflows.readthedocs.io/en/latest/).
We also provide examples [here](examples/).

## Installing

Install via pip:
We support Python versions 3.7 and upwards.

Install Torchflows via pip:

```
pip install torchflows
```

Install the package directly from Github:
Install Torchflows directly from Github:

```
pip install git+https://github.com/davidnabergoj/torchflows.git
Expand All @@ -53,59 +57,3 @@ cd torchflows
pip install -r requirements.txt
```

We support Python versions 3.7 and upwards.

## Brief background

A normalizing flow (NF) is a flexible trainable distribution.
It is defined as a bijective transformation of a simple distribution, such as a standard Gaussian.
The bijection is typically an invertible neural network.
Training a NF using a dataset means optimizing the bijection's parameters to make the dataset likely under the NF.
We can use a NF to compute the probability of a data point or to independently sample data from the process that
generated our dataset.

The density of a NF $q(x)$ with the bijection $f(z) = x$ and base distribution $p(z)$ is defined as:
$$\log q(x) = \log p(f^{-1}(x)) + \log\left|\det J_{f^{-1}}(x)\right|.$$
Sampling from a NF means sampling from the simple distribution and transforming the sample using the bijection.

## Supported architectures

We list supported NF architectures below.
We classify architectures as either autoregressive, residual, or continuous; as defined
by [Papamakarios et al. (2021)](https://arxiv.org/abs/1912.02762).
We specify whether the forward and inverse passes are exact; otherwise they are numerical or not implemented (Planar,
Radial, and Sylvester flows).
An exact forward pass guarantees exact density estimation, whereas an exact inverse pass guarantees exact sampling.
Note that the directions can always be reversed, which enables exact computation for the opposite task.
We also specify whether the logarithm of the Jacobian determinant of the transformation is exact or computed numerically.

| Architecture | Bijection type | Exact forward | Exact inverse | Exact log determinant |
|--------------------------------------------------------------------------|:--------------------------:|:---------------:|:-------------:|:---------------------:|
| [NICE](http://arxiv.org/abs/1410.8516) | Autoregressive | ✔ | ✔ | ✔ |
| [Real NVP](http://arxiv.org/abs/1605.08803) | Autoregressive | ✔ | ✔ | ✔ |
| [MAF](http://arxiv.org/abs/1705.07057) | Autoregressive | ✔ | ✔ | ✔ |
| [IAF](http://arxiv.org/abs/1606.04934) | Autoregressive | ✔ | ✔ | ✔ |
| [Rational quadratic NSF](http://arxiv.org/abs/1906.04032) | Autoregressive | ✔ | ✔ | ✔ |
| [Linear rational NSF](http://arxiv.org/abs/2001.05168) | Autoregressive | ✔ | ✔ | ✔ |
| [NAF](http://arxiv.org/abs/1804.00779) | Autoregressive | ✔ | ✗ | ✔ |
| [UMNN](http://arxiv.org/abs/1908.05164) | Autoregressive | ✗ | ✗ | ✔ |
| [Planar](https://onlinelibrary.wiley.com/doi/abs/10.1002/cpa.21423) | Residual | ✔ | ✗ | ✔ |
| [Radial](https://proceedings.mlr.press/v37/rezende15.html) | Residual | ✔ | ✗ | ✔ |
| [Sylvester](http://arxiv.org/abs/1803.05649) | Residual | ✔ | ✗ | ✔ |
| [Invertible ResNet](http://arxiv.org/abs/1811.00995) | Residual | ✔ | ✗ | ✗ |
| [ResFlow](http://arxiv.org/abs/1906.02735) | Residual | ✔ | ✗ | ✗ |
| [Proximal ResFlow](http://arxiv.org/abs/2211.17158) | Residual | ✔ | ✗ | ✗ |
| [FFJORD](http://arxiv.org/abs/1810.01367) | Continuous | ✗ | ✗ | ✗ |
| [RNODE](http://arxiv.org/abs/2002.02798) | Continuous | ✗ | ✗ | ✗ |
| [DDNF](http://arxiv.org/abs/1810.03256) | Continuous | ✗ | ✗ | ✗ |
| [OT flow](http://arxiv.org/abs/2006.00104) | Continuous | ✗ | ✗ | ✗ |


We also support simple bijections (all with exact forward passes, inverse passes, and log determinants):

* Permutation
* Elementwise translation (shift vector)
* Elementwise scaling (diagonal matrix)
* Rotation (orthogonal matrix)
* Triangular matrix
* Dense matrix (using the QR or LU decomposition)
2 changes: 1 addition & 1 deletion docs/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,4 @@ help:
# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
%: Makefile
@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
35 changes: 0 additions & 35 deletions docs/conf.py

This file was deleted.

8 changes: 4 additions & 4 deletions docs/make.bat
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@ if "%SPHINXBUILD%" == "" (
set SOURCEDIR=source
set BUILDDIR=build

if "%1" == "" goto help

%SPHINXBUILD% >NUL 2>NUL
if errorlevel 9009 (
echo.
Expand All @@ -21,15 +19,17 @@ if errorlevel 9009 (
echo.may add the Sphinx directory to PATH.
echo.
echo.If you don't have Sphinx installed, grab it from
echo.http://sphinx-doc.org/
echo.https://www.sphinx-doc.org/
exit /b 1
)

if "%1" == "" goto help

%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
goto end

:help
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%

:end
popd
popd
90 changes: 90 additions & 0 deletions docs/notebooks/computing_log_determinants.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
{
"cells": [
{
"metadata": {},
"cell_type": "markdown",
"source": [
"# Computing the log determinant of the Jacobian\n",
"\n",
"We show how to compute and retrieve the log determinant of the Jacobian of a bijective transformation. We use Real NVP as an example."
],
"id": "624f99599895f0fd"
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2024-08-13T16:39:40.646799Z",
"start_time": "2024-08-13T16:39:38.868039Z"
}
},
"cell_type": "code",
"source": [
"import torch\n",
"from torchflows import Flow\n",
"from torchflows.architectures import RealNVP\n",
"\n",
"torch.manual_seed(0)\n",
"\n",
"batch_shape = (5, 7)\n",
"event_shape = (2, 3)\n",
"x = torch.randn(size=(*batch_shape, *event_shape))\n",
"z = torch.randn(size=(*batch_shape, *event_shape))\n",
"\n",
"bijection = RealNVP(event_shape=event_shape)\n",
"flow = Flow(bijection)\n",
"\n",
"_, log_det_forward = flow.bijection.forward(x)\n",
"_, log_det_inverse = flow.bijection.inverse(z)"
],
"id": "3f74b61a9929dd3b",
"outputs": [],
"execution_count": 1
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2024-08-13T16:39:40.662420Z",
"start_time": "2024-08-13T16:39:40.653696Z"
}
},
"cell_type": "code",
"source": [
"print(f'{log_det_forward.shape = }')\n",
"print(f'{log_det_inverse.shape = }')"
],
"id": "3c49e132d9c041c2",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"log_det_forward.shape = torch.Size([5, 7])\n",
"log_det_inverse.shape = torch.Size([5, 7])\n"
]
}
],
"execution_count": 2
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
133 changes: 133 additions & 0 deletions docs/notebooks/image_modeling.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
{
"cells": [
{
"metadata": {},
"cell_type": "markdown",
"source": [
"# Image modeling with normalizing flows\n",
"\n",
"When working with images, we can use specialized multiscale flow architectures. We can also use standard normalizing flows, which internally work with a flattened image. Note that multiscale architectures expect input images with shape `(channels, height, width)`."
],
"id": "df68afe10da259a1"
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2024-08-13T17:20:05.803231Z",
"start_time": "2024-08-13T17:20:03.001656Z"
}
},
"cell_type": "code",
"source": [
"from torchvision.datasets import MNIST\n",
"import torch\n",
"\n",
"torch.manual_seed(0)\n",
"\n",
"# pip install torchvision\n",
"dataset = MNIST(root='./data', download=True, train=True)\n",
"train_data = dataset.data.float()[:, None]\n",
"train_data = train_data[torch.randperm(len(train_data))]\n",
"train_data = (train_data - torch.mean(train_data)) / torch.std(train_data)\n",
"x_train, x_val = train_data[:1000], train_data[1000:1200]\n",
"\n",
"print(f'{x_train.shape = }')\n",
"print(f'{x_val.shape = }')\n",
"\n",
"image_shape = train_data.shape[1:]\n",
"print(f'{image_shape = }')"
],
"id": "b4d5e1888ff6a0e7",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"x_train.shape = torch.Size([1000, 1, 28, 28])\n",
"x_val.shape = torch.Size([200, 1, 28, 28])\n",
"image_shape = torch.Size([1, 28, 28])\n"
]
}
],
"execution_count": 1
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2024-08-13T17:20:06.058329Z",
"start_time": "2024-08-13T17:20:05.891695Z"
}
},
"cell_type": "code",
"source": [
"from torchflows.flows import Flow\n",
"from torchflows.architectures import RealNVP, MultiscaleRealNVP\n",
"\n",
"real_nvp = Flow(RealNVP(image_shape))\n",
"multiscale_real_nvp = Flow(MultiscaleRealNVP(image_shape))"
],
"id": "744513899ffa6a46",
"outputs": [],
"execution_count": 2
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2024-08-13T17:26:11.651540Z",
"start_time": "2024-08-13T17:20:06.378393Z"
}
},
"cell_type": "code",
"source": [
"real_nvp.fit(x_train, x_val=x_val, early_stopping=True, show_progress=True)\n",
"multiscale_real_nvp.fit(x_train, x_val=x_val, early_stopping=True, show_progress=True)"
],
"id": "7a439e2565ce5a25",
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Fitting NF: 30%|███ | 151/500 [00:18<00:42, 8.30it/s, Training loss (batch): -0.2608, Validation loss: 1.3448 [best: 0.1847 @ 100]] \n",
"Fitting NF: 30%|███ | 152/500 [05:47<13:14, 2.28s/it, Training loss (batch): -0.3050, Validation loss: 0.9754 [best: 0.1744 @ 101]] \n"
]
}
],
"execution_count": 3
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2024-08-13T17:26:11.699539Z",
"start_time": "2024-08-13T17:26:11.686539Z"
}
},
"cell_type": "code",
"source": "",
"id": "c38fc6cc58bdc0b2",
"outputs": [],
"execution_count": null
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Loading
Loading