diff --git a/torchflows/bijections/finite/autoregressive/layers.py b/torchflows/bijections/finite/autoregressive/layers.py index 6e2133a..be33e97 100644 --- a/torchflows/bijections/finite/autoregressive/layers.py +++ b/torchflows/bijections/finite/autoregressive/layers.py @@ -21,10 +21,9 @@ def __init__(self, event_shape, **kwargs): """ :param Union[Tuple[int, ...], torch.Size] event_shape: shape of the event tensor. - :param kwargs: keyword arguments to Affine. + :param kwargs: keyword arguments to ElementwiseBijection. """ - transformer = Affine(event_shape, **kwargs) - super().__init__(transformer) + super().__init__(event_shape, Affine, **kwargs) class ElementwiseInverseAffine(ElementwiseBijection): @@ -32,10 +31,9 @@ def __init__(self, event_shape, **kwargs): """ :param Union[Tuple[int, ...], torch.Size] event_shape: shape of the event tensor. - :param kwargs: keyword arguments to InverseAffine. + :param kwargs: keyword arguments to ElementwiseBijection. """ - transformer = InverseAffine(event_shape, **kwargs) - super().__init__(transformer) + super().__init__(event_shape, InverseAffine, **kwargs) class ActNorm(ElementwiseInverseAffine): @@ -75,20 +73,19 @@ def __init__(self, event_shape, **kwargs): """ :param Union[Tuple[int, ...], torch.Size] event_shape: shape of the event tensor. - :param kwargs: keyword arguments to Scale. + :param kwargs: keyword arguments to ElementwiseBijection. """ - transformer = Scale(event_shape, **kwargs) - super().__init__(transformer) + super().__init__(event_shape, Scale, **kwargs) class ElementwiseShift(ElementwiseBijection): - def __init__(self, event_shape): + def __init__(self, event_shape, **kwargs): """ :param Union[Tuple[int, ...], torch.Size] event_shape: shape of the event tensor. + :param kwargs: keyword arguments to ElementwiseBijection. """ - transformer = Shift(event_shape) - super().__init__(transformer) + super().__init__(event_shape, Shift, **kwargs) class ElementwiseRQSpline(ElementwiseBijection): @@ -96,10 +93,9 @@ def __init__(self, event_shape, **kwargs): """ :param Union[Tuple[int, ...], torch.Size] event_shape: shape of the event tensor. - :param kwargs: keyword arguments to RationalQuadratic. + :param kwargs: keyword arguments to ElementwiseBijection. """ - transformer = RationalQuadratic(event_shape, **kwargs) - super().__init__(transformer) + super().__init__(event_shape, RationalQuadratic, **kwargs) class AffineCoupling(CouplingBijection): diff --git a/torchflows/bijections/finite/autoregressive/layers_base.py b/torchflows/bijections/finite/autoregressive/layers_base.py index 4e7cdc4..9a47270 100644 --- a/torchflows/bijections/finite/autoregressive/layers_base.py +++ b/torchflows/bijections/finite/autoregressive/layers_base.py @@ -1,4 +1,4 @@ -from typing import Tuple, Union, Type +from typing import Tuple, Union, Type, Optional import torch import torch.nn as nn @@ -15,7 +15,7 @@ class AutoregressiveBijection(Bijection): def __init__(self, event_shape, transformer: Union[TensorTransformer, ScalarTransformer], - conditioner_transform: ConditionerTransform, + conditioner_transform: Optional[ConditionerTransform], **kwargs): super().__init__(event_shape=event_shape) self.conditioner_transform = conditioner_transform @@ -205,11 +205,19 @@ class ElementwiseBijection(AutoregressiveBijection): The bijection for each element has its own set of globally learned parameters. """ - def __init__(self, transformer: ScalarTransformer, fill_value: float = None): + def __init__(self, + event_shape: Union[Tuple[int, ...], torch.Size], + transformer_class: Type[ScalarTransformer], + transformer_kwargs: dict = None, + fill_value: float = None, + **kwargs): + transformer_kwargs = transformer_kwargs or {} + transformer = transformer_class(event_shape=event_shape, **transformer_kwargs) super().__init__( - transformer.event_shape, + event_shape, transformer, - None + None, + **kwargs ) if fill_value is None: diff --git a/torchflows/bijections/finite/autoregressive/transformers/linear/affine.py b/torchflows/bijections/finite/autoregressive/transformers/linear/affine.py index 5c0e9a4..13ab46e 100644 --- a/torchflows/bijections/finite/autoregressive/transformers/linear/affine.py +++ b/torchflows/bijections/finite/autoregressive/transformers/linear/affine.py @@ -135,7 +135,7 @@ def inverse(self, z: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch class Shift(ScalarTransformer): - def __init__(self, event_shape: torch.Size): + def __init__(self, event_shape: torch.Size, **kwargs): super().__init__(event_shape=event_shape) @property