From d0bfd6e6c218dc2661cbbf362de47abfbef4c38f Mon Sep 17 00:00:00 2001 From: Harneet Singh Khanuja Date: Mon, 6 Nov 2023 19:08:36 +0100 Subject: [PATCH] Update dataset_config.py Changed the filter functions according to discussion --- pysaliency/dataset_config.py | 99 +++++++----------------------------- 1 file changed, 18 insertions(+), 81 deletions(-) diff --git a/pysaliency/dataset_config.py b/pysaliency/dataset_config.py index 75a7c31..5e6f673 100644 --- a/pysaliency/dataset_config.py +++ b/pysaliency/dataset_config.py @@ -54,97 +54,34 @@ def apply_dataset_filter_config(stimuli, fixations, filter_config): return filter_fn(stimuli, fixations, **filter_config['parameters']) -def filter_scanpaths_by_attribute(scanpaths: FixationTrains, whitelist: dict=None, blacklist: dict=None): - """Filter Scanpaths by values of scanpath attributes (fixation_trains.scanpath_attributes), the dictionary can have only one attribute""" +def filter_scanpaths_by_attribute(scanpaths: FixationTrains, attribute_name, attribute_value, invert_match=False): + """Filter Scanpaths by values of scanpath attribute (fixation_trains.scanpath_attributes)""" - assert (whitelist is None and blacklist is not None) or (whitelist is not None and blacklist is None) - if whitelist is not None: - assert(len(whitelist)==1) - if blacklist is not None: - assert(len(blacklist)==1) - - if whitelist is not None: - attribute_name = list(whitelist.keys())[0] - attribute_value = list(whitelist.values())[0] - - mask = np.zeros(len(getattr(scanpaths, attribute_name)), dtype=bool) - - mask = np.logical_or(mask,[element == attribute_value for element in getattr(scanpaths, attribute_name)]) - indices = list(np.nonzero(mask)[0]) - return scanpaths.filter_fixation_trains(indices) - - if blacklist is not None: - attribute_name = list(blacklist.keys())[0] - attribute_value = list(blacklist.values())[0] - - mask = np.zeros(len(getattr(scanpaths, attribute_name)), dtype=bool) - - mask = np.logical_or(mask,[element == attribute_value for element in getattr(scanpaths, attribute_name)]) + mask = np.array([element == attribute_value for element in getattr(scanpaths, attribute_name)]) + if invert_match is True: mask = ~mask - indices = list(np.nonzero(mask)[0]) - return scanpaths.filter_fixation_trains(indices) - - -def filter_fixations_by_attribute(fixations: Fixations, whitelist: dict=None, blacklist: dict=None): - """Filter Fixations by values of attributes (fixations.__attributes__), the dictionary can have only one attribute""" - - assert (whitelist is None and blacklist is not None) or (whitelist is not None and blacklist is None) - if whitelist is not None: - assert(len(whitelist)==1) - if blacklist is not None: - assert(len(blacklist)==1) - - if whitelist is not None: - attribute_name = list(whitelist.keys())[0] - attribute_value = list(whitelist.values())[0] - - mask = np.zeros(len(getattr(fixations, attribute_name)), dtype=bool) + indices = list(np.nonzero(mask)[0]) + return scanpaths.filter_fixation_trains(indices) - mask = np.logical_or(mask,[element == attribute_value for element in getattr(fixations, attribute_name)]) - indices = list(np.nonzero(mask)[0]) - return fixations.filter(indices) - - if blacklist is not None: - attribute_name = list(blacklist.keys())[0] - attribute_value = list(blacklist.values())[0] - mask = np.zeros(len(getattr(fixations, attribute_name)), dtype=bool) +def filter_fixations_by_attribute(fixations: Fixations, attribute_name, attribute_value, invert_match=False): + """Filter Fixations by values of attribute (fixations.__attributes__)""" - mask = np.logical_or(mask,[element == attribute_value for element in getattr(fixations, attribute_name)]) + mask = np.array([element == attribute_value for element in getattr(fixations, attribute_name)]) + if invert_match is True: mask = ~mask - indices = list(np.nonzero(mask)[0]) - return fixations.filter(indices) - - -def filter_stimuli_by_attribute(stimuli: Stimuli, fixations: Fixations, whitelist: dict=None, blacklist: dict=None): - """Filter stimuli by values of attribute, the dictionary can have only one attribute""" - - assert (whitelist is None and blacklist is not None) or (whitelist is not None and blacklist is None) - if whitelist is not None: - assert(len(whitelist)==1) - if blacklist is not None: - assert(len(blacklist)==1) - - if whitelist is not None: - attribute_name = list(whitelist.keys())[0] - attribute_value = list(whitelist.values())[0] + indices = list(np.nonzero(mask)[0]) + return fixations.filter(indices) - mask = np.zeros(len(stimuli), dtype=bool) - mask = np.logical_or(mask,[element == attribute_value for element in getattr(stimuli, attribute_name)]) - indices = list(np.nonzero(mask)[0]) - return create_subset(stimuli, fixations, indices) - - if blacklist is not None: - attribute_name = list(blacklist.keys())[0] - attribute_value = list(blacklist.values())[0] +def filter_stimuli_by_attribute(stimuli: Stimuli, fixations: Fixations, attribute_name, attribute_value, invert_match=False): + """Filter stimuli by values of attribute""" - mask = np.zeros(len(stimuli), dtype=bool) - - mask = np.logical_or(mask,[element == attribute_value for element in getattr(stimuli, attribute_name)]) + mask = np.array([element == attribute_value for element in getattr(stimuli, attribute_name)]) + if invert_match is True: mask = ~mask - indices = list(np.nonzero(mask)[0]) - return create_subset(stimuli, fixations, indices) + indices = list(np.nonzero(mask)[0]) + return create_subset(stimuli, fixations, indices) def filter_scanpaths_by_lengths(scanpaths: FixationTrains, intervals: list):