Skip to content

Commit

Permalink
Add graphical affine coupling layer
Browse files Browse the repository at this point in the history
  • Loading branch information
davidnabergoj committed Feb 9, 2024
1 parent a43097e commit 00e9c80
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 3 deletions.
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Tuple, List

import torch


Expand Down Expand Up @@ -54,3 +56,10 @@ class HalfSplit(Coupling):
def __init__(self, event_shape):
event_size = int(torch.prod(torch.as_tensor(event_shape)))
super().__init__(event_shape, mask=torch.less(torch.arange(event_size).view(*event_shape), event_size // 2))


class GraphicalCoupling(PartialCoupling):
def __init__(self, event_shape, edge_list: List[Tuple[int, int]]):
source_mask = torch.tensor(sorted(list(set([e[0] for e in edge_list]))))
target_mask = torch.tensor(sorted(list(set([e[1] for e in edge_list]))))
super().__init__(event_shape, source_mask, target_mask)
23 changes: 22 additions & 1 deletion normalizing_flows/bijections/finite/autoregressive/layers.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from typing import Tuple, List

import torch

from normalizing_flows.bijections.finite.autoregressive.conditioning.transforms import FeedForward
from normalizing_flows.bijections.finite.autoregressive.conditioning.coupling_masks import HalfSplit
from normalizing_flows.bijections.finite.autoregressive.conditioning.coupling_masks import HalfSplit, GraphicalCoupling
from normalizing_flows.bijections.finite.autoregressive.layers_base import MaskedAutoregressiveBijection, \
InverseMaskedAutoregressiveBijection, ElementwiseBijection, CouplingBijection
from normalizing_flows.bijections.finite.autoregressive.transformers.linear.affine import Scale, Affine, Shift
Expand Down Expand Up @@ -59,6 +61,25 @@ def __init__(self,
super().__init__(transformer, coupling, conditioner_transform)


class GraphicalAffine(CouplingBijection):
def __init__(self,
event_shape: torch.Size,
edge_list: List[Tuple[int, int]] = None,
context_shape: torch.Size = None,
**kwargs):
if event_shape == (1,):
raise ValueError
coupling = GraphicalCoupling(event_shape, edge_list=edge_list)
transformer = Affine(event_shape=torch.Size((coupling.target_event_size,)))
conditioner_transform = FeedForward(
input_event_shape=torch.Size((coupling.source_event_size,)),
parameter_shape=torch.Size(transformer.parameter_shape),
context_shape=context_shape,
**kwargs
)
super().__init__(transformer, coupling, conditioner_transform)


class InverseAffineCoupling(CouplingBijection):
def __init__(self,
event_shape: torch.Size,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from normalizing_flows.bijections.finite.autoregressive.conditioning.transforms import ConditionerTransform, \
MADE
from normalizing_flows.bijections.finite.autoregressive.conditioning.coupling_masks import Coupling
from normalizing_flows.bijections.finite.autoregressive.conditioning.coupling_masks import PartialCoupling
from normalizing_flows.bijections.finite.autoregressive.transformers.base import TensorTransformer, ScalarTransformer
from normalizing_flows.bijections.base import Bijection
from normalizing_flows.utils import flatten_event, unflatten_event, get_batch_shape
Expand Down Expand Up @@ -52,7 +52,7 @@ class CouplingBijection(AutoregressiveBijection):

def __init__(self,
transformer: TensorTransformer,
coupling: Coupling,
coupling: PartialCoupling,
conditioner_transform: ConditionerTransform,
**kwargs):
super().__init__(coupling.event_shape, transformer, conditioner_transform, **kwargs)
Expand Down

0 comments on commit 00e9c80

Please sign in to comment.