Skip to content

Commit

Permalink
Update docs for base distributions
Browse files Browse the repository at this point in the history
  • Loading branch information
davidnabergoj committed Aug 14, 2024
1 parent fb067f2 commit 27a58e2
Show file tree
Hide file tree
Showing 4 changed files with 113 additions and 2 deletions.
14 changes: 13 additions & 1 deletion docs/source/api/base_distributions.rst
Original file line number Diff line number Diff line change
@@ -1,2 +1,14 @@
Base distribution objects
==========================
==========================

.. autoclass:: torchflows.base_distributions.gaussian.DiagonalGaussian
:members: __init__

.. autoclass:: torchflows.base_distributions.gaussian.DenseGaussian
:members: __init__

.. autoclass:: torchflows.base_distributions.mixture.DiagonalGaussianMixture
:members: __init__

.. autoclass:: torchflows.base_distributions.mixture.DenseGaussianMixture
:members: __init__
51 changes: 50 additions & 1 deletion docs/source/guides/choosing_base_distributions.rst
Original file line number Diff line number Diff line change
@@ -1,2 +1,51 @@
Choosing a base distribution
==============================
==============================

We may replace the default standard Gaussian distribution with any torch distribution that is also a module.
Some custom distributions are already implemented.
We show an example for a diagonal Gaussian base distribution with mean 3 and standard deviation 2.

.. code-block:: python
import torch
from torchflows.flows import Flow
from torchflows.architectures import RealNVP
from torchflows.base_distributions.gaussian import DiagonalGaussian
torch.manual_seed(0)
event_shape = (10,)
base_distribution = DiagonalGaussian(
loc=torch.full(size=event_shape, fill_value=3.0),
scale=torch.full(size=event_shape, fill_value=2.0),
)
flow = Flow(RealNVP(event_shape), base_distribution=base_distribution)
x_new = flow.sample((10,))
Nontrivial event shapes
------------------------

When the event has more than one axis, the base distribution must deal with flattened data. We show an example below.

.. note::

The requirement to work with flattened data may change in the future.


.. code-block:: python
import torch
from torchflows.flows import Flow
from torchflows.architectures import RealNVP
from torchflows.base_distributions.gaussian import DiagonalGaussian
torch.manual_seed(0)
event_shape = (2, 3, 5)
event_size = int(torch.prod(torch.as_tensor(event_shape)))
base_distribution = DiagonalGaussian(
loc=torch.full(size=(event_size,), fill_value=3.0),
scale=torch.full(size=(event_size,), fill_value=2.0),
)
flow = Flow(RealNVP(event_shape), base_distribution=base_distribution)
x_new = flow.sample((10,))
20 changes: 20 additions & 0 deletions torchflows/base_distributions/gaussian.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,21 @@


class DiagonalGaussian(torch.distributions.Distribution, nn.Module):
"""Diagonal Gaussian distribution. Extends torch.distributions.Distribution and torch.nn.Module.
"""
def __init__(self,
loc: torch.Tensor,
scale: torch.Tensor,
trainable_loc: bool = False,
trainable_scale: bool = False):
"""
DiagonalGaussian constructor.
:param torch.Tensor loc: location vector with shape `(event_size,)`.
:param torch.Tensor scale: scale vector with shape `(event_size,)`.
:param bool trainable_loc: if True, the make the location trainable.
:param bool trainable_scale: if True, the make the scale trainable.
"""
super().__init__(event_shape=loc.shape, validate_args=False)
self.log_2_pi = math.log(2 * math.pi)
if trainable_loc:
Expand Down Expand Up @@ -44,10 +54,20 @@ def log_prob(self, value: torch.Tensor) -> torch.Tensor:


class DenseGaussian(torch.distributions.Distribution, nn.Module):
"""
Dense Gaussian distribution. Extends torch.distributions.Distribution and torch.nn.Module.
"""
def __init__(self,
loc: torch.Tensor,
cov: torch.Tensor,
trainable_loc: bool = False):
"""
DenseGaussian constructor.
:param torch.Tensor loc: location vector with shape `(event_size,)`.
:param torch.Tensor cov: covariance matrix with shape `(event_size, event_size)`.
:param bool trainable_loc: if True, the make the location trainable.
"""
super().__init__(event_shape=loc.shape, validate_args=False)
event_size = int(torch.prod(torch.as_tensor(self.event_shape)))
if cov.shape != (event_size, event_size):
Expand Down
30 changes: 30 additions & 0 deletions torchflows/base_distributions/mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,18 @@


class Mixture(torch.distributions.Distribution, nn.Module):
"""
Base mixture distribution class. Extends torch.distributions.Distribution and torch.nn.Module.
"""
def __init__(self,
components: List[torch.distributions.Distribution],
weights: torch.Tensor = None):
"""
Mixture constructor.
:param List[torch.distributions.Distribution] components: list of distribution components.
:param torch.Tensor weights: tensor of weights with shape `(n_components,)`.
"""
if weights is None:
weights = torch.ones(len(components)) / len(components)
super().__init__(event_shape=components[0].event_shape, validate_args=False)
Expand All @@ -37,12 +46,25 @@ def log_prob(self, value: torch.Tensor) -> torch.Tensor:


class DiagonalGaussianMixture(Mixture):
"""
Mixture distribution of diagonal Gaussians. Extends Mixture.
"""

def __init__(self,
locs: torch.Tensor,
scales: torch.Tensor,
weights: torch.Tensor = None,
trainable_locs: bool = False,
trainable_scales: bool = False):
"""
DiagonalGaussianMixture constructor.
:param torch.Tensor locs: tensor of locations with shape `(n_components, event_size)`.
:param torch.Tensor scales: tensor of scales with shape `(n_components, event_size)`.
:param torch.Tensor weights: tensor of weights with shape `(n_components,)`.
:param bool trainable_locs: if True, make locations trainable.
:param bool trainable_scales: if True, make scales trainable.
"""
n_components, *event_shape = locs.shape
components = []
for i in range(n_components):
Expand All @@ -56,6 +78,14 @@ def __init__(self,
covs: torch.Tensor,
weights: torch.Tensor = None,
trainable_locs: bool = False):
"""
DenseGaussianMixture constructor. Extends Mixture.
:param torch.Tensor locs: tensor of locations with shape `(n_components, event_size)`.
:param torch.Tensor covs: tensor of covariance matrices with shape `(n_components, event_size, event_size)`.
:param torch.Tensor weights: tensor of weights with shape `(n_components,)`.
:param bool trainable_locs: if True, make locations trainable.
"""
n_components, *event_shape = locs.shape
components = []
for i in range(n_components):
Expand Down

0 comments on commit 27a58e2

Please sign in to comment.