From e50ab8ee83d1c9c6ef5dfa2af02118067d8a03b8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Matthias=20K=C3=BCmmerer?= Date: Wed, 25 Sep 2024 17:00:27 +0200 Subject: [PATCH] pass keyword arguments to read_hdf5 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Matthias Kümmerer --- pysaliency/datasets/__init__.py | 20 ++++++++++---------- tests/datasets/test_stimuli.py | 13 +++++++++++++ 2 files changed, 23 insertions(+), 10 deletions(-) diff --git a/pysaliency/datasets/__init__.py b/pysaliency/datasets/__init__.py index b9a1805..63aeb73 100644 --- a/pysaliency/datasets/__init__.py +++ b/pysaliency/datasets/__init__.py @@ -12,30 +12,30 @@ @cached(WeakValueDictionary()) -def _read_hdf5_from_file(source): +def _read_hdf5_from_file(source, **kwargs): import h5py with h5py.File(source, 'r') as hdf5_file: - return read_hdf5(hdf5_file) + return read_hdf5(hdf5_file, **kwargs) -def read_hdf5(source): +def read_hdf5(source, **kwargs): if isinstance(source, (str, pathlib.Path)): - return _read_hdf5_from_file(source) + return _read_hdf5_from_file(source, **kwargs) data_type = decode_string(source.attrs['type']) if data_type == 'Fixations': - return Fixations.read_hdf5(source) + return Fixations.read_hdf5(source, **kwargs) elif data_type == 'ScanpathFixations': - return ScanpathFixations.read_hdf5(source) + return ScanpathFixations.read_hdf5(source, **kwargs) elif data_type == 'FixationTrains': - return FixationTrains.read_hdf5(source) + return FixationTrains.read_hdf5(source, **kwargs) elif data_type == 'Scanpaths': - return Scanpaths.read_hdf5(source) + return Scanpaths.read_hdf5(source, **kwargs) elif data_type == 'Stimuli': - return Stimuli.read_hdf5(source) + return Stimuli.read_hdf5(source, **kwargs) elif data_type == 'FileStimuli': - return FileStimuli.read_hdf5(source) + return FileStimuli.read_hdf5(source, **kwargs) else: raise ValueError("Invalid HDF content type:", data_type) diff --git a/tests/datasets/test_stimuli.py b/tests/datasets/test_stimuli.py index 53d84e7..ef2a9b8 100644 --- a/tests/datasets/test_stimuli.py +++ b/tests/datasets/test_stimuli.py @@ -243,6 +243,19 @@ def test_file_stimuli_attributes(file_stimuli_with_attributes, tmp_path): assert list(np.array(file_stimuli_with_attributes.attributes['some_strings'])[mask]) == partial_stimuli.attributes['some_strings'] +def test_file_stimuli_readhdf5_cached(file_stimuli_with_attributes, tmp_path): + filename = tmp_path / 'stimuli.hdf5' + file_stimuli_with_attributes.to_hdf5(str(filename)) + + new_stimuli = pysaliency.read_hdf5(str(filename)) + + assert new_stimuli.cached + + new_stimuli2 = pysaliency.read_hdf5(str(filename), cached=False) + + assert not new_stimuli2.cached + + def test_concatenate_stimuli_with_attributes(stimuli_with_attributes, file_stimuli_with_attributes): concatenated_stimuli = pysaliency.datasets.concatenate_stimuli([stimuli_with_attributes, file_stimuli_with_attributes])