From 1b2e59142e463a636d3ff7a8833db7bdb26f6115 Mon Sep 17 00:00:00 2001 From: matthias-k Date: Sat, 27 Jan 2024 00:33:54 +0100 Subject: [PATCH] Save memory in gold standard crossvalidation (#46) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit use integer indices instead of binary masks in image-subject-crossval. Signed-off-by: Matthias Kümmmerer --- pysaliency/baseline_utils.py | 29 ++++++++++++++++++++--------- tests/test_baseline_utils.py | 2 -- tests/test_crossvalidation.py | 6 ++---- 3 files changed, 22 insertions(+), 15 deletions(-) diff --git a/pysaliency/baseline_utils.py b/pysaliency/baseline_utils.py index 5794fa6..ecd8a7f 100644 --- a/pysaliency/baseline_utils.py +++ b/pysaliency/baseline_utils.py @@ -142,6 +142,13 @@ def __iter__(self): if test_inds.sum() == 0 or train_inds.sum() == 0: #print("Skipping") continue + + # scikit at some point loads all indices from all crossvalidation folds into memory + # if we use the binary masks, this will use a lot of memory, hence + # we convert to indices here + train_inds = np.nonzero(train_inds)[0] + test_inds = np.nonzero(test_inds)[0] + yield train_inds, test_inds def __len__(self): @@ -168,6 +175,13 @@ def __iter__(self): test_inds[chunk] = 1 test_inds = test_inds > 0.5 train_inds = image_inds & ~test_inds + + # scikit at some point loads all indices from all crossvalidation folds into memory + # if we use the binary masks, this will use a lot of memory, hence + # we convert to indices here + train_inds = np.nonzero(train_inds)[0] + test_inds = np.nonzero(test_inds)[0] + yield train_inds, test_inds def __len__(self): @@ -361,15 +375,12 @@ class CrossvalMultipleRegularizations(object): verbose: verbosity level for cross_val_score """ def __init__(self, stimuli, fixations, regularization_models: OrderedDict, crossvalidation, n_jobs=None, verbose=False): - self.stimuli = stimuli - self.fixations = fixations - self.cv = crossvalidation self.n_jobs = n_jobs self.verbose = verbose X_areas = fixations_to_scikit_learn( - self.fixations, normalize=stimuli, + fixations, normalize=stimuli, keep_aspect=True, add_shape=True, verbose=False @@ -377,13 +388,13 @@ def __init__(self, stimuli, fixations, regularization_models: OrderedDict, cross self.X = fixations_to_scikit_learn( - self.fixations, - normalize=self.stimuli, + fixations, + normalize=stimuli, keep_aspect=True, add_shape=False, add_fixation_number=True, verbose=False ) - stimuli_sizes = np.array(self.stimuli.sizes) - real_areas = stimuli_sizes[self.fixations.n, 0] * stimuli_sizes[self.fixations.n, 1] + stimuli_sizes = np.array(stimuli.sizes) + real_areas = stimuli_sizes[fixations.n, 0] * stimuli_sizes[fixations.n, 1] areas_gold = X_areas[:, 2] * X_areas[:, 3] self.mean_area = np.mean(areas_gold) @@ -393,7 +404,7 @@ def __init__(self, stimuli, fixations, regularization_models: OrderedDict, cross self.regularization_models = [] self.params = ['log_bandwidth'] for model_name, model in regularization_models.items(): - model_lls = model.log_likelihoods(self.stimuli, self.fixations, verbose=True) + model_lls = model.log_likelihoods(stimuli, fixations, verbose=True) self.regularization_log_likelihoods.append(model_lls - correction) self.params.append('log_{}'.format(model_name)) diff --git a/tests/test_baseline_utils.py b/tests/test_baseline_utils.py index eff463c..2f97a8d 100644 --- a/tests/test_baseline_utils.py +++ b/tests/test_baseline_utils.py @@ -144,8 +144,6 @@ def test_crossval_multiple_regularizations(stimuli, fixation_trains): regularization_models = OrderedDict([('model1', pysaliency.UniformModel()), ('model2', pysaliency.models.GaussianModel())]) crossvalidation = ScikitLearnImageCrossValidationGenerator(stimuli, fixation_trains) estimator = CrossvalMultipleRegularizations(stimuli, fixation_trains, regularization_models, crossvalidation) - assert estimator.stimuli is stimuli - assert estimator.fixations is fixation_trains assert estimator.cv is crossvalidation assert estimator.mean_area is not None assert estimator.X is not None diff --git a/tests/test_crossvalidation.py b/tests/test_crossvalidation.py index 63d3a84..4f4c3dd 100644 --- a/tests/test_crossvalidation.py +++ b/tests/test_crossvalidation.py @@ -92,10 +92,8 @@ def test_image_subject_crossvalidation(stimuli, fixation_trains): cv = ScikitLearnImageSubjectCrossValidationGenerator(stimuli, fixation_trains) assert unpack_crossval(cv) == [ - ([False, False, False, True, True, False, False, False, False], - [True, True, True, False, False, False, False, False, False]), - ([True, True, True, False, False, False, False, False, False], - [False, False, False, True, True, False, False, False, False]) + ([3, 4], [0, 1, 2]), + ([0, 1, 2], [3, 4]) ] X = fixations_to_scikit_learn(fixation_trains, normalize=stimuli, add_shape=True)