Merge pull request #6 from davidnabergoj/dev
Performance improvements and fixes
davidnabergoj authored Jan 8, 2024
2 parents e71c6f9 + f555263 commit 7b83bd0
Showing 31 changed files with 270 additions and 172 deletions.
46 changes: 46 additions & 0 deletions examples/Variational inference.md
@@ -0,0 +1,46 @@
We show how to fit normalizing flows using stochastic variational inference (SVI). Whereas traditional maximum
likelihood estimation requires a fixed dataset of samples, SVI lets us optimize normalizing flow parameters using only
the unnormalized log density of the target distribution.

As an example, we define the unnormalized log density of a diagonal Gaussian target with 10 dimensions, each with
mean 5 and variance 9:

```python
import torch

torch.manual_seed(0)

event_shape = (10,)
true_mean = torch.full(size=event_shape, fill_value=5.0)
true_variance = torch.full(size=event_shape, fill_value=9.0)


def target_log_prob(x: torch.Tensor):
    # Unnormalized diagonal Gaussian log density, summed over event dimensions
    return torch.sum(-((x - true_mean) ** 2 / (2 * true_variance)), dim=1)
```

We define the flow and run the variational fit:

```python
from normalizing_flows import Flow
from normalizing_flows.bijections import RealNVP

torch.manual_seed(0)
flow = Flow(RealNVP(event_shape=event_shape))
flow.variational_fit(target_log_prob, show_progress=True)
```
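
For intuition: a variational fit of this kind is typically implemented by minimizing the reverse KL divergence between the flow and the unnormalized target, using differentiable samples drawn from the flow itself. The sketch below is illustrative only and is not necessarily how `variational_fit` works internally; it assumes the flow exposes `parameters()`, a differentiable `sample()`, and `log_prob()`, and the optimizer, learning rate, sample count, and step count are arbitrary choices:

```python
def reverse_kl_fit_sketch(flow, target_log_prob, n_steps=500, n_samples=128, lr=1e-2):
    # Illustrative reverse-KL (ELBO-style) training loop; not the library's exact internals.
    optimizer = torch.optim.Adam(flow.parameters(), lr=lr)
    for _ in range(n_steps):
        optimizer.zero_grad()
        x = flow.sample(n_samples)  # assumes reparameterized (differentiable) samples
        # KL(q || p) up to an additive constant: E_q[log q(x) - log p_unnormalized(x)]
        loss = (flow.log_prob(x) - target_log_prob(x)).mean()
        loss.backward()
        optimizer.step()
```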

We plot samples from the trained flow and print the estimated marginal means and variances; the estimates are roughly accurate.

```python
import matplotlib.pyplot as plt

torch.manual_seed(0)
x_flow = flow.sample(10000).detach()

plt.figure()
plt.scatter(x_flow[:, 0], x_flow[:, 1])  # scatter of the first two dimensions
plt.show()

print(f'{torch.mean(x_flow, dim=0) = }')
print(f'{torch.var(x_flow, dim=0) = }')
```
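
As a rough sanity check, the estimated moments can be compared to the known target parameters. The tolerances below are arbitrary assumptions; sampling noise and optimization error mean the match is only approximate:

```python
assert torch.allclose(torch.mean(x_flow, dim=0), true_mean, atol=0.5)
assert torch.allclose(torch.var(x_flow, dim=0), true_variance, atol=1.0)
```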
2 changes: 1 addition & 1 deletion normalizing_flows/bijections/continuous/base.py
@@ -209,7 +209,7 @@ def forward(self, t, states):
s_.requires_grad_(True)
dy = self.diffeq(t, y, *states[2:])
divergence = self.divergence_step(dy, y)
return tuple([dy, -divergence] + [torch.zeros_like(s_).requires_grad_(True) for s_ in states[2:]])
return tuple([dy, divergence] + [torch.zeros_like(s_).requires_grad_(True) for s_ in states[2:]])


class RegularizedApproximateODEFunction(ApproximateODEFunction):
35 changes: 21 additions & 14 deletions normalizing_flows/bijections/continuous/otflow.py
@@ -1,3 +1,4 @@
import math
from typing import Union, Tuple

import torch
@@ -29,19 +30,21 @@ def __init__(self, event_size: int, hidden_size: int, step_size: float = 0.01):

divisor = max(event_size ** 2, 10)

K0_delta = torch.randn(size=(hidden_size, event_size)) / divisor
b0_delta = torch.randn(size=(hidden_size,)) / divisor
self.K0_delta = nn.Parameter(torch.randn(size=(hidden_size, event_size)) / divisor)
self.b0 = nn.Parameter(torch.randn(size=(hidden_size,)) / divisor)

K1_delta = torch.randn(size=(hidden_size, hidden_size)) / divisor
b1_delta = torch.randn(size=(hidden_size,)) / divisor
self.K1_delta = nn.Parameter(torch.randn(size=(hidden_size, hidden_size)) / divisor)
self.b1 = nn.Parameter(torch.randn(size=(hidden_size,)) / divisor)

self.K0 = nn.Parameter(torch.eye(hidden_size, event_size) + K0_delta)
self.b0 = nn.Parameter(0 + b0_delta)
self.step_size = step_size

self.K1 = nn.Parameter(torch.eye(hidden_size, hidden_size) + K1_delta)
self.b1 = nn.Parameter(0 + b1_delta)
@property
def K0(self):
return torch.eye(*self.K0_delta.shape) + self.K0_delta / 1000

self.step_size = step_size
@property
def K1(self):
return torch.eye(*self.K1_delta.shape) + self.K1_delta / 1000

@staticmethod
def sigma(x):
@@ -114,7 +117,7 @@ def hessian_trace(self,

t0 = torch.sum(
torch.multiply(
(self.sigma_prime_prime(torch.nn.functional.linear(s, self.K0, self.b0)) * z1),
self.sigma_prime_prime(torch.nn.functional.linear(s, self.K0, self.b0)) * z1,
torch.nn.functional.linear(ones, self.K0[:, :-1] ** 2)
),
dim=1
@@ -138,9 +141,13 @@


class OTPotential(TimeDerivative):
def __init__(self, event_size: int, hidden_size: int, **kwargs):
def __init__(self, event_size: int, hidden_size: int = None, **kwargs):
super().__init__()

# hidden_size = m
if hidden_size is None:
hidden_size = max(math.log(event_size), 4)

r = min(10, event_size)

# Initialize w to 1
@@ -159,7 +166,7 @@ def __init__(self, event_size: int, hidden_size: int, **kwargs):
self.resnet = OTResNet(event_size + 1, hidden_size, **kwargs) # (x, t) has d+1 elements

def forward(self, t, x):
return -self.gradient(concatenate_x_t(x, t))
return self.gradient(concatenate_x_t(x, t))

def gradient(self, s):
# Equation 12
@@ -187,8 +194,8 @@ def hessian_trace(self, s: torch.Tensor, u0: torch.Tensor = None, z1: torch.Tens


class OTFlowODEFunction(ExactODEFunction):
def __init__(self, n_dim):
super().__init__(OTPotential(n_dim, hidden_size=30))
def __init__(self, n_dim, **kwargs):
super().__init__(OTPotential(n_dim, **kwargs))

def compute_log_det(self, t, x):
return self.diffeq.hessian_trace(concatenate_x_t(x, t)).view(-1, 1) # Need an empty dim at the end
@@ -180,7 +180,7 @@ def __init__(self, event_shape, n_layers: int = 2, **kwargs):


class UMNNMAF(BijectiveComposition):
def __init__(self, event_shape, n_layers: int = 2, **kwargs):
def __init__(self, event_shape, n_layers: int = 1, **kwargs):
if isinstance(event_shape, int):
event_shape = (event_shape,)
bijections = [ElementwiseAffine(event_shape=event_shape)]

This file was deleted.

This file was deleted.

Empty file.
@@ -1,12 +1,12 @@
import math
from typing import Tuple, Union, Type, List
from typing import Tuple, Union

import torch
import torch.nn as nn

from normalizing_flows.bijections.finite.autoregressive.conditioners.context import Concatenation, ContextCombiner, \
from normalizing_flows.bijections.finite.autoregressive.conditioning.context import Concatenation, ContextCombiner, \
Bypass
from normalizing_flows.utils import get_batch_shape, pad_leading_dims
from normalizing_flows.utils import get_batch_shape


class ConditionerTransform(nn.Module):
@@ -83,12 +83,12 @@ def forward(self, x: torch.Tensor, context: torch.Tensor = None):
else:
if self.n_global_parameters == self.n_transformer_parameters:
# All transformer parameters are learned globally
output = torch.zeros(*batch_shape, *self.parameter_shape)
output = torch.zeros(*batch_shape, *self.parameter_shape, device=x.device)
output[..., self.global_parameter_mask] = self.global_theta_flat
return output
else:
# Some transformer parameters are learned globally, some are predicted
output = torch.zeros(*batch_shape, *self.parameter_shape)
output = torch.zeros(*batch_shape, *self.parameter_shape, device=x.device)
output[..., self.global_parameter_mask] = self.global_theta_flat
output[..., ~self.global_parameter_mask] = self.predict_theta_flat(x, context)
return output
8 changes: 4 additions & 4 deletions normalizing_flows/bijections/finite/autoregressive/layers.py
@@ -1,7 +1,7 @@
import torch

from normalizing_flows.bijections.finite.autoregressive.conditioner_transforms import FeedForward
from normalizing_flows.bijections.finite.autoregressive.conditioners.coupling_masks import HalfSplit
from normalizing_flows.bijections.finite.autoregressive.conditioning.transforms import FeedForward
from normalizing_flows.bijections.finite.autoregressive.conditioning.coupling_masks import HalfSplit
from normalizing_flows.bijections.finite.autoregressive.layers_base import MaskedAutoregressiveBijection, \
InverseMaskedAutoregressiveBijection, ElementwiseBijection, CouplingBijection
from normalizing_flows.bijections.finite.autoregressive.transformers.linear.affine import Scale, Affine, Shift
@@ -223,8 +223,8 @@ class UMNNMaskedAutoregressive(MaskedAutoregressiveBijection):
def __init__(self,
event_shape: torch.Size,
context_shape: torch.Size = None,
n_hidden_layers: int = 1,
hidden_dim: int = 5,
n_hidden_layers: int = None,
hidden_dim: int = None,
**kwargs):
transformer: ScalarTransformer = UnconstrainedMonotonicNeuralNetwork(
event_shape=event_shape,
34 changes: 24 additions & 10 deletions normalizing_flows/bijections/finite/autoregressive/layers_base.py
@@ -1,11 +1,11 @@
from typing import Tuple, Optional, Union
from typing import Tuple, Union

import torch
import torch.nn as nn

from normalizing_flows.bijections.finite.autoregressive.conditioners.base import Conditioner, NullConditioner
from normalizing_flows.bijections.finite.autoregressive.conditioner_transforms import ConditionerTransform, Constant, \
from normalizing_flows.bijections.finite.autoregressive.conditioning.transforms import ConditionerTransform, \
MADE
from normalizing_flows.bijections.finite.autoregressive.conditioners.coupling_masks import CouplingMask
from normalizing_flows.bijections.finite.autoregressive.conditioning.coupling_masks import CouplingMask
from normalizing_flows.bijections.finite.autoregressive.transformers.base import TensorTransformer, ScalarTransformer
from normalizing_flows.bijections.base import Bijection
from normalizing_flows.utils import flatten_event, unflatten_event, get_batch_shape
@@ -14,12 +14,10 @@
class AutoregressiveBijection(Bijection):
def __init__(self,
event_shape,
conditioner: Optional[Conditioner],
transformer: Union[TensorTransformer, ScalarTransformer],
conditioner_transform: ConditionerTransform,
**kwargs):
super().__init__(event_shape=event_shape)
self.conditioner = conditioner
self.conditioner_transform = conditioner_transform
self.transformer = transformer

@@ -57,7 +55,7 @@ def __init__(self,
coupling_mask: CouplingMask,
conditioner_transform: ConditionerTransform,
**kwargs):
super().__init__(coupling_mask.event_shape, None, transformer, conditioner_transform, **kwargs)
super().__init__(coupling_mask.event_shape, transformer, conditioner_transform, **kwargs)
self.coupling_mask = coupling_mask

assert conditioner_transform.input_event_shape == (coupling_mask.constant_event_size,)
@@ -112,7 +110,7 @@ def __init__(self,
context_shape=context_shape,
**kwargs
)
super().__init__(transformer.event_shape, None, transformer, conditioner_transform)
super().__init__(transformer.event_shape, transformer, conditioner_transform)

def apply_conditioner_transformer(self, inputs, context, forward: bool = True):
h = self.conditioner_transform(inputs, context)
@@ -160,7 +158,23 @@ class ElementwiseBijection(AutoregressiveBijection):
def __init__(self, transformer: ScalarTransformer, fill_value: float = None):
super().__init__(
transformer.event_shape,
NullConditioner(),
transformer,
Constant(transformer.event_shape, transformer.parameter_shape, fill_value=fill_value)
None
)

if fill_value is None:
self.value = nn.Parameter(torch.randn(*transformer.parameter_shape))
else:
self.value = nn.Parameter(torch.full(size=transformer.parameter_shape, fill_value=fill_value))

def prepare_h(self, batch_shape):
tmp = self.value[[None] * len(batch_shape)]
return tmp.repeat(*batch_shape, *([1] * len(self.transformer.parameter_shape)))

def forward(self, x: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]:
h = self.prepare_h(get_batch_shape(x, self.event_shape))
return self.transformer.forward(x, h)

def inverse(self, z: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]:
h = self.prepare_h(get_batch_shape(z, self.event_shape))
return self.transformer.inverse(z, h)
@@ -32,14 +32,19 @@ class UnconstrainedMonotonicNeuralNetwork(UnconstrainedMonotonicTransformer):
"""
def __init__(self,
event_shape: Union[torch.Size, Tuple[int, ...]],
n_hidden_layers: int = 2,
n_hidden_layers: int = None,
hidden_dim: int = None):
super().__init__(event_shape, g=self.neural_network_forward, c=torch.tensor(-100.0))

if n_hidden_layers is None:
n_hidden_layers = 1
self.n_hidden_layers = n_hidden_layers

if hidden_dim is None:
hidden_dim = max(5 * int(math.log(self.n_dim)), 4)
hidden_dim = max(int(math.log(self.n_dim)), 4)
self.hidden_dim = hidden_dim
self.const = 1000 # for stability

self.const = 1 # for stability

# weight, bias have self.hidden_dim elements
self.n_input_params = 2 * self.hidden_dim
@@ -118,10 +123,8 @@ def neural_network_forward(inputs, parameters: List[torch.Tensor]):

def base_forward_1d(self, x: torch.Tensor, params: List[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
x_r = x.view(-1, 1, 1)
integral_flat = self.integral(x_r, params)
log_det_flat = self.g(x_r, params).log() # We can apply log since g is always positive
output = integral_flat.view_as(x)
log_det = log_det_flat.view_as(x)
output = self.integral(x_r, params).view_as(x)
log_det = self.g(x_r, params).log().view_as(x) # We can apply log since g is always positive
return output, log_det

def inverse_1d(self, z: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -83,9 +83,9 @@ def compute_knots(self, u_x, u_y, u_l, u_d):
# u_y acts as a delta
# u_d acts as a delta
knots_x = self.compute_bins(u_x, self.min_input, self.max_input, self.min_bin_width)
knots_y = self.compute_bins(u_x + u_y / 1000, self.min_output, self.max_output, self.min_bin_height)
knots_y = self.compute_bins(u_x + u_y / 100, self.min_output, self.max_output, self.min_bin_height)
knots_lambda = torch.sigmoid(u_l)
knots_d = self.compute_derivatives(self.const + u_d / 1000)
knots_d = self.compute_derivatives(self.const + u_d / 100)
return knots_x, knots_y, knots_d, knots_lambda

def forward_1d(self, x, h):
8 changes: 4 additions & 4 deletions normalizing_flows/bijections/finite/autoregressive/util.py
@@ -1,5 +1,5 @@
from functools import lru_cache
from typing import Tuple, List
from typing import List, Optional, Tuple

import numpy as np
import torch.autograd
@@ -10,13 +10,13 @@

class GaussLegendre(torch.autograd.Function):
@staticmethod
def forward(ctx, f, a: torch.Tensor, b: torch.Tensor, n: int, h: List[torch.Tensor]) -> torch.Tensor:
def forward(ctx, f, a: torch.Tensor, b: torch.Tensor, n: int, *h: List[torch.Tensor]) -> torch.Tensor:
ctx.f, ctx.n = f, n
ctx.save_for_backward(a, b, *h)
return GaussLegendre.quadrature(f, a, b, n, h)

@staticmethod
def backward(ctx, grad_area: torch.Tensor) -> Tuple[torch.Tensor, ...]:
def backward(ctx, grad_area: torch.Tensor) -> Tuple[Optional[torch.Tensor], ...]:
f, n = ctx.f, ctx.n
a, b, *h = ctx.saved_tensors

@@ -62,4 +62,4 @@ def nodes(n: int, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:


def gauss_legendre(f, a, b, n, h):
return GaussLegendre.apply(f, a, b, n, h)
return GaussLegendre.apply(f, a, b, n, *h)
