diff --git a/normalizing_flows/bijections/finite/autoregressive/conditioning/coupling_masks.py b/normalizing_flows/bijections/finite/autoregressive/conditioning/coupling_masks.py index 603a3de..1a163c4 100644 --- a/normalizing_flows/bijections/finite/autoregressive/conditioning/coupling_masks.py +++ b/normalizing_flows/bijections/finite/autoregressive/conditioning/coupling_masks.py @@ -80,13 +80,13 @@ class Checkerboard(Coupling): def __init__(self, event_shape, resolution: int = 2, invert: bool = False): """ - :param event_shape: image shape with the form (n_channels, width, height). Note: width and height must be equal + :param event_shape: image shape with the form (n_channels, height, width). Note: width and height must be equal and a power of two. :param resolution: resolution of the checkerboard along one axis - the number of squares. Must be a power of two and smaller than image width. :param invert: invert the checkerboard mask. """ - n_channels, width, _ = event_shape + n_channels, height, width = event_shape assert width % resolution == 0 square_side_length = width // resolution assert resolution % 2 == 0 @@ -96,6 +96,8 @@ def __init__(self, event_shape, resolution: int = 2, invert: bool = False): mask = mask.bool() if invert: mask = ~mask + self.source_shape = (n_channels, height // resolution, width // resolution) + self.target_shape = (n_channels, height // resolution, width // resolution) super().__init__(event_shape, mask)