Skip to content

Commit

Permalink
Add alternative inverse
Browse files Browse the repository at this point in the history
  • Loading branch information
davidnabergoj committed Jun 23, 2024
1 parent df85d58 commit eab2fec
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 0 deletions.
17 changes: 17 additions & 0 deletions normalizing_flows/bijections/finite/multiscale/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,23 @@ def inverse(self, z: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.

return out, log_det

def inverse2(self, z, context=None):
batch_shape = get_batch_shape(z, self.transformed_event_shape)
log_det = torch.zeros(*batch_shape, device=z.device, dtype=z.dtype)

four_channels, half_height, half_width = z.shape[-3:]
assert four_channels % 4 == 0
width = 2 * half_width
height = 2 * half_height
channels = four_channels // 4

out = torch.empty(size=(*batch_shape, channels, height, width), device=z.device, dtype=z.dtype)
out[..., ::2, ::2] = z[..., 0, :, :]
out[..., ::2, 1::2] = z[..., 1, :, :]
out[..., 1::2, ::2] = z[..., 2, :, :]
out[..., 1::2, 1::2] = z[..., 3, :, :]
return out, log_det


class MultiscaleBijection(BijectiveComposition):
def __init__(self,
Expand Down
15 changes: 15 additions & 0 deletions test/test_squeeze_bijection.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,18 @@ def test_efficient_forward():
z, log_det_forward = layer.forward(x)
z2, log_det_forward2 = layer.forward2(x)
assert torch.allclose(z, z2)


def test_efficient_inverse():
x = torch.tensor([
[1, 2, 5, 6],
[3, 4, 7, 8],
[9, 10, 13, 14],
[11, 12, 15, 16]
])[None, None]
layer = Squeeze(event_shape=x.shape[-3:])
z, log_det_forward = layer.forward(x)

xr, _ = layer.inverse2(z)

assert torch.allclose(x, xr)

0 comments on commit eab2fec

Please sign in to comment.