Skip to content

Commit

Permalink
ENH: Subsets of stimuli take over existing stimulus ids
Browse files Browse the repository at this point in the history
When creating a subset of Stimuli, so far neither stimulus data
nor stimulus ids were taken over and had to be reloaded if required.
Now, at least stimulus ids are propagated which can save a lot of
memory, e.g. if splitting a large stimulus set into many small subsets.

Signed-off-by: Matthias Kümmerer <[email protected]>
  • Loading branch information
matthias-k committed Sep 21, 2024
1 parent 51611d8 commit 08bad9d
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 5 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
* Bugfix!: The download location of the RARE2012 model changed. The new source code results in slightly different predictions.
* Feature: The RARE2007 model is now available as `pysaliency.external_models.RARE2007`. It's execution requires MATLAB.
* matlab scripts are now called with the `-batch` option instead of `-nodisplay -nosplash -r`, which should behave better.
* Enhancement: preloaded stimulus ids are passed on to subsets of Stimuli and FileStimuli.


* 0.2.22:
Expand Down
28 changes: 24 additions & 4 deletions pysaliency/datasets/stimuli.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import os
from collections.abc import Sequence
from hashlib import sha1
from typing import Union
from typing import List, Union

import numpy as np

Expand Down Expand Up @@ -152,7 +152,12 @@ def _get_attribute_for_stimulus_subset(self, index):
def __getitem__(self, index):
if isinstance(index, slice):
attributes = self._get_attribute_for_stimulus_subset(index)
return ObjectStimuli([self.stimulus_objects[i] for i in range(len(self))[index]], attributes=attributes)
sub_stimuli = ObjectStimuli([self.stimulus_objects[i] for i in range(len(self))[index]], attributes=attributes)

# populate stimulus_id cache with existing entries
self._propagate_stimulus_ids(sub_stimuli, range(len(self))[index])

return sub_stimuli
elif isinstance(index, (list, np.ndarray)):
index = np.asarray(index)
if index.dtype == bool:
Expand All @@ -161,10 +166,20 @@ def __getitem__(self, index):
index = np.nonzero(index)[0]

attributes = self._get_attribute_for_stimulus_subset(index)
return ObjectStimuli([self.stimulus_objects[i] for i in index], attributes=attributes)
sub_stimuli = ObjectStimuli([self.stimulus_objects[i] for i in index], attributes=attributes)

# populate stimulus_id cache with existing entries
self._propagate_stimulus_ids(sub_stimuli, index)

return sub_stimuli
else:
return self.stimulus_objects[index]

def _propagate_stimulus_ids(self, sub_stimuli: "Stimuli", index: List[int]):
for new_index, old_index in enumerate(index):
if old_index in self.stimulus_ids._cache:
sub_stimuli.stimulus_ids._cache[new_index] = self.stimulus_ids._cache[old_index]

@hdf5_wrapper(mode='w')
def to_hdf5(self, target, verbose=False, compression='gzip', compression_opts=9):
""" Write stimuli to hdf5 file or hdf5 group
Expand Down Expand Up @@ -343,7 +358,12 @@ def __getitem__(self, index):
filenames = [self.filenames[i] for i in index]
shapes = [self.shapes[i] for i in index]
attributes = self._get_attribute_for_stimulus_subset(index)
return type(self)(filenames=filenames, shapes=shapes, attributes=attributes, cached=self.cached)
sub_stimuli = type(self)(filenames=filenames, shapes=shapes, attributes=attributes, cached=self.cached)

# populate stimulus_id cache with existing entries
self._propagate_stimulus_ids(sub_stimuli, index)

return sub_stimuli
else:
return self.stimulus_objects[index]

Expand Down
21 changes: 20 additions & 1 deletion tests/datasets/test_stimuli.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,4 +291,23 @@ def test_check_prediction_shape():
stimulus = Stimulus(np.random.rand(10, 11))
with pytest.raises(ValueError) as excinfo:
check_prediction_shape(prediction, stimulus)
assert str(excinfo.value) == "Prediction shape (10, 10) does not match stimulus shape (10, 11)"
assert str(excinfo.value) == "Prediction shape (10, 10) does not match stimulus shape (10, 11)"


@pytest.mark.parametrize(
'stimuli',
['stimuli_with_attributes', 'file_stimuli_with_attributes']
)
def test_substimuli_inherit_cachedstimulus_ids(stimuli, request):
_stimuli = request.getfixturevalue(stimuli)
# load some stimulus ids
cache_stimulus_indices = [1, 2, 5]
# make sure the ids are cached
for i in cache_stimulus_indices:
_stimuli.stimulus_ids[i]

assert len(_stimuli.stimulus_ids._cache) == len(cache_stimulus_indices)

sub_stimuli = _stimuli[1:5]
assert set(sub_stimuli.stimulus_ids._cache.keys()) == {0, 1}

0 comments on commit 08bad9d

Please sign in to comment.