From 6fb8c0614590ead1dc20d9cd4641052e1bd50486 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Matthias=20K=C3=BCmmerer?= Date: Fri, 29 Mar 2024 01:11:34 +0100 Subject: [PATCH] Adapt FixationTrain attributes MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Matthias Kümmerer --- pysaliency/datasets.py | 211 ++++++++++++++++++++++++++-------- pysaliency/torch_datasets.py | 5 +- tests/test_dataset_config.py | 14 +-- tests/test_datasets.py | 130 +++++++++++++++------ tests/test_filter_datasets.py | 30 ++--- 5 files changed, 279 insertions(+), 111 deletions(-) diff --git a/pysaliency/datasets.py b/pysaliency/datasets.py index eba041b..89c095a 100644 --- a/pysaliency/datasets.py +++ b/pysaliency/datasets.py @@ -5,13 +5,13 @@ from collections.abc import Sequence from functools import wraps from hashlib import sha1 -from typing import Dict, Optional, Union +from typing import Dict, List, Optional, Union from weakref import WeakValueDictionary import numpy as np from boltons.cacheutils import cached -from .utils.variable_length_array import VariableLengthArray +from .utils.variable_length_array import VariableLengthArray, concatenate_variable_length_arrays try: from imageio.v3 import imread @@ -138,33 +138,75 @@ class Fixations(object): """ __attributes__ = ['subjects'] - def __init__(self, x, y, t, x_hist, y_hist, t_hist, n, subjects, attributes=None): - x = np.asarray(x) - y = np.asarray(y) - t = np.asarray(t) - n = np.asarray(n) - x_hist = np.asarray(x_hist) - y_hist = np.asarray(y_hist) - t_hist = np.asarray(t_hist) - subjects = np.asarray(subjects) - - self.x = x - self.y = y - self.t = t + def __init__(self, + x: Union[List, np.ndarray], + y: Union[List, np.ndarray], + t: Union[List, np.ndarray], + x_hist: Union[List, VariableLengthArray], + y_hist: Union[List, VariableLengthArray], + t_hist: Union[List, VariableLengthArray], + n: Union[List, np.ndarray], + subjects: Optional[Union[List, np.ndarray]] = None, + attributes: Optional[Dict[str, Union[np.ndarray, VariableLengthArray]]] = None): + + self.x = np.asarray(x) + self.y = np.asarray(y) + self.t = np.asarray(t) + self.n = np.asarray(n) + + # would be nice, is not yet supported. 
But we can simply pass the VariableLengthArray instead + # if isinstance(x_hist, list): + # x_hist = VariableLengthArray(x_hist) + # self.lengths = x_hist.lengths + if isinstance(x_hist, (list, np.ndarray)): + x_hist = np.array(x_hist) + self.lengths = (1 - np.isnan(x_hist)).sum(axis=-1) + x_hist = VariableLengthArray(x_hist, lengths=self.lengths) + elif isinstance(x_hist, VariableLengthArray): + self.lengths = x_hist.lengths + + + y_hist = self._as_variable_length_array(y_hist) + t_hist = self._as_variable_length_array(t_hist) + + if subjects is not None: + subjects = np.asarray(subjects) + self.x_hist = x_hist self.y_hist = y_hist self.t_hist = t_hist self.n = n self.subjects = subjects - self.lengths = (1 - np.isnan(self.x_hist)).sum(axis=-1) + + if not len(self.x) == len(self.y) == len(self.t) == len(self.x_hist) == len(self.y_hist) == len(self.t_hist) == len(self.n): + raise ValueError("Lengths of fixations have to match") + if self.subjects is not None and not len(self.x) == len(self.subjects): + raise ValueError("Length of subjects has to match number of fixations") if attributes is not None: self.__attributes__ = list(self.__attributes__) for name, value in attributes.items(): if name not in self.__attributes__: self.__attributes__.append(name) + if not len(value) == len(self.x): + raise ValueError(f"Length of attribute '{name}' has to match number of fixations") setattr(self, name, value) + + def _check_lengths(self, other: VariableLengthArray): + if not len(self) == len(other): + raise ValueError("Length of scanpaths has to match") + if not np.all(self.lengths == other.lengths): + raise ValueError("Lengths of scanpaths have to match") + + def _as_variable_length_array(self, data: Union[np.ndarray, VariableLengthArray]) -> VariableLengthArray: + if not isinstance(data, VariableLengthArray): + data = VariableLengthArray(data, self.lengths) + + self._check_lengths(data) + + return data + @classmethod def create_without_history(cls, x, y, n, subjects=None): """ Create new fixation object from fixation data without time and optionally @@ -256,6 +298,7 @@ def filter(self, inds): def filter_array(name): kwargs[name] = getattr(self, name)[inds].copy() + for name in ['x', 'y', 't', 'x_hist', 'y_hist', 't_hist', 'n']: filter_array(name) for name in self.__attributes__: @@ -328,7 +371,7 @@ def subject_count(self): def copy(self): cfix = Fixations(self.x.copy(), self.y.copy(), self.t.copy(), self.x_hist.copy(), self.y_hist.copy(), self.t_hist.copy(), - self.n.copy(), self.subjects.copy()) + self.n.copy(), self.subjects.copy() if self.subjects is not None else None) cfix.__attributes__ = list(self.__attributes__) for name in self.__attributes__: setattr(cfix, name, getattr(self, name).copy()) @@ -350,26 +393,34 @@ def to_hdf5(self, target): """ target.attrs['type'] = np.string_('Fixations') - target.attrs['version'] = np.string_('1.0') + target.attrs['version'] = np.string_('1.1') - for attribute in ['x', 'y', 't', 'x_hist', 'y_hist', 't_hist', 'n'] + self.__attributes__: - target.create_dataset(attribute, data=getattr(self, attribute)) + variable_length_arrays = [] + + for attribute in ['x', 'y', 't', 'x_hist', 'y_hist', 't_hist', 'n', 'lengths'] + self.__attributes__: + data = getattr(self, attribute) + if isinstance(data, VariableLengthArray): + variable_length_arrays.append(attribute) + data = data._data + target.create_dataset(attribute, data=data) - #target.create_dataset('__attributes__', data=self.__attributes__) target.attrs['__attributes__'] = 
np.string_(json.dumps(self.__attributes__)) + target.attrs['__variable_length_arrays__'] = np.string_(json.dumps(sorted(variable_length_arrays))) @classmethod @hdf5_wrapper(mode='r') def read_hdf5(cls, source): """ Read fixations from hdf5 file or hdf5 group """ + # TODO: rewrite to use constructor instead of manipulating the object directly + data_type = decode_string(source.attrs['type']) data_version = decode_string(source.attrs['version']) if data_type != 'Fixations': raise ValueError("Invalid type! Expected 'Fixations', got", data_type) - if data_version != '1.0': + if data_version not in ['1.0', '1.1']: raise ValueError("Invalid version! Expected '1.0', got", data_version) data = {key: source[key][...] for key in ['x', 'y', 't', 'x_hist', 'y_hist', 't_hist', 'n', 'subjects']} @@ -379,10 +430,27 @@ def read_hdf5(cls, source): if not isinstance(json_attributes, str): json_attributes = json_attributes.decode('utf8') __attributes__ = json.loads(json_attributes) - fixations.__attributes__ == list(__attributes__) + fixations.__attributes__ = list(__attributes__) + + if data_version == '1.1': + lengths = source['lengths'][...] + + json_variable_length_arrays = source.attrs['__variable_length_arrays__'] + if not isinstance(json_variable_length_arrays, str): + json_variable_length_arrays = json_variable_length_arrays.decode('utf8') + variable_length_arrays = json.loads(json_variable_length_arrays) + + else: + lengths = fixations.lengths + variable_length_arrays = ['x_hist', 'y_hist', 't_hist'] + [key for key in __attributes__ if key.endswith('_hist')] for key in __attributes__: - setattr(fixations, key, source[key][...]) + data = source[key][...] + + if key in variable_length_arrays: + data = VariableLengthArray(data, lengths) + + setattr(fixations, key, data) return fixations @@ -420,25 +488,6 @@ def __init__(self, train_xs, train_ys, train_ts, train_ns, train_subjects, scanp N_trains = self.train_xs.shape[0] * self.train_xs.shape[1] - np.isnan(self.train_xs).sum() max_length_trains = self.train_xs.shape[1] - if scanpath_attributes is not None: - assert isinstance(scanpath_attributes, dict) - self.scanpath_attributes = {key: np.array(value) for key, value in scanpath_attributes.items()} - for key, value in self.scanpath_attributes.items(): - assert len(value) == len(self.train_xs) - else: - self.scanpath_attributes = {} - - if scanpath_fixation_attributes is not None: - assert isinstance(scanpath_fixation_attributes, dict) - self.scanpath_fixation_attributes = {key: np.array(value) for key, value in scanpath_fixation_attributes.items()} - for key, value in self.scanpath_fixation_attributes.items(): - assert len(value) == len(self.train_xs) - else: - self.scanpath_fixation_attributes = {} - - self.scanpath_attribute_mapping = scanpath_attribute_mapping or {} - - # Create conditional fixations self.x = np.empty(N_trains) @@ -475,6 +524,29 @@ def __init__(self, train_xs, train_ys, train_ts, train_ns, train_subjects, scanp self.t_hist[out_index][:fix_index] = self.train_ts[train_index][:fix_index] out_index += 1 + # TODO: this should become irrelevant once FixationTrains is also completely upgraded to VariableLengthArrays + self.x_hist = VariableLengthArray(self.x_hist, self.lengths) + self.y_hist = VariableLengthArray(self.y_hist, self.lengths) + self.t_hist = VariableLengthArray(self.t_hist, self.lengths) + + if scanpath_attributes is not None: + assert isinstance(scanpath_attributes, dict) + self.scanpath_attributes = {key: np.array(value) for key, value in scanpath_attributes.items()} + for 
key, value in self.scanpath_attributes.items(): + assert len(value) == len(self.train_xs) + else: + self.scanpath_attributes = {} + + if scanpath_fixation_attributes is not None: + assert isinstance(scanpath_fixation_attributes, dict) + self.scanpath_fixation_attributes = {} + for key, value in scanpath_fixation_attributes.items(): + self.scanpath_fixation_attributes[key] = self._as_variable_length_scanpath_array(value) + else: + self.scanpath_fixation_attributes = {} + + self.scanpath_attribute_mapping = scanpath_attribute_mapping or {} + if attributes is None: attributes = {} @@ -509,7 +581,7 @@ def __init__(self, train_xs, train_ys, train_ts, train_ns, train_subjects, scanp hist_attribute_name = new_attribute_name + '_hist' if hist_attribute_name in attributes: raise ValueError("attribute name clash: {hist_attribute_name}".format(hist_attribute_name=hist_attribute_name)) - attributes[hist_attribute_name] = np.full_like(self.x_hist, fill_value=np.nan) + attributes[hist_attribute_name] = np.full_like(self.x_hist._data, fill_value=np.nan) self.auto_attributes.append(hist_attribute_name) out_index = 0 @@ -520,6 +592,8 @@ def __init__(self, train_xs, train_ys, train_ts, train_ns, train_subjects, scanp attributes[hist_attribute_name][out_index][:fix_index] = self.scanpath_fixation_attributes[attribute_name][train_index, :fix_index] out_index += 1 + attributes[hist_attribute_name] = VariableLengthArray(attributes[hist_attribute_name], self.lengths) + if attributes: self.__attributes__ = list(self.__attributes__) for key, value in attributes.items(): @@ -527,11 +601,25 @@ def __init__(self, train_xs, train_ys, train_ts, train_ns, train_subjects, scanp assert key != 'scanpath_index' assert len(value) == len(self.x) self.__attributes__.append(key) - value = np.array(value) + if not isinstance(value, VariableLengthArray): + value = np.array(value) setattr(self, key, value) self.full_nonfixations = None + def _check_train_lengths(self, other: VariableLengthArray): + if not len(self.train_xs) == len(other): + raise ValueError("Length of scanpaths has to match") + if not np.all(self.train_lengths == other.lengths): + raise ValueError("Lengths of scanpaths have to match") + + def _as_variable_length_scanpath_array(self, data: Union[np.ndarray, VariableLengthArray]) -> VariableLengthArray: + if not isinstance(data, VariableLengthArray): + data = VariableLengthArray(data, self.train_lengths) + + self._check_train_lengths(data) + + return data @classmethod def concatenate(cls, fixation_trains): @@ -645,7 +733,7 @@ def filter_fixation_trains(self, indices): train_ns = self.train_ns[indices] train_subjects = self.train_subjects[indices] scanpath_attributes = {key: np.array(value)[indices] for key, value in self.scanpath_attributes.items()} - scanpath_fixation_attributes = {key: np.array(value)[indices] for key, value in self.scanpath_fixation_attributes.items()} + scanpath_fixation_attributes = {key: value[indices] for key, value in self.scanpath_fixation_attributes.items()} scanpath_indices = np.arange(len(self.train_xs), dtype=int)[indices] fixation_indices = np.in1d(self.scanpath_index, scanpath_indices) @@ -907,12 +995,19 @@ def to_hdf5(self, target): """ target.attrs['type'] = np.string_('FixationTrains') - target.attrs['version'] = np.string_('1.2') + target.attrs['version'] = np.string_('1.3') + + variable_length_arrays = [] - for attribute in ['train_xs', 'train_ys', 'train_ts', 'train_ns', 'train_subjects'] + self.__attributes__: + for attribute in ['train_xs', 'train_ys', 'train_ts', 'train_ns', 
'train_subjects', 'train_lengths'] + self.__attributes__: if attribute in ['subjects', 'scanpath_index'] + self.auto_attributes: continue - target.create_dataset(attribute, data=getattr(self, attribute)) + + data = getattr(self, attribute) + if isinstance(data, VariableLengthArray): + variable_length_arrays.append(attribute) + data = data._data + target.create_dataset(attribute, data=data) saved_attributes = [attribute_name for attribute_name in self.__attributes__ if attribute_name not in self.auto_attributes] target.attrs['__attributes__'] = np.string_(json.dumps(saved_attributes)) @@ -926,7 +1021,7 @@ def to_hdf5(self, target): scanpath_fixation_attributes_group = target.create_group('scanpath_fixation_attributes') for attribute_name, attribute_value in self.scanpath_fixation_attributes.items(): - scanpath_fixation_attributes_group.create_dataset(attribute_name, data=attribute_value) + scanpath_fixation_attributes_group.create_dataset(attribute_name, data=attribute_value._data) scanpath_fixation_attributes_group.attrs['__attributes__'] = np.string_(json.dumps(sorted(self.scanpath_fixation_attributes.keys()))) @@ -941,7 +1036,7 @@ def read_hdf5(cls, source): if data_type != 'FixationTrains': raise ValueError("Invalid type! Expected 'FixationTrains', got", data_type) - valid_versions = ['1.0', '1.1', '1.2'] + valid_versions = ['1.0', '1.1', '1.2', '1.3'] if data_version not in valid_versions: raise ValueError("Invalid version! Expected one of {}, got {}".format(', '.join(valid_versions), data_version)) @@ -972,6 +1067,15 @@ def read_hdf5(cls, source): data['scanpath_fixation_attributes'] = _load_attribute_dict_from_hdf5(source['scanpath_fixation_attributes']) data['scanpath_attribute_mapping'] = json.loads(decode_string(source.attrs['scanpath_attribute_mapping'])) + if data_version < '1.3': + train_lengths = np.array([len(remove_trailing_nans(data['train_xs'][i])) for i in range(len(data['train_xs']))]) + else: + train_lengths = source['train_lengths'][...] 
+ + data['scanpath_fixation_attributes'] = { + key: VariableLengthArray(value, train_lengths) for key, value in data['scanpath_fixation_attributes'].items() + } + fixations = cls(**data) return fixations @@ -1597,6 +1701,11 @@ def concatenate_stimuli(stimuli): def concatenate_attributes(attributes): + attributes = list(attributes) + + if isinstance(attributes[0], VariableLengthArray): + return concatenate_variable_length_arrays(attributes) + attributes = [np.array(a) for a in attributes] for a in attributes: assert len(a.shape) == len(attributes[0].shape) diff --git a/pysaliency/torch_datasets.py b/pysaliency/torch_datasets.py index 6b2866e..7ad4a29 100644 --- a/pysaliency/torch_datasets.py +++ b/pysaliency/torch_datasets.py @@ -1,8 +1,8 @@ import random -from boltons.iterutils import chunked import numpy as np import torch +from boltons.iterutils import chunked from tqdm import tqdm from .models import Model @@ -126,7 +126,8 @@ def __call__(self, item): inds = np.array([y, x]) values = np.ones(len(y), dtype=int) - mask = torch.sparse.IntTensor(torch.tensor(inds), torch.tensor(values), shape) + # mask = torch.sparse.IntTensor(torch.tensor(inds), torch.tensor(values), shape) + mask = torch.sparse_coo_tensor(torch.tensor(inds), torch.tensor(values), shape, dtype=torch.int) mask = mask.coalesce() item['fixation_mask'] = mask diff --git a/tests/test_dataset_config.py b/tests/test_dataset_config.py index 6e3fe33..ad1d006 100644 --- a/tests/test_dataset_config.py +++ b/tests/test_dataset_config.py @@ -5,7 +5,7 @@ import numpy as np import pytest from imageio import imwrite -from test_datasets import compare_fixations, compare_fixation_trains +from test_datasets import assert_fixations_equal, assert_fixation_trains_equal from test_filter_datasets import assert_stimuli_equal import pysaliency @@ -129,7 +129,7 @@ def test_apply_dataset_filter_config_filter_scanpaths_by_attribute_task(stimuli, filtered_stimuli, filtered_scanpaths = dc.apply_dataset_filter_config(stimuli, scanpaths, filter_config) inds = [0, 2] expected_scanpaths = scanpaths.filter_fixation_trains(inds) - compare_fixation_trains(filtered_scanpaths, expected_scanpaths) + assert_fixation_trains_equal(filtered_scanpaths, expected_scanpaths) assert_stimuli_equal(filtered_stimuli, stimuli) @@ -146,7 +146,7 @@ def test_apply_dataset_filter_config_filter_scanpaths_by_attribute_multi_dim_att filtered_stimuli, filtered_scanpaths = dc.apply_dataset_filter_config(stimuli, scanpaths, filter_config) inds = [1, 2] expected_scanpaths = scanpaths.filter_fixation_trains(inds) - compare_fixation_trains(filtered_scanpaths, expected_scanpaths) + assert_fixation_trains_equal(filtered_scanpaths, expected_scanpaths) assert_stimuli_equal(filtered_stimuli, stimuli) @@ -163,7 +163,7 @@ def test_apply_dataset_filter_config_filter_fixations_by_attribute_subject_inver filtered_stimuli, filtered_fixations = dc.apply_dataset_filter_config(stimuli, fixations, filter_config) inds = [3, 4, 5, 6, 7] expected_fixations = fixations[inds] - compare_fixations(filtered_fixations, expected_fixations) + assert_fixations_equal(filtered_fixations, expected_fixations) assert_stimuli_equal(filtered_stimuli, stimuli) @@ -180,7 +180,7 @@ def test_apply_dataset_filter_config_filter_stimuli_by_attribute_dva(file_stimul filtered_stimuli, filtered_fixations = dc.apply_dataset_filter_config(file_stimuli_with_attributes, fixations, filter_config) inds = [1] expected_stimuli, expected_fixations = create_subset(file_stimuli_with_attributes, fixations, inds) - 
compare_fixations(filtered_fixations, expected_fixations) + assert_fixations_equal(filtered_fixations, expected_fixations) assert_stimuli_equal(filtered_stimuli, expected_stimuli) @@ -195,7 +195,7 @@ def test_apply_dataset_filter_config_filter_scanpaths_by_length_multiple_inputs( filtered_stimuli, filtered_scanpaths = dc.apply_dataset_filter_config(stimuli, scanpaths, filter_config) inds = [1] expected_scanpaths = scanpaths.filter_fixation_trains(inds) - compare_fixation_trains(filtered_scanpaths, expected_scanpaths) + assert_fixation_trains_equal(filtered_scanpaths, expected_scanpaths) assert_stimuli_equal(filtered_stimuli, stimuli) @@ -210,5 +210,5 @@ def test_apply_dataset_filter_config_filter_scanpaths_by_length_single_input(sti filtered_stimuli, filtered_scanpaths = dc.apply_dataset_filter_config(stimuli, scanpaths, filter_config) inds = [0, 2] expected_scanpaths = scanpaths.filter_fixation_trains(inds) - compare_fixation_trains(filtered_scanpaths, expected_scanpaths) + assert_fixation_trains_equal(filtered_scanpaths, expected_scanpaths) assert_stimuli_equal(filtered_stimuli, stimuli) diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 6c12b73..e100fbd 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -1,9 +1,9 @@ from __future__ import absolute_import, division, print_function -from copy import deepcopy import os.path import pickle import unittest +from copy import deepcopy import dill import numpy as np @@ -15,6 +15,16 @@ import pysaliency from pysaliency.datasets import Fixations, FixationTrains, Scanpaths, Stimulus, check_prediction_shape, scanpaths_from_fixations +from pysaliency.utils.variable_length_array import VariableLengthArray + + +def assert_variable_length_array_equal(array1, array2): + assert isinstance(array1, VariableLengthArray) + assert isinstance(array2, VariableLengthArray) + assert len(array1) == len(array2) + + for i in range(len(array1)): + np.testing.assert_array_equal(array1[i], array2[i], err_msg=f'arrays not equal at index {i}') def compare_fixations_subset(f1, f2, f2_inds): @@ -31,7 +41,7 @@ def compare_fixations_subset(f1, f2, f2_inds): np.testing.assert_array_equal(getattr(f1, attribute), getattr(f2, attribute)[f2_inds]) -def compare_fixations(f1, f2, crop_length=False): +def assert_fixations_equal(f1, f2, crop_length=False): if crop_length: maximum_length = np.max(f2.lengths) else: @@ -39,9 +49,9 @@ def compare_fixations(f1, f2, crop_length=False): np.testing.assert_array_equal(f1.x, f2.x) np.testing.assert_array_equal(f1.y, f2.y) np.testing.assert_array_equal(f1.t, f2.t) - np.testing.assert_array_equal(f1.x_hist[:, :maximum_length], f2.x_hist[:, :maximum_length]) - np.testing.assert_array_equal(f1.y_hist[:, :maximum_length], f2.y_hist[:, :maximum_length]) - np.testing.assert_array_equal(f1.t_hist[:, :maximum_length], f2.t_hist[:, :maximum_length]) + assert_variable_length_array_equal(f1.x_hist, f2.x_hist) + assert_variable_length_array_equal(f1.y_hist, f2.y_hist) + assert_variable_length_array_equal(f1.t_hist, f2.t_hist) assert set(f1.__attributes__) == set(f2.__attributes__) for attribute in f1.__attributes__: @@ -50,14 +60,17 @@ def compare_fixations(f1, f2, crop_length=False): attribute1 = getattr(f1, attribute) attribute2 = getattr(f2, attribute) - if attribute.endswith('_hist'): + if isinstance(attribute1, VariableLengthArray): + assert_variable_length_array_equal(attribute1, attribute2) + continue + elif attribute.endswith('_hist'): attribute1 = attribute1[:, :maximum_length] attribute2 = attribute2[:, 
:maximum_length] np.testing.assert_array_equal(attribute1, attribute2, err_msg=f'attributes not equal: {attribute}') -def compare_fixation_trains(scanpaths1, scanpaths2): +def assert_fixation_trains_equal(scanpaths1, scanpaths2): np.testing.assert_array_equal(scanpaths1.train_xs, scanpaths2.train_xs) np.testing.assert_array_equal(scanpaths1.train_ys, scanpaths2.train_ys) np.testing.assert_array_equal(scanpaths1.train_xs, scanpaths2.train_xs) @@ -73,9 +86,9 @@ def compare_fixation_trains(scanpaths1, scanpaths2): assert scanpaths1.scanpath_fixation_attributes.keys() == scanpaths2.scanpath_fixation_attributes.keys() for attribute_name in scanpaths1.scanpath_fixation_attributes.keys(): - np.testing.assert_array_equal(scanpaths1.scanpath_fixation_attributes[attribute_name], scanpaths2.scanpath_fixation_attributes[attribute_name]) + assert_variable_length_array_equal(scanpaths1.scanpath_fixation_attributes[attribute_name], scanpaths2.scanpath_fixation_attributes[attribute_name]) - compare_fixations(scanpaths1, scanpaths2) + assert_fixations_equal(scanpaths1, scanpaths2) class TestFixations(TestWithData): @@ -121,7 +134,7 @@ def test_from_fixations(self): np.testing.assert_allclose(f.n, [0, 0, 0, 0, 0, 1, 1, 1]) np.testing.assert_allclose(f.subjects, [0, 0, 0, 1, 1, 1, 1, 1]) np.testing.assert_allclose(f.lengths, [0, 1, 2, 0, 1, 0, 1, 2]) - np.testing.assert_allclose(f.x_hist, [[np.nan, np.nan], + np.testing.assert_allclose(f.x_hist._data, [[np.nan, np.nan], [0, np.nan], [0, 1], [np.nan, np.nan], @@ -226,14 +239,15 @@ def test_save_and_load(self): np.testing.assert_allclose(f.n, [0, 0, 0, 0, 0, 1, 1, 1]) np.testing.assert_allclose(f.subjects, [0, 0, 0, 1, 1, 1, 1, 1]) np.testing.assert_allclose(f.lengths, [0, 1, 2, 0, 1, 0, 1, 2]) - np.testing.assert_allclose(f.x_hist, [[np.nan, np.nan], - [0, np.nan], - [0, 1], - [np.nan, np.nan], - [2, np.nan], - [np.nan, np.nan], - [1, np.nan], - [1, 5]]) + np.testing.assert_allclose(f.x_hist._data, + [[np.nan, np.nan], + [0, np.nan], + [0, 1], + [np.nan, np.nan], + [2, np.nan], + [np.nan, np.nan], + [1, np.nan], + [1, 5]]) class TestStimuli(TestWithData): @@ -403,13 +417,13 @@ def fixation_trains(): def test_copy_scanpaths(fixation_trains): copied_fixation_trains = fixation_trains.copy() - compare_fixation_trains(copied_fixation_trains, fixation_trains) + assert_fixation_trains_equal(copied_fixation_trains, fixation_trains) def test_copy_fixations(fixation_trains): fixations = fixation_trains[:] copied_fixations = fixations.copy() - compare_fixations(copied_fixations, fixations) + assert_fixations_equal(copied_fixations, fixations) def test_write_read_scanpaths_pathlib(tmp_path, fixation_trains): @@ -420,7 +434,7 @@ def test_write_read_scanpaths_pathlib(tmp_path, fixation_trains): # make sure there is no sophisticated caching... assert fixation_trains is not new_fixation_trains - compare_fixation_trains(fixation_trains, new_fixation_trains) + assert_fixation_trains_equal(fixation_trains, new_fixation_trains) def test_write_read_scanpaths(tmp_path, fixation_trains): @@ -431,7 +445,7 @@ def test_write_read_scanpaths(tmp_path, fixation_trains): # make sure there is no sophisticated caching... 
assert fixation_trains is not new_fixation_trains - compare_fixation_trains(fixation_trains, new_fixation_trains) + assert_fixation_trains_equal(fixation_trains, new_fixation_trains) def test_scanpath_lengths(fixation_trains): @@ -447,24 +461,34 @@ def test_scanpath_attributes(fixation_trains): def test_scanpath_fixation_attributes(fixation_trains): + # test attribute itself assert "durations" in fixation_trains.scanpath_fixation_attributes - np.testing.assert_array_equal( - fixation_trains.scanpath_fixation_attributes['durations'], - np.array([ - [42, 25, 100], - [99, 98, np.nan], - [200, 150, 120] - ]) - ) + assert isinstance(fixation_trains.scanpath_fixation_attributes['durations'], VariableLengthArray) + np.testing.assert_array_equal(fixation_trains.scanpath_fixation_attributes['durations'][0], [42, 25, 100]) + np.testing.assert_array_equal(fixation_trains.scanpath_fixation_attributes['durations'][1], [99, 98]) + np.testing.assert_array_equal(fixation_trains.scanpath_fixation_attributes['durations'][2], [200, 150, 120]) + # test derived fixation attribute assert "duration" in fixation_trains.__attributes__ np.testing.assert_array_equal(fixation_trains.duration, np.array([ 42, 25, 100, 99, 98, 200, 150, 120 ])) + + # test derived history attribute assert "duration_hist" in fixation_trains.__attributes__ - np.testing.assert_array_equal(fixation_trains.duration_hist[6], [200, np.nan]) + assert isinstance(fixation_trains.duration_hist, VariableLengthArray) + np.testing.assert_array_equal(fixation_trains.duration_hist[0], []) + np.testing.assert_array_equal(fixation_trains.duration_hist[1], [42]) + np.testing.assert_array_equal(fixation_trains.duration_hist[2], [42, 25]) + + np.testing.assert_array_equal(fixation_trains.duration_hist[3], []) + np.testing.assert_array_equal(fixation_trains.duration_hist[4], [99]) + + np.testing.assert_array_equal(fixation_trains.duration_hist[5], []) + np.testing.assert_array_equal(fixation_trains.duration_hist[6], [200]) + np.testing.assert_array_equal(fixation_trains.duration_hist[7], [200, 150]) @pytest.mark.parametrize('scanpath_indices,fixation_indices', [ @@ -500,12 +524,12 @@ def test_filter_fixation_trains(fixation_trains, scanpath_indices, fixation_indi fixation_trains.scanpath_attributes['task'][scanpath_indices] ) - np.testing.assert_array_equal( + assert_variable_length_array_equal( sub_fixations.scanpath_fixation_attributes['durations'], fixation_trains.scanpath_fixation_attributes['durations'][scanpath_indices] ) - compare_fixations(sub_fixations, fixation_trains[fixation_indices]) + assert_fixations_equal(sub_fixations, fixation_trains[fixation_indices]) def test_read_hdf5_caching(fixation_trains, tmp_path): @@ -523,7 +547,7 @@ def test_read_hdf5_caching(fixation_trains, tmp_path): def test_fixation_trains_copy(fixation_trains): copied_fixation_trains = fixation_trains.copy() assert isinstance(copied_fixation_trains, FixationTrains) - compare_fixations(fixation_trains, copied_fixation_trains) + assert_fixations_equal(fixation_trains, copied_fixation_trains) def test_fixations_copy(fixation_trains): @@ -531,7 +555,19 @@ def test_fixations_copy(fixation_trains): assert isinstance(fixations, Fixations) copied_fixations = fixations.copy() assert isinstance(copied_fixations, Fixations) - compare_fixations(fixations, copied_fixations) + assert_fixations_equal(fixations, copied_fixations) + + +def test_fixations_save_load(tmp_path, fixation_trains): + fixations = fixation_trains[:-1] + + assert isinstance(fixations, Fixations) + + filename = tmp_path 
/ 'fixations.hdf5' + fixations.to_hdf5(filename) + new_fixations = pysaliency.read_hdf5(filename) + + assert_fixations_equal(fixations, new_fixations) @pytest.fixture @@ -650,6 +686,28 @@ def test_concatenate_file_stimuli(file_stimuli_with_attributes): def test_concatenate_fixations(fixation_trains): + fixations = fixation_trains[:] + new_fixations = pysaliency.Fixations.concatenate((fixations, fixations)) + assert isinstance(new_fixations, pysaliency.Fixations) + np.testing.assert_allclose( + new_fixations.x, + np.concatenate((fixation_trains.x, fixation_trains.x)) + ) + + np.testing.assert_allclose( + new_fixations.n, + np.concatenate((fixation_trains.n, fixation_trains.n)) + ) + + assert new_fixations.__attributes__ == ['subjects', 'duration', 'duration_hist', 'multi_dim_attribute', 'scanpath_index', 'some_attribute', 'task'] + + np.testing.assert_allclose( + new_fixations.some_attribute, + np.concatenate((fixation_trains.some_attribute, fixation_trains.some_attribute)) + ) + + +def test_concatenate_fixation_trains(fixation_trains): new_fixations = pysaliency.Fixations.concatenate((fixation_trains, fixation_trains)) assert isinstance(new_fixations, pysaliency.Fixations) np.testing.assert_allclose( @@ -777,7 +835,7 @@ def test_scanpaths_from_fixations(fixation_indices): new_scanpaths, new_indices = scanpaths_from_fixations(sub_fixations) new_sub_fixations = new_scanpaths[new_indices] - compare_fixations(sub_fixations, new_sub_fixations, crop_length=True) + assert_fixations_equal(sub_fixations, new_sub_fixations, crop_length=True) def test_check_prediction_shape(): diff --git a/tests/test_filter_datasets.py b/tests/test_filter_datasets.py index f45649a..f7ce463 100644 --- a/tests/test_filter_datasets.py +++ b/tests/test_filter_datasets.py @@ -7,7 +7,7 @@ import pysaliency import pysaliency.filter_datasets as filter_datasets from pysaliency.filter_datasets import filter_fixations_by_attribute, filter_stimuli_by_attribute, filter_scanpaths_by_attribute, filter_scanpaths_by_length, create_subset, remove_stimuli_without_fixations -from test_datasets import compare_fixations, compare_fixation_trains +from test_datasets import assert_fixations_equal, assert_fixation_trains_equal @pytest.fixture @@ -350,7 +350,7 @@ def test_filter_stimuli_by_attribute_dva(file_stimuli_with_attributes, fixation_ filtered_stimuli, filtered_fixations = filter_stimuli_by_attribute(file_stimuli_with_attributes, fixations, attribute_name, attribute_value) inds = [1] expected_stimuli, expected_fixations = create_subset(file_stimuli_with_attributes, fixations, inds) - compare_fixations(filtered_fixations, expected_fixations) + assert_fixations_equal(filtered_fixations, expected_fixations) assert_stimuli_equal(filtered_stimuli, expected_stimuli) @@ -361,7 +361,7 @@ def test_filter_stimuli_by_attribute_multiple_values(file_stimuli_with_attribute filtered_stimuli, filtered_fixations = filter_stimuli_by_attribute(file_stimuli_with_attributes, fixations, attribute_name, attribute_values=attribute_values) inds = [1, 2] expected_stimuli, expected_fixations = create_subset(file_stimuli_with_attributes, fixations, inds) - compare_fixations(filtered_fixations, expected_fixations) + assert_fixations_equal(filtered_fixations, expected_fixations) assert_stimuli_equal(filtered_stimuli, expected_stimuli) @@ -373,7 +373,7 @@ def test_filter_stimuli_by_attribute_some_strings_invert_match(file_stimuli_with filtered_stimuli, filtered_fixations = filter_stimuli_by_attribute(file_stimuli_with_attributes, fixations, attribute_name, 
attribute_value, invert_match=invert_match) inds = list(range(0, 13)) + list(range(14, 18)) expected_stimuli, expected_fixations = create_subset(file_stimuli_with_attributes, fixations, inds) - compare_fixations(filtered_fixations, expected_fixations) + assert_fixations_equal(filtered_fixations, expected_fixations) assert_stimuli_equal(filtered_stimuli, expected_stimuli) @@ -385,7 +385,7 @@ def test_filter_fixations_by_attribute_subject_invert_match(fixation_trains): filtered_fixations = filter_fixations_by_attribute(fixations, attribute_name, attribute_value, invert_match) inds = [3, 4, 5, 6, 7] expected_fixations = fixations[inds] - compare_fixations(filtered_fixations, expected_fixations) + assert_fixations_equal(filtered_fixations, expected_fixations) def test_filter_fixations_by_attribute_some_attribute(fixation_trains): @@ -396,7 +396,7 @@ def test_filter_fixations_by_attribute_some_attribute(fixation_trains): filtered_fixations = filter_fixations_by_attribute(fixations, attribute_name, attribute_value, invert_match) inds = [2] expected_fixations = fixations[inds] - compare_fixations(filtered_fixations, expected_fixations) + assert_fixations_equal(filtered_fixations, expected_fixations) def test_filter_fixations_by_attribute_some_attribute_invert_match(fixation_trains): @@ -407,7 +407,7 @@ def test_filter_fixations_by_attribute_some_attribute_invert_match(fixation_trai filtered_fixations = filter_fixations_by_attribute(fixations, attribute_name, attribute_value, invert_match) inds = list(range(0, 3)) + list(range(4, 8)) expected_fixations = fixations[inds] - compare_fixations(filtered_fixations, expected_fixations) + assert_fixations_equal(filtered_fixations, expected_fixations) def test_filter_scanpaths_by_attribute_task(fixation_trains): @@ -418,7 +418,7 @@ def test_filter_scanpaths_by_attribute_task(fixation_trains): filtered_scanpaths = filter_scanpaths_by_attribute(scanpaths, attribute_name, attribute_value, invert_match) inds = [0, 2] expected_scanpaths = scanpaths.filter_fixation_trains(inds) - compare_fixation_trains(filtered_scanpaths, expected_scanpaths) + assert_fixation_trains_equal(filtered_scanpaths, expected_scanpaths) def test_filter_scanpaths_by_attribute_multi_dim_attribute(fixation_trains): @@ -429,7 +429,7 @@ def test_filter_scanpaths_by_attribute_multi_dim_attribute(fixation_trains): filtered_scanpaths = filter_scanpaths_by_attribute(scanpaths, attribute_name, attribute_value, invert_match) inds = [1] expected_scanpaths = scanpaths.filter_fixation_trains(inds) - compare_fixation_trains(filtered_scanpaths, expected_scanpaths) + assert_fixation_trains_equal(filtered_scanpaths, expected_scanpaths) def test_filter_scanpaths_by_attribute_multi_dim_attribute_invert_match(fixation_trains): @@ -440,7 +440,7 @@ def test_filter_scanpaths_by_attribute_multi_dim_attribute_invert_match(fixation filtered_scanpaths = filter_scanpaths_by_attribute(scanpaths, attribute_name, attribute_value, invert_match) inds = [1, 2] expected_scanpaths = scanpaths.filter_fixation_trains(inds) - compare_fixation_trains(filtered_scanpaths, expected_scanpaths) + assert_fixation_trains_equal(filtered_scanpaths, expected_scanpaths) @pytest.mark.parametrize('intervals', [([(1, 2), (2, 3)]), ([(2, 3), (3, 4)]), ([(2)]), ([(3)])]) @@ -450,19 +450,19 @@ def test_filter_scanpaths_by_length(fixation_trains, intervals): if intervals == [(1, 2), (2, 3)]: inds = [1] expected_scanpaths = scanpaths.filter_fixation_trains(inds) - compare_fixation_trains(filtered_scanpaths, expected_scanpaths) + 
assert_fixation_trains_equal(filtered_scanpaths, expected_scanpaths) if intervals == [(2, 3), (3, 4)]: inds = [0, 1, 2] expected_scanpaths = scanpaths.filter_fixation_trains(inds) - compare_fixation_trains(filtered_scanpaths, expected_scanpaths) + assert_fixation_trains_equal(filtered_scanpaths, expected_scanpaths) if intervals == [(2)]: inds = [1] expected_scanpaths = scanpaths.filter_fixation_trains(inds) - compare_fixation_trains(filtered_scanpaths, expected_scanpaths) + assert_fixation_trains_equal(filtered_scanpaths, expected_scanpaths) if intervals == [(3)]: inds = [0, 2] expected_scanpaths = scanpaths.filter_fixation_trains(inds) - compare_fixation_trains(filtered_scanpaths, expected_scanpaths) + assert_fixation_trains_equal(filtered_scanpaths, expected_scanpaths) def test_remove_stimuli_without_fixations(file_stimuli_with_attributes, fixation_trains): @@ -470,5 +470,5 @@ def test_remove_stimuli_without_fixations(file_stimuli_with_attributes, fixation filtered_stimuli, filtered_fixations = remove_stimuli_without_fixations(file_stimuli_with_attributes, fixations) inds = [0, 1] expected_stimuli, expected_fixations = create_subset(file_stimuli_with_attributes, fixations, inds) - compare_fixations(filtered_fixations, expected_fixations) + assert_fixations_equal(filtered_fixations, expected_fixations) assert_stimuli_equal(filtered_stimuli, expected_stimuli)
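
Note on the representation this patch introduces: ragged per-scanpath data (x_hist, y_hist, t_hist, the derived *_hist attributes and the scanpath_fixation_attributes) is now carried as VariableLengthArray objects instead of raw NaN-padded 2D arrays. The padded array stays reachable via ._data and is what gets written to HDF5, together with a 'lengths' dataset and a '__variable_length_arrays__' JSON attribute, under the bumped format versions ('1.1' for Fixations, '1.3' for FixationTrains). The sketch below is not part of the patch; it only illustrates the mapping, using VariableLengthArray exactly as it is used in this diff, with made-up example data:

    # Sketch: assumes a pysaliency checkout that includes this patch.
    import numpy as np
    from pysaliency.utils.variable_length_array import VariableLengthArray

    # NaN-padded x histories, as Fixations stored them before this patch
    x_hist_padded = np.array([
        [np.nan, np.nan],   # first fixation of a scanpath: no history
        [0.0,    np.nan],   # second fixation: one previous x position
        [0.0,    1.0],      # third fixation: two previous x positions
    ])

    # lengths are still derived from the NaN padding, as in Fixations.__init__
    lengths = (1 - np.isnan(x_hist_padded)).sum(axis=-1)
    x_hist = VariableLengthArray(x_hist_padded, lengths=lengths)

    len(x_hist)       # 3 rows
    x_hist.lengths    # array([0, 1, 2])
    x_hist[2]         # array([0., 1.])  -- trailing NaNs trimmed per row
    x_hist._data      # the underlying padded 2D array, written to HDF5 as-is

On load, read_hdf5 re-wraps every dataset listed in '__variable_length_arrays__' (or, for pre-1.1/1.3 files, the *_hist arrays and scanpath_fixation_attributes) using the stored lengths or lengths recomputed from the NaN padding, so downstream code sees VariableLengthArray objects regardless of file version.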