diff --git a/normalizing_flows/base_distributions/__init__.py b/normalizing_flows/base_distributions/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/normalizing_flows/base_distributions/gaussian.py b/normalizing_flows/base_distributions/gaussian.py new file mode 100644 index 0000000..ae0baab --- /dev/null +++ b/normalizing_flows/base_distributions/gaussian.py @@ -0,0 +1,82 @@ +import torch +import torch.nn as nn +import math + +from normalizing_flows.utils import sum_except_batch + + +class DiagonalGaussian(torch.distributions.Distribution, nn.Module): + def __init__(self, + loc: torch.Tensor, + scale: torch.Tensor, + trainable_loc: bool = False, + trainable_scale: bool = False): + super().__init__(event_shape=loc.shape) + self.log_2_pi = math.log(2 * math.pi) + if trainable_loc: + self.register_parameter('loc', nn.Parameter(loc)) + else: + self.register_buffer('loc', loc) + + if trainable_scale: + self.register_parameter('log_scale', nn.Parameter(torch.log(scale))) + else: + self.register_buffer('log_scale', torch.log(scale)) + + @property + def scale(self): + return torch.exp(self.log_scale) + + def sample(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor: + noise = torch.randn(size=(*sample_shape, *self.event_shape)).to(self.loc) + sample_shape_mask = [None for _ in range(len(sample_shape))] + return self.loc[sample_shape_mask] + noise * self.scale[sample_shape_mask] + + def log_prob(self, value: torch.Tensor) -> torch.Tensor: + if len(value.shape) <= len(self.event_shape): + raise ValueError("Incorrect input shape") + sample_shape_mask = [None for _ in range(len(value.shape) - len(self.event_shape))] + loc = self.loc[sample_shape_mask] + scale = self.scale[sample_shape_mask] + log_scale = self.log_scale[sample_shape_mask] + elementwise_log_prob = -(0.5 * ((value - loc) / scale) ** 2 + 0.5 * self.log_2_pi + log_scale) + return sum_except_batch(elementwise_log_prob, self.event_shape) + + +class DenseGaussian(torch.distributions.Distribution, nn.Module): + def __init__(self, + loc: torch.Tensor, + cov: torch.Tensor, + trainable_loc: bool = False): + super().__init__(event_shape=loc.shape) + event_size = int(torch.prod(torch.as_tensor(self.event_shape))) + if cov.shape != (event_size, event_size): + raise ValueError("Incorrect covariance matrix shape") + + self.log_2_pi = math.log(2 * math.pi) + if trainable_loc: + self.register_parameter('loc', nn.Parameter(loc)) + else: + self.register_buffer('loc', loc) + + cholesky = torch.cholesky(cov) + inverse_cholesky = torch.inverse(cholesky) + inverse_cov = inverse_cholesky.T @ inverse_cholesky + + self.register_buffer('cholesky', cholesky) + self.register_buffer('inverse_cov', inverse_cov) + self.constant = -torch.sum(torch.log(torch.diag(cholesky))) - event_size / 2 * self.log_2_pi + + def sample(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor: + flat_noise = torch.randn(size=(*sample_shape, int(torch.prod(torch.as_tensor(self.event_shape))))).to(self.loc) + sample_shape_mask = [None for _ in range(len(sample_shape))] + loc = self.loc[sample_shape_mask] + return loc + (self.cholesky @ flat_noise).view_as(loc) + + def log_prob(self, value: torch.Tensor) -> torch.Tensor: + # Without the determinant component + if len(value.shape) <= len(self.event_shape): + raise ValueError("Incorrect input shape") + sample_shape_mask = [None for _ in range(len(value.shape) - len(self.event_shape))] + diff = value - self.loc[sample_shape_mask] + return self.constant - 0.5 * torch.einsum('...i,ij,...j->...', diff, self.inverse_cov, diff) diff --git a/normalizing_flows/base_distributions/mixture.py b/normalizing_flows/base_distributions/mixture.py new file mode 100644 index 0000000..b4ac23c --- /dev/null +++ b/normalizing_flows/base_distributions/mixture.py @@ -0,0 +1,63 @@ +from typing import List + +import torch +import torch.nn as nn + +from normalizing_flows.base_distributions.gaussian import DiagonalGaussian, DenseGaussian +from normalizing_flows.utils import get_batch_shape + + +class Mixture(torch.distributions.Distribution, nn.Module): + def __init__(self, + components: List[torch.distributions.Distribution], + weights: torch.Tensor = None): + if weights is None: + weights = torch.ones(len(components)) / len(components) + super().__init__(event_shape=components[0].event_shape) + self.register_buffer('log_weights', torch.log(weights)) + self.components = components + self.categorical = torch.distributions.Categorical(probs=weights) + + def sample(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor: + categories = self.categorical.sample(sample_shape) + outputs = torch.zeros(*sample_shape, *self.event_shape).to(self.log_weights) + for i, component in enumerate(self.components): + outputs[categories == i] = component.sample(sample_shape)[categories == i] + return outputs + + def log_prob(self, value: torch.Tensor) -> torch.Tensor: + # We are assuming all components are normalized + value = value.to(self.log_weights) + batch_shape = get_batch_shape(value, self.event_shape) + log_probs = torch.zeros(*batch_shape, self.n_components).to(self.log_weights) + for i in range(self.n_components): + log_probs[..., i] = self.components[i].log_prob(value) + sample_shape_mask = [None for _ in range(len(value.shape) - len(self.event_shape))] + return torch.logsumexp(self.log_weights[sample_shape_mask] + log_probs, dim=-1) + + +class DiagonalGaussianMixture(Mixture): + def __init__(self, + locs: torch.Tensor, + scales: torch.Tensor, + weights: torch.Tensor = None, + trainable_locs: bool = False, + trainable_scales: bool = False): + n_components, *event_shape = locs.shape + components = [] + for i in range(n_components): + components.append(DiagonalGaussian(locs[i], scales[i], trainable_locs, trainable_scales)) + super().__init__(components, weights) + + +class DenseGaussianMixture(Mixture): + def __init__(self, + locs: torch.Tensor, + covs: torch.Tensor, + weights: torch.Tensor = None, + trainable_locs: bool = False): + n_components, *event_shape = locs.shape + components = [] + for i in range(n_components): + components.append(DenseGaussian(locs[i], covs[i], trainable_locs)) + super().__init__(components, weights)