Skip to content

Commit

Permalink
Update docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
davidnabergoj committed Oct 19, 2024
1 parent 910adaf commit 5aefcd6
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def __init__(self, event_shape: Union[Tuple[int, ...], torch.Size, int], **kwarg
"""
:param event_shape: shape of the event tensor.
:param kwargs: keyword arguments to AffineCoupling.
:param kwargs: keyword arguments to :class:`~bijections.finite.autoregressive.layers.AffineCoupling`.
"""
super().__init__(event_shape, base_bijection=AffineCoupling, **kwargs)

Expand Down
30 changes: 20 additions & 10 deletions torchflows/bijections/finite/autoregressive/layers_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,18 +39,15 @@ class CouplingBijection(AutoregressiveBijection):
"""
Base coupling bijection object.
A coupling bijection is defined using a transformer, conditioner transform, and always a coupling conditioner.
A coupling bijection is defined using a transformer, conditioner transform, and always a coupling conditioner (specifying how to partition the input tensor).
The coupling conditioner receives as input an event tensor x.
It then partitions an input event tensor x into a constant part x_A and a modifiable part x_B.
For x_A, the conditioner outputs a set of parameters which is always the same.
For x_B, the conditioner outputs a set of parameters which are predicted from x_A.
The coupling conditioner receives as input an event tensor :math:`x`.
It then partitions an input event tensor x into a constant part :math:`x_A` and a modifiable part :math:`x_B`.
For :math:`x_A`, the conditioner outputs a set of parameters which is always the same.
For :math:`x_B`, the conditioner outputs a set of parameters which are predicted from :math:`x_A`.
Coupling conditioners differ in the partitioning method. By default, the event is flattened; the first half is :math:`x_A` and the second half is :math:`x_B`. When using this in a normalizing flow, permutation layers can shuffle event dimensions.
Coupling conditioners differ in the partitioning method. By default, the event is flattened; the first half is x_A
and the second half is x_B. When using this in a normalizing flow, permutation layers can shuffle event dimensions.
For improved performance, this implementation does not use a standalone coupling conditioner. It instead implements
a method to partition x into x_A and x_B and then predict parameters for x_B.
For improved performance, this implementation does not use a standalone coupling conditioner, but implements a method to partition x into :math:`x_A` and :math:`x_B` and then predict parameters for :math:`x_B`.
"""

def __init__(self,
Expand All @@ -63,6 +60,19 @@ def __init__(self,
conditioner_kwargs: dict = None,
transformer_kwargs: dict = None,
**kwargs):
"""
CouplingBijection constructor.
:param Union[Tuple[int, ...], torch.Size] event_shape: shape of the event tensor.
:param Type[TensorTransformer] transformer_class: transformer class.
:param Union[Tuple[int, ...], torch.Size] context_shape:
:param PartialCoupling coupling:
:param Type[ConditionerTransform] conditioner_transform_class:
:param Dict coupling_kwargs:
:param Dict conditioner_kwargs:
:param Dict transformer_kwargs:
:param kwargs:
"""
coupling_kwargs = coupling_kwargs or {}
conditioner_kwargs = conditioner_kwargs or {}
transformer_kwargs = transformer_kwargs or {}
Expand Down

0 comments on commit 5aefcd6

Please sign in to comment.