From c0104a27dff754617ffd49c3d3929e8fc8ee9a76 Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Fri, 9 Feb 2024 12:13:28 +0100 Subject: [PATCH] Rework graphical coupling --- .../finite/autoregressive/architectures.py | 29 +++++++------- .../conditioning/coupling_masks.py | 17 ++++++--- .../finite/autoregressive/layers.py | 38 ++++++------------- 3 files changed, 40 insertions(+), 44 deletions(-) diff --git a/normalizing_flows/bijections/finite/autoregressive/architectures.py b/normalizing_flows/bijections/finite/autoregressive/architectures.py index ca31f0a..8ad5838 100644 --- a/normalizing_flows/bijections/finite/autoregressive/architectures.py +++ b/normalizing_flows/bijections/finite/autoregressive/architectures.py @@ -1,3 +1,5 @@ +from typing import Tuple, List + from normalizing_flows.bijections.finite.autoregressive.layers import ( ShiftCoupling, AffineCoupling, @@ -11,49 +13,50 @@ ElementwiseAffine, UMNNMaskedAutoregressive, LRSCoupling, - LRSForwardMaskedAutoregressive, ElementwiseShift + LRSForwardMaskedAutoregressive, + ElementwiseShift ) from normalizing_flows.bijections.base import BijectiveComposition from normalizing_flows.bijections.finite.linear import ReversePermutation class NICE(BijectiveComposition): - def __init__(self, event_shape, n_layers: int = 2, **kwargs): + def __init__(self, event_shape, n_layers: int = 2, edge_list: List[Tuple[int, int]] = None, **kwargs): if isinstance(event_shape, int): event_shape = (event_shape,) bijections = [ElementwiseAffine(event_shape=event_shape)] for _ in range(n_layers): bijections.extend([ ReversePermutation(event_shape=event_shape), - ShiftCoupling(event_shape=event_shape) + ShiftCoupling(event_shape=event_shape, edge_list=edge_list) ]) bijections.append(ElementwiseAffine(event_shape=event_shape)) super().__init__(event_shape, bijections, **kwargs) class RealNVP(BijectiveComposition): - def __init__(self, event_shape, n_layers: int = 2, **kwargs): + def __init__(self, event_shape, n_layers: int = 2, edge_list: List[Tuple[int, int]] = None, **kwargs): if isinstance(event_shape, int): event_shape = (event_shape,) bijections = [ElementwiseAffine(event_shape=event_shape)] for _ in range(n_layers): bijections.extend([ ReversePermutation(event_shape=event_shape), - AffineCoupling(event_shape=event_shape) + AffineCoupling(event_shape=event_shape, edge_list=edge_list) ]) bijections.append(ElementwiseAffine(event_shape=event_shape)) super().__init__(event_shape, bijections, **kwargs) class InverseRealNVP(BijectiveComposition): - def __init__(self, event_shape, n_layers: int = 2, **kwargs): + def __init__(self, event_shape, n_layers: int = 2, edge_list: List[Tuple[int, int]] = None, **kwargs): if isinstance(event_shape, int): event_shape = (event_shape,) bijections = [ElementwiseAffine(event_shape=event_shape)] for _ in range(n_layers): bijections.extend([ ReversePermutation(event_shape=event_shape), - InverseAffineCoupling(event_shape=event_shape) + InverseAffineCoupling(event_shape=event_shape, edge_list=edge_list) ]) bijections.append(ElementwiseAffine(event_shape=event_shape)) super().__init__(event_shape, bijections, **kwargs) @@ -92,14 +95,14 @@ def __init__(self, event_shape, n_layers: int = 2, **kwargs): class CouplingRQNSF(BijectiveComposition): - def __init__(self, event_shape, n_layers: int = 2, **kwargs): + def __init__(self, event_shape, n_layers: int = 2, edge_list: List[Tuple[int, int]] = None, **kwargs): if isinstance(event_shape, int): event_shape = (event_shape,) bijections = [ElementwiseAffine(event_shape=event_shape)] for _ in range(n_layers): bijections.extend([ ReversePermutation(event_shape=event_shape), - RQSCoupling(event_shape=event_shape) + RQSCoupling(event_shape=event_shape, edge_list=edge_list) ]) bijections.append(ElementwiseAffine(event_shape=event_shape)) super().__init__(event_shape, bijections, **kwargs) @@ -124,14 +127,14 @@ def __init__(self, event_shape, n_layers: int = 2, **kwargs): class CouplingLRS(BijectiveComposition): - def __init__(self, event_shape, n_layers: int = 2, **kwargs): + def __init__(self, event_shape, n_layers: int = 2, edge_list: List[Tuple[int, int]] = None, **kwargs): if isinstance(event_shape, int): event_shape = (event_shape,) bijections = [ElementwiseShift(event_shape=event_shape)] for _ in range(n_layers): bijections.extend([ ReversePermutation(event_shape=event_shape), - LRSCoupling(event_shape=event_shape) + LRSCoupling(event_shape=event_shape, edge_list=edge_list) ]) bijections.append(ElementwiseShift(event_shape=event_shape)) super().__init__(event_shape, bijections, **kwargs) @@ -166,14 +169,14 @@ def __init__(self, event_shape, n_layers: int = 2, **kwargs): class CouplingDSF(BijectiveComposition): - def __init__(self, event_shape, n_layers: int = 2, **kwargs): + def __init__(self, event_shape, n_layers: int = 2, edge_list: List[Tuple[int, int]] = None, **kwargs): if isinstance(event_shape, int): event_shape = (event_shape,) bijections = [ElementwiseAffine(event_shape=event_shape)] for _ in range(n_layers): bijections.extend([ ReversePermutation(event_shape=event_shape), - DSCoupling(event_shape=event_shape) # TODO specify percent of global parameters + DSCoupling(event_shape=event_shape, edge_list=edge_list) # TODO specify percent of global parameters ]) bijections.append(ElementwiseAffine(event_shape=event_shape)) super().__init__(event_shape, bijections, **kwargs) diff --git a/normalizing_flows/bijections/finite/autoregressive/conditioning/coupling_masks.py b/normalizing_flows/bijections/finite/autoregressive/conditioning/coupling_masks.py index 422c8dc..1621732 100644 --- a/normalizing_flows/bijections/finite/autoregressive/conditioning/coupling_masks.py +++ b/normalizing_flows/bijections/finite/autoregressive/conditioning/coupling_masks.py @@ -52,14 +52,21 @@ def ignored_event_size(self): return 0 +class GraphicalCoupling(PartialCoupling): + def __init__(self, event_shape, edge_list: List[Tuple[int, int]]): + source_mask = torch.tensor(sorted(list(set([e[0] for e in edge_list])))) + target_mask = torch.tensor(sorted(list(set([e[1] for e in edge_list])))) + super().__init__(event_shape, source_mask, target_mask) + + class HalfSplit(Coupling): def __init__(self, event_shape): event_size = int(torch.prod(torch.as_tensor(event_shape))) super().__init__(event_shape, mask=torch.less(torch.arange(event_size).view(*event_shape), event_size // 2)) -class GraphicalCoupling(PartialCoupling): - def __init__(self, event_shape, edge_list: List[Tuple[int, int]]): - source_mask = torch.tensor(sorted(list(set([e[0] for e in edge_list])))) - target_mask = torch.tensor(sorted(list(set([e[1] for e in edge_list])))) - super().__init__(event_shape, source_mask, target_mask) +def make_coupling(event_shape, edge_list: List[Tuple[int, int]] = None): + if edge_list is None: + return HalfSplit(event_shape) + else: + return GraphicalCoupling(event_shape, edge_list) diff --git a/normalizing_flows/bijections/finite/autoregressive/layers.py b/normalizing_flows/bijections/finite/autoregressive/layers.py index 2b29ba3..d6d613e 100644 --- a/normalizing_flows/bijections/finite/autoregressive/layers.py +++ b/normalizing_flows/bijections/finite/autoregressive/layers.py @@ -3,7 +3,7 @@ import torch from normalizing_flows.bijections.finite.autoregressive.conditioning.transforms import FeedForward -from normalizing_flows.bijections.finite.autoregressive.conditioning.coupling_masks import HalfSplit, GraphicalCoupling +from normalizing_flows.bijections.finite.autoregressive.conditioning.coupling_masks import make_coupling from normalizing_flows.bijections.finite.autoregressive.layers_base import MaskedAutoregressiveBijection, \ InverseMaskedAutoregressiveBijection, ElementwiseBijection, CouplingBijection from normalizing_flows.bijections.finite.autoregressive.transformers.linear.affine import Scale, Affine, Shift @@ -18,7 +18,6 @@ from normalizing_flows.bijections.base import invert -# TODO move elementwise bijections, coupling bijections, and masked autoregressive bijections into separate files. class ElementwiseAffine(ElementwiseBijection): def __init__(self, event_shape, **kwargs): transformer = Affine(event_shape, **kwargs) @@ -47,29 +46,11 @@ class AffineCoupling(CouplingBijection): def __init__(self, event_shape: torch.Size, context_shape: torch.Size = None, - **kwargs): - if event_shape == (1,): - raise ValueError - coupling = HalfSplit(event_shape) - transformer = Affine(event_shape=torch.Size((coupling.target_event_size,))) - conditioner_transform = FeedForward( - input_event_shape=torch.Size((coupling.source_event_size,)), - parameter_shape=torch.Size(transformer.parameter_shape), - context_shape=context_shape, - **kwargs - ) - super().__init__(transformer, coupling, conditioner_transform) - - -class GraphicalAffine(CouplingBijection): - def __init__(self, - event_shape: torch.Size, edge_list: List[Tuple[int, int]] = None, - context_shape: torch.Size = None, **kwargs): if event_shape == (1,): raise ValueError - coupling = GraphicalCoupling(event_shape, edge_list=edge_list) + coupling = make_coupling(event_shape, edge_list) transformer = Affine(event_shape=torch.Size((coupling.target_event_size,))) conditioner_transform = FeedForward( input_event_shape=torch.Size((coupling.source_event_size,)), @@ -84,10 +65,11 @@ class InverseAffineCoupling(CouplingBijection): def __init__(self, event_shape: torch.Size, context_shape: torch.Size = None, + edge_list: List[Tuple[int, int]] = None, **kwargs): if event_shape == (1,): raise ValueError - coupling = HalfSplit(event_shape) + coupling = make_coupling(event_shape, edge_list) transformer = Affine(event_shape=torch.Size((coupling.target_event_size,))).invert() conditioner_transform = FeedForward( input_event_shape=torch.Size((coupling.source_event_size,)), @@ -102,8 +84,9 @@ class ShiftCoupling(CouplingBijection): def __init__(self, event_shape: torch.Size, context_shape: torch.Size = None, + edge_list: List[Tuple[int, int]] = None, **kwargs): - coupling = HalfSplit(event_shape) + coupling = make_coupling(event_shape, edge_list) transformer = Shift(event_shape=torch.Size((coupling.target_event_size,))) conditioner_transform = FeedForward( input_event_shape=torch.Size((coupling.source_event_size,)), @@ -119,9 +102,10 @@ def __init__(self, event_shape: torch.Size, context_shape: torch.Size = None, n_bins: int = 8, + edge_list: List[Tuple[int, int]] = None, **kwargs): assert n_bins >= 1 - coupling = HalfSplit(event_shape) + coupling = make_coupling(event_shape, edge_list) transformer = LinearRational(event_shape=torch.Size((coupling.target_event_size,)), n_bins=n_bins) conditioner_transform = FeedForward( input_event_shape=torch.Size((coupling.source_event_size,)), @@ -137,8 +121,9 @@ def __init__(self, event_shape: torch.Size, context_shape: torch.Size = None, n_bins: int = 8, + edge_list: List[Tuple[int, int]] = None, **kwargs): - coupling = HalfSplit(event_shape) + coupling = make_coupling(event_shape, edge_list) transformer = RationalQuadratic(event_shape=torch.Size((coupling.target_event_size,)), n_bins=n_bins) conditioner_transform = FeedForward( input_event_shape=torch.Size((coupling.source_event_size,)), @@ -154,8 +139,9 @@ def __init__(self, event_shape: torch.Size, context_shape: torch.Size = None, n_hidden_layers: int = 2, + edge_list: List[Tuple[int, int]] = None, **kwargs): - coupling = HalfSplit(event_shape) + coupling = make_coupling(event_shape, edge_list) transformer = DeepSigmoid( event_shape=torch.Size((coupling.target_event_size,)), n_hidden_layers=n_hidden_layers