From f0b32794ce0ed4a4652041a0ca5dccf1adeea3c4 Mon Sep 17 00:00:00 2001
From: David Nabergoj
Date: Mon, 11 Nov 2024 19:30:50 +0100
Subject: [PATCH] Update radial flow log determinant

---
 test/test_radial_numerical_stability.py  | 13 ++++++++
 .../bijections/finite/residual/radial.py | 30 ++++++++++---------
 2 files changed, 29 insertions(+), 14 deletions(-)
 create mode 100644 test/test_radial_numerical_stability.py

diff --git a/test/test_radial_numerical_stability.py b/test/test_radial_numerical_stability.py
new file mode 100644
index 0000000..e7b3dd4
--- /dev/null
+++ b/test/test_radial_numerical_stability.py
@@ -0,0 +1,13 @@
+import torch
+
+from torchflows import Radial
+
+
+def test_exhaustive():
+    torch.manual_seed(0)
+    event_shape = (1000,)
+    bijection = Radial(event_shape=event_shape)
+    z = torch.randn(size=(5000, *event_shape)) ** 2
+    x, log_det_inverse = bijection.inverse(z)
+    assert torch.isfinite(x).all()
+    assert torch.isfinite(log_det_inverse).all()
diff --git a/torchflows/bijections/finite/residual/radial.py b/torchflows/bijections/finite/residual/radial.py
index 47eb1c3..c667e30 100644
--- a/torchflows/bijections/finite/residual/radial.py
+++ b/torchflows/bijections/finite/residual/radial.py
@@ -17,6 +17,8 @@ def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], **kwargs):
         self.unconstrained_alpha = nn.Parameter(torch.randn(size=()))
         self.z0 = nn.Parameter(torch.randn(size=(self.n_dim,)))
 
+        self.eps = 1e-6
+
     @property
     def alpha(self):
         return softplus(self.unconstrained_alpha)
@@ -24,8 +26,8 @@ def alpha(self):
     def h(self, z):
         batch_shape = z.shape[:-1]
         z0 = self.z0.view(*([1] * len(batch_shape)), *self.z0.shape)
-        r = torch.abs(z - z0)
-        return 1 / (self.alpha + r)
+        r = torch.sqrt(torch.square(z - z0))
+        return 1 / (self.alpha + r + self.eps)
 
     def h_deriv(self, z):
         batch_shape = z.shape[:-1]
@@ -37,24 +39,24 @@ def forward(self, x: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.
         raise NotImplementedError
 
     def inverse(self, z: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]:
+        # Flatten event
         batch_shape = get_batch_shape(z, self.event_shape)
         z = z.view(*batch_shape, self.n_dim)
+
+        # Compute auxiliary variables
         z0 = self.z0.view(*([1] * len(batch_shape)), *self.z0.shape)
+        r = torch.sqrt(torch.square(z - z0))
+        h = 1 / (self.alpha + r + self.eps)
 
         # Compute transformed point
-        x = z + self.beta * self.h(z) * (z - z0)
+        x = z + self.beta * h * (z - z0)
 
         # Compute determinant of the Jacobian
-        h_val = self.h(z)
-        r = torch.abs(z - z0)
-        beta_times_h_val = self.beta * h_val
-        # det = (1 + self.beta * h_val) ** (self.n_dim - 1) * (1 + self.beta * h_val + self.h_deriv(z) * r)
-        # log_det = torch.log(torch.abs(det))
-        # log_det = (self.n_dim - 1) * torch.log1p(beta_times_h_val) + torch.log(1 + beta_times_h_val + self.h_deriv(z) * r)
-        log_det = torch.abs(torch.add(
-            (self.n_dim - 1) * torch.log1p(beta_times_h_val),
-            torch.log(1 + beta_times_h_val + self.h_deriv(z) * r)
-        )).sum(dim=-1)
-        x = x.view(*batch_shape, *self.event_shape)
+        log_det = torch.add(
+            torch.log1p(self.alpha * self.beta / h ** 2),
+            torch.log1p(self.beta / h) * (self.n_dim - 1)
+        ).sum(dim=-1)
 
+        # Unflatten event
+        x = x.view(*batch_shape, *self.event_shape)
+        return x, log_det
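
The sketch below restates the patched inverse() arithmetic as a standalone function, so the numerical-stability pieces (the eps term in the denominator and the log1p-based log determinant) can be inspected outside the Radial class. The function name radial_inverse_sketch and the alpha, beta, and z0 values are hypothetical placeholders introduced only for illustration; in torchflows these come from the Radial module's learned parameters.

import torch

def radial_inverse_sketch(z, z0, alpha, beta, eps=1e-6):
    # Assumes z has shape (batch, n_dim) with the event already flattened.
    n_dim = z.shape[-1]
    r = torch.sqrt(torch.square(z - z0))   # element-wise |z - z0|, as in the patch
    h = 1 / (alpha + r + eps)              # eps keeps the denominator away from zero
    x = z + beta * h * (z - z0)            # transformed point
    log_det = torch.add(                   # log determinant as written in the patched inverse()
        torch.log1p(alpha * beta / h ** 2),
        torch.log1p(beta / h) * (n_dim - 1)
    ).sum(dim=-1)
    return x, log_det

# Usage mirroring the new test, with smaller placeholder shapes and parameter values
z = torch.randn(8, 10) ** 2
x, log_det = radial_inverse_sketch(
    z, z0=torch.zeros(10), alpha=torch.tensor(1.0), beta=torch.tensor(0.5)
)
assert torch.isfinite(x).all() and torch.isfinite(log_det).all()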