Skip to content

Commit

Permalink
allow filtering stimuli by multiple attribute values
Browse files Browse the repository at this point in the history
Signed-off-by: Matthias Kümmmerer <[email protected]>
  • Loading branch information
matthias-k committed Dec 16, 2023
1 parent 16e212d commit 4cbad25
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 14 deletions.
12 changes: 9 additions & 3 deletions pysaliency/filter_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,10 +260,16 @@ def filter_fixations_by_attribute(fixations: Fixations, attribute_name, attribut
return fixations[mask]


def filter_stimuli_by_attribute(stimuli: Stimuli, fixations: Fixations, attribute_name, attribute_value, invert_match=False):
"""Filter stimuli by values of attribute (stimuli.attributes)"""
def filter_stimuli_by_attribute(stimuli: Stimuli, fixations: Fixations, attribute_name, attribute_value=None, attribute_values=None, invert_match=False):
"""Filter stimuli by values of attribute (stimuli.attributes)
mask = np.asarray(stimuli.attributes[attribute_name]) == attribute_value
use `attribute_value` to filter for a single value, or `attribute_values` to filter for multiple allowed values
"""

if attribute_values is not None:
mask = np.isin(np.asarray(stimuli.attributes[attribute_name]), attribute_values)
else:
mask = np.asarray(stimuli.attributes[attribute_name]) == attribute_value
if mask.ndim > 1:
mask = np.all(mask, axis=1)

Expand Down
32 changes: 21 additions & 11 deletions tests/test_filter_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,22 +345,32 @@ def test_stratified_crossval_splits_multiple_attributes(many_stimuli, crossval_f

def test_filter_stimuli_by_attribute_dva(file_stimuli_with_attributes, fixation_trains):
fixations = fixation_trains[:]
attribute_name = 'dva'
attribute_name = 'dva'
attribute_value = 1
invert_match = False
filtered_stimuli, filtered_fixations = filter_stimuli_by_attribute(file_stimuli_with_attributes, fixations, attribute_name, attribute_value, invert_match)
filtered_stimuli, filtered_fixations = filter_stimuli_by_attribute(file_stimuli_with_attributes, fixations, attribute_name, attribute_value)
inds = [1]
expected_stimuli, expected_fixations = create_subset(file_stimuli_with_attributes, fixations, inds)
compare_fixations(filtered_fixations, expected_fixations)
assert_stimuli_equal(filtered_stimuli, expected_stimuli)


def test_filter_stimuli_by_attribute_multiple_values(file_stimuli_with_attributes, fixation_trains):
fixations = fixation_trains[:]
attribute_name = 'dva'
attribute_values = [1, 2]
filtered_stimuli, filtered_fixations = filter_stimuli_by_attribute(file_stimuli_with_attributes, fixations, attribute_name, attribute_values=attribute_values)
inds = [1, 2]
expected_stimuli, expected_fixations = create_subset(file_stimuli_with_attributes, fixations, inds)
compare_fixations(filtered_fixations, expected_fixations)
assert_stimuli_equal(filtered_stimuli, expected_stimuli)


def test_filter_stimuli_by_attribute_some_strings_invert_match(file_stimuli_with_attributes, fixation_trains):
fixations = fixation_trains[:]
attribute_name = 'some_strings'
attribute_name = 'some_strings'
attribute_value = 'n'
invert_match = True
filtered_stimuli, filtered_fixations = filter_stimuli_by_attribute(file_stimuli_with_attributes, fixations, attribute_name, attribute_value, invert_match)
filtered_stimuli, filtered_fixations = filter_stimuli_by_attribute(file_stimuli_with_attributes, fixations, attribute_name, attribute_value, invert_match=invert_match)
inds = list(range(0, 13)) + list(range(14, 18))
expected_stimuli, expected_fixations = create_subset(file_stimuli_with_attributes, fixations, inds)
compare_fixations(filtered_fixations, expected_fixations)
Expand All @@ -369,7 +379,7 @@ def test_filter_stimuli_by_attribute_some_strings_invert_match(file_stimuli_with

def test_filter_fixations_by_attribute_subject_invert_match(fixation_trains):
fixations = fixation_trains[:]
attribute_name = 'subjects'
attribute_name = 'subjects'
attribute_value = 0
invert_match = True
filtered_fixations = filter_fixations_by_attribute(fixations, attribute_name, attribute_value, invert_match)
Expand All @@ -380,7 +390,7 @@ def test_filter_fixations_by_attribute_subject_invert_match(fixation_trains):

def test_filter_fixations_by_attribute_some_attribute(fixation_trains):
fixations = fixation_trains[:]
attribute_name = 'some_attribute'
attribute_name = 'some_attribute'
attribute_value = 2
invert_match = False
filtered_fixations = filter_fixations_by_attribute(fixations, attribute_name, attribute_value, invert_match)
Expand All @@ -391,7 +401,7 @@ def test_filter_fixations_by_attribute_some_attribute(fixation_trains):

def test_filter_fixations_by_attribute_some_attribute_invert_match(fixation_trains):
fixations = fixation_trains[:]
attribute_name = 'some_attribute'
attribute_name = 'some_attribute'
attribute_value = 3
invert_match = True
filtered_fixations = filter_fixations_by_attribute(fixations, attribute_name, attribute_value, invert_match)
Expand All @@ -402,7 +412,7 @@ def test_filter_fixations_by_attribute_some_attribute_invert_match(fixation_trai

def test_filter_scanpaths_by_attribute_task(fixation_trains):
scanpaths = fixation_trains
attribute_name = 'task'
attribute_name = 'task'
attribute_value = 0
invert_match = False
filtered_scanpaths = filter_scanpaths_by_attribute(scanpaths, attribute_name, attribute_value, invert_match)
Expand All @@ -413,7 +423,7 @@ def test_filter_scanpaths_by_attribute_task(fixation_trains):

def test_filter_scanpaths_by_attribute_multi_dim_attribute(fixation_trains):
scanpaths = fixation_trains
attribute_name = 'multi_dim_attribute'
attribute_name = 'multi_dim_attribute'
attribute_value = [0, 3]
invert_match = False
filtered_scanpaths = filter_scanpaths_by_attribute(scanpaths, attribute_name, attribute_value, invert_match)
Expand All @@ -424,7 +434,7 @@ def test_filter_scanpaths_by_attribute_multi_dim_attribute(fixation_trains):

def test_filter_scanpaths_by_attribute_multi_dim_attribute_invert_match(fixation_trains):
scanpaths = fixation_trains
attribute_name = 'multi_dim_attribute'
attribute_name = 'multi_dim_attribute'
attribute_value = [0, 1]
invert_match = True
filtered_scanpaths = filter_scanpaths_by_attribute(scanpaths, attribute_name, attribute_value, invert_match)
Expand Down

0 comments on commit 4cbad25

Please sign in to comment.