Skip to content

Commit

Permalink
Rework graphical coupling
Browse files Browse the repository at this point in the history
  • Loading branch information
davidnabergoj committed Feb 9, 2024
1 parent 00e9c80 commit c0104a2
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 44 deletions.
29 changes: 16 additions & 13 deletions normalizing_flows/bijections/finite/autoregressive/architectures.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Tuple, List

from normalizing_flows.bijections.finite.autoregressive.layers import (
ShiftCoupling,
AffineCoupling,
Expand All @@ -11,49 +13,50 @@
ElementwiseAffine,
UMNNMaskedAutoregressive,
LRSCoupling,
LRSForwardMaskedAutoregressive, ElementwiseShift
LRSForwardMaskedAutoregressive,
ElementwiseShift
)
from normalizing_flows.bijections.base import BijectiveComposition
from normalizing_flows.bijections.finite.linear import ReversePermutation


class NICE(BijectiveComposition):
    """NICE-style flow: alternating reverse permutations and shift couplings,
    sandwiched between two elementwise affine layers.
    """

    def __init__(self, event_shape, n_layers: int = 2, edge_list: List[Tuple[int, int]] = None, **kwargs):
        """
        :param event_shape: shape of the event tensor; an int is treated as a 1D shape.
        :param n_layers: number of (reverse permutation, shift coupling) pairs.
        :param edge_list: optional (source, target) index pairs forwarded to each
            ShiftCoupling to define a graphical coupling mask; None keeps the default.
        :param kwargs: forwarded to BijectiveComposition.
        """
        if isinstance(event_shape, int):
            event_shape = (event_shape,)
        bijections = [ElementwiseAffine(event_shape=event_shape)]
        for _ in range(n_layers):
            bijections.extend([
                ReversePermutation(event_shape=event_shape),
                ShiftCoupling(event_shape=event_shape, edge_list=edge_list)
            ])
        bijections.append(ElementwiseAffine(event_shape=event_shape))
        super().__init__(event_shape, bijections, **kwargs)


class RealNVP(BijectiveComposition):
    """RealNVP-style flow: alternating reverse permutations and affine couplings,
    sandwiched between two elementwise affine layers.
    """

    def __init__(self, event_shape, n_layers: int = 2, edge_list: List[Tuple[int, int]] = None, **kwargs):
        """
        :param event_shape: shape of the event tensor; an int is treated as a 1D shape.
        :param n_layers: number of (reverse permutation, affine coupling) pairs.
        :param edge_list: optional (source, target) index pairs forwarded to each
            AffineCoupling to define a graphical coupling mask; None keeps the default.
        :param kwargs: forwarded to BijectiveComposition.
        """
        if isinstance(event_shape, int):
            event_shape = (event_shape,)
        bijections = [ElementwiseAffine(event_shape=event_shape)]
        for _ in range(n_layers):
            bijections.extend([
                ReversePermutation(event_shape=event_shape),
                AffineCoupling(event_shape=event_shape, edge_list=edge_list)
            ])
        bijections.append(ElementwiseAffine(event_shape=event_shape))
        super().__init__(event_shape, bijections, **kwargs)


class InverseRealNVP(BijectiveComposition):
    """RealNVP variant built from inverse affine couplings instead of forward ones."""

    def __init__(self, event_shape, n_layers: int = 2, edge_list: List[Tuple[int, int]] = None, **kwargs):
        """
        :param event_shape: shape of the event tensor; an int is treated as a 1D shape.
        :param n_layers: number of (reverse permutation, inverse affine coupling) pairs.
        :param edge_list: optional (source, target) index pairs forwarded to each
            InverseAffineCoupling to define a graphical coupling mask; None keeps the default.
        :param kwargs: forwarded to BijectiveComposition.
        """
        if isinstance(event_shape, int):
            event_shape = (event_shape,)
        bijections = [ElementwiseAffine(event_shape=event_shape)]
        for _ in range(n_layers):
            bijections.extend([
                ReversePermutation(event_shape=event_shape),
                InverseAffineCoupling(event_shape=event_shape, edge_list=edge_list)
            ])
        bijections.append(ElementwiseAffine(event_shape=event_shape))
        super().__init__(event_shape, bijections, **kwargs)
Expand Down Expand Up @@ -92,14 +95,14 @@ def __init__(self, event_shape, n_layers: int = 2, **kwargs):


class CouplingRQNSF(BijectiveComposition):
    """Coupling rational-quadratic neural spline flow: alternating reverse
    permutations and RQS couplings, sandwiched between elementwise affine layers.
    """

    def __init__(self, event_shape, n_layers: int = 2, edge_list: List[Tuple[int, int]] = None, **kwargs):
        """
        :param event_shape: shape of the event tensor; an int is treated as a 1D shape.
        :param n_layers: number of (reverse permutation, RQS coupling) pairs.
        :param edge_list: optional (source, target) index pairs forwarded to each
            RQSCoupling to define a graphical coupling mask; None keeps the default.
        :param kwargs: forwarded to BijectiveComposition.
        """
        if isinstance(event_shape, int):
            event_shape = (event_shape,)
        bijections = [ElementwiseAffine(event_shape=event_shape)]
        for _ in range(n_layers):
            bijections.extend([
                ReversePermutation(event_shape=event_shape),
                RQSCoupling(event_shape=event_shape, edge_list=edge_list)
            ])
        bijections.append(ElementwiseAffine(event_shape=event_shape))
        super().__init__(event_shape, bijections, **kwargs)
Expand All @@ -124,14 +127,14 @@ def __init__(self, event_shape, n_layers: int = 2, **kwargs):


class CouplingLRS(BijectiveComposition):
    """Coupling linear-rational spline flow: alternating reverse permutations and
    LRS couplings, sandwiched between elementwise shift layers.
    """

    def __init__(self, event_shape, n_layers: int = 2, edge_list: List[Tuple[int, int]] = None, **kwargs):
        """
        :param event_shape: shape of the event tensor; an int is treated as a 1D shape.
        :param n_layers: number of (reverse permutation, LRS coupling) pairs.
        :param edge_list: optional (source, target) index pairs forwarded to each
            LRSCoupling to define a graphical coupling mask; None keeps the default.
        :param kwargs: forwarded to BijectiveComposition.
        """
        if isinstance(event_shape, int):
            event_shape = (event_shape,)
        bijections = [ElementwiseShift(event_shape=event_shape)]
        for _ in range(n_layers):
            bijections.extend([
                ReversePermutation(event_shape=event_shape),
                LRSCoupling(event_shape=event_shape, edge_list=edge_list)
            ])
        bijections.append(ElementwiseShift(event_shape=event_shape))
        super().__init__(event_shape, bijections, **kwargs)
Expand Down Expand Up @@ -166,14 +169,14 @@ def __init__(self, event_shape, n_layers: int = 2, **kwargs):


class CouplingDSF(BijectiveComposition):
    """Coupling deep sigmoidal flow: alternating reverse permutations and DS
    couplings, sandwiched between elementwise affine layers.
    """

    def __init__(self, event_shape, n_layers: int = 2, edge_list: List[Tuple[int, int]] = None, **kwargs):
        """
        :param event_shape: shape of the event tensor; an int is treated as a 1D shape.
        :param n_layers: number of (reverse permutation, DS coupling) pairs.
        :param edge_list: optional (source, target) index pairs forwarded to each
            DSCoupling to define a graphical coupling mask; None keeps the default.
        :param kwargs: forwarded to BijectiveComposition.
        """
        if isinstance(event_shape, int):
            event_shape = (event_shape,)
        bijections = [ElementwiseAffine(event_shape=event_shape)]
        for _ in range(n_layers):
            bijections.extend([
                ReversePermutation(event_shape=event_shape),
                DSCoupling(event_shape=event_shape, edge_list=edge_list)  # TODO specify percent of global parameters
            ])
        bijections.append(ElementwiseAffine(event_shape=event_shape))
        super().__init__(event_shape, bijections, **kwargs)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,14 +52,21 @@ def ignored_event_size(self):
return 0


class GraphicalCoupling(PartialCoupling):
    """Partial coupling whose source and target dimensions are given by an edge list.

    Each (source, target) pair marks one conditioning dependency; the distinct
    sources and targets (sorted ascending) become the source and target masks.
    """

    def __init__(self, event_shape, edge_list: List[Tuple[int, int]]):
        sources = sorted({edge[0] for edge in edge_list})
        targets = sorted({edge[1] for edge in edge_list})
        super().__init__(event_shape, torch.tensor(sources), torch.tensor(targets))


class HalfSplit(Coupling):
    """Default coupling mask: the first half of the flattened event dimensions
    (mask True) conditions the second half (mask False).
    """

    def __init__(self, event_shape):
        n_dim = int(torch.prod(torch.as_tensor(event_shape)))
        half_mask = torch.arange(n_dim).view(*event_shape) < n_dim // 2
        super().__init__(event_shape, mask=half_mask)


def make_coupling(event_shape, edge_list: List[Tuple[int, int]] = None):
    """Create a coupling mask for the given event shape.

    :param event_shape: shape of the event tensor.
    :param edge_list: optional list of (source, target) index pairs describing
        which dimensions condition which; None selects the default half-split.
    :return: HalfSplit when edge_list is None, otherwise GraphicalCoupling.
    """
    if edge_list is None:
        return HalfSplit(event_shape)
    else:
        return GraphicalCoupling(event_shape, edge_list)
38 changes: 12 additions & 26 deletions normalizing_flows/bijections/finite/autoregressive/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import torch

from normalizing_flows.bijections.finite.autoregressive.conditioning.transforms import FeedForward
from normalizing_flows.bijections.finite.autoregressive.conditioning.coupling_masks import HalfSplit, GraphicalCoupling
from normalizing_flows.bijections.finite.autoregressive.conditioning.coupling_masks import make_coupling
from normalizing_flows.bijections.finite.autoregressive.layers_base import MaskedAutoregressiveBijection, \
InverseMaskedAutoregressiveBijection, ElementwiseBijection, CouplingBijection
from normalizing_flows.bijections.finite.autoregressive.transformers.linear.affine import Scale, Affine, Shift
Expand All @@ -18,7 +18,6 @@
from normalizing_flows.bijections.base import invert


# TODO move elementwise bijections, coupling bijections, and masked autoregressive bijections into separate files.
class ElementwiseAffine(ElementwiseBijection):
def __init__(self, event_shape, **kwargs):
transformer = Affine(event_shape, **kwargs)
class AffineCoupling(CouplingBijection):
    """Coupling bijection with an affine transformer on the target partition."""

    def __init__(self,
                 event_shape: torch.Size,
                 context_shape: torch.Size = None,
                 edge_list: List[Tuple[int, int]] = None,
                 **kwargs):
        """
        :param event_shape: shape of the event tensor; must have more than one element.
        :param context_shape: optional shape of the conditioning context tensor.
        :param edge_list: optional (source, target) index pairs defining a graphical
            coupling mask; None selects the default half-split coupling.
        :param kwargs: forwarded to the FeedForward conditioner transform.
        """
        # A single event dimension cannot be split into source and target partitions.
        if event_shape == (1,):
            raise ValueError
        coupling = make_coupling(event_shape, edge_list)
        transformer = Affine(event_shape=torch.Size((coupling.target_event_size,)))
        conditioner_transform = FeedForward(
            input_event_shape=torch.Size((coupling.source_event_size,)),
            parameter_shape=torch.Size(transformer.parameter_shape),
            context_shape=context_shape,
            **kwargs
        )
        super().__init__(transformer, coupling, conditioner_transform)
class InverseAffineCoupling(CouplingBijection):
    """Coupling bijection with an inverted affine transformer on the target partition."""

    def __init__(self,
                 event_shape: torch.Size,
                 context_shape: torch.Size = None,
                 edge_list: List[Tuple[int, int]] = None,
                 **kwargs):
        """
        :param event_shape: shape of the event tensor; must have more than one element.
        :param context_shape: optional shape of the conditioning context tensor.
        :param edge_list: optional (source, target) index pairs defining a graphical
            coupling mask; None selects the default half-split coupling.
        :param kwargs: forwarded to the FeedForward conditioner transform.
        """
        # A single event dimension cannot be split into source and target partitions.
        if event_shape == (1,):
            raise ValueError
        coupling = make_coupling(event_shape, edge_list)
        transformer = Affine(event_shape=torch.Size((coupling.target_event_size,))).invert()
        conditioner_transform = FeedForward(
            input_event_shape=torch.Size((coupling.source_event_size,)),
            parameter_shape=torch.Size(transformer.parameter_shape),
            context_shape=context_shape,
            **kwargs
        )
        super().__init__(transformer, coupling, conditioner_transform)
class ShiftCoupling(CouplingBijection):
    """Coupling bijection with a shift (additive) transformer on the target partition."""

    def __init__(self,
                 event_shape: torch.Size,
                 context_shape: torch.Size = None,
                 edge_list: List[Tuple[int, int]] = None,
                 **kwargs):
        """
        :param event_shape: shape of the event tensor.
        :param context_shape: optional shape of the conditioning context tensor.
        :param edge_list: optional (source, target) index pairs defining a graphical
            coupling mask; None selects the default half-split coupling.
        :param kwargs: forwarded to the FeedForward conditioner transform.
        """
        coupling = make_coupling(event_shape, edge_list)
        transformer = Shift(event_shape=torch.Size((coupling.target_event_size,)))
        conditioner_transform = FeedForward(
            input_event_shape=torch.Size((coupling.source_event_size,)),
            parameter_shape=torch.Size(transformer.parameter_shape),
            context_shape=context_shape,
            **kwargs
        )
        super().__init__(transformer, coupling, conditioner_transform)
class LRSCoupling(CouplingBijection):
    """Coupling bijection with a linear-rational spline transformer on the target partition."""

    def __init__(self,
                 event_shape: torch.Size,
                 context_shape: torch.Size = None,
                 n_bins: int = 8,
                 edge_list: List[Tuple[int, int]] = None,
                 **kwargs):
        """
        :param event_shape: shape of the event tensor.
        :param context_shape: optional shape of the conditioning context tensor.
        :param n_bins: number of spline bins; must be at least 1.
        :param edge_list: optional (source, target) index pairs defining a graphical
            coupling mask; None selects the default half-split coupling.
        :param kwargs: forwarded to the FeedForward conditioner transform.
        """
        assert n_bins >= 1
        coupling = make_coupling(event_shape, edge_list)
        transformer = LinearRational(event_shape=torch.Size((coupling.target_event_size,)), n_bins=n_bins)
        conditioner_transform = FeedForward(
            input_event_shape=torch.Size((coupling.source_event_size,)),
            parameter_shape=torch.Size(transformer.parameter_shape),
            context_shape=context_shape,
            **kwargs
        )
        super().__init__(transformer, coupling, conditioner_transform)
class RQSCoupling(CouplingBijection):
    """Coupling bijection with a rational-quadratic spline transformer on the target partition."""

    def __init__(self,
                 event_shape: torch.Size,
                 context_shape: torch.Size = None,
                 n_bins: int = 8,
                 edge_list: List[Tuple[int, int]] = None,
                 **kwargs):
        """
        :param event_shape: shape of the event tensor.
        :param context_shape: optional shape of the conditioning context tensor.
        :param n_bins: number of spline bins.
        :param edge_list: optional (source, target) index pairs defining a graphical
            coupling mask; None selects the default half-split coupling.
        :param kwargs: forwarded to the FeedForward conditioner transform.
        """
        coupling = make_coupling(event_shape, edge_list)
        transformer = RationalQuadratic(event_shape=torch.Size((coupling.target_event_size,)), n_bins=n_bins)
        conditioner_transform = FeedForward(
            input_event_shape=torch.Size((coupling.source_event_size,)),
            parameter_shape=torch.Size(transformer.parameter_shape),
            context_shape=context_shape,
            **kwargs
        )
        super().__init__(transformer, coupling, conditioner_transform)
class DSCoupling(CouplingBijection):
    """Coupling bijection with a deep sigmoid transformer on the target partition."""

    def __init__(self,
                 event_shape: torch.Size,
                 context_shape: torch.Size = None,
                 n_hidden_layers: int = 2,
                 edge_list: List[Tuple[int, int]] = None,
                 **kwargs):
        """
        :param event_shape: shape of the event tensor.
        :param context_shape: optional shape of the conditioning context tensor.
        :param n_hidden_layers: number of hidden layers in the deep sigmoid transformer.
        :param edge_list: optional (source, target) index pairs defining a graphical
            coupling mask; None selects the default half-split coupling.
        :param kwargs: forwarded to the FeedForward conditioner transform.
        """
        coupling = make_coupling(event_shape, edge_list)
        transformer = DeepSigmoid(
            event_shape=torch.Size((coupling.target_event_size,)),
            n_hidden_layers=n_hidden_layers
        )
        conditioner_transform = FeedForward(
            input_event_shape=torch.Size((coupling.source_event_size,)),
            parameter_shape=torch.Size(transformer.parameter_shape),
            context_shape=context_shape,
            **kwargs
        )
        super().__init__(transformer, coupling, conditioner_transform)
Expand Down

0 comments on commit c0104a2

Please sign in to comment.