Skip to content

Commit

Permalink
Updated filter_datasets.py
Browse files Browse the repository at this point in the history
  • Loading branch information
hkhanuja authored Nov 8, 2023
1 parent a9efdab commit 2072e21
Showing 1 changed file with 18 additions and 13 deletions.
31 changes: 18 additions & 13 deletions pysaliency/filter_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,31 +236,37 @@ def filter_stimuli_by_size(stimuli, fixations, size=None, sizes=None):

def filter_scanpaths_by_attribute(scanpaths: FixationTrains, attribute_name, attribute_value, invert_match=False):
"""Filter Scanpaths by values of scanpath attribute (fixation_trains.scanpath_attributes)"""
mask = np.array([element == attribute_value for element in getattr(scanpaths, attribute_name)])

mask = scanpaths.scanpath_attributes[attribute_name] == attribute_value
if invert_match is True:
mask = ~mask
indices = list(np.nonzero(mask)[0])
return scanpaths.filter_fixation_trains(indices)
if mask.ndim>1:
mask = np.all(mask, axis=1)

return scanpaths.filter_fixation_trains(mask)


def filter_fixations_by_attribute(fixations: Fixations, attribute_name, attribute_value, invert_match=False):
"""Filter Fixations by values of attribute (fixations.__attributes__)"""

mask = np.array([element == attribute_value for element in getattr(fixations, attribute_name)])
mask = np.asarray(getattr(fixations, attribute_name)) == attribute_value
if invert_match is True:
mask = ~mask
indices = list(np.nonzero(mask)[0])
return fixations.filter(indices)

if mask.ndim>1:
mask = np.all(mask, axis=1)

return fixations[mask]

def filter_stimuli_by_attribute(stimuli: Stimuli, fixations: Fixations, attribute_name, attribute_value, invert_match=False):
"""Filter stimuli by values of attribute"""
"""Filter stimuli by values of attribute (stimuli.attributes)"""

mask = np.array([element == attribute_value for element in getattr(stimuli, attribute_name)])
mask = np.asarray(stimuli.attributes[attribute_name]) == attribute_value
if invert_match is True:
mask = ~mask
if mask.ndim>1:
mask = np.all(mask, axis=1)
indices = list(np.nonzero(mask)[0])

return create_subset(stimuli, fixations, indices)


Expand All @@ -269,13 +275,12 @@ def filter_scanpaths_by_lengths(scanpaths: FixationTrains, intervals: list):

intervals = _check_intervals(intervals, type=int)
mask = np.zeros(len(scanpaths.train_lengths), dtype=bool)

for n1, n2 in intervals:
temp_mask = np.logical_and(scanpaths.train_lengths>=n1,scanpaths.train_lengths<=n2)
mask = np.logical_or(mask, temp_mask)

indices = list(np.nonzero(mask)[0])

scanpaths = scanpaths.filter_fixation_trains(indices)

return scanpaths

0 comments on commit 2072e21

Please sign in to comment.