From e87051898734f19b9feae97f1230110f11704255 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Matthias=20K=C3=BCmmerer?= Date: Mon, 15 Apr 2024 13:53:31 +0200 Subject: [PATCH] HDF5 model should not fail if created with empty stimuli MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Matthias Kümmerer --- pysaliency/precomputed_models.py | 3 +++ tests/test_precomputed_models.py | 10 ++++++++++ 2 files changed, 13 insertions(+) diff --git a/pysaliency/precomputed_models.py b/pysaliency/precomputed_models.py index 69d1c7b..795e5b7 100644 --- a/pysaliency/precomputed_models.py +++ b/pysaliency/precomputed_models.py @@ -30,6 +30,9 @@ def get_stimuli_filenames(stimuli): def get_keys_from_filenames(filenames, keys): """checks how much filenames have to be shorted to get the correct hdf5 or other keys""" + if not filenames: + return [] + first_filename_parts = full_split(filenames[0]) for part_index in range(len(first_filename_parts)): remaining_filename = os.path.join(*first_filename_parts[part_index:]) diff --git a/tests/test_precomputed_models.py b/tests/test_precomputed_models.py index ce195ce..24e50e0 100644 --- a/tests/test_precomputed_models.py +++ b/tests/test_precomputed_models.py @@ -113,6 +113,16 @@ def test_hdf5_model_sub_stimuli(stimuli, sub_stimuli, tmpdir): np.testing.assert_allclose(model.log_density(s), model2.log_density(s)) +def test_hdf5_model_empty_stimuli(stimuli, tmpdir): + model = pysaliency.models.SaliencyMapNormalizingModel(TestSaliencyMapModel()) + filename = str(tmpdir.join('model.hdf5')) + export_model_to_hdf5(model, stimuli, filename) + + sub_stimuli = stimuli[[]] + + pysaliency.HDF5Model(sub_stimuli, filename) + + def test_export_model_overwrite(file_stimuli, tmpdir): model1 = pysaliency.GaussianSaliencyMapModel(width=0.1) model2 = pysaliency.GaussianSaliencyMapModel(width=0.8)