CP_NN_HALS_minibatch jitter zeros
RichieHakim committed Sep 6, 2024
1 parent b43715e commit fd44907
Showing 1 changed file with 35 additions and 7 deletions.
42 changes: 35 additions & 7 deletions bnpm/decomposition.py
@@ -16,7 +16,7 @@
 # import cuml.decomposition
 # import cupy
 
-from . import similarity
+from . import similarity, torch_helpers
 
 
 ###########################
@@ -680,6 +680,8 @@ class CP_NN_HALS_minibatch:
                 stored.
             kwargs_dataloader (dict):
                 Keyword arguments to pass to torch.utils.data.DataLoader.
+            jitter_zeros (bool):
+                Whether or not to jitter any 0s in the factors to be small values.
             random_state (int):
                 Random seed for reproducibility.
             verbose (bool):
@@ -710,6 +712,7 @@ def __init__(
             # 'prefetch_factor': 2,
             # 'persistent_workers': False,
         },
+        jitter_zeros: bool = False,
        random_state: int = 0,
        verbose: bool = True,
    ):
@@ -723,6 +726,7 @@ def __init__(
        self.n_epochs = int(n_epochs)
        self.device = torch.device(device)
        self.kwargs_dataloader = dict(kwargs_dataloader)
+        self.jitter_zeros = jitter_zeros
        self.random_state = random_state
        self.verbose = verbose
 
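For context, a hypothetical usage sketch of the new constructor option. Only the parameters visible in this diff are shown; the class likely takes additional arguments (e.g., a rank) that are not part of this change.

import torch
from bnpm.decomposition import CP_NN_HALS_minibatch

model = CP_NN_HALS_minibatch(
    n_epochs=10,
    device='cuda:0' if torch.cuda.is_available() else 'cpu',
    kwargs_dataloader={'batch_size': 256, 'shuffle': True},  ## forwarded to torch.utils.data.DataLoader
    jitter_zeros=True,  ## new option from this commit
    random_state=0,
    verbose=True,
)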
@@ -764,6 +768,14 @@ def fit(self, X):
            self (CP_NN_HALS_minibatch object):
                Returns the CP_NN_HALS_minibatch object.
        """
+        ## Clear CUDA cache
+        factors_tmp = copy.deepcopy(self.factors) if self.factors is not None else None
+        self.factors = None
+        torch_helpers.clear_cuda_cache()
+        self.factors = copy.deepcopy(factors_tmp)
+        del factors_tmp
+
+        ## Initialize factors
        self.init_factors(X) if self.factors is None else None
 
        ## Make dataset (if X is not a dataset)
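The new preamble stashes a copy of the factors, drops the live reference, and empties the CUDA allocator before restoring them, so stale blocks from a previous fit are released before the next one starts. A minimal sketch of the pattern, assuming torch_helpers.clear_cuda_cache simply wraps gc.collect() and torch.cuda.empty_cache() (its actual body is not shown in this diff):

import copy
import gc
import torch

def clear_cuda_cache():
    ## Assumed implementation: drop dead Python objects, then return
    ## cached allocator blocks to the driver.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

factors = {'A': torch.rand(10, 3)}    ## stand-in for self.factors
factors_tmp = copy.deepcopy(factors)  ## stash a copy
factors = None                        ## drop the live reference
clear_cuda_cache()
factors = copy.deepcopy(factors_tmp)  ## restore
del factors_tmp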
@@ -798,7 +810,7 @@ def __getitem__(self, idx):
        kwargs_tmp.pop('n_iter_max', None)
        ## Fit
        for i_epoch in tqdm(range(self.n_epochs), desc='CP_NN_HALS_minibatch', leave=True, disable=self.verbose==False):
-            for i_iter, (X_batch, idx_batch) in enumerate(dataloader):
+            for i_iter, (X_batch, idx_batch) in tqdm(enumerate(dataloader), desc='Batch', leave=False, disable=self.verbose==False, total=len(dataloader)):
                X_batch = X_batch.to(self.device)
                idx_batch = idx_batch.to(self.device)
 
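The inner loop now gets its own transient progress bar. With leave=False it clears after every epoch, so only the epoch-level bar persists; total=len(dataloader) is needed because tqdm cannot infer a length from enumerate(dataloader). The same pattern in isolation:

from tqdm import tqdm

for i_epoch in tqdm(range(3), desc='CP_NN_HALS_minibatch', leave=True):
    for i_batch, x in tqdm(enumerate(range(50)), desc='Batch', leave=False, total=50):
        pass  ## per-batch fitting happens here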
@@ -811,15 +823,31 @@ def __getitem__(self, idx):
                )
                model.fit(X_batch)
 
-                factors_batch = model.decomposition_
-                self.factors.factors[0][idx_batch] = factors_batch.factors[0]
-                self.factors.factors[1:] = factors_batch.factors[1:]
-                self.factors.weights = factors_batch.weights
-
                ## Drop into self.losses
                i_total = np.array(list(self.losses.keys()), dtype=int)[:, 1].max() if len(self.losses) > 0 else 0
                self.losses.update({(i_epoch, i_total + ii): loss.item() for ii, loss in enumerate(model.errors_)})
 
+                factors_batch = model.decomposition_
+
+                violation = False
+                ## If any of the values are NaN, then throw warning and skip
+                if any([torch.any(torch.isnan(f)) for f in factors_batch.factors]) or torch.any(torch.isnan(factors_batch.weights)):
+                    print('WARNING: NaNs found. Skipping batch.')
+                    violation = True
+                ## Jitter any 0s in the factors to be small values
+                elif self.jitter_zeros:
+                    ## Only do it for the batched factors
+                    factors_batch.factors[0] = torch.where(
+                        factors_batch.factors[0] <= 0,
+                        1e-6 * torch.rand_like(factors_batch.factors[0]),
+                        factors_batch.factors[0],
+                    )
+
+                if not violation:
+                    self.factors.factors[0][idx_batch] = factors_batch.factors[0]
+                    self.factors.factors[1:] = factors_batch.factors[1:]
+                    self.factors.weights = factors_batch.weights
+
        return self
 
    @property
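The new guard runs in two stages: batches whose factors or weights contain NaNs are discarded entirely, and otherwise (when jitter_zeros is set) non-positive entries of the batched factor are replaced with tiny positive noise. The likely motivation is that non-negative HALS tends to leave entries stuck at exactly zero once they are clipped there, so the jitter lets them re-enter later updates. A standalone sketch of the same two checks on toy tensors (shapes and names here are illustrative, not the library's API):

import torch

factors = [torch.rand(8, 3), torch.rand(5, 3)]  ## toy stand-ins for factors_batch.factors
weights = torch.ones(3)                         ## toy stand-in for factors_batch.weights
factors[0][0, 0] = 0.0                          ## an entry stuck at zero

if any(torch.any(torch.isnan(f)) for f in factors) or torch.any(torch.isnan(weights)):
    print('WARNING: NaNs found. Skipping batch.')  ## drop the whole update
else:
    ## Replace non-positive entries with noise in [0, 1e-6) so they can
    ## become active again in subsequent updates.
    factors[0] = torch.where(
        factors[0] <= 0,
        1e-6 * torch.rand_like(factors[0]),
        factors[0],
    )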