From b2cb246fac4d9868e2e70200488722378c367c45 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Matthias=20K=C3=BCmmerer?= Date: Wed, 17 Apr 2024 00:12:37 +0200 Subject: [PATCH] Fix concatenate datasets MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Matthias Kümmerer --- pysaliency/datasets/__init__.py | 4 +-- tests/datasets/test_datasets.py | 50 ++++++++++++++++++++++++++++++++- 2 files changed, 51 insertions(+), 3 deletions(-) diff --git a/pysaliency/datasets/__init__.py b/pysaliency/datasets/__init__.py index 4a277cd..23f2e26 100644 --- a/pysaliency/datasets/__init__.py +++ b/pysaliency/datasets/__init__.py @@ -123,8 +123,8 @@ def concatenate_datasets(stimuli, fixations): offset = sum(len(s) for s in stimuli[:i]) f = fixations[i].copy() f.n += offset - if isinstance(f, FixationTrains): - f.train_ns += offset + if isinstance(f, ScanpathFixations): + f.scanpaths.n += offset fixations[i] = f return concatenate_stimuli(stimuli), concatenate_fixations(fixations) diff --git a/tests/datasets/test_datasets.py b/tests/datasets/test_datasets.py index 9d2b9fe..049e3ea 100644 --- a/tests/datasets/test_datasets.py +++ b/tests/datasets/test_datasets.py @@ -174,4 +174,52 @@ def test_create_subset_numpy_mask(file_stimuli_with_attributes, fixation_trains) assert isinstance(sub_fixations, pysaliency.FixationTrains) assert len(sub_stimuli) == 2 - np.testing.assert_array_equal(sub_fixations.x, fixation_trains.x[np.isin(fixation_trains.n, [0, 2])]) \ No newline at end of file + np.testing.assert_array_equal(sub_fixations.x, fixation_trains.x[np.isin(fixation_trains.n, [0, 2])]) + + +def test_concatenate_datasets_with_scanpath_fixations(file_stimuli_with_attributes, scanpath_fixations): + stimuli1 = file_stimuli_with_attributes + stimuli2 = file_stimuli_with_attributes + + fixations1 = scanpath_fixations + fixations2 = scanpath_fixations + + stimuli_list = [stimuli1, stimuli2] + fixations_list = [fixations1, fixations2] + + concatenated_stimuli, concatenated_fixations = pysaliency.datasets.concatenate_datasets(stimuli_list, fixations_list) + + modified_fixations2 = fixations2.copy() + modified_fixations2.n += len(stimuli1) + modified_fixations2.scanpaths.n += len(stimuli1) + + expected_fixations = pysaliency.datasets.concatenate_fixations([fixations1, modified_fixations2]) + + + assert len(concatenated_stimuli) == len(stimuli1) + len(stimuli2) + + assert_scanpath_fixations_equal(concatenated_fixations, expected_fixations) + + +def test_concatenate_datasets_with_fixation_trains(file_stimuli_with_attributes, fixation_trains): + stimuli1 = file_stimuli_with_attributes + stimuli2 = file_stimuli_with_attributes + + fixations1 = fixation_trains + fixations2 = fixation_trains + + stimuli_list = [stimuli1, stimuli2] + fixations_list = [fixations1, fixations2] + + concatenated_stimuli, concatenated_fixations = pysaliency.datasets.concatenate_datasets(stimuli_list, fixations_list) + + modified_fixations2 = fixations2.copy() + modified_fixations2.n += len(stimuli1) + modified_fixations2.scanpaths.n += len(stimuli1) + + expected_fixations = pysaliency.datasets.concatenate_fixations([fixations1, modified_fixations2]) + + + assert len(concatenated_stimuli) == len(stimuli1) + len(stimuli2) + + assert_fixation_trains_equal(concatenated_fixations, expected_fixations) \ No newline at end of file