-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
b00029e
commit 19e360a
Showing
3 changed files
with
145 additions
and
0 deletions.
There are no files selected for viewing
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
import torch | ||
import torch.nn as nn | ||
import math | ||
|
||
from normalizing_flows.utils import sum_except_batch | ||
|
||
|
||
class DiagonalGaussian(torch.distributions.Distribution, nn.Module): | ||
def __init__(self, | ||
loc: torch.Tensor, | ||
scale: torch.Tensor, | ||
trainable_loc: bool = False, | ||
trainable_scale: bool = False): | ||
super().__init__(event_shape=loc.shape) | ||
self.log_2_pi = math.log(2 * math.pi) | ||
if trainable_loc: | ||
self.register_parameter('loc', nn.Parameter(loc)) | ||
else: | ||
self.register_buffer('loc', loc) | ||
|
||
if trainable_scale: | ||
self.register_parameter('log_scale', nn.Parameter(torch.log(scale))) | ||
else: | ||
self.register_buffer('log_scale', torch.log(scale)) | ||
|
||
@property | ||
def scale(self): | ||
return torch.exp(self.log_scale) | ||
|
||
def sample(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor: | ||
noise = torch.randn(size=(*sample_shape, *self.event_shape)).to(self.loc) | ||
sample_shape_mask = [None for _ in range(len(sample_shape))] | ||
return self.loc[sample_shape_mask] + noise * self.scale[sample_shape_mask] | ||
|
||
def log_prob(self, value: torch.Tensor) -> torch.Tensor: | ||
if len(value.shape) <= len(self.event_shape): | ||
raise ValueError("Incorrect input shape") | ||
sample_shape_mask = [None for _ in range(len(value.shape) - len(self.event_shape))] | ||
loc = self.loc[sample_shape_mask] | ||
scale = self.scale[sample_shape_mask] | ||
log_scale = self.log_scale[sample_shape_mask] | ||
elementwise_log_prob = -(0.5 * ((value - loc) / scale) ** 2 + 0.5 * self.log_2_pi + log_scale) | ||
return sum_except_batch(elementwise_log_prob, self.event_shape) | ||
|
||
|
||
class DenseGaussian(torch.distributions.Distribution, nn.Module): | ||
def __init__(self, | ||
loc: torch.Tensor, | ||
cov: torch.Tensor, | ||
trainable_loc: bool = False): | ||
super().__init__(event_shape=loc.shape) | ||
event_size = int(torch.prod(torch.as_tensor(self.event_shape))) | ||
if cov.shape != (event_size, event_size): | ||
raise ValueError("Incorrect covariance matrix shape") | ||
|
||
self.log_2_pi = math.log(2 * math.pi) | ||
if trainable_loc: | ||
self.register_parameter('loc', nn.Parameter(loc)) | ||
else: | ||
self.register_buffer('loc', loc) | ||
|
||
cholesky = torch.cholesky(cov) | ||
inverse_cholesky = torch.inverse(cholesky) | ||
inverse_cov = inverse_cholesky.T @ inverse_cholesky | ||
|
||
self.register_buffer('cholesky', cholesky) | ||
self.register_buffer('inverse_cov', inverse_cov) | ||
self.constant = -torch.sum(torch.log(torch.diag(cholesky))) - event_size / 2 * self.log_2_pi | ||
|
||
def sample(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor: | ||
flat_noise = torch.randn(size=(*sample_shape, int(torch.prod(torch.as_tensor(self.event_shape))))).to(self.loc) | ||
sample_shape_mask = [None for _ in range(len(sample_shape))] | ||
loc = self.loc[sample_shape_mask] | ||
return loc + (self.cholesky @ flat_noise).view_as(loc) | ||
|
||
def log_prob(self, value: torch.Tensor) -> torch.Tensor: | ||
# Without the determinant component | ||
if len(value.shape) <= len(self.event_shape): | ||
raise ValueError("Incorrect input shape") | ||
sample_shape_mask = [None for _ in range(len(value.shape) - len(self.event_shape))] | ||
diff = value - self.loc[sample_shape_mask] | ||
return self.constant - 0.5 * torch.einsum('...i,ij,...j->...', diff, self.inverse_cov, diff) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
from typing import List | ||
|
||
import torch | ||
import torch.nn as nn | ||
|
||
from normalizing_flows.base_distributions.gaussian import DiagonalGaussian, DenseGaussian | ||
from normalizing_flows.utils import get_batch_shape | ||
|
||
|
||
class Mixture(torch.distributions.Distribution, nn.Module): | ||
def __init__(self, | ||
components: List[torch.distributions.Distribution], | ||
weights: torch.Tensor = None): | ||
if weights is None: | ||
weights = torch.ones(len(components)) / len(components) | ||
super().__init__(event_shape=components[0].event_shape) | ||
self.register_buffer('log_weights', torch.log(weights)) | ||
self.components = components | ||
self.categorical = torch.distributions.Categorical(probs=weights) | ||
|
||
def sample(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor: | ||
categories = self.categorical.sample(sample_shape) | ||
outputs = torch.zeros(*sample_shape, *self.event_shape).to(self.log_weights) | ||
for i, component in enumerate(self.components): | ||
outputs[categories == i] = component.sample(sample_shape)[categories == i] | ||
return outputs | ||
|
||
def log_prob(self, value: torch.Tensor) -> torch.Tensor: | ||
# We are assuming all components are normalized | ||
value = value.to(self.log_weights) | ||
batch_shape = get_batch_shape(value, self.event_shape) | ||
log_probs = torch.zeros(*batch_shape, self.n_components).to(self.log_weights) | ||
for i in range(self.n_components): | ||
log_probs[..., i] = self.components[i].log_prob(value) | ||
sample_shape_mask = [None for _ in range(len(value.shape) - len(self.event_shape))] | ||
return torch.logsumexp(self.log_weights[sample_shape_mask] + log_probs, dim=-1) | ||
|
||
|
||
class DiagonalGaussianMixture(Mixture): | ||
def __init__(self, | ||
locs: torch.Tensor, | ||
scales: torch.Tensor, | ||
weights: torch.Tensor = None, | ||
trainable_locs: bool = False, | ||
trainable_scales: bool = False): | ||
n_components, *event_shape = locs.shape | ||
components = [] | ||
for i in range(n_components): | ||
components.append(DiagonalGaussian(locs[i], scales[i], trainable_locs, trainable_scales)) | ||
super().__init__(components, weights) | ||
|
||
|
||
class DenseGaussianMixture(Mixture): | ||
def __init__(self, | ||
locs: torch.Tensor, | ||
covs: torch.Tensor, | ||
weights: torch.Tensor = None, | ||
trainable_locs: bool = False): | ||
n_components, *event_shape = locs.shape | ||
components = [] | ||
for i in range(n_components): | ||
components.append(DenseGaussian(locs[i], covs[i], trainable_locs)) | ||
super().__init__(components, weights) |