From c2ea6de5aac47716da06a8bd48568f3a2d533afd Mon Sep 17 00:00:00 2001
From: matthias-k
Date: Wed, 25 Sep 2024 17:22:12 +0200
Subject: [PATCH] pass keyword arguments to read_hdf5 (#86)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: Matthias Kümmerer
---
 CHANGELOG.md                    |  1 +
 pysaliency/datasets/__init__.py | 20 ++++++++++----------
 tests/datasets/test_stimuli.py  | 13 +++++++++++++
 3 files changed, 24 insertions(+), 10 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 2ef0a06..23fcf1d 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -23,6 +23,7 @@
   * Feature: The RARE2007 model is now available as `pysaliency.external_models.RARE2007`. It's execution requires MATLAB.
   * matlab scripts are now called with the `-batch` option instead of `-nodisplay -nosplash -r`, which should behave better.
   * Enhancement: preloaded stimulus ids are passed on to subsets of Stimuli and FileStimuli.
+  * Feature: `pysaliency.read_hdf5` now takes additional keyword arguments which are passed to the respective class methods. This allows, e.g., loading `FileStimuli` with caching disabled.
 
 * 0.2.22:
 
diff --git a/pysaliency/datasets/__init__.py b/pysaliency/datasets/__init__.py
index b9a1805..63aeb73 100644
--- a/pysaliency/datasets/__init__.py
+++ b/pysaliency/datasets/__init__.py
@@ -12,30 +12,30 @@
 
 
 @cached(WeakValueDictionary())
-def _read_hdf5_from_file(source):
+def _read_hdf5_from_file(source, **kwargs):
     import h5py
     with h5py.File(source, 'r') as hdf5_file:
-        return read_hdf5(hdf5_file)
+        return read_hdf5(hdf5_file, **kwargs)
 
 
-def read_hdf5(source):
+def read_hdf5(source, **kwargs):
     if isinstance(source, (str, pathlib.Path)):
-        return _read_hdf5_from_file(source)
+        return _read_hdf5_from_file(source, **kwargs)
 
     data_type = decode_string(source.attrs['type'])
 
     if data_type == 'Fixations':
-        return Fixations.read_hdf5(source)
+        return Fixations.read_hdf5(source, **kwargs)
     elif data_type == 'ScanpathFixations':
-        return ScanpathFixations.read_hdf5(source)
+        return ScanpathFixations.read_hdf5(source, **kwargs)
     elif data_type == 'FixationTrains':
-        return FixationTrains.read_hdf5(source)
+        return FixationTrains.read_hdf5(source, **kwargs)
     elif data_type == 'Scanpaths':
-        return Scanpaths.read_hdf5(source)
+        return Scanpaths.read_hdf5(source, **kwargs)
     elif data_type == 'Stimuli':
-        return Stimuli.read_hdf5(source)
+        return Stimuli.read_hdf5(source, **kwargs)
     elif data_type == 'FileStimuli':
-        return FileStimuli.read_hdf5(source)
+        return FileStimuli.read_hdf5(source, **kwargs)
     else:
         raise ValueError("Invalid HDF content type:", data_type)
 
diff --git a/tests/datasets/test_stimuli.py b/tests/datasets/test_stimuli.py
index 53d84e7..ef2a9b8 100644
--- a/tests/datasets/test_stimuli.py
+++ b/tests/datasets/test_stimuli.py
@@ -243,6 +243,19 @@ def test_file_stimuli_attributes(file_stimuli_with_attributes, tmp_path):
     assert list(np.array(file_stimuli_with_attributes.attributes['some_strings'])[mask]) == partial_stimuli.attributes['some_strings']
 
 
+def test_file_stimuli_readhdf5_cached(file_stimuli_with_attributes, tmp_path):
+    filename = tmp_path / 'stimuli.hdf5'
+    file_stimuli_with_attributes.to_hdf5(str(filename))
+
+    new_stimuli = pysaliency.read_hdf5(str(filename))
+
+    assert new_stimuli.cached
+
+    new_stimuli2 = pysaliency.read_hdf5(str(filename), cached=False)
+
+    assert not new_stimuli2.cached
+
+
 def test_concatenate_stimuli_with_attributes(stimuli_with_attributes, file_stimuli_with_attributes):
     concatenated_stimuli = pysaliency.datasets.concatenate_stimuli([stimuli_with_attributes, file_stimuli_with_attributes])
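
For reference, a minimal usage sketch of the new keyword pass-through, mirroring
the test added above (the filename is illustrative):

    import pysaliency

    # Default behaviour is unchanged: FileStimuli loaded from HDF5
    # keep image caching enabled.
    stimuli = pysaliency.read_hdf5('stimuli.hdf5')
    assert stimuli.cached

    # Keyword arguments are now forwarded to the class-specific
    # read_hdf5 methods, e.g. to disable caching at load time.
    stimuli_uncached = pysaliency.read_hdf5('stimuli.hdf5', cached=False)
    assert not stimuli_uncached.cached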