Skip to content

Commit

Permalink
Fix bug in cat_rows which prevented efficient root inversion
Browse files Browse the repository at this point in the history
  • Loading branch information
naefjo committed Jan 23, 2024
1 parent 574d2c5 commit 6ffd165
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions linear_operator/operators/_linear_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1267,17 +1267,17 @@ def cat_rows(
R = self.root_inv_decomposition().root.to_dense() # RR^T = A^{-1} (this is fast if L is triangular)
lower_left = B_ @ R # F = BR
schur = D - lower_left.matmul(lower_left.mT) # GG^T = new_mat - FF^T
schur_root = to_linear_operator(schur).root_decomposition().root.to_dense() # G = (new_mat - FF^T)^{1/2}
schur_root = to_linear_operator(schur).root_decomposition().root # G = (new_mat - FF^T)^{1/2}

# Form new root matrix
num_fant = schur_root.size(-2)
new_root = torch.zeros(*batch_shape, m + num_fant, n + num_fant, device=E.device, dtype=E.dtype)
new_root[..., :m, :n] = E.to_dense()
new_root[..., m:, : lower_left.shape[-1]] = lower_left
new_root[..., m:, n : (n + schur_root.shape[-1])] = schur_root
new_root[..., m:, n : (n + schur_root.shape[-1])] = schur_root.to_dense()
if generate_inv_roots:
if isinstance(E, TriangularLinearOperator) and isinstance(schur_root, TriangularLinearOperator):
# make sure these are actually upper triangular
# make sure these are actually lower triangular
if getattr(E, "upper", False) or getattr(schur_root, "upper", False):
raise NotImplementedError
# in this case we know new_root is triangular as well
Expand Down

0 comments on commit 6ffd165

Please sign in to comment.