diff --git a/pysaliency/filter_datasets.py b/pysaliency/filter_datasets.py index b674bd6..5ab15ad 100644 --- a/pysaliency/filter_datasets.py +++ b/pysaliency/filter_datasets.py @@ -260,10 +260,16 @@ def filter_fixations_by_attribute(fixations: Fixations, attribute_name, attribut return fixations[mask] -def filter_stimuli_by_attribute(stimuli: Stimuli, fixations: Fixations, attribute_name, attribute_value, invert_match=False): - """Filter stimuli by values of attribute (stimuli.attributes)""" +def filter_stimuli_by_attribute(stimuli: Stimuli, fixations: Fixations, attribute_name, attribute_value=None, attribute_values=None, invert_match=False): + """Filter stimuli by values of attribute (stimuli.attributes) - mask = np.asarray(stimuli.attributes[attribute_name]) == attribute_value + use `attribute_value` to filter for a single value, or `attribute_values` to filter for multiple allowed values + """ + + if attribute_values is not None: + mask = np.isin(np.asarray(stimuli.attributes[attribute_name]), attribute_values) + else: + mask = np.asarray(stimuli.attributes[attribute_name]) == attribute_value if mask.ndim > 1: mask = np.all(mask, axis=1) diff --git a/tests/test_filter_datasets.py b/tests/test_filter_datasets.py index afe4d8d..40234cb 100644 --- a/tests/test_filter_datasets.py +++ b/tests/test_filter_datasets.py @@ -345,22 +345,32 @@ def test_stratified_crossval_splits_multiple_attributes(many_stimuli, crossval_f def test_filter_stimuli_by_attribute_dva(file_stimuli_with_attributes, fixation_trains): fixations = fixation_trains[:] - attribute_name = 'dva' + attribute_name = 'dva' attribute_value = 1 - invert_match = False - filtered_stimuli, filtered_fixations = filter_stimuli_by_attribute(file_stimuli_with_attributes, fixations, attribute_name, attribute_value, invert_match) + filtered_stimuli, filtered_fixations = filter_stimuli_by_attribute(file_stimuli_with_attributes, fixations, attribute_name, attribute_value) inds = [1] expected_stimuli, expected_fixations = create_subset(file_stimuli_with_attributes, fixations, inds) compare_fixations(filtered_fixations, expected_fixations) assert_stimuli_equal(filtered_stimuli, expected_stimuli) +def test_filter_stimuli_by_attribute_multiple_values(file_stimuli_with_attributes, fixation_trains): + fixations = fixation_trains[:] + attribute_name = 'dva' + attribute_values = [1, 2] + filtered_stimuli, filtered_fixations = filter_stimuli_by_attribute(file_stimuli_with_attributes, fixations, attribute_name, attribute_values=attribute_values) + inds = [1, 2] + expected_stimuli, expected_fixations = create_subset(file_stimuli_with_attributes, fixations, inds) + compare_fixations(filtered_fixations, expected_fixations) + assert_stimuli_equal(filtered_stimuli, expected_stimuli) + + def test_filter_stimuli_by_attribute_some_strings_invert_match(file_stimuli_with_attributes, fixation_trains): fixations = fixation_trains[:] - attribute_name = 'some_strings' + attribute_name = 'some_strings' attribute_value = 'n' invert_match = True - filtered_stimuli, filtered_fixations = filter_stimuli_by_attribute(file_stimuli_with_attributes, fixations, attribute_name, attribute_value, invert_match) + filtered_stimuli, filtered_fixations = filter_stimuli_by_attribute(file_stimuli_with_attributes, fixations, attribute_name, attribute_value, invert_match=invert_match) inds = list(range(0, 13)) + list(range(14, 18)) expected_stimuli, expected_fixations = create_subset(file_stimuli_with_attributes, fixations, inds) compare_fixations(filtered_fixations, expected_fixations) @@ -369,7 +379,7 @@ def test_filter_stimuli_by_attribute_some_strings_invert_match(file_stimuli_with def test_filter_fixations_by_attribute_subject_invert_match(fixation_trains): fixations = fixation_trains[:] - attribute_name = 'subjects' + attribute_name = 'subjects' attribute_value = 0 invert_match = True filtered_fixations = filter_fixations_by_attribute(fixations, attribute_name, attribute_value, invert_match) @@ -380,7 +390,7 @@ def test_filter_fixations_by_attribute_subject_invert_match(fixation_trains): def test_filter_fixations_by_attribute_some_attribute(fixation_trains): fixations = fixation_trains[:] - attribute_name = 'some_attribute' + attribute_name = 'some_attribute' attribute_value = 2 invert_match = False filtered_fixations = filter_fixations_by_attribute(fixations, attribute_name, attribute_value, invert_match) @@ -391,7 +401,7 @@ def test_filter_fixations_by_attribute_some_attribute(fixation_trains): def test_filter_fixations_by_attribute_some_attribute_invert_match(fixation_trains): fixations = fixation_trains[:] - attribute_name = 'some_attribute' + attribute_name = 'some_attribute' attribute_value = 3 invert_match = True filtered_fixations = filter_fixations_by_attribute(fixations, attribute_name, attribute_value, invert_match) @@ -402,7 +412,7 @@ def test_filter_fixations_by_attribute_some_attribute_invert_match(fixation_trai def test_filter_scanpaths_by_attribute_task(fixation_trains): scanpaths = fixation_trains - attribute_name = 'task' + attribute_name = 'task' attribute_value = 0 invert_match = False filtered_scanpaths = filter_scanpaths_by_attribute(scanpaths, attribute_name, attribute_value, invert_match) @@ -413,7 +423,7 @@ def test_filter_scanpaths_by_attribute_task(fixation_trains): def test_filter_scanpaths_by_attribute_multi_dim_attribute(fixation_trains): scanpaths = fixation_trains - attribute_name = 'multi_dim_attribute' + attribute_name = 'multi_dim_attribute' attribute_value = [0, 3] invert_match = False filtered_scanpaths = filter_scanpaths_by_attribute(scanpaths, attribute_name, attribute_value, invert_match) @@ -424,7 +434,7 @@ def test_filter_scanpaths_by_attribute_multi_dim_attribute(fixation_trains): def test_filter_scanpaths_by_attribute_multi_dim_attribute_invert_match(fixation_trains): scanpaths = fixation_trains - attribute_name = 'multi_dim_attribute' + attribute_name = 'multi_dim_attribute' attribute_value = [0, 1] invert_match = True filtered_scanpaths = filter_scanpaths_by_attribute(scanpaths, attribute_name, attribute_value, invert_match)