Skip to content

Commit

Permalink
Save memory in gold standard crossvalidation (#46)
Browse files Browse the repository at this point in the history
use integer indices instead of binary masks in image-subject-crossval.

Signed-off-by: Matthias Kümmmerer <[email protected]>
  • Loading branch information
matthias-k authored Jan 26, 2024
1 parent 332bfb6 commit 1b2e591
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 15 deletions.
29 changes: 20 additions & 9 deletions pysaliency/baseline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,13 @@ def __iter__(self):
if test_inds.sum() == 0 or train_inds.sum() == 0:
#print("Skipping")
continue

# scikit at some point loads all indices from all crossvalidation folds into memory
# if we use the binary masks, this will use a lot of memory, hence
# we convert to indices here
train_inds = np.nonzero(train_inds)[0]
test_inds = np.nonzero(test_inds)[0]

yield train_inds, test_inds

def __len__(self):
Expand All @@ -168,6 +175,13 @@ def __iter__(self):
test_inds[chunk] = 1
test_inds = test_inds > 0.5
train_inds = image_inds & ~test_inds

# scikit at some point loads all indices from all crossvalidation folds into memory
# if we use the binary masks, this will use a lot of memory, hence
# we convert to indices here
train_inds = np.nonzero(train_inds)[0]
test_inds = np.nonzero(test_inds)[0]

yield train_inds, test_inds

def __len__(self):
Expand Down Expand Up @@ -361,29 +375,26 @@ class CrossvalMultipleRegularizations(object):
verbose: verbosity level for cross_val_score
"""
def __init__(self, stimuli, fixations, regularization_models: OrderedDict, crossvalidation, n_jobs=None, verbose=False):
self.stimuli = stimuli
self.fixations = fixations

self.cv = crossvalidation
self.n_jobs = n_jobs
self.verbose = verbose

X_areas = fixations_to_scikit_learn(
self.fixations, normalize=stimuli,
fixations, normalize=stimuli,
keep_aspect=True,
add_shape=True,
verbose=False
)


self.X = fixations_to_scikit_learn(
self.fixations,
normalize=self.stimuli,
fixations,
normalize=stimuli,
keep_aspect=True, add_shape=False, add_fixation_number=True, verbose=False
)

stimuli_sizes = np.array(self.stimuli.sizes)
real_areas = stimuli_sizes[self.fixations.n, 0] * stimuli_sizes[self.fixations.n, 1]
stimuli_sizes = np.array(stimuli.sizes)
real_areas = stimuli_sizes[fixations.n, 0] * stimuli_sizes[fixations.n, 1]
areas_gold = X_areas[:, 2] * X_areas[:, 3]
self.mean_area = np.mean(areas_gold)

Expand All @@ -393,7 +404,7 @@ def __init__(self, stimuli, fixations, regularization_models: OrderedDict, cross
self.regularization_models = []
self.params = ['log_bandwidth']
for model_name, model in regularization_models.items():
model_lls = model.log_likelihoods(self.stimuli, self.fixations, verbose=True)
model_lls = model.log_likelihoods(stimuli, fixations, verbose=True)
self.regularization_log_likelihoods.append(model_lls - correction)
self.params.append('log_{}'.format(model_name))

Expand Down
2 changes: 0 additions & 2 deletions tests/test_baseline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,8 +144,6 @@ def test_crossval_multiple_regularizations(stimuli, fixation_trains):
regularization_models = OrderedDict([('model1', pysaliency.UniformModel()), ('model2', pysaliency.models.GaussianModel())])
crossvalidation = ScikitLearnImageCrossValidationGenerator(stimuli, fixation_trains)
estimator = CrossvalMultipleRegularizations(stimuli, fixation_trains, regularization_models, crossvalidation)
assert estimator.stimuli is stimuli
assert estimator.fixations is fixation_trains
assert estimator.cv is crossvalidation
assert estimator.mean_area is not None
assert estimator.X is not None
Expand Down
6 changes: 2 additions & 4 deletions tests/test_crossvalidation.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,10 +92,8 @@ def test_image_subject_crossvalidation(stimuli, fixation_trains):
cv = ScikitLearnImageSubjectCrossValidationGenerator(stimuli, fixation_trains)

assert unpack_crossval(cv) == [
([False, False, False, True, True, False, False, False, False],
[True, True, True, False, False, False, False, False, False]),
([True, True, True, False, False, False, False, False, False],
[False, False, False, True, True, False, False, False, False])
([3, 4], [0, 1, 2]),
([0, 1, 2], [3, 4])
]

X = fixations_to_scikit_learn(fixation_trains, normalize=stimuli, add_shape=True)
Expand Down

0 comments on commit 1b2e591

Please sign in to comment.