diff --git a/normalizing_flows/architectures.py b/normalizing_flows/architectures.py index 191a022..90c9561 100644 --- a/normalizing_flows/architectures.py +++ b/normalizing_flows/architectures.py @@ -25,3 +25,5 @@ Radial, Sylvester ) + +from normalizing_flows.bijections.finite.multiscale.architectures import MultiscaleRealNVP diff --git a/normalizing_flows/bijections/base.py b/normalizing_flows/bijections/base.py index a086705..8cb358d 100644 --- a/normalizing_flows/bijections/base.py +++ b/normalizing_flows/bijections/base.py @@ -20,6 +20,7 @@ def __init__(self, self.event_shape = event_shape self.n_dim = int(torch.prod(torch.as_tensor(event_shape))) self.context_shape = context_shape + self.transformed_shape = self.event_shape # Overwritten in multiscale flows TODO make into property def forward(self, x: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]: """ diff --git a/normalizing_flows/bijections/finite/autoregressive/architectures.py b/normalizing_flows/bijections/finite/autoregressive/architectures.py index ef9c7c7..d2ee0ef 100644 --- a/normalizing_flows/bijections/finite/autoregressive/architectures.py +++ b/normalizing_flows/bijections/finite/autoregressive/architectures.py @@ -21,24 +21,6 @@ from normalizing_flows.bijections.finite.linear import ReversePermutation -def make_layers(base_bijection: Type[ - Union[CouplingBijection, MaskedAutoregressiveBijection, InverseMaskedAutoregressiveBijection]], - event_shape, - n_layers: int = 2, - edge_list: List[Tuple[int, int]] = None, - image_coupling: bool = False): - if image_coupling: - if len(event_shape) == 2: - bijections = make_image_layers_single_channel(base_bijection, event_shape, n_layers) - elif len(event_shape) == 3: - bijections = make_image_layers_multichannel(base_bijection, event_shape, n_layers) - else: - raise ValueError - else: - bijections = make_basic_layers(base_bijection, event_shape, n_layers, edge_list) - return bijections - - def 
make_basic_layers(base_bijection: Type[ Union[CouplingBijection, MaskedAutoregressiveBijection, InverseMaskedAutoregressiveBijection]], event_shape, @@ -56,100 +38,15 @@ def make_basic_layers(base_bijection: Type[ return bijections -def make_image_layers_single_channel(base_bijection: Type[ - Union[CouplingBijection, MaskedAutoregressiveBijection, InverseMaskedAutoregressiveBijection]], - event_shape, - n_layers: int = 2, - checkerboard_resolution: int = 2): - """ - Returns a list of bijections for transformations of images with a single channel. - - Each layer consists of two coupling transforms: - 1. checkerboard, - 2. checkerboard_inverted. - """ - if len(event_shape) != 2: - raise ValueError("Single-channel image transformation are only possible for inputs with two axes.") - - bijections = [ElementwiseAffine(event_shape=event_shape)] - for _ in range(n_layers): - bijections.append(base_bijection( - event_shape=event_shape, - coupling_kwargs={ - 'coupling_type': 'checkerboard', - 'resolution': checkerboard_resolution, - } - )) - bijections.append(base_bijection( - event_shape=event_shape, - coupling_kwargs={ - 'coupling_type': 'checkerboard_inverted', - 'resolution': checkerboard_resolution, - } - )) - bijections.append(ElementwiseAffine(event_shape=event_shape)) - return bijections - - -def make_image_layers_multichannel(base_bijection: Type[ - Union[CouplingBijection, MaskedAutoregressiveBijection, InverseMaskedAutoregressiveBijection]], - event_shape, - n_layers: int = 2, - checkerboard_resolution: int = 2): - """ - Returns a list of bijections for transformations of images with multiple channels. - - Each layer consists of four coupling transforms: - 1. checkerboard, - 2. channel_wise, - 3. checkerboard_inverted, - 4. channel_wise_inverted. 
- """ - if len(event_shape) != 3: - raise ValueError("Multichannel image transformation are only possible for inputs with three axes.") - - bijections = [ElementwiseAffine(event_shape=event_shape)] - for _ in range(n_layers): - bijections.append(base_bijection( - event_shape=event_shape, - coupling_kwargs={ - 'coupling_type': 'checkerboard', - 'resolution': checkerboard_resolution, - } - )) - bijections.append(base_bijection( - event_shape=event_shape, - coupling_kwargs={ - 'coupling_type': 'channel_wise' - } - )) - bijections.append(base_bijection( - event_shape=event_shape, - coupling_kwargs={ - 'coupling_type': 'checkerboard_inverted', - 'resolution': checkerboard_resolution, - } - )) - bijections.append(base_bijection( - event_shape=event_shape, - coupling_kwargs={ - 'coupling_type': 'channel_wise_inverted' - } - )) - bijections.append(ElementwiseAffine(event_shape=event_shape)) - return bijections - - class NICE(BijectiveComposition): def __init__(self, event_shape, n_layers: int = 2, edge_list: List[Tuple[int, int]] = None, - image_coupling: bool = False, **kwargs): if isinstance(event_shape, int): event_shape = (event_shape,) - bijections = make_layers(ShiftCoupling, event_shape, n_layers, edge_list, image_coupling) + bijections = make_basic_layers(ShiftCoupling, event_shape, n_layers, edge_list) super().__init__(event_shape, bijections, **kwargs) @@ -158,11 +55,10 @@ def __init__(self, event_shape, n_layers: int = 2, edge_list: List[Tuple[int, int]] = None, - image_coupling: bool = False, **kwargs): if isinstance(event_shape, int): event_shape = (event_shape,) - bijections = make_layers(AffineCoupling, event_shape, n_layers, edge_list, image_coupling) + bijections = make_basic_layers(AffineCoupling, event_shape, n_layers, edge_list) super().__init__(event_shape, bijections, **kwargs) @@ -171,11 +67,10 @@ def __init__(self, event_shape, n_layers: int = 2, edge_list: List[Tuple[int, int]] = None, - image_coupling: bool = False, **kwargs): if 
isinstance(event_shape, int): event_shape = (event_shape,) - bijections = make_layers(InverseAffineCoupling, event_shape, n_layers, edge_list, image_coupling) + bijections = make_basic_layers(InverseAffineCoupling, event_shape, n_layers, edge_list) super().__init__(event_shape, bijections, **kwargs) @@ -204,11 +99,10 @@ def __init__(self, event_shape, n_layers: int = 2, edge_list: List[Tuple[int, int]] = None, - image_coupling: bool = False, **kwargs): if isinstance(event_shape, int): event_shape = (event_shape,) - bijections = make_layers(RQSCoupling, event_shape, n_layers, edge_list, image_coupling) + bijections = make_basic_layers(RQSCoupling, event_shape, n_layers, edge_list) super().__init__(event_shape, bijections, **kwargs) @@ -229,11 +123,10 @@ def __init__(self, event_shape, n_layers: int = 2, edge_list: List[Tuple[int, int]] = None, - image_coupling: bool = False, **kwargs): if isinstance(event_shape, int): event_shape = (event_shape,) - bijections = make_layers(LRSCoupling, event_shape, n_layers, edge_list, image_coupling) + bijections = make_basic_layers(LRSCoupling, event_shape, n_layers, edge_list) super().__init__(event_shape, bijections, **kwargs) @@ -258,11 +151,10 @@ def __init__(self, event_shape, n_layers: int = 2, edge_list: List[Tuple[int, int]] = None, - image_coupling: bool = False, **kwargs): if isinstance(event_shape, int): event_shape = (event_shape,) - bijections = make_layers(DSCoupling, event_shape, n_layers, edge_list, image_coupling) + bijections = make_basic_layers(DSCoupling, event_shape, n_layers, edge_list) super().__init__(event_shape, bijections, **kwargs) diff --git a/normalizing_flows/bijections/finite/autoregressive/conditioning/coupling_masks.py b/normalizing_flows/bijections/finite/autoregressive/conditioning/coupling_masks.py index 2c5078e..08b3d61 100644 --- a/normalizing_flows/bijections/finite/autoregressive/conditioning/coupling_masks.py +++ 
b/normalizing_flows/bijections/finite/autoregressive/conditioning/coupling_masks.py @@ -73,51 +73,6 @@ def __init__(self, event_shape): super().__init__(event_shape, mask=torch.less(torch.arange(event_size).view(*event_shape), event_size // 2)) -class Checkerboard(Coupling): - """ - Checkerboard coupling for image data. - """ - - def __init__(self, event_shape, resolution: int = 2, invert: bool = False): - """ - :param event_shape: image shape with the form (n_channels, height, width). Note: width and height must be equal - and a power of two. - :param resolution: resolution of the checkerboard along one axis - the number of squares. Must be a power of two - and smaller than image width. - :param invert: invert the checkerboard mask. - """ - height, width = event_shape[-2:] - assert width % resolution == 0 - square_side_length = width // resolution - assert resolution % 2 == 0 - half_resolution = resolution // 2 - a = torch.tensor([[1, 0] * half_resolution, [0, 1] * half_resolution] * half_resolution) - mask = torch.kron(a, torch.ones((square_side_length, square_side_length))) - mask = mask.bool() - if invert: - mask = ~mask - super().__init__(event_shape, mask) - - -class ChannelWiseHalfSplit(Coupling): - """ - Channel-wise coupling for image data. - """ - - def __init__(self, event_shape, invert: bool = False): - """ - :param event_shape: image shape with the form (n_channels, height, width). Note: width and height must be equal - and a power of two. - :param invert: invert the checkerboard mask. 
- """ - n_channels, height, width = event_shape - mask = torch.as_tensor(torch.arange(start=0, end=n_channels) < (n_channels // 2)) - mask = mask[:, None, None].repeat(1, height, width) - if invert: - mask = ~mask - super().__init__(event_shape, mask) - - def make_coupling(event_shape, edge_list: List[Tuple[int, int]] = None, coupling_type: str = 'half_split', **kwargs): """ @@ -129,16 +84,7 @@ def make_coupling(event_shape, edge_list: List[Tuple[int, int]] = None, coupling """ if edge_list is not None: return GraphicalCoupling(event_shape, edge_list) + elif coupling_type == 'half_split': + return HalfSplit(event_shape) else: - if coupling_type == 'half_split': - return HalfSplit(event_shape) - elif coupling_type == 'checkerboard': - return Checkerboard(event_shape, invert=False, **kwargs) - elif coupling_type == 'checkerboard_inverted': - return Checkerboard(event_shape, invert=True, **kwargs) - elif coupling_type == 'channel_wise': - return ChannelWiseHalfSplit(event_shape, invert=False) - elif coupling_type == 'channel_wise_inverted': - return ChannelWiseHalfSplit(event_shape, invert=True) - else: - raise ValueError + raise ValueError diff --git a/normalizing_flows/bijections/finite/autoregressive/layers_base.py b/normalizing_flows/bijections/finite/autoregressive/layers_base.py index 5c43108..5187c3d 100644 --- a/normalizing_flows/bijections/finite/autoregressive/layers_base.py +++ b/normalizing_flows/bijections/finite/autoregressive/layers_base.py @@ -1,4 +1,4 @@ -from typing import Tuple, Union +from typing import Tuple, Union, Type import torch import torch.nn as nn diff --git a/normalizing_flows/bijections/finite/linear.py b/normalizing_flows/bijections/finite/linear.py index f9c3504..8b98c67 100644 --- a/normalizing_flows/bijections/finite/linear.py +++ b/normalizing_flows/bijections/finite/linear.py @@ -14,89 +14,6 @@ from normalizing_flows.utils import get_batch_shape, flatten_event, unflatten_event, flatten_batch, unflatten_batch -class 
Squeeze(Bijection): - """ - Squeeze a batch of tensors with shape (*batch_shape, channels, height, width) into shape - (*batch_shape, 4 * channels, height / 2, width / 2). - """ - - def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], **kwargs): - # Check shape length - if len(event_shape) != 3: - raise ValueError(f"Event shape must have three components, but got {len(event_shape)}") - # Check that height and width are divisible by two - if event_shape[1] % 2 != 0: - raise ValueError(f"Event dimension 1 must be divisible by 2, but got {event_shape[1]}") - if event_shape[2] % 2 != 0: - raise ValueError(f"Event dimension 2 must be divisible by 2, but got {event_shape[2]}") - super().__init__(event_shape, **kwargs) - c, h, w = event_shape - self.squeezed_event_shape = torch.Size((4 * c, h // 2, w // 2)) - - def forward(self, x: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Squeeze tensor with shape (*batch_shape, channels, height, width) into tensor with shape - (*batch_shape, 4 * channels, height // 2, width // 2). 
- """ - batch_shape = get_batch_shape(x, self.event_shape) - log_det = torch.zeros(*batch_shape, device=x.device, dtype=x.dtype) - - channels, height, width = x.shape[-3:] - assert height % 2 == 0 - assert width % 2 == 0 - n_rows = height // 2 - n_cols = width // 2 - n_squares = n_rows * n_cols - - square_mask = torch.kron( - torch.arange(n_squares).view(n_rows, n_cols), - torch.ones(2, 2) - ) - channel_mask = torch.arange(n_rows * n_cols).view(n_rows, n_cols)[None].repeat(4 * channels, 1, 1) - - # out = torch.zeros(size=(*batch_shape, self.squeezed_event_shape), device=x.device, dtype=x.dtype) - out = torch.empty(size=(*batch_shape, 4 * channels, height // 2, width // 2), device=x.device, dtype=x.dtype) - - channel_mask = channel_mask.repeat(*batch_shape, 1, 1, 1) - square_mask = square_mask.repeat(*batch_shape, channels, 1, 1) - for i in range(n_squares): - out[channel_mask == i] = x[square_mask == i] - - return out, log_det - - def inverse(self, z: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Squeeze tensor with shape (*batch_shape, 4 * channels, height // 2, width // 2) into tensor with shape - (*batch_shape, channels, height, width). 
- """ - batch_shape = get_batch_shape(z, self.squeezed_event_shape) - log_det = torch.zeros(*batch_shape, device=z.device, dtype=z.dtype) - - four_channels, half_height, half_width = z.shape[-3:] - assert four_channels % 4 == 0 - width = 2 * half_width - height = 2 * half_height - channels = four_channels // 4 - - n_rows = height // 2 - n_cols = width // 2 - n_squares = n_rows * n_cols - - square_mask = torch.kron( - torch.arange(n_squares).view(n_rows, n_cols), - torch.ones(2, 2) - ) - channel_mask = torch.arange(n_rows * n_cols).view(n_rows, n_cols)[None].repeat(4 * channels, 1, 1) - out = torch.empty(size=(*batch_shape, channels, height, width), device=z.device, dtype=z.dtype) - - channel_mask = channel_mask.repeat(*batch_shape, 1, 1, 1) - square_mask = square_mask.repeat(*batch_shape, channels, 1, 1) - for i in range(n_squares): - out[square_mask == i] = z[channel_mask == i] - - return out, log_det - - class LinearBijection(Bijection): def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], matrix: InvertibleMatrix): super().__init__(event_shape) diff --git a/normalizing_flows/bijections/finite/multiscale/__init__.py b/normalizing_flows/bijections/finite/multiscale/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/normalizing_flows/bijections/finite/multiscale/architectures.py b/normalizing_flows/bijections/finite/multiscale/architectures.py new file mode 100644 index 0000000..5ecff68 --- /dev/null +++ b/normalizing_flows/bijections/finite/multiscale/architectures.py @@ -0,0 +1,38 @@ +from normalizing_flows.bijections.finite.autoregressive.layers import ElementwiseAffine +from normalizing_flows.bijections.finite.autoregressive.transformers.linear.affine import Affine +from normalizing_flows.bijections import BijectiveComposition +from normalizing_flows.bijections.finite.multiscale.base import MultiscaleBijection + + +def make_image_layers(event_shape, + transformer_class, + n_layers: int = 2): + """ + Returns a list of bijections 
for transformations of images with multiple channels. + """ + if len(event_shape) != 3: + raise ValueError("Multichannel image transformation are only possible for inputs with three axes.") + + bijections = [ElementwiseAffine(event_shape=event_shape)] + for _ in range(n_layers): + bijections.append( + MultiscaleBijection( + input_event_shape=bijections[-1].transformed_shape, + transformer_class=transformer_class + ) + ) + bijections.append(ElementwiseAffine(event_shape=bijections[-1].transformed_shape)) + return bijections + + +class MultiscaleRealNVP(BijectiveComposition): + def __init__(self, + event_shape, + n_layers: int = 3, + **kwargs): + if isinstance(event_shape, int): + event_shape = (event_shape,) + bijections = make_image_layers(event_shape, Affine, n_layers) + super().__init__(event_shape, bijections, **kwargs) + self.transformed_shape = bijections[-1].transformed_shape + diff --git a/normalizing_flows/bijections/finite/multiscale/base.py b/normalizing_flows/bijections/finite/multiscale/base.py new file mode 100644 index 0000000..b379bd4 --- /dev/null +++ b/normalizing_flows/bijections/finite/multiscale/base.py @@ -0,0 +1,152 @@ +from typing import Type, Union, Tuple + +import torch + +from normalizing_flows.bijections import BijectiveComposition, CouplingBijection +from normalizing_flows.bijections.finite.autoregressive.conditioning.transforms import FeedForward +from normalizing_flows.bijections.base import Bijection +from normalizing_flows.bijections.finite.autoregressive.transformers.base import TensorTransformer +from normalizing_flows.bijections.finite.multiscale.coupling import make_image_coupling +from normalizing_flows.utils import get_batch_shape + + +class CheckerboardCoupling(CouplingBijection): + def __init__(self, + event_shape, + transformer_class: Type[TensorTransformer], + alternate: bool = False, + **kwargs): + coupling = make_image_coupling( + event_shape, + coupling_type='checkerboard' if not alternate else 'checkerboard_inverted' + 
) + transformer = transformer_class(event_shape=torch.Size((coupling.target_event_size,))) + conditioner_transform = FeedForward( + input_event_shape=torch.Size((coupling.source_event_size,)), + parameter_shape=torch.Size(transformer.parameter_shape), + **kwargs + ) + super().__init__(transformer, coupling, conditioner_transform, **kwargs) + + +class ChannelWiseCoupling(CouplingBijection): + def __init__(self, + event_shape, + transformer_class: Type[TensorTransformer], + alternate: bool = False, + **kwargs): + coupling = make_image_coupling( + event_shape, + coupling_type='channel_wise' if not alternate else 'channel_wise_inverted' + ) + transformer = transformer_class(event_shape=torch.Size((coupling.target_event_size,))) + conditioner_transform = FeedForward( + input_event_shape=torch.Size((coupling.source_event_size,)), + parameter_shape=torch.Size(transformer.parameter_shape), + **kwargs + ) + super().__init__(transformer, coupling, conditioner_transform, **kwargs) + + +class Squeeze(Bijection): + """ + Squeeze a batch of tensors with shape (*batch_shape, channels, height, width) into shape + (*batch_shape, 4 * channels, height / 2, width / 2). 
+ """ + + def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], **kwargs): + # Check shape length + if len(event_shape) != 3: + raise ValueError(f"Event shape must have three components, but got {len(event_shape)}") + # Check that height and width are divisible by two + if event_shape[1] % 2 != 0: + raise ValueError(f"Event dimension 1 must be divisible by 2, but got {event_shape[1]}") + if event_shape[2] % 2 != 0: + raise ValueError(f"Event dimension 2 must be divisible by 2, but got {event_shape[2]}") + super().__init__(event_shape, **kwargs) + c, h, w = event_shape + self.transformed_event_shape = torch.Size((4 * c, h // 2, w // 2)) + + def forward(self, x: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Squeeze tensor with shape (*batch_shape, channels, height, width) into tensor with shape + (*batch_shape, 4 * channels, height // 2, width // 2). + """ + batch_shape = get_batch_shape(x, self.event_shape) + log_det = torch.zeros(*batch_shape, device=x.device, dtype=x.dtype) + + channels, height, width = x.shape[-3:] + assert height % 2 == 0 + assert width % 2 == 0 + n_rows = height // 2 + n_cols = width // 2 + n_squares = n_rows * n_cols + + square_mask = torch.kron( + torch.arange(n_squares).view(n_rows, n_cols), + torch.ones(2, 2) + ) + channel_mask = torch.arange(n_rows * n_cols).view(n_rows, n_cols)[None].repeat(4 * channels, 1, 1) + + # out = torch.zeros(size=(*batch_shape, *self.transformed_event_shape), device=x.device, dtype=x.dtype) + out = torch.empty(size=(*batch_shape, 4 * channels, height // 2, width // 2), device=x.device, dtype=x.dtype) + + channel_mask = channel_mask.repeat(*batch_shape, 1, 1, 1) + square_mask = square_mask.repeat(*batch_shape, channels, 1, 1) + for i in range(n_squares): + out[channel_mask == i] = x[square_mask == i] + + return out, log_det + + def inverse(self, z: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Unsqueeze tensor with
shape (*batch_shape, 4 * channels, height // 2, width // 2) into tensor with shape + (*batch_shape, channels, height, width). + """ + batch_shape = get_batch_shape(z, self.transformed_event_shape) + log_det = torch.zeros(*batch_shape, device=z.device, dtype=z.dtype) + + four_channels, half_height, half_width = z.shape[-3:] + assert four_channels % 4 == 0 + width = 2 * half_width + height = 2 * half_height + channels = four_channels // 4 + + n_rows = height // 2 + n_cols = width // 2 + n_squares = n_rows * n_cols + + square_mask = torch.kron( + torch.arange(n_squares).view(n_rows, n_cols), + torch.ones(2, 2) + ) + channel_mask = torch.arange(n_rows * n_cols).view(n_rows, n_cols)[None].repeat(4 * channels, 1, 1) + out = torch.empty(size=(*batch_shape, channels, height, width), device=z.device, dtype=z.dtype) + + channel_mask = channel_mask.repeat(*batch_shape, 1, 1, 1) + square_mask = square_mask.repeat(*batch_shape, channels, 1, 1) + for i in range(n_squares): + out[square_mask == i] = z[channel_mask == i] + + return out, log_det + + +class MultiscaleBijection(BijectiveComposition): + def __init__(self, + input_event_shape, + transformer_class: Type[TensorTransformer], + n_checkerboard_layers: int = 3, + n_channel_wise_layers: int = 3, + **kwargs): + checkerboard_layers = [ + CheckerboardCoupling(input_event_shape, transformer_class, alternate=i % 2 == 1) + for i in range(n_checkerboard_layers) + ] + squeeze_layer = Squeeze(input_event_shape) + channel_wise_layers = [ + ChannelWiseCoupling(squeeze_layer.transformed_event_shape, transformer_class, alternate=i % 2 == 1) + for i in range(n_channel_wise_layers) + ] + layers = [*checkerboard_layers, squeeze_layer, *channel_wise_layers] + super().__init__(input_event_shape, layers, **kwargs) + self.transformed_shape = squeeze_layer.transformed_event_shape diff --git a/normalizing_flows/bijections/finite/multiscale/coupling.py b/normalizing_flows/bijections/finite/multiscale/coupling.py new file mode 100644 index 
0000000..c20a323 --- /dev/null +++ b/normalizing_flows/bijections/finite/multiscale/coupling.py @@ -0,0 +1,68 @@ +import torch + +from normalizing_flows.bijections.finite.autoregressive.conditioning.coupling_masks import Coupling + + +class Checkerboard(Coupling): + """ + Checkerboard coupling for image data. + """ + + def __init__(self, event_shape, resolution: int = 2, invert: bool = False): + """ + :param event_shape: image shape with the form (n_channels, height, width). Note: width and height must be equal + and a power of two. + :param resolution: resolution of the checkerboard along one axis - the number of squares. Must be even + and evenly divide the image width. + :param invert: invert the checkerboard mask. + """ + channels, height, width = event_shape[-3:] + assert width % resolution == 0 + square_side_length = width // resolution + assert resolution % 2 == 0 + half_resolution = resolution // 2 + a = torch.tensor([[1, 0] * half_resolution, [0, 1] * half_resolution] * half_resolution) + mask = torch.kron(a, torch.ones((square_side_length, square_side_length))) + mask = mask.bool() + mask = mask[None].repeat(channels, 1, 1) + if invert: + mask = ~mask + super().__init__(event_shape, mask) + + +class ChannelWiseHalfSplit(Coupling): + """ + Channel-wise coupling for image data. + """ + + def __init__(self, event_shape, invert: bool = False): + """ + :param event_shape: image shape with the form (n_channels, height, width). Note: width and height must be equal + and a power of two. + :param invert: invert the channel-wise mask.
+ """ + n_channels, height, width = event_shape + mask = torch.as_tensor(torch.arange(start=0, end=n_channels) < (n_channels // 2)) + mask = mask[:, None, None].repeat(1, height, width) + if invert: + mask = ~mask + super().__init__(event_shape, mask) + + +def make_image_coupling(event_shape, coupling_type: str, **kwargs): + """ + + :param event_shape: + :param coupling_type: one of ['checkerboard', 'checkerboard_inverted', 'channel_wise', 'channel_wise_inverted']. + :return: + """ + if coupling_type == 'checkerboard': + return Checkerboard(event_shape, invert=False, **kwargs) + elif coupling_type == 'checkerboard_inverted': + return Checkerboard(event_shape, invert=True, **kwargs) + elif coupling_type == 'channel_wise': + return ChannelWiseHalfSplit(event_shape, invert=False) + elif coupling_type == 'channel_wise_inverted': + return ChannelWiseHalfSplit(event_shape, invert=True) + else: + raise ValueError diff --git a/normalizing_flows/bijections/finite/multiscale/layers.py b/normalizing_flows/bijections/finite/multiscale/layers.py new file mode 100644 index 0000000..7a95207 --- /dev/null +++ b/normalizing_flows/bijections/finite/multiscale/layers.py @@ -0,0 +1,7 @@ +from normalizing_flows.bijections.finite.multiscale.base import MultiscaleBijection +from normalizing_flows.bijections.finite.autoregressive.transformers.linear.affine import Scale, Affine, Shift + + +class MultiscaleAffineCoupling(MultiscaleBijection): + def __init__(self, input_event_shape, **kwargs): + super().__init__(input_event_shape, transformer_class=Affine, **kwargs) diff --git a/normalizing_flows/flows.py b/normalizing_flows/flows.py index 465047c..602801b 100644 --- a/normalizing_flows/flows.py +++ b/normalizing_flows/flows.py @@ -275,18 +275,22 @@ def sample(self, n: int, context: torch.Tensor = None, no_grad: bool = False, re :param no_grad: if True, do not track gradients in the inverse pass. 
:return: samples with shape (n, *event_shape) if no context given or (n, *c, *event_shape) if context given. """ + if context is not None: - z = self.base_sample(sample_shape=torch.Size((n, len(context)))) + sample_shape = torch.Size((n, len(context))) + z = self.base_sample(sample_shape=sample_shape) context = context[None].repeat(*[n, *([1] * len(context.shape))]) # Make context shape match z shape assert z.shape[:2] == context.shape[:2] else: - z = self.base_sample(sample_shape=torch.Size((n,))) + sample_shape = torch.Size((n,)) + z = self.base_sample(sample_shape=sample_shape) + if no_grad: z = z.detach() with torch.no_grad(): - x, log_det = self.bijection.inverse(z, context=context) + x, log_det = self.bijection.inverse(z.view(*sample_shape, *self.bijection.transformed_shape), context=context) else: - x, log_det = self.bijection.inverse(z, context=context) + x, log_det = self.bijection.inverse(z.view(*sample_shape, *self.bijection.transformed_shape), context=context) x = x.to(self.get_device()) if return_log_prob: diff --git a/test/test_checkerboard_coupling.py b/test/test_checkerboard_coupling.py index 1f55fdf..2f120c0 100644 --- a/test/test_checkerboard_coupling.py +++ b/test/test_checkerboard_coupling.py @@ -1,6 +1,8 @@ -from normalizing_flows.bijections.finite.autoregressive.conditioning.coupling_masks import Checkerboard + import torch +from normalizing_flows.bijections.finite.multiscale.coupling import Checkerboard + def test_checkerboard_small(): torch.manual_seed(0) diff --git a/test/test_squeeze_bijection.py b/test/test_squeeze_bijection.py index b5db0b5..6952f34 100644 --- a/test/test_squeeze_bijection.py +++ b/test/test_squeeze_bijection.py @@ -1,6 +1,7 @@ import torch import pytest -from normalizing_flows.bijections.finite.linear import Squeeze + +from normalizing_flows.bijections.finite.multiscale.base import Squeeze @pytest.mark.parametrize('batch_shape', [(1,), (2,), (2, 3)])