
Merge pull request #19 from davidnabergoj/dev
Dev
davidnabergoj authored Aug 14, 2024
2 parents 2ca23bb + 73c8dde commit bb7046d
Showing 12 changed files with 191 additions and 27 deletions.
14 changes: 13 additions & 1 deletion docs/source/api/base_distributions.rst
@@ -1,2 +1,14 @@
Base distribution objects
==========================

.. autoclass:: torchflows.base_distributions.gaussian.DiagonalGaussian
    :members: __init__

.. autoclass:: torchflows.base_distributions.gaussian.DenseGaussian
    :members: __init__

.. autoclass:: torchflows.base_distributions.mixture.DiagonalGaussianMixture
    :members: __init__

.. autoclass:: torchflows.base_distributions.mixture.DenseGaussianMixture
    :members: __init__
51 changes: 50 additions & 1 deletion docs/source/guides/choosing_base_distributions.rst
@@ -1,2 +1,51 @@
Choosing a base distribution
==============================

We may replace the default standard Gaussian base distribution with any torch distribution that is also a torch module (i.e., one that extends both *torch.distributions.Distribution* and *torch.nn.Module*).
Some custom distributions are already implemented.
We show an example for a diagonal Gaussian base distribution with mean 3 and standard deviation 2.

.. code-block:: python

    import torch
    from torchflows.flows import Flow
    from torchflows.architectures import RealNVP
    from torchflows.base_distributions.gaussian import DiagonalGaussian

    torch.manual_seed(0)
    event_shape = (10,)
    base_distribution = DiagonalGaussian(
        loc=torch.full(size=event_shape, fill_value=3.0),
        scale=torch.full(size=event_shape, fill_value=2.0),
    )
    flow = Flow(RealNVP(event_shape), base_distribution=base_distribution)
    x_new = flow.sample((10,))
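Here `x_new` is a batch of ten samples with shape `(10, 10)`: the first axis is the sample axis and the second is the event axis.
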
Nontrivial event shapes
------------------------

When the event has more than one axis, the base distribution must operate on flattened data. We show an example below.

.. note::

    The requirement to work with flattened data may change in the future.


.. code-block:: python

    import torch
    from torchflows.flows import Flow
    from torchflows.architectures import RealNVP
    from torchflows.base_distributions.gaussian import DiagonalGaussian

    torch.manual_seed(0)
    event_shape = (2, 3, 5)
    event_size = int(torch.prod(torch.as_tensor(event_shape)))
    base_distribution = DiagonalGaussian(
        loc=torch.full(size=(event_size,), fill_value=3.0),
        scale=torch.full(size=(event_size,), fill_value=2.0),
    )
    flow = Flow(RealNVP(event_shape), base_distribution=base_distribution)
    x_new = flow.sample((10,))
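Although the base distribution works with flattened data, sampling should still return tensors with the full event shape; in the example above, `x_new` is expected to have shape `(10, 2, 3, 5)`.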
18 changes: 18 additions & 0 deletions docs/source/guides/cuda.rst
@@ -0,0 +1,18 @@
Using CUDA
===========

Torchflows models are torch modules and thus seamlessly support CUDA (and other devices).
When using the *fit* method, training data is automatically transferred to the flow's device.

.. code-block:: python

    import torch
    from torchflows.flows import Flow
    from torchflows.architectures import RealNVP

    torch.manual_seed(0)
    event_shape = (10,)
    x_train = torch.randn(size=(1000, *event_shape))
    flow = Flow(RealNVP(event_shape)).cuda()
    flow.fit(x_train, show_progress=True)
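After fitting, samples should come back on the flow's device; a short continuation of the example above (relying only on standard torch semantics) moves them back to the CPU:

.. code-block:: python

    x_new = flow.sample((100,))  # expected to live on the GPU, like the flow
    x_new = x_new.cpu()          # move to CPU, e.g. for NumPy interop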
24 changes: 22 additions & 2 deletions docs/source/guides/event_shapes.rst
@@ -1,2 +1,22 @@
-Complex event shapes
-======================
+Event shapes
+======================

Torchflows supports modeling tensors with arbitrary shapes. For example, we can model events with shape `(2, 3, 5)` as follows:

.. code-block:: python

    import torch
    from torchflows.flows import Flow
    from torchflows.architectures import RealNVP

    torch.manual_seed(0)
    event_shape = (2, 3, 5)
    n_data = 1000
    x_train = torch.randn(size=(n_data, *event_shape))
    print(x_train.shape)  # (1000, 2, 3, 5)
    flow = Flow(RealNVP(event_shape))
    flow.fit(x_train, show_progress=True)
    x_new = flow.sample((500,))
    print(x_new.shape)  # (500, 2, 3, 5)
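Internally, flows flatten events before evaluating the base distribution and unflatten samples back to `event_shape` (see *flatten_event* and *unflatten_event* in `torchflows.utils`).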
22 changes: 21 additions & 1 deletion docs/source/guides/image_modeling.rst
@@ -1,2 +1,22 @@
Image modeling
==============

When modeling images, we can use specialized multiscale architectures that rely on convolutional neural network conditioners and specialized coupling schemes.
These architectures expect event shapes of the form *(channels, height, width)*.

.. note::

    Multiscale architectures are currently undergoing improvements.

.. code-block:: python

    import torch
    from torchflows.flows import Flow
    from torchflows.architectures import MultiscaleRealNVP

    image_shape = (3, 28, 28)
    n_images = 100
    torch.manual_seed(0)
    training_images = torch.randn(size=(n_images, *image_shape))  # synthetic data
    flow = Flow(MultiscaleRealNVP(image_shape))
    flow.fit(training_images, show_progress=True)
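Sampling then works as in the other guides; a brief continuation of the example above (the commented shape is what the *(channels, height, width)* convention implies):

.. code-block:: python

    new_images = flow.sample((10,))  # expected shape: (10, 3, 28, 28)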
@@ -1,4 +1,4 @@
-Examples
+Tutorial
===========

We provide tutorials and notebooks for typical Torchflows use cases.
@@ -10,3 +10,4 @@ We provide tutorials and notebooks for typical Torchflows use cases.
event_shapes
image_modeling
choosing_base_distributions
cuda
4 changes: 3 additions & 1 deletion docs/source/index.rst
@@ -31,12 +31,14 @@ Guides
.. toctree::

guides/installing
-guides/usage
+guides/tutorial

API
====

.. toctree::
:maxdepth: 3

api/components
api/architectures
api/multiscale_architectures
8 changes: 2 additions & 6 deletions test/test_cuda.py
@@ -4,7 +4,6 @@
from torchflows.bijections.finite.autoregressive.architectures import RealNVP


-@pytest.mark.skip(reason="Too slow on CI/CD")
def test_real_nvp_log_prob_data_on_cpu():
if not torch.cuda.is_available():
pytest.skip("CUDA not available")
@@ -20,7 +19,6 @@ def test_real_nvp_log_prob_data_on_cpu():
flow.log_prob(x_train)


-@pytest.mark.skip(reason="Too slow on CI/CD")
def test_real_nvp_log_prob_data_on_gpu():
if not torch.cuda.is_available():
pytest.skip("CUDA not available")
@@ -36,7 +34,6 @@ def test_real_nvp_log_prob_data_on_gpu():
flow.log_prob(x_train.cuda())


-@pytest.mark.skip(reason="Too slow on CI/CD")
def test_real_nvp_fit_data_on_cpu():
if not torch.cuda.is_available():
pytest.skip("CUDA not available")
@@ -49,10 +46,9 @@ def test_real_nvp_fit_data_on_cpu():
x_train = torch.randn(*batch_shape, *event_shape)

flow = Flow(RealNVP(event_shape)).cuda()
-    flow.fit(x_train)
+    flow.fit(x_train, n_epochs=3)


-@pytest.mark.skip(reason="Too slow on CI/CD")
def test_real_nvp_fit_data_on_gpu():
if not torch.cuda.is_available():
pytest.skip("CUDA not available")
@@ -65,4 +61,4 @@ def test_real_nvp_fit_data_on_gpu():
x_train = torch.randn(*batch_shape, *event_shape)

flow = Flow(RealNVP(event_shape)).cuda()
-    flow.fit(x_train.cuda())
+    flow.fit(x_train.cuda(), n_epochs=3)
24 changes: 22 additions & 2 deletions torchflows/base_distributions/gaussian.py
@@ -6,12 +6,22 @@


class DiagonalGaussian(torch.distributions.Distribution, nn.Module):
    """Diagonal Gaussian distribution. Extends torch.distributions.Distribution and torch.nn.Module.
    """
    def __init__(self,
                 loc: torch.Tensor,
                 scale: torch.Tensor,
                 trainable_loc: bool = False,
                 trainable_scale: bool = False):
-        super().__init__(event_shape=loc.shape)
        """
        DiagonalGaussian constructor.

        :param torch.Tensor loc: location vector with shape `(event_size,)`.
        :param torch.Tensor scale: scale vector with shape `(event_size,)`.
        :param bool trainable_loc: if True, make the location trainable.
        :param bool trainable_scale: if True, make the scale trainable.
        """
+        super().__init__(event_shape=loc.shape, validate_args=False)
        self.log_2_pi = math.log(2 * math.pi)
        if trainable_loc:
            self.register_parameter('loc', nn.Parameter(loc))
@@ -44,11 +54,21 @@ def log_prob(self, value: torch.Tensor) -> torch.Tensor:


class DenseGaussian(torch.distributions.Distribution, nn.Module):
    """
    Dense Gaussian distribution. Extends torch.distributions.Distribution and torch.nn.Module.
    """
    def __init__(self,
                 loc: torch.Tensor,
                 cov: torch.Tensor,
                 trainable_loc: bool = False):
-        super().__init__(event_shape=loc.shape)
        """
        DenseGaussian constructor.

        :param torch.Tensor loc: location vector with shape `(event_size,)`.
        :param torch.Tensor cov: covariance matrix with shape `(event_size, event_size)`.
        :param bool trainable_loc: if True, make the location trainable.
        """
+        super().__init__(event_shape=loc.shape, validate_args=False)
        event_size = int(torch.prod(torch.as_tensor(self.event_shape)))
        if cov.shape != (event_size, event_size):
            raise ValueError("Incorrect covariance matrix shape")
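For reference, a minimal usage sketch for the constructors documented above (not part of this diff; shapes follow the docstrings):

.. code-block:: python

    import torch
    from torchflows.base_distributions.gaussian import DiagonalGaussian, DenseGaussian

    event_size = 10
    diag = DiagonalGaussian(
        loc=torch.zeros(event_size),
        scale=torch.ones(event_size),
        trainable_loc=True,  # registers loc as a trainable nn.Parameter
    )
    dense = DenseGaussian(
        loc=torch.zeros(event_size),
        cov=torch.eye(event_size),  # must have shape (event_size, event_size)
    )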
32 changes: 31 additions & 1 deletion torchflows/base_distributions/mixture.py
@@ -8,12 +8,21 @@


class Mixture(torch.distributions.Distribution, nn.Module):
    """
    Base mixture distribution class. Extends torch.distributions.Distribution and torch.nn.Module.
    """
    def __init__(self,
                 components: List[torch.distributions.Distribution],
                 weights: torch.Tensor = None):
        """
        Mixture constructor.

        :param List[torch.distributions.Distribution] components: list of distribution components.
        :param torch.Tensor weights: tensor of weights with shape `(n_components,)`.
        """
        if weights is None:
            weights = torch.ones(len(components)) / len(components)
-        super().__init__(event_shape=components[0].event_shape)
+        super().__init__(event_shape=components[0].event_shape, validate_args=False)
        self.register_buffer('log_weights', torch.log(weights))
        self.components = components
        self.categorical = torch.distributions.Categorical(probs=weights)
@@ -37,12 +46,25 @@ def log_prob(self, value: torch.Tensor) -> torch.Tensor:


class DiagonalGaussianMixture(Mixture):
    """
    Mixture distribution of diagonal Gaussians. Extends Mixture.
    """

    def __init__(self,
                 locs: torch.Tensor,
                 scales: torch.Tensor,
                 weights: torch.Tensor = None,
                 trainable_locs: bool = False,
                 trainable_scales: bool = False):
        """
        DiagonalGaussianMixture constructor.

        :param torch.Tensor locs: tensor of locations with shape `(n_components, event_size)`.
        :param torch.Tensor scales: tensor of scales with shape `(n_components, event_size)`.
        :param torch.Tensor weights: tensor of weights with shape `(n_components,)`.
        :param bool trainable_locs: if True, make locations trainable.
        :param bool trainable_scales: if True, make scales trainable.
        """
        n_components, *event_shape = locs.shape
        components = []
        for i in range(n_components):
@@ -56,6 +78,14 @@ def __init__(self,
                 covs: torch.Tensor,
                 weights: torch.Tensor = None,
                 trainable_locs: bool = False):
        """
        DenseGaussianMixture constructor. Extends Mixture.

        :param torch.Tensor locs: tensor of locations with shape `(n_components, event_size)`.
        :param torch.Tensor covs: tensor of covariance matrices with shape `(n_components, event_size, event_size)`.
        :param torch.Tensor weights: tensor of weights with shape `(n_components,)`.
        :param bool trainable_locs: if True, make locations trainable.
        """
        n_components, *event_shape = locs.shape
        components = []
        for i in range(n_components):
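Similarly, a minimal sketch (not part of this diff) constructing a mixture with the signature above; omitting *weights* defaults to uniform component weights:

.. code-block:: python

    import torch
    from torchflows.base_distributions.mixture import DiagonalGaussianMixture

    n_components, event_size = 4, 10
    mixture = DiagonalGaussianMixture(
        locs=torch.randn(n_components, event_size),
        scales=torch.ones(n_components, event_size),
        # weights=None -> uniform weights over the 4 components
    )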
4 changes: 2 additions & 2 deletions torchflows/bijections/base.py
@@ -122,15 +122,15 @@ def __init__(self,
        self.layers = nn.ModuleList(layers)

    def forward(self, x: torch.Tensor, context: torch.Tensor = None, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
-        log_det = torch.zeros(size=get_batch_shape(x, event_shape=self.event_shape), device=x.device)
+        log_det = torch.zeros(size=get_batch_shape(x, event_shape=self.event_shape)).to(x)
        for layer in self.layers:
            x, log_det_layer = layer(x, context=context)
            log_det += log_det_layer
        z = x
        return z, log_det

    def inverse(self, z: torch.Tensor, context: torch.Tensor = None, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
-        log_det = torch.zeros(size=get_batch_shape(z, event_shape=self.event_shape), device=z.device)
+        log_det = torch.zeros(size=get_batch_shape(z, event_shape=self.event_shape)).to(z)
        for layer in self.layers[::-1]:
            z, log_det_layer = layer.inverse(z, context=context)
            log_det += log_det_layer
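The move from *device=x.device* to *.to(x)* makes the log-determinant accumulator match both the device and the dtype of the input; a minimal illustration of the torch semantics being relied on:

.. code-block:: python

    import torch

    x = torch.zeros(5, dtype=torch.float64)
    log_det = torch.zeros(3).to(x)  # adopts x's dtype and device
    assert log_det.dtype == torch.float64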
14 changes: 5 additions & 9 deletions torchflows/flows.py
@@ -7,6 +7,7 @@
from tqdm import tqdm
from torchflows.bijections.base import Bijection
from torchflows.utils import flatten_event, unflatten_event, create_data_loader
from torchflows.base_distributions.gaussian import DiagonalGaussian


class BaseFlow(nn.Module):
@@ -26,16 +27,13 @@ def __init__(self,
        self.event_size = int(torch.prod(torch.as_tensor(event_shape)))

        if base_distribution == 'standard_normal':
-            self.base = torch.distributions.MultivariateNormal(
-                loc=torch.zeros(self.event_size),
-                covariance_matrix=torch.eye(self.event_size)
-            )
+            self.base = DiagonalGaussian(loc=torch.zeros(self.event_size), scale=torch.ones(self.event_size))
        elif isinstance(base_distribution, torch.distributions.Distribution):
            self.base = base_distribution
        else:
            raise ValueError(f'Invalid base distribution: {base_distribution}')

-        self.device_buffer = torch.empty(size=())
+        self.register_buffer('device_buffer', torch.empty(size=()))

def get_device(self):
"""Returns the torch device for this object.
@@ -257,10 +255,8 @@ def variational_fit(self,
show_progress: bool = False):
"""Train the normalizing flow to fit a target log probability.
-        Stochastic variational inference lets us train a distribution using the unnormalized target log density
-        instead of a fixed dataset.
-        Refer to Rezende, Mohamed: "Variational Inference with Normalizing Flows" (2015) for more details
-        (https://arxiv.org/abs/1505.05770, loss definition in Equation 15, training pseudocode for conditional flows in Algorithm 1).
+        Stochastic variational inference lets us train a distribution using the unnormalized target log density instead of a fixed dataset.
+        Refer to Rezende, Mohamed: "Variational Inference with Normalizing Flows" (2015) for more details (https://arxiv.org/abs/1505.05770, loss definition in Equation 15, training pseudocode for conditional flows in Algorithm 1).
:param callable target_log_prob: function that computes the unnormalized target log density for a batch of
points. Receives input batch with shape `(*batch_shape, *event_shape)` and outputs batch with
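For context, a minimal sketch of variational training against an unnormalized target density, assuming the *variational_fit* signature shown above:

.. code-block:: python

    import torch
    from torchflows.flows import Flow
    from torchflows.architectures import RealNVP

    def target_log_prob(x: torch.Tensor) -> torch.Tensor:
        # unnormalized standard-Gaussian log density over the event axis
        return -0.5 * x.pow(2).sum(dim=-1)

    flow = Flow(RealNVP((10,)))
    flow.variational_fit(target_log_prob, show_progress=True)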
