Add custom base distributions
davidnabergoj committed Aug 11, 2024
1 parent b00029e commit 19e360a
Showing 3 changed files with 145 additions and 0 deletions.
normalizing_flows/base_distributions/__init__.py
Empty file.
82 changes: 82 additions & 0 deletions normalizing_flows/base_distributions/gaussian.py
@@ -0,0 +1,82 @@
import torch
import torch.nn as nn
import math

from normalizing_flows.utils import sum_except_batch


class DiagonalGaussian(torch.distributions.Distribution, nn.Module):
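    # Gaussian with diagonal covariance; loc and scale can optionally be trainable parameters.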
def __init__(self,
loc: torch.Tensor,
scale: torch.Tensor,
trainable_loc: bool = False,
trainable_scale: bool = False):
super().__init__(event_shape=loc.shape)
self.log_2_pi = math.log(2 * math.pi)
if trainable_loc:
self.register_parameter('loc', nn.Parameter(loc))
else:
self.register_buffer('loc', loc)

if trainable_scale:
self.register_parameter('log_scale', nn.Parameter(torch.log(scale)))
else:
self.register_buffer('log_scale', torch.log(scale))

    @property
    def scale(self):
        # The scale is stored as log_scale so that it stays positive during optimization.
        return torch.exp(self.log_scale)

    def sample(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor:
        noise = torch.randn(size=(*sample_shape, *self.event_shape)).to(self.loc)
        # Unsqueeze loc and scale over the sample dimensions so they broadcast against the noise.
        # A tuple index (rather than a list) also handles the empty sample_shape case correctly.
        sample_shape_mask = tuple(None for _ in range(len(sample_shape)))
        return self.loc[sample_shape_mask] + noise * self.scale[sample_shape_mask]

    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
        if len(value.shape) <= len(self.event_shape):
            raise ValueError("Input must have at least one batch dimension")
        sample_shape_mask = tuple(None for _ in range(len(value.shape) - len(self.event_shape)))
        loc = self.loc[sample_shape_mask]
        scale = self.scale[sample_shape_mask]
        log_scale = self.log_scale[sample_shape_mask]
        # Elementwise Gaussian log density: -0.5 * ((x - mu) / sigma) ** 2 - 0.5 * log(2 * pi) - log(sigma)
        elementwise_log_prob = -(0.5 * ((value - loc) / scale) ** 2 + 0.5 * self.log_2_pi + log_scale)
        return sum_except_batch(elementwise_log_prob, self.event_shape)


class DenseGaussian(torch.distributions.Distribution, nn.Module):
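    # Gaussian with a full covariance matrix, parameterized through its Cholesky factor.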
def __init__(self,
loc: torch.Tensor,
cov: torch.Tensor,
trainable_loc: bool = False):
super().__init__(event_shape=loc.shape)
event_size = int(torch.prod(torch.as_tensor(self.event_shape)))
if cov.shape != (event_size, event_size):
raise ValueError("Incorrect covariance matrix shape")

self.log_2_pi = math.log(2 * math.pi)
if trainable_loc:
self.register_parameter('loc', nn.Parameter(loc))
else:
self.register_buffer('loc', loc)

        # Cholesky factor L with cov = L @ L.T; the inverse covariance is L^{-T} @ L^{-1}.
        # torch.linalg.cholesky replaces the deprecated torch.cholesky.
        cholesky = torch.linalg.cholesky(cov)
        inverse_cholesky = torch.inverse(cholesky)
        inverse_cov = inverse_cholesky.T @ inverse_cholesky

        self.register_buffer('cholesky', cholesky)
        self.register_buffer('inverse_cov', inverse_cov)
        # Log normalization constant: -0.5 * log|cov| - (d / 2) * log(2 * pi).
        # Registered as a buffer so it moves with the module across devices.
        self.register_buffer('constant', -torch.sum(torch.log(torch.diag(cholesky))) - event_size / 2 * self.log_2_pi)

    def sample(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor:
        flat_noise = torch.randn(size=(*sample_shape, int(torch.prod(torch.as_tensor(self.event_shape))))).to(self.loc)
        # Correlate the noise with the Cholesky factor (einsum handles arbitrary sample shapes),
        # then restore the event shape and shift by the mean.
        correlated_noise = torch.einsum('ij,...j->...i', self.cholesky, flat_noise)
        return self.loc + correlated_noise.view(*sample_shape, *self.event_shape)

    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
        if len(value.shape) <= len(self.event_shape):
            raise ValueError("Input must have at least one batch dimension")
        n_batch_dims = len(value.shape) - len(self.event_shape)
        sample_shape_mask = tuple(None for _ in range(n_batch_dims))
        # Flatten the event dimensions so the quadratic form acts on vectors.
        diff = (value - self.loc[sample_shape_mask]).flatten(start_dim=n_batch_dims)
        # self.constant already contains the log-determinant and 2 * pi terms.
        return self.constant - 0.5 * torch.einsum('...i,ij,...j->...', diff, self.inverse_cov, diff)
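
A minimal usage sketch of the two classes above (illustration only, not part of the commit; the event size and covariance values are made up):

import torch
from normalizing_flows.base_distributions.gaussian import DiagonalGaussian, DenseGaussian

# Diagonal Gaussian over a 3-dimensional event with a trainable mean.
base = DiagonalGaussian(loc=torch.zeros(3), scale=torch.ones(3), trainable_loc=True)
x = base.sample(torch.Size((10,)))  # shape: (10, 3)
log_p = base.log_prob(x)            # shape: (10,)

# Dense Gaussian with a full, positive definite covariance matrix.
cov = torch.tensor([[2.0, 0.5, 0.0],
                    [0.5, 1.0, 0.0],
                    [0.0, 0.0, 3.0]])
dense = DenseGaussian(loc=torch.zeros(3), cov=cov)
log_p_dense = dense.log_prob(dense.sample(torch.Size((10,))))  # shape: (10,)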
63 changes: 63 additions & 0 deletions normalizing_flows/base_distributions/mixture.py
@@ -0,0 +1,63 @@
from typing import List

import torch
import torch.nn as nn

from normalizing_flows.base_distributions.gaussian import DiagonalGaussian, DenseGaussian
from normalizing_flows.utils import get_batch_shape


class Mixture(torch.distributions.Distribution, nn.Module):
    def __init__(self,
                 components: List[torch.distributions.Distribution],
                 weights: torch.Tensor = None):
        if weights is None:
            weights = torch.ones(len(components)) / len(components)
        super().__init__(event_shape=components[0].event_shape)
        self.register_buffer('log_weights', torch.log(weights))
        # ModuleList registers component parameters (e.g. trainable locs) with this module.
        self.components = nn.ModuleList(components)
        self.n_components = len(components)
        self.categorical = torch.distributions.Categorical(probs=weights)

def sample(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor:
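        # Draw a component index per sample, sample from every component, and keep each draw
        # from the component its category selects.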
categories = self.categorical.sample(sample_shape)
outputs = torch.zeros(*sample_shape, *self.event_shape).to(self.log_weights)
for i, component in enumerate(self.components):
outputs[categories == i] = component.sample(sample_shape)[categories == i]
return outputs

    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
        # Assumes all components are normalized; computes log sum_i w_i p_i(x) stably via logsumexp.
        value = value.to(self.log_weights)
        batch_shape = get_batch_shape(value, self.event_shape)
        log_probs = torch.zeros(*batch_shape, self.n_components).to(self.log_weights)
        for i in range(self.n_components):
            log_probs[..., i] = self.components[i].log_prob(value)
        sample_shape_mask = tuple(None for _ in range(len(value.shape) - len(self.event_shape)))
        return torch.logsumexp(self.log_weights[sample_shape_mask] + log_probs, dim=-1)


class DiagonalGaussianMixture(Mixture):
def __init__(self,
locs: torch.Tensor,
scales: torch.Tensor,
weights: torch.Tensor = None,
trainable_locs: bool = False,
trainable_scales: bool = False):
n_components, *event_shape = locs.shape
components = []
for i in range(n_components):
components.append(DiagonalGaussian(locs[i], scales[i], trainable_locs, trainable_scales))
super().__init__(components, weights)


class DenseGaussianMixture(Mixture):
def __init__(self,
locs: torch.Tensor,
covs: torch.Tensor,
weights: torch.Tensor = None,
trainable_locs: bool = False):
n_components, *event_shape = locs.shape
components = []
for i in range(n_components):
components.append(DenseGaussian(locs[i], covs[i], trainable_locs))
super().__init__(components, weights)
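
A short construction example for the mixtures (illustration only, not part of the commit; the shapes and weights are made up):

import torch
from normalizing_flows.base_distributions.mixture import DiagonalGaussianMixture

# Two diagonal-Gaussian components over a 2-dimensional event, mixed 70/30.
mixture = DiagonalGaussianMixture(
    locs=torch.tensor([[-1.0, -1.0], [1.0, 1.0]]),
    scales=torch.ones(2, 2),
    weights=torch.tensor([0.7, 0.3]),
)
x = mixture.sample(torch.Size((5,)))  # shape: (5, 2)
log_p = mixture.log_prob(x)           # shape: (5,)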
