diff --git a/normalizing_flows/bijections/finite/multiscale/architectures.py b/normalizing_flows/bijections/finite/multiscale/architectures.py
index 8f5a49e..6bfabb6 100644
--- a/normalizing_flows/bijections/finite/multiscale/architectures.py
+++ b/normalizing_flows/bijections/finite/multiscale/architectures.py
@@ -1,14 +1,76 @@
+import torch
+
 from normalizing_flows.bijections.finite.autoregressive.layers import ElementwiseAffine
 from normalizing_flows.bijections.finite.autoregressive.transformers.linear.affine import Affine, Shift
 from normalizing_flows.bijections.finite.autoregressive.transformers.spline.rational_quadratic import RationalQuadratic
 from normalizing_flows.bijections.finite.autoregressive.transformers.spline.linear import Linear as LinearRational
 from normalizing_flows.bijections import BijectiveComposition
-from normalizing_flows.bijections.finite.multiscale.base import MultiscaleBijection
+from normalizing_flows.bijections.finite.multiscale.base import MultiscaleBijection, FactoredBijection
+import math
+
+
+def make_factored_image_layers(event_shape,
+                               transformer_class,
+                               n_layers: int = 2):
+    """
+    Creates a list of image transformations consisting of coupling and squeeze layers.
+    After each (coupling, squeeze, coupling) block, half of the channels are factored out and kept as-is (not transformed further).
+
+    :param event_shape: image shape (c, 2^n, 2^m).
+    :param transformer_class: transformer class used within the coupling layers.
+    :param n_layers: number of multiscale levels.
+    :return: list of bijections.
+    """
+    if len(event_shape) != 3:
+        raise ValueError("Multichannel image transformations are only possible for inputs with three axes.")
+    if bin(event_shape[1]).count("1") != 1:
+        raise ValueError("Image height must be a power of two.")
+    if bin(event_shape[2]).count("1") != 1:
+        raise ValueError("Image width must be a power of two.")
+    if n_layers < 1:
+        raise ValueError("n_layers must be at least 1.")
+
+    log_height = math.log2(event_shape[1])
+    log_width = math.log2(event_shape[2])
+    if n_layers > min(log_height, log_width):
+        raise ValueError("Too many layers for the given image size.")
+
+    def recursive_layer_builder(event_shape_, n_layers_):
+        msb = MultiscaleBijection(
+            input_event_shape=event_shape_,
+            transformer_class=transformer_class
+        )
+        if n_layers_ == 1:
+            return msb
+
+        c, h, w = msb.transformed_shape  # c is a multiple of 4 after squeezing
-def make_image_layers(event_shape,
-                      transformer_class,
-                      n_layers: int = 2):
+        small_bijection_shape = (c // 2, h, w)
+        small_bijection_mask = (torch.arange(c) >= c // 2)[:, None, None].repeat(1, h, w)
+        fb = FactoredBijection(
+            event_shape=(c, h, w),
+            small_bijection=recursive_layer_builder(
+                event_shape_=small_bijection_shape,
+                n_layers_=n_layers_ - 1
+            ),
+            small_bijection_mask=small_bijection_mask
+        )
+        composition = BijectiveComposition(
+            event_shape=msb.event_shape,
+            layers=[msb, fb]
+        )
+        composition.transformed_shape = fb.transformed_shape
+        return composition
+
+    bijections = [ElementwiseAffine(event_shape=event_shape)]
+    bijections.append(recursive_layer_builder(bijections[-1].transformed_shape, n_layers))
+    bijections.append(ElementwiseAffine(event_shape=bijections[-1].transformed_shape))
+    return bijections
+
+
+def make_image_layers_non_factored(event_shape,
+                                   transformer_class,
+                                   n_layers: int = 2):
     """
     Returns a list of bijections for transformations of images with multiple channels.
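A minimal usage sketch of the new `factored` option (assuming the architectures are exported from `normalizing_flows.architectures`, as the new tests below do). Spatial dimensions must be powers of two, and `n_layers` may not exceed log2 of the smaller spatial side:

```python
import torch
from normalizing_flows.architectures import MultiscaleRealNVP

# (3, 32, 32) satisfies the power-of-two checks in make_factored_image_layers.
x = torch.randn(5, 3, 32, 32)
flow = MultiscaleRealNVP(event_shape=(3, 32, 32), n_layers=2, factored=True)

z, log_det_forward = flow.forward(x)                # latent has shape flow.transformed_shape
x_reconstructed, log_det_inverse = flow.inverse(z)  # round trip recovers the input

assert torch.allclose(x, x_reconstructed, atol=1e-4)
assert torch.allclose(log_det_forward, -log_det_inverse, atol=1e-2)
```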
""" @@ -17,6 +79,8 @@ def make_image_layers(event_shape, assert n_layers >= 1 + # TODO check that image shape is big enough for this number of layers (divisibility by 2) + bijections = [ElementwiseAffine(event_shape=event_shape)] for _ in range(n_layers - 1): bijections.append( @@ -38,14 +102,22 @@ def make_image_layers(event_shape, return bijections +def make_image_layers(*args, factored: bool = False, **kwargs): + if factored: + return make_factored_image_layers(*args, **kwargs) + else: + return make_image_layers_non_factored(*args, **kwargs) + + class MultiscaleRealNVP(BijectiveComposition): def __init__(self, event_shape, n_layers: int = 3, + factored: bool = False, **kwargs): if isinstance(event_shape, int): event_shape = (event_shape,) - bijections = make_image_layers(event_shape, Affine, n_layers) + bijections = make_image_layers(event_shape, Affine, n_layers, factored=factored) super().__init__(event_shape, bijections, **kwargs) self.transformed_shape = bijections[-1].transformed_shape @@ -54,10 +126,11 @@ class MultiscaleNICE(BijectiveComposition): def __init__(self, event_shape, n_layers: int = 3, + factored: bool = False, **kwargs): if isinstance(event_shape, int): event_shape = (event_shape,) - bijections = make_image_layers(event_shape, Shift, n_layers) + bijections = make_image_layers(event_shape, Shift, n_layers, factored=factored) super().__init__(event_shape, bijections, **kwargs) self.transformed_shape = bijections[-1].transformed_shape @@ -66,10 +139,11 @@ class MultiscaleRQNSF(BijectiveComposition): def __init__(self, event_shape, n_layers: int = 3, + factored: bool = False, **kwargs): if isinstance(event_shape, int): event_shape = (event_shape,) - bijections = make_image_layers(event_shape, RationalQuadratic, n_layers) + bijections = make_image_layers(event_shape, RationalQuadratic, n_layers, factored=factored) super().__init__(event_shape, bijections, **kwargs) self.transformed_shape = bijections[-1].transformed_shape @@ -78,9 +152,10 @@ class MultiscaleLRSNSF(BijectiveComposition): def __init__(self, event_shape, n_layers: int = 3, + factored: bool = False, **kwargs): if isinstance(event_shape, int): event_shape = (event_shape,) - bijections = make_image_layers(event_shape, LinearRational, n_layers) + bijections = make_image_layers(event_shape, LinearRational, n_layers, factored=factored) super().__init__(event_shape, bijections, **kwargs) self.transformed_shape = bijections[-1].transformed_shape diff --git a/normalizing_flows/bijections/finite/multiscale/base.py b/normalizing_flows/bijections/finite/multiscale/base.py index a3a8d21..4298710 100644 --- a/normalizing_flows/bijections/finite/multiscale/base.py +++ b/normalizing_flows/bijections/finite/multiscale/base.py @@ -23,17 +23,14 @@ class FactoredBijection(Bijection): def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], - transformed_event_shape: Union[torch.Size, Tuple[int, ...]], small_bijection: Bijection, - transformed_event_mask: torch.Tensor, + small_bijection_mask: torch.Tensor, **kwargs): """ :param event_shape: shape of input event x. - :param transformed_event_shape: shape of transformed event x_A. - :param constant_event_shape: shape of constant event x_B. :param small_bijection: bijection applied to transformed event x_A. - :param transformed_event_mask: boolean mask that selects which elements of event x correspond to the transformed + :param small_bijection_mask: boolean mask that selects which elements of event x correspond to the transformed event x_A. 
        :param kwargs:
        """
@@ -41,20 +38,18 @@ def __init__(self,
 
         # Check that shapes are correct
         event_size = torch.prod(torch.as_tensor(event_shape))
-        transformed_event_size = torch.prod(torch.as_tensor(transformed_event_shape))
+        transformed_event_size = torch.prod(torch.as_tensor(small_bijection.event_shape))
         assert event_size >= transformed_event_size
-        assert transformed_event_mask.shape == event_shape
-        assert small_bijection.event_shape == transformed_event_shape
+        assert small_bijection_mask.shape == event_shape
 
-        self.transformed_event_mask = transformed_event_mask
-        self.transformed_event_shape = transformed_event_shape
+        self.transformed_event_mask = small_bijection_mask
         self.small_bijection = small_bijection
 
     def forward(self, x: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]:
         batch_shape = get_batch_shape(x, self.event_shape)
         transformed, log_det = self.small_bijection.forward(
-            x[..., self.transformed_event_mask].view(*batch_shape, *self.transformed_event_shape),
+            x[..., self.transformed_event_mask].view(*batch_shape, *self.small_bijection.event_shape),
             context
         )
         out = x.clone()
@@ -64,7 +59,7 @@ def forward(self, x: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.
     def inverse(self, z: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]:
         batch_shape = get_batch_shape(z, self.event_shape)
         transformed, log_det = self.small_bijection.inverse(
-            z[..., self.transformed_event_mask].view(*batch_shape, *self.transformed_event_shape),
+            z[..., self.transformed_event_mask].view(*batch_shape, *self.small_bijection.transformed_shape),
             context
         )
         out = z.clone()
diff --git a/test/test_factored_bijection.py b/test/test_factored_bijection.py
index 6df7726..4c43cf7 100644
--- a/test/test_factored_bijection.py
+++ b/test/test_factored_bijection.py
@@ -8,8 +8,7 @@ def test_basic():
     bijection = FactoredBijection(
         event_shape=(6, 6),
-        transformed_event_shape=(3, 3),
-        transformed_event_mask=torch.tensor([
+        small_bijection_mask=torch.tensor([
             [True, True, True, False, False, False],
             [True, True, True, False, False, False],
             [True, True, True, False, False, False],
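Indexing semantics that `FactoredBijection` and its tests rely on: a boolean mask over the event axes selects a flat slice that can be viewed back into the small bijection's event shape. A self-contained illustration in plain torch:

```python
import torch

c, h, w = 4, 2, 2
x = torch.randn(5, c, h, w)

# Channel mask as in make_factored_image_layers: second half of channels.
mask = (torch.arange(c) >= c // 2)[:, None, None].repeat(1, h, w)

# Boolean indexing over the trailing event axes flattens the selection;
# .view restores the sub-event shape, exactly as FactoredBijection.forward does.
x_a = x[..., mask].view(5, c // 2, h, w)
assert torch.allclose(x_a, x[:, c // 2:])
```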
diff --git a/test/test_multiscale_bijections.py b/test/test_multiscale_bijections.py
new file mode 100644
index 0000000..fa1b181
--- /dev/null
+++ b/test/test_multiscale_bijections.py
@@ -0,0 +1,53 @@
+from normalizing_flows.architectures import MultiscaleNICE, MultiscaleRQNSF, MultiscaleLRSNSF, MultiscaleRealNVP
+import torch
+import pytest
+
+
+@pytest.mark.parametrize('architecture_class', [MultiscaleRQNSF, MultiscaleLRSNSF, MultiscaleNICE, MultiscaleRealNVP])
+@pytest.mark.parametrize('image_shape', [(1, 28, 28), (3, 28, 28)])
+def test_non_factored(architecture_class, image_shape):
+    torch.manual_seed(0)
+    x = torch.randn(size=(5, *image_shape))
+    bijection = architecture_class(image_shape, n_layers=2, factored=False)
+    z, ldf = bijection.forward(x)
+    xr, ldi = bijection.inverse(z)
+    assert torch.allclose(x, xr, atol=1e-4)
+    assert torch.allclose(ldf, -ldi, atol=1e-2)
+
+
+@pytest.mark.parametrize('architecture_class', [MultiscaleRQNSF, MultiscaleLRSNSF, MultiscaleNICE, MultiscaleRealNVP])
+@pytest.mark.parametrize('image_shape', [(1, 28, 28), (3, 28, 28)])
+def test_non_factored_too_small_image(architecture_class, image_shape):
+    torch.manual_seed(0)
+    with pytest.raises(ValueError):
+        bijection = architecture_class(image_shape, n_layers=3, factored=False)
+
+
+@pytest.mark.parametrize('architecture_class', [MultiscaleRQNSF, MultiscaleLRSNSF, MultiscaleNICE, MultiscaleRealNVP])
+@pytest.mark.parametrize('image_shape', [(1, 32, 32), (3, 32, 32)])
+def test_factored(architecture_class, image_shape):
+    torch.manual_seed(0)
+    x = torch.randn(size=(5, *image_shape))
+    bijection = architecture_class(image_shape, n_layers=2, factored=True)
+    z, ldf = bijection.forward(x)
+    xr, ldi = bijection.inverse(z)
+    assert torch.allclose(x, xr, atol=1e-4)
+    assert torch.allclose(ldf, -ldi, atol=1e-2)
+
+
+@pytest.mark.parametrize('architecture_class', [MultiscaleRQNSF, MultiscaleLRSNSF, MultiscaleNICE, MultiscaleRealNVP])
+@pytest.mark.parametrize('image_shape', [(1, 15, 32), (3, 15, 32)])
+def test_factored_wrong_shape(architecture_class, image_shape):
+    torch.manual_seed(0)
+    x = torch.randn(size=(5, *image_shape))
+    with pytest.raises(ValueError):
+        bijection = architecture_class(image_shape, n_layers=2, factored=True)
+
+
+@pytest.mark.parametrize('architecture_class', [MultiscaleRQNSF, MultiscaleLRSNSF, MultiscaleNICE, MultiscaleRealNVP])
+@pytest.mark.parametrize('image_shape', [(1, 8, 8), (3, 8, 8)])
+def test_factored_too_small_image(architecture_class, image_shape):
+    torch.manual_seed(0)
+    x = torch.randn(size=(5, *image_shape))
+    with pytest.raises(ValueError):
+        bijection = architecture_class(image_shape, n_layers=8, factored=True)
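A note on the layer-count bound exercised by `test_factored_too_small_image`: each squeeze halves the spatial dimensions, so the factored recursion can run at most log2(min(height, width)) times. A hypothetical helper (not part of this diff) making the bound explicit:

```python
import math

def max_factored_layers(height: int, width: int) -> int:
    # Each squeeze halves height and width, bounding the recursion depth
    # of make_factored_image_layers.
    return int(min(math.log2(height), math.log2(width)))

print(max_factored_layers(32, 32))  # 5 -> n_layers=2 in test_factored is valid
print(max_factored_layers(8, 8))    # 3 -> n_layers=8 must raise ValueError
```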