diff --git a/torchflows/bijections/finite/residual/radial.py b/torchflows/bijections/finite/residual/radial.py index c667e30..4351ea4 100644 --- a/torchflows/bijections/finite/residual/radial.py +++ b/torchflows/bijections/finite/residual/radial.py @@ -23,18 +23,6 @@ def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], **kwargs): def alpha(self): return softplus(self.unconstrained_alpha) - def h(self, z): - batch_shape = z.shape[:-1] - z0 = self.z0.view(*([1] * len(batch_shape)), *self.z0.shape) - r = torch.sqrt(torch.square(z - z0)) - return 1 / (self.alpha + r + self.eps) - - def h_deriv(self, z): - batch_shape = z.shape[:-1] - z0 = self.z0.view(*([1] * len(batch_shape)), *self.z0.shape) - sign = (-1.0) ** torch.less(z, z0).float() - return -(self.h(z) ** 2) * sign * z - def forward(self, x: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]: raise NotImplementedError