diff --git a/bnpm/decomposition.py b/bnpm/decomposition.py
index 660eb50..ad801dd 100644
--- a/bnpm/decomposition.py
+++ b/bnpm/decomposition.py
@@ -16,7 +16,7 @@
 # import cuml.decomposition
 # import cupy
 
-from . import similarity
+from . import similarity, torch_helpers
 
 
 ###########################
@@ -680,6 +680,8 @@ class CP_NN_HALS_minibatch:
                 stored.
             kwargs_dataloader (dict):
                 Keyword arguments to pass to torch.utils.data.DataLoader.
+            jitter_zeros (bool):
+                Whether or not to jitter any 0s in the factors to be small values.
             random_state (int):
                 Random seed for reproducibility.
             verbose (bool):
@@ -710,6 +712,7 @@ def __init__(
            # 'prefetch_factor': 2,
            # 'persistent_workers': False,
        },
+        jitter_zeros: bool = False,
        random_state: int = 0,
        verbose: bool = True,
    ):
@@ -723,6 +726,7 @@
         self.n_epochs = int(n_epochs)
         self.device = torch.device(device)
         self.kwargs_dataloader = dict(kwargs_dataloader)
+        self.jitter_zeros = jitter_zeros
         self.random_state = random_state
         self.verbose = verbose
 
@@ -764,6 +768,14 @@ def fit(self, X):
             self (CP_NN_HALS_minibatch object):
                 Returns the CP_NN_HALS_minibatch object.
         """
+        ## Clear CUDA cache
+        factors_tmp = copy.deepcopy(self.factors) if self.factors is not None else None
+        self.factors = None
+        torch_helpers.clear_cuda_cache()
+        self.factors = copy.deepcopy(factors_tmp)
+        del factors_tmp
+
+        ## Initialize factors
         self.init_factors(X) if self.factors is None else None
 
         ## Make dataset (if X is not a dataset)
@@ -798,7 +810,7 @@ def __getitem__(self, idx):
            kwargs_tmp.pop('n_iter_max', None)
 
        ## Fit
        for i_epoch in tqdm(range(self.n_epochs), desc='CP_NN_HALS_minibatch', leave=True, disable=self.verbose==False):
-            for i_iter, (X_batch, idx_batch) in enumerate(dataloader):
+            for i_iter, (X_batch, idx_batch) in tqdm(enumerate(dataloader), desc='Batch', leave=False, disable=self.verbose==False, total=len(dataloader)):
                X_batch = X_batch.to(self.device)
                idx_batch = idx_batch.to(self.device)
@@ -811,15 +823,31 @@ def __getitem__(self, idx):
                 )
                 model.fit(X_batch)
 
-                factors_batch = model.decomposition_
-                self.factors.factors[0][idx_batch] = factors_batch.factors[0]
-                self.factors.factors[1:] = factors_batch.factors[1:]
-                self.factors.weights = factors_batch.weights
-
                 ## Drop into self.losses
                 i_total = np.array(list(self.losses.keys()), dtype=int)[:, 1].max() if len(self.losses) > 0 else 0
                 self.losses.update({(i_epoch, i_total + ii): loss.item() for ii, loss in enumerate(model.errors_)})
 
+                factors_batch = model.decomposition_
+
+                violation = False
+                ## If any of the values are NaN, warn and skip this batch
+                if any([torch.any(torch.isnan(f)) for f in factors_batch.factors]) or torch.any(torch.isnan(factors_batch.weights)):
+                    print('WARNING: NaNs found. Skipping batch.')
+                    violation = True
+                ## Jitter any 0s in the factors to be small values
+                elif self.jitter_zeros:
+                    ## Only do it for the batched factors
+                    factors_batch.factors[0] = torch.where(
+                        factors_batch.factors[0] <= 0,
+                        1e-6 * torch.rand_like(factors_batch.factors[0]),
+                        factors_batch.factors[0],
+                    )
+
+                if not violation:
+                    self.factors.factors[0][idx_batch] = factors_batch.factors[0]
+                    self.factors.factors[1:] = factors_batch.factors[1:]
+                    self.factors.weights = factors_batch.weights
+
         return self
 
     @property
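
Note on the new torch_helpers.clear_cuda_cache() call in fit(): the helper's body is not part of this diff, so the sketch below is only an assumption about what such a helper typically does; the actual bnpm.torch_helpers implementation may differ. The deepcopy/None sequence above the call presumably exists so the old factor tensors drop their last references before the allocator cache is flushed.

import gc
import torch

def clear_cuda_cache():
    ## Hypothetical stand-in for bnpm.torch_helpers.clear_cuda_cache.
    ## Collect dead Python objects first so their CUDA tensors become
    ## unreachable, then return PyTorch's cached allocator blocks to the driver.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()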
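Note on the NaN guard and jitter_zeros branch: the logic can be exercised outside the class. A minimal sketch, assuming the factors are a list of torch tensors; the helper names has_nans and jitter_nonpositive are illustrative only, not part of bnpm.

import torch

def has_nans(factors, weights):
    ## Mirrors the patch's guard: the batch update is skipped when any
    ## factor matrix or the weights vector contains a NaN.
    return any(torch.any(torch.isnan(f)) for f in factors) or torch.any(torch.isnan(weights))

def jitter_nonpositive(factor, eps=1e-6):
    ## Same pattern as the patch: entries <= 0 are replaced with small
    ## uniform noise in [0, eps), presumably to keep the non-negative
    ## updates from sticking at exact zeros. The patch applies this only
    ## to the batched mode, factors[0].
    return torch.where(factor <= 0, eps * torch.rand_like(factor), factor)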
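Example of opting in to the new flag. Only arguments that appear in this diff are spelled out; the remaining constructor arguments (rank, batch size, etc.) are assumed to have defaults and are omitted here, so treat this as a sketch rather than a verified call.

import torch
from bnpm.decomposition import CP_NN_HALS_minibatch

X = torch.rand(1000, 50, 20)  ## toy non-negative data tensor

model = CP_NN_HALS_minibatch(
    n_epochs=2,
    device='cpu',
    jitter_zeros=True,  ## new flag from this patch
    random_state=0,
    verbose=True,
)
model.fit(X)  ## NaN batches are skipped; zeros in the batch factor are jittered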