Skip to content

Commit

Permalink
Adapt most code to use ScanpathFixations instead of FixationTrains
Browse files Browse the repository at this point in the history
Signed-off-by: Matthias Kümmerer <[email protected]>
  • Loading branch information
matthias-k committed Apr 11, 2024
1 parent ff50f41 commit a0c8f47
Show file tree
Hide file tree
Showing 18 changed files with 253 additions and 237 deletions.
42 changes: 21 additions & 21 deletions notebooks/Tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@
],
"source": [
"rst = np.random.RandomState(seed=23)\n",
"scanpath_indices = rst.randint(len(fixations.train_xs), size=3)\n",
"scanpath_indices = rst.randint(len(fixations.scanpaths), size=3)\n",
"\n",
"for scanpath_index in scanpath_indices:\n",
" print(f\"Scanpath no {scanpath_index}:\")\n",
Expand Down Expand Up @@ -402,7 +402,7 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 12,
"id": "d7a92d74-4512-42c7-b21f-e62e211df361",
"metadata": {},
"outputs": [
Expand Down Expand Up @@ -446,7 +446,7 @@
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": 13,
"id": "f2085fb4-6a30-426a-8a47-0da876d92ef7",
"metadata": {},
"outputs": [
Expand All @@ -472,7 +472,7 @@
},
{
"cell_type": "code",
"execution_count": 19,
"execution_count": 14,
"id": "658a06b0-f339-41d7-9f74-bba4ee9bebb3",
"metadata": {},
"outputs": [
Expand Down Expand Up @@ -519,7 +519,7 @@
},
{
"cell_type": "code",
"execution_count": 20,
"execution_count": 15,
"id": "a8068b3c-c1ca-4a64-adae-a7ac769e3a82",
"metadata": {},
"outputs": [
Expand Down Expand Up @@ -547,7 +547,7 @@
},
{
"cell_type": "code",
"execution_count": 21,
"execution_count": 16,
"id": "106515c4-3f4c-43f6-8d58-501045132c0f",
"metadata": {},
"outputs": [],
Expand Down Expand Up @@ -593,7 +593,7 @@
},
{
"cell_type": "code",
"execution_count": 22,
"execution_count": 17,
"id": "27fbe076-72d3-4441-a095-69a455e2a210",
"metadata": {},
"outputs": [
Expand Down Expand Up @@ -634,7 +634,7 @@
},
{
"cell_type": "code",
"execution_count": 23,
"execution_count": 18,
"id": "b1b6dad7-e6a9-4a66-b924-8dbc3312ec94",
"metadata": {},
"outputs": [
Expand Down Expand Up @@ -667,7 +667,7 @@
},
{
"cell_type": "code",
"execution_count": 24,
"execution_count": 19,
"id": "b66e64ed-ff99-4932-98f8-35cabf60c26c",
"metadata": {},
"outputs": [
Expand Down Expand Up @@ -712,7 +712,7 @@
},
{
"cell_type": "code",
"execution_count": 25,
"execution_count": 20,
"id": "b20ad889-dbeb-41d6-b79d-1adb0b23150a",
"metadata": {},
"outputs": [
Expand Down Expand Up @@ -783,7 +783,7 @@
},
{
"cell_type": "code",
"execution_count": 28,
"execution_count": 21,
"id": "9696744b-6cf7-4c43-8fe8-93a6f1a0fa4d",
"metadata": {},
"outputs": [],
Expand Down Expand Up @@ -842,7 +842,7 @@
},
{
"cell_type": "code",
"execution_count": 29,
"execution_count": 22,
"id": "3f010b8c-d368-4c5c-b241-1ebf88219b94",
"metadata": {},
"outputs": [
Expand Down Expand Up @@ -889,7 +889,7 @@
},
{
"cell_type": "code",
"execution_count": 30,
"execution_count": 23,
"id": "e74fe712-34fa-4201-8415-8542c0080ffc",
"metadata": {},
"outputs": [],
Expand Down Expand Up @@ -922,7 +922,7 @@
},
{
"cell_type": "code",
"execution_count": 32,
"execution_count": 24,
"id": "38220bad-2512-465d-8d90-f4bfc7426667",
"metadata": {},
"outputs": [
Expand All @@ -935,7 +935,7 @@
" 1.67418705, 3.50254012, 3.13763833, 2.68713872, 1.0400994 ])"
]
},
"execution_count": 32,
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
Expand Down Expand Up @@ -963,7 +963,7 @@
},
{
"cell_type": "code",
"execution_count": 33,
"execution_count": 25,
"id": "9e032691-527e-456e-8639-7f360e4582e3",
"metadata": {},
"outputs": [
Expand Down Expand Up @@ -991,7 +991,7 @@
},
{
"cell_type": "code",
"execution_count": 34,
"execution_count": 26,
"id": "442b8e2e-24a6-44c0-9992-e37336317962",
"metadata": {},
"outputs": [
Expand Down Expand Up @@ -1037,7 +1037,7 @@
},
{
"cell_type": "code",
"execution_count": 35,
"execution_count": 27,
"id": "3255672f-e3df-4b1d-8cad-cb8e7a39fcd2",
"metadata": {},
"outputs": [],
Expand All @@ -1047,7 +1047,7 @@
},
{
"cell_type": "code",
"execution_count": 36,
"execution_count": 28,
"id": "f8c137e6-0fa0-4f19-bf92-75e93ed6bbab",
"metadata": {},
"outputs": [
Expand Down Expand Up @@ -1112,7 +1112,7 @@
},
{
"cell_type": "code",
"execution_count": 37,
"execution_count": 29,
"id": "c17d068c-96c2-410b-8a62-fc229ad45453",
"metadata": {},
"outputs": [
Expand Down Expand Up @@ -1141,7 +1141,7 @@
},
{
"cell_type": "code",
"execution_count": 38,
"execution_count": 30,
"id": "87c10abe-57a2-4bbc-937f-74488a2636e3",
"metadata": {},
"outputs": [
Expand Down
6 changes: 3 additions & 3 deletions pysaliency/baseline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import numba
import numpy as np
from boltons.iterutils import chunked
from scipy.ndimage.filters import gaussian_filter
from scipy.ndimage import gaussian_filter
from scipy.special import logsumexp
from sklearn.base import BaseEstimator, DensityMixin
from sklearn.model_selection import cross_val_score
Expand Down Expand Up @@ -137,7 +137,7 @@ def __iter__(self):
for n in range(len(self.stimuli)):
for s in range(self.fixations.subject_count):
image_inds = self.fixations.n == n
subject_inds = self.fixations.subjects == s
subject_inds = self.fixations.subject == s
train_inds, test_inds = image_inds & ~subject_inds, image_inds & subject_inds
if test_inds.sum() == 0 or train_inds.sum() == 0:
#print("Skipping")
Expand All @@ -152,7 +152,7 @@ def __iter__(self):
yield train_inds, test_inds

def __len__(self):
return len(set(zip(self.fixations.n, self.fixations.subjects)))
return len(set(zip(self.fixations.n, self.fixations.subject)))


class ScikitLearnWithinImageCrossValidationGenerator(object):
Expand Down
2 changes: 1 addition & 1 deletion pysaliency/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def create_subset(stimuli, fixations, stimuli_indices):

new_image_indices = [new_pos[i] for i in fixations.scanpaths.n[scanpath_inds]]

new_fixations = fixations.filter_fixation_trains(scanpath_inds)
new_fixations = fixations.filter_scanpaths(scanpath_inds)
new_fixations.scanpaths.n = np.array(new_image_indices)

new_fixation_ns = [new_pos[i] for i in new_fixations.n]
Expand Down
12 changes: 9 additions & 3 deletions pysaliency/datasets/fixations.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import numpy as np
from tqdm import tqdm

from ..utils import remove_trailing_nans
from ..utils import remove_trailing_nans, deprecated_class
from ..utils.variable_length_array import VariableLengthArray
from .scanpaths import Scanpaths
from .utils import _load_attribute_dict_from_hdf5, concatenate_attributes, decode_string, get_merged_attribute_list, hdf5_wrapper
Expand Down Expand Up @@ -315,7 +315,10 @@ def to_hdf5(self, target):
variable_length_arrays = []

for attribute in ['x', 'y', 't', 'x_hist', 'y_hist', 't_hist', 'n', 'lengths'] + self.__attributes__:
data = getattr(self, attribute)
if attribute == 'lengths':
data = self.scanpath_history_length
else:
data = getattr(self, attribute)
if isinstance(data, VariableLengthArray):
variable_length_arrays.append(attribute)
data = data._data
Expand Down Expand Up @@ -528,6 +531,7 @@ def read_hdf5(cls, source):
return cls(scanpaths=scanpaths)


@deprecated_class(deprecated_in="0.3.0", removed_in="1.0.0", details="Use `ScanpathFixations` instead")
class FixationTrains(ScanpathFixations):
"""
Capsules the fixations of a dataset as fixation trains.
Expand All @@ -552,6 +556,8 @@ class FixationTrains(ScanpathFixations):
"""
def __init__(self, train_xs, train_ys, train_ts, train_ns, train_subjects, scanpath_attributes=None, scanpath_fixation_attributes=None, attributes=None, scanpath_attribute_mapping=None, scanpaths=None):

# raise ValueError("DON'T USE FIXATIONTRAINS ANYMORE, USE SCANPATHFIXATIONS INSTEAD")

if isinstance(train_xs, Scanpaths):
scanpaths = train_xs

Expand Down Expand Up @@ -580,7 +586,7 @@ def __init__(self, train_xs, train_ys, train_ts, train_ns, train_subjects, scanp
attributes = {}

if attributes:
warnings.warn("don't use attributes for FixationTrains, use scanpath_attributes or scanpath_fixation_attributes instead!", stacklevel=2)
warnings.warn("Don't use attributes for FixationTrains, use scanpath_attributes or scanpath_fixation_attributes instead! FixationTrains is deprecated, the successor ScanpathFixations doesn't support attributes anymore", stacklevel=2, category=DeprecationWarning)

if attributes:
self.__attributes__ = list(self.__attributes__)
Expand Down
3 changes: 2 additions & 1 deletion pysaliency/datasets/scanpaths.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def __init__(self,
length=None,
scanpath_attributes: Optional[Dict[str, np.ndarray]] = None,
fixation_attributes: Optional[Dict[str, Union[np.ndarray, VariableLengthArray]]]=None,
attribute_mapping=Dict[str, str],
attribute_mapping: Optional[Dict[str, str]] = None,
**kwargs):

self.n = np.asarray(n)
Expand All @@ -57,6 +57,7 @@ def __init__(self,
scanpath_attributes = scanpath_attributes or {}
fixation_attributes = fixation_attributes or {}
self.attribute_mapping = attribute_mapping or {}
self.attribute_mapping = dict(self.attribute_mapping)


for key, value in kwargs.items():
Expand Down
2 changes: 1 addition & 1 deletion pysaliency/external_datasets/mit.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def check_size(f):
subject_path = os.path.join('DATA', subject)
outfile = '{0}_{1}.mat'.format(stimulus, subject)
outfile = os.path.join(out_path, outfile)
cmds.append("fprintf('%d/%d\\r', {}, {});".format(n * len(subjects) + subject_id, total_cmd_count))
cmds.append("fprintf('Processing scanpath %d/%d\\r', {}, {});".format(n * len(subjects) + subject_id, total_cmd_count))
cmds.append("extract_fixations('{0}', '{1}', '{2}');".format(stimulus, subject_path, outfile))

print('Running original code to extract fixations. This can take some minutes.')
Expand Down
20 changes: 10 additions & 10 deletions pysaliency/filter_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from boltons.iterutils import chunked

from .datasets import create_subset, FixationTrains, Fixations, Stimuli
from .datasets import create_subset, FixationTrains, Fixations, Stimuli, ScanpathFixations


def train_split(stimuli, fixations, crossval_folds, fold_no, val_folds=1, test_folds=1, random=True, stratified_attributes=None):
Expand Down Expand Up @@ -234,17 +234,17 @@ def filter_stimuli_by_size(stimuli, fixations, size=None, sizes=None):
return create_subset(stimuli, fixations, indices)


def filter_scanpaths_by_attribute(scanpaths: FixationTrains, attribute_name, attribute_value, invert_match=False):
"""Filter Scanpaths by values of scanpath attribute (fixation_trains.scanpath_attributes)"""
def filter_scanpaths_by_attribute(scanpaths: ScanpathFixations, attribute_name, attribute_value, invert_match=False):
"""Filter Scanpaths by values of scanpath attribute (fixation_trains.scanpaths.scanpath_attributes)"""

mask = scanpaths.scanpath_attributes[attribute_name] == attribute_value
mask = scanpaths.scanpaths.scanpath_attributes[attribute_name] == attribute_value
if mask.ndim > 1:
mask = np.all(mask, axis=1)

if invert_match is True:
mask = ~mask

return scanpaths.filter_fixation_trains(mask)
return scanpaths.filter_scanpaths(mask)


def filter_fixations_by_attribute(fixations: Fixations, attribute_name, attribute_value, invert_match=False):
Expand Down Expand Up @@ -280,20 +280,20 @@ def filter_stimuli_by_attribute(stimuli: Stimuli, fixations: Fixations, attribut
return create_subset(stimuli, fixations, indices)


def filter_scanpaths_by_length(scanpaths: FixationTrains, intervals: list):
def filter_scanpaths_by_length(scanpath_fixations: ScanpathFixations, intervals: list):
"""Filter Scanpaths by number of fixations"""

intervals = _check_intervals(intervals, type=int)
mask = np.zeros(len(scanpaths.train_lengths), dtype=bool)
mask = np.zeros(len(scanpath_fixations.scanpaths), dtype=bool)
for start, end in intervals:
temp_mask = np.logical_and(
scanpaths.train_lengths >= start, scanpaths.train_lengths < end)
scanpath_fixations.scanpaths.length >= start, scanpath_fixations.scanpaths.length < end)
mask = np.logical_or(mask, temp_mask)
indices = list(np.nonzero(mask)[0])

scanpaths = scanpaths.filter_fixation_trains(indices)
scanpath_fixations = scanpath_fixations.filter_scanpaths(indices)

return scanpaths
return scanpath_fixations


def remove_stimuli_without_fixations(stimuli: Stimuli, fixations: Fixations):
Expand Down
6 changes: 3 additions & 3 deletions pysaliency/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
DisjointUnionMixin,
GaussianSaliencyMapModel,
)
from .datasets import FixationTrains, check_prediction_shape, get_image_hash, as_stimulus
from .datasets import Scanpaths, ScanpathFixations, check_prediction_shape, get_image_hash, as_stimulus
from .metrics import probabilistic_image_based_kl_divergence, convert_saliency_map_to_density
from .sampling_models import SamplingModelMixin
from .utils import Cache, average_values, deprecated_class, remove_trailing_nans, iterator_chunks
Expand Down Expand Up @@ -212,7 +212,7 @@ def _expand_sample_arguments(self, stimuli, train_counts, lengths=None, stimulus

return stimuli, train_counts, lengths, stimulus_indices

def sample(self, stimuli, train_counts, lengths=1, stimulus_indices=None, rst=None, verbose=False):
def sample(self, stimuli, train_counts, lengths=1, stimulus_indices=None, rst=None, verbose=False) -> ScanpathFixations:
"""
Sample fixations for given stimuli
Expand Down Expand Up @@ -261,7 +261,7 @@ def sample(self, stimuli, train_counts, lengths=1, stimulus_indices=None, rst=No
ns.append(stimulus_index)
subjects.append(0)
pbar.update(1)
return FixationTrains.from_fixation_trains(xs, ys, ts, ns, subjects)
return ScanpathFixations(Scanpaths(xs=xs, ys=ys, ts=ts, n=ns, subject=subjects))

def _sample_fixation_train(self, stimulus, length, rst=None):
"""Sample one fixation train of given length from stimulus"""
Expand Down
6 changes: 3 additions & 3 deletions pysaliency/saliency_map_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -956,12 +956,12 @@ def __init__(self, subject_models, **kwargs):

def _split_fixations(self, stimuli, fixations):
for s in self.subject_models:
yield fixations.subjects == s, self.subject_models[s]
yield fixations.subject == s, self.subject_models[s]

def conditional_saliency_map(self, stimulus, x_hist, y_hist, t_hist, attributes=None, out=None, **kwargs):
if 'subjects' not in attributes:
if 'subject' not in attributes:
raise ValueError("SubjectDependentSaliencyModel can't compute conditional saliency maps without subject indication!")
return self.subject_models[attributes['subjects']].conditional_saliency_map(
return self.subject_models[attributes['subject']].conditional_saliency_map(
stimulus, x_hist, y_hist, t_hist, attributes=attributes, **kwargs)


Expand Down
Loading

0 comments on commit a0c8f47

Please sign in to comment.