From e2029a143344e81317230cdefd8d6e4e5f47c7d4 Mon Sep 17 00:00:00 2001 From: matthias-k Date: Fri, 8 Mar 2024 14:52:37 +0100 Subject: [PATCH] Stimuli can be filtered by boolean masks (#54) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Stimuli can be filtered by boolean masks Signed-off-by: Matthias Kümmmerer * fix typo Signed-off-by: Matthias Kümmmerer --------- Signed-off-by: Matthias Kümmmerer --- .github/workflows/test-package-conda.yml | 2 +- pysaliency/datasets.py | 14 +++++++++++++- tests/conftest.py | 5 +++++ tests/external_models/test_deepgaze.py | 3 ++- tests/test_datasets.py | 22 ++++++++++++++++++++++ 5 files changed, 43 insertions(+), 3 deletions(-) diff --git a/.github/workflows/test-package-conda.yml b/.github/workflows/test-package-conda.yml index f862467..d4125d9 100644 --- a/.github/workflows/test-package-conda.yml +++ b/.github/workflows/test-package-conda.yml @@ -61,7 +61,7 @@ jobs: run: | conda install pytest hypothesis python setup.py build_ext --inplace - python -m pytest --nomatlab --notheano tests + python -m pytest --nomatlab --notheano --nodownload tests - name: test build and install shell: bash -el {0} run: | diff --git a/pysaliency/datasets.py b/pysaliency/datasets.py index e636c7e..a158c12 100644 --- a/pysaliency/datasets.py +++ b/pysaliency/datasets.py @@ -1170,7 +1170,13 @@ def __getitem__(self, index): if isinstance(index, slice): attributes = self._get_attribute_for_stimulus_subset(index) return ObjectStimuli([self.stimulus_objects[i] for i in range(len(self))[index]], attributes=attributes) - elif isinstance(index, list): + elif isinstance(index, (list, np.ndarray)): + index = np.asarray(index) + if index.dtype == bool: + if not len(index) == len(self.stimuli): + raise ValueError(f"Boolean index has to have the same length as the stimuli list but got {len(index)} and {len(self.stimuli)}") + index = np.nonzero(index)[0] + attributes = self._get_attribute_for_stimulus_subset(index) return ObjectStimuli([self.stimulus_objects[i] for i in index], attributes=attributes) else: @@ -1345,6 +1351,12 @@ def __getitem__(self, index): index = list(range(len(self)))[index] if isinstance(index, (list, np.ndarray)): + index = np.asarray(index) + if index.dtype == bool: + if not len(index) == len(self.stimuli): + raise ValueError(f"Boolean index has to have the same length as the stimuli list but got {len(index)} and {len(self.stimuli)}") + index = np.nonzero(index)[0] + filenames = [self.filenames[i] for i in index] shapes = [self.shapes[i] for i in index] attributes = self._get_attribute_for_stimulus_subset(index) diff --git a/tests/conftest.py b/tests/conftest.py index 1a6b389..1a3045b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -14,6 +14,7 @@ def pytest_addoption(parser): parser.addoption("--nomatlab", action="store_true", default=False, help="don't run matlab tests") parser.addoption("--nooctave", action="store_true", default=False, help="don't run octave tests") parser.addoption("--notheano", action="store_true", default=False, help="don't run slow theano tests") + parser.addoption("--nodownload", action="store_true", default=False, help="don't download external data") def pytest_collection_modifyitems(config, items): @@ -21,10 +22,12 @@ def pytest_collection_modifyitems(config, items): run_nonfree = config.getoption('--run-nonfree') no_matlab = config.getoption("--nomatlab") no_theano = config.getoption("--notheano") + no_download = config.getoption("--nodownload") skip_slow = 
pytest.mark.skip(reason="need --runslow option to run") skip_nonfree = pytest.mark.skip(reason="need --run-nonfree option to run") skip_matlab = pytest.mark.skip(reason="skipped because of --nomatlab") skip_theano = pytest.mark.skip(reason="skipped because of --notheano") + skip_download = pytest.mark.skip(reason="skipped because of --nodownload") for item in items: if "slow" in item.keywords and not run_slow: item.add_marker(skip_slow) @@ -34,6 +37,8 @@ def pytest_collection_modifyitems(config, items): item.add_marker(skip_matlab) if "theano" in item.keywords and no_theano: item.add_marker(skip_theano) + if "download" in item.keywords and no_download: + item.add_marker(skip_download) @pytest.fixture(params=["matlab", "octave"]) diff --git a/tests/external_models/test_deepgaze.py b/tests/external_models/test_deepgaze.py index bb99b39..72faddf 100644 --- a/tests/external_models/test_deepgaze.py +++ b/tests/external_models/test_deepgaze.py @@ -33,6 +33,7 @@ def fixations(): ) +@pytest.mark.download def test_deepgaze1(stimuli, fixations): model = DeepGazeI(centerbias_model=pysaliency.UniformModel(), device='cpu') @@ -40,7 +41,7 @@ def test_deepgaze1(stimuli, fixations): np.testing.assert_allclose(ig, 0.9455161648442227, rtol=5e-6) - +@pytest.mark.download def test_deepgaze2e(stimuli, fixations): model = DeepGazeIIE(centerbias_model=pysaliency.UniformModel(), device='cpu') diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 93ce8d2..7e39e65 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -566,6 +566,17 @@ def test_stimuli_attributes(stimuli_with_attributes, tmp_path): assert list(np.array(stimuli_with_attributes.attributes['dva'])[[1, 2, 6]]) == partial_stimuli.attributes['dva'] assert list(np.array(stimuli_with_attributes.attributes['some_strings'])[[1, 2, 6]]) == partial_stimuli.attributes['some_strings'] + mask = np.array([True, False, True, False, True, False, True, False, True, False, True, False]) + with pytest.raises(ValueError): + partial_stimuli = stimuli_with_attributes[mask] + + mask = np.array([True, False, True, False, True, False, True, False, True, False]) + partial_stimuli = stimuli_with_attributes[mask] + assert stimuli_with_attributes.attributes.keys() == partial_stimuli.attributes.keys() + assert list(np.array(stimuli_with_attributes.attributes['dva'])[mask]) == partial_stimuli.attributes['dva'] + assert list(np.array(stimuli_with_attributes.attributes['some_strings'])[mask]) == partial_stimuli.attributes['some_strings'] + + @pytest.fixture def file_stimuli_with_attributes(tmpdir): @@ -611,6 +622,17 @@ def test_file_stimuli_attributes(file_stimuli_with_attributes, tmp_path): assert list(np.array(file_stimuli_with_attributes.attributes['dva'])[[1, 2, 6]]) == partial_stimuli.attributes['dva'] assert list(np.array(file_stimuli_with_attributes.attributes['some_strings'])[[1, 2, 6]]) == partial_stimuli.attributes['some_strings'] + mask = np.array([True, False, True, False, True, False, True, False, True, False]) + with pytest.raises(ValueError): + partial_stimuli = file_stimuli_with_attributes[mask] + + mask = np.array([True, False, True, False, True, False, True, False, True, False, True, False, True, False, True, False, True, False]) + partial_stimuli = file_stimuli_with_attributes[mask] + + assert file_stimuli_with_attributes.attributes.keys() == partial_stimuli.attributes.keys() + assert list(np.array(file_stimuli_with_attributes.attributes['dva'])[mask]) == partial_stimuli.attributes['dva'] + assert 
list(np.array(file_stimuli_with_attributes.attributes['some_strings'])[mask]) == partial_stimuli.attributes['some_strings'] + def test_concatenate_stimuli_with_attributes(stimuli_with_attributes, file_stimuli_with_attributes): concatenated_stimuli = pysaliency.datasets.concatenate_stimuli([stimuli_with_attributes, file_stimuli_with_attributes])
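Usage sketch for the boolean-mask indexing added above: the hunks in pysaliency/datasets.py let Stimuli and FileStimuli be indexed with a boolean mask whose length matches the stimuli list. The example assumes pysaliency.Stimuli accepts a list of numpy arrays, as in the test fixtures; the toy images and masks below are hypothetical.

    import numpy as np
    import pysaliency

    # three hypothetical toy stimuli (small random RGB images)
    stimuli = pysaliency.Stimuli([
        np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8)
        for _ in range(3)
    ])

    # boolean mask selecting the first and last stimulus;
    # it must contain exactly len(stimuli) entries
    mask = np.array([True, False, True])
    subset = stimuli[mask]
    assert len(subset) == 2

    # a mask of the wrong length raises ValueError, per the new check in __getitem__
    try:
        stimuli[np.array([True, False])]
    except ValueError:
        pass

Internally, the mask is converted with np.nonzero(index)[0] into integer indices, so boolean indexing reuses the existing list/array code path for subsetting the stimuli and their attributes.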