diff --git a/pysaliency/datasets.py b/pysaliency/datasets.py index 7226d81..eba041b 100644 --- a/pysaliency/datasets.py +++ b/pysaliency/datasets.py @@ -1,19 +1,18 @@ -# vim: set expandtab : -#kate: space-indent on; indent-width 4; backspace-indents on; -from __future__ import absolute_import, print_function, division, unicode_literals - -from collections.abc import Sequence -from functools import wraps -from hashlib import sha1 import json import os import pathlib -from typing import Union import warnings +from collections.abc import Sequence +from functools import wraps +from hashlib import sha1 +from typing import Dict, Optional, Union from weakref import WeakValueDictionary -from boltons.cacheutils import cached import numpy as np +from boltons.cacheutils import cached + +from .utils.variable_length_array import VariableLengthArray + try: from imageio.v3 import imread except ImportError: @@ -21,7 +20,7 @@ from PIL import Image from tqdm import tqdm -from .utils import LazyList, build_padded_2d_array, remove_trailing_nans +from .utils import LazyList, remove_trailing_nans def hdf5_wrapper(mode=None): @@ -76,6 +75,8 @@ def read_hdf5(source): return Fixations.read_hdf5(source) elif data_type == 'FixationTrains': return FixationTrains.read_hdf5(source) + elif data_type == 'Scanpaths': + return Scanpaths.read_hdf5(source) elif data_type == 'Stimuli': return Stimuli.read_hdf5(source) elif data_type == 'FileStimuli': @@ -784,65 +785,6 @@ def generate_crossval(self, splitcount = 10): train_subjects_eval) return fixations_training, fixations_evaluation -# def generate_nonfixations(self, seed=42): -# """Generate nonfixational distribution from this -# fixation object by shuffling the images of the -# fixation trains. The individual fixation trains -# will be left intact""" -# train_xs = self.train_xs.copy() -# train_ys = self.train_ys.copy() -# train_ts = self.train_ts.copy() -# train_ns = self.train_ns.copy() -# train_subjects = self.train_subjects.copy() -# max_n = train_ns.max() -# rs = np.random.RandomState(seed) -# for i in range(len(train_ns)): -# old_n = train_ns[i] -# new_ns = range(0, old_n)+range(old_n+1, max_n+1) -# new_n = rs.choice(new_ns) -# train_ns[i] = new_n -# return type(self)(train_xs, train_ys, train_ts, train_ns, train_subjects) -# -# def generate_more_nonfixations(self, count=1, seed=42): -# """Generate nonfixational distribution from this -# fixation object by assining each fixation -# train to $count other images. -# -# with count=0, each train will be assigned to all -# other images""" -# train_xs = [] -# train_ys = [] -# train_ts = [] -# train_ns = [] -# train_subjects = [] -# max_n = self.train_ns.max() -# if count == 0: -# count = max_n-1 -# rs = np.random.RandomState(seed) -# for i in range(len(self.train_ns)): -# old_n = self.train_ns[i] -# new_ns = range(0, old_n)+range(old_n+1, max_n+1) -# new_ns = rs.choice(new_ns, size=count, replace=False) -# for new_n in new_ns: -# train_xs.append(self.train_xs[i]) -# train_ys.append(self.train_ys[i]) -# train_ts.append(self.train_ts[i]) -# train_ns.append(new_n) -# train_subjects.append(self.train_subjects[i]) -# train_xs = np.vstack(train_xs) -# train_ys = np.vstack(train_ys) -# train_ts = np.vstack(train_ts) -# train_ns = np.hstack(train_ns) -# train_subjects = np.hstack(train_subjects) -# # reorder -# inds = np.argsort(train_ns) -# train_xs = train_xs[inds] -# train_ys = train_ys[inds] -# train_ts = train_ts[inds] -# train_ns = train_ns[inds] -# train_subjects = train_subjects[inds] -# return type(self)(train_xs, train_ys, train_ts, train_ns, train_subjects) - def shuffle_fixations(self, stimuli=None): new_indices = [] new_ns = [] @@ -1035,6 +977,166 @@ def read_hdf5(cls, source): return fixations +class Scanpaths(object): + """ + Represents a collection of scanpaths. + + Attributes: + xs (VariableLengthArray): The x-coordinates of the scanpaths. + ys (VariableLengthArray): The y-coordinates of the scanpaths. + ns (np.ndarray): The number of fixations in each scanpath. + lengths (np.ndarray): The lengths of each scanpath. + scanpath_attributes (dict): Additional attributes associated with the scanpaths. + fixation_attributes (dict): Additional attributes associated with the fixations in the scanpaths. + attribute_mapping (dict): Mapping of attribute names to their corresponding values, will be used when creating `Fixations` instances from the `Scanpaths` instance. + for example {'durations': 'duration'} + """ + + xs: VariableLengthArray + ys: VariableLengthArray + ns: np.ndarray + + def __init__(self, + xs: Union[np.ndarray, VariableLengthArray], + ys: Union[np.ndarray, VariableLengthArray], + ns: np.ndarray, + lengths=None, + scanpath_attributes: Optional[Dict[str, np.ndarray]] = None, + fixation_attributes: Optional[Dict[str, Union[np.ndarray, VariableLengthArray]]]=None, + attribute_mapping=Dict[str, str]): + + self.ns = np.asarray(ns) + + if not isinstance(xs, VariableLengthArray): + self.xs = VariableLengthArray(xs, lengths) + else: + self.xs = xs + + if lengths is not None: + if not np.all(self.xs.lengths == lengths): + raise ValueError("Lengths of xs and lengths do not match") + + self.lengths = self.xs.lengths.copy() + + self.ys = self._as_variable_length_array(ys) + + if not len(self.xs) == len(self.ys) == len(self.ns): + raise ValueError("Length of xs, ys, ts and ns has to match") + + # setting scanpath attributes + + scanpath_attributes = scanpath_attributes or {} + self.scanpath_attributes = {key: np.array(value) for key, value in scanpath_attributes.items()} + + for key, value in self.scanpath_attributes.items(): + if not len(value) == len(self.xs): + raise ValueError(f"Length of scanpath attribute {key} has to match number of scanpaths, but got {len(value)} != {len(self.xs)}") + + # setting fixation attributes + + fixation_attributes = fixation_attributes or {} + + self.fixation_attributes = {key: self._as_variable_length_array(value) for key, value in fixation_attributes.items()} + + self.attribute_mapping = attribute_mapping or {} + + def _check_lengths(self, other: VariableLengthArray): + if not len(self) == len(other): + raise ValueError("Length of scanpaths has to match") + if not np.all(self.lengths == other.lengths): + raise ValueError("Lengths of scanpaths have to match") + + def _as_variable_length_array(self, data: Union[np.ndarray, VariableLengthArray]) -> VariableLengthArray: + if not isinstance(data, VariableLengthArray): + data = VariableLengthArray(data, self.lengths) + + self._check_lengths(data) + + return data + + def __len__(self): + return len(self.xs) + + @hdf5_wrapper(mode='w') + def to_hdf5(self, target): + """ Write scanpaths to hdf5 file or hdf5 group + """ + target.attrs['type'] = np.string_('Scanpaths') + target.attrs['version'] = np.string_('1.0') + + target.create_dataset('xs', data=self.xs._data) + target.create_dataset('ys', data=self.ys._data) + target.create_dataset('ns', data=self.ns) + target.create_dataset('lengths', data=self.lengths) + + scanpath_attributes_group = target.create_group('scanpath_attributes') + for attribute_name, attribute_value in self.scanpath_attributes.items(): + create_hdf5_dataset(scanpath_attributes_group, attribute_name, attribute_value) + scanpath_attributes_group.attrs['__attributes__'] = np.string_(json.dumps(sorted(self.scanpath_attributes.keys()))) + + fixation_attributes_group = target.create_group('fixation_attributes') + for attribute_name, attribute_value in self.fixation_attributes.items(): + fixation_attributes_group.create_dataset(attribute_name, data=attribute_value._data) + fixation_attributes_group.attrs['__attributes__'] = np.string_(json.dumps(sorted(self.fixation_attributes.keys()))) + + target.attrs['attribute_mapping'] = np.string_(json.dumps(self.attribute_mapping)) + + + @classmethod + @hdf5_wrapper(mode='r') + def read_hdf5(cls, source): + data_type = decode_string(source.attrs['type']) + data_version = decode_string(source.attrs['version']) + + if data_type != 'Scanpaths': + raise ValueError("Invalid type! Expected 'Scanpaths', got", data_type) + + valid_versions = ['1.0'] + if data_version not in valid_versions: + raise ValueError("Invalid version! Expected one of {}, got {}".format(', '.join(valid_versions), data_version)) + + lengths = source['lengths'][...] + xs = VariableLengthArray(source['xs'][...], lengths) + ys = VariableLengthArray(source['ys'][...], lengths) + ns = source['ns'][...] + + scanpath_attributes = _load_attribute_dict_from_hdf5(source['scanpath_attributes']) + + fixation_attributes_group = source['fixation_attributes'] + json_attributes = fixation_attributes_group.attrs['__attributes__'] + if not isinstance(json_attributes, str): + json_attributes = json_attributes.decode('utf8') + __attributes__ = json.loads(json_attributes) + + fixation_attributes = {attribute: VariableLengthArray(fixation_attributes_group[attribute][...], lengths) for attribute in __attributes__} + + return cls( + xs=xs, + ys=ys, + ns=ns, + lengths=lengths, + scanpath_attributes=scanpath_attributes, + fixation_attributes=fixation_attributes, + attribute_mapping=json.loads(decode_string(source.attrs['attribute_mapping'])) + ) + + def __getitem__(self, index): + # TODO + # - integer to return single scanpath + # - 2d index to return single Fixation (for now via index of scanpath and index of fixation in scanpath) + # - 2d index array to return Fixations instance (for now via index of scanpath and index of fixation in scanpath) + + if isinstance(index, tuple): + raise NotImplementedError("Not implemented yet") + elif isinstance(index, int): + raise NotImplementedError("Not implemented yet") + else: + return type(self)(self.xs[index], self.ys[index], self.ns[index], self.lengths[index], + scanpath_attributes={key: value[index] for key, value in self.scanpath_attributes.items()}, + fixation_attributes={key: value[index] for key, value in self.fixation_attributes.items()}, + attribute_mapping=self.attribute_mapping) + + def get_image_hash(img): """ Calculate a unique hash for the given image. diff --git a/tests/datasets/test_scanpaths.py b/tests/datasets/test_scanpaths.py new file mode 100644 index 0000000..30e422f --- /dev/null +++ b/tests/datasets/test_scanpaths.py @@ -0,0 +1,219 @@ +from copy import deepcopy + +import numpy as np +import pytest + +import pysaliency +from pysaliency.datasets import Scanpaths +from pysaliency.utils.variable_length_array import VariableLengthArray + + +def assert_variable_length_array_equal(array1, array2): + assert len(array1) == len(array2) + + for i in range(len(array1)): + np.testing.assert_array_equal(array1[i], array2[i], err_msg=f'arrays not equal at index {i}') + + +def assert_scanpaths_equal(scanpaths1: Scanpaths, scanpaths2: Scanpaths, scanpaths2_inds=None): + + if scanpaths2_inds is None: + scanpaths2_inds = slice(None) + + assert isinstance(scanpaths1, Scanpaths) + assert isinstance(scanpaths2, Scanpaths) + + assert_variable_length_array_equal(scanpaths1.xs, scanpaths2.xs[scanpaths2_inds]) + assert_variable_length_array_equal(scanpaths1.ys, scanpaths2.ys[scanpaths2_inds]) + + assert scanpaths1.scanpath_attributes.keys() == scanpaths2.scanpath_attributes.keys() + for attribute_name in scanpaths1.scanpath_attributes.keys(): + np.testing.assert_array_equal(scanpaths1.scanpath_attributes[attribute_name], scanpaths2.scanpath_attributes[attribute_name][scanpaths2_inds]) + + assert scanpaths1.fixation_attributes.keys() == scanpaths2.fixation_attributes.keys() + for attribute_name in scanpaths1.fixation_attributes.keys(): + assert_variable_length_array_equal(scanpaths1.fixation_attributes[attribute_name], scanpaths2.fixation_attributes[attribute_name][scanpaths2_inds]) + + assert scanpaths1.attribute_mapping == scanpaths2.attribute_mapping + + +def test_scanpaths(): + xs = np.array([[0, 1, 2], [2, 2, np.nan], [1, 5, 3]]) + ys = np.array([[10, 11, 12], [12, 12, np.nan], [21, 25, 33]]) + ns = np.array([0, 0, 1]) + lengths = np.array([3, 2, 3]) + scanpath_attributes = {'task': np.array([0, 1, 0])} + fixation_attributes = {'attribute1': np.array([[1, 1, 2], [2, 2, np.nan], [0, 1, 3]]), 'attribute2': np.array([[3, 1.3, 5], [1, 42, np.nan], [0, -1, -3]])} + attribute_mapping = {'attribute1': 'attr1', 'attribute2': 'attr2'} + + scanpaths = Scanpaths(xs, ys, ns, lengths, scanpath_attributes, fixation_attributes, attribute_mapping) + + assert isinstance(scanpaths.xs, VariableLengthArray) + assert isinstance(scanpaths.ys, VariableLengthArray) + assert isinstance(scanpaths.ns, np.ndarray) + assert isinstance(scanpaths.lengths, np.ndarray) + assert isinstance(scanpaths.scanpath_attributes, dict) + assert isinstance(scanpaths.scanpath_attributes['task'], np.ndarray) + assert isinstance(scanpaths.fixation_attributes, dict) + assert isinstance(scanpaths.fixation_attributes['attribute1'], VariableLengthArray) + assert isinstance(scanpaths.fixation_attributes['attribute2'], VariableLengthArray) + assert isinstance(scanpaths.attribute_mapping, dict) + + np.testing.assert_array_equal(scanpaths.xs._data, xs) + np.testing.assert_array_equal(scanpaths.ys._data, ys) + np.testing.assert_array_equal(scanpaths.ns, ns) + np.testing.assert_array_equal(scanpaths.lengths, lengths) + np.testing.assert_array_equal(scanpaths.scanpath_attributes['task'], np.array([0, 1, 0])) + np.testing.assert_array_equal(scanpaths.fixation_attributes['attribute1']._data, np.array([[1, 1, 2], [2, 2, np.nan], [0, 1, 3]])) + np.testing.assert_array_equal(scanpaths.fixation_attributes['attribute2']._data, np.array([[3, 1.3, 5], [1, 42, np.nan], [0, -1, -3]])) + assert scanpaths.attribute_mapping == {'attribute1': 'attr1', 'attribute2': 'attr2'} + + +def test_scanpaths_from_lists(): + xs = [[0, 1, 2], [2, 2], [1, 5, 3]] + ys = [[10, 11, 12], [12, 12], [21, 25, 33]] + ns = [0, 0, 1] + expected_lengths = np.array([3, 2, 3]) + scanpath_attributes = {'task': [0, 1, 0]} + fixation_attributes = {'attribute1': [[1, 1, 2], [2, 2], [0, 1, 3]], 'attribute2': [[3, 1.3, 5], [1, 42], [0, -1, -3]]} + attribute_mapping = {'attribute1': 'attr1', 'attribute2': 'attr2'} + + scanpaths = Scanpaths(xs, ys, ns, lengths=None, scanpath_attributes=scanpath_attributes, fixation_attributes=fixation_attributes, attribute_mapping=attribute_mapping) + + np.testing.assert_array_equal(scanpaths.xs._data, np.array([[0, 1, 2], [2, 2, np.nan], [1, 5, 3]])) + np.testing.assert_array_equal(scanpaths.ys._data, np.array([[10, 11, 12], [12, 12, np.nan], [21, 25, 33]])) + np.testing.assert_array_equal(scanpaths.ns, ns) + np.testing.assert_array_equal(scanpaths.lengths, expected_lengths) + np.testing.assert_array_equal(scanpaths.scanpath_attributes['task'], np.array([0, 1, 0])) + np.testing.assert_array_equal(scanpaths.fixation_attributes['attribute1']._data, np.array([[1, 1, 2], [2, 2, np.nan], [0, 1, 3]])) + np.testing.assert_array_equal(scanpaths.fixation_attributes['attribute2']._data, np.array([[3, 1.3, 5], [1, 42, np.nan], [0, -1, -3]])) + assert scanpaths.attribute_mapping == {'attribute1': 'attr1', 'attribute2': 'attr2'} + + +def test_scanpaths_init_inconsistent_lengths(): + xs = np.array([[0, 1, 2], [2, 2, np.nan], [1, 5, 3]]) + ys = np.array([[10, 11, 12], [12, 12, np.nan]]) # too short, should fail + ns = np.array([0, 0, 1]) + lengths = np.array([3, 2, 3]) + scanpath_attributes = {'task': np.array([0, 1, 0])} + fixation_attributes = {'attribute1': np.array([[1, 1, 2], [2, 2, np.nan], [0, 1, 3]]), 'attribute2': np.array([[3, 1.3, 5], [1, 42, np.nan], [0, -1, -3]])} + attribute_mapping = {'attribute1': 'attr1', 'attribute2': 'attr2'} + + with pytest.raises(ValueError): + Scanpaths(xs, ys, ns, lengths, scanpath_attributes, fixation_attributes, attribute_mapping) + +def test_scanpaths_init_invalid_scanpath_attributes(): + xs = np.array([[0, 1, 2], [2, 2, np.nan], [1, 5, 3]]) + ys = np.array([[10, 11, 12], [12, 12, np.nan], [21, 25, 33]]) + ns = np.array([0, 0, 1]) + lengths = np.array([3, 2, 3]) + scanpath_attributes = {'invalid_attribute': np.array([1, 2]), 'attribute2': np.array([4, 5, 6])} # Invalid attribute length + scanpath_fixation_attributes = {'fixation_attribute1': np.array([[1, 2], [3, 4], [5, 6]]), 'fixation_attribute2': np.array([[7, 8], [9, 10], [11, 12]])} + scanpath_attribute_mapping = {'attribute1': 'attr1', 'attribute2': 'attr2'} + + with pytest.raises(ValueError): + Scanpaths(xs, ys, ns, lengths, scanpath_attributes, scanpath_fixation_attributes, scanpath_attribute_mapping) + +def test_scanpaths_init_invalid_scanpath_fixation_attributes(): + xs = np.array([[0, 1, 2], [2, 2, np.nan], [1, 5, 3]]) + ys = np.array([[10, 11, 12], [12, 12, np.nan], [21, 25, 33]]) + ns = np.array([0, 0, 1]) + lengths = np.array([3, 2, 3]) + scanpath_attributes = {'attribute1': np.array([1, 2, 3]), 'attribute2': np.array([4, 5, 6])} + scanpath_fixation_attributes = {'valid_fixation_attribute': np.array([[1, 2], [3, 4], [5, 6]]), 'invalid_fixation_attribute': np.array([[7, 8], [9, 10]])} # Invalid fixation attribute length + scanpath_attribute_mapping = {'attribute1': 'attr1', 'attribute2': 'attr2'} + + with pytest.raises(ValueError): + Scanpaths(xs, ys, ns, lengths, scanpath_attributes, scanpath_fixation_attributes, scanpath_attribute_mapping) + +def test_scanpaths_init_invalid_scanpath_fixation_attributes_dimensions(): + xs = np.array([[0, 1, 2], [2, 2, np.nan], [1, 5, 3]]) + ys = np.array([[10, 11, 12], [12, 12, np.nan], [21, 25, 33]]) + ns = np.array([0, 0, 1]) + lengths = np.array([3, 2, 3]) + scanpath_attributes = {'attribute1': np.array([1, 2, 3]), 'attribute2': np.array([4, 5, 6])} + scanpath_fixation_attributes = {'fixation_attribute1': np.array([1, 2, 3]), 'fixation_attribute2': np.array([[7, 8], [9, 10], [11, 12]])} # Invalid fixation attribute dimensions + scanpath_attribute_mapping = {'attribute1': 'attr1', 'attribute2': 'attr2'} + + with pytest.raises(ValueError): + Scanpaths(xs, ys, ns, lengths, scanpath_attributes, scanpath_fixation_attributes, scanpath_attribute_mapping) + +def test_scanpaths_init_invalid_scanpath_lengths(): + + data = { + 'xs': [[0, 1, 2], [2, 2], [1, 5, 3]], + 'ys': [[10, 11, 12], [12, 12], [21, 25, 33]], + 'ns': [0, 0, 1], + 'scanpath_attributes': {'task': [0, 1, 0]}, + 'fixation_attributes': {'attribute1': [[1, 1, 2], [2, 2], [0, 1, 3]], 'attribute2': [[3, 1.3, 5], [1, 42], [0, -1, -3]]}, + 'attribute_mapping': {'attribute1': 'attr1', 'attribute2': 'attr2'}, + } + + # make sure original data works + Scanpaths(**data) + + + for scanpath_attribute in ['xs', 'ys']: + data_copy = deepcopy(data) + data_copy[scanpath_attribute][-1].append(4) + with pytest.raises(ValueError): + Scanpaths(**data_copy) + + for scanpath_attribute in data['fixation_attributes'].keys(): + data_copy = deepcopy(data) + data_copy['fixation_attributes'][scanpath_attribute][-1].append(4) + with pytest.raises(ValueError): + Scanpaths(**data_copy) + + + +@pytest.mark.parametrize('inds', [ + slice(None, 2), + slice(1, None), + [0, 1], + [1, 2], + [0, 2], + [2, 1], + [False, True, True], +]) +def test_scanpaths_slicing(inds): + xs = [[0, 1, 2], [2, 2], [1, 5, 3]] + ys = [[10, 11, 12], [12, 12], [21, 25, 33]] + ns = [0, 0, 1] + scanpath_attributes = {'task': [0, 1, 0]} + fixation_attributes = {'attribute1': [[1, 1, 2], [2, 2], [0, 1, 3]], 'attribute2': [[3, 1.3, 5], [1, 42], [0, -1, -3]]} + attribute_mapping = {'attribute1': 'attr1', 'attribute2': 'attr2'} + + scanpaths = Scanpaths(xs, ys, ns, lengths=None, scanpath_attributes=scanpath_attributes, fixation_attributes=fixation_attributes, attribute_mapping=attribute_mapping) + + sliced_scanpaths = scanpaths[inds] + assert_scanpaths_equal(sliced_scanpaths, scanpaths, inds) + + +def test_write_read_scanpaths_pathlib(tmp_path): + filename = tmp_path / 'scanpaths.hdf5' + + xs = [[0, 1, 2], [2, 2], [1, 5, 3]] + ys = [[10, 11, 12], [12, 12], [21, 25, 33]] + ns = [0, 0, 1] + scanpath_attributes = {'task': [0, 1, 0]} + fixation_attributes = {'attribute1': [[1, 1, 2], [2, 2], [0, 1, 3]], 'attribute2': [[3, 1.3, 5], [1, 42], [0, -1, -3]]} + attribute_mapping = {'attribute1': 'attr1', 'attribute2': 'attr2'} + + scanpaths = Scanpaths(xs, ys, ns, lengths=None, scanpath_attributes=scanpath_attributes, fixation_attributes=fixation_attributes, attribute_mapping=attribute_mapping) + + scanpaths.to_hdf5(filename) + + # test loading via class method + + new_scanpaths = Scanpaths.read_hdf5(filename) + + assert scanpaths is not new_scanpaths # make sure there is no sophisticated caching... + assert_scanpaths_equal(scanpaths, new_scanpaths) + + # test loading via pysaliency + + new_scanpaths = pysaliency.read_hdf5(filename) + + assert scanpaths is not new_scanpaths # make sure there is no sophisticated caching... + assert_scanpaths_equal(scanpaths, new_scanpaths) \ No newline at end of file diff --git a/tests/test_dataset_config.py b/tests/test_dataset_config.py index 05eb823..6e3fe33 100644 --- a/tests/test_dataset_config.py +++ b/tests/test_dataset_config.py @@ -5,7 +5,7 @@ import numpy as np import pytest from imageio import imwrite -from test_datasets import compare_fixations, compare_scanpaths +from test_datasets import compare_fixations, compare_fixation_trains from test_filter_datasets import assert_stimuli_equal import pysaliency @@ -129,7 +129,7 @@ def test_apply_dataset_filter_config_filter_scanpaths_by_attribute_task(stimuli, filtered_stimuli, filtered_scanpaths = dc.apply_dataset_filter_config(stimuli, scanpaths, filter_config) inds = [0, 2] expected_scanpaths = scanpaths.filter_fixation_trains(inds) - compare_scanpaths(filtered_scanpaths, expected_scanpaths) + compare_fixation_trains(filtered_scanpaths, expected_scanpaths) assert_stimuli_equal(filtered_stimuli, stimuli) @@ -146,7 +146,7 @@ def test_apply_dataset_filter_config_filter_scanpaths_by_attribute_multi_dim_att filtered_stimuli, filtered_scanpaths = dc.apply_dataset_filter_config(stimuli, scanpaths, filter_config) inds = [1, 2] expected_scanpaths = scanpaths.filter_fixation_trains(inds) - compare_scanpaths(filtered_scanpaths, expected_scanpaths) + compare_fixation_trains(filtered_scanpaths, expected_scanpaths) assert_stimuli_equal(filtered_stimuli, stimuli) @@ -195,7 +195,7 @@ def test_apply_dataset_filter_config_filter_scanpaths_by_length_multiple_inputs( filtered_stimuli, filtered_scanpaths = dc.apply_dataset_filter_config(stimuli, scanpaths, filter_config) inds = [1] expected_scanpaths = scanpaths.filter_fixation_trains(inds) - compare_scanpaths(filtered_scanpaths, expected_scanpaths) + compare_fixation_trains(filtered_scanpaths, expected_scanpaths) assert_stimuli_equal(filtered_stimuli, stimuli) @@ -210,5 +210,5 @@ def test_apply_dataset_filter_config_filter_scanpaths_by_length_single_input(sti filtered_stimuli, filtered_scanpaths = dc.apply_dataset_filter_config(stimuli, scanpaths, filter_config) inds = [0, 2] expected_scanpaths = scanpaths.filter_fixation_trains(inds) - compare_scanpaths(filtered_scanpaths, expected_scanpaths) + compare_fixation_trains(filtered_scanpaths, expected_scanpaths) assert_stimuli_equal(filtered_stimuli, stimuli) diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 610f55a..6c12b73 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -1,5 +1,6 @@ from __future__ import absolute_import, division, print_function +from copy import deepcopy import os.path import pickle import unittest @@ -13,7 +14,7 @@ from test_helpers import TestWithData import pysaliency -from pysaliency.datasets import Fixations, FixationTrains, Stimulus, check_prediction_shape, scanpaths_from_fixations +from pysaliency.datasets import Fixations, FixationTrains, Scanpaths, Stimulus, check_prediction_shape, scanpaths_from_fixations def compare_fixations_subset(f1, f2, f2_inds): @@ -56,7 +57,7 @@ def compare_fixations(f1, f2, crop_length=False): np.testing.assert_array_equal(attribute1, attribute2, err_msg=f'attributes not equal: {attribute}') -def compare_scanpaths(scanpaths1, scanpaths2): +def compare_fixation_trains(scanpaths1, scanpaths2): np.testing.assert_array_equal(scanpaths1.train_xs, scanpaths2.train_xs) np.testing.assert_array_equal(scanpaths1.train_ys, scanpaths2.train_ys) np.testing.assert_array_equal(scanpaths1.train_xs, scanpaths2.train_xs) @@ -77,7 +78,6 @@ def compare_scanpaths(scanpaths1, scanpaths2): compare_fixations(scanpaths1, scanpaths2) - class TestFixations(TestWithData): def test_from_fixations(self): xs_trains = [ @@ -403,7 +403,7 @@ def fixation_trains(): def test_copy_scanpaths(fixation_trains): copied_fixation_trains = fixation_trains.copy() - compare_scanpaths(copied_fixation_trains, fixation_trains) + compare_fixation_trains(copied_fixation_trains, fixation_trains) def test_copy_fixations(fixation_trains): @@ -420,7 +420,7 @@ def test_write_read_scanpaths_pathlib(tmp_path, fixation_trains): # make sure there is no sophisticated caching... assert fixation_trains is not new_fixation_trains - compare_scanpaths(fixation_trains, new_fixation_trains) + compare_fixation_trains(fixation_trains, new_fixation_trains) def test_write_read_scanpaths(tmp_path, fixation_trains): @@ -431,7 +431,7 @@ def test_write_read_scanpaths(tmp_path, fixation_trains): # make sure there is no sophisticated caching... assert fixation_trains is not new_fixation_trains - compare_scanpaths(fixation_trains, new_fixation_trains) + compare_fixation_trains(fixation_trains, new_fixation_trains) def test_scanpath_lengths(fixation_trains): @@ -815,5 +815,7 @@ def test_check_prediction_shape(): check_prediction_shape(prediction, stimulus) assert str(excinfo.value) == "Prediction shape (10, 10) does not match stimulus shape (10, 11)" + + if __name__ == '__main__': - unittest.main() + unittest.main() \ No newline at end of file diff --git a/tests/test_filter_datasets.py b/tests/test_filter_datasets.py index 40234cb..f45649a 100644 --- a/tests/test_filter_datasets.py +++ b/tests/test_filter_datasets.py @@ -7,7 +7,7 @@ import pysaliency import pysaliency.filter_datasets as filter_datasets from pysaliency.filter_datasets import filter_fixations_by_attribute, filter_stimuli_by_attribute, filter_scanpaths_by_attribute, filter_scanpaths_by_length, create_subset, remove_stimuli_without_fixations -from test_datasets import compare_fixations, compare_scanpaths +from test_datasets import compare_fixations, compare_fixation_trains @pytest.fixture @@ -418,7 +418,7 @@ def test_filter_scanpaths_by_attribute_task(fixation_trains): filtered_scanpaths = filter_scanpaths_by_attribute(scanpaths, attribute_name, attribute_value, invert_match) inds = [0, 2] expected_scanpaths = scanpaths.filter_fixation_trains(inds) - compare_scanpaths(filtered_scanpaths, expected_scanpaths) + compare_fixation_trains(filtered_scanpaths, expected_scanpaths) def test_filter_scanpaths_by_attribute_multi_dim_attribute(fixation_trains): @@ -429,7 +429,7 @@ def test_filter_scanpaths_by_attribute_multi_dim_attribute(fixation_trains): filtered_scanpaths = filter_scanpaths_by_attribute(scanpaths, attribute_name, attribute_value, invert_match) inds = [1] expected_scanpaths = scanpaths.filter_fixation_trains(inds) - compare_scanpaths(filtered_scanpaths, expected_scanpaths) + compare_fixation_trains(filtered_scanpaths, expected_scanpaths) def test_filter_scanpaths_by_attribute_multi_dim_attribute_invert_match(fixation_trains): @@ -440,7 +440,7 @@ def test_filter_scanpaths_by_attribute_multi_dim_attribute_invert_match(fixation filtered_scanpaths = filter_scanpaths_by_attribute(scanpaths, attribute_name, attribute_value, invert_match) inds = [1, 2] expected_scanpaths = scanpaths.filter_fixation_trains(inds) - compare_scanpaths(filtered_scanpaths, expected_scanpaths) + compare_fixation_trains(filtered_scanpaths, expected_scanpaths) @pytest.mark.parametrize('intervals', [([(1, 2), (2, 3)]), ([(2, 3), (3, 4)]), ([(2)]), ([(3)])]) @@ -450,19 +450,19 @@ def test_filter_scanpaths_by_length(fixation_trains, intervals): if intervals == [(1, 2), (2, 3)]: inds = [1] expected_scanpaths = scanpaths.filter_fixation_trains(inds) - compare_scanpaths(filtered_scanpaths, expected_scanpaths) + compare_fixation_trains(filtered_scanpaths, expected_scanpaths) if intervals == [(2, 3), (3, 4)]: inds = [0, 1, 2] expected_scanpaths = scanpaths.filter_fixation_trains(inds) - compare_scanpaths(filtered_scanpaths, expected_scanpaths) + compare_fixation_trains(filtered_scanpaths, expected_scanpaths) if intervals == [(2)]: inds = [1] expected_scanpaths = scanpaths.filter_fixation_trains(inds) - compare_scanpaths(filtered_scanpaths, expected_scanpaths) + compare_fixation_trains(filtered_scanpaths, expected_scanpaths) if intervals == [(3)]: inds = [0, 2] expected_scanpaths = scanpaths.filter_fixation_trains(inds) - compare_scanpaths(filtered_scanpaths, expected_scanpaths) + compare_fixation_trains(filtered_scanpaths, expected_scanpaths) def test_remove_stimuli_without_fixations(file_stimuli_with_attributes, fixation_trains):