Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: Subsets of stimuli take over existing stimulus ids #84

Merged
merged 1 commit into from
Sep 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
* Bugfix!: The download location of the RARE2012 model changed. The new source code results in slightly different predictions.
* Feature: The RARE2007 model is now available as `pysaliency.external_models.RARE2007`. Its execution requires MATLAB.
* MATLAB scripts are now called with the `-batch` option instead of `-nodisplay -nosplash -r`, which should behave better.
* Enhancement: preloaded stimulus ids are passed on to subsets of Stimuli and FileStimuli.


* 0.2.22:
Expand Down
28 changes: 24 additions & 4 deletions pysaliency/datasets/stimuli.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import os
from collections.abc import Sequence
from hashlib import sha1
from typing import Union
from typing import List, Union

import numpy as np

Expand Down Expand Up @@ -152,7 +152,12 @@ def _get_attribute_for_stimulus_subset(self, index):
def __getitem__(self, index):
if isinstance(index, slice):
attributes = self._get_attribute_for_stimulus_subset(index)
return ObjectStimuli([self.stimulus_objects[i] for i in range(len(self))[index]], attributes=attributes)
sub_stimuli = ObjectStimuli([self.stimulus_objects[i] for i in range(len(self))[index]], attributes=attributes)

# populate stimulus_id cache with existing entries
self._propagate_stimulus_ids(sub_stimuli, range(len(self))[index])

return sub_stimuli
elif isinstance(index, (list, np.ndarray)):
index = np.asarray(index)
if index.dtype == bool:
Expand All @@ -161,10 +166,20 @@ def __getitem__(self, index):
index = np.nonzero(index)[0]

attributes = self._get_attribute_for_stimulus_subset(index)
return ObjectStimuli([self.stimulus_objects[i] for i in index], attributes=attributes)
sub_stimuli = ObjectStimuli([self.stimulus_objects[i] for i in index], attributes=attributes)

# populate stimulus_id cache with existing entries
self._propagate_stimulus_ids(sub_stimuli, index)

return sub_stimuli
else:
return self.stimulus_objects[index]

def _propagate_stimulus_ids(self, sub_stimuli: "Stimuli", index: List[int]):
    """Copy already cached stimulus ids from this object into a subset.

    `index` lists, in subset order, the positions in `self` that make up
    `sub_stimuli`: entry k of `index` is the parent position of subset
    item k. Only ids that are present in the parent's cache are copied;
    positions without a cached id are left untouched in the subset.
    """
    parent_cache = self.stimulus_ids._cache
    child_cache = sub_stimuli.stimulus_ids._cache

    for position, parent_position in enumerate(index):
        if parent_position not in parent_cache:
            continue
        child_cache[position] = parent_cache[parent_position]

@hdf5_wrapper(mode='w')
def to_hdf5(self, target, verbose=False, compression='gzip', compression_opts=9):
""" Write stimuli to hdf5 file or hdf5 group
Expand Down Expand Up @@ -343,7 +358,12 @@ def __getitem__(self, index):
filenames = [self.filenames[i] for i in index]
shapes = [self.shapes[i] for i in index]
attributes = self._get_attribute_for_stimulus_subset(index)
return type(self)(filenames=filenames, shapes=shapes, attributes=attributes, cached=self.cached)
sub_stimuli = type(self)(filenames=filenames, shapes=shapes, attributes=attributes, cached=self.cached)

# populate stimulus_id cache with existing entries
self._propagate_stimulus_ids(sub_stimuli, index)

return sub_stimuli
else:
return self.stimulus_objects[index]

Expand Down
21 changes: 20 additions & 1 deletion tests/datasets/test_stimuli.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,4 +291,23 @@ def test_check_prediction_shape():
stimulus = Stimulus(np.random.rand(10, 11))
with pytest.raises(ValueError) as excinfo:
check_prediction_shape(prediction, stimulus)
assert str(excinfo.value) == "Prediction shape (10, 10) does not match stimulus shape (10, 11)"
assert str(excinfo.value) == "Prediction shape (10, 10) does not match stimulus shape (10, 11)"


@pytest.mark.parametrize(
    'stimuli',
    ['stimuli_with_attributes', 'file_stimuli_with_attributes']
)
def test_substimuli_inherit_cachedstimulus_ids(stimuli, request):
    """Subsets must take over the stimulus ids already cached on the parent.

    Checks both which subset positions receive a cached id and that the
    id stored there is the one cached for the corresponding parent position.
    """
    _stimuli = request.getfixturevalue(stimuli)

    # access some stimulus ids so that they end up in the parent's cache
    cache_stimulus_indices = [1, 2, 5]
    for i in cache_stimulus_indices:
        _stimuli.stimulus_ids[i]

    parent_cache = _stimuli.stimulus_ids._cache
    assert len(parent_cache) == len(cache_stimulus_indices)

    # Parent positions 1..4 are selected; of those only 1 and 2 are cached,
    # and they become positions 0 and 1 of the subset.
    expected = {0: parent_cache[1], 1: parent_cache[2]}

    # slice indexing
    sub_stimuli = _stimuli[1:5]
    sub_cache = sub_stimuli.stimulus_ids._cache
    assert set(sub_cache.keys()) == {0, 1}
    for new_index, stimulus_id in expected.items():
        assert sub_cache[new_index] == stimulus_id

    # list indexing selecting the same parent positions
    sub_stimuli = _stimuli[[1, 2, 3, 4]]
    sub_cache = sub_stimuli.stimulus_ids._cache
    assert set(sub_cache.keys()) == {0, 1}
    for new_index, stimulus_id in expected.items():
        assert sub_cache[new_index] == stimulus_id

Loading