diff --git a/bnpm/similarity.py b/bnpm/similarity.py index da0a8dd..0f30a01 100644 --- a/bnpm/similarity.py +++ b/bnpm/similarity.py @@ -234,7 +234,7 @@ def orthogonalize(v1, v2, method='OLS', device='cpu', thresh_EVR_PCA=1e-15): return v1_orth, EVR, EVR_total, pca_dict -def orthogonalize_matrix_nearest(X: Union[np.ndarray, torch.Tensor]): +def orthogonalize_matrix_nearest(X: Union[np.ndarray, torch.Tensor], center=True): """ Orthogonalizes a matrix by finding the nearest orthogonal matrix. Nearest is defined as solving the Procrustes problem via minimizing the Frobenius norm @@ -248,6 +248,8 @@ def orthogonalize_matrix_nearest(X: Union[np.ndarray, torch.Tensor]): Args: X (ndarray): Matrix to orthogonalize. shape: (n_samples, n_features) + center (bool): + Whether to center the matrix. Default is True. Returns: (ndarray): @@ -258,6 +260,9 @@ def orthogonalize_matrix_nearest(X: Union[np.ndarray, torch.Tensor]): elif isinstance(X, torch.Tensor): op, qr = torch_helpers.orthogonal_procrustes, torch.linalg.qr + if center: + X = X - X.mean(axis=0, keepdims=True) + Q = qr(X)[0] w, scale = op(Q, X) return w @ scale