Skip to content

Commit

Permalink
Refactor orthogonalize_matrix_nearest function to include center para…
Browse files Browse the repository at this point in the history
…meter
  • Loading branch information
RichieHakim committed Apr 13, 2024
1 parent ae5b03e commit fe18b9a
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion bnpm/similarity.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ def orthogonalize(v1, v2, method='OLS', device='cpu', thresh_EVR_PCA=1e-15):
return v1_orth, EVR, EVR_total, pca_dict


def orthogonalize_matrix_nearest(X: Union[np.ndarray, torch.Tensor]):
def orthogonalize_matrix_nearest(X: Union[np.ndarray, torch.Tensor], center=True):
"""
Orthogonalizes a matrix by finding the nearest orthogonal matrix. Nearest is
defined as solving the Procrustes problem via minimizing the Frobenius norm
Expand All @@ -248,6 +248,8 @@ def orthogonalize_matrix_nearest(X: Union[np.ndarray, torch.Tensor]):
Args:
X (ndarray):
Matrix to orthogonalize. shape: (n_samples, n_features)
center (bool):
Whether to center the matrix. Default is True.
Returns:
(ndarray):
Expand All @@ -258,6 +260,9 @@ def orthogonalize_matrix_nearest(X: Union[np.ndarray, torch.Tensor]):
elif isinstance(X, torch.Tensor):
op, qr = torch_helpers.orthogonal_procrustes, torch.linalg.qr

if center:
X = X - X.mean(axis=0, keepdims=True)

Q = qr(X)[0]
w, scale = op(Q, X)
return w @ scale
Expand Down

0 comments on commit fe18b9a

Please sign in to comment.