From 869dda40c99914841b3aa7154ca4fe92a24ac54b Mon Sep 17 00:00:00 2001
From: David Nabergoj
Date: Mon, 13 Nov 2023 16:01:30 -0800
Subject: [PATCH] Use small epsilon to handle numerics in LRS, move 2 into log
 as square for better stability

---
 .../transformers/spline/linear_rational.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/normalizing_flows/bijections/finite/autoregressive/transformers/spline/linear_rational.py b/normalizing_flows/bijections/finite/autoregressive/transformers/spline/linear_rational.py
index 300dd9f..6e94ed1 100644
--- a/normalizing_flows/bijections/finite/autoregressive/transformers/spline/linear_rational.py
+++ b/normalizing_flows/bijections/finite/autoregressive/transformers/spline/linear_rational.py
@@ -20,6 +20,7 @@ def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], boundary: fl
         self.min_bin_height = 1e-3
         self.min_d = 1e-5
         self.const = math.log(math.exp(1 - self.min_d) - 1)  # to ensure identity initialization
+        self.eps = 1e-7  # Epsilon for numerical stability when computing forward/inverse
 
     @property
     def n_parameters(self) -> int:
@@ -120,7 +121,7 @@ def forward_1d(self, x, h):
         )
         log_det_phi_lt_lambda = (
                 torch.log(lambda_k * w_k * w_m * (y_m - y_k))
-                - 2 * torch.log(w_k * (lambda_k - phi) + w_m * phi)
+                - torch.log((w_k * (lambda_k - phi) + w_m * phi) ** 2 + self.eps)
                 - torch.log(x_kp1 - x_k)
         )
 
@@ -130,7 +131,7 @@ def forward_1d(self, x, h):
         )
         log_det_phi_gt_lambda = (
                 torch.log((1 - lambda_k) * w_m * w_kp1 * (y_kp1 - y_m))
-                - 2 * torch.log(w_m * (1 - phi) + w_kp1 * (phi - lambda_k))
+                - torch.log((w_m * (1 - phi) + w_kp1 * (phi - lambda_k)) ** 2 + self.eps)
                 - torch.log(x_kp1 - x_k)
         )
 
@@ -166,7 +167,7 @@ def inverse_1d(self, z, h):
         ) * (x_kp1 - x_k) + x_k
         log_det_y_lt_ym = (
                 torch.log(lambda_k * w_k * w_m * (y_m - y_k))
-                - torch.log((w_k * (y_k - z) + w_m * (z - y_m)) ** 2)
+                - torch.log((w_k * (y_k - z) + w_m * (z - y_m)) ** 2 + self.eps)
                 + torch.log(x_kp1 - x_k)
         )
 
@@ -176,7 +177,7 @@ def inverse_1d(self, z, h):
         ) * (x_kp1 - x_k) + x_k
         log_det_y_gt_ym = (
                 torch.log((1 - lambda_k) * w_m * w_kp1 * (y_kp1 - y_m))
-                - 2 * torch.log(w_kp1 * (y_kp1 - z) + w_m * (z - y_m))
+                - torch.log((w_kp1 * (y_kp1 - z) + w_m * (z - y_m)) ** 2 + self.eps)
                 + torch.log(x_kp1 - x_k)
         )
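
Note (not part of the patch): a minimal sketch of why the rewrite helps, using made-up tensor values. In float32, the old form 2 * torch.log(d) yields -inf if the linear rational spline denominator d underflows to zero, and NaN if roundoff pushes it slightly negative. Squaring inside the log keeps the argument non-negative, and the epsilon bounds the log-determinant from below.

    import torch

    # Denominator values the spline terms might produce near bin edges:
    # exactly zero, slightly negative from roundoff, and a typical value.
    d = torch.tensor([0.0, -1e-8, 0.5])
    eps = 1e-7  # same constant as self.eps in the patch

    old = 2 * torch.log(d)          # tensor([   -inf,     nan, -1.3863])
    new = torch.log(d ** 2 + eps)   # tensor([-16.1181, -16.1181, -1.3863])

The tradeoff is a small bias in the log-determinant, log(1 + eps / d**2) per term, which is negligible whenever d**2 is large relative to eps.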