Skip to content

Commit

Permalink
refactor: Update CP_NN_HALS_minibatch initialization with default kwargs. Also allow wrapping of a dataset as the input data.
Browse files Browse the repository at this point in the history
  • Loading branch information
RichieHakim committed Sep 4, 2024
1 parent 347fea0 commit b43715e
Showing 1 changed file with 28 additions and 5 deletions.
33 changes: 28 additions & 5 deletions bnpm/decomposition.py
Original file line number Diff line number Diff line change
Expand Up @@ -687,7 +687,16 @@ class CP_NN_HALS_minibatch:
"""
def __init__(
self,
kwargs_CP_NN_HALS: dict = {},
kwargs_CP_NN_HALS: dict = {
'rank': 10,
'init': 'svd',
'svd': 'truncated_svd',
'non_negative': True,
'random_state': 0,
'normalize_factors': False,
'mask': None,
'svd_mask_repeats': 5,
},
batch_dimension: int = 0,
batch_size: int = 10,
n_iter_batch: int = 10,
Expand Down Expand Up @@ -757,11 +766,25 @@ def fit(self, X):
"""
self.init_factors(X) if self.factors is None else None

## Make dataset (if X is not a dataset)
if isinstance(X, torch.Tensor):
dataset = torch.utils.data.TensorDataset(
X.moveaxis(self.batch_dimension, 0),
torch.arange(X.shape[self.batch_dimension]),
)
elif isinstance(X, Dataset):
class DatasetWrapper(Dataset):
    """
    Thin wrapper that pairs every sample of a dataset with its index.

    Indexing the wrapper returns the tuple ``(dataset[idx], idx)``, and its
    length is the length of the wrapped dataset.

    Args:
        dataset:
            The object to wrap. It must support ``len()`` and integer
            indexing via ``[]``.
    """
    def __init__(self, dataset):
        super().__init__()
        ## Keep a reference to the wrapped dataset (no copy is made)
        self.dataset = dataset

    def __len__(self):
        """Return the number of samples in the wrapped dataset."""
        n_samples = len(self.dataset)
        return n_samples

    def __getitem__(self, idx):
        """Return ``(sample, idx)`` for the sample stored at ``idx``."""
        sample = self.dataset[idx]
        return sample, idx
else:
raise ValueError('X must be a torch.Tensor or torch.utils.data.Dataset.')
## Make dataloader
dataset = torch.utils.data.TensorDataset(
X.moveaxis(self.batch_dimension, 0),
torch.arange(X.shape[self.batch_dimension]),
)
kwargs_dataloader_tmp = copy.deepcopy(self.kwargs_dataloader)
kwargs_dataloader_tmp.pop('batch_size', None)
dataloader = torch.utils.data.DataLoader(
Expand Down

0 comments on commit b43715e

Please sign in to comment.