diff --git a/pysaliency/filter_datasets.py b/pysaliency/filter_datasets.py index 429a58d..8b369b0 100644 --- a/pysaliency/filter_datasets.py +++ b/pysaliency/filter_datasets.py @@ -236,31 +236,37 @@ def filter_stimuli_by_size(stimuli, fixations, size=None, sizes=None): def filter_scanpaths_by_attribute(scanpaths: FixationTrains, attribute_name, attribute_value, invert_match=False): """Filter Scanpaths by values of scanpath attribute (fixation_trains.scanpath_attributes)""" - - mask = np.array([element == attribute_value for element in getattr(scanpaths, attribute_name)]) + + mask = scanpaths.scanpath_attributes[attribute_name] == attribute_value if invert_match is True: mask = ~mask - indices = list(np.nonzero(mask)[0]) - return scanpaths.filter_fixation_trains(indices) + if mask.ndim>1: + mask = np.all(mask, axis=1) + + return scanpaths.filter_fixation_trains(mask) def filter_fixations_by_attribute(fixations: Fixations, attribute_name, attribute_value, invert_match=False): """Filter Fixations by values of attribute (fixations.__attributes__)""" - mask = np.array([element == attribute_value for element in getattr(fixations, attribute_name)]) + mask = np.asarray(getattr(fixations, attribute_name)) == attribute_value if invert_match is True: mask = ~mask - indices = list(np.nonzero(mask)[0]) - return fixations.filter(indices) - + if mask.ndim>1: + mask = np.all(mask, axis=1) + + return fixations[mask] def filter_stimuli_by_attribute(stimuli: Stimuli, fixations: Fixations, attribute_name, attribute_value, invert_match=False): - """Filter stimuli by values of attribute""" + """Filter stimuli by values of attribute (stimuli.attributes)""" - mask = np.array([element == attribute_value for element in getattr(stimuli, attribute_name)]) + mask = np.asarray(stimuli.attributes[attribute_name]) == attribute_value if invert_match is True: mask = ~mask + if mask.ndim>1: + mask = np.all(mask, axis=1) indices = list(np.nonzero(mask)[0]) + return create_subset(stimuli, fixations, indices) @@ -269,13 +275,12 @@ def filter_scanpaths_by_lengths(scanpaths: FixationTrains, intervals: list): intervals = _check_intervals(intervals, type=int) mask = np.zeros(len(scanpaths.train_lengths), dtype=bool) - for n1, n2 in intervals: temp_mask = np.logical_and(scanpaths.train_lengths>=n1,scanpaths.train_lengths<=n2) mask = np.logical_or(mask, temp_mask) - indices = list(np.nonzero(mask)[0]) + scanpaths = scanpaths.filter_fixation_trains(indices) - + return scanpaths