diff --git a/normalizing_flows/bijections/base.py b/normalizing_flows/bijections/base.py index 64eb863..a05fe5a 100644 --- a/normalizing_flows/bijections/base.py +++ b/normalizing_flows/bijections/base.py @@ -81,7 +81,7 @@ def batch_inverse(self, x: torch.Tensor, batch_size: int, context: torch.Tensor return outputs, log_dets -def invert(bijection: Bijection) -> Bijection: +def invert(bijection): """ Swap the forward and inverse methods of the input bijection. """ diff --git a/normalizing_flows/bijections/continuous/base.py b/normalizing_flows/bijections/continuous/base.py index 12caf68..6751ee5 100644 --- a/normalizing_flows/bijections/continuous/base.py +++ b/normalizing_flows/bijections/continuous/base.py @@ -2,7 +2,6 @@ import torch import torch.nn as nn -from torchdiffeq import odeint from normalizing_flows.bijections.base import Bijection from normalizing_flows.bijections.continuous.layers import DiffEqLayer @@ -306,6 +305,10 @@ def inverse(self, :param kwargs: :return: """ + + # Import from torchdiffeq locally, so the package does not break if torchdiffeq not installed + from torchdiffeq import odeint + # Flatten everything to facilitate computations batch_shape = get_batch_shape(z, self.event_shape) batch_size = int(torch.prod(torch.as_tensor(batch_shape))) @@ -399,6 +402,9 @@ def inverse(self, :param kwargs: :return: """ + # Import from torchdiffeq locally, so the package does not break if torchdiffeq not installed + from torchdiffeq import odeint + # Flatten everything to facilitate computations batch_shape = get_batch_shape(z, self.event_shape) batch_size = int(torch.prod(torch.as_tensor(batch_shape))) diff --git a/normalizing_flows/bijections/finite/autoregressive/architectures.py b/normalizing_flows/bijections/finite/autoregressive/architectures.py index 0e1331a..cca7e34 100644 --- a/normalizing_flows/bijections/finite/autoregressive/architectures.py +++ b/normalizing_flows/bijections/finite/autoregressive/architectures.py @@ -11,7 +11,7 @@ ElementwiseAffine, UMNNMaskedAutoregressive, LRSCoupling, - LRSForwardMaskedAutoregressive + LRSForwardMaskedAutoregressive, ElementwiseShift ) from normalizing_flows.bijections.base import BijectiveComposition from normalizing_flows.bijections.finite.linear import ReversePermutation @@ -127,13 +127,13 @@ class CouplingLRS(BijectiveComposition): def __init__(self, event_shape, n_layers: int = 2, **kwargs): if isinstance(event_shape, int): event_shape = (event_shape,) - bijections = [ElementwiseAffine(event_shape=event_shape)] + bijections = [ElementwiseShift(event_shape=event_shape)] for _ in range(n_layers): bijections.extend([ ReversePermutation(event_shape=event_shape), LRSCoupling(event_shape=event_shape) ]) - bijections.append(ElementwiseAffine(event_shape=event_shape)) + bijections.append(ElementwiseShift(event_shape=event_shape)) super().__init__(event_shape, bijections, **kwargs) @@ -141,13 +141,13 @@ class MaskedAutoregressiveLRS(BijectiveComposition): def __init__(self, event_shape, n_layers: int = 2, **kwargs): if isinstance(event_shape, int): event_shape = (event_shape,) - bijections = [ElementwiseAffine(event_shape=event_shape)] + bijections = [ElementwiseShift(event_shape=event_shape)] for _ in range(n_layers): bijections.extend([ ReversePermutation(event_shape=event_shape), LRSForwardMaskedAutoregressive(event_shape=event_shape) ]) - bijections.append(ElementwiseAffine(event_shape=event_shape)) + bijections.append(ElementwiseShift(event_shape=event_shape)) super().__init__(event_shape, bijections, **kwargs) @@ -173,7 +173,7 
@@ def __init__(self, event_shape, n_layers: int = 2, **kwargs): for _ in range(n_layers): bijections.extend([ ReversePermutation(event_shape=event_shape), - DSCoupling(event_shape=event_shape) + DSCoupling(event_shape=event_shape) # TODO specify percent of global parameters ]) bijections.append(ElementwiseAffine(event_shape=event_shape)) super().__init__(event_shape, bijections, **kwargs) diff --git a/normalizing_flows/bijections/finite/autoregressive/conditioner_transforms.py b/normalizing_flows/bijections/finite/autoregressive/conditioner_transforms.py index 1bf15e5..608bcc5 100644 --- a/normalizing_flows/bijections/finite/autoregressive/conditioner_transforms.py +++ b/normalizing_flows/bijections/finite/autoregressive/conditioner_transforms.py @@ -1,4 +1,5 @@ import math +from typing import Tuple, Union, Type, List import torch import torch.nn as nn @@ -9,13 +10,41 @@ class ConditionerTransform(nn.Module): + """ + Module which predicts transformer parameters for the transformation of a tensor y using an input tensor x and + possibly a corresponding context tensor c. + + In other words, a conditioner transform f predicts theta = f(x, c) to be used in transformer g with z = g(y; theta). + The transformation g is performed elementwise on tensor y. + Since g transforms each element of y with a parameter tensor of shape (n_transformer_parameters,), + the shape of theta is (*y.shape, n_transformer_parameters). + """ + def __init__(self, input_event_shape, context_shape, - output_event_shape, - n_predicted_parameters: int, - context_combiner: ContextCombiner = None): + parameter_shape: Union[torch.Size, Tuple[int, ...]], + context_combiner: ContextCombiner = None, + global_parameter_mask: torch.Tensor = None, + initial_global_parameter_value: float = None): + """ + :param input_event_shape: shape of conditioner input tensor x. + :param context_shape: shape of conditioner context tensor c. + :param parameter_shape: shape of parameter tensor required to transform transformer input y. + :param context_combiner: ContextCombiner class which defines how to combine x and c to predict theta. + :param global_parameter_mask: boolean tensor which determines which elements of parameter tensors should be + learned globally instead of predicted. If an element is set to 1, that element is learned globally. + We require that global_parameter_mask.shape = parameter_shape. + :param initial_global_parameter_value: initial global parameter value as a single scalar. If None, all initial + global parameters are independently drawn from the standard normal distribution. 
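+
+    Example (illustrative shapes): suppose the transformed tensor y has event shape (5,) and each of its
+    elements requires 2 transformer parameters, so parameter_shape = (5, 2). A mask which learns the second
+    parameter of every element globally can be built as
+
+        mask = torch.zeros(5, 2, dtype=torch.bool)
+        mask[:, 1] = True
+
+    giving n_global_parameters = 5 and n_predicted_parameters = 5. In that case, forward() returns theta with
+    shape (*batch_shape, 5, 2), where the masked entries are filled from the globally learned parameters and
+    the remaining entries are predicted from x (and c).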
+ """ super().__init__() + if global_parameter_mask is not None and global_parameter_mask.shape != parameter_shape: + raise ValueError( + f"Global parameter mask must have shape equal to the output parameter shape {parameter_shape}, " + f"but found {global_parameter_mask.shape}" + ) + if context_shape is None: context_combiner = Bypass(input_event_shape) elif context_shape is not None and context_combiner is None: @@ -24,41 +53,69 @@ def __init__(self, # The conditioner transform receives as input the context combiner output self.input_event_shape = input_event_shape - self.output_event_shape = output_event_shape self.context_shape = context_shape self.n_input_event_dims = self.context_combiner.n_output_dims - self.n_output_event_dims = int(torch.prod(torch.as_tensor(output_event_shape))) - self.n_predicted_parameters = n_predicted_parameters + + # Setup output parameter attributes + self.parameter_shape = parameter_shape + self.global_parameter_mask = global_parameter_mask + self.n_transformer_parameters = int(torch.prod(torch.as_tensor(self.parameter_shape))) + self.n_global_parameters = 0 if global_parameter_mask is None else int(torch.sum(self.global_parameter_mask)) + self.n_predicted_parameters = self.n_transformer_parameters - self.n_global_parameters + + if initial_global_parameter_value is None: + initial_global_theta_flat = torch.randn(size=(self.n_global_parameters,)) + else: + initial_global_theta_flat = torch.full( + size=(self.n_global_parameters,), + fill_value=initial_global_parameter_value + ) + self.global_theta_flat = nn.Parameter(initial_global_theta_flat) def forward(self, x: torch.Tensor, context: torch.Tensor = None): - # x.shape = (*batch_shape, *input_event_shape) - # context.shape = (*batch_shape, *context_shape) - # output.shape = (*batch_shape, *output_event_shape, n_predicted_parameters) + # x.shape = (*batch_shape, *self.input_event_shape) + # context.shape = (*batch_shape, *self.context_shape) + # output.shape = (*batch_shape, *self.parameter_shape) + batch_shape = get_batch_shape(x, self.input_event_shape) + if self.n_global_parameters == 0: + # All parameters are predicted + return self.predict_theta_flat(x, context).view(*batch_shape, *self.parameter_shape) + else: + if self.n_global_parameters == self.n_transformer_parameters: + # All transformer parameters are learned globally + output = torch.zeros(*batch_shape, *self.parameter_shape) + output[..., self.global_parameter_mask] = self.global_theta_flat + return output + else: + # Some transformer parameters are learned globally, some are predicted + output = torch.zeros(*batch_shape, *self.parameter_shape) + output[..., self.global_parameter_mask] = self.global_theta_flat + output[..., ~self.global_parameter_mask] = self.predict_theta_flat(x, context) + return output + + def predict_theta_flat(self, x: torch.Tensor, context: torch.Tensor = None): raise NotImplementedError class Constant(ConditionerTransform): - def __init__(self, output_event_shape, n_parameters: int, fill_value: float = None): + def __init__(self, event_shape, parameter_shape, fill_value: float = None): super().__init__( - input_event_shape=None, + input_event_shape=event_shape, context_shape=None, - output_event_shape=output_event_shape, - n_predicted_parameters=n_parameters + parameter_shape=parameter_shape, + initial_global_parameter_value=fill_value, + global_parameter_mask=torch.ones(parameter_shape, dtype=torch.bool) ) - if fill_value is None: - initial_theta = torch.randn(size=(*self.output_event_shape, n_parameters,)) - else: - 
initial_theta = torch.full(size=(*self.output_event_shape, n_parameters), fill_value=fill_value) - self.theta = nn.Parameter(initial_theta) - - def forward(self, x: torch.Tensor, context: torch.Tensor = None): - n_batch_dims = len(x.shape) - len(self.output_event_shape) - n_event_dims = len(self.output_event_shape) - batch_shape = x.shape[:n_batch_dims] - return pad_leading_dims(self.theta, n_batch_dims).repeat(*batch_shape, *([1] * n_event_dims), 1) class MADE(ConditionerTransform): + """ + Masked autoencoder for distribution estimation (MADE). + + MADE is a conditioner transform that receives as input a tensor x. It predicts parameters for the + transformer such that each dimension only depends on the previous ones. + """ + class MaskedLinear(nn.Linear): def __init__(self, in_features: int, out_features: int, mask: torch.Tensor): super().__init__(in_features=in_features, out_features=out_features) @@ -68,18 +125,21 @@ def forward(self, x): return nn.functional.linear(x, self.weight * self.mask, self.bias) def __init__(self, - input_event_shape: torch.Size, - output_event_shape: torch.Size, - n_predicted_parameters: int, - context_shape: torch.Size = None, + input_event_shape: Union[torch.Size, Tuple[int, ...]], + output_event_shape: Union[torch.Size, Tuple[int, ...]], + parameter_shape_per_element: Union[torch.Size, Tuple[int, ...]], + context_shape: Union[torch.Size, Tuple[int, ...]] = None, n_hidden: int = None, - n_layers: int = 2): + n_layers: int = 2, + **kwargs): super().__init__( input_event_shape=input_event_shape, context_shape=context_shape, - output_event_shape=output_event_shape, - n_predicted_parameters=n_predicted_parameters + parameter_shape=(*output_event_shape, *parameter_shape_per_element), + **kwargs ) + n_predicted_parameters_per_element = int(torch.prod(torch.as_tensor(parameter_shape_per_element))) + n_output_event_dims = int(torch.prod(torch.as_tensor(output_event_shape))) if n_hidden is None: n_hidden = max(int(3 * math.log10(self.n_input_event_dims)), 4) @@ -88,7 +148,7 @@ def __init__(self, ms = [ torch.arange(self.n_input_event_dims) + 1, *[(torch.arange(n_hidden) % (self.n_input_event_dims - 1)) + 1 for _ in range(n_layers - 1)], - torch.arange(self.n_output_event_dims) + 1 + torch.arange(n_output_event_dims) + 1 ] # Create autoencoder masks @@ -103,10 +163,9 @@ def __init__(self, layers.extend([ self.MaskedLinear( masks[-1].shape[1], - masks[-1].shape[0] * n_predicted_parameters, - torch.repeat_interleave(masks[-1], n_predicted_parameters, dim=0) - ), - nn.Unflatten(dim=-1, unflattened_size=(*output_event_shape, n_predicted_parameters)) + masks[-1].shape[0] * n_predicted_parameters_per_element, + torch.repeat_interleave(masks[-1], n_predicted_parameters_per_element, dim=0) + ) ]) self.sequential = nn.Sequential(*layers) @@ -123,61 +182,54 @@ def create_masks(n_layers, ms): masks.append(torch.as_tensor(xx >= yy, dtype=torch.float)) return masks - def forward(self, x: torch.Tensor, context: torch.Tensor = None): - out = self.sequential(self.context_combiner(x, context)) - batch_shape = get_batch_shape(x, self.input_event_shape) - return out.view(*batch_shape, *self.output_event_shape, self.n_predicted_parameters) + def predict_theta_flat(self, x: torch.Tensor, context: torch.Tensor = None): + theta = self.sequential(self.context_combiner(x, context)) + # (*b, *e, *pe) + + if self.global_parameter_mask is None: + return torch.flatten(theta, start_dim=len(theta.shape) - len(self.input_event_shape)) + else: + return theta[..., ~self.global_parameter_mask] class 
LinearMADE(MADE): - def __init__(self, input_event_shape: torch.Size, output_event_shape: torch.Size, n_predicted_parameters: int, - **kwargs): - super().__init__(input_event_shape, output_event_shape, n_predicted_parameters, n_layers=1, **kwargs) + def __init__(self, *args, **kwargs): + super().__init__(*args, n_layers=1, **kwargs) class FeedForward(ConditionerTransform): def __init__(self, input_event_shape: torch.Size, - output_event_shape: torch.Size, - n_predicted_parameters: int, + parameter_shape: torch.Size, context_shape: torch.Size = None, n_hidden: int = None, - n_layers: int = 2): + n_layers: int = 2, + **kwargs): super().__init__( input_event_shape=input_event_shape, context_shape=context_shape, - output_event_shape=output_event_shape, - n_predicted_parameters=n_predicted_parameters + parameter_shape=parameter_shape, + **kwargs ) if n_hidden is None: n_hidden = max(int(3 * math.log10(self.n_input_event_dims)), 4) - # If context given, concatenate it to transform input - if context_shape is not None: - self.n_input_event_dims += self.n_context_dims - layers = [] - - # Check the one layer special case if n_layers == 1: - layers.append(nn.Linear(self.n_input_event_dims, self.n_output_event_dims * n_predicted_parameters)) + layers.append(nn.Linear(self.n_input_event_dims, self.n_predicted_parameters)) elif n_layers > 1: layers.extend([nn.Linear(self.n_input_event_dims, n_hidden), nn.Tanh()]) for _ in range(n_layers - 2): layers.extend([nn.Linear(n_hidden, n_hidden), nn.Tanh()]) - layers.append(nn.Linear(n_hidden, self.n_output_event_dims * n_predicted_parameters)) + layers.append(nn.Linear(n_hidden, self.n_predicted_parameters)) else: raise ValueError - - # Reshape the output - layers.append(nn.Unflatten(dim=-1, unflattened_size=(*output_event_shape, n_predicted_parameters))) + layers.append(nn.Unflatten(dim=-1, unflattened_size=self.parameter_shape)) self.sequential = nn.Sequential(*layers) - def forward(self, x: torch.Tensor, context: torch.Tensor = None): - out = self.sequential(self.context_combiner(x, context)) - batch_shape = get_batch_shape(x, self.input_event_shape) - return out.view(*batch_shape, *self.output_event_shape, self.n_predicted_parameters) + def predict_theta_flat(self, x: torch.Tensor, context: torch.Tensor = None): + return self.sequential(self.context_combiner(x, context)) class Linear(FeedForward): @@ -186,44 +238,49 @@ def __init__(self, *args, **kwargs): class ResidualFeedForward(ConditionerTransform): - class ResidualLinear(nn.Module): - def __init__(self, n_in, n_out): + class ResidualBlock(nn.Module): + def __init__(self, event_size: int, hidden_size: int, block_size: int): super().__init__() - self.linear = nn.Linear(n_in, n_out) + if block_size < 2: + raise ValueError(f"block_size must be at least 2 but found {block_size}. 
" + f"For block_size = 1, use the FeedForward class instead.") + layers = [] + layers.extend([nn.Linear(event_size, hidden_size), nn.ReLU()]) + for _ in range(block_size - 2): + layers.extend([nn.Linear(hidden_size, hidden_size), nn.ReLU()]) + layers.extend([nn.Linear(hidden_size, event_size)]) + self.sequential = nn.Sequential(*layers) def forward(self, x): - return x + self.linear(x) + return x + self.sequential(x) def __init__(self, input_event_shape: torch.Size, - output_event_shape: torch.Size, - n_predicted_parameters: int, + parameter_shape: torch.Size, context_shape: torch.Size = None, - n_layers: int = 2): - super().__init__(input_event_shape, context_shape, output_event_shape, n_predicted_parameters) - - # If context given, concatenate it to transform input - if context_shape is not None: - self.n_input_event_dims += self.n_context_dims + n_hidden: int = None, + n_layers: int = 3, + block_size: int = 2, + **kwargs): + super().__init__( + input_event_shape=input_event_shape, + context_shape=context_shape, + parameter_shape=parameter_shape, + **kwargs + ) - layers = [] + if n_hidden is None: + n_hidden = max(int(3 * math.log10(self.n_input_event_dims)), 4) - # Check the one layer special case - if n_layers == 1: - layers.append(nn.Linear(self.n_input_event_dims, self.n_output_event_dims * n_predicted_parameters)) - elif n_layers > 1: - layers.extend([self.ResidualLinear(self.n_input_event_dims, self.n_input_event_dims), nn.Tanh()]) - for _ in range(n_layers - 2): - layers.extend([self.ResidualLinear(self.n_input_event_dims, self.n_input_event_dims), nn.Tanh()]) - layers.append(nn.Linear(self.n_input_event_dims, self.n_output_event_dims * n_predicted_parameters)) - else: - raise ValueError + if n_layers <= 2: + raise ValueError(f"Number of layers in ResidualFeedForward must be at least 3, but found {n_layers}") - # Reshape the output - layers.append(nn.Unflatten(dim=-1, unflattened_size=(*output_event_shape, n_predicted_parameters))) + layers = [nn.Linear(self.n_input_event_dims, n_hidden), nn.ReLU()] + for _ in range(n_layers - 2): + layers.append(self.ResidualBlock(n_hidden, n_hidden, block_size)) + layers.append(nn.Linear(n_hidden, self.n_predicted_parameters)) + layers.append(nn.Unflatten(dim=-1, unflattened_size=self.parameter_shape)) self.sequential = nn.Sequential(*layers) - def forward(self, x: torch.Tensor, context: torch.Tensor = None): - out = self.sequential(self.context_combiner(x, context)) - batch_shape = get_batch_shape(x, self.input_event_shape) - return out.view(*batch_shape, *self.output_event_shape, self.n_predicted_parameters) + def predict_theta_flat(self, x: torch.Tensor, context: torch.Tensor = None): + return self.sequential(self.context_combiner(x, context)) diff --git a/normalizing_flows/bijections/finite/autoregressive/conditioners/coupling.py b/normalizing_flows/bijections/finite/autoregressive/conditioners/coupling.py deleted file mode 100644 index 2805573..0000000 --- a/normalizing_flows/bijections/finite/autoregressive/conditioners/coupling.py +++ /dev/null @@ -1,64 +0,0 @@ -from typing import Union, Tuple - -import torch - -from normalizing_flows.bijections.finite.autoregressive.conditioners.base import Conditioner -from normalizing_flows.bijections.finite.autoregressive.conditioner_transforms import ConditionerTransform -from normalizing_flows.utils import get_batch_shape - - -class Coupling(Conditioner): - def __init__(self, event_shape: torch.Size, constants: torch.Tensor): - """ - Coupling conditioner. 
- - - Note: Always treats the first n_dim // 2 dimensions as constant. - Shuffling is handled in Permutation bijections. - - :param constants: - """ - super().__init__() - self.event_shape = event_shape - - # TODO add support for other kinds of masks - n_total_dims = int(torch.prod(torch.tensor(event_shape))) - self.n_constant_dims = n_total_dims // 2 - self.n_changed_dims = n_total_dims - self.n_constant_dims - - self.constant_mask = torch.less(torch.arange(n_total_dims).view(*event_shape), self.n_constant_dims) - self.register_buffer('constants', constants) # Takes care of torch devices - - @property - @torch.no_grad() - def input_shape(self): - return (int(torch.sum(self.constant_mask)),) - - @property - @torch.no_grad() - def output_shape(self): - return (int(torch.sum(~self.constant_mask)),) - - def forward(self, - x: torch.Tensor, - transform: ConditionerTransform, - context: torch.Tensor = None, - return_mask: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - # Predict transformer parameters for output dimensions - batch_shape = get_batch_shape(x, self.event_shape) - x_const = x.view(*batch_shape, *self.event_shape)[..., self.constant_mask] - tmp = transform(x_const, context=context) - n_parameters = tmp.shape[-1] - - # Create full parameter tensor - h = torch.empty(size=(*batch_shape, *self.event_shape, n_parameters), dtype=x.dtype, device=x.device) - - # Fill the parameter tensor with predicted values - h[..., ~self.constant_mask, :] = tmp - h[..., self.constant_mask, :] = self.constants - - if return_mask: - # Return the parameters for the to-be-transformed partition and the partition mask itself - return h, self.constant_mask - - return h diff --git a/normalizing_flows/bijections/finite/autoregressive/conditioners/coupling_masks.py b/normalizing_flows/bijections/finite/autoregressive/conditioners/coupling_masks.py new file mode 100644 index 0000000..04ddc8f --- /dev/null +++ b/normalizing_flows/bijections/finite/autoregressive/conditioners/coupling_masks.py @@ -0,0 +1,44 @@ +import torch + + +class CouplingMask: + """ + Base object which holds coupling partition mask information. 
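+
+    For example, the HalfSplit mask below flattens the event and treats its first half as the constant
+    partition:
+
+        coupling_mask = HalfSplit(event_shape=(6,))
+        coupling_mask.mask                    # tensor([True, True, True, False, False, False])
+        coupling_mask.constant_event_size     # 3
+        coupling_mask.transformed_event_size  # 3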
+ """ + + def __init__(self, event_shape): + self.event_shape = event_shape + self.event_size = int(torch.prod(torch.as_tensor(self.event_shape))) + + @property + def mask(self): + raise NotImplementedError + + @property + def constant_event_size(self): + raise NotImplementedError + + @property + def transformed_event_size(self): + raise NotImplementedError + + +class HalfSplit(CouplingMask): + def __init__(self, event_shape): + super().__init__(event_shape) + self.event_partition_mask = torch.less( + torch.arange(self.event_size).view(*self.event_shape), + self.constant_event_size + ) + + @property + def constant_event_size(self): + return self.event_size // 2 + + @property + def transformed_event_size(self): + return self.event_size - self.constant_event_size + + @property + def mask(self): + return self.event_partition_mask diff --git a/normalizing_flows/bijections/finite/autoregressive/conditioners/masked.py b/normalizing_flows/bijections/finite/autoregressive/conditioners/masked.py deleted file mode 100644 index ee6a74f..0000000 --- a/normalizing_flows/bijections/finite/autoregressive/conditioners/masked.py +++ /dev/null @@ -1,11 +0,0 @@ -import torch - -from normalizing_flows.bijections.finite.autoregressive.conditioners.base import Conditioner - - -class MaskedAutoregressive(Conditioner): - def __init__(self): - super().__init__() - - def forward(self, x: torch.Tensor, transform, context: torch.Tensor = None) -> torch.Tensor: - return transform(x, context=context) diff --git a/normalizing_flows/bijections/finite/autoregressive/layers.py b/normalizing_flows/bijections/finite/autoregressive/layers.py index e9e203f..b340cbe 100644 --- a/normalizing_flows/bijections/finite/autoregressive/layers.py +++ b/normalizing_flows/bijections/finite/autoregressive/layers.py @@ -1,11 +1,11 @@ import torch -from normalizing_flows.bijections.finite.autoregressive.conditioner_transforms import MADE, FeedForward -from normalizing_flows.bijections.finite.autoregressive.conditioners.coupling import Coupling -from normalizing_flows.bijections.finite.autoregressive.conditioners.masked import MaskedAutoregressive -from normalizing_flows.bijections.finite.autoregressive.layers_base import ForwardMaskedAutoregressiveBijection, \ +from normalizing_flows.bijections.finite.autoregressive.conditioner_transforms import FeedForward +from normalizing_flows.bijections.finite.autoregressive.conditioners.coupling_masks import HalfSplit +from normalizing_flows.bijections.finite.autoregressive.layers_base import MaskedAutoregressiveBijection, \ InverseMaskedAutoregressiveBijection, ElementwiseBijection, CouplingBijection -from normalizing_flows.bijections.finite.autoregressive.transformers.affine import Scale, Affine, Shift +from normalizing_flows.bijections.finite.autoregressive.transformers.linear.affine import Scale, Affine, Shift +from normalizing_flows.bijections.finite.autoregressive.transformers.base import ScalarTransformer from normalizing_flows.bijections.finite.autoregressive.transformers.integration.unconstrained_monotonic_neural_network import \ UnconstrainedMonotonicNeuralNetwork from normalizing_flows.bijections.finite.autoregressive.transformers.spline.linear_rational import LinearRational @@ -16,28 +16,29 @@ from normalizing_flows.bijections.base import invert +# TODO move elementwise bijections, coupling bijections, and masked autoregressive bijections into separate files. 
class ElementwiseAffine(ElementwiseBijection): def __init__(self, event_shape, **kwargs): transformer = Affine(event_shape, **kwargs) - super().__init__(transformer, n_transformer_parameters=transformer.n_parameters) + super().__init__(transformer) class ElementwiseScale(ElementwiseBijection): def __init__(self, event_shape, **kwargs): transformer = Scale(event_shape, **kwargs) - super().__init__(transformer, n_transformer_parameters=transformer.n_parameters) + super().__init__(transformer) class ElementwiseShift(ElementwiseBijection): def __init__(self, event_shape): transformer = Shift(event_shape) - super().__init__(transformer, n_transformer_parameters=transformer.n_parameters) + super().__init__(transformer) class ElementwiseRQSpline(ElementwiseBijection): def __init__(self, event_shape, **kwargs): transformer = RationalQuadratic(event_shape, **kwargs) - super().__init__(transformer, n_transformer_parameters=transformer.n_parameters) + super().__init__(transformer) class AffineCoupling(CouplingBijection): @@ -47,16 +48,15 @@ def __init__(self, **kwargs): if event_shape == (1,): raise ValueError - transformer = Affine(event_shape=event_shape) - conditioner = Coupling(constants=transformer.default_parameters, event_shape=event_shape) + coupling_mask = HalfSplit(event_shape) + transformer = Affine(event_shape=torch.Size((coupling_mask.transformed_event_size,))) conditioner_transform = FeedForward( - input_event_shape=conditioner.input_shape, - output_event_shape=conditioner.output_shape, - n_predicted_parameters=transformer.n_parameters, + input_event_shape=torch.Size((coupling_mask.constant_event_size,)), + parameter_shape=torch.Size(transformer.parameter_shape), context_shape=context_shape, **kwargs ) - super().__init__(conditioner, transformer, conditioner_transform) + super().__init__(transformer, coupling_mask, conditioner_transform) class InverseAffineCoupling(CouplingBijection): @@ -66,16 +66,15 @@ def __init__(self, **kwargs): if event_shape == (1,): raise ValueError - transformer = Affine(event_shape=event_shape).invert() - conditioner = Coupling(constants=transformer.default_parameters, event_shape=event_shape) + coupling_mask = HalfSplit(event_shape) + transformer = Affine(event_shape=torch.Size((coupling_mask.transformed_event_size,))).invert() conditioner_transform = FeedForward( - input_event_shape=conditioner.input_shape, - output_event_shape=conditioner.output_shape, - n_predicted_parameters=transformer.n_parameters, + input_event_shape=torch.Size((coupling_mask.constant_event_size,)), + parameter_shape=torch.Size(transformer.parameter_shape), context_shape=context_shape, **kwargs ) - super().__init__(conditioner, transformer, conditioner_transform) + super().__init__(transformer, coupling_mask, conditioner_transform) class ShiftCoupling(CouplingBijection): @@ -83,16 +82,15 @@ def __init__(self, event_shape: torch.Size, context_shape: torch.Size = None, **kwargs): - transformer = Shift(event_shape=event_shape) - conditioner = Coupling(constants=transformer.default_parameters, event_shape=event_shape) + coupling_mask = HalfSplit(event_shape) + transformer = Shift(event_shape=torch.Size((coupling_mask.transformed_event_size,))) conditioner_transform = FeedForward( - input_event_shape=conditioner.input_shape, - output_event_shape=conditioner.output_shape, - n_predicted_parameters=transformer.n_parameters, + input_event_shape=torch.Size((coupling_mask.constant_event_size,)), + parameter_shape=torch.Size(transformer.parameter_shape), context_shape=context_shape, **kwargs ) - 
super().__init__(conditioner, transformer, conditioner_transform) + super().__init__(transformer, coupling_mask, conditioner_transform) class LRSCoupling(CouplingBijection): @@ -102,16 +100,15 @@ def __init__(self, n_bins: int = 8, **kwargs): assert n_bins >= 1 - transformer = LinearRational(event_shape=event_shape, n_bins=n_bins) - conditioner = Coupling(constants=transformer.default_parameters, event_shape=event_shape) + coupling_mask = HalfSplit(event_shape) + transformer = LinearRational(event_shape=torch.Size((coupling_mask.transformed_event_size,)), n_bins=n_bins) conditioner_transform = FeedForward( - input_event_shape=conditioner.input_shape, - output_event_shape=conditioner.output_shape, - n_predicted_parameters=transformer.n_parameters, + input_event_shape=torch.Size((coupling_mask.constant_event_size,)), + parameter_shape=torch.Size(transformer.parameter_shape), context_shape=context_shape, **kwargs ) - super().__init__(conditioner, transformer, conditioner_transform) + super().__init__(transformer, coupling_mask, conditioner_transform) class RQSCoupling(CouplingBijection): @@ -120,16 +117,15 @@ def __init__(self, context_shape: torch.Size = None, n_bins: int = 8, **kwargs): - transformer = RationalQuadratic(event_shape=event_shape, n_bins=n_bins) - conditioner = Coupling(constants=transformer.default_parameters, event_shape=event_shape) + coupling_mask = HalfSplit(event_shape) + transformer = RationalQuadratic(event_shape=torch.Size((coupling_mask.transformed_event_size,)), n_bins=n_bins) conditioner_transform = FeedForward( - input_event_shape=conditioner.input_shape, - output_event_shape=conditioner.output_shape, - n_predicted_parameters=transformer.n_parameters, + input_event_shape=torch.Size((coupling_mask.constant_event_size,)), + parameter_shape=torch.Size(transformer.parameter_shape), context_shape=context_shape, **kwargs ) - super().__init__(conditioner, transformer, conditioner_transform) + super().__init__(transformer, coupling_mask, conditioner_transform) class DSCoupling(CouplingBijection): @@ -138,18 +134,20 @@ def __init__(self, context_shape: torch.Size = None, n_hidden_layers: int = 2, **kwargs): - transformer = DeepSigmoid(event_shape=event_shape, n_hidden_layers=n_hidden_layers) - conditioner = Coupling(constants=transformer.default_parameters, event_shape=event_shape) + coupling_mask = HalfSplit(event_shape) + transformer = DeepSigmoid( + event_shape=torch.Size((coupling_mask.transformed_event_size,)), + n_hidden_layers=n_hidden_layers + ) # Parameter order: [c1, c2, c3, c4, ..., ck] for all components # Each component has parameter order [a_unc, b, w_unc] conditioner_transform = FeedForward( - input_event_shape=conditioner.input_shape, - output_event_shape=conditioner.output_shape, - n_predicted_parameters=transformer.n_parameters, + input_event_shape=torch.Size((coupling_mask.constant_event_size,)), + parameter_shape=torch.Size(transformer.parameter_shape), context_shape=context_shape, **kwargs ) - super().__init__(conditioner, transformer, conditioner_transform) + super().__init__(transformer, coupling_mask, conditioner_transform) class LinearAffineCoupling(AffineCoupling): @@ -172,68 +170,33 @@ def __init__(self, event_shape: torch.Size, **kwargs): super().__init__(event_shape, **kwargs, n_layers=1) -class AffineForwardMaskedAutoregressive(ForwardMaskedAutoregressiveBijection): +class AffineForwardMaskedAutoregressive(MaskedAutoregressiveBijection): def __init__(self, event_shape: torch.Size, context_shape: torch.Size = None, **kwargs): - transformer = 
Affine(event_shape=event_shape) - conditioner_transform = MADE( - input_event_shape=event_shape, - output_event_shape=event_shape, - n_predicted_parameters=transformer.n_parameters, - context_shape=context_shape, - **kwargs - ) - conditioner = MaskedAutoregressive() - super().__init__( - conditioner=conditioner, - transformer=transformer, - conditioner_transform=conditioner_transform - ) + transformer: ScalarTransformer = Affine(event_shape=event_shape) + super().__init__(event_shape, context_shape, transformer=transformer, **kwargs) -class RQSForwardMaskedAutoregressive(ForwardMaskedAutoregressiveBijection): +class RQSForwardMaskedAutoregressive(MaskedAutoregressiveBijection): def __init__(self, event_shape: torch.Size, context_shape: torch.Size = None, n_bins: int = 8, **kwargs): - transformer = RationalQuadratic(event_shape=event_shape, n_bins=n_bins) - conditioner_transform = MADE( - input_event_shape=event_shape, - output_event_shape=event_shape, - n_predicted_parameters=transformer.n_parameters, - context_shape=context_shape, - **kwargs - ) - conditioner = MaskedAutoregressive() - super().__init__( - conditioner=conditioner, - transformer=transformer, - conditioner_transform=conditioner_transform - ) + transformer: ScalarTransformer = RationalQuadratic(event_shape=event_shape, n_bins=n_bins) + super().__init__(event_shape, context_shape, transformer=transformer, **kwargs) + -class LRSForwardMaskedAutoregressive(ForwardMaskedAutoregressiveBijection): +class LRSForwardMaskedAutoregressive(MaskedAutoregressiveBijection): def __init__(self, event_shape: torch.Size, context_shape: torch.Size = None, n_bins: int = 8, **kwargs): - transformer = LinearRational(event_shape=event_shape, n_bins=n_bins) - conditioner_transform = MADE( - input_event_shape=event_shape, - output_event_shape=event_shape, - n_predicted_parameters=transformer.n_parameters, - context_shape=context_shape, - **kwargs - ) - conditioner = MaskedAutoregressive() - super().__init__( - conditioner=conditioner, - transformer=transformer, - conditioner_transform=conditioner_transform - ) + transformer: ScalarTransformer = LinearRational(event_shape=event_shape, n_bins=n_bins) + super().__init__(event_shape, context_shape, transformer=transformer, **kwargs) class AffineInverseMaskedAutoregressive(InverseMaskedAutoregressiveBijection): @@ -241,20 +204,8 @@ def __init__(self, event_shape: torch.Size, context_shape: torch.Size = None, **kwargs): - transformer = invert(Affine(event_shape=event_shape)) - conditioner_transform = MADE( - input_event_shape=event_shape, - output_event_shape=event_shape, - n_predicted_parameters=transformer.n_parameters, - context_shape=context_shape, - **kwargs - ) - conditioner = MaskedAutoregressive() - super().__init__( - conditioner=conditioner, - transformer=transformer, - conditioner_transform=conditioner_transform - ) + transformer: ScalarTransformer = invert(Affine(event_shape=event_shape)) + super().__init__(event_shape, context_shape, transformer=transformer, **kwargs) class RQSInverseMaskedAutoregressive(InverseMaskedAutoregressiveBijection): @@ -264,44 +215,20 @@ def __init__(self, n_bins: int = 8, **kwargs): assert n_bins >= 1 - transformer = RationalQuadratic(event_shape=event_shape, n_bins=n_bins) - conditioner_transform = MADE( - input_event_shape=event_shape, - output_event_shape=event_shape, - n_predicted_parameters=transformer.n_parameters, - context_shape=context_shape, - **kwargs - ) - conditioner = MaskedAutoregressive() - super().__init__( - conditioner=conditioner, - 
transformer=transformer, - conditioner_transform=conditioner_transform - ) + transformer: ScalarTransformer = RationalQuadratic(event_shape=event_shape, n_bins=n_bins) + super().__init__(event_shape, context_shape, transformer=transformer, **kwargs) -class UMNNMaskedAutoregressive(ForwardMaskedAutoregressiveBijection): +class UMNNMaskedAutoregressive(MaskedAutoregressiveBijection): def __init__(self, event_shape: torch.Size, context_shape: torch.Size = None, n_hidden_layers: int = 1, hidden_dim: int = 5, **kwargs): - transformer = UnconstrainedMonotonicNeuralNetwork( + transformer: ScalarTransformer = UnconstrainedMonotonicNeuralNetwork( event_shape=event_shape, n_hidden_layers=n_hidden_layers, hidden_dim=hidden_dim ) - conditioner_transform = MADE( - input_event_shape=event_shape, - output_event_shape=event_shape, - n_predicted_parameters=transformer.n_parameters, - context_shape=context_shape, - **kwargs - ) - conditioner = MaskedAutoregressive() - super().__init__( - conditioner=conditioner, - transformer=transformer, - conditioner_transform=conditioner_transform - ) + super().__init__(event_shape, context_shape, transformer=transformer, **kwargs) diff --git a/normalizing_flows/bijections/finite/autoregressive/layers_base.py b/normalizing_flows/bijections/finite/autoregressive/layers_base.py index beddbb6..83fa848 100644 --- a/normalizing_flows/bijections/finite/autoregressive/layers_base.py +++ b/normalizing_flows/bijections/finite/autoregressive/layers_base.py @@ -1,18 +1,24 @@ -from typing import Tuple +from typing import Tuple, Optional, Union import torch from normalizing_flows.bijections.finite.autoregressive.conditioners.base import Conditioner, NullConditioner -from normalizing_flows.bijections.finite.autoregressive.conditioner_transforms import ConditionerTransform, Constant -from normalizing_flows.bijections.finite.autoregressive.conditioners.coupling import Coupling -from normalizing_flows.bijections.finite.autoregressive.transformers.base import Transformer +from normalizing_flows.bijections.finite.autoregressive.conditioner_transforms import ConditionerTransform, Constant, \ + MADE +from normalizing_flows.bijections.finite.autoregressive.conditioners.coupling_masks import CouplingMask +from normalizing_flows.bijections.finite.autoregressive.transformers.base import TensorTransformer, ScalarTransformer from normalizing_flows.bijections.base import Bijection from normalizing_flows.utils import flatten_event, unflatten_event, get_batch_shape class AutoregressiveBijection(Bijection): - def __init__(self, conditioner: Conditioner, transformer: Transformer, conditioner_transform: ConditionerTransform): - super().__init__(event_shape=transformer.event_shape) + def __init__(self, + event_shape, + conditioner: Optional[Conditioner], + transformer: Union[TensorTransformer, ScalarTransformer], + conditioner_transform: ConditionerTransform, + **kwargs): + super().__init__(event_shape=event_shape) self.conditioner = conditioner self.conditioner_transform = conditioner_transform self.transformer = transformer @@ -29,30 +35,95 @@ def inverse(self, z: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch. class CouplingBijection(AutoregressiveBijection): - def __init__(self, conditioner: Coupling, transformer: Transformer, conditioner_transform: ConditionerTransform, + """ + Base coupling bijection object. + + A coupling bijection is defined using a transformer, conditioner transform, and always a coupling conditioner. + + The coupling conditioner receives as input an event tensor x. 
+ It then partitions an input event tensor x into a constant part x_A and a modifiable part x_B. + For x_A, the conditioner outputs a set of parameters which is always the same. + For x_B, the conditioner outputs a set of parameters which are predicted from x_A. + + Coupling conditioners differ in the partitioning method. By default, the event is flattened; the first half is x_A + and the second half is x_B. When using this in a normalizing flow, permutation layers can shuffle event dimensions. + + For improved performance, this implementation does not use a standalone coupling conditioner. It instead implements + a method to partition x into x_A and x_B and then predict parameters for x_B. + """ + + def __init__(self, + transformer: TensorTransformer, + coupling_mask: CouplingMask, + conditioner_transform: ConditionerTransform, **kwargs): - super().__init__(conditioner, transformer, conditioner_transform, **kwargs) + super().__init__(coupling_mask.event_shape, None, transformer, conditioner_transform, **kwargs) + self.coupling_mask = coupling_mask - # We need to change the transformer event shape because it will no longer accept full-shaped events, but only - # a flattened selection of event dimensions. - self.transformer.event_shape = torch.Size((self.conditioner.n_changed_dims,)) + assert conditioner_transform.input_event_shape == (coupling_mask.constant_event_size,) + assert transformer.event_shape == (self.coupling_mask.transformed_event_size,) + + def partition_and_predict_parameters(self, x: torch.Tensor, context: torch.Tensor): + """ + Partition tensor x and compute transformer parameters. + + :param x: input tensor with x.shape = (*batch_shape, *event_shape) to be partitioned into x_A and x_B. + :param context: context tensor with context.shape = (*batch_shape, *context.shape). + :return: parameter tensor h with h.shape = (*batch_shape, *parameter_shape). + """ + # Predict transformer parameters for output dimensions + x_a = x[..., self.coupling_mask.mask] # (*b, constant_event_size) + h_b = self.conditioner_transform(x_a, context=context) # (*b, *p) + return h_b def forward(self, x: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]: z = x.clone() - h, mask = self.conditioner(x, self.conditioner_transform, context, return_mask=True) - z[..., ~mask], log_det = self.transformer.forward(x[..., ~mask], h[..., ~mask, :]) + h_b = self.partition_and_predict_parameters(x, context) + z[..., ~self.coupling_mask.mask], log_det = self.transformer.forward(x[..., ~self.coupling_mask.mask], h_b) return z, log_det def inverse(self, z: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]: x = z.clone() - h, mask = self.conditioner(z, self.conditioner_transform, context, return_mask=True) - x[..., ~mask], log_det = self.transformer.inverse(z[..., ~mask], h[..., ~mask, :]) + h_b = self.partition_and_predict_parameters(x, context) + x[..., ~self.coupling_mask.mask], log_det = self.transformer.inverse(z[..., ~self.coupling_mask.mask], h_b) return x, log_det -class ForwardMaskedAutoregressiveBijection(AutoregressiveBijection): - def __init__(self, conditioner: Conditioner, transformer: Transformer, conditioner_transform: ConditionerTransform): - super().__init__(conditioner, transformer, conditioner_transform) +class MaskedAutoregressiveBijection(AutoregressiveBijection): + """ + Masked autoregressive bijection class. + + This bijection is specified with a scalar transformer. 
+ Its conditioner is always MADE, which receives as input a tensor x with x.shape = (*batch_shape, *event_shape). + MADE outputs parameters h for the scalar transformer with + h.shape = (*batch_shape, *event_shape, *parameter_shape_per_element). + The transformer then applies the bijection elementwise. + """ + + def __init__(self, + event_shape, + context_shape, + transformer: ScalarTransformer, + **kwargs): + conditioner_transform = MADE( + input_event_shape=event_shape, + output_event_shape=event_shape, + parameter_shape_per_element=transformer.parameter_shape_per_element, + context_shape=context_shape, + **kwargs + ) + super().__init__(transformer.event_shape, None, transformer, conditioner_transform) + + def apply_conditioner_transformer(self, inputs, context, forward: bool = True): + h = self.conditioner_transform(inputs, context) + if forward: + outputs, log_det = self.transformer.forward(inputs, h) + else: + outputs, log_det = self.transformer.inverse(inputs, h) + return outputs, log_det + + def forward(self, x: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]: + return self.apply_conditioner_transformer(x, context, True) def inverse(self, z: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]: batch_shape = get_batch_shape(z, self.event_shape) @@ -61,34 +132,35 @@ def inverse(self, z: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch. x_flat = flatten_event(torch.clone(z), self.event_shape) for i in torch.arange(n_event_dims): x_clone = unflatten_event(torch.clone(x_flat), self.event_shape) - h = self.conditioner( - x_clone, - transform=self.conditioner_transform, - context=context - ) - tmp, log_det = self.transformer.inverse(x_clone, h) + tmp, log_det = self.apply_conditioner_transformer(x_clone, context, False) x_flat[..., i] = flatten_event(tmp, self.event_shape)[..., i] x = unflatten_event(x_flat, self.event_shape) return x, log_det -class InverseMaskedAutoregressiveBijection(AutoregressiveBijection): - def __init__(self, conditioner: Conditioner, transformer: Transformer, conditioner_transform: ConditionerTransform): - super().__init__(conditioner, transformer, conditioner_transform) - self.forward_layer = ForwardMaskedAutoregressiveBijection( - conditioner, - transformer, - conditioner_transform - ) +class InverseMaskedAutoregressiveBijection(MaskedAutoregressiveBijection): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) def forward(self, x: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]: - return self.forward_layer.inverse(x, context=context) + return super().inverse(x, context=context) def inverse(self, z: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]: - return self.forward_layer.forward(z, context=context) + return super().forward(z, context=context) class ElementwiseBijection(AutoregressiveBijection): - def __init__(self, transformer: Transformer, n_transformer_parameters: int): - super().__init__(NullConditioner(), transformer, Constant(transformer.event_shape, n_transformer_parameters)) - # TODO override forward and inverse to save on space + """ + Base elementwise bijection class. + + Applies a bijective transformation to each element of the input tensor. + The bijection for each element has its own set of globally learned parameters. 
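+
+    For example (up to the exact parametrization of the Affine transformer), ElementwiseAffine applies
+    z_i = a_i * x_i + b_i with one globally learned pair (a_i, b_i) per event element:
+
+        bijection = ElementwiseAffine(event_shape=(3,))
+        z, log_det = bijection.forward(torch.randn(100, 3))  # z.shape == (100, 3); one log_det per batch element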
+ """ + + def __init__(self, transformer: ScalarTransformer, fill_value: float = None): + super().__init__( + transformer.event_shape, + NullConditioner(), + transformer, + Constant(transformer.event_shape, transformer.parameter_shape, fill_value=fill_value) + ) diff --git a/normalizing_flows/bijections/finite/autoregressive/transformers/base.py b/normalizing_flows/bijections/finite/autoregressive/transformers/base.py index b23ecdf..a24701d 100644 --- a/normalizing_flows/bijections/finite/autoregressive/transformers/base.py +++ b/normalizing_flows/bijections/finite/autoregressive/transformers/base.py @@ -5,28 +5,77 @@ from normalizing_flows.bijections.base import Bijection -class Transformer(Bijection): +class TensorTransformer(Bijection): + """ + Base transformer class. + + A transformer receives as input a tensor x with x.shape = (*batch_shape, *event_shape) and parameters h + with h.shape = (*batch_shape, *parameter_shape). It applies a bijective map to each tensor in the batch + with its corresponding parameter set. In general, the parameters are used to transform the entire tensor at + once. As a special case, the subclass ScalarTransformer transforms each element of an input event + individually. + """ + def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]]): super().__init__(event_shape=event_shape) def forward(self, x: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Applies the forward transformation. + + :param torch.Tensor x: input tensor with shape (*batch_shape, *event_shape). + :param torch.Tensor h: parameter tensor with shape (*batch_shape, *parameter_shape). + :returns: output tensor with shape (*batch_shape, *event_shape). + """ raise NotImplementedError def inverse(self, x: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Applies the inverse transformation. + + :param torch.Tensor x: input tensor with shape (*batch_shape, *event_shape). + :param torch.Tensor h: parameter tensor with shape (*batch_shape, *parameter_shape). + :returns: output tensor with shape (*batch_shape, *event_shape). + """ + raise NotImplementedError + + @property + def parameter_shape(self) -> Union[torch.Size, Tuple[int, ...]]: raise NotImplementedError @property def n_parameters(self) -> int: + return int(torch.prod(torch.as_tensor(self.parameter_shape))) + + @property + def default_parameters(self) -> torch.Tensor: """ - Number of parameters that parametrize this transformer. Example: rational quadratic splines require 3*b-1 where - b is the number of bins. An affine transformation requires 2 (typically corresponding to the unconstrained scale - and shift). + Set of parameters which ensures an identity transformation. """ raise NotImplementedError + +class ScalarTransformer(TensorTransformer): + def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]]): + super().__init__(event_shape) + @property - def default_parameters(self) -> torch.Tensor: + def parameter_shape_per_element(self): """ - Set of parameters which ensures an identity transformation. + The shape of parameters that transform a single element of an input tensor. + + Example: + * if using an affine transformer, this is equal to (2,) (corresponding to scale and shift). + * if using a rational quadratic spline transformer, this is equal to (3 * b - 1,) where b is the + number of bins. 
""" raise NotImplementedError + + @property + def n_parameters_per_element(self): + return int(torch.prod(torch.as_tensor(self.parameter_shape_per_element))) + + @property + def parameter_shape(self) -> Union[torch.Size, Tuple[int, ...]]: + # Scalar transformers map each element individually, so the first dimensions are the event shape + return torch.Size((*self.event_shape, *self.parameter_shape_per_element)) diff --git a/normalizing_flows/bijections/finite/autoregressive/transformers/combination/base.py b/normalizing_flows/bijections/finite/autoregressive/transformers/combination/base.py index fb431ed..661fda2 100644 --- a/normalizing_flows/bijections/finite/autoregressive/transformers/combination/base.py +++ b/normalizing_flows/bijections/finite/autoregressive/transformers/combination/base.py @@ -1,23 +1,23 @@ import torch -from typing import Tuple, List +from typing import Tuple, List, Union -from normalizing_flows.bijections.finite.autoregressive.transformers.base import Transformer +from normalizing_flows.bijections.finite.autoregressive.transformers.base import ScalarTransformer from normalizing_flows.utils import get_batch_shape -class Combination(Transformer): - def __init__(self, event_shape: torch.Size, components: List[Transformer]): +class Combination(ScalarTransformer): + def __init__(self, event_shape: torch.Size, components: List[ScalarTransformer]): super().__init__(event_shape) self.components = components self.n_components = len(self.components) @property - def n_parameters(self) -> int: - return sum([c.n_parameters for c in self.components]) + def parameter_shape_per_element(self) -> Union[torch.Size, Tuple[int, ...]]: + return (sum([c.n_parameters_per_element for c in self.components]),) @property def default_parameters(self) -> torch.Tensor: - return torch.cat([c.default_parameters for c in self.components], dim=0) + return torch.cat([c.default_parameters.ravel() for c in self.components], dim=0) def forward(self, x: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: # h.shape = (*batch_size, *event_shape, n_components * n_output_parameters) @@ -27,9 +27,9 @@ def forward(self, x: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch start_index = 0 for i in range(self.n_components): component = self.components[i] - x, log_det_increment = component.forward(x, h[..., start_index:start_index + component.n_parameters]) + x, log_det_increment = component.forward(x, h[..., start_index:start_index + component.n_parameters_per_element]) log_det += log_det_increment - start_index += component.n_parameters + start_index += component.n_parameters_per_element z = x return z, log_det @@ -38,11 +38,11 @@ def inverse(self, z: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch # We assume last dim is ordered as [c1, c2, ..., ck] i.e. sequence of parameter vectors, one for each component. 
batch_shape = get_batch_shape(z, self.event_shape) log_det = torch.zeros(size=batch_shape) - c = self.n_parameters + c = self.n_parameters_per_element for i in range(self.n_components): component = self.components[self.n_components - i - 1] - c -= component.n_parameters - z, log_det_increment = component.inverse(z, h[..., c:c + component.n_parameters]) + c -= component.n_parameters_per_element + z, log_det_increment = component.inverse(z, h[..., c:c + component.n_parameters_per_element]) log_det += log_det_increment x = z return x, log_det diff --git a/normalizing_flows/bijections/finite/autoregressive/transformers/combination/sigmoid.py b/normalizing_flows/bijections/finite/autoregressive/transformers/combination/sigmoid.py index 43f259b..e470def 100644 --- a/normalizing_flows/bijections/finite/autoregressive/transformers/combination/sigmoid.py +++ b/normalizing_flows/bijections/finite/autoregressive/transformers/combination/sigmoid.py @@ -2,7 +2,7 @@ from typing import Tuple, Union, List import torch import torch.nn as nn -from normalizing_flows.bijections.finite.autoregressive.transformers.base import Transformer +from normalizing_flows.bijections.finite.autoregressive.transformers.base import ScalarTransformer from normalizing_flows.bijections.finite.autoregressive.transformers.combination.base import Combination from normalizing_flows.bijections.finite.autoregressive.transformers.combination.sigmoid_util import log_softmax, \ log_sigmoid, log_dot @@ -19,7 +19,7 @@ def inverse_sigmoid(p): return torch.log(p) - torch.log1p(-p) -class Sigmoid(Transformer): +class Sigmoid(ScalarTransformer): """ Applies z = inv_sigmoid(w.T @ sigmoid(a * x + b)) where a > 0, w > 0 and sum(w) = 1. Note: w, a, b are vectors, so multiplication a * x is broadcast. @@ -40,12 +40,8 @@ def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], super().__init__(event_shape) @property - def n_parameters(self) -> int: - return 3 * self.hidden_dim - - @property - def default_parameters(self) -> torch.Tensor: - return torch.zeros(size=(self.n_parameters,)) + def parameter_shape_per_element(self) -> Union[torch.Size, Tuple[int, ...]]: + return (3 * self.hidden_dim,) def extract_parameters(self, h: torch.Tensor): """ @@ -212,7 +208,7 @@ def forward_1d(self, x, h, eps: float = 1e-6): return z, log_det.view(*x.shape[:2]) -class DenseSigmoid(Transformer): +class DenseSigmoid(ScalarTransformer): """ Apply y = f1 \\circ f2 \\circ ... 
\\circ fn (x) where * f1 is a dense sigmoid inner transform which maps from 1 to h dimensions; @@ -234,12 +230,12 @@ def __init__(self, self.layers = nn.ModuleList(layers) @property - def n_parameters(self) -> int: - return sum([layer.n_parameters for layer in self.layers]) + def parameter_shape_per_element(self) -> Union[torch.Size, Tuple[int, ...]]: + return (sum([layer.n_parameters for layer in self.layers]),) @property def default_parameters(self) -> torch.Tensor: - return torch.zeros(size=(self.n_parameters,)) # TODO set up parametrization with deltas so this holds + return torch.zeros(size=self.parameter_shape) # TODO set up parametrization with deltas so this holds def split_parameters(self, h): # split parameters h into parameters for several layers diff --git a/normalizing_flows/bijections/finite/autoregressive/transformers/integration/base.py b/normalizing_flows/bijections/finite/autoregressive/transformers/integration/base.py index 56ab9fc..af2cc78 100644 --- a/normalizing_flows/bijections/finite/autoregressive/transformers/integration/base.py +++ b/normalizing_flows/bijections/finite/autoregressive/transformers/integration/base.py @@ -3,12 +3,12 @@ import torch -from normalizing_flows.bijections.finite.autoregressive.transformers.base import Transformer +from normalizing_flows.bijections.finite.autoregressive.transformers.base import ScalarTransformer from normalizing_flows.bijections.numerical_inversion import bisection from normalizing_flows.utils import get_batch_shape, sum_except_batch -class Integration(Transformer): +class Integration(ScalarTransformer): def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], bound: float = 100.0, eps: float = 1e-6): """ :param bound: specifies the initial interval [-bound, bound] where numerical inversion is performed. 
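+        For example, with event_shape = (3,) and a transformer using k parameters per element, forward()
+        expects h with shape (*batch_shape, 3, k) and flattens it to (-1, k) before calling forward_1d().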
@@ -61,16 +61,16 @@ def inverse_1d(self, z: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, to def forward(self, x: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """ x.shape = (*batch_shape, *event_shape) - h.shape = (*batch_shape, *event_shape, n_parameters) + h.shape = (*batch_shape, *parameter_shape) """ - z_flat, log_det_flat = self.forward_1d(x.view(-1), h.view(-1, self.n_parameters)) + z_flat, log_det_flat = self.forward_1d(x.view(-1), h.view(-1, self.n_parameters_per_element)) z = z_flat.view_as(x) batch_shape = get_batch_shape(x, self.event_shape) log_det = sum_except_batch(log_det_flat.view(*batch_shape, *self.event_shape), self.event_shape) return z, log_det def inverse(self, z: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - x_flat, log_det_flat = self.inverse_1d(z.view(-1), h.view(-1, self.n_parameters)) + x_flat, log_det_flat = self.inverse_1d(z.view(-1), h.view(-1, self.n_parameters_per_element)) x = x_flat.view_as(z) batch_shape = get_batch_shape(z, self.event_shape) log_det = sum_except_batch(log_det_flat.view(*batch_shape, *self.event_shape), self.event_shape) diff --git a/normalizing_flows/bijections/finite/autoregressive/transformers/integration/unconstrained_monotonic_neural_network.py b/normalizing_flows/bijections/finite/autoregressive/transformers/integration/unconstrained_monotonic_neural_network.py index c0239b8..3036dc7 100644 --- a/normalizing_flows/bijections/finite/autoregressive/transformers/integration/unconstrained_monotonic_neural_network.py +++ b/normalizing_flows/bijections/finite/autoregressive/transformers/integration/unconstrained_monotonic_neural_network.py @@ -24,6 +24,12 @@ def base_forward_1d(self, x: torch.Tensor, params: List[torch.Tensor]) -> Tuple[ class UnconstrainedMonotonicNeuralNetwork(UnconstrainedMonotonicTransformer): + """ + Unconstrained monotonic neural network transformer. + + The unconstrained monotonic neural network is a neural network with positive weights and positive activation + function derivatives. These two conditions ensure its invertibility. 
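+
+    In this implementation, the forward map integrates a neural network whose output is kept strictly
+    positive (1 + ELU is applied to the final layer output), so the transformation is strictly increasing
+    in its input and therefore invertible; the log-determinant is log g(x), where g denotes the integrand
+    network.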
+ """ def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], n_hidden_layers: int = 2, @@ -44,42 +50,40 @@ def __init__(self, # weight is a square matrix, bias is a vector self.n_hidden_params = (self.hidden_dim ** 2 + self.hidden_dim) * self.n_hidden_layers - self._sampled_default_params = torch.randn(size=(self.n_parameters,)) / 1000 + self._sampled_default_params = torch.randn(size=(self.n_dim, *self.parameter_shape_per_element)) / 1000 @property - def n_parameters(self) -> int: - return self.n_input_params + self.n_output_params + self.n_hidden_params + def parameter_shape_per_element(self) -> Union[torch.Size, Tuple]: + return (self.n_input_params + self.n_output_params + self.n_hidden_params,) @property def default_parameters(self) -> torch.Tensor: return self._sampled_default_params def compute_parameters(self, h: torch.Tensor): - batch_shape = h.shape[:-1] p0 = self.default_parameters + batch_size = h.shape[0] + n_events = batch_size // p0.shape[0] # Input layer - input_layer_defaults = pad_leading_dims(p0[:self.n_input_params], len(h.shape) - 1) + input_layer_defaults = p0[..., :self.n_input_params].repeat(n_events, 1) input_layer_deltas = h[..., :self.n_input_params] / self.const input_layer_params = input_layer_defaults + input_layer_deltas - input_layer_params = input_layer_params.view(*batch_shape, self.hidden_dim, 2) + input_layer_params = input_layer_params.view(batch_size, self.hidden_dim, 2) # Output layer - output_layer_defaults = pad_leading_dims(p0[-self.n_output_params:], len(h.shape) - 1) + output_layer_defaults = p0[..., -self.n_output_params:].repeat(n_events, 1) output_layer_deltas = h[..., -self.n_output_params:] / self.const output_layer_params = output_layer_defaults + output_layer_deltas - output_layer_params = output_layer_params.view(*batch_shape, 1, self.hidden_dim + 1) + output_layer_params = output_layer_params.view(batch_size, 1, self.hidden_dim + 1) # Hidden layers - hidden_layer_defaults = pad_leading_dims( - p0[self.n_input_params:self.n_input_params + self.n_hidden_params], - len(h.shape) - 1 - ) + hidden_layer_defaults = p0[..., self.n_input_params:self.n_input_params + self.n_hidden_params].repeat(n_events, 1) hidden_layer_deltas = h[..., self.n_input_params:self.n_input_params + self.n_hidden_params] / self.const hidden_layer_params = hidden_layer_defaults + hidden_layer_deltas hidden_layer_params = torch.chunk(hidden_layer_params, chunks=self.n_hidden_layers, dim=-1) hidden_layer_params = [ - layer.view(*batch_shape, self.hidden_dim, self.hidden_dim + 1) + layer.view(batch_size, self.hidden_dim, self.hidden_dim + 1) for layer in hidden_layer_params ] return [input_layer_params, *hidden_layer_params, output_layer_params] @@ -112,28 +116,18 @@ def neural_network_forward(inputs, parameters: List[torch.Tensor]): out = 1 + torch.nn.functional.elu(out) return out - @staticmethod - def reshape_tensors(x: torch.Tensor, h: List[torch.Tensor]): - # batch_shape = get_batch_shape(x, self.event_shape) - # batch_dims = int(torch.as_tensor(batch_shape).prod()) - # event_dims = int(torch.as_tensor(self.event_shape).prod()) - flattened_dim = int(torch.as_tensor(x.shape).prod()) - x_r = x.view(flattened_dim, 1, 1) - h_r = [p.view(flattened_dim, *p.shape[-2:]) for p in h] - return x_r, h_r - def base_forward_1d(self, x: torch.Tensor, params: List[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]: - x_r, p_r = self.reshape_tensors(x, params) - integral_flat = self.integral(x_r, p_r) - log_det_flat = self.g(x_r, p_r).log() # We can apply log since g is 
always positive + x_r = x.view(-1, 1, 1) + integral_flat = self.integral(x_r, params) + log_det_flat = self.g(x_r, params).log() # We can apply log since g is always positive output = integral_flat.view_as(x) log_det = log_det_flat.view_as(x) return output, log_det def inverse_1d(self, z: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: params = self.compute_parameters(h) - z_r, p_r = self.reshape_tensors(z, params) - x_flat = self.inverse_1d_without_log_det(z_r, p_r) + z_r = z.view(-1, 1, 1) + x_flat = self.inverse_1d_without_log_det(z_r, params) outputs = x_flat.view_as(z) - log_det = -self.g(x_flat, p_r).log().view_as(z) + log_det = -self.g(x_flat, params).log().view_as(z) return outputs, log_det diff --git a/normalizing_flows/bijections/finite/autoregressive/transformers/integration.py b/normalizing_flows/bijections/finite/autoregressive/transformers/linear/__init__.py similarity index 100% rename from normalizing_flows/bijections/finite/autoregressive/transformers/integration.py rename to normalizing_flows/bijections/finite/autoregressive/transformers/linear/__init__.py diff --git a/normalizing_flows/bijections/finite/autoregressive/transformers/affine.py b/normalizing_flows/bijections/finite/autoregressive/transformers/linear/affine.py similarity index 91% rename from normalizing_flows/bijections/finite/autoregressive/transformers/affine.py rename to normalizing_flows/bijections/finite/autoregressive/transformers/linear/affine.py index 4e0b2ba..9b2bb2b 100644 --- a/normalizing_flows/bijections/finite/autoregressive/transformers/affine.py +++ b/normalizing_flows/bijections/finite/autoregressive/transformers/linear/affine.py @@ -3,11 +3,11 @@ import torch -from normalizing_flows.bijections.finite.autoregressive.transformers.base import Transformer +from normalizing_flows.bijections.finite.autoregressive.transformers.base import ScalarTransformer from normalizing_flows.utils import get_batch_shape, sum_except_batch -class Affine(Transformer): +class Affine(ScalarTransformer): """ Affine transformer. @@ -23,14 +23,12 @@ def __init__(self, event_shape: torch.Size, min_scale: float = 1e-3): self.const = 2 @property - def n_parameters(self) -> int: - return 2 + def parameter_shape_per_element(self): + return (2,) @property def default_parameters(self) -> torch.Tensor: - default_u_alpha = torch.zeros(size=(1,)) - default_u_beta = torch.zeros(size=(1,)) - return torch.cat([default_u_alpha, default_u_beta], dim=0) + return torch.zeros(self.parameter_shape) def forward(self, x: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: u_alpha = h[..., 0] @@ -55,7 +53,7 @@ def inverse(self, z: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch return (z - beta) / alpha, log_det -class Affine2(Transformer): +class Affine2(ScalarTransformer): """ Affine transformer with near-identity initialization. 
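The hunks above and below replace n_parameters with parameter_shape_per_element on scalar transformers, and the conditioner now supplies h with shape (*batch_shape, *parameter_shape) rather than (*batch_shape, *event_shape, n_parameters). A rough usage sketch for the relocated Affine transformer follows; it assumes, as the updated tests later in this diff do, that ScalarTransformer exposes a parameter_shape property equal to (*event_shape, *parameter_shape_per_element).

import torch

from normalizing_flows.bijections.finite.autoregressive.transformers.linear.affine import Affine

affine = Affine(event_shape=(3,))
x = torch.randn(8, 3)                        # (*batch_shape, *event_shape)
h = torch.randn(8, *affine.parameter_shape)  # (8, 3, 2): one unconstrained (scale, shift) pair per element
z, log_det = affine.forward(x, h)            # log_det has shape (8,)
x_reconstructed, _ = affine.inverse(z, h)
assert torch.allclose(x, x_reconstructed, atol=1e-3)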
@@ -110,17 +108,17 @@ def inverse(self, z: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch return (z - beta) / alpha, log_det -class Shift(Transformer): +class Shift(ScalarTransformer): def __init__(self, event_shape: torch.Size): super().__init__(event_shape=event_shape) @property - def n_parameters(self) -> int: - return 1 + def parameter_shape_per_element(self): + return (1,) @property def default_parameters(self) -> torch.Tensor: - return torch.zeros(size=(1,)) + return torch.zeros(self.parameter_shape) def forward(self, x: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: beta = h[..., 0] @@ -135,7 +133,7 @@ def inverse(self, z: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch return z - beta, log_det -class Scale(Transformer): +class Scale(ScalarTransformer): """ Scaling transformer. @@ -150,11 +148,12 @@ def __init__(self, event_shape: torch.Size, min_scale: float = 1e-3): self.u_alpha_1 = math.log(1 - self.m) @property - def n_parameters(self) -> int: - return 1 + def parameter_shape_per_element(self): + return (1,) + @property def default_parameters(self) -> torch.Tensor: - return torch.zeros(size=(1,)) + return torch.zeros(self.parameter_shape) def forward(self, x: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: u_alpha = h[..., 0] diff --git a/normalizing_flows/bijections/finite/autoregressive/transformers/linear/convolution.py b/normalizing_flows/bijections/finite/autoregressive/transformers/linear/convolution.py new file mode 100644 index 0000000..c649a61 --- /dev/null +++ b/normalizing_flows/bijections/finite/autoregressive/transformers/linear/convolution.py @@ -0,0 +1,53 @@ +from typing import Union, Tuple +import torch + +from normalizing_flows.bijections import LU +from normalizing_flows.bijections.finite.autoregressive.transformers.base import TensorTransformer +from normalizing_flows.bijections.finite.autoregressive.transformers.linear.matrix import LUTransformer +from normalizing_flows.utils import sum_except_batch, get_batch_shape + + +class Invertible1x1Convolution(TensorTransformer): + """ + Invertible 1x1 convolution. + + This transformer receives as input a batch of images x with x.shape (*batch_shape, *image_dimensions, channels) and + parameters h for an invertible linear transform of the channels + with h.shape = (*batch_shape, *image_dimensions, *parameter_shape). + Note that image_dimensions can be a shape with arbitrarily ordered dimensions (height, width). + In fact, it is not required that the image is two-dimensional. Voxels with shape (height, width, depth, channels) + are also supported, as well as tensors with more general shapes. 
+ """ + + def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]]): + super().__init__(event_shape) + *self.image_dimensions, self.n_channels = event_shape + self.invertible_linear: TensorTransformer = LUTransformer(event_shape=(self.n_channels,)) + + @property + def parameter_shape(self) -> Union[torch.Size, Tuple[int, ...]]: + return self.invertible_linear.parameter_shape + + @property + def default_parameters(self) -> torch.Tensor: + return self.invertible_linear.default_parameters + + def apply_linear(self, inputs: torch.Tensor, h: torch.Tensor, forward: bool): + batch_shape = get_batch_shape(inputs, self.event_shape) + + # Apply linear transformation along channel dimension + if forward: + outputs, log_det = self.invertible_linear.forward(inputs, h) + else: + outputs, log_det = self.invertible_linear.inverse(inputs, h) + log_det = sum_except_batch( + log_det.view(*batch_shape, *self.image_dimensions), + event_shape=self.image_dimensions + ) + return outputs, log_det + + def forward(self, x: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + return self.apply_linear(x, h, forward=True) + + def inverse(self, z: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + return self.apply_linear(z, h, forward=False) diff --git a/normalizing_flows/bijections/finite/autoregressive/transformers/linear/matrix.py b/normalizing_flows/bijections/finite/autoregressive/transformers/linear/matrix.py new file mode 100644 index 0000000..a41c7e5 --- /dev/null +++ b/normalizing_flows/bijections/finite/autoregressive/transformers/linear/matrix.py @@ -0,0 +1,99 @@ +from typing import Union, Tuple + +import torch + +from normalizing_flows.bijections.finite.autoregressive.transformers.base import TensorTransformer +from normalizing_flows.utils import flatten_event, unflatten_event + + +# Matrix transformers that operate on vector inputs (Ax=b) + +class LUTransformer(TensorTransformer): + """Linear transformer with LUx = y. + + It is assumed that all diagonal elements of L are 1. + """ + + def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]]): + super().__init__(event_shape) + + def extract_matrices(self, h: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Extract matrices L, U from tensor h. + + :param torch.Tensor h: parameter tensor with shape (*batch_shape, *parameter_shape) + :returns: tuple with (L, U, log(diag(U))). L and U have shapes (*batch_shape, event_size, event_size), + log(diag(U)) has shape (*batch_shape, event_size). 
+ """ + event_size = int(torch.prod(torch.as_tensor(self.event_shape))) + n_off_diag_el = (event_size ** 2 - event_size) // 2 + + u_unc_diag = h[..., :event_size] + u_diag = torch.exp(u_unc_diag) / 10 + 1 + u_log_diag = torch.log(u_diag) + + u_off_diagonal_elements = h[..., event_size:event_size + n_off_diag_el] / 10 + l_off_diagonal_elements = h[..., -n_off_diag_el:] / 10 + + batch_shape = h.shape[:-len(self.parameter_shape)] + + upper = torch.zeros(size=(*batch_shape, event_size, event_size)) + upper_row_index, upper_col_index = torch.triu_indices(row=event_size, col=event_size, offset=1) + upper[..., upper_row_index, upper_col_index] = u_off_diagonal_elements + upper[..., range(event_size), range(event_size)] = u_diag + + lower = torch.zeros(size=(*batch_shape, event_size, event_size)) + lower_row_index, lower_col_index = torch.tril_indices(row=event_size, col=event_size, offset=-1) + lower[..., lower_row_index, lower_col_index] = l_off_diagonal_elements + lower[..., range(event_size), range(event_size)] = 1 # Unit diagonal + + return lower, upper, u_log_diag + + @staticmethod + def log_determinant(upper_log_diag: torch.Tensor): + """ + Computes the matrix log determinant of A = LU for each pair of matrices in a batch. + + Note: det(A) = det(LU) = det(L) * det(U) so log det(A) = log det(L) + log det(U). + We assume that L has unit diagonal, so log det(L) = 0 and can be skipped. + + :param torch.Tensor upper_log_diag: log diagonals of matrices U with shape (*batch_size, event_size). + :returns: log determinants of LU with shape (*batch_size,). + """ + # Extract the diagonals + return torch.sum(upper_log_diag, dim=-1) + + def forward(self, x: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + lower, upper, upper_log_diag = self.extract_matrices(h) + + # Flatten inputs + x_flat = flatten_event(x, self.event_shape) # (*batch_shape, event_size) + y_flat = torch.einsum('...ij,...jk,...k->...i', lower, upper, x_flat) # y = LUx + + output = unflatten_event(y_flat, self.event_shape) + return output, self.log_determinant(upper_log_diag) + + def inverse(self, y: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + lower, upper, upper_log_diag = self.extract_matrices(h) + + # Flatten inputs + y_flat = flatten_event(y, self.event_shape)[..., None] # (*batch_shape, event_size) + z_flat = torch.linalg.solve_triangular(lower, y_flat, upper=False, unitriangular=True) # y = Lz => z = L^{-1}y + x_flat = torch.linalg.solve_triangular(upper, z_flat, upper=True, unitriangular=False) # z = Ux => x = U^{-1}z + x_flat = x_flat.squeeze(-1) + + output = unflatten_event(x_flat, self.event_shape) + return output, -self.log_determinant(upper_log_diag) + + @property + def parameter_shape(self) -> Union[torch.Size, Tuple[int, ...]]: + event_size = int(torch.prod(torch.as_tensor(self.event_shape))) + # Let n be the event size + # L will have (n^2 - n) / 2 parameters (we assume unit diagonal) + # U will have (n^2 - n) / 2 + n parameters + n_off_diag_el = (event_size ** 2 - event_size) // 2 + return (event_size + n_off_diag_el + n_off_diag_el,) + + @property + def default_parameters(self) -> torch.Tensor: + return torch.zeros(size=self.parameter_shape) diff --git a/normalizing_flows/bijections/finite/autoregressive/transformers/spline/base.py b/normalizing_flows/bijections/finite/autoregressive/transformers/spline/base.py index 2caa2dc..90dadac 100644 --- a/normalizing_flows/bijections/finite/autoregressive/transformers/spline/base.py +++ 
b/normalizing_flows/bijections/finite/autoregressive/transformers/spline/base.py @@ -2,11 +2,11 @@ import torch -from normalizing_flows.bijections.finite.autoregressive.transformers.base import Transformer +from normalizing_flows.bijections.finite.autoregressive.transformers.base import ScalarTransformer from normalizing_flows.utils import sum_except_batch -class MonotonicSpline(Transformer): +class MonotonicSpline(ScalarTransformer): def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], min_input: float = -1.0, @@ -23,7 +23,7 @@ def __init__(self, self.n_knots = n_bins + 1 @property - def n_parameters(self) -> int: + def parameter_shape_per_element(self) -> torch.Size: raise NotImplementedError def forward_inputs_inside_bounds_mask(self, x): diff --git a/normalizing_flows/bijections/finite/autoregressive/transformers/spline/cubic.py b/normalizing_flows/bijections/finite/autoregressive/transformers/spline/cubic.py index e7d03c0..12f7cca 100644 --- a/normalizing_flows/bijections/finite/autoregressive/transformers/spline/cubic.py +++ b/normalizing_flows/bijections/finite/autoregressive/transformers/spline/cubic.py @@ -13,12 +13,8 @@ def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], n_bins: int self.const = 1000 @property - def n_parameters(self) -> int: - return 2 * self.n_bins + 2 - - @property - def default_parameters(self) -> torch.Tensor: - return torch.zeros(size=(self.n_parameters,)) + def parameter_shape_per_element(self) -> torch.Size: + return torch.Size((2 * self.n_bins + 2,)) def compute_spline_parameters(self, knots_x: torch.Tensor, knots_y: torch.Tensor, idx: torch.Tensor): # knots_x.shape == (n, n_knots) diff --git a/normalizing_flows/bijections/finite/autoregressive/transformers/spline/linear.py b/normalizing_flows/bijections/finite/autoregressive/transformers/spline/linear.py index d2ad3d0..e9b0d17 100644 --- a/normalizing_flows/bijections/finite/autoregressive/transformers/spline/linear.py +++ b/normalizing_flows/bijections/finite/autoregressive/transformers/spline/linear.py @@ -35,12 +35,8 @@ def compute_bin_y(self, delta): return cs * (self.max_output - self.min_output) + self.min_output @property - def n_parameters(self) -> int: - return self.n_bins - - @property - def default_parameters(self) -> torch.Tensor: - return torch.zeros(size=(self.n_bins,)) + def parameter_shape_per_element(self) -> torch.Size: + return torch.Size((self.n_bins,)) def forward_1d(self, x, h): assert len(x.shape) == 1 diff --git a/normalizing_flows/bijections/finite/autoregressive/transformers/spline/linear_rational.py b/normalizing_flows/bijections/finite/autoregressive/transformers/spline/linear_rational.py index 300dd9f..a80e887 100644 --- a/normalizing_flows/bijections/finite/autoregressive/transformers/spline/linear_rational.py +++ b/normalizing_flows/bijections/finite/autoregressive/transformers/spline/linear_rational.py @@ -16,23 +16,16 @@ def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], boundary: fl max_output=boundary, **kwargs ) - self.min_bin_width = 1e-3 - self.min_bin_height = 1e-3 + self.min_bin_width = 1e-2 + self.min_bin_height = 1e-2 self.min_d = 1e-5 self.const = math.log(math.exp(1 - self.min_d) - 1) # to ensure identity initialization + self.eps = 5e-10 # Epsilon for numerical stability when computing forward/inverse @property - def n_parameters(self) -> int: - return 4 * self.n_bins + def parameter_shape_per_element(self) -> torch.Size: + return torch.Size((4 * self.n_bins,)) - @property - def default_parameters(self) -> torch.Tensor: - default_u_x
= torch.zeros(size=(self.n_bins,)) - default_u_y = torch.zeros(size=(self.n_bins,)) - default_u_lambda = torch.zeros(size=(self.n_bins,)) - default_u_d = torch.zeros(size=(self.n_bins - 1,)) - default_u_w0 = torch.zeros(size=(1,)) - return torch.cat([default_u_x, default_u_y, default_u_lambda, default_u_d, default_u_w0], dim=0) def compute_parameters(self, idx, knots_x, knots_y, knots_d, knots_lambda, u_w0): assert knots_x.shape == knots_y.shape == knots_d.shape @@ -120,7 +113,7 @@ def forward_1d(self, x, h): ) log_det_phi_lt_lambda = ( torch.log(lambda_k * w_k * w_m * (y_m - y_k)) - - 2 * torch.log(w_k * (lambda_k - phi) + w_m * phi) + - torch.log((w_k * (lambda_k - phi) + w_m * phi) ** 2 + self.eps) - torch.log(x_kp1 - x_k) ) @@ -130,7 +123,7 @@ def forward_1d(self, x, h): ) log_det_phi_gt_lambda = ( torch.log((1 - lambda_k) * w_m * w_kp1 * (y_kp1 - y_m)) - - 2 * torch.log(w_m * (1 - phi) + w_kp1 * (phi - lambda_k)) + - torch.log((w_m * (1 - phi) + w_kp1 * (phi - lambda_k)) ** 2 + self.eps) - torch.log(x_kp1 - x_k) ) @@ -166,7 +159,7 @@ def inverse_1d(self, z, h): ) * (x_kp1 - x_k) + x_k log_det_y_lt_ym = ( torch.log(lambda_k * w_k * w_m * (y_m - y_k)) - - torch.log((w_k * (y_k - z) + w_m * (z - y_m)) ** 2) + - torch.log((w_k * (y_k - z) + w_m * (z - y_m)) ** 2 + self.eps) + torch.log(x_kp1 - x_k) ) @@ -176,7 +169,7 @@ def inverse_1d(self, z, h): ) * (x_kp1 - x_k) + x_k log_det_y_gt_ym = ( torch.log((1 - lambda_k) * w_m * w_kp1 * (y_kp1 - y_m)) - - 2 * torch.log(w_kp1 * (y_kp1 - z) + w_m * (z - y_m)) + - torch.log((w_kp1 * (y_kp1 - z) + w_m * (z - y_m)) ** 2 + self.eps) + torch.log(x_kp1 - x_k) ) diff --git a/normalizing_flows/bijections/finite/autoregressive/transformers/spline/rational_quadratic.py b/normalizing_flows/bijections/finite/autoregressive/transformers/spline/rational_quadratic.py index cf5ec86..9718197 100644 --- a/normalizing_flows/bijections/finite/autoregressive/transformers/spline/rational_quadratic.py +++ b/normalizing_flows/bijections/finite/autoregressive/transformers/spline/rational_quadratic.py @@ -38,15 +38,9 @@ def __init__(self, self.boundary_u_delta = math.log(math.expm1(1 - self.min_delta)) @property - def n_parameters(self) -> int: - return 3 * self.n_bins - 1 + def parameter_shape_per_element(self) -> torch.Size: + return torch.Size((3 * self.n_bins - 1,)) - @property - def default_parameters(self) -> torch.Tensor: - default_u_x = torch.zeros(size=(self.n_bins,)) - default_u_y = torch.zeros(size=(self.n_bins,)) - default_u_d = torch.zeros(size=(self.n_bins - 1,)) - return torch.cat([default_u_x, default_u_y, default_u_d], dim=0) def compute_bins(self, u, minimum, maximum): bin_sizes = torch.softmax(u, dim=-1) diff --git a/test/constants.py b/test/constants.py index a4a1b96..381c918 100644 --- a/test/constants.py +++ b/test/constants.py @@ -1,8 +1,11 @@ __test_constants = { 'batch_shape': [(1,), (2,), (5,), (2, 4), (5, 2, 3)], 'event_shape': [(2,), (3,), (2, 4), (40,), (3, 5, 2)], + 'image_shape': [(4, 4, 3), (20, 20, 3), (10, 20, 3), (200, 200, 3), (20, 20, 1), (10, 20, 1)], 'context_shape': [None, (2,), (3,), (2, 4), (5,)], 'input_event_shape': [(2,), (3,), (2, 4), (40,), (3, 5, 2)], 'output_event_shape': [(2,), (3,), (2, 4), (40,), (3, 5, 2)], - 'n_predicted_parameters': [1, 2, 10, 50, 100] + 'n_predicted_parameters': [1, 2, 10, 50, 100], + 'predicted_parameter_shape': [(1,), (2,), (5,), (2, 4), (5, 2, 3)], + 'parameter_shape_per_element': [(1,), (2,), (5,), (2, 4), (5, 2, 3)], } diff --git a/test/test_conditioner_transforms.py 
b/test/test_conditioner_transforms.py index e68514d..79ac5fc 100644 --- a/test/test_conditioner_transforms.py +++ b/test/test_conditioner_transforms.py @@ -2,29 +2,68 @@ import torch from normalizing_flows.bijections.finite.autoregressive.conditioner_transforms import ( - MADE, FeedForward, LinearMADE, ResidualFeedForward, Constant, Linear + MADE, FeedForward, LinearMADE, ResidualFeedForward, Constant, Linear, ConditionerTransform ) from test.constants import __test_constants @pytest.mark.parametrize('transform_class', [ MADE, - FeedForward, LinearMADE, - ResidualFeedForward, - Linear ]) @pytest.mark.parametrize('batch_shape', __test_constants['batch_shape']) @pytest.mark.parametrize('input_event_shape', __test_constants['input_event_shape']) @pytest.mark.parametrize('output_event_shape', __test_constants['output_event_shape']) -@pytest.mark.parametrize('n_predicted_parameters', __test_constants['n_predicted_parameters']) -def test_shape(transform_class, batch_shape, input_event_shape, output_event_shape, n_predicted_parameters): +@pytest.mark.parametrize('parameter_shape_per_element', __test_constants['parameter_shape_per_element']) +@pytest.mark.parametrize('context_shape', __test_constants['context_shape']) +def test_autoregressive(transform_class, + batch_shape, + input_event_shape, + output_event_shape, + parameter_shape_per_element, + context_shape): torch.manual_seed(0) x = torch.randn(size=(*batch_shape, *input_event_shape)) - transform = transform_class( + transform: ConditionerTransform = transform_class( input_event_shape=input_event_shape, output_event_shape=output_event_shape, - n_predicted_parameters=n_predicted_parameters + parameter_shape_per_element=parameter_shape_per_element, + context_shape=context_shape, + ) + + if context_shape is not None: + c = torch.randn(size=(*batch_shape, *context_shape)) + out = transform(x, c) + else: + out = transform(x) + assert out.shape == (*batch_shape, *output_event_shape, *parameter_shape_per_element) + + +@pytest.mark.parametrize('transform_class', [ + FeedForward, + ResidualFeedForward, + Linear +]) +@pytest.mark.parametrize('batch_shape', __test_constants['batch_shape']) +@pytest.mark.parametrize('input_event_shape', __test_constants['input_event_shape']) +@pytest.mark.parametrize('context_shape', __test_constants['context_shape']) +@pytest.mark.parametrize('predicted_parameter_shape', __test_constants['predicted_parameter_shape']) +def test_neural_network(transform_class, + batch_shape, + input_event_shape, + context_shape, + predicted_parameter_shape): + torch.manual_seed(0) + x = torch.randn(size=(*batch_shape, *input_event_shape)) + transform: ConditionerTransform = transform_class( + input_event_shape=input_event_shape, + context_shape=context_shape, + parameter_shape=predicted_parameter_shape ) - out = transform(x) - assert out.shape == (*batch_shape, *output_event_shape, n_predicted_parameters) + + if context_shape is not None: + c = torch.randn(size=(*batch_shape, *context_shape)) + out = transform(x, c) + else: + out = transform(x) + assert out.shape == (*batch_shape, *predicted_parameter_shape) diff --git a/test/test_lu_matrix_transformer.py b/test/test_lu_matrix_transformer.py new file mode 100644 index 0000000..ee531c0 --- /dev/null +++ b/test/test_lu_matrix_transformer.py @@ -0,0 +1,21 @@ +import torch + +from normalizing_flows.bijections.finite.autoregressive.transformers.linear.matrix import LUTransformer + + +def test_basic(): + torch.manual_seed(0) + + batch_shape = (2, 3) + event_shape = (5, 7) + + transformer = 
LUTransformer(event_shape) + + x = torch.randn(size=(*batch_shape, *event_shape)) + h = torch.randn(size=(*batch_shape, *transformer.parameter_shape)) + + z, log_det_forward = transformer.forward(x, h) + x_reconstructed, log_det_inverse = transformer.inverse(z, h) + + assert torch.allclose(x, x_reconstructed, atol=1e-3), f"{torch.linalg.norm(x-x_reconstructed)}" + assert torch.allclose(log_det_forward, -log_det_inverse, atol=1e-3) diff --git a/test/test_reconstruction_transformers.py b/test/test_reconstruction_transformers.py index 3cdaf43..352cd69 100644 --- a/test/test_reconstruction_transformers.py +++ b/test/test_reconstruction_transformers.py @@ -3,15 +3,14 @@ import pytest import torch -from normalizing_flows.bijections.finite.autoregressive.transformers.base import Transformer +from normalizing_flows.bijections.finite.autoregressive.transformers.base import ScalarTransformer from normalizing_flows.bijections.finite.autoregressive.transformers.spline.linear import Linear as LinearSpline from normalizing_flows.bijections.finite.autoregressive.transformers.spline.linear_rational import \ LinearRational as LinearRationalSpline from normalizing_flows.bijections.finite.autoregressive.transformers.spline.rational_quadratic import \ RationalQuadratic as RationalQuadraticSpline -from normalizing_flows.bijections.finite.autoregressive.transformers.spline.cubic import Cubic as CubicSpline -from normalizing_flows.bijections.finite.autoregressive.transformers.spline.basis import Basis as BasisSpline -from normalizing_flows.bijections.finite.autoregressive.transformers.affine import Affine, Scale, Shift +from normalizing_flows.bijections.finite.autoregressive.transformers.linear.convolution import Invertible1x1Convolution +from normalizing_flows.bijections.finite.autoregressive.transformers.linear.affine import Affine, Scale, Shift from normalizing_flows.bijections.finite.autoregressive.transformers.combination.sigmoid import Sigmoid, DeepSigmoid, \ DenseSigmoid, DeepDenseSigmoid from normalizing_flows.bijections.finite.autoregressive.transformers.integration.unconstrained_monotonic_neural_network import \ @@ -20,16 +19,16 @@ from test.constants import __test_constants -def setup_transformer_data(transformer_class: Transformer, batch_shape, event_shape): +def setup_transformer_data(transformer_class: ScalarTransformer, batch_shape, event_shape): # vector_to_vector: does the transformer map a vector to vector? Otherwise, it maps a scalar to scalar. 
torch.manual_seed(0) transformer = transformer_class(event_shape) x = torch.randn(*batch_shape, *event_shape) - h = torch.randn(*batch_shape, *event_shape, transformer.n_parameters) + h = torch.randn(*batch_shape, *transformer.parameter_shape) return transformer, x, h -def assert_valid_reconstruction(transformer: Transformer, +def assert_valid_reconstruction(transformer: ScalarTransformer, x: torch.Tensor, h: torch.Tensor, reconstruction_eps: float = 1e-3, @@ -66,7 +65,7 @@ def assert_valid_reconstruction(transformer: Transformer, ]) @pytest.mark.parametrize('batch_shape', __test_constants['batch_shape']) @pytest.mark.parametrize('event_shape', __test_constants['event_shape']) -def test_affine(transformer_class: Transformer, batch_shape: Tuple, event_shape: Tuple): +def test_affine(transformer_class: ScalarTransformer, batch_shape: Tuple, event_shape: Tuple): transformer, x, h = setup_transformer_data(transformer_class, batch_shape, event_shape) assert_valid_reconstruction(transformer, x, h) @@ -80,7 +79,7 @@ def test_affine(transformer_class: Transformer, batch_shape: Tuple, event_shape: ]) @pytest.mark.parametrize('batch_shape', __test_constants['batch_shape']) @pytest.mark.parametrize('event_shape', __test_constants['event_shape']) -def test_spline(transformer_class: Transformer, batch_shape: Tuple, event_shape: Tuple): +def test_spline(transformer_class: ScalarTransformer, batch_shape: Tuple, event_shape: Tuple): transformer, x, h = setup_transformer_data(transformer_class, batch_shape, event_shape) assert_valid_reconstruction(transformer, x, h) @@ -90,7 +89,7 @@ def test_spline(transformer_class: Transformer, batch_shape: Tuple, event_shape: ]) @pytest.mark.parametrize('batch_shape', __test_constants['batch_shape']) @pytest.mark.parametrize('event_shape', __test_constants['event_shape']) -def test_integration(transformer_class: Transformer, batch_shape: Tuple, event_shape: Tuple): +def test_integration(transformer_class: ScalarTransformer, batch_shape: Tuple, event_shape: Tuple): transformer, x, h = setup_transformer_data(transformer_class, batch_shape, event_shape) assert_valid_reconstruction(transformer, x, h) @@ -98,7 +97,7 @@ def test_integration(transformer_class: Transformer, batch_shape: Tuple, event_s @pytest.mark.parametrize('transformer_class', [Sigmoid, DeepSigmoid]) @pytest.mark.parametrize('batch_shape', __test_constants['batch_shape']) @pytest.mark.parametrize('event_shape', __test_constants['event_shape']) -def test_combination_basic(transformer_class: Transformer, batch_shape: Tuple, event_shape: Tuple): +def test_combination_basic(transformer_class: ScalarTransformer, batch_shape: Tuple, event_shape: Tuple): transformer, x, h = setup_transformer_data(transformer_class, batch_shape, event_shape) assert_valid_reconstruction(transformer, x, h) @@ -109,6 +108,33 @@ def test_combination_basic(transformer_class: Transformer, batch_shape: Tuple, e ]) @pytest.mark.parametrize('batch_shape', __test_constants['batch_shape']) @pytest.mark.parametrize('event_shape', __test_constants['event_shape']) -def test_combination_vector_to_vector(transformer_class: Transformer, batch_shape: Tuple, event_shape: Tuple): +def test_combination_vector_to_vector(transformer_class: ScalarTransformer, batch_shape: Tuple, event_shape: Tuple): transformer, x, h = setup_transformer_data(transformer_class, batch_shape, event_shape) assert_valid_reconstruction(transformer, x, h) + + +@pytest.mark.parametrize("batch_size", [2, 3, 5, 7, 1]) +@pytest.mark.parametrize('image_shape', 
__test_constants['image_shape']) +def test_convolution(batch_size: int, image_shape: Tuple): + torch.manual_seed(0) + transformer = Invertible1x1Convolution(image_shape) + + *image_dimensions, n_channels = image_shape + + images = torch.randn(size=(batch_size, *image_shape)) + parameters = torch.randn(size=(batch_size, *image_dimensions, *transformer.parameter_shape)) + latent_images, log_det_forward = transformer.forward(images, parameters) + reconstructed_images, log_det_inverse = transformer.inverse(latent_images, parameters) + + assert log_det_forward.shape == (batch_size,) + assert log_det_inverse.shape == (batch_size,) + assert torch.isfinite(log_det_forward).all() + assert torch.isfinite(log_det_inverse).all() + assert torch.allclose(log_det_forward, -log_det_inverse, atol=1e-3) + + assert latent_images.shape == images.shape + assert reconstructed_images.shape == images.shape + assert torch.isfinite(latent_images).all() + assert torch.isfinite(reconstructed_images).all() + rec_err = torch.max(torch.abs(images - reconstructed_images)) + assert torch.allclose(images, reconstructed_images, atol=1e-2), f"{rec_err = }" diff --git a/test/test_spline.py b/test/test_spline.py index eea1604..55dd02b 100644 --- a/test/test_spline.py +++ b/test/test_spline.py @@ -12,7 +12,7 @@ def test_linear_rational(): torch.manual_seed(0) x = torch.tensor([-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0]) spline = LinearRational(event_shape=(1,)) - h = torch.randn(size=(len(x), spline.n_parameters)) + h = torch.randn(size=(len(x), *spline.parameter_shape_per_element)) z, log_det_forward = spline.forward(x, h) xr, log_det_inverse = spline.inverse(z, h) assert x.shape == z.shape == xr.shape @@ -36,7 +36,7 @@ def test_1d_spline(spline_class): [4.0], [-3.6] ]) - h = torch.randn(size=(3, 1, spline.n_parameters)) + h = torch.randn(size=(3, 1, *spline.parameter_shape_per_element)) z, log_det = spline(x, h) assert torch.all(~torch.isnan(z)) assert torch.all(~torch.isnan(log_det)) @@ -67,7 +67,7 @@ def test_2d_spline(spline_class): [4.0, 6.0], [-3.6, 0.7] ]) - h = torch.randn(size=(*batch_shape, *event_shape, spline.n_parameters)) + h = torch.randn(size=(*batch_shape, *spline.parameter_shape)) z, log_det = spline(x, h) assert torch.all(~torch.isnan(z)) assert torch.all(~torch.isnan(log_det)) @@ -95,7 +95,7 @@ def test_spline_exhaustive(spline_class, boundary: float, batch_shape, event_sha spline = spline_class(event_shape=event_shape, n_bins=8, boundary=boundary) x = torch.randn(size=(*batch_shape, *event_shape)) - h = torch.randn(size=(*batch_shape, *event_shape, spline.n_parameters)) + h = torch.randn(size=(*batch_shape, *spline.parameter_shape)) z, log_det = spline(x, h) assert torch.all(~torch.isnan(z)) assert torch.all(~torch.isnan(log_det)) @@ -119,7 +119,7 @@ def test_rq_spline(n_data, n_dim, n_bins, scale): spline = RationalQuadratic(event_shape=torch.Size((n_dim,)), n_bins=n_bins) x = torch.randn(n_data, n_dim) * scale - h = torch.randn(n_data, n_dim, spline.n_parameters) + h = torch.randn(n_data, n_dim, *spline.parameter_shape_per_element) z, log_det_forward = spline.forward(x, h) assert z.shape == x.shape diff --git a/test/test_umnn.py b/test/test_umnn.py index 4aaa322..c6fe481 100644 --- a/test/test_umnn.py +++ b/test/test_umnn.py @@ -15,7 +15,7 @@ def test_umnn(batch_shape: Tuple, event_shape: Tuple): x = torch.randn(*batch_shape, *event_shape) / 100 transformer = UnconstrainedMonotonicNeuralNetwork(event_shape=event_shape, n_hidden_layers=2, hidden_dim=20) - h = torch.randn(*batch_shape, *event_shape, 
len(transformer.default_parameters)) + h = torch.randn(size=(*batch_shape, *transformer.parameter_shape)) z, log_det_forward = transformer.forward(x, h) xr, log_det_inverse = transformer.inverse(z, h) @@ -36,7 +36,7 @@ def test_umnn_forward(): x = torch.cat([x0, x1]).view(2, 1) transformer = UnconstrainedMonotonicNeuralNetwork(event_shape=event_shape, n_hidden_layers=2, hidden_dim=20) - h = torch.randn(*event_shape, len(transformer.default_parameters)) + h = torch.randn(*event_shape, *transformer.parameter_shape_per_element) z0, log_det_forward0 = transformer.forward(x0, h) z1, log_det_forward1 = transformer.forward(x1, h) @@ -55,7 +55,7 @@ def test_umnn_inverse(): x = torch.cat([x0, x1]).view(2, 1) transformer = UnconstrainedMonotonicNeuralNetwork(event_shape=event_shape, n_hidden_layers=2, hidden_dim=20) - h = torch.randn(*event_shape, len(transformer.default_parameters)) + h = torch.randn(*event_shape, *transformer.parameter_shape_per_element) z0, log_det_inverse0 = transformer.inverse(x0, h) z1, log_det_inverse1 = transformer.inverse(x1, h) @@ -74,7 +74,7 @@ def test_umnn_reconstruction(): x = torch.cat([x0, x1]).view(2, 1) transformer = UnconstrainedMonotonicNeuralNetwork(event_shape=event_shape, n_hidden_layers=2, hidden_dim=20) - h = torch.randn(*event_shape, len(transformer.default_parameters)) + h = torch.randn(*event_shape, *transformer.parameter_shape_per_element) z0, log_det_forward0 = transformer.forward(x0, h) z1, log_det_forward1 = transformer.forward(x1, h) @@ -106,7 +106,7 @@ def test_umnn_forward_large_event(): x = torch.cat([x0, x1, x2]).view(3, 2) transformer = UnconstrainedMonotonicNeuralNetwork(event_shape=event_shape, n_hidden_layers=2, hidden_dim=20) - h = torch.randn(*event_shape, len(transformer.default_parameters)) + h = torch.randn(*event_shape, *transformer.parameter_shape_per_element) z0, log_det_forward0 = transformer.forward(x0, h) z1, log_det_forward1 = transformer.forward(x1, h) @@ -127,7 +127,7 @@ def test_umnn_inverse_large_event(): x = torch.cat([x0, x1, x2]).view(3, 2) transformer = UnconstrainedMonotonicNeuralNetwork(event_shape=event_shape, n_hidden_layers=2, hidden_dim=20) - h = torch.randn(*event_shape, len(transformer.default_parameters)) + h = torch.randn(*event_shape, *transformer.parameter_shape_per_element) z0, log_det_inverse0 = transformer.inverse(x0, h) z1, log_det_inverse1 = transformer.inverse(x1, h) @@ -148,7 +148,7 @@ def test_umnn_reconstruction_large_event(): x = torch.cat([x0, x1, x2]).view(3, 2) transformer = UnconstrainedMonotonicNeuralNetwork(event_shape=event_shape, n_hidden_layers=2, hidden_dim=20) - h = torch.randn(*event_shape, len(transformer.default_parameters)) + h = torch.randn(*event_shape, *transformer.parameter_shape_per_element) z0, log_det_forward0 = transformer.forward(x0, h) z1, log_det_forward1 = transformer.forward(x1, h)
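Taken together, the test updates follow a single convention: scalar transformers (Affine, Shift, Scale, the splines, UMNN, the sigmoid combinations) define parameter_shape_per_element for one event element, with parameter_shape covering the full event, while tensor transformers (LUTransformer, Invertible1x1Convolution) define parameter_shape directly. The batched call pattern exercised by test_umnn is summarized in the sketch below, under the assumption that parameter_shape equals (*event_shape, *parameter_shape_per_element) for scalar transformers.

import torch

from normalizing_flows.bijections.finite.autoregressive.transformers.integration.unconstrained_monotonic_neural_network import \
    UnconstrainedMonotonicNeuralNetwork

torch.manual_seed(0)
batch_shape, event_shape = (5,), (2,)
transformer = UnconstrainedMonotonicNeuralNetwork(event_shape=event_shape, n_hidden_layers=2, hidden_dim=20)

x = torch.randn(*batch_shape, *event_shape) / 100
h = torch.randn(*batch_shape, *transformer.parameter_shape)  # (*batch_shape, *event_shape, n_parameters_per_element)

z, log_det_forward = transformer.forward(x, h)
x_reconstructed, log_det_inverse = transformer.inverse(z, h)

assert z.shape == x.shape == x_reconstructed.shape
assert log_det_forward.shape == batch_shape
# The inverse is computed numerically, so forward and inverse log-determinants cancel only approximately.
assert torch.allclose(log_det_forward, -log_det_inverse, atol=1e-2)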