Performance improvements and fixes #6

Merged: 31 commits, Jan 8, 2024

Commits (all by davidnabergoj):
6eec626  Fix device (Dec 20, 2023)
76db053  Fix weight shape (Dec 20, 2023)
304092e  Change defaults for LRS (Dec 20, 2023)
f0167a3  Change default bin size in LRS (Dec 21, 2023)
9f22ac8  Reduce number of tests for faster CI/CD (Dec 21, 2023)
8ac96f8  Rewrite elementwise bijection to avoid null conditioner (Dec 25, 2023)
8a8572b  Remove Conditioner and NullConditioner classes (Dec 25, 2023)
cc607f6  Rename files (Dec 25, 2023)
c0e239e  Towards fixing autograd in residual flows (Dec 27, 2023)
6a17859  Fix autograd for InvertibleResNet (Dec 27, 2023)
952b3a2  Fixing autograd for ResFlow (Dec 27, 2023)
3df7362  Create graph in roulette estimator only if training (Dec 27, 2023)
43168a5  Set default hidden layer size of spectral neural network to max(log(n… (Dec 27, 2023)
a50432a  Change default hidden layer size in proximal neural network (Dec 27, 2023)
8ee1e7c  Set default hidden size in OTFlow (Dec 27, 2023)
e82e79a  Fix OTFlow forward output sign (Dec 27, 2023)
9624725  Handle training=True for residual bijections in Flow.fit (Dec 27, 2023)
549e5c0  Fix divergence sign for FFJORD, RNODE (Dec 27, 2023)
c8be5c1  Fix gamma in Proximal ResFlow with single layer blocks (Dec 27, 2023)
b736d6f  Change defaults for invertible ResNet and ResFlow (Dec 27, 2023)
abb32dc  Fix log determinant in residual bijection, interleave layers with ele… (Dec 27, 2023)
2560814  Use different residual blocks in Invertible ResNet and ResFlow (Dec 27, 2023)
5aaf37a  Change ProximalResFlow defaults (Dec 27, 2023)
fd4dad8  Fix Gaussian quadrature call for UMNN-MAF (Dec 27, 2023)
3e708b1  Better skip message for residual bijection reconstruction test (Dec 27, 2023)
3026654  Change UMNN MAF parametrization (Dec 27, 2023)
8dcea8b  Add variational fit (Jan 8, 2024)
b17da39  Fix SVI, add docs (Jan 8, 2024)
9b87f8f  Skip UMNN tests (Jan 8, 2024)
19ddb80  Merge branch 'dev' of https://github.com/davidnabergoj/normalizing-fl… (Jan 8, 2024)
f555263  Skip UMNN transformer tests (Jan 8, 2024)

Files changed

46 changes: 46 additions & 0 deletions examples/Variational inference.md
@@ -0,0 +1,46 @@
We show how to fit normalizing flows using stochastic variational inference (SVI). Whereas traditional maximum
likelihood estimation requires a fixed dataset of samples, SVI lets us optimize normalizing flow parameters using only
the unnormalized log density of the target distribution.

As an example, we define the unnormalized log density of a diagonal Gaussian. We assume this target has 10 dimensions
with mean 5 and variance 9 in each dimension:

```python
import torch

torch.manual_seed(0)

event_shape = (10,)
true_mean = torch.full(size=event_shape, fill_value=5.0)
true_variance = torch.full(size=event_shape, fill_value=9.0)


def target_log_prob(x: torch.Tensor):
return torch.sum(-((x - true_mean) ** 2 / (2 * true_variance)), dim=1)
```

We define the flow and run the variational fit:

```python
from normalizing_flows import Flow
from normalizing_flows.bijections import RealNVP

torch.manual_seed(0)
flow = Flow(RealNVP(event_shape=event_shape))
flow.variational_fit(target_log_prob, show_progress=True)
```

We plot samples from the trained flow and print the estimated marginal means and variances; the estimates are roughly accurate.
```python
import matplotlib.pyplot as plt

torch.manual_seed(0)
x_flow = flow.sample(10000).detach()

plt.figure()
plt.scatter(x_flow[:, 0], x_flow[:, 1])
plt.show()

print(f'{torch.mean(x_flow, dim=0) = }')
print(f'{torch.var(x_flow, dim=0) = }')
```
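
For orientation, a variational fit of this kind typically minimizes the reverse KL divergence KL(q_flow || p_target), estimated from reparameterized flow samples; the sketch below shows the objective only, with illustrative names rather than the library's internals:

```python
import torch


def reverse_kl_loss(flow_samples: torch.Tensor,
                    flow_log_probs: torch.Tensor,
                    target_log_prob) -> torch.Tensor:
    # Monte Carlo estimate of KL(q || p) up to the target's unknown normalizing
    # constant: E_q[log q(x) - log p_tilde(x)], averaged over flow samples.
    # The samples must be produced with the reparameterization trick so that
    # gradients reach the flow parameters.
    return (flow_log_probs - target_log_prob(flow_samples)).mean()
```

Gradient steps on this estimate need only the unnormalized `target_log_prob`, which is why no dataset is required.
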
2 changes: 1 addition & 1 deletion normalizing_flows/bijections/continuous/base.py
@@ -209,7 +209,7 @@ def forward(self, t, states):
s_.requires_grad_(True)
dy = self.diffeq(t, y, *states[2:])
divergence = self.divergence_step(dy, y)
return tuple([dy, -divergence] + [torch.zeros_like(s_).requires_grad_(True) for s_ in states[2:]])
return tuple([dy, divergence] + [torch.zeros_like(s_).requires_grad_(True) for s_ in states[2:]])


class RegularizedApproximateODEFunction(ApproximateODEFunction):
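
For orientation, FFJORD and RNODE accumulate the log determinant by integrating the divergence (Jacobian trace) of the dynamics alongside the state, so the sign returned here must agree with the d/dt log p(x(t)) = -Tr(∂f/∂x) term of the instantaneous change-of-variables formula. The divergence is usually estimated with Hutchinson's trick; a minimal sketch under that assumption (not the library's `divergence_step`):

```python
import torch


def hutchinson_divergence(dy: torch.Tensor, y: torch.Tensor, eps: torch.Tensor) -> torch.Tensor:
    # Unbiased estimate of Tr(df/dy) as E_eps[eps^T (df/dy) eps], with eps drawn
    # from a standard normal or Rademacher distribution. dy = f(t, y) must be
    # computed with autograd enabled for the vector-Jacobian product to exist.
    vjp, = torch.autograd.grad(dy, y, grad_outputs=eps, create_graph=True)
    return (vjp * eps).flatten(start_dim=1).sum(dim=1)
```
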
35 changes: 21 additions & 14 deletions normalizing_flows/bijections/continuous/otflow.py
@@ -1,3 +1,4 @@
import math
from typing import Union, Tuple

import torch
@@ -29,19 +30,21 @@ def __init__(self, event_size: int, hidden_size: int, step_size: float = 0.01):

divisor = max(event_size ** 2, 10)

K0_delta = torch.randn(size=(hidden_size, event_size)) / divisor
b0_delta = torch.randn(size=(hidden_size,)) / divisor
self.K0_delta = nn.Parameter(torch.randn(size=(hidden_size, event_size)) / divisor)
self.b0 = nn.Parameter(torch.randn(size=(hidden_size,)) / divisor)

K1_delta = torch.randn(size=(hidden_size, hidden_size)) / divisor
b1_delta = torch.randn(size=(hidden_size,)) / divisor
self.K1_delta = nn.Parameter(torch.randn(size=(hidden_size, hidden_size)) / divisor)
self.b1 = nn.Parameter(torch.randn(size=(hidden_size,)) / divisor)

self.K0 = nn.Parameter(torch.eye(hidden_size, event_size) + K0_delta)
self.b0 = nn.Parameter(0 + b0_delta)
self.step_size = step_size

self.K1 = nn.Parameter(torch.eye(hidden_size, hidden_size) + K1_delta)
self.b1 = nn.Parameter(0 + b1_delta)
@property
def K0(self):
return torch.eye(*self.K0_delta.shape) + self.K0_delta / 1000

self.step_size = step_size
@property
def K1(self):
return torch.eye(*self.K1_delta.shape) + self.K1_delta / 1000

@staticmethod
def sigma(x):
@@ -114,7 +117,7 @@ def hessian_trace(self,

t0 = torch.sum(
torch.multiply(
(self.sigma_prime_prime(torch.nn.functional.linear(s, self.K0, self.b0)) * z1),
self.sigma_prime_prime(torch.nn.functional.linear(s, self.K0, self.b0)) * z1,
torch.nn.functional.linear(ones, self.K0[:, :-1] ** 2)
),
dim=1
@@ -138,9 +141,13 @@


class OTPotential(TimeDerivative):
def __init__(self, event_size: int, hidden_size: int, **kwargs):
def __init__(self, event_size: int, hidden_size: int = None, **kwargs):
super().__init__()

# hidden_size = m
if hidden_size is None:
hidden_size = max(math.log(event_size), 4)

r = min(10, event_size)

# Initialize w to 1
@@ -159,7 +166,7 @@ def __init__(self, event_size: int, hidden_size: int, **kwargs):
self.resnet = OTResNet(event_size + 1, hidden_size, **kwargs) # (x, t) has d+1 elements

def forward(self, t, x):
return -self.gradient(concatenate_x_t(x, t))
return self.gradient(concatenate_x_t(x, t))

def gradient(self, s):
# Equation 12
@@ -187,8 +194,8 @@ def hessian_trace(self, s: torch.Tensor, u0: torch.Tensor = None, z1: torch.Tens


class OTFlowODEFunction(ExactODEFunction):
def __init__(self, n_dim):
super().__init__(OTPotential(n_dim, hidden_size=30))
def __init__(self, n_dim, **kwargs):
super().__init__(OTPotential(n_dim, **kwargs))

def compute_log_det(self, t, x):
return self.diffeq.hessian_trace(concatenate_x_t(x, t)).view(-1, 1) # Need an empty dim at the end
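
For orientation, OT-Flow drives samples with a velocity field obtained from a scalar potential; `gradient` computes it in closed form (the "Equation 12" comment above), and the sign returned by `forward` has to be consistent with how `hessian_trace` enters the log determinant in `compute_log_det`. A generic autograd-based sketch of the potential-gradient idea, for illustration only:

```python
import torch


def potential_velocity(potential, s: torch.Tensor) -> torch.Tensor:
    # Velocity as the gradient of a scalar potential evaluated at the
    # concatenated state (x, t). OT-Flow evaluates this gradient analytically;
    # autograd is used here purely to illustrate the relationship.
    s = s.detach().requires_grad_(True)
    phi = potential(s).sum()
    grad, = torch.autograd.grad(phi, s, create_graph=True)
    return grad
```
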
@@ -180,7 +180,7 @@ def __init__(self, event_shape, n_layers: int = 2, **kwargs):


class UMNNMAF(BijectiveComposition):
def __init__(self, event_shape, n_layers: int = 2, **kwargs):
def __init__(self, event_shape, n_layers: int = 1, **kwargs):
if isinstance(event_shape, int):
event_shape = (event_shape,)
bijections = [ElementwiseAffine(event_shape=event_shape)]

This file was deleted.

This file was deleted.

@@ -1,12 +1,12 @@
import math
from typing import Tuple, Union, Type, List
from typing import Tuple, Union

import torch
import torch.nn as nn

from normalizing_flows.bijections.finite.autoregressive.conditioners.context import Concatenation, ContextCombiner, \
from normalizing_flows.bijections.finite.autoregressive.conditioning.context import Concatenation, ContextCombiner, \
Bypass
from normalizing_flows.utils import get_batch_shape, pad_leading_dims
from normalizing_flows.utils import get_batch_shape


class ConditionerTransform(nn.Module):
@@ -83,12 +83,12 @@ def forward(self, x: torch.Tensor, context: torch.Tensor = None):
else:
if self.n_global_parameters == self.n_transformer_parameters:
# All transformer parameters are learned globally
output = torch.zeros(*batch_shape, *self.parameter_shape)
output = torch.zeros(*batch_shape, *self.parameter_shape, device=x.device)
output[..., self.global_parameter_mask] = self.global_theta_flat
return output
else:
# Some transformer parameters are learned globally, some are predicted
output = torch.zeros(*batch_shape, *self.parameter_shape)
output = torch.zeros(*batch_shape, *self.parameter_shape, device=x.device)
output[..., self.global_parameter_mask] = self.global_theta_flat
output[..., ~self.global_parameter_mask] = self.predict_theta_flat(x, context)
return output
8 changes: 4 additions & 4 deletions normalizing_flows/bijections/finite/autoregressive/layers.py
@@ -1,7 +1,7 @@
import torch

from normalizing_flows.bijections.finite.autoregressive.conditioner_transforms import FeedForward
from normalizing_flows.bijections.finite.autoregressive.conditioners.coupling_masks import HalfSplit
from normalizing_flows.bijections.finite.autoregressive.conditioning.transforms import FeedForward
from normalizing_flows.bijections.finite.autoregressive.conditioning.coupling_masks import HalfSplit
from normalizing_flows.bijections.finite.autoregressive.layers_base import MaskedAutoregressiveBijection, \
InverseMaskedAutoregressiveBijection, ElementwiseBijection, CouplingBijection
from normalizing_flows.bijections.finite.autoregressive.transformers.linear.affine import Scale, Affine, Shift
@@ -223,8 +223,8 @@ class UMNNMaskedAutoregressive(MaskedAutoregressiveBijection):
def __init__(self,
event_shape: torch.Size,
context_shape: torch.Size = None,
n_hidden_layers: int = 1,
hidden_dim: int = 5,
n_hidden_layers: int = None,
hidden_dim: int = None,
**kwargs):
transformer: ScalarTransformer = UnconstrainedMonotonicNeuralNetwork(
event_shape=event_shape,
34 changes: 24 additions & 10 deletions normalizing_flows/bijections/finite/autoregressive/layers_base.py
@@ -1,11 +1,11 @@
from typing import Tuple, Optional, Union
from typing import Tuple, Union

import torch
import torch.nn as nn

from normalizing_flows.bijections.finite.autoregressive.conditioners.base import Conditioner, NullConditioner
from normalizing_flows.bijections.finite.autoregressive.conditioner_transforms import ConditionerTransform, Constant, \
from normalizing_flows.bijections.finite.autoregressive.conditioning.transforms import ConditionerTransform, \
MADE
from normalizing_flows.bijections.finite.autoregressive.conditioners.coupling_masks import CouplingMask
from normalizing_flows.bijections.finite.autoregressive.conditioning.coupling_masks import CouplingMask
from normalizing_flows.bijections.finite.autoregressive.transformers.base import TensorTransformer, ScalarTransformer
from normalizing_flows.bijections.base import Bijection
from normalizing_flows.utils import flatten_event, unflatten_event, get_batch_shape
@@ -14,12 +14,10 @@
class AutoregressiveBijection(Bijection):
def __init__(self,
event_shape,
conditioner: Optional[Conditioner],
transformer: Union[TensorTransformer, ScalarTransformer],
conditioner_transform: ConditionerTransform,
**kwargs):
super().__init__(event_shape=event_shape)
self.conditioner = conditioner
self.conditioner_transform = conditioner_transform
self.transformer = transformer

@@ -57,7 +55,7 @@ def __init__(self,
coupling_mask: CouplingMask,
conditioner_transform: ConditionerTransform,
**kwargs):
super().__init__(coupling_mask.event_shape, None, transformer, conditioner_transform, **kwargs)
super().__init__(coupling_mask.event_shape, transformer, conditioner_transform, **kwargs)
self.coupling_mask = coupling_mask

assert conditioner_transform.input_event_shape == (coupling_mask.constant_event_size,)
@@ -112,7 +110,7 @@ def __init__(self,
context_shape=context_shape,
**kwargs
)
super().__init__(transformer.event_shape, None, transformer, conditioner_transform)
super().__init__(transformer.event_shape, transformer, conditioner_transform)

def apply_conditioner_transformer(self, inputs, context, forward: bool = True):
h = self.conditioner_transform(inputs, context)
@@ -160,7 +158,23 @@ class ElementwiseBijection(AutoregressiveBijection):
def __init__(self, transformer: ScalarTransformer, fill_value: float = None):
super().__init__(
transformer.event_shape,
NullConditioner(),
transformer,
Constant(transformer.event_shape, transformer.parameter_shape, fill_value=fill_value)
None
)

if fill_value is None:
self.value = nn.Parameter(torch.randn(*transformer.parameter_shape))
else:
self.value = nn.Parameter(torch.full(size=transformer.parameter_shape, fill_value=fill_value))

def prepare_h(self, batch_shape):
tmp = self.value[[None] * len(batch_shape)]
return tmp.repeat(*batch_shape, *([1] * len(self.transformer.parameter_shape)))

def forward(self, x: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]:
h = self.prepare_h(get_batch_shape(x, self.event_shape))
return self.transformer.forward(x, h)

def inverse(self, z: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]:
h = self.prepare_h(get_batch_shape(z, self.event_shape))
return self.transformer.inverse(z, h)
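
The rewritten `ElementwiseBijection` above stores its transformer parameters directly as an `nn.Parameter` and tiles them across the batch in `prepare_h`; a small self-contained illustration of that tiling, with made-up shapes:

```python
import torch

value = torch.randn(3, 2)      # a single learned parameter tensor, shape (3, 2)
batch_shape = (8, 5)           # arbitrary batch shape

# Add leading singleton dimensions, then repeat over the batch dimensions only.
h = value.view(*([1] * len(batch_shape)), *value.shape).repeat(*batch_shape, 1, 1)
print(h.shape)                 # torch.Size([8, 5, 3, 2])
```
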
@@ -32,14 +32,19 @@ class UnconstrainedMonotonicNeuralNetwork(UnconstrainedMonotonicTransformer):
"""
def __init__(self,
event_shape: Union[torch.Size, Tuple[int, ...]],
n_hidden_layers: int = 2,
n_hidden_layers: int = None,
hidden_dim: int = None):
super().__init__(event_shape, g=self.neural_network_forward, c=torch.tensor(-100.0))

if n_hidden_layers is None:
n_hidden_layers = 1
self.n_hidden_layers = n_hidden_layers

if hidden_dim is None:
hidden_dim = max(5 * int(math.log(self.n_dim)), 4)
hidden_dim = max(int(math.log(self.n_dim)), 4)
self.hidden_dim = hidden_dim
self.const = 1000 # for stability

self.const = 1 # for stability

# weight, bias have self.hidden_dim elements
self.n_input_params = 2 * self.hidden_dim
@@ -118,10 +123,8 @@ def neural_network_forward(inputs, parameters: List[torch.Tensor]):

def base_forward_1d(self, x: torch.Tensor, params: List[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
x_r = x.view(-1, 1, 1)
integral_flat = self.integral(x_r, params)
log_det_flat = self.g(x_r, params).log() # We can apply log since g is always positive
output = integral_flat.view_as(x)
log_det = log_det_flat.view_as(x)
output = self.integral(x_r, params).view_as(x)
log_det = self.g(x_r, params).log().view_as(x) # We can apply log since g is always positive
return output, log_det

def inverse_1d(self, z: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
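
For orientation, the UMNN transformer maps x to the integral of a strictly positive network g, so the map is monotonic and its log derivative is simply log g(x), which is what `base_forward_1d` returns. A rough one-dimensional sketch that substitutes a crude Riemann sum for the Gauss-Legendre quadrature the library uses (see the util.py diff below); illustrative only:

```python
import torch


def umnn_forward_1d(x: torch.Tensor, g, n: int = 100):
    # z = integral of g(t) dt from 0 to x, with g > 0, so dz/dx = g(x) > 0 and
    # log|dz/dx| = log g(x). The integral is approximated with a Riemann sum.
    t = torch.linspace(0.0, 1.0, n).view(1, -1) * x.view(-1, 1)  # nodes on [0, x]
    z = (g(t) * (x.view(-1, 1) / n)).sum(dim=1)
    log_det = g(x).log()
    return z, log_det


z, log_det = umnn_forward_1d(torch.randn(4), torch.nn.functional.softplus)
```
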
@@ -83,9 +83,9 @@ def compute_knots(self, u_x, u_y, u_l, u_d):
# u_y acts as a delta
# u_d acts as a delta
knots_x = self.compute_bins(u_x, self.min_input, self.max_input, self.min_bin_width)
knots_y = self.compute_bins(u_x + u_y / 1000, self.min_output, self.max_output, self.min_bin_height)
knots_y = self.compute_bins(u_x + u_y / 100, self.min_output, self.max_output, self.min_bin_height)
knots_lambda = torch.sigmoid(u_l)
knots_d = self.compute_derivatives(self.const + u_d / 1000)
knots_d = self.compute_derivatives(self.const + u_d / 100)
return knots_x, knots_y, knots_d, knots_lambda

def forward_1d(self, x, h):
8 changes: 4 additions & 4 deletions normalizing_flows/bijections/finite/autoregressive/util.py
@@ -1,5 +1,5 @@
from functools import lru_cache
from typing import Tuple, List
from typing import List, Optional, Tuple

import numpy as np
import torch.autograd
@@ -10,13 +10,13 @@

class GaussLegendre(torch.autograd.Function):
@staticmethod
def forward(ctx, f, a: torch.Tensor, b: torch.Tensor, n: int, h: List[torch.Tensor]) -> torch.Tensor:
def forward(ctx, f, a: torch.Tensor, b: torch.Tensor, n: int, *h: List[torch.Tensor]) -> torch.Tensor:
ctx.f, ctx.n = f, n
ctx.save_for_backward(a, b, *h)
return GaussLegendre.quadrature(f, a, b, n, h)

@staticmethod
def backward(ctx, grad_area: torch.Tensor) -> Tuple[torch.Tensor, ...]:
def backward(ctx, grad_area: torch.Tensor) -> Tuple[Optional[torch.Tensor], ...]:
f, n = ctx.f, ctx.n
a, b, *h = ctx.saved_tensors

@@ -62,4 +62,4 @@ def nodes(n: int, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:


def gauss_legendre(f, a, b, n, h):
return GaussLegendre.apply(f, a, b, n, h)
return GaussLegendre.apply(f, a, b, n, *h)
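
The `*h` unpacking above matters because a custom `torch.autograd.Function` only propagates gradients to tensors passed to `apply` as direct arguments; tensors hidden inside a Python list are not tracked, so the conditioner-predicted parameters would receive no gradient. A minimal toy illustration of that behavior (unrelated to the quadrature code itself):

```python
import torch


class Scale(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, *params):
        ctx.save_for_backward(x, *params)
        return x * params[0]

    @staticmethod
    def backward(ctx, grad_out):
        x, p = ctx.saved_tensors
        return grad_out * p, grad_out * x  # one gradient per input to forward


p = torch.tensor(2.0, requires_grad=True)
x = torch.tensor(3.0, requires_grad=True)
Scale.apply(x, p).backward()
print(p.grad)  # tensor(3.); it would stay None if p were passed inside a list
```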