diff --git a/normalizing_flows/bijections/finite/multiscale/architectures.py b/normalizing_flows/bijections/finite/multiscale/architectures.py index 5ecff68..21f7d8f 100644 --- a/normalizing_flows/bijections/finite/multiscale/architectures.py +++ b/normalizing_flows/bijections/finite/multiscale/architectures.py @@ -13,14 +13,25 @@ def make_image_layers(event_shape, if len(event_shape) != 3: raise ValueError("Multichannel image transformation are only possible for inputs with three axes.") + assert n_layers >= 1 + bijections = [ElementwiseAffine(event_shape=event_shape)] - for _ in range(n_layers): + for _ in range(n_layers - 1): bijections.append( MultiscaleBijection( input_event_shape=bijections[-1].transformed_shape, transformer_class=transformer_class ) ) + bijections.append( + MultiscaleBijection( + input_event_shape=bijections[-1].transformed_shape, + transformer_class=transformer_class, + n_checkerboard_layers=4, + squeeze_layer=False, + n_channel_wise_layers=0 + ) + ) bijections.append(ElementwiseAffine(event_shape=bijections[-1].transformed_shape)) return bijections @@ -35,4 +46,3 @@ def __init__(self, bijections = make_image_layers(event_shape, Affine, n_layers) super().__init__(event_shape, bijections, **kwargs) self.transformed_shape = bijections[-1].transformed_shape -