Update radial flow log determinant
davidnabergoj committed Nov 11, 2024
1 parent deb235b commit f0b3279
Showing 2 changed files with 29 additions and 14 deletions.
13 changes: 13 additions & 0 deletions test/test_radial_numerical_stability.py
@@ -0,0 +1,13 @@
import torch

from torchflows import Radial


def test_exhaustive():
    torch.manual_seed(0)
    event_shape = (1000,)
    bijection = Radial(event_shape=event_shape)
    z = torch.randn(size=(5000, *event_shape)) ** 2
    x, log_det_inverse = bijection.inverse(z)
    assert torch.isfinite(x).all()
    assert torch.isfinite(log_det_inverse).all()
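
The new test only asserts that the transformed points and log-determinants stay finite for large, squared-normal inputs. A complementary sanity check (hypothetical, not part of this commit) would compare the analytic log-determinant returned by Radial.inverse against a brute-force Jacobian in a small dimension, along these lines:

import torch
from torchflows import Radial

torch.manual_seed(0)
event_shape = (5,)  # small event shape so the full Jacobian is cheap to materialize
bijection = Radial(event_shape=event_shape)

z = torch.randn(size=(1, *event_shape))
x, log_det_inverse = bijection.inverse(z)

# Brute-force Jacobian of the inverse map at z: shape (1, 5, 1, 5) -> (5, 5).
jac = torch.autograd.functional.jacobian(lambda t: bijection.inverse(t)[0], z)
jac = jac.reshape(event_shape[0], event_shape[0])
_, brute_force_log_det = torch.linalg.slogdet(jac)

print(log_det_inverse.item(), brute_force_log_det.item())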
30 changes: 16 additions & 14 deletions torchflows/bijections/finite/residual/radial.py
@@ -17,15 +17,17 @@ def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], **kwargs):
        self.unconstrained_alpha = nn.Parameter(torch.randn(size=()))
        self.z0 = nn.Parameter(torch.randn(size=(self.n_dim,)))

        self.eps = 1e-6

    @property
    def alpha(self):
        return softplus(self.unconstrained_alpha)

    def h(self, z):
        batch_shape = z.shape[:-1]
        z0 = self.z0.view(*([1] * len(batch_shape)), *self.z0.shape)
        r = torch.abs(z - z0)
        return 1 / (self.alpha + r)
        r = torch.sqrt(torch.square(z - z0))
        return 1 / (self.alpha + r + self.eps)

    def h_deriv(self, z):
        batch_shape = z.shape[:-1]
@@ -37,24 +39,24 @@ def forward(self, x: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]:
        raise NotImplementedError

    def inverse(self, z: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]:
        # Flatten event
        batch_shape = get_batch_shape(z, self.event_shape)
        z = z.view(*batch_shape, self.n_dim)

        # Compute auxiliary variables
        z0 = self.z0.view(*([1] * len(batch_shape)), *self.z0.shape)
        r = torch.sqrt(torch.square(z - z0))
        h = 1 / (self.alpha + r + self.eps)

        # Compute transformed point
        x = z + self.beta * self.h(z) * (z - z0)
        x = z + self.beta * h * (z - z0)

        # Compute determinant of the Jacobian
        h_val = self.h(z)
        r = torch.abs(z - z0)
        beta_times_h_val = self.beta * h_val
        # det = (1 + self.beta * h_val) ** (self.n_dim - 1) * (1 + self.beta * h_val + self.h_deriv(z) * r)
        # log_det = torch.log(torch.abs(det))
        # log_det = (self.n_dim - 1) * torch.log1p(beta_times_h_val) + torch.log(1 + beta_times_h_val + self.h_deriv(z) * r)
        log_det = torch.abs(torch.add(
            (self.n_dim - 1) * torch.log1p(beta_times_h_val),
            torch.log(1 + beta_times_h_val + self.h_deriv(z) * r)
        )).sum(dim=-1)
        x = x.view(*batch_shape, *self.event_shape)
        log_det = torch.add(
            torch.log1p(self.alpha * self.beta / h ** 2),
            torch.log1p(self.beta / h) * (self.n_dim - 1)
        ).sum(dim=-1)

        # Unflatten event
        x = x.view(*batch_shape, *self.event_shape)
        return x, log_det
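
For reference (not part of the commit), the radial flow as introduced by Rezende and Mohamed (2015), written with r = \lVert z - z_0 \rVert and h(\alpha, r) = 1 / (\alpha + r), is

    x = z + \beta \, h(\alpha, r) \, (z - z_0),

and its Jacobian determinant has the closed form

    \det J = \bigl[1 + \beta h(\alpha, r)\bigr]^{d-1} \bigl[1 + \beta h(\alpha, r) + \beta h'(\alpha, r) \, r\bigr]
           = \Bigl[1 + \tfrac{\beta}{\alpha + r}\Bigr]^{d-1} \Bigl[1 + \tfrac{\alpha \beta}{(\alpha + r)^{2}}\Bigr],

where d is the dimensionality of the flattened event (n_dim in the code). The updated inverse evaluates the log-determinant via torch.log1p terms and adds an \epsilon offset inside h; the new test_radial_numerical_stability.py checks that the result stays finite for large-magnitude inputs.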
