From a66a5c5919c735ccfdb7aa91be93f1149bbfc6bb Mon Sep 17 00:00:00 2001 From: matthias-k Date: Sat, 18 Nov 2023 10:47:55 +0100 Subject: [PATCH] Make CrossvalMultipleRegularizations more effective for very large datasets (#37) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Matthias Kümmerer --- pysaliency/baseline_utils.py | 9 +++++---- tests/test_baseline_utils.py | 3 ++- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/pysaliency/baseline_utils.py b/pysaliency/baseline_utils.py index 9b7f43b..f81cbe6 100644 --- a/pysaliency/baseline_utils.py +++ b/pysaliency/baseline_utils.py @@ -369,8 +369,6 @@ def __init__(self, stimuli, fixations, regularization_models: OrderedDict, cross verbose=False ) - mean_area = np.mean([x[2]*x[3] for x in X_areas]) - self.mean_area = mean_area self.X = fixations_to_scikit_learn( self.fixations, @@ -378,8 +376,11 @@ def __init__(self, stimuli, fixations, regularization_models: OrderedDict, cross keep_aspect=True, add_shape=False, add_fixation_number=True, verbose=False ) - real_areas = [self.stimuli.sizes[n][0]*self.stimuli.sizes[n][1] for n in self.fixations.n] - areas_gold = [x[2]*x[3] for x in X_areas] + stimuli_sizes = np.array(self.stimuli.sizes) + real_areas = stimuli_sizes[self.fixations.n, 0] * stimuli_sizes[self.fixations.n, 1] + areas_gold = X_areas[:, 2] * X_areas[:, 3] + self.mean_area = np.mean(areas_gold) + correction = np.log(areas_gold) - np.log(real_areas) self.regularization_log_likelihoods = [] diff --git a/tests/test_baseline_utils.py b/tests/test_baseline_utils.py index e93571c..eff463c 100644 --- a/tests/test_baseline_utils.py +++ b/tests/test_baseline_utils.py @@ -156,4 +156,5 @@ def test_crossval_multiple_regularizations(stimuli, fixation_trains): log_regularizations = [0.1, 0.2] score = estimator.score(log_bandwidth, *log_regularizations) - assert isinstance(score, float) \ No newline at end of file + assert isinstance(score, float) + 
np.testing.assert_allclose(score, -1.4673831679692528e-10) \ No newline at end of file