diff --git a/pysaliency/filter_datasets.py b/pysaliency/filter_datasets.py index eef963c..429a58d 100644 --- a/pysaliency/filter_datasets.py +++ b/pysaliency/filter_datasets.py @@ -4,7 +4,7 @@ from boltons.iterutils import chunked -from .datasets import create_subset +from .datasets import create_subset, FixationTrains, Fixations, Stimuli def train_split(stimuli, fixations, crossval_folds, fold_no, val_folds=1, test_folds=1, random=True, stratified_attributes=None): @@ -232,3 +232,50 @@ def filter_stimuli_by_size(stimuli, fixations, size=None, sizes=None): indices = [i for i in range(len(stimuli)) if stimuli.sizes[i] in sizes] return create_subset(stimuli, fixations, indices) + + +def filter_scanpaths_by_attribute(scanpaths: FixationTrains, attribute_name, attribute_value, invert_match=False): + """Filter Scanpaths by values of scanpath attribute (fixation_trains.scanpath_attributes)""" + + mask = np.array([element == attribute_value for element in getattr(scanpaths, attribute_name)]) + if invert_match is True: + mask = ~mask + indices = list(np.nonzero(mask)[0]) + return scanpaths.filter_fixation_trains(indices) + + +def filter_fixations_by_attribute(fixations: Fixations, attribute_name, attribute_value, invert_match=False): + """Filter Fixations by values of attribute (fixations.__attributes__)""" + + mask = np.array([element == attribute_value for element in getattr(fixations, attribute_name)]) + if invert_match is True: + mask = ~mask + indices = list(np.nonzero(mask)[0]) + return fixations.filter(indices) + + +def filter_stimuli_by_attribute(stimuli: Stimuli, fixations: Fixations, attribute_name, attribute_value, invert_match=False): + """Filter stimuli by values of attribute""" + + mask = np.array([element == attribute_value for element in getattr(stimuli, attribute_name)]) + if invert_match is True: + mask = ~mask + indices = list(np.nonzero(mask)[0]) + return create_subset(stimuli, fixations, indices) + + +def filter_scanpaths_by_lengths(scanpaths: FixationTrains, intervals: list): + """Filter Scanpaths by number of fixations""" + + intervals = _check_intervals(intervals, type=int) + mask = np.zeros(len(scanpaths.train_lengths), dtype=bool) + + for n1, n2 in intervals: + temp_mask = np.logical_and(scanpaths.train_lengths>=n1,scanpaths.train_lengths<=n2) + mask = np.logical_or(mask, temp_mask) + + indices = list(np.nonzero(mask)[0]) + scanpaths = scanpaths.filter_fixation_trains(indices) + + return scanpaths +