From 474460fb51982f4a6b2ef315c2ecff8dad4cc036 Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Fri, 9 Feb 2024 13:51:25 +0100 Subject: [PATCH] Add FlowMixture class, rework flow class --- normalizing_flows/flows.py | 253 ++++++++++++++++++++++--------------- 1 file changed, 154 insertions(+), 99 deletions(-) diff --git a/normalizing_flows/flows.py b/normalizing_flows/flows.py index fe27fab..788af2d 100644 --- a/normalizing_flows/flows.py +++ b/normalizing_flows/flows.py @@ -1,33 +1,21 @@ from copy import deepcopy -from typing import Union, Tuple +from typing import Union, Tuple, List +import numpy as np import torch import torch.nn as nn -from torch.utils.data import TensorDataset, DataLoader from tqdm import tqdm from normalizing_flows.bijections.base import Bijection -from normalizing_flows.bijections.continuous.ddnf import DeepDiffeomorphicBijection -from normalizing_flows.regularization import reconstruction_error -from normalizing_flows.utils import flatten_event, get_batch_shape, unflatten_event, create_data_loader +from normalizing_flows.utils import flatten_event, unflatten_event, create_data_loader -class Flow(nn.Module): - """ - Normalizing flow class. - - This class represents a bijective transformation of a standard Gaussian distribution (the base distribution). - A normalizing flow is itself a distribution which we can sample from or use it to compute the density of inputs. - """ - - def __init__(self, bijection: Bijection): - """ - - :param bijection: transformation component of the normalizing flow. - """ +class BaseFlow(nn.Module): + def __init__(self, event_shape): super().__init__() - self.register_module('bijection', bijection) - self.register_buffer('loc', torch.zeros(self.bijection.n_dim)) - self.register_buffer('covariance_matrix', torch.eye(self.bijection.n_dim)) + self.event_shape = event_shape + self.event_size = int(torch.prod(torch.as_tensor(event_shape))) + self.register_buffer('loc', torch.zeros(self.event_size)) + self.register_buffer('covariance_matrix', torch.eye(self.event_size)) def get_device(self): return self.loc.device @@ -46,7 +34,7 @@ def base_log_prob(self, z: torch.Tensor): :param z: input tensor. :return: log probability of the input tensor. """ - zf = flatten_event(z, self.bijection.event_shape) + zf = flatten_event(z, self.event_shape) log_prob = self.base.log_prob(zf) return log_prob @@ -58,65 +46,11 @@ def base_sample(self, sample_shape: Union[torch.Size, Tuple[int, ...]]): :return: tensor with shape sample_shape. """ z_flat = self.base.sample(sample_shape) - z = unflatten_event(z_flat, self.bijection.event_shape) + z = unflatten_event(z_flat, self.event_shape) return z - def forward_with_log_prob(self, x: torch.Tensor, context: torch.Tensor = None): - """ - Transform the input x to the space of the base distribution. - - :param x: input tensor. - :param context: context tensor upon which the transformation is conditioned. - :return: transformed tensor and the logarithm of the absolute value of the Jacobian determinant of the - transformation. - """ - if context is not None: - assert context.shape[0] == x.shape[0] - context = context.to(self.loc) - z, log_det = self.bijection.forward(x.to(self.loc), context=context) - log_base = self.base_log_prob(z) - return z, log_base + log_det - - def log_prob(self, x: torch.Tensor, context: torch.Tensor = None): - """ - Compute the logarithm of the probability density of input x according to the normalizing flow. - - :param x: input tensor. - :param context: context tensor. - :return: - """ - return self.forward_with_log_prob(x, context)[1] - - def sample(self, n: int, context: torch.Tensor = None, no_grad: bool = False, return_log_prob: bool = False): - """ - Sample from the normalizing flow. - - If context given, sample n tensors for each context tensor. - Otherwise, sample n tensors. - - :param n: number of tensors to sample. - :param context: context tensor with shape c. - :param no_grad: if True, do not track gradients in the inverse pass. - :return: samples with shape (n, *event_shape) if no context given or (n, *c, *event_shape) if context given. - """ - if context is not None: - z = self.base_sample(sample_shape=torch.Size((n, len(context)))) - context = context[None].repeat(*[n, *([1] * len(context.shape))]) # Make context shape match z shape - assert z.shape[:2] == context.shape[:2] - else: - z = self.base_sample(sample_shape=torch.Size((n,))) - if no_grad: - z = z.detach() - with torch.no_grad(): - x, log_det = self.bijection.inverse(z, context=context) - else: - x, log_det = self.bijection.inverse(z, context=context) - x = x.to(self.loc) - - if return_log_prob: - log_prob = self.base_log_prob(z) + log_det - return x, log_prob - return x + def regularization(self): + return 0.0 def fit(self, x_train: torch.Tensor, @@ -155,10 +89,7 @@ def fit(self, :param early_stopping: if True and validation data is provided, stop the training procedure early once validation loss stops improving for a specified number of consecutive epochs. :param early_stopping_threshold: if early_stopping is True, fitting stops after no improvement in validation loss for this many epochs. """ - self.bijection.train() - - # Compute the number of event dimensions - n_event_dims = int(torch.prod(torch.as_tensor(self.bijection.event_shape))) + self.train() # Set the default batch size if batch_size is None: @@ -172,7 +103,7 @@ def fit(self, "training", batch_size=batch_size, shuffle=shuffle, - event_shape=self.bijection.event_shape + event_shape=self.event_shape ) # Process validation data @@ -184,7 +115,7 @@ def fit(self, "validation", batch_size=batch_size, shuffle=shuffle, - event_shape=self.bijection.event_shape + event_shape=self.event_shape ) best_val_loss = torch.inf @@ -195,10 +126,10 @@ def compute_batch_loss(batch_, reduction: callable = torch.mean): batch_x, batch_weights = batch_[:2] batch_context = batch_[2] if len(batch_) == 3 else None - batch_log_prob = self.log_prob(batch_x.to(self.loc), context=batch_context) - batch_weights = batch_weights.to(self.loc) + batch_log_prob = self.log_prob(batch_x.to(self.get_device()), context=batch_context) + batch_weights = batch_weights.to(self.get_device()) assert batch_log_prob.shape == batch_weights.shape, f"{batch_log_prob.shape = }, {batch_weights.shape = }" - batch_loss = -reduction(batch_log_prob * batch_weights) / n_event_dims + batch_loss = -reduction(batch_log_prob * batch_weights) / self.event_size return batch_loss @@ -210,8 +141,7 @@ def compute_batch_loss(batch_, reduction: callable = torch.mean): for train_batch in train_loader: optimizer.zero_grad() train_loss = compute_batch_loss(train_batch, reduction=torch.mean) - if hasattr(self.bijection, 'regularization'): - train_loss += self.bijection.regularization() + train_loss += self.regularization() train_loss.backward() optimizer.step() @@ -233,8 +163,7 @@ def compute_batch_loss(batch_, reduction: callable = torch.mean): for val_batch in val_loader: n_batch_data = len(val_batch[0]) val_loss += compute_batch_loss(val_batch, reduction=torch.sum) / n_batch_data - if hasattr(self.bijection, 'regularization'): - val_loss += self.bijection.regularization() + val_loss += self.regularization() # Check if validation loss is the lowest so far if val_loss < best_val_loss: @@ -254,7 +183,7 @@ def compute_batch_loss(batch_, reduction: callable = torch.mean): if x_val is not None and keep_best_weights: self.load_state_dict(best_weights) - self.bijection.eval() + self.eval() def variational_fit(self, target_log_prob: callable, @@ -263,8 +192,8 @@ def variational_fit(self, n_samples: int = 1000, show_progress: bool = False): """ - Train a normalizing flow with stochastic variational inference. - Stochastic variational inference lets us train a normalizing flow using the unnormalized target log density + Train a distribution with stochastic variational inference. + Stochastic variational inference lets us train a distribution using the unnormalized target log density instead of a fixed dataset. Refer to Rezende, Mohamed: "Variational Inference with Normalizing Flows" (2015) for more details @@ -279,15 +208,141 @@ def variational_fit(self, :param float n_samples: number of samples to estimate the variational loss in each training step. :param bool show_progress: if True, show a progress bar during training. """ - iterator = tqdm(range(n_epochs), desc='Variational NF fit', disable=not show_progress) + iterator = tqdm(range(n_epochs), desc='Fitting with SVI', disable=not show_progress) optimizer = torch.optim.AdamW(self.parameters(), lr=lr) for _ in iterator: optimizer.zero_grad() flow_x, flow_log_prob = self.sample(n_samples, return_log_prob=True) loss = -torch.mean(target_log_prob(flow_x) + flow_log_prob) - if hasattr(self.bijection, 'regularization'): - loss += self.bijection.regularization() + loss += self.regularization() loss.backward() optimizer.step() - iterator.set_postfix_str(f'Variational loss: {loss:.4f}') + iterator.set_postfix_str(f'Loss: {loss:.4f}') + + +class Flow(BaseFlow): + """ + Normalizing flow class. + + This class represents a bijective transformation of a standard Gaussian distribution (the base distribution). + A normalizing flow is itself a distribution which we can sample from or use it to compute the density of inputs. + """ + + def __init__(self, bijection: Bijection): + """ + + :param bijection: transformation component of the normalizing flow. + """ + super().__init__(event_shape=bijection.event_shape) + self.register_module('bijection', bijection) + + def forward_with_log_prob(self, x: torch.Tensor, context: torch.Tensor = None): + """ + Transform the input x to the space of the base distribution. + + :param x: input tensor. + :param context: context tensor upon which the transformation is conditioned. + :return: transformed tensor and the logarithm of the absolute value of the Jacobian determinant of the + transformation. + """ + if context is not None: + assert context.shape[0] == x.shape[0] + context = context.to(self.get_device()) + z, log_det = self.bijection.forward(x.to(self.get_device()), context=context) + log_base = self.base_log_prob(z) + return z, log_base + log_det + + def log_prob(self, x: torch.Tensor, context: torch.Tensor = None): + """ + Compute the logarithm of the probability density of input x according to the normalizing flow. + + :param x: input tensor. + :param context: context tensor. + :return: + """ + return self.forward_with_log_prob(x, context)[1] + + def sample(self, n: int, context: torch.Tensor = None, no_grad: bool = False, return_log_prob: bool = False): + """ + Sample from the normalizing flow. + + If context given, sample n tensors for each context tensor. + Otherwise, sample n tensors. + + :param n: number of tensors to sample. + :param context: context tensor with shape c. + :param no_grad: if True, do not track gradients in the inverse pass. + :return: samples with shape (n, *event_shape) if no context given or (n, *c, *event_shape) if context given. + """ + if context is not None: + z = self.base_sample(sample_shape=torch.Size((n, len(context)))) + context = context[None].repeat(*[n, *([1] * len(context.shape))]) # Make context shape match z shape + assert z.shape[:2] == context.shape[:2] + else: + z = self.base_sample(sample_shape=torch.Size((n,))) + if no_grad: + z = z.detach() + with torch.no_grad(): + x, log_det = self.bijection.inverse(z, context=context) + else: + x, log_det = self.bijection.inverse(z, context=context) + x = x.to(self.get_device()) + + if return_log_prob: + log_prob = self.base_log_prob(z) + log_det + return x, log_prob + return x + + def regularization(self): + if hasattr(self.bijection, 'regularization'): + return self.bijection.regularization() + else: + return 0.0 + + +class FlowMixture(BaseFlow): + def __init__(self, flows: List[Flow], weights: List[float]): + super().__init__(event_shape=flows[0].event_shape) + assert len(weights) == len(flows) + assert all([w > 0.0 for w in weights]) + assert np.isclose(sum(weights), 1.0) + + self.flows = flows + self.weights = torch.tensor(weights) + self.log_weights = torch.log(self.weights) + self.categorical_distribution = torch.distributions.Categorical(probs=self.weights) + + def log_prob(self, x: torch.Tensor, context: torch.Tensor = None): + flow_log_probs = torch.stack([flow.log_prob(x, context=context) for flow in self.flows]) + # (n_flows, *batch_shape) + + batch_shape = flow_log_probs.shape[1:] + log_weights_reshaped = self.log_weights.view(-1, *([1] * len(batch_shape))) + log_prob = torch.logsumexp(log_weights_reshaped + flow_log_probs, dim=0) # batch_shape + return log_prob + + def sample(self, n: int, context: torch.Tensor = None, no_grad: bool = False, return_log_prob: bool = False): + flow_samples = [] + flow_log_probs = [] + for flow in self.flows: + flow_x, flow_log_prob = flow.sample(n, context=context, no_grad=no_grad, return_log_prob=True) + flow_samples.append(flow_x) + flow_log_probs.append(flow_log_prob) + + with torch.no_grad(): + flow_samples = torch.stack(flow_samples) # (n_flows, n, *event_shape) + categorical_samples = self.categorical_distribution.sample(sample_shape=torch.Size((n,))) # (n,) + one_hot = torch.nn.functional.one_hot(categorical_samples, num_classes=len(flow_samples)).T # (n_flows, n) + samples = torch.sum(one_hot * flow_samples, dim=0) # (n, *event_shape) + + if return_log_prob: + flow_log_probs = torch.stack(flow_log_probs) # (n_flows, n) + log_weights_reshaped = self.log_weights[:, None] # (n_flows, 1) + log_prob = torch.logsumexp(log_weights_reshaped + flow_log_probs, dim=0) # (n,) + return samples, log_prob + else: + return samples + + def regularization(self): + return sum([flow.regularization() for flow in self.flows])