diff --git a/.github/workflows/python-publish.yml b/.github/workflows/python-publish.yml index a6aebfd..f7926e7 100644 --- a/.github/workflows/python-publish.yml +++ b/.github/workflows/python-publish.yml @@ -56,7 +56,7 @@ jobs: uses: pypa/gh-action-pypi-publish@release/v1 github-release: - name: Sign the distributions with Sigstore and upload to Github release + name: Upload to Github release needs: - publish-to-pypi runs-on: ubuntu-latest @@ -69,12 +69,6 @@ jobs: with: name: python-package-distributions path: dist/ - - name: Sign the distributions with Sigstore - uses: sigstore/gh-action-sigstore-python@v2.1.1 - with: - inputs: >- - ./dist/*.tar.gz - ./dist/*.whl - name: Create Github Release env: GITHUB_TOKEN: ${{ github.token }} diff --git a/docs/conf.py b/docs/conf.py deleted file mode 100644 index bf47c93..0000000 --- a/docs/conf.py +++ /dev/null @@ -1,41 +0,0 @@ -# Configuration file for the Sphinx documentation builder. -import pathlib - -# -- Project information - -with open("DOCS_VERSION", 'r') as f: - version = f.read().strip() - -release = version - -project = 'torchflows' -copyright = '2024, David Nabergoj' -author = 'David Nabergoj' - -# release = f'{torchflows.__version__}' -# version = f'{torchflows.__version__}' - -# -- General configuration - -extensions = [ - 'sphinx.ext.duration', - 'sphinx.ext.doctest', - 'sphinx.ext.autodoc', - 'sphinx.ext.autosummary', - 'sphinx.ext.intersphinx', -] - -intersphinx_mapping = { - 'python': ('https://docs.python.org/3/', None), - 'sphinx': ('https://www.sphinx-doc.org/en/master/', None), -} -intersphinx_disabled_domains = ['std'] - -templates_path = ['_templates'] - -# -- Options for HTML output - -html_theme = 'sphinx_rtd_theme' - -# -- Options for EPUB output -epub_show_urls = 'footnote' diff --git a/docs/requirements.txt b/docs/requirements.txt deleted file mode 100644 index 2c4c704..0000000 --- a/docs/requirements.txt +++ /dev/null @@ -1,5 +0,0 @@ -sphinx==7.1.2 -sphinx-rtd-theme==1.3.0rc1 -torchflows>=1.0.2 -sphinx-copybutton -nbsphinx \ No newline at end of file diff --git a/docs/source/architectures/general_modeling.rst b/docs/source/architectures/general_modeling.rst index b1274f3..1171162 100644 --- a/docs/source/architectures/general_modeling.rst +++ b/docs/source/architectures/general_modeling.rst @@ -1,5 +1,5 @@ API for standard architectures -============================ +================================================= We lists notable implemented bijection architectures. These all inherit from the Bijection class. @@ -41,7 +41,13 @@ Autoregressive architectures .. autoclass:: torchflows.architectures.InverseAutoregressiveLRS :members: __init__ -.. autoclass:: torchflows.architectures.CouplingDSF +.. autoclass:: torchflows.architectures.CouplingDeepSF + :members: __init__ + +.. autoclass:: torchflows.architectures.CouplingDenseSF + :members: __init__ + +.. autoclass:: torchflows.architectures.CouplingDeepDenseSF :members: __init__ .. 
autoclass:: torchflows.architectures.UMNNMAF diff --git a/docs/source/conf.py b/docs/source/conf.py index 4211bb7..997ed31 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -2,11 +2,14 @@ # # For the full list of built-in configuration values, see the documentation: # https://www.sphinx-doc.org/en/master/usage/configuration.html -import pathlib +import sys +import os # -- Project information ----------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information +sys.path.insert(0, os.path.abspath('../../')) + with open("../DOCS_VERSION", 'r') as f: version = f.read().strip() @@ -24,6 +27,7 @@ 'sphinx.ext.autodoc', 'sphinx.ext.autosummary', 'sphinx_copybutton', + 'sphinx.ext.viewcode', ] exclude_patterns = [] diff --git a/docs/source/guides/image_modeling.rst b/docs/source/guides/image_modeling.rst index 134d8d5..f0103b5 100644 --- a/docs/source/guides/image_modeling.rst +++ b/docs/source/guides/image_modeling.rst @@ -1,5 +1,5 @@ Image modeling -============== +================= When modeling images, we can use specialized multiscale architectures which use convolutional neural network conditioners and specialized coupling schemes. These architectures expect event shapes to be *(channels, height, width)*. diff --git a/docs/source/guides/numerical_stability.rst b/docs/source/guides/numerical_stability.rst new file mode 100644 index 0000000..9b32893 --- /dev/null +++ b/docs/source/guides/numerical_stability.rst @@ -0,0 +1,51 @@ +Numerical stability +============================= + +We may require a bijection to be very numerically precise when transforming data between original and latent spaces. +Given data `x`, bijection `b`, and tolerance `epsilon`, we may want: + +.. code-block:: python + + z, log_det_forward = b.forward(x) + x_reconstructed, log_det_inverse = b.inverse(z) + + assert torch.all(torch.abs(x_reconstructed - x) < epsilon) + assert torch.all(torch.abs(log_det_forward + log_det_inverse) < epsilon) + + +All architecture presets in Torchflows (with a defined forward and inverse pass) are tested for input reconstruction and log determinant consistency. +We test reconstruction with inputs drawn from a standard Gaussian distribution. +The specified tolerance is either 0.01 or 0.001, though many architectures achieve a lower reconstruction error. + + +Reducing reconstruction error +------------------------------------------ + +We may need an even smaller reconstruction error. +We can start by ensuring the input data is standardized: + +.. 
code-block:: python + + import torch + from torchflows.architectures import RealNVP + + torch.manual_seed(0) + + batch_shape = (5,) + event_shape = (10,) + x = (torch.randn(size=(*batch_shape, *event_shape)) * 12 + 35) ** 0.5 + x_standardized = (x - x.mean()) / x.std() + + real_nvp = RealNVP(event_shape) + + def print_reconstruction_errors(bijection, inputs): + z, log_det_forward = bijection.forward(inputs) + inputs_reconstructed, log_det_inverse = bijection.inverse(z) + + print(f'Data reconstruction error: {torch.max(torch.abs(inputs - inputs_reconstructed)):.8f}') + print(f'Log determinant error: {torch.max(torch.abs(log_det_forward + log_det_inverse)):.8f}') + + # Test with non-standardized inputs + print_reconstruction_errors(real_nvp, x) + print('-------------------------------------------------------') + print_reconstruction_errors(real_nvp, x_standardized) diff --git a/environment.yml b/environment.yml new file mode 100644 index 0000000..1c79427 --- /dev/null +++ b/environment.yml @@ -0,0 +1,18 @@ +name: torchflows-dev +dependencies: + # Mandatory dependencies for torchflows + - python>=3.7 + - pytorch::pytorch>=2.0.1 + - numpy + - tqdm + - pip + - pytest # testing + - pip: + - torchdiffeq # mandatory for continuous NFs + - Pygments # docs + - Babel # docs + - sphinx>=8.0.1 # docs + - sphinx-rtd-theme # docs + - sphinx-copybutton # docs + - nbsphinx # docs + - sphinx-autoapi # docs diff --git a/test/test_radial_numerical_stability.py b/test/test_radial_numerical_stability.py new file mode 100644 index 0000000..e7b3dd4 --- /dev/null +++ b/test/test_radial_numerical_stability.py @@ -0,0 +1,13 @@ +import torch + +from torchflows import Radial + + +def test_exhaustive(): + torch.manual_seed(0) + event_shape = (1000,) + bijection = Radial(event_shape=event_shape) + z = torch.randn(size=(5000, *event_shape)) ** 2 + x, log_det_inverse = bijection.inverse(z) + assert torch.isfinite(x).all() + assert torch.isfinite(log_det_inverse).all() diff --git a/torchflows/bijections/finite/autoregressive/layers.py b/torchflows/bijections/finite/autoregressive/layers.py index be33e97..cc89231 100644 --- a/torchflows/bijections/finite/autoregressive/layers.py +++ b/torchflows/bijections/finite/autoregressive/layers.py @@ -64,7 +64,7 @@ def forward(self, x: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch. else: scale = torch.std(x, dim=list(range(n_batch_dims)))[..., None].to(self.value) unconstrained_scale = self.transformer.unconstrain_scale(scale) - self.value.data = torch.concatenate([unconstrained_scale, shift], dim=-1).data + self.value.data = torch.cat([unconstrained_scale, shift], dim=-1).data return super().forward(x, context) diff --git a/torchflows/bijections/finite/autoregressive/layers_base.py b/torchflows/bijections/finite/autoregressive/layers_base.py index 0c4afcd..cccaa49 100644 --- a/torchflows/bijections/finite/autoregressive/layers_base.py +++ b/torchflows/bijections/finite/autoregressive/layers_base.py @@ -147,9 +147,8 @@ class MaskedAutoregressiveBijection(AutoregressiveBijection): Masked autoregressive bijection class. This bijection is specified with a scalar transformer. - Its conditioner is always MADE, which receives as input a tensor x with x.shape = (*batch_shape, *event_shape). - MADE outputs parameters h for the scalar transformer with - h.shape = (*batch_shape, *event_shape, *parameter_shape_per_element). + Its conditioner is always MADE, which receives as input a tensor x with shape `(*batch_shape, *event_shape)`. 
+ MADE outputs parameters h for the scalar transformer with shape `(*batch_shape, *event_shape, *parameter_shape_per_element)`. The transformer then applies the bijection elementwise. """ diff --git a/torchflows/bijections/finite/autoregressive/transformers/combination/sigmoid.py b/torchflows/bijections/finite/autoregressive/transformers/combination/sigmoid.py index 07d49e0..0463bc2 100644 --- a/torchflows/bijections/finite/autoregressive/transformers/combination/sigmoid.py +++ b/torchflows/bijections/finite/autoregressive/transformers/combination/sigmoid.py @@ -21,8 +21,10 @@ def inverse_sigmoid(p): class Sigmoid(ScalarTransformer): """ - Applies z = inv_sigmoid(w.T @ sigmoid(a * x + b)) where a > 0, w > 0 and sum(w) = 1. - Note: w, a, b are vectors, so multiplication a * x is broadcast. + Scalar transformer that applies sigmoid-based transformations. + + Applies `z = inv_sigmoid(w.T @ sigmoid(a * x + b))` where `a > 0`, `w > 0` and `sum(w) = 1`. + Note: `w, a, b` are vectors, so multiplication `a * x` is broadcast. """ def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], @@ -45,7 +47,9 @@ def parameter_shape_per_element(self) -> Union[torch.Size, Tuple[int, ...]]: def extract_parameters(self, h: torch.Tensor): """ - h.shape = (batch_size, self.n_parameters) + Convert a parameter vector h into a tuple of parameters to be used in downstream transformations. + + :param torch.Tensor h: parameter vector with shape `(batch_size, n_parameters)`. """ da = h[:, :self.hidden_dim] db = h[:, self.hidden_dim:self.hidden_dim * 2] @@ -60,13 +64,14 @@ def extract_parameters(self, h: torch.Tensor): def forward_1d(self, x, h, eps: float = 1e-6): """ - x.shape = (batch_size,) - h.shape = (batch_size, hidden_size * 3) + Apply forward transformation on input tensor with one event dimension. + + For debugging purposes, the following holds within this function: + `a.shape = (batch_size, hidden_size)`, `b.shape = (batch_size, hidden_size)`, `w.shape = (batch_size, hidden_size)`. - Within the function: - a.shape = (batch_size, hidden_size) - b.shape = (batch_size, hidden_size) - w.shape = (batch_size, hidden_size) + :param torch.Tensor x: input tensor with shape `(batch_size,)`. + :param torch.Tensor h: parameter vector with shape `(batch_size, 3 * hidden_size)`. + :param float eps: small positive scalar. """ a, b, w, log_w = self.extract_parameters(h) c = a * x[:, None] + b # (batch_size, n_hidden) diff --git a/torchflows/bijections/finite/autoregressive/transformers/integration/base.py b/torchflows/bijections/finite/autoregressive/transformers/integration/base.py index 8db6b0c..59450e4 100644 --- a/torchflows/bijections/finite/autoregressive/transformers/integration/base.py +++ b/torchflows/bijections/finite/autoregressive/transformers/integration/base.py @@ -9,9 +9,17 @@ class Integration(ScalarTransformer): + """ + Base integration transformer class. + """ + def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], bound: float = 100.0, eps: float = 1e-6): """ - :param bound: specifies the initial interval [-bound, bound] where numerical inversion is performed. + Integration transformer constructor. + + :param event_shape: shape of the input event. + :param bound: specifies the initial interval `[-bound, bound]` where numerical inversion is performed. + :param eps: small scalar value for transformer inversion. 
""" super().__init__(event_shape) self.bound = bound @@ -33,8 +41,10 @@ def base_forward_1d(self, x: torch.Tensor, params: List[torch.Tensor]): def forward_1d(self, x: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """ - x.shape = (n,) - h.shape = (n, n_parameters) + Apply forward pass on input tensor with one event dimension. + + :param torch.Tensor x: input tensor with shape `(n,)`. + :param torch.Tensor h: parameter tensor with shape `(n, n_parameters)`. """ params = self.compute_parameters(h) return self.base_forward_1d(x, params) @@ -51,8 +61,10 @@ def inverse_1d_without_log_det(self, z: torch.Tensor, params: List[torch.Tensor] def inverse_1d(self, z: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """ - z.shape = (n,) - h.shape = (n, n_parameters) + Apply inverse pass on input tensor with one event dimension. + + :param torch.Tensor x: input tensor with shape `(n,)`. + :param torch.Tensor h: parameter tensor with shape `(n, n_parameters)`. """ params = self.compute_parameters(h) x = self.inverse_without_log_det(z, params) @@ -60,8 +72,10 @@ def inverse_1d(self, z: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, to def forward(self, x: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """ - x.shape = (*batch_shape, *event_shape) - h.shape = (*batch_shape, *parameter_shape) + Apply forward pass. + + :param torch.Tensor x: input tensor with shape `(*batch_shape, *event_shape)`. + :param torch.Tensor h: parameter tensor with shape `(*batch_shape, *parameter_shape)`. """ z_flat, log_det_flat = self.forward_1d(x.view(-1), h.view(-1, self.n_parameters_per_element)) z = z_flat.view_as(x) @@ -70,6 +84,12 @@ def forward(self, x: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch return z, log_det def inverse(self, z: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Apply inverse pass. + + :param torch.Tensor z: input tensor with shape `(*batch_shape, *event_shape)`. + :param torch.Tensor h: parameter tensor with shape `(*batch_shape, *parameter_shape)`. + """ x_flat, log_det_flat = self.inverse_1d(z.view(-1), h.view(-1, self.n_parameters_per_element)) x = x_flat.view_as(z) batch_shape = get_batch_shape(z, self.event_shape) diff --git a/torchflows/bijections/finite/autoregressive/util.py b/torchflows/bijections/finite/autoregressive/util.py index 42b7108..74c4e79 100644 --- a/torchflows/bijections/finite/autoregressive/util.py +++ b/torchflows/bijections/finite/autoregressive/util.py @@ -9,14 +9,23 @@ class GaussLegendre(torch.autograd.Function): + """ + Autograd function that computes the Gauss-Legendre quadrature. + """ @staticmethod def forward(ctx, f, a: torch.Tensor, b: torch.Tensor, n: int, *h: List[torch.Tensor]) -> torch.Tensor: + """ + Forward autograd map. + """ ctx.f, ctx.n = f, n ctx.save_for_backward(a, b, *h) return GaussLegendre.quadrature(f, a, b, n, h) @staticmethod def backward(ctx, grad_area: torch.Tensor) -> Tuple[Optional[torch.Tensor], ...]: + """ + Inverse autograd map. + """ f, n = ctx.f, ctx.n a, b, *h = ctx.saved_tensors @@ -62,4 +71,7 @@ def nodes(n: int, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]: def gauss_legendre(f, a, b, n, h): + """ + Compute the Gauss-Legendre quadrature. 
+ """ return GaussLegendre.apply(f, a, b, n, *h) diff --git a/torchflows/bijections/finite/multiscale/architectures.py b/torchflows/bijections/finite/multiscale/architectures.py index f9d6d88..ed32624 100644 --- a/torchflows/bijections/finite/multiscale/architectures.py +++ b/torchflows/bijections/finite/multiscale/architectures.py @@ -4,7 +4,7 @@ from torchflows.bijections.finite.autoregressive.transformers.linear.affine import Shift, Affine from torchflows.bijections.finite.autoregressive.transformers.spline.rational_quadratic import RationalQuadratic -from torchflows.bijections.finite.autoregressive.transformers.spline.linear import Linear as LinearRational +from torchflows.bijections.finite.autoregressive.transformers.spline.linear_rational import LinearRational from torchflows.bijections.finite.autoregressive.transformers.combination.sigmoid import ( DeepSigmoid, DeepDenseSigmoid, diff --git a/torchflows/bijections/finite/multiscale/base.py b/torchflows/bijections/finite/multiscale/base.py index 40cb68a..ceaa423 100644 --- a/torchflows/bijections/finite/multiscale/base.py +++ b/torchflows/bijections/finite/multiscale/base.py @@ -145,7 +145,7 @@ def forward(self, x: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch. assert height % 2 == 0 assert width % 2 == 0 - out = torch.concatenate([ + out = torch.cat([ x[..., ::2, ::2], x[..., ::2, 1::2], x[..., 1::2, ::2], diff --git a/torchflows/bijections/finite/residual/radial.py b/torchflows/bijections/finite/residual/radial.py index 47eb1c3..4351ea4 100644 --- a/torchflows/bijections/finite/residual/radial.py +++ b/torchflows/bijections/finite/residual/radial.py @@ -17,44 +17,34 @@ def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], **kwargs): self.unconstrained_alpha = nn.Parameter(torch.randn(size=())) self.z0 = nn.Parameter(torch.randn(size=(self.n_dim,))) + self.eps = 1e-6 + @property def alpha(self): return softplus(self.unconstrained_alpha) - def h(self, z): - batch_shape = z.shape[:-1] - z0 = self.z0.view(*([1] * len(batch_shape)), *self.z0.shape) - r = torch.abs(z - z0) - return 1 / (self.alpha + r) - - def h_deriv(self, z): - batch_shape = z.shape[:-1] - z0 = self.z0.view(*([1] * len(batch_shape)), *self.z0.shape) - sign = (-1.0) ** torch.less(z, z0).float() - return -(self.h(z) ** 2) * sign * z - def forward(self, x: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]: raise NotImplementedError def inverse(self, z: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]: + # Flatten event batch_shape = get_batch_shape(z, self.event_shape) z = z.view(*batch_shape, self.n_dim) + + # Compute auxiliary variables z0 = self.z0.view(*([1] * len(batch_shape)), *self.z0.shape) + r = torch.sqrt(torch.square(z - z0)) + h = 1 / (self.alpha + r + self.eps) # Compute transformed point - x = z + self.beta * self.h(z) * (z - z0) + x = z + self.beta * h * (z - z0) # Compute determinant of the Jacobian - h_val = self.h(z) - r = torch.abs(z - z0) - beta_times_h_val = self.beta * h_val - # det = (1 + self.beta * h_val) ** (self.n_dim - 1) * (1 + self.beta * h_val + self.h_deriv(z) * r) - # log_det = torch.log(torch.abs(det)) - # log_det = (self.n_dim - 1) * torch.log1p(beta_times_h_val) + torch.log(1 + beta_times_h_val + self.h_deriv(z) * r) - log_det = torch.abs(torch.add( - (self.n_dim - 1) * torch.log1p(beta_times_h_val), - torch.log(1 + beta_times_h_val + self.h_deriv(z) * r) - )).sum(dim=-1) - x = x.view(*batch_shape, *self.event_shape) + log_det = torch.add( + 
torch.log1p(self.alpha * self.beta * h ** 2), + torch.log1p(self.beta * h) * (self.n_dim - 1) + ).sum(dim=-1) + # Unflatten event + x = x.view(*batch_shape, *self.event_shape) return x, log_det diff --git a/torchflows/bijections/numerical_inversion.py b/torchflows/bijections/numerical_inversion.py index 9a642c8..add4baf 100644 --- a/torchflows/bijections/numerical_inversion.py +++ b/torchflows/bijections/numerical_inversion.py @@ -3,8 +3,14 @@ class Bisection(torch.autograd.Function): + """ + Autograd bisection function. + """ @staticmethod def forward(ctx: Any, f, y, a: torch.Tensor, b: torch.Tensor, n: int, h: List[torch.Tensor]) -> torch.Tensor: + """ + Forward method for the autograd function. + """ ctx.f = f ctx.save_for_backward(*h) for _ in range(n): @@ -17,6 +23,9 @@ def forward(ctx: Any, f, y, a: torch.Tensor, b: torch.Tensor, n: int, h: List[to @staticmethod def backward(ctx: Any, grad_x: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]: + """ + Backward method for the autograd function. + """ f, x = ctx.f, ctx.x h = ctx.saved_tensors with torch.enable_grad(): @@ -33,6 +42,9 @@ def backward(ctx: Any, grad_x: torch.Tensor) -> Tuple[Union[torch.Tensor, None], def bisection(f, y, a, b, n, h): + """ + Apply bisection with autograd support. + """ return Bisection.apply(f, y, a.to(y), b.to(y), n, h) @@ -43,12 +55,16 @@ def bisection_no_gradient(f: callable, n_iterations: int = 500, atol: float = 1e-9): """ - Find x that satisfies f(x) = y. - We assume x.shape == y.shape. - f is applied elementwise. + Apply bisection without autograd support. + + Find x that satisfies f(x) = y. We assume x.shape == y.shape; f is applied elementwise. - a: lower bound. - b: upper bound. + :param f: function that maps an input tensor to an output tensor, applied elementwise. + :param y: target output values. + :param a: lower bound for bisection search. + :param b: upper bound for bisection search. + :param n_iterations: number of bisection iterations. + :param atol: absolute tolerance. """ if a is None: