Skip to content

Commit

Permalink
Add AR-LRS
Browse files Browse the repository at this point in the history
  • Loading branch information
davidnabergoj committed Oct 3, 2023
1 parent 70f9b30 commit d73111a
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@
InverseAffineCoupling,
DSCoupling,
ElementwiseAffine,
UMNNMaskedAutoregressive
UMNNMaskedAutoregressive,
LRSCoupling,
LRSForwardMaskedAutoregressive
)
from normalizing_flows.bijections.base import BijectiveComposition
from normalizing_flows.bijections.finite.linear import ReversePermutation
Expand Down Expand Up @@ -121,6 +123,34 @@ def __init__(self, event_shape, n_layers: int = 2, **kwargs):
super().__init__(event_shape, bijections, **kwargs)


class CouplingLRS(BijectiveComposition):
def __init__(self, event_shape, n_layers: int = 2, **kwargs):
if isinstance(event_shape, int):
event_shape = (event_shape,)
bijections = [ElementwiseAffine(event_shape=event_shape)]
for _ in range(n_layers):
bijections.extend([
ReversePermutation(event_shape=event_shape),
LRSCoupling(event_shape=event_shape)
])
bijections.append(ElementwiseAffine(event_shape=event_shape))
super().__init__(event_shape, bijections, **kwargs)


class MaskedAutoregressiveLRS(BijectiveComposition):
def __init__(self, event_shape, n_layers: int = 2, **kwargs):
if isinstance(event_shape, int):
event_shape = (event_shape,)
bijections = [ElementwiseAffine(event_shape=event_shape)]
for _ in range(n_layers):
bijections.extend([
ReversePermutation(event_shape=event_shape),
LRSForwardMaskedAutoregressive(event_shape=event_shape)
])
bijections.append(ElementwiseAffine(event_shape=event_shape))
super().__init__(event_shape, bijections, **kwargs)


class InverseAutoregressiveRQNSF(BijectiveComposition):
def __init__(self, event_shape, n_layers: int = 2, **kwargs):
if isinstance(event_shape, int):
Expand Down
21 changes: 21 additions & 0 deletions normalizing_flows/bijections/finite/autoregressive/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,27 @@ def __init__(self,
conditioner_transform=conditioner_transform
)

class LRSForwardMaskedAutoregressive(ForwardMaskedAutoregressiveBijection):
def __init__(self,
event_shape: torch.Size,
context_shape: torch.Size = None,
n_bins: int = 8,
**kwargs):
transformer = LinearRational(event_shape=event_shape, n_bins=n_bins)
conditioner_transform = MADE(
input_event_shape=event_shape,
output_event_shape=event_shape,
n_predicted_parameters=transformer.n_parameters,
context_shape=context_shape,
**kwargs
)
conditioner = MaskedAutoregressive()
super().__init__(
conditioner=conditioner,
transformer=transformer,
conditioner_transform=conditioner_transform
)


class AffineInverseMaskedAutoregressive(InverseMaskedAutoregressiveBijection):
def __init__(self,
Expand Down

0 comments on commit d73111a

Please sign in to comment.