From 9460d6b682daa35c21277d0dad2ef25cc6d0d958 Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Mon, 5 Feb 2024 15:55:45 +0100 Subject: [PATCH 01/50] Add device getter for NFs --- normalizing_flows/flows.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/normalizing_flows/flows.py b/normalizing_flows/flows.py index 444e6f8..37fec95 100644 --- a/normalizing_flows/flows.py +++ b/normalizing_flows/flows.py @@ -29,6 +29,9 @@ def __init__(self, bijection: Bijection): self.register_buffer('loc', torch.zeros(self.bijection.n_dim)) self.register_buffer('covariance_matrix', torch.eye(self.bijection.n_dim)) + def get_device(self): + return self.loc.device + @property def base(self) -> torch.distributions.Distribution: """ From e1eae9181ba81f3dc3ea0f67766065e8c28aae86 Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Mon, 5 Feb 2024 15:56:33 +0100 Subject: [PATCH 02/50] Handle device in OT-Flow --- normalizing_flows/bijections/continuous/otflow.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/normalizing_flows/bijections/continuous/otflow.py b/normalizing_flows/bijections/continuous/otflow.py index 5021504..cc7946b 100644 --- a/normalizing_flows/bijections/continuous/otflow.py +++ b/normalizing_flows/bijections/continuous/otflow.py @@ -40,11 +40,11 @@ def __init__(self, event_size: int, hidden_size: int, step_size: float = 0.01): @property def K0(self): - return torch.eye(*self.K0_delta.shape) + self.K0_delta / 1000 + return torch.eye(*self.K0_delta.shape).to(self.K0_delta) + self.K0_delta / 1000 @property def K1(self): - return torch.eye(*self.K1_delta.shape) + self.K1_delta / 1000 + return torch.eye(*self.K1_delta.shape).to(self.K1_delta) + self.K1_delta / 1000 @staticmethod def sigma(x): @@ -113,7 +113,7 @@ def hessian_trace(self, # Compute the first term in Equation 14 - ones = torch.ones(size=(self.K0.shape[1] - 1,)) + ones = torch.ones(size=(self.K0.shape[1] - 1,)).to(s) t0 = torch.sum( torch.multiply( From a2c37de518cc8da4e921a828b9035a3c2254f765 Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Mon, 5 Feb 2024 16:17:09 +0100 Subject: [PATCH 03/50] Ensure hidden size is an integer --- normalizing_flows/bijections/continuous/otflow.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/normalizing_flows/bijections/continuous/otflow.py b/normalizing_flows/bijections/continuous/otflow.py index cc7946b..976cbd9 100644 --- a/normalizing_flows/bijections/continuous/otflow.py +++ b/normalizing_flows/bijections/continuous/otflow.py @@ -146,7 +146,7 @@ def __init__(self, event_size: int, hidden_size: int = None, **kwargs): # hidden_size = m if hidden_size is None: - hidden_size = max(math.log(event_size), 4) + hidden_size = max(int(math.log(event_size)), 4) r = min(10, event_size) From a43097ee54d747a512a41964ab586f783f2636eb Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Fri, 9 Feb 2024 11:38:47 +0100 Subject: [PATCH 04/50] Allow unused constant dimensions in coupling masks --- .../conditioning/coupling_masks.py | 60 +++++++++++-------- .../finite/autoregressive/layers.py | 48 +++++++-------- .../finite/autoregressive/layers_base.py | 24 +++++--- 3 files changed, 75 insertions(+), 57 deletions(-) diff --git a/normalizing_flows/bijections/finite/autoregressive/conditioning/coupling_masks.py b/normalizing_flows/bijections/finite/autoregressive/conditioning/coupling_masks.py index 04ddc8f..7b829f2 100644 --- a/normalizing_flows/bijections/finite/autoregressive/conditioning/coupling_masks.py +++ b/normalizing_flows/bijections/finite/autoregressive/conditioning/coupling_masks.py @@ -1,44 +1,56 @@ import torch -class CouplingMask: +class PartialCoupling: """ - Base object which holds coupling partition mask information. + Coupling mask object where a part of dimensions is kept unchanged and does not affect other dimensions. """ - def __init__(self, event_shape): + def __init__(self, + event_shape, + source_mask: torch.Tensor, + target_mask: torch): + """ + Partial coupling mask constructor. + + :param source_mask: boolean mask tensor of dimensions that affect target dimensions. This tensor has shape + event_shape. + :param target_mask: boolean mask tensor of affected dimensions. This tensor has shape event_shape. + """ self.event_shape = event_shape + self.source_mask = source_mask + self.target_mask = target_mask + self.event_size = int(torch.prod(torch.as_tensor(self.event_shape))) @property - def mask(self): - raise NotImplementedError + def ignored_event_size(self): + # Event size of ignored dimensions. + return torch.sum(1 - (self.source_mask + self.target_mask)) @property - def constant_event_size(self): - raise NotImplementedError + def source_event_size(self): + return int(torch.sum(self.source_mask)) @property - def transformed_event_size(self): - raise NotImplementedError + def target_event_size(self): + return int(torch.sum(self.target_mask)) -class HalfSplit(CouplingMask): - def __init__(self, event_shape): - super().__init__(event_shape) - self.event_partition_mask = torch.less( - torch.arange(self.event_size).view(*self.event_shape), - self.constant_event_size - ) +class Coupling(PartialCoupling): + """ + Base object which holds coupling partition mask information. + """ - @property - def constant_event_size(self): - return self.event_size // 2 + def __init__(self, event_shape, mask: torch.Tensor): + super().__init__(event_shape, source_mask=mask, target_mask=~mask) @property - def transformed_event_size(self): - return self.event_size - self.constant_event_size + def ignored_event_size(self): + return 0 - @property - def mask(self): - return self.event_partition_mask + +class HalfSplit(Coupling): + def __init__(self, event_shape): + event_size = int(torch.prod(torch.as_tensor(event_shape))) + super().__init__(event_shape, mask=torch.less(torch.arange(event_size).view(*event_shape), event_size // 2)) diff --git a/normalizing_flows/bijections/finite/autoregressive/layers.py b/normalizing_flows/bijections/finite/autoregressive/layers.py index f5ba2c7..ca44d40 100644 --- a/normalizing_flows/bijections/finite/autoregressive/layers.py +++ b/normalizing_flows/bijections/finite/autoregressive/layers.py @@ -48,15 +48,15 @@ def __init__(self, **kwargs): if event_shape == (1,): raise ValueError - coupling_mask = HalfSplit(event_shape) - transformer = Affine(event_shape=torch.Size((coupling_mask.transformed_event_size,))) + coupling = HalfSplit(event_shape) + transformer = Affine(event_shape=torch.Size((coupling.target_event_size,))) conditioner_transform = FeedForward( - input_event_shape=torch.Size((coupling_mask.constant_event_size,)), + input_event_shape=torch.Size((coupling.source_event_size,)), parameter_shape=torch.Size(transformer.parameter_shape), context_shape=context_shape, **kwargs ) - super().__init__(transformer, coupling_mask, conditioner_transform) + super().__init__(transformer, coupling, conditioner_transform) class InverseAffineCoupling(CouplingBijection): @@ -66,15 +66,15 @@ def __init__(self, **kwargs): if event_shape == (1,): raise ValueError - coupling_mask = HalfSplit(event_shape) - transformer = Affine(event_shape=torch.Size((coupling_mask.transformed_event_size,))).invert() + coupling = HalfSplit(event_shape) + transformer = Affine(event_shape=torch.Size((coupling.target_event_size,))).invert() conditioner_transform = FeedForward( - input_event_shape=torch.Size((coupling_mask.constant_event_size,)), + input_event_shape=torch.Size((coupling.source_event_size,)), parameter_shape=torch.Size(transformer.parameter_shape), context_shape=context_shape, **kwargs ) - super().__init__(transformer, coupling_mask, conditioner_transform) + super().__init__(transformer, coupling, conditioner_transform) class ShiftCoupling(CouplingBijection): @@ -82,15 +82,15 @@ def __init__(self, event_shape: torch.Size, context_shape: torch.Size = None, **kwargs): - coupling_mask = HalfSplit(event_shape) - transformer = Shift(event_shape=torch.Size((coupling_mask.transformed_event_size,))) + coupling = HalfSplit(event_shape) + transformer = Shift(event_shape=torch.Size((coupling.target_event_size,))) conditioner_transform = FeedForward( - input_event_shape=torch.Size((coupling_mask.constant_event_size,)), + input_event_shape=torch.Size((coupling.source_event_size,)), parameter_shape=torch.Size(transformer.parameter_shape), context_shape=context_shape, **kwargs ) - super().__init__(transformer, coupling_mask, conditioner_transform) + super().__init__(transformer, coupling, conditioner_transform) class LRSCoupling(CouplingBijection): @@ -100,15 +100,15 @@ def __init__(self, n_bins: int = 8, **kwargs): assert n_bins >= 1 - coupling_mask = HalfSplit(event_shape) - transformer = LinearRational(event_shape=torch.Size((coupling_mask.transformed_event_size,)), n_bins=n_bins) + coupling = HalfSplit(event_shape) + transformer = LinearRational(event_shape=torch.Size((coupling.target_event_size,)), n_bins=n_bins) conditioner_transform = FeedForward( - input_event_shape=torch.Size((coupling_mask.constant_event_size,)), + input_event_shape=torch.Size((coupling.source_event_size,)), parameter_shape=torch.Size(transformer.parameter_shape), context_shape=context_shape, **kwargs ) - super().__init__(transformer, coupling_mask, conditioner_transform) + super().__init__(transformer, coupling, conditioner_transform) class RQSCoupling(CouplingBijection): @@ -117,15 +117,15 @@ def __init__(self, context_shape: torch.Size = None, n_bins: int = 8, **kwargs): - coupling_mask = HalfSplit(event_shape) - transformer = RationalQuadratic(event_shape=torch.Size((coupling_mask.transformed_event_size,)), n_bins=n_bins) + coupling = HalfSplit(event_shape) + transformer = RationalQuadratic(event_shape=torch.Size((coupling.target_event_size,)), n_bins=n_bins) conditioner_transform = FeedForward( - input_event_shape=torch.Size((coupling_mask.constant_event_size,)), + input_event_shape=torch.Size((coupling.source_event_size,)), parameter_shape=torch.Size(transformer.parameter_shape), context_shape=context_shape, **kwargs ) - super().__init__(transformer, coupling_mask, conditioner_transform) + super().__init__(transformer, coupling, conditioner_transform) class DSCoupling(CouplingBijection): @@ -134,20 +134,20 @@ def __init__(self, context_shape: torch.Size = None, n_hidden_layers: int = 2, **kwargs): - coupling_mask = HalfSplit(event_shape) + coupling = HalfSplit(event_shape) transformer = DeepSigmoid( - event_shape=torch.Size((coupling_mask.transformed_event_size,)), + event_shape=torch.Size((coupling.target_event_size,)), n_hidden_layers=n_hidden_layers ) # Parameter order: [c1, c2, c3, c4, ..., ck] for all components # Each component has parameter order [a_unc, b, w_unc] conditioner_transform = FeedForward( - input_event_shape=torch.Size((coupling_mask.constant_event_size,)), + input_event_shape=torch.Size((coupling.source_event_size,)), parameter_shape=torch.Size(transformer.parameter_shape), context_shape=context_shape, **kwargs ) - super().__init__(transformer, coupling_mask, conditioner_transform) + super().__init__(transformer, coupling, conditioner_transform) class LinearAffineCoupling(AffineCoupling): diff --git a/normalizing_flows/bijections/finite/autoregressive/layers_base.py b/normalizing_flows/bijections/finite/autoregressive/layers_base.py index 6562dda..2bb1cb4 100644 --- a/normalizing_flows/bijections/finite/autoregressive/layers_base.py +++ b/normalizing_flows/bijections/finite/autoregressive/layers_base.py @@ -5,7 +5,7 @@ from normalizing_flows.bijections.finite.autoregressive.conditioning.transforms import ConditionerTransform, \ MADE -from normalizing_flows.bijections.finite.autoregressive.conditioning.coupling_masks import CouplingMask +from normalizing_flows.bijections.finite.autoregressive.conditioning.coupling_masks import Coupling from normalizing_flows.bijections.finite.autoregressive.transformers.base import TensorTransformer, ScalarTransformer from normalizing_flows.bijections.base import Bijection from normalizing_flows.utils import flatten_event, unflatten_event, get_batch_shape @@ -52,14 +52,14 @@ class CouplingBijection(AutoregressiveBijection): def __init__(self, transformer: TensorTransformer, - coupling_mask: CouplingMask, + coupling: Coupling, conditioner_transform: ConditionerTransform, **kwargs): - super().__init__(coupling_mask.event_shape, transformer, conditioner_transform, **kwargs) - self.coupling_mask = coupling_mask + super().__init__(coupling.event_shape, transformer, conditioner_transform, **kwargs) + self.coupling = coupling - assert conditioner_transform.input_event_shape == (coupling_mask.constant_event_size,) - assert transformer.event_shape == (self.coupling_mask.transformed_event_size,) + assert conditioner_transform.input_event_shape == (coupling.source_event_size,) + assert transformer.event_shape == (self.coupling.target_event_size,) def partition_and_predict_parameters(self, x: torch.Tensor, context: torch.Tensor): """ @@ -70,20 +70,26 @@ def partition_and_predict_parameters(self, x: torch.Tensor, context: torch.Tenso :return: parameter tensor h with h.shape = (*batch_shape, *parameter_shape). """ # Predict transformer parameters for output dimensions - x_a = x[..., self.coupling_mask.mask] # (*b, constant_event_size) + x_a = x[..., self.coupling.source_mask] # (*b, constant_event_size) h_b = self.conditioner_transform(x_a, context=context) # (*b, *p) return h_b def forward(self, x: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]: z = x.clone() h_b = self.partition_and_predict_parameters(x, context) - z[..., ~self.coupling_mask.mask], log_det = self.transformer.forward(x[..., ~self.coupling_mask.mask], h_b) + z[..., self.coupling.target_mask], log_det = self.transformer.forward( + x[..., self.coupling.target_mask], + h_b + ) return z, log_det def inverse(self, z: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]: x = z.clone() h_b = self.partition_and_predict_parameters(x, context) - x[..., ~self.coupling_mask.mask], log_det = self.transformer.inverse(z[..., ~self.coupling_mask.mask], h_b) + x[..., self.coupling.target_mask], log_det = self.transformer.inverse( + z[..., self.coupling.target_mask], + h_b + ) return x, log_det From 00e9c8057c8810063c34d14583f9106f5e6917f5 Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Fri, 9 Feb 2024 11:49:19 +0100 Subject: [PATCH 05/50] Add graphical affine coupling layer --- .../conditioning/coupling_masks.py | 9 ++++++++ .../finite/autoregressive/layers.py | 23 ++++++++++++++++++- .../finite/autoregressive/layers_base.py | 4 ++-- 3 files changed, 33 insertions(+), 3 deletions(-) diff --git a/normalizing_flows/bijections/finite/autoregressive/conditioning/coupling_masks.py b/normalizing_flows/bijections/finite/autoregressive/conditioning/coupling_masks.py index 7b829f2..422c8dc 100644 --- a/normalizing_flows/bijections/finite/autoregressive/conditioning/coupling_masks.py +++ b/normalizing_flows/bijections/finite/autoregressive/conditioning/coupling_masks.py @@ -1,3 +1,5 @@ +from typing import Tuple, List + import torch @@ -54,3 +56,10 @@ class HalfSplit(Coupling): def __init__(self, event_shape): event_size = int(torch.prod(torch.as_tensor(event_shape))) super().__init__(event_shape, mask=torch.less(torch.arange(event_size).view(*event_shape), event_size // 2)) + + +class GraphicalCoupling(PartialCoupling): + def __init__(self, event_shape, edge_list: List[Tuple[int, int]]): + source_mask = torch.tensor(sorted(list(set([e[0] for e in edge_list])))) + target_mask = torch.tensor(sorted(list(set([e[1] for e in edge_list])))) + super().__init__(event_shape, source_mask, target_mask) diff --git a/normalizing_flows/bijections/finite/autoregressive/layers.py b/normalizing_flows/bijections/finite/autoregressive/layers.py index ca44d40..2b29ba3 100644 --- a/normalizing_flows/bijections/finite/autoregressive/layers.py +++ b/normalizing_flows/bijections/finite/autoregressive/layers.py @@ -1,7 +1,9 @@ +from typing import Tuple, List + import torch from normalizing_flows.bijections.finite.autoregressive.conditioning.transforms import FeedForward -from normalizing_flows.bijections.finite.autoregressive.conditioning.coupling_masks import HalfSplit +from normalizing_flows.bijections.finite.autoregressive.conditioning.coupling_masks import HalfSplit, GraphicalCoupling from normalizing_flows.bijections.finite.autoregressive.layers_base import MaskedAutoregressiveBijection, \ InverseMaskedAutoregressiveBijection, ElementwiseBijection, CouplingBijection from normalizing_flows.bijections.finite.autoregressive.transformers.linear.affine import Scale, Affine, Shift @@ -59,6 +61,25 @@ def __init__(self, super().__init__(transformer, coupling, conditioner_transform) +class GraphicalAffine(CouplingBijection): + def __init__(self, + event_shape: torch.Size, + edge_list: List[Tuple[int, int]] = None, + context_shape: torch.Size = None, + **kwargs): + if event_shape == (1,): + raise ValueError + coupling = GraphicalCoupling(event_shape, edge_list=edge_list) + transformer = Affine(event_shape=torch.Size((coupling.target_event_size,))) + conditioner_transform = FeedForward( + input_event_shape=torch.Size((coupling.source_event_size,)), + parameter_shape=torch.Size(transformer.parameter_shape), + context_shape=context_shape, + **kwargs + ) + super().__init__(transformer, coupling, conditioner_transform) + + class InverseAffineCoupling(CouplingBijection): def __init__(self, event_shape: torch.Size, diff --git a/normalizing_flows/bijections/finite/autoregressive/layers_base.py b/normalizing_flows/bijections/finite/autoregressive/layers_base.py index 2bb1cb4..0b18b34 100644 --- a/normalizing_flows/bijections/finite/autoregressive/layers_base.py +++ b/normalizing_flows/bijections/finite/autoregressive/layers_base.py @@ -5,7 +5,7 @@ from normalizing_flows.bijections.finite.autoregressive.conditioning.transforms import ConditionerTransform, \ MADE -from normalizing_flows.bijections.finite.autoregressive.conditioning.coupling_masks import Coupling +from normalizing_flows.bijections.finite.autoregressive.conditioning.coupling_masks import PartialCoupling from normalizing_flows.bijections.finite.autoregressive.transformers.base import TensorTransformer, ScalarTransformer from normalizing_flows.bijections.base import Bijection from normalizing_flows.utils import flatten_event, unflatten_event, get_batch_shape @@ -52,7 +52,7 @@ class CouplingBijection(AutoregressiveBijection): def __init__(self, transformer: TensorTransformer, - coupling: Coupling, + coupling: PartialCoupling, conditioner_transform: ConditionerTransform, **kwargs): super().__init__(coupling.event_shape, transformer, conditioner_transform, **kwargs) From c0104a27dff754617ffd49c3d3929e8fc8ee9a76 Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Fri, 9 Feb 2024 12:13:28 +0100 Subject: [PATCH 06/50] Rework graphical coupling --- .../finite/autoregressive/architectures.py | 29 +++++++------- .../conditioning/coupling_masks.py | 17 ++++++--- .../finite/autoregressive/layers.py | 38 ++++++------------- 3 files changed, 40 insertions(+), 44 deletions(-) diff --git a/normalizing_flows/bijections/finite/autoregressive/architectures.py b/normalizing_flows/bijections/finite/autoregressive/architectures.py index ca31f0a..8ad5838 100644 --- a/normalizing_flows/bijections/finite/autoregressive/architectures.py +++ b/normalizing_flows/bijections/finite/autoregressive/architectures.py @@ -1,3 +1,5 @@ +from typing import Tuple, List + from normalizing_flows.bijections.finite.autoregressive.layers import ( ShiftCoupling, AffineCoupling, @@ -11,49 +13,50 @@ ElementwiseAffine, UMNNMaskedAutoregressive, LRSCoupling, - LRSForwardMaskedAutoregressive, ElementwiseShift + LRSForwardMaskedAutoregressive, + ElementwiseShift ) from normalizing_flows.bijections.base import BijectiveComposition from normalizing_flows.bijections.finite.linear import ReversePermutation class NICE(BijectiveComposition): - def __init__(self, event_shape, n_layers: int = 2, **kwargs): + def __init__(self, event_shape, n_layers: int = 2, edge_list: List[Tuple[int, int]] = None, **kwargs): if isinstance(event_shape, int): event_shape = (event_shape,) bijections = [ElementwiseAffine(event_shape=event_shape)] for _ in range(n_layers): bijections.extend([ ReversePermutation(event_shape=event_shape), - ShiftCoupling(event_shape=event_shape) + ShiftCoupling(event_shape=event_shape, edge_list=edge_list) ]) bijections.append(ElementwiseAffine(event_shape=event_shape)) super().__init__(event_shape, bijections, **kwargs) class RealNVP(BijectiveComposition): - def __init__(self, event_shape, n_layers: int = 2, **kwargs): + def __init__(self, event_shape, n_layers: int = 2, edge_list: List[Tuple[int, int]] = None, **kwargs): if isinstance(event_shape, int): event_shape = (event_shape,) bijections = [ElementwiseAffine(event_shape=event_shape)] for _ in range(n_layers): bijections.extend([ ReversePermutation(event_shape=event_shape), - AffineCoupling(event_shape=event_shape) + AffineCoupling(event_shape=event_shape, edge_list=edge_list) ]) bijections.append(ElementwiseAffine(event_shape=event_shape)) super().__init__(event_shape, bijections, **kwargs) class InverseRealNVP(BijectiveComposition): - def __init__(self, event_shape, n_layers: int = 2, **kwargs): + def __init__(self, event_shape, n_layers: int = 2, edge_list: List[Tuple[int, int]] = None, **kwargs): if isinstance(event_shape, int): event_shape = (event_shape,) bijections = [ElementwiseAffine(event_shape=event_shape)] for _ in range(n_layers): bijections.extend([ ReversePermutation(event_shape=event_shape), - InverseAffineCoupling(event_shape=event_shape) + InverseAffineCoupling(event_shape=event_shape, edge_list=edge_list) ]) bijections.append(ElementwiseAffine(event_shape=event_shape)) super().__init__(event_shape, bijections, **kwargs) @@ -92,14 +95,14 @@ def __init__(self, event_shape, n_layers: int = 2, **kwargs): class CouplingRQNSF(BijectiveComposition): - def __init__(self, event_shape, n_layers: int = 2, **kwargs): + def __init__(self, event_shape, n_layers: int = 2, edge_list: List[Tuple[int, int]] = None, **kwargs): if isinstance(event_shape, int): event_shape = (event_shape,) bijections = [ElementwiseAffine(event_shape=event_shape)] for _ in range(n_layers): bijections.extend([ ReversePermutation(event_shape=event_shape), - RQSCoupling(event_shape=event_shape) + RQSCoupling(event_shape=event_shape, edge_list=edge_list) ]) bijections.append(ElementwiseAffine(event_shape=event_shape)) super().__init__(event_shape, bijections, **kwargs) @@ -124,14 +127,14 @@ def __init__(self, event_shape, n_layers: int = 2, **kwargs): class CouplingLRS(BijectiveComposition): - def __init__(self, event_shape, n_layers: int = 2, **kwargs): + def __init__(self, event_shape, n_layers: int = 2, edge_list: List[Tuple[int, int]] = None, **kwargs): if isinstance(event_shape, int): event_shape = (event_shape,) bijections = [ElementwiseShift(event_shape=event_shape)] for _ in range(n_layers): bijections.extend([ ReversePermutation(event_shape=event_shape), - LRSCoupling(event_shape=event_shape) + LRSCoupling(event_shape=event_shape, edge_list=edge_list) ]) bijections.append(ElementwiseShift(event_shape=event_shape)) super().__init__(event_shape, bijections, **kwargs) @@ -166,14 +169,14 @@ def __init__(self, event_shape, n_layers: int = 2, **kwargs): class CouplingDSF(BijectiveComposition): - def __init__(self, event_shape, n_layers: int = 2, **kwargs): + def __init__(self, event_shape, n_layers: int = 2, edge_list: List[Tuple[int, int]] = None, **kwargs): if isinstance(event_shape, int): event_shape = (event_shape,) bijections = [ElementwiseAffine(event_shape=event_shape)] for _ in range(n_layers): bijections.extend([ ReversePermutation(event_shape=event_shape), - DSCoupling(event_shape=event_shape) # TODO specify percent of global parameters + DSCoupling(event_shape=event_shape, edge_list=edge_list) # TODO specify percent of global parameters ]) bijections.append(ElementwiseAffine(event_shape=event_shape)) super().__init__(event_shape, bijections, **kwargs) diff --git a/normalizing_flows/bijections/finite/autoregressive/conditioning/coupling_masks.py b/normalizing_flows/bijections/finite/autoregressive/conditioning/coupling_masks.py index 422c8dc..1621732 100644 --- a/normalizing_flows/bijections/finite/autoregressive/conditioning/coupling_masks.py +++ b/normalizing_flows/bijections/finite/autoregressive/conditioning/coupling_masks.py @@ -52,14 +52,21 @@ def ignored_event_size(self): return 0 +class GraphicalCoupling(PartialCoupling): + def __init__(self, event_shape, edge_list: List[Tuple[int, int]]): + source_mask = torch.tensor(sorted(list(set([e[0] for e in edge_list])))) + target_mask = torch.tensor(sorted(list(set([e[1] for e in edge_list])))) + super().__init__(event_shape, source_mask, target_mask) + + class HalfSplit(Coupling): def __init__(self, event_shape): event_size = int(torch.prod(torch.as_tensor(event_shape))) super().__init__(event_shape, mask=torch.less(torch.arange(event_size).view(*event_shape), event_size // 2)) -class GraphicalCoupling(PartialCoupling): - def __init__(self, event_shape, edge_list: List[Tuple[int, int]]): - source_mask = torch.tensor(sorted(list(set([e[0] for e in edge_list])))) - target_mask = torch.tensor(sorted(list(set([e[1] for e in edge_list])))) - super().__init__(event_shape, source_mask, target_mask) +def make_coupling(event_shape, edge_list: List[Tuple[int, int]] = None): + if edge_list is None: + return HalfSplit(event_shape) + else: + return GraphicalCoupling(event_shape, edge_list) diff --git a/normalizing_flows/bijections/finite/autoregressive/layers.py b/normalizing_flows/bijections/finite/autoregressive/layers.py index 2b29ba3..d6d613e 100644 --- a/normalizing_flows/bijections/finite/autoregressive/layers.py +++ b/normalizing_flows/bijections/finite/autoregressive/layers.py @@ -3,7 +3,7 @@ import torch from normalizing_flows.bijections.finite.autoregressive.conditioning.transforms import FeedForward -from normalizing_flows.bijections.finite.autoregressive.conditioning.coupling_masks import HalfSplit, GraphicalCoupling +from normalizing_flows.bijections.finite.autoregressive.conditioning.coupling_masks import make_coupling from normalizing_flows.bijections.finite.autoregressive.layers_base import MaskedAutoregressiveBijection, \ InverseMaskedAutoregressiveBijection, ElementwiseBijection, CouplingBijection from normalizing_flows.bijections.finite.autoregressive.transformers.linear.affine import Scale, Affine, Shift @@ -18,7 +18,6 @@ from normalizing_flows.bijections.base import invert -# TODO move elementwise bijections, coupling bijections, and masked autoregressive bijections into separate files. class ElementwiseAffine(ElementwiseBijection): def __init__(self, event_shape, **kwargs): transformer = Affine(event_shape, **kwargs) @@ -47,29 +46,11 @@ class AffineCoupling(CouplingBijection): def __init__(self, event_shape: torch.Size, context_shape: torch.Size = None, - **kwargs): - if event_shape == (1,): - raise ValueError - coupling = HalfSplit(event_shape) - transformer = Affine(event_shape=torch.Size((coupling.target_event_size,))) - conditioner_transform = FeedForward( - input_event_shape=torch.Size((coupling.source_event_size,)), - parameter_shape=torch.Size(transformer.parameter_shape), - context_shape=context_shape, - **kwargs - ) - super().__init__(transformer, coupling, conditioner_transform) - - -class GraphicalAffine(CouplingBijection): - def __init__(self, - event_shape: torch.Size, edge_list: List[Tuple[int, int]] = None, - context_shape: torch.Size = None, **kwargs): if event_shape == (1,): raise ValueError - coupling = GraphicalCoupling(event_shape, edge_list=edge_list) + coupling = make_coupling(event_shape, edge_list) transformer = Affine(event_shape=torch.Size((coupling.target_event_size,))) conditioner_transform = FeedForward( input_event_shape=torch.Size((coupling.source_event_size,)), @@ -84,10 +65,11 @@ class InverseAffineCoupling(CouplingBijection): def __init__(self, event_shape: torch.Size, context_shape: torch.Size = None, + edge_list: List[Tuple[int, int]] = None, **kwargs): if event_shape == (1,): raise ValueError - coupling = HalfSplit(event_shape) + coupling = make_coupling(event_shape, edge_list) transformer = Affine(event_shape=torch.Size((coupling.target_event_size,))).invert() conditioner_transform = FeedForward( input_event_shape=torch.Size((coupling.source_event_size,)), @@ -102,8 +84,9 @@ class ShiftCoupling(CouplingBijection): def __init__(self, event_shape: torch.Size, context_shape: torch.Size = None, + edge_list: List[Tuple[int, int]] = None, **kwargs): - coupling = HalfSplit(event_shape) + coupling = make_coupling(event_shape, edge_list) transformer = Shift(event_shape=torch.Size((coupling.target_event_size,))) conditioner_transform = FeedForward( input_event_shape=torch.Size((coupling.source_event_size,)), @@ -119,9 +102,10 @@ def __init__(self, event_shape: torch.Size, context_shape: torch.Size = None, n_bins: int = 8, + edge_list: List[Tuple[int, int]] = None, **kwargs): assert n_bins >= 1 - coupling = HalfSplit(event_shape) + coupling = make_coupling(event_shape, edge_list) transformer = LinearRational(event_shape=torch.Size((coupling.target_event_size,)), n_bins=n_bins) conditioner_transform = FeedForward( input_event_shape=torch.Size((coupling.source_event_size,)), @@ -137,8 +121,9 @@ def __init__(self, event_shape: torch.Size, context_shape: torch.Size = None, n_bins: int = 8, + edge_list: List[Tuple[int, int]] = None, **kwargs): - coupling = HalfSplit(event_shape) + coupling = make_coupling(event_shape, edge_list) transformer = RationalQuadratic(event_shape=torch.Size((coupling.target_event_size,)), n_bins=n_bins) conditioner_transform = FeedForward( input_event_shape=torch.Size((coupling.source_event_size,)), @@ -154,8 +139,9 @@ def __init__(self, event_shape: torch.Size, context_shape: torch.Size = None, n_hidden_layers: int = 2, + edge_list: List[Tuple[int, int]] = None, **kwargs): - coupling = HalfSplit(event_shape) + coupling = make_coupling(event_shape, edge_list) transformer = DeepSigmoid( event_shape=torch.Size((coupling.target_event_size,)), n_hidden_layers=n_hidden_layers From 5a00e38ddf15e28f1946e7811bc13b787193bb13 Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Fri, 9 Feb 2024 12:40:25 +0100 Subject: [PATCH 07/50] Add graphical coupling tests, fix graphical coupling masks --- .../conditioning/coupling_masks.py | 12 ++- test/test_graphical_normalizing_flow.py | 85 +++++++++++++++++++ 2 files changed, 95 insertions(+), 2 deletions(-) create mode 100644 test/test_graphical_normalizing_flow.py diff --git a/normalizing_flows/bijections/finite/autoregressive/conditioning/coupling_masks.py b/normalizing_flows/bijections/finite/autoregressive/conditioning/coupling_masks.py index 1621732..7ac5b70 100644 --- a/normalizing_flows/bijections/finite/autoregressive/conditioning/coupling_masks.py +++ b/normalizing_flows/bijections/finite/autoregressive/conditioning/coupling_masks.py @@ -54,8 +54,16 @@ def ignored_event_size(self): class GraphicalCoupling(PartialCoupling): def __init__(self, event_shape, edge_list: List[Tuple[int, int]]): - source_mask = torch.tensor(sorted(list(set([e[0] for e in edge_list])))) - target_mask = torch.tensor(sorted(list(set([e[1] for e in edge_list])))) + if len(event_shape) != 1: + raise ValueError("GraphicalCoupling is currently only implemented for vector data") + + source_dims = torch.tensor(sorted(list(set([e[0] for e in edge_list])))) + target_dims = torch.tensor(sorted(list(set([e[1] for e in edge_list])))) + + event_size = int(torch.prod(torch.as_tensor(event_shape))) + source_mask = torch.isin(torch.arange(event_size), source_dims) + target_mask = torch.isin(torch.arange(event_size), target_dims) + super().__init__(event_shape, source_mask, target_mask) diff --git a/test/test_graphical_normalizing_flow.py b/test/test_graphical_normalizing_flow.py new file mode 100644 index 0000000..2b92025 --- /dev/null +++ b/test/test_graphical_normalizing_flow.py @@ -0,0 +1,85 @@ +import torch +from normalizing_flows.architectures import RealNVP + + +def test_basic_2d(): + torch.manual_seed(0) + + n_data = 100 + n_dim = 2 + x = torch.randn(size=(n_data, n_dim)) + bijection = RealNVP(event_shape=(n_dim,), edge_list=[(0, 1)]) + z, log_det_forward = bijection.forward(x) + x_reconstructed, log_det_inverse = bijection.inverse(z) + + assert torch.allclose(x, x_reconstructed) + assert torch.allclose(log_det_forward, -log_det_inverse) + + +def test_basic_5d(): + torch.manual_seed(0) + + n_data = 100 + n_dim = 5 + x = torch.randn(size=(n_data, n_dim)) + bijection = RealNVP(event_shape=(n_dim,), edge_list=[(0, 1), (0, 2), (0, 3), (0, 4)]) + z, log_det_forward = bijection.forward(x) + x_reconstructed, log_det_inverse = bijection.inverse(z) + + assert torch.allclose(x, x_reconstructed) + assert torch.allclose(log_det_forward, -log_det_inverse) + + +def test_basic_5d_2(): + torch.manual_seed(0) + + n_data = 100 + n_dim = 5 + x = torch.randn(size=(n_data, n_dim)) + bijection = RealNVP(event_shape=(n_dim,), edge_list=[(0, 1)]) + z, log_det_forward = bijection.forward(x) + x_reconstructed, log_det_inverse = bijection.inverse(z) + + assert torch.allclose(x, x_reconstructed) + assert torch.allclose(log_det_forward, -log_det_inverse) + + +def test_basic_5d_3(): + torch.manual_seed(0) + + n_data = 100 + n_dim = 5 + x = torch.randn(size=(n_data, n_dim)) + bijection = RealNVP(event_shape=(n_dim,), edge_list=[(0, 2), (1, 3), (1, 4)]) + z, log_det_forward = bijection.forward(x) + x_reconstructed, log_det_inverse = bijection.inverse(z) + + assert torch.allclose(x, x_reconstructed, atol=1e-5), f"{torch.linalg.norm(x - x_reconstructed)}" + assert torch.allclose(log_det_forward, -log_det_inverse, + atol=1e-5), f"{torch.linalg.norm(log_det_forward + log_det_inverse)}" + + +def test_random(): + torch.manual_seed(0) + + n_data = 100 + n_dim = 30 + x = torch.randn(size=(n_data, n_dim)) + + interacting_dimensions = torch.unique(torch.randint(low=0, high=n_dim, size=(n_dim,))) + interacting_dimensions = interacting_dimensions[torch.randperm(len(interacting_dimensions))] + source_dimensions = interacting_dimensions[:len(interacting_dimensions) // 2] + target_dimensions = interacting_dimensions[len(interacting_dimensions) // 2:] + + edge_list = [] + for s in source_dimensions: + for t in target_dimensions: + edge_list.append((s, t)) + + bijection = RealNVP(event_shape=(n_dim,), edge_list=edge_list) + z, log_det_forward = bijection.forward(x) + x_reconstructed, log_det_inverse = bijection.inverse(z) + + assert torch.allclose(x, x_reconstructed, atol=1e-5), f"{torch.linalg.norm(x - x_reconstructed)}" + assert torch.allclose(log_det_forward, -log_det_inverse, + atol=1e-5), f"{torch.linalg.norm(log_det_forward + log_det_inverse)}" From df810dae1d729530299925bdaad2d97d7f7b1153 Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Fri, 9 Feb 2024 12:43:37 +0100 Subject: [PATCH 08/50] Expand graphical coupling tests --- test/test_graphical_normalizing_flow.py | 44 ++++++++++++++----------- 1 file changed, 25 insertions(+), 19 deletions(-) diff --git a/test/test_graphical_normalizing_flow.py b/test/test_graphical_normalizing_flow.py index 2b92025..0c465f8 100644 --- a/test/test_graphical_normalizing_flow.py +++ b/test/test_graphical_normalizing_flow.py @@ -1,69 +1,75 @@ +import pytest import torch -from normalizing_flows.architectures import RealNVP +from normalizing_flows.architectures import RealNVP, NICE, CouplingRQNSF -def test_basic_2d(): +@pytest.mark.parametrize('architecture', [RealNVP, NICE, CouplingRQNSF]) +def test_basic_2d(architecture): torch.manual_seed(0) n_data = 100 n_dim = 2 x = torch.randn(size=(n_data, n_dim)) - bijection = RealNVP(event_shape=(n_dim,), edge_list=[(0, 1)]) + bijection = architecture(event_shape=(n_dim,), edge_list=[(0, 1)]) z, log_det_forward = bijection.forward(x) x_reconstructed, log_det_inverse = bijection.inverse(z) - assert torch.allclose(x, x_reconstructed) + assert torch.allclose(x, x_reconstructed, atol=1e-4), f"{torch.linalg.norm(x - x_reconstructed)}" assert torch.allclose(log_det_forward, -log_det_inverse) -def test_basic_5d(): +@pytest.mark.parametrize('architecture', [RealNVP, NICE, CouplingRQNSF]) +def test_basic_5d(architecture): torch.manual_seed(0) n_data = 100 n_dim = 5 x = torch.randn(size=(n_data, n_dim)) - bijection = RealNVP(event_shape=(n_dim,), edge_list=[(0, 1), (0, 2), (0, 3), (0, 4)]) + bijection = architecture(event_shape=(n_dim,), edge_list=[(0, 1), (0, 2), (0, 3), (0, 4)]) z, log_det_forward = bijection.forward(x) x_reconstructed, log_det_inverse = bijection.inverse(z) - assert torch.allclose(x, x_reconstructed) + assert torch.allclose(x, x_reconstructed, atol=1e-4) assert torch.allclose(log_det_forward, -log_det_inverse) -def test_basic_5d_2(): +@pytest.mark.parametrize('architecture', [RealNVP, NICE, CouplingRQNSF]) +def test_basic_5d_2(architecture): torch.manual_seed(0) n_data = 100 n_dim = 5 x = torch.randn(size=(n_data, n_dim)) - bijection = RealNVP(event_shape=(n_dim,), edge_list=[(0, 1)]) + bijection = architecture(event_shape=(n_dim,), edge_list=[(0, 1)]) z, log_det_forward = bijection.forward(x) x_reconstructed, log_det_inverse = bijection.inverse(z) - assert torch.allclose(x, x_reconstructed) + assert torch.allclose(x, x_reconstructed, atol=1e-4) assert torch.allclose(log_det_forward, -log_det_inverse) -def test_basic_5d_3(): +@pytest.mark.parametrize('architecture', [RealNVP, NICE, CouplingRQNSF]) +def test_basic_5d_3(architecture): torch.manual_seed(0) n_data = 100 n_dim = 5 x = torch.randn(size=(n_data, n_dim)) - bijection = RealNVP(event_shape=(n_dim,), edge_list=[(0, 2), (1, 3), (1, 4)]) + bijection = architecture(event_shape=(n_dim,), edge_list=[(0, 2), (1, 3), (1, 4)]) z, log_det_forward = bijection.forward(x) x_reconstructed, log_det_inverse = bijection.inverse(z) - assert torch.allclose(x, x_reconstructed, atol=1e-5), f"{torch.linalg.norm(x - x_reconstructed)}" + assert torch.allclose(x, x_reconstructed, atol=1e-4), f"{torch.linalg.norm(x - x_reconstructed)}" assert torch.allclose(log_det_forward, -log_det_inverse, - atol=1e-5), f"{torch.linalg.norm(log_det_forward + log_det_inverse)}" + atol=1e-4), f"{torch.linalg.norm(log_det_forward + log_det_inverse)}" -def test_random(): +@pytest.mark.parametrize('architecture', [RealNVP, NICE, CouplingRQNSF]) +def test_random(architecture): torch.manual_seed(0) n_data = 100 - n_dim = 30 + n_dim = 50 x = torch.randn(size=(n_data, n_dim)) interacting_dimensions = torch.unique(torch.randint(low=0, high=n_dim, size=(n_dim,))) @@ -76,10 +82,10 @@ def test_random(): for t in target_dimensions: edge_list.append((s, t)) - bijection = RealNVP(event_shape=(n_dim,), edge_list=edge_list) + bijection = architecture(event_shape=(n_dim,), edge_list=edge_list) z, log_det_forward = bijection.forward(x) x_reconstructed, log_det_inverse = bijection.inverse(z) - assert torch.allclose(x, x_reconstructed, atol=1e-5), f"{torch.linalg.norm(x - x_reconstructed)}" + assert torch.allclose(x, x_reconstructed, atol=1e-4), f"{torch.linalg.norm(x - x_reconstructed)}" assert torch.allclose(log_det_forward, -log_det_inverse, - atol=1e-5), f"{torch.linalg.norm(log_det_forward + log_det_inverse)}" + atol=1e-4), f"{torch.linalg.norm(log_det_forward + log_det_inverse)}" From 92b25186526c9591074d5d5ba693afe7aad77878 Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Fri, 9 Feb 2024 12:46:53 +0100 Subject: [PATCH 09/50] Remove separate DDNF class --- normalizing_flows/flows.py | 74 -------------------------------------- 1 file changed, 74 deletions(-) diff --git a/normalizing_flows/flows.py b/normalizing_flows/flows.py index 37fec95..fe27fab 100644 --- a/normalizing_flows/flows.py +++ b/normalizing_flows/flows.py @@ -291,77 +291,3 @@ def variational_fit(self, loss.backward() optimizer.step() iterator.set_postfix_str(f'Variational loss: {loss:.4f}') - - -class DDNF(Flow): - """ - Deep diffeomorphic normalizing flow. - - Salman et al. Deep diffeomorphic normalizing flows (2018). - """ - - def __init__(self, event_shape: torch.Size, **kwargs): - bijection = DeepDiffeomorphicBijection(event_shape=event_shape, **kwargs) - super().__init__(bijection) - - def fit(self, - x_train: torch.Tensor, - n_epochs: int = 500, - lr: float = 0.05, - batch_size: int = 1024, - shuffle: bool = True, - show_progress: bool = False, - w_train: torch.Tensor = None, - rec_err_coef: float = 1.0): - """ - - :param x_train: - :param n_epochs: - :param lr: learning rate. In general, lower learning rates are recommended for high-parametric bijections. - :param batch_size: - :param shuffle: - :param show_progress: - :param w_train: training data weights - :param rec_err_coef: reconstruction error regularization coefficient. - :return: - """ - if w_train is None: - batch_shape = get_batch_shape(x_train, self.bijection.event_shape) - w_train = torch.ones(batch_shape) - if batch_size is None: - batch_size = len(x_train) - optimizer = torch.optim.AdamW(self.parameters(), lr=lr) - dataset = TensorDataset(x_train, w_train) - data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle) - - n_event_dims = int(torch.prod(torch.as_tensor(self.bijection.event_shape))) - - if show_progress: - iterator = tqdm(range(n_epochs), desc='Fitting NF') - else: - iterator = range(n_epochs) - - for _ in iterator: - for batch_x, batch_w in data_loader: - optimizer.zero_grad() - - z, log_prob = self.forward_with_log_prob(batch_x.to(self.loc)) # TODO context! - w = batch_w.to(self.loc) - assert log_prob.shape == w.shape - loss = -torch.mean(log_prob * w) / n_event_dims - - if hasattr(self.bijection, 'regularization'): - # Always true for DeepDiffeomorphicBijection, but we keep it for clarity - loss += self.bijection.regularization() - - # Inverse consistency regularization - x_reconstructed = self.bijection.inverse(z) - loss += reconstruction_error(batch_x, x_reconstructed, self.bijection.event_shape, rec_err_coef) - - # Geodesic regularization - - loss.backward() - optimizer.step() - - if show_progress: - iterator.set_postfix_str(f'Loss: {loss:.4f}') From 474460fb51982f4a6b2ef315c2ecff8dad4cc036 Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Fri, 9 Feb 2024 13:51:25 +0100 Subject: [PATCH 10/50] Add FlowMixture class, rework flow class --- normalizing_flows/flows.py | 253 ++++++++++++++++++++++--------------- 1 file changed, 154 insertions(+), 99 deletions(-) diff --git a/normalizing_flows/flows.py b/normalizing_flows/flows.py index fe27fab..788af2d 100644 --- a/normalizing_flows/flows.py +++ b/normalizing_flows/flows.py @@ -1,33 +1,21 @@ from copy import deepcopy -from typing import Union, Tuple +from typing import Union, Tuple, List +import numpy as np import torch import torch.nn as nn -from torch.utils.data import TensorDataset, DataLoader from tqdm import tqdm from normalizing_flows.bijections.base import Bijection -from normalizing_flows.bijections.continuous.ddnf import DeepDiffeomorphicBijection -from normalizing_flows.regularization import reconstruction_error -from normalizing_flows.utils import flatten_event, get_batch_shape, unflatten_event, create_data_loader +from normalizing_flows.utils import flatten_event, unflatten_event, create_data_loader -class Flow(nn.Module): - """ - Normalizing flow class. - - This class represents a bijective transformation of a standard Gaussian distribution (the base distribution). - A normalizing flow is itself a distribution which we can sample from or use it to compute the density of inputs. - """ - - def __init__(self, bijection: Bijection): - """ - - :param bijection: transformation component of the normalizing flow. - """ +class BaseFlow(nn.Module): + def __init__(self, event_shape): super().__init__() - self.register_module('bijection', bijection) - self.register_buffer('loc', torch.zeros(self.bijection.n_dim)) - self.register_buffer('covariance_matrix', torch.eye(self.bijection.n_dim)) + self.event_shape = event_shape + self.event_size = int(torch.prod(torch.as_tensor(event_shape))) + self.register_buffer('loc', torch.zeros(self.event_size)) + self.register_buffer('covariance_matrix', torch.eye(self.event_size)) def get_device(self): return self.loc.device @@ -46,7 +34,7 @@ def base_log_prob(self, z: torch.Tensor): :param z: input tensor. :return: log probability of the input tensor. """ - zf = flatten_event(z, self.bijection.event_shape) + zf = flatten_event(z, self.event_shape) log_prob = self.base.log_prob(zf) return log_prob @@ -58,65 +46,11 @@ def base_sample(self, sample_shape: Union[torch.Size, Tuple[int, ...]]): :return: tensor with shape sample_shape. """ z_flat = self.base.sample(sample_shape) - z = unflatten_event(z_flat, self.bijection.event_shape) + z = unflatten_event(z_flat, self.event_shape) return z - def forward_with_log_prob(self, x: torch.Tensor, context: torch.Tensor = None): - """ - Transform the input x to the space of the base distribution. - - :param x: input tensor. - :param context: context tensor upon which the transformation is conditioned. - :return: transformed tensor and the logarithm of the absolute value of the Jacobian determinant of the - transformation. - """ - if context is not None: - assert context.shape[0] == x.shape[0] - context = context.to(self.loc) - z, log_det = self.bijection.forward(x.to(self.loc), context=context) - log_base = self.base_log_prob(z) - return z, log_base + log_det - - def log_prob(self, x: torch.Tensor, context: torch.Tensor = None): - """ - Compute the logarithm of the probability density of input x according to the normalizing flow. - - :param x: input tensor. - :param context: context tensor. - :return: - """ - return self.forward_with_log_prob(x, context)[1] - - def sample(self, n: int, context: torch.Tensor = None, no_grad: bool = False, return_log_prob: bool = False): - """ - Sample from the normalizing flow. - - If context given, sample n tensors for each context tensor. - Otherwise, sample n tensors. - - :param n: number of tensors to sample. - :param context: context tensor with shape c. - :param no_grad: if True, do not track gradients in the inverse pass. - :return: samples with shape (n, *event_shape) if no context given or (n, *c, *event_shape) if context given. - """ - if context is not None: - z = self.base_sample(sample_shape=torch.Size((n, len(context)))) - context = context[None].repeat(*[n, *([1] * len(context.shape))]) # Make context shape match z shape - assert z.shape[:2] == context.shape[:2] - else: - z = self.base_sample(sample_shape=torch.Size((n,))) - if no_grad: - z = z.detach() - with torch.no_grad(): - x, log_det = self.bijection.inverse(z, context=context) - else: - x, log_det = self.bijection.inverse(z, context=context) - x = x.to(self.loc) - - if return_log_prob: - log_prob = self.base_log_prob(z) + log_det - return x, log_prob - return x + def regularization(self): + return 0.0 def fit(self, x_train: torch.Tensor, @@ -155,10 +89,7 @@ def fit(self, :param early_stopping: if True and validation data is provided, stop the training procedure early once validation loss stops improving for a specified number of consecutive epochs. :param early_stopping_threshold: if early_stopping is True, fitting stops after no improvement in validation loss for this many epochs. """ - self.bijection.train() - - # Compute the number of event dimensions - n_event_dims = int(torch.prod(torch.as_tensor(self.bijection.event_shape))) + self.train() # Set the default batch size if batch_size is None: @@ -172,7 +103,7 @@ def fit(self, "training", batch_size=batch_size, shuffle=shuffle, - event_shape=self.bijection.event_shape + event_shape=self.event_shape ) # Process validation data @@ -184,7 +115,7 @@ def fit(self, "validation", batch_size=batch_size, shuffle=shuffle, - event_shape=self.bijection.event_shape + event_shape=self.event_shape ) best_val_loss = torch.inf @@ -195,10 +126,10 @@ def compute_batch_loss(batch_, reduction: callable = torch.mean): batch_x, batch_weights = batch_[:2] batch_context = batch_[2] if len(batch_) == 3 else None - batch_log_prob = self.log_prob(batch_x.to(self.loc), context=batch_context) - batch_weights = batch_weights.to(self.loc) + batch_log_prob = self.log_prob(batch_x.to(self.get_device()), context=batch_context) + batch_weights = batch_weights.to(self.get_device()) assert batch_log_prob.shape == batch_weights.shape, f"{batch_log_prob.shape = }, {batch_weights.shape = }" - batch_loss = -reduction(batch_log_prob * batch_weights) / n_event_dims + batch_loss = -reduction(batch_log_prob * batch_weights) / self.event_size return batch_loss @@ -210,8 +141,7 @@ def compute_batch_loss(batch_, reduction: callable = torch.mean): for train_batch in train_loader: optimizer.zero_grad() train_loss = compute_batch_loss(train_batch, reduction=torch.mean) - if hasattr(self.bijection, 'regularization'): - train_loss += self.bijection.regularization() + train_loss += self.regularization() train_loss.backward() optimizer.step() @@ -233,8 +163,7 @@ def compute_batch_loss(batch_, reduction: callable = torch.mean): for val_batch in val_loader: n_batch_data = len(val_batch[0]) val_loss += compute_batch_loss(val_batch, reduction=torch.sum) / n_batch_data - if hasattr(self.bijection, 'regularization'): - val_loss += self.bijection.regularization() + val_loss += self.regularization() # Check if validation loss is the lowest so far if val_loss < best_val_loss: @@ -254,7 +183,7 @@ def compute_batch_loss(batch_, reduction: callable = torch.mean): if x_val is not None and keep_best_weights: self.load_state_dict(best_weights) - self.bijection.eval() + self.eval() def variational_fit(self, target_log_prob: callable, @@ -263,8 +192,8 @@ def variational_fit(self, n_samples: int = 1000, show_progress: bool = False): """ - Train a normalizing flow with stochastic variational inference. - Stochastic variational inference lets us train a normalizing flow using the unnormalized target log density + Train a distribution with stochastic variational inference. + Stochastic variational inference lets us train a distribution using the unnormalized target log density instead of a fixed dataset. Refer to Rezende, Mohamed: "Variational Inference with Normalizing Flows" (2015) for more details @@ -279,15 +208,141 @@ def variational_fit(self, :param float n_samples: number of samples to estimate the variational loss in each training step. :param bool show_progress: if True, show a progress bar during training. """ - iterator = tqdm(range(n_epochs), desc='Variational NF fit', disable=not show_progress) + iterator = tqdm(range(n_epochs), desc='Fitting with SVI', disable=not show_progress) optimizer = torch.optim.AdamW(self.parameters(), lr=lr) for _ in iterator: optimizer.zero_grad() flow_x, flow_log_prob = self.sample(n_samples, return_log_prob=True) loss = -torch.mean(target_log_prob(flow_x) + flow_log_prob) - if hasattr(self.bijection, 'regularization'): - loss += self.bijection.regularization() + loss += self.regularization() loss.backward() optimizer.step() - iterator.set_postfix_str(f'Variational loss: {loss:.4f}') + iterator.set_postfix_str(f'Loss: {loss:.4f}') + + +class Flow(BaseFlow): + """ + Normalizing flow class. + + This class represents a bijective transformation of a standard Gaussian distribution (the base distribution). + A normalizing flow is itself a distribution which we can sample from or use it to compute the density of inputs. + """ + + def __init__(self, bijection: Bijection): + """ + + :param bijection: transformation component of the normalizing flow. + """ + super().__init__(event_shape=bijection.event_shape) + self.register_module('bijection', bijection) + + def forward_with_log_prob(self, x: torch.Tensor, context: torch.Tensor = None): + """ + Transform the input x to the space of the base distribution. + + :param x: input tensor. + :param context: context tensor upon which the transformation is conditioned. + :return: transformed tensor and the logarithm of the absolute value of the Jacobian determinant of the + transformation. + """ + if context is not None: + assert context.shape[0] == x.shape[0] + context = context.to(self.get_device()) + z, log_det = self.bijection.forward(x.to(self.get_device()), context=context) + log_base = self.base_log_prob(z) + return z, log_base + log_det + + def log_prob(self, x: torch.Tensor, context: torch.Tensor = None): + """ + Compute the logarithm of the probability density of input x according to the normalizing flow. + + :param x: input tensor. + :param context: context tensor. + :return: + """ + return self.forward_with_log_prob(x, context)[1] + + def sample(self, n: int, context: torch.Tensor = None, no_grad: bool = False, return_log_prob: bool = False): + """ + Sample from the normalizing flow. + + If context given, sample n tensors for each context tensor. + Otherwise, sample n tensors. + + :param n: number of tensors to sample. + :param context: context tensor with shape c. + :param no_grad: if True, do not track gradients in the inverse pass. + :return: samples with shape (n, *event_shape) if no context given or (n, *c, *event_shape) if context given. + """ + if context is not None: + z = self.base_sample(sample_shape=torch.Size((n, len(context)))) + context = context[None].repeat(*[n, *([1] * len(context.shape))]) # Make context shape match z shape + assert z.shape[:2] == context.shape[:2] + else: + z = self.base_sample(sample_shape=torch.Size((n,))) + if no_grad: + z = z.detach() + with torch.no_grad(): + x, log_det = self.bijection.inverse(z, context=context) + else: + x, log_det = self.bijection.inverse(z, context=context) + x = x.to(self.get_device()) + + if return_log_prob: + log_prob = self.base_log_prob(z) + log_det + return x, log_prob + return x + + def regularization(self): + if hasattr(self.bijection, 'regularization'): + return self.bijection.regularization() + else: + return 0.0 + + +class FlowMixture(BaseFlow): + def __init__(self, flows: List[Flow], weights: List[float]): + super().__init__(event_shape=flows[0].event_shape) + assert len(weights) == len(flows) + assert all([w > 0.0 for w in weights]) + assert np.isclose(sum(weights), 1.0) + + self.flows = flows + self.weights = torch.tensor(weights) + self.log_weights = torch.log(self.weights) + self.categorical_distribution = torch.distributions.Categorical(probs=self.weights) + + def log_prob(self, x: torch.Tensor, context: torch.Tensor = None): + flow_log_probs = torch.stack([flow.log_prob(x, context=context) for flow in self.flows]) + # (n_flows, *batch_shape) + + batch_shape = flow_log_probs.shape[1:] + log_weights_reshaped = self.log_weights.view(-1, *([1] * len(batch_shape))) + log_prob = torch.logsumexp(log_weights_reshaped + flow_log_probs, dim=0) # batch_shape + return log_prob + + def sample(self, n: int, context: torch.Tensor = None, no_grad: bool = False, return_log_prob: bool = False): + flow_samples = [] + flow_log_probs = [] + for flow in self.flows: + flow_x, flow_log_prob = flow.sample(n, context=context, no_grad=no_grad, return_log_prob=True) + flow_samples.append(flow_x) + flow_log_probs.append(flow_log_prob) + + with torch.no_grad(): + flow_samples = torch.stack(flow_samples) # (n_flows, n, *event_shape) + categorical_samples = self.categorical_distribution.sample(sample_shape=torch.Size((n,))) # (n,) + one_hot = torch.nn.functional.one_hot(categorical_samples, num_classes=len(flow_samples)).T # (n_flows, n) + samples = torch.sum(one_hot * flow_samples, dim=0) # (n, *event_shape) + + if return_log_prob: + flow_log_probs = torch.stack(flow_log_probs) # (n_flows, n) + log_weights_reshaped = self.log_weights[:, None] # (n_flows, 1) + log_prob = torch.logsumexp(log_weights_reshaped + flow_log_probs, dim=0) # (n,) + return samples, log_prob + else: + return samples + + def regularization(self): + return sum([flow.regularization() for flow in self.flows]) From 4fb4a00e9d0432a11690dfd2a1c168288cf48f6b Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Fri, 9 Feb 2024 14:05:26 +0100 Subject: [PATCH 11/50] FlowMixture fixes and tests --- normalizing_flows/flows.py | 12 +++++-- test/test_mixture.py | 69 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 79 insertions(+), 2 deletions(-) create mode 100644 test/test_mixture.py diff --git a/normalizing_flows/flows.py b/normalizing_flows/flows.py index 788af2d..d3056fc 100644 --- a/normalizing_flows/flows.py +++ b/normalizing_flows/flows.py @@ -302,8 +302,13 @@ def regularization(self): class FlowMixture(BaseFlow): - def __init__(self, flows: List[Flow], weights: List[float]): + def __init__(self, flows: List[Flow], weights: List[float] = None): super().__init__(event_shape=flows[0].event_shape) + + # Use uniform weights by default + if weights is None: + weights = [1.0 / len(flows)] * len(flows) + assert len(weights) == len(flows) assert all([w > 0.0 for w in weights]) assert np.isclose(sum(weights), 1.0) @@ -334,7 +339,10 @@ def sample(self, n: int, context: torch.Tensor = None, no_grad: bool = False, re flow_samples = torch.stack(flow_samples) # (n_flows, n, *event_shape) categorical_samples = self.categorical_distribution.sample(sample_shape=torch.Size((n,))) # (n,) one_hot = torch.nn.functional.one_hot(categorical_samples, num_classes=len(flow_samples)).T # (n_flows, n) - samples = torch.sum(one_hot * flow_samples, dim=0) # (n, *event_shape) + one_hot_reshaped = one_hot.view(*one_hot.shape, *([1] * len(self.event_shape))) + # (n_flows, n, *event_shape) + + samples = torch.sum(one_hot_reshaped * flow_samples, dim=0) # (n, *event_shape) if return_log_prob: flow_log_probs = torch.stack(flow_log_probs) # (n_flows, n) diff --git a/test/test_mixture.py b/test/test_mixture.py new file mode 100644 index 0000000..eed7ff0 --- /dev/null +++ b/test/test_mixture.py @@ -0,0 +1,69 @@ +from normalizing_flows.flows import FlowMixture, Flow +from normalizing_flows.architectures import RealNVP, NICE, CouplingRQNSF +import torch + + +def test_basic(): + torch.manual_seed(0) + + n_data = 100 + n_dim = 10 + x = torch.randn(size=(n_data, n_dim)) + + mixture = FlowMixture([ + Flow(RealNVP(event_shape=(n_dim,))), + Flow(NICE(event_shape=(n_dim,))), + Flow(CouplingRQNSF(event_shape=(n_dim,))) + ]) + + log_prob = mixture.log_prob(x) + assert log_prob.shape == (n_data,) + assert torch.all(torch.isfinite(log_prob)) + + x_sampled = mixture.sample(n_data) + assert x_sampled.shape == x.shape + assert torch.all(torch.isfinite(x_sampled)) + + +def test_medium(): + torch.manual_seed(0) + + n_data = 1000 + n_dim = 100 + x = torch.randn(size=(n_data, n_dim)) + + mixture = FlowMixture([ + Flow(RealNVP(event_shape=(n_dim,))), + Flow(NICE(event_shape=(n_dim,))), + Flow(CouplingRQNSF(event_shape=(n_dim,))) + ]) + + log_prob = mixture.log_prob(x) + assert log_prob.shape == (n_data,) + assert torch.all(torch.isfinite(log_prob)) + + x_sampled = mixture.sample(n_data) + assert x_sampled.shape == x.shape + assert torch.all(torch.isfinite(x_sampled)) + + +def test_complex_event(): + torch.manual_seed(0) + + n_data = 1000 + event_shape = (2, 3, 4, 5) + x = torch.randn(size=(n_data, *event_shape)) + + mixture = FlowMixture([ + Flow(RealNVP(event_shape=event_shape)), + Flow(NICE(event_shape=event_shape)), + Flow(CouplingRQNSF(event_shape=event_shape)) + ]) + + log_prob = mixture.log_prob(x) + assert log_prob.shape == (n_data,) + assert torch.all(torch.isfinite(log_prob)) + + x_sampled = mixture.sample(n_data) + assert x_sampled.shape == x.shape + assert torch.all(torch.isfinite(x_sampled)) From bf0529e1f447dc4b42856b099ee900e4ce1d9079 Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Fri, 9 Feb 2024 14:43:01 +0100 Subject: [PATCH 12/50] Add learnable weights in FlowMixture --- normalizing_flows/flows.py | 30 ++++++++++++++++-------------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/normalizing_flows/flows.py b/normalizing_flows/flows.py index d3056fc..465047c 100644 --- a/normalizing_flows/flows.py +++ b/normalizing_flows/flows.py @@ -302,7 +302,7 @@ def regularization(self): class FlowMixture(BaseFlow): - def __init__(self, flows: List[Flow], weights: List[float] = None): + def __init__(self, flows: List[Flow], weights: List[float] = None, trainable_weights: bool = False): super().__init__(event_shape=flows[0].event_shape) # Use uniform weights by default @@ -313,17 +313,18 @@ def __init__(self, flows: List[Flow], weights: List[float] = None): assert all([w > 0.0 for w in weights]) assert np.isclose(sum(weights), 1.0) - self.flows = flows - self.weights = torch.tensor(weights) - self.log_weights = torch.log(self.weights) - self.categorical_distribution = torch.distributions.Categorical(probs=self.weights) + self.flows = nn.ModuleList(flows) + if trainable_weights: + self.logit_weights = nn.Parameter(torch.log(torch.tensor(weights))) + else: + self.logit_weights = torch.log(torch.tensor(weights)) def log_prob(self, x: torch.Tensor, context: torch.Tensor = None): flow_log_probs = torch.stack([flow.log_prob(x, context=context) for flow in self.flows]) # (n_flows, *batch_shape) batch_shape = flow_log_probs.shape[1:] - log_weights_reshaped = self.log_weights.view(-1, *([1] * len(batch_shape))) + log_weights_reshaped = self.logit_weights.view(-1, *([1] * len(batch_shape))) log_prob = torch.logsumexp(log_weights_reshaped + flow_log_probs, dim=0) # batch_shape return log_prob @@ -335,18 +336,19 @@ def sample(self, n: int, context: torch.Tensor = None, no_grad: bool = False, re flow_samples.append(flow_x) flow_log_probs.append(flow_log_prob) - with torch.no_grad(): - flow_samples = torch.stack(flow_samples) # (n_flows, n, *event_shape) - categorical_samples = self.categorical_distribution.sample(sample_shape=torch.Size((n,))) # (n,) - one_hot = torch.nn.functional.one_hot(categorical_samples, num_classes=len(flow_samples)).T # (n_flows, n) - one_hot_reshaped = one_hot.view(*one_hot.shape, *([1] * len(self.event_shape))) - # (n_flows, n, *event_shape) + flow_samples = torch.stack(flow_samples) # (n_flows, n, *event_shape) + categorical_samples = torch.distributions.Categorical(logits=self.logit_weights).sample( + sample_shape=torch.Size((n,)) + ) # (n,) + one_hot = torch.nn.functional.one_hot(categorical_samples, num_classes=len(flow_samples)).T # (n_flows, n) + one_hot_reshaped = one_hot.view(*one_hot.shape, *([1] * len(self.event_shape))) + # (n_flows, n, *event_shape) - samples = torch.sum(one_hot_reshaped * flow_samples, dim=0) # (n, *event_shape) + samples = torch.sum(one_hot_reshaped * flow_samples, dim=0) # (n, *event_shape) if return_log_prob: flow_log_probs = torch.stack(flow_log_probs) # (n_flows, n) - log_weights_reshaped = self.log_weights[:, None] # (n_flows, 1) + log_weights_reshaped = self.logit_weights[:, None] # (n_flows, 1) log_prob = torch.logsumexp(log_weights_reshaped + flow_log_probs, dim=0) # (n,) return samples, log_prob else: From 07261cbfa13369c6c59f3956997ceb03565f9f09 Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Fri, 9 Feb 2024 17:28:21 +0100 Subject: [PATCH 13/50] Remove permutations when using graphical coupling --- .../finite/autoregressive/architectures.py | 42 ++++++++----------- 1 file changed, 18 insertions(+), 24 deletions(-) diff --git a/normalizing_flows/bijections/finite/autoregressive/architectures.py b/normalizing_flows/bijections/finite/autoregressive/architectures.py index 8ad5838..78467a3 100644 --- a/normalizing_flows/bijections/finite/autoregressive/architectures.py +++ b/normalizing_flows/bijections/finite/autoregressive/architectures.py @@ -26,10 +26,9 @@ def __init__(self, event_shape, n_layers: int = 2, edge_list: List[Tuple[int, in event_shape = (event_shape,) bijections = [ElementwiseAffine(event_shape=event_shape)] for _ in range(n_layers): - bijections.extend([ - ReversePermutation(event_shape=event_shape), - ShiftCoupling(event_shape=event_shape, edge_list=edge_list) - ]) + if edge_list is None: + bijections.append(ReversePermutation(event_shape=event_shape)) + bijections.append(ShiftCoupling(event_shape=event_shape, edge_list=edge_list)) bijections.append(ElementwiseAffine(event_shape=event_shape)) super().__init__(event_shape, bijections, **kwargs) @@ -40,10 +39,9 @@ def __init__(self, event_shape, n_layers: int = 2, edge_list: List[Tuple[int, in event_shape = (event_shape,) bijections = [ElementwiseAffine(event_shape=event_shape)] for _ in range(n_layers): - bijections.extend([ - ReversePermutation(event_shape=event_shape), - AffineCoupling(event_shape=event_shape, edge_list=edge_list) - ]) + if edge_list is None: + bijections.append(ReversePermutation(event_shape=event_shape)) + bijections.append(AffineCoupling(event_shape=event_shape, edge_list=edge_list)) bijections.append(ElementwiseAffine(event_shape=event_shape)) super().__init__(event_shape, bijections, **kwargs) @@ -54,10 +52,9 @@ def __init__(self, event_shape, n_layers: int = 2, edge_list: List[Tuple[int, in event_shape = (event_shape,) bijections = [ElementwiseAffine(event_shape=event_shape)] for _ in range(n_layers): - bijections.extend([ - ReversePermutation(event_shape=event_shape), - InverseAffineCoupling(event_shape=event_shape, edge_list=edge_list) - ]) + if edge_list is None: + bijections.append(ReversePermutation(event_shape=event_shape)) + bijections.append(InverseAffineCoupling(event_shape=event_shape, edge_list=edge_list)) bijections.append(ElementwiseAffine(event_shape=event_shape)) super().__init__(event_shape, bijections, **kwargs) @@ -100,10 +97,9 @@ def __init__(self, event_shape, n_layers: int = 2, edge_list: List[Tuple[int, in event_shape = (event_shape,) bijections = [ElementwiseAffine(event_shape=event_shape)] for _ in range(n_layers): - bijections.extend([ - ReversePermutation(event_shape=event_shape), - RQSCoupling(event_shape=event_shape, edge_list=edge_list) - ]) + if edge_list is None: + bijections.append(ReversePermutation(event_shape=event_shape)) + bijections.append(RQSCoupling(event_shape=event_shape, edge_list=edge_list)) bijections.append(ElementwiseAffine(event_shape=event_shape)) super().__init__(event_shape, bijections, **kwargs) @@ -132,10 +128,9 @@ def __init__(self, event_shape, n_layers: int = 2, edge_list: List[Tuple[int, in event_shape = (event_shape,) bijections = [ElementwiseShift(event_shape=event_shape)] for _ in range(n_layers): - bijections.extend([ - ReversePermutation(event_shape=event_shape), - LRSCoupling(event_shape=event_shape, edge_list=edge_list) - ]) + if edge_list is None: + bijections.append(ReversePermutation(event_shape=event_shape)) + bijections.append(LRSCoupling(event_shape=event_shape, edge_list=edge_list)) bijections.append(ElementwiseShift(event_shape=event_shape)) super().__init__(event_shape, bijections, **kwargs) @@ -174,10 +169,9 @@ def __init__(self, event_shape, n_layers: int = 2, edge_list: List[Tuple[int, in event_shape = (event_shape,) bijections = [ElementwiseAffine(event_shape=event_shape)] for _ in range(n_layers): - bijections.extend([ - ReversePermutation(event_shape=event_shape), - DSCoupling(event_shape=event_shape, edge_list=edge_list) # TODO specify percent of global parameters - ]) + if edge_list is None: + bijections.append(ReversePermutation(event_shape=event_shape)) + bijections.append(DSCoupling(event_shape=event_shape, edge_list=edge_list)) bijections.append(ElementwiseAffine(event_shape=event_shape)) super().__init__(event_shape, bijections, **kwargs) From ba478476af25b7746952dd5501eb5d70aef0ce81 Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Thu, 22 Feb 2024 11:08:56 +0100 Subject: [PATCH 14/50] Have bijective compositions accept kwargs --- normalizing_flows/bijections/base.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/normalizing_flows/bijections/base.py b/normalizing_flows/bijections/base.py index a05fe5a..146bff2 100644 --- a/normalizing_flows/bijections/base.py +++ b/normalizing_flows/bijections/base.py @@ -11,7 +11,8 @@ class Bijection(nn.Module): def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], - context_shape: Union[torch.Size, Tuple[int, ...]] = None): + context_shape: Union[torch.Size, Tuple[int, ...]] = None, + **kwargs): """ Bijection class. """ @@ -93,7 +94,8 @@ class BijectiveComposition(Bijection): def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], layers: List[Bijection], - context_shape: Union[torch.Size, Tuple[int, ...]] = None): + context_shape: Union[torch.Size, Tuple[int, ...]] = None, + **kwargs): super().__init__(event_shape=event_shape, context_shape=context_shape) self.layers = nn.ModuleList(layers) From 2022caa7fe711889753bcfe688e54f044268ce9e Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Fri, 8 Mar 2024 15:21:05 +0100 Subject: [PATCH 15/50] Add regularized graphical conditioner --- .../autoregressive/conditioning/transforms.py | 65 +++++++++++++++++++ 1 file changed, 65 insertions(+) diff --git a/normalizing_flows/bijections/finite/autoregressive/conditioning/transforms.py b/normalizing_flows/bijections/finite/autoregressive/conditioning/transforms.py index 74f3966..1b373f6 100644 --- a/normalizing_flows/bijections/finite/autoregressive/conditioning/transforms.py +++ b/normalizing_flows/bijections/finite/autoregressive/conditioning/transforms.py @@ -284,3 +284,68 @@ def __init__(self, def predict_theta_flat(self, x: torch.Tensor, context: torch.Tensor = None): return self.sequential(self.context_combiner(x, context)) + + +class CombinedConditioner(nn.Module): + """ + Class that uses two different conditioners (each acting on different dimensions) to predict transformation + parameters. Transformation parameters are combined in a single vector. + """ + + def __init__(self, + conditioner1: ConditionerTransform, + conditioner2: ConditionerTransform, + conditioner1_input_mask: torch.Tensor, + conditioner2_input_mask: torch.Tensor): + super().__init__() + self.conditioner1 = conditioner1 + self.conditioner2 = conditioner2 + self.mask1 = conditioner1_input_mask + self.mask2 = conditioner2_input_mask + + def forward(self, x: torch.Tensor, context: torch.Tensor = None): + h1 = self.conditioner1(x[..., self.mask1], context) + h2 = self.conditioner1(x[..., self.mask1], context) + return h1 + h2 + + +class RegularizedCombinedConditioner(CombinedConditioner): + def __init__(self, + conditioner1: ConditionerTransform, + conditioner2: ConditionerTransform, + conditioner1_input_mask: torch.Tensor, + conditioner2_input_mask: torch.Tensor, + regularization_coefficient_1: float, + regularization_coefficient_2: float): + super().__init__( + conditioner1, + conditioner2, + conditioner1_input_mask, + conditioner2_input_mask + ) + self.c1 = regularization_coefficient_1 + self.c2 = regularization_coefficient_2 + + @property + def regularization(self): + # L2 aka Gaussian prior + c1_reg = self.c1 * sum([(p ** 2).sum() for p in self.conditioner1.parameters()]) + c2_reg = self.c2 * sum([(p ** 2).sum() for p in self.conditioner2.parameters()]) + return c1_reg + c2_reg + + +class RegularizedGraphicalConditioner(RegularizedCombinedConditioner): + def __init__(self, + interacting_dimensions_conditioner: ConditionerTransform, + auxiliary_dimensions_conditioner: ConditionerTransform, + interacting_dimensions_mask: torch.Tensor, + auxiliary_dimensions_mask: torch.Tensor, + coefficient: float = 0.1): + super().__init__( + interacting_dimensions_conditioner, + auxiliary_dimensions_conditioner, + interacting_dimensions_mask, + auxiliary_dimensions_mask, + regularization_coefficient_1=0.0, + regularization_coefficient_2=coefficient + ) From 06c53dd7e1baa8bc72dba547c5fb9cd796659513 Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Mon, 11 Mar 2024 11:03:03 +0100 Subject: [PATCH 16/50] Add conditioner regularization --- .../finite/autoregressive/conditioning/transforms.py | 12 +++++++----- .../bijections/finite/autoregressive/layers_base.py | 3 +++ 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/normalizing_flows/bijections/finite/autoregressive/conditioning/transforms.py b/normalizing_flows/bijections/finite/autoregressive/conditioning/transforms.py index 1b373f6..f963e15 100644 --- a/normalizing_flows/bijections/finite/autoregressive/conditioning/transforms.py +++ b/normalizing_flows/bijections/finite/autoregressive/conditioning/transforms.py @@ -96,6 +96,9 @@ def forward(self, x: torch.Tensor, context: torch.Tensor = None): def predict_theta_flat(self, x: torch.Tensor, context: torch.Tensor = None): raise NotImplementedError + def regularization(self): + return sum([torch.sum(torch.square(p)) for p in self.parameters()]) + class Constant(ConditionerTransform): def __init__(self, event_shape, parameter_shape, fill_value: float = None): @@ -308,6 +311,9 @@ def forward(self, x: torch.Tensor, context: torch.Tensor = None): h2 = self.conditioner1(x[..., self.mask1], context) return h1 + h2 + def regularization(self): + return self.conditioner1.regularization() + self.conditioner2.regularization() + class RegularizedCombinedConditioner(CombinedConditioner): def __init__(self, @@ -326,12 +332,8 @@ def __init__(self, self.c1 = regularization_coefficient_1 self.c2 = regularization_coefficient_2 - @property def regularization(self): - # L2 aka Gaussian prior - c1_reg = self.c1 * sum([(p ** 2).sum() for p in self.conditioner1.parameters()]) - c2_reg = self.c2 * sum([(p ** 2).sum() for p in self.conditioner2.parameters()]) - return c1_reg + c2_reg + return self.c1 * self.conditioner1.regularization() + self.c2 * self.conditioner2.regularization() class RegularizedGraphicalConditioner(RegularizedCombinedConditioner): diff --git a/normalizing_flows/bijections/finite/autoregressive/layers_base.py b/normalizing_flows/bijections/finite/autoregressive/layers_base.py index 0b18b34..f868f0d 100644 --- a/normalizing_flows/bijections/finite/autoregressive/layers_base.py +++ b/normalizing_flows/bijections/finite/autoregressive/layers_base.py @@ -31,6 +31,9 @@ def inverse(self, z: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch. x, log_det = self.transformer.inverse(z, h) return x, log_det + def regularization(self): + return self.conditioner_transform.regularization() + class CouplingBijection(AutoregressiveBijection): """ From f6e4abb28993f98bca6644cfcf6cb5051868fb1e Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Mon, 11 Mar 2024 11:04:34 +0100 Subject: [PATCH 17/50] Add regularization to autoregressive architectures --- normalizing_flows/bijections/base.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/normalizing_flows/bijections/base.py b/normalizing_flows/bijections/base.py index 146bff2..a989a21 100644 --- a/normalizing_flows/bijections/base.py +++ b/normalizing_flows/bijections/base.py @@ -114,3 +114,6 @@ def inverse(self, z: torch.Tensor, context: torch.Tensor = None, **kwargs) -> Tu log_det += log_det_layer x = z return x, log_det + + def regularization(self): + return sum([layer.regularization() for layer in self.layers]) From c22d58bcdb4093e8b8c9dd4c2dcdfe6ab2c3d6f1 Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Sat, 16 Mar 2024 18:13:48 +0100 Subject: [PATCH 18/50] Handle edge cases for regularization --- normalizing_flows/bijections/base.py | 2 ++ .../bijections/finite/autoregressive/layers_base.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/normalizing_flows/bijections/base.py b/normalizing_flows/bijections/base.py index a989a21..a086705 100644 --- a/normalizing_flows/bijections/base.py +++ b/normalizing_flows/bijections/base.py @@ -81,6 +81,8 @@ def batch_inverse(self, x: torch.Tensor, batch_size: int, context: torch.Tensor assert log_dets.shape == batch_shape return outputs, log_dets + def regularization(self): + return 0.0 def invert(bijection): """ diff --git a/normalizing_flows/bijections/finite/autoregressive/layers_base.py b/normalizing_flows/bijections/finite/autoregressive/layers_base.py index f868f0d..5c43108 100644 --- a/normalizing_flows/bijections/finite/autoregressive/layers_base.py +++ b/normalizing_flows/bijections/finite/autoregressive/layers_base.py @@ -32,7 +32,7 @@ def inverse(self, z: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch. return x, log_det def regularization(self): - return self.conditioner_transform.regularization() + return self.conditioner_transform.regularization() if self.conditioner_transform is not None else 0.0 class CouplingBijection(AutoregressiveBijection): From 3dab74e7faafcf6d617c0b1b7a2838f35ed82702 Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Fri, 22 Mar 2024 00:41:24 +0100 Subject: [PATCH 19/50] Add test for identity bijections with maximum regularization --- test/test_identity_bijections.py | 67 ++++++++++++++++++++++++++++++++ 1 file changed, 67 insertions(+) create mode 100644 test/test_identity_bijections.py diff --git a/test/test_identity_bijections.py b/test/test_identity_bijections.py new file mode 100644 index 0000000..fa88e93 --- /dev/null +++ b/test/test_identity_bijections.py @@ -0,0 +1,67 @@ +# Check that when all bijection parameters are set to 0, the bijections reduce to an identity map + +from normalizing_flows.bijections.finite.autoregressive.layers import ( + AffineCoupling, + DSCoupling, + RQSCoupling, + InverseAffineCoupling, + LRSCoupling, + ShiftCoupling, + AffineForwardMaskedAutoregressive, + AffineInverseMaskedAutoregressive, + ElementwiseAffine, + ElementwiseRQSpline, + ElementwiseScale, + ElementwiseShift, + LinearAffineCoupling, + LinearLRSCoupling, + LinearRQSCoupling, + LinearShiftCoupling, + LRSForwardMaskedAutoregressive, + RQSForwardMaskedAutoregressive, + RQSInverseMaskedAutoregressive, + UMNNMaskedAutoregressive, + +) +import torch +import pytest + + +@pytest.mark.parametrize( + 'layer_class', + [ + AffineCoupling, + DSCoupling, + RQSCoupling, + InverseAffineCoupling, + LRSCoupling, + ShiftCoupling, + AffineForwardMaskedAutoregressive, + AffineInverseMaskedAutoregressive, + ElementwiseAffine, + ElementwiseRQSpline, + ElementwiseScale, + ElementwiseShift, + LinearAffineCoupling, + LinearLRSCoupling, + LinearRQSCoupling, + LinearShiftCoupling, + LRSForwardMaskedAutoregressive, + RQSForwardMaskedAutoregressive, + RQSInverseMaskedAutoregressive, + # UMNNMaskedAutoregressive, # Inexact due to numerics + ] +) +def test_basic(layer_class): + n_batch, n_dim = 2, 3 + + torch.manual_seed(0) + x = torch.randn(size=(n_batch, n_dim)) + layer = layer_class(event_shape=torch.Size((n_dim,))) + + # Set all conditioner parameters to 0 + with torch.no_grad(): + for p in layer.parameters(): + p.data *= 0 + + assert torch.allclose(layer(x)[0], x, atol=1e-2) From 8d6a5a7f22d5faa06470b36c1eec951ab794a242 Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Fri, 22 Mar 2024 00:41:39 +0100 Subject: [PATCH 20/50] Fix bijection inversion call --- normalizing_flows/bijections/finite/autoregressive/layers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/normalizing_flows/bijections/finite/autoregressive/layers.py b/normalizing_flows/bijections/finite/autoregressive/layers.py index d6d613e..1594177 100644 --- a/normalizing_flows/bijections/finite/autoregressive/layers.py +++ b/normalizing_flows/bijections/finite/autoregressive/layers.py @@ -70,7 +70,7 @@ def __init__(self, if event_shape == (1,): raise ValueError coupling = make_coupling(event_shape, edge_list) - transformer = Affine(event_shape=torch.Size((coupling.target_event_size,))).invert() + transformer = invert(Affine(event_shape=torch.Size((coupling.target_event_size,)))) conditioner_transform = FeedForward( input_event_shape=torch.Size((coupling.source_event_size,)), parameter_shape=torch.Size(transformer.parameter_shape), From 84ad0623ca42d0e1e93ce118b2df38daf812a571 Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Thu, 6 Jun 2024 13:17:44 +0200 Subject: [PATCH 21/50] Add checkerboard mask and test --- .../conditioning/coupling_masks.py | 23 ++++++++++ test/test_checkerboard_coupling.py | 46 +++++++++++++++++++ 2 files changed, 69 insertions(+) create mode 100644 test/test_checkerboard_coupling.py diff --git a/normalizing_flows/bijections/finite/autoregressive/conditioning/coupling_masks.py b/normalizing_flows/bijections/finite/autoregressive/conditioning/coupling_masks.py index 7ac5b70..508f459 100644 --- a/normalizing_flows/bijections/finite/autoregressive/conditioning/coupling_masks.py +++ b/normalizing_flows/bijections/finite/autoregressive/conditioning/coupling_masks.py @@ -73,6 +73,29 @@ def __init__(self, event_shape): super().__init__(event_shape, mask=torch.less(torch.arange(event_size).view(*event_shape), event_size // 2)) +class Checkerboard(Coupling): + """ + Checkerboard coupling for image data. + """ + + def __init__(self, event_shape, resolution: int = 2): + """ + :param event_shape: image shape with the form (n_channels, width, height). Note: width and height must be equal + and a power of two. + :param resolution: resolution of the checkerboard along one axis - the number of squares. Must be a power of two + and smaller than image width. + """ + n_channels, width, _ = event_shape + assert width % resolution == 0 + square_side_length = width // resolution + assert resolution % 2 == 0 + half_resolution = resolution // 2 + a = torch.tensor([[1, 0] * half_resolution, [0, 1] * half_resolution] * half_resolution) + mask = torch.kron(a, torch.ones((square_side_length, square_side_length))) + mask = mask.bool() + super().__init__(event_shape, mask) + + def make_coupling(event_shape, edge_list: List[Tuple[int, int]] = None): if edge_list is None: return HalfSplit(event_shape) diff --git a/test/test_checkerboard_coupling.py b/test/test_checkerboard_coupling.py new file mode 100644 index 0000000..f4fbd23 --- /dev/null +++ b/test/test_checkerboard_coupling.py @@ -0,0 +1,46 @@ +from normalizing_flows.bijections.finite.autoregressive.conditioning.coupling_masks import Checkerboard +import torch + + +def test_checkerboard_small(): + torch.manual_seed(0) + image_shape = (3, 4, 4) + coupling = Checkerboard(image_shape, resolution=2) + assert torch.allclose( + coupling.source_mask, + torch.tensor([ + [1, 1, 0, 0], + [1, 1, 0, 0], + [0, 0, 1, 1], + [0, 0, 1, 1], + ], dtype=torch.bool)[None].repeat(3, 1, 1) + ) + assert torch.allclose(coupling.target_mask, ~coupling.source_mask) + + +def test_checkerboard_medium(): + torch.manual_seed(0) + image_shape = (3, 16, 16) + coupling = Checkerboard(image_shape, resolution=4) + assert torch.allclose( + coupling.source_mask, + torch.tensor([ + [1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0], + [1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0], + [1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0], + [1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1], + [0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1], + [0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1], + [0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1], + [1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0], + [1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0], + [1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0], + [1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1], + [0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1], + [0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1], + [0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1], + ], dtype=torch.bool)[None].repeat(3, 1, 1) + ) + assert torch.allclose(coupling.target_mask, ~coupling.source_mask) From 3e36b6ebaa916a6ed6e5b1dc057f636ed01276d4 Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Thu, 6 Jun 2024 13:24:24 +0200 Subject: [PATCH 22/50] Add checkerboard inversion --- .../conditioning/coupling_masks.py | 5 ++++- test/test_checkerboard_coupling.py | 16 ++++++++++++++++ 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/normalizing_flows/bijections/finite/autoregressive/conditioning/coupling_masks.py b/normalizing_flows/bijections/finite/autoregressive/conditioning/coupling_masks.py index 508f459..603a3de 100644 --- a/normalizing_flows/bijections/finite/autoregressive/conditioning/coupling_masks.py +++ b/normalizing_flows/bijections/finite/autoregressive/conditioning/coupling_masks.py @@ -78,12 +78,13 @@ class Checkerboard(Coupling): Checkerboard coupling for image data. """ - def __init__(self, event_shape, resolution: int = 2): + def __init__(self, event_shape, resolution: int = 2, invert: bool = False): """ :param event_shape: image shape with the form (n_channels, width, height). Note: width and height must be equal and a power of two. :param resolution: resolution of the checkerboard along one axis - the number of squares. Must be a power of two and smaller than image width. + :param invert: invert the checkerboard mask. """ n_channels, width, _ = event_shape assert width % resolution == 0 @@ -93,6 +94,8 @@ def __init__(self, event_shape, resolution: int = 2): a = torch.tensor([[1, 0] * half_resolution, [0, 1] * half_resolution] * half_resolution) mask = torch.kron(a, torch.ones((square_side_length, square_side_length))) mask = mask.bool() + if invert: + mask = ~mask super().__init__(event_shape, mask) diff --git a/test/test_checkerboard_coupling.py b/test/test_checkerboard_coupling.py index f4fbd23..1f55fdf 100644 --- a/test/test_checkerboard_coupling.py +++ b/test/test_checkerboard_coupling.py @@ -44,3 +44,19 @@ def test_checkerboard_medium(): ], dtype=torch.bool)[None].repeat(3, 1, 1) ) assert torch.allclose(coupling.target_mask, ~coupling.source_mask) + + +def test_checkerboard_small_inverted(): + torch.manual_seed(0) + image_shape = (3, 4, 4) + coupling = Checkerboard(image_shape, resolution=2, invert=True) + assert torch.allclose( + coupling.source_mask, + ~torch.tensor([ + [1, 1, 0, 0], + [1, 1, 0, 0], + [0, 0, 1, 1], + [0, 0, 1, 1], + ], dtype=torch.bool)[None].repeat(3, 1, 1) + ) + assert torch.allclose(coupling.target_mask, ~coupling.source_mask) From 5f81473839842f97ab99cfe2394ad5c8f6b1bcda Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Thu, 6 Jun 2024 16:22:09 +0200 Subject: [PATCH 23/50] Add source and target shape to checkerboard mask --- .../finite/autoregressive/conditioning/coupling_masks.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/normalizing_flows/bijections/finite/autoregressive/conditioning/coupling_masks.py b/normalizing_flows/bijections/finite/autoregressive/conditioning/coupling_masks.py index 603a3de..1a163c4 100644 --- a/normalizing_flows/bijections/finite/autoregressive/conditioning/coupling_masks.py +++ b/normalizing_flows/bijections/finite/autoregressive/conditioning/coupling_masks.py @@ -80,13 +80,13 @@ class Checkerboard(Coupling): def __init__(self, event_shape, resolution: int = 2, invert: bool = False): """ - :param event_shape: image shape with the form (n_channels, width, height). Note: width and height must be equal + :param event_shape: image shape with the form (n_channels, height, width). Note: width and height must be equal and a power of two. :param resolution: resolution of the checkerboard along one axis - the number of squares. Must be a power of two and smaller than image width. :param invert: invert the checkerboard mask. """ - n_channels, width, _ = event_shape + n_channels, height, width = event_shape assert width % resolution == 0 square_side_length = width // resolution assert resolution % 2 == 0 @@ -96,6 +96,8 @@ def __init__(self, event_shape, resolution: int = 2, invert: bool = False): mask = mask.bool() if invert: mask = ~mask + self.source_shape = (n_channels, height // resolution, width // resolution) + self.target_shape = (n_channels, height // resolution, width // resolution) super().__init__(event_shape, mask) From 11f80f0ecd84629156e496534859173e7f522386 Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Thu, 13 Jun 2024 22:24:03 +0200 Subject: [PATCH 24/50] Minor refactor for NF layers to allow future image-based transforms --- .../finite/autoregressive/architectures.py | 163 +++++++++--------- .../conditioning/coupling_masks.py | 27 ++- .../autoregressive/conditioning/transforms.py | 3 +- .../finite/autoregressive/layers.py | 30 +++- 4 files changed, 131 insertions(+), 92 deletions(-) diff --git a/normalizing_flows/bijections/finite/autoregressive/architectures.py b/normalizing_flows/bijections/finite/autoregressive/architectures.py index 78467a3..e366fbc 100644 --- a/normalizing_flows/bijections/finite/autoregressive/architectures.py +++ b/normalizing_flows/bijections/finite/autoregressive/architectures.py @@ -1,4 +1,4 @@ -from typing import Tuple, List +from typing import Tuple, List, Type, Union from normalizing_flows.bijections.finite.autoregressive.layers import ( ShiftCoupling, @@ -13,23 +13,85 @@ ElementwiseAffine, UMNNMaskedAutoregressive, LRSCoupling, - LRSForwardMaskedAutoregressive, - ElementwiseShift + LRSForwardMaskedAutoregressive ) from normalizing_flows.bijections.base import BijectiveComposition +from normalizing_flows.bijections.finite.autoregressive.layers_base import CouplingBijection, \ + MaskedAutoregressiveBijection, InverseMaskedAutoregressiveBijection from normalizing_flows.bijections.finite.linear import ReversePermutation +def make_basic_layers(base_bijection: Type[ + Union[CouplingBijection, MaskedAutoregressiveBijection, InverseMaskedAutoregressiveBijection]], + event_shape, + n_layers: int = 2, + edge_list: List[Tuple[int, int]] = None): + """ + Returns a list of bijections for transformations of vectors. + """ + bijections = [ElementwiseAffine(event_shape=event_shape)] + for _ in range(n_layers): + if edge_list is None: + bijections.append(ReversePermutation(event_shape=event_shape)) + bijections.append(base_bijection(event_shape=event_shape, edge_list=edge_list)) + bijections.append(ElementwiseAffine(event_shape=event_shape)) + return bijections + + +def make_image_layers(base_bijection: Type[ + Union[CouplingBijection, MaskedAutoregressiveBijection, InverseMaskedAutoregressiveBijection]], + event_shape, + checkerboard_resolution: int = 2, + n_layers: int = 2): + """ + Returns a list of bijections for transformations of images. + + Each layer consists of four coupling transforms: + 1. checkerboard, + 2. channel_wise, + 3. checkerboard_inverted, + 4. channel_wise_inverted. + """ + if len(event_shape) != 3: + raise ValueError("Image-based transformation are only possible for inputs with three axes.") + + bijections = [ElementwiseAffine(event_shape=event_shape)] + for _ in range(n_layers): + bijections.append(base_bijection( + event_shape=event_shape, + coupling_kwargs={ + 'coupling_type': 'checkerboard', + 'resolution': checkerboard_resolution, + } + )) + bijections.append(base_bijection( + event_shape=event_shape, + coupling_kwargs={ + 'coupling_type': 'channel_wise' + } + )) + bijections.append(base_bijection( + event_shape=event_shape, + coupling_kwargs={ + 'coupling_type': 'checkerboard_inverted', + 'resolution': checkerboard_resolution, + } + )) + bijections.append(base_bijection( + event_shape=event_shape, + coupling_kwargs={ + 'coupling_type': 'channel_wise_inverted' + } + )) + bijections.append(ElementwiseAffine(event_shape=event_shape)) + return bijections + + class NICE(BijectiveComposition): def __init__(self, event_shape, n_layers: int = 2, edge_list: List[Tuple[int, int]] = None, **kwargs): if isinstance(event_shape, int): event_shape = (event_shape,) - bijections = [ElementwiseAffine(event_shape=event_shape)] - for _ in range(n_layers): - if edge_list is None: - bijections.append(ReversePermutation(event_shape=event_shape)) - bijections.append(ShiftCoupling(event_shape=event_shape, edge_list=edge_list)) - bijections.append(ElementwiseAffine(event_shape=event_shape)) + bijections = make_basic_layers(ShiftCoupling, event_shape, n_layers, edge_list) super().__init__(event_shape, bijections, **kwargs) @@ -37,12 +99,7 @@ class RealNVP(BijectiveComposition): def __init__(self, event_shape, n_layers: int = 2, edge_list: List[Tuple[int, int]] = None, **kwargs): if isinstance(event_shape, int): event_shape = (event_shape,) - bijections = [ElementwiseAffine(event_shape=event_shape)] - for _ in range(n_layers): - if edge_list is None: - bijections.append(ReversePermutation(event_shape=event_shape)) - bijections.append(AffineCoupling(event_shape=event_shape, edge_list=edge_list)) - bijections.append(ElementwiseAffine(event_shape=event_shape)) + bijections = make_basic_layers(AffineCoupling, event_shape, n_layers, edge_list) super().__init__(event_shape, bijections, **kwargs) @@ -50,12 +107,7 @@ class InverseRealNVP(BijectiveComposition): def __init__(self, event_shape, n_layers: int = 2, edge_list: List[Tuple[int, int]] = None, **kwargs): if isinstance(event_shape, int): event_shape = (event_shape,) - bijections = [ElementwiseAffine(event_shape=event_shape)] - for _ in range(n_layers): - if edge_list is None: - bijections.append(ReversePermutation(event_shape=event_shape)) - bijections.append(InverseAffineCoupling(event_shape=event_shape, edge_list=edge_list)) - bijections.append(ElementwiseAffine(event_shape=event_shape)) + bijections = make_basic_layers(InverseAffineCoupling, event_shape, n_layers, edge_list) super().__init__(event_shape, bijections, **kwargs) @@ -67,13 +119,7 @@ class MAF(BijectiveComposition): def __init__(self, event_shape, n_layers: int = 2, **kwargs): if isinstance(event_shape, int): event_shape = (event_shape,) - bijections = [ElementwiseAffine(event_shape=event_shape)] - for _ in range(n_layers): - bijections.extend([ - ReversePermutation(event_shape=event_shape), - AffineForwardMaskedAutoregressive(event_shape=event_shape) - ]) - bijections.append(ElementwiseAffine(event_shape=event_shape)) + bijections = make_basic_layers(AffineForwardMaskedAutoregressive, event_shape, n_layers) super().__init__(event_shape, bijections, **kwargs) @@ -81,13 +127,7 @@ class IAF(BijectiveComposition): def __init__(self, event_shape, n_layers: int = 2, **kwargs): if isinstance(event_shape, int): event_shape = (event_shape,) - bijections = [ElementwiseAffine(event_shape=event_shape)] - for _ in range(n_layers): - bijections.extend([ - ReversePermutation(event_shape=event_shape), - AffineInverseMaskedAutoregressive(event_shape=event_shape) - ]) - bijections.append(ElementwiseAffine(event_shape=event_shape)) + bijections = make_basic_layers(AffineInverseMaskedAutoregressive, event_shape, n_layers) super().__init__(event_shape, bijections, **kwargs) @@ -95,12 +135,7 @@ class CouplingRQNSF(BijectiveComposition): def __init__(self, event_shape, n_layers: int = 2, edge_list: List[Tuple[int, int]] = None, **kwargs): if isinstance(event_shape, int): event_shape = (event_shape,) - bijections = [ElementwiseAffine(event_shape=event_shape)] - for _ in range(n_layers): - if edge_list is None: - bijections.append(ReversePermutation(event_shape=event_shape)) - bijections.append(RQSCoupling(event_shape=event_shape, edge_list=edge_list)) - bijections.append(ElementwiseAffine(event_shape=event_shape)) + bijections = make_basic_layers(RQSCoupling, event_shape, n_layers, edge_list) super().__init__(event_shape, bijections, **kwargs) @@ -112,13 +147,7 @@ class MaskedAutoregressiveRQNSF(BijectiveComposition): def __init__(self, event_shape, n_layers: int = 2, **kwargs): if isinstance(event_shape, int): event_shape = (event_shape,) - bijections = [ElementwiseAffine(event_shape=event_shape)] - for _ in range(n_layers): - bijections.extend([ - ReversePermutation(event_shape=event_shape), - RQSForwardMaskedAutoregressive(event_shape=event_shape) - ]) - bijections.append(ElementwiseAffine(event_shape=event_shape)) + bijections = make_basic_layers(RQSForwardMaskedAutoregressive, event_shape, n_layers) super().__init__(event_shape, bijections, **kwargs) @@ -126,12 +155,7 @@ class CouplingLRS(BijectiveComposition): def __init__(self, event_shape, n_layers: int = 2, edge_list: List[Tuple[int, int]] = None, **kwargs): if isinstance(event_shape, int): event_shape = (event_shape,) - bijections = [ElementwiseShift(event_shape=event_shape)] - for _ in range(n_layers): - if edge_list is None: - bijections.append(ReversePermutation(event_shape=event_shape)) - bijections.append(LRSCoupling(event_shape=event_shape, edge_list=edge_list)) - bijections.append(ElementwiseShift(event_shape=event_shape)) + bijections = make_basic_layers(LRSCoupling, event_shape, n_layers, edge_list) super().__init__(event_shape, bijections, **kwargs) @@ -139,13 +163,7 @@ class MaskedAutoregressiveLRS(BijectiveComposition): def __init__(self, event_shape, n_layers: int = 2, **kwargs): if isinstance(event_shape, int): event_shape = (event_shape,) - bijections = [ElementwiseShift(event_shape=event_shape)] - for _ in range(n_layers): - bijections.extend([ - ReversePermutation(event_shape=event_shape), - LRSForwardMaskedAutoregressive(event_shape=event_shape) - ]) - bijections.append(ElementwiseShift(event_shape=event_shape)) + bijections = make_basic_layers(LRSForwardMaskedAutoregressive, event_shape, n_layers) super().__init__(event_shape, bijections, **kwargs) @@ -153,13 +171,7 @@ class InverseAutoregressiveRQNSF(BijectiveComposition): def __init__(self, event_shape, n_layers: int = 2, **kwargs): if isinstance(event_shape, int): event_shape = (event_shape,) - bijections = [ElementwiseAffine(event_shape=event_shape)] - for _ in range(n_layers): - bijections.extend([ - ReversePermutation(event_shape=event_shape), - RQSInverseMaskedAutoregressive(event_shape=event_shape) - ]) - bijections.append(ElementwiseAffine(event_shape=event_shape)) + bijections = make_basic_layers(RQSInverseMaskedAutoregressive, event_shape, n_layers) super().__init__(event_shape, bijections, **kwargs) @@ -167,12 +179,7 @@ class CouplingDSF(BijectiveComposition): def __init__(self, event_shape, n_layers: int = 2, edge_list: List[Tuple[int, int]] = None, **kwargs): if isinstance(event_shape, int): event_shape = (event_shape,) - bijections = [ElementwiseAffine(event_shape=event_shape)] - for _ in range(n_layers): - if edge_list is None: - bijections.append(ReversePermutation(event_shape=event_shape)) - bijections.append(DSCoupling(event_shape=event_shape, edge_list=edge_list)) - bijections.append(ElementwiseAffine(event_shape=event_shape)) + bijections = make_basic_layers(DSCoupling, event_shape, n_layers, edge_list) super().__init__(event_shape, bijections, **kwargs) @@ -180,11 +187,5 @@ class UMNNMAF(BijectiveComposition): def __init__(self, event_shape, n_layers: int = 1, **kwargs): if isinstance(event_shape, int): event_shape = (event_shape,) - bijections = [ElementwiseAffine(event_shape=event_shape)] - for _ in range(n_layers): - bijections.extend([ - ReversePermutation(event_shape=event_shape), - UMNNMaskedAutoregressive(event_shape=event_shape) - ]) - bijections.append(ElementwiseAffine(event_shape=event_shape)) + bijections = make_basic_layers(UMNNMaskedAutoregressive, event_shape, n_layers) super().__init__(event_shape, bijections, **kwargs) diff --git a/normalizing_flows/bijections/finite/autoregressive/conditioning/coupling_masks.py b/normalizing_flows/bijections/finite/autoregressive/conditioning/coupling_masks.py index 1a163c4..2a5ec50 100644 --- a/normalizing_flows/bijections/finite/autoregressive/conditioning/coupling_masks.py +++ b/normalizing_flows/bijections/finite/autoregressive/conditioning/coupling_masks.py @@ -101,8 +101,27 @@ def __init__(self, event_shape, resolution: int = 2, invert: bool = False): super().__init__(event_shape, mask) -def make_coupling(event_shape, edge_list: List[Tuple[int, int]] = None): - if edge_list is None: - return HalfSplit(event_shape) - else: +def make_coupling(event_shape, edge_list: List[Tuple[int, int]] = None, coupling_type: str = 'half_split', **kwargs): + """ + + :param event_shape: + :param coupling_type: one of ['half_split', 'checkerboard', 'checkerboard_inverted', 'channel_wise', + 'channel_wise_inverted']. + :param edge_list: + :return: + """ + if edge_list is not None: return GraphicalCoupling(event_shape, edge_list) + else: + if coupling_type == 'half_split': + return HalfSplit(event_shape) + elif coupling_type == 'checkerboard': + return Checkerboard(event_shape, invert=False, **kwargs) + elif coupling_type == 'checkerboard_inverted': + return Checkerboard(event_shape, invert=True, **kwargs) + elif coupling_type == 'channel_wise': + raise NotImplementedError + elif coupling_type == 'channel_wise_inverted': + raise NotImplementedError + else: + raise ValueError diff --git a/normalizing_flows/bijections/finite/autoregressive/conditioning/transforms.py b/normalizing_flows/bijections/finite/autoregressive/conditioning/transforms.py index f963e15..197dcb6 100644 --- a/normalizing_flows/bijections/finite/autoregressive/conditioning/transforms.py +++ b/normalizing_flows/bijections/finite/autoregressive/conditioning/transforms.py @@ -26,7 +26,8 @@ def __init__(self, parameter_shape: Union[torch.Size, Tuple[int, ...]], context_combiner: ContextCombiner = None, global_parameter_mask: torch.Tensor = None, - initial_global_parameter_value: float = None): + initial_global_parameter_value: float = None, + **kwargs): """ :param input_event_shape: shape of conditioner input tensor x. :param context_shape: shape of conditioner context tensor c. diff --git a/normalizing_flows/bijections/finite/autoregressive/layers.py b/normalizing_flows/bijections/finite/autoregressive/layers.py index 1594177..85e3bf1 100644 --- a/normalizing_flows/bijections/finite/autoregressive/layers.py +++ b/normalizing_flows/bijections/finite/autoregressive/layers.py @@ -47,10 +47,13 @@ def __init__(self, event_shape: torch.Size, context_shape: torch.Size = None, edge_list: List[Tuple[int, int]] = None, + coupling_kwargs: dict = None, **kwargs): if event_shape == (1,): raise ValueError - coupling = make_coupling(event_shape, edge_list) + if coupling_kwargs is None: + coupling_kwargs = dict() + coupling = make_coupling(event_shape, edge_list, **coupling_kwargs) transformer = Affine(event_shape=torch.Size((coupling.target_event_size,))) conditioner_transform = FeedForward( input_event_shape=torch.Size((coupling.source_event_size,)), @@ -66,10 +69,13 @@ def __init__(self, event_shape: torch.Size, context_shape: torch.Size = None, edge_list: List[Tuple[int, int]] = None, + coupling_kwargs: dict = None, **kwargs): if event_shape == (1,): raise ValueError - coupling = make_coupling(event_shape, edge_list) + if coupling_kwargs is None: + coupling_kwargs = dict() + coupling = make_coupling(event_shape, edge_list, **coupling_kwargs) transformer = invert(Affine(event_shape=torch.Size((coupling.target_event_size,)))) conditioner_transform = FeedForward( input_event_shape=torch.Size((coupling.source_event_size,)), @@ -85,8 +91,11 @@ def __init__(self, event_shape: torch.Size, context_shape: torch.Size = None, edge_list: List[Tuple[int, int]] = None, + coupling_kwargs: dict = None, **kwargs): - coupling = make_coupling(event_shape, edge_list) + if coupling_kwargs is None: + coupling_kwargs = dict() + coupling = make_coupling(event_shape, edge_list, **coupling_kwargs) transformer = Shift(event_shape=torch.Size((coupling.target_event_size,))) conditioner_transform = FeedForward( input_event_shape=torch.Size((coupling.source_event_size,)), @@ -103,9 +112,12 @@ def __init__(self, context_shape: torch.Size = None, n_bins: int = 8, edge_list: List[Tuple[int, int]] = None, + coupling_kwargs: dict = None, **kwargs): assert n_bins >= 1 - coupling = make_coupling(event_shape, edge_list) + if coupling_kwargs is None: + coupling_kwargs = dict() + coupling = make_coupling(event_shape, edge_list, **coupling_kwargs) transformer = LinearRational(event_shape=torch.Size((coupling.target_event_size,)), n_bins=n_bins) conditioner_transform = FeedForward( input_event_shape=torch.Size((coupling.source_event_size,)), @@ -122,8 +134,11 @@ def __init__(self, context_shape: torch.Size = None, n_bins: int = 8, edge_list: List[Tuple[int, int]] = None, + coupling_kwargs: dict = None, **kwargs): - coupling = make_coupling(event_shape, edge_list) + if coupling_kwargs is None: + coupling_kwargs = dict() + coupling = make_coupling(event_shape, edge_list, **coupling_kwargs) transformer = RationalQuadratic(event_shape=torch.Size((coupling.target_event_size,)), n_bins=n_bins) conditioner_transform = FeedForward( input_event_shape=torch.Size((coupling.source_event_size,)), @@ -140,8 +155,11 @@ def __init__(self, context_shape: torch.Size = None, n_hidden_layers: int = 2, edge_list: List[Tuple[int, int]] = None, + coupling_kwargs: dict = None, **kwargs): - coupling = make_coupling(event_shape, edge_list) + if coupling_kwargs is None: + coupling_kwargs = dict() + coupling = make_coupling(event_shape, edge_list, **coupling_kwargs) transformer = DeepSigmoid( event_shape=torch.Size((coupling.target_event_size,)), n_hidden_layers=n_hidden_layers From 37c41ad845160c6fa9d6a5f243f359022f8caf8c Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Sun, 16 Jun 2024 13:11:30 +0200 Subject: [PATCH 25/50] Add channel-wise coupling --- .../conditioning/coupling_masks.py | 27 +++++++++++++++---- 1 file changed, 22 insertions(+), 5 deletions(-) diff --git a/normalizing_flows/bijections/finite/autoregressive/conditioning/coupling_masks.py b/normalizing_flows/bijections/finite/autoregressive/conditioning/coupling_masks.py index 2a5ec50..2c5078e 100644 --- a/normalizing_flows/bijections/finite/autoregressive/conditioning/coupling_masks.py +++ b/normalizing_flows/bijections/finite/autoregressive/conditioning/coupling_masks.py @@ -86,7 +86,7 @@ def __init__(self, event_shape, resolution: int = 2, invert: bool = False): and smaller than image width. :param invert: invert the checkerboard mask. """ - n_channels, height, width = event_shape + height, width = event_shape[-2:] assert width % resolution == 0 square_side_length = width // resolution assert resolution % 2 == 0 @@ -96,8 +96,25 @@ def __init__(self, event_shape, resolution: int = 2, invert: bool = False): mask = mask.bool() if invert: mask = ~mask - self.source_shape = (n_channels, height // resolution, width // resolution) - self.target_shape = (n_channels, height // resolution, width // resolution) + super().__init__(event_shape, mask) + + +class ChannelWiseHalfSplit(Coupling): + """ + Channel-wise coupling for image data. + """ + + def __init__(self, event_shape, invert: bool = False): + """ + :param event_shape: image shape with the form (n_channels, height, width). Note: width and height must be equal + and a power of two. + :param invert: invert the checkerboard mask. + """ + n_channels, height, width = event_shape + mask = torch.as_tensor(torch.arange(start=0, end=n_channels) < (n_channels // 2)) + mask = mask[:, None, None].repeat(1, height, width) + if invert: + mask = ~mask super().__init__(event_shape, mask) @@ -120,8 +137,8 @@ def make_coupling(event_shape, edge_list: List[Tuple[int, int]] = None, coupling elif coupling_type == 'checkerboard_inverted': return Checkerboard(event_shape, invert=True, **kwargs) elif coupling_type == 'channel_wise': - raise NotImplementedError + return ChannelWiseHalfSplit(event_shape, invert=False) elif coupling_type == 'channel_wise_inverted': - raise NotImplementedError + return ChannelWiseHalfSplit(event_shape, invert=True) else: raise ValueError From 491a63eb9fd61beb152a5135b65a0c245dfde650 Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Sun, 16 Jun 2024 13:12:00 +0200 Subject: [PATCH 26/50] Put layer construction into separate function --- .../finite/autoregressive/architectures.py | 119 +++++++++++++++--- 1 file changed, 101 insertions(+), 18 deletions(-) diff --git a/normalizing_flows/bijections/finite/autoregressive/architectures.py b/normalizing_flows/bijections/finite/autoregressive/architectures.py index e366fbc..ef9c7c7 100644 --- a/normalizing_flows/bijections/finite/autoregressive/architectures.py +++ b/normalizing_flows/bijections/finite/autoregressive/architectures.py @@ -21,6 +21,24 @@ from normalizing_flows.bijections.finite.linear import ReversePermutation +def make_layers(base_bijection: Type[ + Union[CouplingBijection, MaskedAutoregressiveBijection, InverseMaskedAutoregressiveBijection]], + event_shape, + n_layers: int = 2, + edge_list: List[Tuple[int, int]] = None, + image_coupling: bool = False): + if image_coupling: + if len(event_shape) == 2: + bijections = make_image_layers_single_channel(base_bijection, event_shape, n_layers) + elif len(event_shape) == 3: + bijections = make_image_layers_multichannel(base_bijection, event_shape, n_layers) + else: + raise ValueError + else: + bijections = make_basic_layers(base_bijection, event_shape, n_layers, edge_list) + return bijections + + def make_basic_layers(base_bijection: Type[ Union[CouplingBijection, MaskedAutoregressiveBijection, InverseMaskedAutoregressiveBijection]], event_shape, @@ -38,13 +56,48 @@ def make_basic_layers(base_bijection: Type[ return bijections -def make_image_layers(base_bijection: Type[ +def make_image_layers_single_channel(base_bijection: Type[ Union[CouplingBijection, MaskedAutoregressiveBijection, InverseMaskedAutoregressiveBijection]], - event_shape, - checkerboard_resolution: int = 2, - n_layers: int = 2): + event_shape, + n_layers: int = 2, + checkerboard_resolution: int = 2): + """ + Returns a list of bijections for transformations of images with a single channel. + + Each layer consists of two coupling transforms: + 1. checkerboard, + 2. checkerboard_inverted. + """ + if len(event_shape) != 2: + raise ValueError("Single-channel image transformation are only possible for inputs with two axes.") + + bijections = [ElementwiseAffine(event_shape=event_shape)] + for _ in range(n_layers): + bijections.append(base_bijection( + event_shape=event_shape, + coupling_kwargs={ + 'coupling_type': 'checkerboard', + 'resolution': checkerboard_resolution, + } + )) + bijections.append(base_bijection( + event_shape=event_shape, + coupling_kwargs={ + 'coupling_type': 'checkerboard_inverted', + 'resolution': checkerboard_resolution, + } + )) + bijections.append(ElementwiseAffine(event_shape=event_shape)) + return bijections + + +def make_image_layers_multichannel(base_bijection: Type[ + Union[CouplingBijection, MaskedAutoregressiveBijection, InverseMaskedAutoregressiveBijection]], + event_shape, + n_layers: int = 2, + checkerboard_resolution: int = 2): """ - Returns a list of bijections for transformations of images. + Returns a list of bijections for transformations of images with multiple channels. Each layer consists of four coupling transforms: 1. checkerboard, @@ -53,7 +106,7 @@ def make_image_layers(base_bijection: Type[ 4. channel_wise_inverted. """ if len(event_shape) != 3: - raise ValueError("Image-based transformation are only possible for inputs with three axes.") + raise ValueError("Multichannel image transformation are only possible for inputs with three axes.") bijections = [ElementwiseAffine(event_shape=event_shape)] for _ in range(n_layers): @@ -88,26 +141,41 @@ def make_image_layers(base_bijection: Type[ class NICE(BijectiveComposition): - def __init__(self, event_shape, n_layers: int = 2, edge_list: List[Tuple[int, int]] = None, **kwargs): + def __init__(self, + event_shape, + n_layers: int = 2, + edge_list: List[Tuple[int, int]] = None, + image_coupling: bool = False, + **kwargs): if isinstance(event_shape, int): event_shape = (event_shape,) - bijections = make_basic_layers(ShiftCoupling, event_shape, n_layers, edge_list) + bijections = make_layers(ShiftCoupling, event_shape, n_layers, edge_list, image_coupling) super().__init__(event_shape, bijections, **kwargs) class RealNVP(BijectiveComposition): - def __init__(self, event_shape, n_layers: int = 2, edge_list: List[Tuple[int, int]] = None, **kwargs): + def __init__(self, + event_shape, + n_layers: int = 2, + edge_list: List[Tuple[int, int]] = None, + image_coupling: bool = False, + **kwargs): if isinstance(event_shape, int): event_shape = (event_shape,) - bijections = make_basic_layers(AffineCoupling, event_shape, n_layers, edge_list) + bijections = make_layers(AffineCoupling, event_shape, n_layers, edge_list, image_coupling) super().__init__(event_shape, bijections, **kwargs) class InverseRealNVP(BijectiveComposition): - def __init__(self, event_shape, n_layers: int = 2, edge_list: List[Tuple[int, int]] = None, **kwargs): + def __init__(self, + event_shape, + n_layers: int = 2, + edge_list: List[Tuple[int, int]] = None, + image_coupling: bool = False, + **kwargs): if isinstance(event_shape, int): event_shape = (event_shape,) - bijections = make_basic_layers(InverseAffineCoupling, event_shape, n_layers, edge_list) + bijections = make_layers(InverseAffineCoupling, event_shape, n_layers, edge_list, image_coupling) super().__init__(event_shape, bijections, **kwargs) @@ -132,10 +200,15 @@ def __init__(self, event_shape, n_layers: int = 2, **kwargs): class CouplingRQNSF(BijectiveComposition): - def __init__(self, event_shape, n_layers: int = 2, edge_list: List[Tuple[int, int]] = None, **kwargs): + def __init__(self, + event_shape, + n_layers: int = 2, + edge_list: List[Tuple[int, int]] = None, + image_coupling: bool = False, + **kwargs): if isinstance(event_shape, int): event_shape = (event_shape,) - bijections = make_basic_layers(RQSCoupling, event_shape, n_layers, edge_list) + bijections = make_layers(RQSCoupling, event_shape, n_layers, edge_list, image_coupling) super().__init__(event_shape, bijections, **kwargs) @@ -152,10 +225,15 @@ def __init__(self, event_shape, n_layers: int = 2, **kwargs): class CouplingLRS(BijectiveComposition): - def __init__(self, event_shape, n_layers: int = 2, edge_list: List[Tuple[int, int]] = None, **kwargs): + def __init__(self, + event_shape, + n_layers: int = 2, + edge_list: List[Tuple[int, int]] = None, + image_coupling: bool = False, + **kwargs): if isinstance(event_shape, int): event_shape = (event_shape,) - bijections = make_basic_layers(LRSCoupling, event_shape, n_layers, edge_list) + bijections = make_layers(LRSCoupling, event_shape, n_layers, edge_list, image_coupling) super().__init__(event_shape, bijections, **kwargs) @@ -176,10 +254,15 @@ def __init__(self, event_shape, n_layers: int = 2, **kwargs): class CouplingDSF(BijectiveComposition): - def __init__(self, event_shape, n_layers: int = 2, edge_list: List[Tuple[int, int]] = None, **kwargs): + def __init__(self, + event_shape, + n_layers: int = 2, + edge_list: List[Tuple[int, int]] = None, + image_coupling: bool = False, + **kwargs): if isinstance(event_shape, int): event_shape = (event_shape,) - bijections = make_basic_layers(DSCoupling, event_shape, n_layers, edge_list) + bijections = make_layers(DSCoupling, event_shape, n_layers, edge_list, image_coupling) super().__init__(event_shape, bijections, **kwargs) From 0b883d91153e3cc1875397259b5e271b4a72f7e2 Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Thu, 20 Jun 2024 13:57:16 +0200 Subject: [PATCH 27/50] Add squeeze bijection --- normalizing_flows/bijections/finite/linear.py | 83 +++++++++++++++++++ test/test_squeeze_bijection.py | 20 +++++ 2 files changed, 103 insertions(+) create mode 100644 test/test_squeeze_bijection.py diff --git a/normalizing_flows/bijections/finite/linear.py b/normalizing_flows/bijections/finite/linear.py index 8b98c67..f9c3504 100644 --- a/normalizing_flows/bijections/finite/linear.py +++ b/normalizing_flows/bijections/finite/linear.py @@ -14,6 +14,89 @@ from normalizing_flows.utils import get_batch_shape, flatten_event, unflatten_event, flatten_batch, unflatten_batch +class Squeeze(Bijection): + """ + Squeeze a batch of tensors with shape (*batch_shape, channels, height, width) into shape + (*batch_shape, 4 * channels, height / 2, width / 2). + """ + + def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], **kwargs): + # Check shape length + if len(event_shape) != 3: + raise ValueError(f"Event shape must have three components, but got {len(event_shape)}") + # Check that height and width are divisible by two + if event_shape[1] % 2 != 0: + raise ValueError(f"Event dimension 1 must be divisible by 2, but got {event_shape[1]}") + if event_shape[2] % 2 != 0: + raise ValueError(f"Event dimension 2 must be divisible by 2, but got {event_shape[2]}") + super().__init__(event_shape, **kwargs) + c, h, w = event_shape + self.squeezed_event_shape = torch.Size((4 * c, h // 2, w // 2)) + + def forward(self, x: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Squeeze tensor with shape (*batch_shape, channels, height, width) into tensor with shape + (*batch_shape, 4 * channels, height // 2, width // 2). + """ + batch_shape = get_batch_shape(x, self.event_shape) + log_det = torch.zeros(*batch_shape, device=x.device, dtype=x.dtype) + + channels, height, width = x.shape[-3:] + assert height % 2 == 0 + assert width % 2 == 0 + n_rows = height // 2 + n_cols = width // 2 + n_squares = n_rows * n_cols + + square_mask = torch.kron( + torch.arange(n_squares).view(n_rows, n_cols), + torch.ones(2, 2) + ) + channel_mask = torch.arange(n_rows * n_cols).view(n_rows, n_cols)[None].repeat(4 * channels, 1, 1) + + # out = torch.zeros(size=(*batch_shape, self.squeezed_event_shape), device=x.device, dtype=x.dtype) + out = torch.empty(size=(*batch_shape, 4 * channels, height // 2, width // 2), device=x.device, dtype=x.dtype) + + channel_mask = channel_mask.repeat(*batch_shape, 1, 1, 1) + square_mask = square_mask.repeat(*batch_shape, channels, 1, 1) + for i in range(n_squares): + out[channel_mask == i] = x[square_mask == i] + + return out, log_det + + def inverse(self, z: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Squeeze tensor with shape (*batch_shape, 4 * channels, height // 2, width // 2) into tensor with shape + (*batch_shape, channels, height, width). + """ + batch_shape = get_batch_shape(z, self.squeezed_event_shape) + log_det = torch.zeros(*batch_shape, device=z.device, dtype=z.dtype) + + four_channels, half_height, half_width = z.shape[-3:] + assert four_channels % 4 == 0 + width = 2 * half_width + height = 2 * half_height + channels = four_channels // 4 + + n_rows = height // 2 + n_cols = width // 2 + n_squares = n_rows * n_cols + + square_mask = torch.kron( + torch.arange(n_squares).view(n_rows, n_cols), + torch.ones(2, 2) + ) + channel_mask = torch.arange(n_rows * n_cols).view(n_rows, n_cols)[None].repeat(4 * channels, 1, 1) + out = torch.empty(size=(*batch_shape, channels, height, width), device=z.device, dtype=z.dtype) + + channel_mask = channel_mask.repeat(*batch_shape, 1, 1, 1) + square_mask = square_mask.repeat(*batch_shape, channels, 1, 1) + for i in range(n_squares): + out[square_mask == i] = z[channel_mask == i] + + return out, log_det + + class LinearBijection(Bijection): def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], matrix: InvertibleMatrix): super().__init__(event_shape) diff --git a/test/test_squeeze_bijection.py b/test/test_squeeze_bijection.py new file mode 100644 index 0000000..b5db0b5 --- /dev/null +++ b/test/test_squeeze_bijection.py @@ -0,0 +1,20 @@ +import torch +import pytest +from normalizing_flows.bijections.finite.linear import Squeeze + + +@pytest.mark.parametrize('batch_shape', [(1,), (2,), (2, 3)]) +@pytest.mark.parametrize('channels', [1, 3, 10]) +@pytest.mark.parametrize('height', [4, 16, 32]) +@pytest.mark.parametrize('width', [4, 16, 32]) +def test_reconstruction(batch_shape, channels, height, width): + torch.manual_seed(0) + x = torch.randn(size=(*batch_shape, channels, height, width)) + layer = Squeeze(event_shape=x.shape[-3:]) + z, log_det_forward = layer.forward(x) + x_reconstructed, log_det_inverse = layer.inverse(z) + + assert z.shape == (*batch_shape, 4 * channels, height // 2, width // 2) + assert torch.allclose(x_reconstructed, x) + assert torch.allclose(log_det_forward, torch.zeros_like(log_det_forward)) + assert torch.allclose(log_det_forward, log_det_inverse) From 416525ef8fcbc1acbd2bd0b4e005905bec6bb9f4 Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Sun, 23 Jun 2024 16:13:38 +0200 Subject: [PATCH 28/50] Add multiscale bijections and supporting classes --- normalizing_flows/architectures.py | 2 + normalizing_flows/bijections/base.py | 1 + .../finite/autoregressive/architectures.py | 120 +------------- .../conditioning/coupling_masks.py | 60 +------ .../finite/autoregressive/layers_base.py | 2 +- normalizing_flows/bijections/finite/linear.py | 83 ---------- .../bijections/finite/multiscale/__init__.py | 0 .../finite/multiscale/architectures.py | 38 +++++ .../bijections/finite/multiscale/base.py | 152 ++++++++++++++++++ .../bijections/finite/multiscale/coupling.py | 68 ++++++++ .../bijections/finite/multiscale/layers.py | 7 + normalizing_flows/flows.py | 12 +- test/test_checkerboard_coupling.py | 4 +- test/test_squeeze_bijection.py | 3 +- 14 files changed, 291 insertions(+), 261 deletions(-) create mode 100644 normalizing_flows/bijections/finite/multiscale/__init__.py create mode 100644 normalizing_flows/bijections/finite/multiscale/architectures.py create mode 100644 normalizing_flows/bijections/finite/multiscale/base.py create mode 100644 normalizing_flows/bijections/finite/multiscale/coupling.py create mode 100644 normalizing_flows/bijections/finite/multiscale/layers.py diff --git a/normalizing_flows/architectures.py b/normalizing_flows/architectures.py index 191a022..90c9561 100644 --- a/normalizing_flows/architectures.py +++ b/normalizing_flows/architectures.py @@ -25,3 +25,5 @@ Radial, Sylvester ) + +from normalizing_flows.bijections.finite.multiscale.architectures import MultiscaleRealNVP diff --git a/normalizing_flows/bijections/base.py b/normalizing_flows/bijections/base.py index a086705..8cb358d 100644 --- a/normalizing_flows/bijections/base.py +++ b/normalizing_flows/bijections/base.py @@ -20,6 +20,7 @@ def __init__(self, self.event_shape = event_shape self.n_dim = int(torch.prod(torch.as_tensor(event_shape))) self.context_shape = context_shape + self.transformed_shape = self.event_shape # Overwritten in multiscale flows TODO make into property def forward(self, x: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]: """ diff --git a/normalizing_flows/bijections/finite/autoregressive/architectures.py b/normalizing_flows/bijections/finite/autoregressive/architectures.py index ef9c7c7..d2ee0ef 100644 --- a/normalizing_flows/bijections/finite/autoregressive/architectures.py +++ b/normalizing_flows/bijections/finite/autoregressive/architectures.py @@ -21,24 +21,6 @@ from normalizing_flows.bijections.finite.linear import ReversePermutation -def make_layers(base_bijection: Type[ - Union[CouplingBijection, MaskedAutoregressiveBijection, InverseMaskedAutoregressiveBijection]], - event_shape, - n_layers: int = 2, - edge_list: List[Tuple[int, int]] = None, - image_coupling: bool = False): - if image_coupling: - if len(event_shape) == 2: - bijections = make_image_layers_single_channel(base_bijection, event_shape, n_layers) - elif len(event_shape) == 3: - bijections = make_image_layers_multichannel(base_bijection, event_shape, n_layers) - else: - raise ValueError - else: - bijections = make_basic_layers(base_bijection, event_shape, n_layers, edge_list) - return bijections - - def make_basic_layers(base_bijection: Type[ Union[CouplingBijection, MaskedAutoregressiveBijection, InverseMaskedAutoregressiveBijection]], event_shape, @@ -56,100 +38,15 @@ def make_basic_layers(base_bijection: Type[ return bijections -def make_image_layers_single_channel(base_bijection: Type[ - Union[CouplingBijection, MaskedAutoregressiveBijection, InverseMaskedAutoregressiveBijection]], - event_shape, - n_layers: int = 2, - checkerboard_resolution: int = 2): - """ - Returns a list of bijections for transformations of images with a single channel. - - Each layer consists of two coupling transforms: - 1. checkerboard, - 2. checkerboard_inverted. - """ - if len(event_shape) != 2: - raise ValueError("Single-channel image transformation are only possible for inputs with two axes.") - - bijections = [ElementwiseAffine(event_shape=event_shape)] - for _ in range(n_layers): - bijections.append(base_bijection( - event_shape=event_shape, - coupling_kwargs={ - 'coupling_type': 'checkerboard', - 'resolution': checkerboard_resolution, - } - )) - bijections.append(base_bijection( - event_shape=event_shape, - coupling_kwargs={ - 'coupling_type': 'checkerboard_inverted', - 'resolution': checkerboard_resolution, - } - )) - bijections.append(ElementwiseAffine(event_shape=event_shape)) - return bijections - - -def make_image_layers_multichannel(base_bijection: Type[ - Union[CouplingBijection, MaskedAutoregressiveBijection, InverseMaskedAutoregressiveBijection]], - event_shape, - n_layers: int = 2, - checkerboard_resolution: int = 2): - """ - Returns a list of bijections for transformations of images with multiple channels. - - Each layer consists of four coupling transforms: - 1. checkerboard, - 2. channel_wise, - 3. checkerboard_inverted, - 4. channel_wise_inverted. - """ - if len(event_shape) != 3: - raise ValueError("Multichannel image transformation are only possible for inputs with three axes.") - - bijections = [ElementwiseAffine(event_shape=event_shape)] - for _ in range(n_layers): - bijections.append(base_bijection( - event_shape=event_shape, - coupling_kwargs={ - 'coupling_type': 'checkerboard', - 'resolution': checkerboard_resolution, - } - )) - bijections.append(base_bijection( - event_shape=event_shape, - coupling_kwargs={ - 'coupling_type': 'channel_wise' - } - )) - bijections.append(base_bijection( - event_shape=event_shape, - coupling_kwargs={ - 'coupling_type': 'checkerboard_inverted', - 'resolution': checkerboard_resolution, - } - )) - bijections.append(base_bijection( - event_shape=event_shape, - coupling_kwargs={ - 'coupling_type': 'channel_wise_inverted' - } - )) - bijections.append(ElementwiseAffine(event_shape=event_shape)) - return bijections - - class NICE(BijectiveComposition): def __init__(self, event_shape, n_layers: int = 2, edge_list: List[Tuple[int, int]] = None, - image_coupling: bool = False, **kwargs): if isinstance(event_shape, int): event_shape = (event_shape,) - bijections = make_layers(ShiftCoupling, event_shape, n_layers, edge_list, image_coupling) + bijections = make_basic_layers(ShiftCoupling, event_shape, n_layers, edge_list) super().__init__(event_shape, bijections, **kwargs) @@ -158,11 +55,10 @@ def __init__(self, event_shape, n_layers: int = 2, edge_list: List[Tuple[int, int]] = None, - image_coupling: bool = False, **kwargs): if isinstance(event_shape, int): event_shape = (event_shape,) - bijections = make_layers(AffineCoupling, event_shape, n_layers, edge_list, image_coupling) + bijections = make_basic_layers(AffineCoupling, event_shape, n_layers, edge_list) super().__init__(event_shape, bijections, **kwargs) @@ -171,11 +67,10 @@ def __init__(self, event_shape, n_layers: int = 2, edge_list: List[Tuple[int, int]] = None, - image_coupling: bool = False, **kwargs): if isinstance(event_shape, int): event_shape = (event_shape,) - bijections = make_layers(InverseAffineCoupling, event_shape, n_layers, edge_list, image_coupling) + bijections = make_basic_layers(InverseAffineCoupling, event_shape, n_layers, edge_list) super().__init__(event_shape, bijections, **kwargs) @@ -204,11 +99,10 @@ def __init__(self, event_shape, n_layers: int = 2, edge_list: List[Tuple[int, int]] = None, - image_coupling: bool = False, **kwargs): if isinstance(event_shape, int): event_shape = (event_shape,) - bijections = make_layers(RQSCoupling, event_shape, n_layers, edge_list, image_coupling) + bijections = make_basic_layers(RQSCoupling, event_shape, n_layers, edge_list) super().__init__(event_shape, bijections, **kwargs) @@ -229,11 +123,10 @@ def __init__(self, event_shape, n_layers: int = 2, edge_list: List[Tuple[int, int]] = None, - image_coupling: bool = False, **kwargs): if isinstance(event_shape, int): event_shape = (event_shape,) - bijections = make_layers(LRSCoupling, event_shape, n_layers, edge_list, image_coupling) + bijections = make_basic_layers(LRSCoupling, event_shape, n_layers, edge_list) super().__init__(event_shape, bijections, **kwargs) @@ -258,11 +151,10 @@ def __init__(self, event_shape, n_layers: int = 2, edge_list: List[Tuple[int, int]] = None, - image_coupling: bool = False, **kwargs): if isinstance(event_shape, int): event_shape = (event_shape,) - bijections = make_layers(DSCoupling, event_shape, n_layers, edge_list, image_coupling) + bijections = make_basic_layers(DSCoupling, event_shape, n_layers, edge_list) super().__init__(event_shape, bijections, **kwargs) diff --git a/normalizing_flows/bijections/finite/autoregressive/conditioning/coupling_masks.py b/normalizing_flows/bijections/finite/autoregressive/conditioning/coupling_masks.py index 2c5078e..08b3d61 100644 --- a/normalizing_flows/bijections/finite/autoregressive/conditioning/coupling_masks.py +++ b/normalizing_flows/bijections/finite/autoregressive/conditioning/coupling_masks.py @@ -73,51 +73,6 @@ def __init__(self, event_shape): super().__init__(event_shape, mask=torch.less(torch.arange(event_size).view(*event_shape), event_size // 2)) -class Checkerboard(Coupling): - """ - Checkerboard coupling for image data. - """ - - def __init__(self, event_shape, resolution: int = 2, invert: bool = False): - """ - :param event_shape: image shape with the form (n_channels, height, width). Note: width and height must be equal - and a power of two. - :param resolution: resolution of the checkerboard along one axis - the number of squares. Must be a power of two - and smaller than image width. - :param invert: invert the checkerboard mask. - """ - height, width = event_shape[-2:] - assert width % resolution == 0 - square_side_length = width // resolution - assert resolution % 2 == 0 - half_resolution = resolution // 2 - a = torch.tensor([[1, 0] * half_resolution, [0, 1] * half_resolution] * half_resolution) - mask = torch.kron(a, torch.ones((square_side_length, square_side_length))) - mask = mask.bool() - if invert: - mask = ~mask - super().__init__(event_shape, mask) - - -class ChannelWiseHalfSplit(Coupling): - """ - Channel-wise coupling for image data. - """ - - def __init__(self, event_shape, invert: bool = False): - """ - :param event_shape: image shape with the form (n_channels, height, width). Note: width and height must be equal - and a power of two. - :param invert: invert the checkerboard mask. - """ - n_channels, height, width = event_shape - mask = torch.as_tensor(torch.arange(start=0, end=n_channels) < (n_channels // 2)) - mask = mask[:, None, None].repeat(1, height, width) - if invert: - mask = ~mask - super().__init__(event_shape, mask) - - def make_coupling(event_shape, edge_list: List[Tuple[int, int]] = None, coupling_type: str = 'half_split', **kwargs): """ @@ -129,16 +84,7 @@ def make_coupling(event_shape, edge_list: List[Tuple[int, int]] = None, coupling """ if edge_list is not None: return GraphicalCoupling(event_shape, edge_list) + elif coupling_type == 'half_split': + return HalfSplit(event_shape) else: - if coupling_type == 'half_split': - return HalfSplit(event_shape) - elif coupling_type == 'checkerboard': - return Checkerboard(event_shape, invert=False, **kwargs) - elif coupling_type == 'checkerboard_inverted': - return Checkerboard(event_shape, invert=True, **kwargs) - elif coupling_type == 'channel_wise': - return ChannelWiseHalfSplit(event_shape, invert=False) - elif coupling_type == 'channel_wise_inverted': - return ChannelWiseHalfSplit(event_shape, invert=True) - else: - raise ValueError + raise ValueError diff --git a/normalizing_flows/bijections/finite/autoregressive/layers_base.py b/normalizing_flows/bijections/finite/autoregressive/layers_base.py index 5c43108..5187c3d 100644 --- a/normalizing_flows/bijections/finite/autoregressive/layers_base.py +++ b/normalizing_flows/bijections/finite/autoregressive/layers_base.py @@ -1,4 +1,4 @@ -from typing import Tuple, Union +from typing import Tuple, Union, Type import torch import torch.nn as nn diff --git a/normalizing_flows/bijections/finite/linear.py b/normalizing_flows/bijections/finite/linear.py index f9c3504..8b98c67 100644 --- a/normalizing_flows/bijections/finite/linear.py +++ b/normalizing_flows/bijections/finite/linear.py @@ -14,89 +14,6 @@ from normalizing_flows.utils import get_batch_shape, flatten_event, unflatten_event, flatten_batch, unflatten_batch -class Squeeze(Bijection): - """ - Squeeze a batch of tensors with shape (*batch_shape, channels, height, width) into shape - (*batch_shape, 4 * channels, height / 2, width / 2). - """ - - def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], **kwargs): - # Check shape length - if len(event_shape) != 3: - raise ValueError(f"Event shape must have three components, but got {len(event_shape)}") - # Check that height and width are divisible by two - if event_shape[1] % 2 != 0: - raise ValueError(f"Event dimension 1 must be divisible by 2, but got {event_shape[1]}") - if event_shape[2] % 2 != 0: - raise ValueError(f"Event dimension 2 must be divisible by 2, but got {event_shape[2]}") - super().__init__(event_shape, **kwargs) - c, h, w = event_shape - self.squeezed_event_shape = torch.Size((4 * c, h // 2, w // 2)) - - def forward(self, x: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Squeeze tensor with shape (*batch_shape, channels, height, width) into tensor with shape - (*batch_shape, 4 * channels, height // 2, width // 2). - """ - batch_shape = get_batch_shape(x, self.event_shape) - log_det = torch.zeros(*batch_shape, device=x.device, dtype=x.dtype) - - channels, height, width = x.shape[-3:] - assert height % 2 == 0 - assert width % 2 == 0 - n_rows = height // 2 - n_cols = width // 2 - n_squares = n_rows * n_cols - - square_mask = torch.kron( - torch.arange(n_squares).view(n_rows, n_cols), - torch.ones(2, 2) - ) - channel_mask = torch.arange(n_rows * n_cols).view(n_rows, n_cols)[None].repeat(4 * channels, 1, 1) - - # out = torch.zeros(size=(*batch_shape, self.squeezed_event_shape), device=x.device, dtype=x.dtype) - out = torch.empty(size=(*batch_shape, 4 * channels, height // 2, width // 2), device=x.device, dtype=x.dtype) - - channel_mask = channel_mask.repeat(*batch_shape, 1, 1, 1) - square_mask = square_mask.repeat(*batch_shape, channels, 1, 1) - for i in range(n_squares): - out[channel_mask == i] = x[square_mask == i] - - return out, log_det - - def inverse(self, z: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Squeeze tensor with shape (*batch_shape, 4 * channels, height // 2, width // 2) into tensor with shape - (*batch_shape, channels, height, width). - """ - batch_shape = get_batch_shape(z, self.squeezed_event_shape) - log_det = torch.zeros(*batch_shape, device=z.device, dtype=z.dtype) - - four_channels, half_height, half_width = z.shape[-3:] - assert four_channels % 4 == 0 - width = 2 * half_width - height = 2 * half_height - channels = four_channels // 4 - - n_rows = height // 2 - n_cols = width // 2 - n_squares = n_rows * n_cols - - square_mask = torch.kron( - torch.arange(n_squares).view(n_rows, n_cols), - torch.ones(2, 2) - ) - channel_mask = torch.arange(n_rows * n_cols).view(n_rows, n_cols)[None].repeat(4 * channels, 1, 1) - out = torch.empty(size=(*batch_shape, channels, height, width), device=z.device, dtype=z.dtype) - - channel_mask = channel_mask.repeat(*batch_shape, 1, 1, 1) - square_mask = square_mask.repeat(*batch_shape, channels, 1, 1) - for i in range(n_squares): - out[square_mask == i] = z[channel_mask == i] - - return out, log_det - - class LinearBijection(Bijection): def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], matrix: InvertibleMatrix): super().__init__(event_shape) diff --git a/normalizing_flows/bijections/finite/multiscale/__init__.py b/normalizing_flows/bijections/finite/multiscale/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/normalizing_flows/bijections/finite/multiscale/architectures.py b/normalizing_flows/bijections/finite/multiscale/architectures.py new file mode 100644 index 0000000..5ecff68 --- /dev/null +++ b/normalizing_flows/bijections/finite/multiscale/architectures.py @@ -0,0 +1,38 @@ +from normalizing_flows.bijections.finite.autoregressive.layers import ElementwiseAffine +from normalizing_flows.bijections.finite.autoregressive.transformers.linear.affine import Affine +from normalizing_flows.bijections import BijectiveComposition +from normalizing_flows.bijections.finite.multiscale.base import MultiscaleBijection + + +def make_image_layers(event_shape, + transformer_class, + n_layers: int = 2): + """ + Returns a list of bijections for transformations of images with multiple channels. + """ + if len(event_shape) != 3: + raise ValueError("Multichannel image transformation are only possible for inputs with three axes.") + + bijections = [ElementwiseAffine(event_shape=event_shape)] + for _ in range(n_layers): + bijections.append( + MultiscaleBijection( + input_event_shape=bijections[-1].transformed_shape, + transformer_class=transformer_class + ) + ) + bijections.append(ElementwiseAffine(event_shape=bijections[-1].transformed_shape)) + return bijections + + +class MultiscaleRealNVP(BijectiveComposition): + def __init__(self, + event_shape, + n_layers: int = 3, + **kwargs): + if isinstance(event_shape, int): + event_shape = (event_shape,) + bijections = make_image_layers(event_shape, Affine, n_layers) + super().__init__(event_shape, bijections, **kwargs) + self.transformed_shape = bijections[-1].transformed_shape + diff --git a/normalizing_flows/bijections/finite/multiscale/base.py b/normalizing_flows/bijections/finite/multiscale/base.py new file mode 100644 index 0000000..b379bd4 --- /dev/null +++ b/normalizing_flows/bijections/finite/multiscale/base.py @@ -0,0 +1,152 @@ +from typing import Type, Union, Tuple + +import torch + +from normalizing_flows.bijections import BijectiveComposition, CouplingBijection +from normalizing_flows.bijections.finite.autoregressive.conditioning.transforms import FeedForward +from normalizing_flows.bijections.base import Bijection +from normalizing_flows.bijections.finite.autoregressive.transformers.base import TensorTransformer +from normalizing_flows.bijections.finite.multiscale.coupling import make_image_coupling +from normalizing_flows.utils import get_batch_shape + + +class CheckerboardCoupling(CouplingBijection): + def __init__(self, + event_shape, + transformer_class: Type[TensorTransformer], + alternate: bool = False, + **kwargs): + coupling = make_image_coupling( + event_shape, + coupling_type='checkerboard' if not alternate else 'checkerboard_inverted' + ) + transformer = transformer_class(event_shape=torch.Size((coupling.target_event_size,))) + conditioner_transform = FeedForward( + input_event_shape=torch.Size((coupling.source_event_size,)), + parameter_shape=torch.Size(transformer.parameter_shape), + **kwargs + ) + super().__init__(transformer, coupling, conditioner_transform, **kwargs) + + +class ChannelWiseCoupling(CouplingBijection): + def __init__(self, + event_shape, + transformer_class: Type[TensorTransformer], + alternate: bool = False, + **kwargs): + coupling = make_image_coupling( + event_shape, + coupling_type='channel_wise' if not alternate else 'channel_wise_inverted' + ) + transformer = transformer_class(event_shape=torch.Size((coupling.target_event_size,))) + conditioner_transform = FeedForward( + input_event_shape=torch.Size((coupling.source_event_size,)), + parameter_shape=torch.Size(transformer.parameter_shape), + **kwargs + ) + super().__init__(transformer, coupling, conditioner_transform, **kwargs) + + +class Squeeze(Bijection): + """ + Squeeze a batch of tensors with shape (*batch_shape, channels, height, width) into shape + (*batch_shape, 4 * channels, height / 2, width / 2). + """ + + def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], **kwargs): + # Check shape length + if len(event_shape) != 3: + raise ValueError(f"Event shape must have three components, but got {len(event_shape)}") + # Check that height and width are divisible by two + if event_shape[1] % 2 != 0: + raise ValueError(f"Event dimension 1 must be divisible by 2, but got {event_shape[1]}") + if event_shape[2] % 2 != 0: + raise ValueError(f"Event dimension 2 must be divisible by 2, but got {event_shape[2]}") + super().__init__(event_shape, **kwargs) + c, h, w = event_shape + self.transformed_event_shape = torch.Size((4 * c, h // 2, w // 2)) + + def forward(self, x: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Squeeze tensor with shape (*batch_shape, channels, height, width) into tensor with shape + (*batch_shape, 4 * channels, height // 2, width // 2). + """ + batch_shape = get_batch_shape(x, self.event_shape) + log_det = torch.zeros(*batch_shape, device=x.device, dtype=x.dtype) + + channels, height, width = x.shape[-3:] + assert height % 2 == 0 + assert width % 2 == 0 + n_rows = height // 2 + n_cols = width // 2 + n_squares = n_rows * n_cols + + square_mask = torch.kron( + torch.arange(n_squares).view(n_rows, n_cols), + torch.ones(2, 2) + ) + channel_mask = torch.arange(n_rows * n_cols).view(n_rows, n_cols)[None].repeat(4 * channels, 1, 1) + + # out = torch.zeros(size=(*batch_shape, self.squeezed_event_shape), device=x.device, dtype=x.dtype) + out = torch.empty(size=(*batch_shape, 4 * channels, height // 2, width // 2), device=x.device, dtype=x.dtype) + + channel_mask = channel_mask.repeat(*batch_shape, 1, 1, 1) + square_mask = square_mask.repeat(*batch_shape, channels, 1, 1) + for i in range(n_squares): + out[channel_mask == i] = x[square_mask == i] + + return out, log_det + + def inverse(self, z: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Squeeze tensor with shape (*batch_shape, 4 * channels, height // 2, width // 2) into tensor with shape + (*batch_shape, channels, height, width). + """ + batch_shape = get_batch_shape(z, self.transformed_event_shape) + log_det = torch.zeros(*batch_shape, device=z.device, dtype=z.dtype) + + four_channels, half_height, half_width = z.shape[-3:] + assert four_channels % 4 == 0 + width = 2 * half_width + height = 2 * half_height + channels = four_channels // 4 + + n_rows = height // 2 + n_cols = width // 2 + n_squares = n_rows * n_cols + + square_mask = torch.kron( + torch.arange(n_squares).view(n_rows, n_cols), + torch.ones(2, 2) + ) + channel_mask = torch.arange(n_rows * n_cols).view(n_rows, n_cols)[None].repeat(4 * channels, 1, 1) + out = torch.empty(size=(*batch_shape, channels, height, width), device=z.device, dtype=z.dtype) + + channel_mask = channel_mask.repeat(*batch_shape, 1, 1, 1) + square_mask = square_mask.repeat(*batch_shape, channels, 1, 1) + for i in range(n_squares): + out[square_mask == i] = z[channel_mask == i] + + return out, log_det + + +class MultiscaleBijection(BijectiveComposition): + def __init__(self, + input_event_shape, + transformer_class: Type[TensorTransformer], + n_checkerboard_layers: int = 3, + n_channel_wise_layers: int = 3, + **kwargs): + checkerboard_layers = [ + CheckerboardCoupling(input_event_shape, transformer_class, alternate=i % 2 == 1) + for i in range(n_checkerboard_layers) + ] + squeeze_layer = Squeeze(input_event_shape) + channel_wise_layers = [ + ChannelWiseCoupling(squeeze_layer.transformed_event_shape, transformer_class, alternate=i % 2 == 1) + for i in range(n_channel_wise_layers) + ] + layers = [*checkerboard_layers, squeeze_layer, *channel_wise_layers] + super().__init__(input_event_shape, layers, **kwargs) + self.transformed_shape = squeeze_layer.transformed_event_shape diff --git a/normalizing_flows/bijections/finite/multiscale/coupling.py b/normalizing_flows/bijections/finite/multiscale/coupling.py new file mode 100644 index 0000000..c20a323 --- /dev/null +++ b/normalizing_flows/bijections/finite/multiscale/coupling.py @@ -0,0 +1,68 @@ +import torch + +from normalizing_flows.bijections.finite.autoregressive.conditioning.coupling_masks import Coupling + + +class Checkerboard(Coupling): + """ + Checkerboard coupling for image data. + """ + + def __init__(self, event_shape, resolution: int = 2, invert: bool = False): + """ + :param event_shape: image shape with the form (n_channels, height, width). Note: width and height must be equal + and a power of two. + :param resolution: resolution of the checkerboard along one axis - the number of squares. Must be a power of two + and smaller than image width. + :param invert: invert the checkerboard mask. + """ + channels, height, width = event_shape[-3:] + assert width % resolution == 0 + square_side_length = width // resolution + assert resolution % 2 == 0 + half_resolution = resolution // 2 + a = torch.tensor([[1, 0] * half_resolution, [0, 1] * half_resolution] * half_resolution) + mask = torch.kron(a, torch.ones((square_side_length, square_side_length))) + mask = mask.bool() + mask = mask[None].repeat(channels, 1, 1) + if invert: + mask = ~mask + super().__init__(event_shape, mask) + + +class ChannelWiseHalfSplit(Coupling): + """ + Channel-wise coupling for image data. + """ + + def __init__(self, event_shape, invert: bool = False): + """ + :param event_shape: image shape with the form (n_channels, height, width). Note: width and height must be equal + and a power of two. + :param invert: invert the checkerboard mask. + """ + n_channels, height, width = event_shape + mask = torch.as_tensor(torch.arange(start=0, end=n_channels) < (n_channels // 2)) + mask = mask[:, None, None].repeat(1, height, width) + if invert: + mask = ~mask + super().__init__(event_shape, mask) + + +def make_image_coupling(event_shape, coupling_type: str, **kwargs): + """ + + :param event_shape: + :param coupling_type: one of ['checkerboard', 'checkerboard_inverted', 'channel_wise', 'channel_wise_inverted']. + :return: + """ + if coupling_type == 'checkerboard': + return Checkerboard(event_shape, invert=False, **kwargs) + elif coupling_type == 'checkerboard_inverted': + return Checkerboard(event_shape, invert=True, **kwargs) + elif coupling_type == 'channel_wise': + return ChannelWiseHalfSplit(event_shape, invert=False) + elif coupling_type == 'channel_wise_inverted': + return ChannelWiseHalfSplit(event_shape, invert=True) + else: + raise ValueError diff --git a/normalizing_flows/bijections/finite/multiscale/layers.py b/normalizing_flows/bijections/finite/multiscale/layers.py new file mode 100644 index 0000000..7a95207 --- /dev/null +++ b/normalizing_flows/bijections/finite/multiscale/layers.py @@ -0,0 +1,7 @@ +from normalizing_flows.bijections.finite.multiscale.base import MultiscaleBijection +from normalizing_flows.bijections.finite.autoregressive.transformers.linear.affine import Scale, Affine, Shift + + +class MultiscaleAffineCoupling(MultiscaleBijection): + def __init__(self, input_event_shape, **kwargs): + super().__init__(input_event_shape, transformer_class=Affine, **kwargs) diff --git a/normalizing_flows/flows.py b/normalizing_flows/flows.py index 465047c..602801b 100644 --- a/normalizing_flows/flows.py +++ b/normalizing_flows/flows.py @@ -275,18 +275,22 @@ def sample(self, n: int, context: torch.Tensor = None, no_grad: bool = False, re :param no_grad: if True, do not track gradients in the inverse pass. :return: samples with shape (n, *event_shape) if no context given or (n, *c, *event_shape) if context given. """ + if context is not None: - z = self.base_sample(sample_shape=torch.Size((n, len(context)))) + sample_shape = torch.Size((n, len(context))) + z = self.base_sample(sample_shape=sample_shape) context = context[None].repeat(*[n, *([1] * len(context.shape))]) # Make context shape match z shape assert z.shape[:2] == context.shape[:2] else: - z = self.base_sample(sample_shape=torch.Size((n,))) + sample_shape = torch.Size((n,)) + z = self.base_sample(sample_shape=sample_shape) + if no_grad: z = z.detach() with torch.no_grad(): - x, log_det = self.bijection.inverse(z, context=context) + x, log_det = self.bijection.inverse(z.view(*sample_shape, *self.bijection.transformed_shape), context=context) else: - x, log_det = self.bijection.inverse(z, context=context) + x, log_det = self.bijection.inverse(z.view(*sample_shape, *self.bijection.transformed_shape), context=context) x = x.to(self.get_device()) if return_log_prob: diff --git a/test/test_checkerboard_coupling.py b/test/test_checkerboard_coupling.py index 1f55fdf..2f120c0 100644 --- a/test/test_checkerboard_coupling.py +++ b/test/test_checkerboard_coupling.py @@ -1,6 +1,8 @@ -from normalizing_flows.bijections.finite.autoregressive.conditioning.coupling_masks import Checkerboard + import torch +from normalizing_flows.bijections.finite.multiscale.coupling import Checkerboard + def test_checkerboard_small(): torch.manual_seed(0) diff --git a/test/test_squeeze_bijection.py b/test/test_squeeze_bijection.py index b5db0b5..6952f34 100644 --- a/test/test_squeeze_bijection.py +++ b/test/test_squeeze_bijection.py @@ -1,6 +1,7 @@ import torch import pytest -from normalizing_flows.bijections.finite.linear import Squeeze + +from normalizing_flows.bijections.finite.multiscale.base import Squeeze @pytest.mark.parametrize('batch_shape', [(1,), (2,), (2, 3)]) From df85d58a76af78669fd8a1085f6ccc8a629f111d Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Sun, 23 Jun 2024 17:23:22 +0200 Subject: [PATCH 29/50] Add faster squeeze method --- .../bijections/finite/multiscale/base.py | 16 ++++++++++++++++ test/test_squeeze_bijection.py | 13 +++++++++++++ 2 files changed, 29 insertions(+) diff --git a/normalizing_flows/bijections/finite/multiscale/base.py b/normalizing_flows/bijections/finite/multiscale/base.py index b379bd4..2258c47 100644 --- a/normalizing_flows/bijections/finite/multiscale/base.py +++ b/normalizing_flows/bijections/finite/multiscale/base.py @@ -98,6 +98,22 @@ def forward(self, x: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch. return out, log_det + def forward2(self, x, context=None): + batch_shape = get_batch_shape(x, self.event_shape) + log_det = torch.zeros(*batch_shape, device=x.device, dtype=x.dtype) + + channels, height, width = x.shape[-3:] + assert height % 2 == 0 + assert width % 2 == 0 + + out = torch.concatenate([ + x[..., ::2, ::2], + x[..., ::2, 1::2], + x[..., 1::2, ::2], + x[..., 1::2, 1::2] + ], dim=1) + return out, log_det + def inverse(self, z: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]: """ Squeeze tensor with shape (*batch_shape, 4 * channels, height // 2, width // 2) into tensor with shape diff --git a/test/test_squeeze_bijection.py b/test/test_squeeze_bijection.py index 6952f34..a92eb87 100644 --- a/test/test_squeeze_bijection.py +++ b/test/test_squeeze_bijection.py @@ -19,3 +19,16 @@ def test_reconstruction(batch_shape, channels, height, width): assert torch.allclose(x_reconstructed, x) assert torch.allclose(log_det_forward, torch.zeros_like(log_det_forward)) assert torch.allclose(log_det_forward, log_det_inverse) + + +def test_efficient_forward(): + x = torch.tensor([ + [1, 2, 5, 6], + [3, 4, 7, 8], + [9, 10, 13, 14], + [11, 12, 15, 16] + ])[None, None] + layer = Squeeze(event_shape=x.shape[-3:]) + z, log_det_forward = layer.forward(x) + z2, log_det_forward2 = layer.forward2(x) + assert torch.allclose(z, z2) From eab2fece7ea6fd6647ff581b82f0a9bf015629f4 Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Sun, 23 Jun 2024 19:21:14 +0200 Subject: [PATCH 30/50] Add alternative inverse --- .../bijections/finite/multiscale/base.py | 17 +++++++++++++++++ test/test_squeeze_bijection.py | 15 +++++++++++++++ 2 files changed, 32 insertions(+) diff --git a/normalizing_flows/bijections/finite/multiscale/base.py b/normalizing_flows/bijections/finite/multiscale/base.py index 2258c47..f073c62 100644 --- a/normalizing_flows/bijections/finite/multiscale/base.py +++ b/normalizing_flows/bijections/finite/multiscale/base.py @@ -146,6 +146,23 @@ def inverse(self, z: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch. return out, log_det + def inverse2(self, z, context=None): + batch_shape = get_batch_shape(z, self.transformed_event_shape) + log_det = torch.zeros(*batch_shape, device=z.device, dtype=z.dtype) + + four_channels, half_height, half_width = z.shape[-3:] + assert four_channels % 4 == 0 + width = 2 * half_width + height = 2 * half_height + channels = four_channels // 4 + + out = torch.empty(size=(*batch_shape, channels, height, width), device=z.device, dtype=z.dtype) + out[..., ::2, ::2] = z[..., 0, :, :] + out[..., ::2, 1::2] = z[..., 1, :, :] + out[..., 1::2, ::2] = z[..., 2, :, :] + out[..., 1::2, 1::2] = z[..., 3, :, :] + return out, log_det + class MultiscaleBijection(BijectiveComposition): def __init__(self, diff --git a/test/test_squeeze_bijection.py b/test/test_squeeze_bijection.py index a92eb87..86a3e39 100644 --- a/test/test_squeeze_bijection.py +++ b/test/test_squeeze_bijection.py @@ -32,3 +32,18 @@ def test_efficient_forward(): z, log_det_forward = layer.forward(x) z2, log_det_forward2 = layer.forward2(x) assert torch.allclose(z, z2) + + +def test_efficient_inverse(): + x = torch.tensor([ + [1, 2, 5, 6], + [3, 4, 7, 8], + [9, 10, 13, 14], + [11, 12, 15, 16] + ])[None, None] + layer = Squeeze(event_shape=x.shape[-3:]) + z, log_det_forward = layer.forward(x) + + xr, _ = layer.inverse2(z) + + assert torch.allclose(x, xr) From 39586c8a709ada586275af88f3dc4f26174fc1a1 Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Sun, 23 Jun 2024 19:28:26 +0200 Subject: [PATCH 31/50] Fix inverse and forward in squeeze --- .../bijections/finite/multiscale/base.py | 65 ++----------------- 1 file changed, 5 insertions(+), 60 deletions(-) diff --git a/normalizing_flows/bijections/finite/multiscale/base.py b/normalizing_flows/bijections/finite/multiscale/base.py index f073c62..fad40f7 100644 --- a/normalizing_flows/bijections/finite/multiscale/base.py +++ b/normalizing_flows/bijections/finite/multiscale/base.py @@ -75,33 +75,6 @@ def forward(self, x: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch. batch_shape = get_batch_shape(x, self.event_shape) log_det = torch.zeros(*batch_shape, device=x.device, dtype=x.dtype) - channels, height, width = x.shape[-3:] - assert height % 2 == 0 - assert width % 2 == 0 - n_rows = height // 2 - n_cols = width // 2 - n_squares = n_rows * n_cols - - square_mask = torch.kron( - torch.arange(n_squares).view(n_rows, n_cols), - torch.ones(2, 2) - ) - channel_mask = torch.arange(n_rows * n_cols).view(n_rows, n_cols)[None].repeat(4 * channels, 1, 1) - - # out = torch.zeros(size=(*batch_shape, self.squeezed_event_shape), device=x.device, dtype=x.dtype) - out = torch.empty(size=(*batch_shape, 4 * channels, height // 2, width // 2), device=x.device, dtype=x.dtype) - - channel_mask = channel_mask.repeat(*batch_shape, 1, 1, 1) - square_mask = square_mask.repeat(*batch_shape, channels, 1, 1) - for i in range(n_squares): - out[channel_mask == i] = x[square_mask == i] - - return out, log_det - - def forward2(self, x, context=None): - batch_shape = get_batch_shape(x, self.event_shape) - log_det = torch.zeros(*batch_shape, device=x.device, dtype=x.dtype) - channels, height, width = x.shape[-3:] assert height % 2 == 0 assert width % 2 == 0 @@ -111,7 +84,7 @@ def forward2(self, x, context=None): x[..., ::2, 1::2], x[..., 1::2, ::2], x[..., 1::2, 1::2] - ], dim=1) + ], dim=-3) return out, log_det def inverse(self, z: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]: @@ -128,39 +101,11 @@ def inverse(self, z: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch. height = 2 * half_height channels = four_channels // 4 - n_rows = height // 2 - n_cols = width // 2 - n_squares = n_rows * n_cols - - square_mask = torch.kron( - torch.arange(n_squares).view(n_rows, n_cols), - torch.ones(2, 2) - ) - channel_mask = torch.arange(n_rows * n_cols).view(n_rows, n_cols)[None].repeat(4 * channels, 1, 1) - out = torch.empty(size=(*batch_shape, channels, height, width), device=z.device, dtype=z.dtype) - - channel_mask = channel_mask.repeat(*batch_shape, 1, 1, 1) - square_mask = square_mask.repeat(*batch_shape, channels, 1, 1) - for i in range(n_squares): - out[square_mask == i] = z[channel_mask == i] - - return out, log_det - - def inverse2(self, z, context=None): - batch_shape = get_batch_shape(z, self.transformed_event_shape) - log_det = torch.zeros(*batch_shape, device=z.device, dtype=z.dtype) - - four_channels, half_height, half_width = z.shape[-3:] - assert four_channels % 4 == 0 - width = 2 * half_width - height = 2 * half_height - channels = four_channels // 4 - out = torch.empty(size=(*batch_shape, channels, height, width), device=z.device, dtype=z.dtype) - out[..., ::2, ::2] = z[..., 0, :, :] - out[..., ::2, 1::2] = z[..., 1, :, :] - out[..., 1::2, ::2] = z[..., 2, :, :] - out[..., 1::2, 1::2] = z[..., 3, :, :] + out[..., ::2, ::2] = z[..., 0:channels, :, :] + out[..., ::2, 1::2] = z[..., channels:2 * channels, :, :] + out[..., 1::2, ::2] = z[..., 2 * channels:3 * channels, :, :] + out[..., 1::2, 1::2] = z[..., 3 * channels:4 * channels, :, :] return out, log_det From ea1d5c5a4b8aeb5008bfbcbc0c225b6fd7e6b2f1 Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Sun, 23 Jun 2024 19:30:42 +0200 Subject: [PATCH 32/50] Remove old tests --- test/test_squeeze_bijection.py | 28 ---------------------------- 1 file changed, 28 deletions(-) diff --git a/test/test_squeeze_bijection.py b/test/test_squeeze_bijection.py index 86a3e39..6952f34 100644 --- a/test/test_squeeze_bijection.py +++ b/test/test_squeeze_bijection.py @@ -19,31 +19,3 @@ def test_reconstruction(batch_shape, channels, height, width): assert torch.allclose(x_reconstructed, x) assert torch.allclose(log_det_forward, torch.zeros_like(log_det_forward)) assert torch.allclose(log_det_forward, log_det_inverse) - - -def test_efficient_forward(): - x = torch.tensor([ - [1, 2, 5, 6], - [3, 4, 7, 8], - [9, 10, 13, 14], - [11, 12, 15, 16] - ])[None, None] - layer = Squeeze(event_shape=x.shape[-3:]) - z, log_det_forward = layer.forward(x) - z2, log_det_forward2 = layer.forward2(x) - assert torch.allclose(z, z2) - - -def test_efficient_inverse(): - x = torch.tensor([ - [1, 2, 5, 6], - [3, 4, 7, 8], - [9, 10, 13, 14], - [11, 12, 15, 16] - ])[None, None] - layer = Squeeze(event_shape=x.shape[-3:]) - z, log_det_forward = layer.forward(x) - - xr, _ = layer.inverse2(z) - - assert torch.allclose(x, xr) From 72e67fd5a276f9419281619443b34ad25fe99a4a Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Mon, 24 Jun 2024 12:58:30 +0200 Subject: [PATCH 33/50] Generalize conditioner nonlinearities --- .../autoregressive/conditioning/transforms.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/normalizing_flows/bijections/finite/autoregressive/conditioning/transforms.py b/normalizing_flows/bijections/finite/autoregressive/conditioning/transforms.py index 197dcb6..be50ef5 100644 --- a/normalizing_flows/bijections/finite/autoregressive/conditioning/transforms.py +++ b/normalizing_flows/bijections/finite/autoregressive/conditioning/transforms.py @@ -1,5 +1,5 @@ import math -from typing import Tuple, Union +from typing import Tuple, Union, Type import torch import torch.nn as nn @@ -208,6 +208,7 @@ def __init__(self, context_shape: torch.Size = None, n_hidden: int = None, n_layers: int = 2, + nonlinearity: Type[nn.Module] = nn.Tanh, **kwargs): super().__init__( input_event_shape=input_event_shape, @@ -223,9 +224,9 @@ def __init__(self, if n_layers == 1: layers.append(nn.Linear(self.n_input_event_dims, self.n_predicted_parameters)) elif n_layers > 1: - layers.extend([nn.Linear(self.n_input_event_dims, n_hidden), nn.Tanh()]) + layers.extend([nn.Linear(self.n_input_event_dims, n_hidden), nonlinearity()]) for _ in range(n_layers - 2): - layers.extend([nn.Linear(n_hidden, n_hidden), nn.Tanh()]) + layers.extend([nn.Linear(n_hidden, n_hidden), nonlinearity()]) layers.append(nn.Linear(n_hidden, self.n_predicted_parameters)) else: raise ValueError @@ -243,15 +244,15 @@ def __init__(self, *args, **kwargs): class ResidualFeedForward(ConditionerTransform): class ResidualBlock(nn.Module): - def __init__(self, event_size: int, hidden_size: int, block_size: int): + def __init__(self, event_size: int, hidden_size: int, block_size: int, nonlinearity: Type[nn.Module]): super().__init__() if block_size < 2: raise ValueError(f"block_size must be at least 2 but found {block_size}. " f"For block_size = 1, use the FeedForward class instead.") layers = [] - layers.extend([nn.Linear(event_size, hidden_size), nn.ReLU()]) + layers.extend([nn.Linear(event_size, hidden_size), nonlinearity()]) for _ in range(block_size - 2): - layers.extend([nn.Linear(hidden_size, hidden_size), nn.ReLU()]) + layers.extend([nn.Linear(hidden_size, hidden_size), nonlinearity()]) layers.extend([nn.Linear(hidden_size, event_size)]) self.sequential = nn.Sequential(*layers) @@ -265,6 +266,7 @@ def __init__(self, n_hidden: int = None, n_layers: int = 3, block_size: int = 2, + nonlinearity: Type[nn.Module] = nn.ReLU, **kwargs): super().__init__( input_event_shape=input_event_shape, @@ -279,9 +281,9 @@ def __init__(self, if n_layers <= 2: raise ValueError(f"Number of layers in ResidualFeedForward must be at least 3, but found {n_layers}") - layers = [nn.Linear(self.n_input_event_dims, n_hidden), nn.ReLU()] + layers = [nn.Linear(self.n_input_event_dims, n_hidden), nonlinearity()] for _ in range(n_layers - 2): - layers.append(self.ResidualBlock(n_hidden, n_hidden, block_size)) + layers.append(self.ResidualBlock(n_hidden, n_hidden, block_size, nonlinearity=nonlinearity)) layers.append(nn.Linear(n_hidden, self.n_predicted_parameters)) layers.append(nn.Unflatten(dim=-1, unflattened_size=self.parameter_shape)) self.sequential = nn.Sequential(*layers) From 9c2620b253cb2df5ca25fdf370bceecf10f3594e Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Mon, 24 Jun 2024 12:58:48 +0200 Subject: [PATCH 34/50] Use maximum resolution in multiscale coupling --- .../bijections/finite/multiscale/base.py | 33 +++++++++++++++---- 1 file changed, 26 insertions(+), 7 deletions(-) diff --git a/normalizing_flows/bijections/finite/multiscale/base.py b/normalizing_flows/bijections/finite/multiscale/base.py index fad40f7..458975b 100644 --- a/normalizing_flows/bijections/finite/multiscale/base.py +++ b/normalizing_flows/bijections/finite/multiscale/base.py @@ -1,9 +1,10 @@ from typing import Type, Union, Tuple import torch +import torch.nn as nn from normalizing_flows.bijections import BijectiveComposition, CouplingBijection -from normalizing_flows.bijections.finite.autoregressive.conditioning.transforms import FeedForward +from normalizing_flows.bijections.finite.autoregressive.conditioning.transforms import FeedForward, ResidualFeedForward from normalizing_flows.bijections.base import Bijection from normalizing_flows.bijections.finite.autoregressive.transformers.base import TensorTransformer from normalizing_flows.bijections.finite.multiscale.coupling import make_image_coupling @@ -21,9 +22,10 @@ def __init__(self, coupling_type='checkerboard' if not alternate else 'checkerboard_inverted' ) transformer = transformer_class(event_shape=torch.Size((coupling.target_event_size,))) - conditioner_transform = FeedForward( + conditioner_transform = ResidualFeedForward( input_event_shape=torch.Size((coupling.source_event_size,)), parameter_shape=torch.Size(transformer.parameter_shape), + nonlinearity=nn.Tanh, **kwargs ) super().__init__(transformer, coupling, conditioner_transform, **kwargs) @@ -40,9 +42,10 @@ def __init__(self, coupling_type='channel_wise' if not alternate else 'channel_wise_inverted' ) transformer = transformer_class(event_shape=torch.Size((coupling.target_event_size,))) - conditioner_transform = FeedForward( + conditioner_transform = ResidualFeedForward( input_event_shape=torch.Size((coupling.source_event_size,)), parameter_shape=torch.Size(transformer.parameter_shape), + nonlinearity=nn.Tanh, **kwargs ) super().__init__(transformer, coupling, conditioner_transform, **kwargs) @@ -115,16 +118,32 @@ def __init__(self, transformer_class: Type[TensorTransformer], n_checkerboard_layers: int = 3, n_channel_wise_layers: int = 3, + use_squeeze_layer: bool = True, **kwargs): + channels, height, width = input_event_shape[-3:] + resolution = min(width, height) // 2 checkerboard_layers = [ - CheckerboardCoupling(input_event_shape, transformer_class, alternate=i % 2 == 1) + CheckerboardCoupling( + input_event_shape, + transformer_class, + alternate=i % 2 == 1, + resolution=resolution + ) for i in range(n_checkerboard_layers) ] squeeze_layer = Squeeze(input_event_shape) channel_wise_layers = [ - ChannelWiseCoupling(squeeze_layer.transformed_event_shape, transformer_class, alternate=i % 2 == 1) + ChannelWiseCoupling( + squeeze_layer.transformed_event_shape, + transformer_class, + alternate=i % 2 == 1, + resolution=resolution + ) for i in range(n_channel_wise_layers) ] - layers = [*checkerboard_layers, squeeze_layer, *channel_wise_layers] + if use_squeeze_layer: + layers = [*checkerboard_layers, squeeze_layer, *channel_wise_layers] + else: + layers = [*checkerboard_layers, *channel_wise_layers] super().__init__(input_event_shape, layers, **kwargs) - self.transformed_shape = squeeze_layer.transformed_event_shape + self.transformed_shape = squeeze_layer.transformed_event_shape if use_squeeze_layer else input_event_shape From 152135e8dafcea3d2827bb844182eb9900fd75cd Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Mon, 24 Jun 2024 12:59:39 +0200 Subject: [PATCH 35/50] Minor change to image-based bijections --- .../bijections/finite/multiscale/architectures.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/normalizing_flows/bijections/finite/multiscale/architectures.py b/normalizing_flows/bijections/finite/multiscale/architectures.py index 5ecff68..21f7d8f 100644 --- a/normalizing_flows/bijections/finite/multiscale/architectures.py +++ b/normalizing_flows/bijections/finite/multiscale/architectures.py @@ -13,14 +13,25 @@ def make_image_layers(event_shape, if len(event_shape) != 3: raise ValueError("Multichannel image transformation are only possible for inputs with three axes.") + assert n_layers >= 1 + bijections = [ElementwiseAffine(event_shape=event_shape)] - for _ in range(n_layers): + for _ in range(n_layers - 1): bijections.append( MultiscaleBijection( input_event_shape=bijections[-1].transformed_shape, transformer_class=transformer_class ) ) + bijections.append( + MultiscaleBijection( + input_event_shape=bijections[-1].transformed_shape, + transformer_class=transformer_class, + n_checkerboard_layers=4, + squeeze_layer=False, + n_channel_wise_layers=0 + ) + ) bijections.append(ElementwiseAffine(event_shape=bijections[-1].transformed_shape)) return bijections @@ -35,4 +46,3 @@ def __init__(self, bijections = make_image_layers(event_shape, Affine, n_layers) super().__init__(event_shape, bijections, **kwargs) self.transformed_shape = bijections[-1].transformed_shape - From cbaa9d73da0ef1e6227a6b21c88eb6901f16ec7e Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Mon, 24 Jun 2024 15:15:56 +0200 Subject: [PATCH 36/50] Refactor coupling bijections for CNN conditioner transformers --- .../finite/autoregressive/layers_base.py | 21 ++++-- .../bijections/finite/multiscale/base.py | 69 ++++++++++++++----- .../bijections/finite/multiscale/coupling.py | 26 ++++++- test/test_channel_wise_coupling.py | 19 +++++ test/test_checkerboard_coupling.py | 25 ++++++- 5 files changed, 132 insertions(+), 28 deletions(-) create mode 100644 test/test_channel_wise_coupling.py diff --git a/normalizing_flows/bijections/finite/autoregressive/layers_base.py b/normalizing_flows/bijections/finite/autoregressive/layers_base.py index 5187c3d..2bb9c77 100644 --- a/normalizing_flows/bijections/finite/autoregressive/layers_base.py +++ b/normalizing_flows/bijections/finite/autoregressive/layers_base.py @@ -64,6 +64,15 @@ def __init__(self, assert conditioner_transform.input_event_shape == (coupling.source_event_size,) assert transformer.event_shape == (self.coupling.target_event_size,) + def get_constant_part(self, x: torch.Tensor) -> torch.Tensor: + return x[..., self.coupling.source_mask] + + def get_transformed_part(self, x: torch.Tensor) -> torch.Tensor: + return x[..., self.coupling.target_mask] + + def set_transformed_part(self, x: torch.Tensor, x_transformed: torch.Tensor): + x[..., self.coupling.target_mask] = x_transformed + def partition_and_predict_parameters(self, x: torch.Tensor, context: torch.Tensor): """ Partition tensor x and compute transformer parameters. @@ -73,26 +82,28 @@ def partition_and_predict_parameters(self, x: torch.Tensor, context: torch.Tenso :return: parameter tensor h with h.shape = (*batch_shape, *parameter_shape). """ # Predict transformer parameters for output dimensions - x_a = x[..., self.coupling.source_mask] # (*b, constant_event_size) + x_a = self.get_constant_part(x) # (*b, constant_event_size) h_b = self.conditioner_transform(x_a, context=context) # (*b, *p) return h_b def forward(self, x: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]: z = x.clone() h_b = self.partition_and_predict_parameters(x, context) - z[..., self.coupling.target_mask], log_det = self.transformer.forward( - x[..., self.coupling.target_mask], + z_transformed, log_det = self.transformer.forward( + self.get_transformed_part(x), h_b ) + self.set_transformed_part(z, z_transformed) return z, log_det def inverse(self, z: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]: x = z.clone() h_b = self.partition_and_predict_parameters(x, context) - x[..., self.coupling.target_mask], log_det = self.transformer.inverse( - z[..., self.coupling.target_mask], + x_transformed, log_det = self.transformer.inverse( + self.get_transformed_part(z), h_b ) + self.set_transformed_part(x, x_transformed) return x, log_det diff --git a/normalizing_flows/bijections/finite/multiscale/base.py b/normalizing_flows/bijections/finite/multiscale/base.py index 458975b..a48de57 100644 --- a/normalizing_flows/bijections/finite/multiscale/base.py +++ b/normalizing_flows/bijections/finite/multiscale/base.py @@ -3,15 +3,58 @@ import torch import torch.nn as nn -from normalizing_flows.bijections import BijectiveComposition, CouplingBijection -from normalizing_flows.bijections.finite.autoregressive.conditioning.transforms import FeedForward, ResidualFeedForward +from normalizing_flows.bijections import BijectiveComposition +from normalizing_flows.bijections.finite.autoregressive.conditioning.transforms import ConditionerTransform from normalizing_flows.bijections.base import Bijection +from normalizing_flows.bijections.finite.autoregressive.layers_base import CouplingBijection from normalizing_flows.bijections.finite.autoregressive.transformers.base import TensorTransformer -from normalizing_flows.bijections.finite.multiscale.coupling import make_image_coupling +from normalizing_flows.bijections.finite.multiscale.coupling import make_image_coupling, Checkerboard, \ + ChannelWiseHalfSplit from normalizing_flows.utils import get_batch_shape -class CheckerboardCoupling(CouplingBijection): +class ResNet(ConditionerTransform): + pass + + +class ConvolutionalCouplingBijection(CouplingBijection): + def __init__(self, + transformer: TensorTransformer, + coupling: Union[Checkerboard, ChannelWiseHalfSplit], + **kwargs): + conditioner_transform = ResNet() + super().__init__(coupling.event_shape, transformer, conditioner_transform, **kwargs) + self.coupling = coupling + + def get_constant_part(self, x: torch.Tensor) -> torch.Tensor: + """ + + :param x: tensor with shape (*b, channels, height, width). + :return: tensor with shape (*b, constant_channels, constant_height, constant_width). + """ + batch_shape = get_batch_shape(x, self.event_shape) + return x[..., self.coupling.source_mask].view(*batch_shape, *self.coupling.constant_shape) + + def get_transformed_part(self, x: torch.Tensor) -> torch.Tensor: + """ + + :param x: tensor with shape (*b, channels, height, width). + :return: tensor with shape (*b, transformed_channels, transformed_height, constant_width). + """ + batch_shape = get_batch_shape(x, self.event_shape) + return x[..., self.coupling.target_mask].view(*batch_shape, *self.coupling.transformed_shape) + + def set_transformed_part(self, x: torch.Tensor, x_transformed: torch.Tensor): + """ + + :param x: tensor with shape (*b, channels, height, width). + :param x_transformed: tensor with shape (*b, transformed_channels, transformed_height, transformed_width). + """ + batch_shape = get_batch_shape(x, self.event_shape) + return x[..., self.coupling.target_mask].view(*batch_shape, *self.coupling.transformed_shape) + + +class CheckerboardCoupling(ConvolutionalCouplingBijection): def __init__(self, event_shape, transformer_class: Type[TensorTransformer], @@ -22,16 +65,10 @@ def __init__(self, coupling_type='checkerboard' if not alternate else 'checkerboard_inverted' ) transformer = transformer_class(event_shape=torch.Size((coupling.target_event_size,))) - conditioner_transform = ResidualFeedForward( - input_event_shape=torch.Size((coupling.source_event_size,)), - parameter_shape=torch.Size(transformer.parameter_shape), - nonlinearity=nn.Tanh, - **kwargs - ) - super().__init__(transformer, coupling, conditioner_transform, **kwargs) + super().__init__(transformer, coupling, **kwargs) -class ChannelWiseCoupling(CouplingBijection): +class ChannelWiseCoupling(ConvolutionalCouplingBijection): def __init__(self, event_shape, transformer_class: Type[TensorTransformer], @@ -42,13 +79,7 @@ def __init__(self, coupling_type='channel_wise' if not alternate else 'channel_wise_inverted' ) transformer = transformer_class(event_shape=torch.Size((coupling.target_event_size,))) - conditioner_transform = ResidualFeedForward( - input_event_shape=torch.Size((coupling.source_event_size,)), - parameter_shape=torch.Size(transformer.parameter_shape), - nonlinearity=nn.Tanh, - **kwargs - ) - super().__init__(transformer, coupling, conditioner_transform, **kwargs) + super().__init__(transformer, coupling, **kwargs) class Squeeze(Bijection): diff --git a/normalizing_flows/bijections/finite/multiscale/coupling.py b/normalizing_flows/bijections/finite/multiscale/coupling.py index c20a323..6470086 100644 --- a/normalizing_flows/bijections/finite/multiscale/coupling.py +++ b/normalizing_flows/bijections/finite/multiscale/coupling.py @@ -16,7 +16,7 @@ def __init__(self, event_shape, resolution: int = 2, invert: bool = False): and smaller than image width. :param invert: invert the checkerboard mask. """ - channels, height, width = event_shape[-3:] + channels, height, width = event_shape assert width % resolution == 0 square_side_length = width // resolution assert resolution % 2 == 0 @@ -24,11 +24,21 @@ def __init__(self, event_shape, resolution: int = 2, invert: bool = False): a = torch.tensor([[1, 0] * half_resolution, [0, 1] * half_resolution] * half_resolution) mask = torch.kron(a, torch.ones((square_side_length, square_side_length))) mask = mask.bool() - mask = mask[None].repeat(channels, 1, 1) + mask = mask[None].repeat(channels, 1, 1) # (channels, height, width) if invert: mask = ~mask + self.resolution = resolution super().__init__(event_shape, mask) + @property + def constant_shape(self): + n_channels, _, _ = self.event_shape + return n_channels, self.resolution, self.resolution + + @property + def transformed_shape(self): + return self.constant_shape + class ChannelWiseHalfSplit(Coupling): """ @@ -43,11 +53,21 @@ def __init__(self, event_shape, invert: bool = False): """ n_channels, height, width = event_shape mask = torch.as_tensor(torch.arange(start=0, end=n_channels) < (n_channels // 2)) - mask = mask[:, None, None].repeat(1, height, width) + mask = mask[:, None, None].repeat(1, height, width) # (channels, height, width) if invert: mask = ~mask super().__init__(event_shape, mask) + @property + def constant_shape(self): + n_channels, height, width = self.event_shape + return n_channels // 2, height, width + + @property + def transformed_shape(self): + n_channels, height, width = self.event_shape + return n_channels - n_channels // 2, height, width + def make_image_coupling(event_shape, coupling_type: str, **kwargs): """ diff --git a/test/test_channel_wise_coupling.py b/test/test_channel_wise_coupling.py new file mode 100644 index 0000000..48bda61 --- /dev/null +++ b/test/test_channel_wise_coupling.py @@ -0,0 +1,19 @@ +import torch + +from normalizing_flows.bijections.finite.multiscale.coupling import ChannelWiseHalfSplit + + +def test_partition_shapes_1(): + torch.manual_seed(0) + image_shape = (3, 4, 4) + coupling = ChannelWiseHalfSplit(image_shape, invert=True) + assert coupling.constant_shape == (1, 4, 4) + assert coupling.transformed_shape == (2, 4, 4) + + +def test_partition_shapes_2(): + torch.manual_seed(0) + image_shape = (3, 16, 16) + coupling = ChannelWiseHalfSplit(image_shape, invert=True) + assert coupling.constant_shape == (1, 16, 16) + assert coupling.transformed_shape == (2, 16, 16) diff --git a/test/test_checkerboard_coupling.py b/test/test_checkerboard_coupling.py index 2f120c0..00b1dd4 100644 --- a/test/test_checkerboard_coupling.py +++ b/test/test_checkerboard_coupling.py @@ -1,4 +1,3 @@ - import torch from normalizing_flows.bijections.finite.multiscale.coupling import Checkerboard @@ -62,3 +61,27 @@ def test_checkerboard_small_inverted(): ], dtype=torch.bool)[None].repeat(3, 1, 1) ) assert torch.allclose(coupling.target_mask, ~coupling.source_mask) + + +def test_partition_shapes_1(): + torch.manual_seed(0) + image_shape = (3, 4, 4) + coupling = Checkerboard(image_shape, resolution=2, invert=True) + assert coupling.constant_shape == (3, 2, 2) + assert coupling.transformed_shape == (3, 2, 2) + + +def test_partition_shapes_2(): + torch.manual_seed(0) + image_shape = (3, 16, 16) + coupling = Checkerboard(image_shape, resolution=8, invert=True) + assert coupling.constant_shape == (3, 8, 8) + assert coupling.transformed_shape == (3, 8, 8) + + +def test_partition_shapes_3(): + torch.manual_seed(0) + image_shape = (3, 16, 8) + coupling = Checkerboard(image_shape, resolution=4, invert=True) + assert coupling.constant_shape == (3, 4, 4) + assert coupling.transformed_shape == (3, 4, 4) From 7a028bf25ee7778a9eabc8c4cea3121eb8996431 Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Mon, 24 Jun 2024 15:50:19 +0200 Subject: [PATCH 37/50] Add resnets --- normalizing_flows/neural_networks/__init__.py | 0 normalizing_flows/resnet.py | 145 ++++++++++++++++++ 2 files changed, 145 insertions(+) create mode 100644 normalizing_flows/neural_networks/__init__.py create mode 100644 normalizing_flows/resnet.py diff --git a/normalizing_flows/neural_networks/__init__.py b/normalizing_flows/neural_networks/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/normalizing_flows/resnet.py b/normalizing_flows/resnet.py new file mode 100644 index 0000000..0823cd3 --- /dev/null +++ b/normalizing_flows/resnet.py @@ -0,0 +1,145 @@ +# https://github.com/kuangliu/pytorch-cifar/blob/master/models/resnet.py + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, in_planes, planes, stride=1): + super(BasicBlock, self).__init__() + self.conv1 = nn.Conv2d( + in_planes, + planes, + kernel_size=3, + stride=stride, + padding=1, + bias=False + ) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d( + planes, + planes, + kernel_size=3, + stride=1, + padding=1, + bias=False + ) + self.bn2 = nn.BatchNorm2d(planes) + + self.shortcut = nn.Sequential() + if stride != 1 or in_planes != self.expansion * planes: + self.shortcut = nn.Sequential( + nn.Conv2d(in_planes, self.expansion * planes, + kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(self.expansion * planes) + ) + + def forward(self, x): + out = F.relu(self.bn1(self.conv1(x))) + out = self.bn2(self.conv2(out)) + out += self.shortcut(x) + out = F.relu(out) + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, in_planes, planes, stride=1): + super(Bottleneck, self).__init__() + self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, + stride=stride, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.conv3 = nn.Conv2d(planes, self.expansion * + planes, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(self.expansion * planes) + + self.shortcut = nn.Sequential() + if stride != 1 or in_planes != self.expansion * planes: + self.shortcut = nn.Sequential( + nn.Conv2d(in_planes, self.expansion * planes, + kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(self.expansion * planes) + ) + + def forward(self, x): + out = F.relu(self.bn1(self.conv1(x))) + out = F.relu(self.bn2(self.conv2(out))) + out = self.bn3(self.conv3(out)) + out += self.shortcut(x) + out = F.relu(out) + return out + + +class ResNet(nn.Module): + def __init__(self, image_width, block, num_blocks, n_hidden=100, n_outputs=10): + super(ResNet, self).__init__() + self.in_planes = 64 + + self.conv1 = nn.Conv2d(3, 64, kernel_size=3, + stride=1, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) + self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) + self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) + self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) + self.linear1 = nn.Linear(512 * block.expansion * (image_width // 32) ** 2, n_hidden) + self.linear2 = nn.Linear(n_hidden, n_outputs) + + def _make_layer(self, block, planes, num_blocks, stride): + strides = [stride] + [1] * (num_blocks - 1) + layers = [] + for stride in strides: + layers.append(block(self.in_planes, planes, stride)) + self.in_planes = planes * block.expansion + return nn.Sequential(*layers) + + def forward(self, x): + """ + :param x: tensor with shape (*b, channels, height, width). Height and width must be equal. + :return: + """ + out = F.relu(self.bn1(self.conv1(x))) + out = self.layer1(out) + out = self.layer2(out) + out = self.layer3(out) + out = self.layer4(out) + out = F.avg_pool2d(out, 4) + out = out.flatten(start_dim=1, end_dim=-1) + out = self.linear1(out) + out = self.linear2(out) + return out + + +def make_resnet18(image_width, n_outputs): + return ResNet(image_width, BasicBlock, [2, 2, 2, 2], n_outputs=n_outputs) + + +def make_resnet34(image_width, n_outputs): + return ResNet(image_width, BasicBlock, [3, 4, 6, 3], n_outputs=n_outputs) + + +def make_resnet50(image_width, n_outputs): + return ResNet(image_width, Bottleneck, [3, 4, 6, 3], n_outputs=n_outputs) + + +def make_resnet101(image_width, n_outputs): + return ResNet(image_width, Bottleneck, [3, 4, 23, 3], n_outputs=n_outputs) + + +def make_resnet152(image_width, n_outputs): + return ResNet(image_width, Bottleneck, [3, 8, 36, 3], n_outputs=n_outputs) + + +if __name__ == '__main__': + n_images = 2 + image_shape = (3, 270, 270) + + net = make_resnet18(image_width=image_shape[-1], n_outputs=15) + y = net(torch.randn(n_images, *image_shape)) + print(y.size()) From 43fe31576ba5c441a34d9fa2c94c97bf20d7f39f Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Tue, 25 Jun 2024 11:45:26 +0200 Subject: [PATCH 38/50] Add convnet conditioner --- normalizing_flows/neural_networks/convnet.py | 76 ++++++++++++++++++++ 1 file changed, 76 insertions(+) create mode 100644 normalizing_flows/neural_networks/convnet.py diff --git a/normalizing_flows/neural_networks/convnet.py b/normalizing_flows/neural_networks/convnet.py new file mode 100644 index 0000000..484cb55 --- /dev/null +++ b/normalizing_flows/neural_networks/convnet.py @@ -0,0 +1,76 @@ +from typing import Tuple + +import torch +import torch.nn as nn + + +class ConvNet(nn.Module): + class ConvNetBlock(nn.Module): + def __init__(self, in_channels, out_channels, input_height, input_width, use_pooling: bool = True): + super().__init__() + self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1) + self.bn = nn.BatchNorm2d(out_channels) + self.pool = nn.MaxPool2d(2) if use_pooling else nn.Identity() + + if use_pooling: + self.output_shape = (out_channels, input_height // 2, input_width // 2) + else: + self.output_shape = (out_channels, input_height, input_width) + + def forward(self, x): + return self.bn(self.pool(torch.relu(self.conv(x)))) + + def __init__(self, input_shape, n_outputs: int, kernels: Tuple[int, ...] = None): + """ + + :param input_shape: (channels, height, width) + :param n_outputs: + """ + super().__init__() + channels, height, width = input_shape + + if kernels is None: + kernels = (4, 8, 16, 24) + else: + assert len(kernels) >= 1 + + blocks = [ + self.ConvNetBlock( + in_channels=channels, + out_channels=kernels[0], + input_height=height, + input_width=width, + use_pooling=min(height, width) >= 2 + ) + ] + for i in range(len(kernels) - 1): + blocks.append( + self.ConvNetBlock( + in_channels=kernels[i], + out_channels=kernels[i + 1], + input_height=blocks[i].output_shape[1], + input_width=blocks[i].output_shape[2], + use_pooling=min(blocks[i].output_shape[1], blocks[i].output_shape[2]) >= 2 + ) + ) + self.blocks = nn.ModuleList(blocks) + self.linear = nn.Linear( + in_features=int(torch.prod(torch.as_tensor(self.blocks[-1].output_shape))), + out_features=n_outputs + ) + + def forward(self, x): + for block in self.blocks: + x = block(x) + x = x.flatten(start_dim=1, end_dim=-1) + x = self.linear(x) + return x + + +if __name__ == '__main__': + torch.manual_seed(0) + image_shape = (1, 36, 29) + images = torch.randn(size=(11, *image_shape)) + net = ConvNet(input_shape=image_shape, n_outputs=77) + out = net(images) + print(f'{out.shape = }') From 480c2ff2b225b7a9dbfdfd7db74ffad867145c8c Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Tue, 25 Jun 2024 11:45:57 +0200 Subject: [PATCH 39/50] Remove shape assertion in coupling bijection --- .../bijections/finite/autoregressive/layers_base.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/normalizing_flows/bijections/finite/autoregressive/layers_base.py b/normalizing_flows/bijections/finite/autoregressive/layers_base.py index 2bb9c77..fe20d30 100644 --- a/normalizing_flows/bijections/finite/autoregressive/layers_base.py +++ b/normalizing_flows/bijections/finite/autoregressive/layers_base.py @@ -61,9 +61,6 @@ def __init__(self, super().__init__(coupling.event_shape, transformer, conditioner_transform, **kwargs) self.coupling = coupling - assert conditioner_transform.input_event_shape == (coupling.source_event_size,) - assert transformer.event_shape == (self.coupling.target_event_size,) - def get_constant_part(self, x: torch.Tensor) -> torch.Tensor: return x[..., self.coupling.source_mask] From 58459cfa6ebe47de684a1a0a2d3381bed4389bbc Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Tue, 25 Jun 2024 11:46:19 +0200 Subject: [PATCH 40/50] Change image shape in resnet test case --- normalizing_flows/{ => neural_networks}/resnet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) rename normalizing_flows/{ => neural_networks}/resnet.py (99%) diff --git a/normalizing_flows/resnet.py b/normalizing_flows/neural_networks/resnet.py similarity index 99% rename from normalizing_flows/resnet.py rename to normalizing_flows/neural_networks/resnet.py index 0823cd3..9b6c8a9 100644 --- a/normalizing_flows/resnet.py +++ b/normalizing_flows/neural_networks/resnet.py @@ -138,7 +138,7 @@ def make_resnet152(image_width, n_outputs): if __name__ == '__main__': n_images = 2 - image_shape = (3, 270, 270) + image_shape = (3, 32, 32) net = make_resnet18(image_width=image_shape[-1], n_outputs=15) y = net(torch.randn(n_images, *image_shape)) From 96122ba9f8d853d2342e849df3a3a54e252785fa Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Tue, 25 Jun 2024 11:46:51 +0200 Subject: [PATCH 41/50] Rework image coupling layers, remove checkerboard coupling test --- .../bijections/finite/multiscale/base.py | 50 ++++++++--- .../bijections/finite/multiscale/coupling.py | 17 +--- test/test_checkerboard_coupling.py | 87 ------------------- 3 files changed, 41 insertions(+), 113 deletions(-) delete mode 100644 test/test_checkerboard_coupling.py diff --git a/normalizing_flows/bijections/finite/multiscale/base.py b/normalizing_flows/bijections/finite/multiscale/base.py index a48de57..aa7b844 100644 --- a/normalizing_flows/bijections/finite/multiscale/base.py +++ b/normalizing_flows/bijections/finite/multiscale/base.py @@ -1,7 +1,6 @@ from typing import Type, Union, Tuple import torch -import torch.nn as nn from normalizing_flows.bijections import BijectiveComposition from normalizing_flows.bijections.finite.autoregressive.conditioning.transforms import ConditionerTransform @@ -10,11 +9,30 @@ from normalizing_flows.bijections.finite.autoregressive.transformers.base import TensorTransformer from normalizing_flows.bijections.finite.multiscale.coupling import make_image_coupling, Checkerboard, \ ChannelWiseHalfSplit +from normalizing_flows.neural_networks.convnet import ConvNet from normalizing_flows.utils import get_batch_shape -class ResNet(ConditionerTransform): - pass +class ConvNetConditioner(ConditionerTransform): + def __init__(self, + input_event_shape: torch.Size, + parameter_shape: torch.Size, + kernels: Tuple[int, ...] = None, + **kwargs): + super().__init__( + input_event_shape=input_event_shape, + context_shape=None, + parameter_shape=parameter_shape, + **kwargs + ) + self.network = ConvNet( + input_shape=input_event_shape, + n_outputs=self.n_transformer_parameters, + kernels=kernels, + ) + + def predict_theta_flat(self, x: torch.Tensor, context: torch.Tensor = None) -> torch.Tensor: + return self.network(x) class ConvolutionalCouplingBijection(CouplingBijection): @@ -22,8 +40,12 @@ def __init__(self, transformer: TensorTransformer, coupling: Union[Checkerboard, ChannelWiseHalfSplit], **kwargs): - conditioner_transform = ResNet() - super().__init__(coupling.event_shape, transformer, conditioner_transform, **kwargs) + conditioner_transform = ConvNetConditioner( + input_event_shape=coupling.constant_shape, + parameter_shape=transformer.parameter_shape, + **kwargs + ) + super().__init__(transformer, coupling, conditioner_transform, **kwargs) self.coupling = coupling def get_constant_part(self, x: torch.Tensor) -> torch.Tensor: @@ -53,6 +75,12 @@ def set_transformed_part(self, x: torch.Tensor, x_transformed: torch.Tensor): batch_shape = get_batch_shape(x, self.event_shape) return x[..., self.coupling.target_mask].view(*batch_shape, *self.coupling.transformed_shape) + def partition_and_predict_parameters(self, x: torch.Tensor, context: torch.Tensor): + batch_shape = get_batch_shape(x, self.event_shape) + super_out = super().partition_and_predict_parameters(x, context) + return super_out.view(*batch_shape, *self.coupling.transformed_shape, + *self.transformer.parameter_shape_per_element) + class CheckerboardCoupling(ConvolutionalCouplingBijection): def __init__(self, @@ -64,7 +92,7 @@ def __init__(self, event_shape, coupling_type='checkerboard' if not alternate else 'checkerboard_inverted' ) - transformer = transformer_class(event_shape=torch.Size((coupling.target_event_size,))) + transformer = transformer_class(event_shape=coupling.transformed_shape) super().__init__(transformer, coupling, **kwargs) @@ -78,7 +106,7 @@ def __init__(self, event_shape, coupling_type='channel_wise' if not alternate else 'channel_wise_inverted' ) - transformer = transformer_class(event_shape=torch.Size((coupling.target_event_size,))) + transformer = transformer_class(event_shape=coupling.transformed_shape) super().__init__(transformer, coupling, **kwargs) @@ -151,14 +179,11 @@ def __init__(self, n_channel_wise_layers: int = 3, use_squeeze_layer: bool = True, **kwargs): - channels, height, width = input_event_shape[-3:] - resolution = min(width, height) // 2 checkerboard_layers = [ CheckerboardCoupling( input_event_shape, transformer_class, - alternate=i % 2 == 1, - resolution=resolution + alternate=i % 2 == 1 ) for i in range(n_checkerboard_layers) ] @@ -167,8 +192,7 @@ def __init__(self, ChannelWiseCoupling( squeeze_layer.transformed_event_shape, transformer_class, - alternate=i % 2 == 1, - resolution=resolution + alternate=i % 2 == 1 ) for i in range(n_channel_wise_layers) ] diff --git a/normalizing_flows/bijections/finite/multiscale/coupling.py b/normalizing_flows/bijections/finite/multiscale/coupling.py index 6470086..33bf6cd 100644 --- a/normalizing_flows/bijections/finite/multiscale/coupling.py +++ b/normalizing_flows/bijections/finite/multiscale/coupling.py @@ -8,32 +8,23 @@ class Checkerboard(Coupling): Checkerboard coupling for image data. """ - def __init__(self, event_shape, resolution: int = 2, invert: bool = False): + def __init__(self, event_shape, invert: bool = False): """ :param event_shape: image shape with the form (n_channels, height, width). Note: width and height must be equal and a power of two. - :param resolution: resolution of the checkerboard along one axis - the number of squares. Must be a power of two - and smaller than image width. :param invert: invert the checkerboard mask. """ channels, height, width = event_shape - assert width % resolution == 0 - square_side_length = width // resolution - assert resolution % 2 == 0 - half_resolution = resolution // 2 - a = torch.tensor([[1, 0] * half_resolution, [0, 1] * half_resolution] * half_resolution) - mask = torch.kron(a, torch.ones((square_side_length, square_side_length))) - mask = mask.bool() + mask = (torch.arange(height * width) % 2).view(height, width).bool() mask = mask[None].repeat(channels, 1, 1) # (channels, height, width) if invert: mask = ~mask - self.resolution = resolution super().__init__(event_shape, mask) @property def constant_shape(self): - n_channels, _, _ = self.event_shape - return n_channels, self.resolution, self.resolution + n_channels, height, width = self.event_shape + return n_channels, height // 2, width # rectangular shape @property def transformed_shape(self): diff --git a/test/test_checkerboard_coupling.py b/test/test_checkerboard_coupling.py deleted file mode 100644 index 00b1dd4..0000000 --- a/test/test_checkerboard_coupling.py +++ /dev/null @@ -1,87 +0,0 @@ -import torch - -from normalizing_flows.bijections.finite.multiscale.coupling import Checkerboard - - -def test_checkerboard_small(): - torch.manual_seed(0) - image_shape = (3, 4, 4) - coupling = Checkerboard(image_shape, resolution=2) - assert torch.allclose( - coupling.source_mask, - torch.tensor([ - [1, 1, 0, 0], - [1, 1, 0, 0], - [0, 0, 1, 1], - [0, 0, 1, 1], - ], dtype=torch.bool)[None].repeat(3, 1, 1) - ) - assert torch.allclose(coupling.target_mask, ~coupling.source_mask) - - -def test_checkerboard_medium(): - torch.manual_seed(0) - image_shape = (3, 16, 16) - coupling = Checkerboard(image_shape, resolution=4) - assert torch.allclose( - coupling.source_mask, - torch.tensor([ - [1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0], - [1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0], - [1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0], - [1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0], - [0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1], - [0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1], - [0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1], - [0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1], - [1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0], - [1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0], - [1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0], - [1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0], - [0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1], - [0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1], - [0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1], - [0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1], - ], dtype=torch.bool)[None].repeat(3, 1, 1) - ) - assert torch.allclose(coupling.target_mask, ~coupling.source_mask) - - -def test_checkerboard_small_inverted(): - torch.manual_seed(0) - image_shape = (3, 4, 4) - coupling = Checkerboard(image_shape, resolution=2, invert=True) - assert torch.allclose( - coupling.source_mask, - ~torch.tensor([ - [1, 1, 0, 0], - [1, 1, 0, 0], - [0, 0, 1, 1], - [0, 0, 1, 1], - ], dtype=torch.bool)[None].repeat(3, 1, 1) - ) - assert torch.allclose(coupling.target_mask, ~coupling.source_mask) - - -def test_partition_shapes_1(): - torch.manual_seed(0) - image_shape = (3, 4, 4) - coupling = Checkerboard(image_shape, resolution=2, invert=True) - assert coupling.constant_shape == (3, 2, 2) - assert coupling.transformed_shape == (3, 2, 2) - - -def test_partition_shapes_2(): - torch.manual_seed(0) - image_shape = (3, 16, 16) - coupling = Checkerboard(image_shape, resolution=8, invert=True) - assert coupling.constant_shape == (3, 8, 8) - assert coupling.transformed_shape == (3, 8, 8) - - -def test_partition_shapes_3(): - torch.manual_seed(0) - image_shape = (3, 16, 8) - coupling = Checkerboard(image_shape, resolution=4, invert=True) - assert coupling.constant_shape == (3, 4, 4) - assert coupling.transformed_shape == (3, 4, 4) From d9aba8cafc02f0e39a33ba3c6037ea4896a1a731 Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Thu, 27 Jun 2024 07:58:25 +0200 Subject: [PATCH 42/50] Add more multiscale architectures --- normalizing_flows/architectures.py | 7 +++- .../finite/multiscale/architectures.py | 40 ++++++++++++++++++- 2 files changed, 45 insertions(+), 2 deletions(-) diff --git a/normalizing_flows/architectures.py b/normalizing_flows/architectures.py index 90c9561..b441edf 100644 --- a/normalizing_flows/architectures.py +++ b/normalizing_flows/architectures.py @@ -26,4 +26,9 @@ Sylvester ) -from normalizing_flows.bijections.finite.multiscale.architectures import MultiscaleRealNVP +from normalizing_flows.bijections.finite.multiscale.architectures import ( + MultiscaleRealNVP, + MultiscaleRQNSF, + MultiscaleLRSNSF, + MultiscaleNICE +) diff --git a/normalizing_flows/bijections/finite/multiscale/architectures.py b/normalizing_flows/bijections/finite/multiscale/architectures.py index 21f7d8f..8f5a49e 100644 --- a/normalizing_flows/bijections/finite/multiscale/architectures.py +++ b/normalizing_flows/bijections/finite/multiscale/architectures.py @@ -1,5 +1,7 @@ from normalizing_flows.bijections.finite.autoregressive.layers import ElementwiseAffine -from normalizing_flows.bijections.finite.autoregressive.transformers.linear.affine import Affine +from normalizing_flows.bijections.finite.autoregressive.transformers.linear.affine import Affine, Shift +from normalizing_flows.bijections.finite.autoregressive.transformers.spline.rational_quadratic import RationalQuadratic +from normalizing_flows.bijections.finite.autoregressive.transformers.spline.linear import Linear as LinearRational from normalizing_flows.bijections import BijectiveComposition from normalizing_flows.bijections.finite.multiscale.base import MultiscaleBijection @@ -46,3 +48,39 @@ def __init__(self, bijections = make_image_layers(event_shape, Affine, n_layers) super().__init__(event_shape, bijections, **kwargs) self.transformed_shape = bijections[-1].transformed_shape + + +class MultiscaleNICE(BijectiveComposition): + def __init__(self, + event_shape, + n_layers: int = 3, + **kwargs): + if isinstance(event_shape, int): + event_shape = (event_shape,) + bijections = make_image_layers(event_shape, Shift, n_layers) + super().__init__(event_shape, bijections, **kwargs) + self.transformed_shape = bijections[-1].transformed_shape + + +class MultiscaleRQNSF(BijectiveComposition): + def __init__(self, + event_shape, + n_layers: int = 3, + **kwargs): + if isinstance(event_shape, int): + event_shape = (event_shape,) + bijections = make_image_layers(event_shape, RationalQuadratic, n_layers) + super().__init__(event_shape, bijections, **kwargs) + self.transformed_shape = bijections[-1].transformed_shape + + +class MultiscaleLRSNSF(BijectiveComposition): + def __init__(self, + event_shape, + n_layers: int = 3, + **kwargs): + if isinstance(event_shape, int): + event_shape = (event_shape,) + bijections = make_image_layers(event_shape, LinearRational, n_layers) + super().__init__(event_shape, bijections, **kwargs) + self.transformed_shape = bijections[-1].transformed_shape From ecb9140bad8b562b33b2643b594325092285a904 Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Thu, 4 Jul 2024 15:34:48 +0200 Subject: [PATCH 43/50] Add factored bijection --- .../bijections/finite/multiscale/base.py | 59 +++++++++++++++++++ test/test_factored_bijection.py | 42 +++++++++++++ 2 files changed, 101 insertions(+) create mode 100644 test/test_factored_bijection.py diff --git a/normalizing_flows/bijections/finite/multiscale/base.py b/normalizing_flows/bijections/finite/multiscale/base.py index aa7b844..a3a8d21 100644 --- a/normalizing_flows/bijections/finite/multiscale/base.py +++ b/normalizing_flows/bijections/finite/multiscale/base.py @@ -13,6 +13,65 @@ from normalizing_flows.utils import get_batch_shape +class FactoredBijection(Bijection): + """ + Factored bijection class. + + Partitions the input tensor x into parts x_A and x_B, then applies a bijection to x_A independently of x_B while + keeping x_B identical. + """ + + def __init__(self, + event_shape: Union[torch.Size, Tuple[int, ...]], + transformed_event_shape: Union[torch.Size, Tuple[int, ...]], + small_bijection: Bijection, + transformed_event_mask: torch.Tensor, + **kwargs): + """ + + :param event_shape: shape of input event x. + :param transformed_event_shape: shape of transformed event x_A. + :param constant_event_shape: shape of constant event x_B. + :param small_bijection: bijection applied to transformed event x_A. + :param transformed_event_mask: boolean mask that selects which elements of event x correspond to the transformed + event x_A. + :param kwargs: + """ + super().__init__(event_shape, **kwargs) + + # Check that shapes are correct + event_size = torch.prod(torch.as_tensor(event_shape)) + transformed_event_size = torch.prod(torch.as_tensor(transformed_event_shape)) + assert event_size >= transformed_event_size + + assert transformed_event_mask.shape == event_shape + assert small_bijection.event_shape == transformed_event_shape + + self.transformed_event_mask = transformed_event_mask + self.transformed_event_shape = transformed_event_shape + self.small_bijection = small_bijection + + def forward(self, x: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]: + batch_shape = get_batch_shape(x, self.event_shape) + transformed, log_det = self.small_bijection.forward( + x[..., self.transformed_event_mask].view(*batch_shape, *self.transformed_event_shape), + context + ) + out = x.clone() + out[..., self.transformed_event_mask] = transformed.view(*batch_shape, -1) + return out, log_det + + def inverse(self, z: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]: + batch_shape = get_batch_shape(z, self.event_shape) + transformed, log_det = self.small_bijection.inverse( + z[..., self.transformed_event_mask].view(*batch_shape, *self.transformed_event_shape), + context + ) + out = z.clone() + out[..., self.transformed_event_mask] = transformed.view(*batch_shape, -1) + return out, log_det + + class ConvNetConditioner(ConditionerTransform): def __init__(self, input_event_shape: torch.Size, diff --git a/test/test_factored_bijection.py b/test/test_factored_bijection.py new file mode 100644 index 0000000..6df7726 --- /dev/null +++ b/test/test_factored_bijection.py @@ -0,0 +1,42 @@ +import torch +from normalizing_flows.bijections.finite.multiscale.base import FactoredBijection +from normalizing_flows.bijections.finite.autoregressive.layers import ElementwiseAffine + + +def test_basic(): + torch.manual_seed(0) + + bijection = FactoredBijection( + event_shape=(6, 6), + transformed_event_shape=(3, 3), + transformed_event_mask=torch.tensor([ + [True, True, True, False, False, False], + [True, True, True, False, False, False], + [True, True, True, False, False, False], + [False, False, False, False, False, False], + [False, False, False, False, False, False], + [False, False, False, False, False, False], + ]), + small_bijection=ElementwiseAffine(event_shape=(3, 3)) + ) + + x = torch.randn(100, *bijection.event_shape) + z, log_det_forward = bijection.forward(x) + + assert torch.allclose( + x[..., ~bijection.transformed_event_mask], + z[..., ~bijection.transformed_event_mask], + atol=1e-5 + ) + + assert torch.all( + ~torch.isclose( + x[..., bijection.transformed_event_mask], + z[..., bijection.transformed_event_mask], + atol=1e-5 + ) + ) + + xr, log_det_inverse = bijection.inverse(z) + assert torch.allclose(x, xr, atol=1e-5) + assert torch.allclose(log_det_forward, -log_det_inverse, atol=1e-5) From 23d319e26ca52dfc23f5b4a6c502490b98519acb Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Thu, 4 Jul 2024 15:49:52 +0200 Subject: [PATCH 44/50] Fix validation loss computation --- normalizing_flows/flows.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/normalizing_flows/flows.py b/normalizing_flows/flows.py index 602801b..f50b1b8 100644 --- a/normalizing_flows/flows.py +++ b/normalizing_flows/flows.py @@ -161,8 +161,8 @@ def compute_batch_loss(batch_, reduction: callable = torch.mean): # Compute validation loss val_loss = 0.0 for val_batch in val_loader: - n_batch_data = len(val_batch[0]) - val_loss += compute_batch_loss(val_batch, reduction=torch.sum) / n_batch_data + val_loss += compute_batch_loss(val_batch, reduction=torch.sum) + val_loss /= len(x_val) val_loss += self.regularization() # Check if validation loss is the lowest so far From d0c5e3cdf2966330fe3627de3d63ed7532970961 Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Fri, 5 Jul 2024 17:33:23 +0200 Subject: [PATCH 45/50] Add factored bijections (drop out half the channels at each layer) --- .../finite/multiscale/architectures.py | 91 +++++++++++++++++-- .../bijections/finite/multiscale/base.py | 19 ++-- test/test_factored_bijection.py | 4 +- test/test_multiscale_bijections.py | 53 +++++++++++ 4 files changed, 145 insertions(+), 22 deletions(-) create mode 100644 test/test_multiscale_bijections.py diff --git a/normalizing_flows/bijections/finite/multiscale/architectures.py b/normalizing_flows/bijections/finite/multiscale/architectures.py index 8f5a49e..6bfabb6 100644 --- a/normalizing_flows/bijections/finite/multiscale/architectures.py +++ b/normalizing_flows/bijections/finite/multiscale/architectures.py @@ -1,14 +1,76 @@ +import torch + from normalizing_flows.bijections.finite.autoregressive.layers import ElementwiseAffine from normalizing_flows.bijections.finite.autoregressive.transformers.linear.affine import Affine, Shift from normalizing_flows.bijections.finite.autoregressive.transformers.spline.rational_quadratic import RationalQuadratic from normalizing_flows.bijections.finite.autoregressive.transformers.spline.linear import Linear as LinearRational from normalizing_flows.bijections import BijectiveComposition -from normalizing_flows.bijections.finite.multiscale.base import MultiscaleBijection +from normalizing_flows.bijections.finite.multiscale.base import MultiscaleBijection, FactoredBijection +import math + + +def make_factored_image_layers(event_shape, + transformer_class, + n_layers: int = 2): + """ + Creates a list of image transformations consisting of coupling layers and squeeze layers. + After each coupling, squeeze, coupling mapping, half of the channels are kept as is (not transformed anymore). + + :param event_shape: (c, 2^n, 2^m). + :param transformer_class: + :param n_layers: + :return: + """ + if len(event_shape) != 3: + raise ValueError("Multichannel image transformation are only possible for inputs with three axes.") + if bin(event_shape[1]).count("1") != 1: + raise ValueError("Image height must be a power of two.") + if bin(event_shape[2]).count("1") != 1: + raise ValueError("Image width must be a power of two.") + if n_layers < 1: + raise ValueError + + log_height = math.log2(event_shape[1]) + log_width = math.log2(event_shape[2]) + if n_layers > min(log_height, log_width): + raise ValueError("Too many layers for input image size") + + def recursive_layer_builder(event_shape_, n_layers_): + msb = MultiscaleBijection( + input_event_shape=event_shape_, + transformer_class=transformer_class + ) + if n_layers_ == 1: + return msb + c, h, w = msb.transformed_shape # c is a multiple of 4 after squeezing -def make_image_layers(event_shape, - transformer_class, - n_layers: int = 2): + small_bijection_shape = (c // 2, h, w) + small_bijection_mask = (torch.arange(c) >= c // 2)[:, None, None].repeat(1, h, w) + fb = FactoredBijection( + event_shape=(c, h, w), + small_bijection=recursive_layer_builder( + event_shape_=small_bijection_shape, + n_layers_=n_layers_ - 1 + ), + small_bijection_mask=small_bijection_mask + ) + composition = BijectiveComposition( + event_shape=msb.event_shape, + layers=[msb, fb] + ) + composition.transformed_shape = fb.transformed_shape + return composition + + bijections = [ElementwiseAffine(event_shape=event_shape)] + bijections.append(recursive_layer_builder(bijections[-1].transformed_shape, n_layers)) + bijections.append(ElementwiseAffine(event_shape=bijections[-1].transformed_shape)) + return bijections + + +def make_image_layers_non_factored(event_shape, + transformer_class, + n_layers: int = 2): """ Returns a list of bijections for transformations of images with multiple channels. """ @@ -17,6 +79,8 @@ def make_image_layers(event_shape, assert n_layers >= 1 + # TODO check that image shape is big enough for this number of layers (divisibility by 2) + bijections = [ElementwiseAffine(event_shape=event_shape)] for _ in range(n_layers - 1): bijections.append( @@ -38,14 +102,22 @@ def make_image_layers(event_shape, return bijections +def make_image_layers(*args, factored: bool = False, **kwargs): + if factored: + return make_factored_image_layers(*args, **kwargs) + else: + return make_image_layers_non_factored(*args, **kwargs) + + class MultiscaleRealNVP(BijectiveComposition): def __init__(self, event_shape, n_layers: int = 3, + factored: bool = False, **kwargs): if isinstance(event_shape, int): event_shape = (event_shape,) - bijections = make_image_layers(event_shape, Affine, n_layers) + bijections = make_image_layers(event_shape, Affine, n_layers, factored=factored) super().__init__(event_shape, bijections, **kwargs) self.transformed_shape = bijections[-1].transformed_shape @@ -54,10 +126,11 @@ class MultiscaleNICE(BijectiveComposition): def __init__(self, event_shape, n_layers: int = 3, + factored: bool = False, **kwargs): if isinstance(event_shape, int): event_shape = (event_shape,) - bijections = make_image_layers(event_shape, Shift, n_layers) + bijections = make_image_layers(event_shape, Shift, n_layers, factored=factored) super().__init__(event_shape, bijections, **kwargs) self.transformed_shape = bijections[-1].transformed_shape @@ -66,10 +139,11 @@ class MultiscaleRQNSF(BijectiveComposition): def __init__(self, event_shape, n_layers: int = 3, + factored: bool = False, **kwargs): if isinstance(event_shape, int): event_shape = (event_shape,) - bijections = make_image_layers(event_shape, RationalQuadratic, n_layers) + bijections = make_image_layers(event_shape, RationalQuadratic, n_layers, factored=factored) super().__init__(event_shape, bijections, **kwargs) self.transformed_shape = bijections[-1].transformed_shape @@ -78,9 +152,10 @@ class MultiscaleLRSNSF(BijectiveComposition): def __init__(self, event_shape, n_layers: int = 3, + factored: bool = False, **kwargs): if isinstance(event_shape, int): event_shape = (event_shape,) - bijections = make_image_layers(event_shape, LinearRational, n_layers) + bijections = make_image_layers(event_shape, LinearRational, n_layers, factored=factored) super().__init__(event_shape, bijections, **kwargs) self.transformed_shape = bijections[-1].transformed_shape diff --git a/normalizing_flows/bijections/finite/multiscale/base.py b/normalizing_flows/bijections/finite/multiscale/base.py index a3a8d21..4298710 100644 --- a/normalizing_flows/bijections/finite/multiscale/base.py +++ b/normalizing_flows/bijections/finite/multiscale/base.py @@ -23,17 +23,14 @@ class FactoredBijection(Bijection): def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], - transformed_event_shape: Union[torch.Size, Tuple[int, ...]], small_bijection: Bijection, - transformed_event_mask: torch.Tensor, + small_bijection_mask: torch.Tensor, **kwargs): """ :param event_shape: shape of input event x. - :param transformed_event_shape: shape of transformed event x_A. - :param constant_event_shape: shape of constant event x_B. :param small_bijection: bijection applied to transformed event x_A. - :param transformed_event_mask: boolean mask that selects which elements of event x correspond to the transformed + :param small_bijection_mask: boolean mask that selects which elements of event x correspond to the transformed event x_A. :param kwargs: """ @@ -41,20 +38,18 @@ def __init__(self, # Check that shapes are correct event_size = torch.prod(torch.as_tensor(event_shape)) - transformed_event_size = torch.prod(torch.as_tensor(transformed_event_shape)) + transformed_event_size = torch.prod(torch.as_tensor(small_bijection.event_shape)) assert event_size >= transformed_event_size - assert transformed_event_mask.shape == event_shape - assert small_bijection.event_shape == transformed_event_shape + assert small_bijection_mask.shape == event_shape - self.transformed_event_mask = transformed_event_mask - self.transformed_event_shape = transformed_event_shape + self.transformed_event_mask = small_bijection_mask self.small_bijection = small_bijection def forward(self, x: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]: batch_shape = get_batch_shape(x, self.event_shape) transformed, log_det = self.small_bijection.forward( - x[..., self.transformed_event_mask].view(*batch_shape, *self.transformed_event_shape), + x[..., self.transformed_event_mask].view(*batch_shape, *self.small_bijection.event_shape), context ) out = x.clone() @@ -64,7 +59,7 @@ def forward(self, x: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch. def inverse(self, z: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]: batch_shape = get_batch_shape(z, self.event_shape) transformed, log_det = self.small_bijection.inverse( - z[..., self.transformed_event_mask].view(*batch_shape, *self.transformed_event_shape), + z[..., self.transformed_event_mask].view(*batch_shape, *self.small_bijection.transformed_shape), context ) out = z.clone() diff --git a/test/test_factored_bijection.py b/test/test_factored_bijection.py index 6df7726..4c43cf7 100644 --- a/test/test_factored_bijection.py +++ b/test/test_factored_bijection.py @@ -8,8 +8,8 @@ def test_basic(): bijection = FactoredBijection( event_shape=(6, 6), - transformed_event_shape=(3, 3), - transformed_event_mask=torch.tensor([ + small_bijection_event_shape=(3, 3), + small_bijection_mask=torch.tensor([ [True, True, True, False, False, False], [True, True, True, False, False, False], [True, True, True, False, False, False], diff --git a/test/test_multiscale_bijections.py b/test/test_multiscale_bijections.py new file mode 100644 index 0000000..fa1b181 --- /dev/null +++ b/test/test_multiscale_bijections.py @@ -0,0 +1,53 @@ +from normalizing_flows.architectures import MultiscaleNICE, MultiscaleRQNSF, MultiscaleLRSNSF, MultiscaleRealNVP +import torch +import pytest + + +@pytest.mark.parametrize('architecture_class', [MultiscaleRQNSF, MultiscaleLRSNSF, MultiscaleNICE, MultiscaleRealNVP]) +@pytest.mark.parametrize('image_shape', [(1, 28, 28), (3, 28, 28)]) +def test_non_factored(architecture_class, image_shape): + torch.manual_seed(0) + x = torch.randn(size=(5, *image_shape)) + bijection = architecture_class(image_shape, n_layers=2, factored=False) + z, ldf = bijection.forward(x) + xr, ldi = bijection.inverse(z) + assert torch.allclose(x, xr, atol=1e-4) + assert torch.allclose(ldf, -ldi, atol=1e-2) + + +@pytest.mark.parametrize('architecture_class', [MultiscaleRQNSF, MultiscaleLRSNSF, MultiscaleNICE, MultiscaleRealNVP]) +@pytest.mark.parametrize('image_shape', [(1, 28, 28), (3, 28, 28)]) +def test_non_factored_too_small_image(architecture_class, image_shape): + torch.manual_seed(0) + with pytest.raises(ValueError): + bijection = architecture_class(image_shape, n_layers=3, factored=False) + + +@pytest.mark.parametrize('architecture_class', [MultiscaleRQNSF, MultiscaleLRSNSF, MultiscaleNICE, MultiscaleRealNVP]) +@pytest.mark.parametrize('image_shape', [(1, 32, 32), (3, 32, 32)]) +def test_factored(architecture_class, image_shape): + torch.manual_seed(0) + x = torch.randn(size=(5, *image_shape)) + bijection = architecture_class(image_shape, n_layers=2, factored=True) + z, ldf = bijection.forward(x) + xr, ldi = bijection.inverse(z) + assert torch.allclose(x, xr, atol=1e-4) + assert torch.allclose(ldf, -ldi, atol=1e-2) + + +@pytest.mark.parametrize('architecture_class', [MultiscaleRQNSF, MultiscaleLRSNSF, MultiscaleNICE, MultiscaleRealNVP]) +@pytest.mark.parametrize('image_shape', [(1, 15, 32), (3, 15, 32)]) +def test_factored_wrong_shape(architecture_class, image_shape): + torch.manual_seed(0) + x = torch.randn(size=(5, *image_shape)) + with pytest.raises(ValueError): + bijection = architecture_class(image_shape, n_layers=2, factored=True) + + +@pytest.mark.parametrize('architecture_class', [MultiscaleRQNSF, MultiscaleLRSNSF, MultiscaleNICE, MultiscaleRealNVP]) +@pytest.mark.parametrize('image_shape', [(1, 8, 8), (3, 8, 8)]) +def test_factored_too_small_image(architecture_class, image_shape): + torch.manual_seed(0) + x = torch.randn(size=(5, *image_shape)) + with pytest.raises(ValueError): + bijection = architecture_class(image_shape, n_layers=8, factored=True) From fba246807b2bb17570a1b7dff75cd401ee748822 Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Sat, 6 Jul 2024 15:08:59 +0200 Subject: [PATCH 46/50] Automatically determine number of layers for multiscale flows if no argument provided --- .../finite/multiscale/architectures.py | 78 ++++++++++++------- 1 file changed, 52 insertions(+), 26 deletions(-) diff --git a/normalizing_flows/bijections/finite/multiscale/architectures.py b/normalizing_flows/bijections/finite/multiscale/architectures.py index 6bfabb6..cc930e0 100644 --- a/normalizing_flows/bijections/finite/multiscale/architectures.py +++ b/normalizing_flows/bijections/finite/multiscale/architectures.py @@ -4,14 +4,48 @@ from normalizing_flows.bijections.finite.autoregressive.transformers.linear.affine import Affine, Shift from normalizing_flows.bijections.finite.autoregressive.transformers.spline.rational_quadratic import RationalQuadratic from normalizing_flows.bijections.finite.autoregressive.transformers.spline.linear import Linear as LinearRational +from normalizing_flows.bijections.finite.autoregressive.transformers.combination.sigmoid import ( + DeepSigmoid, + DeepDenseSigmoid, + DenseSigmoid +) from normalizing_flows.bijections import BijectiveComposition from normalizing_flows.bijections.finite.multiscale.base import MultiscaleBijection, FactoredBijection -import math + + +def check_image_shape_for_multiscale_flow(event_shape, n_layers): + if len(event_shape) != 3: + raise ValueError("Multichannel image transformation are only possible for inputs with 3 axes.") + if event_shape[1] % 2 != 0 or event_shape[2] % 2 != 0: + raise ValueError("Image height and width must be divisible by 2.") + if n_layers is not None and n_layers < 1: + raise ValueError("Need at least one layer for multiscale flow.") + + # Check image height and width + if n_layers is not None: + if event_shape[1] % (2 ** n_layers) != 0: + raise ValueError("Image height must be divisible by pow(2, n_layers).") + elif event_shape[2] % (2 ** n_layers) != 0: + raise ValueError("Image width must be divisible by pow(2, n_layers).") + + +def automatically_determine_n_layers(event_shape): + if event_shape[1] % (2 ** 3) == 0 and event_shape[2] % (2 ** 3) == 0: + # Try using 3 layers + n_layers = 3 + elif event_shape[1] % (2 ** 2) == 0 and event_shape[2] % (2 ** 2) == 0: + # Try using 2 layers + n_layers = 2 + elif event_shape[1] % 2 == 0 and event_shape[2] % 2 == 0: + n_layers = 1 + else: + raise ValueError("Image height and width must be divisible by 2.") + return n_layers def make_factored_image_layers(event_shape, transformer_class, - n_layers: int = 2): + n_layers: int = None): """ Creates a list of image transformations consisting of coupling layers and squeeze layers. After each coupling, squeeze, coupling mapping, half of the channels are kept as is (not transformed anymore). @@ -21,19 +55,10 @@ def make_factored_image_layers(event_shape, :param n_layers: :return: """ - if len(event_shape) != 3: - raise ValueError("Multichannel image transformation are only possible for inputs with three axes.") - if bin(event_shape[1]).count("1") != 1: - raise ValueError("Image height must be a power of two.") - if bin(event_shape[2]).count("1") != 1: - raise ValueError("Image width must be a power of two.") - if n_layers < 1: - raise ValueError - - log_height = math.log2(event_shape[1]) - log_width = math.log2(event_shape[2]) - if n_layers > min(log_height, log_width): - raise ValueError("Too many layers for input image size") + check_image_shape_for_multiscale_flow(event_shape, n_layers) + if n_layers is None: + n_layers = automatically_determine_n_layers(event_shape) + check_image_shape_for_multiscale_flow(event_shape, n_layers) def recursive_layer_builder(event_shape_, n_layers_): msb = MultiscaleBijection( @@ -70,16 +95,17 @@ def recursive_layer_builder(event_shape_, n_layers_): def make_image_layers_non_factored(event_shape, transformer_class, - n_layers: int = 2): + n_layers: int = None): """ Returns a list of bijections for transformations of images with multiple channels. - """ - if len(event_shape) != 3: - raise ValueError("Multichannel image transformation are only possible for inputs with three axes.") - assert n_layers >= 1 - - # TODO check that image shape is big enough for this number of layers (divisibility by 2) + Let n be the number of layers. This sequence of bijections takes as input an image with shape (c, h, w) and outputs + an image with shape (4 ** n * c, h / 2 ** n, w / 2 ** n). We require h and w to be divisible by 2 ** n. + """ + check_image_shape_for_multiscale_flow(event_shape, n_layers) + if n_layers is None: + n_layers = automatically_determine_n_layers(event_shape) + check_image_shape_for_multiscale_flow(event_shape, n_layers) bijections = [ElementwiseAffine(event_shape=event_shape)] for _ in range(n_layers - 1): @@ -112,7 +138,7 @@ def make_image_layers(*args, factored: bool = False, **kwargs): class MultiscaleRealNVP(BijectiveComposition): def __init__(self, event_shape, - n_layers: int = 3, + n_layers: int = None, factored: bool = False, **kwargs): if isinstance(event_shape, int): @@ -125,7 +151,7 @@ def __init__(self, class MultiscaleNICE(BijectiveComposition): def __init__(self, event_shape, - n_layers: int = 3, + n_layers: int = None, factored: bool = False, **kwargs): if isinstance(event_shape, int): @@ -138,7 +164,7 @@ def __init__(self, class MultiscaleRQNSF(BijectiveComposition): def __init__(self, event_shape, - n_layers: int = 3, + n_layers: int = None, factored: bool = False, **kwargs): if isinstance(event_shape, int): @@ -151,7 +177,7 @@ def __init__(self, class MultiscaleLRSNSF(BijectiveComposition): def __init__(self, event_shape, - n_layers: int = 3, + n_layers: int = None, factored: bool = False, **kwargs): if isinstance(event_shape, int): From 727b4b52d8715b61c78cf19a2acb6df7c4841f51 Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Sat, 6 Jul 2024 15:09:12 +0200 Subject: [PATCH 47/50] Add sigmoid-based multiscale architectures --- .../finite/multiscale/architectures.py | 39 +++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/normalizing_flows/bijections/finite/multiscale/architectures.py b/normalizing_flows/bijections/finite/multiscale/architectures.py index cc930e0..dd6613d 100644 --- a/normalizing_flows/bijections/finite/multiscale/architectures.py +++ b/normalizing_flows/bijections/finite/multiscale/architectures.py @@ -185,3 +185,42 @@ def __init__(self, bijections = make_image_layers(event_shape, LinearRational, n_layers, factored=factored) super().__init__(event_shape, bijections, **kwargs) self.transformed_shape = bijections[-1].transformed_shape + + +class MultiscaleDeepSigmoid(BijectiveComposition): + def __init__(self, + event_shape, + n_layers: int = None, + factored: bool = False, + **kwargs): + if isinstance(event_shape, int): + event_shape = (event_shape,) + bijections = make_image_layers(event_shape, DeepSigmoid, n_layers, factored=factored) + super().__init__(event_shape, bijections, **kwargs) + self.transformed_shape = bijections[-1].transformed_shape + + +class MultiscaleDeepDenseSigmoid(BijectiveComposition): + def __init__(self, + event_shape, + n_layers: int = None, + factored: bool = False, + **kwargs): + if isinstance(event_shape, int): + event_shape = (event_shape,) + bijections = make_image_layers(event_shape, DeepDenseSigmoid, n_layers, factored=factored) + super().__init__(event_shape, bijections, **kwargs) + self.transformed_shape = bijections[-1].transformed_shape + + +class MultiscaleDenseSigmoid(BijectiveComposition): + def __init__(self, + event_shape, + n_layers: int = None, + factored: bool = False, + **kwargs): + if isinstance(event_shape, int): + event_shape = (event_shape,) + bijections = make_image_layers(event_shape, DenseSigmoid, n_layers, factored=factored) + super().__init__(event_shape, bijections, **kwargs) + self.transformed_shape = bijections[-1].transformed_shape From d42ac60f9115079f99491f19dfd14b2f99cd8841 Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Sat, 6 Jul 2024 15:09:27 +0200 Subject: [PATCH 48/50] Add sigmoid-based multiscale architectures to architectures.py --- normalizing_flows/architectures.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/normalizing_flows/architectures.py b/normalizing_flows/architectures.py index b441edf..58921c0 100644 --- a/normalizing_flows/architectures.py +++ b/normalizing_flows/architectures.py @@ -30,5 +30,8 @@ MultiscaleRealNVP, MultiscaleRQNSF, MultiscaleLRSNSF, - MultiscaleNICE + MultiscaleNICE, + # MultiscaleDeepSigmoid, # TODO stabler initialization + # MultiscaleDenseSigmoid, # TODO stabler initialization + # MultiscaleDeepDenseSigmoid # TODO stabler initialization ) From 475c0f6dbdcbfcf1bff65c5368a486e3d0cc1a7e Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Sat, 6 Jul 2024 15:09:46 +0200 Subject: [PATCH 49/50] Add tests for automatically determined number of layers --- test/test_multiscale_bijections.py | 68 +++++++++++++++++++++++++++--- 1 file changed, 62 insertions(+), 6 deletions(-) diff --git a/test/test_multiscale_bijections.py b/test/test_multiscale_bijections.py index fa1b181..08bb548 100644 --- a/test/test_multiscale_bijections.py +++ b/test/test_multiscale_bijections.py @@ -1,9 +1,19 @@ -from normalizing_flows.architectures import MultiscaleNICE, MultiscaleRQNSF, MultiscaleLRSNSF, MultiscaleRealNVP +from normalizing_flows.architectures import ( + MultiscaleNICE, + MultiscaleRQNSF, + MultiscaleLRSNSF, + MultiscaleRealNVP +) import torch import pytest -@pytest.mark.parametrize('architecture_class', [MultiscaleRQNSF, MultiscaleLRSNSF, MultiscaleNICE, MultiscaleRealNVP]) +@pytest.mark.parametrize('architecture_class', [ + MultiscaleRQNSF, + MultiscaleLRSNSF, + MultiscaleNICE, + MultiscaleRealNVP +]) @pytest.mark.parametrize('image_shape', [(1, 28, 28), (3, 28, 28)]) def test_non_factored(architecture_class, image_shape): torch.manual_seed(0) @@ -15,7 +25,12 @@ def test_non_factored(architecture_class, image_shape): assert torch.allclose(ldf, -ldi, atol=1e-2) -@pytest.mark.parametrize('architecture_class', [MultiscaleRQNSF, MultiscaleLRSNSF, MultiscaleNICE, MultiscaleRealNVP]) +@pytest.mark.parametrize('architecture_class', [ + MultiscaleRQNSF, + MultiscaleLRSNSF, + MultiscaleNICE, + MultiscaleRealNVP +]) @pytest.mark.parametrize('image_shape', [(1, 28, 28), (3, 28, 28)]) def test_non_factored_too_small_image(architecture_class, image_shape): torch.manual_seed(0) @@ -23,7 +38,12 @@ def test_non_factored_too_small_image(architecture_class, image_shape): bijection = architecture_class(image_shape, n_layers=3, factored=False) -@pytest.mark.parametrize('architecture_class', [MultiscaleRQNSF, MultiscaleLRSNSF, MultiscaleNICE, MultiscaleRealNVP]) +@pytest.mark.parametrize('architecture_class', [ + MultiscaleRQNSF, + MultiscaleLRSNSF, + MultiscaleNICE, + MultiscaleRealNVP +]) @pytest.mark.parametrize('image_shape', [(1, 32, 32), (3, 32, 32)]) def test_factored(architecture_class, image_shape): torch.manual_seed(0) @@ -35,7 +55,12 @@ def test_factored(architecture_class, image_shape): assert torch.allclose(ldf, -ldi, atol=1e-2) -@pytest.mark.parametrize('architecture_class', [MultiscaleRQNSF, MultiscaleLRSNSF, MultiscaleNICE, MultiscaleRealNVP]) +@pytest.mark.parametrize('architecture_class', [ + MultiscaleRQNSF, + MultiscaleLRSNSF, + MultiscaleNICE, + MultiscaleRealNVP +]) @pytest.mark.parametrize('image_shape', [(1, 15, 32), (3, 15, 32)]) def test_factored_wrong_shape(architecture_class, image_shape): torch.manual_seed(0) @@ -44,10 +69,41 @@ def test_factored_wrong_shape(architecture_class, image_shape): bijection = architecture_class(image_shape, n_layers=2, factored=True) -@pytest.mark.parametrize('architecture_class', [MultiscaleRQNSF, MultiscaleLRSNSF, MultiscaleNICE, MultiscaleRealNVP]) +@pytest.mark.parametrize('architecture_class', [ + MultiscaleRQNSF, + MultiscaleLRSNSF, + MultiscaleNICE, + MultiscaleRealNVP +]) @pytest.mark.parametrize('image_shape', [(1, 8, 8), (3, 8, 8)]) def test_factored_too_small_image(architecture_class, image_shape): torch.manual_seed(0) x = torch.randn(size=(5, *image_shape)) with pytest.raises(ValueError): bijection = architecture_class(image_shape, n_layers=8, factored=True) + + +@pytest.mark.parametrize('architecture_class', [ + MultiscaleRQNSF, + MultiscaleLRSNSF, + MultiscaleNICE, + MultiscaleRealNVP +]) +@pytest.mark.parametrize('image_shape', [(1, 4, 4), (3, 4, 4)]) +def test_non_factored_automatic_n_layers(architecture_class, image_shape): + torch.manual_seed(0) + x = torch.randn(size=(5, *image_shape)) + bijection = architecture_class(image_shape, factored=False) + + +@pytest.mark.parametrize('architecture_class', [ + MultiscaleRQNSF, + MultiscaleLRSNSF, + MultiscaleNICE, + MultiscaleRealNVP +]) +@pytest.mark.parametrize('image_shape', [(1, 4, 8), (3, 4, 4)]) +def test_factored_automatic_n_layers(architecture_class, image_shape): + torch.manual_seed(0) + x = torch.randn(size=(5, *image_shape)) + bijection = architecture_class(image_shape, factored=True) From e12de1d7b662a1f5056b17fe9926a66b02442e3e Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Sat, 6 Jul 2024 15:45:00 +0200 Subject: [PATCH 50/50] Add FlowMixture to __init__.py --- normalizing_flows/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/normalizing_flows/__init__.py b/normalizing_flows/__init__.py index fc92a0a..1f268b5 100644 --- a/normalizing_flows/__init__.py +++ b/normalizing_flows/__init__.py @@ -1,4 +1,4 @@ -from normalizing_flows.flows import Flow +from normalizing_flows.flows import Flow, FlowMixture from normalizing_flows.bijections import ( NICE, RealNVP,