From 1c36e5799d7a135625b89fe22b8ba9d9d78a7264 Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Sun, 29 Oct 2023 21:47:07 +0100 Subject: [PATCH 01/11] Simplify README.md --- README.md | 91 +++++++++++++++++++++++++++---------------------------- 1 file changed, 45 insertions(+), 46 deletions(-) diff --git a/README.md b/README.md index d2a2ee7..6643b23 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,10 @@ # Normalizing flows in PyTorch This package implements normalizing flows and their building blocks. -The package is meant for researchers, enabling: +It allows: -* easy use of normalizing flows as generative models or density estimators in various applications; -* systematic comparisons of normalizing flows or their building blocks; -* simple implementation of new normalizing flows which belong to either the autoregressive, residual, or continuous - families; +* easy use of normalizing flows as trainable distributions; +* easy implementation of new normalizing flows. Example use: @@ -53,48 +51,51 @@ We support Python versions 3.7 and upwards. ## Brief background -A normalizing flow (NF) is a flexible distribution, defined as a bijective transformation of a simple statistical -distribution. -The simple distribution is typically a standard Gaussian. -The transformation is typically an invertible neural network that can make the NF arbitrarily complex. -Training a NF using a dataset means optimizing the parameters of transformation to make the dataset likely under the NF. +A normalizing flow (NF) is a flexible trainable distribution. +It is defined as a bijective transformation of a simple distribution, such as a standard Gaussian. +The bijection is typically an invertible neural network. +Training a NF using a dataset means optimizing the bijection's parameters to make the dataset likely under the NF. We can use a NF to compute the probability of a data point or to independently sample data from the process that generated our dataset. -A NF $q(x)$ with the bijection $f(z) = x$ and base distribution $p(z)$ is defined as: -$$\log q(x) = \log p(f^{-1}(x)) + \log\left|\det J_{f^{-1}}(x)\right|$$ - -## Implemented architectures - -We implement the following NF transformations: - -| Bijection | Inverse | Log determinant | Inverse implemented | -|---------------------------------------------------------------------|:-----------:|:-----------------------:|:-------------------:| -| [NICE](http://arxiv.org/abs/1410.8516) | Exact | Exact | Yes | -| [Real NVP](http://arxiv.org/abs/1605.08803) | Exact | Exact | Yes | -| [MAF](http://arxiv.org/abs/1705.07057) | Exact | Exact | Yes | -| [IAF](http://arxiv.org/abs/1606.04934) | Exact | Exact | Yes | -| [Rational quadratic NSF](http://arxiv.org/abs/1906.04032) | Exact | Exact | Yes | -| [Linear rational NSF](http://arxiv.org/abs/2001.05168) | Exact | Exact | Yes | -| [NAF](http://arxiv.org/abs/1804.00779) | | | | -| [Block NAF](http://arxiv.org/abs/1904.04676) | | | | -| [UMNN](http://arxiv.org/abs/1908.05164) | Approximate | Exact | No | -| [Planar](https://onlinelibrary.wiley.com/doi/abs/10.1002/cpa.21423) | Approximate | Exact | No | -| [Radial](https://proceedings.mlr.press/v37/rezende15.html) | Approximate | Exact | No | -| [Sylvester](http://arxiv.org/abs/1803.05649) | Approximate | Exact | No | -| [Invertible ResNet](http://arxiv.org/abs/1811.00995) | Approximate | Biased approximation | Yes | -| [ResFlow](http://arxiv.org/abs/1906.02735) | Approximate | Unbiased approximation | Yes | -| [Proximal ResFlow](http://arxiv.org/abs/2211.17158) | Approximate | Exact (if single layer) | Yes | -| [FFJORD](http://arxiv.org/abs/1810.01367) | Approximate | Approximate | Yes | -| [RNODE](http://arxiv.org/abs/2002.02798) | Approximate | Approximate | Yes | -| [DDNF](http://arxiv.org/abs/1810.03256) | Approximate | Approximate | Yes | -| [OT flow](http://arxiv.org/abs/2006.00104) | Approximate | Exact | Yes | - -Note: inverse approximations can be made arbitrarily accurate with stricter convergence conditions. -Architectures without an implemented inverse support either sampling or density estimation, but not both at once. -Such architectures are unsuitable for downstream tasks which require both functionalities. - -We also implement simple bijections that can be used in the same manner: +The density of a NF $q(x)$ with the bijection $f(z) = x$ and base distribution $p(z)$ is defined as: +$$\log q(x) = \log p(f^{-1}(x)) + \log\left|\det J_{f^{-1}}(x)\right|.$$ +Sampling from a NF means sampling from the simple distribution and transforming the sample using the bijection. + +## Supported architectures + +We list supported NF architectures below. +We classify architectures as either autoregressive, residual, or continuous; as defined +by [Papamakarios et al. (2021)](https://arxiv.org/abs/1912.02762). +Exact architectures do not use numerical approximations to generate data or compute the log density. + +| Architecture | Bijection type | Exact | Two-way | +|--------------------------------------------------------------------------|:--------------------------:|:-------:|:-------:| +| [NICE](http://arxiv.org/abs/1410.8516) | Autoregressive | ✔ | ✔ | +| [Real NVP](http://arxiv.org/abs/1605.08803) | Autoregressive | ✔ | ✔ | +| [MAF](http://arxiv.org/abs/1705.07057) | Autoregressive | ✔ | ✔ | +| [IAF](http://arxiv.org/abs/1606.04934) | Autoregressive | ✔ | ✔ | +| [Rational quadratic NSF](http://arxiv.org/abs/1906.04032) | Autoregressive | ✔ | ✔ | +| [Linear rational NSF](http://arxiv.org/abs/2001.05168) | Autoregressive | ✔ | ✔ | +| [NAF](http://arxiv.org/abs/1804.00779) | Autoregressive | ✗ | ✔ | +| [UMNN](http://arxiv.org/abs/1908.05164) | Autoregressive | ✗ | ✔ | +| [Planar](https://onlinelibrary.wiley.com/doi/abs/10.1002/cpa.21423) | Residual | ✗ | ✗ | +| [Radial](https://proceedings.mlr.press/v37/rezende15.html) | Residual | ✗ | ✗ | +| [Sylvester](http://arxiv.org/abs/1803.05649) | Residual | ✗ | ✗ | +| [Invertible ResNet](http://arxiv.org/abs/1811.00995) | Residual | ✗ | ✔* | +| [ResFlow](http://arxiv.org/abs/1906.02735) | Residual | ✗ | ✔* | +| [Proximal ResFlow](http://arxiv.org/abs/2211.17158) | Residual | ✗ | ✔* | +| [FFJORD](http://arxiv.org/abs/1810.01367) | Continuous | ✗ | ✔* | +| [RNODE](http://arxiv.org/abs/2002.02798) | Continuous | ✗ | ✔* | +| [DDNF](http://arxiv.org/abs/1810.03256) | Continuous | ✗ | ✔* | +| [OT flow](http://arxiv.org/abs/2006.00104) | Continuous | ✗ | ✔ | + +Two-way architectures support both sampling and density estimation. +Two-way architectures marked with an asterisk (*) support both, but use a numerical approximation to sample or estimate +density. +One-way architectures support either sampling or density estimation, but not both at once. + +We also support simple bijections (all exact and two-way): * Permutation * Elementwise translation (shift vector) @@ -102,5 +103,3 @@ We also implement simple bijections that can be used in the same manner: * Rotation (orthogonal matrix) * Triangular matrix * Dense matrix (using the QR or LU decomposition) - -All of these have exact inverses and log determinants. \ No newline at end of file From 3cdb1d0f13809352431c664ee06168ca19010b5b Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Mon, 6 Nov 2023 14:26:05 -0800 Subject: [PATCH 02/11] Remove .gitignore --- .idea/.gitignore | 8 -------- 1 file changed, 8 deletions(-) delete mode 100644 .idea/.gitignore diff --git a/.idea/.gitignore b/.idea/.gitignore deleted file mode 100644 index 13566b8..0000000 --- a/.idea/.gitignore +++ /dev/null @@ -1,8 +0,0 @@ -# Default ignored files -/shelf/ -/workspace.xml -# Editor-based HTTP Client requests -/httpRequests/ -# Datasource local storage ignored files -/dataSources/ -/dataSources.local.xml From 02ec3cbebb9af6931373b2e4fa410733b64e5a0f Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Tue, 7 Nov 2023 09:02:53 -0800 Subject: [PATCH 03/11] Add architectures.py, add flow speed test --- normalizing_flows/architectures.py | 20 +++++++ speed_test.py | 92 ++++++++++++++++++++++++++++++ 2 files changed, 112 insertions(+) create mode 100644 normalizing_flows/architectures.py create mode 100644 speed_test.py diff --git a/normalizing_flows/architectures.py b/normalizing_flows/architectures.py new file mode 100644 index 0000000..83a0295 --- /dev/null +++ b/normalizing_flows/architectures.py @@ -0,0 +1,20 @@ +from normalizing_flows.bijections.finite.autoregressive.architectures import ( + NICE, + RealNVP, + MAF, + IAF, + CouplingRQNSF, + MaskedAutoregressiveRQNSF, + InverseAutoregressiveRQNSF, + CouplingLRS, + MaskedAutoregressiveLRS, + CouplingDSF, + UMNNMAF +) + +from normalizing_flows.bijections.continuous.ddnf import DeepDiffeomorphicBijection +from normalizing_flows.bijections.continuous.rnode import RNODE +from normalizing_flows.bijections.continuous.ffjord import FFJORD +from normalizing_flows.bijections.continuous.otflow import OTFlow + +from normalizing_flows.bijections.finite.residual.architectures import ResFlow, ProximalResFlow, InvertibleResNet diff --git a/speed_test.py b/speed_test.py new file mode 100644 index 0000000..1697124 --- /dev/null +++ b/speed_test.py @@ -0,0 +1,92 @@ +# Test the speed of standard NF operations + +import torch +import timeit +import matplotlib.pyplot as plt + +from normalizing_flows import Flow +from normalizing_flows.architectures import ( + NICE, + RealNVP, + MAF, + IAF, + CouplingRQNSF, + MaskedAutoregressiveRQNSF, + InverseAutoregressiveRQNSF, + CouplingLRS, + MaskedAutoregressiveLRS, + CouplingDSF, + UMNNMAF, + DeepDiffeomorphicBijection, + RNODE, + FFJORD, + OTFlow, + ResFlow, + ProximalResFlow, + InvertibleResNet +) + + +def avg_eval_time(flow: Flow, n_repeats: int = 30): + total_time = timeit.timeit(lambda: flow.log_prob(x), number=n_repeats) + return total_time / n_repeats + + +def avg_sampling_time(flow: Flow, batch_size: int = 100, n_repeats: int = 30): + total_time = timeit.timeit(lambda: flow.sample(batch_size), number=n_repeats) + return total_time / n_repeats + + +if __name__ == '__main__': + torch.manual_seed(0) + batch_shape = (100,) + event_shape = (50,) + x = torch.randn(*batch_shape, *event_shape) + + eval_times = {} + sample_times = {} + for bijection_class in [ + NICE, + RealNVP, + MAF, + IAF, + CouplingRQNSF, + MaskedAutoregressiveRQNSF, + InverseAutoregressiveRQNSF, + CouplingLRS, + MaskedAutoregressiveLRS, + CouplingDSF, + # UMNNMAF, # Too slow + DeepDiffeomorphicBijection, + RNODE, + FFJORD, + OTFlow, + ResFlow, + ProximalResFlow, + InvertibleResNet + ]: + f = Flow(bijection_class(event_shape)) + + name = bijection_class.__name__ + e_avg = avg_eval_time(f) + s_avg = avg_sampling_time(f) + + print(f'{name:<30}\t| e: {e_avg:.4f}\t| s: {s_avg:.4f}') + eval_times[name] = e_avg + sample_times[name] = s_avg + + plt.figure() + plt.bar(list(eval_times.keys()), list(eval_times.values())) + plt.ylabel("log_prob time [s]") + plt.xlabel("Bijection") + plt.xticks(rotation=30) + plt.tight_layout() + plt.show() + + plt.figure() + plt.bar(list(sample_times.keys()), list(sample_times.values())) + plt.ylabel("Sampling time [s]") + plt.xlabel("Bijection") + plt.xticks(rotation=30) + plt.tight_layout() + plt.show() From a7f40759bbd6d498a8a286f8f9a1d37f5698995d Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Tue, 7 Nov 2023 09:05:08 -0800 Subject: [PATCH 04/11] Update NF examples --- README.md | 4 +++- examples/Computing log determinants.md | 2 +- examples/Modifying architectures.md | 2 +- examples/Training a normalizing flow.md | 2 +- 4 files changed, 6 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 6643b23..e3931fd 100644 --- a/README.md +++ b/README.md @@ -10,7 +10,9 @@ Example use: ```python import torch -from normalizing_flows import RealNVP, Flow +from normalizing_flows import Flow +from normalizing_flows.architectures import RealNVP + torch.manual_seed(0) diff --git a/examples/Computing log determinants.md b/examples/Computing log determinants.md index 17e236c..27a37dc 100644 --- a/examples/Computing log determinants.md +++ b/examples/Computing log determinants.md @@ -7,7 +7,7 @@ The code is as follows: ```python import torch from normalizing_flows import Flow -from normalizing_flows.bijections import RealNVP +from normalizing_flows.architectures import RealNVP torch.manual_seed(0) diff --git a/examples/Modifying architectures.md b/examples/Modifying architectures.md index aec3d68..f8794dc 100644 --- a/examples/Modifying architectures.md +++ b/examples/Modifying architectures.md @@ -4,7 +4,7 @@ We give an example on how to modify a bijection's architecture. We use the Masked Autoregressive Flow (MAF) as an example. We can manually set the number of invertible layers as follows: ```python -from normalizing_flows.bijections import MAF +from normalizing_flows.architectures import MAF event_shape = (10,) flow = MAF(event_shape=event_shape, n_layers=5) diff --git a/examples/Training a normalizing flow.md b/examples/Training a normalizing flow.md index 1078bfc..291b965 100644 --- a/examples/Training a normalizing flow.md +++ b/examples/Training a normalizing flow.md @@ -7,7 +7,7 @@ The code is as follows: ```python import torch from normalizing_flows import Flow -from normalizing_flows.bijections import RealNVP +from normalizing_flows.architectures import RealNVP torch.manual_seed(0) From 944ef9accbbf3334d7f53b5d8776a33859af003f Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Tue, 7 Nov 2023 12:56:24 -0800 Subject: [PATCH 05/11] Fix t call --- normalizing_flows/bijections/continuous/base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/normalizing_flows/bijections/continuous/base.py b/normalizing_flows/bijections/continuous/base.py index 8980ea4..12caf68 100644 --- a/normalizing_flows/bijections/continuous/base.py +++ b/normalizing_flows/bijections/continuous/base.py @@ -159,7 +159,7 @@ def forward(self, t, states): y = states[0] self._n_evals += 1 - t = torch.tensor(t).type_as(y) + t = torch.as_tensor(t).type_as(y) with torch.enable_grad(): y.requires_grad_(True) @@ -198,7 +198,7 @@ def forward(self, t, states): y = states[0] self._n_evals += 1 - t = torch.tensor(t).type_as(y) + t = torch.as_tensor(t).type_as(y) if self.hutch_noise is None: self.hutch_noise = torch.randn_like(y) From b04691a4bc414d1d18147b54e2f7abbfc33c44ad Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Tue, 7 Nov 2023 16:21:30 -0800 Subject: [PATCH 06/11] Rename script to avoid pytest conflict --- speed_test.py => speed_check.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename speed_test.py => speed_check.py (100%) diff --git a/speed_test.py b/speed_check.py similarity index 100% rename from speed_test.py rename to speed_check.py From e7d5b9241f45d49afb164600e8fc1ce444b35a44 Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Tue, 7 Nov 2023 18:35:32 -0800 Subject: [PATCH 07/11] Add documentation --- normalizing_flows/flows.py | 62 ++++++++++++++++++++++++++++++++------ 1 file changed, 53 insertions(+), 9 deletions(-) diff --git a/normalizing_flows/flows.py b/normalizing_flows/flows.py index 4c28642..37f27ac 100644 --- a/normalizing_flows/flows.py +++ b/normalizing_flows/flows.py @@ -1,3 +1,5 @@ +from typing import Union, Tuple + import torch import torch.nn as nn from torch.utils.data import TensorDataset, DataLoader @@ -9,27 +11,60 @@ class Flow(nn.Module): + """ + Normalizing flow class. + + This class represents a bijective transformation of a standard Gaussian distribution (the base distribution). + A normalizing flow is itself a distribution which we can sample from or use it to compute the density of inputs. + """ def __init__(self, bijection: Bijection): + """ + + :param bijection: transformation component of the normalizing flow. + """ super().__init__() self.register_module('bijection', bijection) self.register_buffer('loc', torch.zeros(self.bijection.n_dim)) self.register_buffer('covariance_matrix', torch.eye(self.bijection.n_dim)) @property - def base(self): + def base(self) -> torch.distributions.Distribution: + """ + :return: base distribution of the normalizing flow. + """ return torch.distributions.MultivariateNormal(loc=self.loc, covariance_matrix=self.covariance_matrix) - def base_log_prob(self, z): + def base_log_prob(self, z: torch.Tensor): + """ + Compute the log probability of input z under the base distribution. + + :param z: input tensor. + :return: log probability of the input tensor. + """ zf = flatten_event(z, self.bijection.event_shape) log_prob = self.base.log_prob(zf) return log_prob - def base_sample(self, sample_shape): + def base_sample(self, sample_shape: Union[torch.Size, Tuple[int, ...]]): + """ + Sample from the base distribution. + + :param sample_shape: desired shape of sampled tensor. + :return: tensor with shape sample_shape. + """ z_flat = self.base.sample(sample_shape) z = unflatten_event(z_flat, self.bijection.event_shape) return z def forward_with_log_prob(self, x: torch.Tensor, context: torch.Tensor = None): + """ + Transform the input x to the space of the base distribution. + + :param x: input tensor. + :param context: context tensor upon which the transformation is conditioned. + :return: transformed tensor and the logarithm of the absolute value of the Jacobian determinant of the + transformation. + """ if context is not None: assert context.shape[0] == x.shape[0] context = context.to(self.loc) @@ -38,17 +73,26 @@ def forward_with_log_prob(self, x: torch.Tensor, context: torch.Tensor = None): return z, log_base + log_det def log_prob(self, x: torch.Tensor, context: torch.Tensor = None): + """ + Compute the logarithm of the probability density of input x according to the normalizing flow. + + :param x: input tensor. + :param context: context tensor. + :return: + """ return self.forward_with_log_prob(x, context)[1] def sample(self, n: int, context: torch.Tensor = None, no_grad: bool = False, return_log_prob: bool = False): """ - If context given, sample n vectors for each context vector. - Otherwise, sample n vectors. + Sample from the normalizing flow. - :param n: - :param context: - :param no_grad: - :return: + If context given, sample n tensors for each context tensor. + Otherwise, sample n tensors. + + :param n: number of tensors to sample. + :param context: context tensor with shape c. + :param no_grad: if True, do not track gradients in the inverse pass. + :return: samples with shape (n, *event_shape) if no context given or (n, *c, *event_shape) if context given. """ if context is not None: z = self.base_sample(sample_shape=torch.Size((n, len(context)))) From e058ba7774615b2c38ba9de1c1675c5be7e4a3ce Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Thu, 9 Nov 2023 12:15:46 -0800 Subject: [PATCH 08/11] Add support for validation data in Flow.fit --- normalizing_flows/flows.py | 116 ++++++++++++++++++++++++++----------- normalizing_flows/utils.py | 35 ++++++++++- 2 files changed, 117 insertions(+), 34 deletions(-) diff --git a/normalizing_flows/flows.py b/normalizing_flows/flows.py index 37f27ac..dedf92e 100644 --- a/normalizing_flows/flows.py +++ b/normalizing_flows/flows.py @@ -7,7 +7,7 @@ from normalizing_flows.bijections.base import Bijection from normalizing_flows.bijections.continuous.ddnf import DeepDiffeomorphicBijection from normalizing_flows.regularization import reconstruction_error -from normalizing_flows.utils import flatten_event, get_batch_shape, unflatten_event +from normalizing_flows.utils import flatten_event, get_batch_shape, unflatten_event, create_data_loader class Flow(nn.Module): @@ -17,6 +17,7 @@ class Flow(nn.Module): This class represents a bijective transformation of a standard Gaussian distribution (the base distribution). A normalizing flow is itself a distribution which we can sample from or use it to compute the density of inputs. """ + def __init__(self, bijection: Bijection): """ @@ -120,51 +121,100 @@ def fit(self, batch_size: int = 1024, shuffle: bool = True, show_progress: bool = False, - w_train: torch.Tensor = None): + w_train: torch.Tensor = None, + context_train: torch.Tensor = None, + x_val: torch.Tensor = None, + w_val: torch.Tensor = None, + context_val: torch.Tensor = None): """ + Fit the normalizing flow. - :param x_train: - :param n_epochs: + Fitting the flow means finding the parameters of the bijection that maximize the probability of training data. + Bijection parameters are iteratively updated for a specified number of epochs. + If validation data is provided, we keep the bijection weights with the highest probability of validation data. + If context data is provided, the normalizing flow learns the distribution of data conditional on context data. + + :param x_train: training data with shape (n_training_data, *event_shape). + :param n_epochs: perform fitting for this many steps. :param lr: learning rate. In general, lower learning rates are recommended for high-parametric bijections. - :param batch_size: - :param shuffle: - :param show_progress: - :param w_train: training data weights - :return: + :param batch_size: in each epoch, split training data into batches of this size and perform a parameter update for each batch. + :param shuffle: shuffle training data. This helps avoid incorrect fitting if nearby training samples are similar. + :param show_progress: show a progress bar with the current batch loss. + :param w_train: training data weights with shape (n_training_data,). + :param context_train: training data context tensor with shape (n_training_data, *context_shape). + :param x_val: validation data with shape (n_validation_data, *event_shape). + :param w_val: validation data weights with shape (n_validation_data,). + :param context_val: validation data context tensor with shape (n_validation_data, *context_shape). """ - if w_train is None: - batch_shape = get_batch_shape(x_train, self.bijection.event_shape) - w_train = torch.ones(batch_shape) - if batch_size is None: - batch_size = len(x_train) - optimizer = torch.optim.AdamW(self.parameters(), lr=lr) - dataset = TensorDataset(x_train, w_train) - data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle) - + # Compute the number of event dimensions n_event_dims = int(torch.prod(torch.as_tensor(self.bijection.event_shape))) - if show_progress: - iterator = tqdm(range(n_epochs), desc='Fitting NF') - else: - iterator = range(n_epochs) + # Set the default batch size + if batch_size is None: + batch_size = len(x_train) + # Process training data + train_loader = create_data_loader( + x_train, + w_train, + context_train, + "training", + batch_size=batch_size, + shuffle=shuffle + ) + + # Process validation data + val_loader = create_data_loader( + x_val, + w_val, + context_val, + "validation", + batch_size=batch_size, + shuffle=shuffle + ) + + def compute_batch_loss(batch_, reduction: callable = torch.mean): + batch_x, batch_weights = batch_[:2] + batch_context = batch_[2] if len(batch_) == 3 else None + + batch_log_prob = self.log_prob(batch_x.to(self.loc), context=batch_context) + batch_weights = batch_weights.to(self.loc) + assert batch_log_prob.shape == batch_weights.shape + batch_loss = -reduction(batch_log_prob * batch_weights) / n_event_dims + + return batch_loss + + iterator = tqdm(range(n_epochs), desc='Fitting NF', disable=not show_progress) + optimizer = torch.optim.AdamW(self.parameters(), lr=lr) + val_loss = None for _ in iterator: - for batch_x, batch_w in data_loader: + for train_batch in train_loader: optimizer.zero_grad() - - log_prob = self.log_prob(batch_x.to(self.loc)) # TODO context! - w = batch_w.to(self.loc) - assert log_prob.shape == w.shape - loss = -torch.mean(log_prob * w) / n_event_dims - + train_loss = compute_batch_loss(train_batch, reduction=torch.mean) if hasattr(self.bijection, 'regularization'): - loss += self.bijection.regularization() - - loss.backward() + train_loss += self.bijection.regularization() + train_loss.backward() optimizer.step() if show_progress: - iterator.set_postfix_str(f'Loss: {loss:.4f}') + if val_loss is None: + iterator.set_postfix_str(f'Training loss (batch): {train_loss:.4f}') + else: + iterator.set_postfix_str( + f'Training loss (batch): {train_loss:.4f}, ' + f'Validation loss: {val_loss:.4f}' + ) + + # Compute validation loss at the end of each epoch + # Validation loss will be displayed at the start of the next epoch + if x_val is not None: + with torch.no_grad(): + val_loss = 0.0 + for val_batch in val_loader: + n_batch_data = len(val_batch[0]) + val_loss += compute_batch_loss(val_batch, reduction=torch.sum) / n_batch_data + if hasattr(self.bijection, 'regularization'): + val_loss += self.bijection.regularization() def variational_fit(self, target, diff --git a/normalizing_flows/utils.py b/normalizing_flows/utils.py index 56f84d7..db96245 100644 --- a/normalizing_flows/utils.py +++ b/normalizing_flows/utils.py @@ -1,6 +1,6 @@ from typing import Tuple, Union - import torch +from torch.utils.data import TensorDataset, DataLoader def pad_leading_dims(x: torch.Tensor, n_dims: int): @@ -184,3 +184,36 @@ def sample(self, sample_shape=torch.Size()): def log_prob(self, value): return super().log_prob(value - self.minimum) + + +def create_data_loader(x: torch.Tensor, + weights: torch.Tensor, + context: torch.Tensor, + label: str, + **kwargs): + """ + Creates a DataLoader object for NF training. + """ + assert label in ["training", "validation", "testing"] + + # Set default weights + if weights is None: + weights = torch.ones(len(x)) + + # Create the training dataset and loader + if len(x) != len(weights): + raise ValueError( + f"Expected same number of {label} data and {label} weights, " + f"but found {len(x)} and {len(weights)}" + ) + if context is None: + dataset = TensorDataset(x, weights) + else: + if len(x) != len(context): + raise ValueError( + f"Expected same number of {label} data and {label} contexts, " + f"but found {len(x)} and {len(context)}" + ) + dataset = TensorDataset(x, weights, context) + loader = DataLoader(dataset, **kwargs) + return loader From 54574a848df2ca77725ee3d8b1584c05f42aecef Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Thu, 9 Nov 2023 12:30:33 -0800 Subject: [PATCH 09/11] Check if validation data exists in Flow.fit, add tests --- normalizing_flows/flows.py | 17 +++++----- normalizing_flows/utils.py | 6 ++-- test/test_fit.py | 63 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 75 insertions(+), 11 deletions(-) diff --git a/normalizing_flows/flows.py b/normalizing_flows/flows.py index dedf92e..f803971 100644 --- a/normalizing_flows/flows.py +++ b/normalizing_flows/flows.py @@ -164,14 +164,15 @@ def fit(self, ) # Process validation data - val_loader = create_data_loader( - x_val, - w_val, - context_val, - "validation", - batch_size=batch_size, - shuffle=shuffle - ) + if x_val is not None: + val_loader = create_data_loader( + x_val, + w_val, + context_val, + "validation", + batch_size=batch_size, + shuffle=shuffle + ) def compute_batch_loss(batch_, reduction: callable = torch.mean): batch_x, batch_weights = batch_[:2] diff --git a/normalizing_flows/utils.py b/normalizing_flows/utils.py index db96245..7a0a5e1 100644 --- a/normalizing_flows/utils.py +++ b/normalizing_flows/utils.py @@ -1,4 +1,4 @@ -from typing import Tuple, Union +from typing import Tuple, Union, Optional import torch from torch.utils.data import TensorDataset, DataLoader @@ -187,8 +187,8 @@ def log_prob(self, value): def create_data_loader(x: torch.Tensor, - weights: torch.Tensor, - context: torch.Tensor, + weights: Optional[torch.Tensor], + context: Optional[torch.Tensor], label: str, **kwargs): """ diff --git a/test/test_fit.py b/test/test_fit.py index 98c9fee..ebf89f3 100644 --- a/test/test_fit.py +++ b/test/test_fit.py @@ -3,6 +3,7 @@ from normalizing_flows import Flow from normalizing_flows.bijections import NICE, RealNVP, MAF, ElementwiseAffine, ElementwiseShift, ElementwiseRQSpline, \ CouplingRQNSF, MaskedAutoregressiveRQNSF, LowerTriangular, ElementwiseScale, QR, LU +from test.constants import __test_constants @pytest.mark.skip(reason='Takes too long, fit quality is architecture-dependent') @@ -107,3 +108,65 @@ def test_diagonal_gaussian_1(bijection_class): relative_error = max((x_std - sigma.ravel()).abs() / sigma.ravel()) assert relative_error < 0.1 + + +@pytest.mark.parametrize("n_train", [1, 10, 2200]) +@pytest.mark.parametrize("event_shape", __test_constants["event_shape"]) +def test_fit_basic(n_train, event_shape): + torch.manual_seed(0) + x_train = torch.randn(size=(n_train, *event_shape)) + flow = Flow(RealNVP(event_shape)) + flow.fit(x_train, n_epochs=2) + + +@pytest.mark.parametrize("n_train", [1, 10, 4000]) +@pytest.mark.parametrize("n_val", [1, 10, 2400]) +def test_fit_with_validation_data(n_train, n_val): + torch.manual_seed(0) + + event_shape = (2, 3) + + x_train = torch.randn(size=(n_train, *event_shape)) + x_val = torch.randn(size=(n_val, *event_shape)) + + flow = Flow(RealNVP(event_shape)) + flow.fit(x_train, n_epochs=2, x_val=x_val) + + +@pytest.mark.parametrize("n_train", [1, 10, 2200]) +@pytest.mark.parametrize("event_shape", __test_constants["event_shape"]) +@pytest.mark.parametrize("context_shape", __test_constants["context_shape"]) +def test_fit_with_training_context(n_train, event_shape, context_shape): + torch.manual_seed(0) + x_train = torch.randn(size=(n_train, *event_shape)) + if context_shape is None: + c_train = None + else: + c_train = torch.randn(size=(n_train, *context_shape)) + flow = Flow(RealNVP(event_shape)) + flow.fit(x_train, n_epochs=2, context_train=c_train) + + +@pytest.mark.parametrize("n_train", [1, 10, 2200]) +@pytest.mark.parametrize("n_val", [1, 10, 2200]) +@pytest.mark.parametrize("event_shape", __test_constants["event_shape"]) +@pytest.mark.parametrize("context_shape", __test_constants["context_shape"]) +def test_fit_with_context_and_validation_data(n_train, n_val, event_shape, context_shape): + torch.manual_seed(0) + + # Setup training data + x_train = torch.randn(size=(n_train, *event_shape)) + if context_shape is None: + c_train = None + else: + c_train = torch.randn(size=(n_train, *context_shape)) + + # Setup validation data + x_val = torch.randn(size=(n_val, *event_shape)) + if context_shape is None: + c_val = None + else: + c_val = torch.randn(size=(n_val, *context_shape)) + + flow = Flow(RealNVP(event_shape)) + flow.fit(x_train, n_epochs=2, context_train=c_train, x_val=x_val, context_val=c_val) From 46385b5d1258be47776e9e07916611c6497d8739 Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Thu, 9 Nov 2023 12:48:36 -0800 Subject: [PATCH 10/11] Add "keep best weights" and "early stopping" options to Flow.fit and --- normalizing_flows/flows.py | 36 +++++++++++++++++++++++++++++++++--- 1 file changed, 33 insertions(+), 3 deletions(-) diff --git a/normalizing_flows/flows.py b/normalizing_flows/flows.py index f803971..ac5f5d5 100644 --- a/normalizing_flows/flows.py +++ b/normalizing_flows/flows.py @@ -1,3 +1,4 @@ +from copy import deepcopy from typing import Union, Tuple import torch @@ -125,13 +126,15 @@ def fit(self, context_train: torch.Tensor = None, x_val: torch.Tensor = None, w_val: torch.Tensor = None, - context_val: torch.Tensor = None): + context_val: torch.Tensor = None, + keep_best_weights: bool = True, + early_stopping: bool = False, + early_stopping_threshold: int = 50): """ Fit the normalizing flow. Fitting the flow means finding the parameters of the bijection that maximize the probability of training data. Bijection parameters are iteratively updated for a specified number of epochs. - If validation data is provided, we keep the bijection weights with the highest probability of validation data. If context data is provided, the normalizing flow learns the distribution of data conditional on context data. :param x_train: training data with shape (n_training_data, *event_shape). @@ -145,6 +148,9 @@ def fit(self, :param x_val: validation data with shape (n_validation_data, *event_shape). :param w_val: validation data weights with shape (n_validation_data,). :param context_val: validation data context tensor with shape (n_validation_data, *context_shape). + :param keep_best_weights: if True and validation data is provided, keep the bijection weights with the highest probability of validation data. + :param early_stopping: if True and validation data is provided, stop the training procedure early once validation loss stops improving for a specified number of consecutive epochs. + :param early_stopping_threshold: if early_stopping is True, fitting stops after no improvement in validation loss for this many epochs. """ # Compute the number of event dimensions n_event_dims = int(torch.prod(torch.as_tensor(self.bijection.event_shape))) @@ -174,6 +180,10 @@ def fit(self, shuffle=shuffle ) + best_val_loss = torch.inf + best_epoch = 0 + best_weights = deepcopy(self.state_dict()) + def compute_batch_loss(batch_, reduction: callable = torch.mean): batch_x, batch_weights = batch_[:2] batch_context = batch_[2] if len(batch_) == 3 else None @@ -188,7 +198,8 @@ def compute_batch_loss(batch_, reduction: callable = torch.mean): iterator = tqdm(range(n_epochs), desc='Fitting NF', disable=not show_progress) optimizer = torch.optim.AdamW(self.parameters(), lr=lr) val_loss = None - for _ in iterator: + + for epoch in iterator: for train_batch in train_loader: optimizer.zero_grad() train_loss = compute_batch_loss(train_batch, reduction=torch.mean) @@ -210,6 +221,7 @@ def compute_batch_loss(batch_, reduction: callable = torch.mean): # Validation loss will be displayed at the start of the next epoch if x_val is not None: with torch.no_grad(): + # Compute validation loss val_loss = 0.0 for val_batch in val_loader: n_batch_data = len(val_batch[0]) @@ -217,6 +229,24 @@ def compute_batch_loss(batch_, reduction: callable = torch.mean): if hasattr(self.bijection, 'regularization'): val_loss += self.bijection.regularization() + # Check if validation loss is the lowest so far + if val_loss < best_val_loss: + best_val_loss = val_loss + best_epoch = epoch + + # Store current weights + if keep_best_weights: + if best_epoch == epoch: + best_weights = deepcopy(self.state_dict()) + + # Optionally stop training early + if early_stopping: + if epoch - best_epoch > early_stopping_threshold: + break + + if x_val is not None and keep_best_weights: + self.load_state_dict(best_weights) + def variational_fit(self, target, n_epochs: int = 10, From 8c73cc8e77fd0d77ff47b53bded294066791066d Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Thu, 9 Nov 2023 13:12:57 -0800 Subject: [PATCH 11/11] Add notebook with validation data fitting example --- notebooks/fitting_with_validation_data.ipynb | 206 +++++++++++++++++++ 1 file changed, 206 insertions(+) create mode 100644 notebooks/fitting_with_validation_data.ipynb diff --git a/notebooks/fitting_with_validation_data.ipynb b/notebooks/fitting_with_validation_data.ipynb new file mode 100644 index 0000000..fa60bad --- /dev/null +++ b/notebooks/fitting_with_validation_data.ipynb @@ -0,0 +1,206 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "source": [ + "# Fitting with validation data \n", + "\n", + "This notebook shows how using validation data can improve the normalizing flow fit.\n", + "\n", + "We create a synthetic example with very little training data and a flow with a very large number of layers. We show that using validation data prevents the flow from overfitting in spite of having too many parameters. " + ], + "metadata": { + "collapsed": false + }, + "id": "ea05f26c641401d0" + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "initial_id", + "metadata": { + "collapsed": true, + "ExecuteTime": { + "end_time": "2023-11-09T21:11:10.440521400Z", + "start_time": "2023-11-09T21:11:08.267904100Z" + } + }, + "outputs": [], + "source": [ + "import torch\n", + "from normalizing_flows.flows import Flow\n", + "from normalizing_flows.bijections import RealNVP" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "outputs": [], + "source": [ + "# Create some synthetic training and validation data\n", + "torch.manual_seed(0)\n", + "\n", + "event_shape = (10,)\n", + "n_train = 100\n", + "n_val = 20\n", + "n_test = 10000\n", + "\n", + "x_train = torch.randn(n_train, *event_shape) * 2 + 4\n", + "x_val = torch.randn(n_val, *event_shape) * 2 + 4\n", + "x_test = torch.randn(n_test, *event_shape) * 2 + 4" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-11-09T21:11:10.456694900Z", + "start_time": "2023-11-09T21:11:10.445522900Z" + } + }, + "id": "21b252329b5695cf" + }, + { + "cell_type": "code", + "execution_count": 3, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Fitting NF: 100%|██████████| 500/500 [00:15<00:00, 32.71it/s, Training loss (batch): 1.7106]\n" + ] + } + ], + "source": [ + "# Train without validation data\n", + "torch.manual_seed(0)\n", + "flow0 = Flow(RealNVP(event_shape, n_layers=20))\n", + "flow0.fit(x_train, show_progress=True)" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-11-09T21:11:25.777575300Z", + "start_time": "2023-11-09T21:11:10.457694Z" + } + }, + "id": "b8c5703314f84814" + }, + { + "cell_type": "code", + "execution_count": 4, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Fitting NF: 100%|██████████| 500/500 [00:23<00:00, 21.42it/s, Training loss (batch): 1.7630, Validation loss: 2.8325]\n" + ] + } + ], + "source": [ + "# Train with validation data and keep the best weights\n", + "torch.manual_seed(0)\n", + "flow1 = Flow(RealNVP(event_shape, n_layers=20))\n", + "flow1.fit(x_train, show_progress=True, x_val=x_val)" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-11-09T21:11:49.164216Z", + "start_time": "2023-11-09T21:11:25.775746200Z" + } + }, + "id": "95d4d4e0447f1d4" + }, + { + "cell_type": "code", + "execution_count": 5, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Fitting NF: 39%|███▉ | 194/500 [00:11<00:18, 16.57it/s, Training loss (batch): 1.9825, Validation loss: 2.1353]\n" + ] + } + ], + "source": [ + "# Train with validation data, early stopping, and keep the best weights\n", + "torch.manual_seed(0)\n", + "flow2 = Flow(RealNVP(event_shape, n_layers=20))\n", + "flow2.fit(x_train, show_progress=True, x_val=x_val, early_stopping=True)" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-11-09T21:12:00.931776800Z", + "start_time": "2023-11-09T21:11:49.165794100Z" + } + }, + "id": "2a6ff6eaea4e1323" + }, + { + "cell_type": "markdown", + "source": [ + "The normalizing flow has a lot of parameters and thus overfits without validation data. The test loss is much lower when using validation data. We may stop training early after no observable validation loss improvement for a certain number of epochs (default: 50). In this experiment, validation loss does not improve after these epochs, as evidenced by the same test loss as observed without early stopping." + ], + "metadata": { + "collapsed": false + }, + "id": "84366140ce6804fe" + }, + { + "cell_type": "code", + "execution_count": 6, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Test loss values\n", + "\n", + "Without validation data: 55.78230667114258\n", + "With validation data, no early stopping: 24.563425064086914\n", + "With validation data, early stopping: 24.563425064086914\n" + ] + } + ], + "source": [ + "print(\"Test loss values\")\n", + "print()\n", + "print(f\"Without validation data: {torch.mean(-flow0.log_prob(x_test))}\")\n", + "print(f\"With validation data, no early stopping: {torch.mean(-flow1.log_prob(x_test))}\")\n", + "print(f\"With validation data, early stopping: {torch.mean(-flow2.log_prob(x_test))}\")" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-11-09T21:12:01.263469Z", + "start_time": "2023-11-09T21:12:00.925959700Z" + } + }, + "id": "bfaca2ae85997ee3" + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}