Skip to content

Commit

Permalink
Add documentation
Browse files Browse the repository at this point in the history
  • Loading branch information
davidnabergoj committed Nov 8, 2023
1 parent b04691a commit e7d5b92
Showing 1 changed file with 53 additions and 9 deletions.
62 changes: 53 additions & 9 deletions normalizing_flows/flows.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Union, Tuple

import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
Expand All @@ -9,27 +11,60 @@


class Flow(nn.Module):
"""
Normalizing flow class.
This class represents a bijective transformation of a standard Gaussian distribution (the base distribution).
A normalizing flow is itself a distribution which we can sample from or use it to compute the density of inputs.
"""
def __init__(self, bijection: Bijection):
"""
:param bijection: transformation component of the normalizing flow.
"""
super().__init__()
self.register_module('bijection', bijection)
self.register_buffer('loc', torch.zeros(self.bijection.n_dim))
self.register_buffer('covariance_matrix', torch.eye(self.bijection.n_dim))

@property
def base(self):
def base(self) -> torch.distributions.Distribution:
"""
:return: base distribution of the normalizing flow.
"""
return torch.distributions.MultivariateNormal(loc=self.loc, covariance_matrix=self.covariance_matrix)

def base_log_prob(self, z):
def base_log_prob(self, z: torch.Tensor):
"""
Compute the log probability of input z under the base distribution.
:param z: input tensor.
:return: log probability of the input tensor.
"""
zf = flatten_event(z, self.bijection.event_shape)
log_prob = self.base.log_prob(zf)
return log_prob

def base_sample(self, sample_shape):
def base_sample(self, sample_shape: Union[torch.Size, Tuple[int, ...]]):
"""
Sample from the base distribution.
:param sample_shape: desired shape of sampled tensor.
:return: tensor with shape sample_shape.
"""
z_flat = self.base.sample(sample_shape)
z = unflatten_event(z_flat, self.bijection.event_shape)
return z

def forward_with_log_prob(self, x: torch.Tensor, context: torch.Tensor = None):
"""
Transform the input x to the space of the base distribution.
:param x: input tensor.
:param context: context tensor upon which the transformation is conditioned.
:return: transformed tensor and the logarithm of the absolute value of the Jacobian determinant of the
transformation.
"""
if context is not None:
assert context.shape[0] == x.shape[0]
context = context.to(self.loc)
Expand All @@ -38,17 +73,26 @@ def forward_with_log_prob(self, x: torch.Tensor, context: torch.Tensor = None):
return z, log_base + log_det

def log_prob(self, x: torch.Tensor, context: torch.Tensor = None):
"""
Compute the logarithm of the probability density of input x according to the normalizing flow.
:param x: input tensor.
:param context: context tensor.
:return:
"""
return self.forward_with_log_prob(x, context)[1]

def sample(self, n: int, context: torch.Tensor = None, no_grad: bool = False, return_log_prob: bool = False):
"""
If context given, sample n vectors for each context vector.
Otherwise, sample n vectors.
Sample from the normalizing flow.
:param n:
:param context:
:param no_grad:
:return:
If context given, sample n tensors for each context tensor.
Otherwise, sample n tensors.
:param n: number of tensors to sample.
:param context: context tensor with shape c.
:param no_grad: if True, do not track gradients in the inverse pass.
:return: samples with shape (n, *event_shape) if no context given or (n, *c, *event_shape) if context given.
"""
if context is not None:
z = self.base_sample(sample_shape=torch.Size((n, len(context))))
Expand Down

0 comments on commit e7d5b92

Please sign in to comment.