From 286563a746547df333676ee38df6565edcb00bab Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Sat, 31 Aug 2024 18:02:04 +0200 Subject: [PATCH 01/25] Add convolutional continuous architectures --- torchflows/architectures.py | 6 +++--- torchflows/bijections/continuous/base.py | 13 +++++++++++- torchflows/bijections/continuous/ddnf.py | 21 +++++++++++++++++-- torchflows/bijections/continuous/ffjord.py | 19 +++++++++++++++-- torchflows/bijections/continuous/rnode.py | 24 ++++++++++++++++++++-- 5 files changed, 73 insertions(+), 10 deletions(-) diff --git a/torchflows/architectures.py b/torchflows/architectures.py index b8710d8..7f19356 100644 --- a/torchflows/architectures.py +++ b/torchflows/architectures.py @@ -22,9 +22,9 @@ UMNNMAF ) -from torchflows.bijections.continuous.ddnf import DeepDiffeomorphicBijection -from torchflows.bijections.continuous.rnode import RNODE -from torchflows.bijections.continuous.ffjord import FFJORD +from torchflows.bijections.continuous.ddnf import DeepDiffeomorphicBijection, ConvolutionalDeepDiffeomorphicBijection +from torchflows.bijections.continuous.rnode import RNODE, ConvolutionalRNODE +from torchflows.bijections.continuous.ffjord import FFJORD, ConvolutionalFFJORD from torchflows.bijections.continuous.otflow import OTFlow from torchflows.bijections.finite.residual.architectures import ( diff --git a/torchflows/bijections/continuous/base.py b/torchflows/bijections/continuous/base.py index 212873c..4b795c7 100644 --- a/torchflows/bijections/continuous/base.py +++ b/torchflows/bijections/continuous/base.py @@ -5,7 +5,7 @@ import torch.nn as nn from torchflows.bijections.base import Bijection -from torchflows.bijections.continuous.layers import DiffEqLayer +from torchflows.bijections.continuous.layers import DiffEqLayer, ConcatConv2d, IgnoreConv2d import torchflows.bijections.continuous.layers as diff_eq_layers from torchflows.utils import flatten_event, flatten_batch, get_batch_shape, unflatten_batch, unflatten_event @@ -18,6 +18,7 @@ # Based on: https://github.com/rtqichen/ffjord/blob/994864ad0517db3549717c25170f9b71e96788b1/lib/layers/cnf.py#L11 + def _flip(x, dim): indices = [slice(None)] * x.dim() indices[dim] = torch.arange(x.size(dim) - 1, -1, -1, dtype=torch.long, device=x.device) @@ -86,6 +87,16 @@ def create_nn(event_size: int, hidden_size: int = None, n_hidden_layers: int = 2 return TimeDerivativeDNN(layers) +def create_cnn(c: int, n_layers: int = 2): + # c: number of image channels + return TimeDerivativeDNN([ConcatConv2d(c, c) for _ in range(n_layers)]) + + +def create_cnn_time_independent(c: int, n_layers: int = 2): + # c: number of image channels + return TimeDerivativeDNN([IgnoreConv2d(c, c) for _ in range(n_layers)]) + + class TimeDerivative(nn.Module): def __init__(self): super().__init__() diff --git a/torchflows/bijections/continuous/ddnf.py b/torchflows/bijections/continuous/ddnf.py index 4cbff88..2f60630 100644 --- a/torchflows/bijections/continuous/ddnf.py +++ b/torchflows/bijections/continuous/ddnf.py @@ -2,8 +2,12 @@ import torch -from torchflows.bijections.continuous.base import ApproximateContinuousBijection, \ - RegularizedApproximateODEFunction, create_nn_time_independent +from torchflows.bijections.continuous.base import ( + ApproximateContinuousBijection, + RegularizedApproximateODEFunction, + create_nn_time_independent, + create_cnn_time_independent +) class DeepDiffeomorphicBijection(ApproximateContinuousBijection): @@ -28,3 +32,16 @@ def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], n_steps: int diff_eq 
= RegularizedApproximateODEFunction(create_nn_time_independent(n_dim)) self.n_steps = n_steps super().__init__(event_shape, diff_eq, solver=solver, **kwargs) + +class ConvolutionalDeepDiffeomorphicBijection(ApproximateContinuousBijection): + """Convolutional variant of the DDNF architecture. + + Reference: Salman et al. "Deep diffeomorphic normalizing flows" (2018); https://arxiv.org/abs/1810.03256. + """ + + def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], n_steps: int = 150, solver="euler", **kwargs): + if len(event_shape) != 3: + raise ValueError("Event shape must be of length 3 (channels, height, width).") + diff_eq = RegularizedApproximateODEFunction(create_cnn_time_independent(event_shape[0])) + self.n_steps = n_steps + super().__init__(event_shape, diff_eq, solver=solver, **kwargs) diff --git a/torchflows/bijections/continuous/ffjord.py b/torchflows/bijections/continuous/ffjord.py index f5229ae..4f77f28 100644 --- a/torchflows/bijections/continuous/ffjord.py +++ b/torchflows/bijections/continuous/ffjord.py @@ -5,18 +5,33 @@ from torchflows.bijections.continuous.base import ( ApproximateContinuousBijection, create_nn, - RegularizedApproximateODEFunction + RegularizedApproximateODEFunction, + create_cnn ) # https://github.com/rtqichen/ffjord/blob/master/lib/layers/cnf.py class FFJORD(ApproximateContinuousBijection): - """ Free-form Jacobian of reversible dynamics (FFJORD) architecture. + """Free-form Jacobian of reversible dynamics (FFJORD) architecture. Gratwohl et al. "FFJORD: Free-form Continuous Dynamics for Scalable Reversible Generative Models" (2018); https://arxiv.org/abs/1810.01367. """ + def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], **kwargs): n_dim = int(torch.prod(torch.as_tensor(event_shape))) diff_eq = RegularizedApproximateODEFunction(create_nn(n_dim)) super().__init__(event_shape, diff_eq, **kwargs) + + +class ConvolutionalFFJORD(ApproximateContinuousBijection): + """Convolutional variant of the FFJORD architecture. + + Gratwohl et al. "FFJORD: Free-form Continuous Dynamics for Scalable Reversible Generative Models" (2018); https://arxiv.org/abs/1810.01367. + """ + + def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], **kwargs): + if len(event_shape) != 3: + raise ValueError("Event shape must be of length 3 (channels, height, width).") + diff_eq = RegularizedApproximateODEFunction(create_cnn(event_shape[0])) + super().__init__(event_shape, diff_eq, **kwargs) diff --git a/torchflows/bijections/continuous/rnode.py b/torchflows/bijections/continuous/rnode.py index 3ba06ba..d7a92f5 100644 --- a/torchflows/bijections/continuous/rnode.py +++ b/torchflows/bijections/continuous/rnode.py @@ -2,7 +2,12 @@ import torch -from torchflows.bijections.continuous.base import ApproximateContinuousBijection, create_nn, RegularizedApproximateODEFunction +from torchflows.bijections.continuous.base import ( + ApproximateContinuousBijection, + create_nn, + RegularizedApproximateODEFunction, + create_cnn +) # https://github.com/cfinlay/ffjord-rnode/blob/master/train.py @@ -12,7 +17,22 @@ class RNODE(ApproximateContinuousBijection): Reference: Finlay et al. "How to train your neural ODE: the world of Jacobian and kinetic regularization" (2020); https://arxiv.org/abs/2002.02798. 
""" + def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], **kwargs): n_dim = int(torch.prod(torch.as_tensor(event_shape))) - diff_eq = RegularizedApproximateODEFunction(create_nn(n_dim, hidden_size=100, n_hidden_layers=1), regularization="sq_jac_norm") + diff_eq = RegularizedApproximateODEFunction(create_nn(n_dim, hidden_size=100, n_hidden_layers=1), + regularization="sq_jac_norm") + super().__init__(event_shape, diff_eq, **kwargs) + + +class ConvolutionalRNODE(ApproximateContinuousBijection): + """Convolutional variant of the RNODE architecture + + Reference: Finlay et al. "How to train your neural ODE: the world of Jacobian and kinetic regularization" (2020); https://arxiv.org/abs/2002.02798. + """ + + def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], **kwargs): + if len(event_shape) != 3: + raise ValueError("Event shape must be of length 3 (channels, height, width).") + diff_eq = RegularizedApproximateODEFunction(create_cnn(event_shape[0]), regularization="sq_jac_norm") super().__init__(event_shape, diff_eq, **kwargs) From d19939b316aecc3946a0720b79cc19192125cba9 Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Sat, 31 Aug 2024 18:18:39 +0200 Subject: [PATCH 02/25] Add more Glow architectures --- torchflows/architectures.py | 5 +- .../finite/multiscale/architectures.py | 51 +++++++++++++++++++ 2 files changed, 55 insertions(+), 1 deletion(-) diff --git a/torchflows/architectures.py b/torchflows/architectures.py index 7f19356..f667bb9 100644 --- a/torchflows/architectures.py +++ b/torchflows/architectures.py @@ -44,5 +44,8 @@ MultiscaleDeepSigmoid, MultiscaleDenseSigmoid, MultiscaleDeepDenseSigmoid, - AffineGlow + AffineGlow, + RQSGlow, + LRSGlow, + ShiftGlow ) diff --git a/torchflows/bijections/finite/multiscale/architectures.py b/torchflows/bijections/finite/multiscale/architectures.py index bee149c..90191cc 100644 --- a/torchflows/bijections/finite/multiscale/architectures.py +++ b/torchflows/bijections/finite/multiscale/architectures.py @@ -178,6 +178,23 @@ def __init__(self, event_shape: Union[int, torch.Size, Tuple[int, ...]], n_layer ) +class ShiftGlow(MultiscaleBijection): + def __init__(self, event_shape: Union[int, torch.Size, Tuple[int, ...]], n_layers: int = None, **kwargs): + if isinstance(event_shape, int): + event_shape = (event_shape,) + if n_layers is None: + n_layers = automatically_determine_n_layers(event_shape) + check_image_shape_for_multiscale_flow(event_shape, n_layers) + super().__init__( + event_shape=event_shape, + transformer_class=Shift, + checkerboard_class=GlowCheckerboardCoupling, + channel_wise_class=GlowChannelWiseCoupling, + n_blocks=n_layers, + **kwargs + ) + + class AffineGlow(MultiscaleBijection): def __init__(self, event_shape: Union[int, torch.Size, Tuple[int, ...]], n_layers: int = None, **kwargs): if isinstance(event_shape, int): @@ -193,3 +210,37 @@ def __init__(self, event_shape: Union[int, torch.Size, Tuple[int, ...]], n_layer n_blocks=n_layers, **kwargs ) + + +class RQSGlow(MultiscaleBijection): + def __init__(self, event_shape: Union[int, torch.Size, Tuple[int, ...]], n_layers: int = None, **kwargs): + if isinstance(event_shape, int): + event_shape = (event_shape,) + if n_layers is None: + n_layers = automatically_determine_n_layers(event_shape) + check_image_shape_for_multiscale_flow(event_shape, n_layers) + super().__init__( + event_shape=event_shape, + transformer_class=RationalQuadratic, + checkerboard_class=GlowCheckerboardCoupling, + channel_wise_class=GlowChannelWiseCoupling, + n_blocks=n_layers, + 
**kwargs + ) + + +class LRSGlow(MultiscaleBijection): + def __init__(self, event_shape: Union[int, torch.Size, Tuple[int, ...]], n_layers: int = None, **kwargs): + if isinstance(event_shape, int): + event_shape = (event_shape,) + if n_layers is None: + n_layers = automatically_determine_n_layers(event_shape) + check_image_shape_for_multiscale_flow(event_shape, n_layers) + super().__init__( + event_shape=event_shape, + transformer_class=LinearRational, + checkerboard_class=GlowCheckerboardCoupling, + channel_wise_class=GlowChannelWiseCoupling, + n_blocks=n_layers, + **kwargs + ) From 898d3cae29a1e437b4b0d14ae9d48e4070520619 Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Sat, 31 Aug 2024 18:33:05 +0200 Subject: [PATCH 03/25] Fix convolutional continuous NF shapes --- torchflows/bijections/continuous/base.py | 39 ++++++++++++++-------- torchflows/bijections/continuous/ddnf.py | 3 +- torchflows/bijections/continuous/ffjord.py | 3 +- torchflows/bijections/continuous/layers.py | 33 ++++++++++++------ torchflows/bijections/continuous/rnode.py | 3 +- 5 files changed, 52 insertions(+), 29 deletions(-) diff --git a/torchflows/bijections/continuous/base.py b/torchflows/bijections/continuous/base.py index 4b795c7..b840a8f 100644 --- a/torchflows/bijections/continuous/base.py +++ b/torchflows/bijections/continuous/base.py @@ -56,32 +56,45 @@ def divergence_approx_extended(f, y, e: Union[torch.Tensor, Tuple[torch.Tensor]] return approx_tr_dzdx, N -def create_nn_time_independent(event_size: int, hidden_size: int = 30, n_hidden_layers: int = 2): +def create_nn_time_independent(event_shape: Union[Tuple[int, ...], torch.Size], + hidden_size: int = 30, + n_hidden_layers: int = 2): + event_size = int(torch.prod(torch.as_tensor(event_shape))) + + if hidden_size is None: + hidden_size = max(4, int(3 * math.log(event_size))) + hidden_shape = (hidden_size,) + assert n_hidden_layers >= 0 if n_hidden_layers == 0: - layers = [diff_eq_layers.IgnoreLinear(event_size, event_size)] + layers = [diff_eq_layers.IgnoreLinear(event_shape, event_shape)] else: layers = [ - diff_eq_layers.IgnoreLinear(event_size, hidden_size), - *[diff_eq_layers.IgnoreLinear(hidden_size, hidden_size) for _ in range(n_hidden_layers)], - diff_eq_layers.IgnoreLinear(hidden_size, event_size) + diff_eq_layers.IgnoreLinear(event_shape, hidden_shape), + *[diff_eq_layers.IgnoreLinear(hidden_shape, hidden_shape) for _ in range(n_hidden_layers)], + diff_eq_layers.IgnoreLinear(hidden_shape, event_shape) ] return TimeDerivativeDNN(layers) -def create_nn(event_size: int, hidden_size: int = None, n_hidden_layers: int = 2): +def create_nn(event_shape: Union[Tuple[int, ...], torch.Size], + hidden_size: int = None, + n_hidden_layers: int = 2): + event_size = int(torch.prod(torch.as_tensor(event_shape))) + if hidden_size is None: hidden_size = max(4, int(3 * math.log(event_size))) + hidden_shape = (hidden_size,) assert n_hidden_layers >= 0 if n_hidden_layers == 0: - layers = [diff_eq_layers.ConcatLinear(event_size, event_size)] + layers = [diff_eq_layers.ConcatLinear(event_shape, event_shape)] else: layers = [ - diff_eq_layers.ConcatLinear(event_size, hidden_size), - *[diff_eq_layers.ConcatLinear(hidden_size, hidden_size) for _ in range(n_hidden_layers)], - diff_eq_layers.ConcatLinear(hidden_size, event_size) + diff_eq_layers.ConcatLinear(event_shape, hidden_shape), + *[diff_eq_layers.ConcatLinear(hidden_shape, hidden_shape) for _ in range(n_hidden_layers)], + diff_eq_layers.ConcatLinear(hidden_shape, event_shape) ] return TimeDerivativeDNN(layers) @@ 
-442,7 +455,7 @@ def inverse(self, # Flatten everything to facilitate computations batch_shape = get_batch_shape(z, self.event_shape) batch_size = int(torch.prod(torch.as_tensor(batch_shape))) - z_flat = flatten_batch(flatten_event(z, self.event_shape), batch_shape) + z_flat = flatten_batch(z, batch_shape) if integration_times is None: integration_times = self.make_integrations_times(z_flat) @@ -450,7 +463,7 @@ def inverse(self, # Refresh odefunc statistics self.f.before_odeint(noise=noise) - log_det_initial = torch.zeros(size=(batch_size, 1)).to(z_flat) + log_det_initial = torch.zeros(size=(batch_size, *([1] * len(self.event_shape)))).to(z_flat) state_t = odeint( self.f, (z_flat, log_det_initial), @@ -466,7 +479,7 @@ def inverse(self, z_final_flat, log_det_final_flat = state_t[:2] # Reshape back to original shape - x = unflatten_event(unflatten_batch(z_final_flat, batch_shape), self.event_shape) + x = unflatten_batch(z_final_flat, batch_shape) log_det = log_det_final_flat.view(*batch_shape) return x, log_det diff --git a/torchflows/bijections/continuous/ddnf.py b/torchflows/bijections/continuous/ddnf.py index 2f60630..3fe235b 100644 --- a/torchflows/bijections/continuous/ddnf.py +++ b/torchflows/bijections/continuous/ddnf.py @@ -28,8 +28,7 @@ def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], n_steps: int :param event_shape: shape of the event tensor. :param n_steps: parameter T in the paper, i.e. the number of ResNet cells. """ - n_dim = int(torch.prod(torch.as_tensor(event_shape))) - diff_eq = RegularizedApproximateODEFunction(create_nn_time_independent(n_dim)) + diff_eq = RegularizedApproximateODEFunction(create_nn_time_independent(event_shape)) self.n_steps = n_steps super().__init__(event_shape, diff_eq, solver=solver, **kwargs) diff --git a/torchflows/bijections/continuous/ffjord.py b/torchflows/bijections/continuous/ffjord.py index 4f77f28..a3aed56 100644 --- a/torchflows/bijections/continuous/ffjord.py +++ b/torchflows/bijections/continuous/ffjord.py @@ -19,8 +19,7 @@ class FFJORD(ApproximateContinuousBijection): """ def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], **kwargs): - n_dim = int(torch.prod(torch.as_tensor(event_shape))) - diff_eq = RegularizedApproximateODEFunction(create_nn(n_dim)) + diff_eq = RegularizedApproximateODEFunction(create_nn(event_shape)) super().__init__(event_shape, diff_eq, **kwargs) diff --git a/torchflows/bijections/continuous/layers.py b/torchflows/bijections/continuous/layers.py index 48cff63..aa6dd91 100644 --- a/torchflows/bijections/continuous/layers.py +++ b/torchflows/bijections/continuous/layers.py @@ -57,12 +57,18 @@ class IgnoreLinear(DiffEqLayer): Apply y = A @ x + b without any time information. 
""" - def __init__(self, dim_in, dim_out): + def __init__(self, input_shape, output_shape): super().__init__() - self._layer = nn.Linear(dim_in, dim_out) + self.input_shape = input_shape + self.input_size = int(torch.prod(torch.as_tensor(input_shape))) + self.output_shape = output_shape + self.output_size = int(torch.prod(torch.as_tensor(output_shape))) + self._layer = nn.Linear(self.input_size, self.output_size) def forward(self, t, x): - return self._layer(x) + x_flat = x.view(-1, self.input_size) + out_flat = self._layer(x_flat) + return out_flat.view(-1, *self.output_shape) class ConcatLinear(DiffEqLayer): @@ -70,14 +76,21 @@ class ConcatLinear(DiffEqLayer): Apply y = A @ [t; x] + b """ - def __init__(self, dim_in, dim_out): + def __init__(self, input_shape, output_shape): super().__init__() - self._layer = nn.Linear(dim_in + 1, dim_out) + self.input_shape = input_shape + self.input_size = int(torch.prod(torch.as_tensor(input_shape))) + self.output_shape = output_shape + self.output_size = int(torch.prod(torch.as_tensor(output_shape))) + self._layer = nn.Linear(self.input_size + 1, self.output_size) def forward(self, t, x): - tt = torch.ones_like(x[:, :1]) * t - ttx = torch.cat([tt, x], 1) - return self._layer(ttx) + # x.shape = (b, *input_shape) + x_flat = x.view(-1, self.input_size) + tt_flat = torch.ones_like(x_flat[:, :1]) * t + ttx_flat = torch.cat([tt_flat, x_flat], 1) + out_flat = self._layer(ttx_flat) + return out_flat.view(-1, *self.output_shape) class ConcatLinear_v2(DiffEqLayer): @@ -169,7 +182,7 @@ class IgnoreConv2d(DiffEqLayer): Apply y = Conv2d(x, W, b) """ - def __init__(self, dim_in, dim_out, ksize=3, stride=1, padding=0, dilation=1, groups=1, bias=True, transpose=False): + def __init__(self, dim_in, dim_out, ksize=3, stride=1, padding=1, dilation=1, groups=1, bias=True, transpose=False): super().__init__() module = nn.ConvTranspose2d if transpose else nn.Conv2d self._layer = module( @@ -204,7 +217,7 @@ class ConcatConv2d(DiffEqLayer): Apply y(t) = Conv2d(concatenate([x, t]), W, b) """ - def __init__(self, dim_in, dim_out, ksize=3, stride=1, padding=0, dilation=1, groups=1, bias=True, transpose=False): + def __init__(self, dim_in, dim_out, ksize=3, stride=1, padding=1, dilation=1, groups=1, bias=True, transpose=False): super().__init__() module = nn.ConvTranspose2d if transpose else nn.Conv2d self._layer = module( diff --git a/torchflows/bijections/continuous/rnode.py b/torchflows/bijections/continuous/rnode.py index d7a92f5..f1213ac 100644 --- a/torchflows/bijections/continuous/rnode.py +++ b/torchflows/bijections/continuous/rnode.py @@ -19,8 +19,7 @@ class RNODE(ApproximateContinuousBijection): """ def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], **kwargs): - n_dim = int(torch.prod(torch.as_tensor(event_shape))) - diff_eq = RegularizedApproximateODEFunction(create_nn(n_dim, hidden_size=100, n_hidden_layers=1), + diff_eq = RegularizedApproximateODEFunction(create_nn(event_shape, hidden_size=100, n_hidden_layers=1), regularization="sq_jac_norm") super().__init__(event_shape, diff_eq, **kwargs) From e68af96967eb0345df7cb6a99b313ce9b2bae86e Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Sat, 31 Aug 2024 18:34:47 +0200 Subject: [PATCH 04/25] Create Glow tests and convolutional continuous NF tests --- test/test_convolutional_architectures.py | 73 ++++++++++++++++++++++++ test/test_multiscale_bijections.py | 48 ---------------- 2 files changed, 73 insertions(+), 48 deletions(-) create mode 100644 test/test_convolutional_architectures.py delete mode 
100644 test/test_multiscale_bijections.py diff --git a/test/test_convolutional_architectures.py b/test/test_convolutional_architectures.py new file mode 100644 index 0000000..166765f --- /dev/null +++ b/test/test_convolutional_architectures.py @@ -0,0 +1,73 @@ +from torchflows.architectures import ( + MultiscaleNICE, + MultiscaleRQNSF, + MultiscaleLRSNSF, + MultiscaleRealNVP, + AffineGlow, + RQSGlow, + LRSGlow, + ShiftGlow, + ConvolutionalRNODE, + ConvolutionalDeepDiffeomorphicBijection, + ConvolutionalFFJORD +) +import torch +import pytest +from test.constants import __test_constants + + +@pytest.mark.parametrize('architecture_class', [ + MultiscaleRQNSF, + MultiscaleLRSNSF, + MultiscaleNICE, + MultiscaleRealNVP, + AffineGlow, + RQSGlow, + LRSGlow, + ShiftGlow +]) +@pytest.mark.parametrize('image_shape', [(1, 28, 28), (3, 28, 28)]) +def test_autoregressive(architecture_class, image_shape): + torch.manual_seed(0) + x = torch.randn(size=(5, *image_shape)) + bijection = architecture_class(image_shape, n_layers=2) + z, ldf = bijection.forward(x) + xr, ldi = bijection.inverse(z) + assert torch.allclose(x, xr, atol=__test_constants['data_atol']) + assert torch.allclose(ldf, -ldi, atol=__test_constants['log_det_atol']) # 1e-2 + + +@pytest.mark.skip('Unsupported/failing') +@pytest.mark.parametrize('architecture_class', [ + ConvolutionalRNODE, + ConvolutionalDeepDiffeomorphicBijection, + ConvolutionalFFJORD +]) +@pytest.mark.parametrize('image_shape', [(1, 28, 28), (3, 28, 28)]) +def test_continuous(architecture_class, image_shape): + torch.manual_seed(0) + x = torch.randn(size=(5, *image_shape)) + bijection = architecture_class(image_shape) + z, ldf = bijection.forward(x) + xr, ldi = bijection.inverse(z) + assert x.shape == xr.shape + assert ldf.shape == ldi.shape + assert torch.allclose(x, xr, atol=__test_constants['data_atol']), f'"{(x-xr).abs().max()}"' + assert torch.allclose(ldf, -ldi, atol=__test_constants['log_det_atol']) # 1e-2 + + +@pytest.mark.parametrize('architecture_class', [ + MultiscaleRQNSF, + MultiscaleLRSNSF, + MultiscaleNICE, + MultiscaleRealNVP, + AffineGlow, + RQSGlow, + LRSGlow, + ShiftGlow +]) +@pytest.mark.parametrize('image_shape', [(1, 28, 28), (3, 28, 28)]) +def test_too_small_image(architecture_class, image_shape): + torch.manual_seed(0) + with pytest.raises(ValueError): + bijection = architecture_class(image_shape, n_layers=3) diff --git a/test/test_multiscale_bijections.py b/test/test_multiscale_bijections.py deleted file mode 100644 index cfd4493..0000000 --- a/test/test_multiscale_bijections.py +++ /dev/null @@ -1,48 +0,0 @@ -from torchflows.architectures import ( - MultiscaleNICE, - MultiscaleRQNSF, - MultiscaleLRSNSF, - MultiscaleRealNVP -) -import torch -import pytest -from test.constants import __test_constants - - -@pytest.mark.parametrize('architecture_class', [ - MultiscaleRQNSF, - MultiscaleLRSNSF, - MultiscaleNICE, - MultiscaleRealNVP -]) -@pytest.mark.parametrize('image_shape', [(1, 28, 28), (3, 28, 28)]) -def test_basic(architecture_class, image_shape): - torch.manual_seed(0) - x = torch.randn(size=(5, *image_shape)) - bijection = architecture_class(image_shape, n_layers=2) - z, ldf = bijection.forward(x) - xr, ldi = bijection.inverse(z) - assert torch.allclose(x, xr, atol=__test_constants['data_atol']) - assert torch.allclose(ldf, -ldi, atol=__test_constants['log_det_atol']) # 1e-2 - - -@pytest.mark.parametrize('architecture_class', [ - MultiscaleRQNSF, - MultiscaleLRSNSF, - MultiscaleNICE, - MultiscaleRealNVP -]) 
-@pytest.mark.parametrize('image_shape', [(1, 28, 28), (3, 28, 28)]) -def test_too_small_image(architecture_class, image_shape): - torch.manual_seed(0) - with pytest.raises(ValueError): - bijection = architecture_class(image_shape, n_layers=3) - - - - - - - - - From 2e9a2ef5129ca88029d638ffb30f7bc7de43de6b Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Sun, 1 Sep 2024 00:01:34 +0200 Subject: [PATCH 05/25] Simplify autoregressive architectures --- .../finite/autoregressive/architectures.py | 340 ++++++------------ 1 file changed, 110 insertions(+), 230 deletions(-) diff --git a/torchflows/bijections/finite/autoregressive/architectures.py b/torchflows/bijections/finite/autoregressive/architectures.py index 1aba1e7..962a25d 100644 --- a/torchflows/bijections/finite/autoregressive/architectures.py +++ b/torchflows/bijections/finite/autoregressive/architectures.py @@ -1,4 +1,6 @@ -from typing import Tuple, List, Type, Union +from typing import Type, Union, Tuple + +import torch from torchflows.bijections.finite.autoregressive.layers import ( ShiftCoupling, @@ -26,392 +28,270 @@ from torchflows.bijections.finite.linear import ReversePermutation -def make_basic_layers(base_bijection: Type[ - Union[CouplingBijection, MaskedAutoregressiveBijection, InverseMaskedAutoregressiveBijection]], - event_shape, - n_layers: int = 2, - edge_list: List[Tuple[int, int]] = None, - **kwargs): - """ - Returns a list of bijections for transformations of vectors. - """ - bijections = [ElementwiseAffine(event_shape=event_shape)] - for _ in range(n_layers): - if edge_list is None: - bijections.append(ReversePermutation(event_shape=event_shape)) - bijections.append(base_bijection(event_shape=event_shape, edge_list=edge_list, **kwargs)) +class AutoregressiveArchitecture(BijectiveComposition): + def __init__(self, + event_shape: Union[Tuple[int, ...], torch.Size, int], + base_bijection: Type[ + Union[ + CouplingBijection, + MaskedAutoregressiveBijection, + InverseMaskedAutoregressiveBijection + ] + ], + n_layers: int = 2, + **kwargs): + if isinstance(event_shape, int): + event_shape = (event_shape,) + bijections = [ElementwiseAffine(event_shape=event_shape)] + for _ in range(n_layers): + if 'edge_list' not in kwargs or kwargs['edge_list'] is None: + bijections.append(ReversePermutation(event_shape=event_shape)) + bijections.append(base_bijection(event_shape=event_shape, **kwargs)) + bijections.append(ActNorm(event_shape=event_shape)) + bijections.append(ElementwiseAffine(event_shape=event_shape)) bijections.append(ActNorm(event_shape=event_shape)) - bijections.append(ElementwiseAffine(event_shape=event_shape)) - bijections.append(ActNorm(event_shape=event_shape)) - return bijections + super().__init__(event_shape, bijections) -class NICE(BijectiveComposition): +class NICE(AutoregressiveArchitecture): """Nonlinear independent components estimation (NICE) architecture. Reference: Dinh et al. "NICE: Non-linear Independent Components Estimation" (2015); https://arxiv.org/abs/1410.8516. 
""" - def __init__(self, - event_shape, - n_layers: int = 2, - edge_list: List[Tuple[int, int]] = None, - **kwargs): - if isinstance(event_shape, int): - event_shape = (event_shape,) - bijections = make_basic_layers(ShiftCoupling, event_shape, n_layers, edge_list) - super().__init__(event_shape, bijections, **kwargs) + def __init__(self, event_shape, **kwargs): + super().__init__(event_shape, base_bijection=ShiftCoupling, **kwargs) -class RealNVP(BijectiveComposition): +class RealNVP(AutoregressiveArchitecture): """Real non-volume-preserving (Real NVP) architecture. Reference: Dinh et al. "Density estimation using Real NVP" (2017); https://arxiv.org/abs/1605.08803. """ - def __init__(self, - event_shape, - n_layers: int = 2, - edge_list: List[Tuple[int, int]] = None, - **kwargs): - if isinstance(event_shape, int): - event_shape = (event_shape,) - bijections = make_basic_layers(AffineCoupling, event_shape, n_layers, edge_list) - super().__init__(event_shape, bijections, **kwargs) + def __init__(self, event_shape, **kwargs): + super().__init__(event_shape, base_bijection=AffineCoupling, **kwargs) -class InverseRealNVP(BijectiveComposition): +class InverseRealNVP(AutoregressiveArchitecture): """Inverse of the Real NVP architecture. Reference: Dinh et al. "Density estimation using Real NVP" (2017); https://arxiv.org/abs/1605.08803. """ - def __init__(self, - event_shape, - n_layers: int = 2, - edge_list: List[Tuple[int, int]] = None, - **kwargs): - if isinstance(event_shape, int): - event_shape = (event_shape,) - bijections = make_basic_layers(InverseAffineCoupling, event_shape, n_layers, edge_list) - super().__init__(event_shape, bijections, **kwargs) + def __init__(self, event_shape, **kwargs): + super().__init__(event_shape, base_bijection=InverseAffineCoupling, **kwargs) -class MAF(BijectiveComposition): +class MAF(AutoregressiveArchitecture): """Masked autoregressive flow (MAF) architecture. Reference: Papamakarios et al. "Masked Autoregressive Flow for Density Estimation" (2018); https://arxiv.org/abs/1705.07057. """ - def __init__(self, event_shape, n_layers: int = 2, **kwargs): - if isinstance(event_shape, int): - event_shape = (event_shape,) - bijections = make_basic_layers(AffineForwardMaskedAutoregressive, event_shape, n_layers) - super().__init__(event_shape, bijections, **kwargs) + def __init__(self, event_shape, **kwargs): + super().__init__(event_shape, base_bijection=AffineForwardMaskedAutoregressive, **kwargs) -class IAF(BijectiveComposition): +class IAF(AutoregressiveArchitecture): """Inverse autoregressive flow (IAF) architecture. Reference: Kingma et al. "Improving Variational Inference with Inverse Autoregressive Flow" (2017); https://arxiv.org/abs/1606.04934. """ - def __init__(self, event_shape, n_layers: int = 2, **kwargs): - if isinstance(event_shape, int): - event_shape = (event_shape,) - bijections = make_basic_layers(AffineInverseMaskedAutoregressive, event_shape, n_layers) - super().__init__(event_shape, bijections, **kwargs) + def __init__(self, event_shape, **kwargs): + super().__init__(event_shape, base_bijection=AffineInverseMaskedAutoregressive, **kwargs) -class CouplingRQNSF(BijectiveComposition): +class CouplingRQNSF(AutoregressiveArchitecture): """Coupling rational quadratic neural spline flow (C-RQNSF) architecture. Reference: Durkan et al. "Neural Spline Flows" (2019); https://arxiv.org/abs/1906.04032. 
""" - def __init__(self, - event_shape, - n_layers: int = 2, - edge_list: List[Tuple[int, int]] = None, - **kwargs): - if isinstance(event_shape, int): - event_shape = (event_shape,) - bijections = make_basic_layers(RQSCoupling, event_shape, n_layers, edge_list) - super().__init__(event_shape, bijections, **kwargs) + def __init__(self, event_shape, **kwargs): + super().__init__(event_shape, base_bijection=RQSCoupling, **kwargs) -class MaskedAutoregressiveRQNSF(BijectiveComposition): +class MaskedAutoregressiveRQNSF(AutoregressiveArchitecture): """Masked autoregressive rational quadratic neural spline flow (MA-RQNSF) architecture. Reference: Durkan et al. "Neural Spline Flows" (2019); https://arxiv.org/abs/1906.04032. """ - def __init__(self, event_shape, n_layers: int = 2, **kwargs): - if isinstance(event_shape, int): - event_shape = (event_shape,) - bijections = make_basic_layers(RQSForwardMaskedAutoregressive, event_shape, n_layers) - super().__init__(event_shape, bijections, **kwargs) + def __init__(self, event_shape, **kwargs): + super().__init__(event_shape, base_bijection=RQSForwardMaskedAutoregressive, **kwargs) -class CouplingLRS(BijectiveComposition): +class CouplingLRS(AutoregressiveArchitecture): """Coupling linear rational spline (C-LRS) architecture. Reference: Dolatabadi et al. "Invertible Generative Modeling using Linear Rational Splines" (2020); https://arxiv.org/abs/2001.05168. """ - def __init__(self, - event_shape, - n_layers: int = 2, - edge_list: List[Tuple[int, int]] = None, - **kwargs): - if isinstance(event_shape, int): - event_shape = (event_shape,) - bijections = make_basic_layers(LRSCoupling, event_shape, n_layers, edge_list) - super().__init__(event_shape, bijections, **kwargs) + def __init__(self, event_shape, **kwargs): + super().__init__(event_shape, base_bijection=LRSCoupling, **kwargs) -class MaskedAutoregressiveLRS(BijectiveComposition): +class MaskedAutoregressiveLRS(AutoregressiveArchitecture): """Masked autoregressive linear rational spline (MA-LRS) architecture. Reference: Dolatabadi et al. "Invertible Generative Modeling using Linear Rational Splines" (2020); https://arxiv.org/abs/2001.05168. """ - def __init__(self, event_shape, n_layers: int = 2, **kwargs): - if isinstance(event_shape, int): - event_shape = (event_shape,) - bijections = make_basic_layers(LRSForwardMaskedAutoregressive, event_shape, n_layers) - super().__init__(event_shape, bijections, **kwargs) + def __init__(self, event_shape, **kwargs): + super().__init__(event_shape, base_bijection=LRSForwardMaskedAutoregressive, **kwargs) -class InverseAutoregressiveRQNSF(BijectiveComposition): +class InverseAutoregressiveRQNSF(AutoregressiveArchitecture): """Inverse autoregressive rational quadratic neural spline flow (IA-RQNSF) architecture. Reference: Durkan et al. "Neural Spline Flows" (2019); https://arxiv.org/abs/1906.04032. """ - def __init__(self, event_shape, n_layers: int = 2, **kwargs): - if isinstance(event_shape, int): - event_shape = (event_shape,) - bijections = make_basic_layers(RQSInverseMaskedAutoregressive, event_shape, n_layers) - super().__init__(event_shape, bijections, **kwargs) + def __init__(self, event_shape, **kwargs): + super().__init__(event_shape, base_bijection=RQSInverseMaskedAutoregressive, **kwargs) -class InverseAutoregressiveLRS(BijectiveComposition): +class InverseAutoregressiveLRS(AutoregressiveArchitecture): """Inverse autoregressive linear rational spline (MA-LRS) architecture. Reference: Dolatabadi et al. 
"Invertible Generative Modeling using Linear Rational Splines" (2020); https://arxiv.org/abs/2001.05168. """ - def __init__(self, event_shape, n_layers: int = 2, **kwargs): - if isinstance(event_shape, int): - event_shape = (event_shape,) - bijections = make_basic_layers(LRSInverseMaskedAutoregressive, event_shape, n_layers) - super().__init__(event_shape, bijections, **kwargs) + def __init__(self, event_shape, **kwargs): + super().__init__(event_shape, base_bijection=LRSInverseMaskedAutoregressive, **kwargs) -class CouplingDeepSF(BijectiveComposition): +class CouplingDeepSF(AutoregressiveArchitecture): """Coupling deep sigmoidal flow architecture. Reference: Huang et al. "Neural Autoregressive Flows" (2018); https://arxiv.org/abs/1804.00779. """ - def __init__(self, - event_shape, - n_layers: int = 2, - edge_list: List[Tuple[int, int]] = None, - **kwargs): - if isinstance(event_shape, int): - event_shape = (event_shape,) - bijections = make_basic_layers(DeepSigmoidalCoupling, event_shape, n_layers, edge_list) - super().__init__(event_shape, bijections, **kwargs) + def __init__(self, event_shape, **kwargs): + super().__init__(event_shape, base_bijection=DeepSigmoidalCoupling, **kwargs) -class InverseAutoregressiveDeepSF(BijectiveComposition): +class InverseAutoregressiveDeepSF(AutoregressiveArchitecture): """Inverse autoregressive deep sigmoidal flow architecture. Reference: Huang et al. "Neural Autoregressive Flows" (2018); https://arxiv.org/abs/1804.00779. """ - def __init__(self, - event_shape, - n_layers: int = 2, - edge_list: List[Tuple[int, int]] = None, - **kwargs): - if isinstance(event_shape, int): - event_shape = (event_shape,) - bijections = make_basic_layers(DeepSigmoidalInverseMaskedAutoregressive, event_shape, n_layers, edge_list) - super().__init__(event_shape, bijections, **kwargs) + def __init__(self, event_shape, **kwargs): + super().__init__(event_shape, base_bijection=DeepSigmoidalInverseMaskedAutoregressive, **kwargs) -class MaskedAutoregressiveDeepSF(BijectiveComposition): +class MaskedAutoregressiveDeepSF(AutoregressiveArchitecture): """Masked autoregressive deep sigmoidal flow architecture. Reference: Huang et al. "Neural Autoregressive Flows" (2018); https://arxiv.org/abs/1804.00779. """ - def __init__(self, - event_shape, - n_layers: int = 2, - edge_list: List[Tuple[int, int]] = None, - **kwargs): - if isinstance(event_shape, int): - event_shape = (event_shape,) - bijections = make_basic_layers(DeepSigmoidalForwardMaskedAutoregressive, event_shape, n_layers, edge_list) - super().__init__(event_shape, bijections, **kwargs) + def __init__(self, event_shape, **kwargs): + super().__init__(event_shape, base_bijection=DeepSigmoidalForwardMaskedAutoregressive, **kwargs) -class CouplingDenseSF(BijectiveComposition): +class CouplingDenseSF(AutoregressiveArchitecture): """Coupling dense sigmoidal flow architecture. Reference: Huang et al. "Neural Autoregressive Flows" (2018); https://arxiv.org/abs/1804.00779. 
""" - def __init__(self, - event_shape, - n_layers: int = 2, - edge_list: List[Tuple[int, int]] = None, - percentage_global_parameters: float = 0.8, - **kwargs): - if isinstance(event_shape, int): - event_shape = (event_shape,) - bijections = make_basic_layers( - DenseSigmoidalCoupling, + def __init__(self, event_shape, percentage_global_parameters: float = 0.8, **kwargs): + super().__init__( event_shape, - n_layers, - edge_list, - percentage_global_parameters=percentage_global_parameters + base_bijection=DenseSigmoidalCoupling, + percentage_global_parameters=percentage_global_parameters, + **kwargs ) - super().__init__(event_shape, bijections, **kwargs) -class InverseAutoregressiveDenseSF(BijectiveComposition): +class InverseAutoregressiveDenseSF(AutoregressiveArchitecture): """Inverse autoregressive dense sigmoidal flow architecture. Reference: Huang et al. "Neural Autoregressive Flows" (2018); https://arxiv.org/abs/1804.00779. """ - def __init__(self, - event_shape, - n_layers: int = 2, - edge_list: List[Tuple[int, int]] = None, - percentage_global_parameters: float = 0.8, - **kwargs): - if isinstance(event_shape, int): - event_shape = (event_shape,) - bijections = make_basic_layers( - DenseSigmoidalInverseMaskedAutoregressive, + def __init__(self, event_shape, percentage_global_parameters: float = 0.8, **kwargs): + super().__init__( event_shape, - n_layers, - edge_list, - percentage_global_parameters=percentage_global_parameters + base_bijection=DenseSigmoidalInverseMaskedAutoregressive, + percentage_global_parameters=percentage_global_parameters, + **kwargs ) - super().__init__(event_shape, bijections, **kwargs) -class MaskedAutoregressiveDenseSF(BijectiveComposition): +class MaskedAutoregressiveDenseSF(AutoregressiveArchitecture): """Masked autoregressive dense sigmoidal flow architecture. Reference: Huang et al. "Neural Autoregressive Flows" (2018); https://arxiv.org/abs/1804.00779. """ - def __init__(self, - event_shape, - n_layers: int = 2, - edge_list: List[Tuple[int, int]] = None, - percentage_global_parameters: float = 0.8, - **kwargs): - if isinstance(event_shape, int): - event_shape = (event_shape,) - bijections = make_basic_layers( - DenseSigmoidalForwardMaskedAutoregressive, + def __init__(self, event_shape, percentage_global_parameters: float = 0.8, **kwargs): + super().__init__( event_shape, - n_layers, - edge_list, - percentage_global_parameters=percentage_global_parameters + base_bijection=DenseSigmoidalForwardMaskedAutoregressive, + percentage_global_parameters=percentage_global_parameters, + **kwargs ) - super().__init__(event_shape, bijections, **kwargs) -class CouplingDeepDenseSF(BijectiveComposition): +class CouplingDeepDenseSF(AutoregressiveArchitecture): """Coupling deep-dense sigmoidal flow architecture. Reference: Huang et al. "Neural Autoregressive Flows" (2018); https://arxiv.org/abs/1804.00779. 
""" - def __init__(self, - event_shape, - n_layers: int = 2, - edge_list: List[Tuple[int, int]] = None, - percentage_global_parameters: float = 0.8, - **kwargs): - if isinstance(event_shape, int): - event_shape = (event_shape,) - bijections = make_basic_layers( - DeepDenseSigmoidalCoupling, + def __init__(self, event_shape, percentage_global_parameters: float = 0.8, **kwargs): + super().__init__( event_shape, - n_layers, - edge_list, - percentage_global_parameters=percentage_global_parameters + base_bijection=DeepDenseSigmoidalCoupling, + percentage_global_parameters=percentage_global_parameters, + **kwargs ) - super().__init__(event_shape, bijections, **kwargs) -class InverseAutoregressiveDeepDenseSF(BijectiveComposition): +class InverseAutoregressiveDeepDenseSF(AutoregressiveArchitecture): """Inverse autoregressive deep-dense sigmoidal flow architecture. Reference: Huang et al. "Neural Autoregressive Flows" (2018); https://arxiv.org/abs/1804.00779. """ - def __init__(self, - event_shape, - n_layers: int = 2, - edge_list: List[Tuple[int, int]] = None, - percentage_global_parameters: float = 0.8, - **kwargs): - if isinstance(event_shape, int): - event_shape = (event_shape,) - bijections = make_basic_layers( - DeepDenseSigmoidalInverseMaskedAutoregressive, + def __init__(self, event_shape, percentage_global_parameters: float = 0.8, **kwargs): + super().__init__( event_shape, - n_layers, - edge_list, - percentage_global_parameters=percentage_global_parameters + base_bijection=DeepDenseSigmoidalInverseMaskedAutoregressive, + percentage_global_parameters=percentage_global_parameters, + **kwargs ) - super().__init__(event_shape, bijections, **kwargs) -class MaskedAutoregressiveDeepDenseSF(BijectiveComposition): +class MaskedAutoregressiveDeepDenseSF(AutoregressiveArchitecture): """Masked autoregressive deep-dense sigmoidal flow architecture. Reference: Huang et al. "Neural Autoregressive Flows" (2018); https://arxiv.org/abs/1804.00779. """ - def __init__(self, - event_shape, - n_layers: int = 2, - edge_list: List[Tuple[int, int]] = None, - percentage_global_parameters: float = 0.8, - **kwargs): - if isinstance(event_shape, int): - event_shape = (event_shape,) - bijections = make_basic_layers( - DeepDenseSigmoidalForwardMaskedAutoregressive, + def __init__(self, event_shape, percentage_global_parameters: float = 0.8, **kwargs): + super().__init__( event_shape, - n_layers, - edge_list, - percentage_global_parameters=percentage_global_parameters + base_bijection=DeepDenseSigmoidalForwardMaskedAutoregressive, + percentage_global_parameters=percentage_global_parameters, + **kwargs ) - super().__init__(event_shape, bijections, **kwargs) -class UMNNMAF(BijectiveComposition): +class UMNNMAF(AutoregressiveArchitecture): """Unconstrained monotonic neural network masked autoregressive flow (UMNN-MAF) architecture. Reference: Wehenkel and Louppe "Unconstrained Monotonic Neural Networks" (2021); https://arxiv.org/abs/1908.05164. 
""" - def __init__(self, event_shape, n_layers: int = 1, **kwargs): - if isinstance(event_shape, int): - event_shape = (event_shape,) - bijections = make_basic_layers(UMNNMaskedAutoregressive, event_shape, n_layers) - super().__init__(event_shape, bijections, **kwargs) + def __init__(self, event_shape, **kwargs): + super().__init__( + event_shape, + base_bijection=DeepDenseSigmoidalForwardMaskedAutoregressive, + **kwargs + ) From c5ad1b0aa2756b9eaf341bbc387065bee3402e4b Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Sun, 1 Sep 2024 00:02:05 +0200 Subject: [PATCH 06/25] Fix UMNN MAF --- torchflows/bijections/finite/autoregressive/architectures.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchflows/bijections/finite/autoregressive/architectures.py b/torchflows/bijections/finite/autoregressive/architectures.py index 962a25d..cb1d826 100644 --- a/torchflows/bijections/finite/autoregressive/architectures.py +++ b/torchflows/bijections/finite/autoregressive/architectures.py @@ -292,6 +292,6 @@ class UMNNMAF(AutoregressiveArchitecture): def __init__(self, event_shape, **kwargs): super().__init__( event_shape, - base_bijection=DeepDenseSigmoidalForwardMaskedAutoregressive, + base_bijection=UMNNMaskedAutoregressive, **kwargs ) From 71a92ed6c284ab879a6ed0be5ea0512b066d1bea Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Sun, 1 Sep 2024 00:15:58 +0200 Subject: [PATCH 07/25] Add docstrings for class constructors --- .../finite/autoregressive/architectures.py | 177 +++++++++++++++--- 1 file changed, 156 insertions(+), 21 deletions(-) diff --git a/torchflows/bijections/finite/autoregressive/architectures.py b/torchflows/bijections/finite/autoregressive/architectures.py index cb1d826..8975685 100644 --- a/torchflows/bijections/finite/autoregressive/architectures.py +++ b/torchflows/bijections/finite/autoregressive/architectures.py @@ -59,7 +59,12 @@ class NICE(AutoregressiveArchitecture): Reference: Dinh et al. "NICE: Non-linear Independent Components Estimation" (2015); https://arxiv.org/abs/1410.8516. """ - def __init__(self, event_shape, **kwargs): + def __init__(self, event_shape: Union[Tuple[int, ...], torch.Size, int], **kwargs): + """ + + :param event_shape: shape of the event tensor. + :param kwargs: keyword arguments to ShiftCoupling. + """ super().__init__(event_shape, base_bijection=ShiftCoupling, **kwargs) @@ -69,7 +74,12 @@ class RealNVP(AutoregressiveArchitecture): Reference: Dinh et al. "Density estimation using Real NVP" (2017); https://arxiv.org/abs/1605.08803. """ - def __init__(self, event_shape, **kwargs): + def __init__(self, event_shape: Union[Tuple[int, ...], torch.Size, int], **kwargs): + """ + + :param event_shape: shape of the event tensor. + :param kwargs: keyword arguments to AffineCoupling. + """ super().__init__(event_shape, base_bijection=AffineCoupling, **kwargs) @@ -79,7 +89,12 @@ class InverseRealNVP(AutoregressiveArchitecture): Reference: Dinh et al. "Density estimation using Real NVP" (2017); https://arxiv.org/abs/1605.08803. """ - def __init__(self, event_shape, **kwargs): + def __init__(self, event_shape: Union[Tuple[int, ...], torch.Size, int], **kwargs): + """ + + :param event_shape: shape of the event tensor. + :param kwargs: keyword arguments to InverseAffineCoupling. + """ super().__init__(event_shape, base_bijection=InverseAffineCoupling, **kwargs) @@ -89,7 +104,12 @@ class MAF(AutoregressiveArchitecture): Reference: Papamakarios et al. 
"Masked Autoregressive Flow for Density Estimation" (2018); https://arxiv.org/abs/1705.07057. """ - def __init__(self, event_shape, **kwargs): + def __init__(self, event_shape: Union[Tuple[int, ...], torch.Size, int], **kwargs): + """ + + :param event_shape: shape of the event tensor. + :param kwargs: keyword arguments to AffineForwardMaskedAutoregressive. + """ super().__init__(event_shape, base_bijection=AffineForwardMaskedAutoregressive, **kwargs) @@ -99,7 +119,12 @@ class IAF(AutoregressiveArchitecture): Reference: Kingma et al. "Improving Variational Inference with Inverse Autoregressive Flow" (2017); https://arxiv.org/abs/1606.04934. """ - def __init__(self, event_shape, **kwargs): + def __init__(self, event_shape: Union[Tuple[int, ...], torch.Size, int], **kwargs): + """ + + :param event_shape: shape of the event tensor. + :param kwargs: keyword arguments to AffineInverseMaskedAutoregressive. + """ super().__init__(event_shape, base_bijection=AffineInverseMaskedAutoregressive, **kwargs) @@ -109,7 +134,12 @@ class CouplingRQNSF(AutoregressiveArchitecture): Reference: Durkan et al. "Neural Spline Flows" (2019); https://arxiv.org/abs/1906.04032. """ - def __init__(self, event_shape, **kwargs): + def __init__(self, event_shape: Union[Tuple[int, ...], torch.Size, int], **kwargs): + """ + + :param event_shape: shape of the event tensor. + :param kwargs: keyword arguments to RQSCoupling. + """ super().__init__(event_shape, base_bijection=RQSCoupling, **kwargs) @@ -119,7 +149,12 @@ class MaskedAutoregressiveRQNSF(AutoregressiveArchitecture): Reference: Durkan et al. "Neural Spline Flows" (2019); https://arxiv.org/abs/1906.04032. """ - def __init__(self, event_shape, **kwargs): + def __init__(self, event_shape: Union[Tuple[int, ...], torch.Size, int], **kwargs): + """ + + :param event_shape: shape of the event tensor. + :param kwargs: keyword arguments to RQSForwardMaskedAutoregressive. + """ super().__init__(event_shape, base_bijection=RQSForwardMaskedAutoregressive, **kwargs) @@ -129,7 +164,12 @@ class CouplingLRS(AutoregressiveArchitecture): Reference: Dolatabadi et al. "Invertible Generative Modeling using Linear Rational Splines" (2020); https://arxiv.org/abs/2001.05168. """ - def __init__(self, event_shape, **kwargs): + def __init__(self, event_shape: Union[Tuple[int, ...], torch.Size, int], **kwargs): + """ + + :param event_shape: shape of the event tensor. + :param kwargs: keyword arguments to LRSCoupling. + """ super().__init__(event_shape, base_bijection=LRSCoupling, **kwargs) @@ -139,7 +179,12 @@ class MaskedAutoregressiveLRS(AutoregressiveArchitecture): Reference: Dolatabadi et al. "Invertible Generative Modeling using Linear Rational Splines" (2020); https://arxiv.org/abs/2001.05168. """ - def __init__(self, event_shape, **kwargs): + def __init__(self, event_shape: Union[Tuple[int, ...], torch.Size, int], **kwargs): + """ + + :param event_shape: shape of the event tensor. + :param kwargs: keyword arguments to LRSForwardMaskedAutoregressive. + """ super().__init__(event_shape, base_bijection=LRSForwardMaskedAutoregressive, **kwargs) @@ -149,7 +194,12 @@ class InverseAutoregressiveRQNSF(AutoregressiveArchitecture): Reference: Durkan et al. "Neural Spline Flows" (2019); https://arxiv.org/abs/1906.04032. """ - def __init__(self, event_shape, **kwargs): + def __init__(self, event_shape: Union[Tuple[int, ...], torch.Size, int], **kwargs): + """ + + :param event_shape: shape of the event tensor. + :param kwargs: keyword arguments to RQSInverseMaskedAutoregressive. 
+ """ super().__init__(event_shape, base_bijection=RQSInverseMaskedAutoregressive, **kwargs) @@ -159,7 +209,12 @@ class InverseAutoregressiveLRS(AutoregressiveArchitecture): Reference: Dolatabadi et al. "Invertible Generative Modeling using Linear Rational Splines" (2020); https://arxiv.org/abs/2001.05168. """ - def __init__(self, event_shape, **kwargs): + def __init__(self, event_shape: Union[Tuple[int, ...], torch.Size, int], **kwargs): + """ + + :param event_shape: shape of the event tensor. + :param kwargs: keyword arguments to LRSInverseMaskedAutoregressive. + """ super().__init__(event_shape, base_bijection=LRSInverseMaskedAutoregressive, **kwargs) @@ -169,7 +224,12 @@ class CouplingDeepSF(AutoregressiveArchitecture): Reference: Huang et al. "Neural Autoregressive Flows" (2018); https://arxiv.org/abs/1804.00779. """ - def __init__(self, event_shape, **kwargs): + def __init__(self, event_shape: Union[Tuple[int, ...], torch.Size, int], **kwargs): + """ + + :param event_shape: shape of the event tensor. + :param kwargs: keyword arguments to DeepSigmoidalCoupling. + """ super().__init__(event_shape, base_bijection=DeepSigmoidalCoupling, **kwargs) @@ -179,7 +239,12 @@ class InverseAutoregressiveDeepSF(AutoregressiveArchitecture): Reference: Huang et al. "Neural Autoregressive Flows" (2018); https://arxiv.org/abs/1804.00779. """ - def __init__(self, event_shape, **kwargs): + def __init__(self, event_shape: Union[Tuple[int, ...], torch.Size, int], **kwargs): + """ + + :param event_shape: shape of the event tensor. + :param kwargs: keyword arguments to DeepSigmoidalInverseMaskedAutoregressive. + """ super().__init__(event_shape, base_bijection=DeepSigmoidalInverseMaskedAutoregressive, **kwargs) @@ -189,7 +254,12 @@ class MaskedAutoregressiveDeepSF(AutoregressiveArchitecture): Reference: Huang et al. "Neural Autoregressive Flows" (2018); https://arxiv.org/abs/1804.00779. """ - def __init__(self, event_shape, **kwargs): + def __init__(self, event_shape: Union[Tuple[int, ...], torch.Size, int], **kwargs): + """ + + :param event_shape: shape of the event tensor. + :param kwargs: keyword arguments to DeepSigmoidalForwardMaskedAutoregressive. + """ super().__init__(event_shape, base_bijection=DeepSigmoidalForwardMaskedAutoregressive, **kwargs) @@ -199,7 +269,17 @@ class CouplingDenseSF(AutoregressiveArchitecture): Reference: Huang et al. "Neural Autoregressive Flows" (2018); https://arxiv.org/abs/1804.00779. """ - def __init__(self, event_shape, percentage_global_parameters: float = 0.8, **kwargs): + def __init__(self, + event_shape: Union[Tuple[int, ...], torch.Size, int], + percentage_global_parameters: float = 0.8, + **kwargs): + """ + + :param event_shape: shape of the event tensor. + :param percentage_global_parameters: percentage of transformer inputs to be learned globally instead of being + predicted from the conditioner neural network. + :param kwargs: keyword arguments to DenseSigmoidalCoupling. + """ super().__init__( event_shape, base_bijection=DenseSigmoidalCoupling, @@ -214,7 +294,17 @@ class InverseAutoregressiveDenseSF(AutoregressiveArchitecture): Reference: Huang et al. "Neural Autoregressive Flows" (2018); https://arxiv.org/abs/1804.00779. """ - def __init__(self, event_shape, percentage_global_parameters: float = 0.8, **kwargs): + def __init__(self, + event_shape: Union[Tuple[int, ...], torch.Size, int], + percentage_global_parameters: float = 0.8, + **kwargs): + """ + + :param event_shape: shape of the event tensor. 
+ :param percentage_global_parameters: percentage of transformer inputs to be learned globally instead of being + predicted from the conditioner neural network. + :param kwargs: keyword arguments to DenseSigmoidalInverseMaskedAutoregressive. + """ super().__init__( event_shape, base_bijection=DenseSigmoidalInverseMaskedAutoregressive, @@ -229,7 +319,17 @@ class MaskedAutoregressiveDenseSF(AutoregressiveArchitecture): Reference: Huang et al. "Neural Autoregressive Flows" (2018); https://arxiv.org/abs/1804.00779. """ - def __init__(self, event_shape, percentage_global_parameters: float = 0.8, **kwargs): + def __init__(self, + event_shape: Union[Tuple[int, ...], torch.Size, int], + percentage_global_parameters: float = 0.8, + **kwargs): + """ + + :param event_shape: shape of the event tensor. + :param percentage_global_parameters: percentage of transformer inputs to be learned globally instead of being + predicted from the conditioner neural network. + :param kwargs: keyword arguments to DenseSigmoidalForwardMaskedAutoregressive. + """ super().__init__( event_shape, base_bijection=DenseSigmoidalForwardMaskedAutoregressive, @@ -244,7 +344,17 @@ class CouplingDeepDenseSF(AutoregressiveArchitecture): Reference: Huang et al. "Neural Autoregressive Flows" (2018); https://arxiv.org/abs/1804.00779. """ - def __init__(self, event_shape, percentage_global_parameters: float = 0.8, **kwargs): + def __init__(self, + event_shape: Union[Tuple[int, ...], torch.Size, int], + percentage_global_parameters: float = 0.8, + **kwargs): + """ + + :param event_shape: shape of the event tensor. + :param percentage_global_parameters: percentage of transformer inputs to be learned globally instead of being + predicted from the conditioner neural network. + :param kwargs: keyword arguments to DeepDenseSigmoidalCoupling. + """ super().__init__( event_shape, base_bijection=DeepDenseSigmoidalCoupling, @@ -259,7 +369,17 @@ class InverseAutoregressiveDeepDenseSF(AutoregressiveArchitecture): Reference: Huang et al. "Neural Autoregressive Flows" (2018); https://arxiv.org/abs/1804.00779. """ - def __init__(self, event_shape, percentage_global_parameters: float = 0.8, **kwargs): + def __init__(self, + event_shape: Union[Tuple[int, ...], torch.Size, int], + percentage_global_parameters: float = 0.8, + **kwargs): + """ + + :param event_shape: shape of the event tensor. + :param percentage_global_parameters: percentage of transformer inputs to be learned globally instead of being + predicted from the conditioner neural network. + :param kwargs: keyword arguments to DeepDenseSigmoidalInverseMaskedAutoregressive. + """ super().__init__( event_shape, base_bijection=DeepDenseSigmoidalInverseMaskedAutoregressive, @@ -274,7 +394,17 @@ class MaskedAutoregressiveDeepDenseSF(AutoregressiveArchitecture): Reference: Huang et al. "Neural Autoregressive Flows" (2018); https://arxiv.org/abs/1804.00779. """ - def __init__(self, event_shape, percentage_global_parameters: float = 0.8, **kwargs): + def __init__(self, + event_shape: Union[Tuple[int, ...], torch.Size, int], + percentage_global_parameters: float = 0.8, + **kwargs): + """ + + :param event_shape: shape of the event tensor. + :param percentage_global_parameters: percentage of transformer inputs to be learned globally instead of being + predicted from the conditioner neural network. + :param kwargs: keyword arguments to DeepDenseSigmoidalForwardMaskedAutoregressive. 
+ """ super().__init__( event_shape, base_bijection=DeepDenseSigmoidalForwardMaskedAutoregressive, @@ -289,7 +419,12 @@ class UMNNMAF(AutoregressiveArchitecture): Reference: Wehenkel and Louppe "Unconstrained Monotonic Neural Networks" (2021); https://arxiv.org/abs/1908.05164. """ - def __init__(self, event_shape, **kwargs): + def __init__(self, event_shape: Union[Tuple[int, ...], torch.Size, int], **kwargs): + """ + + :param event_shape: shape of the event tensor. + :param kwargs: keyword arguments to UMNNMaskedAutoregressive. + """ super().__init__( event_shape, base_bijection=UMNNMaskedAutoregressive, From 42c5a06283b87f9390d1a037f44b993e0d0125fa Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Sun, 1 Sep 2024 00:44:27 +0200 Subject: [PATCH 08/25] Add docstrings --- .../finite/autoregressive/layers.py | 378 ++++++++++++++---- 1 file changed, 291 insertions(+), 87 deletions(-) diff --git a/torchflows/bijections/finite/autoregressive/layers.py b/torchflows/bijections/finite/autoregressive/layers.py index 727ca5d..cc034ee 100644 --- a/torchflows/bijections/finite/autoregressive/layers.py +++ b/torchflows/bijections/finite/autoregressive/layers.py @@ -1,4 +1,4 @@ -from typing import Tuple, List +from typing import Tuple, List, Union import torch @@ -22,18 +22,33 @@ class ElementwiseAffine(ElementwiseBijection): def __init__(self, event_shape, **kwargs): + """ + + :param Union[Tuple[int, ...], torch.Size] event_shape: shape of the event tensor. + :param kwargs: keyword arguments to Affine. + """ transformer = Affine(event_shape, **kwargs) super().__init__(transformer) class ElementwiseInverseAffine(ElementwiseBijection): def __init__(self, event_shape, **kwargs): + """ + + :param Union[Tuple[int, ...], torch.Size] event_shape: shape of the event tensor. + :param kwargs: keyword arguments to InverseAffine. + """ transformer = InverseAffine(event_shape, **kwargs) super().__init__(transformer) class ActNorm(ElementwiseInverseAffine): def __init__(self, event_shape, **kwargs): + """ + + :param Union[Tuple[int, ...], torch.Size] event_shape: shape of the event tensor. + :param kwargs: keyword arguments to ElementwiseInverseAffine. + """ super().__init__(event_shape, **kwargs) self.first_training_batch_pass: bool = True self.value.requires_grad_(False) @@ -61,34 +76,54 @@ def forward(self, x: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch. class ElementwiseScale(ElementwiseBijection): def __init__(self, event_shape, **kwargs): + """ + + :param Union[Tuple[int, ...], torch.Size] event_shape: shape of the event tensor. + :param kwargs: keyword arguments to Scale. + """ transformer = Scale(event_shape, **kwargs) super().__init__(transformer) class ElementwiseShift(ElementwiseBijection): def __init__(self, event_shape): + """ + + :param Union[Tuple[int, ...], torch.Size] event_shape: shape of the event tensor. + """ transformer = Shift(event_shape) super().__init__(transformer) class ElementwiseRQSpline(ElementwiseBijection): def __init__(self, event_shape, **kwargs): + """ + + :param Union[Tuple[int, ...], torch.Size] event_shape: shape of the event tensor. + :param kwargs: keyword arguments to RationalQuadratic. 
+ """ transformer = RationalQuadratic(event_shape, **kwargs) super().__init__(transformer) class AffineCoupling(CouplingBijection): def __init__(self, - event_shape: torch.Size, - context_shape: torch.Size = None, - edge_list: List[Tuple[int, int]] = None, + event_shape: Union[Tuple[int, ...], torch.Size], + context_shape: Union[Tuple[int, ...], torch.Size] = None, coupling_kwargs: dict = None, **kwargs): + """ + + :param Union[Tuple[int, ...], torch.Size] event_shape: shape of the event tensor. + :param Union[Tuple[int, ...], torch.Size] context_shape: shape of the context tensor. + :param dict coupling_kwargs: keyword arguments to `make_coupling`. + :param kwargs: keyword arguments to CouplingBijection. + """ if event_shape == (1,): raise ValueError if coupling_kwargs is None: coupling_kwargs = dict() - coupling = make_coupling(event_shape, edge_list, **coupling_kwargs) + coupling = make_coupling(event_shape, **coupling_kwargs) transformer = Affine(event_shape=torch.Size((coupling.target_event_size,))) conditioner_transform = FeedForward( input_event_shape=torch.Size((coupling.source_event_size,)), @@ -101,16 +136,22 @@ def __init__(self, class InverseAffineCoupling(CouplingBijection): def __init__(self, - event_shape: torch.Size, - context_shape: torch.Size = None, - edge_list: List[Tuple[int, int]] = None, + event_shape: Union[Tuple[int, ...], torch.Size], + context_shape: Union[Tuple[int, ...], torch.Size] = None, coupling_kwargs: dict = None, **kwargs): + """ + + :param Union[Tuple[int, ...], torch.Size] event_shape: shape of the event tensor. + :param Union[Tuple[int, ...], torch.Size] context_shape: shape of the context tensor. + :param dict coupling_kwargs: keyword arguments to `make_coupling`. + :param kwargs: keyword arguments to CouplingBijection. + """ if event_shape == (1,): raise ValueError if coupling_kwargs is None: coupling_kwargs = dict() - coupling = make_coupling(event_shape, edge_list, **coupling_kwargs) + coupling = make_coupling(event_shape, **coupling_kwargs) transformer = invert(Affine(event_shape=torch.Size((coupling.target_event_size,)))) conditioner_transform = FeedForward( input_event_shape=torch.Size((coupling.source_event_size,)), @@ -123,14 +164,20 @@ def __init__(self, class ShiftCoupling(CouplingBijection): def __init__(self, - event_shape: torch.Size, - context_shape: torch.Size = None, - edge_list: List[Tuple[int, int]] = None, + event_shape: Union[Tuple[int, ...], torch.Size], + context_shape: Union[Tuple[int, ...], torch.Size] = None, coupling_kwargs: dict = None, **kwargs): + """ + + :param Union[Tuple[int, ...], torch.Size] event_shape: shape of the event tensor. + :param Union[Tuple[int, ...], torch.Size] context_shape: shape of the context tensor. + :param dict coupling_kwargs: keyword arguments to `make_coupling`. + :param kwargs: keyword arguments to CouplingBijection. 
+ """ if coupling_kwargs is None: coupling_kwargs = dict() - coupling = make_coupling(event_shape, edge_list, **coupling_kwargs) + coupling = make_coupling(event_shape, **coupling_kwargs) transformer = Shift(event_shape=torch.Size((coupling.target_event_size,))) conditioner_transform = FeedForward( input_event_shape=torch.Size((coupling.source_event_size,)), @@ -143,16 +190,23 @@ def __init__(self, class LRSCoupling(CouplingBijection): def __init__(self, - event_shape: torch.Size, - context_shape: torch.Size = None, + event_shape: Union[Tuple[int, ...], torch.Size], + context_shape: Union[Tuple[int, ...], torch.Size] = None, n_bins: int = 8, - edge_list: List[Tuple[int, int]] = None, coupling_kwargs: dict = None, **kwargs): + """ + + :param Union[Tuple[int, ...], torch.Size] event_shape: shape of the event tensor. + :param Union[Tuple[int, ...], torch.Size] context_shape: shape of the context tensor. + :param int n_bins: number of spline bins. + :param dict coupling_kwargs: keyword arguments to `make_coupling`. + :param kwargs: keyword arguments to CouplingBijection. + """ assert n_bins >= 1 if coupling_kwargs is None: coupling_kwargs = dict() - coupling = make_coupling(event_shape, edge_list, **coupling_kwargs) + coupling = make_coupling(event_shape, **coupling_kwargs) transformer = LinearRational(event_shape=torch.Size((coupling.target_event_size,)), n_bins=n_bins) conditioner_transform = FeedForward( input_event_shape=torch.Size((coupling.source_event_size,)), @@ -165,15 +219,22 @@ def __init__(self, class RQSCoupling(CouplingBijection): def __init__(self, - event_shape: torch.Size, - context_shape: torch.Size = None, + event_shape: Union[Tuple[int, ...], torch.Size], + context_shape: Union[Tuple[int, ...], torch.Size] = None, n_bins: int = 8, - edge_list: List[Tuple[int, int]] = None, coupling_kwargs: dict = None, **kwargs): + """ + + :param Union[Tuple[int, ...], torch.Size] event_shape: shape of the event tensor. + :param Union[Tuple[int, ...], torch.Size] context_shape: shape of the context tensor. + :param int n_bins: number of spline bins. + :param dict coupling_kwargs: keyword arguments to `make_coupling`. + :param kwargs: keyword arguments to CouplingBijection. + """ if coupling_kwargs is None: coupling_kwargs = dict() - coupling = make_coupling(event_shape, edge_list, **coupling_kwargs) + coupling = make_coupling(event_shape, **coupling_kwargs) transformer = RationalQuadratic(event_shape=torch.Size((coupling.target_event_size,)), n_bins=n_bins) conditioner_transform = FeedForward( input_event_shape=torch.Size((coupling.source_event_size,)), @@ -186,18 +247,25 @@ def __init__(self, class DeepSigmoidalCoupling(CouplingBijection): def __init__(self, - event_shape: torch.Size, - context_shape: torch.Size = None, - n_hidden_layers: int = 2, - edge_list: List[Tuple[int, int]] = None, + event_shape: Union[Tuple[int, ...], torch.Size], + context_shape: Union[Tuple[int, ...], torch.Size] = None, + n_transformer_hidden_layers: int = 2, coupling_kwargs: dict = None, **kwargs): + """ + + :param Union[Tuple[int, ...], torch.Size] event_shape: shape of the event tensor. + :param Union[Tuple[int, ...], torch.Size] context_shape: shape of the context tensor. + :param int n_transformer_hidden_layers: number of transformer hidden layers. + :param dict coupling_kwargs: keyword arguments to `make_coupling`. + :param kwargs: keyword arguments to CouplingBijection. 
+ """ if coupling_kwargs is None: coupling_kwargs = dict() - coupling = make_coupling(event_shape, edge_list, **coupling_kwargs) + coupling = make_coupling(event_shape, **coupling_kwargs) transformer = DeepSigmoid( event_shape=torch.Size((coupling.target_event_size,)), - n_hidden_layers=n_hidden_layers + n_hidden_layers=n_transformer_hidden_layers ) # Parameter order: [c1, c2, c3, c4, ..., ck] for all components # Each component has parameter order [a_unc, b, w_unc] @@ -212,45 +280,68 @@ def __init__(self, class DeepSigmoidalInverseMaskedAutoregressive(InverseMaskedAutoregressiveBijection): def __init__(self, - event_shape: torch.Size, - context_shape: torch.Size = None, - n_hidden_layers: int = 2, + event_shape: Union[Tuple[int, ...], torch.Size], + context_shape: Union[Tuple[int, ...], torch.Size] = None, + n_transformer_hidden_layers: int = 2, **kwargs): + """ + + :param Union[Tuple[int, ...], torch.Size] event_shape: shape of the event tensor. + :param Union[Tuple[int, ...], torch.Size] context_shape: shape of the context tensor. + :param int n_transformer_hidden_layers: number of transformer hidden layers. + :param kwargs: keyword arguments to InverseMaskedAutoregressiveBijection. + """ transformer: ScalarTransformer = DeepSigmoid( event_shape=torch.Size(event_shape), - n_hidden_layers=n_hidden_layers + n_hidden_layers=n_transformer_hidden_layers ) super().__init__(event_shape, context_shape, transformer=transformer, **kwargs) class DeepSigmoidalForwardMaskedAutoregressive(MaskedAutoregressiveBijection): def __init__(self, - event_shape: torch.Size, - context_shape: torch.Size = None, - n_hidden_layers: int = 2, + event_shape: Union[Tuple[int, ...], torch.Size], + context_shape: Union[Tuple[int, ...], torch.Size] = None, + n_transformer_hidden_layers: int = 2, **kwargs): + """ + + :param Union[Tuple[int, ...], torch.Size] event_shape: shape of the event tensor. + :param Union[Tuple[int, ...], torch.Size] context_shape: shape of the context tensor. + :param int n_transformer_hidden_layers: number of transformer hidden layers. + :param kwargs: keyword arguments to MaskedAutoregressiveBijection. + """ transformer: ScalarTransformer = DeepSigmoid( event_shape=torch.Size(event_shape), - n_hidden_layers=n_hidden_layers + n_hidden_layers=n_transformer_hidden_layers ) super().__init__(event_shape, context_shape, transformer=transformer, **kwargs) class DenseSigmoidalCoupling(CouplingBijection): def __init__(self, - event_shape: torch.Size, - context_shape: torch.Size = None, - n_dense_layers: int = 2, - edge_list: List[Tuple[int, int]] = None, - coupling_kwargs: dict = None, + event_shape: Union[Tuple[int, ...], torch.Size], + context_shape: Union[Tuple[int, ...], torch.Size] = None, + n_transformer_layers: int = 2, percentage_global_parameters: float = 0.8, + coupling_kwargs: dict = None, **kwargs): + """ + + :param Union[Tuple[int, ...], torch.Size] event_shape: shape of the event tensor. + :param Union[Tuple[int, ...], torch.Size] context_shape: shape of the context tensor. + :param int n_transformer_layers: number of transformer layers. + :param float percentage_global_parameters: percentage of transformer inputs to be learned globally instead of + being predicted from the conditioner neural network. + :param dict coupling_kwargs: keyword arguments to `make_coupling`. + :param kwargs: keyword arguments to CouplingBijection. 
+ """ if coupling_kwargs is None: coupling_kwargs = dict() - coupling = make_coupling(event_shape, edge_list, **coupling_kwargs) + coupling = make_coupling(event_shape, **coupling_kwargs) transformer = DenseSigmoid( event_shape=torch.Size((coupling.target_event_size,)), - n_dense_layers=n_dense_layers + n_dense_layers=n_transformer_layers ) # Parameter order: [c1, c2, c3, c4, ..., ck] for all components # Each component has parameter order [a_unc, b, w_unc] @@ -268,14 +359,23 @@ def __init__(self, class DenseSigmoidalInverseMaskedAutoregressive(InverseMaskedAutoregressiveBijection): def __init__(self, - event_shape: torch.Size, - context_shape: torch.Size = None, - n_dense_layers: int = 2, + event_shape: Union[Tuple[int, ...], torch.Size], + context_shape: Union[Tuple[int, ...], torch.Size] = None, + n_transformer_layers: int = 2, percentage_global_parameters: float = 0.8, **kwargs): + """ + + :param Union[Tuple[int, ...], torch.Size] event_shape: shape of the event tensor. + :param Union[Tuple[int, ...], torch.Size] context_shape: shape of the context tensor. + :param int n_transformer_layers: number of transformer layers. + :param float percentage_global_parameters: percentage of transformer inputs to be learned globally instead of + being predicted from the conditioner neural network. + :param kwargs: keyword arguments to InverseMaskedAutoregressiveBijection. + """ transformer: ScalarTransformer = DenseSigmoid( event_shape=torch.Size(event_shape), - n_dense_layers=n_dense_layers + n_dense_layers=n_transformer_layers ) super().__init__( event_shape, @@ -290,14 +390,23 @@ def __init__(self, class DenseSigmoidalForwardMaskedAutoregressive(MaskedAutoregressiveBijection): def __init__(self, - event_shape: torch.Size, - context_shape: torch.Size = None, - n_dense_layers: int = 2, + event_shape: Union[Tuple[int, ...], torch.Size], + context_shape: Union[Tuple[int, ...], torch.Size] = None, + n_transformer_layers: int = 2, percentage_global_parameters: float = 0.8, **kwargs): + """ + + :param Union[Tuple[int, ...], torch.Size] event_shape: shape of the event tensor. + :param Union[Tuple[int, ...], torch.Size] context_shape: shape of the context tensor. + :param int n_transformer_layers: number of transformer layers. + :param float percentage_global_parameters: percentage of transformer inputs to be learned globally instead of + being predicted from the conditioner neural network. + :param kwargs: keyword arguments to MaskedAutoregressiveBijection. + """ transformer: ScalarTransformer = DenseSigmoid( event_shape=torch.Size(event_shape), - n_dense_layers=n_dense_layers + n_dense_layers=n_transformer_layers ) super().__init__( event_shape, @@ -312,19 +421,28 @@ def __init__(self, class DeepDenseSigmoidalCoupling(CouplingBijection): def __init__(self, - event_shape: torch.Size, - context_shape: torch.Size = None, - n_hidden_layers: int = 2, - edge_list: List[Tuple[int, int]] = None, - coupling_kwargs: dict = None, + event_shape: Union[Tuple[int, ...], torch.Size], + context_shape: Union[Tuple[int, ...], torch.Size] = None, + n_transformer_hidden_layers: int = 2, percentage_global_parameters: float = 0.8, + coupling_kwargs: dict = None, **kwargs): + """ + + :param Union[Tuple[int, ...], torch.Size] event_shape: shape of the event tensor. + :param Union[Tuple[int, ...], torch.Size] context_shape: shape of the context tensor. + :param int n_transformer_hidden_layers: number of transformer hidden layers. 
+ :param float percentage_global_parameters: percentage of transformer inputs to be learned globally instead of + being predicted from the conditioner neural network. + :param dict coupling_kwargs: keyword arguments to `make_coupling`. + :param kwargs: keyword arguments to CouplingBijection. + """ if coupling_kwargs is None: coupling_kwargs = dict() - coupling = make_coupling(event_shape, edge_list, **coupling_kwargs) + coupling = make_coupling(event_shape, **coupling_kwargs) transformer = DeepDenseSigmoid( event_shape=torch.Size((coupling.target_event_size,)), - n_hidden_layers=n_hidden_layers + n_hidden_layers=n_transformer_hidden_layers ) # Parameter order: [c1, c2, c3, c4, ..., ck] for all components # Each component has parameter order [a_unc, b, w_unc] @@ -342,14 +460,23 @@ def __init__(self, class DeepDenseSigmoidalInverseMaskedAutoregressive(InverseMaskedAutoregressiveBijection): def __init__(self, - event_shape: torch.Size, - context_shape: torch.Size = None, - n_hidden_layers: int = 2, + event_shape: Union[Tuple[int, ...], torch.Size], + context_shape: Union[Tuple[int, ...], torch.Size] = None, + n_transformer_hidden_layers: int = 2, percentage_global_parameters: float = 0.8, **kwargs): + """ + + :param Union[Tuple[int, ...], torch.Size] event_shape: shape of the event tensor. + :param Union[Tuple[int, ...], torch.Size] context_shape: shape of the context tensor. + :param int n_transformer_hidden_layers: number of transformer hidden layers. + :param float percentage_global_parameters: percentage of transformer inputs to be learned globally instead of + being predicted from the conditioner neural network. + :param kwargs: keyword arguments to InverseMaskedAutoregressiveBijection. + """ transformer: ScalarTransformer = DeepDenseSigmoid( event_shape=torch.Size(event_shape), - n_hidden_layers=n_hidden_layers + n_hidden_layers=n_transformer_hidden_layers ) super().__init__( event_shape, @@ -364,14 +491,23 @@ def __init__(self, class DeepDenseSigmoidalForwardMaskedAutoregressive(MaskedAutoregressiveBijection): def __init__(self, - event_shape: torch.Size, - context_shape: torch.Size = None, - n_hidden_layers: int = 2, + event_shape: Union[Tuple[int, ...], torch.Size], + context_shape: Union[Tuple[int, ...], torch.Size] = None, + n_transformer_hidden_layers: int = 2, percentage_global_parameters: float = 0.8, **kwargs): + """ + + :param Union[Tuple[int, ...], torch.Size] event_shape: shape of the event tensor. + :param Union[Tuple[int, ...], torch.Size] context_shape: shape of the context tensor. + :param int n_transformer_hidden_layers: number of transformer hidden layers. + :param float percentage_global_parameters: percentage of transformer inputs to be learned globally instead of + being predicted from the conditioner neural network. + :param kwargs: keyword arguments to MaskedAutoregressiveBijection. + """ transformer: ScalarTransformer = DeepDenseSigmoid( event_shape=torch.Size(event_shape), - n_hidden_layers=n_hidden_layers + n_hidden_layers=n_transformer_hidden_layers ) super().__init__( event_shape, @@ -385,69 +521,122 @@ def __init__(self, class LinearAffineCoupling(AffineCoupling): - def __init__(self, event_shape: torch.Size, **kwargs): + def __init__(self, event_shape: Union[Tuple[int, ...], torch.Size], **kwargs): + """ + + :param Union[Tuple[int, ...], torch.Size] event_shape: shape of the event tensor. + :param kwargs: keyword arguments to AffineCoupling. 
+ """ super().__init__(event_shape, **kwargs, n_layers=1) class LinearRQSCoupling(RQSCoupling): - def __init__(self, event_shape: torch.Size, **kwargs): + def __init__(self, event_shape: Union[Tuple[int, ...], torch.Size], **kwargs): + """ + + :param Union[Tuple[int, ...], torch.Size] event_shape: shape of the event tensor. + :param kwargs: keyword arguments to RQSCoupling. + """ super().__init__(event_shape, **kwargs, n_layers=1) class LinearLRSCoupling(LRSCoupling): - def __init__(self, event_shape: torch.Size, **kwargs): + def __init__(self, event_shape: Union[Tuple[int, ...], torch.Size], **kwargs): + """ + + :param Union[Tuple[int, ...], torch.Size] event_shape: shape of the event tensor. + :param kwargs: keyword arguments to LRSCoupling. + """ super().__init__(event_shape, **kwargs, n_layers=1) class LinearShiftCoupling(ShiftCoupling): - def __init__(self, event_shape: torch.Size, **kwargs): + def __init__(self, event_shape: Union[Tuple[int, ...], torch.Size], **kwargs): + """ + + :param Union[Tuple[int, ...], torch.Size] event_shape: shape of the event tensor. + :param kwargs: keyword arguments to ShiftCoupling. + """ super().__init__(event_shape, **kwargs, n_layers=1) class AffineForwardMaskedAutoregressive(MaskedAutoregressiveBijection): def __init__(self, - event_shape: torch.Size, - context_shape: torch.Size = None, + event_shape: Union[Tuple[int, ...], torch.Size], + context_shape: Union[Tuple[int, ...], torch.Size] = None, **kwargs): + """ + + :param Union[Tuple[int, ...], torch.Size] event_shape: shape of the event tensor. + :param Union[Tuple[int, ...], torch.Size] context_shape: shape of the context tensor. + :param kwargs: keyword arguments to MaskedAutoregressiveBijection. + """ transformer: ScalarTransformer = Affine(event_shape=event_shape) super().__init__(event_shape, context_shape, transformer=transformer, **kwargs) class RQSForwardMaskedAutoregressive(MaskedAutoregressiveBijection): def __init__(self, - event_shape: torch.Size, - context_shape: torch.Size = None, + event_shape: Union[Tuple[int, ...], torch.Size], + context_shape: Union[Tuple[int, ...], torch.Size] = None, n_bins: int = 8, **kwargs): + """ + + :param Union[Tuple[int, ...], torch.Size] event_shape: shape of the event tensor. + :param Union[Tuple[int, ...], torch.Size] context_shape: shape of the context tensor. + :param int n_bins: number of spline bins. + :param kwargs: keyword arguments to MaskedAutoregressiveBijection. + """ transformer: ScalarTransformer = RationalQuadratic(event_shape=event_shape, n_bins=n_bins) super().__init__(event_shape, context_shape, transformer=transformer, **kwargs) class LRSForwardMaskedAutoregressive(MaskedAutoregressiveBijection): def __init__(self, - event_shape: torch.Size, - context_shape: torch.Size = None, + event_shape: Union[Tuple[int, ...], torch.Size], + context_shape: Union[Tuple[int, ...], torch.Size] = None, n_bins: int = 8, **kwargs): + """ + + :param Union[Tuple[int, ...], torch.Size] event_shape: shape of the event tensor. + :param Union[Tuple[int, ...], torch.Size] context_shape: shape of the context tensor. + :param int n_bins: number of spline bins. + :param kwargs: keyword arguments to MaskedAutoregressiveBijection. 
+ """ transformer: ScalarTransformer = LinearRational(event_shape=event_shape, n_bins=n_bins) super().__init__(event_shape, context_shape, transformer=transformer, **kwargs) class AffineInverseMaskedAutoregressive(InverseMaskedAutoregressiveBijection): def __init__(self, - event_shape: torch.Size, - context_shape: torch.Size = None, + event_shape: Union[Tuple[int, ...], torch.Size], + context_shape: Union[Tuple[int, ...], torch.Size] = None, **kwargs): + """ + + :param Union[Tuple[int, ...], torch.Size] event_shape: shape of the event tensor. + :param Union[Tuple[int, ...], torch.Size] context_shape: shape of the context tensor. + :param kwargs: keyword arguments to InverseMaskedAutoregressiveBijection. + """ transformer: ScalarTransformer = invert(Affine(event_shape=event_shape)) super().__init__(event_shape, context_shape, transformer=transformer, **kwargs) class RQSInverseMaskedAutoregressive(InverseMaskedAutoregressiveBijection): def __init__(self, - event_shape: torch.Size, - context_shape: torch.Size = None, + event_shape: Union[Tuple[int, ...], torch.Size], + context_shape: Union[Tuple[int, ...], torch.Size] = None, n_bins: int = 8, **kwargs): + """ + + :param Union[Tuple[int, ...], torch.Size] event_shape: shape of the event tensor. + :param Union[Tuple[int, ...], torch.Size] context_shape: shape of the context tensor. + :param int n_bins: number of spline bins. + :param kwargs: keyword arguments to InverseMaskedAutoregressiveBijection. + """ assert n_bins >= 1 transformer: ScalarTransformer = RationalQuadratic(event_shape=event_shape, n_bins=n_bins) super().__init__(event_shape, context_shape, transformer=transformer, **kwargs) @@ -455,24 +644,39 @@ def __init__(self, class LRSInverseMaskedAutoregressive(InverseMaskedAutoregressiveBijection): def __init__(self, - event_shape: torch.Size, - context_shape: torch.Size = None, + event_shape: Union[Tuple[int, ...], torch.Size], + context_shape: Union[Tuple[int, ...], torch.Size] = None, n_bins: int = 8, **kwargs): + """ + + :param Union[Tuple[int, ...], torch.Size] event_shape: shape of the event tensor. + :param Union[Tuple[int, ...], torch.Size] context_shape: shape of the context tensor. + :param int n_bins: number of spline bins. + :param kwargs: keyword arguments to InverseMaskedAutoregressiveBijection. + """ transformer: ScalarTransformer = LinearRational(event_shape=event_shape, n_bins=n_bins) super().__init__(event_shape, context_shape, transformer=transformer, **kwargs) class UMNNMaskedAutoregressive(MaskedAutoregressiveBijection): def __init__(self, - event_shape: torch.Size, - context_shape: torch.Size = None, - n_hidden_layers: int = None, - hidden_dim: int = None, + event_shape: Union[Tuple[int, ...], torch.Size], + context_shape: Union[Tuple[int, ...], torch.Size] = None, + n_transformer_hidden_layers: int = None, + transformer_hidden_size: int = None, **kwargs): + """ + + :param Union[Tuple[int, ...], torch.Size] event_shape: shape of the event tensor. + :param Union[Tuple[int, ...], torch.Size] context_shape: shape of the context tensor. + :param int n_transformer_hidden_layers: number of transformer hidden layers. + :param int transformer_hidden_size: transformer hidden layer size. + :param kwargs: keyword arguments to MaskedAutoregressiveBijection. 
+ """ transformer: ScalarTransformer = UnconstrainedMonotonicNeuralNetwork( event_shape=event_shape, - n_hidden_layers=n_hidden_layers, - hidden_dim=hidden_dim + n_hidden_layers=n_transformer_hidden_layers, + hidden_dim=transformer_hidden_size ) super().__init__(event_shape, context_shape, transformer=transformer, **kwargs) From 028ce3e5ecd704daedfaebbcfe4561f37a688a6c Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Sun, 1 Sep 2024 01:07:16 +0200 Subject: [PATCH 09/25] Simplify autoregressive layer transformer kwargs --- .../finite/autoregressive/layers.py | 149 ++++++++++-------- 1 file changed, 81 insertions(+), 68 deletions(-) diff --git a/torchflows/bijections/finite/autoregressive/layers.py b/torchflows/bijections/finite/autoregressive/layers.py index cc034ee..aea947c 100644 --- a/torchflows/bijections/finite/autoregressive/layers.py +++ b/torchflows/bijections/finite/autoregressive/layers.py @@ -1,4 +1,4 @@ -from typing import Tuple, List, Union +from typing import Tuple, Union import torch @@ -111,6 +111,7 @@ def __init__(self, event_shape: Union[Tuple[int, ...], torch.Size], context_shape: Union[Tuple[int, ...], torch.Size] = None, coupling_kwargs: dict = None, + transformer_kwargs: dict = None, **kwargs): """ @@ -121,10 +122,10 @@ def __init__(self, """ if event_shape == (1,): raise ValueError - if coupling_kwargs is None: - coupling_kwargs = dict() + coupling_kwargs = coupling_kwargs or {} + transformer_kwargs = transformer_kwargs or {} coupling = make_coupling(event_shape, **coupling_kwargs) - transformer = Affine(event_shape=torch.Size((coupling.target_event_size,))) + transformer = Affine(event_shape=torch.Size((coupling.target_event_size,)), **transformer_kwargs) conditioner_transform = FeedForward( input_event_shape=torch.Size((coupling.source_event_size,)), parameter_shape=torch.Size(transformer.parameter_shape), @@ -139,6 +140,7 @@ def __init__(self, event_shape: Union[Tuple[int, ...], torch.Size], context_shape: Union[Tuple[int, ...], torch.Size] = None, coupling_kwargs: dict = None, + transformer_kwargs: dict = None, **kwargs): """ @@ -149,10 +151,10 @@ def __init__(self, """ if event_shape == (1,): raise ValueError - if coupling_kwargs is None: - coupling_kwargs = dict() + coupling_kwargs = coupling_kwargs or {} + transformer_kwargs = transformer_kwargs or {} coupling = make_coupling(event_shape, **coupling_kwargs) - transformer = invert(Affine(event_shape=torch.Size((coupling.target_event_size,)))) + transformer = invert(Affine(event_shape=torch.Size((coupling.target_event_size,)), **transformer_kwargs)) conditioner_transform = FeedForward( input_event_shape=torch.Size((coupling.source_event_size,)), parameter_shape=torch.Size(transformer.parameter_shape), @@ -175,8 +177,7 @@ def __init__(self, :param dict coupling_kwargs: keyword arguments to `make_coupling`. :param kwargs: keyword arguments to CouplingBijection. """ - if coupling_kwargs is None: - coupling_kwargs = dict() + coupling_kwargs = coupling_kwargs or {} coupling = make_coupling(event_shape, **coupling_kwargs) transformer = Shift(event_shape=torch.Size((coupling.target_event_size,))) conditioner_transform = FeedForward( @@ -192,8 +193,8 @@ class LRSCoupling(CouplingBijection): def __init__(self, event_shape: Union[Tuple[int, ...], torch.Size], context_shape: Union[Tuple[int, ...], torch.Size] = None, - n_bins: int = 8, coupling_kwargs: dict = None, + transformer_kwargs: dict = None, **kwargs): """ @@ -203,11 +204,10 @@ def __init__(self, :param dict coupling_kwargs: keyword arguments to `make_coupling`. 
:param kwargs: keyword arguments to CouplingBijection. """ - assert n_bins >= 1 - if coupling_kwargs is None: - coupling_kwargs = dict() + coupling_kwargs = coupling_kwargs or {} + transformer_kwargs = transformer_kwargs or {} coupling = make_coupling(event_shape, **coupling_kwargs) - transformer = LinearRational(event_shape=torch.Size((coupling.target_event_size,)), n_bins=n_bins) + transformer = LinearRational(event_shape=torch.Size((coupling.target_event_size,)), **transformer_kwargs) conditioner_transform = FeedForward( input_event_shape=torch.Size((coupling.source_event_size,)), parameter_shape=torch.Size(transformer.parameter_shape), @@ -221,21 +221,20 @@ class RQSCoupling(CouplingBijection): def __init__(self, event_shape: Union[Tuple[int, ...], torch.Size], context_shape: Union[Tuple[int, ...], torch.Size] = None, - n_bins: int = 8, coupling_kwargs: dict = None, + transformer_kwargs: dict = None, **kwargs): """ :param Union[Tuple[int, ...], torch.Size] event_shape: shape of the event tensor. :param Union[Tuple[int, ...], torch.Size] context_shape: shape of the context tensor. - :param int n_bins: number of spline bins. :param dict coupling_kwargs: keyword arguments to `make_coupling`. :param kwargs: keyword arguments to CouplingBijection. """ - if coupling_kwargs is None: - coupling_kwargs = dict() + coupling_kwargs = coupling_kwargs or {} + transformer_kwargs = transformer_kwargs or {} coupling = make_coupling(event_shape, **coupling_kwargs) - transformer = RationalQuadratic(event_shape=torch.Size((coupling.target_event_size,)), n_bins=n_bins) + transformer = RationalQuadratic(event_shape=torch.Size((coupling.target_event_size,)), **transformer_kwargs) conditioner_transform = FeedForward( input_event_shape=torch.Size((coupling.source_event_size,)), parameter_shape=torch.Size(transformer.parameter_shape), @@ -249,23 +248,22 @@ class DeepSigmoidalCoupling(CouplingBijection): def __init__(self, event_shape: Union[Tuple[int, ...], torch.Size], context_shape: Union[Tuple[int, ...], torch.Size] = None, - n_transformer_hidden_layers: int = 2, coupling_kwargs: dict = None, + transformer_kwargs: dict = None, **kwargs): """ :param Union[Tuple[int, ...], torch.Size] event_shape: shape of the event tensor. :param Union[Tuple[int, ...], torch.Size] context_shape: shape of the context tensor. - :param int n_transformer_hidden_layers: number of transformer hidden layers. :param dict coupling_kwargs: keyword arguments to `make_coupling`. :param kwargs: keyword arguments to CouplingBijection. """ - if coupling_kwargs is None: - coupling_kwargs = dict() + coupling_kwargs = coupling_kwargs or {} + transformer_kwargs = transformer_kwargs or {} coupling = make_coupling(event_shape, **coupling_kwargs) transformer = DeepSigmoid( event_shape=torch.Size((coupling.target_event_size,)), - n_hidden_layers=n_transformer_hidden_layers + **transformer_kwargs ) # Parameter order: [c1, c2, c3, c4, ..., ck] for all components # Each component has parameter order [a_unc, b, w_unc] @@ -282,18 +280,18 @@ class DeepSigmoidalInverseMaskedAutoregressive(InverseMaskedAutoregressiveBiject def __init__(self, event_shape: Union[Tuple[int, ...], torch.Size], context_shape: Union[Tuple[int, ...], torch.Size] = None, - n_transformer_hidden_layers: int = 2, + transformer_kwargs: dict = None, **kwargs): """ :param Union[Tuple[int, ...], torch.Size] event_shape: shape of the event tensor. :param Union[Tuple[int, ...], torch.Size] context_shape: shape of the context tensor. 
- :param int n_transformer_hidden_layers: number of transformer hidden layers. :param kwargs: keyword arguments to InverseMaskedAutoregressiveBijection. """ + transformer_kwargs = transformer_kwargs or {} transformer: ScalarTransformer = DeepSigmoid( event_shape=torch.Size(event_shape), - n_hidden_layers=n_transformer_hidden_layers + **transformer_kwargs ) super().__init__(event_shape, context_shape, transformer=transformer, **kwargs) @@ -302,18 +300,18 @@ class DeepSigmoidalForwardMaskedAutoregressive(MaskedAutoregressiveBijection): def __init__(self, event_shape: Union[Tuple[int, ...], torch.Size], context_shape: Union[Tuple[int, ...], torch.Size] = None, - n_transformer_hidden_layers: int = 2, + transformer_kwargs: dict = None, **kwargs): """ :param Union[Tuple[int, ...], torch.Size] event_shape: shape of the event tensor. :param Union[Tuple[int, ...], torch.Size] context_shape: shape of the context tensor. - :param int n_transformer_hidden_layers: number of transformer hidden layers. :param kwargs: keyword arguments to MaskedAutoregressiveBijection. """ + transformer_kwargs = transformer_kwargs or {} transformer: ScalarTransformer = DeepSigmoid( event_shape=torch.Size(event_shape), - n_hidden_layers=n_transformer_hidden_layers + **transformer_kwargs ) super().__init__(event_shape, context_shape, transformer=transformer, **kwargs) @@ -322,26 +320,25 @@ class DenseSigmoidalCoupling(CouplingBijection): def __init__(self, event_shape: Union[Tuple[int, ...], torch.Size], context_shape: Union[Tuple[int, ...], torch.Size] = None, - n_transformer_layers: int = 2, - percentage_global_parameters: float = 0.8, coupling_kwargs: dict = None, + transformer_kwargs: dict = None, + percentage_global_parameters: float = 0.8, **kwargs): """ :param Union[Tuple[int, ...], torch.Size] event_shape: shape of the event tensor. :param Union[Tuple[int, ...], torch.Size] context_shape: shape of the context tensor. - :param int n_transformer_layers: number of transformer layers. :param float percentage_global_parameters: percentage of transformer inputs to be learned globally instead of being predicted from the conditioner neural network. :param dict coupling_kwargs: keyword arguments to `make_coupling`. :param kwargs: keyword arguments to CouplingBijection. """ - if coupling_kwargs is None: - coupling_kwargs = dict() + coupling_kwargs = coupling_kwargs or {} + transformer_kwargs = transformer_kwargs or {} coupling = make_coupling(event_shape, **coupling_kwargs) transformer = DenseSigmoid( event_shape=torch.Size((coupling.target_event_size,)), - n_dense_layers=n_transformer_layers + **transformer_kwargs ) # Parameter order: [c1, c2, c3, c4, ..., ck] for all components # Each component has parameter order [a_unc, b, w_unc] @@ -361,21 +358,21 @@ class DenseSigmoidalInverseMaskedAutoregressive(InverseMaskedAutoregressiveBijec def __init__(self, event_shape: Union[Tuple[int, ...], torch.Size], context_shape: Union[Tuple[int, ...], torch.Size] = None, - n_transformer_layers: int = 2, + transformer_kwargs: dict = None, percentage_global_parameters: float = 0.8, **kwargs): """ :param Union[Tuple[int, ...], torch.Size] event_shape: shape of the event tensor. :param Union[Tuple[int, ...], torch.Size] context_shape: shape of the context tensor. - :param int n_transformer_layers: number of transformer layers. :param float percentage_global_parameters: percentage of transformer inputs to be learned globally instead of being predicted from the conditioner neural network. 
:param kwargs: keyword arguments to InverseMaskedAutoregressiveBijection. """ + transformer_kwargs = transformer_kwargs or {} transformer: ScalarTransformer = DenseSigmoid( event_shape=torch.Size(event_shape), - n_dense_layers=n_transformer_layers + **transformer_kwargs ) super().__init__( event_shape, @@ -392,7 +389,7 @@ class DenseSigmoidalForwardMaskedAutoregressive(MaskedAutoregressiveBijection): def __init__(self, event_shape: Union[Tuple[int, ...], torch.Size], context_shape: Union[Tuple[int, ...], torch.Size] = None, - n_transformer_layers: int = 2, + transformer_kwargs: dict = None, percentage_global_parameters: float = 0.8, **kwargs): """ @@ -404,9 +401,10 @@ def __init__(self, being predicted from the conditioner neural network. :param kwargs: keyword arguments to MaskedAutoregressiveBijection. """ + transformer_kwargs = transformer_kwargs or {} transformer: ScalarTransformer = DenseSigmoid( event_shape=torch.Size(event_shape), - n_dense_layers=n_transformer_layers + **transformer_kwargs ) super().__init__( event_shape, @@ -423,26 +421,30 @@ class DeepDenseSigmoidalCoupling(CouplingBijection): def __init__(self, event_shape: Union[Tuple[int, ...], torch.Size], context_shape: Union[Tuple[int, ...], torch.Size] = None, - n_transformer_hidden_layers: int = 2, - percentage_global_parameters: float = 0.8, + transformer_kwargs: dict = None, coupling_kwargs: dict = None, + percentage_global_parameters: float = 0.8, **kwargs): """ :param Union[Tuple[int, ...], torch.Size] event_shape: shape of the event tensor. :param Union[Tuple[int, ...], torch.Size] context_shape: shape of the context tensor. :param int n_transformer_hidden_layers: number of transformer hidden layers. + :param int n_transformer_dense_layers: number of transformer dense layers. + :param int transformer_hidden_size: transformer hidden layer size. :param float percentage_global_parameters: percentage of transformer inputs to be learned globally instead of being predicted from the conditioner neural network. :param dict coupling_kwargs: keyword arguments to `make_coupling`. + :param dict transformer_kwargs: keyword arguments to DeepDenseSigmoid. :param kwargs: keyword arguments to CouplingBijection. """ - if coupling_kwargs is None: - coupling_kwargs = dict() + coupling_kwargs = coupling_kwargs or {} + transformer_kwargs = transformer_kwargs or {} + coupling = make_coupling(event_shape, **coupling_kwargs) transformer = DeepDenseSigmoid( event_shape=torch.Size((coupling.target_event_size,)), - n_hidden_layers=n_transformer_hidden_layers + **transformer_kwargs ) # Parameter order: [c1, c2, c3, c4, ..., ck] for all components # Each component has parameter order [a_unc, b, w_unc] @@ -462,7 +464,7 @@ class DeepDenseSigmoidalInverseMaskedAutoregressive(InverseMaskedAutoregressiveB def __init__(self, event_shape: Union[Tuple[int, ...], torch.Size], context_shape: Union[Tuple[int, ...], torch.Size] = None, - n_transformer_hidden_layers: int = 2, + transformer_kwargs: dict = None, percentage_global_parameters: float = 0.8, **kwargs): """ @@ -470,13 +472,16 @@ def __init__(self, :param Union[Tuple[int, ...], torch.Size] event_shape: shape of the event tensor. :param Union[Tuple[int, ...], torch.Size] context_shape: shape of the context tensor. :param int n_transformer_hidden_layers: number of transformer hidden layers. + :param int n_transformer_dense_layers: number of transformer dense layers. + :param int transformer_hidden_size: transformer hidden layer size. 
:param float percentage_global_parameters: percentage of transformer inputs to be learned globally instead of being predicted from the conditioner neural network. :param kwargs: keyword arguments to InverseMaskedAutoregressiveBijection. """ + transformer_kwargs = transformer_kwargs or {} transformer: ScalarTransformer = DeepDenseSigmoid( event_shape=torch.Size(event_shape), - n_hidden_layers=n_transformer_hidden_layers + **transformer_kwargs ) super().__init__( event_shape, @@ -493,7 +498,7 @@ class DeepDenseSigmoidalForwardMaskedAutoregressive(MaskedAutoregressiveBijectio def __init__(self, event_shape: Union[Tuple[int, ...], torch.Size], context_shape: Union[Tuple[int, ...], torch.Size] = None, - n_transformer_hidden_layers: int = 2, + transformer_kwargs: dict = None, percentage_global_parameters: float = 0.8, **kwargs): """ @@ -501,13 +506,16 @@ def __init__(self, :param Union[Tuple[int, ...], torch.Size] event_shape: shape of the event tensor. :param Union[Tuple[int, ...], torch.Size] context_shape: shape of the context tensor. :param int n_transformer_hidden_layers: number of transformer hidden layers. + :param int n_transformer_dense_layers: number of transformer dense layers. + :param int transformer_hidden_size: transformer hidden layer size. :param float percentage_global_parameters: percentage of transformer inputs to be learned globally instead of being predicted from the conditioner neural network. :param kwargs: keyword arguments to MaskedAutoregressiveBijection. """ + transformer_kwargs = transformer_kwargs or {} transformer: ScalarTransformer = DeepDenseSigmoid( event_shape=torch.Size(event_shape), - n_hidden_layers=n_transformer_hidden_layers + **transformer_kwargs ) super().__init__( event_shape, @@ -564,6 +572,7 @@ class AffineForwardMaskedAutoregressive(MaskedAutoregressiveBijection): def __init__(self, event_shape: Union[Tuple[int, ...], torch.Size], context_shape: Union[Tuple[int, ...], torch.Size] = None, + transformer_kwargs: dict = None, **kwargs): """ @@ -571,7 +580,8 @@ def __init__(self, :param Union[Tuple[int, ...], torch.Size] context_shape: shape of the context tensor. :param kwargs: keyword arguments to MaskedAutoregressiveBijection. """ - transformer: ScalarTransformer = Affine(event_shape=event_shape) + transformer_kwargs = transformer_kwargs or {} + transformer: ScalarTransformer = Affine(event_shape=event_shape, **transformer_kwargs) super().__init__(event_shape, context_shape, transformer=transformer, **kwargs) @@ -579,7 +589,7 @@ class RQSForwardMaskedAutoregressive(MaskedAutoregressiveBijection): def __init__(self, event_shape: Union[Tuple[int, ...], torch.Size], context_shape: Union[Tuple[int, ...], torch.Size] = None, - n_bins: int = 8, + transformer_kwargs: dict = None, **kwargs): """ @@ -588,7 +598,8 @@ def __init__(self, :param int n_bins: number of spline bins. :param kwargs: keyword arguments to MaskedAutoregressiveBijection. 
""" - transformer: ScalarTransformer = RationalQuadratic(event_shape=event_shape, n_bins=n_bins) + transformer_kwargs = transformer_kwargs or {} + transformer: ScalarTransformer = RationalQuadratic(event_shape=event_shape, **transformer_kwargs) super().__init__(event_shape, context_shape, transformer=transformer, **kwargs) @@ -596,7 +607,7 @@ class LRSForwardMaskedAutoregressive(MaskedAutoregressiveBijection): def __init__(self, event_shape: Union[Tuple[int, ...], torch.Size], context_shape: Union[Tuple[int, ...], torch.Size] = None, - n_bins: int = 8, + transformer_kwargs: dict = None, **kwargs): """ @@ -605,7 +616,8 @@ def __init__(self, :param int n_bins: number of spline bins. :param kwargs: keyword arguments to MaskedAutoregressiveBijection. """ - transformer: ScalarTransformer = LinearRational(event_shape=event_shape, n_bins=n_bins) + transformer_kwargs = transformer_kwargs or {} + transformer: ScalarTransformer = LinearRational(event_shape=event_shape, **transformer_kwargs) super().__init__(event_shape, context_shape, transformer=transformer, **kwargs) @@ -613,6 +625,7 @@ class AffineInverseMaskedAutoregressive(InverseMaskedAutoregressiveBijection): def __init__(self, event_shape: Union[Tuple[int, ...], torch.Size], context_shape: Union[Tuple[int, ...], torch.Size] = None, + transformer_kwargs: dict = None, **kwargs): """ @@ -620,7 +633,8 @@ def __init__(self, :param Union[Tuple[int, ...], torch.Size] context_shape: shape of the context tensor. :param kwargs: keyword arguments to InverseMaskedAutoregressiveBijection. """ - transformer: ScalarTransformer = invert(Affine(event_shape=event_shape)) + transformer_kwargs = transformer_kwargs or {} + transformer: ScalarTransformer = invert(Affine(event_shape=event_shape, **transformer_kwargs)) super().__init__(event_shape, context_shape, transformer=transformer, **kwargs) @@ -628,7 +642,7 @@ class RQSInverseMaskedAutoregressive(InverseMaskedAutoregressiveBijection): def __init__(self, event_shape: Union[Tuple[int, ...], torch.Size], context_shape: Union[Tuple[int, ...], torch.Size] = None, - n_bins: int = 8, + transformer_kwargs: dict = None, **kwargs): """ @@ -637,8 +651,8 @@ def __init__(self, :param int n_bins: number of spline bins. :param kwargs: keyword arguments to InverseMaskedAutoregressiveBijection. """ - assert n_bins >= 1 - transformer: ScalarTransformer = RationalQuadratic(event_shape=event_shape, n_bins=n_bins) + transformer_kwargs = transformer_kwargs or {} + transformer: ScalarTransformer = RationalQuadratic(event_shape=event_shape, **transformer_kwargs) super().__init__(event_shape, context_shape, transformer=transformer, **kwargs) @@ -646,16 +660,17 @@ class LRSInverseMaskedAutoregressive(InverseMaskedAutoregressiveBijection): def __init__(self, event_shape: Union[Tuple[int, ...], torch.Size], context_shape: Union[Tuple[int, ...], torch.Size] = None, - n_bins: int = 8, + transformer_kwargs: dict = None, **kwargs): """ :param Union[Tuple[int, ...], torch.Size] event_shape: shape of the event tensor. :param Union[Tuple[int, ...], torch.Size] context_shape: shape of the context tensor. - :param int n_bins: number of spline bins. + :param dict transformer_kwargs: keyword arguments to LinearRational. :param kwargs: keyword arguments to InverseMaskedAutoregressiveBijection. 
""" - transformer: ScalarTransformer = LinearRational(event_shape=event_shape, n_bins=n_bins) + transformer_kwargs = transformer_kwargs or {} + transformer: ScalarTransformer = LinearRational(event_shape=event_shape, **transformer_kwargs) super().__init__(event_shape, context_shape, transformer=transformer, **kwargs) @@ -663,20 +678,18 @@ class UMNNMaskedAutoregressive(MaskedAutoregressiveBijection): def __init__(self, event_shape: Union[Tuple[int, ...], torch.Size], context_shape: Union[Tuple[int, ...], torch.Size] = None, - n_transformer_hidden_layers: int = None, - transformer_hidden_size: int = None, + transformer_kwargs: dict = None, **kwargs): """ :param Union[Tuple[int, ...], torch.Size] event_shape: shape of the event tensor. :param Union[Tuple[int, ...], torch.Size] context_shape: shape of the context tensor. - :param int n_transformer_hidden_layers: number of transformer hidden layers. - :param int transformer_hidden_size: transformer hidden layer size. + :param dict transformer_kwargs: keyword arguments to UnconstrainedMonotonicNeuralNetwork. :param kwargs: keyword arguments to MaskedAutoregressiveBijection. """ + transformer_kwargs = transformer_kwargs or {} transformer: ScalarTransformer = UnconstrainedMonotonicNeuralNetwork( event_shape=event_shape, - n_hidden_layers=n_transformer_hidden_layers, - hidden_dim=transformer_hidden_size + **transformer_kwargs ) super().__init__(event_shape, context_shape, transformer=transformer, **kwargs) From 1888a442a2a5269ed0f799af0ccf21d0603fb08a Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Sun, 1 Sep 2024 02:12:58 +0200 Subject: [PATCH 10/25] Major coupling NF cleanup/refactor --- test/test_channel_wise_coupling.py | 4 +- .../conditioning/coupling_masks.py | 8 + .../autoregressive/conditioning/transforms.py | 6 +- .../finite/autoregressive/layers.py | 175 ++---------------- .../finite/autoregressive/layers_base.py | 48 ++++- .../autoregressive/transformers/base.py | 2 +- .../bijections/finite/multiscale/base.py | 68 ++----- .../finite/multiscale/conditioning/classic.py | 7 +- .../finite/multiscale/conditioning/resnet.py | 1 - .../bijections/finite/multiscale/coupling.py | 4 +- 10 files changed, 88 insertions(+), 235 deletions(-) diff --git a/test/test_channel_wise_coupling.py b/test/test_channel_wise_coupling.py index 4e7ea05..159c7c3 100644 --- a/test/test_channel_wise_coupling.py +++ b/test/test_channel_wise_coupling.py @@ -8,7 +8,7 @@ def test_partition_shapes_1(): image_shape = (3, 4, 4) coupling = ChannelWiseHalfSplit(image_shape, invert=True) assert coupling.constant_shape == (1, 4, 4) - assert coupling.transformed_shape == (2, 4, 4) + assert coupling.target_shape == (2, 4, 4) def test_partition_shapes_2(): @@ -16,4 +16,4 @@ def test_partition_shapes_2(): image_shape = (3, 16, 16) coupling = ChannelWiseHalfSplit(image_shape, invert=True) assert coupling.constant_shape == (1, 16, 16) - assert coupling.transformed_shape == (2, 16, 16) + assert coupling.target_shape == (2, 16, 16) diff --git a/torchflows/bijections/finite/autoregressive/conditioning/coupling_masks.py b/torchflows/bijections/finite/autoregressive/conditioning/coupling_masks.py index 08b3d61..41592c4 100644 --- a/torchflows/bijections/finite/autoregressive/conditioning/coupling_masks.py +++ b/torchflows/bijections/finite/autoregressive/conditioning/coupling_masks.py @@ -34,10 +34,18 @@ def ignored_event_size(self): def source_event_size(self): return int(torch.sum(self.source_mask)) + @property + def constant_shape(self) -> Tuple[int, ...]: + return 
(self.source_event_size,) + @property def target_event_size(self): return int(torch.sum(self.target_mask)) + @property + def target_shape(self) -> Tuple[int, ...]: + return (self.target_event_size,) + class Coupling(PartialCoupling): """ diff --git a/torchflows/bijections/finite/autoregressive/conditioning/transforms.py b/torchflows/bijections/finite/autoregressive/conditioning/transforms.py index 7ce5d28..e34d947 100644 --- a/torchflows/bijections/finite/autoregressive/conditioning/transforms.py +++ b/torchflows/bijections/finite/autoregressive/conditioning/transforms.py @@ -272,9 +272,9 @@ def __init__(self, *args, **kwargs): class FeedForward(TensorConditionerTransform): def __init__(self, - input_event_shape: torch.Size, - parameter_shape: torch.Size, - context_shape: torch.Size = None, + input_event_shape: Union[torch.Size, Tuple[int, ...]], + parameter_shape: Union[torch.Size, Tuple[int, ...]], + context_shape: Union[torch.Size, Tuple[int, ...]] = None, n_hidden: int = None, n_layers: int = 2, nonlinearity: Type[nn.Module] = nn.Tanh, diff --git a/torchflows/bijections/finite/autoregressive/layers.py b/torchflows/bijections/finite/autoregressive/layers.py index aea947c..22e724c 100644 --- a/torchflows/bijections/finite/autoregressive/layers.py +++ b/torchflows/bijections/finite/autoregressive/layers.py @@ -109,171 +109,77 @@ def __init__(self, event_shape, **kwargs): class AffineCoupling(CouplingBijection): def __init__(self, event_shape: Union[Tuple[int, ...], torch.Size], - context_shape: Union[Tuple[int, ...], torch.Size] = None, - coupling_kwargs: dict = None, - transformer_kwargs: dict = None, **kwargs): """ :param Union[Tuple[int, ...], torch.Size] event_shape: shape of the event tensor. - :param Union[Tuple[int, ...], torch.Size] context_shape: shape of the context tensor. - :param dict coupling_kwargs: keyword arguments to `make_coupling`. :param kwargs: keyword arguments to CouplingBijection. """ if event_shape == (1,): raise ValueError - coupling_kwargs = coupling_kwargs or {} - transformer_kwargs = transformer_kwargs or {} - coupling = make_coupling(event_shape, **coupling_kwargs) - transformer = Affine(event_shape=torch.Size((coupling.target_event_size,)), **transformer_kwargs) - conditioner_transform = FeedForward( - input_event_shape=torch.Size((coupling.source_event_size,)), - parameter_shape=torch.Size(transformer.parameter_shape), - context_shape=context_shape, - **kwargs - ) - super().__init__(transformer, coupling, conditioner_transform) + super().__init__(event_shape, Affine, **kwargs) class InverseAffineCoupling(CouplingBijection): def __init__(self, event_shape: Union[Tuple[int, ...], torch.Size], - context_shape: Union[Tuple[int, ...], torch.Size] = None, - coupling_kwargs: dict = None, - transformer_kwargs: dict = None, **kwargs): """ :param Union[Tuple[int, ...], torch.Size] event_shape: shape of the event tensor. - :param Union[Tuple[int, ...], torch.Size] context_shape: shape of the context tensor. - :param dict coupling_kwargs: keyword arguments to `make_coupling`. :param kwargs: keyword arguments to CouplingBijection.
""" if event_shape == (1,): raise ValueError - coupling_kwargs = coupling_kwargs or {} - transformer_kwargs = transformer_kwargs or {} - coupling = make_coupling(event_shape, **coupling_kwargs) - transformer = invert(Affine(event_shape=torch.Size((coupling.target_event_size,)), **transformer_kwargs)) - conditioner_transform = FeedForward( - input_event_shape=torch.Size((coupling.source_event_size,)), - parameter_shape=torch.Size(transformer.parameter_shape), - context_shape=context_shape, - **kwargs - ) - super().__init__(transformer, coupling, conditioner_transform) + super().__init__(event_shape, InverseAffine, **kwargs) class ShiftCoupling(CouplingBijection): def __init__(self, event_shape: Union[Tuple[int, ...], torch.Size], - context_shape: Union[Tuple[int, ...], torch.Size] = None, - coupling_kwargs: dict = None, **kwargs): """ :param Union[Tuple[int, ...], torch.Size] event_shape: shape of the event tensor. - :param Union[Tuple[int, ...], torch.Size] context_shape: shape of the context tensor. - :param dict coupling_kwargs: keyword arguments to `make_coupling`. :param kwargs: keyword arguments to CouplingBijection. """ - coupling_kwargs = coupling_kwargs or {} - coupling = make_coupling(event_shape, **coupling_kwargs) - transformer = Shift(event_shape=torch.Size((coupling.target_event_size,))) - conditioner_transform = FeedForward( - input_event_shape=torch.Size((coupling.source_event_size,)), - parameter_shape=torch.Size(transformer.parameter_shape), - context_shape=context_shape, - **kwargs - ) - super().__init__(transformer, coupling, conditioner_transform) + super().__init__(event_shape, Shift, **kwargs) class LRSCoupling(CouplingBijection): def __init__(self, event_shape: Union[Tuple[int, ...], torch.Size], - context_shape: Union[Tuple[int, ...], torch.Size] = None, - coupling_kwargs: dict = None, - transformer_kwargs: dict = None, **kwargs): """ :param Union[Tuple[int, ...], torch.Size] event_shape: shape of the event tensor. - :param Union[Tuple[int, ...], torch.Size] context_shape: shape of the context tensor. - :param int n_bins: number of spline bins. - :param dict coupling_kwargs: keyword arguments to `make_coupling`. :param kwargs: keyword arguments to CouplingBijection. """ - coupling_kwargs = coupling_kwargs or {} - transformer_kwargs = transformer_kwargs or {} - coupling = make_coupling(event_shape, **coupling_kwargs) - transformer = LinearRational(event_shape=torch.Size((coupling.target_event_size,)), **transformer_kwargs) - conditioner_transform = FeedForward( - input_event_shape=torch.Size((coupling.source_event_size,)), - parameter_shape=torch.Size(transformer.parameter_shape), - context_shape=context_shape, - **kwargs - ) - super().__init__(transformer, coupling, conditioner_transform) + super().__init__(event_shape, LinearRational, **kwargs) class RQSCoupling(CouplingBijection): def __init__(self, event_shape: Union[Tuple[int, ...], torch.Size], - context_shape: Union[Tuple[int, ...], torch.Size] = None, - coupling_kwargs: dict = None, - transformer_kwargs: dict = None, **kwargs): """ :param Union[Tuple[int, ...], torch.Size] event_shape: shape of the event tensor. - :param Union[Tuple[int, ...], torch.Size] context_shape: shape of the context tensor. - :param dict coupling_kwargs: keyword arguments to `make_coupling`. :param kwargs: keyword arguments to CouplingBijection. 
""" - coupling_kwargs = coupling_kwargs or {} - transformer_kwargs = transformer_kwargs or {} - coupling = make_coupling(event_shape, **coupling_kwargs) - transformer = RationalQuadratic(event_shape=torch.Size((coupling.target_event_size,)), **transformer_kwargs) - conditioner_transform = FeedForward( - input_event_shape=torch.Size((coupling.source_event_size,)), - parameter_shape=torch.Size(transformer.parameter_shape), - context_shape=context_shape, - **kwargs - ) - super().__init__(transformer, coupling, conditioner_transform) + super().__init__(event_shape, RationalQuadratic, **kwargs) class DeepSigmoidalCoupling(CouplingBijection): def __init__(self, event_shape: Union[Tuple[int, ...], torch.Size], - context_shape: Union[Tuple[int, ...], torch.Size] = None, - coupling_kwargs: dict = None, - transformer_kwargs: dict = None, **kwargs): """ :param Union[Tuple[int, ...], torch.Size] event_shape: shape of the event tensor. - :param Union[Tuple[int, ...], torch.Size] context_shape: shape of the context tensor. - :param dict coupling_kwargs: keyword arguments to `make_coupling`. :param kwargs: keyword arguments to CouplingBijection. """ - coupling_kwargs = coupling_kwargs or {} - transformer_kwargs = transformer_kwargs or {} - coupling = make_coupling(event_shape, **coupling_kwargs) - transformer = DeepSigmoid( - event_shape=torch.Size((coupling.target_event_size,)), - **transformer_kwargs - ) - # Parameter order: [c1, c2, c3, c4, ..., ck] for all components - # Each component has parameter order [a_unc, b, w_unc] - conditioner_transform = FeedForward( - input_event_shape=torch.Size((coupling.source_event_size,)), - parameter_shape=torch.Size(transformer.parameter_shape), - context_shape=context_shape, - **kwargs - ) - super().__init__(transformer, coupling, conditioner_transform) + super().__init__(event_shape, DeepSigmoid, **kwargs) class DeepSigmoidalInverseMaskedAutoregressive(InverseMaskedAutoregressiveBijection): @@ -319,39 +225,17 @@ def __init__(self, class DenseSigmoidalCoupling(CouplingBijection): def __init__(self, event_shape: Union[Tuple[int, ...], torch.Size], - context_shape: Union[Tuple[int, ...], torch.Size] = None, - coupling_kwargs: dict = None, - transformer_kwargs: dict = None, - percentage_global_parameters: float = 0.8, **kwargs): """ :param Union[Tuple[int, ...], torch.Size] event_shape: shape of the event tensor. - :param Union[Tuple[int, ...], torch.Size] context_shape: shape of the context tensor. - :param float percentage_global_parameters: percentage of transformer inputs to be learned globally instead of - being predicted from the conditioner neural network. - :param dict coupling_kwargs: keyword arguments to `make_coupling`. :param kwargs: keyword arguments to CouplingBijection. 
""" - coupling_kwargs = coupling_kwargs or {} - transformer_kwargs = transformer_kwargs or {} - coupling = make_coupling(event_shape, **coupling_kwargs) - transformer = DenseSigmoid( - event_shape=torch.Size((coupling.target_event_size,)), - **transformer_kwargs - ) - # Parameter order: [c1, c2, c3, c4, ..., ck] for all components - # Each component has parameter order [a_unc, b, w_unc] - conditioner_transform = FeedForward( - input_event_shape=torch.Size((coupling.source_event_size,)), - parameter_shape=torch.Size(transformer.parameter_shape), - context_shape=context_shape, - **{ - **kwargs, - **dict(percentage_global_parameters=percentage_global_parameters) - } - ) - super().__init__(transformer, coupling, conditioner_transform) + if 'conditioner_kwargs' not in kwargs: + kwargs['conditioner_kwargs'] = {} + if 'percentage_global_parameters' not in kwargs['conditioner_kwargs']: + kwargs['conditioner_kwargs']['percentage_global_parameters'] = 0.8 + super().__init__(event_shape, DenseSigmoid, **kwargs) class DenseSigmoidalInverseMaskedAutoregressive(InverseMaskedAutoregressiveBijection): @@ -420,44 +304,17 @@ def __init__(self, class DeepDenseSigmoidalCoupling(CouplingBijection): def __init__(self, event_shape: Union[Tuple[int, ...], torch.Size], - context_shape: Union[Tuple[int, ...], torch.Size] = None, - transformer_kwargs: dict = None, - coupling_kwargs: dict = None, - percentage_global_parameters: float = 0.8, **kwargs): """ :param Union[Tuple[int, ...], torch.Size] event_shape: shape of the event tensor. - :param Union[Tuple[int, ...], torch.Size] context_shape: shape of the context tensor. - :param int n_transformer_hidden_layers: number of transformer hidden layers. - :param int n_transformer_dense_layers: number of transformer dense layers. - :param int transformer_hidden_size: transformer hidden layer size. - :param float percentage_global_parameters: percentage of transformer inputs to be learned globally instead of - being predicted from the conditioner neural network. - :param dict coupling_kwargs: keyword arguments to `make_coupling`. - :param dict transformer_kwargs: keyword arguments to DeepDenseSigmoid. :param kwargs: keyword arguments to CouplingBijection. 
""" - coupling_kwargs = coupling_kwargs or {} - transformer_kwargs = transformer_kwargs or {} - - coupling = make_coupling(event_shape, **coupling_kwargs) - transformer = DeepDenseSigmoid( - event_shape=torch.Size((coupling.target_event_size,)), - **transformer_kwargs - ) - # Parameter order: [c1, c2, c3, c4, ..., ck] for all components - # Each component has parameter order [a_unc, b, w_unc] - conditioner_transform = FeedForward( - input_event_shape=torch.Size((coupling.source_event_size,)), - parameter_shape=torch.Size(transformer.parameter_shape), - context_shape=context_shape, - **{ - **kwargs, - **dict(percentage_global_parameters=percentage_global_parameters) - } - ) - super().__init__(transformer, coupling, conditioner_transform) + if 'conditioner_kwargs' not in kwargs: + kwargs['conditioner_kwargs'] = {} + if 'percentage_global_parameters' not in kwargs['conditioner_kwargs']: + kwargs['conditioner_kwargs']['percentage_global_parameters'] = 0.8 + super().__init__(event_shape, DeepDenseSigmoid, **kwargs) class DeepDenseSigmoidalInverseMaskedAutoregressive(InverseMaskedAutoregressiveBijection): diff --git a/torchflows/bijections/finite/autoregressive/layers_base.py b/torchflows/bijections/finite/autoregressive/layers_base.py index fe07d4c..a33da2a 100644 --- a/torchflows/bijections/finite/autoregressive/layers_base.py +++ b/torchflows/bijections/finite/autoregressive/layers_base.py @@ -4,8 +4,8 @@ import torch.nn as nn from torchflows.bijections.finite.autoregressive.conditioning.transforms import ConditionerTransform, \ - MADE -from torchflows.bijections.finite.autoregressive.conditioning.coupling_masks import PartialCoupling + MADE, FeedForward +from torchflows.bijections.finite.autoregressive.conditioning.coupling_masks import PartialCoupling, make_coupling from torchflows.bijections.finite.autoregressive.transformers.base import TensorTransformer, ScalarTransformer from torchflows.bijections.base import Bijection from torchflows.utils import flatten_event, unflatten_event, get_batch_shape @@ -54,21 +54,48 @@ class CouplingBijection(AutoregressiveBijection): """ def __init__(self, - transformer: TensorTransformer, - coupling: PartialCoupling, - conditioner_transform: ConditionerTransform, + event_shape: Union[Tuple[int, ...], torch.Size], + transformer_class: Type[TensorTransformer], + context_shape: Union[Tuple[int, ...], torch.Size] = None, + coupling: PartialCoupling = None, + conditioner_transform_class: Type[ConditionerTransform] = FeedForward, + coupling_kwargs: dict = None, + conditioner_kwargs: dict = None, + transformer_kwargs: dict = None, **kwargs): - super().__init__(coupling.event_shape, transformer, conditioner_transform, **kwargs) + coupling_kwargs = coupling_kwargs or {} + conditioner_kwargs = conditioner_kwargs or {} + transformer_kwargs = transformer_kwargs or {} + + if coupling is None: + coupling = make_coupling(event_shape, **coupling_kwargs) + + transformer = transformer_class( + event_shape=coupling.target_shape, + **transformer_kwargs + ) + + conditioner_transform = conditioner_transform_class( + input_event_shape=coupling.constant_shape, # (coupling.source_event_size,), + context_shape=context_shape, + parameter_shape=transformer.parameter_shape, + **conditioner_kwargs + ) + + super().__init__(event_shape, transformer, conditioner_transform, **kwargs) self.coupling = coupling def get_constant_part(self, x: torch.Tensor) -> torch.Tensor: - return x[..., self.coupling.source_mask] + batch_shape = get_batch_shape(x, self.event_shape) + return x[..., 
self.coupling.source_mask].view(*batch_shape, *self.coupling.constant_shape) def get_transformed_part(self, x: torch.Tensor) -> torch.Tensor: - return x[..., self.coupling.target_mask] + batch_shape = get_batch_shape(x, self.event_shape) + return x[..., self.coupling.target_mask].view(*batch_shape, *self.coupling.target_shape) def set_transformed_part(self, x: torch.Tensor, x_transformed: torch.Tensor): - x[..., self.coupling.target_mask] = x_transformed + batch_shape = get_batch_shape(x, self.event_shape) + x[..., self.coupling.target_mask] = x_transformed.reshape(*batch_shape, -1) def partition_and_predict_parameters(self, x: torch.Tensor, context: torch.Tensor): """ @@ -79,9 +106,10 @@ def partition_and_predict_parameters(self, x: torch.Tensor, context: torch.Tenso :return: parameter tensor h with h.shape = (*batch_shape, *parameter_shape). """ # Predict transformer parameters for output dimensions + batch_shape = get_batch_shape(x, self.event_shape) x_a = self.get_constant_part(x) # (*b, constant_event_size) h_b = self.conditioner_transform(x_a, context=context) # (*b, *p) - return h_b + return h_b.view(*batch_shape, *self.transformer.parameter_shape) def forward(self, x: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]: z = x.clone() diff --git a/torchflows/bijections/finite/autoregressive/transformers/base.py b/torchflows/bijections/finite/autoregressive/transformers/base.py index c6e389c..c0e6f3f 100644 --- a/torchflows/bijections/finite/autoregressive/transformers/base.py +++ b/torchflows/bijections/finite/autoregressive/transformers/base.py @@ -16,7 +16,7 @@ class TensorTransformer(Bijection): individually. """ - def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]]): + def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], **kwargs): super().__init__(event_shape=event_shape) def forward(self, x: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: diff --git a/torchflows/bijections/finite/multiscale/base.py b/torchflows/bijections/finite/multiscale/base.py index b88a9d1..40cb68a 100644 --- a/torchflows/bijections/finite/multiscale/base.py +++ b/torchflows/bijections/finite/multiscale/base.py @@ -18,59 +18,24 @@ class ConvolutionalCouplingBijection(CouplingBijection): def __init__(self, - transformer: TensorTransformer, + event_shape: Union[Tuple[int, ...], torch.Size], + transformer_class: Type[TensorTransformer], coupling: Union[Checkerboard, ChannelWiseHalfSplit], - conditioner='convnet', + conditioner: str = 'convnet', **kwargs): if conditioner == 'convnet': - conditioner_transform = ConvNetConditioner( - input_event_shape=coupling.constant_shape, - parameter_shape=transformer.parameter_shape, - **kwargs - ) + conditioner_transform_class = ConvNetConditioner elif conditioner == 'resnet': - conditioner_transform = ResNetConditioner( - input_event_shape=coupling.constant_shape, - parameter_shape=transformer.parameter_shape, - **kwargs - ) + conditioner_transform_class = ResNetConditioner else: raise ValueError(f'Unknown conditioner: {conditioner}') - super().__init__(transformer, coupling, conditioner_transform, **kwargs) - self.coupling = coupling - - def get_constant_part(self, x: torch.Tensor) -> torch.Tensor: - """ - - :param x: tensor with shape (*b, channels, height, width). - :return: tensor with shape (*b, constant_channels, constant_height, constant_width). 
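Note (illustrative aside, not part of the patch): under the new CouplingBijection constructor, the coupling mask, transformer, and conditioner are all built internally from transformer_class plus the optional coupling_kwargs / transformer_kwargs / conditioner_kwargs dicts. A rough sketch of a direct instantiation, assuming the signature shown above; the Affine transformer import is the one used elsewhere in this patch series:

    import torch
    from torchflows.bijections.finite.autoregressive.layers_base import CouplingBijection
    from torchflows.bijections.finite.autoregressive.transformers.linear.affine import Affine

    event_shape = (6,)
    # make_coupling(event_shape, **coupling_kwargs) is called internally to build the mask.
    layer = CouplingBijection(event_shape, transformer_class=Affine)
    x = torch.randn(16, 6)
    z, log_det = layer.forward(x)
    x_reconstructed, _ = layer.inverse(z)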
- """ - batch_shape = get_batch_shape(x, self.event_shape) - return x[..., self.coupling.source_mask].view(*batch_shape, *self.coupling.constant_shape) - - def get_transformed_part(self, x: torch.Tensor) -> torch.Tensor: - """ - - :param x: tensor with shape (*b, channels, height, width). - :return: tensor with shape (*b, transformed_channels, transformed_height, constant_width). - """ - batch_shape = get_batch_shape(x, self.event_shape) - return x[..., self.coupling.target_mask].view(*batch_shape, *self.coupling.transformed_shape) - - def set_transformed_part(self, x: torch.Tensor, x_transformed: torch.Tensor): - """ - - :param x: tensor with shape (*b, channels, height, width). - :param x_transformed: tensor with shape (*b, transformed_channels, transformed_height, transformed_width). - """ - batch_shape = get_batch_shape(x, self.event_shape) - x[..., self.coupling.target_mask] = x_transformed.reshape(*batch_shape, -1) - return x - - def partition_and_predict_parameters(self, x: torch.Tensor, context: torch.Tensor): - batch_shape = get_batch_shape(x, self.event_shape) - super_out = super().partition_and_predict_parameters(x, context) - return super_out.view(*batch_shape, *self.transformer.parameter_shape) + super().__init__( + event_shape=event_shape, + transformer_class=transformer_class, + coupling=coupling, + conditioner_transform_class=conditioner_transform_class, + **kwargs + ) class CheckerboardCoupling(ConvolutionalCouplingBijection): @@ -83,8 +48,7 @@ def __init__(self, event_shape, coupling_type='checkerboard' if not alternate else 'checkerboard_inverted', ) - transformer = transformer_class(event_shape=coupling.transformed_shape) - super().__init__(transformer, coupling, **kwargs) + super().__init__(event_shape, transformer_class, coupling, **kwargs) class NormalizedCheckerboardCoupling(BijectiveComposition): @@ -105,8 +69,7 @@ def __init__(self, event_shape, coupling_type='channel_wise' if not alternate else 'channel_wise_inverted', ) - transformer = Invertible1x1ConvolutionTransformer(event_shape=coupling.transformed_shape) - super().__init__(transformer, coupling, **kwargs) + super().__init__(event_shape, Invertible1x1ConvolutionTransformer, coupling, **kwargs) class GlowCheckerboardCoupling(BijectiveComposition): @@ -129,8 +92,7 @@ def __init__(self, event_shape, coupling_type='channel_wise' if not alternate else 'channel_wise_inverted' ) - transformer = transformer_class(event_shape=coupling.transformed_shape) - super().__init__(transformer, coupling, **kwargs) + super().__init__(event_shape, transformer_class, coupling, **kwargs) class NormalizedChannelWiseCoupling(BijectiveComposition): diff --git a/torchflows/bijections/finite/multiscale/conditioning/classic.py b/torchflows/bijections/finite/multiscale/conditioning/classic.py index e5da0f6..edf81cf 100644 --- a/torchflows/bijections/finite/multiscale/conditioning/classic.py +++ b/torchflows/bijections/finite/multiscale/conditioning/classic.py @@ -1,4 +1,4 @@ -from typing import Tuple +from typing import Tuple, Union import torch import torch.nn as nn @@ -124,13 +124,12 @@ def forward(self, x): class ConvNetConditioner(TensorConditionerTransform): def __init__(self, - input_event_shape: torch.Size, - parameter_shape: torch.Size, + input_event_shape: Union[Tuple[int, ...], torch.Size], + parameter_shape: Union[Tuple[int, ...], torch.Size], kernels: Tuple[int, ...] 
= None, **kwargs): super().__init__( input_event_shape=input_event_shape, - context_shape=None, parameter_shape=parameter_shape, output_lower_bound=-2.0, output_upper_bound=2.0, diff --git a/torchflows/bijections/finite/multiscale/conditioning/resnet.py b/torchflows/bijections/finite/multiscale/conditioning/resnet.py index d2e735b..ac474de 100644 --- a/torchflows/bijections/finite/multiscale/conditioning/resnet.py +++ b/torchflows/bijections/finite/multiscale/conditioning/resnet.py @@ -125,7 +125,6 @@ def __init__(self, **kwargs): super().__init__( input_event_shape=input_event_shape, - context_shape=None, parameter_shape=parameter_shape, output_lower_bound=-2.0, output_upper_bound=2.0, diff --git a/torchflows/bijections/finite/multiscale/coupling.py b/torchflows/bijections/finite/multiscale/coupling.py index 785ff2e..6b8fd87 100644 --- a/torchflows/bijections/finite/multiscale/coupling.py +++ b/torchflows/bijections/finite/multiscale/coupling.py @@ -27,7 +27,7 @@ def constant_shape(self): return n_channels, height // 2, width # rectangular shape @property - def transformed_shape(self): + def target_shape(self): return self.constant_shape @@ -58,7 +58,7 @@ def constant_shape(self): return n_channels // 2, height, width @property - def transformed_shape(self): + def target_shape(self): n_channels, height, width = self.event_shape return n_channels - n_channels // 2, height, width From bfdfbce39df244b8ab824da43c5889b629f56bb0 Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Sun, 1 Sep 2024 02:24:57 +0200 Subject: [PATCH 11/25] Masked/inverse autoregressive NF cleanup/refactor --- .../finite/autoregressive/layers.py | 188 +++--------------- .../finite/autoregressive/layers_base.py | 15 +- .../autoregressive/transformers/base.py | 4 +- 3 files changed, 41 insertions(+), 166 deletions(-) diff --git a/torchflows/bijections/finite/autoregressive/layers.py b/torchflows/bijections/finite/autoregressive/layers.py index 22e724c..6e2133a 100644 --- a/torchflows/bijections/finite/autoregressive/layers.py +++ b/torchflows/bijections/finite/autoregressive/layers.py @@ -2,12 +2,9 @@ import torch -from torchflows.bijections.finite.autoregressive.conditioning.transforms import FeedForward -from torchflows.bijections.finite.autoregressive.conditioning.coupling_masks import make_coupling from torchflows.bijections.finite.autoregressive.layers_base import MaskedAutoregressiveBijection, \ InverseMaskedAutoregressiveBijection, ElementwiseBijection, CouplingBijection from torchflows.bijections.finite.autoregressive.transformers.linear.affine import Scale, Affine, Shift, InverseAffine -from torchflows.bijections.finite.autoregressive.transformers.base import ScalarTransformer from torchflows.bijections.finite.autoregressive.transformers.integration.unconstrained_monotonic_neural_network import \ UnconstrainedMonotonicNeuralNetwork from torchflows.bijections.finite.autoregressive.transformers.spline.linear_rational import LinearRational @@ -17,7 +14,6 @@ DenseSigmoid, DeepDenseSigmoid ) -from torchflows.bijections.base import invert class ElementwiseAffine(ElementwiseBijection): @@ -185,41 +181,25 @@ def __init__(self, class DeepSigmoidalInverseMaskedAutoregressive(InverseMaskedAutoregressiveBijection): def __init__(self, event_shape: Union[Tuple[int, ...], torch.Size], - context_shape: Union[Tuple[int, ...], torch.Size] = None, - transformer_kwargs: dict = None, **kwargs): """ :param Union[Tuple[int, ...], torch.Size] event_shape: shape of the event tensor. 
- :param Union[Tuple[int, ...], torch.Size] context_shape: shape of the context tensor. :param kwargs: keyword arguments to InverseMaskedAutoregressiveBijection. """ - transformer_kwargs = transformer_kwargs or {} - transformer: ScalarTransformer = DeepSigmoid( - event_shape=torch.Size(event_shape), - **transformer_kwargs - ) - super().__init__(event_shape, context_shape, transformer=transformer, **kwargs) + super().__init__(event_shape, DeepSigmoid, **kwargs) class DeepSigmoidalForwardMaskedAutoregressive(MaskedAutoregressiveBijection): def __init__(self, event_shape: Union[Tuple[int, ...], torch.Size], - context_shape: Union[Tuple[int, ...], torch.Size] = None, - transformer_kwargs: dict = None, **kwargs): """ :param Union[Tuple[int, ...], torch.Size] event_shape: shape of the event tensor. - :param Union[Tuple[int, ...], torch.Size] context_shape: shape of the context tensor. :param kwargs: keyword arguments to MaskedAutoregressiveBijection. """ - transformer_kwargs = transformer_kwargs or {} - transformer: ScalarTransformer = DeepSigmoid( - event_shape=torch.Size(event_shape), - **transformer_kwargs - ) - super().__init__(event_shape, context_shape, transformer=transformer, **kwargs) + super().__init__(event_shape, DeepSigmoid, **kwargs) class DenseSigmoidalCoupling(CouplingBijection): @@ -241,64 +221,33 @@ def __init__(self, class DenseSigmoidalInverseMaskedAutoregressive(InverseMaskedAutoregressiveBijection): def __init__(self, event_shape: Union[Tuple[int, ...], torch.Size], - context_shape: Union[Tuple[int, ...], torch.Size] = None, - transformer_kwargs: dict = None, - percentage_global_parameters: float = 0.8, **kwargs): """ :param Union[Tuple[int, ...], torch.Size] event_shape: shape of the event tensor. - :param Union[Tuple[int, ...], torch.Size] context_shape: shape of the context tensor. - :param float percentage_global_parameters: percentage of transformer inputs to be learned globally instead of - being predicted from the conditioner neural network. :param kwargs: keyword arguments to InverseMaskedAutoregressiveBijection. """ - transformer_kwargs = transformer_kwargs or {} - transformer: ScalarTransformer = DenseSigmoid( - event_shape=torch.Size(event_shape), - **transformer_kwargs - ) - super().__init__( - event_shape, - context_shape, - transformer=transformer, - **{ - **kwargs, - **dict(percentage_global_parameters=percentage_global_parameters) - } - ) + if 'conditioner_kwargs' not in kwargs: + kwargs['conditioner_kwargs'] = {} + if 'percentage_global_parameters' not in kwargs['conditioner_kwargs']: + kwargs['conditioner_kwargs']['percentage_global_parameters'] = 0.8 + super().__init__(event_shape, DenseSigmoid, **kwargs) class DenseSigmoidalForwardMaskedAutoregressive(MaskedAutoregressiveBijection): def __init__(self, event_shape: Union[Tuple[int, ...], torch.Size], - context_shape: Union[Tuple[int, ...], torch.Size] = None, - transformer_kwargs: dict = None, - percentage_global_parameters: float = 0.8, **kwargs): """ :param Union[Tuple[int, ...], torch.Size] event_shape: shape of the event tensor. - :param Union[Tuple[int, ...], torch.Size] context_shape: shape of the context tensor. - :param int n_transformer_layers: number of transformer layers. - :param float percentage_global_parameters: percentage of transformer inputs to be learned globally instead of - being predicted from the conditioner neural network. :param kwargs: keyword arguments to MaskedAutoregressiveBijection. 
""" - transformer_kwargs = transformer_kwargs or {} - transformer: ScalarTransformer = DenseSigmoid( - event_shape=torch.Size(event_shape), - **transformer_kwargs - ) - super().__init__( - event_shape, - context_shape, - transformer=transformer, - **{ - **kwargs, - **dict(percentage_global_parameters=percentage_global_parameters) - } - ) + if 'conditioner_kwargs' not in kwargs: + kwargs['conditioner_kwargs'] = {} + if 'percentage_global_parameters' not in kwargs['conditioner_kwargs']: + kwargs['conditioner_kwargs']['percentage_global_parameters'] = 0.8 + super().__init__(event_shape, DenseSigmoid, **kwargs) class DeepDenseSigmoidalCoupling(CouplingBijection): @@ -320,69 +269,33 @@ def __init__(self, class DeepDenseSigmoidalInverseMaskedAutoregressive(InverseMaskedAutoregressiveBijection): def __init__(self, event_shape: Union[Tuple[int, ...], torch.Size], - context_shape: Union[Tuple[int, ...], torch.Size] = None, - transformer_kwargs: dict = None, - percentage_global_parameters: float = 0.8, **kwargs): """ :param Union[Tuple[int, ...], torch.Size] event_shape: shape of the event tensor. - :param Union[Tuple[int, ...], torch.Size] context_shape: shape of the context tensor. - :param int n_transformer_hidden_layers: number of transformer hidden layers. - :param int n_transformer_dense_layers: number of transformer dense layers. - :param int transformer_hidden_size: transformer hidden layer size. - :param float percentage_global_parameters: percentage of transformer inputs to be learned globally instead of - being predicted from the conditioner neural network. :param kwargs: keyword arguments to InverseMaskedAutoregressiveBijection. """ - transformer_kwargs = transformer_kwargs or {} - transformer: ScalarTransformer = DeepDenseSigmoid( - event_shape=torch.Size(event_shape), - **transformer_kwargs - ) - super().__init__( - event_shape, - context_shape, - transformer=transformer, - **{ - **kwargs, - **dict(percentage_global_parameters=percentage_global_parameters) - } - ) + if 'conditioner_kwargs' not in kwargs: + kwargs['conditioner_kwargs'] = {} + if 'percentage_global_parameters' not in kwargs['conditioner_kwargs']: + kwargs['conditioner_kwargs']['percentage_global_parameters'] = 0.8 + super().__init__(event_shape, DeepDenseSigmoid, **kwargs) class DeepDenseSigmoidalForwardMaskedAutoregressive(MaskedAutoregressiveBijection): def __init__(self, event_shape: Union[Tuple[int, ...], torch.Size], - context_shape: Union[Tuple[int, ...], torch.Size] = None, - transformer_kwargs: dict = None, - percentage_global_parameters: float = 0.8, **kwargs): """ :param Union[Tuple[int, ...], torch.Size] event_shape: shape of the event tensor. - :param Union[Tuple[int, ...], torch.Size] context_shape: shape of the context tensor. - :param int n_transformer_hidden_layers: number of transformer hidden layers. - :param int n_transformer_dense_layers: number of transformer dense layers. - :param int transformer_hidden_size: transformer hidden layer size. - :param float percentage_global_parameters: percentage of transformer inputs to be learned globally instead of - being predicted from the conditioner neural network. :param kwargs: keyword arguments to MaskedAutoregressiveBijection. 
""" - transformer_kwargs = transformer_kwargs or {} - transformer: ScalarTransformer = DeepDenseSigmoid( - event_shape=torch.Size(event_shape), - **transformer_kwargs - ) - super().__init__( - event_shape, - context_shape, - transformer=transformer, - **{ - **kwargs, - **dict(percentage_global_parameters=percentage_global_parameters) - } - ) + if 'conditioner_kwargs' not in kwargs: + kwargs['conditioner_kwargs'] = {} + if 'percentage_global_parameters' not in kwargs['conditioner_kwargs']: + kwargs['conditioner_kwargs']['percentage_global_parameters'] = 0.8 + super().__init__(event_shape, DeepDenseSigmoid, **kwargs) class LinearAffineCoupling(AffineCoupling): @@ -428,125 +341,82 @@ def __init__(self, event_shape: Union[Tuple[int, ...], torch.Size], **kwargs): class AffineForwardMaskedAutoregressive(MaskedAutoregressiveBijection): def __init__(self, event_shape: Union[Tuple[int, ...], torch.Size], - context_shape: Union[Tuple[int, ...], torch.Size] = None, - transformer_kwargs: dict = None, **kwargs): """ :param Union[Tuple[int, ...], torch.Size] event_shape: shape of the event tensor. - :param Union[Tuple[int, ...], torch.Size] context_shape: shape of the context tensor. :param kwargs: keyword arguments to MaskedAutoregressiveBijection. """ - transformer_kwargs = transformer_kwargs or {} - transformer: ScalarTransformer = Affine(event_shape=event_shape, **transformer_kwargs) - super().__init__(event_shape, context_shape, transformer=transformer, **kwargs) + super().__init__(event_shape, Affine, **kwargs) class RQSForwardMaskedAutoregressive(MaskedAutoregressiveBijection): def __init__(self, event_shape: Union[Tuple[int, ...], torch.Size], - context_shape: Union[Tuple[int, ...], torch.Size] = None, - transformer_kwargs: dict = None, **kwargs): """ :param Union[Tuple[int, ...], torch.Size] event_shape: shape of the event tensor. - :param Union[Tuple[int, ...], torch.Size] context_shape: shape of the context tensor. - :param int n_bins: number of spline bins. :param kwargs: keyword arguments to MaskedAutoregressiveBijection. """ - transformer_kwargs = transformer_kwargs or {} - transformer: ScalarTransformer = RationalQuadratic(event_shape=event_shape, **transformer_kwargs) - super().__init__(event_shape, context_shape, transformer=transformer, **kwargs) + super().__init__(event_shape, RationalQuadratic, **kwargs) class LRSForwardMaskedAutoregressive(MaskedAutoregressiveBijection): def __init__(self, event_shape: Union[Tuple[int, ...], torch.Size], - context_shape: Union[Tuple[int, ...], torch.Size] = None, - transformer_kwargs: dict = None, **kwargs): """ :param Union[Tuple[int, ...], torch.Size] event_shape: shape of the event tensor. - :param Union[Tuple[int, ...], torch.Size] context_shape: shape of the context tensor. - :param int n_bins: number of spline bins. :param kwargs: keyword arguments to MaskedAutoregressiveBijection. """ - transformer_kwargs = transformer_kwargs or {} - transformer: ScalarTransformer = LinearRational(event_shape=event_shape, **transformer_kwargs) - super().__init__(event_shape, context_shape, transformer=transformer, **kwargs) + super().__init__(event_shape, LinearRational, **kwargs) class AffineInverseMaskedAutoregressive(InverseMaskedAutoregressiveBijection): def __init__(self, event_shape: Union[Tuple[int, ...], torch.Size], - context_shape: Union[Tuple[int, ...], torch.Size] = None, - transformer_kwargs: dict = None, **kwargs): """ :param Union[Tuple[int, ...], torch.Size] event_shape: shape of the event tensor. 
- :param Union[Tuple[int, ...], torch.Size] context_shape: shape of the context tensor. :param kwargs: keyword arguments to InverseMaskedAutoregressiveBijection. """ - transformer_kwargs = transformer_kwargs or {} - transformer: ScalarTransformer = invert(Affine(event_shape=event_shape, **transformer_kwargs)) - super().__init__(event_shape, context_shape, transformer=transformer, **kwargs) + super().__init__(event_shape, InverseAffine, **kwargs) class RQSInverseMaskedAutoregressive(InverseMaskedAutoregressiveBijection): def __init__(self, event_shape: Union[Tuple[int, ...], torch.Size], - context_shape: Union[Tuple[int, ...], torch.Size] = None, - transformer_kwargs: dict = None, **kwargs): """ :param Union[Tuple[int, ...], torch.Size] event_shape: shape of the event tensor. - :param Union[Tuple[int, ...], torch.Size] context_shape: shape of the context tensor. - :param int n_bins: number of spline bins. :param kwargs: keyword arguments to InverseMaskedAutoregressiveBijection. """ - transformer_kwargs = transformer_kwargs or {} - transformer: ScalarTransformer = RationalQuadratic(event_shape=event_shape, **transformer_kwargs) - super().__init__(event_shape, context_shape, transformer=transformer, **kwargs) + super().__init__(event_shape, RationalQuadratic, **kwargs) class LRSInverseMaskedAutoregressive(InverseMaskedAutoregressiveBijection): def __init__(self, event_shape: Union[Tuple[int, ...], torch.Size], - context_shape: Union[Tuple[int, ...], torch.Size] = None, - transformer_kwargs: dict = None, **kwargs): """ :param Union[Tuple[int, ...], torch.Size] event_shape: shape of the event tensor. - :param Union[Tuple[int, ...], torch.Size] context_shape: shape of the context tensor. - :param dict transformer_kwargs: keyword arguments to LinearRational. :param kwargs: keyword arguments to InverseMaskedAutoregressiveBijection. """ - transformer_kwargs = transformer_kwargs or {} - transformer: ScalarTransformer = LinearRational(event_shape=event_shape, **transformer_kwargs) - super().__init__(event_shape, context_shape, transformer=transformer, **kwargs) + super().__init__(event_shape, LinearRational, **kwargs) class UMNNMaskedAutoregressive(MaskedAutoregressiveBijection): def __init__(self, event_shape: Union[Tuple[int, ...], torch.Size], - context_shape: Union[Tuple[int, ...], torch.Size] = None, - transformer_kwargs: dict = None, **kwargs): """ :param Union[Tuple[int, ...], torch.Size] event_shape: shape of the event tensor. - :param Union[Tuple[int, ...], torch.Size] context_shape: shape of the context tensor. - :param dict transformer_kwargs: keyword arguments to UnconstrainedMonotonicNeuralNetwork. :param kwargs: keyword arguments to MaskedAutoregressiveBijection. 
""" - transformer_kwargs = transformer_kwargs or {} - transformer: ScalarTransformer = UnconstrainedMonotonicNeuralNetwork( - event_shape=event_shape, - **transformer_kwargs - ) - super().__init__(event_shape, context_shape, transformer=transformer, **kwargs) + super().__init__(event_shape, UnconstrainedMonotonicNeuralNetwork, **kwargs) diff --git a/torchflows/bijections/finite/autoregressive/layers_base.py b/torchflows/bijections/finite/autoregressive/layers_base.py index a33da2a..4e7cdc4 100644 --- a/torchflows/bijections/finite/autoregressive/layers_base.py +++ b/torchflows/bijections/finite/autoregressive/layers_base.py @@ -144,18 +144,23 @@ class MaskedAutoregressiveBijection(AutoregressiveBijection): """ def __init__(self, - event_shape, - context_shape, - transformer: ScalarTransformer, + event_shape: Union[Tuple[int, ...], torch.Size], + transformer_class: Type[ScalarTransformer], + context_shape: Union[Tuple[int, ...], torch.Size] = None, + transformer_kwargs: dict = None, + conditioner_kwargs: dict = None, **kwargs): + conditioner_kwargs = conditioner_kwargs or {} + transformer_kwargs = transformer_kwargs or {} + transformer = transformer_class(event_shape=event_shape, **transformer_kwargs) conditioner_transform = MADE( input_event_shape=event_shape, transformed_event_shape=event_shape, parameter_shape_per_element=transformer.parameter_shape_per_element, context_shape=context_shape, - **kwargs + **conditioner_kwargs ) - super().__init__(transformer.event_shape, transformer, conditioner_transform) + super().__init__(transformer.event_shape, transformer, conditioner_transform, **kwargs) def apply_conditioner_transformer(self, inputs, context, forward: bool = True): h = self.conditioner_transform(inputs, context) diff --git a/torchflows/bijections/finite/autoregressive/transformers/base.py b/torchflows/bijections/finite/autoregressive/transformers/base.py index c0e6f3f..fc9858e 100644 --- a/torchflows/bijections/finite/autoregressive/transformers/base.py +++ b/torchflows/bijections/finite/autoregressive/transformers/base.py @@ -56,8 +56,8 @@ def default_parameters(self) -> torch.Tensor: class ScalarTransformer(TensorTransformer): - def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]]): - super().__init__(event_shape) + def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], **kwargs): + super().__init__(event_shape, **kwargs) @property def parameter_shape_per_element(self): From a510134a2ac215b5b165e174878f90b765dcf577 Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Sun, 1 Sep 2024 02:29:47 +0200 Subject: [PATCH 12/25] Elementwise bijection cleanup/refactor --- .../finite/autoregressive/layers.py | 26 ++++++++----------- .../finite/autoregressive/layers_base.py | 18 +++++++++---- .../transformers/linear/affine.py | 2 +- 3 files changed, 25 insertions(+), 21 deletions(-) diff --git a/torchflows/bijections/finite/autoregressive/layers.py b/torchflows/bijections/finite/autoregressive/layers.py index 6e2133a..be33e97 100644 --- a/torchflows/bijections/finite/autoregressive/layers.py +++ b/torchflows/bijections/finite/autoregressive/layers.py @@ -21,10 +21,9 @@ def __init__(self, event_shape, **kwargs): """ :param Union[Tuple[int, ...], torch.Size] event_shape: shape of the event tensor. - :param kwargs: keyword arguments to Affine. + :param kwargs: keyword arguments to ElementwiseBijection. 
""" - transformer = Affine(event_shape, **kwargs) - super().__init__(transformer) + super().__init__(event_shape, Affine, **kwargs) class ElementwiseInverseAffine(ElementwiseBijection): @@ -32,10 +31,9 @@ def __init__(self, event_shape, **kwargs): """ :param Union[Tuple[int, ...], torch.Size] event_shape: shape of the event tensor. - :param kwargs: keyword arguments to InverseAffine. + :param kwargs: keyword arguments to ElementwiseBijection. """ - transformer = InverseAffine(event_shape, **kwargs) - super().__init__(transformer) + super().__init__(event_shape, InverseAffine, **kwargs) class ActNorm(ElementwiseInverseAffine): @@ -75,20 +73,19 @@ def __init__(self, event_shape, **kwargs): """ :param Union[Tuple[int, ...], torch.Size] event_shape: shape of the event tensor. - :param kwargs: keyword arguments to Scale. + :param kwargs: keyword arguments to ElementwiseBijection. """ - transformer = Scale(event_shape, **kwargs) - super().__init__(transformer) + super().__init__(event_shape, Scale, **kwargs) class ElementwiseShift(ElementwiseBijection): - def __init__(self, event_shape): + def __init__(self, event_shape, **kwargs): """ :param Union[Tuple[int, ...], torch.Size] event_shape: shape of the event tensor. + :param kwargs: keyword arguments to ElementwiseBijection. """ - transformer = Shift(event_shape) - super().__init__(transformer) + super().__init__(event_shape, Shift, **kwargs) class ElementwiseRQSpline(ElementwiseBijection): @@ -96,10 +93,9 @@ def __init__(self, event_shape, **kwargs): """ :param Union[Tuple[int, ...], torch.Size] event_shape: shape of the event tensor. - :param kwargs: keyword arguments to RationalQuadratic. + :param kwargs: keyword arguments to ElementwiseBijection. """ - transformer = RationalQuadratic(event_shape, **kwargs) - super().__init__(transformer) + super().__init__(event_shape, RationalQuadratic, **kwargs) class AffineCoupling(CouplingBijection): diff --git a/torchflows/bijections/finite/autoregressive/layers_base.py b/torchflows/bijections/finite/autoregressive/layers_base.py index 4e7cdc4..9a47270 100644 --- a/torchflows/bijections/finite/autoregressive/layers_base.py +++ b/torchflows/bijections/finite/autoregressive/layers_base.py @@ -1,4 +1,4 @@ -from typing import Tuple, Union, Type +from typing import Tuple, Union, Type, Optional import torch import torch.nn as nn @@ -15,7 +15,7 @@ class AutoregressiveBijection(Bijection): def __init__(self, event_shape, transformer: Union[TensorTransformer, ScalarTransformer], - conditioner_transform: ConditionerTransform, + conditioner_transform: Optional[ConditionerTransform], **kwargs): super().__init__(event_shape=event_shape) self.conditioner_transform = conditioner_transform @@ -205,11 +205,19 @@ class ElementwiseBijection(AutoregressiveBijection): The bijection for each element has its own set of globally learned parameters. 
""" - def __init__(self, transformer: ScalarTransformer, fill_value: float = None): + def __init__(self, + event_shape: Union[Tuple[int, ...], torch.Size], + transformer_class: Type[ScalarTransformer], + transformer_kwargs: dict = None, + fill_value: float = None, + **kwargs): + transformer_kwargs = transformer_kwargs or {} + transformer = transformer_class(event_shape=event_shape, **transformer_kwargs) super().__init__( - transformer.event_shape, + event_shape, transformer, - None + None, + **kwargs ) if fill_value is None: diff --git a/torchflows/bijections/finite/autoregressive/transformers/linear/affine.py b/torchflows/bijections/finite/autoregressive/transformers/linear/affine.py index 5c0e9a4..13ab46e 100644 --- a/torchflows/bijections/finite/autoregressive/transformers/linear/affine.py +++ b/torchflows/bijections/finite/autoregressive/transformers/linear/affine.py @@ -135,7 +135,7 @@ def inverse(self, z: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch class Shift(ScalarTransformer): - def __init__(self, event_shape: torch.Size): + def __init__(self, event_shape: torch.Size, **kwargs): super().__init__(event_shape=event_shape) @property From 3bc85b08588c7fce46df3c25dad7db0c7ba0fd04 Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Sun, 1 Sep 2024 03:01:40 +0200 Subject: [PATCH 13/25] Refactor classic residual bijections --- test/test_autograd_bijections.py | 5 -- test/test_invert_classic_residual.py | 26 ++++++ torchflows/__init__.py | 2 +- torchflows/bijections/base.py | 3 + .../finite/residual/architectures.py | 85 +++++++------------ torchflows/bijections/finite/residual/base.py | 42 ++++++--- .../bijections/finite/residual/planar.py | 20 +---- .../bijections/finite/residual/radial.py | 8 +- .../bijections/finite/residual/sylvester.py | 21 ++--- 9 files changed, 108 insertions(+), 104 deletions(-) create mode 100644 test/test_invert_classic_residual.py diff --git a/test/test_autograd_bijections.py b/test/test_autograd_bijections.py index 6b4069f..81ac954 100644 --- a/test/test_autograd_bijections.py +++ b/test/test_autograd_bijections.py @@ -11,11 +11,6 @@ LRSCoupling, LinearRQSCoupling from torchflows.bijections.finite.linear import LU, ReversePermutation, LowerTriangular, Orthogonal, QR from torchflows.bijections.finite.residual.architectures import InvertibleResNet, ResFlow, ProximalResFlow -from torchflows.bijections.finite.residual.iterative import InvertibleResNetBlock, ResFlowBlock -from torchflows.bijections.finite.residual.planar import Planar -from torchflows.bijections.finite.residual.proximal import ProximalResFlowBlock -from torchflows.bijections.finite.residual.radial import Radial -from torchflows.bijections.finite.residual.sylvester import Sylvester from torchflows.utils import get_batch_shape from test.constants import __test_constants diff --git a/test/test_invert_classic_residual.py b/test/test_invert_classic_residual.py new file mode 100644 index 0000000..f6656f4 --- /dev/null +++ b/test/test_invert_classic_residual.py @@ -0,0 +1,26 @@ +import pytest +import torch + +from torchflows.flows import Flow +from torchflows.bijections.finite.residual.architectures import RadialFlow, SylvesterFlow, PlanarFlow + + +@pytest.mark.parametrize( + 'architecture_class', + [ + RadialFlow, + SylvesterFlow, + PlanarFlow + ] +) +def test_basic(architecture_class): + torch.manual_seed(0) + event_shape = (1, 2, 3, 4) + batch_shape = (5, 6) + + flow = Flow(architecture_class(event_shape)) + x_new = flow.sample(batch_shape) + assert x_new.shape == (*batch_shape, 
*event_shape) + + flow.bijection.invert() + assert flow.log_prob(x_new).shape == batch_shape diff --git a/torchflows/__init__.py b/torchflows/__init__.py index 6d3802c..1a7460a 100644 --- a/torchflows/__init__.py +++ b/torchflows/__init__.py @@ -46,7 +46,7 @@ 'ProximalResFlow', 'Radial', 'Planar', - 'Sylvester', + 'InverseSylvester', 'ElementwiseShift', 'ElementwiseAffine', 'ElementwiseRQSpline', diff --git a/torchflows/bijections/base.py b/torchflows/bijections/base.py index 57c0ca6..23b0164 100644 --- a/torchflows/bijections/base.py +++ b/torchflows/bijections/base.py @@ -98,6 +98,9 @@ def batch_inverse(self, x: torch.Tensor, batch_size: int, context: torch.Tensor def regularization(self): return 0.0 + def invert(self): + self.forward, self.inverse = self.inverse, self.forward + def invert(bijection: Bijection) -> Bijection: """Swap the forward and inverse methods of the input bijection. diff --git a/torchflows/bijections/finite/residual/architectures.py b/torchflows/bijections/finite/residual/architectures.py index d14e392..9d363bc 100644 --- a/torchflows/bijections/finite/residual/architectures.py +++ b/torchflows/bijections/finite/residual/architectures.py @@ -2,59 +2,49 @@ import torch -from torchflows.bijections.base import BijectiveComposition -from torchflows.bijections.finite.autoregressive.layers import ElementwiseAffine -from torchflows.bijections.finite.residual.base import ResidualComposition +from torchflows.bijections.finite.residual.base import ResidualArchitecture from torchflows.bijections.finite.residual.iterative import InvertibleResNetBlock, ResFlowBlock from torchflows.bijections.finite.residual.proximal import ProximalResFlowBlock -from torchflows.bijections.finite.residual.planar import Planar, InversePlanar +from torchflows.bijections.finite.residual.planar import Planar from torchflows.bijections.finite.residual.radial import Radial from torchflows.bijections.finite.residual.sylvester import Sylvester -class InvertibleResNet(ResidualComposition): +class InvertibleResNet(ResidualArchitecture): """Invertible residual network (i-ResNet) architecture. Reference: Behrmann et al. "Invertible Residual Networks" (2019); https://arxiv.org/abs/1811.00995. """ - def __init__(self, event_shape, context_shape=None, n_layers: int = 2, **kwargs): - blocks = [ - InvertibleResNetBlock(event_shape=event_shape, context_shape=context_shape, **kwargs) - for _ in range(n_layers) - ] - super().__init__(blocks) + def __init__(self, event_shape, **kwargs): + super().__init__(event_shape, InvertibleResNetBlock, **kwargs) -class ResFlow(ResidualComposition): +class ResFlow(ResidualArchitecture): """Residual flow (ResFlow) architecture. Reference: Chen et al. "Residual Flows for Invertible Generative Modeling" (2020); https://arxiv.org/abs/1906.02735. """ - def __init__(self, event_shape, context_shape=None, n_layers: int = 2, **kwargs): - blocks = [ - ResFlowBlock(event_shape=event_shape, context_shape=context_shape, **kwargs) - for _ in range(n_layers) - ] - super().__init__(blocks) + def __init__(self, event_shape, **kwargs): + super().__init__(event_shape, ResFlowBlock, **kwargs) -class ProximalResFlow(ResidualComposition): +class ProximalResFlow(ResidualArchitecture): """Proximal residual flow architecture. Reference: Hertrich "Proximal Residual Flows for Bayesian Inverse Problems" (2022); https://arxiv.org/abs/2211.17158. 
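Note (illustrative aside, not part of the patch): the new Bijection.invert swaps the forward and inverse callables in place; the added test samples first and then inverts the bijection so that log_prob can be evaluated. A condensed version of that usage, mirroring test_invert_classic_residual.py:

    import torch
    from torchflows.flows import Flow
    from torchflows.bijections.finite.residual.architectures import PlanarFlow

    flow = Flow(PlanarFlow((2, 3)))
    x = flow.sample((5,))        # one-way sampling direction
    flow.bijection.invert()      # swap forward/inverse
    log_prob = flow.log_prob(x)  # now available; shape (5,)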
""" - def __init__(self, event_shape, context_shape=None, n_layers: int = 2, **kwargs): - blocks = [ - ProximalResFlowBlock(event_shape=event_shape, context_shape=context_shape, gamma=0.01, **kwargs) - for _ in range(n_layers) - ] - super().__init__(blocks) + def __init__(self, event_shape, **kwargs): + if 'layer_kwargs' not in kwargs: + kwargs['layer_kwargs'] = {} + if 'gamma' not in kwargs['layer_kwargs']: + kwargs['layer_kwargs']['gamma'] = 0.01 + super().__init__(event_shape, ProximalResFlowBlock, **kwargs) -class PlanarFlow(BijectiveComposition): +class PlanarFlow(ResidualArchitecture): """Planar flow architecture. Note: this model currently supports only one-way transformations. @@ -64,18 +54,11 @@ class PlanarFlow(BijectiveComposition): def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], - n_layers: int = 2, - inverse: bool = True): - if n_layers < 1: - raise ValueError(f"Flow needs at least one layer, but got {n_layers}") - super().__init__(event_shape, [ - ElementwiseAffine(event_shape), - *[(InversePlanar if inverse else Planar)(event_shape) for _ in range(n_layers)], - ElementwiseAffine(event_shape) - ]) - - -class RadialFlow(BijectiveComposition): + **kwargs): + super().__init__(event_shape, Planar, **kwargs) + + +class RadialFlow(ResidualArchitecture): """Radial flow architecture. Note: this model currently supports only one-way transformations. @@ -83,17 +66,13 @@ class RadialFlow(BijectiveComposition): Reference: Rezende and Mohamed "Variational Inference with Normalizing Flows" (2016); https://arxiv.org/abs/1505.05770. """ - def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], n_layers: int = 2): - if n_layers < 1: - raise ValueError(f"Flow needs at least one layer, but got {n_layers}") - super().__init__(event_shape, [ - ElementwiseAffine(event_shape), - *[Radial(event_shape) for _ in range(n_layers)], - ElementwiseAffine(event_shape) - ]) + def __init__(self, + event_shape: Union[torch.Size, Tuple[int, ...]], + **kwargs): + super().__init__(event_shape, Radial, **kwargs) -class SylvesterFlow(BijectiveComposition): +class SylvesterFlow(ResidualArchitecture): """Sylvester flow architecture. Note: this model currently supports only one-way transformations. @@ -101,11 +80,7 @@ class SylvesterFlow(BijectiveComposition): Reference: Van den Berg et al. "Sylvester Normalizing Flows for Variational Inference" (2019); https://arxiv.org/abs/1803.05649. 
""" - def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], n_layers: int = 2, **kwargs): - if n_layers < 1: - raise ValueError(f"Flow needs at least one layer, but got {n_layers}") - super().__init__(event_shape, [ - ElementwiseAffine(event_shape), - *[Sylvester(event_shape, **kwargs) for _ in range(n_layers)], - ElementwiseAffine(event_shape) - ]) + def __init__(self, + event_shape: Union[torch.Size, Tuple[int, ...]], + **kwargs): + super().__init__(event_shape, Sylvester, **kwargs) diff --git a/torchflows/bijections/finite/residual/base.py b/torchflows/bijections/finite/residual/base.py index 55b0e61..9ac4e13 100644 --- a/torchflows/bijections/finite/residual/base.py +++ b/torchflows/bijections/finite/residual/base.py @@ -1,4 +1,4 @@ -from typing import Union, Tuple, List +from typing import Union, Tuple, List, Type import torch @@ -7,8 +7,18 @@ from torchflows.utils import get_batch_shape, unflatten_event, flatten_event, flatten_batch, unflatten_batch +class ClassicResidualBijection(Bijection): + def __init__(self, + event_shape: Union[torch.Size, Tuple[int, ...]], + inverse: bool = False, + **kwargs): + super().__init__(event_shape, **kwargs) + if inverse: + self.invert() + + class ResidualBijection(Bijection): - def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]]): + def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], **kwargs): """ g maps from (*batch_shape, n_event_dims) to (*batch_shape, n_event_dims) @@ -65,18 +75,24 @@ def inverse(self, return x, log_det -class ResidualComposition(BijectiveComposition): - def __init__(self, blocks: List[ResidualBijection]): - assert len(blocks) > 0 - event_shape = blocks[0].event_shape +class ResidualArchitecture(BijectiveComposition): + def __init__(self, + event_shape: Union[Tuple[int, ...], torch.Size], + layer_class: Type[Union[ResidualBijection, ClassicResidualBijection]], + n_layers: int = 2, + layer_kwargs: dict = None, + **kwargs): + assert n_layers > 0 + layer_kwargs = layer_kwargs or {} - updated_layers = [ElementwiseAffine(event_shape)] - for i in range(len(blocks)): - updated_layers.append(blocks[i]) - updated_layers.append(ElementwiseAffine(event_shape)) + layers = [ElementwiseAffine(event_shape)] + for i in range(n_layers): + layers.append(layer_class(event_shape, **layer_kwargs)) + layers.append(ElementwiseAffine(event_shape)) super().__init__( - event_shape=updated_layers[0].event_shape, - layers=updated_layers, - context_shape=updated_layers[0].context_shape + event_shape=layers[0].event_shape, + layers=layers, + context_shape=layers[0].context_shape, + **kwargs ) diff --git a/torchflows/bijections/finite/residual/planar.py b/torchflows/bijections/finite/residual/planar.py index 1edc957..80e061b 100644 --- a/torchflows/bijections/finite/residual/planar.py +++ b/torchflows/bijections/finite/residual/planar.py @@ -2,25 +2,13 @@ import torch import torch.nn as nn -from torchflows.bijections.base import Bijection +from torchflows.bijections.finite.residual.base import ClassicResidualBijection from torchflows.utils import get_batch_shape -class Planar(Bijection): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.inv_planar = InversePlanar(*args, **kwargs) - - def forward(self, x: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]: - return self.inv_planar.inverse(z=x, context=context) - - def inverse(self, z: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]: - return self.inv_planar.forward(x=z, 
context=context) - - -class InversePlanar(Bijection): - def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]]): - super().__init__(event_shape) +class Planar(ClassicResidualBijection): + def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], **kwargs): + super().__init__(event_shape, **kwargs) self.w = nn.Parameter(torch.randn(size=(self.n_dim,))) self.u = nn.Parameter(torch.randn(size=(self.n_dim,))) self.b = nn.Parameter(torch.randn(size=())) diff --git a/torchflows/bijections/finite/residual/radial.py b/torchflows/bijections/finite/residual/radial.py index 385e291..47eb1c3 100644 --- a/torchflows/bijections/finite/residual/radial.py +++ b/torchflows/bijections/finite/residual/radial.py @@ -4,15 +4,15 @@ import torch.nn as nn from torch.nn.functional import softplus -from torchflows.bijections.base import Bijection +from torchflows.bijections.finite.residual.base import ClassicResidualBijection from torchflows.utils import get_batch_shape -class Radial(Bijection): +class Radial(ClassicResidualBijection): # as per Rezende, Mohamed (2015) - def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]]): - super().__init__(event_shape) + def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], **kwargs): + super().__init__(event_shape, **kwargs) self.beta = nn.Parameter(torch.randn(size=())) self.unconstrained_alpha = nn.Parameter(torch.randn(size=())) self.z0 = nn.Parameter(torch.randn(size=(self.n_dim,))) diff --git a/torchflows/bijections/finite/residual/sylvester.py b/torchflows/bijections/finite/residual/sylvester.py index 5208448..9a39aca 100644 --- a/torchflows/bijections/finite/residual/sylvester.py +++ b/torchflows/bijections/finite/residual/sylvester.py @@ -3,17 +3,18 @@ import torch import torch.nn as nn -from torchflows.bijections.base import Bijection +from torchflows.bijections.finite.residual.base import ClassicResidualBijection from torchflows.bijections.matrices import UpperTriangularInvertibleMatrix, HouseholderOrthogonalMatrix, \ IdentityMatrix, PermutationMatrix from torchflows.utils import get_batch_shape -class BaseSylvester(Bijection): +class BaseSylvester(ClassicResidualBijection): def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], - m: int = None): - super().__init__(event_shape) + m: int = None, + **kwargs): + super().__init__(event_shape, **kwargs) self.n_dim = int(torch.prod(torch.as_tensor(event_shape))) if m is None: @@ -75,14 +76,14 @@ def inverse(self, z: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch. 
class HouseholderSylvester(BaseSylvester): - def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], m: int = None): - super().__init__(event_shape, m) + def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], **kwargs): + super().__init__(event_shape, **kwargs) self.q = HouseholderOrthogonalMatrix(n_dim=self.n_dim, n_factors=self.m) class IdentitySylvester(BaseSylvester): - def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], m: int = None): - super().__init__(event_shape, m) + def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], **kwargs): + super().__init__(event_shape, **kwargs) self.q = IdentityMatrix(n_dim=self.n_dim) @@ -90,6 +91,6 @@ def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], m: int = Non class PermutationSylvester(BaseSylvester): - def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], m: int = None): - super().__init__(event_shape, m) + def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], **kwargs): + super().__init__(event_shape, **kwargs) self.q = PermutationMatrix(n_dim=self.n_dim) From f3b22ca8d9adbc7011a0aa02ab3c45368a3fa7b8 Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Sun, 1 Sep 2024 22:52:43 +0200 Subject: [PATCH 14/25] Add convolutional residual flows and refactor residual event shapes --- test/test_convolutional_architectures.py | 21 ++- test/test_stochastic_log_det_estimation.py | 9 +- torchflows/architectures.py | 4 +- .../finite/residual/architectures.py | 27 +++- torchflows/bijections/finite/residual/base.py | 26 ++-- .../bijections/finite/residual/iterative.py | 133 ++++++++++++++---- .../finite/residual/log_abs_det_estimators.py | 56 +++++--- .../bijections/finite/residual/proximal.py | 19 ++- .../finite/residual/quasi_autoregressive.py | 4 +- 9 files changed, 224 insertions(+), 75 deletions(-) diff --git a/test/test_convolutional_architectures.py b/test/test_convolutional_architectures.py index 166765f..d8a178e 100644 --- a/test/test_convolutional_architectures.py +++ b/test/test_convolutional_architectures.py @@ -14,6 +14,7 @@ import torch import pytest from test.constants import __test_constants +from torchflows.bijections.finite.residual.architectures import ConvolutionalInvertibleResNet, ConvolutionalResFlow @pytest.mark.parametrize('architecture_class', [ @@ -52,7 +53,25 @@ def test_continuous(architecture_class, image_shape): xr, ldi = bijection.inverse(z) assert x.shape == xr.shape assert ldf.shape == ldi.shape - assert torch.allclose(x, xr, atol=__test_constants['data_atol']), f'"{(x-xr).abs().max()}"' + assert torch.allclose(x, xr, atol=__test_constants['data_atol']), f'"{(x - xr).abs().max()}"' + assert torch.allclose(ldf, -ldi, atol=__test_constants['log_det_atol']) # 1e-2 + + +@pytest.mark.skip('Unsupported/failing') +@pytest.mark.parametrize('architecture_class', [ + ConvolutionalInvertibleResNet, + ConvolutionalResFlow +]) +@pytest.mark.parametrize('image_shape', [(1, 28, 28), (3, 28, 28)]) +def test_residual(architecture_class, image_shape): + torch.manual_seed(0) + x = torch.randn(size=(5, *image_shape)) + bijection = architecture_class(image_shape) + z, ldf = bijection.forward(x) + xr, ldi = bijection.inverse(z) + assert x.shape == xr.shape + assert ldf.shape == ldi.shape + assert torch.allclose(x, xr, atol=__test_constants['data_atol']), f'"{(x - xr).abs().max()}"' assert torch.allclose(ldf, -ldi, atol=__test_constants['log_det_atol']) # 1e-2 diff --git a/test/test_stochastic_log_det_estimation.py 
b/test/test_stochastic_log_det_estimation.py index af55655..d9b3a5f 100644 --- a/test/test_stochastic_log_det_estimation.py +++ b/test/test_stochastic_log_det_estimation.py @@ -48,6 +48,7 @@ def test_power_series_estimator(n_iterations, n_hutchinson_samples): torch.manual_seed(0) x = torch.randn(size=(n_data, n_dim)) g_value, log_det_f_estimated = log_det_power_series( + (n_dim,), g, x, training=False, @@ -76,7 +77,13 @@ def test_roulette_estimator(p): torch.manual_seed(0) x = torch.randn(size=(n_data, n_dim)) - g_value, log_det_f = log_det_roulette(g, x, training=False, p=p) + g_value, log_det_f = log_det_roulette( + (n_dim,), + g, + x, + training=False, + p=p + ) log_det_f_true = test_data.log_det_jac_f(x).ravel() print(f'{log_det_f = }') diff --git a/torchflows/architectures.py b/torchflows/architectures.py index f667bb9..a737ed0 100644 --- a/torchflows/architectures.py +++ b/torchflows/architectures.py @@ -33,7 +33,9 @@ InvertibleResNet, PlanarFlow, RadialFlow, - SylvesterFlow + SylvesterFlow, + ConvolutionalInvertibleResNet, + ConvolutionalResFlow ) from torchflows.bijections.finite.multiscale.architectures import ( diff --git a/torchflows/bijections/finite/residual/architectures.py b/torchflows/bijections/finite/residual/architectures.py index 9d363bc..7666d06 100644 --- a/torchflows/bijections/finite/residual/architectures.py +++ b/torchflows/bijections/finite/residual/architectures.py @@ -3,7 +3,12 @@ import torch from torchflows.bijections.finite.residual.base import ResidualArchitecture -from torchflows.bijections.finite.residual.iterative import InvertibleResNetBlock, ResFlowBlock +from torchflows.bijections.finite.residual.iterative import ( + InvertibleResNetBlock, + ResFlowBlock, + ConvolutionalInvertibleResNetBlock, + ConvolutionalResFlowBlock +) from torchflows.bijections.finite.residual.proximal import ProximalResFlowBlock from torchflows.bijections.finite.residual.planar import Planar from torchflows.bijections.finite.residual.radial import Radial @@ -20,6 +25,16 @@ def __init__(self, event_shape, **kwargs): super().__init__(event_shape, InvertibleResNetBlock, **kwargs) +class ConvolutionalInvertibleResNet(ResidualArchitecture): + """Convolutional variant of i-ResNet. + + Reference: Behrmann et al. "Invertible Residual Networks" (2019); https://arxiv.org/abs/1811.00995. + """ + + def __init__(self, event_shape, **kwargs): + super().__init__(event_shape, ConvolutionalInvertibleResNetBlock, **kwargs) + + class ResFlow(ResidualArchitecture): """Residual flow (ResFlow) architecture. @@ -30,6 +45,16 @@ def __init__(self, event_shape, **kwargs): super().__init__(event_shape, ResFlowBlock, **kwargs) +class ConvolutionalResFlow(ResidualArchitecture): + """Convolutional variant of ResFlow. + + Reference: Chen et al. "Residual Flows for Invertible Generative Modeling" (2020); https://arxiv.org/abs/1906.02735. + """ + + def __init__(self, event_shape, **kwargs): + super().__init__(event_shape, ConvolutionalResFlowBlock, **kwargs) + + class ProximalResFlow(ResidualArchitecture): """Proximal residual flow architecture. 
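Note (illustrative aside, not part of the patch): the convolutional residual architectures are meant to be constructed directly from an image-shaped event, as in the (currently skipped) test above; this shows only the intended calling convention, not verified behaviour:

    import torch
    from torchflows.bijections.finite.residual.architectures import ConvolutionalResFlow

    image_shape = (3, 28, 28)  # (channels, height, width)
    bijection = ConvolutionalResFlow(image_shape)
    x = torch.randn(5, *image_shape)
    z, log_det = bijection.forward(x)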
diff --git a/torchflows/bijections/finite/residual/base.py b/torchflows/bijections/finite/residual/base.py index 9ac4e13..f9113de 100644 --- a/torchflows/bijections/finite/residual/base.py +++ b/torchflows/bijections/finite/residual/base.py @@ -17,14 +17,12 @@ def __init__(self, self.invert() -class ResidualBijection(Bijection): - def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], **kwargs): - """ - - g maps from (*batch_shape, n_event_dims) to (*batch_shape, n_event_dims) +class IterativeResidualBijection(Bijection): + """ + g maps from (*batch_shape, *event_shape) to (*batch_shape, *event_shape) + """ - :param event_shape: - """ + def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], **kwargs): super().__init__(event_shape) self.g: callable = None @@ -36,16 +34,16 @@ def forward(self, context: torch.Tensor = None, skip_log_det: bool = False) -> Tuple[torch.Tensor, torch.Tensor]: batch_shape = get_batch_shape(x, self.event_shape) - x_flat = flatten_batch(flatten_event(x, self.event_shape), batch_shape) + x_flat = flatten_batch(x, batch_shape) g_flat = self.g(x_flat) - g = unflatten_event(unflatten_batch(g_flat, batch_shape), self.event_shape) + g = unflatten_batch(g_flat, batch_shape) z = x + g if skip_log_det: log_det = torch.full(size=batch_shape, fill_value=torch.nan) else: - x_flat = flatten_batch(flatten_event(x, self.event_shape).clone(), batch_shape) + x_flat = flatten_batch(x.clone(), batch_shape) x_flat.requires_grad_(True) log_det = -unflatten_batch(self.log_det(x_flat, training=self.training), batch_shape) @@ -59,16 +57,16 @@ def inverse(self, batch_shape = get_batch_shape(z, self.event_shape) x = z for _ in range(n_iterations): - x_flat = flatten_batch(flatten_event(x, self.event_shape), batch_shape) + x_flat = flatten_batch(x, batch_shape) g_flat = self.g(x_flat) - g = unflatten_event(unflatten_batch(g_flat, batch_shape), self.event_shape) + g = unflatten_batch(g_flat, batch_shape) x = z - g if skip_log_det: log_det = torch.full(size=batch_shape, fill_value=torch.nan) else: - x_flat = flatten_batch(flatten_event(x, self.event_shape).clone(), batch_shape) + x_flat = flatten_batch(x.clone(), batch_shape) x_flat.requires_grad_(True) log_det = -unflatten_batch(self.log_det(x_flat, training=self.training), batch_shape) @@ -78,7 +76,7 @@ def inverse(self, class ResidualArchitecture(BijectiveComposition): def __init__(self, event_shape: Union[Tuple[int, ...], torch.Size], - layer_class: Type[Union[ResidualBijection, ClassicResidualBijection]], + layer_class: Type[Union[IterativeResidualBijection, ClassicResidualBijection]], n_layers: int = 2, layer_kwargs: dict = None, **kwargs): diff --git a/torchflows/bijections/finite/residual/iterative.py b/torchflows/bijections/finite/residual/iterative.py index 9e7bbd2..ef8e3fa 100644 --- a/torchflows/bijections/finite/residual/iterative.py +++ b/torchflows/bijections/finite/residual/iterative.py @@ -4,19 +4,16 @@ import torch import torch.nn as nn -from torchflows.bijections.finite.residual.base import ResidualBijection +from torchflows.bijections.finite.residual.base import IterativeResidualBijection from torchflows.bijections.finite.residual.log_abs_det_estimators import log_det_power_series, log_det_roulette +from torchflows.utils import get_batch_shape -class SpectralLinear(nn.Module): - # https://arxiv.org/pdf/1811.00995.pdf - - def __init__(self, n_inputs: int, n_outputs: int, c: float = 0.7, n_iterations: int = 5): +class SpectralMatrix(nn.Module): + def __init__(self, shape: Tuple[int, int], c: float = 
0.7, n_iterations: int = 5): super().__init__() + self.data = torch.randn(size=shape) self.c = c - self.n_inputs = n_inputs - self.w = torch.randn(n_outputs, n_inputs) - self.bias = nn.Parameter(torch.randn(n_outputs)) self.n_iterations = n_iterations @torch.no_grad() @@ -25,7 +22,7 @@ def power_iteration(self, w): # Spectral Normalization for Generative Adversarial Networks - Miyato et al. - 2018 # Get maximum singular value of rectangular matrix w - u = torch.randn(self.n_inputs, 1) + u = torch.randn(self.data.shape[1], 1) v = None w = w.T @@ -40,50 +37,128 @@ def power_iteration(self, w): factor = u.T @ w @ v return factor - @property - def normalized_mat(self): + def normalized(self): # Estimate sigma - sigma = self.power_iteration(self.w) + sigma = self.power_iteration(self.data) # ratio = self.c / sigma # return self.w * (ratio ** (ratio < 1)) - return self.w / sigma + return self.data / sigma + + +class SpectralLinear(nn.Module): + # https://arxiv.org/pdf/1811.00995.pdf + + def __init__(self, n_inputs: int, n_outputs: int, **kwargs): + super().__init__() + self.w = SpectralMatrix((n_outputs, n_inputs), **kwargs) + self.bias = nn.Parameter(torch.randn(n_outputs)) def forward(self, x): - return torch.nn.functional.linear(x, self.normalized_mat, self.bias) + return torch.nn.functional.linear(x, self.w.normalized(), self.bias) -class SpectralNeuralNetwork(nn.Sequential): - def __init__(self, n_dim: int, n_hidden: int = None, n_hidden_layers: int = 1, **kwargs): +class SpectralConv2d(nn.Module): + def __init__(self, n_channels: int, kernel_shape: Tuple[int, int] = (3, 3), **kwargs): + super().__init__() + self.n_channels = n_channels + self.kernel_shape = kernel_shape + self.weight = SpectralMatrix((n_channels * kernel_shape[0], n_channels * kernel_shape[1]), **kwargs) + self.bias = nn.Parameter(torch.randn(n_channels)) + + def forward(self, x): + w = self.weight.normalized().view(self.n_channels, self.n_channels, *self.kernel_shape) + return torch.conv2d(x, w, self.bias, padding='same') + + +class SpectralNeuralNetwork(nn.Module): + def __init__(self, + event_shape: Union[Tuple[int, ...], torch.Size], + n_hidden: int = None, + n_hidden_layers: int = 1, + **kwargs): + self.event_shape = event_shape + event_size = int(torch.prod(torch.as_tensor(event_shape))) if n_hidden is None: - n_hidden = int(3 * max(math.log(n_dim), 4)) + n_hidden = int(3 * max(math.log(event_size), 4)) - layers = [] if n_hidden_layers == 0: - layers = [SpectralLinear(n_dim, n_dim, **kwargs)] + layers = [SpectralLinear(event_size, event_size, **kwargs)] else: - layers.append(SpectralLinear(n_dim, n_hidden, **kwargs)) + layers = [SpectralLinear(event_size, n_hidden, **kwargs)] for _ in range(n_hidden): layers.append(nn.Tanh()) layers.append(SpectralLinear(n_hidden, n_hidden, **kwargs)) layers.pop(-1) - layers.append(SpectralLinear(n_hidden, n_dim, **kwargs)) + layers.append(SpectralLinear(n_hidden, event_size, **kwargs)) + super().__init__() + self.layers = nn.ModuleList(layers) + + def forward(self, x): + batch_shape = get_batch_shape(x, self.event_shape) + x_flat = x.view(*batch_shape, -1) + for layer in self.layers: + x_flat = layer(x_flat) + x = x_flat.view_as(x) + return x + + +class SpectralCNN(nn.Sequential): + def __init__(self, n_channels: int, n_layers: int = 2, **kwargs): + layers = [] + for _ in range(n_layers): + layers.append(SpectralConv2d(n_channels, **kwargs)) + layers.append(nn.Tanh()) + layers.pop(-1) super().__init__(*layers) -class InvertibleResNetBlock(ResidualBijection): - def __init__(self, 
event_shape: Union[torch.Size, Tuple[int, ...]], context_shape=None, **kwargs): +class InvertibleResNetBlock(IterativeResidualBijection): + def __init__(self, + event_shape: Union[torch.Size, Tuple[int, ...]], + context_shape: Union[torch.Size, Tuple[int, ...]] = None, + g: nn.Module = None, + **kwargs): + # TODO add context super().__init__(event_shape) - self.g = SpectralNeuralNetwork(n_dim=self.n_dim, **kwargs) + if g is None: + g = SpectralNeuralNetwork(event_shape, **kwargs) + self.g = g def log_det(self, x: torch.Tensor, **kwargs): - return log_det_power_series(self.g, x, n_iterations=2, **kwargs)[1] + return log_det_power_series(self.event_shape, self.g, x, n_iterations=2, **kwargs)[1] -class ResFlowBlock(InvertibleResNetBlock): - def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], context_shape=None, p: float = 0.5, **kwargs): +class ConvolutionalInvertibleResNetBlock(InvertibleResNetBlock): + def __init__(self, + event_shape: Union[torch.Size, Tuple[int, ...]], + context_shape: Union[torch.Size, Tuple[int, ...]] = None, + **kwargs): # TODO add context + super().__init__(event_shape, g=SpectralCNN(n_channels=event_shape[0]), **kwargs) + + +class ResFlowBlock(IterativeResidualBijection): + def __init__(self, + event_shape: Union[torch.Size, Tuple[int, ...]], + context_shape: Union[torch.Size, Tuple[int, ...]] = None, + g: nn.Module = None, + p: float = 0.5, + **kwargs): + # TODO add context + super().__init__(event_shape) + if g is None: + g = SpectralNeuralNetwork(event_shape, **kwargs) + self.g = g self.p = p - super().__init__(event_shape, **kwargs) def log_det(self, x: torch.Tensor, **kwargs): - return log_det_roulette(self.g, x, p=self.p, **kwargs)[1] + return log_det_roulette(self.event_shape, self.g, x, p=self.p, **kwargs)[1] + + +class ConvolutionalResFlowBlock(ResFlowBlock): + def __init__(self, + event_shape: Union[torch.Size, Tuple[int, ...]], + context_shape: Union[torch.Size, Tuple[int, ...]] = None, + **kwargs): + # TODO add context + super().__init__(event_shape, g=SpectralCNN(n_channels=event_shape[0]), **kwargs) diff --git a/torchflows/bijections/finite/residual/log_abs_det_estimators.py b/torchflows/bijections/finite/residual/log_abs_det_estimators.py index b9d99d6..acd9340 100644 --- a/torchflows/bijections/finite/residual/log_abs_det_estimators.py +++ b/torchflows/bijections/finite/residual/log_abs_det_estimators.py @@ -1,7 +1,9 @@ +from typing import Tuple, Union + import torch import torch.nn as nn -from torchflows.utils import Geometric, vjp_tensor +from torchflows.utils import Geometric, flatten_batch, get_batch_shape def power_series_log_abs_det_estimator(g: callable, @@ -10,14 +12,13 @@ def power_series_log_abs_det_estimator(g: callable, training: bool, n_iterations: int = 8): # f(x) = x + g(x) - # x.shape == (batch_size, event_size) - # noise.shape == (batch_size, event_size, n_hutchinson_samples) - # g(x).shape == (batch_size, event_size) - - assert len(noise.shape) == 3 - batch_size, event_size, n_hutchinson_samples = noise.shape - assert len(x.shape) == 2 - assert x.shape == (batch_size, event_size) + # x.shape == (batch_size, *event_shape) + # noise.shape == (batch_size, *event_shape, n_hutchinson_samples) + # g(x).shape == (batch_size, *event_shape) + + batch_size, *event_shape, n_hutchinson_samples = noise.shape + event_dims = tuple(range(1, len(x.shape))) + assert x.shape == (batch_size, *event_shape) assert n_iterations >= 2 w = noise # (batch_size, event_size, n_hutchinson_samples) @@ -27,18 +28,21 @@ def 
power_series_log_abs_det_estimator(g: callable, # Compute VJP, reshape appropriately for hutchinson averaging gs_r, ws_r = torch.autograd.functional.vjp( g, - x[..., None].repeat(1, 1, n_hutchinson_samples).view(batch_size * n_hutchinson_samples, event_size), - w.view(batch_size * n_hutchinson_samples, event_size), + x[..., None].repeat(*([1] * (len(event_shape) + 1)), n_hutchinson_samples).view( + batch_size * n_hutchinson_samples, + *event_shape + ), + w.view(batch_size * n_hutchinson_samples, *event_shape), create_graph=training ) if g_value is None: - g_value = gs_r.view(batch_size, event_size, n_hutchinson_samples)[..., 0] + g_value = gs_r.view(batch_size, *event_shape, n_hutchinson_samples)[..., 0] - w = ws_r.view(batch_size, event_size, n_hutchinson_samples) + w = ws_r.view(batch_size, *event_shape, n_hutchinson_samples) - # sum over event dim, average over hutchinson dim - log_abs_det_jac_f += (-1) ** (k + 1) / k * torch.sum(w * noise, dim=1).mean(dim=1) + # sum over event dims, average over hutchinson dim + log_abs_det_jac_f += (-1) ** (k + 1) / k * torch.sum(w * noise, dim=event_dims).mean(dim=1) assert log_abs_det_jac_f.shape == (batch_size,) return g_value, log_abs_det_jac_f @@ -58,6 +62,7 @@ def roulette_log_abs_det_estimator(g: callable, :return: """ # f(x) = x + g(x) + event_dims = tuple(range(1, len(x.shape))) w = noise neumann_vjp = noise dist = Geometric(probs=torch.tensor(p), minimum=1) @@ -71,7 +76,7 @@ def roulette_log_abs_det_estimator(g: callable, neumann_vjp = neumann_vjp + (-1) ** k / (k * p_k) * w g_value, vjp_jac = torch.autograd.functional.vjp(g, x, neumann_vjp, create_graph=training) # vjp_jac = torch.autograd.grad(g_value, x, neumann_vjp, create_graph=training)[0] - log_abs_det_jac_f = torch.sum(vjp_jac * noise, dim=-1) + log_abs_det_jac_f = torch.sum(vjp_jac * noise, dim=event_dims) return g_value, log_abs_det_jac_f @@ -127,7 +132,8 @@ def backward(ctx, grad_g, grad_logdetgrad): g_params = params_and_grad[:len(params_and_grad) // 2] grad_params = params_and_grad[len(params_and_grad) // 2:] - dg_x, *dg_params = torch.autograd.grad(g_value, [x] + g_params, grad_g, allow_unused=True, retain_graph=training) + dg_x, *dg_params = torch.autograd.grad(g_value, [x] + g_params, grad_g, allow_unused=True, + retain_graph=training) # Update based on gradient from log determinant. 
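As a standalone sanity check of the truncated power series that the estimator above generalizes (for f(x) = x + g(x) with Lip(g) < 1, log|det J_f| = sum over k >= 1 of (-1)^(k+1) tr(J_g^k) / k, with each trace estimated via Hutchinson probes), the following sketch uses a fixed contractive linear map as g so the exact value is available for comparison. It is illustrative only, not library code:

.. code-block:: python

    import torch

    torch.manual_seed(0)
    dim, n_probes, n_terms = 4, 1000, 8
    A = 0.1 * torch.randn(dim, dim)               # J_g = A, spectral norm well below 1
    exact = torch.slogdet(torch.eye(dim) + A)[1]  # log|det(I + J_g)|

    noise = torch.randn(dim, n_probes)            # Hutchinson probes
    w = noise.clone()
    estimate = torch.tensor(0.0)
    for k in range(1, n_terms + 1):
        w = A.T @ w                               # w <- J_g^T w (a vector-Jacobian product)
        # v^T J_g^k v, summed over the event dimension and averaged over probes
        estimate = estimate + (-1) ** (k + 1) / k * (w * noise).sum(dim=0).mean()

    print(float(exact), float(estimate))          # the two numbers should be close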
dL = grad_logdetgrad[0].detach() @@ -143,7 +149,13 @@ def backward(ctx, grad_g, grad_logdetgrad): return (None, None, grad_x, None, None) + grad_params -def log_det_roulette(g: nn.Module, x: torch.Tensor, training: bool = False, p: float = 0.5): +def log_det_roulette(event_shape: Union[torch.Size, Tuple[int, ...]], + g: nn.Module, + x: torch.Tensor, + training: bool = False, + p: float = 0.5): + batch_shape = get_batch_shape(x, event_shape) + x = flatten_batch(x, batch_shape) noise = torch.randn_like(x) return LogDeterminantEstimator.apply( lambda *args, **kwargs: roulette_log_abs_det_estimator(*args, **kwargs, p=p), @@ -155,8 +167,14 @@ def log_det_roulette(g: nn.Module, x: torch.Tensor, training: bool = False, p: f ) -def log_det_power_series(g: nn.Module, x: torch.Tensor, training: bool = False, n_iterations: int = 8, +def log_det_power_series(event_shape: Union[torch.Size, Tuple[int, ...]], + g: nn.Module, + x: torch.Tensor, + training: bool = False, + n_iterations: int = 8, n_hutchinson_samples: int = 1): + batch_shape = get_batch_shape(x, event_shape) + x = flatten_batch(x, batch_shape) noise = torch.randn(size=(*x.shape, n_hutchinson_samples)) return LogDeterminantEstimator.apply( lambda *args, **kwargs: power_series_log_abs_det_estimator(*args, **kwargs, n_iterations=n_iterations), diff --git a/torchflows/bijections/finite/residual/proximal.py b/torchflows/bijections/finite/residual/proximal.py index 7b71d44..5e5dda9 100644 --- a/torchflows/bijections/finite/residual/proximal.py +++ b/torchflows/bijections/finite/residual/proximal.py @@ -4,7 +4,7 @@ import torch import torch.nn as nn -from torchflows.bijections.finite.residual.base import ResidualBijection +from torchflows.bijections.finite.residual.base import IterativeResidualBijection from torchflows.bijections.finite.residual.log_abs_det_estimators import log_det_roulette @@ -93,11 +93,13 @@ def regularization(self): def forward(self, x): """ - x.shape = (batch_size, event_size) + x.shape = (batch_size, *event_shape) """ + x_flat = x.view(x.shape[0], -1) mat = self.stiefel_matrix - act = self.act(torch.nn.functional.linear(x, mat, self.b)) - return torch.einsum('...ij,...kj->...ki', mat.T, act) + act = self.act(torch.nn.functional.linear(x_flat, mat, self.b)) + out = torch.einsum('...ij,...kj->...ki', mat.T, act) + return out.view_as(x) class PNN(nn.Sequential): @@ -128,7 +130,9 @@ def __init__(self, pnn: PNN, gamma: float, max_gamma: float): self.phi = pnn def r(self, x): - return 1 / self.phi.t * (self.phi(x) - (1 - self.phi.t) * x) + x_flat = x.view(x.shape[0], -1) + out = 1 / self.phi.t * (self.phi(x_flat) - (1 - self.phi.t) * x_flat) + return out.view_as(x) def forward(self, x): const = self.gamma * self.phi.t / (1 + self.gamma - self.gamma * self.phi.t) @@ -141,12 +145,13 @@ def log_det_single_layer(self, x): mat = layer.stiefel_matrix b = layer.b - act_derivative = layer.act.derivative(torch.nn.functional.linear(x, mat, b)) + x_flat = x.view(x.shape[0], -1) + act_derivative = layer.act.derivative(torch.nn.functional.linear(x_flat, mat, b)) log_derivatives = torch.log1p(self.gamma * act_derivative) return torch.sum(log_derivatives, dim=-1) -class ProximalResFlowBlock(ResidualBijection): +class ProximalResFlowBlock(IterativeResidualBijection): def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], context_shape: Union[torch.Size, Tuple[int, ...]] = None, diff --git a/torchflows/bijections/finite/residual/quasi_autoregressive.py b/torchflows/bijections/finite/residual/quasi_autoregressive.py index 
ea95977..714a454 100644 --- a/torchflows/bijections/finite/residual/quasi_autoregressive.py +++ b/torchflows/bijections/finite/residual/quasi_autoregressive.py @@ -3,7 +3,7 @@ import torch import torch.nn as nn -from torchflows.bijections.finite.residual.base import ResidualBijection +from torchflows.bijections.finite.residual.base import IterativeResidualBijection class AffineQuasiMADELayerSet(nn.Module): @@ -77,7 +77,7 @@ def forward(self, x): return x, jac -class QuasiAutoregressiveFlowBlock(ResidualBijection): +class QuasiAutoregressiveFlowBlock(IterativeResidualBijection): def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], sigma: float = 0.7): super().__init__(event_shape) self.log_theta = nn.Parameter(torch.randn()) From b2c09c248dfbfa5d93dbb9f1f81ec83e9963b576 Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Mon, 2 Sep 2024 20:54:29 +0200 Subject: [PATCH 15/25] Add Glow architectures --- .../finite/multiscale/architectures.py | 51 +++++++++++++++++++ 1 file changed, 51 insertions(+) diff --git a/torchflows/bijections/finite/multiscale/architectures.py b/torchflows/bijections/finite/multiscale/architectures.py index 90191cc..f9d6d88 100644 --- a/torchflows/bijections/finite/multiscale/architectures.py +++ b/torchflows/bijections/finite/multiscale/architectures.py @@ -244,3 +244,54 @@ def __init__(self, event_shape: Union[int, torch.Size, Tuple[int, ...]], n_layer n_blocks=n_layers, **kwargs ) + + +class DenseSigmoidGlow(MultiscaleBijection): + def __init__(self, event_shape: Union[int, torch.Size, Tuple[int, ...]], n_layers: int = None, **kwargs): + if isinstance(event_shape, int): + event_shape = (event_shape,) + if n_layers is None: + n_layers = automatically_determine_n_layers(event_shape) + check_image_shape_for_multiscale_flow(event_shape, n_layers) + super().__init__( + event_shape=event_shape, + transformer_class=DenseSigmoid, + checkerboard_class=GlowCheckerboardCoupling, + channel_wise_class=GlowChannelWiseCoupling, + n_blocks=n_layers, + **kwargs + ) + + +class DeepSigmoidGlow(MultiscaleBijection): + def __init__(self, event_shape: Union[int, torch.Size, Tuple[int, ...]], n_layers: int = None, **kwargs): + if isinstance(event_shape, int): + event_shape = (event_shape,) + if n_layers is None: + n_layers = automatically_determine_n_layers(event_shape) + check_image_shape_for_multiscale_flow(event_shape, n_layers) + super().__init__( + event_shape=event_shape, + transformer_class=DeepSigmoid, + checkerboard_class=GlowCheckerboardCoupling, + channel_wise_class=GlowChannelWiseCoupling, + n_blocks=n_layers, + **kwargs + ) + + +class DeepDenseSigmoidGlow(MultiscaleBijection): + def __init__(self, event_shape: Union[int, torch.Size, Tuple[int, ...]], n_layers: int = None, **kwargs): + if isinstance(event_shape, int): + event_shape = (event_shape,) + if n_layers is None: + n_layers = automatically_determine_n_layers(event_shape) + check_image_shape_for_multiscale_flow(event_shape, n_layers) + super().__init__( + event_shape=event_shape, + transformer_class=DeepDenseSigmoid, + checkerboard_class=GlowCheckerboardCoupling, + channel_wise_class=GlowChannelWiseCoupling, + n_blocks=n_layers, + **kwargs + ) From 14007d7eaa3e0047a6e93d248177c8c1d827b2ce Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Mon, 2 Sep 2024 20:59:03 +0200 Subject: [PATCH 16/25] Add architectures to __init__ --- torchflows/architectures.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/torchflows/architectures.py b/torchflows/architectures.py index a737ed0..13171f0 
100644 --- a/torchflows/architectures.py +++ b/torchflows/architectures.py @@ -49,5 +49,8 @@ AffineGlow, RQSGlow, LRSGlow, - ShiftGlow + ShiftGlow, + DeepSigmoidGlow, + DeepDenseSigmoidGlow, + DenseSigmoidGlow ) From c5b681b2034439789b8e758dc441a0b724a2487d Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Tue, 3 Sep 2024 00:11:09 +0200 Subject: [PATCH 17/25] Remove unused imports --- torchflows/bijections/finite/residual/base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchflows/bijections/finite/residual/base.py b/torchflows/bijections/finite/residual/base.py index f9113de..d664b48 100644 --- a/torchflows/bijections/finite/residual/base.py +++ b/torchflows/bijections/finite/residual/base.py @@ -1,10 +1,10 @@ -from typing import Union, Tuple, List, Type +from typing import Union, Tuple, Type import torch from torchflows.bijections.finite.autoregressive.layers import ElementwiseAffine from torchflows.bijections.base import Bijection, BijectiveComposition -from torchflows.utils import get_batch_shape, unflatten_event, flatten_event, flatten_batch, unflatten_batch +from torchflows.utils import get_batch_shape, flatten_batch, unflatten_batch class ClassicResidualBijection(Bijection): From 283ab05a5ffe61d4b14674371ed6390caed560dd Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Tue, 3 Sep 2024 17:35:50 +0200 Subject: [PATCH 18/25] Add maximum batch size limit in megabytes for adaptive batch size; add divergence checks for SVI --- torchflows/flows.py | 79 +++++++++++++++++++++++++++------------------ 1 file changed, 48 insertions(+), 31 deletions(-) diff --git a/torchflows/flows.py b/torchflows/flows.py index 22bde3c..8bb31bb 100644 --- a/torchflows/flows.py +++ b/torchflows/flows.py @@ -79,7 +79,8 @@ def fit(self, context_val: torch.Tensor = None, keep_best_weights: bool = True, early_stopping: bool = False, - early_stopping_threshold: int = 50): + early_stopping_threshold: int = 50, + max_batch_size_mb: int = 2000): """Fit the normalizing flow to a dataset. Fitting the flow means finding the parameters of the bijection that maximize the probability of training data. @@ -100,6 +101,7 @@ def fit(self, :param keep_best_weights: if True and validation data is provided, keep the bijection weights with the highest probability of validation data. :param early_stopping: if True and validation data is provided, stop the training procedure early once validation loss stops improving for a specified number of consecutive epochs. :param early_stopping_threshold: if early_stopping is True, fitting stops after no improvement in validation loss for this many epochs. + :param int max_batch_size_mb: maximum batch size in megabytes. 
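With the cap introduced here, the adaptive schedule still starts small and doubles the batch size every few epochs, but the largest batch is limited by the requested size in megabytes. A minimal sketch of how the new argument might be used; the RealNVP preset and the toy dataset are placeholders:

.. code-block:: python

    import torch

    from torchflows import Flow
    from torchflows.architectures import RealNVP

    x_train = torch.randn(50_000, 10)
    flow = Flow(RealNVP((10,)))

    # Let the batch size grow adaptively, but cap a single batch at roughly 100 MB.
    flow.fit(x_train, n_epochs=100, batch_size='adaptive', max_batch_size_mb=100)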
""" if len(list(self.parameters())) == 0: # If the flow has no trainable parameters, do nothing @@ -114,6 +116,10 @@ def fit(self, elif isinstance(batch_size, str) and batch_size == "adaptive": min_batch_size = max(32, min(1024, len(x_train) // 100)) max_batch_size = min(4096, len(x_train) // 10) + + event_size_mb = self.event_size / 2 ** 20 + max_batch_size = max(1, min(max_batch_size, int(max_batch_size_mb / event_size_mb))) + batch_size_adaptation_interval = 10 # double the batch size every 10 epochs adaptive_batch_size = True batch_size = min_batch_size @@ -290,42 +296,53 @@ def variational_fit(self, print('Flow training diverged') print('Reverting to initial weights') break - - optimizer.zero_grad() - flow_x, flow_log_prob = self.sample(n_samples, return_log_prob=True) - target_log_prob_value = target_log_prob(flow_x) - loss = -torch.mean(target_log_prob_value + flow_log_prob) - loss += self.regularization() - epoch_diverged = False - if check_for_divergences: - if not torch.isfinite(loss): - epoch_diverged = True - if torch.max(torch.abs(flow_x)) > 1e8: - epoch_diverged = True - elif torch.max(torch.abs(flow_log_prob)) > 1e6: - epoch_diverged = True - elif torch.any(~torch.isfinite(flow_x)): - epoch_diverged = True - elif torch.any(~torch.isfinite(flow_log_prob)): - epoch_diverged = True - n_divergences += epoch_diverged + optimizer.zero_grad() - if not epoch_diverged: - loss.backward() - optimizer.step() - if loss < best_loss: - best_loss = loss - best_epoch = epoch - if keep_best_weights: - best_weights = deepcopy(self.state_dict()) - else: + try: + flow_x, flow_log_prob = self.sample(n_samples, return_log_prob=True) + target_log_prob_value = target_log_prob(flow_x) + loss = -torch.mean(target_log_prob_value + flow_log_prob) + loss += self.regularization() + + if check_for_divergences: + if not torch.isfinite(loss): + epoch_diverged = True + if torch.max(torch.abs(flow_x)) > 1e8: + epoch_diverged = True + elif torch.max(torch.abs(flow_log_prob)) > 1e6: + epoch_diverged = True + elif torch.any(~torch.isfinite(flow_x)): + epoch_diverged = True + elif torch.any(~torch.isfinite(flow_log_prob)): + epoch_diverged = True + + if not epoch_diverged: + loss.backward() + optimizer.step() + if loss < best_loss: + best_loss = loss + best_epoch = epoch + if keep_best_weights: + best_weights = deepcopy(self.state_dict()) + mean_flow_log_prob = flow_log_prob.mean() + mean_target_log_prob = target_log_prob_value.mean() + else: + loss = torch.nan + mean_flow_log_prob = torch.nan + mean_target_log_prob = torch.nan + except ValueError: + epoch_diverged = True loss = torch.nan + mean_flow_log_prob = torch.nan + mean_target_log_prob = torch.nan + + n_divergences += epoch_diverged pbar.set_postfix_str(f'Loss: {loss:.4f} [best: {best_loss:.4f} @ {best_epoch}], ' f'divergences: {n_divergences}, ' - f'flow log_prob: {flow_log_prob.mean():.2f}, ' - f'target log_prob: {target_log_prob_value.mean():.2f}') + f'flow log_prob: {mean_flow_log_prob:.2f}, ' + f'target log_prob: {mean_target_log_prob:.2f}') if epoch - best_epoch > early_stopping_threshold and early_stopping: break From d1f43a819e524f2362bcf7fec9c823053a5ecadb Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Tue, 3 Sep 2024 17:42:28 +0200 Subject: [PATCH 19/25] Set default max batch size in megabytes to None --- torchflows/flows.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/torchflows/flows.py b/torchflows/flows.py index 8bb31bb..79fc656 100644 --- a/torchflows/flows.py +++ b/torchflows/flows.py @@ -80,7 +80,7 @@ def 
fit(self, keep_best_weights: bool = True, early_stopping: bool = False, early_stopping_threshold: int = 50, - max_batch_size_mb: int = 2000): + max_batch_size_mb: int = None): """Fit the normalizing flow to a dataset. Fitting the flow means finding the parameters of the bijection that maximize the probability of training data. @@ -117,8 +117,9 @@ def fit(self, min_batch_size = max(32, min(1024, len(x_train) // 100)) max_batch_size = min(4096, len(x_train) // 10) - event_size_mb = self.event_size / 2 ** 20 - max_batch_size = max(1, min(max_batch_size, int(max_batch_size_mb / event_size_mb))) + if max_batch_size_mb is not None: + event_size_mb = self.event_size / 2 ** 20 + max_batch_size = max(1, min(max_batch_size, int(max_batch_size_mb / event_size_mb))) batch_size_adaptation_interval = 10 # double the batch size every 10 epochs adaptive_batch_size = True From 6232722ea476c40daac852c864a102be68369ca9 Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Tue, 3 Sep 2024 22:28:04 +0200 Subject: [PATCH 20/25] Add maximum training time option for NF training --- torchflows/flows.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/torchflows/flows.py b/torchflows/flows.py index 79fc656..0c0f5a3 100644 --- a/torchflows/flows.py +++ b/torchflows/flows.py @@ -1,3 +1,4 @@ +import time from copy import deepcopy from typing import Union, Tuple, List @@ -80,7 +81,8 @@ def fit(self, keep_best_weights: bool = True, early_stopping: bool = False, early_stopping_threshold: int = 50, - max_batch_size_mb: int = None): + max_batch_size_mb: int = None, + time_limit_seconds: Union[float, int] = None): """Fit the normalizing flow to a dataset. Fitting the flow means finding the parameters of the bijection that maximize the probability of training data. @@ -102,7 +104,10 @@ def fit(self, :param early_stopping: if True and validation data is provided, stop the training procedure early once validation loss stops improving for a specified number of consecutive epochs. :param early_stopping_threshold: if early_stopping is True, fitting stops after no improvement in validation loss for this many epochs. :param int max_batch_size_mb: maximum batch size in megabytes. + :param Union[float, int] time_limit_seconds: maximum allowed time for training. """ + t0 = time.time() + if len(list(self.parameters())) == 0: # If the flow has no trainable parameters, do nothing return @@ -167,6 +172,10 @@ def compute_batch_loss(batch_, reduction: callable = torch.mean): val_loss = None for epoch in (pbar := tqdm(range(n_epochs), desc='Fitting NF', disable=not show_progress)): + if time_limit_seconds is not None and time.time() - t0 >= time_limit_seconds: + print("Training time limit exceeded") + break + if ( adaptive_batch_size and epoch % batch_size_adaptation_interval == batch_size_adaptation_interval - 1 @@ -263,7 +272,8 @@ def variational_fit(self, early_stopping_threshold: int = 50, keep_best_weights: bool = True, show_progress: bool = False, - check_for_divergences: bool = False): + check_for_divergences: bool = False, + time_limit_seconds:Union[float, int] = None): """Train the normalizing flow to fit a target log probability. Stochastic variational inference lets us train a distribution using the unnormalized target log density instead of a fixed dataset. @@ -277,6 +287,8 @@ def variational_fit(self, :param float n_samples: number of samples to estimate the variational loss in each training step. :param bool show_progress: if True, show a progress bar during training. 
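A sketch of how the new time limit might be used with stochastic variational inference; the unnormalized Gaussian target and passing the target log density as the first positional argument are assumptions for illustration, based on how the function body calls target_log_prob:

.. code-block:: python

    import torch

    from torchflows import Flow
    from torchflows.architectures import RealNVP

    def target_log_prob(x):
        # Unnormalized standard Gaussian log-density
        return -0.5 * (x ** 2).sum(dim=-1)

    flow = Flow(RealNVP((10,)))

    # Stop training after at most 60 seconds, even if more epochs remain.
    flow.variational_fit(target_log_prob, show_progress=True, time_limit_seconds=60)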
""" + t0 = time.time() + if len(list(self.parameters())) == 0: # If the flow has no trainable parameters, do nothing return @@ -292,6 +304,9 @@ def variational_fit(self, n_divergences = 0 for epoch in (pbar := tqdm(range(n_epochs), desc='Fitting with SVI', disable=not show_progress)): + if time_limit_seconds is not None and time.time() - t0 >= time_limit_seconds: + print("Training time limit exceeded") + break if check_for_divergences and not all([torch.isfinite(p).all() for p in self.parameters()]): flow_training_diverged = True print('Flow training diverged') From 9094cdbe6cc4d6eee83ef094efda513043f439aa Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Wed, 4 Sep 2024 22:14:42 +0200 Subject: [PATCH 21/25] Fix residual flow fits --- test/test_fit_conv_residual_flow.py | 13 ++++++ torchflows/bijections/finite/residual/base.py | 4 +- .../bijections/finite/residual/iterative.py | 9 ++-- torchflows/flows.py | 43 +++++++++---------- 4 files changed, 40 insertions(+), 29 deletions(-) create mode 100644 test/test_fit_conv_residual_flow.py diff --git a/test/test_fit_conv_residual_flow.py b/test/test_fit_conv_residual_flow.py new file mode 100644 index 0000000..857700c --- /dev/null +++ b/test/test_fit_conv_residual_flow.py @@ -0,0 +1,13 @@ +import pytest +import torch + +from torchflows import Flow +from torchflows.architectures import ConvolutionalResFlow, ConvolutionalInvertibleResNet + + +@pytest.mark.parametrize('arch_cls', [ConvolutionalResFlow, ConvolutionalInvertibleResNet]) +def test_basic(arch_cls): + torch.manual_seed(0) + event_shape = (3, 20, 20) + flow = Flow(arch_cls(event_shape)) + flow.fit(torch.randn(size=(5, *event_shape)), n_epochs=20) diff --git a/torchflows/bijections/finite/residual/base.py b/torchflows/bijections/finite/residual/base.py index d664b48..d745be9 100644 --- a/torchflows/bijections/finite/residual/base.py +++ b/torchflows/bijections/finite/residual/base.py @@ -45,7 +45,7 @@ def forward(self, else: x_flat = flatten_batch(x.clone(), batch_shape) x_flat.requires_grad_(True) - log_det = -unflatten_batch(self.log_det(x_flat, training=self.training), batch_shape) + log_det = unflatten_batch(self.log_det(x_flat, training=self.training), batch_shape) return z, log_det @@ -68,7 +68,7 @@ def inverse(self, else: x_flat = flatten_batch(x.clone(), batch_shape) x_flat.requires_grad_(True) - log_det = -unflatten_batch(self.log_det(x_flat, training=self.training), batch_shape) + log_det = unflatten_batch(self.log_det(x_flat, training=self.training), batch_shape) return x, log_det diff --git a/torchflows/bijections/finite/residual/iterative.py b/torchflows/bijections/finite/residual/iterative.py index ef8e3fa..30843d4 100644 --- a/torchflows/bijections/finite/residual/iterative.py +++ b/torchflows/bijections/finite/residual/iterative.py @@ -12,7 +12,7 @@ class SpectralMatrix(nn.Module): def __init__(self, shape: Tuple[int, int], c: float = 0.7, n_iterations: int = 5): super().__init__() - self.data = torch.randn(size=shape) + self.data = nn.Parameter(torch.randn(size=shape)) self.c = c self.n_iterations = n_iterations @@ -22,7 +22,7 @@ def power_iteration(self, w): # Spectral Normalization for Generative Adversarial Networks - Miyato et al. 
- 2018 # Get maximum singular value of rectangular matrix w - u = torch.randn(self.data.shape[1], 1) + u = torch.randn(self.data.shape[1], 1).to(w) v = None w = w.T @@ -39,9 +39,8 @@ def power_iteration(self, w): def normalized(self): # Estimate sigma - sigma = self.power_iteration(self.data) - # ratio = self.c / sigma - # return self.w * (ratio ** (ratio < 1)) + with torch.no_grad(): + sigma = self.power_iteration(self.data) return self.data / sigma diff --git a/torchflows/flows.py b/torchflows/flows.py index 0c0f5a3..4713509 100644 --- a/torchflows/flows.py +++ b/torchflows/flows.py @@ -235,28 +235,27 @@ def compute_batch_loss(batch_, reduction: callable = torch.mean): # Compute validation loss at the end of each epoch # Validation loss will be displayed at the start of the next epoch if x_val is not None: - with torch.no_grad(): - # Compute validation loss - val_loss = 0.0 - for val_batch in val_loader: - val_loss += compute_batch_loss(val_batch, reduction=torch.sum) - val_loss /= len(x_val) - val_loss += self.regularization() - - # Check if validation loss is the lowest so far - if val_loss < best_val_loss: - best_val_loss = val_loss - best_epoch = epoch - - # Store current weights - if keep_best_weights: - if best_epoch == epoch: - best_weights = deepcopy(self.state_dict()) - - # Optionally stop training early - if early_stopping: - if epoch - best_epoch > early_stopping_threshold: - break + # Compute validation loss + val_loss = 0.0 + for val_batch in val_loader: + val_loss += compute_batch_loss(val_batch, reduction=torch.sum).detach() + val_loss /= len(x_val) + val_loss += self.regularization() + + # Check if validation loss is the lowest so far + if val_loss < best_val_loss: + best_val_loss = val_loss + best_epoch = epoch + + # Store current weights + if keep_best_weights: + if best_epoch == epoch: + best_weights = deepcopy(self.state_dict()) + + # Optionally stop training early + if early_stopping: + if epoch - best_epoch > early_stopping_threshold: + break if x_val is not None and keep_best_weights: self.load_state_dict(best_weights) From fc5456baf08babf41e00fe34783da1111db40d8b Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Thu, 12 Sep 2024 21:45:17 +0200 Subject: [PATCH 22/25] Minor changes to continuous and residual NF parameter specification --- torchflows/bijections/continuous/ddnf.py | 21 +++++++++++++++---- torchflows/bijections/continuous/ffjord.py | 10 +++++---- torchflows/bijections/continuous/otflow.py | 15 +++++++++---- torchflows/bijections/continuous/rnode.py | 21 ++++++++++++++----- .../bijections/finite/residual/iterative.py | 14 ++++++------- 5 files changed, 57 insertions(+), 24 deletions(-) diff --git a/torchflows/bijections/continuous/ddnf.py b/torchflows/bijections/continuous/ddnf.py index 3fe235b..d5ac497 100644 --- a/torchflows/bijections/continuous/ddnf.py +++ b/torchflows/bijections/continuous/ddnf.py @@ -21,26 +21,39 @@ class DeepDiffeomorphicBijection(ApproximateContinuousBijection): Reference: Salman et al. "Deep diffeomorphic normalizing flows" (2018); https://arxiv.org/abs/1810.03256. """ - def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], n_steps: int = 150, solver="euler", **kwargs): + def __init__(self, + event_shape: Union[torch.Size, Tuple[int, ...]], + n_steps: int = 150, + solver="euler", + nn_kwargs: dict = None, + **kwargs): """ Constructor. :param event_shape: shape of the event tensor. :param n_steps: parameter T in the paper, i.e. the number of ResNet cells. 
""" - diff_eq = RegularizedApproximateODEFunction(create_nn_time_independent(event_shape)) + nn_kwargs = nn_kwargs or {} + diff_eq = RegularizedApproximateODEFunction(create_nn_time_independent(event_shape, **nn_kwargs)) self.n_steps = n_steps super().__init__(event_shape, diff_eq, solver=solver, **kwargs) + class ConvolutionalDeepDiffeomorphicBijection(ApproximateContinuousBijection): """Convolutional variant of the DDNF architecture. Reference: Salman et al. "Deep diffeomorphic normalizing flows" (2018); https://arxiv.org/abs/1810.03256. """ - def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], n_steps: int = 150, solver="euler", **kwargs): + def __init__(self, + event_shape: Union[torch.Size, Tuple[int, ...]], + n_steps: int = 150, + solver="euler", + nn_kwargs: dict = None, + **kwargs): + nn_kwargs = nn_kwargs or {} if len(event_shape) != 3: raise ValueError("Event shape must be of length 3 (channels, height, width).") - diff_eq = RegularizedApproximateODEFunction(create_cnn_time_independent(event_shape[0])) + diff_eq = RegularizedApproximateODEFunction(create_cnn_time_independent(event_shape[0], **nn_kwargs)) self.n_steps = n_steps super().__init__(event_shape, diff_eq, solver=solver, **kwargs) diff --git a/torchflows/bijections/continuous/ffjord.py b/torchflows/bijections/continuous/ffjord.py index a3aed56..3acd04c 100644 --- a/torchflows/bijections/continuous/ffjord.py +++ b/torchflows/bijections/continuous/ffjord.py @@ -18,8 +18,9 @@ class FFJORD(ApproximateContinuousBijection): Gratwohl et al. "FFJORD: Free-form Continuous Dynamics for Scalable Reversible Generative Models" (2018); https://arxiv.org/abs/1810.01367. """ - def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], **kwargs): - diff_eq = RegularizedApproximateODEFunction(create_nn(event_shape)) + def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], nn_kwargs: dict = None, **kwargs): + nn_kwargs = nn_kwargs or {} + diff_eq = RegularizedApproximateODEFunction(create_nn(event_shape, **nn_kwargs)) super().__init__(event_shape, diff_eq, **kwargs) @@ -29,8 +30,9 @@ class ConvolutionalFFJORD(ApproximateContinuousBijection): Gratwohl et al. "FFJORD: Free-form Continuous Dynamics for Scalable Reversible Generative Models" (2018); https://arxiv.org/abs/1810.01367. 
""" - def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], **kwargs): + def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], nn_kwargs: dict = None, **kwargs): + nn_kwargs = nn_kwargs or {} if len(event_shape) != 3: raise ValueError("Event shape must be of length 3 (channels, height, width).") - diff_eq = RegularizedApproximateODEFunction(create_cnn(event_shape[0])) + diff_eq = RegularizedApproximateODEFunction(create_cnn(event_shape[0], **nn_kwargs)) super().__init__(event_shape, diff_eq, **kwargs) diff --git a/torchflows/bijections/continuous/otflow.py b/torchflows/bijections/continuous/otflow.py index fed00f5..d9817bf 100644 --- a/torchflows/bijections/continuous/otflow.py +++ b/torchflows/bijections/continuous/otflow.py @@ -141,8 +141,9 @@ def hessian_trace(self, class OTPotential(TimeDerivative): - def __init__(self, event_size: int, hidden_size: int = None, **kwargs): + def __init__(self, event_size: int, hidden_size: int = 50, resnet_kwargs: dict = None): super().__init__() + resnet_kwargs = resnet_kwargs or {} # hidden_size = m if hidden_size is None: @@ -163,7 +164,7 @@ def __init__(self, event_size: int, hidden_size: int = None, **kwargs): self.w = nn.Parameter(1 + delta_w) self.A = nn.Parameter(torch.eye(r, event_size + 1) + delta_A) self.b = nn.Parameter(0 + delta_b) - self.resnet = OTResNet(event_size + 1, hidden_size, **kwargs) # (x, t) has d+1 elements + self.resnet = OTResNet(event_size + 1, hidden_size, **resnet_kwargs) # (x, t) has d+1 elements def forward(self, t, x): return self.gradient(concatenate_x_t(x, t)) @@ -208,7 +209,13 @@ class OTFlow(ExactContinuousBijection): Reference: Onken et al. "OT-Flow: Fast and Accurate Continuous Normalizing Flows via Optimal Transport" (2021); https://arxiv.org/abs/2006.00104. """ - def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], solver='dopri8', **kwargs): + def __init__(self, + event_shape: Union[torch.Size, Tuple[int, ...]], + ot_flow_kwargs: dict = None, + solver='dopri8', + **kwargs): + ot_flow_kwargs = ot_flow_kwargs or {} + n_dim = int(torch.prod(torch.as_tensor(event_shape))) - diff_eq = OTFlowODEFunction(n_dim, hidden_size=50) + diff_eq = OTFlowODEFunction(n_dim, **ot_flow_kwargs) super().__init__(event_shape, diff_eq, solver=solver, **kwargs) diff --git a/torchflows/bijections/continuous/rnode.py b/torchflows/bijections/continuous/rnode.py index f1213ac..0612a22 100644 --- a/torchflows/bijections/continuous/rnode.py +++ b/torchflows/bijections/continuous/rnode.py @@ -18,9 +18,14 @@ class RNODE(ApproximateContinuousBijection): Reference: Finlay et al. "How to train your neural ODE: the world of Jacobian and kinetic regularization" (2020); https://arxiv.org/abs/2002.02798. """ - def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], **kwargs): - diff_eq = RegularizedApproximateODEFunction(create_nn(event_shape, hidden_size=100, n_hidden_layers=1), - regularization="sq_jac_norm") + def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], nn_kwargs: dict = None, **kwargs): + default_nn_kwargs = {'hidden_size': 100, 'n_hidden_layers': 1} + nn_kwargs = nn_kwargs or dict() + default_nn_kwargs.update(nn_kwargs) + diff_eq = RegularizedApproximateODEFunction( + create_nn(event_shape, **default_nn_kwargs), + regularization="sq_jac_norm" + ) super().__init__(event_shape, diff_eq, **kwargs) @@ -30,8 +35,14 @@ class ConvolutionalRNODE(ApproximateContinuousBijection): Reference: Finlay et al. 
"How to train your neural ODE: the world of Jacobian and kinetic regularization" (2020); https://arxiv.org/abs/2002.02798. """ - def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], **kwargs): + def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], nn_kwargs: dict = None, **kwargs): + default_nn_kwargs = {'n_layers': 2} + nn_kwargs = nn_kwargs or dict() + default_nn_kwargs.update(nn_kwargs) if len(event_shape) != 3: raise ValueError("Event shape must be of length 3 (channels, height, width).") - diff_eq = RegularizedApproximateODEFunction(create_cnn(event_shape[0]), regularization="sq_jac_norm") + diff_eq = RegularizedApproximateODEFunction( + create_cnn(event_shape[0], **default_nn_kwargs), + regularization="sq_jac_norm" + ) super().__init__(event_shape, diff_eq, **kwargs) diff --git a/torchflows/bijections/finite/residual/iterative.py b/torchflows/bijections/finite/residual/iterative.py index 30843d4..c82c762 100644 --- a/torchflows/bijections/finite/residual/iterative.py +++ b/torchflows/bijections/finite/residual/iterative.py @@ -72,23 +72,23 @@ def forward(self, x): class SpectralNeuralNetwork(nn.Module): def __init__(self, event_shape: Union[Tuple[int, ...], torch.Size], - n_hidden: int = None, + hidden_size: int = None, n_hidden_layers: int = 1, **kwargs): self.event_shape = event_shape event_size = int(torch.prod(torch.as_tensor(event_shape))) - if n_hidden is None: - n_hidden = int(3 * max(math.log(event_size), 4)) + if hidden_size is None: + hidden_size = int(3 * max(math.log(event_size), 4)) if n_hidden_layers == 0: layers = [SpectralLinear(event_size, event_size, **kwargs)] else: - layers = [SpectralLinear(event_size, n_hidden, **kwargs)] - for _ in range(n_hidden): + layers = [SpectralLinear(event_size, hidden_size, **kwargs)] + for _ in range(hidden_size): layers.append(nn.Tanh()) - layers.append(SpectralLinear(n_hidden, n_hidden, **kwargs)) + layers.append(SpectralLinear(hidden_size, hidden_size, **kwargs)) layers.pop(-1) - layers.append(SpectralLinear(n_hidden, event_size, **kwargs)) + layers.append(SpectralLinear(hidden_size, event_size, **kwargs)) super().__init__() self.layers = nn.ModuleList(layers) From 910adaf579bbf21264c3348e6ec08600164356f2 Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Sat, 19 Oct 2024 16:41:44 +0200 Subject: [PATCH 23/25] Update docs --- docs/source/api/base_distributions.rst | 14 -- docs/source/api/bijections.rst | 24 --- docs/source/api/components.rst | 7 - docs/source/api/multiscale_architectures.rst | 11 -- .../general_modeling.rst} | 54 ++++- docs/source/architectures/image_modeling.rst | 56 ++++++ docs/source/architectures/index.rst | 187 ++++++++++++++++++ .../base_distributions.rst | 61 ++++++ .../bijections/autoregressive/bijections.rst | 103 ++++++++++ .../autoregressive/conditioner_transforms.rst | 21 ++ .../autoregressive/conditioners.rst | 8 + .../autoregressive/transformers.rst | 30 +++ .../bijections/continuous/bijections.rst | 5 + .../developer_reference/bijections/index.rst | 35 ++++ .../bijections/residual/bijections.rst | 2 + .../{api => developer_reference}/flow.rst | 2 +- docs/source/developer_reference/index.rst | 25 +++ docs/source/guides/image_modeling.rst | 26 ++- .../source/guides/mathematical_background.rst | 6 +- .../source/guides/modifying_architectures.rst | 154 +++++++++++++++ docs/source/guides/tutorial.rst | 1 + docs/source/index.rst | 19 +- 22 files changed, 776 insertions(+), 75 deletions(-) delete mode 100644 docs/source/api/base_distributions.rst delete mode 100644 
docs/source/api/bijections.rst delete mode 100644 docs/source/api/components.rst delete mode 100644 docs/source/api/multiscale_architectures.rst rename docs/source/{api/architectures.rst => architectures/general_modeling.rst} (63%) create mode 100644 docs/source/architectures/image_modeling.rst create mode 100644 docs/source/architectures/index.rst create mode 100644 docs/source/developer_reference/base_distributions.rst create mode 100644 docs/source/developer_reference/bijections/autoregressive/bijections.rst create mode 100644 docs/source/developer_reference/bijections/autoregressive/conditioner_transforms.rst create mode 100644 docs/source/developer_reference/bijections/autoregressive/conditioners.rst create mode 100644 docs/source/developer_reference/bijections/autoregressive/transformers.rst create mode 100644 docs/source/developer_reference/bijections/continuous/bijections.rst create mode 100644 docs/source/developer_reference/bijections/index.rst create mode 100644 docs/source/developer_reference/bijections/residual/bijections.rst rename docs/source/{api => developer_reference}/flow.rst (96%) create mode 100644 docs/source/developer_reference/index.rst create mode 100644 docs/source/guides/modifying_architectures.rst diff --git a/docs/source/api/base_distributions.rst b/docs/source/api/base_distributions.rst deleted file mode 100644 index 174ba1e..0000000 --- a/docs/source/api/base_distributions.rst +++ /dev/null @@ -1,14 +0,0 @@ -Base distribution objects -========================== - -.. autoclass:: torchflows.base_distributions.gaussian.DiagonalGaussian - :members: __init__ - -.. autoclass:: torchflows.base_distributions.gaussian.DenseGaussian - :members: __init__ - -.. autoclass:: torchflows.base_distributions.mixture.DiagonalGaussianMixture - :members: __init__ - -.. autoclass:: torchflows.base_distributions.mixture.DenseGaussianMixture - :members: __init__ diff --git a/docs/source/api/bijections.rst b/docs/source/api/bijections.rst deleted file mode 100644 index 2be7d28..0000000 --- a/docs/source/api/bijections.rst +++ /dev/null @@ -1,24 +0,0 @@ -Bijection objects -==================== - -All normalizing flow transformations are bijections. -The following classes define forward and inverse pass methods which all flow architectures inherit. - -.. autoclass:: torchflows.bijections.base.Bijection - :members: __init__, forward, inverse - -.. autoclass:: torchflows.bijections.base.BijectiveComposition - :members: __init__ - -.. autoclass:: torchflows.bijections.continuous.base.ContinuousBijection - :members: __init__, forward, inverse - -.. autoclass:: torchflows.bijections.finite.multiscale.base.MultiscaleBijection - :members: __init__ - -Inverting a bijection ---------------------- - -Each bijection can be inverted with the `invert` function. - -.. autofunction:: torchflows.bijections.base.invert \ No newline at end of file diff --git a/docs/source/api/components.rst b/docs/source/api/components.rst deleted file mode 100644 index a84b33e..0000000 --- a/docs/source/api/components.rst +++ /dev/null @@ -1,7 +0,0 @@ -Model components -=================== - -.. toctree:: - base_distributions - bijections - flow diff --git a/docs/source/api/multiscale_architectures.rst b/docs/source/api/multiscale_architectures.rst deleted file mode 100644 index b74d185..0000000 --- a/docs/source/api/multiscale_architectures.rst +++ /dev/null @@ -1,11 +0,0 @@ -Multiscale architectures -======================================================== - -Multiscale architectures are suitable for image modeling. 
- -.. _multiscale_architectures: - -.. autoclass:: torchflows.architectures.MultiscaleRealNVP -.. autoclass:: torchflows.architectures.MultiscaleRQNSF -.. autoclass:: torchflows.architectures.MultiscaleLRSNSF -.. autoclass:: torchflows.architectures.MultiscaleNICE diff --git a/docs/source/api/architectures.rst b/docs/source/architectures/general_modeling.rst similarity index 63% rename from docs/source/api/architectures.rst rename to docs/source/architectures/general_modeling.rst index 73b9613..b1274f3 100644 --- a/docs/source/api/architectures.rst +++ b/docs/source/architectures/general_modeling.rst @@ -1,38 +1,86 @@ -Standard architectures +API for standard architectures ============================ We lists notable implemented bijection architectures. These all inherit from the Bijection class. -.. _architectures: +.. _autoregressive_architecture_api: Autoregressive architectures -------------------------------- .. autoclass:: torchflows.architectures.RealNVP + :members: __init__ + .. autoclass:: torchflows.architectures.InverseRealNVP + :members: __init__ + .. autoclass:: torchflows.architectures.NICE + :members: __init__ + .. autoclass:: torchflows.architectures.MAF + :members: __init__ + .. autoclass:: torchflows.architectures.IAF + :members: __init__ + .. autoclass:: torchflows.architectures.CouplingRQNSF + :members: __init__ + .. autoclass:: torchflows.architectures.MaskedAutoregressiveRQNSF + :members: __init__ + .. autoclass:: torchflows.architectures.InverseAutoregressiveRQNSF + :members: __init__ + .. autoclass:: torchflows.architectures.CouplingLRS + :members: __init__ + .. autoclass:: torchflows.architectures.MaskedAutoregressiveLRS + :members: __init__ + +.. autoclass:: torchflows.architectures.InverseAutoregressiveLRS + :members: __init__ + .. autoclass:: torchflows.architectures.CouplingDSF + :members: __init__ + .. autoclass:: torchflows.architectures.UMNNMAF + :members: __init__ + +.. _continuous_architecture_api: Continuous architectures ------------------------- .. autoclass:: torchflows.architectures.DeepDiffeomorphicBijection + :members: __init__ + .. autoclass:: torchflows.architectures.RNODE + :members: __init__ + .. autoclass:: torchflows.architectures.FFJORD + :members: __init__ + .. autoclass:: torchflows.architectures.OTFlow + :members: __init__ + +.. _residual_architecture_api: Residual architectures ----------------------- .. autoclass:: torchflows.architectures.ResFlow + :members: __init__ + .. autoclass:: torchflows.architectures.ProximalResFlow + :members: __init__ + .. autoclass:: torchflows.architectures.InvertibleResNet + :members: __init__ + .. autoclass:: torchflows.architectures.PlanarFlow + :members: __init__ + .. autoclass:: torchflows.architectures.RadialFlow -.. autoclass:: torchflows.architectures.SylvesterFlow \ No newline at end of file + :members: __init__ + +.. autoclass:: torchflows.architectures.SylvesterFlow + :members: __init__ diff --git a/docs/source/architectures/image_modeling.rst b/docs/source/architectures/image_modeling.rst new file mode 100644 index 0000000..8609a27 --- /dev/null +++ b/docs/source/architectures/image_modeling.rst @@ -0,0 +1,56 @@ +API for multiscale architectures +======================================================== + +Multiscale architectures are suitable for image modeling. + +.. _multiscale_architecture_api: + + +Classic multiscale architectures +------------------------------ + +.. autoclass:: torchflows.architectures.MultiscaleNICE + :members: __init__ + +.. 
autoclass:: torchflows.architectures.MultiscaleRealNVP + :members: __init__ + +.. autoclass:: torchflows.architectures.MultiscaleRQNSF + :members: __init__ + +.. autoclass:: torchflows.architectures.MultiscaleLRSNSF + :members: __init__ + +.. autoclass:: torchflows.bijections.finite.multiscale.architectures.MultiscaleDeepSigmoid + :members: __init__ + +.. autoclass:: torchflows.bijections.finite.multiscale.architectures.MultiscaleDenseSigmoid + :members: __init__ + +.. autoclass:: torchflows.bijections.finite.multiscale.architectures.MultiscaleDeepDenseSigmoid + :members: __init__ + + +Glow-style multiscale architectures +------------------------------ + +.. autoclass:: torchflows.architectures.AffineGlow + :members: __init__ + +.. autoclass:: torchflows.architectures.ShiftGlow + :members: __init__ + +.. autoclass:: torchflows.bijections.finite.multiscale.architectures.RQSGlow + :members: __init__ + +.. autoclass:: torchflows.bijections.finite.multiscale.architectures.LRSGlow + :members: __init__ + +.. autoclass:: torchflows.bijections.finite.multiscale.architectures.DeepSigmoidGlow + :members: __init__ + +.. autoclass:: torchflows.bijections.finite.multiscale.architectures.DenseSigmoidGlow + :members: __init__ + +.. autoclass:: torchflows.bijections.finite.multiscale.architectures.DeepDenseSigmoidGlow + :members: __init__ diff --git a/docs/source/architectures/index.rst b/docs/source/architectures/index.rst new file mode 100644 index 0000000..61a09ee --- /dev/null +++ b/docs/source/architectures/index.rst @@ -0,0 +1,187 @@ +Full list of architectures (presets) +===================================================== + +We list all implemented NF architectures and their respective class names below. +Using these presets facilitates experimentation and modeling, however you can also modify each architecture and build new ones. + +.. _autoregressive_architecture_list: + +Autoregressive architectures +----------------------------- + +We provide the list of autoregressive architectures in the table below. +Click the architecture name to see the API and usage examples. +Check the API for all autoregressive architectures :ref:`here `. + +.. list-table:: + :header-rows: 1 + + * - Architecture + - Reference + * - :class:`NICE ` + - Dinh et al. `NICE: Non-linear Independent Components Estimation `_ (2015) + * - :class:`RealNVP ` + - Dinh et al. `Density estimation using Real NVP `_ (2017) + * - :class:`Inverse RealNVP ` + - Dinh et al. `Density estimation using Real NVP `_ (2017) + * - :class:`MAF ` + - Papamakarios et al. `Masked Autoregressive Flow for Density Estimation `_ (2018) + * - :class:`IAF ` + - Kingma et al. `Improving Variational Inference with Inverse Autoregressive Flow `_ (2017) + * - :class:`Coupling RQ-NSF ` + - Durkan et al. `Neural Spline Flows `_ (2019) + * - :class:`Masked autoregressive RQ-NSF ` + - Durkan et al. `Neural Spline Flows `_ (2019) + * - :class:`Inverse autoregressive RQ-NSF ` + - Durkan et al. `Neural Spline Flows `_ (2019) + * - :class:`Coupling LR-NSF ` + - Dolatabadi et al. `Invertible Generative Modeling using Linear Rational Splines `_ (2020) + * - :class:`Masked autoregressive LR-NSF ` + - Dolatabadi et al. `Invertible Generative Modeling using Linear Rational Splines `_ (2020) + * - :class:`Inverse autoregressive LR-NSF ` + - Dolatabadi et al. 
`Invertible Generative Modeling using Linear Rational Splines `_ (2020) + * - :class:`Coupling deep SF ` + - + * - :class:`Masked autoregressive deep SF ` + - + * - :class:`Inverse autoregressive deep SF ` + - + * - :class:`Coupling dense SF ` + - + * - :class:`Masked autoregressive dense SF ` + - + * - :class:`Inverse autoregressive dense SF ` + - + * - :class:`Coupling deep-dense SF ` + - + * - :class:`Masked autoregressive deep-dense SF ` + - + * - :class:`Inverse autoregressive deep-dense SF ` + - + * - :class:`Unconstrained monotonic neural network ` + - + +.. _multiscale_architecture_list: + +Multiscale architectures +----------------------------------------- +We provide the list of multiscale autoregressive architectures in the table below. +These architectures are specifically made for image modeling, but can also be used for voxels or tensors with more dimensions. +Click the architecture name to see the API and usage examples. +Check the API for all multiscale architectures :ref:`here `. + +.. list-table:: + :header-rows: 1 + + * - Architecture + - Reference + * - :class:`MultiscaleNICE ` + - Dinh et al. `NICE: Non-linear Independent Components Estimation `_ (2015) + * - :class:`Multiscale RealNVP ` + - Dinh et al. `Density estimation using Real NVP `_ (2017) + * - :class:`Multiscale RQ-NSF ` + - Durkan et al. `Neural Spline Flows `_ (2019) + * - :class:`Multiscale LR-NSF ` + - Dolatabadi et al. `Invertible Generative Modeling using Linear Rational Splines `_ (2020) + * - :class:`Multiscale deep SF ` + - + * - :class:`Multiscale dense SF ` + - + * - :class:`Multiscale deep-dense SF ` + - + * - :class:`Shift Glow ` + - + * - :class:`Affine Glow ` + - + * - :class:`RQS Glow ` + - + * - :class:`LRS Glow ` + - + * - :class:`Deep sigmoidal Glow ` + - + * - :class:`Dense sigmoidal Glow ` + - + * - :class:`Deep-dense sigmoidal Glow ` + - + +Residual architectures +---------------------------- +We provide the list of iterative residual architectures in the table below. +Click the architecture name to see the API and usage examples. +Check the API for all residual architectures :ref:`here `. + +.. list-table:: + :header-rows: 1 + + * - Architecture + - Reference + * - :class:`Invertible ResNet ` + - + * - :class:`ResFlow ` + - + * - :class:`ProximalResFlow ` + - + +We also list presets for some convolutional iterative residual architectures in the table below. +These are suitable for image modeling. + +.. list-table:: + :header-rows: 1 + + * - Architecture + - Reference + * - :class:`Convolutional invertible ResNet ` + - + * - :class:`Convolutional ResFlow ` + - + +We finally list presets for residual architectures, based on the matrix determinant lemma. +These support either forward or inverse transformation, but not both. +This means they can be used for either sampling (and variational inference) or density estimation (and maximum likelihood fits), but not both at the same time. + +.. list-table:: + :header-rows: 1 + + * - Architecture + - Reference + * - :class:`Planar flow ` + - + * - :class:`Radial flow ` + - + * - :class:`Sylvester flow ` + - + +Continuous architectures +---------------------------- +We provide the list of continuous architectures in the table below. +Click the architecture name to see the API and usage examples. +Check the API for all continuous architectures :ref:`here `. + +.. 
list-table:: + :header-rows: 1 + + * - Architecture + - Reference + * - :class:`DDNF ` + - + * - :class:`FFJORD ` + - + * - :class:`RNODE ` + - + * - :class:`OT-Flow ` + - + +We also list presets for convolutional continuous architectures in the table below. +These are suitable for image modeling. + +.. list-table:: + :header-rows: 1 + + * - Architecture + - Reference + * - :class:`Convolutional DDNF ` + - + * - :class:`Convolutional FFJORD ` + - + * - :class:`Convolutional RNODE ` + - \ No newline at end of file diff --git a/docs/source/developer_reference/base_distributions.rst b/docs/source/developer_reference/base_distributions.rst new file mode 100644 index 0000000..614b876 --- /dev/null +++ b/docs/source/developer_reference/base_distributions.rst @@ -0,0 +1,61 @@ +Base distributions +========================== + +Existing base distributions +----------------------------- + +.. autoclass:: torchflows.base_distributions.gaussian.DiagonalGaussian + :members: __init__ + +.. autoclass:: torchflows.base_distributions.gaussian.DenseGaussian + :members: __init__ + +.. autoclass:: torchflows.base_distributions.mixture.DiagonalGaussianMixture + :members: __init__ + +.. autoclass:: torchflows.base_distributions.mixture.DenseGaussianMixture + :members: __init__ + +Creating new base distributions +----------------------------------- + +To create a new base distribution, we must create a subclass of :class:`torch.distributions.Distribution` and :class:`torch.nn.Module`. +This class should support the methods sampling and log probability computation. +We give an example for the diagonal Gaussian base distribution: + +.. code-block:: python + + import torch + import torch.distributions + import torch.nn as nn + import math + + class DiagonalGaussian(torch.distributions.Distribution, nn.Module): + def __init__(self, loc: torch.Tensor, scale: torch.Tensor): + super().__init__(event_shape=loc.shape, validate_args=False) + self.log_2_pi = math.log(2 * math.pi) + self.register_buffer('loc', loc) + self.register_buffer('log_scale', torch.log(scale)) + + @property + def scale(self): + return torch.exp(self.log_scale) + + def sample(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor: + noise = torch.randn(size=(*sample_shape, *self.event_shape)).to(self.loc) + # Unsqueeze loc and scale to match batch shape + sample_shape_mask = [None for _ in range(len(sample_shape))] + return self.loc[sample_shape_mask] + noise * self.scale[sample_shape_mask] + + def log_prob(self, value: torch.Tensor) -> torch.Tensor: + if len(value.shape) <= len(self.event_shape): + raise ValueError("Incorrect input shape") + # Unsqueeze loc and scale to match batch shape + sample_shape_mask = [None for _ in range(len(value.shape) - len(self.event_shape))] + loc = self.loc[sample_shape_mask] + scale = self.scale[sample_shape_mask] + log_scale = self.log_scale[sample_shape_mask] + + # Compute log probability + elementwise_log_prob = -(0.5 * ((value - loc) / scale) ** 2 + 0.5 * self.log_2_pi + log_scale) + return sum_except_batch(elementwise_log_prob, self.event_shape) \ No newline at end of file diff --git a/docs/source/developer_reference/bijections/autoregressive/bijections.rst b/docs/source/developer_reference/bijections/autoregressive/bijections.rst new file mode 100644 index 0000000..d793793 --- /dev/null +++ b/docs/source/developer_reference/bijections/autoregressive/bijections.rst @@ -0,0 +1,103 @@ +Autoregressive bijections +=========================== + +Autoregressive bijections belong in one of two categories: 
diff --git a/docs/source/developer_reference/bijections/autoregressive/bijections.rst b/docs/source/developer_reference/bijections/autoregressive/bijections.rst
new file mode 100644
index 0000000..d793793
--- /dev/null
+++ b/docs/source/developer_reference/bijections/autoregressive/bijections.rst
@@ -0,0 +1,103 @@
+Autoregressive bijections
+===========================
+
+Autoregressive bijections belong to one of two categories: coupling or masked autoregressive bijections.
+Architectures like IAF make use of the inverse masked autoregressive bijection, which simply swaps the `forward` and `inverse` methods of its corresponding masked autoregressive counterpart.
+Multiscale architectures are special cases of coupling architectures.
+Each autoregressive bijection consists of a transformer (a parameterized bijection that transforms part of the input), a conditioner, and a conditioner transform (a model that predicts transformer parameters).
+See the :doc:`transformers` and :doc:`conditioner_transforms` sections for more details.
+To improve performance, we define subclasses according to the conditioner type.
+We list these subclasses in the rest of the document.
+
+Coupling bijections
+--------------------------------------------------
+
+Coupling architectures are compositions of coupling bijections, which extend the following base class:
+
+.. autoclass:: torchflows.bijections.finite.autoregressive.layers_base.CouplingBijection
+    :members: __init__
+
+We give an example of how to create a custom coupling bijection using a transformer, coupling strategy, and conditioner transform:
+
+.. code-block:: python
+
+    from torchflows.bijections.finite.autoregressive.layers_base import CouplingBijection
+    from torchflows.bijections.finite.autoregressive.transformers.linear.affine import Affine
+    from torchflows.bijections.finite.autoregressive.conditioning.coupling_masks import HalfSplit
+    from torchflows.bijections.finite.autoregressive.conditioning.transforms import ResidualFeedForward
+
+    class AffineCoupling(CouplingBijection):
+        def __init__(self, event_shape, **kwargs):
+            coupling = HalfSplit(event_shape)
+            super().__init__(
+                event_shape,
+                transformer_class=Affine,
+                coupling=coupling,
+                conditioner_transform_class=ResidualFeedForward
+            )
+
+    event_shape = (10,)  # say we have vectors of size 10
+    bijection = AffineCoupling(event_shape)  # create the bijection
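+
+As a quick usage sketch (reusing ``event_shape`` and ``bijection`` from the snippet above), the forward and inverse passes return the transformed tensor together with the log-determinant of the Jacobian, following the same calling convention as the masked autoregressive example below:
+
+.. code-block:: python
+
+    import torch
+
+    # say we have 100 vectors of size 10
+    x = torch.randn(size=(100, *event_shape))
+
+    z, log_det_forward = bijection.forward(x)
+    x_reconstructed, log_det_inverse = bijection.inverse(z)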
+
+Masked autoregressive bijections
+----------------------------------------
+
+Masked autoregressive and inverse autoregressive architectures are compositions of their respective bijections, extending one of the following classes:
+
+.. autoclass:: torchflows.bijections.finite.autoregressive.layers_base.MaskedAutoregressiveBijection
+    :members: __init__
+
+.. autoclass:: torchflows.bijections.finite.autoregressive.layers_base.InverseMaskedAutoregressiveBijection
+    :members: __init__
+
+We give an example of how to create custom masked autoregressive and inverse masked autoregressive bijections using a transformer and conditioner transform:
+
+.. code-block:: python
+
+    import torch
+    from torchflows.bijections.finite.autoregressive.layers_base import (
+        MaskedAutoregressiveBijection,
+        InverseMaskedAutoregressiveBijection
+    )
+    from torchflows.bijections.finite.autoregressive.transformers.linear.affine import Affine
+    from torchflows.bijections.finite.autoregressive.conditioning.transforms import ResidualFeedForward
+
+    class AffineForwardMaskedAutoregressive(MaskedAutoregressiveBijection):
+        def __init__(self, event_shape, **kwargs):
+            super().__init__(
+                event_shape,
+                transformer_class=Affine,
+                conditioner_transform_class=ResidualFeedForward
+            )
+
+    class AffineInverseMaskedAutoregressive(InverseMaskedAutoregressiveBijection):
+        def __init__(self, event_shape, **kwargs):
+            super().__init__(
+                event_shape,
+                transformer_class=Affine,
+                conditioner_transform_class=ResidualFeedForward
+            )
+
+
+    # say we have 100 vectors of size 10
+    event_shape = (10,)
+    x = torch.randn(size=(100, *event_shape))
+
+    bijection = AffineForwardMaskedAutoregressive(event_shape)  # create the bijection
+    z, log_det_forward = bijection.forward(x)
+    y, log_det_inverse = bijection.inverse(z)
+
+
+Multiscale autoregressive bijections
+--------------------------------------------------
+
+Multiscale architectures are coupling architectures which are specialized for image modeling, extending the class below:
+
+.. autoclass:: torchflows.bijections.finite.multiscale.base.MultiscaleBijection
+    :members: __init__
+
+See also
+------------
+
+.. toctree::
+    :maxdepth: 1
+
+    transformers
+    conditioners
+    conditioner_transforms
\ No newline at end of file
diff --git a/docs/source/developer_reference/bijections/autoregressive/conditioner_transforms.rst b/docs/source/developer_reference/bijections/autoregressive/conditioner_transforms.rst
new file mode 100644
index 0000000..98873fc
--- /dev/null
+++ b/docs/source/developer_reference/bijections/autoregressive/conditioner_transforms.rst
@@ -0,0 +1,21 @@
+List of conditioner transforms
+==============================
+
+.. autoclass:: torchflows.bijections.finite.autoregressive.conditioning.transforms.FeedForward
+.. autoclass:: torchflows.bijections.finite.autoregressive.conditioning.transforms.Linear
+.. autoclass:: torchflows.bijections.finite.autoregressive.conditioning.transforms.ResidualFeedForward
+.. autoclass:: torchflows.bijections.finite.autoregressive.conditioning.transforms.Constant
+.. autoclass:: torchflows.bijections.finite.autoregressive.conditioning.transforms.MADE
+.. autoclass:: torchflows.bijections.finite.autoregressive.conditioning.transforms.LinearMADE
+
+Conditioner combinations
+--------------------------
+.. autoclass:: torchflows.bijections.finite.autoregressive.conditioning.transforms.CombinedConditioner
+.. autoclass:: torchflows.bijections.finite.autoregressive.conditioning.transforms.RegularizedCombinedConditioner
+.. autoclass:: torchflows.bijections.finite.autoregressive.conditioning.transforms.RegularizedGraphicalConditioner
+
+Base classes
+---------------
+.. autoclass:: torchflows.bijections.finite.autoregressive.conditioning.transforms.ConditionerTransform
+.. autoclass:: torchflows.bijections.finite.autoregressive.conditioning.transforms.ElementwiseConditionerTransform
+.. 
autoclass:: torchflows.bijections.finite.autoregressive.conditioning.transforms.TensorConditionerTransform diff --git a/docs/source/developer_reference/bijections/autoregressive/conditioners.rst b/docs/source/developer_reference/bijections/autoregressive/conditioners.rst new file mode 100644 index 0000000..24f5e73 --- /dev/null +++ b/docs/source/developer_reference/bijections/autoregressive/conditioners.rst @@ -0,0 +1,8 @@ +List of conditioners +======================================== + +.. autoclass:: torchflows.bijections.finite.autoregressive.conditioning.coupling_masks.PartialCoupling +.. autoclass:: torchflows.bijections.finite.autoregressive.conditioning.coupling_masks.Coupling +.. autoclass:: torchflows.bijections.finite.autoregressive.conditioning.coupling_masks.GraphicalCoupling +.. autoclass:: torchflows.bijections.finite.autoregressive.conditioning.coupling_masks.HalfSplit +.. autofunction:: torchflows.bijections.finite.autoregressive.conditioning.coupling_masks.make_coupling diff --git a/docs/source/developer_reference/bijections/autoregressive/transformers.rst b/docs/source/developer_reference/bijections/autoregressive/transformers.rst new file mode 100644 index 0000000..a95a5f7 --- /dev/null +++ b/docs/source/developer_reference/bijections/autoregressive/transformers.rst @@ -0,0 +1,30 @@ +List of transformers +================================ + +Torchflows supports several transformers to be used in autoregressive and multiscale normalizing flows. + +Linear transformers +-------------------- +.. autoclass:: torchflows.bijections.finite.autoregressive.transformers.linear.affine.Affine +.. autoclass:: torchflows.bijections.finite.autoregressive.transformers.linear.affine.InverseAffine +.. autoclass:: torchflows.bijections.finite.autoregressive.transformers.linear.affine.Shift +.. autoclass:: torchflows.bijections.finite.autoregressive.transformers.linear.convolution.Invertible1x1ConvolutionTransformer +.. autoclass:: torchflows.bijections.finite.autoregressive.transformers.linear.matrix.LUTransformer + +Spline transformers +-------------------------------- +.. autoclass:: torchflows.bijections.finite.autoregressive.transformers.spline.linear.Linear +.. autoclass:: torchflows.bijections.finite.autoregressive.transformers.spline.linear_rational.LinearRational +.. autoclass:: torchflows.bijections.finite.autoregressive.transformers.spline.rational_quadratic.RationalQuadratic + +Combination transformers +--------------------------------------- + +.. autoclass:: torchflows.bijections.finite.autoregressive.transformers.combination.sigmoid.Sigmoid +.. autoclass:: torchflows.bijections.finite.autoregressive.transformers.combination.sigmoid.DenseSigmoid +.. autoclass:: torchflows.bijections.finite.autoregressive.transformers.combination.sigmoid.DeepSigmoid +.. autoclass:: torchflows.bijections.finite.autoregressive.transformers.combination.sigmoid.DeepDenseSigmoid + +Integration transformers +--------------------------------- +.. autoclass:: torchflows.bijections.finite.autoregressive.transformers.integration.unconstrained_monotonic_neural_network.UnconstrainedMonotonicNeuralNetwork diff --git a/docs/source/developer_reference/bijections/continuous/bijections.rst b/docs/source/developer_reference/bijections/continuous/bijections.rst new file mode 100644 index 0000000..3323007 --- /dev/null +++ b/docs/source/developer_reference/bijections/continuous/bijections.rst @@ -0,0 +1,5 @@ +Continuous bijections +=========================== + +.. 
autoclass:: torchflows.bijections.continuous.base.ContinuousBijection + :members: __init__, forward, inverse \ No newline at end of file diff --git a/docs/source/developer_reference/bijections/index.rst b/docs/source/developer_reference/bijections/index.rst new file mode 100644 index 0000000..f5cffd5 --- /dev/null +++ b/docs/source/developer_reference/bijections/index.rst @@ -0,0 +1,35 @@ +Bijections +==================== + +All normalizing flow transformations are bijections and compositions thereof. + +Base bijections +------------------ + +The following classes define forward and inverse pass methods which all bijections inherit. + +.. autoclass:: torchflows.bijections.base.Bijection + :members: __init__, forward, inverse + +.. autoclass:: torchflows.bijections.base.BijectiveComposition + :members: __init__ + + +Bijection subclasses for different NF families +------------------------------------------------------------------ +To improve efficiency of forward and inverse passes in NF layers, we subclass the base bijections with respect to different families of NF architectures. +On the pages below, we list base classes for each family, and provide a list of already implemented classes. + +.. toctree:: + :maxdepth: 1 + + autoregressive/bijections + residual/bijections + continuous/bijections + +Inverting a bijection +------------------------------ + +Each bijection can be inverted with the `invert` function. + +.. autofunction:: torchflows.bijections.base.invert \ No newline at end of file diff --git a/docs/source/developer_reference/bijections/residual/bijections.rst b/docs/source/developer_reference/bijections/residual/bijections.rst new file mode 100644 index 0000000..22e7aa1 --- /dev/null +++ b/docs/source/developer_reference/bijections/residual/bijections.rst @@ -0,0 +1,2 @@ +Residual bijections +=========================== diff --git a/docs/source/api/flow.rst b/docs/source/developer_reference/flow.rst similarity index 96% rename from docs/source/api/flow.rst rename to docs/source/developer_reference/flow.rst index 0fe1d1d..f4cdfff 100644 --- a/docs/source/api/flow.rst +++ b/docs/source/developer_reference/flow.rst @@ -1,4 +1,4 @@ -Flow objects +Flow wrappers =============================== The `Flow` object contains a base distribution and a bijection. diff --git a/docs/source/developer_reference/index.rst b/docs/source/developer_reference/index.rst new file mode 100644 index 0000000..729d00d --- /dev/null +++ b/docs/source/developer_reference/index.rst @@ -0,0 +1,25 @@ +Developer reference +=========================== + +This section describes how to create NF architectures and NF components in Torchflows. +NFs consist of two main components: + +* a base distribution, +* a bijection. + +In Torchflows, we further wrap these two with the :class:`torchflows.flows.Flow` object or one of its subclasses to enable e.g., fitting NFs, computing the log probability density, and sampling. + +At its core, each of these components is a PyTorch module which extends existing base classes: + +* :class:`torch.distributions.Distribution` and :class:`torch.nn.Module` for base distributions, +* :class:`torchflows.bijections.base.Bijection` for bijections, +* :class:`torchflows.flows.BaseFlow` for flow wrappers. + +Check the following pages for existing subclasses and to learn to create new subclasses for your modeling and research needs: + +.. 
toctree:: + :maxdepth: 1 + + base_distributions + bijections/index + flow diff --git a/docs/source/guides/image_modeling.rst b/docs/source/guides/image_modeling.rst index 757365e..134d8d5 100644 --- a/docs/source/guides/image_modeling.rst +++ b/docs/source/guides/image_modeling.rst @@ -3,9 +3,11 @@ Image modeling When modeling images, we can use specialized multiscale architectures which use convolutional neural network conditioners and specialized coupling schemes. These architectures expect event shapes to be *(channels, height, width)*. +See the :ref:`list of multiscale architecture presets here `. -.. note:: - Multiscale architectures are currently undergoing improvements. +Basic multiscale architectures +--------------------------------------- +We provide some basic multiscale presets and give an example for the RealNVP variant below: .. code-block:: python @@ -19,4 +21,24 @@ These architectures expect event shapes to be *(channels, height, width)*. torch.manual_seed(0) training_images = torch.randn(size=(n_images, *image_shape)) # synthetic data flow = Flow(MultiscaleRealNVP(image_shape)) + flow.fit(training_images, show_progress=True) + +Glow-style multiscale architectures +------------------------------------------- + +Glow-style architectures are extensions of basic multiscale architectures which use an additional invertible 1x1 convolution in each layer. +We give an example for Glow with affine transformers below: + +.. code-block:: python + + import torch + from torchflows.flows import Flow + from torchflows.architectures import AffineGlow + + image_shape = (3, 28, 28) + n_images = 100 + + torch.manual_seed(0) + training_images = torch.randn(size=(n_images, *image_shape)) # synthetic data + flow = Flow(AffineGlow(image_shape)) flow.fit(training_images, show_progress=True) \ No newline at end of file diff --git a/docs/source/guides/mathematical_background.rst b/docs/source/guides/mathematical_background.rst index cb4dd94..d1807c7 100644 --- a/docs/source/guides/mathematical_background.rst +++ b/docs/source/guides/mathematical_background.rst @@ -1,5 +1,9 @@ +Mathematical background +============================= + + What is a normalizing flow -========================== +-------------------------------------------- A normalizing flow (NF) is a flexible trainable distribution. It is defined as a bijective transformation of a simple distribution, such as a standard Gaussian. diff --git a/docs/source/guides/modifying_architectures.rst b/docs/source/guides/modifying_architectures.rst new file mode 100644 index 0000000..f1f2249 --- /dev/null +++ b/docs/source/guides/modifying_architectures.rst @@ -0,0 +1,154 @@ +Modifying normalizing flow architectures +============================================ + +We sometimes wish to experiment with bijection parameters to improve NF performance on a given dataset. +We give a few examples on how to achieve this with Torchflows. + +Passing hyperparameters to existing architecture constructors +------------------------------------------------------------------- +We can make basic modifications to an existing NF architecture by passing it certain keyword arguments. +The permitted keyword arguments depend on the architecture. +Suppose we are working with RealNVP, which is a composition of several affine coupling layers. +We wish our RealNVP instance to have 5 affine coupling layers. +Each affine coupling layer should use a feed-forward neural network conditioner with 5 layers, as well as 10 hidden neurons and the ReLU activation in each layer. + +.. 
code-block:: python
+
+    import torch.nn as nn
+    from torchflows.flows import Flow
+    from torchflows.architectures import RealNVP
+    from torchflows.bijections.finite.autoregressive.conditioning.transforms import FeedForward
+
+    event_shape = (10,)
+    custom_hyperparameters = {
+        'n_layers': 5,
+        'conditioner_transform_class': FeedForward,
+        'conditioner_kwargs': {
+            'n_layers': 5,
+            'n_hidden': 10,
+            'nonlinearity': nn.ReLU
+        }
+    }
+    bijection = RealNVP(event_shape, **custom_hyperparameters)
+    flow = Flow(bijection)
+
+`Autoregressive architectures `_ can receive hyperparameters through the following keyword arguments:
+
+* ``n_layers``: the number of affine coupling layers;
+* ``conditioner_transform_class``: the conditioner type to use in each layer;
+* ``conditioner_kwargs``: conditioner keyword arguments for each layer;
+* ``transformer_kwargs``: transformer keyword arguments for each layer.
+
+The specific keyword arguments depend on which conditioner and transformer we are using.
+Check the list of implemented conditioner transforms and their constructors :doc:`here <../developer_reference/bijections/autoregressive/conditioner_transforms>`.
+See which transformers are used in each architecture :ref:`here `.
+
+Coupling architectures can also receive:
+
+* ``edge_list``: an edge list of conditional dimension interactions;
+* ``coupling_kwargs``: keyword arguments for :func:`make_coupling` in each layer.
+
+To see how other architectures use keyword arguments, consider checking the :doc:`list of architectures <../architectures/index>`.
+
+Composing existing bijections with custom hyperparameters
+-------------------------------------------------------------
+In the previous section, we learned how to modify a preset architecture by passing some hyperparameters.
+In residual and autoregressive NFs, this approach will use the same hyperparameters for each layer of the NF.
+For more customization, we can create individual layers and compose them into a custom architecture.
+Suppose we wish to create an NF with five layers:
+
+* two affine coupling layers,
+* a rational quadratic spline coupling layer,
+* an invertible residual network layer,
+* an elementwise shift layer.
+
+The above model can be coded as follows:
+
+.. code-block:: python
+
+    import torch
+    from torchflows.flows import Flow
+    from torchflows.bijections.base import BijectiveComposition
+    from torchflows.bijections.finite.autoregressive.layers import AffineCoupling, RQSCoupling, ElementwiseShift
+    from torchflows.bijections.finite.residual.iterative import InvertibleResNetBlock
+
+    torch.manual_seed(0)
+    event_shape = (10,)
+    bijection = BijectiveComposition(
+        event_shape,
+        [
+            AffineCoupling(event_shape),
+            AffineCoupling(event_shape),
+            RQSCoupling(event_shape),
+            InvertibleResNetBlock(event_shape),
+            ElementwiseShift(event_shape),
+        ]
+    )
+    flow = Flow(bijection)
+
+    x_new = flow.sample((10,))
+    log_prob = flow.log_prob(x_new)
+
+We can also customize each layer with custom hyperparameters, for example:
+
+.. 
code-block:: python + + import torch + from torchflows.flows import Flow + from torchflows.bijections.base import BijectiveComposition + from torchflows.bijections.finite.autoregressive.layers import AffineCoupling, RQSCoupling, ElementwiseShift + from torchflows.bijections.finite.residual.iterative import InvertibleResNetBlock + + torch.manual_seed(0) + event_shape = (10,) + bijection = BijectiveComposition( + event_shape, + [ + AffineCoupling(event_shape, conditioner_kwargs={'n_hidden': 5, 'n_layers': 10}), + AffineCoupling(event_shape), + RQSCoupling(event_shape, conditioner_kwargs={'n_layers': 1}), + InvertibleResNetBlock(event_shape, hidden_size=4, n_hidden_layers=3), + ElementwiseShift(event_shape), + ] + ) + flow = Flow(bijection) + + x_new = flow.sample((10,)) + log_prob = flow.log_prob(x_new) + +.. note:: + + Due to the large number of bijections in the library, argument names are not always consistent across bijections. + Check bijection constructors to make sure you are using correct argument names. + We are working to improve this in a future release. + +Composing NF architectures +---------------------------------------- + +Since each NF transformation is a bijection, we can compose them as any other. +We give an example below, where we compose RealNVP, coupling RQ-NSF, FFJORD, and ResFlow: + +.. code-block:: python + + import torch + from torchflows.flows import Flow + from torchflows.bijections.base import BijectiveComposition + from torchflows.bijections.finite.autoregressive.architectures import RealNVP, CouplingRQNSF + from torchflows.bijections.finite.residual.architectures import ResFlow + from torchflows.bijections.continuous.ffjord import FFJORD + + torch.manual_seed(0) + event_shape = (10,) + bijection = BijectiveComposition( + event_shape, + [ + RealNVP(event_shape), + CouplingRQNSF(event_shape), + FFJORD(event_shape), + ResFlow(event_shape) + ] + ) + flow = Flow(bijection) + + x_new = flow.sample((10,)) + log_prob = flow.log_prob(x_new) diff --git a/docs/source/guides/tutorial.rst b/docs/source/guides/tutorial.rst index f9f63b1..a3fd141 100644 --- a/docs/source/guides/tutorial.rst +++ b/docs/source/guides/tutorial.rst @@ -11,3 +11,4 @@ We provide tutorials and notebooks for typical Torchflows use cases. image_modeling choosing_base_distributions cuda + modifying_architectures diff --git a/docs/source/index.rst b/docs/source/index.rst index d2f9b46..eb2210f 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -25,21 +25,16 @@ Torchflows can be installed easily using pip: For other install options, see the :ref:`install ` section. -Guides -========= +Table of contents +---------------------------- .. toctree:: + :maxdepth: 2 guides/installing guides/tutorial - -API -==== - -.. 
toctree:: - :maxdepth: 3 - - api/components - api/architectures - api/multiscale_architectures + architectures/index + architectures/general_modeling + architectures/image_modeling + developer_reference/index From 5aefcd66104699fcb79ea252edbc4e3b24af4f09 Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Sat, 19 Oct 2024 16:41:53 +0200 Subject: [PATCH 24/25] Update docstrings --- .../finite/autoregressive/architectures.py | 2 +- .../finite/autoregressive/layers_base.py | 30 ++++++++++++------- 2 files changed, 21 insertions(+), 11 deletions(-) diff --git a/torchflows/bijections/finite/autoregressive/architectures.py b/torchflows/bijections/finite/autoregressive/architectures.py index 8975685..97cb0da 100644 --- a/torchflows/bijections/finite/autoregressive/architectures.py +++ b/torchflows/bijections/finite/autoregressive/architectures.py @@ -78,7 +78,7 @@ def __init__(self, event_shape: Union[Tuple[int, ...], torch.Size, int], **kwarg """ :param event_shape: shape of the event tensor. - :param kwargs: keyword arguments to AffineCoupling. + :param kwargs: keyword arguments to :class:`~bijections.finite.autoregressive.layers.AffineCoupling`. """ super().__init__(event_shape, base_bijection=AffineCoupling, **kwargs) diff --git a/torchflows/bijections/finite/autoregressive/layers_base.py b/torchflows/bijections/finite/autoregressive/layers_base.py index 9a47270..0c4afcd 100644 --- a/torchflows/bijections/finite/autoregressive/layers_base.py +++ b/torchflows/bijections/finite/autoregressive/layers_base.py @@ -39,18 +39,15 @@ class CouplingBijection(AutoregressiveBijection): """ Base coupling bijection object. - A coupling bijection is defined using a transformer, conditioner transform, and always a coupling conditioner. + A coupling bijection is defined using a transformer, conditioner transform, and always a coupling conditioner (specifying how to partition the input tensor). - The coupling conditioner receives as input an event tensor x. - It then partitions an input event tensor x into a constant part x_A and a modifiable part x_B. - For x_A, the conditioner outputs a set of parameters which is always the same. - For x_B, the conditioner outputs a set of parameters which are predicted from x_A. + The coupling conditioner receives as input an event tensor :math:`x`. + It then partitions an input event tensor x into a constant part :math:`x_A` and a modifiable part :math:`x_B`. + For :math:`x_A`, the conditioner outputs a set of parameters which is always the same. + For :math:`x_B`, the conditioner outputs a set of parameters which are predicted from :math:`x_A`. + Coupling conditioners differ in the partitioning method. By default, the event is flattened; the first half is :math:`x_A` and the second half is :math:`x_B`. When using this in a normalizing flow, permutation layers can shuffle event dimensions. - Coupling conditioners differ in the partitioning method. By default, the event is flattened; the first half is x_A - and the second half is x_B. When using this in a normalizing flow, permutation layers can shuffle event dimensions. - - For improved performance, this implementation does not use a standalone coupling conditioner. It instead implements - a method to partition x into x_A and x_B and then predict parameters for x_B. + For improved performance, this implementation does not use a standalone coupling conditioner, but implements a method to partition x into :math:`x_A` and :math:`x_B` and then predict parameters for :math:`x_B`. 
""" def __init__(self, @@ -63,6 +60,19 @@ def __init__(self, conditioner_kwargs: dict = None, transformer_kwargs: dict = None, **kwargs): + """ + CouplingBijection constructor. + + :param Union[Tuple[int, ...], torch.Size] event_shape: shape of the event tensor. + :param Type[TensorTransformer] transformer_class: transformer class. + :param Union[Tuple[int, ...], torch.Size] context_shape: + :param PartialCoupling coupling: + :param Type[ConditionerTransform] conditioner_transform_class: + :param Dict coupling_kwargs: + :param Dict conditioner_kwargs: + :param Dict transformer_kwargs: + :param kwargs: + """ coupling_kwargs = coupling_kwargs or {} conditioner_kwargs = conditioner_kwargs or {} transformer_kwargs = transformer_kwargs or {} From 38d29e1b37400a67b013ded56b14370a24a2edda Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Sat, 19 Oct 2024 16:47:55 +0200 Subject: [PATCH 25/25] Fix typehint syntaxt --- .../bijections/finite/autoregressive/conditioning/transforms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchflows/bijections/finite/autoregressive/conditioning/transforms.py b/torchflows/bijections/finite/autoregressive/conditioning/transforms.py index e34d947..008a2cc 100644 --- a/torchflows/bijections/finite/autoregressive/conditioning/transforms.py +++ b/torchflows/bijections/finite/autoregressive/conditioning/transforms.py @@ -273,7 +273,7 @@ def __init__(self, *args, **kwargs): class FeedForward(TensorConditionerTransform): def __init__(self, input_event_shape: Union[torch.Size, Tuple[int, ...]], - parameter_shape: torch.Union[torch.Size, Tuple[int, ...]], + parameter_shape: Union[torch.Size, Tuple[int, ...]], context_shape: Union[torch.Size, Tuple[int, ...]] = None, n_hidden: int = None, n_layers: int = 2,