From 9e40e62e2a553e5d655f4414725aefe1eac63127 Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Tue, 14 Nov 2023 15:43:07 -0800 Subject: [PATCH] Add option to learn global parameters in ConditionerTransform --- .../autoregressive/conditioner_transforms.py | 119 ++++++++++++------ .../finite/autoregressive/layers.py | 24 ++-- test/test_conditioner_transforms.py | 2 +- 3 files changed, 95 insertions(+), 50 deletions(-) diff --git a/normalizing_flows/bijections/finite/autoregressive/conditioner_transforms.py b/normalizing_flows/bijections/finite/autoregressive/conditioner_transforms.py index 1bf15e5..0ac8240 100644 --- a/normalizing_flows/bijections/finite/autoregressive/conditioner_transforms.py +++ b/normalizing_flows/bijections/finite/autoregressive/conditioner_transforms.py @@ -9,12 +9,39 @@ class ConditionerTransform(nn.Module): + """ + Module which predicts transformer parameters for the transformation of a tensor y using an input tensor x and + possibly a corresponding context tensor c. + + In other words, a conditioner transform f predicts theta = f(x, c) to be used in transformer g with z = g(y; theta). + The transformation g is performed elementwise on tensor y. + Since g transforms each element of y with a parameter tensor of shape (n_transformer_parameters,), + the shape of theta is (*y.shape, n_transformer_parameters). + """ + def __init__(self, input_event_shape, context_shape, output_event_shape, - n_predicted_parameters: int, - context_combiner: ContextCombiner = None): + n_transformer_parameters: int, + context_combiner: ContextCombiner = None, + percent_globally_learned_parameters: float = 0.0, + initial_global_parameter_value: float = None): + """ + :param input_event_shape: shape of conditioner input tensor x. + :param context_shape: shape of conditioner context tensor c. + :param output_event_shape: shape of transformer input tensor y. + :param n_transformer_parameters: number of parameters required to transform a single element of y. 
+ :param context_combiner: ContextCombiner class which defines how to combine x and c to predict theta. + :param percent_globally_learned_parameters: fraction of all parameters in theta that should be learned directly. + A value of 0 means the conditioner predicts n_transformer_parameters parameters based on x and c. + A value of 1 means the conditioner predicts no parameters based on x and c, but outputs globally learned theta. + A value of alpha means the conditioner outputs alpha * n_transformer_parameters parameters globally and + predicts the rest. In this case, the predicted parameters are the last (1 - alpha) * n_transformer_parameters + elements in theta. + :param initial_global_parameter_value: the initial value for the entire globally learned part of theta. If None, + the global part of theta is initialized to samples from the standard normal distribution. + """ super().__init__() if context_shape is None: context_combiner = Bypass(input_event_shape) @@ -28,12 +55,38 @@ def __init__(self, self.context_shape = context_shape self.n_input_event_dims = self.context_combiner.n_output_dims self.n_output_event_dims = int(torch.prod(torch.as_tensor(output_event_shape))) - self.n_predicted_parameters = n_predicted_parameters + self.n_transformer_parameters = n_transformer_parameters + self.n_globally_learned_parameters = int(n_transformer_parameters * percent_globally_learned_parameters) + self.n_predicted_parameters = self.n_transformer_parameters - self.n_globally_learned_parameters + + if initial_global_parameter_value is None: + initial_global_theta = torch.randn(size=(*output_event_shape, self.n_globally_learned_parameters)) + else: + initial_global_theta = torch.full( + size=(*output_event_shape, self.n_globally_learned_parameters), + fill_value=initial_global_parameter_value + ) + self.global_theta = nn.Parameter(initial_global_theta) def forward(self, x: torch.Tensor, context: torch.Tensor = None): # x.shape = (*batch_shape, *input_event_shape) # context.shape = 
(*batch_shape, *context_shape) - # output.shape = (*batch_shape, *output_event_shape, n_predicted_parameters) + # output.shape = (*batch_shape, *output_event_shape, n_transformer_parameters) + if self.n_globally_learned_parameters == 0: + return self.predict_theta(x, context) + else: + n_batch_dims = len(x.shape) - len(self.output_event_shape) + n_event_dims = len(self.output_event_shape) + batch_shape = x.shape[:n_batch_dims] + batch_global_theta = pad_leading_dims(self.global_theta, n_batch_dims).repeat( + *batch_shape, *([1] * n_event_dims), 1 + ) + if self.n_globally_learned_parameters == self.n_transformer_parameters: + return batch_global_theta + else: + return torch.cat([batch_global_theta, self.predict_theta(x, context)], dim=-1) + + def predict_theta(self, x: torch.Tensor, context: torch.Tensor = None): raise NotImplementedError @@ -43,19 +96,10 @@ def __init__(self, output_event_shape, n_parameters: int, fill_value: float = No input_event_shape=None, context_shape=None, output_event_shape=output_event_shape, - n_predicted_parameters=n_parameters + n_transformer_parameters=n_parameters, + initial_global_parameter_value=fill_value, + percent_globally_learned_parameters=1.0 ) - if fill_value is None: - initial_theta = torch.randn(size=(*self.output_event_shape, n_parameters,)) - else: - initial_theta = torch.full(size=(*self.output_event_shape, n_parameters), fill_value=fill_value) - self.theta = nn.Parameter(initial_theta) - - def forward(self, x: torch.Tensor, context: torch.Tensor = None): - n_batch_dims = len(x.shape) - len(self.output_event_shape) - n_event_dims = len(self.output_event_shape) - batch_shape = x.shape[:n_batch_dims] - return pad_leading_dims(self.theta, n_batch_dims).repeat(*batch_shape, *([1] * n_event_dims), 1) class MADE(ConditionerTransform): @@ -70,7 +114,7 @@ def forward(self, x): def __init__(self, input_event_shape: torch.Size, output_event_shape: torch.Size, - n_predicted_parameters: int, + n_transformer_parameters: int, 
context_shape: torch.Size = None, n_hidden: int = None, n_layers: int = 2): @@ -78,7 +122,7 @@ def __init__(self, input_event_shape=input_event_shape, context_shape=context_shape, output_event_shape=output_event_shape, - n_predicted_parameters=n_predicted_parameters + n_transformer_parameters=n_transformer_parameters ) if n_hidden is None: @@ -103,10 +147,10 @@ def __init__(self, layers.extend([ self.MaskedLinear( masks[-1].shape[1], - masks[-1].shape[0] * n_predicted_parameters, - torch.repeat_interleave(masks[-1], n_predicted_parameters, dim=0) + masks[-1].shape[0] * self.n_predicted_parameters, + torch.repeat_interleave(masks[-1], self.n_predicted_parameters, dim=0) ), - nn.Unflatten(dim=-1, unflattened_size=(*output_event_shape, n_predicted_parameters)) + nn.Unflatten(dim=-1, unflattened_size=(*output_event_shape, self.n_predicted_parameters)) ]) self.sequential = nn.Sequential(*layers) @@ -123,23 +167,23 @@ def create_masks(n_layers, ms): masks.append(torch.as_tensor(xx >= yy, dtype=torch.float)) return masks - def forward(self, x: torch.Tensor, context: torch.Tensor = None): + def predict_theta(self, x: torch.Tensor, context: torch.Tensor = None): out = self.sequential(self.context_combiner(x, context)) batch_shape = get_batch_shape(x, self.input_event_shape) return out.view(*batch_shape, *self.output_event_shape, self.n_predicted_parameters) class LinearMADE(MADE): - def __init__(self, input_event_shape: torch.Size, output_event_shape: torch.Size, n_predicted_parameters: int, + def __init__(self, input_event_shape: torch.Size, output_event_shape: torch.Size, n_transformer_parameters: int, **kwargs): - super().__init__(input_event_shape, output_event_shape, n_predicted_parameters, n_layers=1, **kwargs) + super().__init__(input_event_shape, output_event_shape, n_transformer_parameters, n_layers=1, **kwargs) class FeedForward(ConditionerTransform): def __init__(self, input_event_shape: torch.Size, output_event_shape: torch.Size, - n_predicted_parameters: int, + 
n_transformer_parameters: int, context_shape: torch.Size = None, n_hidden: int = None, n_layers: int = 2): @@ -147,7 +191,7 @@ def __init__(self, input_event_shape=input_event_shape, context_shape=context_shape, output_event_shape=output_event_shape, - n_predicted_parameters=n_predicted_parameters + n_transformer_parameters=n_transformer_parameters ) if n_hidden is None: @@ -161,20 +205,20 @@ def __init__(self, # Check the one layer special case if n_layers == 1: - layers.append(nn.Linear(self.n_input_event_dims, self.n_output_event_dims * n_predicted_parameters)) + layers.append(nn.Linear(self.n_input_event_dims, self.n_output_event_dims * n_transformer_parameters)) elif n_layers > 1: layers.extend([nn.Linear(self.n_input_event_dims, n_hidden), nn.Tanh()]) for _ in range(n_layers - 2): layers.extend([nn.Linear(n_hidden, n_hidden), nn.Tanh()]) - layers.append(nn.Linear(n_hidden, self.n_output_event_dims * n_predicted_parameters)) + layers.append(nn.Linear(n_hidden, self.n_output_event_dims * self.n_predicted_parameters)) else: raise ValueError # Reshape the output - layers.append(nn.Unflatten(dim=-1, unflattened_size=(*output_event_shape, n_predicted_parameters))) + layers.append(nn.Unflatten(dim=-1, unflattened_size=(*output_event_shape, self.n_predicted_parameters))) self.sequential = nn.Sequential(*layers) - def forward(self, x: torch.Tensor, context: torch.Tensor = None): + def predict_theta(self, x: torch.Tensor, context: torch.Tensor = None): out = self.sequential(self.context_combiner(x, context)) batch_shape = get_batch_shape(x, self.input_event_shape) return out.view(*batch_shape, *self.output_event_shape, self.n_predicted_parameters) @@ -197,10 +241,11 @@ def forward(self, x): def __init__(self, input_event_shape: torch.Size, output_event_shape: torch.Size, - n_predicted_parameters: int, + n_transformer_parameters: int, context_shape: torch.Size = None, - n_layers: int = 2): - super().__init__(input_event_shape, context_shape, output_event_shape, 
n_predicted_parameters) + n_layers: int = 2, + **kwargs): + super().__init__(input_event_shape, context_shape, output_event_shape, n_transformer_parameters, **kwargs) # If context given, concatenate it to transform input if context_shape is not None: @@ -210,20 +255,20 @@ def __init__(self, # Check the one layer special case if n_layers == 1: - layers.append(nn.Linear(self.n_input_event_dims, self.n_output_event_dims * n_predicted_parameters)) + layers.append(nn.Linear(self.n_input_event_dims, self.n_output_event_dims * self.n_predicted_parameters)) elif n_layers > 1: layers.extend([self.ResidualLinear(self.n_input_event_dims, self.n_input_event_dims), nn.Tanh()]) for _ in range(n_layers - 2): layers.extend([self.ResidualLinear(self.n_input_event_dims, self.n_input_event_dims), nn.Tanh()]) - layers.append(nn.Linear(self.n_input_event_dims, self.n_output_event_dims * n_predicted_parameters)) + layers.append(nn.Linear(self.n_input_event_dims, self.n_output_event_dims * self.n_predicted_parameters)) else: raise ValueError # Reshape the output - layers.append(nn.Unflatten(dim=-1, unflattened_size=(*output_event_shape, n_predicted_parameters))) + layers.append(nn.Unflatten(dim=-1, unflattened_size=(*output_event_shape, self.n_predicted_parameters))) self.sequential = nn.Sequential(*layers) - def forward(self, x: torch.Tensor, context: torch.Tensor = None): + def predict_theta(self, x: torch.Tensor, context: torch.Tensor = None): out = self.sequential(self.context_combiner(x, context)) batch_shape = get_batch_shape(x, self.input_event_shape) return out.view(*batch_shape, *self.output_event_shape, self.n_predicted_parameters) diff --git a/normalizing_flows/bijections/finite/autoregressive/layers.py b/normalizing_flows/bijections/finite/autoregressive/layers.py index e9e203f..6ab4c79 100644 --- a/normalizing_flows/bijections/finite/autoregressive/layers.py +++ b/normalizing_flows/bijections/finite/autoregressive/layers.py @@ -52,7 +52,7 @@ def __init__(self, 
conditioner_transform = FeedForward( input_event_shape=conditioner.input_shape, output_event_shape=conditioner.output_shape, - n_predicted_parameters=transformer.n_parameters, + n_transformer_parameters=transformer.n_parameters, context_shape=context_shape, **kwargs ) @@ -71,7 +71,7 @@ def __init__(self, conditioner_transform = FeedForward( input_event_shape=conditioner.input_shape, output_event_shape=conditioner.output_shape, - n_predicted_parameters=transformer.n_parameters, + n_transformer_parameters=transformer.n_parameters, context_shape=context_shape, **kwargs ) @@ -88,7 +88,7 @@ def __init__(self, conditioner_transform = FeedForward( input_event_shape=conditioner.input_shape, output_event_shape=conditioner.output_shape, - n_predicted_parameters=transformer.n_parameters, + n_transformer_parameters=transformer.n_parameters, context_shape=context_shape, **kwargs ) @@ -107,7 +107,7 @@ def __init__(self, conditioner_transform = FeedForward( input_event_shape=conditioner.input_shape, output_event_shape=conditioner.output_shape, - n_predicted_parameters=transformer.n_parameters, + n_transformer_parameters=transformer.n_parameters, context_shape=context_shape, **kwargs ) @@ -125,7 +125,7 @@ def __init__(self, conditioner_transform = FeedForward( input_event_shape=conditioner.input_shape, output_event_shape=conditioner.output_shape, - n_predicted_parameters=transformer.n_parameters, + n_transformer_parameters=transformer.n_parameters, context_shape=context_shape, **kwargs ) @@ -145,7 +145,7 @@ def __init__(self, conditioner_transform = FeedForward( input_event_shape=conditioner.input_shape, output_event_shape=conditioner.output_shape, - n_predicted_parameters=transformer.n_parameters, + n_transformer_parameters=transformer.n_parameters, context_shape=context_shape, **kwargs ) @@ -181,7 +181,7 @@ def __init__(self, conditioner_transform = MADE( input_event_shape=event_shape, output_event_shape=event_shape, - n_predicted_parameters=transformer.n_parameters, + 
n_transformer_parameters=transformer.n_parameters, context_shape=context_shape, **kwargs ) @@ -203,7 +203,7 @@ def __init__(self, conditioner_transform = MADE( input_event_shape=event_shape, output_event_shape=event_shape, - n_predicted_parameters=transformer.n_parameters, + n_transformer_parameters=transformer.n_parameters, context_shape=context_shape, **kwargs ) @@ -224,7 +224,7 @@ def __init__(self, conditioner_transform = MADE( input_event_shape=event_shape, output_event_shape=event_shape, - n_predicted_parameters=transformer.n_parameters, + n_transformer_parameters=transformer.n_parameters, context_shape=context_shape, **kwargs ) @@ -245,7 +245,7 @@ def __init__(self, conditioner_transform = MADE( input_event_shape=event_shape, output_event_shape=event_shape, - n_predicted_parameters=transformer.n_parameters, + n_transformer_parameters=transformer.n_parameters, context_shape=context_shape, **kwargs ) @@ -268,7 +268,7 @@ def __init__(self, conditioner_transform = MADE( input_event_shape=event_shape, output_event_shape=event_shape, - n_predicted_parameters=transformer.n_parameters, + n_transformer_parameters=transformer.n_parameters, context_shape=context_shape, **kwargs ) @@ -295,7 +295,7 @@ def __init__(self, conditioner_transform = MADE( input_event_shape=event_shape, output_event_shape=event_shape, - n_predicted_parameters=transformer.n_parameters, + n_transformer_parameters=transformer.n_parameters, context_shape=context_shape, **kwargs ) diff --git a/test/test_conditioner_transforms.py b/test/test_conditioner_transforms.py index e68514d..c74f0b9 100644 --- a/test/test_conditioner_transforms.py +++ b/test/test_conditioner_transforms.py @@ -24,7 +24,7 @@ def test_shape(transform_class, batch_shape, input_event_shape, output_event_sha transform = transform_class( input_event_shape=input_event_shape, output_event_shape=output_event_shape, - n_predicted_parameters=n_predicted_parameters + n_transformer_parameters=n_predicted_parameters ) out = transform(x) 
assert out.shape == (*batch_shape, *output_event_shape, n_predicted_parameters)