diff --git a/.idea/.gitignore b/.idea/.gitignore deleted file mode 100644 index 13566b8..0000000 --- a/.idea/.gitignore +++ /dev/null @@ -1,8 +0,0 @@ -# Default ignored files -/shelf/ -/workspace.xml -# Editor-based HTTP Client requests -/httpRequests/ -# Datasource local storage ignored files -/dataSources/ -/dataSources.local.xml diff --git a/README.md b/README.md index d2a2ee7..e3931fd 100644 --- a/README.md +++ b/README.md @@ -1,18 +1,18 @@ # Normalizing flows in PyTorch This package implements normalizing flows and their building blocks. -The package is meant for researchers, enabling: +It allows: -* easy use of normalizing flows as generative models or density estimators in various applications; -* systematic comparisons of normalizing flows or their building blocks; -* simple implementation of new normalizing flows which belong to either the autoregressive, residual, or continuous - families; +* easy use of normalizing flows as trainable distributions; +* easy implementation of new normalizing flows. Example use: ```python import torch -from normalizing_flows import RealNVP, Flow +from normalizing_flows import Flow +from normalizing_flows.architectures import RealNVP + torch.manual_seed(0) @@ -53,48 +53,51 @@ We support Python versions 3.7 and upwards. ## Brief background -A normalizing flow (NF) is a flexible distribution, defined as a bijective transformation of a simple statistical -distribution. -The simple distribution is typically a standard Gaussian. -The transformation is typically an invertible neural network that can make the NF arbitrarily complex. -Training a NF using a dataset means optimizing the parameters of transformation to make the dataset likely under the NF. +A normalizing flow (NF) is a flexible trainable distribution. +It is defined as a bijective transformation of a simple distribution, such as a standard Gaussian. +The bijection is typically an invertible neural network. +Training a NF using a dataset means optimizing the bijection's parameters to make the dataset likely under the NF. We can use a NF to compute the probability of a data point or to independently sample data from the process that generated our dataset. -A NF $q(x)$ with the bijection $f(z) = x$ and base distribution $p(z)$ is defined as: -$$\log q(x) = \log p(f^{-1}(x)) + \log\left|\det J_{f^{-1}}(x)\right|$$ - -## Implemented architectures - -We implement the following NF transformations: - -| Bijection | Inverse | Log determinant | Inverse implemented | -|---------------------------------------------------------------------|:-----------:|:-----------------------:|:-------------------:| -| [NICE](http://arxiv.org/abs/1410.8516) | Exact | Exact | Yes | -| [Real NVP](http://arxiv.org/abs/1605.08803) | Exact | Exact | Yes | -| [MAF](http://arxiv.org/abs/1705.07057) | Exact | Exact | Yes | -| [IAF](http://arxiv.org/abs/1606.04934) | Exact | Exact | Yes | -| [Rational quadratic NSF](http://arxiv.org/abs/1906.04032) | Exact | Exact | Yes | -| [Linear rational NSF](http://arxiv.org/abs/2001.05168) | Exact | Exact | Yes | -| [NAF](http://arxiv.org/abs/1804.00779) | | | | -| [Block NAF](http://arxiv.org/abs/1904.04676) | | | | -| [UMNN](http://arxiv.org/abs/1908.05164) | Approximate | Exact | No | -| [Planar](https://onlinelibrary.wiley.com/doi/abs/10.1002/cpa.21423) | Approximate | Exact | No | -| [Radial](https://proceedings.mlr.press/v37/rezende15.html) | Approximate | Exact | No | -| [Sylvester](http://arxiv.org/abs/1803.05649) | Approximate | Exact | No | -| [Invertible ResNet](http://arxiv.org/abs/1811.00995) | Approximate | Biased approximation | Yes | -| [ResFlow](http://arxiv.org/abs/1906.02735) | Approximate | Unbiased approximation | Yes | -| [Proximal ResFlow](http://arxiv.org/abs/2211.17158) | Approximate | Exact (if single layer) | Yes | -| [FFJORD](http://arxiv.org/abs/1810.01367) | Approximate | Approximate | Yes | -| [RNODE](http://arxiv.org/abs/2002.02798) | Approximate | Approximate | Yes | -| [DDNF](http://arxiv.org/abs/1810.03256) | Approximate | Approximate | Yes | -| [OT flow](http://arxiv.org/abs/2006.00104) | Approximate | Exact | Yes | - -Note: inverse approximations can be made arbitrarily accurate with stricter convergence conditions. -Architectures without an implemented inverse support either sampling or density estimation, but not both at once. -Such architectures are unsuitable for downstream tasks which require both functionalities. - -We also implement simple bijections that can be used in the same manner: +The density of a NF $q(x)$ with the bijection $f(z) = x$ and base distribution $p(z)$ is defined as: +$$\log q(x) = \log p(f^{-1}(x)) + \log\left|\det J_{f^{-1}}(x)\right|.$$ +Sampling from a NF means sampling from the simple distribution and transforming the sample using the bijection. + +## Supported architectures + +We list supported NF architectures below. +We classify architectures as either autoregressive, residual, or continuous; as defined +by [Papamakarios et al. (2021)](https://arxiv.org/abs/1912.02762). +Exact architectures do not use numerical approximations to generate data or compute the log density. + +| Architecture | Bijection type | Exact | Two-way | +|--------------------------------------------------------------------------|:--------------------------:|:-------:|:-------:| +| [NICE](http://arxiv.org/abs/1410.8516) | Autoregressive | ✔ | ✔ | +| [Real NVP](http://arxiv.org/abs/1605.08803) | Autoregressive | ✔ | ✔ | +| [MAF](http://arxiv.org/abs/1705.07057) | Autoregressive | ✔ | ✔ | +| [IAF](http://arxiv.org/abs/1606.04934) | Autoregressive | ✔ | ✔ | +| [Rational quadratic NSF](http://arxiv.org/abs/1906.04032) | Autoregressive | ✔ | ✔ | +| [Linear rational NSF](http://arxiv.org/abs/2001.05168) | Autoregressive | ✔ | ✔ | +| [NAF](http://arxiv.org/abs/1804.00779) | Autoregressive | ✗ | ✔ | +| [UMNN](http://arxiv.org/abs/1908.05164) | Autoregressive | ✗ | ✔ | +| [Planar](https://onlinelibrary.wiley.com/doi/abs/10.1002/cpa.21423) | Residual | ✗ | ✗ | +| [Radial](https://proceedings.mlr.press/v37/rezende15.html) | Residual | ✗ | ✗ | +| [Sylvester](http://arxiv.org/abs/1803.05649) | Residual | ✗ | ✗ | +| [Invertible ResNet](http://arxiv.org/abs/1811.00995) | Residual | ✗ | ✔* | +| [ResFlow](http://arxiv.org/abs/1906.02735) | Residual | ✗ | ✔* | +| [Proximal ResFlow](http://arxiv.org/abs/2211.17158) | Residual | ✗ | ✔* | +| [FFJORD](http://arxiv.org/abs/1810.01367) | Continuous | ✗ | ✔* | +| [RNODE](http://arxiv.org/abs/2002.02798) | Continuous | ✗ | ✔* | +| [DDNF](http://arxiv.org/abs/1810.03256) | Continuous | ✗ | ✔* | +| [OT flow](http://arxiv.org/abs/2006.00104) | Continuous | ✗ | ✔ | + +Two-way architectures support both sampling and density estimation. +Two-way architectures marked with an asterisk (*) support both, but use a numerical approximation to sample or estimate +density. +One-way architectures support either sampling or density estimation, but not both at once. + +We also support simple bijections (all exact and two-way): * Permutation * Elementwise translation (shift vector) @@ -102,5 +105,3 @@ We also implement simple bijections that can be used in the same manner: * Rotation (orthogonal matrix) * Triangular matrix * Dense matrix (using the QR or LU decomposition) - -All of these have exact inverses and log determinants. \ No newline at end of file diff --git a/examples/Computing log determinants.md b/examples/Computing log determinants.md index 17e236c..27a37dc 100644 --- a/examples/Computing log determinants.md +++ b/examples/Computing log determinants.md @@ -7,7 +7,7 @@ The code is as follows: ```python import torch from normalizing_flows import Flow -from normalizing_flows.bijections import RealNVP +from normalizing_flows.architectures import RealNVP torch.manual_seed(0) diff --git a/examples/Modifying architectures.md b/examples/Modifying architectures.md index aec3d68..f8794dc 100644 --- a/examples/Modifying architectures.md +++ b/examples/Modifying architectures.md @@ -4,7 +4,7 @@ We give an example on how to modify a bijection's architecture. We use the Masked Autoregressive Flow (MAF) as an example. We can manually set the number of invertible layers as follows: ```python -from normalizing_flows.bijections import MAF +from normalizing_flows.architectures import MAF event_shape = (10,) flow = MAF(event_shape=event_shape, n_layers=5) diff --git a/examples/Training a normalizing flow.md b/examples/Training a normalizing flow.md index 1078bfc..291b965 100644 --- a/examples/Training a normalizing flow.md +++ b/examples/Training a normalizing flow.md @@ -7,7 +7,7 @@ The code is as follows: ```python import torch from normalizing_flows import Flow -from normalizing_flows.bijections import RealNVP +from normalizing_flows.architectures import RealNVP torch.manual_seed(0) diff --git a/normalizing_flows/architectures.py b/normalizing_flows/architectures.py new file mode 100644 index 0000000..83a0295 --- /dev/null +++ b/normalizing_flows/architectures.py @@ -0,0 +1,20 @@ +from normalizing_flows.bijections.finite.autoregressive.architectures import ( + NICE, + RealNVP, + MAF, + IAF, + CouplingRQNSF, + MaskedAutoregressiveRQNSF, + InverseAutoregressiveRQNSF, + CouplingLRS, + MaskedAutoregressiveLRS, + CouplingDSF, + UMNNMAF +) + +from normalizing_flows.bijections.continuous.ddnf import DeepDiffeomorphicBijection +from normalizing_flows.bijections.continuous.rnode import RNODE +from normalizing_flows.bijections.continuous.ffjord import FFJORD +from normalizing_flows.bijections.continuous.otflow import OTFlow + +from normalizing_flows.bijections.finite.residual.architectures import ResFlow, ProximalResFlow, InvertibleResNet diff --git a/normalizing_flows/bijections/continuous/base.py b/normalizing_flows/bijections/continuous/base.py index 8980ea4..12caf68 100644 --- a/normalizing_flows/bijections/continuous/base.py +++ b/normalizing_flows/bijections/continuous/base.py @@ -159,7 +159,7 @@ def forward(self, t, states): y = states[0] self._n_evals += 1 - t = torch.tensor(t).type_as(y) + t = torch.as_tensor(t).type_as(y) with torch.enable_grad(): y.requires_grad_(True) @@ -198,7 +198,7 @@ def forward(self, t, states): y = states[0] self._n_evals += 1 - t = torch.tensor(t).type_as(y) + t = torch.as_tensor(t).type_as(y) if self.hutch_noise is None: self.hutch_noise = torch.randn_like(y) diff --git a/normalizing_flows/flows.py b/normalizing_flows/flows.py index 4c28642..ac5f5d5 100644 --- a/normalizing_flows/flows.py +++ b/normalizing_flows/flows.py @@ -1,3 +1,6 @@ +from copy import deepcopy +from typing import Union, Tuple + import torch import torch.nn as nn from torch.utils.data import TensorDataset, DataLoader @@ -5,31 +8,65 @@ from normalizing_flows.bijections.base import Bijection from normalizing_flows.bijections.continuous.ddnf import DeepDiffeomorphicBijection from normalizing_flows.regularization import reconstruction_error -from normalizing_flows.utils import flatten_event, get_batch_shape, unflatten_event +from normalizing_flows.utils import flatten_event, get_batch_shape, unflatten_event, create_data_loader class Flow(nn.Module): + """ + Normalizing flow class. + + This class represents a bijective transformation of a standard Gaussian distribution (the base distribution). + A normalizing flow is itself a distribution which we can sample from or use it to compute the density of inputs. + """ + def __init__(self, bijection: Bijection): + """ + + :param bijection: transformation component of the normalizing flow. + """ super().__init__() self.register_module('bijection', bijection) self.register_buffer('loc', torch.zeros(self.bijection.n_dim)) self.register_buffer('covariance_matrix', torch.eye(self.bijection.n_dim)) @property - def base(self): + def base(self) -> torch.distributions.Distribution: + """ + :return: base distribution of the normalizing flow. + """ return torch.distributions.MultivariateNormal(loc=self.loc, covariance_matrix=self.covariance_matrix) - def base_log_prob(self, z): + def base_log_prob(self, z: torch.Tensor): + """ + Compute the log probability of input z under the base distribution. + + :param z: input tensor. + :return: log probability of the input tensor. + """ zf = flatten_event(z, self.bijection.event_shape) log_prob = self.base.log_prob(zf) return log_prob - def base_sample(self, sample_shape): + def base_sample(self, sample_shape: Union[torch.Size, Tuple[int, ...]]): + """ + Sample from the base distribution. + + :param sample_shape: desired shape of sampled tensor. + :return: tensor with shape sample_shape. + """ z_flat = self.base.sample(sample_shape) z = unflatten_event(z_flat, self.bijection.event_shape) return z def forward_with_log_prob(self, x: torch.Tensor, context: torch.Tensor = None): + """ + Transform the input x to the space of the base distribution. + + :param x: input tensor. + :param context: context tensor upon which the transformation is conditioned. + :return: transformed tensor and the logarithm of the absolute value of the Jacobian determinant of the + transformation. + """ if context is not None: assert context.shape[0] == x.shape[0] context = context.to(self.loc) @@ -38,17 +75,26 @@ def forward_with_log_prob(self, x: torch.Tensor, context: torch.Tensor = None): return z, log_base + log_det def log_prob(self, x: torch.Tensor, context: torch.Tensor = None): + """ + Compute the logarithm of the probability density of input x according to the normalizing flow. + + :param x: input tensor. + :param context: context tensor. + :return: + """ return self.forward_with_log_prob(x, context)[1] def sample(self, n: int, context: torch.Tensor = None, no_grad: bool = False, return_log_prob: bool = False): """ - If context given, sample n vectors for each context vector. - Otherwise, sample n vectors. + Sample from the normalizing flow. - :param n: - :param context: - :param no_grad: - :return: + If context given, sample n tensors for each context tensor. + Otherwise, sample n tensors. + + :param n: number of tensors to sample. + :param context: context tensor with shape c. + :param no_grad: if True, do not track gradients in the inverse pass. + :return: samples with shape (n, *event_shape) if no context given or (n, *c, *event_shape) if context given. """ if context is not None: z = self.base_sample(sample_shape=torch.Size((n, len(context)))) @@ -76,51 +122,130 @@ def fit(self, batch_size: int = 1024, shuffle: bool = True, show_progress: bool = False, - w_train: torch.Tensor = None): + w_train: torch.Tensor = None, + context_train: torch.Tensor = None, + x_val: torch.Tensor = None, + w_val: torch.Tensor = None, + context_val: torch.Tensor = None, + keep_best_weights: bool = True, + early_stopping: bool = False, + early_stopping_threshold: int = 50): """ + Fit the normalizing flow. - :param x_train: - :param n_epochs: + Fitting the flow means finding the parameters of the bijection that maximize the probability of training data. + Bijection parameters are iteratively updated for a specified number of epochs. + If context data is provided, the normalizing flow learns the distribution of data conditional on context data. + + :param x_train: training data with shape (n_training_data, *event_shape). + :param n_epochs: perform fitting for this many steps. :param lr: learning rate. In general, lower learning rates are recommended for high-parametric bijections. - :param batch_size: - :param shuffle: - :param show_progress: - :param w_train: training data weights - :return: + :param batch_size: in each epoch, split training data into batches of this size and perform a parameter update for each batch. + :param shuffle: shuffle training data. This helps avoid incorrect fitting if nearby training samples are similar. + :param show_progress: show a progress bar with the current batch loss. + :param w_train: training data weights with shape (n_training_data,). + :param context_train: training data context tensor with shape (n_training_data, *context_shape). + :param x_val: validation data with shape (n_validation_data, *event_shape). + :param w_val: validation data weights with shape (n_validation_data,). + :param context_val: validation data context tensor with shape (n_validation_data, *context_shape). + :param keep_best_weights: if True and validation data is provided, keep the bijection weights with the highest probability of validation data. + :param early_stopping: if True and validation data is provided, stop the training procedure early once validation loss stops improving for a specified number of consecutive epochs. + :param early_stopping_threshold: if early_stopping is True, fitting stops after no improvement in validation loss for this many epochs. """ - if w_train is None: - batch_shape = get_batch_shape(x_train, self.bijection.event_shape) - w_train = torch.ones(batch_shape) + # Compute the number of event dimensions + n_event_dims = int(torch.prod(torch.as_tensor(self.bijection.event_shape))) + + # Set the default batch size if batch_size is None: batch_size = len(x_train) - optimizer = torch.optim.AdamW(self.parameters(), lr=lr) - dataset = TensorDataset(x_train, w_train) - data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle) - n_event_dims = int(torch.prod(torch.as_tensor(self.bijection.event_shape))) - - if show_progress: - iterator = tqdm(range(n_epochs), desc='Fitting NF') - else: - iterator = range(n_epochs) + # Process training data + train_loader = create_data_loader( + x_train, + w_train, + context_train, + "training", + batch_size=batch_size, + shuffle=shuffle + ) + + # Process validation data + if x_val is not None: + val_loader = create_data_loader( + x_val, + w_val, + context_val, + "validation", + batch_size=batch_size, + shuffle=shuffle + ) + + best_val_loss = torch.inf + best_epoch = 0 + best_weights = deepcopy(self.state_dict()) + + def compute_batch_loss(batch_, reduction: callable = torch.mean): + batch_x, batch_weights = batch_[:2] + batch_context = batch_[2] if len(batch_) == 3 else None + + batch_log_prob = self.log_prob(batch_x.to(self.loc), context=batch_context) + batch_weights = batch_weights.to(self.loc) + assert batch_log_prob.shape == batch_weights.shape + batch_loss = -reduction(batch_log_prob * batch_weights) / n_event_dims + + return batch_loss + + iterator = tqdm(range(n_epochs), desc='Fitting NF', disable=not show_progress) + optimizer = torch.optim.AdamW(self.parameters(), lr=lr) + val_loss = None - for _ in iterator: - for batch_x, batch_w in data_loader: + for epoch in iterator: + for train_batch in train_loader: optimizer.zero_grad() - - log_prob = self.log_prob(batch_x.to(self.loc)) # TODO context! - w = batch_w.to(self.loc) - assert log_prob.shape == w.shape - loss = -torch.mean(log_prob * w) / n_event_dims - + train_loss = compute_batch_loss(train_batch, reduction=torch.mean) if hasattr(self.bijection, 'regularization'): - loss += self.bijection.regularization() - - loss.backward() + train_loss += self.bijection.regularization() + train_loss.backward() optimizer.step() if show_progress: - iterator.set_postfix_str(f'Loss: {loss:.4f}') + if val_loss is None: + iterator.set_postfix_str(f'Training loss (batch): {train_loss:.4f}') + else: + iterator.set_postfix_str( + f'Training loss (batch): {train_loss:.4f}, ' + f'Validation loss: {val_loss:.4f}' + ) + + # Compute validation loss at the end of each epoch + # Validation loss will be displayed at the start of the next epoch + if x_val is not None: + with torch.no_grad(): + # Compute validation loss + val_loss = 0.0 + for val_batch in val_loader: + n_batch_data = len(val_batch[0]) + val_loss += compute_batch_loss(val_batch, reduction=torch.sum) / n_batch_data + if hasattr(self.bijection, 'regularization'): + val_loss += self.bijection.regularization() + + # Check if validation loss is the lowest so far + if val_loss < best_val_loss: + best_val_loss = val_loss + best_epoch = epoch + + # Store current weights + if keep_best_weights: + if best_epoch == epoch: + best_weights = deepcopy(self.state_dict()) + + # Optionally stop training early + if early_stopping: + if epoch - best_epoch > early_stopping_threshold: + break + + if x_val is not None and keep_best_weights: + self.load_state_dict(best_weights) def variational_fit(self, target, diff --git a/normalizing_flows/utils.py b/normalizing_flows/utils.py index 56f84d7..7a0a5e1 100644 --- a/normalizing_flows/utils.py +++ b/normalizing_flows/utils.py @@ -1,6 +1,6 @@ -from typing import Tuple, Union - +from typing import Tuple, Union, Optional import torch +from torch.utils.data import TensorDataset, DataLoader def pad_leading_dims(x: torch.Tensor, n_dims: int): @@ -184,3 +184,36 @@ def sample(self, sample_shape=torch.Size()): def log_prob(self, value): return super().log_prob(value - self.minimum) + + +def create_data_loader(x: torch.Tensor, + weights: Optional[torch.Tensor], + context: Optional[torch.Tensor], + label: str, + **kwargs): + """ + Creates a DataLoader object for NF training. + """ + assert label in ["training", "validation", "testing"] + + # Set default weights + if weights is None: + weights = torch.ones(len(x)) + + # Create the training dataset and loader + if len(x) != len(weights): + raise ValueError( + f"Expected same number of {label} data and {label} weights, " + f"but found {len(x)} and {len(weights)}" + ) + if context is None: + dataset = TensorDataset(x, weights) + else: + if len(x) != len(context): + raise ValueError( + f"Expected same number of {label} data and {label} contexts, " + f"but found {len(x)} and {len(context)}" + ) + dataset = TensorDataset(x, weights, context) + loader = DataLoader(dataset, **kwargs) + return loader diff --git a/notebooks/fitting_with_validation_data.ipynb b/notebooks/fitting_with_validation_data.ipynb new file mode 100644 index 0000000..fa60bad --- /dev/null +++ b/notebooks/fitting_with_validation_data.ipynb @@ -0,0 +1,206 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "source": [ + "# Fitting with validation data \n", + "\n", + "This notebook shows how using validation data can improve the normalizing flow fit.\n", + "\n", + "We create a synthetic example with very little training data and a flow with a very large number of layers. We show that using validation data prevents the flow from overfitting in spite of having too many parameters. " + ], + "metadata": { + "collapsed": false + }, + "id": "ea05f26c641401d0" + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "initial_id", + "metadata": { + "collapsed": true, + "ExecuteTime": { + "end_time": "2023-11-09T21:11:10.440521400Z", + "start_time": "2023-11-09T21:11:08.267904100Z" + } + }, + "outputs": [], + "source": [ + "import torch\n", + "from normalizing_flows.flows import Flow\n", + "from normalizing_flows.bijections import RealNVP" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "outputs": [], + "source": [ + "# Create some synthetic training and validation data\n", + "torch.manual_seed(0)\n", + "\n", + "event_shape = (10,)\n", + "n_train = 100\n", + "n_val = 20\n", + "n_test = 10000\n", + "\n", + "x_train = torch.randn(n_train, *event_shape) * 2 + 4\n", + "x_val = torch.randn(n_val, *event_shape) * 2 + 4\n", + "x_test = torch.randn(n_test, *event_shape) * 2 + 4" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-11-09T21:11:10.456694900Z", + "start_time": "2023-11-09T21:11:10.445522900Z" + } + }, + "id": "21b252329b5695cf" + }, + { + "cell_type": "code", + "execution_count": 3, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Fitting NF: 100%|██████████| 500/500 [00:15<00:00, 32.71it/s, Training loss (batch): 1.7106]\n" + ] + } + ], + "source": [ + "# Train without validation data\n", + "torch.manual_seed(0)\n", + "flow0 = Flow(RealNVP(event_shape, n_layers=20))\n", + "flow0.fit(x_train, show_progress=True)" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-11-09T21:11:25.777575300Z", + "start_time": "2023-11-09T21:11:10.457694Z" + } + }, + "id": "b8c5703314f84814" + }, + { + "cell_type": "code", + "execution_count": 4, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Fitting NF: 100%|██████████| 500/500 [00:23<00:00, 21.42it/s, Training loss (batch): 1.7630, Validation loss: 2.8325]\n" + ] + } + ], + "source": [ + "# Train with validation data and keep the best weights\n", + "torch.manual_seed(0)\n", + "flow1 = Flow(RealNVP(event_shape, n_layers=20))\n", + "flow1.fit(x_train, show_progress=True, x_val=x_val)" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-11-09T21:11:49.164216Z", + "start_time": "2023-11-09T21:11:25.775746200Z" + } + }, + "id": "95d4d4e0447f1d4" + }, + { + "cell_type": "code", + "execution_count": 5, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Fitting NF: 39%|███▉ | 194/500 [00:11<00:18, 16.57it/s, Training loss (batch): 1.9825, Validation loss: 2.1353]\n" + ] + } + ], + "source": [ + "# Train with validation data, early stopping, and keep the best weights\n", + "torch.manual_seed(0)\n", + "flow2 = Flow(RealNVP(event_shape, n_layers=20))\n", + "flow2.fit(x_train, show_progress=True, x_val=x_val, early_stopping=True)" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-11-09T21:12:00.931776800Z", + "start_time": "2023-11-09T21:11:49.165794100Z" + } + }, + "id": "2a6ff6eaea4e1323" + }, + { + "cell_type": "markdown", + "source": [ + "The normalizing flow has a lot of parameters and thus overfits without validation data. The test loss is much lower when using validation data. We may stop training early after no observable validation loss improvement for a certain number of epochs (default: 50). In this experiment, validation loss does not improve after these epochs, as evidenced by the same test loss as observed without early stopping." + ], + "metadata": { + "collapsed": false + }, + "id": "84366140ce6804fe" + }, + { + "cell_type": "code", + "execution_count": 6, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Test loss values\n", + "\n", + "Without validation data: 55.78230667114258\n", + "With validation data, no early stopping: 24.563425064086914\n", + "With validation data, early stopping: 24.563425064086914\n" + ] + } + ], + "source": [ + "print(\"Test loss values\")\n", + "print()\n", + "print(f\"Without validation data: {torch.mean(-flow0.log_prob(x_test))}\")\n", + "print(f\"With validation data, no early stopping: {torch.mean(-flow1.log_prob(x_test))}\")\n", + "print(f\"With validation data, early stopping: {torch.mean(-flow2.log_prob(x_test))}\")" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-11-09T21:12:01.263469Z", + "start_time": "2023-11-09T21:12:00.925959700Z" + } + }, + "id": "bfaca2ae85997ee3" + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/speed_check.py b/speed_check.py new file mode 100644 index 0000000..1697124 --- /dev/null +++ b/speed_check.py @@ -0,0 +1,92 @@ +# Test the speed of standard NF operations + +import torch +import timeit +import matplotlib.pyplot as plt + +from normalizing_flows import Flow +from normalizing_flows.architectures import ( + NICE, + RealNVP, + MAF, + IAF, + CouplingRQNSF, + MaskedAutoregressiveRQNSF, + InverseAutoregressiveRQNSF, + CouplingLRS, + MaskedAutoregressiveLRS, + CouplingDSF, + UMNNMAF, + DeepDiffeomorphicBijection, + RNODE, + FFJORD, + OTFlow, + ResFlow, + ProximalResFlow, + InvertibleResNet +) + + +def avg_eval_time(flow: Flow, n_repeats: int = 30): + total_time = timeit.timeit(lambda: flow.log_prob(x), number=n_repeats) + return total_time / n_repeats + + +def avg_sampling_time(flow: Flow, batch_size: int = 100, n_repeats: int = 30): + total_time = timeit.timeit(lambda: flow.sample(batch_size), number=n_repeats) + return total_time / n_repeats + + +if __name__ == '__main__': + torch.manual_seed(0) + batch_shape = (100,) + event_shape = (50,) + x = torch.randn(*batch_shape, *event_shape) + + eval_times = {} + sample_times = {} + for bijection_class in [ + NICE, + RealNVP, + MAF, + IAF, + CouplingRQNSF, + MaskedAutoregressiveRQNSF, + InverseAutoregressiveRQNSF, + CouplingLRS, + MaskedAutoregressiveLRS, + CouplingDSF, + # UMNNMAF, # Too slow + DeepDiffeomorphicBijection, + RNODE, + FFJORD, + OTFlow, + ResFlow, + ProximalResFlow, + InvertibleResNet + ]: + f = Flow(bijection_class(event_shape)) + + name = bijection_class.__name__ + e_avg = avg_eval_time(f) + s_avg = avg_sampling_time(f) + + print(f'{name:<30}\t| e: {e_avg:.4f}\t| s: {s_avg:.4f}') + eval_times[name] = e_avg + sample_times[name] = s_avg + + plt.figure() + plt.bar(list(eval_times.keys()), list(eval_times.values())) + plt.ylabel("log_prob time [s]") + plt.xlabel("Bijection") + plt.xticks(rotation=30) + plt.tight_layout() + plt.show() + + plt.figure() + plt.bar(list(sample_times.keys()), list(sample_times.values())) + plt.ylabel("Sampling time [s]") + plt.xlabel("Bijection") + plt.xticks(rotation=30) + plt.tight_layout() + plt.show() diff --git a/test/test_fit.py b/test/test_fit.py index 98c9fee..ebf89f3 100644 --- a/test/test_fit.py +++ b/test/test_fit.py @@ -3,6 +3,7 @@ from normalizing_flows import Flow from normalizing_flows.bijections import NICE, RealNVP, MAF, ElementwiseAffine, ElementwiseShift, ElementwiseRQSpline, \ CouplingRQNSF, MaskedAutoregressiveRQNSF, LowerTriangular, ElementwiseScale, QR, LU +from test.constants import __test_constants @pytest.mark.skip(reason='Takes too long, fit quality is architecture-dependent') @@ -107,3 +108,65 @@ def test_diagonal_gaussian_1(bijection_class): relative_error = max((x_std - sigma.ravel()).abs() / sigma.ravel()) assert relative_error < 0.1 + + +@pytest.mark.parametrize("n_train", [1, 10, 2200]) +@pytest.mark.parametrize("event_shape", __test_constants["event_shape"]) +def test_fit_basic(n_train, event_shape): + torch.manual_seed(0) + x_train = torch.randn(size=(n_train, *event_shape)) + flow = Flow(RealNVP(event_shape)) + flow.fit(x_train, n_epochs=2) + + +@pytest.mark.parametrize("n_train", [1, 10, 4000]) +@pytest.mark.parametrize("n_val", [1, 10, 2400]) +def test_fit_with_validation_data(n_train, n_val): + torch.manual_seed(0) + + event_shape = (2, 3) + + x_train = torch.randn(size=(n_train, *event_shape)) + x_val = torch.randn(size=(n_val, *event_shape)) + + flow = Flow(RealNVP(event_shape)) + flow.fit(x_train, n_epochs=2, x_val=x_val) + + +@pytest.mark.parametrize("n_train", [1, 10, 2200]) +@pytest.mark.parametrize("event_shape", __test_constants["event_shape"]) +@pytest.mark.parametrize("context_shape", __test_constants["context_shape"]) +def test_fit_with_training_context(n_train, event_shape, context_shape): + torch.manual_seed(0) + x_train = torch.randn(size=(n_train, *event_shape)) + if context_shape is None: + c_train = None + else: + c_train = torch.randn(size=(n_train, *context_shape)) + flow = Flow(RealNVP(event_shape)) + flow.fit(x_train, n_epochs=2, context_train=c_train) + + +@pytest.mark.parametrize("n_train", [1, 10, 2200]) +@pytest.mark.parametrize("n_val", [1, 10, 2200]) +@pytest.mark.parametrize("event_shape", __test_constants["event_shape"]) +@pytest.mark.parametrize("context_shape", __test_constants["context_shape"]) +def test_fit_with_context_and_validation_data(n_train, n_val, event_shape, context_shape): + torch.manual_seed(0) + + # Setup training data + x_train = torch.randn(size=(n_train, *event_shape)) + if context_shape is None: + c_train = None + else: + c_train = torch.randn(size=(n_train, *context_shape)) + + # Setup validation data + x_val = torch.randn(size=(n_val, *event_shape)) + if context_shape is None: + c_val = None + else: + c_val = torch.randn(size=(n_val, *context_shape)) + + flow = Flow(RealNVP(event_shape)) + flow.fit(x_train, n_epochs=2, context_train=c_train, x_val=x_val, context_val=c_val)