diff --git a/normalizing_flows/bijections/finite/multiscale/base.py b/normalizing_flows/bijections/finite/multiscale/base.py index f073c62..fad40f7 100644 --- a/normalizing_flows/bijections/finite/multiscale/base.py +++ b/normalizing_flows/bijections/finite/multiscale/base.py @@ -75,33 +75,6 @@ def forward(self, x: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch. batch_shape = get_batch_shape(x, self.event_shape) log_det = torch.zeros(*batch_shape, device=x.device, dtype=x.dtype) - channels, height, width = x.shape[-3:] - assert height % 2 == 0 - assert width % 2 == 0 - n_rows = height // 2 - n_cols = width // 2 - n_squares = n_rows * n_cols - - square_mask = torch.kron( - torch.arange(n_squares).view(n_rows, n_cols), - torch.ones(2, 2) - ) - channel_mask = torch.arange(n_rows * n_cols).view(n_rows, n_cols)[None].repeat(4 * channels, 1, 1) - - # out = torch.zeros(size=(*batch_shape, self.squeezed_event_shape), device=x.device, dtype=x.dtype) - out = torch.empty(size=(*batch_shape, 4 * channels, height // 2, width // 2), device=x.device, dtype=x.dtype) - - channel_mask = channel_mask.repeat(*batch_shape, 1, 1, 1) - square_mask = square_mask.repeat(*batch_shape, channels, 1, 1) - for i in range(n_squares): - out[channel_mask == i] = x[square_mask == i] - - return out, log_det - - def forward2(self, x, context=None): - batch_shape = get_batch_shape(x, self.event_shape) - log_det = torch.zeros(*batch_shape, device=x.device, dtype=x.dtype) - channels, height, width = x.shape[-3:] assert height % 2 == 0 assert width % 2 == 0 @@ -111,7 +84,7 @@ def forward2(self, x, context=None): x[..., ::2, 1::2], x[..., 1::2, ::2], x[..., 1::2, 1::2] - ], dim=1) + ], dim=-3) return out, log_det def inverse(self, z: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]: @@ -128,39 +101,11 @@ def inverse(self, z: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch. height = 2 * half_height channels = four_channels // 4 - n_rows = height // 2 - n_cols = width // 2 - n_squares = n_rows * n_cols - - square_mask = torch.kron( - torch.arange(n_squares).view(n_rows, n_cols), - torch.ones(2, 2) - ) - channel_mask = torch.arange(n_rows * n_cols).view(n_rows, n_cols)[None].repeat(4 * channels, 1, 1) - out = torch.empty(size=(*batch_shape, channels, height, width), device=z.device, dtype=z.dtype) - - channel_mask = channel_mask.repeat(*batch_shape, 1, 1, 1) - square_mask = square_mask.repeat(*batch_shape, channels, 1, 1) - for i in range(n_squares): - out[square_mask == i] = z[channel_mask == i] - - return out, log_det - - def inverse2(self, z, context=None): - batch_shape = get_batch_shape(z, self.transformed_event_shape) - log_det = torch.zeros(*batch_shape, device=z.device, dtype=z.dtype) - - four_channels, half_height, half_width = z.shape[-3:] - assert four_channels % 4 == 0 - width = 2 * half_width - height = 2 * half_height - channels = four_channels // 4 - out = torch.empty(size=(*batch_shape, channels, height, width), device=z.device, dtype=z.dtype) - out[..., ::2, ::2] = z[..., 0, :, :] - out[..., ::2, 1::2] = z[..., 1, :, :] - out[..., 1::2, ::2] = z[..., 2, :, :] - out[..., 1::2, 1::2] = z[..., 3, :, :] + out[..., ::2, ::2] = z[..., 0:channels, :, :] + out[..., ::2, 1::2] = z[..., channels:2 * channels, :, :] + out[..., 1::2, ::2] = z[..., 2 * channels:3 * channels, :, :] + out[..., 1::2, 1::2] = z[..., 3 * channels:4 * channels, :, :] return out, log_det