Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Save memory in gold standard crossvalidation #46

Merged
merged 2 commits into from
Jan 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 20 additions & 9 deletions pysaliency/baseline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,13 @@ def __iter__(self):
if test_inds.sum() == 0 or train_inds.sum() == 0:
#print("Skipping")
continue

# scikit at some point loads all indices from all crossvalidation folds into memory
# if we use the binary masks, this will use a lot of memory, hence
# we convert to indices here
train_inds = np.nonzero(train_inds)[0]
test_inds = np.nonzero(test_inds)[0]

yield train_inds, test_inds

def __len__(self):
Expand All @@ -168,6 +175,13 @@ def __iter__(self):
test_inds[chunk] = 1
test_inds = test_inds > 0.5
train_inds = image_inds & ~test_inds

# scikit at some point loads all indices from all crossvalidation folds into memory
# if we use the binary masks, this will use a lot of memory, hence
# we convert to indices here
train_inds = np.nonzero(train_inds)[0]
test_inds = np.nonzero(test_inds)[0]

yield train_inds, test_inds

def __len__(self):
Expand Down Expand Up @@ -361,29 +375,26 @@ class CrossvalMultipleRegularizations(object):
verbose: verbosity level for cross_val_score
"""
def __init__(self, stimuli, fixations, regularization_models: OrderedDict, crossvalidation, n_jobs=None, verbose=False):
self.stimuli = stimuli
self.fixations = fixations

self.cv = crossvalidation
self.n_jobs = n_jobs
self.verbose = verbose

X_areas = fixations_to_scikit_learn(
self.fixations, normalize=stimuli,
fixations, normalize=stimuli,
keep_aspect=True,
add_shape=True,
verbose=False
)


self.X = fixations_to_scikit_learn(
self.fixations,
normalize=self.stimuli,
fixations,
normalize=stimuli,
keep_aspect=True, add_shape=False, add_fixation_number=True, verbose=False
)

stimuli_sizes = np.array(self.stimuli.sizes)
real_areas = stimuli_sizes[self.fixations.n, 0] * stimuli_sizes[self.fixations.n, 1]
stimuli_sizes = np.array(stimuli.sizes)
real_areas = stimuli_sizes[fixations.n, 0] * stimuli_sizes[fixations.n, 1]
areas_gold = X_areas[:, 2] * X_areas[:, 3]
self.mean_area = np.mean(areas_gold)

Expand All @@ -393,7 +404,7 @@ def __init__(self, stimuli, fixations, regularization_models: OrderedDict, cross
self.regularization_models = []
self.params = ['log_bandwidth']
for model_name, model in regularization_models.items():
model_lls = model.log_likelihoods(self.stimuli, self.fixations, verbose=True)
model_lls = model.log_likelihoods(stimuli, fixations, verbose=True)
self.regularization_log_likelihoods.append(model_lls - correction)
self.params.append('log_{}'.format(model_name))

Expand Down
2 changes: 0 additions & 2 deletions tests/test_baseline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,8 +144,6 @@ def test_crossval_multiple_regularizations(stimuli, fixation_trains):
regularization_models = OrderedDict([('model1', pysaliency.UniformModel()), ('model2', pysaliency.models.GaussianModel())])
crossvalidation = ScikitLearnImageCrossValidationGenerator(stimuli, fixation_trains)
estimator = CrossvalMultipleRegularizations(stimuli, fixation_trains, regularization_models, crossvalidation)
assert estimator.stimuli is stimuli
assert estimator.fixations is fixation_trains
assert estimator.cv is crossvalidation
assert estimator.mean_area is not None
assert estimator.X is not None
Expand Down
6 changes: 2 additions & 4 deletions tests/test_crossvalidation.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,10 +92,8 @@ def test_image_subject_crossvalidation(stimuli, fixation_trains):
cv = ScikitLearnImageSubjectCrossValidationGenerator(stimuli, fixation_trains)

assert unpack_crossval(cv) == [
([False, False, False, True, True, False, False, False, False],
[True, True, True, False, False, False, False, False, False]),
([True, True, True, False, False, False, False, False, False],
[False, False, False, True, True, False, False, False, False])
([3, 4], [0, 1, 2]),
([0, 1, 2], [3, 4])
]

X = fixations_to_scikit_learn(fixation_trains, normalize=stimuli, add_shape=True)
Expand Down
Loading