Skip to content

Commit

Permalink
Added functionality to use combined centerbias (#45)
Browse files Browse the repository at this point in the history
* Merge branch 'dev' of github.com:naman0210/pysaliency into dev

* Update test_precomputed_models.py

* Update tests/test_precomputed_models.py

added comment

* Update tests/test_precomputed_models.py

---------

Co-authored-by: Harneet Singh Khanuja <[email protected]>
Co-authored-by: matthias-k <[email protected]>
  • Loading branch information
3 people authored Jan 22, 2024
1 parent e32e2af commit 332bfb6
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 4 deletions.
29 changes: 27 additions & 2 deletions pysaliency/precomputed_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,29 @@ def _log_density(self, stimulus):
return smap


def get_keys_recursive(group, prefix=''):
import h5py

keys = []

for subgroup_name, subgroup in group.items():
if isinstance(subgroup, h5py.Group):
subprefix = f"{prefix}{subgroup_name}/"
keys.extend(get_keys_recursive(subgroup, prefix=subprefix))
else:
keys.append(f"{prefix}{subgroup_name}")

return keys

def get_stimulus_key(stimulus_name, all_keys):
matching_keys = [key for key in all_keys if key.endswith(stimulus_name)]
if len(matching_keys) == 0:
raise ValueError(f"Stimulus {stimulus_name} not found in hdf5 file!")
elif len(matching_keys) > 1:
raise ValueError(f"Stimulus {stimulus_name} not unique in hdf5 file!")
return matching_keys[0]


class HDF5SaliencyMapModel(SaliencyMapModel):
""" exposes a HDF5 file with saliency maps as pysaliency model
Expand All @@ -203,15 +226,17 @@ def __init__(self, stimuli, filename, check_shape=True, **kwargs):

import h5py
self.hdf5_file = h5py.File(self.filename, 'r')
self.all_keys = get_keys_recursive(self.hdf5_file)

def _saliency_map(self, stimulus):
stimulus_id = get_image_hash(stimulus)
stimulus_index = self.stimuli.stimulus_ids.index(stimulus_id)
stimulus_filename = self.names[stimulus_index]
smap = self.hdf5_file[stimulus_filename][:]
stimulus_key = get_stimulus_key(stimulus_filename, self.all_keys)
smap = self.hdf5_file[stimulus_key][:]
if not smap.shape == (stimulus.shape[0], stimulus.shape[1]):
if self.check_shape:
warnings.warn('Wrong shape for stimulus {}'.format(stimulus_filename))
warnings.warn('Wrong shape for stimulus {}'.format(stimulus_key))
return smap


Expand Down
6 changes: 4 additions & 2 deletions tests/test_precomputed_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
def file_stimuli(tmpdir):
filenames = []
for i in range(3):
filename = tmpdir.join('stimulus_{:04d}.png'.format(i))
# TODO: change back to stimulus_... once this is supported again
filename = tmpdir.join('_stimulus_{:04d}.png'.format(i))
imsave(str(filename), np.random.randint(low=0, high=255, size=(100, 100, 3), dtype=np.uint8))
filenames.append(str(filename))

Expand All @@ -36,7 +37,8 @@ def stimuli_with_filenames(tmpdir):
filenames = []
stimuli = []
for i in range(3):
filename = tmpdir.join('stimulus_{:04d}.png'.format(i))
# TODO: change back to stimulus_... once this is supported again
filename = tmpdir.join('_stimulus_{:04d}.png'.format(i))
stimuli.append(np.random.randint(low=0, high=255, size=(100, 100, 3), dtype=np.uint8))
filenames.append(str(filename))

Expand Down

0 comments on commit 332bfb6

Please sign in to comment.