Skip to content

Commit

Permalink
Bugfix (#66)
Browse files Browse the repository at this point in the history
Signed-off-by: Matthias Kümmerer <[email protected]>
  • Loading branch information
matthias-k authored Apr 20, 2024
1 parent 4137c03 commit cb04d7f
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 1 deletion.
2 changes: 1 addition & 1 deletion pysaliency/datasets/scanpaths.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ def concatenate_scanpaths(scanpaths_list: List[Scanpaths]) -> Scanpaths:
mappings = {scanpaths.attribute_mapping.get(key) for scanpaths in scanpaths_list}
if len(mappings) > 1:
raise ValueError(f"Multiple mappings for attribute {key} found: {mappings}")
elif len(mappings) == 1:
elif len(mappings) == 1 and list(mappings)[0] is not None:
merged_attribute_mapping[key] = mappings.pop()

return Scanpaths(xs, ys, n, length, scanpath_attributes=scanpath_attributes, fixation_attributes=fixation_attributes, attribute_mapping=merged_attribute_mapping)
32 changes: 32 additions & 0 deletions tests/datasets/test_scanpaths.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,38 @@ def test_concatenate_scanpaths():
assert concatenated_scanpaths.attribute_mapping == {'attribute1': 'attr1', 'attribute2': 'attr2'}


def test_concatenate_scanpaths_no_mapping():
xs1 = [[0, 1, 2], [2, 2], [1, 5, 3]]
ys1 = [[10, 11, 12], [12, 11], [21, 25, 33]]
n1 = [0, 0, 1]
scanpath_attributes1 = {'task': [0, 1, 0]}
fixation_attributes1 = {'attribute1': [[1, 1, 2], [2, 2], [0, 1, 3]], 'attribute2': [[3, 1.3, 5], [1, 42], [0, -1, -3]]}
attribute_mapping1 = {}

scanpaths1 = Scanpaths(xs1, ys1, n1, length=None, scanpath_attributes=scanpath_attributes1, fixation_attributes=fixation_attributes1, attribute_mapping=attribute_mapping1)

xs2 = [[0, 1, 2], [2, 2], [1, 5, 4]]
ys2 = [[10, 11, 12], [12, 12], [21, 25, 33]]
n2 = [0, 1, 1]
scanpath_attributes2 = {'task': [0, 1, 0]}
fixation_attributes2 = {'attribute1': [[1, 1, 2], [2, 2], [0, 1, 3]], 'attribute2': [[3, 1.3, 5], [1, 42], [0, -1, -3]]}
attribute_mapping2 = {}

scanpaths2 = Scanpaths(xs2, ys2, n2, length=None, scanpath_attributes=scanpath_attributes2, fixation_attributes=fixation_attributes2, attribute_mapping=attribute_mapping2)

concatenated_scanpaths = concatenate_scanpaths([scanpaths1, scanpaths2])

assert_variable_length_array_equal(concatenated_scanpaths.xs, VariableLengthArray(xs1 + xs2))
assert_variable_length_array_equal(concatenated_scanpaths.ys, VariableLengthArray(ys1 + ys2))
np.testing.assert_array_equal(concatenated_scanpaths.n, np.array(n1 + n2))
assert concatenated_scanpaths.scanpath_attributes.keys() == {'task'}
np.testing.assert_array_equal(concatenated_scanpaths.scanpath_attributes['task'], np.array([0, 1, 0, 0, 1, 0]))
assert concatenated_scanpaths.fixation_attributes.keys() == {'attribute1', 'attribute2'}
assert_variable_length_array_equal(concatenated_scanpaths.fixation_attributes['attribute1'], VariableLengthArray([[1, 1, 2], [2, 2], [0, 1, 3], [1, 1, 2], [2, 2], [0, 1, 3]]))
assert_variable_length_array_equal(concatenated_scanpaths.fixation_attributes['attribute2'], VariableLengthArray([[3, 1.3, 5], [1, 42], [0, -1, -3], [3, 1.3, 5], [1, 42], [0, -1, -3]]))
assert concatenated_scanpaths.attribute_mapping == {}


def test_concatenate_scanpaths_missing_fixation_attribute():
xs1 = [[0, 1, 2], [2, 2], [1, 5, 3]]
ys1 = [[10, 11, 12], [12, 11], [21, 25, 33]]
Expand Down

0 comments on commit cb04d7f

Please sign in to comment.