Skip to content

Commit

Permalink
Updated test_datasets.py
Browse files Browse the repository at this point in the history
Added tests for filtering stimuli and fixations
  • Loading branch information
hkhanuja authored Nov 6, 2023
1 parent d0bfd6e commit d37f2e6
Showing 1 changed file with 34 additions and 0 deletions.
34 changes: 34 additions & 0 deletions tests/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -747,5 +747,39 @@ def test_scanpaths_from_fixations(fixation_indices):
compare_fixations(sub_fixations, new_sub_fixations, crop_length=True)


invert_param_values = [True, False]
@pytest.mark.parametrize('attribute_name , attribute_value', [('dva',1), ('dva',4), ('dva',15), ('some_strings','a'), ('some_strings','n'), ('some_strings','q')])
@pytest.mark.parametrize('invert_match', invert_param_values)
def test_filter_stimuli_by_attribute(file_stimuli_with_attributes, fixation_trains, attribute_name, attribute_value, invert_match):

fixations = fixation_trains[np.arange(len(fixation_trains))]
# Access the attribute using the attributes dictionary
attribute_data = file_stimuli_with_attributes.attributes[attribute_name]
mask = np.array([element == attribute_value for element in attribute_data])
# mask = [element == attribute_value for element in getattr(file_stimuli_with_attributes, attribute_name)] not sure about this line because in real mit stimuli, stimuli.attributes returns an empty dict and this works

if invert_match is True:
mask = ~mask
stimulus_indices = list(np.nonzero(mask)[0])
sub_stimuli, sub_fixations = pysaliency.datasets.create_subset(file_stimuli_with_attributes, fixations, stimulus_indices)

assert not isinstance(sub_fixations, pysaliency.FixationTrains)
assert len(sub_stimuli) == len(stimulus_indices)
np.testing.assert_array_equal(sub_fixations.x, fixations.x[np.isin(fixations.n, stimulus_indices)])


invert_param_values = [True, False]
@pytest.mark.parametrize('attribute_name , attribute_value', [('subjects',0), ('subjects',1), ('subjects',100), ('x',1), ('x',19), ('y',10), ('y',12), ('t',100), ('t',500), ('t',10000)])
@pytest.mark.parametrize('invert_match', invert_param_values)
def test_filter_fixations_by_attribute(fixation_trains, attribute_name, attribute_value, invert_match):
fixations = fixation_trains[:]
mask = np.array([element == attribute_value for element in getattr(fixations, attribute_name)])
if invert_match is True:
mask = ~mask
inds = list(np.nonzero(mask)[0])
_f = fixations.filter(inds)
compare_fixations_subset(_f, fixations, inds)


if __name__ == '__main__':
unittest.main()

0 comments on commit d37f2e6

Please sign in to comment.