Skip to content

Commit

Permalink
Merge pull request #17 from davidnabergoj/docs
Browse files Browse the repository at this point in the history
Docs
  • Loading branch information
davidnabergoj authored Aug 14, 2024
2 parents cd01d8b + a489c46 commit 8e892ea
Show file tree
Hide file tree
Showing 64 changed files with 1,270 additions and 271 deletions.
2 changes: 1 addition & 1 deletion .readthedocs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ build:
python: "3.11"

sphinx:
configuration: docs/conf.py
configuration: docs/source/conf.py

python:
install:
Expand Down
66 changes: 7 additions & 59 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,16 +30,20 @@ print(log_prob.shape) # (100,)
print(x_new.shape) # (50, 3)
```

We provide more examples [here](examples/).
Check examples and documentation, including the list of supported architectures [here](torchflows.readthedocs.io/en/latest/).
We also provide examples [here](examples/).

## Installing

Install via pip:
We support Python versions 3.7 and upwards.

Install Torchflows via pip:

```
pip install torchflows
```

Install the package directly from Github:
Install Torchflows directly from Github:

```
pip install git+https://github.com/davidnabergoj/torchflows.git
Expand All @@ -53,59 +57,3 @@ cd torchflows
pip install -r requirements.txt
```

We support Python versions 3.7 and upwards.

## Brief background

A normalizing flow (NF) is a flexible trainable distribution.
It is defined as a bijective transformation of a simple distribution, such as a standard Gaussian.
The bijection is typically an invertible neural network.
Training a NF using a dataset means optimizing the bijection's parameters to make the dataset likely under the NF.
We can use a NF to compute the probability of a data point or to independently sample data from the process that
generated our dataset.

The density of a NF $q(x)$ with the bijection $f(z) = x$ and base distribution $p(z)$ is defined as:
$$\log q(x) = \log p(f^{-1}(x)) + \log\left|\det J_{f^{-1}}(x)\right|.$$
Sampling from a NF means sampling from the simple distribution and transforming the sample using the bijection.

## Supported architectures

We list supported NF architectures below.
We classify architectures as either autoregressive, residual, or continuous; as defined
by [Papamakarios et al. (2021)](https://arxiv.org/abs/1912.02762).
We specify whether the forward and inverse passes are exact; otherwise they are numerical or not implemented (Planar,
Radial, and Sylvester flows).
An exact forward pass guarantees exact density estimation, whereas an exact inverse pass guarantees exact sampling.
Note that the directions can always be reversed, which enables exact computation for the opposite task.
We also specify whether the logarithm of the Jacobian determinant of the transformation is exact or computed numerically.

| Architecture | Bijection type | Exact forward | Exact inverse | Exact log determinant |
|--------------------------------------------------------------------------|:--------------------------:|:---------------:|:-------------:|:---------------------:|
| [NICE](http://arxiv.org/abs/1410.8516) | Autoregressive ||||
| [Real NVP](http://arxiv.org/abs/1605.08803) | Autoregressive ||||
| [MAF](http://arxiv.org/abs/1705.07057) | Autoregressive ||||
| [IAF](http://arxiv.org/abs/1606.04934) | Autoregressive ||||
| [Rational quadratic NSF](http://arxiv.org/abs/1906.04032) | Autoregressive ||||
| [Linear rational NSF](http://arxiv.org/abs/2001.05168) | Autoregressive ||||
| [NAF](http://arxiv.org/abs/1804.00779) | Autoregressive ||||
| [UMNN](http://arxiv.org/abs/1908.05164) | Autoregressive ||||
| [Planar](https://onlinelibrary.wiley.com/doi/abs/10.1002/cpa.21423) | Residual ||||
| [Radial](https://proceedings.mlr.press/v37/rezende15.html) | Residual ||||
| [Sylvester](http://arxiv.org/abs/1803.05649) | Residual ||||
| [Invertible ResNet](http://arxiv.org/abs/1811.00995) | Residual ||||
| [ResFlow](http://arxiv.org/abs/1906.02735) | Residual ||||
| [Proximal ResFlow](http://arxiv.org/abs/2211.17158) | Residual ||||
| [FFJORD](http://arxiv.org/abs/1810.01367) | Continuous ||||
| [RNODE](http://arxiv.org/abs/2002.02798) | Continuous ||||
| [DDNF](http://arxiv.org/abs/1810.03256) | Continuous ||||
| [OT flow](http://arxiv.org/abs/2006.00104) | Continuous ||||


We also support simple bijections (all with exact forward passes, inverse passes, and log determinants):

* Permutation
* Elementwise translation (shift vector)
* Elementwise scaling (diagonal matrix)
* Rotation (orthogonal matrix)
* Triangular matrix
* Dense matrix (using the QR or LU decomposition)
2 changes: 1 addition & 1 deletion docs/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,4 @@ help:
# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
%: Makefile
@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
35 changes: 0 additions & 35 deletions docs/conf.py

This file was deleted.

8 changes: 4 additions & 4 deletions docs/make.bat
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@ if "%SPHINXBUILD%" == "" (
set SOURCEDIR=source
set BUILDDIR=build

if "%1" == "" goto help

%SPHINXBUILD% >NUL 2>NUL
if errorlevel 9009 (
echo.
Expand All @@ -21,15 +19,17 @@ if errorlevel 9009 (
echo.may add the Sphinx directory to PATH.
echo.
echo.If you don't have Sphinx installed, grab it from
echo.http://sphinx-doc.org/
echo.https://www.sphinx-doc.org/
exit /b 1
)

if "%1" == "" goto help

%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
goto end

:help
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%

:end
popd
popd
90 changes: 90 additions & 0 deletions docs/notebooks/computing_log_determinants.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
{
"cells": [
{
"metadata": {},
"cell_type": "markdown",
"source": [
"# Computing the log determinant of the Jacobian\n",
"\n",
"We show how to compute and retrieve the log determinant of the Jacobian of a bijective transformation. We use Real NVP as an example."
],
"id": "624f99599895f0fd"
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2024-08-13T16:39:40.646799Z",
"start_time": "2024-08-13T16:39:38.868039Z"
}
},
"cell_type": "code",
"source": [
"import torch\n",
"from torchflows import Flow\n",
"from torchflows.architectures import RealNVP\n",
"\n",
"torch.manual_seed(0)\n",
"\n",
"batch_shape = (5, 7)\n",
"event_shape = (2, 3)\n",
"x = torch.randn(size=(*batch_shape, *event_shape))\n",
"z = torch.randn(size=(*batch_shape, *event_shape))\n",
"\n",
"bijection = RealNVP(event_shape=event_shape)\n",
"flow = Flow(bijection)\n",
"\n",
"_, log_det_forward = flow.bijection.forward(x)\n",
"_, log_det_inverse = flow.bijection.inverse(z)"
],
"id": "3f74b61a9929dd3b",
"outputs": [],
"execution_count": 1
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2024-08-13T16:39:40.662420Z",
"start_time": "2024-08-13T16:39:40.653696Z"
}
},
"cell_type": "code",
"source": [
"print(f'{log_det_forward.shape = }')\n",
"print(f'{log_det_inverse.shape = }')"
],
"id": "3c49e132d9c041c2",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"log_det_forward.shape = torch.Size([5, 7])\n",
"log_det_inverse.shape = torch.Size([5, 7])\n"
]
}
],
"execution_count": 2
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
133 changes: 133 additions & 0 deletions docs/notebooks/image_modeling.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
{
"cells": [
{
"metadata": {},
"cell_type": "markdown",
"source": [
"# Image modeling with normalizing flows\n",
"\n",
"When working with images, we can use specialized multiscale flow architectures. We can also use standard normalizing flows, which internally work with a flattened image. Note that multiscale architectures expect input images with shape `(channels, height, width)`."
],
"id": "df68afe10da259a1"
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2024-08-13T17:20:05.803231Z",
"start_time": "2024-08-13T17:20:03.001656Z"
}
},
"cell_type": "code",
"source": [
"from torchvision.datasets import MNIST\n",
"import torch\n",
"\n",
"torch.manual_seed(0)\n",
"\n",
"# pip install torchvision\n",
"dataset = MNIST(root='./data', download=True, train=True)\n",
"train_data = dataset.data.float()[:, None]\n",
"train_data = train_data[torch.randperm(len(train_data))]\n",
"train_data = (train_data - torch.mean(train_data)) / torch.std(train_data)\n",
"x_train, x_val = train_data[:1000], train_data[1000:1200]\n",
"\n",
"print(f'{x_train.shape = }')\n",
"print(f'{x_val.shape = }')\n",
"\n",
"image_shape = train_data.shape[1:]\n",
"print(f'{image_shape = }')"
],
"id": "b4d5e1888ff6a0e7",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"x_train.shape = torch.Size([1000, 1, 28, 28])\n",
"x_val.shape = torch.Size([200, 1, 28, 28])\n",
"image_shape = torch.Size([1, 28, 28])\n"
]
}
],
"execution_count": 1
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2024-08-13T17:20:06.058329Z",
"start_time": "2024-08-13T17:20:05.891695Z"
}
},
"cell_type": "code",
"source": [
"from torchflows.flows import Flow\n",
"from torchflows.architectures import RealNVP, MultiscaleRealNVP\n",
"\n",
"real_nvp = Flow(RealNVP(image_shape))\n",
"multiscale_real_nvp = Flow(MultiscaleRealNVP(image_shape))"
],
"id": "744513899ffa6a46",
"outputs": [],
"execution_count": 2
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2024-08-13T17:26:11.651540Z",
"start_time": "2024-08-13T17:20:06.378393Z"
}
},
"cell_type": "code",
"source": [
"real_nvp.fit(x_train, x_val=x_val, early_stopping=True, show_progress=True)\n",
"multiscale_real_nvp.fit(x_train, x_val=x_val, early_stopping=True, show_progress=True)"
],
"id": "7a439e2565ce5a25",
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Fitting NF: 30%|███ | 151/500 [00:18<00:42, 8.30it/s, Training loss (batch): -0.2608, Validation loss: 1.3448 [best: 0.1847 @ 100]] \n",
"Fitting NF: 30%|███ | 152/500 [05:47<13:14, 2.28s/it, Training loss (batch): -0.3050, Validation loss: 0.9754 [best: 0.1744 @ 101]] \n"
]
}
],
"execution_count": 3
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2024-08-13T17:26:11.699539Z",
"start_time": "2024-08-13T17:26:11.686539Z"
}
},
"cell_type": "code",
"source": "",
"id": "c38fc6cc58bdc0b2",
"outputs": [],
"execution_count": null
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Loading

0 comments on commit 8e892ea

Please sign in to comment.