Skip to content

Commit

Permalink
Refactor orthogonalize_matrix_nearest function to include center para…
Browse files Browse the repository at this point in the history
…meter and improve code readability
  • Loading branch information
RichieHakim committed Apr 14, 2024
1 parent f2c6f11 commit f3007b0
Showing 1 changed file with 16 additions and 10 deletions.
26 changes: 16 additions & 10 deletions bnpm/similarity.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,11 +237,15 @@ def orthogonalize(v1, v2, method='OLS', device='cpu', thresh_EVR_PCA=1e-15):

def orthogonalize_matrix_nearest(X: Union[np.ndarray, torch.Tensor], center=True):
"""
Orthogonalizes a matrix by finding the nearest orthogonal matrix. Nearest is
defined as solving the Procrustes problem via minimizing the Frobenius norm
of the difference between the original input matrix and the orthogonal
matrix. \n
X_orth = argmin ||X - X_orth||_F
Orthogonalizes a matrix by finding the nearest orthonormal matrix to an
input ``X``. 'Nearest' is defined as solving the orthogonal Procrustes
problem via minimizing the Frobenius norm of the difference between the
original input matrix and a orthonormal matrix that spans the same
dimensions as ``X``. For the initial orthonormal matrix, we use a centered
``Q`` from a QR decomposition of ``X``. \n
``X_orth = argmin: X_orth for ||X - X_orth||_F`` , where ``X_orth = Q @ R`` , where
``R`` is a rotation matrix derived from solving the orthogonal Procrustes
problem for ``R = argmin: R for ||X - Q @ R||_F``. \n
RH 2024
Expand All @@ -256,19 +260,21 @@ def orthogonalize_matrix_nearest(X: Union[np.ndarray, torch.Tensor], center=True
Returns:
(ndarray):
X_orth: orthogonalized matrix. shape: (n_samples, n_features)
X_orth:
Orthogonalized matrix. shape: (n_samples, n_features)
"""
if isinstance(X, np.ndarray):
op, qr = scipy.linalg.orthogonal_procrustes, np.linalg.qr
op, qr, svd = scipy.linalg.orthogonal_procrustes, np.linalg.qr, np.linalg.svd
elif isinstance(X, torch.Tensor):
op, qr = torch_helpers.orthogonal_procrustes, torch.linalg.qr
op, qr, svd = torch_helpers.orthogonal_procrustes, torch.linalg.qr, torch.linalg.svd

if center:
X = X - X.mean(axis=0, keepdims=True)

Q = qr(X)[0]
w, scale = op(Q, X)
return Q @ w
Q -= Q.mean(axis=0, keepdims=True)
R, scale = op(Q, X)
return Q @ R


@njit
Expand Down

0 comments on commit f3007b0

Please sign in to comment.