
Commit

Update docs and environment.yml
davidnabergoj committed Nov 10, 2024
1 parent 98b2abb commit ed37e74
Showing 11 changed files with 109 additions and 72 deletions.
41 changes: 0 additions & 41 deletions docs/conf.py

This file was deleted.

5 changes: 0 additions & 5 deletions docs/requirements.txt

This file was deleted.

10 changes: 8 additions & 2 deletions docs/source/architectures/general_modeling.rst
@@ -1,5 +1,5 @@
API for standard architectures
============================
=================================================
We list notable implemented bijection architectures.
These all inherit from the Bijection class.

@@ -41,7 +41,13 @@ Autoregressive architectures
.. autoclass:: torchflows.architectures.InverseAutoregressiveLRS
:members: __init__

.. autoclass:: torchflows.architectures.CouplingDSF
.. autoclass:: torchflows.architectures.CouplingDeepSF
:members: __init__

.. autoclass:: torchflows.architectures.CouplingDenseSF
:members: __init__

.. autoclass:: torchflows.architectures.CouplingDeepDenseSF
:members: __init__

.. autoclass:: torchflows.architectures.UMNNMAF
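For orientation, below is a minimal usage sketch for one of the coupling architectures that replace `CouplingDSF` in these docs. It assumes the constructor takes the event shape as its first argument and that the `Flow` wrapper exposes `fit`, `log_prob`, and `sample`; this diff does not itself confirm those signatures.

```python
import torch
from torchflows import Flow
from torchflows.architectures import CouplingDeepSF  # one of the classes documented above

torch.manual_seed(0)
event_shape = (10,)                       # hypothetical event shape
x = torch.randn(1000, *event_shape)       # synthetic training data
flow = Flow(CouplingDeepSF(event_shape))  # assumed constructor signature
flow.fit(x, show_progress=True)
log_prob = flow.log_prob(x)               # log densities, shape (1000,)
x_new = flow.sample(50)                   # 50 new samples
```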
6 changes: 5 additions & 1 deletion docs/source/conf.py
@@ -2,11 +2,14 @@
#
# For the full list of built-in configuration values, see the documentation:
# https://www.sphinx-doc.org/en/master/usage/configuration.html
import pathlib
import sys
import os

# -- Project information -----------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information

sys.path.insert(0, os.path.abspath('../../'))

with open("../DOCS_VERSION", 'r') as f:
version = f.read().strip()

@@ -24,6 +27,7 @@
'sphinx.ext.autodoc',
'sphinx.ext.autosummary',
'sphinx_copybutton',
'sphinx.ext.viewcode',
]

exclude_patterns = []
20 changes: 20 additions & 0 deletions environment.yml
@@ -0,0 +1,20 @@
name: torchflows-dev
dependencies:
# Mandatory dependencies for torchflows
- python>=3.7
- pytorch::pytorch>=2.0.1
- numpy
- tqdm
- pip
- pip:
- torchdiffeq
# Dependencies for testing and docs
- pytest
- pip:
- Pygments
- Babel
- sphinx>=8.0.1
- sphinx-rtd-theme
- sphinx-copybutton
- nbsphinx
- sphinx-autoapi
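After creating the environment (for example with `conda env create -f environment.yml` followed by `conda activate torchflows-dev`), a quick sanity check of the mandatory runtime dependencies might look like the sketch below; the version threshold mirrors the pin above.

```python
# Sanity check for the core dependencies listed in environment.yml.
import torch
import numpy
import tqdm
import torchdiffeq

print(torch.__version__)   # expected to satisfy pytorch>=2.0.1
print(numpy.__version__)
```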
2 changes: 1 addition & 1 deletion torchflows/bijections/finite/autoregressive/layers.py
@@ -64,7 +64,7 @@ def forward(self, x: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.
else:
scale = torch.std(x, dim=list(range(n_batch_dims)))[..., None].to(self.value)
unconstrained_scale = self.transformer.unconstrain_scale(scale)
self.value.data = torch.concatenate([unconstrained_scale, shift], dim=-1).data
self.value.data = torch.cat([unconstrained_scale, shift], dim=-1).data
return super().forward(x, context)


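The switch from `torch.concatenate` to `torch.cat` is behavior-preserving: `torch.concatenate` is a newer alias of the long-standing `torch.cat`, so the older spelling simply widens PyTorch version compatibility. A quick illustrative check, on versions where the alias exists:

```python
import torch

a = torch.randn(3, 1)
b = torch.randn(3, 1)
# Both spellings produce the same result along the last dimension.
assert torch.equal(torch.cat([a, b], dim=-1), torch.concatenate([a, b], dim=-1))
```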
@@ -21,8 +21,10 @@ def inverse_sigmoid(p):

class Sigmoid(ScalarTransformer):
"""
Applies z = inv_sigmoid(w.T @ sigmoid(a * x + b)) where a > 0, w > 0 and sum(w) = 1.
Note: w, a, b are vectors, so multiplication a * x is broadcast.
Scalar transformer that applies sigmoid-based transformations.
Applies `z = inv_sigmoid(w.T @ sigmoid(a * x + b))` where `a > 0`, `w > 0` and `sum(w) = 1`.
Note: `w, a, b` are vectors, so multiplication `a * x` is broadcast.
"""

def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]],
@@ -45,7 +47,9 @@ def parameter_shape_per_element(self) -> Union[torch.Size, Tuple[int, ...]]:

def extract_parameters(self, h: torch.Tensor):
"""
h.shape = (batch_size, self.n_parameters)
Convert a parameter vector h into a tuple of parameters to be used in downstream transformations.
:param torch.Tensor h: parameter vector with shape `(batch_size, n_parameters)`.
"""
da = h[:, :self.hidden_dim]
db = h[:, self.hidden_dim:self.hidden_dim * 2]
@@ -60,13 +64,14 @@ def extract_parameters(self, h: torch.Tensor):

def forward_1d(self, x, h, eps: float = 1e-6):
"""
x.shape = (batch_size,)
h.shape = (batch_size, hidden_size * 3)
Apply the forward transformation to an input tensor with one event dimension.
For debugging purposes, the following shapes hold within this function:
`a.shape = (batch_size, hidden_size)`, `b.shape = (batch_size, hidden_size)`, `w.shape = (batch_size, hidden_size)`.
Within the function:
a.shape = (batch_size, hidden_size)
b.shape = (batch_size, hidden_size)
w.shape = (batch_size, hidden_size)
:param torch.Tensor x: input tensor with shape `(batch_size,)`.
:param torch.Tensor h: parameter vector with shape `(batch_size, 3 * hidden_size)`.
:param float eps: small positive scalar.
"""
a, b, w, log_w = self.extract_parameters(h)
c = a * x[:, None] + b # (batch_size, n_hidden)
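The formula in the `Sigmoid` docstring, `z = inv_sigmoid(w.T @ sigmoid(a * x + b))`, can be written directly in a few lines of torch. The sketch below illustrates the forward map only; the class itself additionally handles parameter extraction, constraints, and the log-determinant.

```python
import torch

def sigmoidal_forward_1d(x, a, b, w, eps=1e-6):
    # x: (batch_size,); a, b, w: (batch_size, hidden_size) with a > 0, w > 0, w.sum(-1) == 1
    p = torch.sigmoid(a * x[:, None] + b)         # (batch_size, hidden_size)
    q = (w * p).sum(dim=-1).clamp(eps, 1 - eps)   # convex combination, stays in (0, 1)
    return torch.log(q) - torch.log1p(-q)         # inverse sigmoid (logit)
```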
@@ -9,9 +9,17 @@


class Integration(ScalarTransformer):
"""
Base integration transformer class.
"""

def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], bound: float = 100.0, eps: float = 1e-6):
"""
:param bound: specifies the initial interval [-bound, bound] where numerical inversion is performed.
Integration transformer constructor.
:param event_shape: shape of the input tensor.
:param bound: specifies the initial interval `[-bound, bound]` where numerical inversion is performed.
:param eps: small scalar value for transformer inversion.
"""
super().__init__(event_shape)
self.bound = bound
@@ -33,8 +41,10 @@ def base_forward_1d(self, x: torch.Tensor, params: List[torch.Tensor]):

def forward_1d(self, x: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
x.shape = (n,)
h.shape = (n, n_parameters)
Apply the forward pass to an input tensor with one event dimension.
:param torch.Tensor x: input tensor with shape `(n,)`.
:param torch.Tensor h: parameter tensor with shape `(n, n_parameters)`.
"""
params = self.compute_parameters(h)
return self.base_forward_1d(x, params)
@@ -51,17 +61,21 @@ def inverse_1d_without_log_det(self, z: torch.Tensor, params: List[torch.Tensor]

def inverse_1d(self, z: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
z.shape = (n,)
h.shape = (n, n_parameters)
Apply the inverse pass to an input tensor with one event dimension.
:param torch.Tensor z: input tensor with shape `(n,)`.
:param torch.Tensor h: parameter tensor with shape `(n, n_parameters)`.
"""
params = self.compute_parameters(h)
x = self.inverse_without_log_det(z, params)
return x, -self.base_forward_1d(x, params)[1]

def forward(self, x: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
x.shape = (*batch_shape, *event_shape)
h.shape = (*batch_shape, *parameter_shape)
Apply forward pass.
:param torch.Tensor x: input tensor with shape `(*batch_shape, *event_shape)`.
:param torch.Tensor h: parameter tensor with shape `(*batch_shape, *parameter_shape)`.
"""
z_flat, log_det_flat = self.forward_1d(x.view(-1), h.view(-1, self.n_parameters_per_element))
z = z_flat.view_as(x)
@@ -70,6 +84,12 @@ def forward(self, x: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch
return z, log_det

def inverse(self, z: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Apply inverse pass.
:param torch.Tensor z: input tensor with shape `(*batch_shape, *event_shape)`.
:param torch.Tensor h: parameter tensor with shape `(*batch_shape, *parameter_shape)`.
"""
x_flat, log_det_flat = self.inverse_1d(z.view(-1), h.view(-1, self.n_parameters_per_element))
x = x_flat.view_as(z)
batch_shape = get_batch_shape(z, self.event_shape)
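The `forward`/`inverse` pair follows a common flatten-transform-reshape pattern: collapse all batch and event dimensions, apply the one-dimensional transform elementwise, then restore the original shape and sum the per-element log-determinants over the event dimensions. A standalone sketch of that pattern; names and the parameter layout are illustrative, not the actual class internals.

```python
import torch

def flatten_transform_reshape(x, h, event_shape, n_params_per_element, transform_1d):
    # x: (*batch_shape, *event_shape); h: (*batch_shape, *event_shape, n_params_per_element)
    batch_shape = x.shape[:x.dim() - len(event_shape)]
    z_flat, log_det_flat = transform_1d(x.reshape(-1), h.reshape(-1, n_params_per_element))
    z = z_flat.view_as(x)
    log_det = log_det_flat.view(*batch_shape, *event_shape)
    # Sum the elementwise log-determinants over the event dimensions only.
    event_dims = tuple(range(len(batch_shape), len(batch_shape) + len(event_shape)))
    return z, log_det.sum(dim=event_dims)
```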
12 changes: 12 additions & 0 deletions torchflows/bijections/finite/autoregressive/util.py
@@ -9,14 +9,23 @@


class GaussLegendre(torch.autograd.Function):
"""
Autograd function that computes the Gauss-Legendre quadrature.
"""
@staticmethod
def forward(ctx, f, a: torch.Tensor, b: torch.Tensor, n: int, *h: List[torch.Tensor]) -> torch.Tensor:
"""
Forward autograd map.
"""
ctx.f, ctx.n = f, n
ctx.save_for_backward(a, b, *h)
return GaussLegendre.quadrature(f, a, b, n, h)

@staticmethod
def backward(ctx, grad_area: torch.Tensor) -> Tuple[Optional[torch.Tensor], ...]:
"""
Backward autograd map (gradient computation).
"""
f, n = ctx.f, ctx.n
a, b, *h = ctx.saved_tensors

@@ -62,4 +71,7 @@ def nodes(n: int, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:


def gauss_legendre(f, a, b, n, h):
"""
Compute the Gauss-Legendre quadrature.
"""
return GaussLegendre.apply(f, a, b, n, *h)
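As a plain, non-autograd reference for what Gauss-Legendre quadrature computes, the snippet below integrates a function on `[a, b]` using NumPy's node/weight tables. It is an independent illustration, not the `GaussLegendre` autograd function above.

```python
import numpy as np

def gauss_legendre_reference(f, a, b, n):
    nodes, weights = np.polynomial.legendre.leggauss(n)  # nodes and weights on [-1, 1]
    x = 0.5 * (b - a) * nodes + 0.5 * (b + a)            # affine map to [a, b]
    return 0.5 * (b - a) * np.sum(weights * f(x))

# The integral of x^2 over [0, 1] is 1/3.
print(gauss_legendre_reference(lambda x: x ** 2, 0.0, 1.0, n=8))  # ~0.33333
```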
2 changes: 1 addition & 1 deletion torchflows/bijections/finite/multiscale/base.py
@@ -145,7 +145,7 @@ def forward(self, x: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.
assert height % 2 == 0
assert width % 2 == 0

out = torch.concatenate([
out = torch.cat([
x[..., ::2, ::2],
x[..., ::2, 1::2],
x[..., 1::2, ::2],
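The four strided slices implement the usual multiscale "squeeze", trading spatial resolution for channels; each slice has shape `(..., H/2, W/2)`. A shape-level sketch, assuming channels-first images and concatenation along the channel dimension (the concatenation axis is not visible in this hunk):

```python
import torch

x = torch.randn(8, 3, 32, 32)  # (batch, channels, height, width)
out = torch.cat([
    x[..., ::2, ::2],
    x[..., ::2, 1::2],
    x[..., 1::2, ::2],
    x[..., 1::2, 1::2],
], dim=1)                      # assumed channel axis
print(out.shape)               # torch.Size([8, 12, 16, 16])
```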
26 changes: 21 additions & 5 deletions torchflows/bijections/numerical_inversion.py
@@ -3,8 +3,14 @@


class Bisection(torch.autograd.Function):
"""
Autograd bisection function.
"""
@staticmethod
def forward(ctx: Any, f, y, a: torch.Tensor, b: torch.Tensor, n: int, h: List[torch.Tensor]) -> torch.Tensor:
"""
Forward method for the autograd function.
"""
ctx.f = f
ctx.save_for_backward(*h)
for _ in range(n):
@@ -17,6 +23,9 @@ def forward(ctx: Any, f, y, a: torch.Tensor, b: torch.Tensor, n: int, h: List[to

@staticmethod
def backward(ctx: Any, grad_x: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
"""
Backward method for the autograd function.
"""
f, x = ctx.f, ctx.x
h = ctx.saved_tensors
with torch.enable_grad():
@@ -33,6 +42,9 @@ def backward(ctx: Any, grad_x: torch.Tensor) -> Tuple[Union[torch.Tensor, None],


def bisection(f, y, a, b, n, h):
"""
Apply bisection with autograd support.
"""
return Bisection.apply(f, y, a.to(y), b.to(y), n, h)


@@ -43,12 +55,16 @@ def bisection_no_gradient(f: callable,
n_iterations: int = 500,
atol: float = 1e-9):
"""
Find x that satisfies f(x) = y.
We assume x.shape == y.shape.
f is applied elementwise.
Apply bisection without autograd support.
Explanation: find x that satisfies f(x) = y. We assume x.shape == y.shape. f is applied elementwise.
a: lower bound.
b: upper bound.
:param f: elementwise function mapping an input tensor to an output z.
:param y: value to match to z.
:param a: lower bound for bisection search.
:param b: upper bound for bisection search.
:param n_iterations: number of bisection iterations.
:param atol: absolute tolerance.
"""

if a is None:
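As a concrete illustration of the elementwise search described in the docstring, here is a standalone bisection sketch assuming `f` is monotonically increasing on `[a, b]`; it is not the library function itself.

```python
import torch

def bisection_sketch(f, y, a, b, n_iterations=100):
    # Find x with f(x) = y elementwise; f must be monotonically increasing on [a, b].
    a, b = a.clone(), b.clone()
    for _ in range(n_iterations):
        c = 0.5 * (a + b)
        mask = f(c) < y              # move the lower bound up where f(c) is still too small
        a = torch.where(mask, c, a)
        b = torch.where(mask, b, c)
    return 0.5 * (a + b)

y = torch.tensor([0.25, 4.0])
x = bisection_sketch(lambda t: t ** 2, y, torch.zeros(2), torch.full((2,), 10.0))
print(x)  # approximately [0.5, 2.0]
```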
