diff --git a/bnpm/similarity.py b/bnpm/similarity.py index 2df2ebe..48cf3cb 100644 --- a/bnpm/similarity.py +++ b/bnpm/similarity.py @@ -237,11 +237,15 @@ def orthogonalize(v1, v2, method='OLS', device='cpu', thresh_EVR_PCA=1e-15): def orthogonalize_matrix_nearest(X: Union[np.ndarray, torch.Tensor], center=True): """ - Orthogonalizes a matrix by finding the nearest orthogonal matrix. Nearest is - defined as solving the Procrustes problem via minimizing the Frobenius norm - of the difference between the original input matrix and the orthogonal - matrix. \n - X_orth = argmin ||X - X_orth||_F + Orthogonalizes a matrix by finding the nearest orthonormal matrix to an + input ``X``. 'Nearest' is defined as solving the orthogonal Procrustes + problem via minimizing the Frobenius norm of the difference between the + original input matrix and an orthonormal matrix that spans the same + dimensions as ``X``. For the initial orthonormal matrix, we use a centered + ``Q`` from a QR decomposition of ``X``. \n + ``X_orth = argmin: X_orth for ||X - X_orth||_F``, where ``X_orth = Q @ R``, where + ``R`` is an orthogonal matrix derived from solving the orthogonal Procrustes + problem for ``R = argmin: R for ||X - Q @ R||_F``. \n RH 2024 @@ -256,19 +260,21 @@ def orthogonalize_matrix_nearest(X: Union[np.ndarray, torch.Tensor], center=True Returns: (ndarray): - X_orth: orthogonalized matrix. shape: (n_samples, n_features) + X_orth: + Orthogonalized matrix. 
shape: (n_samples, n_features) """ if isinstance(X, np.ndarray): - op, qr = scipy.linalg.orthogonal_procrustes, np.linalg.qr + op, qr, svd = scipy.linalg.orthogonal_procrustes, np.linalg.qr, np.linalg.svd elif isinstance(X, torch.Tensor): - op, qr = torch_helpers.orthogonal_procrustes, torch.linalg.qr + op, qr, svd = torch_helpers.orthogonal_procrustes, torch.linalg.qr, torch.linalg.svd if center: X = X - X.mean(axis=0, keepdims=True) Q = qr(X)[0] - w, scale = op(Q, X) - return Q @ w + Q -= Q.mean(axis=0, keepdims=True) + R, scale = op(Q, X) + return Q @ R @njit