diff --git a/test/test_autograd_bijections.py b/test/test_autograd_bijections.py
index 6b4069f..81ac954 100644
--- a/test/test_autograd_bijections.py
+++ b/test/test_autograd_bijections.py
@@ -11,11 +11,6 @@ LRSCoupling, LinearRQSCoupling
 from torchflows.bijections.finite.linear import LU, ReversePermutation, LowerTriangular, Orthogonal, QR
 from torchflows.bijections.finite.residual.architectures import InvertibleResNet, ResFlow, ProximalResFlow
-from torchflows.bijections.finite.residual.iterative import InvertibleResNetBlock, ResFlowBlock
-from torchflows.bijections.finite.residual.planar import Planar
-from torchflows.bijections.finite.residual.proximal import ProximalResFlowBlock
-from torchflows.bijections.finite.residual.radial import Radial
-from torchflows.bijections.finite.residual.sylvester import Sylvester
 from torchflows.utils import get_batch_shape
 
 from test.constants import __test_constants
 
diff --git a/test/test_invert_classic_residual.py b/test/test_invert_classic_residual.py
new file mode 100644
index 0000000..f6656f4
--- /dev/null
+++ b/test/test_invert_classic_residual.py
@@ -0,0 +1,26 @@
+import pytest
+import torch
+
+from torchflows.flows import Flow
+from torchflows.bijections.finite.residual.architectures import RadialFlow, SylvesterFlow, PlanarFlow
+
+
+@pytest.mark.parametrize(
+    'architecture_class',
+    [
+        RadialFlow,
+        SylvesterFlow,
+        PlanarFlow
+    ]
+)
+def test_basic(architecture_class):
+    torch.manual_seed(0)
+    event_shape = (1, 2, 3, 4)
+    batch_shape = (5, 6)
+
+    flow = Flow(architecture_class(event_shape))
+    x_new = flow.sample(batch_shape)
+    assert x_new.shape == (*batch_shape, *event_shape)
+
+    flow.bijection.invert()
+    assert flow.log_prob(x_new).shape == batch_shape
diff --git a/torchflows/__init__.py b/torchflows/__init__.py
index 6d3802c..1a7460a 100644
--- a/torchflows/__init__.py
+++ b/torchflows/__init__.py
@@ -46,7 +46,7 @@
     'ProximalResFlow',
     'Radial',
     'Planar',
-    'Sylvester',
+    'InverseSylvester',
     'ElementwiseShift',
     'ElementwiseAffine',
     'ElementwiseRQSpline',
diff --git a/torchflows/bijections/base.py b/torchflows/bijections/base.py
index 57c0ca6..23b0164 100644
--- a/torchflows/bijections/base.py
+++ b/torchflows/bijections/base.py
@@ -98,6 +98,9 @@ def batch_inverse(self, x: torch.Tensor, batch_size: int, context: torch.Tensor
     def regularization(self):
         return 0.0
 
+    def invert(self):
+        self.forward, self.inverse = self.inverse, self.forward
+
 
 def invert(bijection: Bijection) -> Bijection:
     """Swap the forward and inverse methods of the input bijection.
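The `invert` method added to `Bijection` above is the hinge of this change set: swapping the bound `forward` and `inverse` methods in place is what lets the one-way classic residual flows run in either direction. A minimal usage sketch, mirroring `test_invert_classic_residual.py` (shapes are illustrative; only APIs the new test itself exercises are assumed):

    import torch
    from torchflows.flows import Flow
    from torchflows.bijections.finite.residual.architectures import PlanarFlow

    torch.manual_seed(0)
    flow = Flow(PlanarFlow(event_shape=(2, 3)))

    x = flow.sample((10,))       # sampling runs the cheap one-way (forward) direction
    flow.bijection.invert()      # swap forward/inverse in place
    log_prob = flow.log_prob(x)  # density evaluation is now available; shape (10,)
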
diff --git a/torchflows/bijections/finite/residual/architectures.py b/torchflows/bijections/finite/residual/architectures.py
index d14e392..9d363bc 100644
--- a/torchflows/bijections/finite/residual/architectures.py
+++ b/torchflows/bijections/finite/residual/architectures.py
@@ -2,59 +2,49 @@
 
 import torch
 
-from torchflows.bijections.base import BijectiveComposition
-from torchflows.bijections.finite.autoregressive.layers import ElementwiseAffine
-from torchflows.bijections.finite.residual.base import ResidualComposition
+from torchflows.bijections.finite.residual.base import ResidualArchitecture
 from torchflows.bijections.finite.residual.iterative import InvertibleResNetBlock, ResFlowBlock
 from torchflows.bijections.finite.residual.proximal import ProximalResFlowBlock
-from torchflows.bijections.finite.residual.planar import Planar, InversePlanar
+from torchflows.bijections.finite.residual.planar import Planar
 from torchflows.bijections.finite.residual.radial import Radial
 from torchflows.bijections.finite.residual.sylvester import Sylvester
 
 
-class InvertibleResNet(ResidualComposition):
+class InvertibleResNet(ResidualArchitecture):
     """Invertible residual network (i-ResNet) architecture.
 
     Reference: Behrmann et al. "Invertible Residual Networks" (2019); https://arxiv.org/abs/1811.00995.
     """
-    def __init__(self, event_shape, context_shape=None, n_layers: int = 2, **kwargs):
-        blocks = [
-            InvertibleResNetBlock(event_shape=event_shape, context_shape=context_shape, **kwargs)
-            for _ in range(n_layers)
-        ]
-        super().__init__(blocks)
+    def __init__(self, event_shape, **kwargs):
+        super().__init__(event_shape, InvertibleResNetBlock, **kwargs)
 
 
-class ResFlow(ResidualComposition):
+class ResFlow(ResidualArchitecture):
     """Residual flow (ResFlow) architecture.
 
     Reference: Chen et al. "Residual Flows for Invertible Generative Modeling" (2020); https://arxiv.org/abs/1906.02735.
     """
-    def __init__(self, event_shape, context_shape=None, n_layers: int = 2, **kwargs):
-        blocks = [
-            ResFlowBlock(event_shape=event_shape, context_shape=context_shape, **kwargs)
-            for _ in range(n_layers)
-        ]
-        super().__init__(blocks)
+    def __init__(self, event_shape, **kwargs):
+        super().__init__(event_shape, ResFlowBlock, **kwargs)
 
 
-class ProximalResFlow(ResidualComposition):
+class ProximalResFlow(ResidualArchitecture):
     """Proximal residual flow architecture.
 
     Reference: Hertrich "Proximal Residual Flows for Bayesian Inverse Problems" (2022); https://arxiv.org/abs/2211.17158.
     """
-    def __init__(self, event_shape, context_shape=None, n_layers: int = 2, **kwargs):
-        blocks = [
-            ProximalResFlowBlock(event_shape=event_shape, context_shape=context_shape, gamma=0.01, **kwargs)
-            for _ in range(n_layers)
-        ]
-        super().__init__(blocks)
+    def __init__(self, event_shape, **kwargs):
+        if 'layer_kwargs' not in kwargs:
+            kwargs['layer_kwargs'] = {}
+        if 'gamma' not in kwargs['layer_kwargs']:
+            kwargs['layer_kwargs']['gamma'] = 0.01
+        super().__init__(event_shape, ProximalResFlowBlock, **kwargs)
 
 
-class PlanarFlow(BijectiveComposition):
+class PlanarFlow(ResidualArchitecture):
     """Planar flow architecture.
 
     Note: this model currently supports only one-way transformations.
@@ -64,18 +54,11 @@ class PlanarFlow(BijectiveComposition):
     def __init__(self,
                  event_shape: Union[torch.Size, Tuple[int, ...]],
-                 n_layers: int = 2,
-                 inverse: bool = True):
-        if n_layers < 1:
-            raise ValueError(f"Flow needs at least one layer, but got {n_layers}")
-        super().__init__(event_shape, [
-            ElementwiseAffine(event_shape),
-            *[(InversePlanar if inverse else Planar)(event_shape) for _ in range(n_layers)],
-            ElementwiseAffine(event_shape)
-        ])
-
-
-class RadialFlow(BijectiveComposition):
+                 **kwargs):
+        super().__init__(event_shape, Planar, **kwargs)
+
+
+class RadialFlow(ResidualArchitecture):
     """Radial flow architecture.
 
     Note: this model currently supports only one-way transformations.
@@ -83,17 +66,13 @@ class RadialFlow(BijectiveComposition):
     Reference: Rezende and Mohamed "Variational Inference with Normalizing Flows" (2016); https://arxiv.org/abs/1505.05770.
     """
-    def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], n_layers: int = 2):
-        if n_layers < 1:
-            raise ValueError(f"Flow needs at least one layer, but got {n_layers}")
-        super().__init__(event_shape, [
-            ElementwiseAffine(event_shape),
-            *[Radial(event_shape) for _ in range(n_layers)],
-            ElementwiseAffine(event_shape)
-        ])
+    def __init__(self,
+                 event_shape: Union[torch.Size, Tuple[int, ...]],
+                 **kwargs):
+        super().__init__(event_shape, Radial, **kwargs)
 
 
-class SylvesterFlow(BijectiveComposition):
+class SylvesterFlow(ResidualArchitecture):
     """Sylvester flow architecture.
 
     Note: this model currently supports only one-way transformations.
@@ -101,11 +80,7 @@ class SylvesterFlow(BijectiveComposition):
     Reference: Van den Berg et al. "Sylvester Normalizing Flows for Variational Inference" (2019); https://arxiv.org/abs/1803.05649.
     """
-    def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], n_layers: int = 2, **kwargs):
-        if n_layers < 1:
-            raise ValueError(f"Flow needs at least one layer, but got {n_layers}")
-        super().__init__(event_shape, [
-            ElementwiseAffine(event_shape),
-            *[Sylvester(event_shape, **kwargs) for _ in range(n_layers)],
-            ElementwiseAffine(event_shape)
-        ])
+    def __init__(self,
+                 event_shape: Union[torch.Size, Tuple[int, ...]],
+                 **kwargs):
+        super().__init__(event_shape, Sylvester, **kwargs)
 
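With every architecture above delegating to `ResidualArchitecture`, per-block options travel through a shared `(event_shape, n_layers, layer_kwargs)` signature instead of bespoke constructors; `ProximalResFlow` merely seeds `layer_kwargs['gamma'] = 0.01` as a default. A sketch of the resulting call pattern (parameter values are illustrative, not library defaults):

    from torchflows.bijections.finite.residual.architectures import ResFlow, ProximalResFlow

    # Each of the n_layers blocks is built as layer_class(event_shape, **layer_kwargs)
    # and sandwiched between ElementwiseAffine layers.
    resflow = ResFlow(event_shape=(2, 3), n_layers=4)

    # Overriding the seeded gamma default:
    proximal = ProximalResFlow(event_shape=(2, 3), layer_kwargs={'gamma': 0.05})
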
diff --git a/torchflows/bijections/finite/residual/base.py b/torchflows/bijections/finite/residual/base.py
index 55b0e61..9ac4e13 100644
--- a/torchflows/bijections/finite/residual/base.py
+++ b/torchflows/bijections/finite/residual/base.py
@@ -1,4 +1,4 @@
-from typing import Union, Tuple, List
+from typing import Union, Tuple, List, Type
 
 import torch
 
@@ -7,8 +7,18 @@
 from torchflows.utils import get_batch_shape, unflatten_event, flatten_event, flatten_batch, unflatten_batch
 
 
+class ClassicResidualBijection(Bijection):
+    def __init__(self,
+                 event_shape: Union[torch.Size, Tuple[int, ...]],
+                 inverse: bool = False,
+                 **kwargs):
+        super().__init__(event_shape, **kwargs)
+        if inverse:
+            self.invert()
+
+
 class ResidualBijection(Bijection):
-    def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]]):
+    def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], **kwargs):
         """
         g maps from (*batch_shape, n_event_dims) to (*batch_shape, n_event_dims)
@@ -65,18 +75,24 @@ def inverse(self,
         return x, log_det
 
 
-class ResidualComposition(BijectiveComposition):
-    def __init__(self, blocks: List[ResidualBijection]):
-        assert len(blocks) > 0
-        event_shape = blocks[0].event_shape
+class ResidualArchitecture(BijectiveComposition):
+    def __init__(self,
+                 event_shape: Union[Tuple[int, ...], torch.Size],
+                 layer_class: Type[Union[ResidualBijection, ClassicResidualBijection]],
+                 n_layers: int = 2,
+                 layer_kwargs: dict = None,
+                 **kwargs):
+        assert n_layers > 0
+        layer_kwargs = layer_kwargs or {}
 
-        updated_layers = [ElementwiseAffine(event_shape)]
-        for i in range(len(blocks)):
-            updated_layers.append(blocks[i])
-            updated_layers.append(ElementwiseAffine(event_shape))
+        layers = [ElementwiseAffine(event_shape)]
+        for i in range(n_layers):
+            layers.append(layer_class(event_shape, **layer_kwargs))
+            layers.append(ElementwiseAffine(event_shape))
 
         super().__init__(
-            event_shape=updated_layers[0].event_shape,
-            layers=updated_layers,
-            context_shape=updated_layers[0].context_shape
+            event_shape=layers[0].event_shape,
+            layers=layers,
+            context_shape=layers[0].context_shape,
+            **kwargs
         )
diff --git a/torchflows/bijections/finite/residual/planar.py b/torchflows/bijections/finite/residual/planar.py
index 1edc957..80e061b 100644
--- a/torchflows/bijections/finite/residual/planar.py
+++ b/torchflows/bijections/finite/residual/planar.py
@@ -2,25 +2,13 @@
 import torch
 import torch.nn as nn
 
-from torchflows.bijections.base import Bijection
+from torchflows.bijections.finite.residual.base import ClassicResidualBijection
 from torchflows.utils import get_batch_shape
 
 
-class Planar(Bijection):
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.inv_planar = InversePlanar(*args, **kwargs)
-
-    def forward(self, x: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]:
-        return self.inv_planar.inverse(z=x, context=context)
-
-    def inverse(self, z: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]:
-        return self.inv_planar.forward(x=z, context=context)
-
-
-class InversePlanar(Bijection):
-    def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]]):
-        super().__init__(event_shape)
+class Planar(ClassicResidualBijection):
+    def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], **kwargs):
+        super().__init__(event_shape, **kwargs)
         self.w = nn.Parameter(torch.randn(size=(self.n_dim,)))
         self.u = nn.Parameter(torch.randn(size=(self.n_dim,)))
         self.b = nn.Parameter(torch.randn(size=()))
diff --git a/torchflows/bijections/finite/residual/radial.py b/torchflows/bijections/finite/residual/radial.py
index 385e291..47eb1c3 100644
--- a/torchflows/bijections/finite/residual/radial.py
+++ b/torchflows/bijections/finite/residual/radial.py
@@ -4,15 +4,15 @@
 import torch.nn as nn
 from torch.nn.functional import softplus
 
-from torchflows.bijections.base import Bijection
+from torchflows.bijections.finite.residual.base import ClassicResidualBijection
 from torchflows.utils import get_batch_shape
 
 
-class Radial(Bijection):
+class Radial(ClassicResidualBijection):
     # as per Rezende, Mohamed (2015)
-    def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]]):
-        super().__init__(event_shape)
+    def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], **kwargs):
+        super().__init__(event_shape, **kwargs)
         self.beta = nn.Parameter(torch.randn(size=()))
         self.unconstrained_alpha = nn.Parameter(torch.randn(size=()))
         self.z0 = nn.Parameter(torch.randn(size=(self.n_dim,)))
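`ClassicResidualBijection` also gives every classic layer a construction-time `inverse` flag, which simply calls the new `invert()` on the freshly built module. Since `Planar` and `Radial` now forward `**kwargs` to it, a layer can be built pre-inverted; a minimal sketch under that assumption (the event shape is illustrative):

    from torchflows.bijections.finite.residual.radial import Radial

    layer = Radial(event_shape=(3,), inverse=True)  # forward and inverse already swapped
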
diff --git a/torchflows/bijections/finite/residual/sylvester.py b/torchflows/bijections/finite/residual/sylvester.py
index 5208448..9a39aca 100644
--- a/torchflows/bijections/finite/residual/sylvester.py
+++ b/torchflows/bijections/finite/residual/sylvester.py
@@ -3,17 +3,18 @@
 import torch
 import torch.nn as nn
 
-from torchflows.bijections.base import Bijection
+from torchflows.bijections.finite.residual.base import ClassicResidualBijection
 from torchflows.bijections.matrices import UpperTriangularInvertibleMatrix, HouseholderOrthogonalMatrix, \
     IdentityMatrix, PermutationMatrix
 from torchflows.utils import get_batch_shape
 
 
-class BaseSylvester(Bijection):
+class BaseSylvester(ClassicResidualBijection):
     def __init__(self,
                  event_shape: Union[torch.Size, Tuple[int, ...]],
-                 m: int = None):
-        super().__init__(event_shape)
+                 m: int = None,
+                 **kwargs):
+        super().__init__(event_shape, **kwargs)
         self.n_dim = int(torch.prod(torch.as_tensor(event_shape)))
 
         if m is None:
@@ -75,14 +76,14 @@ def inverse(self, z: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.
 
 
 class HouseholderSylvester(BaseSylvester):
-    def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], m: int = None):
-        super().__init__(event_shape, m)
+    def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], **kwargs):
+        super().__init__(event_shape, **kwargs)
         self.q = HouseholderOrthogonalMatrix(n_dim=self.n_dim, n_factors=self.m)
 
 
 class IdentitySylvester(BaseSylvester):
-    def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], m: int = None):
-        super().__init__(event_shape, m)
+    def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], **kwargs):
+        super().__init__(event_shape, **kwargs)
         self.q = IdentityMatrix(n_dim=self.n_dim)
 
 
@@ -90,6 +91,6 @@ def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], m: int = Non
 class PermutationSylvester(BaseSylvester):
-    def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], m: int = None):
-        super().__init__(event_shape, m)
+    def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], **kwargs):
+        super().__init__(event_shape, **kwargs)
         self.q = PermutationMatrix(n_dim=self.n_dim)
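Because the Sylvester variants now forward `**kwargs` through `BaseSylvester` to `ClassicResidualBijection`, both the width parameter `m` and the `inverse` flag ride the same keyword path. A speculative sketch (direct construction of these layers is not exercised by the new test; shapes and values are illustrative):

    from torchflows.bijections.finite.residual.sylvester import IdentitySylvester

    # m is consumed by BaseSylvester, inverse by ClassicResidualBijection
    layer = IdentitySylvester(event_shape=(4,), m=2, inverse=True)
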