Masked/inverse autoregressive NF cleanup/refactor
davidnabergoj committed Sep 1, 2024
1 parent 1888a44 commit bfdfbce
Showing 3 changed files with 41 additions and 166 deletions.
188 changes: 29 additions & 159 deletions torchflows/bijections/finite/autoregressive/layers.py
@@ -2,12 +2,9 @@

import torch

from torchflows.bijections.finite.autoregressive.conditioning.transforms import FeedForward
from torchflows.bijections.finite.autoregressive.conditioning.coupling_masks import make_coupling
from torchflows.bijections.finite.autoregressive.layers_base import MaskedAutoregressiveBijection, \
InverseMaskedAutoregressiveBijection, ElementwiseBijection, CouplingBijection
from torchflows.bijections.finite.autoregressive.transformers.linear.affine import Scale, Affine, Shift, InverseAffine
from torchflows.bijections.finite.autoregressive.transformers.base import ScalarTransformer
from torchflows.bijections.finite.autoregressive.transformers.integration.unconstrained_monotonic_neural_network import \
UnconstrainedMonotonicNeuralNetwork
from torchflows.bijections.finite.autoregressive.transformers.spline.linear_rational import LinearRational
@@ -17,7 +14,6 @@
DenseSigmoid,
DeepDenseSigmoid
)
from torchflows.bijections.base import invert


class ElementwiseAffine(ElementwiseBijection):
@@ -185,41 +181,25 @@ def __init__(self,
class DeepSigmoidalInverseMaskedAutoregressive(InverseMaskedAutoregressiveBijection):
def __init__(self,
event_shape: Union[Tuple[int, ...], torch.Size],
context_shape: Union[Tuple[int, ...], torch.Size] = None,
transformer_kwargs: dict = None,
**kwargs):
"""
:param Union[Tuple[int, ...], torch.Size] event_shape: shape of the event tensor.
:param Union[Tuple[int, ...], torch.Size] context_shape: shape of the context tensor.
:param kwargs: keyword arguments to InverseMaskedAutoregressiveBijection.
"""
transformer_kwargs = transformer_kwargs or {}
transformer: ScalarTransformer = DeepSigmoid(
event_shape=torch.Size(event_shape),
**transformer_kwargs
)
super().__init__(event_shape, context_shape, transformer=transformer, **kwargs)
super().__init__(event_shape, DeepSigmoid, **kwargs)
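
The one-line super().__init__(event_shape, DeepSigmoid, **kwargs) calls above work because the base bijection now instantiates the transformer itself from the class and its transformer_kwargs. A minimal standalone sketch of that pattern, using illustrative stand-in classes rather than torchflows API:

from typing import Tuple, Type

class Transformer:
    # stand-in for a torchflows ScalarTransformer
    def __init__(self, event_shape: Tuple[int, ...], hidden_size: int = 8):
        self.event_shape = event_shape
        self.hidden_size = hidden_size

class Bijection:
    # stand-in for the refactored base bijection
    def __init__(self,
                 event_shape: Tuple[int, ...],
                 transformer_class: Type[Transformer],
                 transformer_kwargs: dict = None):
        transformer_kwargs = transformer_kwargs or {}
        # single instantiation point, shared by every subclass
        self.transformer = transformer_class(event_shape, **transformer_kwargs)

layer = Bijection((4,), Transformer, transformer_kwargs={'hidden_size': 16})
assert layer.transformer.hidden_size == 16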


class DeepSigmoidalForwardMaskedAutoregressive(MaskedAutoregressiveBijection):
def __init__(self,
event_shape: Union[Tuple[int, ...], torch.Size],
context_shape: Union[Tuple[int, ...], torch.Size] = None,
transformer_kwargs: dict = None,
**kwargs):
"""
:param Union[Tuple[int, ...], torch.Size] event_shape: shape of the event tensor.
:param Union[Tuple[int, ...], torch.Size] context_shape: shape of the context tensor.
:param kwargs: keyword arguments to MaskedAutoregressiveBijection.
"""
transformer_kwargs = transformer_kwargs or {}
transformer: ScalarTransformer = DeepSigmoid(
event_shape=torch.Size(event_shape),
**transformer_kwargs
)
super().__init__(event_shape, context_shape, transformer=transformer, **kwargs)
super().__init__(event_shape, DeepSigmoid, **kwargs)


class DenseSigmoidalCoupling(CouplingBijection):
@@ -241,64 +221,33 @@ def __init__(self,
class DenseSigmoidalInverseMaskedAutoregressive(InverseMaskedAutoregressiveBijection):
def __init__(self,
event_shape: Union[Tuple[int, ...], torch.Size],
context_shape: Union[Tuple[int, ...], torch.Size] = None,
transformer_kwargs: dict = None,
percentage_global_parameters: float = 0.8,
**kwargs):
"""
:param Union[Tuple[int, ...], torch.Size] event_shape: shape of the event tensor.
:param Union[Tuple[int, ...], torch.Size] context_shape: shape of the context tensor.
:param float percentage_global_parameters: percentage of transformer inputs to be learned globally instead of
being predicted from the conditioner neural network.
:param kwargs: keyword arguments to InverseMaskedAutoregressiveBijection.
"""
transformer_kwargs = transformer_kwargs or {}
transformer: ScalarTransformer = DenseSigmoid(
event_shape=torch.Size(event_shape),
**transformer_kwargs
)
super().__init__(
event_shape,
context_shape,
transformer=transformer,
**{
**kwargs,
**dict(percentage_global_parameters=percentage_global_parameters)
}
)
if 'conditioner_kwargs' not in kwargs:
kwargs['conditioner_kwargs'] = {}
if 'percentage_global_parameters' not in kwargs['conditioner_kwargs']:
kwargs['conditioner_kwargs']['percentage_global_parameters'] = 0.8
super().__init__(event_shape, DenseSigmoid, **kwargs)
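
The defaulting logic above is equivalent to two nested dict.setdefault calls; a short sketch (plain Python, not torchflows code) of the behavior it implements:

def apply_defaults(**kwargs):
    conditioner_kwargs = kwargs.setdefault('conditioner_kwargs', {})
    conditioner_kwargs.setdefault('percentage_global_parameters', 0.8)
    return kwargs

# default applied when the caller passes nothing
assert apply_defaults()['conditioner_kwargs']['percentage_global_parameters'] == 0.8
# a caller-supplied value wins
out = apply_defaults(conditioner_kwargs={'percentage_global_parameters': 0.5})
assert out['conditioner_kwargs']['percentage_global_parameters'] == 0.5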


class DenseSigmoidalForwardMaskedAutoregressive(MaskedAutoregressiveBijection):
def __init__(self,
event_shape: Union[Tuple[int, ...], torch.Size],
context_shape: Union[Tuple[int, ...], torch.Size] = None,
transformer_kwargs: dict = None,
percentage_global_parameters: float = 0.8,
**kwargs):
"""
:param Union[Tuple[int, ...], torch.Size] event_shape: shape of the event tensor.
:param Union[Tuple[int, ...], torch.Size] context_shape: shape of the context tensor.
:param int n_transformer_layers: number of transformer layers.
:param float percentage_global_parameters: percentage of transformer inputs to be learned globally instead of
being predicted from the conditioner neural network.
:param kwargs: keyword arguments to MaskedAutoregressiveBijection.
"""
transformer_kwargs = transformer_kwargs or {}
transformer: ScalarTransformer = DenseSigmoid(
event_shape=torch.Size(event_shape),
**transformer_kwargs
)
super().__init__(
event_shape,
context_shape,
transformer=transformer,
**{
**kwargs,
**dict(percentage_global_parameters=percentage_global_parameters)
}
)
if 'conditioner_kwargs' not in kwargs:
kwargs['conditioner_kwargs'] = {}
if 'percentage_global_parameters' not in kwargs['conditioner_kwargs']:
kwargs['conditioner_kwargs']['percentage_global_parameters'] = 0.8
super().__init__(event_shape, DenseSigmoid, **kwargs)


class DeepDenseSigmoidalCoupling(CouplingBijection):
@@ -320,69 +269,33 @@ def __init__(self,
class DeepDenseSigmoidalInverseMaskedAutoregressive(InverseMaskedAutoregressiveBijection):
def __init__(self,
event_shape: Union[Tuple[int, ...], torch.Size],
context_shape: Union[Tuple[int, ...], torch.Size] = None,
transformer_kwargs: dict = None,
percentage_global_parameters: float = 0.8,
**kwargs):
"""
:param Union[Tuple[int, ...], torch.Size] event_shape: shape of the event tensor.
:param Union[Tuple[int, ...], torch.Size] context_shape: shape of the context tensor.
:param int n_transformer_hidden_layers: number of transformer hidden layers.
:param int n_transformer_dense_layers: number of transformer dense layers.
:param int transformer_hidden_size: transformer hidden layer size.
:param float percentage_global_parameters: percentage of transformer inputs to be learned globally instead of
being predicted from the conditioner neural network.
:param kwargs: keyword arguments to InverseMaskedAutoregressiveBijection.
"""
transformer_kwargs = transformer_kwargs or {}
transformer: ScalarTransformer = DeepDenseSigmoid(
event_shape=torch.Size(event_shape),
**transformer_kwargs
)
super().__init__(
event_shape,
context_shape,
transformer=transformer,
**{
**kwargs,
**dict(percentage_global_parameters=percentage_global_parameters)
}
)
if 'conditioner_kwargs' not in kwargs:
kwargs['conditioner_kwargs'] = {}
if 'percentage_global_parameters' not in kwargs['conditioner_kwargs']:
kwargs['conditioner_kwargs']['percentage_global_parameters'] = 0.8
super().__init__(event_shape, DeepDenseSigmoid, **kwargs)


class DeepDenseSigmoidalForwardMaskedAutoregressive(MaskedAutoregressiveBijection):
def __init__(self,
event_shape: Union[Tuple[int, ...], torch.Size],
context_shape: Union[Tuple[int, ...], torch.Size] = None,
transformer_kwargs: dict = None,
percentage_global_parameters: float = 0.8,
**kwargs):
"""
:param Union[Tuple[int, ...], torch.Size] event_shape: shape of the event tensor.
:param Union[Tuple[int, ...], torch.Size] context_shape: shape of the context tensor.
:param int n_transformer_hidden_layers: number of transformer hidden layers.
:param int n_transformer_dense_layers: number of transformer dense layers.
:param int transformer_hidden_size: transformer hidden layer size.
:param float percentage_global_parameters: percentage of transformer inputs to be learned globally instead of
being predicted from the conditioner neural network.
:param kwargs: keyword arguments to MaskedAutoregressiveBijection.
"""
transformer_kwargs = transformer_kwargs or {}
transformer: ScalarTransformer = DeepDenseSigmoid(
event_shape=torch.Size(event_shape),
**transformer_kwargs
)
super().__init__(
event_shape,
context_shape,
transformer=transformer,
**{
**kwargs,
**dict(percentage_global_parameters=percentage_global_parameters)
}
)
if 'conditioner_kwargs' not in kwargs:
kwargs['conditioner_kwargs'] = {}
if 'percentage_global_parameters' not in kwargs['conditioner_kwargs']:
kwargs['conditioner_kwargs']['percentage_global_parameters'] = 0.8
super().__init__(event_shape, DeepDenseSigmoid, **kwargs)


class LinearAffineCoupling(AffineCoupling):
@@ -428,125 +341,82 @@ def __init__(self, event_shape: Union[Tuple[int, ...], torch.Size], **kwargs):
class AffineForwardMaskedAutoregressive(MaskedAutoregressiveBijection):
def __init__(self,
event_shape: Union[Tuple[int, ...], torch.Size],
context_shape: Union[Tuple[int, ...], torch.Size] = None,
transformer_kwargs: dict = None,
**kwargs):
"""
:param Union[Tuple[int, ...], torch.Size] event_shape: shape of the event tensor.
:param Union[Tuple[int, ...], torch.Size] context_shape: shape of the context tensor.
:param kwargs: keyword arguments to MaskedAutoregressiveBijection.
"""
transformer_kwargs = transformer_kwargs or {}
transformer: ScalarTransformer = Affine(event_shape=event_shape, **transformer_kwargs)
super().__init__(event_shape, context_shape, transformer=transformer, **kwargs)
super().__init__(event_shape, Affine, **kwargs)


class RQSForwardMaskedAutoregressive(MaskedAutoregressiveBijection):
def __init__(self,
event_shape: Union[Tuple[int, ...], torch.Size],
context_shape: Union[Tuple[int, ...], torch.Size] = None,
transformer_kwargs: dict = None,
**kwargs):
"""
:param Union[Tuple[int, ...], torch.Size] event_shape: shape of the event tensor.
:param Union[Tuple[int, ...], torch.Size] context_shape: shape of the context tensor.
:param int n_bins: number of spline bins.
:param kwargs: keyword arguments to MaskedAutoregressiveBijection.
"""
transformer_kwargs = transformer_kwargs or {}
transformer: ScalarTransformer = RationalQuadratic(event_shape=event_shape, **transformer_kwargs)
super().__init__(event_shape, context_shape, transformer=transformer, **kwargs)
super().__init__(event_shape, RationalQuadratic, **kwargs)


class LRSForwardMaskedAutoregressive(MaskedAutoregressiveBijection):
def __init__(self,
event_shape: Union[Tuple[int, ...], torch.Size],
context_shape: Union[Tuple[int, ...], torch.Size] = None,
transformer_kwargs: dict = None,
**kwargs):
"""
:param Union[Tuple[int, ...], torch.Size] event_shape: shape of the event tensor.
:param Union[Tuple[int, ...], torch.Size] context_shape: shape of the context tensor.
:param int n_bins: number of spline bins.
:param kwargs: keyword arguments to MaskedAutoregressiveBijection.
"""
transformer_kwargs = transformer_kwargs or {}
transformer: ScalarTransformer = LinearRational(event_shape=event_shape, **transformer_kwargs)
super().__init__(event_shape, context_shape, transformer=transformer, **kwargs)
super().__init__(event_shape, LinearRational, **kwargs)


class AffineInverseMaskedAutoregressive(InverseMaskedAutoregressiveBijection):
def __init__(self,
event_shape: Union[Tuple[int, ...], torch.Size],
context_shape: Union[Tuple[int, ...], torch.Size] = None,
transformer_kwargs: dict = None,
**kwargs):
"""
:param Union[Tuple[int, ...], torch.Size] event_shape: shape of the event tensor.
:param Union[Tuple[int, ...], torch.Size] context_shape: shape of the context tensor.
:param kwargs: keyword arguments to InverseMaskedAutoregressiveBijection.
"""
transformer_kwargs = transformer_kwargs or {}
transformer: ScalarTransformer = invert(Affine(event_shape=event_shape, **transformer_kwargs))
super().__init__(event_shape, context_shape, transformer=transformer, **kwargs)
super().__init__(event_shape, InverseAffine, **kwargs)
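
Replacing invert(Affine(...)) with a dedicated InverseAffine transformer is behavior-preserving in the sense sketched below (illustrative functions only; the real transformers also track log-determinants and parameter constraints):

import torch

def affine_forward(x, scale, shift):
    return x * scale + shift

def inverse_affine_forward(z, scale, shift):
    # the forward pass of the inverted transformer is the affine inverse
    return (z - shift) / scale

x = torch.randn(5)
scale, shift = torch.tensor(2.0), torch.tensor(0.5)
z = affine_forward(x, scale, shift)
assert torch.allclose(inverse_affine_forward(z, scale, shift), x)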


class RQSInverseMaskedAutoregressive(InverseMaskedAutoregressiveBijection):
def __init__(self,
event_shape: Union[Tuple[int, ...], torch.Size],
context_shape: Union[Tuple[int, ...], torch.Size] = None,
transformer_kwargs: dict = None,
**kwargs):
"""
:param Union[Tuple[int, ...], torch.Size] event_shape: shape of the event tensor.
:param Union[Tuple[int, ...], torch.Size] context_shape: shape of the context tensor.
:param int n_bins: number of spline bins.
:param kwargs: keyword arguments to InverseMaskedAutoregressiveBijection.
"""
transformer_kwargs = transformer_kwargs or {}
transformer: ScalarTransformer = RationalQuadratic(event_shape=event_shape, **transformer_kwargs)
super().__init__(event_shape, context_shape, transformer=transformer, **kwargs)
super().__init__(event_shape, RationalQuadratic, **kwargs)


class LRSInverseMaskedAutoregressive(InverseMaskedAutoregressiveBijection):
def __init__(self,
event_shape: Union[Tuple[int, ...], torch.Size],
context_shape: Union[Tuple[int, ...], torch.Size] = None,
transformer_kwargs: dict = None,
**kwargs):
"""
:param Union[Tuple[int, ...], torch.Size] event_shape: shape of the event tensor.
:param Union[Tuple[int, ...], torch.Size] context_shape: shape of the context tensor.
:param dict transformer_kwargs: keyword arguments to LinearRational.
:param kwargs: keyword arguments to InverseMaskedAutoregressiveBijection.
"""
transformer_kwargs = transformer_kwargs or {}
transformer: ScalarTransformer = LinearRational(event_shape=event_shape, **transformer_kwargs)
super().__init__(event_shape, context_shape, transformer=transformer, **kwargs)
super().__init__(event_shape, LinearRational, **kwargs)


class UMNNMaskedAutoregressive(MaskedAutoregressiveBijection):
def __init__(self,
event_shape: Union[Tuple[int, ...], torch.Size],
context_shape: Union[Tuple[int, ...], torch.Size] = None,
transformer_kwargs: dict = None,
**kwargs):
"""
:param Union[Tuple[int, ...], torch.Size] event_shape: shape of the event tensor.
:param Union[Tuple[int, ...], torch.Size] context_shape: shape of the context tensor.
:param dict transformer_kwargs: keyword arguments to UnconstrainedMonotonicNeuralNetwork.
:param kwargs: keyword arguments to MaskedAutoregressiveBijection.
"""
transformer_kwargs = transformer_kwargs or {}
transformer: ScalarTransformer = UnconstrainedMonotonicNeuralNetwork(
event_shape=event_shape,
**transformer_kwargs
)
super().__init__(event_shape, context_shape, transformer=transformer, **kwargs)
super().__init__(event_shape, UnconstrainedMonotonicNeuralNetwork, **kwargs)
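
A hedged usage sketch of the refactored constructors in this file, assuming the signatures shown in this diff (the empty dicts stand in for whatever options RationalQuadratic and the MADE conditioner actually accept):

import torch
from torchflows.bijections.finite.autoregressive.layers import RQSForwardMaskedAutoregressive

layer = RQSForwardMaskedAutoregressive(
    event_shape=torch.Size((8,)),
    transformer_kwargs={},  # forwarded to RationalQuadratic(event_shape, ...)
    conditioner_kwargs={},  # forwarded to the MADE conditioner
)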
15 changes: 10 additions & 5 deletions torchflows/bijections/finite/autoregressive/layers_base.py
@@ -144,18 +144,23 @@ class MaskedAutoregressiveBijection(AutoregressiveBijection):
"""

def __init__(self,
event_shape,
context_shape,
transformer: ScalarTransformer,
event_shape: Union[Tuple[int, ...], torch.Size],
transformer_class: Type[ScalarTransformer],
context_shape: Union[Tuple[int, ...], torch.Size] = None,
transformer_kwargs: dict = None,
conditioner_kwargs: dict = None,
**kwargs):
conditioner_kwargs = conditioner_kwargs or {}
transformer_kwargs = transformer_kwargs or {}
transformer = transformer_class(event_shape=event_shape, **transformer_kwargs)
conditioner_transform = MADE(
input_event_shape=event_shape,
transformed_event_shape=event_shape,
parameter_shape_per_element=transformer.parameter_shape_per_element,
context_shape=context_shape,
**kwargs
**conditioner_kwargs
)
super().__init__(transformer.event_shape, transformer, conditioner_transform)
super().__init__(transformer.event_shape, transformer, conditioner_transform, **kwargs)

def apply_conditioner_transformer(self, inputs, context, forward: bool = True):
h = self.conditioner_transform(inputs, context)
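
How the refactored base __init__ routes its keyword arguments, as a construction sketch (imports and signatures assumed to resolve as shown in this diff):

import torch
from torchflows.bijections.finite.autoregressive.layers_base import MaskedAutoregressiveBijection
from torchflows.bijections.finite.autoregressive.transformers.spline.linear_rational import LinearRational

bijection = MaskedAutoregressiveBijection(
    event_shape=torch.Size((4,)),
    transformer_class=LinearRational,  # instantiated internally from transformer_kwargs
    transformer_kwargs={},  # -> LinearRational(event_shape=..., ...)
    conditioner_kwargs={},  # -> MADE(...)
)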
4 changes: 2 additions & 2 deletions torchflows/bijections/finite/autoregressive/transformers/base.py
@@ -56,8 +56,8 @@ def default_parameters(self) -> torch.Tensor:


class ScalarTransformer(TensorTransformer):
def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]]):
super().__init__(event_shape)
def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], **kwargs):
super().__init__(event_shape, **kwargs)

@property
def parameter_shape_per_element(self):
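
The two-line change above lets ScalarTransformer subclasses accept extra keyword arguments and forward them up the constructor chain without redeclaring each one; a minimal standalone sketch (illustrative classes, not torchflows code):

class Base:
    def __init__(self, event_shape, min_scale: float = 1e-3):
        self.event_shape = event_shape
        self.min_scale = min_scale

class Scalar(Base):
    def __init__(self, event_shape, **kwargs):
        super().__init__(event_shape, **kwargs)  # pass extras through unchanged

t = Scalar((4,), min_scale=1e-2)
assert t.min_scale == 1e-2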
