From 08bad9dd69528f277d6236cac7619c1847cfc608 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Matthias=20K=C3=BCmmerer?= <matthias@matthias-k.org> Date: Sat, 21 Sep 2024 21:11:05 +0200 Subject: [PATCH] ENH: Subsets of stimuli take over existing stimulus ids MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit When creating a subset of Stimuli, so far neither stimulus data nor stimulus ids were taken over and had to be reloaded if required. Now, at least stimulus ids are propagated which can save a lot of memory, e.g. if splitting a large stimulus set into many small subsets. Signed-off-by: Matthias Kümmerer <matthias@matthias-k.org> --- CHANGELOG.md | 1 + pysaliency/datasets/stimuli.py | 28 ++++++++++++++++++++++++---- tests/datasets/test_stimuli.py | 21 ++++++++++++++++++++- 3 files changed, 45 insertions(+), 5 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 55c50ac..2ef0a06 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,6 +22,7 @@ * Bugfix!: The download location of the RARE2012 model changed. The new source code results in slightly different predictions. * Feature: The RARE2007 model is now available as `pysaliency.external_models.RARE2007`. It's execution requires MATLAB. * matlab scripts are now called with the `-batch` option instead of `-nodisplay -nosplash -r`, which should behave better. + * Enhancement: preloaded stimulus ids are passed on to subsets of Stimuli and FileStimuli. * 0.2.22: diff --git a/pysaliency/datasets/stimuli.py b/pysaliency/datasets/stimuli.py index 2de5ce5..0029d90 100644 --- a/pysaliency/datasets/stimuli.py +++ b/pysaliency/datasets/stimuli.py @@ -2,7 +2,7 @@ import os from collections.abc import Sequence from hashlib import sha1 -from typing import Union +from typing import List, Union import numpy as np @@ -152,7 +152,12 @@ def _get_attribute_for_stimulus_subset(self, index): def __getitem__(self, index): if isinstance(index, slice): attributes = self._get_attribute_for_stimulus_subset(index) - return ObjectStimuli([self.stimulus_objects[i] for i in range(len(self))[index]], attributes=attributes) + sub_stimuli = ObjectStimuli([self.stimulus_objects[i] for i in range(len(self))[index]], attributes=attributes) + + # populate stimulus_id cache with existing entries + self._propagate_stimulus_ids(sub_stimuli, range(len(self))[index]) + + return sub_stimuli elif isinstance(index, (list, np.ndarray)): index = np.asarray(index) if index.dtype == bool: @@ -161,10 +166,20 @@ def __getitem__(self, index): index = np.nonzero(index)[0] attributes = self._get_attribute_for_stimulus_subset(index) - return ObjectStimuli([self.stimulus_objects[i] for i in index], attributes=attributes) + sub_stimuli = ObjectStimuli([self.stimulus_objects[i] for i in index], attributes=attributes) + + # populate stimulus_id cache with existing entries + self._propagate_stimulus_ids(sub_stimuli, index) + + return sub_stimuli else: return self.stimulus_objects[index] + def _propagate_stimulus_ids(self, sub_stimuli: "Stimuli", index: List[int]): + for new_index, old_index in enumerate(index): + if old_index in self.stimulus_ids._cache: + sub_stimuli.stimulus_ids._cache[new_index] = self.stimulus_ids._cache[old_index] + @hdf5_wrapper(mode='w') def to_hdf5(self, target, verbose=False, compression='gzip', compression_opts=9): """ Write stimuli to hdf5 file or hdf5 group @@ -343,7 +358,12 @@ def __getitem__(self, index): filenames = [self.filenames[i] for i in index] shapes = [self.shapes[i] for i in index] attributes = self._get_attribute_for_stimulus_subset(index) - return type(self)(filenames=filenames, shapes=shapes, attributes=attributes, cached=self.cached) + sub_stimuli = type(self)(filenames=filenames, shapes=shapes, attributes=attributes, cached=self.cached) + + # populate stimulus_id cache with existing entries + self._propagate_stimulus_ids(sub_stimuli, index) + + return sub_stimuli else: return self.stimulus_objects[index] diff --git a/tests/datasets/test_stimuli.py b/tests/datasets/test_stimuli.py index 6f63c32..53d84e7 100644 --- a/tests/datasets/test_stimuli.py +++ b/tests/datasets/test_stimuli.py @@ -291,4 +291,23 @@ def test_check_prediction_shape(): stimulus = Stimulus(np.random.rand(10, 11)) with pytest.raises(ValueError) as excinfo: check_prediction_shape(prediction, stimulus) - assert str(excinfo.value) == "Prediction shape (10, 10) does not match stimulus shape (10, 11)" \ No newline at end of file + assert str(excinfo.value) == "Prediction shape (10, 10) does not match stimulus shape (10, 11)" + + +@pytest.mark.parametrize( + 'stimuli', + ['stimuli_with_attributes', 'file_stimuli_with_attributes'] +) +def test_substimuli_inherit_cachedstimulus_ids(stimuli, request): + _stimuli = request.getfixturevalue(stimuli) + # load some stimulus ids + cache_stimulus_indices = [1, 2, 5] + # make sure the ids are cached + for i in cache_stimulus_indices: + _stimuli.stimulus_ids[i] + + assert len(_stimuli.stimulus_ids._cache) == len(cache_stimulus_indices) + + sub_stimuli = _stimuli[1:5] + assert set(sub_stimuli.stimulus_ids._cache.keys()) == {0, 1} +