Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

allow filtering stimuli by multiple attribute values #43

Merged
merged 1 commit into from
Dec 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions pysaliency/filter_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,10 +260,16 @@ def filter_fixations_by_attribute(fixations: Fixations, attribute_name, attribut
return fixations[mask]


def filter_stimuli_by_attribute(stimuli: Stimuli, fixations: Fixations, attribute_name, attribute_value, invert_match=False):
"""Filter stimuli by values of attribute (stimuli.attributes)"""
def filter_stimuli_by_attribute(stimuli: Stimuli, fixations: Fixations, attribute_name, attribute_value=None, attribute_values=None, invert_match=False):
"""Filter stimuli by values of attribute (stimuli.attributes)
mask = np.asarray(stimuli.attributes[attribute_name]) == attribute_value
use `attribute_value` to filter for a single value, or `attribute_values` to filter for multiple allowed values
"""

if attribute_values is not None:
mask = np.isin(np.asarray(stimuli.attributes[attribute_name]), attribute_values)
else:
mask = np.asarray(stimuli.attributes[attribute_name]) == attribute_value
if mask.ndim > 1:
mask = np.all(mask, axis=1)

Expand Down
32 changes: 21 additions & 11 deletions tests/test_filter_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,22 +345,32 @@ def test_stratified_crossval_splits_multiple_attributes(many_stimuli, crossval_f

def test_filter_stimuli_by_attribute_dva(file_stimuli_with_attributes, fixation_trains):
fixations = fixation_trains[:]
attribute_name = 'dva'
attribute_name = 'dva'
attribute_value = 1
invert_match = False
filtered_stimuli, filtered_fixations = filter_stimuli_by_attribute(file_stimuli_with_attributes, fixations, attribute_name, attribute_value, invert_match)
filtered_stimuli, filtered_fixations = filter_stimuli_by_attribute(file_stimuli_with_attributes, fixations, attribute_name, attribute_value)
inds = [1]
expected_stimuli, expected_fixations = create_subset(file_stimuli_with_attributes, fixations, inds)
compare_fixations(filtered_fixations, expected_fixations)
assert_stimuli_equal(filtered_stimuli, expected_stimuli)


def test_filter_stimuli_by_attribute_multiple_values(file_stimuli_with_attributes, fixation_trains):
fixations = fixation_trains[:]
attribute_name = 'dva'
attribute_values = [1, 2]
filtered_stimuli, filtered_fixations = filter_stimuli_by_attribute(file_stimuli_with_attributes, fixations, attribute_name, attribute_values=attribute_values)
inds = [1, 2]
expected_stimuli, expected_fixations = create_subset(file_stimuli_with_attributes, fixations, inds)
compare_fixations(filtered_fixations, expected_fixations)
assert_stimuli_equal(filtered_stimuli, expected_stimuli)


def test_filter_stimuli_by_attribute_some_strings_invert_match(file_stimuli_with_attributes, fixation_trains):
fixations = fixation_trains[:]
attribute_name = 'some_strings'
attribute_name = 'some_strings'
attribute_value = 'n'
invert_match = True
filtered_stimuli, filtered_fixations = filter_stimuli_by_attribute(file_stimuli_with_attributes, fixations, attribute_name, attribute_value, invert_match)
filtered_stimuli, filtered_fixations = filter_stimuli_by_attribute(file_stimuli_with_attributes, fixations, attribute_name, attribute_value, invert_match=invert_match)
inds = list(range(0, 13)) + list(range(14, 18))
expected_stimuli, expected_fixations = create_subset(file_stimuli_with_attributes, fixations, inds)
compare_fixations(filtered_fixations, expected_fixations)
Expand All @@ -369,7 +379,7 @@ def test_filter_stimuli_by_attribute_some_strings_invert_match(file_stimuli_with

def test_filter_fixations_by_attribute_subject_invert_match(fixation_trains):
fixations = fixation_trains[:]
attribute_name = 'subjects'
attribute_name = 'subjects'
attribute_value = 0
invert_match = True
filtered_fixations = filter_fixations_by_attribute(fixations, attribute_name, attribute_value, invert_match)
Expand All @@ -380,7 +390,7 @@ def test_filter_fixations_by_attribute_subject_invert_match(fixation_trains):

def test_filter_fixations_by_attribute_some_attribute(fixation_trains):
fixations = fixation_trains[:]
attribute_name = 'some_attribute'
attribute_name = 'some_attribute'
attribute_value = 2
invert_match = False
filtered_fixations = filter_fixations_by_attribute(fixations, attribute_name, attribute_value, invert_match)
Expand All @@ -391,7 +401,7 @@ def test_filter_fixations_by_attribute_some_attribute(fixation_trains):

def test_filter_fixations_by_attribute_some_attribute_invert_match(fixation_trains):
fixations = fixation_trains[:]
attribute_name = 'some_attribute'
attribute_name = 'some_attribute'
attribute_value = 3
invert_match = True
filtered_fixations = filter_fixations_by_attribute(fixations, attribute_name, attribute_value, invert_match)
Expand All @@ -402,7 +412,7 @@ def test_filter_fixations_by_attribute_some_attribute_invert_match(fixation_trai

def test_filter_scanpaths_by_attribute_task(fixation_trains):
scanpaths = fixation_trains
attribute_name = 'task'
attribute_name = 'task'
attribute_value = 0
invert_match = False
filtered_scanpaths = filter_scanpaths_by_attribute(scanpaths, attribute_name, attribute_value, invert_match)
Expand All @@ -413,7 +423,7 @@ def test_filter_scanpaths_by_attribute_task(fixation_trains):

def test_filter_scanpaths_by_attribute_multi_dim_attribute(fixation_trains):
scanpaths = fixation_trains
attribute_name = 'multi_dim_attribute'
attribute_name = 'multi_dim_attribute'
attribute_value = [0, 3]
invert_match = False
filtered_scanpaths = filter_scanpaths_by_attribute(scanpaths, attribute_name, attribute_value, invert_match)
Expand All @@ -424,7 +434,7 @@ def test_filter_scanpaths_by_attribute_multi_dim_attribute(fixation_trains):

def test_filter_scanpaths_by_attribute_multi_dim_attribute_invert_match(fixation_trains):
scanpaths = fixation_trains
attribute_name = 'multi_dim_attribute'
attribute_name = 'multi_dim_attribute'
attribute_value = [0, 1]
invert_match = True
filtered_scanpaths = filter_scanpaths_by_attribute(scanpaths, attribute_name, attribute_value, invert_match)
Expand Down
Loading