diff --git a/test/test_sampling.py b/test/test_sampling.py
new file mode 100644
index 0000000..2d78772
--- /dev/null
+++ b/test/test_sampling.py
@@ -0,0 +1,11 @@
+import pytest
+
+from torchflows import Flow
+from torchflows.architectures import PlanarFlow, SylvesterFlow, RadialFlow
+
+
+@pytest.mark.parametrize('arch_cls', [PlanarFlow, SylvesterFlow, RadialFlow])
+def test_basic(arch_cls):
+    event_shape = (1, 2, 3, 4)
+    f = Flow(arch_cls(event_shape=event_shape))
+    assert f.sample((10,)).shape == (10, *event_shape)
diff --git a/torchflows/bijections/finite/residual/architectures.py b/torchflows/bijections/finite/residual/architectures.py
index acfb1a3..d14e392 100644
--- a/torchflows/bijections/finite/residual/architectures.py
+++ b/torchflows/bijections/finite/residual/architectures.py
@@ -3,11 +3,11 @@
 import torch
 
 from torchflows.bijections.base import BijectiveComposition
-from torchflows.bijections.finite.autoregressive.transformers.linear.affine import Affine
+from torchflows.bijections.finite.autoregressive.layers import ElementwiseAffine
 from torchflows.bijections.finite.residual.base import ResidualComposition
 from torchflows.bijections.finite.residual.iterative import InvertibleResNetBlock, ResFlowBlock
 from torchflows.bijections.finite.residual.proximal import ProximalResFlowBlock
-from torchflows.bijections.finite.residual.planar import Planar
+from torchflows.bijections.finite.residual.planar import Planar, InversePlanar
 from torchflows.bijections.finite.residual.radial import Radial
 from torchflows.bijections.finite.residual.sylvester import Sylvester
 
@@ -17,6 +17,7 @@ class InvertibleResNet(ResidualComposition):
 
     Reference: Behrmann et al. "Invertible Residual Networks" (2019); https://arxiv.org/abs/1811.00995.
     """
+
     def __init__(self, event_shape, context_shape=None, n_layers: int = 2, **kwargs):
         blocks = [
             InvertibleResNetBlock(event_shape=event_shape, context_shape=context_shape, **kwargs)
@@ -30,6 +31,7 @@ class ResFlow(ResidualComposition):
 
     Reference: Chen et al. "Residual Flows for Invertible Generative Modeling" (2020); https://arxiv.org/abs/1906.02735.
     """
+
     def __init__(self, event_shape, context_shape=None, n_layers: int = 2, **kwargs):
         blocks = [
             ResFlowBlock(event_shape=event_shape, context_shape=context_shape, **kwargs)
@@ -43,6 +45,7 @@ class ProximalResFlow(ResidualComposition):
 
     Reference: Hertrich "Proximal Residual Flows for Bayesian Inverse Problems" (2022); https://arxiv.org/abs/2211.17158.
     """
+
     def __init__(self, event_shape, context_shape=None, n_layers: int = 2, **kwargs):
         blocks = [
             ProximalResFlowBlock(event_shape=event_shape, context_shape=context_shape, gamma=0.01, **kwargs)
@@ -58,13 +61,17 @@ class PlanarFlow(BijectiveComposition):
 
     Reference: Rezende and Mohamed "Variational Inference with Normalizing Flows" (2016); https://arxiv.org/abs/1505.05770.
""" - def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], n_layers: int = 2): + + def __init__(self, + event_shape: Union[torch.Size, Tuple[int, ...]], + n_layers: int = 2, + inverse: bool = True): if n_layers < 1: raise ValueError(f"Flow needs at least one layer, but got {n_layers}") super().__init__(event_shape, [ - Affine(event_shape), - *[Planar(event_shape) for _ in range(n_layers)], - Affine(event_shape) + ElementwiseAffine(event_shape), + *[(InversePlanar if inverse else Planar)(event_shape) for _ in range(n_layers)], + ElementwiseAffine(event_shape) ]) @@ -75,13 +82,14 @@ class RadialFlow(BijectiveComposition): Reference: Rezende and Mohamed "Variational Inference with Normalizing Flows" (2016); https://arxiv.org/abs/1505.05770. """ + def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], n_layers: int = 2): if n_layers < 1: raise ValueError(f"Flow needs at least one layer, but got {n_layers}") super().__init__(event_shape, [ - Affine(event_shape), + ElementwiseAffine(event_shape), *[Radial(event_shape) for _ in range(n_layers)], - Affine(event_shape) + ElementwiseAffine(event_shape) ]) @@ -92,11 +100,12 @@ class SylvesterFlow(BijectiveComposition): Reference: Van den Berg et al. "Sylvester Normalizing Flows for Variational Inference" (2019); https://arxiv.org/abs/1803.05649. """ + def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], n_layers: int = 2, **kwargs): if n_layers < 1: raise ValueError(f"Flow needs at least one layer, but got {n_layers}") super().__init__(event_shape, [ - Affine(event_shape), + ElementwiseAffine(event_shape), *[Sylvester(event_shape, **kwargs) for _ in range(n_layers)], - Affine(event_shape) + ElementwiseAffine(event_shape) ]) diff --git a/torchflows/bijections/finite/residual/planar.py b/torchflows/bijections/finite/residual/planar.py index b6d59d3..1edc957 100644 --- a/torchflows/bijections/finite/residual/planar.py +++ b/torchflows/bijections/finite/residual/planar.py @@ -41,10 +41,10 @@ def inverse(self, z: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch. z = z.view(*batch_shape, self.n_dim) # x = z + u * self.h(w.T @ z + self.b) - x = z + u * self.h(torch.einsum('...i,...i', w, z) + self.b) + x = z + u * self.h(torch.einsum('...i,...i', w, z) + self.b)[..., None] # phi = self.h_deriv(w.T @ z + self.b) * w - phi = self.h_deriv(torch.einsum('...i,...i', w, z) + self.b) * w + phi = w * self.h_deriv(torch.einsum('...i,...i', w, z) + self.b)[..., None] # log_det = torch.log(torch.abs(1 + u.T @ phi)) log_det = torch.log(torch.abs(1 + torch.einsum('...i,...i', u, phi))) diff --git a/torchflows/bijections/finite/residual/radial.py b/torchflows/bijections/finite/residual/radial.py index 736b320..385e291 100644 --- a/torchflows/bijections/finite/residual/radial.py +++ b/torchflows/bijections/finite/residual/radial.py @@ -30,7 +30,7 @@ def h(self, z): def h_deriv(self, z): batch_shape = z.shape[:-1] z0 = self.z0.view(*([1] * len(batch_shape)), *self.z0.shape) - sign = (-1.0) ** torch.where(z - z0 < 0)[0] + sign = (-1.0) ** torch.less(z, z0).float() return -(self.h(z) ** 2) * sign * z def forward(self, x: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]: @@ -54,7 +54,7 @@ def inverse(self, z: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch. 
         log_det = torch.abs(torch.add(
             (self.n_dim - 1) * torch.log1p(beta_times_h_val),
             torch.log(1 + beta_times_h_val + self.h_deriv(z) * r)
-        ))
+        )).sum(dim=-1)
 
         x = x.view(*batch_shape, *self.event_shape)
         return x, log_det
diff --git a/torchflows/bijections/finite/residual/sylvester.py b/torchflows/bijections/finite/residual/sylvester.py
index 24ba10e..5208448 100644
--- a/torchflows/bijections/finite/residual/sylvester.py
+++ b/torchflows/bijections/finite/residual/sylvester.py
@@ -18,6 +18,8 @@ def __init__(self,
 
         if m is None:
             m = self.n_dim // 2
+        if m > self.n_dim:
+            raise ValueError(f"m must not exceed the number of event dimensions ({self.n_dim}), but got {m}")
         self.m = m
 
         self.b = nn.Parameter(torch.randn(m))
@@ -29,13 +31,13 @@ def __init__(self,
     @property
     def w(self):
         r_tilde = self.r_tilde.mat()
-        q = self.q.mat()
+        q = self.q.mat()[:, :self.m]
         return torch.einsum('...ij,...kj->...ik', r_tilde, q)
 
     @property
     def u(self):
         r = self.r.mat()
-        q = self.q.mat()
+        q = self.q.mat()[:, :self.m]
         return torch.einsum('...ij,...jk->...ik', q, r)
 
     def h(self, x):
@@ -49,21 +51,21 @@ def forward(self, x: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]:
 
     def inverse(self, z: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]:
         batch_shape = get_batch_shape(z, self.event_shape)
+        z_flat = torch.flatten(z, start_dim=len(batch_shape))
         u = self.u.view(*([1] * len(batch_shape)), *self.u.shape)
         w = self.w.view(*([1] * len(batch_shape)), *self.w.shape)
         b = self.b.view(*([1] * len(batch_shape)), *self.b.shape)
 
-        wzpb = torch.einsum('...ij,...j->...i', w, z) + b  # (..., m)
+        wzpb = torch.einsum('...ij,...j->...i', w, z_flat) + b  # (..., m)
 
-        z = z.view(*batch_shape, self.n_dim)
-        x = z + torch.einsum(
+        x = z_flat + torch.einsum(
             '...ij,...j->...i',
             u,
             self.h(wzpb)
         )
 
         wu = torch.einsum('...ij,...jk->...ik', w, u)  # (..., m, m)
-        diag = torch.zeros(size=(batch_shape, self.m, self.m))
+        diag = torch.zeros(size=(*batch_shape, self.m, self.m))
         diag[..., range(self.m), range(self.m)] = self.h_deriv(wzpb)  # (..., m, m)
 
         _, log_det = torch.linalg.slogdet(torch.eye(self.m) + torch.einsum('...ij,...jk->...ik', diag, wu))
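---
Why the new `[..., None]` in the planar.py hunk: `torch.einsum('...i,...i', w, z)` contracts the feature axis, so `self.h(...)` returns a tensor of batch shape `(...,)` while `u` and `w` keep shape `(..., n_dim)`. The scalar activation therefore needs a trailing axis before the elementwise multiplication. A minimal standalone sketch of that broadcasting fix (tanh stands in for `self.h`; shapes are illustrative):

    import torch

    batch, n_dim = 10, 3
    u = torch.randn(batch, n_dim)
    w = torch.randn(batch, n_dim)
    z = torch.randn(batch, n_dim)
    b = 0.5

    act = torch.tanh(torch.einsum('...i,...i', w, z) + b)  # shape: (batch,)

    # Without the trailing axis, (batch,) would be broadcast against the
    # feature axis of (batch, n_dim), which errors outright (or silently
    # mis-broadcasts when batch == n_dim). [..., None] instead aligns the
    # scalar activation with its batch element.
    x = z + u * act[..., None]  # shape: (batch, n_dim)
    assert x.shape == (batch, n_dim)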
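Usage sketch (illustrative, not part of the patch): constructing and sampling the refactored architectures after this change, mirroring test/test_sampling.py above. The constructor signatures and the `inverse` flag are taken from the diff; the rest is an assumption about the surrounding torchflows API.

    from torchflows import Flow
    from torchflows.architectures import PlanarFlow, RadialFlow, SylvesterFlow

    event_shape = (1, 2, 3, 4)

    # PlanarFlow now defaults to inverse=True, i.e. it stacks InversePlanar
    # layers between the two ElementwiseAffine layers; pass inverse=False to
    # recover the original Planar blocks.
    flow = Flow(PlanarFlow(event_shape=event_shape, n_layers=2))

    x = flow.sample((10,))
    assert x.shape == (10, *event_shape)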