diff --git a/torchflows/bijections/finite/autoregressive/architectures.py b/torchflows/bijections/finite/autoregressive/architectures.py index 8975685..97cb0da 100644 --- a/torchflows/bijections/finite/autoregressive/architectures.py +++ b/torchflows/bijections/finite/autoregressive/architectures.py @@ -78,7 +78,7 @@ def __init__(self, event_shape: Union[Tuple[int, ...], torch.Size, int], **kwarg """ :param event_shape: shape of the event tensor. - :param kwargs: keyword arguments to AffineCoupling. + :param kwargs: keyword arguments to :class:`~bijections.finite.autoregressive.layers.AffineCoupling`. """ super().__init__(event_shape, base_bijection=AffineCoupling, **kwargs) diff --git a/torchflows/bijections/finite/autoregressive/layers_base.py b/torchflows/bijections/finite/autoregressive/layers_base.py index 9a47270..0c4afcd 100644 --- a/torchflows/bijections/finite/autoregressive/layers_base.py +++ b/torchflows/bijections/finite/autoregressive/layers_base.py @@ -39,18 +39,15 @@ class CouplingBijection(AutoregressiveBijection): """ Base coupling bijection object. - A coupling bijection is defined using a transformer, conditioner transform, and always a coupling conditioner. + A coupling bijection is defined using a transformer, conditioner transform, and always a coupling conditioner (specifying how to partition the input tensor). - The coupling conditioner receives as input an event tensor x. - It then partitions an input event tensor x into a constant part x_A and a modifiable part x_B. - For x_A, the conditioner outputs a set of parameters which is always the same. - For x_B, the conditioner outputs a set of parameters which are predicted from x_A. + The coupling conditioner receives as input an event tensor :math:`x`. + It then partitions the input event tensor :math:`x` into a constant part :math:`x_A` and a modifiable part :math:`x_B`. + For :math:`x_A`, the conditioner outputs a set of parameters which is always the same. 
+ For :math:`x_B`, the conditioner outputs a set of parameters which are predicted from :math:`x_A`. + Coupling conditioners differ in the partitioning method. By default, the event is flattened; the first half is :math:`x_A` and the second half is :math:`x_B`. When using this in a normalizing flow, permutation layers can shuffle event dimensions. - Coupling conditioners differ in the partitioning method. By default, the event is flattened; the first half is x_A - and the second half is x_B. When using this in a normalizing flow, permutation layers can shuffle event dimensions. - - For improved performance, this implementation does not use a standalone coupling conditioner. It instead implements - a method to partition x into x_A and x_B and then predict parameters for x_B. + For improved performance, this implementation does not use a standalone coupling conditioner, but implements a method to partition :math:`x` into :math:`x_A` and :math:`x_B` and then predict parameters for :math:`x_B`. """ def __init__(self, @@ -63,6 +60,19 @@ def __init__(self, conditioner_kwargs: dict = None, transformer_kwargs: dict = None, **kwargs): + """ + CouplingBijection constructor. + + :param Union[Tuple[int, ...], torch.Size] event_shape: shape of the event tensor. + :param Type[TensorTransformer] transformer_class: transformer class. + :param Union[Tuple[int, ...], torch.Size] context_shape: + :param PartialCoupling coupling: + :param Type[ConditionerTransform] conditioner_transform_class: + :param Dict coupling_kwargs: + :param Dict conditioner_kwargs: + :param Dict transformer_kwargs: + :param kwargs: + """ coupling_kwargs = coupling_kwargs or {} conditioner_kwargs = conditioner_kwargs or {} transformer_kwargs = transformer_kwargs or {}