Skip to content

Commit

Permalink
New ZCA, torch orthogonal_procrustes, and orthogonalize_nearest funct…
Browse files Browse the repository at this point in the history
…ions
  • Loading branch information
RichieHakim committed Apr 13, 2024
1 parent d742add commit 90f0e31
Show file tree
Hide file tree
Showing 3 changed files with 149 additions and 27 deletions.
71 changes: 64 additions & 7 deletions bnpm/decomposition.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,20 @@
from typing import Tuple, Optional
from typing import Tuple, Optional, Union
import math
import gc
import copy

import sklearn.decomposition
import numpy as np
import scipy.interpolate
import matplotlib.pyplot as plt

import torch
from torch.utils.data import Dataset, DataLoader

import copy
from tqdm.notebook import tqdm

# import cuml
# import cuml.decomposition
# import cupy

import gc

from tqdm.notebook import tqdm

###########################
########## PCA ############
Expand Down Expand Up @@ -597,6 +594,66 @@ def dimensionality_pca(
return interp(ev)


def ZCA_whiten(
X: Union[np.ndarray, torch.Tensor],
V: Union[np.ndarray, torch.Tensor],
S: Union[np.ndarray, torch.Tensor],
eps: float = 1e-5,
):
"""
ZCA whitening of data.
See: https://jermwatt.github.io/control-notes/posts/zca_sphereing/ZCA_Sphereing.html
RH 2024
Args:
X (np.ndarray or torch.Tensor):
Data to be whitened. \n
Shape (n_samples, n_features).
V (np.ndarray or torch.Tensor):
The principal components / eigenvectors. \n
You can use PCA.components_ from sklearn.decomposition.PCA or
PCA.components_ above. \n
Shape (n_features, n_components).
S (np.ndarray or torch.Tensor):
The singular values / eigenvalues. \n
You can use PCA.singular_values_ from sklearn.decomposition.PCA or
PCA.singular_values_ above. \n
Shape (n_components,).
eps (float):
Small value to prevent division by zero.
Returns:
X_zca (np.ndarray or torch.Tensor):
The ZCA whitened data. \n
Shape (n_samples, n_features).
Demo:
..code-block:: python
X = np.random.randn(100, 10)
pca = PCA(n_components=5)
pca.fit(X)
X_zca = ZCA_whiten(
X=X,
V=pca.components_,
S=pca.singular_values_,
eps=1e-5,
)
"""
if isinstance(X, np.ndarray):
mean, sqrt, diag = np.mean, np.sqrt, np.diag
elif isinstance(X, torch.Tensor):
mean, sqrt, diag = torch.mean, torch.sqrt, torch.diag

X = X - mean(X, axis=0, keepdims=True)

D_inv = diag(1.0 / (S + eps))
W_zca = V.T @ D_inv @ V
X_zca = X @ W_zca

return X_zca


#######################################
########## Incremental PCA ############
#######################################
Expand Down
54 changes: 34 additions & 20 deletions bnpm/similarity.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,15 @@
'''
Table of Contents
Functions and Interdependencies:
proj
orthogonalize
- proj
OLS
EV
pairwise_similarity
best_permutation
- pairwise_similarity
self_similarity_pairwise
- best_permutation
'''
import copy
import time
from functools import partial

import numpy as np
import scipy.optimize
# import sklearn.decomposition
from numba import njit, prange, jit
import torch
from tqdm import tqdm

from . import indexing
from . import indexing, torch_helpers

import copy
import time
from functools import partial

def proj(v1, v2):
'''
Expand Down Expand Up @@ -249,6 +234,35 @@ def orthogonalize(v1, v2, method='OLS', device='cpu', thresh_EVR_PCA=1e-15):
return v1_orth, EVR, EVR_total, pca_dict


def orthogonalize_matrix_nearest(X: Union[np.ndarray, torch.Tensor]):
"""
Orthogonalizes a matrix by finding the nearest orthogonal matrix. Nearest is
defined as solving the Procrustes problem via minimizing the Frobenius norm
of the difference between the original input matrix and the orthogonal
matrix. \n
X_orth = argmin ||X - X_orth||_F
Note: The solution to this problem is generally equivalent to the ZCA
whitening of PCA(X).
Args:
X (ndarray):
Matrix to orthogonalize. shape: (n_samples, n_features)
Returns:
(ndarray):
X_orth: orthogonalized matrix. shape: (n_samples, n_features)
"""
if isinstance(X, np.ndarray):
op, qr = scipy.linalg.orthogonal_procrustes, np.linalg.qr
elif isinstance(X, torch.Tensor):
op, qr = torch_helpers.orthogonal_procrustes, torch.linalg.qr

Q = qr(X)[0]
w, scale = op(Q, X)
return w @ scale


@njit
def pair_orth_helper(v1, v2):
"""
Expand Down
51 changes: 51 additions & 0 deletions bnpm/torch_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -895,6 +895,57 @@ def slice_along_dim(
return X[tuple(slices)]


def orthogonal_procrustes(
A: torch.Tensor,
B: torch.Tensor,
check_finite: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Port of the scipy.linalg.orthogonal_procrustes function:
https://github.com/scipy/scipy/blob/v1.13.0/scipy/linalg/_procrustes.py
Computes the matrix solution of the orthogonal Procrustes problem.
Given two matrices, A and B, find the orthogonal matrix that most closely
maps A to B using the algorithm in [1].
Args:
A (torch.Tensor):
The input matrix.
B (torch.Tensor):
The target matrix.
check_finite (bool):
Whether to check that the input matrices contain only finite
numbers. Disabling may give a performance gain, but may result in
problems (crashes, non-termination) if the inputs do contain infinities
or NaNs. (Default is ``True``)
Returns:
(Tuple[torch.Tensor, torch.Tensor]):
(R, scale):
R (torch.Tensor):
The matrix solution of the orthogonal Procrustes problem.
Minimizes the Frobenius norm of ``(A @ R) - B``, subject to
``R.T @ R = I``.
scale (torch.Tensor):
Sum of the singular values of ``A.T @ B``.
References:
[1] Peter H. Schonemann, "A generalized solution of the orthogonal
Procrustes problem", Psychometrica -- Vol. 31, No. 1, March, 1966.
:doi:`10.1007/BF02289451`
"""
if check_finite:
if not torch.isfinite(A).all() or not torch.isfinite(B).all():
raise ValueError("Input contains non-finite values.")
assert A.shape == B.shape, 'Input matrices must have the same shape.'
assert A.ndim == 2, 'Input matrices must be 2D.'

U, S, V = torch.linalg.svd((B.T @ A).T, full_matrices=False)
R = U @ V
scale = S.sum()
return R, scale


#########################################################
############ INTRA-MODULE HELPER FUNCTIONS ##############
#########################################################
Expand Down

0 comments on commit 90f0e31

Please sign in to comment.