diff --git a/normalizing_flows/bijections/base.py b/normalizing_flows/bijections/base.py index 1f53a97..64eb863 100644 --- a/normalizing_flows/bijections/base.py +++ b/normalizing_flows/bijections/base.py @@ -91,7 +91,8 @@ def invert(bijection: Bijection) -> Bijection: class BijectiveComposition(Bijection): def __init__(self, - event_shape: torch.Size, layers: List[Bijection], + event_shape: Union[torch.Size, Tuple[int, ...]], + layers: List[Bijection], context_shape: Union[torch.Size, Tuple[int, ...]] = None): super().__init__(event_shape=event_shape, context_shape=context_shape) self.layers = nn.ModuleList(layers)