Skip to content

Commit

Permalink
Update dataset_config.py
Browse files Browse the repository at this point in the history
Changed the filter functions according to discussion
  • Loading branch information
hkhanuja authored Nov 6, 2023
1 parent 519f4af commit d0bfd6e
Showing 1 changed file with 18 additions and 81 deletions.
99 changes: 18 additions & 81 deletions pysaliency/dataset_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,97 +54,34 @@ def apply_dataset_filter_config(stimuli, fixations, filter_config):
return filter_fn(stimuli, fixations, **filter_config['parameters'])


def filter_scanpaths_by_attribute(scanpaths: FixationTrains, whitelist: dict=None, blacklist: dict=None):
"""Filter Scanpaths by values of scanpath attributes (fixation_trains.scanpath_attributes), the dictionary can have only one attribute"""
def filter_scanpaths_by_attribute(scanpaths: FixationTrains, attribute_name, attribute_value, invert_match=False):
"""Filter Scanpaths by values of scanpath attribute (fixation_trains.scanpath_attributes)"""

assert (whitelist is None and blacklist is not None) or (whitelist is not None and blacklist is None)
if whitelist is not None:
assert(len(whitelist)==1)
if blacklist is not None:
assert(len(blacklist)==1)

if whitelist is not None:
attribute_name = list(whitelist.keys())[0]
attribute_value = list(whitelist.values())[0]

mask = np.zeros(len(getattr(scanpaths, attribute_name)), dtype=bool)

mask = np.logical_or(mask,[element == attribute_value for element in getattr(scanpaths, attribute_name)])
indices = list(np.nonzero(mask)[0])
return scanpaths.filter_fixation_trains(indices)

if blacklist is not None:
attribute_name = list(blacklist.keys())[0]
attribute_value = list(blacklist.values())[0]

mask = np.zeros(len(getattr(scanpaths, attribute_name)), dtype=bool)

mask = np.logical_or(mask,[element == attribute_value for element in getattr(scanpaths, attribute_name)])
mask = np.array([element == attribute_value for element in getattr(scanpaths, attribute_name)])
if invert_match is True:
mask = ~mask
indices = list(np.nonzero(mask)[0])
return scanpaths.filter_fixation_trains(indices)


def filter_fixations_by_attribute(fixations: Fixations, whitelist: dict=None, blacklist: dict=None):
"""Filter Fixations by values of attributes (fixations.__attributes__), the dictionary can have only one attribute"""

assert (whitelist is None and blacklist is not None) or (whitelist is not None and blacklist is None)
if whitelist is not None:
assert(len(whitelist)==1)
if blacklist is not None:
assert(len(blacklist)==1)

if whitelist is not None:
attribute_name = list(whitelist.keys())[0]
attribute_value = list(whitelist.values())[0]

mask = np.zeros(len(getattr(fixations, attribute_name)), dtype=bool)
indices = list(np.nonzero(mask)[0])
return scanpaths.filter_fixation_trains(indices)

mask = np.logical_or(mask,[element == attribute_value for element in getattr(fixations, attribute_name)])
indices = list(np.nonzero(mask)[0])
return fixations.filter(indices)

if blacklist is not None:
attribute_name = list(blacklist.keys())[0]
attribute_value = list(blacklist.values())[0]

mask = np.zeros(len(getattr(fixations, attribute_name)), dtype=bool)
def filter_fixations_by_attribute(fixations: Fixations, attribute_name, attribute_value, invert_match=False):
"""Filter Fixations by values of attribute (fixations.__attributes__)"""

mask = np.logical_or(mask,[element == attribute_value for element in getattr(fixations, attribute_name)])
mask = np.array([element == attribute_value for element in getattr(fixations, attribute_name)])
if invert_match is True:
mask = ~mask
indices = list(np.nonzero(mask)[0])
return fixations.filter(indices)


def filter_stimuli_by_attribute(stimuli: Stimuli, fixations: Fixations, whitelist: dict=None, blacklist: dict=None):
"""Filter stimuli by values of attribute, the dictionary can have only one attribute"""

assert (whitelist is None and blacklist is not None) or (whitelist is not None and blacklist is None)
if whitelist is not None:
assert(len(whitelist)==1)
if blacklist is not None:
assert(len(blacklist)==1)

if whitelist is not None:
attribute_name = list(whitelist.keys())[0]
attribute_value = list(whitelist.values())[0]
indices = list(np.nonzero(mask)[0])
return fixations.filter(indices)

mask = np.zeros(len(stimuli), dtype=bool)

mask = np.logical_or(mask,[element == attribute_value for element in getattr(stimuli, attribute_name)])
indices = list(np.nonzero(mask)[0])
return create_subset(stimuli, fixations, indices)

if blacklist is not None:
attribute_name = list(blacklist.keys())[0]
attribute_value = list(blacklist.values())[0]
def filter_stimuli_by_attribute(stimuli: Stimuli, fixations: Fixations, attribute_name, attribute_value, invert_match=False):
"""Filter stimuli by values of attribute"""

mask = np.zeros(len(stimuli), dtype=bool)

mask = np.logical_or(mask,[element == attribute_value for element in getattr(stimuli, attribute_name)])
mask = np.array([element == attribute_value for element in getattr(stimuli, attribute_name)])
if invert_match is True:
mask = ~mask
indices = list(np.nonzero(mask)[0])
return create_subset(stimuli, fixations, indices)
indices = list(np.nonzero(mask)[0])
return create_subset(stimuli, fixations, indices)


def filter_scanpaths_by_lengths(scanpaths: FixationTrains, intervals: list):
Expand Down

0 comments on commit d0bfd6e

Please sign in to comment.