Fix global conditioner predictions, make 80% of dense NAF outputs globally learned

davidnabergoj committed Aug 29, 2024
1 parent bd14e13 commit 1a5a57a
Showing 5 changed files with 164 additions and 26 deletions.
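In short: conditioner transforms such as FeedForward gain a percentage_global_parameters option, and that fraction of the transformer parameters is learned as a single shared (global) tensor instead of being predicted from the conditioning input; the dense sigmoidal NAF architectures now pass 0.8 by default. A minimal usage sketch mirroring the new test below (shapes are illustrative, not required values):

import torch

from torchflows.bijections.finite.autoregressive.conditioning.transforms import FeedForward

torch.manual_seed(0)

input_event_shape = torch.Size((10, 10))  # shape of one conditioning input
parameter_shape = torch.Size((20, 3))     # shape of the transformer parameter tensor

# 80% of the 60 transformer parameters become globally learned; the rest are predicted per input.
conditioner = FeedForward(input_event_shape, parameter_shape, percentage_global_parameters=0.8)

x = torch.randn(100, *input_event_shape)
theta = conditioner(x)
assert theta.shape == (100, *parameter_shape)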
29 changes: 29 additions & 0 deletions test/test_globally_learned_conditioner_outputs.py
@@ -0,0 +1,29 @@
import torch

from torchflows.bijections.finite.autoregressive.conditioning.transforms import FeedForward


def test_standard():
    torch.manual_seed(0)

    input_event_shape = torch.Size((10, 10))
    parameter_shape = torch.Size((20, 3))
    test_inputs = torch.randn(100, *input_event_shape)

    t = FeedForward(input_event_shape, parameter_shape)
    output = t(test_inputs)

    assert output.shape == (100, *parameter_shape)


def test_eighty_pct_global():
    torch.manual_seed(0)

    input_event_shape = torch.Size((10, 10))
    parameter_shape = torch.Size((20, 3))
    test_inputs = torch.randn(100, *input_event_shape)

    t = FeedForward(input_event_shape, parameter_shape, percentage_global_parameters=0.8)
    output = t(test_inputs)

    assert output.shape == (100, *parameter_shape)
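As a rough illustration of what the option does at the output level: entries selected by the global parameter mask are filled from one shared parameter tensor, so they are identical for every element of a batch, while the remaining entries come from the feed-forward prediction (see the transforms.py changes below). The std-over-batch check here is illustrative only and not part of the new test:

import torch

from torchflows.bijections.finite.autoregressive.conditioning.transforms import FeedForward

torch.manual_seed(0)

conditioner = FeedForward(
    torch.Size((10, 10)),
    torch.Size((20, 3)),
    percentage_global_parameters=0.8,
)

x = torch.randn(5, 10, 10)
theta = conditioner(x)  # shape (5, 20, 3)

# Globally learned entries do not vary across the batch; predicted entries generally do.
per_entry_std = theta.std(dim=0)           # (20, 3): zero wherever a parameter is global
n_constant = int((per_entry_std == 0).sum())
print(n_constant)                          # about 0.8 * 60 = 48 entries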
7 changes: 5 additions & 2 deletions test/test_reconstruction_bijections.py
@@ -11,7 +11,7 @@
from torchflows.bijections.finite.autoregressive.architectures import NICE, RealNVP, CouplingRQNSF, MAF, IAF, \
    InverseAutoregressiveRQNSF, MaskedAutoregressiveRQNSF
from torchflows.bijections.finite.autoregressive.layers import ElementwiseScale, ElementwiseAffine, ElementwiseShift, \
    LRSCoupling, LinearRQSCoupling, ActNorm
    LRSCoupling, LinearRQSCoupling, ActNorm, DenseSigmoidalCoupling, DeepDenseSigmoidalCoupling, DeepSigmoidalCoupling
from torchflows.bijections.finite.linear import LU, ReversePermutation, LowerTriangular, Orthogonal, QR
from torchflows.bijections.finite.residual.architectures import ResFlow, InvertibleResNet, ProximalResFlow
from torchflows.bijections.finite.residual.iterative import InvertibleResNetBlock, ResFlowBlock
@@ -145,7 +145,7 @@ def test_linear(bijection_class: Bijection, batch_shape: Tuple, event_shape: Tup
    RealNVP,
    CouplingRQNSF,
    LRSCoupling,
    LinearRQSCoupling
    LinearRQSCoupling,
    DenseSigmoidalCoupling,
    DeepDenseSigmoidalCoupling,
    DeepSigmoidalCoupling,
])
@pytest.mark.parametrize('batch_shape', __test_constants['batch_shape'])
@pytest.mark.parametrize('event_shape', __test_constants['event_shape'])
59 changes: 51 additions & 8 deletions torchflows/bijections/finite/autoregressive/architectures.py
@@ -30,15 +30,16 @@ def make_basic_layers(base_bijection: Type[
    Union[CouplingBijection, MaskedAutoregressiveBijection, InverseMaskedAutoregressiveBijection]],
                      event_shape,
                      n_layers: int = 2,
                      edge_list: List[Tuple[int, int]] = None):
                      edge_list: List[Tuple[int, int]] = None,
                      **kwargs):
    """
    Returns a list of bijections for transformations of vectors.
    """
    bijections = [ElementwiseAffine(event_shape=event_shape)]
    for _ in range(n_layers):
        if edge_list is None:
            bijections.append(ReversePermutation(event_shape=event_shape))
        bijections.append(base_bijection(event_shape=event_shape, edge_list=edge_list))
        bijections.append(base_bijection(event_shape=event_shape, edge_list=edge_list, **kwargs))
        bijections.append(ActNorm(event_shape=event_shape))
    bijections.append(ElementwiseAffine(event_shape=event_shape))
    bijections.append(ActNorm(event_shape=event_shape))
@@ -269,10 +270,17 @@ def __init__(self,
                 event_shape,
                 n_layers: int = 2,
                 edge_list: List[Tuple[int, int]] = None,
                 percentage_global_parameters: float = 0.8,
                 **kwargs):
        if isinstance(event_shape, int):
            event_shape = (event_shape,)
        bijections = make_basic_layers(DenseSigmoidalCoupling, event_shape, n_layers, edge_list)
        bijections = make_basic_layers(
            DenseSigmoidalCoupling,
            event_shape,
            n_layers,
            edge_list,
            percentage_global_parameters=percentage_global_parameters
        )
        super().__init__(event_shape, bijections, **kwargs)


@@ -286,10 +294,17 @@ def __init__(self,
                 event_shape,
                 n_layers: int = 2,
                 edge_list: List[Tuple[int, int]] = None,
                 percentage_global_parameters: float = 0.8,
                 **kwargs):
        if isinstance(event_shape, int):
            event_shape = (event_shape,)
        bijections = make_basic_layers(DenseSigmoidalInverseMaskedAutoregressive, event_shape, n_layers, edge_list)
        bijections = make_basic_layers(
            DenseSigmoidalInverseMaskedAutoregressive,
            event_shape,
            n_layers,
            edge_list,
            percentage_global_parameters=percentage_global_parameters
        )
        super().__init__(event_shape, bijections, **kwargs)


@@ -303,10 +318,17 @@ def __init__(self,
                 event_shape,
                 n_layers: int = 2,
                 edge_list: List[Tuple[int, int]] = None,
                 percentage_global_parameters: float = 0.8,
                 **kwargs):
        if isinstance(event_shape, int):
            event_shape = (event_shape,)
        bijections = make_basic_layers(DenseSigmoidalForwardMaskedAutoregressive, event_shape, n_layers, edge_list)
        bijections = make_basic_layers(
            DenseSigmoidalForwardMaskedAutoregressive,
            event_shape,
            n_layers,
            edge_list,
            percentage_global_parameters=percentage_global_parameters
        )
        super().__init__(event_shape, bijections, **kwargs)


@@ -320,10 +342,17 @@ def __init__(self,
                 event_shape,
                 n_layers: int = 2,
                 edge_list: List[Tuple[int, int]] = None,
                 percentage_global_parameters: float = 0.8,
                 **kwargs):
        if isinstance(event_shape, int):
            event_shape = (event_shape,)
        bijections = make_basic_layers(DeepDenseSigmoidalCoupling, event_shape, n_layers, edge_list)
        bijections = make_basic_layers(
            DeepDenseSigmoidalCoupling,
            event_shape,
            n_layers,
            edge_list,
            percentage_global_parameters=percentage_global_parameters
        )
        super().__init__(event_shape, bijections, **kwargs)


@@ -337,10 +366,17 @@ def __init__(self,
                 event_shape,
                 n_layers: int = 2,
                 edge_list: List[Tuple[int, int]] = None,
                 percentage_global_parameters: float = 0.8,
                 **kwargs):
        if isinstance(event_shape, int):
            event_shape = (event_shape,)
        bijections = make_basic_layers(DeepDenseSigmoidalInverseMaskedAutoregressive, event_shape, n_layers, edge_list)
        bijections = make_basic_layers(
            DeepDenseSigmoidalInverseMaskedAutoregressive,
            event_shape,
            n_layers,
            edge_list,
            percentage_global_parameters=percentage_global_parameters
        )
        super().__init__(event_shape, bijections, **kwargs)


@@ -354,10 +390,17 @@ def __init__(self,
                 event_shape,
                 n_layers: int = 2,
                 edge_list: List[Tuple[int, int]] = None,
                 percentage_global_parameters: float = 0.8,
                 **kwargs):
        if isinstance(event_shape, int):
            event_shape = (event_shape,)
        bijections = make_basic_layers(DeepDenseSigmoidalForwardMaskedAutoregressive, event_shape, n_layers, edge_list)
        bijections = make_basic_layers(
            DeepDenseSigmoidalForwardMaskedAutoregressive,
            event_shape,
            n_layers,
            edge_list,
            percentage_global_parameters=percentage_global_parameters
        )
        super().__init__(event_shape, bijections, **kwargs)


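All six constructors above thread the new keyword through make_basic_layers, which forwards it via **kwargs to each coupling or masked autoregressive layer. A sketch of the equivalent direct call; it assumes make_basic_layers and DenseSigmoidalCoupling are importable as shown and that the layer passes percentage_global_parameters on to its conditioner:

from torchflows.bijections.finite.autoregressive.architectures import make_basic_layers
from torchflows.bijections.finite.autoregressive.layers import DenseSigmoidalCoupling

# Builds the layer list an architecture class assembles internally: elementwise affine /
# ActNorm bookkeeping, per-layer permutations, and n_layers coupling blocks, each with
# 80% of its transformer parameters learned globally.
bijections = make_basic_layers(
    DenseSigmoidalCoupling,
    event_shape=(10,),
    n_layers=2,
    percentage_global_parameters=0.8,  # forwarded to each coupling layer via **kwargs
)
print(len(bijections))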
torchflows/bijections/finite/autoregressive/conditioning/transforms.py
@@ -1,5 +1,5 @@
import math
from typing import Tuple, Union, Type
from typing import Tuple, Union, Type, Optional

import torch
import torch.nn as nn
@@ -25,7 +25,7 @@ def __init__(self,
                 context_shape: Union[torch.Size, Tuple[int, ...]],
                 parameter_shape: Union[torch.Size, Tuple[int, ...]],
                 context_combiner: ContextCombiner = None,
                 global_parameter_mask: torch.Tensor = None,
                 global_parameter_mask: Optional[torch.Tensor] = None,
                 initial_global_parameter_value: float = None,
                 **kwargs):
        """
@@ -61,7 +61,10 @@ def __init__(self,
        self.parameter_shape = parameter_shape
        self.global_parameter_mask = global_parameter_mask
        self.n_transformer_parameters = int(torch.prod(torch.as_tensor(self.parameter_shape)))
        self.n_global_parameters = 0 if global_parameter_mask is None else int(torch.sum(self.global_parameter_mask))
        if global_parameter_mask is None:
            self.n_global_parameters = 0
        else:
            self.n_global_parameters = int(torch.sum(global_parameter_mask))
        self.n_predicted_parameters = self.n_transformer_parameters - self.n_global_parameters

        if initial_global_parameter_value is None:
@@ -84,12 +87,12 @@ def forward(self, x: torch.Tensor, context: torch.Tensor = None):
        else:
            if self.n_global_parameters == self.n_transformer_parameters:
                # All transformer parameters are learned globally
                output = torch.zeros(*batch_shape, *self.parameter_shape, device=x.device)
                output = torch.zeros(*batch_shape, *self.parameter_shape).to(x)
                output[..., self.global_parameter_mask] = self.global_theta_flat
                return output
            else:
                # Some transformer parameters are learned globally, some are predicted
                output = torch.zeros(*batch_shape, *self.parameter_shape, device=x.device)
                output = torch.zeros(*batch_shape, *self.parameter_shape).to(x)
                output[..., self.global_parameter_mask] = self.global_theta_flat
                output[..., ~self.global_parameter_mask] = self.predict_theta_flat(x, context)
                return output
@@ -129,12 +132,28 @@ def __init__(self,
                 input_event_shape: Union[torch.Size, Tuple[int, ...]],
                 parameter_shape: Union[torch.Size, Tuple[int, ...]],
                 context_shape: Union[torch.Size, Tuple[int, ...]] = None,
                 percentage_global_parameters: float = 0.0,
                 **kwargs):
        if 0.0 < percentage_global_parameters <= 1.0:
            n_parameters = int(torch.prod(torch.as_tensor(parameter_shape)))
            parameter_permutation = torch.randperm(n_parameters)
            global_param_indices = parameter_permutation[:int(n_parameters * percentage_global_parameters)]
            global_mask = torch.zeros(size=(n_parameters,), dtype=torch.bool)
            global_mask[global_param_indices] = True
            global_mask = global_mask.view(*parameter_shape)
        else:
            global_mask = None

        super().__init__(
            input_event_shape=input_event_shape,
            parameter_shape=parameter_shape,
            context_shape=context_shape,
            **kwargs
            **{
                **kwargs,
                **dict(
                    global_parameter_mask=global_mask
                )
            }
        )


@@ -255,7 +274,7 @@ def __init__(self,
        )

        if n_hidden is None:
            n_hidden = max(int(3 * math.log10(self.n_input_event_dims)), 4)
            n_hidden = max(int(5 * math.log10(max(self.n_input_event_dims, self.n_predicted_parameters))), 4)

        layers = []
        if n_layers == 1:
@@ -267,7 +286,7 @@ def __init__(self,
            layers.append(nn.Linear(n_hidden, self.n_predicted_parameters))
        else:
            raise ValueError
        layers.append(nn.Unflatten(dim=-1, unflattened_size=self.parameter_shape))
        layers.append(nn.Unflatten(dim=-1, unflattened_size=(self.n_predicted_parameters,)))
        self.sequential = nn.Sequential(*layers)

    def predict_theta_flat(self, x: torch.Tensor, context: torch.Tensor = None):
@@ -313,7 +332,7 @@ def __init__(self,
        )

        if n_hidden is None:
            n_hidden = max(int(3 * math.log10(self.n_input_event_dims)), 4)
            n_hidden = max(int(5 * math.log10(max(self.n_input_event_dims, self.n_predicted_parameters))), 4)

        if n_layers <= 2:
            raise ValueError(f"Number of layers in ResidualFeedForward must be at least 3, but found {n_layers}")
@@ -322,7 +341,7 @@ def __init__(self,
        for _ in range(n_layers - 2):
            layers.append(self.ResidualBlock(n_hidden, n_hidden, block_size, nonlinearity=nonlinearity))
        layers.append(nn.Linear(n_hidden, self.n_predicted_parameters))
        layers.append(nn.Unflatten(dim=-1, unflattened_size=self.parameter_shape))
        layers.append(nn.Unflatten(dim=-1, unflattened_size=(self.n_predicted_parameters,)))
        self.sequential = nn.Sequential(*layers)

    def predict_theta_flat(self, x: torch.Tensor, context: torch.Tensor = None):
