diff --git a/docs/source/api/base_distributions.rst b/docs/source/api/base_distributions.rst
index 2421e93..174ba1e 100644
--- a/docs/source/api/base_distributions.rst
+++ b/docs/source/api/base_distributions.rst
@@ -1,2 +1,14 @@
 Base distribution objects
-==========================
\ No newline at end of file
+==========================
+
+.. autoclass:: torchflows.base_distributions.gaussian.DiagonalGaussian
+   :members: __init__
+
+.. autoclass:: torchflows.base_distributions.gaussian.DenseGaussian
+   :members: __init__
+
+.. autoclass:: torchflows.base_distributions.mixture.DiagonalGaussianMixture
+   :members: __init__
+
+.. autoclass:: torchflows.base_distributions.mixture.DenseGaussianMixture
+   :members: __init__
diff --git a/docs/source/guides/choosing_base_distributions.rst b/docs/source/guides/choosing_base_distributions.rst
index 70d7801..bc9e472 100644
--- a/docs/source/guides/choosing_base_distributions.rst
+++ b/docs/source/guides/choosing_base_distributions.rst
@@ -1,2 +1,74 @@
 Choosing a base distribution
-==============================
\ No newline at end of file
+==============================
+
+We can replace the default standard Gaussian base distribution with any torch distribution that is also a torch module.
+Several custom base distributions are already implemented.
+Below, we show an example with a diagonal Gaussian base distribution with mean 3 and standard deviation 2.
+
+.. code-block:: python
+
+    import torch
+    from torchflows.flows import Flow
+    from torchflows.architectures import RealNVP
+    from torchflows.base_distributions.gaussian import DiagonalGaussian
+
+    torch.manual_seed(0)
+    event_shape = (10,)
+    base_distribution = DiagonalGaussian(
+        loc=torch.full(size=event_shape, fill_value=3.0),
+        scale=torch.full(size=event_shape, fill_value=2.0),
+    )
+    flow = Flow(RealNVP(event_shape), base_distribution=base_distribution)
+
+    x_new = flow.sample((10,))
+
+Nontrivial event shapes
+------------------------
+
+When the event has more than one axis, the base distribution must operate on flattened data. We show an example below.
+
+.. note::
+
+    The requirement to work with flattened data may change in the future.
+
+
+.. code-block:: python
+
+    import torch
+    from torchflows.flows import Flow
+    from torchflows.architectures import RealNVP
+    from torchflows.base_distributions.gaussian import DiagonalGaussian
+
+    torch.manual_seed(0)
+    event_shape = (2, 3, 5)
+    event_size = int(torch.prod(torch.as_tensor(event_shape)))
+    base_distribution = DiagonalGaussian(
+        loc=torch.full(size=(event_size,), fill_value=3.0),
+        scale=torch.full(size=(event_size,), fill_value=2.0),
+    )
+    flow = Flow(RealNVP(event_shape), base_distribution=base_distribution)
+
+    x_new = flow.sample((10,))
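+
+Mixture base distributions
+---------------------------
+
+Gaussian mixture base distributions are also implemented.
+Below is a minimal sketch with two diagonal Gaussian components centered at -3 and 3; see the API reference for the full *DiagonalGaussianMixture* signature.
+
+.. code-block:: python
+
+    import torch
+    from torchflows.flows import Flow
+    from torchflows.architectures import RealNVP
+    from torchflows.base_distributions.mixture import DiagonalGaussianMixture
+
+    torch.manual_seed(0)
+    event_shape = (10,)
+    base_distribution = DiagonalGaussianMixture(
+        locs=torch.stack([torch.full(event_shape, -3.0), torch.full(event_shape, 3.0)]),
+        scales=torch.ones(2, *event_shape),  # component weights are uniform by default
+    )
+    flow = Flow(RealNVP(event_shape), base_distribution=base_distribution)
+
+    x_new = flow.sample((10,))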
diff --git a/docs/source/guides/cuda.rst b/docs/source/guides/cuda.rst
new file mode 100644
index 0000000..84e3009
--- /dev/null
+++ b/docs/source/guides/cuda.rst
@@ -0,0 +1,26 @@
+Using CUDA
+===========
+
+Torchflows models are torch modules and thus seamlessly support CUDA (and other devices).
+When using the *fit* method, training data is automatically transferred to the flow's device.
+
+.. code-block:: python
+
+    import torch
+    from torchflows.flows import Flow
+    from torchflows.architectures import RealNVP
+
+    torch.manual_seed(0)
+    event_shape = (10,)
+    x_train = torch.randn(size=(1000, *event_shape))
+
+    flow = Flow(RealNVP(event_shape)).cuda()
+    flow.fit(x_train, show_progress=True)
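+
+The same applies to *log_prob*: with the flow on the GPU, input data may live on either device.
+A minimal sketch, continuing the example above:
+
+.. code-block:: python
+
+    log_prob_cpu = flow.log_prob(x_train)         # data on the CPU
+    log_prob_gpu = flow.log_prob(x_train.cuda())  # data on the GPU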
\ No newline at end of file
diff --git a/docs/source/guides/event_shapes.rst b/docs/source/guides/event_shapes.rst
index 28e0f63..e2a62d3 100644
--- a/docs/source/guides/event_shapes.rst
+++ b/docs/source/guides/event_shapes.rst
@@ -1,2 +1,22 @@
-Complex event shapes
-======================
\ No newline at end of file
+Event shapes
+============
+
+Torchflows supports modeling tensors with arbitrary event shapes. For example, we can model events with shape `(2, 3, 5)` as follows:
+
+.. code-block:: python
+
+    import torch
+    from torchflows.flows import Flow
+    from torchflows.architectures import RealNVP
+
+    torch.manual_seed(0)
+    event_shape = (2, 3, 5)
+    n_data = 1000
+    x_train = torch.randn(size=(n_data, *event_shape))
+    print(x_train.shape)  # (1000, 2, 3, 5)
+
+    flow = Flow(RealNVP(event_shape))
+    flow.fit(x_train, show_progress=True)
+
+    x_new = flow.sample((500,))
+    print(x_new.shape)  # (500, 2, 3, 5)
\ No newline at end of file
diff --git a/docs/source/guides/image_modeling.rst b/docs/source/guides/image_modeling.rst
index 4153480..757365e 100644
--- a/docs/source/guides/image_modeling.rst
+++ b/docs/source/guides/image_modeling.rst
@@ -1,2 +1,29 @@
 Image modeling
-==============
\ No newline at end of file
+==============
+
+When modeling images, we can use multiscale architectures, which rely on convolutional neural network conditioners and specialized coupling schemes.
+These architectures expect event shapes to be *(channels, height, width)*.
+
+.. note::
+    Multiscale architectures are currently undergoing improvements.
+
+.. code-block:: python
+
+    import torch
+    from torchflows.flows import Flow
+    from torchflows.architectures import MultiscaleRealNVP
+
+    image_shape = (3, 28, 28)
+    n_images = 100
+
+    torch.manual_seed(0)
+    training_images = torch.randn(size=(n_images, *image_shape))  # synthetic data
+    flow = Flow(MultiscaleRealNVP(image_shape))
+    flow.fit(training_images, show_progress=True)
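+
+After training, we can sample new images; sampled tensors keep the full image shape:
+
+.. code-block:: python
+
+    new_images = flow.sample((10,))
+    print(new_images.shape)  # (10, 3, 28, 28)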
\ No newline at end of file
diff --git a/docs/source/guides/usage.rst b/docs/source/guides/tutorial.rst
similarity index 92%
rename from docs/source/guides/usage.rst
rename to docs/source/guides/tutorial.rst
index 9b91ff6..f9f63b1 100644
--- a/docs/source/guides/usage.rst
+++ b/docs/source/guides/tutorial.rst
@@ -1,4 +1,4 @@
-Examples
+Tutorial
 ===========
 
 We provide tutorials and notebooks for typical Torchflows use cases.
@@ -10,3 +10,4 @@ We provide tutorials and notebooks for typical Torchflows use cases.
    event_shapes
    image_modeling
    choosing_base_distributions
+   cuda
diff --git a/docs/source/index.rst b/docs/source/index.rst
index 606b626..d2f9b46 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -31,12 +31,14 @@ Guides
 .. toctree::
 
    guides/installing
-   guides/usage
+   guides/tutorial
 
 API
 ====
 
 .. toctree::
+   :maxdepth: 3
+
    api/components
    api/architectures
    api/multiscale_architectures
diff --git a/test/test_cuda.py b/test/test_cuda.py
index e449291..888a08f 100644
--- a/test/test_cuda.py
+++ b/test/test_cuda.py
@@ -4,7 +4,6 @@
 from torchflows.bijections.finite.autoregressive.architectures import RealNVP
 
 
-@pytest.mark.skip(reason="Too slow on CI/CD")
 def test_real_nvp_log_prob_data_on_cpu():
     if not torch.cuda.is_available():
         pytest.skip("CUDA not available")
@@ -20,7 +19,6 @@
     flow.log_prob(x_train)
 
 
-@pytest.mark.skip(reason="Too slow on CI/CD")
 def test_real_nvp_log_prob_data_on_gpu():
     if not torch.cuda.is_available():
         pytest.skip("CUDA not available")
@@ -36,7 +34,6 @@
     flow.log_prob(x_train.cuda())
 
 
-@pytest.mark.skip(reason="Too slow on CI/CD")
 def test_real_nvp_fit_data_on_cpu():
     if not torch.cuda.is_available():
         pytest.skip("CUDA not available")
@@ -49,10 +46,9 @@
     x_train = torch.randn(*batch_shape, *event_shape)
 
     flow = Flow(RealNVP(event_shape)).cuda()
-    flow.fit(x_train)
+    flow.fit(x_train, n_epochs=3)
 
 
-@pytest.mark.skip(reason="Too slow on CI/CD")
 def test_real_nvp_fit_data_on_gpu():
     if not torch.cuda.is_available():
         pytest.skip("CUDA not available")
@@ -65,4 +61,4 @@
     x_train = torch.randn(*batch_shape, *event_shape)
 
     flow = Flow(RealNVP(event_shape)).cuda()
-    flow.fit(x_train.cuda())
+    flow.fit(x_train.cuda(), n_epochs=3)
diff --git a/torchflows/base_distributions/gaussian.py b/torchflows/base_distributions/gaussian.py
index bc37342..9ce1764 100644
--- a/torchflows/base_distributions/gaussian.py
+++ b/torchflows/base_distributions/gaussian.py
@@ -6,12 +6,22 @@
 
 
 class DiagonalGaussian(torch.distributions.Distribution, nn.Module):
+    """Diagonal Gaussian distribution. Extends torch.distributions.Distribution and torch.nn.Module.
+    """
     def __init__(self,
                  loc: torch.Tensor,
                  scale: torch.Tensor,
                  trainable_loc: bool = False,
                  trainable_scale: bool = False):
-        super().__init__(event_shape=loc.shape)
+        """
+        DiagonalGaussian constructor.
+
+        :param torch.Tensor loc: location vector with shape `(event_size,)`.
+        :param torch.Tensor scale: scale vector with shape `(event_size,)`.
+        :param bool trainable_loc: if True, make the location trainable.
+        :param bool trainable_scale: if True, make the scale trainable.
+        """
+        super().__init__(event_shape=loc.shape, validate_args=False)
         self.log_2_pi = math.log(2 * math.pi)
         if trainable_loc:
             self.register_parameter('loc', nn.Parameter(loc))
@@ -44,11 +54,21 @@ def log_prob(self, value: torch.Tensor) -> torch.Tensor:
 
 
 class DenseGaussian(torch.distributions.Distribution, nn.Module):
+    """
+    Dense Gaussian distribution. Extends torch.distributions.Distribution and torch.nn.Module.
+    """
     def __init__(self,
                  loc: torch.Tensor,
                  cov: torch.Tensor,
                  trainable_loc: bool = False):
-        super().__init__(event_shape=loc.shape)
+        """
+        DenseGaussian constructor.
+
+        :param torch.Tensor loc: location vector with shape `(event_size,)`.
+        :param torch.Tensor cov: covariance matrix with shape `(event_size, event_size)`.
+        :param bool trainable_loc: if True, make the location trainable.
+        """
+        super().__init__(event_shape=loc.shape, validate_args=False)
         event_size = int(torch.prod(torch.as_tensor(self.event_shape)))
         if cov.shape != (event_size, event_size):
             raise ValueError("Incorrect covariance matrix shape")
diff --git a/torchflows/base_distributions/mixture.py b/torchflows/base_distributions/mixture.py
index 3c35358..17f3a0f 100644
--- a/torchflows/base_distributions/mixture.py
+++ b/torchflows/base_distributions/mixture.py
@@ -8,12 +8,21 @@
 
 
 class Mixture(torch.distributions.Distribution, nn.Module):
+    """
+    Base mixture distribution class. Extends torch.distributions.Distribution and torch.nn.Module.
+    """
     def __init__(self,
                  components: List[torch.distributions.Distribution],
                  weights: torch.Tensor = None):
+        """
+        Mixture constructor.
+
+        :param List[torch.distributions.Distribution] components: list of distribution components.
+        :param torch.Tensor weights: tensor of weights with shape `(n_components,)`.
+        """
         if weights is None:
             weights = torch.ones(len(components)) / len(components)
-        super().__init__(event_shape=components[0].event_shape)
+        super().__init__(event_shape=components[0].event_shape, validate_args=False)
         self.register_buffer('log_weights', torch.log(weights))
         self.components = components
         self.categorical = torch.distributions.Categorical(probs=weights)
@@ -37,12 +46,25 @@ def log_prob(self, value: torch.Tensor) -> torch.Tensor:
 
 
 class DiagonalGaussianMixture(Mixture):
+    """
+    Mixture distribution of diagonal Gaussians. Extends Mixture.
+    """
+
     def __init__(self,
                  locs: torch.Tensor,
                  scales: torch.Tensor,
                  weights: torch.Tensor = None,
                  trainable_locs: bool = False,
                  trainable_scales: bool = False):
+        """
+        DiagonalGaussianMixture constructor.
+
+        :param torch.Tensor locs: tensor of locations with shape `(n_components, event_size)`.
+        :param torch.Tensor scales: tensor of scales with shape `(n_components, event_size)`.
+        :param torch.Tensor weights: tensor of weights with shape `(n_components,)`.
+        :param bool trainable_locs: if True, make locations trainable.
+        :param bool trainable_scales: if True, make scales trainable.
+        """
         n_components, *event_shape = locs.shape
         components = []
         for i in range(n_components):
@@ -56,6 +78,14 @@ def __init__(self,
                  covs: torch.Tensor,
                  weights: torch.Tensor = None,
                  trainable_locs: bool = False):
+        """
+        DenseGaussianMixture constructor.
+
+        :param torch.Tensor locs: tensor of locations with shape `(n_components, event_size)`.
+        :param torch.Tensor covs: tensor of covariance matrices with shape `(n_components, event_size, event_size)`.
+        :param torch.Tensor weights: tensor of weights with shape `(n_components,)`.
+        :param bool trainable_locs: if True, make locations trainable.
+ """ n_components, *event_shape = locs.shape components = [] for i in range(n_components): diff --git a/torchflows/bijections/base.py b/torchflows/bijections/base.py index 218de6f..23012fb 100644 --- a/torchflows/bijections/base.py +++ b/torchflows/bijections/base.py @@ -122,7 +122,7 @@ def __init__(self, self.layers = nn.ModuleList(layers) def forward(self, x: torch.Tensor, context: torch.Tensor = None, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]: - log_det = torch.zeros(size=get_batch_shape(x, event_shape=self.event_shape), device=x.device) + log_det = torch.zeros(size=get_batch_shape(x, event_shape=self.event_shape)).to(x) for layer in self.layers: x, log_det_layer = layer(x, context=context) log_det += log_det_layer @@ -130,7 +130,7 @@ def forward(self, x: torch.Tensor, context: torch.Tensor = None, **kwargs) -> Tu return z, log_det def inverse(self, z: torch.Tensor, context: torch.Tensor = None, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]: - log_det = torch.zeros(size=get_batch_shape(z, event_shape=self.event_shape), device=z.device) + log_det = torch.zeros(size=get_batch_shape(z, event_shape=self.event_shape)).to(z) for layer in self.layers[::-1]: z, log_det_layer = layer.inverse(z, context=context) log_det += log_det_layer diff --git a/torchflows/flows.py b/torchflows/flows.py index 39aa4df..8cc8362 100644 --- a/torchflows/flows.py +++ b/torchflows/flows.py @@ -7,6 +7,7 @@ from tqdm import tqdm from torchflows.bijections.base import Bijection from torchflows.utils import flatten_event, unflatten_event, create_data_loader +from torchflows.base_distributions.gaussian import DiagonalGaussian class BaseFlow(nn.Module): @@ -26,16 +27,13 @@ def __init__(self, self.event_size = int(torch.prod(torch.as_tensor(event_shape))) if base_distribution == 'standard_normal': - self.base = torch.distributions.MultivariateNormal( - loc=torch.zeros(self.event_size), - covariance_matrix=torch.eye(self.event_size) - ) + self.base = DiagonalGaussian(loc=torch.zeros(self.event_size), scale=torch.ones(self.event_size)) elif isinstance(base_distribution, torch.distributions.Distribution): self.base = base_distribution else: raise ValueError(f'Invalid base distribution: {base_distribution}') - self.device_buffer = torch.empty(size=()) + self.register_buffer('device_buffer', torch.empty(size=())) def get_device(self): """Returns the torch device for this object. @@ -257,10 +255,8 @@ def variational_fit(self, show_progress: bool = False): """Train the normalizing flow to fit a target log probability. - Stochastic variational inference lets us train a distribution using the unnormalized target log density - instead of a fixed dataset. - Refer to Rezende, Mohamed: "Variational Inference with Normalizing Flows" (2015) for more details - (https://arxiv.org/abs/1505.05770, loss definition in Equation 15, training pseudocode for conditional flows in Algorithm 1). + Stochastic variational inference lets us train a distribution using the unnormalized target log density instead of a fixed dataset. + Refer to Rezende, Mohamed: "Variational Inference with Normalizing Flows" (2015) for more details (https://arxiv.org/abs/1505.05770, loss definition in Equation 15, training pseudocode for conditional flows in Algorithm 1). :param callable target_log_prob: function that computes the unnormalized target log density for a batch of points. Receives input batch with shape `(*batch_shape, *event_shape)` and outputs batch with