From 00e9c8057c8810063c34d14583f9106f5e6917f5 Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Fri, 9 Feb 2024 11:49:19 +0100 Subject: [PATCH] Add graphical affine coupling layer --- .../conditioning/coupling_masks.py | 9 ++++++++ .../finite/autoregressive/layers.py | 23 ++++++++++++++++++- .../finite/autoregressive/layers_base.py | 4 ++-- 3 files changed, 33 insertions(+), 3 deletions(-) diff --git a/normalizing_flows/bijections/finite/autoregressive/conditioning/coupling_masks.py b/normalizing_flows/bijections/finite/autoregressive/conditioning/coupling_masks.py index 7b829f2..422c8dc 100644 --- a/normalizing_flows/bijections/finite/autoregressive/conditioning/coupling_masks.py +++ b/normalizing_flows/bijections/finite/autoregressive/conditioning/coupling_masks.py @@ -1,3 +1,5 @@ +from typing import Tuple, List + import torch @@ -54,3 +56,10 @@ class HalfSplit(Coupling): def __init__(self, event_shape): event_size = int(torch.prod(torch.as_tensor(event_shape))) super().__init__(event_shape, mask=torch.less(torch.arange(event_size).view(*event_shape), event_size // 2)) + + +class GraphicalCoupling(PartialCoupling): + def __init__(self, event_shape, edge_list: List[Tuple[int, int]]): + source_mask = torch.tensor(sorted(list(set([e[0] for e in edge_list])))) + target_mask = torch.tensor(sorted(list(set([e[1] for e in edge_list])))) + super().__init__(event_shape, source_mask, target_mask) diff --git a/normalizing_flows/bijections/finite/autoregressive/layers.py b/normalizing_flows/bijections/finite/autoregressive/layers.py index ca44d40..2b29ba3 100644 --- a/normalizing_flows/bijections/finite/autoregressive/layers.py +++ b/normalizing_flows/bijections/finite/autoregressive/layers.py @@ -1,7 +1,9 @@ +from typing import Tuple, List + import torch from normalizing_flows.bijections.finite.autoregressive.conditioning.transforms import FeedForward -from normalizing_flows.bijections.finite.autoregressive.conditioning.coupling_masks import HalfSplit +from normalizing_flows.bijections.finite.autoregressive.conditioning.coupling_masks import HalfSplit, GraphicalCoupling from normalizing_flows.bijections.finite.autoregressive.layers_base import MaskedAutoregressiveBijection, \ InverseMaskedAutoregressiveBijection, ElementwiseBijection, CouplingBijection from normalizing_flows.bijections.finite.autoregressive.transformers.linear.affine import Scale, Affine, Shift @@ -59,6 +61,25 @@ def __init__(self, super().__init__(transformer, coupling, conditioner_transform) +class GraphicalAffine(CouplingBijection): + def __init__(self, + event_shape: torch.Size, + edge_list: List[Tuple[int, int]] = None, + context_shape: torch.Size = None, + **kwargs): + if event_shape == (1,): + raise ValueError + coupling = GraphicalCoupling(event_shape, edge_list=edge_list) + transformer = Affine(event_shape=torch.Size((coupling.target_event_size,))) + conditioner_transform = FeedForward( + input_event_shape=torch.Size((coupling.source_event_size,)), + parameter_shape=torch.Size(transformer.parameter_shape), + context_shape=context_shape, + **kwargs + ) + super().__init__(transformer, coupling, conditioner_transform) + + class InverseAffineCoupling(CouplingBijection): def __init__(self, event_shape: torch.Size, diff --git a/normalizing_flows/bijections/finite/autoregressive/layers_base.py b/normalizing_flows/bijections/finite/autoregressive/layers_base.py index 2bb1cb4..0b18b34 100644 --- a/normalizing_flows/bijections/finite/autoregressive/layers_base.py +++ b/normalizing_flows/bijections/finite/autoregressive/layers_base.py @@ -5,7 +5,7 @@ from normalizing_flows.bijections.finite.autoregressive.conditioning.transforms import ConditionerTransform, \ MADE -from normalizing_flows.bijections.finite.autoregressive.conditioning.coupling_masks import Coupling +from normalizing_flows.bijections.finite.autoregressive.conditioning.coupling_masks import PartialCoupling from normalizing_flows.bijections.finite.autoregressive.transformers.base import TensorTransformer, ScalarTransformer from normalizing_flows.bijections.base import Bijection from normalizing_flows.utils import flatten_event, unflatten_event, get_batch_shape @@ -52,7 +52,7 @@ class CouplingBijection(AutoregressiveBijection): def __init__(self, transformer: TensorTransformer, - coupling: Coupling, + coupling: PartialCoupling, conditioner_transform: ConditionerTransform, **kwargs): super().__init__(coupling.event_shape, transformer, conditioner_transform, **kwargs)