Merge pull request #29 from davidnabergoj/dev
Dev
davidnabergoj authored Nov 10, 2024
2 parents 58ddd76 + c4ab195 commit 60bfcee
Showing 15 changed files with 162 additions and 77 deletions.
41 changes: 0 additions & 41 deletions docs/conf.py

This file was deleted.

5 changes: 0 additions & 5 deletions docs/requirements.txt

This file was deleted.

10 changes: 8 additions & 2 deletions docs/source/architectures/general_modeling.rst
@@ -1,5 +1,5 @@
API for standard architectures
============================
=================================================
We list notable implemented bijection architectures.
These all inherit from the Bijection class.
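For example, an architecture preset can be instantiated directly from torchflows.architectures (a minimal sketch; the RealNVP constructor call mirrors the numerical stability guide added later in this pull request):

from torchflows.architectures import RealNVP

event_shape = (10,)
real_nvp = RealNVP(event_shape)  # architecture presets inherit from the Bijection class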

@@ -41,7 +41,13 @@ Autoregressive architectures
.. autoclass:: torchflows.architectures.InverseAutoregressiveLRS
:members: __init__

.. autoclass:: torchflows.architectures.CouplingDSF
.. autoclass:: torchflows.architectures.CouplingDeepSF
:members: __init__

.. autoclass:: torchflows.architectures.CouplingDenseSF
:members: __init__

.. autoclass:: torchflows.architectures.CouplingDeepDenseSF
:members: __init__

.. autoclass:: torchflows.architectures.UMNNMAF
6 changes: 5 additions & 1 deletion docs/source/conf.py
@@ -2,11 +2,14 @@
#
# For the full list of built-in configuration values, see the documentation:
# https://www.sphinx-doc.org/en/master/usage/configuration.html
import pathlib
import sys
import os

# -- Project information -----------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information

sys.path.insert(0, os.path.abspath('../../'))

with open("../DOCS_VERSION", 'r') as f:
version = f.read().strip()

@@ -24,6 +27,7 @@
'sphinx.ext.autodoc',
'sphinx.ext.autosummary',
'sphinx_copybutton',
'sphinx.ext.viewcode',
]

exclude_patterns = []
2 changes: 1 addition & 1 deletion docs/source/guides/image_modeling.rst
@@ -1,5 +1,5 @@
Image modeling
==============
=================

When modeling images, we can use multiscale architectures that rely on convolutional neural network conditioners and specialized coupling schemes.
These architectures expect event shapes to be *(channels, height, width)*.
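For example (a minimal, standalone illustration): a batch of eight RGB 32x32 images corresponds to an event shape of (3, 32, 32) and a tensor of shape (8, 3, 32, 32):

import torch

event_shape = (3, 32, 32)  # (channels, height, width)
images = torch.randn(8, *event_shape)  # batch of 8 images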
51 changes: 51 additions & 0 deletions docs/source/guides/numerical_stability.rst
@@ -0,0 +1,51 @@
Numerical stability
=============================

We may require a bijection to be very numerically precise when transforming data between original and latent spaces.
Given data `x`, bijection `b`, and tolerance `epsilon`, we may want:

.. code-block:: python
z, log_det_forward = b.forward(x)
x_reconstructed, log_det_inverse = b.inverse(z)
assert torch.all(torch.abs(x_reconstructed - x) < epsilon)
assert torch.all(torch.abs(log_det_forward + log_det_inverse) < epsilon)
All architecture presets in Torchflows (with a defined forward and inverse pass) are tested to reconstruct inputs and log determinants.
We test reconstruction with inputs taken from a standard Gaussian distribution.
The specified tolerance is either 0.01 or 0.001, though many architectures achieve a lower reconstruction error.


Reducing reconstruction error
------------------------------------------

We may need an even smaller reconstruction error.
We can start by ensuring the input data is standardized:

.. code-block:: python
import torch
from torchflows.architectures import RealNVP
torch.manual_seed(0)
batch_shape = (5,)
event_shape = (10,)
x = (torch.randn(size=(*batch_shape, *event_shape)) * 12 + 35) ** 0.5
x_standardized = (x - x.mean()) / x.std()
real_nvp = RealNVP(event_shape)
def print_reconstruction_errors(bijection, inputs):
z, log_det_forward = bijection.forward(inputs)
inputs_reconstructed, log_det_inverse = bijection.inverse(z)
print(f'Data reconstruction error: {torch.max(torch.abs(inputs - inputs_reconstructed)):.8f}')
print(f'Log determinant error: {torch.max(torch.abs(log_det_forward + log_det_inverse)):.8f}')
# Test with non-standardized inputs
print_reconstruction_errors(real_nvp, x)
print('-------------------------------------------------------')
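# Test with standardized inputs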
print_reconstruction_errors(real_nvp, x_standardized)
18 changes: 18 additions & 0 deletions environment.yml
@@ -0,0 +1,18 @@
name: torchflows-dev
dependencies:
  # Mandatory dependencies for torchflows
  - python>=3.7
  - pytorch::pytorch>=2.0.1
  - numpy
  - tqdm
  - pip
  - pytest # testing
  - pip:
      - torchdiffeq # mandatory for continuous NFs
      - Pygments # docs
      - Babel # docs
      - sphinx>=8.0.1 # docs
      - sphinx-rtd-theme # docs
      - sphinx-copybutton # docs
      - nbsphinx # docs
      - sphinx-autoapi # docs
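After creating this environment, a quick way to confirm that the mandatory runtime dependencies resolve is to import them (a minimal sketch; it assumes each package exposes __version__, which may vary by release):

import torch
import numpy
import tqdm
import torchdiffeq

# Report the versions of the mandatory torchflows dependencies listed above.
for module in (torch, numpy, tqdm, torchdiffeq):
    print(module.__name__, getattr(module, '__version__', 'unknown'))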
2 changes: 1 addition & 1 deletion torchflows/bijections/finite/autoregressive/layers.py
Expand Up @@ -64,7 +64,7 @@ def forward(self, x: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.
else:
scale = torch.std(x, dim=list(range(n_batch_dims)))[..., None].to(self.value)
unconstrained_scale = self.transformer.unconstrain_scale(scale)
self.value.data = torch.concatenate([unconstrained_scale, shift], dim=-1).data
self.value.data = torch.cat([unconstrained_scale, shift], dim=-1).data
return super().forward(x, context)
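For reference, torch.cat concatenates tensors along an existing dimension and is available across PyTorch releases, whereas torch.concatenate is a newer alias. A standalone sketch with assumed shapes (not the layer's actual parameter values):

import torch

unconstrained_scale = torch.zeros(3, 1)
shift = torch.zeros(3, 1)
value = torch.cat([unconstrained_scale, shift], dim=-1)
print(value.shape)  # torch.Size([3, 2])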


5 changes: 2 additions & 3 deletions torchflows/bijections/finite/autoregressive/layers_base.py
Expand Up @@ -147,9 +147,8 @@ class MaskedAutoregressiveBijection(AutoregressiveBijection):
Masked autoregressive bijection class.
This bijection is specified with a scalar transformer.
Its conditioner is always MADE, which receives as input a tensor x with x.shape = (*batch_shape, *event_shape).
MADE outputs parameters h for the scalar transformer with
h.shape = (*batch_shape, *event_shape, *parameter_shape_per_element).
Its conditioner is always MADE, which receives as input a tensor x with shape `(*batch_shape, *event_shape)`.
MADE outputs parameters h for the scalar transformer with shape `(*batch_shape, *event_shape, *parameter_shape_per_element)`.
The transformer then applies the bijection elementwise.
"""

@@ -21,8 +21,10 @@ def inverse_sigmoid(p):

class Sigmoid(ScalarTransformer):
"""
Applies z = inv_sigmoid(w.T @ sigmoid(a * x + b)) where a > 0, w > 0 and sum(w) = 1.
Note: w, a, b are vectors, so multiplication a * x is broadcast.
Scalar transformer that applies sigmoid-based transformations.
Applies `z = inv_sigmoid(w.T @ sigmoid(a * x + b))` where `a > 0`, `w > 0` and `sum(w) = 1`.
Note: `w, a, b` are vectors, so multiplication `a * x` is broadcast.
"""

def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]],
@@ -45,7 +47,9 @@ def parameter_shape_per_element(self) -> Union[torch.Size, Tuple[int, ...]]:

def extract_parameters(self, h: torch.Tensor):
"""
h.shape = (batch_size, self.n_parameters)
Convert a parameter vector h into a tuple of parameters to be used in downstream transformations.
:param torch.Tensor h: parameter vector with shape `(batch_size, n_parameters)`.
"""
da = h[:, :self.hidden_dim]
db = h[:, self.hidden_dim:self.hidden_dim * 2]
@@ -60,13 +64,14 @@ def extract_parameters(self, h: torch.Tensor):

def forward_1d(self, x, h, eps: float = 1e-6):
"""
x.shape = (batch_size,)
h.shape = (batch_size, hidden_size * 3)
Apply forward transformation on input tensor with one event dimension.
For debug purposes - within this function, the following holds:
`a.shape = (batch_size, hidden_size)`, `b.shape = (batch_size, hidden_size)`, `w.shape = (batch_size, hidden_size)`.
Within the function:
a.shape = (batch_size, hidden_size)
b.shape = (batch_size, hidden_size)
w.shape = (batch_size, hidden_size)
:param torch.Tensor x: input tensor with shape `(batch_size,)`.
:param torch.Tensor h: parameter vector with shape `(batch_size, 3 * hidden_size)`.
:param float eps: small positive scalar.
"""
a, b, w, log_w = self.extract_parameters(h)
c = a * x[:, None] + b # (batch_size, n_hidden)
@@ -9,9 +9,17 @@


class Integration(ScalarTransformer):
"""
Base integration transformer class.
"""

def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], bound: float = 100.0, eps: float = 1e-6):
"""
:param bound: specifies the initial interval [-bound, bound] where numerical inversion is performed.
Integration transformer constructor.
:param event_shape: shape of the input tensor.
:param bound: specifies the initial interval `[-bound, bound]` where numerical inversion is performed.
:param eps: small scalar value for transformer inversion.
"""
super().__init__(event_shape)
self.bound = bound
@@ -33,8 +41,10 @@ def base_forward_1d(self, x: torch.Tensor, params: List[torch.Tensor]):

def forward_1d(self, x: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
x.shape = (n,)
h.shape = (n, n_parameters)
Apply forward pass on input tensor with one event dimension.
:param torch.Tensor x: input tensor with shape `(n,)`.
:param torch.Tensor h: parameter tensor with shape `(n, n_parameters)`.
"""
params = self.compute_parameters(h)
return self.base_forward_1d(x, params)
@@ -51,17 +61,21 @@ def inverse_1d_without_log_det(self, z: torch.Tensor, params: List[torch.Tensor]

def inverse_1d(self, z: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
z.shape = (n,)
h.shape = (n, n_parameters)
Apply inverse pass on input tensor with one event dimension.
:param torch.Tensor z: input tensor with shape `(n,)`.
:param torch.Tensor h: parameter tensor with shape `(n, n_parameters)`.
"""
params = self.compute_parameters(h)
x = self.inverse_without_log_det(z, params)
return x, -self.base_forward_1d(x, params)[1]

def forward(self, x: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
x.shape = (*batch_shape, *event_shape)
h.shape = (*batch_shape, *parameter_shape)
Apply forward pass.
:param torch.Tensor x: input tensor with shape `(*batch_shape, *event_shape)`.
:param torch.Tensor h: parameter tensor with shape `(*batch_shape, *parameter_shape)`.
"""
z_flat, log_det_flat = self.forward_1d(x.view(-1), h.view(-1, self.n_parameters_per_element))
z = z_flat.view_as(x)
Expand All @@ -70,6 +84,12 @@ def forward(self, x: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch
return z, log_det

def inverse(self, z: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Apply inverse pass.
:param torch.Tensor z: input tensor with shape `(*batch_shape, *event_shape)`.
:param torch.Tensor h: parameter tensor with shape `(*batch_shape, *parameter_shape)`.
"""
x_flat, log_det_flat = self.inverse_1d(z.view(-1), h.view(-1, self.n_parameters_per_element))
x = x_flat.view_as(z)
batch_shape = get_batch_shape(z, self.event_shape)
12 changes: 12 additions & 0 deletions torchflows/bijections/finite/autoregressive/util.py
@@ -9,14 +9,23 @@


class GaussLegendre(torch.autograd.Function):
"""
Autograd function that computes the Gauss-Legendre quadrature.
"""
@staticmethod
def forward(ctx, f, a: torch.Tensor, b: torch.Tensor, n: int, *h: List[torch.Tensor]) -> torch.Tensor:
"""
Forward autograd map.
"""
ctx.f, ctx.n = f, n
ctx.save_for_backward(a, b, *h)
return GaussLegendre.quadrature(f, a, b, n, h)

@staticmethod
def backward(ctx, grad_area: torch.Tensor) -> Tuple[Optional[torch.Tensor], ...]:
"""
Backward autograd map (gradient computation).
"""
f, n = ctx.f, ctx.n
a, b, *h = ctx.saved_tensors

@@ -62,4 +71,7 @@ def nodes(n: int, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:


def gauss_legendre(f, a, b, n, h):
"""
Compute the Gauss-Legendre quadrature.
"""
return GaussLegendre.apply(f, a, b, n, *h)
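For intuition, Gauss-Legendre quadrature approximates an integral over [a, b] as a weighted sum of function values at fixed nodes. A standalone sketch using NumPy's node/weight helper (for illustration only, not the class above):

import numpy as np

def gauss_legendre_quadrature(f, a, b, n):
    # Nodes and weights on [-1, 1], rescaled to [a, b].
    nodes, weights = np.polynomial.legendre.leggauss(n)
    x = 0.5 * (b - a) * nodes + 0.5 * (a + b)
    return 0.5 * (b - a) * np.sum(weights * f(x))

# The integral of x^2 over [0, 1] is 1/3; a 3-point rule is exact for this polynomial.
print(gauss_legendre_quadrature(lambda x: x ** 2, 0.0, 1.0, 3))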
2 changes: 1 addition & 1 deletion torchflows/bijections/finite/multiscale/architectures.py
Expand Up @@ -4,7 +4,7 @@

from torchflows.bijections.finite.autoregressive.transformers.linear.affine import Shift, Affine
from torchflows.bijections.finite.autoregressive.transformers.spline.rational_quadratic import RationalQuadratic
from torchflows.bijections.finite.autoregressive.transformers.spline.linear import Linear as LinearRational
from torchflows.bijections.finite.autoregressive.transformers.spline.linear_rational import LinearRational
from torchflows.bijections.finite.autoregressive.transformers.combination.sigmoid import (
DeepSigmoid,
DeepDenseSigmoid,
2 changes: 1 addition & 1 deletion torchflows/bijections/finite/multiscale/base.py
Expand Up @@ -145,7 +145,7 @@ def forward(self, x: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.
assert height % 2 == 0
assert width % 2 == 0

out = torch.concatenate([
out = torch.cat([
x[..., ::2, ::2],
x[..., ::2, 1::2],
x[..., 1::2, ::2],
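For context, the hunk above is part of a squeeze-type operation that moves spatial positions into channels. A standalone sketch of the same strided-slicing pattern (the fourth slice and the concatenation dimension are assumptions for illustration):

import torch

x = torch.randn(1, 3, 4, 4)  # (batch, channels, height, width), height and width even
out = torch.cat([
    x[..., ::2, ::2],
    x[..., ::2, 1::2],
    x[..., 1::2, ::2],
    x[..., 1::2, 1::2],
], dim=-3)
print(out.shape)  # torch.Size([1, 12, 2, 2]): 4x the channels, half the spatial size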