From d37f2e6a38b70c5b0d11d6654930eb7de28ea92e Mon Sep 17 00:00:00 2001 From: Harneet Singh Khanuja Date: Mon, 6 Nov 2023 19:10:09 +0100 Subject: [PATCH] Updated test_datasets.py Added tests for filtering stimuli and fixations --- tests/test_datasets.py | 34 ++++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 0b71deb..8b39af9 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -747,5 +747,39 @@ def test_scanpaths_from_fixations(fixation_indices): compare_fixations(sub_fixations, new_sub_fixations, crop_length=True) +invert_param_values = [True, False] +@pytest.mark.parametrize('attribute_name , attribute_value', [('dva',1), ('dva',4), ('dva',15), ('some_strings','a'), ('some_strings','n'), ('some_strings','q')]) +@pytest.mark.parametrize('invert_match', invert_param_values) +def test_filter_stimuli_by_attribute(file_stimuli_with_attributes, fixation_trains, attribute_name, attribute_value, invert_match): + + fixations = fixation_trains[np.arange(len(fixation_trains))] + # Access the attribute using the attributes dictionary + attribute_data = file_stimuli_with_attributes.attributes[attribute_name] + mask = np.array([element == attribute_value for element in attribute_data]) + # mask = [element == attribute_value for element in getattr(file_stimuli_with_attributes, attribute_name)] not sure about this line because in real mit stimuli, stimuli.attributes returns an empty dict and this works + + if invert_match is True: + mask = ~mask + stimulus_indices = list(np.nonzero(mask)[0]) + sub_stimuli, sub_fixations = pysaliency.datasets.create_subset(file_stimuli_with_attributes, fixations, stimulus_indices) + + assert not isinstance(sub_fixations, pysaliency.FixationTrains) + assert len(sub_stimuli) == len(stimulus_indices) + np.testing.assert_array_equal(sub_fixations.x, fixations.x[np.isin(fixations.n, stimulus_indices)]) + + +invert_param_values = [True, False] +@pytest.mark.parametrize('attribute_name , attribute_value', [('subjects',0), ('subjects',1), ('subjects',100), ('x',1), ('x',19), ('y',10), ('y',12), ('t',100), ('t',500), ('t',10000)]) +@pytest.mark.parametrize('invert_match', invert_param_values) +def test_filter_fixations_by_attribute(fixation_trains, attribute_name, attribute_value, invert_match): + fixations = fixation_trains[:] + mask = np.array([element == attribute_value for element in getattr(fixations, attribute_name)]) + if invert_match is True: + mask = ~mask + inds = list(np.nonzero(mask)[0]) + _f = fixations.filter(inds) + compare_fixations_subset(_f, fixations, inds) + + if __name__ == '__main__': unittest.main()