From 0cb76987de184cc0b8a2b9a48e0c59c18574972f Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Sun, 17 Nov 2024 13:40:05 +0100 Subject: [PATCH] Handle device in Sylvester flows --- .../bijections/finite/residual/sylvester.py | 20 +++++++++---------- torchflows/bijections/matrices.py | 9 ++++++--- 2 files changed, 16 insertions(+), 13 deletions(-) diff --git a/torchflows/bijections/finite/residual/sylvester.py b/torchflows/bijections/finite/residual/sylvester.py index fa43171..d15c00a 100644 --- a/torchflows/bijections/finite/residual/sylvester.py +++ b/torchflows/bijections/finite/residual/sylvester.py @@ -26,8 +26,8 @@ def __init__(self, self.b = nn.Parameter(torch.randn(m)) # q is implemented in subclasses - self.r = UpperTriangularInvertibleMatrix(n_dim=self.m) - self.r_tilde = UpperTriangularInvertibleMatrix(n_dim=self.m) + self.register_module('r', UpperTriangularInvertibleMatrix(n_dim=self.m)) + self.register_module('r_tilde', UpperTriangularInvertibleMatrix(n_dim=self.m)) @property def w(self): @@ -53,9 +53,9 @@ def forward(self, x: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch. def inverse(self, z: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]: batch_shape = get_batch_shape(z, self.event_shape) z_flat = torch.flatten(z, start_dim=len(batch_shape)) - u = self.u.view(*([1] * len(batch_shape)), *self.u.shape) - w = self.w.view(*([1] * len(batch_shape)), *self.w.shape) - b = self.b.view(*([1] * len(batch_shape)), *self.b.shape) + u = self.u.view(*([1] * len(batch_shape)), *self.u.shape).to(z) + w = self.w.view(*([1] * len(batch_shape)), *self.w.shape).to(z) + b = self.b.view(*([1] * len(batch_shape)), *self.b.shape).to(z) wzpb = torch.einsum('...ij,...j->...i', w, z_flat) + b # (..., m) @@ -66,9 +66,9 @@ def inverse(self, z: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch. ) wu = torch.einsum('...ij,...jk->...ik', w, u) # (..., m, m) - diag = torch.zeros(size=(*batch_shape, self.m, self.m)) + diag = torch.zeros(size=(*batch_shape, self.m, self.m)).to(z) diag[..., range(self.m), range(self.m)] = self.h_deriv(wzpb) # (..., m, m) - _, log_det = torch.linalg.slogdet(torch.eye(self.m) + torch.einsum('...ij,...jk->...ik', diag, wu)) + _, log_det = torch.linalg.slogdet(torch.eye(self.m).to(z) + torch.einsum('...ij,...jk->...ik', diag, wu)) x = x.view(*batch_shape, *self.event_shape) @@ -78,13 +78,13 @@ def inverse(self, z: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch. class HouseholderSylvester(BaseSylvester): def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], **kwargs): super().__init__(event_shape, **kwargs) - self.q = HouseholderOrthogonalMatrix(n_dim=self.n_dim, n_factors=self.m) + self.register_module('q', HouseholderOrthogonalMatrix(n_dim=self.n_dim, n_factors=self.m)) class IdentitySylvester(BaseSylvester): def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], **kwargs): super().__init__(event_shape, **kwargs) - self.q = IdentityMatrix(n_dim=self.n_dim) + self.register_module('q', IdentityMatrix(n_dim=self.n_dim)) Sylvester = IdentitySylvester @@ -93,4 +93,4 @@ def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], **kwargs): class PermutationSylvester(BaseSylvester): def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], **kwargs): super().__init__(event_shape, **kwargs) - self.q = PermutationMatrix(n_dim=self.n_dim) + self.register_module('q', PermutationMatrix(n_dim=self.n_dim)) diff --git a/torchflows/bijections/matrices.py b/torchflows/bijections/matrices.py index 6053d68..4c0c980 100644 --- a/torchflows/bijections/matrices.py +++ b/torchflows/bijections/matrices.py @@ -47,8 +47,10 @@ def __init__(self, n_dim: int, unitriangular: bool = False, min_eigval: float = self.off_diagonal_indices = torch.tril_indices(self.n_dim, self.n_dim, -1) self.min_eigval = min_eigval + self.register_buffer('mat_zeros', torch.zeros(size=(self.n_dim, self.n_dim))) + def mat(self): - mat = torch.zeros(size=(self.n_dim, self.n_dim)) + mat = self.mat_zeros mat[range(self.n_dim), range(self.n_dim)] = self.compute_diagonal_elements() mat[self.off_diagonal_indices[0], self.off_diagonal_indices[1]] = self.off_diagonal_elements return mat @@ -94,7 +96,7 @@ def __init__(self, n_dim: int, n_factors: int = None): def mat(self): v_outer = torch.einsum('fi,fj->fij', self.v, self.v) v_norms_squared = torch.linalg.norm(self.v, dim=1).view(-1, 1, 1) ** 2 - h = (torch.eye(self.n_dim)[None] - 2 * (v_outer / v_norms_squared)) + h = (torch.eye(self.n_dim)[None].to(v_outer) - 2 * (v_outer / v_norms_squared)) return torch.linalg.multi_dot(list(h)) def log_det(self): @@ -107,9 +109,10 @@ def solve(self, z): class IdentityMatrix(InvertibleMatrix): def __init__(self, n_dim: int, **kwargs): super().__init__(n_dim, **kwargs) + self.register_buffer('_mat', torch.eye(self.n_dim)) def mat(self): - return torch.eye(self.n_dim) + return self._mat def log_det(self): return 0.0