Skip to content

Commit

Permalink
Add faster squeeze method
Browse files Browse the repository at this point in the history
  • Loading branch information
davidnabergoj committed Jun 23, 2024
1 parent 416525e commit df85d58
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 0 deletions.
16 changes: 16 additions & 0 deletions normalizing_flows/bijections/finite/multiscale/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,22 @@ def forward(self, x: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.

return out, log_det

def forward2(self, x, context=None):
batch_shape = get_batch_shape(x, self.event_shape)
log_det = torch.zeros(*batch_shape, device=x.device, dtype=x.dtype)

channels, height, width = x.shape[-3:]
assert height % 2 == 0
assert width % 2 == 0

out = torch.concatenate([
x[..., ::2, ::2],
x[..., ::2, 1::2],
x[..., 1::2, ::2],
x[..., 1::2, 1::2]
], dim=1)
return out, log_det

def inverse(self, z: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Squeeze tensor with shape (*batch_shape, 4 * channels, height // 2, width // 2) into tensor with shape
Expand Down
13 changes: 13 additions & 0 deletions test/test_squeeze_bijection.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,16 @@ def test_reconstruction(batch_shape, channels, height, width):
assert torch.allclose(x_reconstructed, x)
assert torch.allclose(log_det_forward, torch.zeros_like(log_det_forward))
assert torch.allclose(log_det_forward, log_det_inverse)


def test_efficient_forward():
x = torch.tensor([
[1, 2, 5, 6],
[3, 4, 7, 8],
[9, 10, 13, 14],
[11, 12, 15, 16]
])[None, None]
layer = Squeeze(event_shape=x.shape[-3:])
z, log_det_forward = layer.forward(x)
z2, log_det_forward2 = layer.forward2(x)
assert torch.allclose(z, z2)

0 comments on commit df85d58

Please sign in to comment.