Skip to content

Commit

Permalink
Allow keyword specification of scanpath attributes
Browse files Browse the repository at this point in the history
Signed-off-by: Matthias Kümmerer <[email protected]>
  • Loading branch information
matthias-k committed Apr 9, 2024
1 parent 6d6e304 commit 8ff7f59
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 8 deletions.
30 changes: 22 additions & 8 deletions pysaliency/datasets/scanpaths.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,9 @@
from typing import Dict, List, Optional, Union

import numpy as np
from boltons.cacheutils import cached

from ..utils.variable_length_array import VariableLengthArray, concatenate_variable_length_arrays
from .utils import create_hdf5_dataset, decode_string, get_merged_attribute_list, hdf5_wrapper, _load_attribute_dict_from_hdf5
from .utils import _load_attribute_dict_from_hdf5, create_hdf5_dataset, decode_string, get_merged_attribute_list, hdf5_wrapper


class Scanpaths(object):
Expand Down Expand Up @@ -34,7 +33,8 @@ def __init__(self,
lengths=None,
scanpath_attributes: Optional[Dict[str, np.ndarray]] = None,
fixation_attributes: Optional[Dict[str, Union[np.ndarray, VariableLengthArray]]]=None,
attribute_mapping=Dict[str, str]):
attribute_mapping=Dict[str, str],
**kwargs):

self.n = np.asarray(n)

Expand All @@ -54,9 +54,27 @@ def __init__(self,
if not len(self.xs) == len(self.ys) == len(self.n):
raise ValueError("Length of xs, ys, ts and n has to match")

scanpath_attributes = scanpath_attributes or {}
fixation_attributes = fixation_attributes or {}
self.attribute_mapping = attribute_mapping or {}


for key, value in kwargs.items():
if not len(value) == len(self.xs):
raise ValueError(f"Length of attribute {key} has to match number of scanpaths, but got {len(value)} != {len(self.xs)}")
if isinstance(value, VariableLengthArray) or isinstance(value[0], (list, np.ndarray)):
if key in fixation_attributes:
raise ValueError(f"Attribute {key} already exists in fixation_attributes")
fixation_attributes[key] = self._as_variable_length_array(value)
if key not in self.attribute_mapping and key[-1] == 's':
self.attribute_mapping[key] = key[:-1]
else:
if key in scanpath_attributes:
raise ValueError(f"Attribute {key} already exists in scanpath_attributes")
scanpath_attributes[key] = np.array(value)

# setting scanpath attributes

scanpath_attributes = scanpath_attributes or {}
self.scanpath_attributes = {key: np.array(value) for key, value in scanpath_attributes.items()}

for key, value in self.scanpath_attributes.items():
Expand All @@ -65,12 +83,8 @@ def __init__(self,

# setting fixation attributes

fixation_attributes = fixation_attributes or {}

self.fixation_attributes = {key: self._as_variable_length_array(value) for key, value in fixation_attributes.items()}

self.attribute_mapping = attribute_mapping or {}

def _check_lengths(self, other: VariableLengthArray):
if not len(self) == len(other):
raise ValueError("Length of scanpaths has to match")
Expand Down
32 changes: 32 additions & 0 deletions tests/datasets/test_scanpaths.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,38 @@ def test_scanpaths_from_lists():
assert scanpaths.attribute_mapping == {'attribute1': 'attr1', 'attribute2': 'attr2'}


def test_scanpaths_from_lists_with_keyword_arguments():
xs = [[0, 1, 2], [2, 2], [1, 5, 3]]
ys = [[10, 11, 12], [12, 12], [21, 25, 33]]
ts = [[1, 2.5, 4], [2, 3], [3, 4, 6]]
subject = [0, 1, 2]
n = [0, 0, 1]
expected_lengths = np.array([3, 2, 3])
scanpath_attributes = {'task': [0, 1, 0]}
fixation_attributes = {'attribute1': [[1, 1, 2], [2, 2], [0, 1, 3]], 'attribute2': [[3, 1.3, 5], [1, 42], [0, -1, -3]]}
attribute_mapping = {'attribute1': 'attr1', 'attribute2': 'attr2'}

scanpaths = Scanpaths(
xs,
ys,
n,
lengths=None, scanpath_attributes=scanpath_attributes, fixation_attributes=fixation_attributes, attribute_mapping=attribute_mapping,
ts=ts,
subject=subject,
)

np.testing.assert_array_equal(scanpaths.xs._data, np.array([[0, 1, 2], [2, 2, np.nan], [1, 5, 3]]))
np.testing.assert_array_equal(scanpaths.ys._data, np.array([[10, 11, 12], [12, 12, np.nan], [21, 25, 33]]))
np.testing.assert_array_equal(scanpaths.n, n)
np.testing.assert_array_equal(scanpaths.lengths, expected_lengths)
np.testing.assert_array_equal(scanpaths.scanpath_attributes['task'], np.array([0, 1, 0]))
np.testing.assert_array_equal(scanpaths.scanpath_attributes['subject'], np.array([0, 1, 2]))
np.testing.assert_array_equal(scanpaths.fixation_attributes['attribute1']._data, np.array([[1, 1, 2], [2, 2, np.nan], [0, 1, 3]]))
np.testing.assert_array_equal(scanpaths.fixation_attributes['attribute2']._data, np.array([[3, 1.3, 5], [1, 42, np.nan], [0, -1, -3]]))
np.testing.assert_array_equal(scanpaths.fixation_attributes['ts']._data, np.array([[1, 2.5, 4], [2, 3, np.nan], [3, 4, 6]]))
assert scanpaths.attribute_mapping == {'attribute1': 'attr1', 'attribute2': 'attr2', 'ts': 't'}


def test_scanpaths_init_inconsistent_lengths():
xs = np.array([[0, 1, 2], [2, 2, np.nan], [1, 5, 3]])
ys = np.array([[10, 11, 12], [12, 12, np.nan]]) # too short, should fail
Expand Down

0 comments on commit 8ff7f59

Please sign in to comment.