diff --git a/normalizing_flows/bijections/finite/multiscale/base.py b/normalizing_flows/bijections/finite/multiscale/base.py
index a48de57..aa7b844 100644
--- a/normalizing_flows/bijections/finite/multiscale/base.py
+++ b/normalizing_flows/bijections/finite/multiscale/base.py
@@ -1,7 +1,6 @@
 from typing import Type, Union, Tuple
 
 import torch
-import torch.nn as nn
 
 from normalizing_flows.bijections import BijectiveComposition
 from normalizing_flows.bijections.finite.autoregressive.conditioning.transforms import ConditionerTransform
@@ -10,11 +9,30 @@
 from normalizing_flows.bijections.finite.autoregressive.transformers.base import TensorTransformer
 from normalizing_flows.bijections.finite.multiscale.coupling import make_image_coupling, Checkerboard, \
     ChannelWiseHalfSplit
+from normalizing_flows.neural_networks.convnet import ConvNet
 from normalizing_flows.utils import get_batch_shape
 
 
-class ResNet(ConditionerTransform):
-    pass
+class ConvNetConditioner(ConditionerTransform):
+    def __init__(self,
+                 input_event_shape: torch.Size,
+                 parameter_shape: torch.Size,
+                 kernels: Tuple[int, ...] = None,
+                 **kwargs):
+        super().__init__(
+            input_event_shape=input_event_shape,
+            context_shape=None,
+            parameter_shape=parameter_shape,
+            **kwargs
+        )
+        self.network = ConvNet(
+            input_shape=input_event_shape,
+            n_outputs=self.n_transformer_parameters,
+            kernels=kernels,
+        )
+
+    def predict_theta_flat(self, x: torch.Tensor, context: torch.Tensor = None) -> torch.Tensor:
+        return self.network(x)
 
 
 class ConvolutionalCouplingBijection(CouplingBijection):
@@ -22,8 +40,12 @@ def __init__(self,
                  transformer: TensorTransformer,
                  coupling: Union[Checkerboard, ChannelWiseHalfSplit],
                  **kwargs):
-        conditioner_transform = ResNet()
-        super().__init__(coupling.event_shape, transformer, conditioner_transform, **kwargs)
+        conditioner_transform = ConvNetConditioner(
+            input_event_shape=coupling.constant_shape,
+            parameter_shape=transformer.parameter_shape,
+            **kwargs
+        )
+        super().__init__(transformer, coupling, conditioner_transform, **kwargs)
         self.coupling = coupling
 
     def get_constant_part(self, x: torch.Tensor) -> torch.Tensor:
@@ -53,6 +75,12 @@ def set_transformed_part(self, x: torch.Tensor, x_transformed: torch.Tensor):
         batch_shape = get_batch_shape(x, self.event_shape)
         return x[..., self.coupling.target_mask].view(*batch_shape, *self.coupling.transformed_shape)
 
+    def partition_and_predict_parameters(self, x: torch.Tensor, context: torch.Tensor):
+        batch_shape = get_batch_shape(x, self.event_shape)
+        super_out = super().partition_and_predict_parameters(x, context)
+        return super_out.view(*batch_shape, *self.coupling.transformed_shape,
+                              *self.transformer.parameter_shape_per_element)
+
 
 class CheckerboardCoupling(ConvolutionalCouplingBijection):
     def __init__(self,
@@ -64,7 +92,7 @@ def __init__(self,
             event_shape,
             coupling_type='checkerboard' if not alternate else 'checkerboard_inverted'
         )
-        transformer = transformer_class(event_shape=torch.Size((coupling.target_event_size,)))
+        transformer = transformer_class(event_shape=coupling.transformed_shape)
         super().__init__(transformer, coupling, **kwargs)
 
 
@@ -78,7 +106,7 @@ def __init__(self,
             event_shape,
             coupling_type='channel_wise' if not alternate else 'channel_wise_inverted'
         )
-        transformer = transformer_class(event_shape=torch.Size((coupling.target_event_size,)))
+        transformer = transformer_class(event_shape=coupling.transformed_shape)
         super().__init__(transformer, coupling, **kwargs)
 
 
@@ -151,14 +179,11 @@ def __init__(self,
                  n_channel_wise_layers: int = 3,
                  use_squeeze_layer: bool = True,
                  **kwargs):
-        channels, height, width = input_event_shape[-3:]
-        resolution = min(width, height) // 2
         checkerboard_layers = [
             CheckerboardCoupling(
                 input_event_shape,
                 transformer_class,
-                alternate=i % 2 == 1,
-                resolution=resolution
+                alternate=i % 2 == 1
             )
             for i in range(n_checkerboard_layers)
         ]
@@ -167,8 +192,7 @@ def __init__(self,
             ChannelWiseCoupling(
                 squeeze_layer.transformed_event_shape,
                 transformer_class,
-                alternate=i % 2 == 1,
-                resolution=resolution
+                alternate=i % 2 == 1
             )
             for i in range(n_channel_wise_layers)
         ]
diff --git a/normalizing_flows/bijections/finite/multiscale/coupling.py b/normalizing_flows/bijections/finite/multiscale/coupling.py
index 6470086..33bf6cd 100644
--- a/normalizing_flows/bijections/finite/multiscale/coupling.py
+++ b/normalizing_flows/bijections/finite/multiscale/coupling.py
@@ -8,32 +8,23 @@ class Checkerboard(Coupling):
     Checkerboard coupling for image data.
     """
 
-    def __init__(self, event_shape, resolution: int = 2, invert: bool = False):
+    def __init__(self, event_shape, invert: bool = False):
         """
         :param event_shape: image shape with the form (n_channels, height, width). Note: width and height must be
             equal and a power of two.
-        :param resolution: resolution of the checkerboard along one axis - the number of squares. Must be a power of
-            two and smaller than image width.
         :param invert: invert the checkerboard mask.
         """
         channels, height, width = event_shape
-        assert width % resolution == 0
-        square_side_length = width // resolution
-        assert resolution % 2 == 0
-        half_resolution = resolution // 2
-        a = torch.tensor([[1, 0] * half_resolution, [0, 1] * half_resolution] * half_resolution)
-        mask = torch.kron(a, torch.ones((square_side_length, square_side_length)))
-        mask = mask.bool()
+        mask = (torch.arange(height * width) % 2).view(height, width).bool()
         mask = mask[None].repeat(channels, 1, 1)  # (channels, height, width)
         if invert:
             mask = ~mask
-        self.resolution = resolution
         super().__init__(event_shape, mask)
 
     @property
     def constant_shape(self):
-        n_channels, _, _ = self.event_shape
-        return n_channels, self.resolution, self.resolution
+        n_channels, height, width = self.event_shape
+        return n_channels, height // 2, width  # rectangular shape
 
     @property
     def transformed_shape(self):
diff --git a/test/test_checkerboard_coupling.py b/test/test_checkerboard_coupling.py
deleted file mode 100644
index 00b1dd4..0000000
--- a/test/test_checkerboard_coupling.py
+++ /dev/null
@@ -1,87 +0,0 @@
-import torch
-
-from normalizing_flows.bijections.finite.multiscale.coupling import Checkerboard
-
-
-def test_checkerboard_small():
-    torch.manual_seed(0)
-    image_shape = (3, 4, 4)
-    coupling = Checkerboard(image_shape, resolution=2)
-    assert torch.allclose(
-        coupling.source_mask,
-        torch.tensor([
-            [1, 1, 0, 0],
-            [1, 1, 0, 0],
-            [0, 0, 1, 1],
-            [0, 0, 1, 1],
-        ], dtype=torch.bool)[None].repeat(3, 1, 1)
-    )
-    assert torch.allclose(coupling.target_mask, ~coupling.source_mask)
-
-
-def test_checkerboard_medium():
-    torch.manual_seed(0)
-    image_shape = (3, 16, 16)
-    coupling = Checkerboard(image_shape, resolution=4)
-    assert torch.allclose(
-        coupling.source_mask,
-        torch.tensor([
-            [1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0],
-            [1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0],
-            [1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0],
-            [1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0],
-            [0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1],
-            [0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1],
-            [0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1],
-            [0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1],
-            [1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0],
-            [1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0],
-            [1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0],
-            [1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0],
-            [0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1],
-            [0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1],
-            [0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1],
-            [0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1],
-        ], dtype=torch.bool)[None].repeat(3, 1, 1)
-    )
-    assert torch.allclose(coupling.target_mask, ~coupling.source_mask)
-
-
-def test_checkerboard_small_inverted():
-    torch.manual_seed(0)
-    image_shape = (3, 4, 4)
-    coupling = Checkerboard(image_shape, resolution=2, invert=True)
-    assert torch.allclose(
-        coupling.source_mask,
-        ~torch.tensor([
-            [1, 1, 0, 0],
-            [1, 1, 0, 0],
-            [0, 0, 1, 1],
-            [0, 0, 1, 1],
-        ], dtype=torch.bool)[None].repeat(3, 1, 1)
-    )
-    assert torch.allclose(coupling.target_mask, ~coupling.source_mask)
-
-
-def test_partition_shapes_1():
-    torch.manual_seed(0)
-    image_shape = (3, 4, 4)
-    coupling = Checkerboard(image_shape, resolution=2, invert=True)
-    assert coupling.constant_shape == (3, 2, 2)
-    assert coupling.transformed_shape == (3, 2, 2)
-
-
-def test_partition_shapes_2():
-    torch.manual_seed(0)
-    image_shape = (3, 16, 16)
-    coupling = Checkerboard(image_shape, resolution=8, invert=True)
-    assert coupling.constant_shape == (3, 8, 8)
-    assert coupling.transformed_shape == (3, 8, 8)
-
-
-def test_partition_shapes_3():
-    torch.manual_seed(0)
-    image_shape = (3, 16, 8)
-    coupling = Checkerboard(image_shape, resolution=4, invert=True)
-    assert coupling.constant_shape == (3, 4, 4)
-    assert coupling.transformed_shape == (3, 4, 4)
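
Note: with the resolution parameter gone, the deleted tests above no longer document the mask layout. Below is a minimal standalone sketch that mirrors the new mask construction in Checkerboard.__init__; the (3, 4, 4) image shape is only an illustrative assumption, not part of the diff.

import torch

# Same construction as the new Checkerboard.__init__: select every other flattened pixel.
channels, height, width = 3, 4, 4
mask = (torch.arange(height * width) % 2).view(height, width).bool()
mask = mask[None].repeat(channels, 1, 1)  # (channels, height, width)

print(mask[0].int())  # inspect which pixels land on each side of the split
# Exactly half of the pixels are selected, so the masked values can be viewed as the
# rectangular (channels, height // 2, width) block reported by constant_shape.
print(int(mask.sum()), channels * height * width // 2)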