From 64bcde42171bba2c167165790051346f9a6203ba Mon Sep 17 00:00:00 2001 From: matthias-k Date: Fri, 8 Mar 2024 13:28:11 +0100 Subject: [PATCH] StimulusDependentSaliencyMapModel (#53) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Matthias Kümmmerer --- CHANGELOG.md | 1 + pysaliency/__init__.py | 1 + pysaliency/saliency_map_models.py | 28 +++++++++++++++++++++++++- tests/test_saliency_map_models.py | 33 +++++++++++++++++++++++++++++++ 4 files changed, 62 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c29ce6b..58270d5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,7 @@ # Changelog * 0.2.22 (dev): + * Feature: `StimulusDependentSaliencyMapModel` * Bugfix: The NUSEF dataset scaled some fixations not correctly to image coordinates. Also, we now account for some typos in the dataset source data. * Feature: CrossvalMultipleRegularizations and GeneralMixtureKernelDensityEstimator in baseline utils (names might change!) diff --git a/pysaliency/__init__.py b/pysaliency/__init__.py index 50b1010..e8c9f54 100755 --- a/pysaliency/__init__.py +++ b/pysaliency/__init__.py @@ -28,6 +28,7 @@ ExpSaliencyMapModel, DisjointUnionSaliencyMapModel, SubjectDependentSaliencyMapModel, + StimulusDependentSaliencyMapModel, ResizingSaliencyMapModel, BluringSaliencyMapModel, DigitizeMapModel, diff --git a/pysaliency/saliency_map_models.py b/pysaliency/saliency_map_models.py index eac0d1b..236e614 100644 --- a/pysaliency/saliency_map_models.py +++ b/pysaliency/saliency_map_models.py @@ -1,4 +1,5 @@ from __future__ import absolute_import, print_function, division, unicode_literals +from itertools import combinations import os from abc import ABCMeta, abstractmethod @@ -15,7 +16,7 @@ from .numba_utils import fill_fixation_map, auc_for_one_positive from .utils import TemporaryDirectory, run_matlab_cmd, Cache, average_values, deprecated_class, remove_trailing_nans -from .datasets import Stimulus, Fixations +from .datasets import Stimulus, Fixations, get_image_hash from .metrics import CC, NSS, SIM from .sampling_models import SamplingModelMixin @@ -934,6 +935,31 @@ def conditional_saliency_map(self, stimulus, x_hist, y_hist, t_hist, attributes= stimulus, x_hist, y_hist, t_hist, attributes=attributes, **kwargs) +class StimulusDependentSaliencyMapModel(SaliencyMapModel): + def __init__(self, stimuli_models, check_stimuli=True, fallback_model=None, **kwargs): + super(StimulusDependentSaliencyMapModel, self).__init__(**kwargs) + self.stimuli_models = stimuli_models + self.fallback_model = fallback_model + if check_stimuli: + self.check_stimuli() + + def check_stimuli(self): + for s1, s2 in tqdm(list(combinations(self.stimuli_models, 2))): + if not set(s1.stimulus_ids).isdisjoint(s2.stimulus_ids): + raise ValueError('Stimuli not disjoint') + + def _saliency_map(self, stimulus): + stimulus_hash = get_image_hash(stimulus) + for stimuli, model in self.stimuli_models.items(): + if stimulus_hash in stimuli.stimulus_ids: + return model.saliency_map(stimulus) + else: + if self.fallback_model is not None: + return self.fallback_model.saliency_map(stimulus) + else: + raise ValueError('stimulus not provided by these models') + + class ExpSaliencyMapModel(SaliencyMapModel): def __init__(self, parent_model): super(ExpSaliencyMapModel, self).__init__(caching=False) diff --git a/tests/test_saliency_map_models.py b/tests/test_saliency_map_models.py index c4b65b2..20f4e98 100644 --- a/tests/test_saliency_map_models.py +++ b/tests/test_saliency_map_models.py @@ -493,3 +493,36 @@ def test_conditional_saliency_maps(stimuli, fixation_trains): saliency_maps_2 = [model.conditional_saliency_map_for_fixation(stimuli, fixation_trains, i) for i in range(len(fixation_trains))] np.testing.assert_allclose(saliency_maps_1, saliency_maps_2) + + +def test_stimulus_dependent_saliency_map_model(stimuli, fixation_trains): + # Create stimulus models + stimulus_model_1 = ConstantSaliencyMapModel(value=0.5) + stimulus_model_2 = GaussianSaliencyMapModel() + + # Create the stimulus-dependent saliency map model + stimuli_models = {stimuli[[0]]: stimulus_model_1, stimuli[[1]]: stimulus_model_2} + fallback_model = ConstantSaliencyMapModel(value=0.2) + sdsmm = pysaliency.saliency_map_models.StimulusDependentSaliencyMapModel(stimuli_models, fallback_model=fallback_model) + + # Test saliency map for stimulus 1 + saliency_map_1 = sdsmm.saliency_map(stimuli[0]) + np.testing.assert_allclose(saliency_map_1, np.ones((40, 40)) * 0.5) + + # Test saliency map for stimulus 2 + saliency_map_2 = sdsmm.saliency_map(stimuli[1]) + height = stimuli[1].shape[0] + width = stimuli[1].shape[1] + expected_saliency_map_2 = np.exp(-0.5 * ((np.mgrid[:height, :width][1] - 0.5 * width) ** 2 + + (np.mgrid[:height, :width][0] - 0.5 * height) ** 2) / + np.sqrt(width ** 2 + height ** 2)) + np.testing.assert_allclose(saliency_map_2, expected_saliency_map_2) + + # Test fallback model + fallback_saliency_map = fallback_model.saliency_map(np.random.randn(50, 50, 3)) + np.testing.assert_allclose(fallback_saliency_map, np.ones((50, 50)) * 0.2) + + # Test saliency map for stimulus not provided by the models if there is no fallback model + sdsmm.fallback_model = None + with pytest.raises(ValueError): + sdsmm.saliency_map(np.random.randn(50, 50, 3))