diff --git a/pysaliency/datasets.py b/pysaliency/datasets.py index 856a460..7226d81 100644 --- a/pysaliency/datasets.py +++ b/pysaliency/datasets.py @@ -8,6 +8,7 @@ import json import os import pathlib +from typing import Union import warnings from weakref import WeakValueDictionary @@ -1045,13 +1046,6 @@ def get_image_hash(img): return sha1(np.ascontiguousarray(img)).hexdigest() -def as_stimulus(img_or_stimulus): - if isinstance(img_or_stimulus, Stimulus): - return img_or_stimulus - - return Stimulus(img_or_stimulus) - - class Stimulus(object): """ Manages a stimulus. @@ -1087,6 +1081,13 @@ def size(self): return self._size +def as_stimulus(img_or_stimulus: Union[np.ndarray, Stimulus]) -> Stimulus: + if isinstance(img_or_stimulus, Stimulus): + return img_or_stimulus + + return Stimulus(img_or_stimulus) + + class StimuliStimulus(Stimulus): """ Stimulus bound to a Stimuli object @@ -1776,3 +1777,10 @@ def _load_attribute_dict_from_hdf5(attribute_group): attributes = {attribute: attribute_group[attribute][...] for attribute in __attributes__} return attributes + + +def check_prediction_shape(prediction: np.ndarray, stimulus: Union[np.ndarray, Stimulus]): + stimulus = as_stimulus(stimulus) + + if prediction.shape != stimulus.size: + raise ValueError(f"Prediction shape {prediction.shape} does not match stimulus shape {stimulus.size}") \ No newline at end of file diff --git a/pysaliency/models.py b/pysaliency/models.py index 7c02fda..0366151 100755 --- a/pysaliency/models.py +++ b/pysaliency/models.py @@ -16,7 +16,7 @@ DisjointUnionMixin, GaussianSaliencyMapModel, ) -from .datasets import FixationTrains, get_image_hash, as_stimulus +from .datasets import FixationTrains, check_prediction_shape, get_image_hash, as_stimulus from .metrics import probabilistic_image_based_kl_divergence, convert_saliency_map_to_density from .sampling_models import SamplingModelMixin from .utils import Cache, average_values, deprecated_class, remove_trailing_nans, iterator_chunks @@ -155,6 +155,7 @@ def log_likelihoods(self, stimuli, fixations, verbose=False): log_likelihoods = np.empty(len(fixations.x)) for i in tqdm(range(len(fixations.x)), disable=not verbose): conditional_log_density = self.conditional_log_density_for_fixation(stimuli, fixations, i) + check_prediction_shape(conditional_log_density, stimuli[fixations.n[i]]) log_likelihoods[i] = conditional_log_density[fixations.y_int[i], fixations.x_int[i]] return log_likelihoods @@ -331,7 +332,8 @@ def log_likelihoods(self, stimuli, fixations, verbose=False): inds = fixations.n == n if not inds.sum(): continue - log_density = self.log_density(stimuli.stimulus_objects[n]) + log_density = self.log_density(stimuli[n]) + check_prediction_shape(log_density, stimuli[n]) this_log_likelihoods = log_density[fixations.y_int[inds], fixations.x_int[inds]] log_likelihoods[inds] = this_log_likelihoods @@ -372,6 +374,8 @@ def kl_divergences(self, stimuli, gold_standard, log_regularization=0, quotient_ for s in tqdm(stimuli, disable=not verbose): logp_model = self.log_density(s) logp_gold = gold_standard.log_density(s) + check_prediction_shape(logp_model, s) + check_prediction_shape(logp_gold, s) kl_divs.append( probabilistic_image_based_kl_divergence(logp_model, logp_gold, log_regularization=log_regularization, quotient_regularization=quotient_regularization) ) @@ -380,9 +384,9 @@ def kl_divergences(self, stimuli, gold_standard, log_regularization=0, quotient_ def set_params(self, **kwargs): """ - Set model parameters, if the model has parameters + Set model parameters, if the model has parameters - This method has to reset caches etc., if the depend on the parameters + This method has to reset caches etc., if the depend on the parameters """ if kwargs: raise ValueError('Unkown parameters!', kwargs) diff --git a/pysaliency/saliency_map_models.py b/pysaliency/saliency_map_models.py index 236e614..b02c0ac 100644 --- a/pysaliency/saliency_map_models.py +++ b/pysaliency/saliency_map_models.py @@ -16,7 +16,7 @@ from .numba_utils import fill_fixation_map, auc_for_one_positive from .utils import TemporaryDirectory, run_matlab_cmd, Cache, average_values, deprecated_class, remove_trailing_nans -from .datasets import Stimulus, Fixations, get_image_hash +from .datasets import Stimulus, Fixations, check_prediction_shape, get_image_hash from .metrics import CC, NSS, SIM from .sampling_models import SamplingModelMixin @@ -155,6 +155,8 @@ def AUCs(self, stimuli, fixations, nonfixations='uniform', verbose=False): for i in tqdm(range(len(fixations.x)), total=len(fixations.x), disable=not verbose): out = self.conditional_saliency_map_for_fixation(stimuli, fixations, i, out=out) + check_prediction_shape(out, stimuli[fixations.n[i]]) + positive = out[fixations.y_int[i], fixations.x_int[i]] if nonfixations == 'uniform': negatives = out.flatten() @@ -220,6 +222,7 @@ def NSSs(self, stimuli, fixations, verbose=False): for i in tqdm(range(len(fixations.x)), disable=not verbose, total=len(fixations.x)): out = self.conditional_saliency_map_for_fixation(stimuli, fixations, i, out=out) + check_prediction_shape(out, stimuli[fixations.n[i]]) values[i] = NSS(out, fixations.x_int[i], fixations.y_int[i]) return values @@ -331,6 +334,7 @@ def AUCs(self, stimuli, fixations, nonfixations='uniform', verbose=False): if not inds.sum(): continue out = self.saliency_map(stimuli.stimulus_objects[n]) + check_prediction_shape(out, stimuli[n]) positives = np.asarray(out[fixations.y_int[inds], fixations.x_int[inds]]) if nonfixations == 'uniform': negatives = out.flatten() @@ -407,6 +411,7 @@ def AUC_per_image(self, stimuli, fixations, nonfixations='uniform', thresholds=' for n in tqdm(range(len(stimuli)), disable=not verbose): out = self.saliency_map(stimuli.stimulus_objects[n]) + check_prediction_shape(out, stimuli[n]) inds = fixations.n == n positives = np.asarray(out[fixations.y_int[inds], fixations.x_int[inds]]) if nonfixations == 'uniform': @@ -533,7 +538,8 @@ def fixation_based_KL_divergence(self, stimuli, fixations, nonfixations='shuffle saliency_max = -np.inf for n in range(len(stimuli.stimuli)): - saliency_map = self.saliency_map(stimuli.stimulus_objects[n]) + saliency_map = self.saliency_map(stimuli[n]) + check_prediction_shape(saliency_map, stimuli[n]) saliency_min = min(saliency_min, saliency_map.min()) saliency_max = max(saliency_max, saliency_map.max()) @@ -631,7 +637,13 @@ def CCs(self, stimuli, other, verbose=False): coeffs = [] for s in tqdm(stimuli, disable=not verbose): - coeffs.append(CC(self.saliency_map(s), other.saliency_map(s))) + saliency_map_self = self.saliency_map(s) + saliency_map_other = other.saliency_map(s) + + check_prediction_shape(saliency_map_self, s) + check_prediction_shape(saliency_map_other, s) + + coeffs.append(CC(saliency_map_self, saliency_map_other)) return np.asarray(coeffs) @@ -645,6 +657,7 @@ def NSSs(self, stimuli, fixations, verbose=False): if not inds.sum(): continue smap = self.saliency_map(s).copy() + check_prediction_shape(smap, s) values[inds] = NSS(smap, fixations.x_int[inds], fixations.y_int[inds]) return values @@ -660,6 +673,10 @@ def SIMs(self, stimuli, other, verbose=False): for s in tqdm(stimuli, disable=not verbose): smap1 = self.saliency_map(s) smap2 = other.saliency_map(s) + + check_prediction_shape(smap1, s) + check_prediction_shape(smap2, s) + values.append(SIM(smap1, smap2)) return np.asarray(values) diff --git a/pysaliency/utils.py b/pysaliency/utils.py index 7aa6f69..407d4e3 100644 --- a/pysaliency/utils.py +++ b/pysaliency/utils.py @@ -505,4 +505,4 @@ def iterator_chunks(iterable, chunk_size=10): counter = count() for _, g in groupby(iterable, lambda _: next(counter) // chunk_size): - yield g + yield g \ No newline at end of file diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 7e39e65..610f55a 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -1,19 +1,19 @@ -from __future__ import absolute_import, print_function, division +from __future__ import absolute_import, division, print_function -import unittest import os.path -import dill import pickle -import pytest +import unittest +import dill import numpy as np +import pytest +from hypothesis import given +from hypothesis import strategies as st from imageio import imwrite - -from hypothesis import given, strategies as st +from test_helpers import TestWithData import pysaliency -from pysaliency.datasets import FixationTrains, Fixations, scanpaths_from_fixations -from test_helpers import TestWithData +from pysaliency.datasets import Fixations, FixationTrains, Stimulus, check_prediction_shape, scanpaths_from_fixations def compare_fixations_subset(f1, f2, f2_inds): @@ -780,5 +780,40 @@ def test_scanpaths_from_fixations(fixation_indices): compare_fixations(sub_fixations, new_sub_fixations, crop_length=True) +def test_check_prediction_shape(): + # Test with matching shapes + prediction = np.random.rand(10, 10) + stimulus = np.random.rand(10, 10) + check_prediction_shape(prediction, stimulus) # Should not raise any exception + + # Test with matching shapes, colorimage + prediction = np.random.rand(10, 10) + stimulus = np.random.rand(10, 10, 3) + check_prediction_shape(prediction, stimulus) # Should not raise any exception + + # Test with mismatching shapes + prediction = np.random.rand(10, 10) + stimulus = np.random.rand(10, 11) + with pytest.raises(ValueError) as excinfo: + check_prediction_shape(prediction, stimulus) + assert str(excinfo.value) == "Prediction shape (10, 10) does not match stimulus shape (10, 11)" + + # Test with Stimulus object + prediction = np.random.rand(10, 10) + stimulus = Stimulus(np.random.rand(10, 10)) + check_prediction_shape(prediction, stimulus) # Should not raise any exception + + # Test with Stimulus object + prediction = np.random.rand(10, 10) + stimulus = Stimulus(np.random.rand(10, 10, 3)) + check_prediction_shape(prediction, stimulus) # Should not raise any exception + + # Test with mismatching shapes and Stimulus object + prediction = np.random.rand(10, 10) + stimulus = Stimulus(np.random.rand(10, 11)) + with pytest.raises(ValueError) as excinfo: + check_prediction_shape(prediction, stimulus) + assert str(excinfo.value) == "Prediction shape (10, 10) does not match stimulus shape (10, 11)" + if __name__ == '__main__': unittest.main()