diff --git a/normalizing_flows/bijections/base.py b/normalizing_flows/bijections/base.py index a05fe5a..146bff2 100644 --- a/normalizing_flows/bijections/base.py +++ b/normalizing_flows/bijections/base.py @@ -11,7 +11,8 @@ class Bijection(nn.Module): def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], - context_shape: Union[torch.Size, Tuple[int, ...]] = None): + context_shape: Union[torch.Size, Tuple[int, ...]] = None, + **kwargs): """ Bijection class. """ @@ -93,7 +94,8 @@ class BijectiveComposition(Bijection): def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], layers: List[Bijection], - context_shape: Union[torch.Size, Tuple[int, ...]] = None): + context_shape: Union[torch.Size, Tuple[int, ...]] = None, + **kwargs): super().__init__(event_shape=event_shape, context_shape=context_shape) self.layers = nn.ModuleList(layers)