From 1a5a57a2cd12df19e4fdf62b7c7086fe99efaa43 Mon Sep 17 00:00:00 2001
From: David Nabergoj
Date: Thu, 29 Aug 2024 20:18:52 +0200
Subject: [PATCH] Fix global conditioner predictions, make 80% of dense NAF outputs globally learned

---
 ...st_globally_learned_conditioner_outputs.py | 29 +++++++++
 test/test_reconstruction_bijections.py        |  7 ++-
 .../finite/autoregressive/architectures.py    | 59 ++++++++++++++++---
 .../autoregressive/conditioning/transforms.py | 39 ++++++++----
 .../finite/autoregressive/layers.py           | 56 ++++++++++++++++--
 5 files changed, 164 insertions(+), 26 deletions(-)
 create mode 100644 test/test_globally_learned_conditioner_outputs.py

diff --git a/test/test_globally_learned_conditioner_outputs.py b/test/test_globally_learned_conditioner_outputs.py
new file mode 100644
index 0000000..0a2dd0f
--- /dev/null
+++ b/test/test_globally_learned_conditioner_outputs.py
@@ -0,0 +1,29 @@
+import torch
+
+from torchflows.bijections.finite.autoregressive.conditioning.transforms import FeedForward
+
+
+def test_standard():
+    torch.manual_seed(0)
+
+    input_event_shape = torch.Size((10, 10))
+    parameter_shape = torch.Size((20, 3))
+    test_inputs = torch.randn(100, *input_event_shape)
+
+    t = FeedForward(input_event_shape, parameter_shape)
+    output = t(test_inputs)
+
+    assert output.shape == (100, *parameter_shape)
+
+
+def test_eighty_pct_global():
+    torch.manual_seed(0)
+
+    input_event_shape = torch.Size((10, 10))
+    parameter_shape = torch.Size((20, 3))
+    test_inputs = torch.randn(100, *input_event_shape)
+
+    t = FeedForward(input_event_shape, parameter_shape, percentage_global_parameters=0.8)
+    output = t(test_inputs)
+
+    assert output.shape == (100, *parameter_shape)
diff --git a/test/test_reconstruction_bijections.py b/test/test_reconstruction_bijections.py
index ae9d5fc..2aaaffd 100644
--- a/test/test_reconstruction_bijections.py
+++ b/test/test_reconstruction_bijections.py
@@ -11,7 +11,7 @@
 from torchflows.bijections.finite.autoregressive.architectures import NICE, RealNVP, CouplingRQNSF, MAF, IAF, \
     InverseAutoregressiveRQNSF, MaskedAutoregressiveRQNSF
 from torchflows.bijections.finite.autoregressive.layers import ElementwiseScale, ElementwiseAffine, ElementwiseShift, \
-    LRSCoupling, LinearRQSCoupling, ActNorm
+    LRSCoupling, LinearRQSCoupling, ActNorm, DenseSigmoidalCoupling, DeepDenseSigmoidalCoupling, DeepSigmoidalCoupling
 from torchflows.bijections.finite.linear import LU, ReversePermutation, LowerTriangular, Orthogonal, QR
 from torchflows.bijections.finite.residual.architectures import ResFlow, InvertibleResNet, ProximalResFlow
 from torchflows.bijections.finite.residual.iterative import InvertibleResNetBlock, ResFlowBlock
@@ -145,7 +145,10 @@ def test_linear(bijection_class: Bijection, batch_shape: Tuple, event_shape: Tup
     RealNVP,
     CouplingRQNSF,
     LRSCoupling,
-    LinearRQSCoupling
+    LinearRQSCoupling,
+    DenseSigmoidalCoupling,
+    DeepDenseSigmoidalCoupling,
+    DeepSigmoidalCoupling,
 ])
 @pytest.mark.parametrize('batch_shape', __test_constants['batch_shape'])
 @pytest.mark.parametrize('event_shape', __test_constants['event_shape'])
diff --git a/torchflows/bijections/finite/autoregressive/architectures.py b/torchflows/bijections/finite/autoregressive/architectures.py
index a1c8a5d..1aba1e7 100644
--- a/torchflows/bijections/finite/autoregressive/architectures.py
+++ b/torchflows/bijections/finite/autoregressive/architectures.py
@@ -30,7 +30,8 @@ def make_basic_layers(base_bijection: Type[
     Union[CouplingBijection, MaskedAutoregressiveBijection, InverseMaskedAutoregressiveBijection]],
                       event_shape,
                       n_layers: int = 2,
-                      edge_list: List[Tuple[int, int]] = None):
+                      edge_list: List[Tuple[int, int]] = None,
+                      **kwargs):
     """
     Returns a list of bijections for transformations of vectors.
     """
@@ -38,7 +39,7 @@ def make_basic_layers(base_bijection: Type[
     for _ in range(n_layers):
         if edge_list is None:
             bijections.append(ReversePermutation(event_shape=event_shape))
-        bijections.append(base_bijection(event_shape=event_shape, edge_list=edge_list))
+        bijections.append(base_bijection(event_shape=event_shape, edge_list=edge_list, **kwargs))
     bijections.append(ActNorm(event_shape=event_shape))
     bijections.append(ElementwiseAffine(event_shape=event_shape))
     bijections.append(ActNorm(event_shape=event_shape))
@@ -269,10 +270,17 @@ def __init__(self,
                  event_shape,
                  n_layers: int = 2,
                  edge_list: List[Tuple[int, int]] = None,
+                 percentage_global_parameters: float = 0.8,
                  **kwargs):
         if isinstance(event_shape, int):
             event_shape = (event_shape,)
-        bijections = make_basic_layers(DenseSigmoidalCoupling, event_shape, n_layers, edge_list)
+        bijections = make_basic_layers(
+            DenseSigmoidalCoupling,
+            event_shape,
+            n_layers,
+            edge_list,
+            percentage_global_parameters=percentage_global_parameters
+        )
         super().__init__(event_shape, bijections, **kwargs)


@@ -286,10 +294,17 @@ def __init__(self,
                  event_shape,
                  n_layers: int = 2,
                  edge_list: List[Tuple[int, int]] = None,
+                 percentage_global_parameters: float = 0.8,
                  **kwargs):
         if isinstance(event_shape, int):
             event_shape = (event_shape,)
-        bijections = make_basic_layers(DenseSigmoidalInverseMaskedAutoregressive, event_shape, n_layers, edge_list)
+        bijections = make_basic_layers(
+            DenseSigmoidalInverseMaskedAutoregressive,
+            event_shape,
+            n_layers,
+            edge_list,
+            percentage_global_parameters=percentage_global_parameters
+        )
         super().__init__(event_shape, bijections, **kwargs)


@@ -303,10 +318,17 @@ def __init__(self,
                  event_shape,
                  n_layers: int = 2,
                  edge_list: List[Tuple[int, int]] = None,
+                 percentage_global_parameters: float = 0.8,
                  **kwargs):
         if isinstance(event_shape, int):
             event_shape = (event_shape,)
-        bijections = make_basic_layers(DenseSigmoidalForwardMaskedAutoregressive, event_shape, n_layers, edge_list)
+        bijections = make_basic_layers(
+            DenseSigmoidalForwardMaskedAutoregressive,
+            event_shape,
+            n_layers,
+            edge_list,
+            percentage_global_parameters=percentage_global_parameters
+        )
         super().__init__(event_shape, bijections, **kwargs)


@@ -320,10 +342,17 @@ def __init__(self,
                  event_shape,
                  n_layers: int = 2,
                  edge_list: List[Tuple[int, int]] = None,
+                 percentage_global_parameters: float = 0.8,
                  **kwargs):
         if isinstance(event_shape, int):
             event_shape = (event_shape,)
-        bijections = make_basic_layers(DeepDenseSigmoidalCoupling, event_shape, n_layers, edge_list)
+        bijections = make_basic_layers(
+            DeepDenseSigmoidalCoupling,
+            event_shape,
+            n_layers,
+            edge_list,
+            percentage_global_parameters=percentage_global_parameters
+        )
         super().__init__(event_shape, bijections, **kwargs)


@@ -337,10 +366,17 @@ def __init__(self,
                  event_shape,
                  n_layers: int = 2,
                  edge_list: List[Tuple[int, int]] = None,
+                 percentage_global_parameters: float = 0.8,
                  **kwargs):
         if isinstance(event_shape, int):
             event_shape = (event_shape,)
-        bijections = make_basic_layers(DeepDenseSigmoidalInverseMaskedAutoregressive, event_shape, n_layers, edge_list)
+        bijections = make_basic_layers(
+            DeepDenseSigmoidalInverseMaskedAutoregressive,
+            event_shape,
+            n_layers,
+            edge_list,
+            percentage_global_parameters=percentage_global_parameters
+        )
         super().__init__(event_shape, bijections, **kwargs)


@@ -354,10 +390,17 @@ def __init__(self,
                  event_shape,
                  n_layers: int = 2,
                  edge_list: List[Tuple[int, int]] = None,
+                 percentage_global_parameters: float = 0.8,
                  **kwargs):
         if isinstance(event_shape, int):
             event_shape = (event_shape,)
-        bijections = make_basic_layers(DeepDenseSigmoidalForwardMaskedAutoregressive, event_shape, n_layers, edge_list)
+        bijections = make_basic_layers(
+            DeepDenseSigmoidalForwardMaskedAutoregressive,
+            event_shape,
+            n_layers,
+            edge_list,
+            percentage_global_parameters=percentage_global_parameters
+        )
         super().__init__(event_shape, bijections, **kwargs)
diff --git a/torchflows/bijections/finite/autoregressive/conditioning/transforms.py b/torchflows/bijections/finite/autoregressive/conditioning/transforms.py
index 6d73834..9c330a0 100644
--- a/torchflows/bijections/finite/autoregressive/conditioning/transforms.py
+++ b/torchflows/bijections/finite/autoregressive/conditioning/transforms.py
@@ -1,5 +1,5 @@
 import math
-from typing import Tuple, Union, Type
+from typing import Tuple, Union, Type, Optional

 import torch
 import torch.nn as nn
@@ -25,7 +25,7 @@ def __init__(self,
                  context_shape: Union[torch.Size, Tuple[int, ...]],
                  parameter_shape: Union[torch.Size, Tuple[int, ...]],
                  context_combiner: ContextCombiner = None,
-                 global_parameter_mask: torch.Tensor = None,
+                 global_parameter_mask: Optional[torch.Tensor] = None,
                  initial_global_parameter_value: float = None,
                  **kwargs):
         """
@@ -61,7 +61,10 @@ def __init__(self,
         self.parameter_shape = parameter_shape
         self.global_parameter_mask = global_parameter_mask
         self.n_transformer_parameters = int(torch.prod(torch.as_tensor(self.parameter_shape)))
-        self.n_global_parameters = 0 if global_parameter_mask is None else int(torch.sum(self.global_parameter_mask))
+        if global_parameter_mask is None:
+            self.n_global_parameters = 0
+        else:
+            self.n_global_parameters = int(torch.sum(global_parameter_mask))
         self.n_predicted_parameters = self.n_transformer_parameters - self.n_global_parameters

         if initial_global_parameter_value is None:
@@ -84,12 +87,12 @@ def forward(self, x: torch.Tensor, context: torch.Tensor = None):
         else:
             if self.n_global_parameters == self.n_transformer_parameters:
                 # All transformer parameters are learned globally
-                output = torch.zeros(*batch_shape, *self.parameter_shape, device=x.device)
+                output = torch.zeros(*batch_shape, *self.parameter_shape).to(x)
                 output[..., self.global_parameter_mask] = self.global_theta_flat
                 return output
             else:
                 # Some transformer parameters are learned globally, some are predicted
-                output = torch.zeros(*batch_shape, *self.parameter_shape, device=x.device)
+                output = torch.zeros(*batch_shape, *self.parameter_shape).to(x)
                 output[..., self.global_parameter_mask] = self.global_theta_flat
                 output[..., ~self.global_parameter_mask] = self.predict_theta_flat(x, context)
                 return output
@@ -129,12 +132,28 @@ def __init__(self,
                  input_event_shape: Union[torch.Size, Tuple[int, ...]],
                  parameter_shape: Union[torch.Size, Tuple[int, ...]],
                  context_shape: Union[torch.Size, Tuple[int, ...]] = None,
+                 percentage_global_parameters: float = 0.0,
                  **kwargs):
+        if 0.0 < percentage_global_parameters <= 1.0:
+            n_parameters = int(torch.prod(torch.as_tensor(parameter_shape)))
+            parameter_permutation = torch.randperm(n_parameters)
+            global_param_indices = parameter_permutation[:int(n_parameters * percentage_global_parameters)]
+            global_mask = torch.zeros(size=(n_parameters,), dtype=torch.bool)
+            global_mask[global_param_indices] = True
+            global_mask = global_mask.view(*parameter_shape)
+        else:
+            global_mask = None
+
         super().__init__(
             input_event_shape=input_event_shape,
             parameter_shape=parameter_shape,
             context_shape=context_shape,
-            **kwargs
+            **{
+                **kwargs,
+                **dict(
+                    global_parameter_mask=global_mask
+                )
+            }
         )


@@ -255,7 +274,7 @@ def __init__(self,
         )

         if n_hidden is None:
-            n_hidden = max(int(3 * math.log10(self.n_input_event_dims)), 4)
+            n_hidden = max(int(5 * math.log10(max(self.n_input_event_dims, self.n_predicted_parameters))), 4)

         layers = []
         if n_layers == 1:
@@ -267,7 +286,7 @@ def __init__(self,
             layers.append(nn.Linear(n_hidden, self.n_predicted_parameters))
         else:
             raise ValueError
-        layers.append(nn.Unflatten(dim=-1, unflattened_size=self.parameter_shape))
+        layers.append(nn.Unflatten(dim=-1, unflattened_size=(self.n_predicted_parameters,)))
         self.sequential = nn.Sequential(*layers)

     def predict_theta_flat(self, x: torch.Tensor, context: torch.Tensor = None):
@@ -313,7 +332,7 @@ def __init__(self,
         )

         if n_hidden is None:
-            n_hidden = max(int(3 * math.log10(self.n_input_event_dims)), 4)
+            n_hidden = max(int(5 * math.log10(max(self.n_input_event_dims, self.n_predicted_parameters))), 4)

         if n_layers <= 2:
             raise ValueError(f"Number of layers in ResidualFeedForward must be at least 3, but found {n_layers}")
@@ -322,7 +341,7 @@ def __init__(self,
         for _ in range(n_layers - 2):
             layers.append(self.ResidualBlock(n_hidden, n_hidden, block_size, nonlinearity=nonlinearity))
         layers.append(nn.Linear(n_hidden, self.n_predicted_parameters))
-        layers.append(nn.Unflatten(dim=-1, unflattened_size=self.parameter_shape))
+        layers.append(nn.Unflatten(dim=-1, unflattened_size=(self.n_predicted_parameters,)))
         self.sequential = nn.Sequential(*layers)

     def predict_theta_flat(self, x: torch.Tensor, context: torch.Tensor = None):
diff --git a/torchflows/bijections/finite/autoregressive/layers.py b/torchflows/bijections/finite/autoregressive/layers.py
index ba75787..727ca5d 100644
--- a/torchflows/bijections/finite/autoregressive/layers.py
+++ b/torchflows/bijections/finite/autoregressive/layers.py
@@ -243,6 +243,7 @@ def __init__(self,
                  n_dense_layers: int = 2,
                  edge_list: List[Tuple[int, int]] = None,
                  coupling_kwargs: dict = None,
+                 percentage_global_parameters: float = 0.8,
                  **kwargs):
         if coupling_kwargs is None:
             coupling_kwargs = dict()
@@ -257,7 +258,10 @@ def __init__(self,
             input_event_shape=torch.Size((coupling.source_event_size,)),
             parameter_shape=torch.Size(transformer.parameter_shape),
             context_shape=context_shape,
-            **kwargs
+            **{
+                **kwargs,
+                **dict(percentage_global_parameters=percentage_global_parameters)
+            }
         )
         super().__init__(transformer, coupling, conditioner_transform)

@@ -267,12 +271,21 @@ def __init__(self,
                  event_shape: torch.Size,
                  context_shape: torch.Size = None,
                  n_dense_layers: int = 2,
+                 percentage_global_parameters: float = 0.8,
                  **kwargs):
         transformer: ScalarTransformer = DenseSigmoid(
             event_shape=torch.Size(event_shape),
             n_dense_layers=n_dense_layers
         )
-        super().__init__(event_shape, context_shape, transformer=transformer, **kwargs)
+        super().__init__(
+            event_shape,
+            context_shape,
+            transformer=transformer,
+            **{
+                **kwargs,
+                **dict(percentage_global_parameters=percentage_global_parameters)
+            }
+        )


 class DenseSigmoidalForwardMaskedAutoregressive(MaskedAutoregressiveBijection):
@@ -280,12 +293,21 @@ def __init__(self,
                  event_shape: torch.Size,
                  context_shape: torch.Size = None,
                  n_dense_layers: int = 2,
+                 percentage_global_parameters: float = 0.8,
                  **kwargs):
         transformer: ScalarTransformer = DenseSigmoid(
             event_shape=torch.Size(event_shape),
             n_dense_layers=n_dense_layers
         )
-        super().__init__(event_shape, context_shape, transformer=transformer, **kwargs)
+        super().__init__(
+            event_shape,
+            context_shape,
+            transformer=transformer,
+            **{
+                **kwargs,
+                **dict(percentage_global_parameters=percentage_global_parameters)
+            }
+        )


 class DeepDenseSigmoidalCoupling(CouplingBijection):
@@ -295,6 +317,7 @@ def __init__(self,
                  n_hidden_layers: int = 2,
                  edge_list: List[Tuple[int, int]] = None,
                  coupling_kwargs: dict = None,
+                 percentage_global_parameters: float = 0.8,
                  **kwargs):
         if coupling_kwargs is None:
             coupling_kwargs = dict()
@@ -309,7 +332,10 @@ def __init__(self,
             input_event_shape=torch.Size((coupling.source_event_size,)),
             parameter_shape=torch.Size(transformer.parameter_shape),
             context_shape=context_shape,
-            **kwargs
+            **{
+                **kwargs,
+                **dict(percentage_global_parameters=percentage_global_parameters)
+            }
         )
         super().__init__(transformer, coupling, conditioner_transform)

@@ -319,12 +345,21 @@ def __init__(self,
                  event_shape: torch.Size,
                  context_shape: torch.Size = None,
                  n_hidden_layers: int = 2,
+                 percentage_global_parameters: float = 0.8,
                  **kwargs):
         transformer: ScalarTransformer = DeepDenseSigmoid(
             event_shape=torch.Size(event_shape),
             n_hidden_layers=n_hidden_layers
         )
-        super().__init__(event_shape, context_shape, transformer=transformer, **kwargs)
+        super().__init__(
+            event_shape,
+            context_shape,
+            transformer=transformer,
+            **{
+                **kwargs,
+                **dict(percentage_global_parameters=percentage_global_parameters)
+            }
+        )


 class DeepDenseSigmoidalForwardMaskedAutoregressive(MaskedAutoregressiveBijection):
@@ -332,12 +367,21 @@ def __init__(self,
                  event_shape: torch.Size,
                  context_shape: torch.Size = None,
                  n_hidden_layers: int = 2,
+                 percentage_global_parameters: float = 0.8,
                  **kwargs):
         transformer: ScalarTransformer = DeepDenseSigmoid(
             event_shape=torch.Size(event_shape),
             n_hidden_layers=n_hidden_layers
         )
-        super().__init__(event_shape, context_shape, transformer=transformer, **kwargs)
+        super().__init__(
+            event_shape,
+            context_shape,
+            transformer=transformer,
+            **{
+                **kwargs,
+                **dict(percentage_global_parameters=percentage_global_parameters)
+            }
+        )


 class LinearAffineCoupling(AffineCoupling):
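
Note on the mechanism this patch introduces: when FeedForward receives percentage_global_parameters, it samples a random boolean mask over the flattened transformer parameters; the masked entries are learned as a single global tensor, and only the remaining entries are predicted from the input. The following is a minimal, self-contained sketch of that split (not the library code); shapes are taken from the new test, and the tensor names are purely illustrative.

import torch

# Illustrative shapes from the new test: a (20, 3) transformer parameter tensor,
# 80% of whose entries are learned globally rather than predicted per input.
parameter_shape = (20, 3)
percentage_global_parameters = 0.8

n_parameters = int(torch.prod(torch.as_tensor(parameter_shape)))  # 60
n_global = int(n_parameters * percentage_global_parameters)       # 48

# Randomly choose which flattened parameter entries are global, mirroring FeedForward.__init__.
permutation = torch.randperm(n_parameters)
global_mask = torch.zeros(n_parameters, dtype=torch.bool)
global_mask[permutation[:n_global]] = True
global_mask = global_mask.view(*parameter_shape)

# Forward-pass sketch: global entries come from a learned tensor, the rest from a predictor.
batch_size = 100
global_theta_flat = torch.zeros(n_global)                                 # stands in for the learned global parameters
predicted_theta_flat = torch.randn(batch_size, n_parameters - n_global)   # stands in for predict_theta_flat(x)

output = torch.zeros(batch_size, *parameter_shape)
output[..., global_mask] = global_theta_flat
output[..., ~global_mask] = predicted_theta_flat
assert output.shape == (batch_size, *parameter_shape)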
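
For a quick check of how the split comes out in practice, the snippet below mirrors the new test and inspects the bookkeeping attributes used in this patch (n_global_parameters, n_predicted_parameters, global_parameter_mask). With a (20, 3) parameter shape and percentage_global_parameters=0.8, 48 of the 60 entries are global and 12 are predicted.

import torch

from torchflows.bijections.finite.autoregressive.conditioning.transforms import FeedForward

t = FeedForward(
    input_event_shape=torch.Size((10, 10)),
    parameter_shape=torch.Size((20, 3)),
    percentage_global_parameters=0.8,
)

print(t.n_global_parameters)           # 48 entries learned globally
print(t.n_predicted_parameters)        # 12 entries predicted from the input
print(t.global_parameter_mask.shape)   # torch.Size([20, 3])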
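
The hidden-width heuristic in FeedForward and ResidualFeedForward also changes, from 3 * log10(n_input_event_dims) to 5 * log10(max(n_input_event_dims, n_predicted_parameters)). A small worked example using the test's shapes (100 input dimensions, 12 predicted parameters after the 80% global split); the numbers are illustration only:

import math

n_input_event_dims = 100      # prod of input_event_shape (10, 10) from the new test
n_predicted_parameters = 12   # 60 total parameters, 48 learned globally at 80%

old_n_hidden = max(int(3 * math.log10(n_input_event_dims)), 4)                               # 6
new_n_hidden = max(int(5 * math.log10(max(n_input_event_dims, n_predicted_parameters))), 4)  # 10

print(old_n_hidden, new_n_hidden)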
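
The forwarding pattern **{**kwargs, **dict(percentage_global_parameters=...)} used throughout layers.py relies on plain dict-merge semantics: keys unpacked later override earlier ones, so the layer's own value always reaches the conditioner even if the same key were already present in kwargs. A two-call illustration (hypothetical helper, not library code):

def merged(**kwargs):
    # Mirrors the forwarding pattern used in this patch.
    return {
        **kwargs,
        **dict(percentage_global_parameters=0.8)
    }

print(merged(n_hidden=16))
# {'n_hidden': 16, 'percentage_global_parameters': 0.8}
print(merged(percentage_global_parameters=0.2))
# later entries win: {'percentage_global_parameters': 0.8}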