Skip to content

Commit

Permalink
Fix concatenate datasets (#65)
Browse files Browse the repository at this point in the history
Signed-off-by: Matthias Kümmerer <[email protected]>
  • Loading branch information
matthias-k authored Apr 16, 2024
1 parent 8259b2d commit 4137c03
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 3 deletions.
4 changes: 2 additions & 2 deletions pysaliency/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,8 @@ def concatenate_datasets(stimuli, fixations):
offset = sum(len(s) for s in stimuli[:i])
f = fixations[i].copy()
f.n += offset
if isinstance(f, FixationTrains):
f.train_ns += offset
if isinstance(f, ScanpathFixations):
f.scanpaths.n += offset
fixations[i] = f

return concatenate_stimuli(stimuli), concatenate_fixations(fixations)
Expand Down
50 changes: 49 additions & 1 deletion tests/datasets/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,4 +174,52 @@ def test_create_subset_numpy_mask(file_stimuli_with_attributes, fixation_trains)

assert isinstance(sub_fixations, pysaliency.FixationTrains)
assert len(sub_stimuli) == 2
np.testing.assert_array_equal(sub_fixations.x, fixation_trains.x[np.isin(fixation_trains.n, [0, 2])])
np.testing.assert_array_equal(sub_fixations.x, fixation_trains.x[np.isin(fixation_trains.n, [0, 2])])


def test_concatenate_datasets_with_scanpath_fixations(file_stimuli_with_attributes, scanpath_fixations):
stimuli1 = file_stimuli_with_attributes
stimuli2 = file_stimuli_with_attributes

fixations1 = scanpath_fixations
fixations2 = scanpath_fixations

stimuli_list = [stimuli1, stimuli2]
fixations_list = [fixations1, fixations2]

concatenated_stimuli, concatenated_fixations = pysaliency.datasets.concatenate_datasets(stimuli_list, fixations_list)

modified_fixations2 = fixations2.copy()
modified_fixations2.n += len(stimuli1)
modified_fixations2.scanpaths.n += len(stimuli1)

expected_fixations = pysaliency.datasets.concatenate_fixations([fixations1, modified_fixations2])


assert len(concatenated_stimuli) == len(stimuli1) + len(stimuli2)

assert_scanpath_fixations_equal(concatenated_fixations, expected_fixations)


def test_concatenate_datasets_with_fixation_trains(file_stimuli_with_attributes, fixation_trains):
stimuli1 = file_stimuli_with_attributes
stimuli2 = file_stimuli_with_attributes

fixations1 = fixation_trains
fixations2 = fixation_trains

stimuli_list = [stimuli1, stimuli2]
fixations_list = [fixations1, fixations2]

concatenated_stimuli, concatenated_fixations = pysaliency.datasets.concatenate_datasets(stimuli_list, fixations_list)

modified_fixations2 = fixations2.copy()
modified_fixations2.n += len(stimuli1)
modified_fixations2.scanpaths.n += len(stimuli1)

expected_fixations = pysaliency.datasets.concatenate_fixations([fixations1, modified_fixations2])


assert len(concatenated_stimuli) == len(stimuli1) + len(stimuli2)

assert_fixation_trains_equal(concatenated_fixations, expected_fixations)

0 comments on commit 4137c03

Please sign in to comment.