Skip to content

Commit

Permalink
New PCA implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
RichieHakim committed Mar 30, 2024
1 parent b2f9d72 commit cf04d21
Show file tree
Hide file tree
Showing 2 changed files with 475 additions and 0 deletions.
311 changes: 311 additions & 0 deletions bnpm/decomposition.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from typing import Tuple, Optional
import math

import sklearn.decomposition
import numpy as np
import scipy.interpolate
Expand Down Expand Up @@ -100,6 +103,314 @@ def simple_pca(X , n_components=None , mean_sub=True, zscore=False, plot_pref=Fa
return components , scores , decomp.explained_variance_ratio_


def svd_flip(
    u: torch.Tensor,
    v: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Sign correction to ensure deterministic output from SVD.
    The output from SVD does not have a unique sign. This function flips the
    sign of each singular-vector pair so that, in every column of ``u``, the
    entry with the largest absolute value is positive.
    RH 2024

    Note:
        ``u`` and ``v`` are modified IN PLACE (and also returned). If a numpy
        array is passed, it is wrapped with ``torch.as_tensor`` (which shares
        memory with the array where possible), so the array may be modified
        too.

    Args:
        u (torch.Tensor):
            The left singular vectors. Shape (n_samples, n_components);
            columns are the vectors.
        v (torch.Tensor):
            The right singular vectors, transposed. Shape (n_components,
            n_features); rows are the vectors.

    Returns:
        (Tuple[torch.Tensor, torch.Tensor]):
            u (torch.Tensor):
                The corrected left singular vectors.
            v (torch.Tensor):
                The corrected right singular vectors.
    """
    if isinstance(u, np.ndarray):
        u = torch.as_tensor(u)
    if isinstance(v, np.ndarray):
        v = torch.as_tensor(v)

    ## Row index of the largest-magnitude entry in each column of u.
    max_abs_rows = torch.argmax(torch.abs(u), dim=0)
    cols = torch.arange(u.shape[1], device=u.device)
    signs = torch.sign(u[max_abs_rows, cols])
    ## torch.sign gives 0 for an all-zero column. Multiplying by 0 would wipe
    ##  out the matching row of v, so treat 'no sign' as +1 (leave untouched).
    signs = torch.where(signs == 0, torch.ones_like(signs), signs)

    u *= signs
    v *= signs.unsqueeze(-1)
    return u, v


class PCA(torch.nn.Module, sklearn.base.BaseEstimator, sklearn.base.TransformerMixin):
    """
    Principal Component Analysis (PCA) module.
    This module performs PCA on the input data and returns the principal
    components and the explained variance. The PCA is performed using the
    singular value decomposition (SVD) method. This class follows sklearn's PCA
    implementation and style and is subclassed from torch.nn.Module,
    sklearn.base.BaseEstimator, and sklearn.base.TransformerMixin. The
    decomposed variables (components_, explained_variance_, etc.) are stored as
    buffers so that they are stored in the state_dict, respond to .to() and
    .cuda(), and are saved when the model is saved.
    RH 2024

    Args:
        n_components (Optional[int]):
            Number of principal components to retain. If ``None``, all
            components are retained.
        center (bool):
            If ``True``, the data is mean-subtracted before performing SVD.
        zscale (bool):
            If ``True``, the data is z-scored before performing SVD. Equivalent
            of doing eigenvalue decomposition on the correlation matrix.
        whiten (bool):
            If ``True``, the principal components are divided by the square root
            of the explained variance.
        use_lowRank (bool):
            If ``True``, the low-rank SVD is used. This is faster but less
            accurate. Uses torch.svd_lowrank instead of torch.linalg.svd.
        lowRank_niter (int):
            Number of subspace iterations for low-rank SVD. See
            torch.svd_lowrank for more details.

    Attributes:
        n_components_ (int):
            Number of components actually retained after fitting. Bounded by
            min(n_samples, n_features).
        n_samples_ (int):
            Number of samples seen during fit.
        n_features_ (int):
            Number of features seen during fit.
        mean_ (torch.Tensor):
            Per-feature mean of the training data. Only set if ``center``.
        std_ (torch.Tensor):
            Per-feature standard deviation of the training data. Only set if
            ``zscale``.
        components_ (torch.Tensor):
            The principal components. Shape (n_components_, n_features_).
        singular_values_ (torch.Tensor):
            The singular values of the retained components.
        explained_variance_ (torch.Tensor):
            The explained variance of the retained components.
        explained_variance_ratio_ (torch.Tensor):
            The fraction of the total variance explained by each retained
            component.

    Example:
        .. highlight:: python
        .. code-block:: python

            X = torch.randn(100, 10)
            pca = PCA(n_components=5)
            pca.fit(X)
            X_pca = pca.transform(X)
    """
    def __init__(
        self,
        n_components: Optional[int] = None,
        center: bool = True,
        zscale: bool = False,
        whiten: bool = False,
        use_lowRank: bool = False,
        lowRank_niter: int = 2,
    ):
        """
        Initializes the PCA module with the provided parameters.
        """
        super(PCA, self).__init__()
        self.n_components = n_components
        self.center = center
        self.zscale = zscale
        self.whiten = whiten
        self.use_lowRank = use_lowRank
        self.lowRank_niter = lowRank_niter

    def prepare_input(
        self,
        X: torch.Tensor,
        center: bool,
        zscale: bool,
        use_stored: bool = False,
    ) -> torch.Tensor:
        """
        Prepares the input data for PCA.

        Args:
            X (torch.Tensor):
                The input data to prepare. numpy arrays are converted to
                tensors and 1D inputs are treated as a single feature column.
            center (bool):
                If ``True``, the data is mean-subtracted.
            zscale (bool):
                If ``True``, the data is z-scored.
            use_stored (bool):
                If ``False`` (default), the mean / std are computed from ``X``
                and stored as the ``mean_`` / ``std_`` buffers (fitting
                behavior). If ``True``, the previously stored ``mean_`` /
                ``std_`` are applied and left untouched (transform behavior).

        Returns:
            (torch.Tensor):
                The prepared input data.
        """
        if isinstance(X, np.ndarray):
            X = torch.as_tensor(X)
        assert isinstance(X, torch.Tensor), 'Input must be a torch.Tensor.'
        X = X[:, None] if X.ndim == 1 else X
        assert X.ndim == 2, 'Input must be 2D.'

        if center:
            if use_stored:
                assert hasattr(self, 'mean_'), 'self.center is True, but mean_ is not found.'
                mean_ = self.mean_
            else:
                mean_ = torch.mean(X, dim=0)
                self.register_buffer('mean_', mean_)
            X = X - mean_
        if zscale:
            if use_stored:
                assert hasattr(self, 'std_'), 'self.zscale is True, but std_ is not found.'
                std_ = self.std_
            else:
                std_ = torch.std(X, dim=0)
                self.register_buffer('std_', std_)
            X = X / std_
        return X

    def fit(
        self,
        X: torch.Tensor,
        y: Optional[torch.Tensor] = None,
    ) -> "PCA":
        """
        Fits the PCA module to the input data.

        Args:
            X (torch.Tensor):
                The input data to fit the PCA module to. Should be shape
                (n_samples, n_features).
            y (Optional[torch.Tensor]):
                Ignored. This parameter exists to match the sklearn API.

        Returns:
            self (PCA object):
                Returns the PCA object.
        """
        self._fit(X)
        return self

    def _fit(
        self,
        X: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Fits the PCA module to the input data.

        Args:
            X (torch.Tensor):
                The input data to fit the PCA module to. Should be shape
                (n_samples, n_features).

        Returns:
            (Tuple[torch.Tensor, torch.Tensor, torch.Tensor]):
                U (torch.Tensor):
                    The sign-corrected left singular vectors, one per column.
                    Not truncated to n_components_ when the full SVD is used.
                S (torch.Tensor):
                    The singular values (not truncated for the full SVD).
                Vh (torch.Tensor):
                    The sign-corrected right singular vectors, one per row
                    (i.e. already transposed). Second dimension is n_features.
        """
        ## Prepare first so that numpy / 1D inputs have a valid 2D shape below.
        X = self.prepare_input(X, center=self.center, zscale=self.zscale)
        self.n_samples_, self.n_features_ = X.shape

        ## The number of available components is bounded by the rank of X.
        max_components = min(self.n_samples_, self.n_features_)
        if self.n_components is None:
            self.n_components_ = max_components
        else:
            self.n_components_ = min(self.n_components, max_components)

        if self.use_lowRank:
            ## torch.svd_lowrank returns V (n_features, q), not Vh.
            U, S, V = torch.svd_lowrank(X, q=self.n_components_, niter=self.lowRank_niter)
            Vh = V.T
            ## Only q singular values are available here, so the total variance
            ##  must come from the data itself (sum of all S**2 == ||X||_F**2).
            total_variance = torch.sum(X ** 2) / (self.n_samples_ - 1)
        else:
            ## Vh is returned already transposed: rows are right singular vectors.
            U, S, Vh = torch.linalg.svd(X, full_matrices=False)
            total_variance = None
        U, Vh = svd_flip(U, Vh)

        explained_variance_ = S**2 / (self.n_samples_ - 1)
        if total_variance is None:
            total_variance = torch.sum(explained_variance_)
        explained_variance_ratio_ = explained_variance_ / total_variance

        fitted = {
            'components_': Vh[:self.n_components_],
            'singular_values_': S[:self.n_components_],
            'explained_variance_': explained_variance_[:self.n_components_],
            'explained_variance_ratio_': explained_variance_ratio_[:self.n_components_],
        }
        for name, value in fitted.items():
            self.register_buffer(name, value)

        return U, S, Vh

    def transform(
        self,
        X: torch.Tensor,
        y: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Transforms the input data using the fitted PCA module.
        The centering / scaling statistics learned during fit (``mean_`` and
        ``std_``) are applied; they are not re-estimated from ``X``.

        Args:
            X (torch.Tensor):
                The input data to transform.
            y (Optional[torch.Tensor]):
                Ignored. This parameter exists to match the sklearn API.

        Returns:
            (torch.Tensor):
                The transformed data. Will be shape (n_samples, n_components).
        """
        assert hasattr(self, 'components_'), 'PCA module must be fitted before transforming data.'
        X = self.prepare_input(X, center=self.center, zscale=self.zscale, use_stored=True)
        X_transformed = X @ self.components_.T
        if self.whiten:
            X_transformed /= torch.sqrt(self.explained_variance_)
        return X_transformed

    def fit_transform(
        self,
        X: torch.Tensor,
        y: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Fits the PCA module to the input data and transforms the input data.

        Args:
            X (torch.Tensor):
                The input data to fit the PCA module to and transform.
            y (Optional[torch.Tensor]):
                Ignored. This parameter exists to match the sklearn API.

        Returns:
            (torch.Tensor):
                The transformed data.
        """
        U, S, _ = self._fit(X)
        U = U[:, :self.n_components_]

        if self.whiten:
            ## X @ V / sqrt(explained_variance) == U * sqrt(n_samples - 1)
            U *= math.sqrt(self.n_samples_ - 1)
        else:
            ## X @ V == U * S
            U *= S[:self.n_components_]

        return U

    def inverse_transform(
        self,
        X: torch.Tensor
    ) -> torch.Tensor:
        """
        Inverse transforms the input data using the fitted PCA module.

        Args:
            X (torch.Tensor):
                The input data to inverse transform. Should be shape (n_samples,
                n_components).

        Returns:
            (torch.Tensor):
                The inverse transformed data. Will be shape (n_samples,
                n_features).
        """
        assert hasattr(self, 'components_'), 'PCA module must be fitted before transforming data.'
        X = self.prepare_input(X, center=False, zscale=False)

        if self.whiten:
            ## Scale each component (row) by its own standard deviation. The
            ##  [:, None] makes the (n_components,) vector broadcast over rows
            ##  rather than over the feature axis.
            scaled_components = torch.sqrt(self.explained_variance_)[:, None] * self.components_
        else:
            scaled_components = self.components_

        X = X @ scaled_components

        if self.zscale:
            assert hasattr(self, 'std_'), 'self.zscale is True, but std_ is not found.'
            X = X * self.std_
        if self.center:
            assert hasattr(self, 'mean_'), 'self.center is True, but mean_ is not found.'
            X = X + self.mean_

        return X


def torch_pca(
X_in,
device='cpu',
Expand Down
Loading

0 comments on commit cf04d21

Please sign in to comment.