diff --git a/normalizing_flows/__init__.py b/normalizing_flows/__init__.py index fc92a0a..1f268b5 100644 --- a/normalizing_flows/__init__.py +++ b/normalizing_flows/__init__.py @@ -1,4 +1,4 @@ -from normalizing_flows.flows import Flow +from normalizing_flows.flows import Flow, FlowMixture from normalizing_flows.bijections import ( NICE, RealNVP, diff --git a/normalizing_flows/architectures.py b/normalizing_flows/architectures.py index 191a022..58921c0 100644 --- a/normalizing_flows/architectures.py +++ b/normalizing_flows/architectures.py @@ -25,3 +25,13 @@ Radial, Sylvester ) + +from normalizing_flows.bijections.finite.multiscale.architectures import ( + MultiscaleRealNVP, + MultiscaleRQNSF, + MultiscaleLRSNSF, + MultiscaleNICE, + # MultiscaleDeepSigmoid, # TODO stabler initialization + # MultiscaleDenseSigmoid, # TODO stabler initialization + # MultiscaleDeepDenseSigmoid # TODO stabler initialization +) diff --git a/normalizing_flows/bijections/base.py b/normalizing_flows/bijections/base.py index a05fe5a..8cb358d 100644 --- a/normalizing_flows/bijections/base.py +++ b/normalizing_flows/bijections/base.py @@ -11,7 +11,8 @@ class Bijection(nn.Module): def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], - context_shape: Union[torch.Size, Tuple[int, ...]] = None): + context_shape: Union[torch.Size, Tuple[int, ...]] = None, + **kwargs): """ Bijection class. """ @@ -19,6 +20,7 @@ def __init__(self, self.event_shape = event_shape self.n_dim = int(torch.prod(torch.as_tensor(event_shape))) self.context_shape = context_shape + self.transformed_shape = self.event_shape # Overwritten in multiscale flows TODO make into property def forward(self, x: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]: """ @@ -80,6 +82,8 @@ def batch_inverse(self, x: torch.Tensor, batch_size: int, context: torch.Tensor assert log_dets.shape == batch_shape return outputs, log_dets + def regularization(self): + return 0.0 def invert(bijection): """ @@ -93,7 +97,8 @@ class BijectiveComposition(Bijection): def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], layers: List[Bijection], - context_shape: Union[torch.Size, Tuple[int, ...]] = None): + context_shape: Union[torch.Size, Tuple[int, ...]] = None, + **kwargs): super().__init__(event_shape=event_shape, context_shape=context_shape) self.layers = nn.ModuleList(layers) @@ -112,3 +117,6 @@ def inverse(self, z: torch.Tensor, context: torch.Tensor = None, **kwargs) -> Tu log_det += log_det_layer x = z return x, log_det + + def regularization(self): + return sum([layer.regularization() for layer in self.layers]) diff --git a/normalizing_flows/bijections/continuous/otflow.py b/normalizing_flows/bijections/continuous/otflow.py index 5021504..976cbd9 100644 --- a/normalizing_flows/bijections/continuous/otflow.py +++ b/normalizing_flows/bijections/continuous/otflow.py @@ -40,11 +40,11 @@ def __init__(self, event_size: int, hidden_size: int, step_size: float = 0.01): @property def K0(self): - return torch.eye(*self.K0_delta.shape) + self.K0_delta / 1000 + return torch.eye(*self.K0_delta.shape).to(self.K0_delta) + self.K0_delta / 1000 @property def K1(self): - return torch.eye(*self.K1_delta.shape) + self.K1_delta / 1000 + return torch.eye(*self.K1_delta.shape).to(self.K1_delta) + self.K1_delta / 1000 @staticmethod def sigma(x): @@ -113,7 +113,7 @@ def hessian_trace(self, # Compute the first term in Equation 14 - ones = torch.ones(size=(self.K0.shape[1] - 1,)) + ones = torch.ones(size=(self.K0.shape[1] - 1,)).to(s) t0 = 
torch.sum( torch.multiply( @@ -146,7 +146,7 @@ def __init__(self, event_size: int, hidden_size: int = None, **kwargs): # hidden_size = m if hidden_size is None: - hidden_size = max(math.log(event_size), 4) + hidden_size = max(int(math.log(event_size)), 4) r = min(10, event_size) diff --git a/normalizing_flows/bijections/finite/autoregressive/architectures.py b/normalizing_flows/bijections/finite/autoregressive/architectures.py index ca31f0a..d2ee0ef 100644 --- a/normalizing_flows/bijections/finite/autoregressive/architectures.py +++ b/normalizing_flows/bijections/finite/autoregressive/architectures.py @@ -1,3 +1,5 @@ +from typing import Tuple, List, Type, Union + from normalizing_flows.bijections.finite.autoregressive.layers import ( ShiftCoupling, AffineCoupling, @@ -11,51 +13,64 @@ ElementwiseAffine, UMNNMaskedAutoregressive, LRSCoupling, - LRSForwardMaskedAutoregressive, ElementwiseShift + LRSForwardMaskedAutoregressive ) from normalizing_flows.bijections.base import BijectiveComposition +from normalizing_flows.bijections.finite.autoregressive.layers_base import CouplingBijection, \ + MaskedAutoregressiveBijection, InverseMaskedAutoregressiveBijection from normalizing_flows.bijections.finite.linear import ReversePermutation +def make_basic_layers(base_bijection: Type[ + Union[CouplingBijection, MaskedAutoregressiveBijection, InverseMaskedAutoregressiveBijection]], + event_shape, + n_layers: int = 2, + edge_list: List[Tuple[int, int]] = None): + """ + Returns a list of bijections for transformations of vectors. + """ + bijections = [ElementwiseAffine(event_shape=event_shape)] + for _ in range(n_layers): + if edge_list is None: + bijections.append(ReversePermutation(event_shape=event_shape)) + bijections.append(base_bijection(event_shape=event_shape, edge_list=edge_list)) + bijections.append(ElementwiseAffine(event_shape=event_shape)) + return bijections + + class NICE(BijectiveComposition): - def __init__(self, event_shape, n_layers: int = 2, **kwargs): + def __init__(self, + event_shape, + n_layers: int = 2, + edge_list: List[Tuple[int, int]] = None, + **kwargs): if isinstance(event_shape, int): event_shape = (event_shape,) - bijections = [ElementwiseAffine(event_shape=event_shape)] - for _ in range(n_layers): - bijections.extend([ - ReversePermutation(event_shape=event_shape), - ShiftCoupling(event_shape=event_shape) - ]) - bijections.append(ElementwiseAffine(event_shape=event_shape)) + bijections = make_basic_layers(ShiftCoupling, event_shape, n_layers, edge_list) super().__init__(event_shape, bijections, **kwargs) class RealNVP(BijectiveComposition): - def __init__(self, event_shape, n_layers: int = 2, **kwargs): + def __init__(self, + event_shape, + n_layers: int = 2, + edge_list: List[Tuple[int, int]] = None, + **kwargs): if isinstance(event_shape, int): event_shape = (event_shape,) - bijections = [ElementwiseAffine(event_shape=event_shape)] - for _ in range(n_layers): - bijections.extend([ - ReversePermutation(event_shape=event_shape), - AffineCoupling(event_shape=event_shape) - ]) - bijections.append(ElementwiseAffine(event_shape=event_shape)) + bijections = make_basic_layers(AffineCoupling, event_shape, n_layers, edge_list) super().__init__(event_shape, bijections, **kwargs) class InverseRealNVP(BijectiveComposition): - def __init__(self, event_shape, n_layers: int = 2, **kwargs): + def __init__(self, + event_shape, + n_layers: int = 2, + edge_list: List[Tuple[int, int]] = None, + **kwargs): if isinstance(event_shape, int): event_shape = (event_shape,) - bijections = 
[ElementwiseAffine(event_shape=event_shape)] - for _ in range(n_layers): - bijections.extend([ - ReversePermutation(event_shape=event_shape), - InverseAffineCoupling(event_shape=event_shape) - ]) - bijections.append(ElementwiseAffine(event_shape=event_shape)) + bijections = make_basic_layers(InverseAffineCoupling, event_shape, n_layers, edge_list) super().__init__(event_shape, bijections, **kwargs) @@ -67,13 +82,7 @@ class MAF(BijectiveComposition): def __init__(self, event_shape, n_layers: int = 2, **kwargs): if isinstance(event_shape, int): event_shape = (event_shape,) - bijections = [ElementwiseAffine(event_shape=event_shape)] - for _ in range(n_layers): - bijections.extend([ - ReversePermutation(event_shape=event_shape), - AffineForwardMaskedAutoregressive(event_shape=event_shape) - ]) - bijections.append(ElementwiseAffine(event_shape=event_shape)) + bijections = make_basic_layers(AffineForwardMaskedAutoregressive, event_shape, n_layers) super().__init__(event_shape, bijections, **kwargs) @@ -81,27 +90,19 @@ class IAF(BijectiveComposition): def __init__(self, event_shape, n_layers: int = 2, **kwargs): if isinstance(event_shape, int): event_shape = (event_shape,) - bijections = [ElementwiseAffine(event_shape=event_shape)] - for _ in range(n_layers): - bijections.extend([ - ReversePermutation(event_shape=event_shape), - AffineInverseMaskedAutoregressive(event_shape=event_shape) - ]) - bijections.append(ElementwiseAffine(event_shape=event_shape)) + bijections = make_basic_layers(AffineInverseMaskedAutoregressive, event_shape, n_layers) super().__init__(event_shape, bijections, **kwargs) class CouplingRQNSF(BijectiveComposition): - def __init__(self, event_shape, n_layers: int = 2, **kwargs): + def __init__(self, + event_shape, + n_layers: int = 2, + edge_list: List[Tuple[int, int]] = None, + **kwargs): if isinstance(event_shape, int): event_shape = (event_shape,) - bijections = [ElementwiseAffine(event_shape=event_shape)] - for _ in range(n_layers): - bijections.extend([ - ReversePermutation(event_shape=event_shape), - RQSCoupling(event_shape=event_shape) - ]) - bijections.append(ElementwiseAffine(event_shape=event_shape)) + bijections = make_basic_layers(RQSCoupling, event_shape, n_layers, edge_list) super().__init__(event_shape, bijections, **kwargs) @@ -113,27 +114,19 @@ class MaskedAutoregressiveRQNSF(BijectiveComposition): def __init__(self, event_shape, n_layers: int = 2, **kwargs): if isinstance(event_shape, int): event_shape = (event_shape,) - bijections = [ElementwiseAffine(event_shape=event_shape)] - for _ in range(n_layers): - bijections.extend([ - ReversePermutation(event_shape=event_shape), - RQSForwardMaskedAutoregressive(event_shape=event_shape) - ]) - bijections.append(ElementwiseAffine(event_shape=event_shape)) + bijections = make_basic_layers(RQSForwardMaskedAutoregressive, event_shape, n_layers) super().__init__(event_shape, bijections, **kwargs) class CouplingLRS(BijectiveComposition): - def __init__(self, event_shape, n_layers: int = 2, **kwargs): + def __init__(self, + event_shape, + n_layers: int = 2, + edge_list: List[Tuple[int, int]] = None, + **kwargs): if isinstance(event_shape, int): event_shape = (event_shape,) - bijections = [ElementwiseShift(event_shape=event_shape)] - for _ in range(n_layers): - bijections.extend([ - ReversePermutation(event_shape=event_shape), - LRSCoupling(event_shape=event_shape) - ]) - bijections.append(ElementwiseShift(event_shape=event_shape)) + bijections = make_basic_layers(LRSCoupling, event_shape, n_layers, edge_list) 
super().__init__(event_shape, bijections, **kwargs)


@@ -141,13 +134,7 @@ class MaskedAutoregressiveLRS(BijectiveComposition):
     def __init__(self, event_shape, n_layers: int = 2, **kwargs):
         if isinstance(event_shape, int):
             event_shape = (event_shape,)
-        bijections = [ElementwiseShift(event_shape=event_shape)]
-        for _ in range(n_layers):
-            bijections.extend([
-                ReversePermutation(event_shape=event_shape),
-                LRSForwardMaskedAutoregressive(event_shape=event_shape)
-            ])
-        bijections.append(ElementwiseShift(event_shape=event_shape))
+        bijections = make_basic_layers(LRSForwardMaskedAutoregressive, event_shape, n_layers)
         super().__init__(event_shape, bijections, **kwargs)
 
@@ -155,27 +142,19 @@ class InverseAutoregressiveRQNSF(BijectiveComposition):
     def __init__(self, event_shape, n_layers: int = 2, **kwargs):
         if isinstance(event_shape, int):
             event_shape = (event_shape,)
-        bijections = [ElementwiseAffine(event_shape=event_shape)]
-        for _ in range(n_layers):
-            bijections.extend([
-                ReversePermutation(event_shape=event_shape),
-                RQSInverseMaskedAutoregressive(event_shape=event_shape)
-            ])
-        bijections.append(ElementwiseAffine(event_shape=event_shape))
+        bijections = make_basic_layers(RQSInverseMaskedAutoregressive, event_shape, n_layers)
         super().__init__(event_shape, bijections, **kwargs)
 
 
 class CouplingDSF(BijectiveComposition):
-    def __init__(self, event_shape, n_layers: int = 2, **kwargs):
+    def __init__(self,
+                 event_shape,
+                 n_layers: int = 2,
+                 edge_list: List[Tuple[int, int]] = None,
+                 **kwargs):
         if isinstance(event_shape, int):
             event_shape = (event_shape,)
-        bijections = [ElementwiseAffine(event_shape=event_shape)]
-        for _ in range(n_layers):
-            bijections.extend([
-                ReversePermutation(event_shape=event_shape),
-                DSCoupling(event_shape=event_shape)  # TODO specify percent of global parameters
-            ])
-        bijections.append(ElementwiseAffine(event_shape=event_shape))
+        bijections = make_basic_layers(DSCoupling, event_shape, n_layers, edge_list)
         super().__init__(event_shape, bijections, **kwargs)
 
@@ -183,11 +162,5 @@ class UMNNMAF(BijectiveComposition):
     def __init__(self, event_shape, n_layers: int = 1, **kwargs):
         if isinstance(event_shape, int):
             event_shape = (event_shape,)
-        bijections = [ElementwiseAffine(event_shape=event_shape)]
-        for _ in range(n_layers):
-            bijections.extend([
-                ReversePermutation(event_shape=event_shape),
-                UMNNMaskedAutoregressive(event_shape=event_shape)
-            ])
-        bijections.append(ElementwiseAffine(event_shape=event_shape))
+        bijections = make_basic_layers(UMNNMaskedAutoregressive, event_shape, n_layers)
         super().__init__(event_shape, bijections, **kwargs)
diff --git a/normalizing_flows/bijections/finite/autoregressive/conditioning/coupling_masks.py b/normalizing_flows/bijections/finite/autoregressive/conditioning/coupling_masks.py
index 04ddc8f..08b3d61 100644
--- a/normalizing_flows/bijections/finite/autoregressive/conditioning/coupling_masks.py
+++ b/normalizing_flows/bijections/finite/autoregressive/conditioning/coupling_masks.py
@@ -1,44 +1,90 @@
+from typing import Tuple, List
+
 import torch
 
 
-class CouplingMask:
+class PartialCoupling:
     """
-    Base object which holds coupling partition mask information.
+    Coupling mask object where some dimensions are kept unchanged and do not affect the remaining dimensions.
     """
 
-    def __init__(self, event_shape):
+    def __init__(self,
+                 event_shape,
+                 source_mask: torch.Tensor,
+                 target_mask: torch.Tensor):
+        """
+        Partial coupling mask constructor.
+
+        :param source_mask: boolean mask tensor of dimensions that affect target dimensions. This tensor has shape
+            event_shape.
+        :param target_mask: boolean mask tensor of affected dimensions. This tensor has shape event_shape.
+        """
         self.event_shape = event_shape
+        self.source_mask = source_mask
+        self.target_mask = target_mask
+
         self.event_size = int(torch.prod(torch.as_tensor(self.event_shape)))
 
     @property
-    def mask(self):
-        raise NotImplementedError
+    def ignored_event_size(self):
+        # Event size of ignored dimensions (boolean masks must be combined with ~ and |, not arithmetic).
+        return int(torch.sum(~(self.source_mask | self.target_mask)))
 
     @property
-    def constant_event_size(self):
-        raise NotImplementedError
+    def source_event_size(self):
+        return int(torch.sum(self.source_mask))
 
     @property
-    def transformed_event_size(self):
-        raise NotImplementedError
+    def target_event_size(self):
+        return int(torch.sum(self.target_mask))
 
 
-class HalfSplit(CouplingMask):
-    def __init__(self, event_shape):
-        super().__init__(event_shape)
-        self.event_partition_mask = torch.less(
-            torch.arange(self.event_size).view(*self.event_shape),
-            self.constant_event_size
-        )
+class Coupling(PartialCoupling):
+    """
+    Base object which holds coupling partition mask information.
+    """
 
-    @property
-    def constant_event_size(self):
-        return self.event_size // 2
+    def __init__(self, event_shape, mask: torch.Tensor):
+        super().__init__(event_shape, source_mask=mask, target_mask=~mask)
 
     @property
-    def transformed_event_size(self):
-        return self.event_size - self.constant_event_size
+    def ignored_event_size(self):
+        return 0
 
-    @property
-    def mask(self):
-        return self.event_partition_mask
+
+class GraphicalCoupling(PartialCoupling):
+    def __init__(self, event_shape, edge_list: List[Tuple[int, int]]):
+        if len(event_shape) != 1:
+            raise ValueError("GraphicalCoupling is currently only implemented for vector data")
+
+        source_dims = torch.tensor(sorted(list(set([e[0] for e in edge_list]))))
+        target_dims = torch.tensor(sorted(list(set([e[1] for e in edge_list]))))
+
+        event_size = int(torch.prod(torch.as_tensor(event_shape)))
+        source_mask = torch.isin(torch.arange(event_size), source_dims)
+        target_mask = torch.isin(torch.arange(event_size), target_dims)
+
+        super().__init__(event_shape, source_mask, target_mask)
+
+
+class HalfSplit(Coupling):
+    def __init__(self, event_shape):
+        event_size = int(torch.prod(torch.as_tensor(event_shape)))
+        super().__init__(event_shape, mask=torch.less(torch.arange(event_size).view(*event_shape), event_size // 2))
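+
+
+# Illustrative usage sketch (hypothetical values, not part of the library API) showing how the
+# three mask objects above partition a 4-dimensional vector event:
+#
+#     half = HalfSplit(event_shape=(4,))
+#     half.source_event_size, half.target_event_size  # (2, 2): first half conditions the second half
+#
+#     graphical = GraphicalCoupling(event_shape=(4,), edge_list=[(0, 2), (1, 2)])
+#     graphical.source_event_size   # 2: dimensions 0 and 1 predict transformer parameters
+#     graphical.target_event_size   # 1: dimension 2 is transformed
+#     graphical.ignored_event_size  # 1: dimension 3 is left untouched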
+def make_coupling(event_shape, edge_list: List[Tuple[int, int]] = None, coupling_type: str = 'half_split', **kwargs):
+    """
+    Create a coupling mask for vector events.
+
+    :param event_shape: shape of the event tensor.
+    :param edge_list: optional list of (source_dim, target_dim) pairs; when provided, a GraphicalCoupling is built
+        from these interactions and coupling_type is ignored.
+    :param coupling_type: currently only 'half_split' is supported here; image couplings such as 'checkerboard' and
+        'channel_wise' are created with make_image_coupling instead.
+    :return: coupling mask object.
+    """
+    if edge_list is not None:
+        return GraphicalCoupling(event_shape, edge_list)
+    elif coupling_type == 'half_split':
+        return HalfSplit(event_shape)
+    else:
+        raise ValueError(f"Unsupported coupling type: {coupling_type}")
diff --git a/normalizing_flows/bijections/finite/autoregressive/conditioning/transforms.py b/normalizing_flows/bijections/finite/autoregressive/conditioning/transforms.py
index 74f3966..be50ef5 100644
--- a/normalizing_flows/bijections/finite/autoregressive/conditioning/transforms.py
+++ b/normalizing_flows/bijections/finite/autoregressive/conditioning/transforms.py
@@ -1,5 +1,5 @@
 import math
-from typing import Tuple, Union
+from typing import Tuple, Union, Type
 
 import torch
 import torch.nn as nn
@@ -26,7 +26,8 @@ def __init__(self,
                  parameter_shape: Union[torch.Size, Tuple[int, ...]],
                  context_combiner: ContextCombiner = None,
                  global_parameter_mask: torch.Tensor = None,
-                 initial_global_parameter_value: float = None):
+                 initial_global_parameter_value: float = None,
+                 **kwargs):
         """
         :param input_event_shape: shape of conditioner input tensor x.
         :param context_shape: shape of conditioner context tensor c.
@@ -96,6 +97,9 @@ def forward(self, x: torch.Tensor, context: torch.Tensor = None):
     def predict_theta_flat(self, x: torch.Tensor, context: torch.Tensor = None):
         raise NotImplementedError
 
+    def regularization(self):
+        return sum([torch.sum(torch.square(p)) for p in self.parameters()])
+
 
 class Constant(ConditionerTransform):
     def __init__(self, event_shape, parameter_shape, fill_value: float = None):
@@ -204,6 +208,7 @@ def __init__(self,
                  context_shape: torch.Size = None,
                  n_hidden: int = None,
                  n_layers: int = 2,
+                 nonlinearity: Type[nn.Module] = nn.Tanh,
                  **kwargs):
         super().__init__(
             input_event_shape=input_event_shape,
@@ -219,9 +224,9 @@ def __init__(self,
         if n_layers == 1:
             layers.append(nn.Linear(self.n_input_event_dims, self.n_predicted_parameters))
         elif n_layers > 1:
-            layers.extend([nn.Linear(self.n_input_event_dims, n_hidden), nn.Tanh()])
+            layers.extend([nn.Linear(self.n_input_event_dims, n_hidden), nonlinearity()])
             for _ in range(n_layers - 2):
-                layers.extend([nn.Linear(n_hidden, n_hidden), nn.Tanh()])
+                layers.extend([nn.Linear(n_hidden, n_hidden), nonlinearity()])
             layers.append(nn.Linear(n_hidden, self.n_predicted_parameters))
         else:
             raise ValueError
@@ -239,15 +244,15 @@ def __init__(self, *args, **kwargs):
 
 class ResidualFeedForward(ConditionerTransform):
     class ResidualBlock(nn.Module):
-        def __init__(self, event_size: int, hidden_size: int, block_size: int):
+        def __init__(self, event_size: int, hidden_size: int, block_size: int, nonlinearity: Type[nn.Module]):
             super().__init__()
             if block_size < 2:
                 raise ValueError(f"block_size must be at least 2 but found {block_size}. "
                                  f"For block_size = 1, use the FeedForward class instead.")
             layers = []
-            layers.extend([nn.Linear(event_size, hidden_size), nn.ReLU()])
+            layers.extend([nn.Linear(event_size, hidden_size), nonlinearity()])
             for _ in range(block_size - 2):
-                layers.extend([nn.Linear(hidden_size, hidden_size), nn.ReLU()])
+                layers.extend([nn.Linear(hidden_size, hidden_size), nonlinearity()])
             layers.extend([nn.Linear(hidden_size, event_size)])
             self.sequential = nn.Sequential(*layers)
@@ -261,6 +266,7 @@ def __init__(self,
                  n_hidden: int = None,
                  n_layers: int = 3,
                  block_size: int = 2,
+                 nonlinearity: Type[nn.Module] = nn.ReLU,
                  **kwargs):
         super().__init__(
             input_event_shape=input_event_shape,
@@ -275,12 +281,76 @@ def __init__(self,
         if n_layers <= 2:
             raise ValueError(f"Number of layers in ResidualFeedForward must be at least 3, but found {n_layers}")
 
-        layers = [nn.Linear(self.n_input_event_dims, n_hidden), nn.ReLU()]
+        layers = [nn.Linear(self.n_input_event_dims, n_hidden), nonlinearity()]
         for _ in range(n_layers - 2):
-            layers.append(self.ResidualBlock(n_hidden, n_hidden, block_size))
+            layers.append(self.ResidualBlock(n_hidden, n_hidden, block_size, nonlinearity=nonlinearity))
         layers.append(nn.Linear(n_hidden, self.n_predicted_parameters))
         layers.append(nn.Unflatten(dim=-1, unflattened_size=self.parameter_shape))
         self.sequential = nn.Sequential(*layers)
 
     def predict_theta_flat(self, x: torch.Tensor, context: torch.Tensor = None):
         return self.sequential(self.context_combiner(x, context))
+
+
+class CombinedConditioner(nn.Module):
+    """
+    Class that uses two different conditioners (each acting on different dimensions) to predict transformation
+    parameters. The two parameter predictions are summed into a single vector.
+    """
+
+    def __init__(self,
+                 conditioner1: ConditionerTransform,
+                 conditioner2: ConditionerTransform,
+                 conditioner1_input_mask: torch.Tensor,
+                 conditioner2_input_mask: torch.Tensor):
+        super().__init__()
+        self.conditioner1 = conditioner1
+        self.conditioner2 = conditioner2
+        self.mask1 = conditioner1_input_mask
+        self.mask2 = conditioner2_input_mask
+
+    def forward(self, x: torch.Tensor, context: torch.Tensor = None):
+        h1 = self.conditioner1(x[..., self.mask1], context)
+        h2 = self.conditioner2(x[..., self.mask2], context)
+        return h1 + h2
+
+    def regularization(self):
+        return self.conditioner1.regularization() + self.conditioner2.regularization()
+
+
+class RegularizedCombinedConditioner(CombinedConditioner):
+    def __init__(self,
+                 conditioner1: ConditionerTransform,
+                 conditioner2: ConditionerTransform,
+                 conditioner1_input_mask: torch.Tensor,
+                 conditioner2_input_mask: torch.Tensor,
+                 regularization_coefficient_1: float,
+                 regularization_coefficient_2: float):
+        super().__init__(
+            conditioner1,
+            conditioner2,
+            conditioner1_input_mask,
+            conditioner2_input_mask
+        )
+        self.c1 = regularization_coefficient_1
+        self.c2 = regularization_coefficient_2
+
+    def regularization(self):
+        return self.c1 * self.conditioner1.regularization() + self.c2 * self.conditioner2.regularization()
+
+
+class RegularizedGraphicalConditioner(RegularizedCombinedConditioner):
+    def __init__(self,
+                 interacting_dimensions_conditioner: ConditionerTransform,
+                 auxiliary_dimensions_conditioner: ConditionerTransform,
+                 interacting_dimensions_mask: torch.Tensor,
+                 auxiliary_dimensions_mask: torch.Tensor,
+                 coefficient: float = 0.1):
+        super().__init__(
+            interacting_dimensions_conditioner,
+            auxiliary_dimensions_conditioner,
+            interacting_dimensions_mask,
+            auxiliary_dimensions_mask,
+            regularization_coefficient_1=0.0,
regularization_coefficient_2=coefficient + ) diff --git a/normalizing_flows/bijections/finite/autoregressive/layers.py b/normalizing_flows/bijections/finite/autoregressive/layers.py index f5ba2c7..85e3bf1 100644 --- a/normalizing_flows/bijections/finite/autoregressive/layers.py +++ b/normalizing_flows/bijections/finite/autoregressive/layers.py @@ -1,7 +1,9 @@ +from typing import Tuple, List + import torch from normalizing_flows.bijections.finite.autoregressive.conditioning.transforms import FeedForward -from normalizing_flows.bijections.finite.autoregressive.conditioning.coupling_masks import HalfSplit +from normalizing_flows.bijections.finite.autoregressive.conditioning.coupling_masks import make_coupling from normalizing_flows.bijections.finite.autoregressive.layers_base import MaskedAutoregressiveBijection, \ InverseMaskedAutoregressiveBijection, ElementwiseBijection, CouplingBijection from normalizing_flows.bijections.finite.autoregressive.transformers.linear.affine import Scale, Affine, Shift @@ -16,7 +18,6 @@ from normalizing_flows.bijections.base import invert -# TODO move elementwise bijections, coupling bijections, and masked autoregressive bijections into separate files. class ElementwiseAffine(ElementwiseBijection): def __init__(self, event_shape, **kwargs): transformer = Affine(event_shape, **kwargs) @@ -45,52 +46,64 @@ class AffineCoupling(CouplingBijection): def __init__(self, event_shape: torch.Size, context_shape: torch.Size = None, + edge_list: List[Tuple[int, int]] = None, + coupling_kwargs: dict = None, **kwargs): if event_shape == (1,): raise ValueError - coupling_mask = HalfSplit(event_shape) - transformer = Affine(event_shape=torch.Size((coupling_mask.transformed_event_size,))) + if coupling_kwargs is None: + coupling_kwargs = dict() + coupling = make_coupling(event_shape, edge_list, **coupling_kwargs) + transformer = Affine(event_shape=torch.Size((coupling.target_event_size,))) conditioner_transform = FeedForward( - input_event_shape=torch.Size((coupling_mask.constant_event_size,)), + input_event_shape=torch.Size((coupling.source_event_size,)), parameter_shape=torch.Size(transformer.parameter_shape), context_shape=context_shape, **kwargs ) - super().__init__(transformer, coupling_mask, conditioner_transform) + super().__init__(transformer, coupling, conditioner_transform) class InverseAffineCoupling(CouplingBijection): def __init__(self, event_shape: torch.Size, context_shape: torch.Size = None, + edge_list: List[Tuple[int, int]] = None, + coupling_kwargs: dict = None, **kwargs): if event_shape == (1,): raise ValueError - coupling_mask = HalfSplit(event_shape) - transformer = Affine(event_shape=torch.Size((coupling_mask.transformed_event_size,))).invert() + if coupling_kwargs is None: + coupling_kwargs = dict() + coupling = make_coupling(event_shape, edge_list, **coupling_kwargs) + transformer = invert(Affine(event_shape=torch.Size((coupling.target_event_size,)))) conditioner_transform = FeedForward( - input_event_shape=torch.Size((coupling_mask.constant_event_size,)), + input_event_shape=torch.Size((coupling.source_event_size,)), parameter_shape=torch.Size(transformer.parameter_shape), context_shape=context_shape, **kwargs ) - super().__init__(transformer, coupling_mask, conditioner_transform) + super().__init__(transformer, coupling, conditioner_transform) class ShiftCoupling(CouplingBijection): def __init__(self, event_shape: torch.Size, context_shape: torch.Size = None, + edge_list: List[Tuple[int, int]] = None, + coupling_kwargs: dict = None, **kwargs): - 
coupling_mask = HalfSplit(event_shape) - transformer = Shift(event_shape=torch.Size((coupling_mask.transformed_event_size,))) + if coupling_kwargs is None: + coupling_kwargs = dict() + coupling = make_coupling(event_shape, edge_list, **coupling_kwargs) + transformer = Shift(event_shape=torch.Size((coupling.target_event_size,))) conditioner_transform = FeedForward( - input_event_shape=torch.Size((coupling_mask.constant_event_size,)), + input_event_shape=torch.Size((coupling.source_event_size,)), parameter_shape=torch.Size(transformer.parameter_shape), context_shape=context_shape, **kwargs ) - super().__init__(transformer, coupling_mask, conditioner_transform) + super().__init__(transformer, coupling, conditioner_transform) class LRSCoupling(CouplingBijection): @@ -98,17 +111,21 @@ def __init__(self, event_shape: torch.Size, context_shape: torch.Size = None, n_bins: int = 8, + edge_list: List[Tuple[int, int]] = None, + coupling_kwargs: dict = None, **kwargs): assert n_bins >= 1 - coupling_mask = HalfSplit(event_shape) - transformer = LinearRational(event_shape=torch.Size((coupling_mask.transformed_event_size,)), n_bins=n_bins) + if coupling_kwargs is None: + coupling_kwargs = dict() + coupling = make_coupling(event_shape, edge_list, **coupling_kwargs) + transformer = LinearRational(event_shape=torch.Size((coupling.target_event_size,)), n_bins=n_bins) conditioner_transform = FeedForward( - input_event_shape=torch.Size((coupling_mask.constant_event_size,)), + input_event_shape=torch.Size((coupling.source_event_size,)), parameter_shape=torch.Size(transformer.parameter_shape), context_shape=context_shape, **kwargs ) - super().__init__(transformer, coupling_mask, conditioner_transform) + super().__init__(transformer, coupling, conditioner_transform) class RQSCoupling(CouplingBijection): @@ -116,16 +133,20 @@ def __init__(self, event_shape: torch.Size, context_shape: torch.Size = None, n_bins: int = 8, + edge_list: List[Tuple[int, int]] = None, + coupling_kwargs: dict = None, **kwargs): - coupling_mask = HalfSplit(event_shape) - transformer = RationalQuadratic(event_shape=torch.Size((coupling_mask.transformed_event_size,)), n_bins=n_bins) + if coupling_kwargs is None: + coupling_kwargs = dict() + coupling = make_coupling(event_shape, edge_list, **coupling_kwargs) + transformer = RationalQuadratic(event_shape=torch.Size((coupling.target_event_size,)), n_bins=n_bins) conditioner_transform = FeedForward( - input_event_shape=torch.Size((coupling_mask.constant_event_size,)), + input_event_shape=torch.Size((coupling.source_event_size,)), parameter_shape=torch.Size(transformer.parameter_shape), context_shape=context_shape, **kwargs ) - super().__init__(transformer, coupling_mask, conditioner_transform) + super().__init__(transformer, coupling, conditioner_transform) class DSCoupling(CouplingBijection): @@ -133,21 +154,25 @@ def __init__(self, event_shape: torch.Size, context_shape: torch.Size = None, n_hidden_layers: int = 2, + edge_list: List[Tuple[int, int]] = None, + coupling_kwargs: dict = None, **kwargs): - coupling_mask = HalfSplit(event_shape) + if coupling_kwargs is None: + coupling_kwargs = dict() + coupling = make_coupling(event_shape, edge_list, **coupling_kwargs) transformer = DeepSigmoid( - event_shape=torch.Size((coupling_mask.transformed_event_size,)), + event_shape=torch.Size((coupling.target_event_size,)), n_hidden_layers=n_hidden_layers ) # Parameter order: [c1, c2, c3, c4, ..., ck] for all components # Each component has parameter order [a_unc, b, w_unc] conditioner_transform = 
FeedForward( - input_event_shape=torch.Size((coupling_mask.constant_event_size,)), + input_event_shape=torch.Size((coupling.source_event_size,)), parameter_shape=torch.Size(transformer.parameter_shape), context_shape=context_shape, **kwargs ) - super().__init__(transformer, coupling_mask, conditioner_transform) + super().__init__(transformer, coupling, conditioner_transform) class LinearAffineCoupling(AffineCoupling): diff --git a/normalizing_flows/bijections/finite/autoregressive/layers_base.py b/normalizing_flows/bijections/finite/autoregressive/layers_base.py index 6562dda..fe20d30 100644 --- a/normalizing_flows/bijections/finite/autoregressive/layers_base.py +++ b/normalizing_flows/bijections/finite/autoregressive/layers_base.py @@ -1,11 +1,11 @@ -from typing import Tuple, Union +from typing import Tuple, Union, Type import torch import torch.nn as nn from normalizing_flows.bijections.finite.autoregressive.conditioning.transforms import ConditionerTransform, \ MADE -from normalizing_flows.bijections.finite.autoregressive.conditioning.coupling_masks import CouplingMask +from normalizing_flows.bijections.finite.autoregressive.conditioning.coupling_masks import PartialCoupling from normalizing_flows.bijections.finite.autoregressive.transformers.base import TensorTransformer, ScalarTransformer from normalizing_flows.bijections.base import Bijection from normalizing_flows.utils import flatten_event, unflatten_event, get_batch_shape @@ -31,6 +31,9 @@ def inverse(self, z: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch. x, log_det = self.transformer.inverse(z, h) return x, log_det + def regularization(self): + return self.conditioner_transform.regularization() if self.conditioner_transform is not None else 0.0 + class CouplingBijection(AutoregressiveBijection): """ @@ -52,14 +55,20 @@ class CouplingBijection(AutoregressiveBijection): def __init__(self, transformer: TensorTransformer, - coupling_mask: CouplingMask, + coupling: PartialCoupling, conditioner_transform: ConditionerTransform, **kwargs): - super().__init__(coupling_mask.event_shape, transformer, conditioner_transform, **kwargs) - self.coupling_mask = coupling_mask + super().__init__(coupling.event_shape, transformer, conditioner_transform, **kwargs) + self.coupling = coupling + + def get_constant_part(self, x: torch.Tensor) -> torch.Tensor: + return x[..., self.coupling.source_mask] + + def get_transformed_part(self, x: torch.Tensor) -> torch.Tensor: + return x[..., self.coupling.target_mask] - assert conditioner_transform.input_event_shape == (coupling_mask.constant_event_size,) - assert transformer.event_shape == (self.coupling_mask.transformed_event_size,) + def set_transformed_part(self, x: torch.Tensor, x_transformed: torch.Tensor): + x[..., self.coupling.target_mask] = x_transformed def partition_and_predict_parameters(self, x: torch.Tensor, context: torch.Tensor): """ @@ -70,20 +79,28 @@ def partition_and_predict_parameters(self, x: torch.Tensor, context: torch.Tenso :return: parameter tensor h with h.shape = (*batch_shape, *parameter_shape). 
""" # Predict transformer parameters for output dimensions - x_a = x[..., self.coupling_mask.mask] # (*b, constant_event_size) + x_a = self.get_constant_part(x) # (*b, constant_event_size) h_b = self.conditioner_transform(x_a, context=context) # (*b, *p) return h_b def forward(self, x: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]: z = x.clone() h_b = self.partition_and_predict_parameters(x, context) - z[..., ~self.coupling_mask.mask], log_det = self.transformer.forward(x[..., ~self.coupling_mask.mask], h_b) + z_transformed, log_det = self.transformer.forward( + self.get_transformed_part(x), + h_b + ) + self.set_transformed_part(z, z_transformed) return z, log_det def inverse(self, z: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]: x = z.clone() h_b = self.partition_and_predict_parameters(x, context) - x[..., ~self.coupling_mask.mask], log_det = self.transformer.inverse(z[..., ~self.coupling_mask.mask], h_b) + x_transformed, log_det = self.transformer.inverse( + self.get_transformed_part(z), + h_b + ) + self.set_transformed_part(x, x_transformed) return x, log_det diff --git a/normalizing_flows/bijections/finite/multiscale/__init__.py b/normalizing_flows/bijections/finite/multiscale/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/normalizing_flows/bijections/finite/multiscale/architectures.py b/normalizing_flows/bijections/finite/multiscale/architectures.py new file mode 100644 index 0000000..dd6613d --- /dev/null +++ b/normalizing_flows/bijections/finite/multiscale/architectures.py @@ -0,0 +1,226 @@ +import torch + +from normalizing_flows.bijections.finite.autoregressive.layers import ElementwiseAffine +from normalizing_flows.bijections.finite.autoregressive.transformers.linear.affine import Affine, Shift +from normalizing_flows.bijections.finite.autoregressive.transformers.spline.rational_quadratic import RationalQuadratic +from normalizing_flows.bijections.finite.autoregressive.transformers.spline.linear import Linear as LinearRational +from normalizing_flows.bijections.finite.autoregressive.transformers.combination.sigmoid import ( + DeepSigmoid, + DeepDenseSigmoid, + DenseSigmoid +) +from normalizing_flows.bijections import BijectiveComposition +from normalizing_flows.bijections.finite.multiscale.base import MultiscaleBijection, FactoredBijection + + +def check_image_shape_for_multiscale_flow(event_shape, n_layers): + if len(event_shape) != 3: + raise ValueError("Multichannel image transformation are only possible for inputs with 3 axes.") + if event_shape[1] % 2 != 0 or event_shape[2] % 2 != 0: + raise ValueError("Image height and width must be divisible by 2.") + if n_layers is not None and n_layers < 1: + raise ValueError("Need at least one layer for multiscale flow.") + + # Check image height and width + if n_layers is not None: + if event_shape[1] % (2 ** n_layers) != 0: + raise ValueError("Image height must be divisible by pow(2, n_layers).") + elif event_shape[2] % (2 ** n_layers) != 0: + raise ValueError("Image width must be divisible by pow(2, n_layers).") + + +def automatically_determine_n_layers(event_shape): + if event_shape[1] % (2 ** 3) == 0 and event_shape[2] % (2 ** 3) == 0: + # Try using 3 layers + n_layers = 3 + elif event_shape[1] % (2 ** 2) == 0 and event_shape[2] % (2 ** 2) == 0: + # Try using 2 layers + n_layers = 2 + elif event_shape[1] % 2 == 0 and event_shape[2] % 2 == 0: + n_layers = 1 + else: + raise ValueError("Image height and width must be divisible by 2.") + return 
n_layers + + +def make_factored_image_layers(event_shape, + transformer_class, + n_layers: int = None): + """ + Creates a list of image transformations consisting of coupling layers and squeeze layers. + After each coupling, squeeze, coupling mapping, half of the channels are kept as is (not transformed anymore). + + :param event_shape: (c, 2^n, 2^m). + :param transformer_class: + :param n_layers: + :return: + """ + check_image_shape_for_multiscale_flow(event_shape, n_layers) + if n_layers is None: + n_layers = automatically_determine_n_layers(event_shape) + check_image_shape_for_multiscale_flow(event_shape, n_layers) + + def recursive_layer_builder(event_shape_, n_layers_): + msb = MultiscaleBijection( + input_event_shape=event_shape_, + transformer_class=transformer_class + ) + if n_layers_ == 1: + return msb + + c, h, w = msb.transformed_shape # c is a multiple of 4 after squeezing + + small_bijection_shape = (c // 2, h, w) + small_bijection_mask = (torch.arange(c) >= c // 2)[:, None, None].repeat(1, h, w) + fb = FactoredBijection( + event_shape=(c, h, w), + small_bijection=recursive_layer_builder( + event_shape_=small_bijection_shape, + n_layers_=n_layers_ - 1 + ), + small_bijection_mask=small_bijection_mask + ) + composition = BijectiveComposition( + event_shape=msb.event_shape, + layers=[msb, fb] + ) + composition.transformed_shape = fb.transformed_shape + return composition + + bijections = [ElementwiseAffine(event_shape=event_shape)] + bijections.append(recursive_layer_builder(bijections[-1].transformed_shape, n_layers)) + bijections.append(ElementwiseAffine(event_shape=bijections[-1].transformed_shape)) + return bijections + + +def make_image_layers_non_factored(event_shape, + transformer_class, + n_layers: int = None): + """ + Returns a list of bijections for transformations of images with multiple channels. + + Let n be the number of layers. This sequence of bijections takes as input an image with shape (c, h, w) and outputs + an image with shape (4 ** n * c, h / 2 ** n, w / 2 ** n). We require h and w to be divisible by 2 ** n. 
+ """ + check_image_shape_for_multiscale_flow(event_shape, n_layers) + if n_layers is None: + n_layers = automatically_determine_n_layers(event_shape) + check_image_shape_for_multiscale_flow(event_shape, n_layers) + + bijections = [ElementwiseAffine(event_shape=event_shape)] + for _ in range(n_layers - 1): + bijections.append( + MultiscaleBijection( + input_event_shape=bijections[-1].transformed_shape, + transformer_class=transformer_class + ) + ) + bijections.append( + MultiscaleBijection( + input_event_shape=bijections[-1].transformed_shape, + transformer_class=transformer_class, + n_checkerboard_layers=4, + squeeze_layer=False, + n_channel_wise_layers=0 + ) + ) + bijections.append(ElementwiseAffine(event_shape=bijections[-1].transformed_shape)) + return bijections + + +def make_image_layers(*args, factored: bool = False, **kwargs): + if factored: + return make_factored_image_layers(*args, **kwargs) + else: + return make_image_layers_non_factored(*args, **kwargs) + + +class MultiscaleRealNVP(BijectiveComposition): + def __init__(self, + event_shape, + n_layers: int = None, + factored: bool = False, + **kwargs): + if isinstance(event_shape, int): + event_shape = (event_shape,) + bijections = make_image_layers(event_shape, Affine, n_layers, factored=factored) + super().__init__(event_shape, bijections, **kwargs) + self.transformed_shape = bijections[-1].transformed_shape + + +class MultiscaleNICE(BijectiveComposition): + def __init__(self, + event_shape, + n_layers: int = None, + factored: bool = False, + **kwargs): + if isinstance(event_shape, int): + event_shape = (event_shape,) + bijections = make_image_layers(event_shape, Shift, n_layers, factored=factored) + super().__init__(event_shape, bijections, **kwargs) + self.transformed_shape = bijections[-1].transformed_shape + + +class MultiscaleRQNSF(BijectiveComposition): + def __init__(self, + event_shape, + n_layers: int = None, + factored: bool = False, + **kwargs): + if isinstance(event_shape, int): + event_shape = (event_shape,) + bijections = make_image_layers(event_shape, RationalQuadratic, n_layers, factored=factored) + super().__init__(event_shape, bijections, **kwargs) + self.transformed_shape = bijections[-1].transformed_shape + + +class MultiscaleLRSNSF(BijectiveComposition): + def __init__(self, + event_shape, + n_layers: int = None, + factored: bool = False, + **kwargs): + if isinstance(event_shape, int): + event_shape = (event_shape,) + bijections = make_image_layers(event_shape, LinearRational, n_layers, factored=factored) + super().__init__(event_shape, bijections, **kwargs) + self.transformed_shape = bijections[-1].transformed_shape + + +class MultiscaleDeepSigmoid(BijectiveComposition): + def __init__(self, + event_shape, + n_layers: int = None, + factored: bool = False, + **kwargs): + if isinstance(event_shape, int): + event_shape = (event_shape,) + bijections = make_image_layers(event_shape, DeepSigmoid, n_layers, factored=factored) + super().__init__(event_shape, bijections, **kwargs) + self.transformed_shape = bijections[-1].transformed_shape + + +class MultiscaleDeepDenseSigmoid(BijectiveComposition): + def __init__(self, + event_shape, + n_layers: int = None, + factored: bool = False, + **kwargs): + if isinstance(event_shape, int): + event_shape = (event_shape,) + bijections = make_image_layers(event_shape, DeepDenseSigmoid, n_layers, factored=factored) + super().__init__(event_shape, bijections, **kwargs) + self.transformed_shape = bijections[-1].transformed_shape + + +class 
MultiscaleDenseSigmoid(BijectiveComposition): + def __init__(self, + event_shape, + n_layers: int = None, + factored: bool = False, + **kwargs): + if isinstance(event_shape, int): + event_shape = (event_shape,) + bijections = make_image_layers(event_shape, DenseSigmoid, n_layers, factored=factored) + super().__init__(event_shape, bijections, **kwargs) + self.transformed_shape = bijections[-1].transformed_shape diff --git a/normalizing_flows/bijections/finite/multiscale/base.py b/normalizing_flows/bijections/finite/multiscale/base.py new file mode 100644 index 0000000..4298710 --- /dev/null +++ b/normalizing_flows/bijections/finite/multiscale/base.py @@ -0,0 +1,258 @@ +from typing import Type, Union, Tuple + +import torch + +from normalizing_flows.bijections import BijectiveComposition +from normalizing_flows.bijections.finite.autoregressive.conditioning.transforms import ConditionerTransform +from normalizing_flows.bijections.base import Bijection +from normalizing_flows.bijections.finite.autoregressive.layers_base import CouplingBijection +from normalizing_flows.bijections.finite.autoregressive.transformers.base import TensorTransformer +from normalizing_flows.bijections.finite.multiscale.coupling import make_image_coupling, Checkerboard, \ + ChannelWiseHalfSplit +from normalizing_flows.neural_networks.convnet import ConvNet +from normalizing_flows.utils import get_batch_shape + + +class FactoredBijection(Bijection): + """ + Factored bijection class. + + Partitions the input tensor x into parts x_A and x_B, then applies a bijection to x_A independently of x_B while + keeping x_B identical. + """ + + def __init__(self, + event_shape: Union[torch.Size, Tuple[int, ...]], + small_bijection: Bijection, + small_bijection_mask: torch.Tensor, + **kwargs): + """ + + :param event_shape: shape of input event x. + :param small_bijection: bijection applied to transformed event x_A. + :param small_bijection_mask: boolean mask that selects which elements of event x correspond to the transformed + event x_A. + :param kwargs: + """ + super().__init__(event_shape, **kwargs) + + # Check that shapes are correct + event_size = torch.prod(torch.as_tensor(event_shape)) + transformed_event_size = torch.prod(torch.as_tensor(small_bijection.event_shape)) + assert event_size >= transformed_event_size + + assert small_bijection_mask.shape == event_shape + + self.transformed_event_mask = small_bijection_mask + self.small_bijection = small_bijection + + def forward(self, x: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]: + batch_shape = get_batch_shape(x, self.event_shape) + transformed, log_det = self.small_bijection.forward( + x[..., self.transformed_event_mask].view(*batch_shape, *self.small_bijection.event_shape), + context + ) + out = x.clone() + out[..., self.transformed_event_mask] = transformed.view(*batch_shape, -1) + return out, log_det + + def inverse(self, z: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]: + batch_shape = get_batch_shape(z, self.event_shape) + transformed, log_det = self.small_bijection.inverse( + z[..., self.transformed_event_mask].view(*batch_shape, *self.small_bijection.transformed_shape), + context + ) + out = z.clone() + out[..., self.transformed_event_mask] = transformed.view(*batch_shape, -1) + return out, log_det + + +class ConvNetConditioner(ConditionerTransform): + def __init__(self, + input_event_shape: torch.Size, + parameter_shape: torch.Size, + kernels: Tuple[int, ...] 
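+
+
+# Illustrative usage sketch (hypothetical values; ElementwiseAffine is an arbitrary choice of inner
+# bijection and would need to be imported from
+# normalizing_flows.bijections.finite.autoregressive.layers). It factors out the first two channels
+# of a (4, 8, 8) image event, mirroring the mask construction in make_factored_image_layers:
+#
+#     mask = (torch.arange(4) >= 2)[:, None, None].repeat(1, 8, 8)
+#     factored = FactoredBijection(
+#         event_shape=(4, 8, 8),
+#         small_bijection=ElementwiseAffine(event_shape=(2, 8, 8)),
+#         small_bijection_mask=mask,
+#     )
+#     z, log_det = factored.forward(torch.randn(100, 4, 8, 8))  # the unmasked part passes through unchanged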
+class ConvNetConditioner(ConditionerTransform):
+    def __init__(self,
+                 input_event_shape: torch.Size,
+                 parameter_shape: torch.Size,
+                 kernels: Tuple[int, ...] = None,
+                 **kwargs):
+        super().__init__(
+            input_event_shape=input_event_shape,
+            context_shape=None,
+            parameter_shape=parameter_shape,
+            **kwargs
+        )
+        self.network = ConvNet(
+            input_shape=input_event_shape,
+            n_outputs=self.n_transformer_parameters,
+            kernels=kernels,
+        )
+
+    def predict_theta_flat(self, x: torch.Tensor, context: torch.Tensor = None) -> torch.Tensor:
+        return self.network(x)
+
+
+class ConvolutionalCouplingBijection(CouplingBijection):
+    def __init__(self,
+                 transformer: TensorTransformer,
+                 coupling: Union[Checkerboard, ChannelWiseHalfSplit],
+                 **kwargs):
+        conditioner_transform = ConvNetConditioner(
+            input_event_shape=coupling.constant_shape,
+            parameter_shape=transformer.parameter_shape,
+            **kwargs
+        )
+        super().__init__(transformer, coupling, conditioner_transform, **kwargs)
+        self.coupling = coupling
+
+    def get_constant_part(self, x: torch.Tensor) -> torch.Tensor:
+        """
+
+        :param x: tensor with shape (*b, channels, height, width).
+        :return: tensor with shape (*b, constant_channels, constant_height, constant_width).
+        """
+        batch_shape = get_batch_shape(x, self.event_shape)
+        return x[..., self.coupling.source_mask].view(*batch_shape, *self.coupling.constant_shape)
+
+    def get_transformed_part(self, x: torch.Tensor) -> torch.Tensor:
+        """
+
+        :param x: tensor with shape (*b, channels, height, width).
+        :return: tensor with shape (*b, transformed_channels, transformed_height, transformed_width).
+        """
+        batch_shape = get_batch_shape(x, self.event_shape)
+        return x[..., self.coupling.target_mask].view(*batch_shape, *self.coupling.transformed_shape)
+
+    def set_transformed_part(self, x: torch.Tensor, x_transformed: torch.Tensor):
+        """
+
+        :param x: tensor with shape (*b, channels, height, width).
+        :param x_transformed: tensor with shape (*b, transformed_channels, transformed_height, transformed_width).
+        """
+        batch_shape = get_batch_shape(x, self.event_shape)
+        # Write the transformed values back into the masked positions of x
+        x[..., self.coupling.target_mask] = x_transformed.reshape(*batch_shape, -1)
+
+    def partition_and_predict_parameters(self, x: torch.Tensor, context: torch.Tensor):
+        batch_shape = get_batch_shape(x, self.event_shape)
+        super_out = super().partition_and_predict_parameters(x, context)
+        return super_out.view(*batch_shape, *self.coupling.transformed_shape,
+                              *self.transformer.parameter_shape_per_element)
+
+
+class CheckerboardCoupling(ConvolutionalCouplingBijection):
+    def __init__(self,
+                 event_shape,
+                 transformer_class: Type[TensorTransformer],
+                 alternate: bool = False,
+                 **kwargs):
+        coupling = make_image_coupling(
+            event_shape,
+            coupling_type='checkerboard' if not alternate else 'checkerboard_inverted'
+        )
+        transformer = transformer_class(event_shape=coupling.transformed_shape)
+        super().__init__(transformer, coupling, **kwargs)
+
+
+class ChannelWiseCoupling(ConvolutionalCouplingBijection):
+    def __init__(self,
+                 event_shape,
+                 transformer_class: Type[TensorTransformer],
+                 alternate: bool = False,
+                 **kwargs):
+        coupling = make_image_coupling(
+            event_shape,
+            coupling_type='channel_wise' if not alternate else 'channel_wise_inverted'
+        )
+        transformer = transformer_class(event_shape=coupling.transformed_shape)
+        super().__init__(transformer, coupling, **kwargs)
+
+
+class Squeeze(Bijection):
+    """
+    Squeeze a batch of tensors with shape (*batch_shape, channels, height, width) into shape
+    (*batch_shape, 4 * channels, height // 2, width // 2).
+ """ + + def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], **kwargs): + # Check shape length + if len(event_shape) != 3: + raise ValueError(f"Event shape must have three components, but got {len(event_shape)}") + # Check that height and width are divisible by two + if event_shape[1] % 2 != 0: + raise ValueError(f"Event dimension 1 must be divisible by 2, but got {event_shape[1]}") + if event_shape[2] % 2 != 0: + raise ValueError(f"Event dimension 2 must be divisible by 2, but got {event_shape[2]}") + super().__init__(event_shape, **kwargs) + c, h, w = event_shape + self.transformed_event_shape = torch.Size((4 * c, h // 2, w // 2)) + + def forward(self, x: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Squeeze tensor with shape (*batch_shape, channels, height, width) into tensor with shape + (*batch_shape, 4 * channels, height // 2, width // 2). + """ + batch_shape = get_batch_shape(x, self.event_shape) + log_det = torch.zeros(*batch_shape, device=x.device, dtype=x.dtype) + + channels, height, width = x.shape[-3:] + assert height % 2 == 0 + assert width % 2 == 0 + + out = torch.concatenate([ + x[..., ::2, ::2], + x[..., ::2, 1::2], + x[..., 1::2, ::2], + x[..., 1::2, 1::2] + ], dim=-3) + return out, log_det + + def inverse(self, z: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Squeeze tensor with shape (*batch_shape, 4 * channels, height // 2, width // 2) into tensor with shape + (*batch_shape, channels, height, width). + """ + batch_shape = get_batch_shape(z, self.transformed_event_shape) + log_det = torch.zeros(*batch_shape, device=z.device, dtype=z.dtype) + + four_channels, half_height, half_width = z.shape[-3:] + assert four_channels % 4 == 0 + width = 2 * half_width + height = 2 * half_height + channels = four_channels // 4 + + out = torch.empty(size=(*batch_shape, channels, height, width), device=z.device, dtype=z.dtype) + out[..., ::2, ::2] = z[..., 0:channels, :, :] + out[..., ::2, 1::2] = z[..., channels:2 * channels, :, :] + out[..., 1::2, ::2] = z[..., 2 * channels:3 * channels, :, :] + out[..., 1::2, 1::2] = z[..., 3 * channels:4 * channels, :, :] + return out, log_det + + +class MultiscaleBijection(BijectiveComposition): + def __init__(self, + input_event_shape, + transformer_class: Type[TensorTransformer], + n_checkerboard_layers: int = 3, + n_channel_wise_layers: int = 3, + use_squeeze_layer: bool = True, + **kwargs): + checkerboard_layers = [ + CheckerboardCoupling( + input_event_shape, + transformer_class, + alternate=i % 2 == 1 + ) + for i in range(n_checkerboard_layers) + ] + squeeze_layer = Squeeze(input_event_shape) + channel_wise_layers = [ + ChannelWiseCoupling( + squeeze_layer.transformed_event_shape, + transformer_class, + alternate=i % 2 == 1 + ) + for i in range(n_channel_wise_layers) + ] + if use_squeeze_layer: + layers = [*checkerboard_layers, squeeze_layer, *channel_wise_layers] + else: + layers = [*checkerboard_layers, *channel_wise_layers] + super().__init__(input_event_shape, layers, **kwargs) + self.transformed_shape = squeeze_layer.transformed_event_shape if use_squeeze_layer else input_event_shape diff --git a/normalizing_flows/bijections/finite/multiscale/coupling.py b/normalizing_flows/bijections/finite/multiscale/coupling.py new file mode 100644 index 0000000..33bf6cd --- /dev/null +++ b/normalizing_flows/bijections/finite/multiscale/coupling.py @@ -0,0 +1,79 @@ +import torch + +from 
normalizing_flows.bijections.finite.autoregressive.conditioning.coupling_masks import Coupling
+
+
+class Checkerboard(Coupling):
+    """
+    Checkerboard coupling for image data.
+    """
+
+    def __init__(self, event_shape, invert: bool = False):
+        """
+        :param event_shape: image shape with the form (n_channels, height, width). Note: height and width must be
+            even.
+        :param invert: invert the checkerboard mask.
+        """
+        channels, height, width = event_shape
+        # (i + j) parity alternates along both rows and columns, giving a proper checkerboard pattern
+        # (indexing with arange(height * width) % 2 would collapse into column stripes for even widths)
+        mask = ((torch.arange(height)[:, None] + torch.arange(width)[None, :]) % 2).bool()
+        mask = mask[None].repeat(channels, 1, 1)  # (channels, height, width)
+        if invert:
+            mask = ~mask
+        super().__init__(event_shape, mask)
+
+    @property
+    def constant_shape(self):
+        n_channels, height, width = self.event_shape
+        return n_channels, height // 2, width  # rectangular shape
+
+    @property
+    def transformed_shape(self):
+        return self.constant_shape
+
+
+class ChannelWiseHalfSplit(Coupling):
+    """
+    Channel-wise coupling for image data.
+    """
+
+    def __init__(self, event_shape, invert: bool = False):
+        """
+        :param event_shape: image shape with the form (n_channels, height, width).
+        :param invert: invert the channel mask.
+        """
+        n_channels, height, width = event_shape
+        mask = torch.arange(n_channels) < (n_channels // 2)
+        mask = mask[:, None, None].repeat(1, height, width)  # (channels, height, width)
+        if invert:
+            mask = ~mask
+        super().__init__(event_shape, mask)
+
+    @property
+    def constant_shape(self):
+        n_channels, height, width = self.event_shape
+        return n_channels // 2, height, width
+
+    @property
+    def transformed_shape(self):
+        n_channels, height, width = self.event_shape
+        return n_channels - n_channels // 2, height, width
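+
+
+# Illustrative sketch of the two mask types on a hypothetical (3, 4, 4) image event:
+#
+#     checkerboard = Checkerboard(event_shape=(3, 4, 4))
+#     checkerboard.constant_shape     # (3, 2, 4): half of the pixels in every channel
+#     checkerboard.transformed_shape  # (3, 2, 4): the complementary pixels
+#
+#     channel_wise = ChannelWiseHalfSplit(event_shape=(3, 4, 4))
+#     channel_wise.constant_shape     # (1, 4, 4): first n_channels // 2 channels
+#     channel_wise.transformed_shape  # (2, 4, 4): remaining channels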
+def make_image_coupling(event_shape, coupling_type: str, **kwargs):
+    """
+    Create a coupling mask for image events.
+
+    :param event_shape: image shape with the form (n_channels, height, width).
+    :param coupling_type: one of ['checkerboard', 'checkerboard_inverted', 'channel_wise', 'channel_wise_inverted'].
+    :return: coupling mask object.
+    """
+    if coupling_type == 'checkerboard':
+        return Checkerboard(event_shape, invert=False, **kwargs)
+    elif coupling_type == 'checkerboard_inverted':
+        return Checkerboard(event_shape, invert=True, **kwargs)
+    elif coupling_type == 'channel_wise':
+        return ChannelWiseHalfSplit(event_shape, invert=False)
+    elif coupling_type == 'channel_wise_inverted':
+        return ChannelWiseHalfSplit(event_shape, invert=True)
+    else:
+        raise ValueError(f"Unsupported coupling type: {coupling_type}")
diff --git a/normalizing_flows/bijections/finite/multiscale/layers.py b/normalizing_flows/bijections/finite/multiscale/layers.py
new file mode 100644
index 0000000..7a95207
--- /dev/null
+++ b/normalizing_flows/bijections/finite/multiscale/layers.py
@@ -0,0 +1,7 @@
+from normalizing_flows.bijections.finite.multiscale.base import MultiscaleBijection
+from normalizing_flows.bijections.finite.autoregressive.transformers.linear.affine import Scale, Affine, Shift
+
+
+class MultiscaleAffineCoupling(MultiscaleBijection):
+    def __init__(self, input_event_shape, **kwargs):
+        super().__init__(input_event_shape, transformer_class=Affine, **kwargs)
diff --git a/normalizing_flows/flows.py b/normalizing_flows/flows.py
index 444e6f8..f50b1b8 100644
--- a/normalizing_flows/flows.py
+++ b/normalizing_flows/flows.py
@@ -1,33 +1,24 @@
 from copy import deepcopy
-from typing import Union, Tuple
+from typing import Union, Tuple, List
 
+import numpy as np
 import torch
 import torch.nn as nn
-from torch.utils.data import TensorDataset, DataLoader
 from tqdm import tqdm
 
 from normalizing_flows.bijections.base import Bijection
-from normalizing_flows.bijections.continuous.ddnf import DeepDiffeomorphicBijection
-from normalizing_flows.regularization import reconstruction_error
-from normalizing_flows.utils import flatten_event, get_batch_shape, unflatten_event, create_data_loader
+from normalizing_flows.utils import flatten_event, unflatten_event, create_data_loader
 
 
-class Flow(nn.Module):
-    """
-    Normalizing flow class.
-
-    This class represents a bijective transformation of a standard Gaussian distribution (the base distribution).
-    A normalizing flow is itself a distribution which we can sample from or use it to compute the density of inputs.
-    """
-
-    def __init__(self, bijection: Bijection):
-        """
-
-        :param bijection: transformation component of the normalizing flow.
-        """
+class BaseFlow(nn.Module):
+    def __init__(self, event_shape):
         super().__init__()
-        self.register_module('bijection', bijection)
-        self.register_buffer('loc', torch.zeros(self.bijection.n_dim))
-        self.register_buffer('covariance_matrix', torch.eye(self.bijection.n_dim))
+        self.event_shape = event_shape
+        self.event_size = int(torch.prod(torch.as_tensor(event_shape)))
+        self.register_buffer('loc', torch.zeros(self.event_size))
+        self.register_buffer('covariance_matrix', torch.eye(self.event_size))
+
+    def get_device(self):
+        return self.loc.device
 
     @property
     def base(self) -> torch.distributions.Distribution:
@@ -43,7 +34,7 @@ def base_log_prob(self, z: torch.Tensor):
         :param z: input tensor.
         :return: log probability of the input tensor.
         """
-        zf = flatten_event(z, self.bijection.event_shape)
+        zf = flatten_event(z, self.event_shape)
         log_prob = self.base.log_prob(zf)
         return log_prob
 
@@ -55,65 +46,11 @@ def base_sample(self, sample_shape: Union[torch.Size, Tuple[int, ...]]):
         :return: tensor with shape sample_shape.
""" z_flat = self.base.sample(sample_shape) - z = unflatten_event(z_flat, self.bijection.event_shape) + z = unflatten_event(z_flat, self.event_shape) return z - def forward_with_log_prob(self, x: torch.Tensor, context: torch.Tensor = None): - """ - Transform the input x to the space of the base distribution. - - :param x: input tensor. - :param context: context tensor upon which the transformation is conditioned. - :return: transformed tensor and the logarithm of the absolute value of the Jacobian determinant of the - transformation. - """ - if context is not None: - assert context.shape[0] == x.shape[0] - context = context.to(self.loc) - z, log_det = self.bijection.forward(x.to(self.loc), context=context) - log_base = self.base_log_prob(z) - return z, log_base + log_det - - def log_prob(self, x: torch.Tensor, context: torch.Tensor = None): - """ - Compute the logarithm of the probability density of input x according to the normalizing flow. - - :param x: input tensor. - :param context: context tensor. - :return: - """ - return self.forward_with_log_prob(x, context)[1] - - def sample(self, n: int, context: torch.Tensor = None, no_grad: bool = False, return_log_prob: bool = False): - """ - Sample from the normalizing flow. - - If context given, sample n tensors for each context tensor. - Otherwise, sample n tensors. - - :param n: number of tensors to sample. - :param context: context tensor with shape c. - :param no_grad: if True, do not track gradients in the inverse pass. - :return: samples with shape (n, *event_shape) if no context given or (n, *c, *event_shape) if context given. - """ - if context is not None: - z = self.base_sample(sample_shape=torch.Size((n, len(context)))) - context = context[None].repeat(*[n, *([1] * len(context.shape))]) # Make context shape match z shape - assert z.shape[:2] == context.shape[:2] - else: - z = self.base_sample(sample_shape=torch.Size((n,))) - if no_grad: - z = z.detach() - with torch.no_grad(): - x, log_det = self.bijection.inverse(z, context=context) - else: - x, log_det = self.bijection.inverse(z, context=context) - x = x.to(self.loc) - - if return_log_prob: - log_prob = self.base_log_prob(z) + log_det - return x, log_prob - return x + def regularization(self): + return 0.0 def fit(self, x_train: torch.Tensor, @@ -152,10 +89,7 @@ def fit(self, :param early_stopping: if True and validation data is provided, stop the training procedure early once validation loss stops improving for a specified number of consecutive epochs. :param early_stopping_threshold: if early_stopping is True, fitting stops after no improvement in validation loss for this many epochs. 
""" - self.bijection.train() - - # Compute the number of event dimensions - n_event_dims = int(torch.prod(torch.as_tensor(self.bijection.event_shape))) + self.train() # Set the default batch size if batch_size is None: @@ -169,7 +103,7 @@ def fit(self, "training", batch_size=batch_size, shuffle=shuffle, - event_shape=self.bijection.event_shape + event_shape=self.event_shape ) # Process validation data @@ -181,7 +115,7 @@ def fit(self, "validation", batch_size=batch_size, shuffle=shuffle, - event_shape=self.bijection.event_shape + event_shape=self.event_shape ) best_val_loss = torch.inf @@ -192,10 +126,10 @@ def compute_batch_loss(batch_, reduction: callable = torch.mean): batch_x, batch_weights = batch_[:2] batch_context = batch_[2] if len(batch_) == 3 else None - batch_log_prob = self.log_prob(batch_x.to(self.loc), context=batch_context) - batch_weights = batch_weights.to(self.loc) + batch_log_prob = self.log_prob(batch_x.to(self.get_device()), context=batch_context) + batch_weights = batch_weights.to(self.get_device()) assert batch_log_prob.shape == batch_weights.shape, f"{batch_log_prob.shape = }, {batch_weights.shape = }" - batch_loss = -reduction(batch_log_prob * batch_weights) / n_event_dims + batch_loss = -reduction(batch_log_prob * batch_weights) / self.event_size return batch_loss @@ -207,8 +141,7 @@ def compute_batch_loss(batch_, reduction: callable = torch.mean): for train_batch in train_loader: optimizer.zero_grad() train_loss = compute_batch_loss(train_batch, reduction=torch.mean) - if hasattr(self.bijection, 'regularization'): - train_loss += self.bijection.regularization() + train_loss += self.regularization() train_loss.backward() optimizer.step() @@ -228,10 +161,9 @@ def compute_batch_loss(batch_, reduction: callable = torch.mean): # Compute validation loss val_loss = 0.0 for val_batch in val_loader: - n_batch_data = len(val_batch[0]) - val_loss += compute_batch_loss(val_batch, reduction=torch.sum) / n_batch_data - if hasattr(self.bijection, 'regularization'): - val_loss += self.bijection.regularization() + val_loss += compute_batch_loss(val_batch, reduction=torch.sum) + val_loss /= len(x_val) + val_loss += self.regularization() # Check if validation loss is the lowest so far if val_loss < best_val_loss: @@ -251,7 +183,7 @@ def compute_batch_loss(batch_, reduction: callable = torch.mean): if x_val is not None and keep_best_weights: self.load_state_dict(best_weights) - self.bijection.eval() + self.eval() def variational_fit(self, target_log_prob: callable, @@ -260,8 +192,8 @@ def variational_fit(self, n_samples: int = 1000, show_progress: bool = False): """ - Train a normalizing flow with stochastic variational inference. - Stochastic variational inference lets us train a normalizing flow using the unnormalized target log density + Train a distribution with stochastic variational inference. + Stochastic variational inference lets us train a distribution using the unnormalized target log density instead of a fixed dataset. Refer to Rezende, Mohamed: "Variational Inference with Normalizing Flows" (2015) for more details @@ -276,89 +208,155 @@ def variational_fit(self, :param float n_samples: number of samples to estimate the variational loss in each training step. :param bool show_progress: if True, show a progress bar during training. 
""" - iterator = tqdm(range(n_epochs), desc='Variational NF fit', disable=not show_progress) + iterator = tqdm(range(n_epochs), desc='Fitting with SVI', disable=not show_progress) optimizer = torch.optim.AdamW(self.parameters(), lr=lr) for _ in iterator: optimizer.zero_grad() flow_x, flow_log_prob = self.sample(n_samples, return_log_prob=True) loss = -torch.mean(target_log_prob(flow_x) + flow_log_prob) - if hasattr(self.bijection, 'regularization'): - loss += self.bijection.regularization() + loss += self.regularization() loss.backward() optimizer.step() - iterator.set_postfix_str(f'Variational loss: {loss:.4f}') + iterator.set_postfix_str(f'Loss: {loss:.4f}') -class DDNF(Flow): +class Flow(BaseFlow): """ - Deep diffeomorphic normalizing flow. + Normalizing flow class. - Salman et al. Deep diffeomorphic normalizing flows (2018). + This class represents a bijective transformation of a standard Gaussian distribution (the base distribution). + A normalizing flow is itself a distribution which we can sample from or use it to compute the density of inputs. """ - def __init__(self, event_shape: torch.Size, **kwargs): - bijection = DeepDiffeomorphicBijection(event_shape=event_shape, **kwargs) - super().__init__(bijection) + def __init__(self, bijection: Bijection): + """ - def fit(self, - x_train: torch.Tensor, - n_epochs: int = 500, - lr: float = 0.05, - batch_size: int = 1024, - shuffle: bool = True, - show_progress: bool = False, - w_train: torch.Tensor = None, - rec_err_coef: float = 1.0): + :param bijection: transformation component of the normalizing flow. """ + super().__init__(event_shape=bijection.event_shape) + self.register_module('bijection', bijection) - :param x_train: - :param n_epochs: - :param lr: learning rate. In general, lower learning rates are recommended for high-parametric bijections. - :param batch_size: - :param shuffle: - :param show_progress: - :param w_train: training data weights - :param rec_err_coef: reconstruction error regularization coefficient. + def forward_with_log_prob(self, x: torch.Tensor, context: torch.Tensor = None): + """ + Transform the input x to the space of the base distribution. + + :param x: input tensor. + :param context: context tensor upon which the transformation is conditioned. + :return: transformed tensor and the logarithm of the absolute value of the Jacobian determinant of the + transformation. + """ + if context is not None: + assert context.shape[0] == x.shape[0] + context = context.to(self.get_device()) + z, log_det = self.bijection.forward(x.to(self.get_device()), context=context) + log_base = self.base_log_prob(z) + return z, log_base + log_det + + def log_prob(self, x: torch.Tensor, context: torch.Tensor = None): + """ + Compute the logarithm of the probability density of input x according to the normalizing flow. + + :param x: input tensor. + :param context: context tensor. :return: """ - if w_train is None: - batch_shape = get_batch_shape(x_train, self.bijection.event_shape) - w_train = torch.ones(batch_shape) - if batch_size is None: - batch_size = len(x_train) - optimizer = torch.optim.AdamW(self.parameters(), lr=lr) - dataset = TensorDataset(x_train, w_train) - data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle) + return self.forward_with_log_prob(x, context)[1] + + def sample(self, n: int, context: torch.Tensor = None, no_grad: bool = False, return_log_prob: bool = False): + """ + Sample from the normalizing flow. 
- n_event_dims = int(torch.prod(torch.as_tensor(self.bijection.event_shape))) + If context given, sample n tensors for each context tensor. + Otherwise, sample n tensors. + + :param n: number of tensors to sample. + :param context: context tensor with shape c. + :param no_grad: if True, do not track gradients in the inverse pass. + :return: samples with shape (n, *event_shape) if no context given or (n, *c, *event_shape) if context given. + """ + + if context is not None: + sample_shape = torch.Size((n, len(context))) + z = self.base_sample(sample_shape=sample_shape) + context = context[None].repeat(*[n, *([1] * len(context.shape))]) # Make context shape match z shape + assert z.shape[:2] == context.shape[:2] + else: + sample_shape = torch.Size((n,)) + z = self.base_sample(sample_shape=sample_shape) - if show_progress: - iterator = tqdm(range(n_epochs), desc='Fitting NF') + if no_grad: + z = z.detach() + with torch.no_grad(): + x, log_det = self.bijection.inverse(z.view(*sample_shape, *self.bijection.transformed_shape), context=context) else: - iterator = range(n_epochs) + x, log_det = self.bijection.inverse(z.view(*sample_shape, *self.bijection.transformed_shape), context=context) + x = x.to(self.get_device()) - for _ in iterator: - for batch_x, batch_w in data_loader: - optimizer.zero_grad() + if return_log_prob: + log_prob = self.base_log_prob(z) + log_det + return x, log_prob + return x - z, log_prob = self.forward_with_log_prob(batch_x.to(self.loc)) # TODO context! - w = batch_w.to(self.loc) - assert log_prob.shape == w.shape - loss = -torch.mean(log_prob * w) / n_event_dims + def regularization(self): + if hasattr(self.bijection, 'regularization'): + return self.bijection.regularization() + else: + return 0.0 - if hasattr(self.bijection, 'regularization'): - # Always true for DeepDiffeomorphicBijection, but we keep it for clarity - loss += self.bijection.regularization() - # Inverse consistency regularization - x_reconstructed = self.bijection.inverse(z) - loss += reconstruction_error(batch_x, x_reconstructed, self.bijection.event_shape, rec_err_coef) +class FlowMixture(BaseFlow): + def __init__(self, flows: List[Flow], weights: List[float] = None, trainable_weights: bool = False): + super().__init__(event_shape=flows[0].event_shape) - # Geodesic regularization + # Use uniform weights by default + if weights is None: + weights = [1.0 / len(flows)] * len(flows) - loss.backward() - optimizer.step() + assert len(weights) == len(flows) + assert all([w > 0.0 for w in weights]) + assert np.isclose(sum(weights), 1.0) - if show_progress: - iterator.set_postfix_str(f'Loss: {loss:.4f}') + self.flows = nn.ModuleList(flows) + if trainable_weights: + self.logit_weights = nn.Parameter(torch.log(torch.tensor(weights))) + else: + # Register fixed weights as a buffer so they move with the module across devices + self.register_buffer('logit_weights', torch.log(torch.tensor(weights))) + + def log_prob(self, x: torch.Tensor, context: torch.Tensor = None): + flow_log_probs = torch.stack([flow.log_prob(x, context=context) for flow in self.flows]) + # (n_flows, *batch_shape) + + batch_shape = flow_log_probs.shape[1:] + log_weights_reshaped = self.logit_weights.view(-1, *([1] * len(batch_shape))) + log_prob = torch.logsumexp(log_weights_reshaped + flow_log_probs, dim=0) # batch_shape + return log_prob + + def sample(self, n: int, context: torch.Tensor = None, no_grad: bool = False, return_log_prob: bool = False): + flow_samples = [] + flow_log_probs = [] + for flow in self.flows: + flow_x, flow_log_prob = flow.sample(n, context=context, no_grad=no_grad, return_log_prob=True) + flow_samples.append(flow_x) +
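# Keep per-component log-densities; the mixture density of each sample is a weighted logsumexp over them. +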
flow_log_probs.append(flow_log_prob) + + flow_samples = torch.stack(flow_samples) # (n_flows, n, *event_shape) + categorical_samples = torch.distributions.Categorical(logits=self.logit_weights).sample( + sample_shape=torch.Size((n,)) + ) # (n,) + one_hot = torch.nn.functional.one_hot(categorical_samples, num_classes=len(flow_samples)).T # (n_flows, n) + one_hot_reshaped = one_hot.view(*one_hot.shape, *([1] * len(self.event_shape))) + # (n_flows, n, *event_shape) + + samples = torch.sum(one_hot_reshaped * flow_samples, dim=0) # (n, *event_shape) + + if return_log_prob: + flow_log_probs = torch.stack(flow_log_probs) # (n_flows, n) + log_weights_reshaped = self.logit_weights[:, None] # (n_flows, 1) + log_prob = torch.logsumexp(log_weights_reshaped + flow_log_probs, dim=0) # (n,) + return samples, log_prob + else: + return samples + + def regularization(self): + return sum([flow.regularization() for flow in self.flows]) diff --git a/normalizing_flows/neural_networks/__init__.py b/normalizing_flows/neural_networks/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/normalizing_flows/neural_networks/convnet.py b/normalizing_flows/neural_networks/convnet.py new file mode 100644 index 0000000..484cb55 --- /dev/null +++ b/normalizing_flows/neural_networks/convnet.py @@ -0,0 +1,76 @@ +from typing import Tuple + +import torch +import torch.nn as nn + + +class ConvNet(nn.Module): + class ConvNetBlock(nn.Module): + def __init__(self, in_channels, out_channels, input_height, input_width, use_pooling: bool = True): + super().__init__() + self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1) + self.bn = nn.BatchNorm2d(out_channels) + self.pool = nn.MaxPool2d(2) if use_pooling else nn.Identity() + + if use_pooling: + self.output_shape = (out_channels, input_height // 2, input_width // 2) + else: + self.output_shape = (out_channels, input_height, input_width) + + def forward(self, x): + return self.bn(self.pool(torch.relu(self.conv(x)))) + + def __init__(self, input_shape, n_outputs: int, kernels: Tuple[int, ...] 
= None): + """ + + :param input_shape: (channels, height, width) + :param n_outputs: number of output features. + :param kernels: number of output channels for each convolutional block. + """ + super().__init__() + channels, height, width = input_shape + + if kernels is None: + kernels = (4, 8, 16, 24) + else: + assert len(kernels) >= 1 + + blocks = [ + self.ConvNetBlock( + in_channels=channels, + out_channels=kernels[0], + input_height=height, + input_width=width, + use_pooling=min(height, width) >= 2 + ) + ] + for i in range(len(kernels) - 1): + blocks.append( + self.ConvNetBlock( + in_channels=kernels[i], + out_channels=kernels[i + 1], + input_height=blocks[i].output_shape[1], + input_width=blocks[i].output_shape[2], + use_pooling=min(blocks[i].output_shape[1], blocks[i].output_shape[2]) >= 2 + ) + ) + self.blocks = nn.ModuleList(blocks) + self.linear = nn.Linear( + in_features=int(torch.prod(torch.as_tensor(self.blocks[-1].output_shape))), + out_features=n_outputs + ) + + def forward(self, x): + for block in self.blocks: + x = block(x) + x = x.flatten(start_dim=1, end_dim=-1) + x = self.linear(x) + return x + + +if __name__ == '__main__': + torch.manual_seed(0) + image_shape = (1, 36, 29) + images = torch.randn(size=(11, *image_shape)) + net = ConvNet(input_shape=image_shape, n_outputs=77) + out = net(images) + print(f'{out.shape = }') diff --git a/normalizing_flows/neural_networks/resnet.py b/normalizing_flows/neural_networks/resnet.py new file mode 100644 index 0000000..9b6c8a9 --- /dev/null +++ b/normalizing_flows/neural_networks/resnet.py @@ -0,0 +1,145 @@ +# https://github.com/kuangliu/pytorch-cifar/blob/master/models/resnet.py + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, in_planes, planes, stride=1): + super(BasicBlock, self).__init__() + self.conv1 = nn.Conv2d( + in_planes, + planes, + kernel_size=3, + stride=stride, + padding=1, + bias=False + ) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d( + planes, + planes, + kernel_size=3, + stride=1, + padding=1, + bias=False + ) + self.bn2 = nn.BatchNorm2d(planes) + + self.shortcut = nn.Sequential() + if stride != 1 or in_planes != self.expansion * planes: + self.shortcut = nn.Sequential( + nn.Conv2d(in_planes, self.expansion * planes, + kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(self.expansion * planes) + ) + + def forward(self, x): + out = F.relu(self.bn1(self.conv1(x))) + out = self.bn2(self.conv2(out)) + out += self.shortcut(x) + out = F.relu(out) + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, in_planes, planes, stride=1): + super(Bottleneck, self).__init__() + self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, + stride=stride, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.conv3 = nn.Conv2d(planes, self.expansion * + planes, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(self.expansion * planes) + + self.shortcut = nn.Sequential() + if stride != 1 or in_planes != self.expansion * planes: + self.shortcut = nn.Sequential( + nn.Conv2d(in_planes, self.expansion * planes, + kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(self.expansion * planes) + ) + + def forward(self, x): + out = F.relu(self.bn1(self.conv1(x))) + out = F.relu(self.bn2(self.conv2(out))) + out = self.bn3(self.conv3(out)) + out += self.shortcut(x) + out = F.relu(out) + return out + + +class ResNet(nn.Module): + def __init__(self, image_width, block,
num_blocks, n_hidden=100, n_outputs=10): + super(ResNet, self).__init__() + self.in_planes = 64 + + self.conv1 = nn.Conv2d(3, 64, kernel_size=3, + stride=1, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) + self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) + self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) + self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) + self.linear1 = nn.Linear(512 * block.expansion * (image_width // 32) ** 2, n_hidden) + self.linear2 = nn.Linear(n_hidden, n_outputs) + + def _make_layer(self, block, planes, num_blocks, stride): + strides = [stride] + [1] * (num_blocks - 1) + layers = [] + for stride in strides: + layers.append(block(self.in_planes, planes, stride)) + self.in_planes = planes * block.expansion + return nn.Sequential(*layers) + + def forward(self, x): + """ + :param x: tensor with shape (*b, channels, height, width). Height and width must be equal. + :return: + """ + out = F.relu(self.bn1(self.conv1(x))) + out = self.layer1(out) + out = self.layer2(out) + out = self.layer3(out) + out = self.layer4(out) + out = F.avg_pool2d(out, 4) + out = out.flatten(start_dim=1, end_dim=-1) + out = self.linear1(out) + out = self.linear2(out) + return out + + +def make_resnet18(image_width, n_outputs): + return ResNet(image_width, BasicBlock, [2, 2, 2, 2], n_outputs=n_outputs) + + +def make_resnet34(image_width, n_outputs): + return ResNet(image_width, BasicBlock, [3, 4, 6, 3], n_outputs=n_outputs) + + +def make_resnet50(image_width, n_outputs): + return ResNet(image_width, Bottleneck, [3, 4, 6, 3], n_outputs=n_outputs) + + +def make_resnet101(image_width, n_outputs): + return ResNet(image_width, Bottleneck, [3, 4, 23, 3], n_outputs=n_outputs) + + +def make_resnet152(image_width, n_outputs): + return ResNet(image_width, Bottleneck, [3, 8, 36, 3], n_outputs=n_outputs) + + +if __name__ == '__main__': + n_images = 2 + image_shape = (3, 32, 32) + + net = make_resnet18(image_width=image_shape[-1], n_outputs=15) + y = net(torch.randn(n_images, *image_shape)) + print(y.size()) diff --git a/test/test_channel_wise_coupling.py b/test/test_channel_wise_coupling.py new file mode 100644 index 0000000..48bda61 --- /dev/null +++ b/test/test_channel_wise_coupling.py @@ -0,0 +1,19 @@ +import torch + +from normalizing_flows.bijections.finite.multiscale.coupling import ChannelWiseHalfSplit + + +def test_partition_shapes_1(): + torch.manual_seed(0) + image_shape = (3, 4, 4) + coupling = ChannelWiseHalfSplit(image_shape, invert=True) + assert coupling.constant_shape == (1, 4, 4) + assert coupling.transformed_shape == (2, 4, 4) + + +def test_partition_shapes_2(): + torch.manual_seed(0) + image_shape = (3, 16, 16) + coupling = ChannelWiseHalfSplit(image_shape, invert=True) + assert coupling.constant_shape == (1, 16, 16) + assert coupling.transformed_shape == (2, 16, 16) diff --git a/test/test_factored_bijection.py b/test/test_factored_bijection.py new file mode 100644 index 0000000..4c43cf7 --- /dev/null +++ b/test/test_factored_bijection.py @@ -0,0 +1,42 @@ +import torch +from normalizing_flows.bijections.finite.multiscale.base import FactoredBijection +from normalizing_flows.bijections.finite.autoregressive.layers import ElementwiseAffine + + +def test_basic(): + torch.manual_seed(0) + + bijection = FactoredBijection( + event_shape=(6, 6), + small_bijection_event_shape=(3, 3), + small_bijection_mask=torch.tensor([ + [True, True, True, False, False, False], + 
[True, True, True, False, False, False], + [True, True, True, False, False, False], + [False, False, False, False, False, False], + [False, False, False, False, False, False], + [False, False, False, False, False, False], + ]), + small_bijection=ElementwiseAffine(event_shape=(3, 3)) + ) + + x = torch.randn(100, *bijection.event_shape) + z, log_det_forward = bijection.forward(x) + + assert torch.allclose( + x[..., ~bijection.transformed_event_mask], + z[..., ~bijection.transformed_event_mask], + atol=1e-5 + ) + + assert torch.all( + ~torch.isclose( + x[..., bijection.transformed_event_mask], + z[..., bijection.transformed_event_mask], + atol=1e-5 + ) + ) + + xr, log_det_inverse = bijection.inverse(z) + assert torch.allclose(x, xr, atol=1e-5) + assert torch.allclose(log_det_forward, -log_det_inverse, atol=1e-5) diff --git a/test/test_graphical_normalizing_flow.py b/test/test_graphical_normalizing_flow.py new file mode 100644 index 0000000..0c465f8 --- /dev/null +++ b/test/test_graphical_normalizing_flow.py @@ -0,0 +1,91 @@ +import pytest +import torch +from normalizing_flows.architectures import RealNVP, NICE, CouplingRQNSF + + +@pytest.mark.parametrize('architecture', [RealNVP, NICE, CouplingRQNSF]) +def test_basic_2d(architecture): + torch.manual_seed(0) + + n_data = 100 + n_dim = 2 + x = torch.randn(size=(n_data, n_dim)) + bijection = architecture(event_shape=(n_dim,), edge_list=[(0, 1)]) + z, log_det_forward = bijection.forward(x) + x_reconstructed, log_det_inverse = bijection.inverse(z) + + assert torch.allclose(x, x_reconstructed, atol=1e-4), f"{torch.linalg.norm(x - x_reconstructed)}" + assert torch.allclose(log_det_forward, -log_det_inverse) + + +@pytest.mark.parametrize('architecture', [RealNVP, NICE, CouplingRQNSF]) +def test_basic_5d(architecture): + torch.manual_seed(0) + + n_data = 100 + n_dim = 5 + x = torch.randn(size=(n_data, n_dim)) + bijection = architecture(event_shape=(n_dim,), edge_list=[(0, 1), (0, 2), (0, 3), (0, 4)]) + z, log_det_forward = bijection.forward(x) + x_reconstructed, log_det_inverse = bijection.inverse(z) + + assert torch.allclose(x, x_reconstructed, atol=1e-4) + assert torch.allclose(log_det_forward, -log_det_inverse) + + +@pytest.mark.parametrize('architecture', [RealNVP, NICE, CouplingRQNSF]) +def test_basic_5d_2(architecture): + torch.manual_seed(0) + + n_data = 100 + n_dim = 5 + x = torch.randn(size=(n_data, n_dim)) + bijection = architecture(event_shape=(n_dim,), edge_list=[(0, 1)]) + z, log_det_forward = bijection.forward(x) + x_reconstructed, log_det_inverse = bijection.inverse(z) + + assert torch.allclose(x, x_reconstructed, atol=1e-4) + assert torch.allclose(log_det_forward, -log_det_inverse) + + +@pytest.mark.parametrize('architecture', [RealNVP, NICE, CouplingRQNSF]) +def test_basic_5d_3(architecture): + torch.manual_seed(0) + + n_data = 100 + n_dim = 5 + x = torch.randn(size=(n_data, n_dim)) + bijection = architecture(event_shape=(n_dim,), edge_list=[(0, 2), (1, 3), (1, 4)]) + z, log_det_forward = bijection.forward(x) + x_reconstructed, log_det_inverse = bijection.inverse(z) + + assert torch.allclose(x, x_reconstructed, atol=1e-4), f"{torch.linalg.norm(x - x_reconstructed)}" + assert torch.allclose(log_det_forward, -log_det_inverse, + atol=1e-4), f"{torch.linalg.norm(log_det_forward + log_det_inverse)}" + + +@pytest.mark.parametrize('architecture', [RealNVP, NICE, CouplingRQNSF]) +def test_random(architecture): + torch.manual_seed(0) + + n_data = 100 + n_dim = 50 + x = torch.randn(size=(n_data, n_dim)) + + interacting_dimensions = 
torch.unique(torch.randint(low=0, high=n_dim, size=(n_dim,))) + interacting_dimensions = interacting_dimensions[torch.randperm(len(interacting_dimensions))] + source_dimensions = interacting_dimensions[:len(interacting_dimensions) // 2] + target_dimensions = interacting_dimensions[len(interacting_dimensions) // 2:] + + edge_list = [] + for s in source_dimensions: + for t in target_dimensions: + edge_list.append((int(s), int(t))) # cast tensor scalars to ints to match List[Tuple[int, int]] + + bijection = architecture(event_shape=(n_dim,), edge_list=edge_list) + z, log_det_forward = bijection.forward(x) + x_reconstructed, log_det_inverse = bijection.inverse(z) + + assert torch.allclose(x, x_reconstructed, atol=1e-4), f"{torch.linalg.norm(x - x_reconstructed)}" + assert torch.allclose(log_det_forward, -log_det_inverse, + atol=1e-4), f"{torch.linalg.norm(log_det_forward + log_det_inverse)}" diff --git a/test/test_identity_bijections.py b/test/test_identity_bijections.py new file mode 100644 index 0000000..fa88e93 --- /dev/null +++ b/test/test_identity_bijections.py @@ -0,0 +1,67 @@ +# Check that when all bijection parameters are set to 0, the bijections reduce to an identity map + +from normalizing_flows.bijections.finite.autoregressive.layers import ( + AffineCoupling, + DSCoupling, + RQSCoupling, + InverseAffineCoupling, + LRSCoupling, + ShiftCoupling, + AffineForwardMaskedAutoregressive, + AffineInverseMaskedAutoregressive, + ElementwiseAffine, + ElementwiseRQSpline, + ElementwiseScale, + ElementwiseShift, + LinearAffineCoupling, + LinearLRSCoupling, + LinearRQSCoupling, + LinearShiftCoupling, + LRSForwardMaskedAutoregressive, + RQSForwardMaskedAutoregressive, + RQSInverseMaskedAutoregressive, + UMNNMaskedAutoregressive, + +) +import torch +import pytest + + +@pytest.mark.parametrize( + 'layer_class', + [ + AffineCoupling, + DSCoupling, + RQSCoupling, + InverseAffineCoupling, + LRSCoupling, + ShiftCoupling, + AffineForwardMaskedAutoregressive, + AffineInverseMaskedAutoregressive, + ElementwiseAffine, + ElementwiseRQSpline, + ElementwiseScale, + ElementwiseShift, + LinearAffineCoupling, + LinearLRSCoupling, + LinearRQSCoupling, + LinearShiftCoupling, + LRSForwardMaskedAutoregressive, + RQSForwardMaskedAutoregressive, + RQSInverseMaskedAutoregressive, + # UMNNMaskedAutoregressive, # Inexact due to numerics + ] +) +def test_basic(layer_class): + n_batch, n_dim = 2, 3 + + torch.manual_seed(0) + x = torch.randn(size=(n_batch, n_dim)) + layer = layer_class(event_shape=torch.Size((n_dim,))) + + # Set all conditioner parameters to 0 + with torch.no_grad(): + for p in layer.parameters(): + p.data *= 0 + + assert torch.allclose(layer(x)[0], x, atol=1e-2) diff --git a/test/test_mixture.py b/test/test_mixture.py new file mode 100644 index 0000000..eed7ff0 --- /dev/null +++ b/test/test_mixture.py @@ -0,0 +1,69 @@ +from normalizing_flows.flows import FlowMixture, Flow +from normalizing_flows.architectures import RealNVP, NICE, CouplingRQNSF +import torch + + +def test_basic(): + torch.manual_seed(0) + + n_data = 100 + n_dim = 10 + x = torch.randn(size=(n_data, n_dim)) + + mixture = FlowMixture([ + Flow(RealNVP(event_shape=(n_dim,))), + Flow(NICE(event_shape=(n_dim,))), + Flow(CouplingRQNSF(event_shape=(n_dim,))) + ]) + + log_prob = mixture.log_prob(x) + assert log_prob.shape == (n_data,) + assert torch.all(torch.isfinite(log_prob)) + + x_sampled = mixture.sample(n_data) + assert x_sampled.shape == x.shape + assert torch.all(torch.isfinite(x_sampled)) + + +def test_medium(): + torch.manual_seed(0) + + n_data = 1000 + n_dim = 100 + x = torch.randn(size=(n_data, n_dim)) + +
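# Components may be different flow families over the same event shape; weights default to uniform. +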
mixture = FlowMixture([ + Flow(RealNVP(event_shape=(n_dim,))), + Flow(NICE(event_shape=(n_dim,))), + Flow(CouplingRQNSF(event_shape=(n_dim,))) + ]) + + log_prob = mixture.log_prob(x) + assert log_prob.shape == (n_data,) + assert torch.all(torch.isfinite(log_prob)) + + x_sampled = mixture.sample(n_data) + assert x_sampled.shape == x.shape + assert torch.all(torch.isfinite(x_sampled)) + + +def test_complex_event(): + torch.manual_seed(0) + + n_data = 1000 + event_shape = (2, 3, 4, 5) + x = torch.randn(size=(n_data, *event_shape)) + + mixture = FlowMixture([ + Flow(RealNVP(event_shape=event_shape)), + Flow(NICE(event_shape=event_shape)), + Flow(CouplingRQNSF(event_shape=event_shape)) + ]) + + log_prob = mixture.log_prob(x) + assert log_prob.shape == (n_data,) + assert torch.all(torch.isfinite(log_prob)) + + x_sampled = mixture.sample(n_data) + assert x_sampled.shape == x.shape + assert torch.all(torch.isfinite(x_sampled)) diff --git a/test/test_multiscale_bijections.py b/test/test_multiscale_bijections.py new file mode 100644 index 0000000..08bb548 --- /dev/null +++ b/test/test_multiscale_bijections.py @@ -0,0 +1,109 @@ +from normalizing_flows.architectures import ( + MultiscaleNICE, + MultiscaleRQNSF, + MultiscaleLRSNSF, + MultiscaleRealNVP +) +import torch +import pytest + + +@pytest.mark.parametrize('architecture_class', [ + MultiscaleRQNSF, + MultiscaleLRSNSF, + MultiscaleNICE, + MultiscaleRealNVP +]) +@pytest.mark.parametrize('image_shape', [(1, 28, 28), (3, 28, 28)]) +def test_non_factored(architecture_class, image_shape): + torch.manual_seed(0) + x = torch.randn(size=(5, *image_shape)) + bijection = architecture_class(image_shape, n_layers=2, factored=False) + z, ldf = bijection.forward(x) + xr, ldi = bijection.inverse(z) + assert torch.allclose(x, xr, atol=1e-4) + assert torch.allclose(ldf, -ldi, atol=1e-2) + + +@pytest.mark.parametrize('architecture_class', [ + MultiscaleRQNSF, + MultiscaleLRSNSF, + MultiscaleNICE, + MultiscaleRealNVP +]) +@pytest.mark.parametrize('image_shape', [(1, 28, 28), (3, 28, 28)]) +def test_non_factored_too_small_image(architecture_class, image_shape): + torch.manual_seed(0) + with pytest.raises(ValueError): + bijection = architecture_class(image_shape, n_layers=3, factored=False) + + +@pytest.mark.parametrize('architecture_class', [ + MultiscaleRQNSF, + MultiscaleLRSNSF, + MultiscaleNICE, + MultiscaleRealNVP +]) +@pytest.mark.parametrize('image_shape', [(1, 32, 32), (3, 32, 32)]) +def test_factored(architecture_class, image_shape): + torch.manual_seed(0) + x = torch.randn(size=(5, *image_shape)) + bijection = architecture_class(image_shape, n_layers=2, factored=True) + z, ldf = bijection.forward(x) + xr, ldi = bijection.inverse(z) + assert torch.allclose(x, xr, atol=1e-4) + assert torch.allclose(ldf, -ldi, atol=1e-2) + + +@pytest.mark.parametrize('architecture_class', [ + MultiscaleRQNSF, + MultiscaleLRSNSF, + MultiscaleNICE, + MultiscaleRealNVP +]) +@pytest.mark.parametrize('image_shape', [(1, 15, 32), (3, 15, 32)]) +def test_factored_wrong_shape(architecture_class, image_shape): + torch.manual_seed(0) + x = torch.randn(size=(5, *image_shape)) + with pytest.raises(ValueError): + bijection = architecture_class(image_shape, n_layers=2, factored=True) + + +@pytest.mark.parametrize('architecture_class', [ + MultiscaleRQNSF, + MultiscaleLRSNSF, + MultiscaleNICE, + MultiscaleRealNVP +]) +@pytest.mark.parametrize('image_shape', [(1, 8, 8), (3, 8, 8)]) +def test_factored_too_small_image(architecture_class, image_shape): + torch.manual_seed(0) + x = 
torch.randn(size=(5, *image_shape)) + with pytest.raises(ValueError): + bijection = architecture_class(image_shape, n_layers=8, factored=True) + + +@pytest.mark.parametrize('architecture_class', [ + MultiscaleRQNSF, + MultiscaleLRSNSF, + MultiscaleNICE, + MultiscaleRealNVP +]) +@pytest.mark.parametrize('image_shape', [(1, 4, 4), (3, 4, 4)]) +def test_non_factored_automatic_n_layers(architecture_class, image_shape): + torch.manual_seed(0) + x = torch.randn(size=(5, *image_shape)) + bijection = architecture_class(image_shape, factored=False) + + +@pytest.mark.parametrize('architecture_class', [ + MultiscaleRQNSF, + MultiscaleLRSNSF, + MultiscaleNICE, + MultiscaleRealNVP +]) +@pytest.mark.parametrize('image_shape', [(1, 4, 8), (3, 4, 4)]) +def test_factored_automatic_n_layers(architecture_class, image_shape): + torch.manual_seed(0) + x = torch.randn(size=(5, *image_shape)) + bijection = architecture_class(image_shape, factored=True) diff --git a/test/test_squeeze_bijection.py b/test/test_squeeze_bijection.py new file mode 100644 index 0000000..6952f34 --- /dev/null +++ b/test/test_squeeze_bijection.py @@ -0,0 +1,21 @@ +import torch +import pytest + +from normalizing_flows.bijections.finite.multiscale.base import Squeeze + + +@pytest.mark.parametrize('batch_shape', [(1,), (2,), (2, 3)]) +@pytest.mark.parametrize('channels', [1, 3, 10]) +@pytest.mark.parametrize('height', [4, 16, 32]) +@pytest.mark.parametrize('width', [4, 16, 32]) +def test_reconstruction(batch_shape, channels, height, width): + torch.manual_seed(0) + x = torch.randn(size=(*batch_shape, channels, height, width)) + layer = Squeeze(event_shape=x.shape[-3:]) + z, log_det_forward = layer.forward(x) + x_reconstructed, log_det_inverse = layer.inverse(z) + + assert z.shape == (*batch_shape, 4 * channels, height // 2, width // 2) + assert torch.allclose(x_reconstructed, x) + assert torch.allclose(log_det_forward, torch.zeros_like(log_det_forward)) + assert torch.allclose(log_det_forward, log_det_inverse)
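+ + +if __name__ == '__main__': + # Editor's illustration (not part of the original test suite): a minimal Squeeze usage sketch, + # mirroring the __main__ demos in convnet.py and resnet.py. Squeezing trades spatial resolution + # for channels and, being a pure rearrangement of elements, has a zero log-determinant. + torch.manual_seed(0) + x = torch.randn(size=(2, 3, 8, 8)) + layer = Squeeze(event_shape=x.shape[-3:]) + z, log_det = layer.forward(x) + print(f'{z.shape = }') # z.shape = torch.Size([2, 12, 4, 4]) + print(f'{log_det = }') # zero for every batch element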