From 5f81473839842f97ab99cfe2394ad5c8f6b1bcda Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Thu, 6 Jun 2024 16:22:09 +0200 Subject: [PATCH] Add source and target shape to checkerboard mask --- .../finite/autoregressive/conditioning/coupling_masks.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/normalizing_flows/bijections/finite/autoregressive/conditioning/coupling_masks.py b/normalizing_flows/bijections/finite/autoregressive/conditioning/coupling_masks.py index 603a3de..1a163c4 100644 --- a/normalizing_flows/bijections/finite/autoregressive/conditioning/coupling_masks.py +++ b/normalizing_flows/bijections/finite/autoregressive/conditioning/coupling_masks.py @@ -80,13 +80,13 @@ class Checkerboard(Coupling): def __init__(self, event_shape, resolution: int = 2, invert: bool = False): """ - :param event_shape: image shape with the form (n_channels, width, height). Note: width and height must be equal + :param event_shape: image shape with the form (n_channels, height, width). Note: width and height must be equal and a power of two. :param resolution: resolution of the checkerboard along one axis - the number of squares. Must be a power of two and smaller than image width. :param invert: invert the checkerboard mask. """ - n_channels, width, _ = event_shape + n_channels, height, width = event_shape assert width % resolution == 0 square_side_length = width // resolution assert resolution % 2 == 0 @@ -96,6 +96,8 @@ def __init__(self, event_shape, resolution: int = 2, invert: bool = False): mask = mask.bool() if invert: mask = ~mask + self.source_shape = (n_channels, height // resolution, width // resolution) + self.target_shape = (n_channels, height // resolution, width // resolution) super().__init__(event_shape, mask)