Skip to content

Commit

Permalink
Fix inverse and forward in squeeze
Browse files Browse the repository at this point in the history
  • Loading branch information
davidnabergoj committed Jun 23, 2024
1 parent eab2fec commit 39586c8
Showing 1 changed file with 5 additions and 60 deletions.
65 changes: 5 additions & 60 deletions normalizing_flows/bijections/finite/multiscale/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,33 +75,6 @@ def forward(self, x: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.
batch_shape = get_batch_shape(x, self.event_shape)
log_det = torch.zeros(*batch_shape, device=x.device, dtype=x.dtype)

channels, height, width = x.shape[-3:]
assert height % 2 == 0
assert width % 2 == 0
n_rows = height // 2
n_cols = width // 2
n_squares = n_rows * n_cols

square_mask = torch.kron(
torch.arange(n_squares).view(n_rows, n_cols),
torch.ones(2, 2)
)
channel_mask = torch.arange(n_rows * n_cols).view(n_rows, n_cols)[None].repeat(4 * channels, 1, 1)

# out = torch.zeros(size=(*batch_shape, self.squeezed_event_shape), device=x.device, dtype=x.dtype)
out = torch.empty(size=(*batch_shape, 4 * channels, height // 2, width // 2), device=x.device, dtype=x.dtype)

channel_mask = channel_mask.repeat(*batch_shape, 1, 1, 1)
square_mask = square_mask.repeat(*batch_shape, channels, 1, 1)
for i in range(n_squares):
out[channel_mask == i] = x[square_mask == i]

return out, log_det

def forward2(self, x, context=None):
batch_shape = get_batch_shape(x, self.event_shape)
log_det = torch.zeros(*batch_shape, device=x.device, dtype=x.dtype)

channels, height, width = x.shape[-3:]
assert height % 2 == 0
assert width % 2 == 0
Expand All @@ -111,7 +84,7 @@ def forward2(self, x, context=None):
x[..., ::2, 1::2],
x[..., 1::2, ::2],
x[..., 1::2, 1::2]
], dim=1)
], dim=-3)
return out, log_det

def inverse(self, z: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]:
Expand All @@ -128,39 +101,11 @@ def inverse(self, z: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.
height = 2 * half_height
channels = four_channels // 4

n_rows = height // 2
n_cols = width // 2
n_squares = n_rows * n_cols

square_mask = torch.kron(
torch.arange(n_squares).view(n_rows, n_cols),
torch.ones(2, 2)
)
channel_mask = torch.arange(n_rows * n_cols).view(n_rows, n_cols)[None].repeat(4 * channels, 1, 1)
out = torch.empty(size=(*batch_shape, channels, height, width), device=z.device, dtype=z.dtype)

channel_mask = channel_mask.repeat(*batch_shape, 1, 1, 1)
square_mask = square_mask.repeat(*batch_shape, channels, 1, 1)
for i in range(n_squares):
out[square_mask == i] = z[channel_mask == i]

return out, log_det

def inverse2(self, z, context=None):
batch_shape = get_batch_shape(z, self.transformed_event_shape)
log_det = torch.zeros(*batch_shape, device=z.device, dtype=z.dtype)

four_channels, half_height, half_width = z.shape[-3:]
assert four_channels % 4 == 0
width = 2 * half_width
height = 2 * half_height
channels = four_channels // 4

out = torch.empty(size=(*batch_shape, channels, height, width), device=z.device, dtype=z.dtype)
out[..., ::2, ::2] = z[..., 0, :, :]
out[..., ::2, 1::2] = z[..., 1, :, :]
out[..., 1::2, ::2] = z[..., 2, :, :]
out[..., 1::2, 1::2] = z[..., 3, :, :]
out[..., ::2, ::2] = z[..., 0:channels, :, :]
out[..., ::2, 1::2] = z[..., channels:2 * channels, :, :]
out[..., 1::2, ::2] = z[..., 2 * channels:3 * channels, :, :]
out[..., 1::2, 1::2] = z[..., 3 * channels:4 * channels, :, :]
return out, log_det


Expand Down

0 comments on commit 39586c8

Please sign in to comment.