diff --git a/torchflows/bijections/finite/multiscale/base.py b/torchflows/bijections/finite/multiscale/base.py index 6d70259..b88a9d1 100644 --- a/torchflows/bijections/finite/multiscale/base.py +++ b/torchflows/bijections/finite/multiscale/base.py @@ -231,14 +231,20 @@ def __init__(self, Type[NormalizedChannelWiseCoupling], Type[GlowChannelWiseCoupling] ] = NormalizedChannelWiseCoupling, + first_layer: bool = True, **kwargs): if n_blocks < 1: raise ValueError super().__init__(event_shape, **kwargs) self.n_blocks = n_blocks + + if first_layer and checkerboard_class == GlowCheckerboardCoupling: + layer_checkerboard_class = NormalizedCheckerboardCoupling # Compatibility with single channel images + else: + layer_checkerboard_class = checkerboard_class self.checkerboard_layers = nn.ModuleList([ - checkerboard_class( + layer_checkerboard_class( event_shape, transformer_class=transformer_class, alternate=i % 2 == 1, @@ -269,6 +275,7 @@ def __init__(self, event_shape=small_event_shape, transformer_class=transformer_class, n_blocks=self.n_blocks - 1, + first_layer=False, **kwargs )