diff --git a/linear_operator/operators/_linear_operator.py b/linear_operator/operators/_linear_operator.py index b3243571..3144fa76 100644 --- a/linear_operator/operators/_linear_operator.py +++ b/linear_operator/operators/_linear_operator.py @@ -1267,17 +1267,17 @@ def cat_rows( R = self.root_inv_decomposition().root.to_dense() # RR^T = A^{-1} (this is fast if L is triangular) lower_left = B_ @ R # F = BR schur = D - lower_left.matmul(lower_left.mT) # GG^T = new_mat - FF^T - schur_root = to_linear_operator(schur).root_decomposition().root.to_dense() # G = (new_mat - FF^T)^{1/2} + schur_root = to_linear_operator(schur).root_decomposition().root # G = (new_mat - FF^T)^{1/2} # Form new root matrix num_fant = schur_root.size(-2) new_root = torch.zeros(*batch_shape, m + num_fant, n + num_fant, device=E.device, dtype=E.dtype) new_root[..., :m, :n] = E.to_dense() new_root[..., m:, : lower_left.shape[-1]] = lower_left - new_root[..., m:, n : (n + schur_root.shape[-1])] = schur_root + new_root[..., m:, n : (n + schur_root.shape[-1])] = schur_root.to_dense() if generate_inv_roots: if isinstance(E, TriangularLinearOperator) and isinstance(schur_root, TriangularLinearOperator): - # make sure these are actually upper triangular + # make sure these are actually lower triangular if getattr(E, "upper", False) or getattr(schur_root, "upper", False): raise NotImplementedError # in this case we know new_root is triangular as well