New CP_NN_HALS_minibatch class
RichieHakim committed Jun 25, 2024
1 parent 8887245 commit d1f5c48
Showing 1 changed file with 180 additions and 0 deletions.
180 changes: 180 additions & 0 deletions bnpm/decomposition.py
@@ -10,11 +10,14 @@
import torch
from torch.utils.data import Dataset, DataLoader
from tqdm.notebook import tqdm
import tensorly as tl

# import cuml
# import cuml.decomposition
# import cupy

from . import similarity


###########################
########## PCA ############
@@ -654,6 +657,183 @@ def ZCA_whiten(
return X_zca


class CP_NN_HALS_minibatch:
"""
    Minibatch version of Tensorly's CP_NN_HALS. Optimization proceeds by
    randomly sampling minibatches of indices along the specified batching
    dimension; on each batch, the corresponding rows of the batching-mode
    factor are updated in place, while the remaining factors and the weights
    are refit on every batch.
    RH 2024

    Args:
kwargs_CP_NN_HALS (dict):
Keyword arguments to pass to Tensorly's CP_NN_HALS.
batch_dimension (int):
Dimension along which to batch the data (0 to n_dims-1).
batch_size (int):
                Number of samples per batch.
n_iter_batch (int):
Number of optimization iterations to perform on each batch.
n_epochs (int):
Number of epochs to perform.
device (str):
Device to transfer each batch to and where the factors will be
stored.
kwargs_dataloader (dict):
Keyword arguments to pass to torch.utils.data.DataLoader.
random_state (int):
Random seed for reproducibility.
verbose (bool):
Whether or not to print progress.
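
    Example:
        ## Illustrative sketch only: the tensor shape, rank, and batch
        ## settings below are arbitrary placeholders.
        X = torch.rand(1000, 50, 20)
        model = CP_NN_HALS_minibatch(
            kwargs_CP_NN_HALS={'rank': 5},
            batch_dimension=0,
            batch_size=100,
            n_epochs=3,
        )
        model.fit(X)
        factors = model.components_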
"""
def __init__(
self,
kwargs_CP_NN_HALS: dict = {},
batch_dimension: int = 0,
batch_size: int = 10,
n_iter_batch: int = 10,
n_epochs: int = 10,
device: str = 'cpu',
kwargs_dataloader: dict = {
'drop_last': True,
'shuffle': True,
# 'num_workers': 0,
# 'pin_memory': False,
# 'prefetch_factor': 2,
# 'persistent_workers': False,
},
random_state: int = 0,
verbose: bool = True,
):
"""
Initializes the CP_NN_HALS_minibatch module with the provided parameters.
"""
        self.kwargs_CP_NN_HALS = dict(kwargs_CP_NN_HALS)  ## copy; fit() pops keys from this dict
self.batch_dimension = int(batch_dimension)
self.batch_size = int(batch_size)
self.n_iter_batch = int(n_iter_batch)
self.n_epochs = int(n_epochs)
self.device = torch.device(device)
self.kwargs_dataloader = dict(kwargs_dataloader)
self.random_state = random_state
self.verbose = verbose

self.factors = None
self.losses = {}
        tl.set_backend('pytorch')  ## Tensorly must use the PyTorch backend to operate on torch tensors

def init_factors(self, X):
"""
Initializes the factors.
"""
        ## Defaults below are overridden by matching keys in self.kwargs_CP_NN_HALS
        kwargs_default = {
'rank': 10,
'init': 'svd',
'svd': 'truncated_svd',
'non_negative': True,
'random_state': 0,
'normalize_factors': False,
'mask': None,
'svd_mask_repeats': 5,
}
self.factors = tl.decomposition._cp.initialize_cp(
tensor=X.moveaxis(self.batch_dimension, 0),
**{k: self.kwargs_CP_NN_HALS.get(k, v) for k, v in kwargs_default.items()},
)
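        ## Move the initialized factors and weights onto the target device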
self.factors.factors = [f.to(self.device) for f in self.factors.factors]
self.factors.weights = self.factors.weights.to(self.device)
return self.factors

def fit(self, X):
"""
        Fits the CP_NN_HALS_minibatch module to the input data.

        Args:
X (torch.Tensor):
The input data to fit the CP_NN_HALS_minibatch module to.
Returns:
self (CP_NN_HALS_minibatch object):
Returns the CP_NN_HALS_minibatch object.
"""
        ## Initialize the factors on the first call to fit
        if self.factors is None:
            self.init_factors(X)

        ## Build a dataloader that pairs each slice along the batch dimension with
        ## its index, so updated rows of the first factor can be written back.
dataset = torch.utils.data.TensorDataset(
X.moveaxis(self.batch_dimension, 0),
torch.arange(X.shape[self.batch_dimension]),
)
        self.kwargs_dataloader.pop('batch_size', None)  ## batch_size is set explicitly below
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=self.batch_size,
**self.kwargs_dataloader,
)

        ## Remove keys that are set explicitly for each per-batch fit below
        self.kwargs_CP_NN_HALS.pop('init', None)
        self.kwargs_CP_NN_HALS.pop('n_iter_max', None)
## Fit
        for i_epoch in tqdm(range(self.n_epochs), desc='CP_NN_HALS_minibatch', leave=True, disable=not self.verbose):
for i_iter, (X_batch, idx_batch) in enumerate(dataloader):
X_batch = X_batch.to(self.device)
idx_batch = idx_batch.to(self.device)

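                ## Warm-start from the current factors; slice this batch's rows out of
                ## the batching-mode factor so the per-batch fit only touches those rows.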
factors_batch = copy.deepcopy(self.factors)
factors_batch.factors[0] = factors_batch.factors[0][idx_batch]
model = tl.decomposition.CP_NN_HALS(
n_iter_max=self.n_iter_batch,
init=factors_batch,
**self.kwargs_CP_NN_HALS,
)
model.fit(X_batch)

                ## Write the updated batch rows back into the batching-mode factor;
                ## remaining factors and weights are replaced wholesale
                factors_batch = model.decomposition_
self.factors.factors[0][idx_batch] = factors_batch.factors[0]
self.factors.factors[1:] = factors_batch.factors[1:]
self.factors.weights = factors_batch.weights

                ## Log batch losses, keyed by (epoch, global iteration index); offset by
                ## +1 past the last stored index so existing entries are not overwritten
                i_total = (np.array(list(self.losses.keys()), dtype=int)[:, 1].max() + 1) if len(self.losses) > 0 else 0
                self.losses.update({(i_epoch, i_total + ii): loss.item() for ii, loss in enumerate(model.errors_)})

return self

@property
def components_(self):
"""
Returns the components of the decomposition.
"""
return self.factors.factors

def score(self, X):
"""
Returns the explained variance.
Reconstructs X_hat using the factors and scores against the input X.
Args:
X (torch.Tensor):
The input data to score against.
"""
evr = similarity.cp_reconstruction_EVR(
tensor_dense=X,
tensor_CP=self.factors,
)
return evr

def to(self, device):
"""
        Moves the factors and weights to the specified device.

        Args:
device (str):
Device to move the factors to.
"""
if self.factors is not None:
self.factors.factors = [f.to(device) for f in self.factors.factors]
self.factors.weights = self.factors.weights.to(device)
        self.device = torch.device(device)  ## keep type consistent with __init__
return self


#######################################
########## Incremental PCA ############
#######################################
