Skip to content

Commit

Permalink
Add more LRS and NAF architectures
Browse files Browse the repository at this point in the history
  • Loading branch information
davidnabergoj committed Aug 23, 2024
1 parent e436979 commit 49e265d
Show file tree
Hide file tree
Showing 5 changed files with 142 additions and 16 deletions.
4 changes: 2 additions & 2 deletions test/test_identity_bijections.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from torchflows.bijections.finite.autoregressive.layers import (
AffineCoupling,
DSCoupling,
DeepSigmoidalCoupling,
RQSCoupling,
InverseAffineCoupling,
LRSCoupling,
Expand Down Expand Up @@ -32,7 +32,7 @@
'layer_class',
[
AffineCoupling,
DSCoupling,
DeepSigmoidalCoupling,
RQSCoupling,
InverseAffineCoupling,
LRSCoupling,
Expand Down
12 changes: 6 additions & 6 deletions test/test_sigmoid_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
import torch

from torchflows import Flow
from torchflows.bijections.finite.autoregressive.architectures import CouplingDSF
from torchflows.bijections.finite.autoregressive.layers import DSCoupling
from torchflows.bijections.finite.autoregressive.architectures import CouplingDeepSF
from torchflows.bijections.finite.autoregressive.layers import DeepSigmoidalCoupling
from torchflows.bijections.finite.autoregressive.transformers.combination.sigmoid import Sigmoid, DeepSigmoid
from torchflows.bijections.base import invert
from test.constants import __test_constants
Expand Down Expand Up @@ -77,8 +77,8 @@ def test_deep_sigmoid_transformer(event_shape, batch_shape, hidden_dim):
def test_deep_sigmoid_coupling(event_shape, batch_shape):
torch.manual_seed(0)

forward_layer = DSCoupling(torch.Size(event_shape))
inverse_layer = invert(DSCoupling(torch.Size(event_shape)))
forward_layer = DeepSigmoidalCoupling(torch.Size(event_shape))
inverse_layer = invert(DeepSigmoidalCoupling(torch.Size(event_shape)))

x = torch.randn(size=(*batch_shape, *event_shape)) # Reduce magnitude for stability
y, log_det_forward = forward_layer.forward(x)
Expand Down Expand Up @@ -109,15 +109,15 @@ def test_deep_sigmoid_coupling_flow(event_shape, batch_shape):
n_dim = int(torch.prod(torch.tensor(event_shape)))
event_shape = (n_dim,) # Overwrite

forward_flow = Flow(CouplingDSF(event_shape))
forward_flow = Flow(CouplingDeepSF(event_shape))
x = torch.randn(size=(*batch_shape, n_dim))
log_prob = forward_flow.log_prob(x)

assert log_prob.shape == batch_shape
assert torch.all(~torch.isnan(log_prob))
assert torch.all(~torch.isinf(log_prob))

inverse_flow = Flow(invert(CouplingDSF(event_shape)))
inverse_flow = Flow(invert(CouplingDeepSF(event_shape)))
x_new = inverse_flow.sample(len(x))

assert x_new.shape == (len(x), *inverse_flow.bijection.event_shape)
Expand Down
5 changes: 4 additions & 1 deletion torchflows/architectures.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@
InverseAutoregressiveRQNSF,
CouplingLRS,
MaskedAutoregressiveLRS,
CouplingDSF,
InverseAutoregressiveLRS,
CouplingDeepSF,
CouplingDenseSF,
CouplingDeepDenseSF,
UMNNMAF
)

Expand Down
69 changes: 64 additions & 5 deletions torchflows/bijections/finite/autoregressive/architectures.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,14 @@
RQSForwardMaskedAutoregressive,
RQSInverseMaskedAutoregressive,
InverseAffineCoupling,
DSCoupling,
DeepSigmoidalCoupling,
ElementwiseAffine,
UMNNMaskedAutoregressive,
LRSCoupling,
LRSForwardMaskedAutoregressive
LRSForwardMaskedAutoregressive,
LRSInverseMaskedAutoregressive,
DenseSigmoidalCoupling,
DeepDenseSigmoidalCoupling
)
from torchflows.bijections.base import BijectiveComposition
from torchflows.bijections.finite.autoregressive.layers_base import CouplingBijection, \
Expand Down Expand Up @@ -43,6 +46,7 @@ class NICE(BijectiveComposition):
Reference: Dinh et al. "NICE: Non-linear Independent Components Estimation" (2015); https://arxiv.org/abs/1410.8516.
"""

def __init__(self,
event_shape,
n_layers: int = 2,
Expand All @@ -59,6 +63,7 @@ class RealNVP(BijectiveComposition):
Reference: Dinh et al. "Density estimation using Real NVP" (2017); https://arxiv.org/abs/1605.08803.
"""

def __init__(self,
event_shape,
n_layers: int = 2,
Expand All @@ -75,6 +80,7 @@ class InverseRealNVP(BijectiveComposition):
Reference: Dinh et al. "Density estimation using Real NVP" (2017); https://arxiv.org/abs/1605.08803.
"""

def __init__(self,
event_shape,
n_layers: int = 2,
Expand Down Expand Up @@ -117,6 +123,7 @@ class CouplingRQNSF(BijectiveComposition):
Reference: Durkan et al. "Neural Spline Flows" (2019); https://arxiv.org/abs/1906.04032.
"""

def __init__(self,
event_shape,
n_layers: int = 2,
Expand Down Expand Up @@ -146,6 +153,7 @@ class CouplingLRS(BijectiveComposition):
Reference: Dolatabadi et al. "Invertible Generative Modeling using Linear Rational Splines" (2020); https://arxiv.org/abs/2001.05168.
"""

def __init__(self,
event_shape,
n_layers: int = 2,
Expand All @@ -162,6 +170,7 @@ class MaskedAutoregressiveLRS(BijectiveComposition):
Reference: Dolatabadi et al. "Invertible Generative Modeling using Linear Rational Splines" (2020); https://arxiv.org/abs/2001.05168.
"""

def __init__(self, event_shape, n_layers: int = 2, **kwargs):
if isinstance(event_shape, int):
event_shape = (event_shape,)
Expand All @@ -174,26 +183,75 @@ class InverseAutoregressiveRQNSF(BijectiveComposition):
Reference: Durkan et al. "Neural Spline Flows" (2019); https://arxiv.org/abs/1906.04032.
"""

def __init__(self, event_shape, n_layers: int = 2, **kwargs):
if isinstance(event_shape, int):
event_shape = (event_shape,)
bijections = make_basic_layers(RQSInverseMaskedAutoregressive, event_shape, n_layers)
super().__init__(event_shape, bijections, **kwargs)


class CouplingDSF(BijectiveComposition):
"""Coupling deep sigmoidal flow (C-DSF) architecture.
class InverseAutoregressiveLRS(BijectiveComposition):
"""Inverse autoregressive linear rational spline (MA-LRS) architecture.
Reference: Dolatabadi et al. "Invertible Generative Modeling using Linear Rational Splines" (2020); https://arxiv.org/abs/2001.05168.
"""

def __init__(self, event_shape, n_layers: int = 2, **kwargs):
if isinstance(event_shape, int):
event_shape = (event_shape,)
bijections = make_basic_layers(LRSInverseMaskedAutoregressive, event_shape, n_layers)
super().__init__(event_shape, bijections, **kwargs)


class CouplingDeepSF(BijectiveComposition):
"""Coupling deep sigmoidal flow architecture.
Reference: Huang et al. "Neural Autoregressive Flows" (2018); https://arxiv.org/abs/1804.00779.
"""

def __init__(self,
event_shape,
n_layers: int = 2,
edge_list: List[Tuple[int, int]] = None,
**kwargs):
if isinstance(event_shape, int):
event_shape = (event_shape,)
bijections = make_basic_layers(DSCoupling, event_shape, n_layers, edge_list)
bijections = make_basic_layers(DeepSigmoidalCoupling, event_shape, n_layers, edge_list)
super().__init__(event_shape, bijections, **kwargs)


class CouplingDenseSF(BijectiveComposition):
"""Coupling dense sigmoidal flow architecture.
Reference: Huang et al. "Neural Autoregressive Flows" (2018); https://arxiv.org/abs/1804.00779.
"""

def __init__(self,
event_shape,
n_layers: int = 2,
edge_list: List[Tuple[int, int]] = None,
**kwargs):
if isinstance(event_shape, int):
event_shape = (event_shape,)
bijections = make_basic_layers(DenseSigmoidalCoupling, event_shape, n_layers, edge_list)
super().__init__(event_shape, bijections, **kwargs)


class CouplingDeepDenseSF(BijectiveComposition):
"""Coupling deep-dense sigmoidal flow architecture.
Reference: Huang et al. "Neural Autoregressive Flows" (2018); https://arxiv.org/abs/1804.00779.
"""

def __init__(self,
event_shape,
n_layers: int = 2,
edge_list: List[Tuple[int, int]] = None,
**kwargs):
if isinstance(event_shape, int):
event_shape = (event_shape,)
bijections = make_basic_layers(DeepDenseSigmoidalCoupling, event_shape, n_layers, edge_list)
super().__init__(event_shape, bijections, **kwargs)


Expand All @@ -202,6 +260,7 @@ class UMNNMAF(BijectiveComposition):
Reference: Wehenkel and Louppe "Unconstrained Monotonic Neural Networks" (2021); https://arxiv.org/abs/1908.05164.
"""

def __init__(self, event_shape, n_layers: int = 1, **kwargs):
if isinstance(event_shape, int):
event_shape = (event_shape,)
Expand Down
68 changes: 66 additions & 2 deletions torchflows/bijections/finite/autoregressive/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@
from torchflows.bijections.finite.autoregressive.transformers.spline.linear_rational import LinearRational
from torchflows.bijections.finite.autoregressive.transformers.spline.rational_quadratic import RationalQuadratic
from torchflows.bijections.finite.autoregressive.transformers.combination.sigmoid import (
DeepSigmoid
DeepSigmoid,
DenseSigmoid,
DeepDenseSigmoid
)
from torchflows.bijections.base import invert

Expand Down Expand Up @@ -181,7 +183,7 @@ def __init__(self,
super().__init__(transformer, coupling, conditioner_transform)


class DSCoupling(CouplingBijection):
class DeepSigmoidalCoupling(CouplingBijection):
def __init__(self,
event_shape: torch.Size,
context_shape: torch.Size = None,
Expand All @@ -207,6 +209,58 @@ def __init__(self,
super().__init__(transformer, coupling, conditioner_transform)


class DenseSigmoidalCoupling(CouplingBijection):
def __init__(self,
event_shape: torch.Size,
context_shape: torch.Size = None,
n_dense_layers: int = 2,
edge_list: List[Tuple[int, int]] = None,
coupling_kwargs: dict = None,
**kwargs):
if coupling_kwargs is None:
coupling_kwargs = dict()
coupling = make_coupling(event_shape, edge_list, **coupling_kwargs)
transformer = DenseSigmoid(
event_shape=torch.Size((coupling.target_event_size,)),
n_dense_layers=n_dense_layers
)
# Parameter order: [c1, c2, c3, c4, ..., ck] for all components
# Each component has parameter order [a_unc, b, w_unc]
conditioner_transform = FeedForward(
input_event_shape=torch.Size((coupling.source_event_size,)),
parameter_shape=torch.Size(transformer.parameter_shape),
context_shape=context_shape,
**kwargs
)
super().__init__(transformer, coupling, conditioner_transform)


class DeepDenseSigmoidalCoupling(CouplingBijection):
def __init__(self,
event_shape: torch.Size,
context_shape: torch.Size = None,
n_hidden_layers: int = 2,
edge_list: List[Tuple[int, int]] = None,
coupling_kwargs: dict = None,
**kwargs):
if coupling_kwargs is None:
coupling_kwargs = dict()
coupling = make_coupling(event_shape, edge_list, **coupling_kwargs)
transformer = DeepDenseSigmoid(
event_shape=torch.Size((coupling.target_event_size,)),
n_hidden_layers=n_hidden_layers
)
# Parameter order: [c1, c2, c3, c4, ..., ck] for all components
# Each component has parameter order [a_unc, b, w_unc]
conditioner_transform = FeedForward(
input_event_shape=torch.Size((coupling.source_event_size,)),
parameter_shape=torch.Size(transformer.parameter_shape),
context_shape=context_shape,
**kwargs
)
super().__init__(transformer, coupling, conditioner_transform)


class LinearAffineCoupling(AffineCoupling):
def __init__(self, event_shape: torch.Size, **kwargs):
super().__init__(event_shape, **kwargs, n_layers=1)
Expand Down Expand Up @@ -276,6 +330,16 @@ def __init__(self,
super().__init__(event_shape, context_shape, transformer=transformer, **kwargs)


class LRSInverseMaskedAutoregressive(InverseMaskedAutoregressiveBijection):
def __init__(self,
event_shape: torch.Size,
context_shape: torch.Size = None,
n_bins: int = 8,
**kwargs):
transformer: ScalarTransformer = LinearRational(event_shape=event_shape, n_bins=n_bins)
super().__init__(event_shape, context_shape, transformer=transformer, **kwargs)


class UMNNMaskedAutoregressive(MaskedAutoregressiveBijection):
def __init__(self,
event_shape: torch.Size,
Expand Down

0 comments on commit 49e265d

Please sign in to comment.