diff --git a/torchflows/bijections/finite/autoregressive/architectures.py b/torchflows/bijections/finite/autoregressive/architectures.py
index a45140d..66046b0 100644
--- a/torchflows/bijections/finite/autoregressive/architectures.py
+++ b/torchflows/bijections/finite/autoregressive/architectures.py
@@ -16,7 +16,9 @@
     LRSForwardMaskedAutoregressive,
     LRSInverseMaskedAutoregressive,
     DenseSigmoidalCoupling,
-    DeepDenseSigmoidalCoupling
+    DeepDenseSigmoidalCoupling, DeepSigmoidalInverseMaskedAutoregressive, DeepSigmoidalForwardMaskedAutoregressive,
+    DenseSigmoidalInverseMaskedAutoregressive, DenseSigmoidalForwardMaskedAutoregressive,
+    DeepDenseSigmoidalInverseMaskedAutoregressive, DeepDenseSigmoidalForwardMaskedAutoregressive
 )
 from torchflows.bijections.base import BijectiveComposition
 from torchflows.bijections.finite.autoregressive.layers_base import CouplingBijection, \
@@ -221,6 +223,40 @@ def __init__(self,
         super().__init__(event_shape, bijections, **kwargs)
 
 
+class InverseAutoregressiveDeepSF(BijectiveComposition):
+    """Inverse autoregressive deep sigmoidal flow architecture.
+
+    Reference: Huang et al. "Neural Autoregressive Flows" (2018); https://arxiv.org/abs/1804.00779.
+    """
+
+    def __init__(self,
+                 event_shape,
+                 n_layers: int = 2,
+                 edge_list: List[Tuple[int, int]] = None,
+                 **kwargs):
+        if isinstance(event_shape, int):
+            event_shape = (event_shape,)
+        bijections = make_basic_layers(DeepSigmoidalInverseMaskedAutoregressive, event_shape, n_layers, edge_list)
+        super().__init__(event_shape, bijections, **kwargs)
+
+
+class MaskedAutoregressiveDeepSF(BijectiveComposition):
+    """Masked autoregressive deep sigmoidal flow architecture.
+
+    Reference: Huang et al. "Neural Autoregressive Flows" (2018); https://arxiv.org/abs/1804.00779.
+    """
+
+    def __init__(self,
+                 event_shape,
+                 n_layers: int = 2,
+                 edge_list: List[Tuple[int, int]] = None,
+                 **kwargs):
+        if isinstance(event_shape, int):
+            event_shape = (event_shape,)
+        bijections = make_basic_layers(DeepSigmoidalForwardMaskedAutoregressive, event_shape, n_layers, edge_list)
+        super().__init__(event_shape, bijections, **kwargs)
+
+
 class CouplingDenseSF(BijectiveComposition):
     """Coupling dense sigmoidal flow architecture.
 
@@ -238,6 +274,40 @@ def __init__(self,
         super().__init__(event_shape, bijections, **kwargs)
 
 
+class InverseAutoregressiveDenseSF(BijectiveComposition):
+    """Inverse autoregressive dense sigmoidal flow architecture.
+
+    Reference: Huang et al. "Neural Autoregressive Flows" (2018); https://arxiv.org/abs/1804.00779.
+    """
+
+    def __init__(self,
+                 event_shape,
+                 n_layers: int = 2,
+                 edge_list: List[Tuple[int, int]] = None,
+                 **kwargs):
+        if isinstance(event_shape, int):
+            event_shape = (event_shape,)
+        bijections = make_basic_layers(DenseSigmoidalInverseMaskedAutoregressive, event_shape, n_layers, edge_list)
+        super().__init__(event_shape, bijections, **kwargs)
+
+
+class MaskedAutoregressiveDenseSF(BijectiveComposition):
+    """Masked autoregressive dense sigmoidal flow architecture.
+
+    Reference: Huang et al. "Neural Autoregressive Flows" (2018); https://arxiv.org/abs/1804.00779.
+ """ + + def __init__(self, + event_shape, + n_layers: int = 2, + edge_list: List[Tuple[int, int]] = None, + **kwargs): + if isinstance(event_shape, int): + event_shape = (event_shape,) + bijections = make_basic_layers(DenseSigmoidalForwardMaskedAutoregressive, event_shape, n_layers, edge_list) + super().__init__(event_shape, bijections, **kwargs) + + class CouplingDeepDenseSF(BijectiveComposition): """Coupling deep-dense sigmoidal flow architecture. @@ -255,6 +325,40 @@ def __init__(self, super().__init__(event_shape, bijections, **kwargs) +class InverseAutoregressiveDeepDenseSF(BijectiveComposition): + """Inverse autoregressive deep-dense sigmoidal flow architecture. + + Reference: Huang et al. "Neural Autoregressive Flows" (2018); https://arxiv.org/abs/1804.00779. + """ + + def __init__(self, + event_shape, + n_layers: int = 2, + edge_list: List[Tuple[int, int]] = None, + **kwargs): + if isinstance(event_shape, int): + event_shape = (event_shape,) + bijections = make_basic_layers(DeepDenseSigmoidalInverseMaskedAutoregressive, event_shape, n_layers, edge_list) + super().__init__(event_shape, bijections, **kwargs) + + +class MaskedAutoregressiveDeepDenseSF(BijectiveComposition): + """Masked autoregressive deep-dense sigmoidal flow architecture. + + Reference: Huang et al. "Neural Autoregressive Flows" (2018); https://arxiv.org/abs/1804.00779. + """ + + def __init__(self, + event_shape, + n_layers: int = 2, + edge_list: List[Tuple[int, int]] = None, + **kwargs): + if isinstance(event_shape, int): + event_shape = (event_shape,) + bijections = make_basic_layers(DeepDenseSigmoidalForwardMaskedAutoregressive, event_shape, n_layers, edge_list) + super().__init__(event_shape, bijections, **kwargs) + + class UMNNMAF(BijectiveComposition): """Unconstrained monotonic neural network masked autoregressive flow (UMNN-MAF) architecture. 
diff --git a/torchflows/bijections/finite/autoregressive/layers.py b/torchflows/bijections/finite/autoregressive/layers.py
index 4e1eb22..5d4b173 100644
--- a/torchflows/bijections/finite/autoregressive/layers.py
+++ b/torchflows/bijections/finite/autoregressive/layers.py
@@ -209,6 +209,32 @@ def __init__(self,
         super().__init__(transformer, coupling, conditioner_transform)
 
 
+class DeepSigmoidalInverseMaskedAutoregressive(InverseMaskedAutoregressiveBijection):
+    def __init__(self,
+                 event_shape: torch.Size,
+                 context_shape: torch.Size = None,
+                 n_hidden_layers: int = 2,
+                 **kwargs):
+        transformer: ScalarTransformer = DeepSigmoid(
+            event_shape=torch.Size(event_shape),
+            n_hidden_layers=n_hidden_layers
+        )
+        super().__init__(event_shape, context_shape, transformer=transformer, **kwargs)
+
+
+class DeepSigmoidalForwardMaskedAutoregressive(MaskedAutoregressiveBijection):
+    def __init__(self,
+                 event_shape: torch.Size,
+                 context_shape: torch.Size = None,
+                 n_hidden_layers: int = 2,
+                 **kwargs):
+        transformer: ScalarTransformer = DeepSigmoid(
+            event_shape=torch.Size(event_shape),
+            n_hidden_layers=n_hidden_layers
+        )
+        super().__init__(event_shape, context_shape, transformer=transformer, **kwargs)
+
+
 class DenseSigmoidalCoupling(CouplingBijection):
     def __init__(self,
                  event_shape: torch.Size,
@@ -235,6 +261,32 @@ def __init__(self,
         super().__init__(transformer, coupling, conditioner_transform)
 
 
+class DenseSigmoidalInverseMaskedAutoregressive(InverseMaskedAutoregressiveBijection):
+    def __init__(self,
+                 event_shape: torch.Size,
+                 context_shape: torch.Size = None,
+                 n_dense_layers: int = 2,
+                 **kwargs):
+        transformer: ScalarTransformer = DenseSigmoid(
+            event_shape=torch.Size(event_shape),
+            n_dense_layers=n_dense_layers
+        )
+        super().__init__(event_shape, context_shape, transformer=transformer, **kwargs)
+
+
+class DenseSigmoidalForwardMaskedAutoregressive(MaskedAutoregressiveBijection):
+    def __init__(self,
+                 event_shape: torch.Size,
+                 context_shape: torch.Size = None,
+                 n_dense_layers: int = 2,
+                 **kwargs):
+        transformer: ScalarTransformer = DenseSigmoid(
+            event_shape=torch.Size(event_shape),
+            n_dense_layers=n_dense_layers
+        )
+        super().__init__(event_shape, context_shape, transformer=transformer, **kwargs)
+
+
 class DeepDenseSigmoidalCoupling(CouplingBijection):
     def __init__(self,
                  event_shape: torch.Size,
@@ -261,6 +313,32 @@ def __init__(self,
         super().__init__(transformer, coupling, conditioner_transform)
 
 
+class DeepDenseSigmoidalInverseMaskedAutoregressive(InverseMaskedAutoregressiveBijection):
+    def __init__(self,
+                 event_shape: torch.Size,
+                 context_shape: torch.Size = None,
+                 n_hidden_layers: int = 2,
+                 **kwargs):
+        transformer: ScalarTransformer = DeepDenseSigmoid(
+            event_shape=torch.Size(event_shape),
+            n_hidden_layers=n_hidden_layers
+        )
+        super().__init__(event_shape, context_shape, transformer=transformer, **kwargs)
+
+
+class DeepDenseSigmoidalForwardMaskedAutoregressive(MaskedAutoregressiveBijection):
+    def __init__(self,
+                 event_shape: torch.Size,
+                 context_shape: torch.Size = None,
+                 n_hidden_layers: int = 2,
+                 **kwargs):
+        transformer: ScalarTransformer = DeepDenseSigmoid(
+            event_shape=torch.Size(event_shape),
+            n_hidden_layers=n_hidden_layers
+        )
+        super().__init__(event_shape, context_shape, transformer=transformer, **kwargs)
+
+
 class LinearAffineCoupling(AffineCoupling):
     def __init__(self, event_shape: torch.Size, **kwargs):
         super().__init__(event_shape, **kwargs, n_layers=1)
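Usage note for the new layer-level bijections (illustrative, not part of the patch): a minimal sketch assuming the Bijection interface used elsewhere in torchflows, where forward(x) returns the transformed tensor together with the per-sample log-determinant; only names introduced by this diff are used.

import torch

from torchflows.bijections.finite.autoregressive.layers import (
    DeepSigmoidalForwardMaskedAutoregressive,
    DenseSigmoidalInverseMaskedAutoregressive,
)

event_shape = torch.Size((5,))
x = torch.randn(16, *event_shape)

# Forward (MAF-style) layer: one-pass transform in the density-evaluation direction.
maf_layer = DeepSigmoidalForwardMaskedAutoregressive(event_shape, n_hidden_layers=2)
z, log_det = maf_layer.forward(x)  # assumed to return (transformed x, per-sample log|det J|)

# Inverse (IAF-style) layer: one-pass transform in the sampling direction.
iaf_layer = DenseSigmoidalInverseMaskedAutoregressive(event_shape, n_dense_layers=2)
y, log_det_y = iaf_layer.forward(x)

The keyword arguments n_hidden_layers and n_dense_layers are the transformer hyperparameters exposed by the constructors added in this patch; everything else about the call signatures is an assumption based on the surrounding library code.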