Skip to content

Commit

Permalink
Bugfix: create_dataset returned wrong stimulus indices
Browse files Browse the repository at this point in the history
only happened for ScanpathFixations. Fixations and FixationTrains
were correct. For ScanpathFixations, the `n` attributes of scanpaths
and fixations were not updated to reflect the new indices.

Signed-off-by: Matthias Kümmerer <[email protected]>
  • Loading branch information
matthias-k committed Jun 24, 2024
1 parent 8fe67e9 commit 6d629d0
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 6 deletions.
1 change: 1 addition & 0 deletions pysaliency/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def create_subset(stimuli, fixations, stimuli_indices):
new_image_indices = [new_pos[i] for i in fixations.scanpaths.n[scanpath_inds]]

new_scanpaths = fixations.scanpaths[scanpath_inds]
new_scanpaths.n = np.array(new_image_indices)

new_fixations = ScanpathFixations(
scanpaths=new_scanpaths,
Expand Down
16 changes: 10 additions & 6 deletions tests/datasets/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,10 @@ def test_create_subset_scanpath_fixations(file_stimuli_with_attributes, scanpath
sub_stimuli, sub_fixations = pysaliency.datasets.create_subset(file_stimuli_with_attributes, scanpath_fixations, stimulus_indices)

expected_sub_fixations = scanpath_fixations.filter_scanpaths(scanpath_indices).copy()
expected_sub_fixations.scanpaths.n = sub_fixations.scanpaths.n
expected_sub_fixations.n = sub_fixations.n
expected_scanpath_n = np.array([stimulus_indices.index(n) for n in expected_sub_fixations.scanpaths.n])
expected_sub_fixations.scanpaths.n = expected_scanpath_n
expected_fixation_n = np.array([stimulus_indices.index(n) for n in expected_sub_fixations.n])
expected_sub_fixations.n = expected_fixation_n

assert_scanpath_fixations_equal(sub_fixations, expected_sub_fixations)

Expand All @@ -132,8 +134,10 @@ def test_create_subset_fixation_trains(file_stimuli_with_attributes, fixation_tr
sub_stimuli, sub_fixations = pysaliency.datasets.create_subset(file_stimuli_with_attributes, fixation_trains, stimulus_indices)

expected_sub_fixations= fixation_trains.filter_scanpaths(scanpath_indices).copy()
expected_sub_fixations.scanpaths.n = sub_fixations.scanpaths.n
expected_sub_fixations.n = sub_fixations.n
expected_scanpath_n = np.array([stimulus_indices.index(n) for n in expected_sub_fixations.scanpaths.n])
expected_sub_fixations.scanpaths.n = expected_scanpath_n
expected_fixation_n = np.array([stimulus_indices.index(n) for n in expected_sub_fixations.n])
expected_sub_fixations.n = expected_fixation_n

assert_fixation_trains_equal(sub_fixations, expected_sub_fixations)

Expand All @@ -148,8 +152,8 @@ def test_create_subset_fixations(file_stimuli_with_attributes, fixation_trains,
sub_stimuli, sub_fixations = pysaliency.datasets.create_subset(file_stimuli_with_attributes, fixations, stimulus_indices)

expected_sub_fixations= fixations[fixation_indices].copy()
expected_sub_fixations.n = sub_fixations.n

expected_fixation_n = np.array([stimulus_indices.index(n) for n in expected_sub_fixations.n])
expected_sub_fixations.n = expected_fixation_n
assert not isinstance(sub_fixations, pysaliency.FixationTrains)
assert_fixations_equal(sub_fixations, expected_sub_fixations)

Expand Down
1 change: 1 addition & 0 deletions tests/datasets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def assert_fixations_equal(f1, f2, crop_length=False):
np.testing.assert_array_equal(f1.x, f2.x)
np.testing.assert_array_equal(f1.y, f2.y)
np.testing.assert_array_equal(f1.t, f2.t)
np.testing.assert_array_equal(f1.n, f2.n)
assert_variable_length_array_equal(f1.x_hist, f2.x_hist)
assert_variable_length_array_equal(f1.y_hist, f2.y_hist)
assert_variable_length_array_equal(f1.t_hist, f2.t_hist)
Expand Down

0 comments on commit 6d629d0

Please sign in to comment.