From eab2fece7ea6fd6647ff581b82f0a9bf015629f4 Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Sun, 23 Jun 2024 19:21:14 +0200 Subject: [PATCH] Add alternative inverse --- .../bijections/finite/multiscale/base.py | 17 +++++++++++++++++ test/test_squeeze_bijection.py | 15 +++++++++++++++ 2 files changed, 32 insertions(+) diff --git a/normalizing_flows/bijections/finite/multiscale/base.py b/normalizing_flows/bijections/finite/multiscale/base.py index 2258c47..f073c62 100644 --- a/normalizing_flows/bijections/finite/multiscale/base.py +++ b/normalizing_flows/bijections/finite/multiscale/base.py @@ -146,6 +146,23 @@ def inverse(self, z: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch. return out, log_det + def inverse2(self, z, context=None): + batch_shape = get_batch_shape(z, self.transformed_event_shape) + log_det = torch.zeros(*batch_shape, device=z.device, dtype=z.dtype) + + four_channels, half_height, half_width = z.shape[-3:] + assert four_channels % 4 == 0 + width = 2 * half_width + height = 2 * half_height + channels = four_channels // 4 + + out = torch.empty(size=(*batch_shape, channels, height, width), device=z.device, dtype=z.dtype) + out[..., ::2, ::2] = z[..., 0, :, :] + out[..., ::2, 1::2] = z[..., 1, :, :] + out[..., 1::2, ::2] = z[..., 2, :, :] + out[..., 1::2, 1::2] = z[..., 3, :, :] + return out, log_det + class MultiscaleBijection(BijectiveComposition): def __init__(self, diff --git a/test/test_squeeze_bijection.py b/test/test_squeeze_bijection.py index a92eb87..86a3e39 100644 --- a/test/test_squeeze_bijection.py +++ b/test/test_squeeze_bijection.py @@ -32,3 +32,18 @@ def test_efficient_forward(): z, log_det_forward = layer.forward(x) z2, log_det_forward2 = layer.forward2(x) assert torch.allclose(z, z2) + + +def test_efficient_inverse(): + x = torch.tensor([ + [1, 2, 5, 6], + [3, 4, 7, 8], + [9, 10, 13, 14], + [11, 12, 15, 16] + ])[None, None] + layer = Squeeze(event_shape=x.shape[-3:]) + z, log_det_forward = layer.forward(x) + + xr, _ = layer.inverse2(z) + + assert torch.allclose(x, xr)