Skip to content

Commit

Permalink
Update dataset_config.py
Browse files Browse the repository at this point in the history
  • Loading branch information
hkhanuja authored Nov 6, 2023
1 parent d37f2e6 commit 15ca142
Showing 1 changed file with 2 additions and 50 deletions.
52 changes: 2 additions & 50 deletions pysaliency/dataset_config.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
import numpy as np
from .datasets import read_hdf5, clip_out_of_stimulus_fixations, remove_out_of_stimulus_fixations, FixationTrains, Fixations, Stimuli, create_subset
from .datasets import read_hdf5, clip_out_of_stimulus_fixations, remove_out_of_stimulus_fixations
from .filter_datasets import (
filter_fixations_by_number,
filter_stimuli_by_number,
filter_stimuli_by_size,
train_split,
validation_split,
test_split,
_check_intervals
)

from schema import Schema, Optional
Expand Down Expand Up @@ -53,53 +51,7 @@ def apply_dataset_filter_config(stimuli, fixations, filter_config):

return filter_fn(stimuli, fixations, **filter_config['parameters'])


def filter_scanpaths_by_attribute(scanpaths: FixationTrains, attribute_name, attribute_value, invert_match=False):
"""Filter Scanpaths by values of scanpath attribute (fixation_trains.scanpath_attributes)"""

mask = np.array([element == attribute_value for element in getattr(scanpaths, attribute_name)])
if invert_match is True:
mask = ~mask
indices = list(np.nonzero(mask)[0])
return scanpaths.filter_fixation_trains(indices)


def filter_fixations_by_attribute(fixations: Fixations, attribute_name, attribute_value, invert_match=False):
"""Filter Fixations by values of attribute (fixations.__attributes__)"""

mask = np.array([element == attribute_value for element in getattr(fixations, attribute_name)])
if invert_match is True:
mask = ~mask
indices = list(np.nonzero(mask)[0])
return fixations.filter(indices)


def filter_stimuli_by_attribute(stimuli: Stimuli, fixations: Fixations, attribute_name, attribute_value, invert_match=False):
"""Filter stimuli by values of attribute"""

mask = np.array([element == attribute_value for element in getattr(stimuli, attribute_name)])
if invert_match is True:
mask = ~mask
indices = list(np.nonzero(mask)[0])
return create_subset(stimuli, fixations, indices)


def filter_scanpaths_by_lengths(scanpaths: FixationTrains, intervals: list):
"""Filter Scanpaths by number of fixations"""

intervals = _check_intervals(intervals, type=int)
mask = np.zeros(len(scanpaths.train_lengths), dtype=bool)

for n1, n2 in intervals:
temp_mask = np.logical_and(scanpaths.train_lengths>=n1,scanpaths.train_lengths<=n2)
mask = np.logical_or(mask, temp_mask)

indices = list(np.nonzero(mask)[0])
scanpaths = scanpaths.filter_fixation_trains(indices)

return scanpaths



def _clip_out_of_stimulus_fixations(stimuli, fixations):
clipped_fixations = clip_out_of_stimulus_fixations(fixations, stimuli=stimuli)
return stimuli, clipped_fixations
Expand Down

0 comments on commit 15ca142

Please sign in to comment.