Skip to content

Commit

Permalink
Make CrossvalMultipleRegularizations more effective for very large datasets (#37)
Browse files Browse the repository at this point in the history

Signed-off-by: Matthias Kümmerer <[email protected]>
  • Loading branch information
matthias-k authored Nov 18, 2023
1 parent c4cdbb2 commit a66a5c5
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 5 deletions.
9 changes: 5 additions & 4 deletions pysaliency/baseline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,17 +369,18 @@ def __init__(self, stimuli, fixations, regularization_models: OrderedDict, cross
verbose=False
)

mean_area = np.mean([x[2]*x[3] for x in X_areas])
self.mean_area = mean_area

self.X = fixations_to_scikit_learn(
self.fixations,
normalize=self.stimuli,
keep_aspect=True, add_shape=False, add_fixation_number=True, verbose=False
)

real_areas = [self.stimuli.sizes[n][0]*self.stimuli.sizes[n][1] for n in self.fixations.n]
areas_gold = [x[2]*x[3] for x in X_areas]
stimuli_sizes = np.array(self.stimuli.sizes)
real_areas = stimuli_sizes[self.fixations.n, 0] * stimuli_sizes[self.fixations.n, 1]
areas_gold = X_areas[:, 2] * X_areas[:, 3]
self.mean_area = np.mean(areas_gold)

correction = np.log(areas_gold) - np.log(real_areas)
self.regularization_log_likelihoods = []

Expand Down
3 changes: 2 additions & 1 deletion tests/test_baseline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,4 +156,5 @@ def test_crossval_multiple_regularizations(stimuli, fixation_trains):
log_regularizations = [0.1, 0.2]

score = estimator.score(log_bandwidth, *log_regularizations)
assert isinstance(score, float)
assert isinstance(score, float)
np.testing.assert_allclose(score, -1.4673831679692528e-10)

0 comments on commit a66a5c5

Please sign in to comment.