From ba478476af25b7746952dd5501eb5d70aef0ce81 Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Thu, 22 Feb 2024 11:08:56 +0100 Subject: [PATCH] Have bijective compositions accept kwargs --- normalizing_flows/bijections/base.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/normalizing_flows/bijections/base.py b/normalizing_flows/bijections/base.py index a05fe5a..146bff2 100644 --- a/normalizing_flows/bijections/base.py +++ b/normalizing_flows/bijections/base.py @@ -11,7 +11,8 @@ class Bijection(nn.Module): def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], - context_shape: Union[torch.Size, Tuple[int, ...]] = None): + context_shape: Union[torch.Size, Tuple[int, ...]] = None, + **kwargs): """ Bijection class. """ @@ -93,7 +94,8 @@ class BijectiveComposition(Bijection): def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], layers: List[Bijection], - context_shape: Union[torch.Size, Tuple[int, ...]] = None): + context_shape: Union[torch.Size, Tuple[int, ...]] = None, + **kwargs): super().__init__(event_shape=event_shape, context_shape=context_shape) self.layers = nn.ModuleList(layers)