Refactor classic residual bijections
davidnabergoj committed Sep 1, 2024
1 parent a510134 commit 3bc85b0
Showing 9 changed files with 108 additions and 104 deletions.
5 changes: 0 additions & 5 deletions test/test_autograd_bijections.py
@@ -11,11 +11,6 @@
LRSCoupling, LinearRQSCoupling
from torchflows.bijections.finite.linear import LU, ReversePermutation, LowerTriangular, Orthogonal, QR
from torchflows.bijections.finite.residual.architectures import InvertibleResNet, ResFlow, ProximalResFlow
from torchflows.bijections.finite.residual.iterative import InvertibleResNetBlock, ResFlowBlock
from torchflows.bijections.finite.residual.planar import Planar
from torchflows.bijections.finite.residual.proximal import ProximalResFlowBlock
from torchflows.bijections.finite.residual.radial import Radial
from torchflows.bijections.finite.residual.sylvester import Sylvester
from torchflows.utils import get_batch_shape
from test.constants import __test_constants

26 changes: 26 additions & 0 deletions test/test_invert_classic_residual.py
@@ -0,0 +1,26 @@
import pytest
import torch

from torchflows.flows import Flow
from torchflows.bijections.finite.residual.architectures import RadialFlow, SylvesterFlow, PlanarFlow


@pytest.mark.parametrize(
'architecture_class',
[
RadialFlow,
SylvesterFlow,
PlanarFlow
]
)
def test_basic(architecture_class):
torch.manual_seed(0)
event_shape = (1, 2, 3, 4)
batch_shape = (5, 6)

flow = Flow(architecture_class(event_shape))
x_new = flow.sample(batch_shape)
assert x_new.shape == (*batch_shape, *event_shape)

flow.bijection.invert()
assert flow.log_prob(x_new).shape == batch_shape
2 changes: 1 addition & 1 deletion torchflows/__init__.py
@@ -46,7 +46,7 @@
'ProximalResFlow',
'Radial',
'Planar',
'Sylvester',
'InverseSylvester',
'ElementwiseShift',
'ElementwiseAffine',
'ElementwiseRQSpline',
3 changes: 3 additions & 0 deletions torchflows/bijections/base.py
@@ -98,6 +98,9 @@ def batch_inverse(self, x: torch.Tensor, batch_size: int, context: torch.Tensor
def regularization(self):
return 0.0

def invert(self):
self.forward, self.inverse = self.inverse, self.forward


def invert(bijection: Bijection) -> Bijection:
"""Swap the forward and inverse methods of the input bijection.
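The new instance-level invert() complements the module-level invert function above by swapping a bijection's forward and inverse methods in place. A minimal usage sketch, mirroring the new test (shapes are illustrative):

import torch

from torchflows.flows import Flow
from torchflows.bijections.finite.residual.architectures import PlanarFlow

torch.manual_seed(0)
flow = Flow(PlanarFlow(event_shape=(4,)))

x = flow.sample((10,))     # planar layers transform in one direction only
flow.bijection.invert()    # swap forward and inverse in place
log_px = flow.log_prob(x)  # density evaluation now runs through the former forward pass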
85 changes: 30 additions & 55 deletions torchflows/bijections/finite/residual/architectures.py
@@ -2,59 +2,49 @@

import torch

from torchflows.bijections.base import BijectiveComposition
from torchflows.bijections.finite.autoregressive.layers import ElementwiseAffine
from torchflows.bijections.finite.residual.base import ResidualComposition
from torchflows.bijections.finite.residual.base import ResidualArchitecture
from torchflows.bijections.finite.residual.iterative import InvertibleResNetBlock, ResFlowBlock
from torchflows.bijections.finite.residual.proximal import ProximalResFlowBlock
from torchflows.bijections.finite.residual.planar import Planar, InversePlanar
from torchflows.bijections.finite.residual.planar import Planar
from torchflows.bijections.finite.residual.radial import Radial
from torchflows.bijections.finite.residual.sylvester import Sylvester


class InvertibleResNet(ResidualComposition):
class InvertibleResNet(ResidualArchitecture):
"""Invertible residual network (i-ResNet) architecture.
Reference: Behrmann et al. "Invertible Residual Networks" (2019); https://arxiv.org/abs/1811.00995.
"""

def __init__(self, event_shape, context_shape=None, n_layers: int = 2, **kwargs):
blocks = [
InvertibleResNetBlock(event_shape=event_shape, context_shape=context_shape, **kwargs)
for _ in range(n_layers)
]
super().__init__(blocks)
def __init__(self, event_shape, **kwargs):
super().__init__(event_shape, InvertibleResNetBlock, **kwargs)


class ResFlow(ResidualComposition):
class ResFlow(ResidualArchitecture):
"""Residual flow (ResFlow) architecture.
Reference: Chen et al. "Residual Flows for Invertible Generative Modeling" (2020); https://arxiv.org/abs/1906.02735.
"""

def __init__(self, event_shape, context_shape=None, n_layers: int = 2, **kwargs):
blocks = [
ResFlowBlock(event_shape=event_shape, context_shape=context_shape, **kwargs)
for _ in range(n_layers)
]
super().__init__(blocks)
def __init__(self, event_shape, **kwargs):
super().__init__(event_shape, ResFlowBlock, **kwargs)


class ProximalResFlow(ResidualComposition):
class ProximalResFlow(ResidualArchitecture):
"""Proximal residual flow architecture.
Reference: Hertrich "Proximal Residual Flows for Bayesian Inverse Problems" (2022); https://arxiv.org/abs/2211.17158.
"""

def __init__(self, event_shape, context_shape=None, n_layers: int = 2, **kwargs):
blocks = [
ProximalResFlowBlock(event_shape=event_shape, context_shape=context_shape, gamma=0.01, **kwargs)
for _ in range(n_layers)
]
super().__init__(blocks)
def __init__(self, event_shape, **kwargs):
if 'layer_kwargs' not in kwargs:
kwargs['layer_kwargs'] = {}
if 'gamma' not in kwargs['layer_kwargs']:
kwargs['layer_kwargs']['gamma'] = 0.01
super().__init__(event_shape, ProximalResFlowBlock, **kwargs)


class PlanarFlow(BijectiveComposition):
class PlanarFlow(ResidualArchitecture):
"""Planar flow architecture.
Note: this model currently supports only one-way transformations.
@@ -64,48 +54,33 @@ class PlanarFlow(BijectiveComposition):

def __init__(self,
event_shape: Union[torch.Size, Tuple[int, ...]],
n_layers: int = 2,
inverse: bool = True):
if n_layers < 1:
raise ValueError(f"Flow needs at least one layer, but got {n_layers}")
super().__init__(event_shape, [
ElementwiseAffine(event_shape),
*[(InversePlanar if inverse else Planar)(event_shape) for _ in range(n_layers)],
ElementwiseAffine(event_shape)
])


class RadialFlow(BijectiveComposition):
**kwargs):
super().__init__(event_shape, Planar, **kwargs)


class RadialFlow(ResidualArchitecture):
"""Radial flow architecture.
Note: this model currently supports only one-way transformations.
Reference: Rezende and Mohamed "Variational Inference with Normalizing Flows" (2016); https://arxiv.org/abs/1505.05770.
"""

def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], n_layers: int = 2):
if n_layers < 1:
raise ValueError(f"Flow needs at least one layer, but got {n_layers}")
super().__init__(event_shape, [
ElementwiseAffine(event_shape),
*[Radial(event_shape) for _ in range(n_layers)],
ElementwiseAffine(event_shape)
])
def __init__(self,
event_shape: Union[torch.Size, Tuple[int, ...]],
**kwargs):
super().__init__(event_shape, Radial, **kwargs)


class SylvesterFlow(BijectiveComposition):
class SylvesterFlow(ResidualArchitecture):
"""Sylvester flow architecture.
Note: this model currently supports only one-way transformations.
Reference: Van den Berg et al. "Sylvester Normalizing Flows for Variational Inference" (2019); https://arxiv.org/abs/1803.05649.
"""

def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], n_layers: int = 2, **kwargs):
if n_layers < 1:
raise ValueError(f"Flow needs at least one layer, but got {n_layers}")
super().__init__(event_shape, [
ElementwiseAffine(event_shape),
*[Sylvester(event_shape, **kwargs) for _ in range(n_layers)],
ElementwiseAffine(event_shape)
])
def __init__(self,
event_shape: Union[torch.Size, Tuple[int, ...]],
**kwargs):
super().__init__(event_shape, Sylvester, **kwargs)
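With every architecture now deriving from ResidualArchitecture, the layer count and per-layer options are passed uniformly through n_layers and layer_kwargs. A sketch of the assumed usage (the layer options shown, m and gamma, come from the BaseSylvester and ProximalResFlowBlock signatures):

from torchflows.bijections.finite.residual.architectures import (
    InvertibleResNet,
    ProximalResFlow,
    SylvesterFlow,
)

event_shape = (2, 3)
resnet = InvertibleResNet(event_shape, n_layers=4)                    # four i-ResNet blocks
sylvester = SylvesterFlow(event_shape, layer_kwargs={'m': 5})         # 'm' forwarded to BaseSylvester
proximal = ProximalResFlow(event_shape, layer_kwargs={'gamma': 0.1})  # overrides the 0.01 default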
42 changes: 29 additions & 13 deletions torchflows/bijections/finite/residual/base.py
@@ -1,4 +1,4 @@
from typing import Union, Tuple, List
from typing import Union, Tuple, List, Type

import torch

@@ -7,8 +7,18 @@
from torchflows.utils import get_batch_shape, unflatten_event, flatten_event, flatten_batch, unflatten_batch


class ClassicResidualBijection(Bijection):
def __init__(self,
event_shape: Union[torch.Size, Tuple[int, ...]],
inverse: bool = False,
**kwargs):
super().__init__(event_shape, **kwargs)
if inverse:
self.invert()


class ResidualBijection(Bijection):
def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]]):
def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], **kwargs):
"""
g maps from (*batch_shape, n_event_dims) to (*batch_shape, n_event_dims)
@@ -65,18 +75,24 @@ def inverse(self,
return x, log_det


class ResidualComposition(BijectiveComposition):
def __init__(self, blocks: List[ResidualBijection]):
assert len(blocks) > 0
event_shape = blocks[0].event_shape
class ResidualArchitecture(BijectiveComposition):
def __init__(self,
event_shape: Union[Tuple[int, ...], torch.Size],
layer_class: Type[Union[ResidualBijection, ClassicResidualBijection]],
n_layers: int = 2,
layer_kwargs: dict = None,
**kwargs):
assert n_layers > 0
layer_kwargs = layer_kwargs or {}

updated_layers = [ElementwiseAffine(event_shape)]
for i in range(len(blocks)):
updated_layers.append(blocks[i])
updated_layers.append(ElementwiseAffine(event_shape))
layers = [ElementwiseAffine(event_shape)]
for i in range(n_layers):
layers.append(layer_class(event_shape, **layer_kwargs))
layers.append(ElementwiseAffine(event_shape))

super().__init__(
event_shape=updated_layers[0].event_shape,
layers=updated_layers,
context_shape=updated_layers[0].context_shape
event_shape=layers[0].event_shape,
layers=layers,
context_shape=layers[0].context_shape,
**kwargs
)
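ResidualArchitecture now owns the layer construction that each architecture previously duplicated, interleaving the chosen residual layer class with ElementwiseAffine bijections. A sketch of the resulting composition (Radial stands in for any layer class):

from torchflows.bijections.finite.residual.base import ResidualArchitecture
from torchflows.bijections.finite.residual.radial import Radial

# For n_layers=2 the constructor builds the composition
#   [ElementwiseAffine, Radial, ElementwiseAffine, Radial, ElementwiseAffine]
arch = ResidualArchitecture((3,), Radial, n_layers=2)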
20 changes: 4 additions & 16 deletions torchflows/bijections/finite/residual/planar.py
@@ -2,25 +2,13 @@
import torch
import torch.nn as nn

from torchflows.bijections.base import Bijection
from torchflows.bijections.finite.residual.base import ClassicResidualBijection
from torchflows.utils import get_batch_shape


class Planar(Bijection):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.inv_planar = InversePlanar(*args, **kwargs)

def forward(self, x: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]:
return self.inv_planar.inverse(z=x, context=context)

def inverse(self, z: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]:
return self.inv_planar.forward(x=z, context=context)


class InversePlanar(Bijection):
def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]]):
super().__init__(event_shape)
class Planar(ClassicResidualBijection):
def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], **kwargs):
super().__init__(event_shape, **kwargs)
self.w = nn.Parameter(torch.randn(size=(self.n_dim,)))
self.u = nn.Parameter(torch.randn(size=(self.n_dim,)))
self.b = nn.Parameter(torch.randn(size=()))
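The separate InversePlanar class is gone; the direction is now selected with the inverse flag inherited from ClassicResidualBijection. A sketch of the assumed equivalent usage:

from torchflows.bijections.finite.residual.planar import Planar

p = Planar(event_shape=(4,))                    # standard direction
p_inv = Planar(event_shape=(4,), inverse=True)  # replaces the old InversePlanar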
8 changes: 4 additions & 4 deletions torchflows/bijections/finite/residual/radial.py
@@ -4,15 +4,15 @@
import torch.nn as nn
from torch.nn.functional import softplus

from torchflows.bijections.base import Bijection
from torchflows.bijections.finite.residual.base import ClassicResidualBijection
from torchflows.utils import get_batch_shape


class Radial(Bijection):
class Radial(ClassicResidualBijection):
# as per Rezende, Mohamed (2015)

def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]]):
super().__init__(event_shape)
def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], **kwargs):
super().__init__(event_shape, **kwargs)
self.beta = nn.Parameter(torch.randn(size=()))
self.unconstrained_alpha = nn.Parameter(torch.randn(size=()))
self.z0 = nn.Parameter(torch.randn(size=(self.n_dim,)))
21 changes: 11 additions & 10 deletions torchflows/bijections/finite/residual/sylvester.py
@@ -3,17 +3,18 @@
import torch
import torch.nn as nn

from torchflows.bijections.base import Bijection
from torchflows.bijections.finite.residual.base import ClassicResidualBijection
from torchflows.bijections.matrices import UpperTriangularInvertibleMatrix, HouseholderOrthogonalMatrix, \
IdentityMatrix, PermutationMatrix
from torchflows.utils import get_batch_shape


class BaseSylvester(Bijection):
class BaseSylvester(ClassicResidualBijection):
def __init__(self,
event_shape: Union[torch.Size, Tuple[int, ...]],
m: int = None):
super().__init__(event_shape)
m: int = None,
**kwargs):
super().__init__(event_shape, **kwargs)
self.n_dim = int(torch.prod(torch.as_tensor(event_shape)))

if m is None:
@@ -75,21 +76,21 @@ def inverse(self, z: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]:


class HouseholderSylvester(BaseSylvester):
def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], m: int = None):
super().__init__(event_shape, m)
def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], **kwargs):
super().__init__(event_shape, **kwargs)
self.q = HouseholderOrthogonalMatrix(n_dim=self.n_dim, n_factors=self.m)


class IdentitySylvester(BaseSylvester):
def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], m: int = None):
super().__init__(event_shape, m)
def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], **kwargs):
super().__init__(event_shape, **kwargs)
self.q = IdentityMatrix(n_dim=self.n_dim)


Sylvester = IdentitySylvester


class PermutationSylvester(BaseSylvester):
def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], m: int = None):
super().__init__(event_shape, m)
def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], **kwargs):
super().__init__(event_shape, **kwargs)
self.q = PermutationMatrix(n_dim=self.n_dim)
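All three Sylvester variants now pass keyword arguments (including m) through to BaseSylvester, and Sylvester remains an alias for the identity variant. A brief sketch:

from torchflows.bijections.finite.residual.sylvester import (
    HouseholderSylvester,
    IdentitySylvester,
    Sylvester,
)

h = HouseholderSylvester(event_shape=(6,), m=3)  # 'm' forwarded through **kwargs
assert Sylvester is IdentitySylvester            # default variant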
