Skip to content

Commit

Permalink
Stimuli can be filtered by boolean masks (#54)
Browse files Browse the repository at this point in the history
* Stimuli can be filtered by boolean masks

Signed-off-by: Matthias Kümmmerer <[email protected]>

* fix typo

Signed-off-by: Matthias Kümmmerer <[email protected]>

---------

Signed-off-by: Matthias Kümmmerer <[email protected]>
  • Loading branch information
matthias-k authored Mar 8, 2024
1 parent 64bcde4 commit e2029a1
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 3 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test-package-conda.yml
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ jobs:
run: |
conda install pytest hypothesis
python setup.py build_ext --inplace
python -m pytest --nomatlab --notheano tests
python -m pytest --nomatlab --notheano --nodownload tests
- name: test build and install
shell: bash -el {0}
run: |
Expand Down
14 changes: 13 additions & 1 deletion pysaliency/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -1170,7 +1170,13 @@ def __getitem__(self, index):
if isinstance(index, slice):
attributes = self._get_attribute_for_stimulus_subset(index)
return ObjectStimuli([self.stimulus_objects[i] for i in range(len(self))[index]], attributes=attributes)
elif isinstance(index, list):
elif isinstance(index, (list, np.ndarray)):
index = np.asarray(index)
if index.dtype == bool:
if not len(index) == len(self.stimuli):
raise ValueError(f"Boolean index has to have the same length as the stimuli list but got {len(index)} and {len(self.stimuli)}")
index = np.nonzero(index)[0]

attributes = self._get_attribute_for_stimulus_subset(index)
return ObjectStimuli([self.stimulus_objects[i] for i in index], attributes=attributes)
else:
Expand Down Expand Up @@ -1345,6 +1351,12 @@ def __getitem__(self, index):
index = list(range(len(self)))[index]

if isinstance(index, (list, np.ndarray)):
index = np.asarray(index)
if index.dtype == bool:
if not len(index) == len(self.stimuli):
raise ValueError(f"Boolean index has to have the same length as the stimuli list but got {len(index)} and {len(self.stimuli)}")
index = np.nonzero(index)[0]

filenames = [self.filenames[i] for i in index]
shapes = [self.shapes[i] for i in index]
attributes = self._get_attribute_for_stimulus_subset(index)
Expand Down
5 changes: 5 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,20 @@ def pytest_addoption(parser):
parser.addoption("--nomatlab", action="store_true", default=False, help="don't run matlab tests")
parser.addoption("--nooctave", action="store_true", default=False, help="don't run octave tests")
parser.addoption("--notheano", action="store_true", default=False, help="don't run slow theano tests")
parser.addoption("--nodownload", action="store_true", default=False, help="don't download external data")


def pytest_collection_modifyitems(config, items):
run_slow = config.getoption("--runslow")
run_nonfree = config.getoption('--run-nonfree')
no_matlab = config.getoption("--nomatlab")
no_theano = config.getoption("--notheano")
no_download = config.getoption("--nodownload")
skip_slow = pytest.mark.skip(reason="need --runslow option to run")
skip_nonfree = pytest.mark.skip(reason="need --run-nonfree option to run")
skip_matlab = pytest.mark.skip(reason="skipped because of --nomatlab")
skip_theano = pytest.mark.skip(reason="skipped because of --notheano")
skip_download = pytest.mark.skip(reason="skipped because of --nodownload")
for item in items:
if "slow" in item.keywords and not run_slow:
item.add_marker(skip_slow)
Expand All @@ -34,6 +37,8 @@ def pytest_collection_modifyitems(config, items):
item.add_marker(skip_matlab)
if "theano" in item.keywords and no_theano:
item.add_marker(skip_theano)
if "download" in item.keywords and no_download:
item.add_marker(skip_download)


@pytest.fixture(params=["matlab", "octave"])
Expand Down
3 changes: 2 additions & 1 deletion tests/external_models/test_deepgaze.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,15 @@ def fixations():
)


@pytest.mark.download
def test_deepgaze1(stimuli, fixations):
model = DeepGazeI(centerbias_model=pysaliency.UniformModel(), device='cpu')

ig = model.information_gain(stimuli, fixations)

np.testing.assert_allclose(ig, 0.9455161648442227, rtol=5e-6)


@pytest.mark.download
def test_deepgaze2e(stimuli, fixations):
model = DeepGazeIIE(centerbias_model=pysaliency.UniformModel(), device='cpu')

Expand Down
22 changes: 22 additions & 0 deletions tests/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,6 +566,17 @@ def test_stimuli_attributes(stimuli_with_attributes, tmp_path):
assert list(np.array(stimuli_with_attributes.attributes['dva'])[[1, 2, 6]]) == partial_stimuli.attributes['dva']
assert list(np.array(stimuli_with_attributes.attributes['some_strings'])[[1, 2, 6]]) == partial_stimuli.attributes['some_strings']

mask = np.array([True, False, True, False, True, False, True, False, True, False, True, False])
with pytest.raises(ValueError):
partial_stimuli = stimuli_with_attributes[mask]

mask = np.array([True, False, True, False, True, False, True, False, True, False])
partial_stimuli = stimuli_with_attributes[mask]
assert stimuli_with_attributes.attributes.keys() == partial_stimuli.attributes.keys()
assert list(np.array(stimuli_with_attributes.attributes['dva'])[mask]) == partial_stimuli.attributes['dva']
assert list(np.array(stimuli_with_attributes.attributes['some_strings'])[mask]) == partial_stimuli.attributes['some_strings']



@pytest.fixture
def file_stimuli_with_attributes(tmpdir):
Expand Down Expand Up @@ -611,6 +622,17 @@ def test_file_stimuli_attributes(file_stimuli_with_attributes, tmp_path):
assert list(np.array(file_stimuli_with_attributes.attributes['dva'])[[1, 2, 6]]) == partial_stimuli.attributes['dva']
assert list(np.array(file_stimuli_with_attributes.attributes['some_strings'])[[1, 2, 6]]) == partial_stimuli.attributes['some_strings']

mask = np.array([True, False, True, False, True, False, True, False, True, False])
with pytest.raises(ValueError):
partial_stimuli = file_stimuli_with_attributes[mask]

mask = np.array([True, False, True, False, True, False, True, False, True, False, True, False, True, False, True, False, True, False])
partial_stimuli = file_stimuli_with_attributes[mask]

assert file_stimuli_with_attributes.attributes.keys() == partial_stimuli.attributes.keys()
assert list(np.array(file_stimuli_with_attributes.attributes['dva'])[mask]) == partial_stimuli.attributes['dva']
assert list(np.array(file_stimuli_with_attributes.attributes['some_strings'])[mask]) == partial_stimuli.attributes['some_strings']


def test_concatenate_stimuli_with_attributes(stimuli_with_attributes, file_stimuli_with_attributes):
concatenated_stimuli = pysaliency.datasets.concatenate_stimuli([stimuli_with_attributes, file_stimuli_with_attributes])
Expand Down

0 comments on commit e2029a1

Please sign in to comment.