diff --git a/README.md b/README.md index e3931fd..79856b2 100644 --- a/README.md +++ b/README.md @@ -13,7 +13,6 @@ import torch from normalizing_flows import Flow from normalizing_flows.architectures import RealNVP - torch.manual_seed(0) n_data = 1000 @@ -69,35 +68,35 @@ Sampling from a NF means sampling from the simple distribution and transforming We list supported NF architectures below. We classify architectures as either autoregressive, residual, or continuous; as defined by [Papamakarios et al. (2021)](https://arxiv.org/abs/1912.02762). -Exact architectures do not use numerical approximations to generate data or compute the log density. - -| Architecture | Bijection type | Exact | Two-way | -|--------------------------------------------------------------------------|:--------------------------:|:-------:|:-------:| -| [NICE](http://arxiv.org/abs/1410.8516) | Autoregressive | ✔ | ✔ | -| [Real NVP](http://arxiv.org/abs/1605.08803) | Autoregressive | ✔ | ✔ | -| [MAF](http://arxiv.org/abs/1705.07057) | Autoregressive | ✔ | ✔ | -| [IAF](http://arxiv.org/abs/1606.04934) | Autoregressive | ✔ | ✔ | -| [Rational quadratic NSF](http://arxiv.org/abs/1906.04032) | Autoregressive | ✔ | ✔ | -| [Linear rational NSF](http://arxiv.org/abs/2001.05168) | Autoregressive | ✔ | ✔ | -| [NAF](http://arxiv.org/abs/1804.00779) | Autoregressive | ✗ | ✔ | -| [UMNN](http://arxiv.org/abs/1908.05164) | Autoregressive | ✗ | ✔ | -| [Planar](https://onlinelibrary.wiley.com/doi/abs/10.1002/cpa.21423) | Residual | ✗ | ✗ | -| [Radial](https://proceedings.mlr.press/v37/rezende15.html) | Residual | ✗ | ✗ | -| [Sylvester](http://arxiv.org/abs/1803.05649) | Residual | ✗ | ✗ | -| [Invertible ResNet](http://arxiv.org/abs/1811.00995) | Residual | ✗ | ✔* | -| [ResFlow](http://arxiv.org/abs/1906.02735) | Residual | ✗ | ✔* | -| [Proximal ResFlow](http://arxiv.org/abs/2211.17158) | Residual | ✗ | ✔* | -| [FFJORD](http://arxiv.org/abs/1810.01367) | Continuous | ✗ | ✔* | -| 
[RNODE](http://arxiv.org/abs/2002.02798) | Continuous | ✗ | ✔* | -| [DDNF](http://arxiv.org/abs/1810.03256) | Continuous | ✗ | ✔* | -| [OT flow](http://arxiv.org/abs/2006.00104) | Continuous | ✗ | ✔ | - -Two-way architectures support both sampling and density estimation. -Two-way architectures marked with an asterisk (*) support both, but use a numerical approximation to sample or estimate -density. -One-way architectures support either sampling or density estimation, but not both at once. - -We also support simple bijections (all exact and two-way): +We specify whether the forward and inverse passes are exact; otherwise they are numerical or not implemented (Planar, +Radial, and Sylvester flows). +An exact forward pass guarantees exact density estimation, whereas an exact inverse pass guarantees exact sampling. +Note that the directions can always be reversed, which enables exact computation for the opposite task. +We also specify whether the logarithm of the Jacobian determinant of the transformation is exact or computed numerically. 
+ +| Architecture | Bijection type | Exact forward | Exact inverse | Exact log determinant | +|--------------------------------------------------------------------------|:--------------------------:|:---------------:|:-------------:|:---------------------:| +| [NICE](http://arxiv.org/abs/1410.8516) | Autoregressive | ✔ | ✔ | ✔ | +| [Real NVP](http://arxiv.org/abs/1605.08803) | Autoregressive | ✔ | ✔ | ✔ | +| [MAF](http://arxiv.org/abs/1705.07057) | Autoregressive | ✔ | ✔ | ✔ | +| [IAF](http://arxiv.org/abs/1606.04934) | Autoregressive | ✔ | ✔ | ✔ | +| [Rational quadratic NSF](http://arxiv.org/abs/1906.04032) | Autoregressive | ✔ | ✔ | ✔ | +| [Linear rational NSF](http://arxiv.org/abs/2001.05168) | Autoregressive | ✔ | ✔ | ✔ | +| [NAF](http://arxiv.org/abs/1804.00779) | Autoregressive | ✔ | ✗ | ✔ | +| [UMNN](http://arxiv.org/abs/1908.05164) | Autoregressive | ✗ | ✗ | ✔ | +| [Planar](https://onlinelibrary.wiley.com/doi/abs/10.1002/cpa.21423) | Residual | ✔ | ✗ | ✔ | +| [Radial](https://proceedings.mlr.press/v37/rezende15.html) | Residual | ✔ | ✗ | ✔ | +| [Sylvester](http://arxiv.org/abs/1803.05649) | Residual | ✔ | ✗ | ✔ | +| [Invertible ResNet](http://arxiv.org/abs/1811.00995) | Residual | ✔ | ✗ | ✗ | +| [ResFlow](http://arxiv.org/abs/1906.02735) | Residual | ✔ | ✗ | ✗ | +| [Proximal ResFlow](http://arxiv.org/abs/2211.17158) | Residual | ✔ | ✗ | ✗ | +| [FFJORD](http://arxiv.org/abs/1810.01367) | Continuous | ✗ | ✗ | ✗ | +| [RNODE](http://arxiv.org/abs/2002.02798) | Continuous | ✗ | ✗ | ✗ | +| [DDNF](http://arxiv.org/abs/1810.03256) | Continuous | ✗ | ✗ | ✗ | +| [OT flow](http://arxiv.org/abs/2006.00104) | Continuous | ✗ | ✗ | ✗ | + + +We also support simple bijections (all with exact forward passes, inverse passes, and log determinants): * Permutation * Elementwise translation (shift vector) diff --git a/normalizing_flows/architectures.py b/normalizing_flows/architectures.py index 83a0295..191a022 100644 --- a/normalizing_flows/architectures.py +++ 
b/normalizing_flows/architectures.py
@@ -17,4 +17,14 @@
 from normalizing_flows.bijections.continuous.ffjord import FFJORD
 from normalizing_flows.bijections.continuous.otflow import OTFlow
 
-from normalizing_flows.bijections.finite.residual.architectures import ResFlow, ProximalResFlow, InvertibleResNet
+from normalizing_flows.bijections.finite.residual.architectures import (
+    ResFlow,
+    ProximalResFlow,
+    InvertibleResNet,
+    Planar,
+    Radial,
+    Sylvester,
+    PlanarFlow,
+    RadialFlow,
+    SylvesterFlow,
+)
diff --git a/normalizing_flows/bijections/finite/residual/architectures.py b/normalizing_flows/bijections/finite/residual/architectures.py
index 8024068..2351987 100644
--- a/normalizing_flows/bijections/finite/residual/architectures.py
+++ b/normalizing_flows/bijections/finite/residual/architectures.py
@@ -1,6 +1,15 @@
+from typing import Union, Tuple
+
+import torch
+
+from normalizing_flows.bijections import Affine
+from normalizing_flows.bijections.base import BijectiveComposition
 from normalizing_flows.bijections.finite.residual.base import ResidualComposition
 from normalizing_flows.bijections.finite.residual.iterative import InvertibleResNetBlock, ResFlowBlock
 from normalizing_flows.bijections.finite.residual.proximal import ProximalResFlowBlock
+from normalizing_flows.bijections.finite.residual.planar import Planar
+from normalizing_flows.bijections.finite.residual.radial import Radial
+from normalizing_flows.bijections.finite.residual.sylvester import Sylvester
 
 
 class InvertibleResNet(ResidualComposition):
@@ -22,3 +31,61 @@ def __init__(self, event_shape, context_shape=None, n_layers: int = 16, **kwargs
         block = ProximalResFlowBlock(event_shape=event_shape, context_shape=context_shape, **kwargs)
         blocks = [block for _ in range(n_layers)]  # The same block
         super().__init__(blocks)
+
+
+class PlanarFlow(BijectiveComposition):
+    """
+    Composition of an affine bijection, `n_layers` planar layers, and a final affine bijection.
+
+    :param event_shape: event shape, forwarded to the parent constructor and to every composed bijection.
+    :param n_layers: number of `Planar` layers placed between the two affine bijections.
+    :raises ValueError: if `n_layers` is less than 1.
+    """
+
+    def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], n_layers: int = 2):
+        if n_layers < 1:
+            raise ValueError(f"Flow needs at least one layer, but got {n_layers}")
+        super().__init__(event_shape, [
+            Affine(event_shape),
+            *[Planar(event_shape) for _ in range(n_layers)],
+            Affine(event_shape)
+        ])
+
+
+class RadialFlow(BijectiveComposition):
+    """
+    Composition of an affine bijection, `n_layers` radial layers, and a final affine bijection.
+
+    :param event_shape: event shape, forwarded to the parent constructor and to every composed bijection.
+    :param n_layers: number of `Radial` layers placed between the two affine bijections.
+    :raises ValueError: if `n_layers` is less than 1.
+    """
+
+    def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], n_layers: int = 2):
+        if n_layers < 1:
+            raise ValueError(f"Flow needs at least one layer, but got {n_layers}")
+        super().__init__(event_shape, [
+            Affine(event_shape),
+            *[Radial(event_shape) for _ in range(n_layers)],
+            Affine(event_shape)
+        ])
+
+
+class SylvesterFlow(BijectiveComposition):
+    """
+    Composition of an affine bijection, `n_layers` Sylvester layers, and a final affine bijection.
+
+    :param event_shape: event shape, forwarded to the parent constructor and to every composed bijection.
+    :param n_layers: number of `Sylvester` layers placed between the two affine bijections.
+    :param kwargs: extra keyword arguments forwarded unchanged to every `Sylvester` layer.
+    :raises ValueError: if `n_layers` is less than 1.
+    """
+
+    def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], n_layers: int = 2, **kwargs):
+        if n_layers < 1:
+            raise ValueError(f"Flow needs at least one layer, but got {n_layers}")
+        super().__init__(event_shape, [
+            Affine(event_shape),
+            *[Sylvester(event_shape, **kwargs) for _ in range(n_layers)],
+            Affine(event_shape)
+        ])