diff --git a/pysaliency/precomputed_models.py b/pysaliency/precomputed_models.py index 8dc3bed..8fbb0d4 100644 --- a/pysaliency/precomputed_models.py +++ b/pysaliency/precomputed_models.py @@ -183,6 +183,29 @@ def _log_density(self, stimulus): return smap +def get_keys_recursive(group, prefix=''): + import h5py + + keys = [] + + for subgroup_name, subgroup in group.items(): + if isinstance(subgroup, h5py.Group): + subprefix = f"{prefix}{subgroup_name}/" + keys.extend(get_keys_recursive(subgroup, prefix=subprefix)) + else: + keys.append(f"{prefix}{subgroup_name}") + + return keys + +def get_stimulus_key(stimulus_name, all_keys): + matching_keys = [key for key in all_keys if key.endswith(stimulus_name)] + if len(matching_keys) == 0: + raise ValueError(f"Stimulus {stimulus_name} not found in hdf5 file!") + elif len(matching_keys) > 1: + raise ValueError(f"Stimulus {stimulus_name} not unique in hdf5 file!") + return matching_keys[0] + + class HDF5SaliencyMapModel(SaliencyMapModel): """ exposes a HDF5 file with saliency maps as pysaliency model @@ -203,15 +226,17 @@ def __init__(self, stimuli, filename, check_shape=True, **kwargs): import h5py self.hdf5_file = h5py.File(self.filename, 'r') + self.all_keys = get_keys_recursive(self.hdf5_file) def _saliency_map(self, stimulus): stimulus_id = get_image_hash(stimulus) stimulus_index = self.stimuli.stimulus_ids.index(stimulus_id) stimulus_filename = self.names[stimulus_index] - smap = self.hdf5_file[stimulus_filename][:] + stimulus_key = get_stimulus_key(stimulus_filename, self.all_keys) + smap = self.hdf5_file[stimulus_key][:] if not smap.shape == (stimulus.shape[0], stimulus.shape[1]): if self.check_shape: - warnings.warn('Wrong shape for stimulus {}'.format(stimulus_filename)) + warnings.warn('Wrong shape for stimulus {}'.format(stimulus_key)) return smap diff --git a/tests/test_precomputed_models.py b/tests/test_precomputed_models.py index 66ab172..1b25a57 100644 --- a/tests/test_precomputed_models.py +++ b/tests/test_precomputed_models.py @@ -17,7 +17,8 @@ def file_stimuli(tmpdir): filenames = [] for i in range(3): - filename = tmpdir.join('stimulus_{:04d}.png'.format(i)) + # TODO: change back to stimulus_... once this is supported again + filename = tmpdir.join('_stimulus_{:04d}.png'.format(i)) imsave(str(filename), np.random.randint(low=0, high=255, size=(100, 100, 3), dtype=np.uint8)) filenames.append(str(filename)) @@ -36,7 +37,8 @@ def stimuli_with_filenames(tmpdir): filenames = [] stimuli = [] for i in range(3): - filename = tmpdir.join('stimulus_{:04d}.png'.format(i)) + # TODO: change back to stimulus_... once this is supported again + filename = tmpdir.join('_stimulus_{:04d}.png'.format(i)) stimuli.append(np.random.randint(low=0, high=255, size=(100, 100, 3), dtype=np.uint8)) filenames.append(str(filename))