Skip to content

Commit

Permalink
Copy and concatenate scanpaths
Browse files Browse the repository at this point in the history
Signed-off-by: Matthias Kümmerer <[email protected]>
  • Loading branch information
matthias-k committed Apr 9, 2024
1 parent 9bd8a9b commit 82748fb
Show file tree
Hide file tree
Showing 3 changed files with 145 additions and 19 deletions.
37 changes: 35 additions & 2 deletions pysaliency/datasets/scanpaths.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from boltons.cacheutils import cached

from ..utils.variable_length_array import VariableLengthArray, concatenate_variable_length_arrays
from .utils import create_hdf5_dataset, decode_string, hdf5_wrapper, _load_attribute_dict_from_hdf5
from .utils import create_hdf5_dataset, decode_string, get_merged_attribute_list, hdf5_wrapper, _load_attribute_dict_from_hdf5


class Scanpaths(object):
Expand Down Expand Up @@ -174,4 +174,37 @@ def __getitem__(self, index):
return type(self)(self.xs[index], self.ys[index], self.n[index], self.lengths[index],
scanpath_attributes={key: value[index] for key, value in self.scanpath_attributes.items()},
fixation_attributes={key: value[index] for key, value in self.fixation_attributes.items()},
attribute_mapping=self.attribute_mapping)
attribute_mapping=self.attribute_mapping)

def copy(self) -> 'Scanpaths':
return type(self)(self.xs.copy(), self.ys.copy(), self.n.copy(), self.lengths.copy(),
scanpath_attributes={key: value.copy() for key, value in self.scanpath_attributes.items()},
fixation_attributes={key: value.copy() for key, value in self.fixation_attributes.items()},
attribute_mapping=self.attribute_mapping.copy())

@classmethod
def concatenate(cls, scanpaths_list: List['Scanpaths']) -> 'Scanpaths':
return concatenate_scanpaths(scanpaths_list)


def concatenate_scanpaths(scanpaths_list: List[Scanpaths]) -> Scanpaths:
xs = concatenate_variable_length_arrays([scanpaths.xs for scanpaths in scanpaths_list])
ys = concatenate_variable_length_arrays([scanpaths.ys for scanpaths in scanpaths_list])
n = np.concatenate([scanpaths.n for scanpaths in scanpaths_list])
lengths = np.concatenate([scanpaths.lengths for scanpaths in scanpaths_list])

merged_scanpath_attributes = get_merged_attribute_list([scanpaths.scanpath_attributes.keys() for scanpaths in scanpaths_list])
scanpath_attributes = {key: np.concatenate([scanpaths.scanpath_attributes[key] for scanpaths in scanpaths_list]) for key in merged_scanpath_attributes}

merged_fixation_attributes = get_merged_attribute_list([scanpaths.fixation_attributes.keys() for scanpaths in scanpaths_list])
fixation_attributes = {key: concatenate_variable_length_arrays([scanpaths.fixation_attributes[key] for scanpaths in scanpaths_list]) for key in merged_fixation_attributes}

merged_attribute_mapping = {}
for key in merged_fixation_attributes:
mappings = {scanpaths.attribute_mapping.get(key) for scanpaths in scanpaths_list}
if len(mappings) > 1:
raise ValueError(f"Multiple mappings for attribute {key} found: {mappings}")
elif len(mappings) == 1:
merged_attribute_mapping[key] = mappings.pop()

return Scanpaths(xs, ys, n, lengths, scanpath_attributes=scanpath_attributes, fixation_attributes=fixation_attributes, attribute_mapping=merged_attribute_mapping)
15 changes: 0 additions & 15 deletions pysaliency/datasets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,21 +86,6 @@ def _load_attribute_dict_from_hdf5(attribute_group):
return attributes


def get_merged_attribute_list(attributes):
all_attributes = set(attributes[0])
common_attributes = set(attributes[0])

for _attributes in attributes[1:]:
all_attributes = all_attributes.union(_attributes)
common_attributes = common_attributes.intersection(_attributes)

if common_attributes != all_attributes:
lost_attributes = all_attributes.difference(common_attributes)
warnings.warn(f"Discarding attributes which are not present everywhere: {lost_attributes}", stacklevel=4)

return sorted(common_attributes)


def concatenate_attributes(attributes):
attributes = list(attributes)

Expand Down
112 changes: 110 additions & 2 deletions tests/datasets/test_scanpaths.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pytest

import pysaliency
from pysaliency.datasets import Scanpaths
from pysaliency.datasets import Scanpaths, concatenate_scanpaths
from pysaliency.utils.variable_length_array import VariableLengthArray


Expand Down Expand Up @@ -216,4 +216,112 @@ def test_write_read_scanpaths_pathlib(tmp_path):
new_scanpaths = pysaliency.read_hdf5(filename)

assert scanpaths is not new_scanpaths # make sure there is no sophisticated caching...
assert_scanpaths_equal(scanpaths, new_scanpaths)
assert_scanpaths_equal(scanpaths, new_scanpaths)


def test_scanpaths_copy():
xs = [[0, 1, 2], [2, 2], [1, 5, 3]]
ys = [[10, 11, 12], [12, 12], [21, 25, 33]]
n = [0, 0, 1]
scanpath_attributes = {'task': [0, 1, 0]}
fixation_attributes = {'attribute1': [[1, 1, 2], [2, 2], [0, 1, 3]], 'attribute2': [[3, 1.3, 5], [1, 42], [0, -1, -3]]}
attribute_mapping = {'attribute1': 'attr1', 'attribute2': 'attr2'}

scanpaths = Scanpaths(xs, ys, n, lengths=None, scanpath_attributes=scanpath_attributes, fixation_attributes=fixation_attributes, attribute_mapping=attribute_mapping)

new_scanpaths = scanpaths.copy()

assert_scanpaths_equal(scanpaths, new_scanpaths)
assert scanpaths is not new_scanpaths


def test_concatenate_scanpaths():
xs1 = [[0, 1, 2], [2, 2], [1, 5, 3]]
ys1 = [[10, 11, 12], [12, 11], [21, 25, 33]]
n1 = [0, 0, 1]
scanpath_attributes1 = {'task': [0, 1, 0]}
fixation_attributes1 = {'attribute1': [[1, 1, 2], [2, 2], [0, 1, 3]], 'attribute2': [[3, 1.3, 5], [1, 42], [0, -1, -3]]}
attribute_mapping1 = {'attribute1': 'attr1', 'attribute2': 'attr2'}

scanpaths1 = Scanpaths(xs1, ys1, n1, lengths=None, scanpath_attributes=scanpath_attributes1, fixation_attributes=fixation_attributes1, attribute_mapping=attribute_mapping1)

xs2 = [[0, 1, 2], [2, 2], [1, 5, 4]]
ys2 = [[10, 11, 12], [12, 12], [21, 25, 33]]
n2 = [0, 1, 1]
scanpath_attributes2 = {'task': [0, 1, 0]}
fixation_attributes2 = {'attribute1': [[1, 1, 2], [2, 2], [0, 1, 3]], 'attribute2': [[3, 1.3, 5], [1, 42], [0, -1, -3]]}
attribute_mapping2 = {'attribute1': 'attr1', 'attribute2': 'attr2'}

scanpaths2 = Scanpaths(xs2, ys2, n2, lengths=None, scanpath_attributes=scanpath_attributes2, fixation_attributes=fixation_attributes2, attribute_mapping=attribute_mapping2)

concatenated_scanpaths = concatenate_scanpaths([scanpaths1, scanpaths2])

assert_variable_length_array_equal(concatenated_scanpaths.xs, VariableLengthArray(xs1 + xs2))
assert_variable_length_array_equal(concatenated_scanpaths.ys, VariableLengthArray(ys1 + ys2))
np.testing.assert_array_equal(concatenated_scanpaths.n, np.array(n1 + n2))
assert concatenated_scanpaths.scanpath_attributes.keys() == {'task'}
np.testing.assert_array_equal(concatenated_scanpaths.scanpath_attributes['task'], np.array([0, 1, 0, 0, 1, 0]))
assert concatenated_scanpaths.fixation_attributes.keys() == {'attribute1', 'attribute2'}
assert_variable_length_array_equal(concatenated_scanpaths.fixation_attributes['attribute1'], VariableLengthArray([[1, 1, 2], [2, 2], [0, 1, 3], [1, 1, 2], [2, 2], [0, 1, 3]]))
assert_variable_length_array_equal(concatenated_scanpaths.fixation_attributes['attribute2'], VariableLengthArray([[3, 1.3, 5], [1, 42], [0, -1, -3], [3, 1.3, 5], [1, 42], [0, -1, -3]]))
assert concatenated_scanpaths.attribute_mapping == {'attribute1': 'attr1', 'attribute2': 'attr2'}


def test_concatenate_scanpaths_missing_fixation_attribute():
xs1 = [[0, 1, 2], [2, 2], [1, 5, 3]]
ys1 = [[10, 11, 12], [12, 11], [21, 25, 33]]
n1 = [0, 0, 1]
scanpath_attributes1 = {'task': [0, 1, 0]}
fixation_attributes1 = {'attribute1': [[1, 1, 2], [2, 2], [0, 1, 3]], 'attribute2': [[3, 1.3, 5], [1, 42], [0, -1, -3]]}
attribute_mapping1 = {'attribute1': 'attr1', 'attribute2': 'attr2'}

scanpaths1 = Scanpaths(xs1, ys1, n1, lengths=None, scanpath_attributes=scanpath_attributes1, fixation_attributes=fixation_attributes1, attribute_mapping=attribute_mapping1)

xs2 = [[0, 1, 2], [2, 2], [1, 5, 4]]
ys2 = [[10, 11, 12], [12, 12], [21, 25, 33]]
n2 = [0, 1, 1]
scanpath_attributes2 = {'task': [0, 1, 0]}
fixation_attributes2 = {'attribute1': [[1, 1, 2], [2, 2], [0, 1, 3]]}
attribute_mapping2 = {'attribute1': 'attr1'}

scanpaths2 = Scanpaths(xs2, ys2, n2, lengths=None, scanpath_attributes=scanpath_attributes2, fixation_attributes=fixation_attributes2, attribute_mapping=attribute_mapping2)

concatenated_scanpaths = concatenate_scanpaths([scanpaths1, scanpaths2])

assert_variable_length_array_equal(concatenated_scanpaths.xs, VariableLengthArray(xs1 + xs2))
assert_variable_length_array_equal(concatenated_scanpaths.ys, VariableLengthArray(ys1 + ys2))
np.testing.assert_array_equal(concatenated_scanpaths.n, np.array(n1 + n2))
assert concatenated_scanpaths.scanpath_attributes.keys() == {'task'}
np.testing.assert_array_equal(concatenated_scanpaths.scanpath_attributes['task'], np.array([0, 1, 0, 0, 1, 0]))
assert concatenated_scanpaths.fixation_attributes.keys() == {'attribute1'}
assert_variable_length_array_equal(concatenated_scanpaths.fixation_attributes['attribute1'], VariableLengthArray([[1, 1, 2], [2, 2], [0, 1, 3], [1, 1, 2], [2, 2], [0, 1, 3]]))
assert concatenated_scanpaths.attribute_mapping == {'attribute1': 'attr1'}


def test_concatenate_scanpaths_inconsistent_attribute_mappings():
xs1 = [[0, 1, 2], [2, 2], [1, 5, 3]]
ys1 = [[10, 11, 12], [12, 11], [21, 25, 33]]
n1 = [0, 0, 1]
scanpath_attributes1 = {'task': [0, 1, 0]}
fixation_attributes1 = {'attribute1': [[1, 1, 2], [2, 2], [0, 1, 3]], 'attribute2': [[3, 1.3, 5], [1, 42], [0, -1, -3]]}
attribute_mapping1 = {'attribute1': 'attr1', 'attribute2': 'attr2'}

scanpaths1 = Scanpaths(xs1, ys1, n1, lengths=None, scanpath_attributes=scanpath_attributes1, fixation_attributes=fixation_attributes1, attribute_mapping=attribute_mapping1)

xs2 = [[0, 1, 2], [2, 2], [1, 5, 4]]
ys2 = [[10, 11, 12], [12, 12], [21, 25, 33]]
n2 = [0, 1, 1]
scanpath_attributes2 = {'task': [0, 1, 0]}
fixation_attributes2 = {'attribute1': [[1, 1, 2], [2, 2], [0, 1, 3]], 'attribute2': [[3, 1.3, 5], [1, 42], [0, -1, -3]]}
attribute_mapping2 = {'attribute1': 'attr1', 'attribute2': 'attr2'}

scanpaths2_clean = Scanpaths(xs2, ys2, n2, lengths=None, scanpath_attributes=scanpath_attributes2, fixation_attributes=fixation_attributes2, attribute_mapping=attribute_mapping2)

# make sure that everyuthing else wearks
concatenate_scanpaths([scanpaths1, scanpaths2_clean])

attribute_mapping2 = {'attribute1': 'attr1', 'attribute2': 'attr3'}
scanpaths2_inconsistent = Scanpaths(xs2, ys2, n2, lengths=None, scanpath_attributes=scanpath_attributes2, fixation_attributes=fixation_attributes2, attribute_mapping=attribute_mapping2)

with pytest.raises(ValueError):
concatenate_scanpaths([scanpaths1, scanpaths2_inconsistent])

0 comments on commit 82748fb

Please sign in to comment.