Skip to content

Commit

Permalink
Added functionality to call filter functions from config file (#35)
Browse files Browse the repository at this point in the history
  • Loading branch information
hkhanuja authored Nov 14, 2023
1 parent 4a846b8 commit b7f4b09
Show file tree
Hide file tree
Showing 4 changed files with 161 additions and 7 deletions.
10 changes: 9 additions & 1 deletion pysaliency/dataset_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@
filter_stimuli_by_size,
train_split,
validation_split,
test_split
test_split,
filter_scanpaths_by_attribute,
filter_fixations_by_attribute,
filter_stimuli_by_attribute,
filter_scanpaths_by_length
)

from schema import Schema, Optional
Expand Down Expand Up @@ -42,6 +46,10 @@ def apply_dataset_filter_config(stimuli, fixations, filter_config):
'train_split': train_split,
'validation_split': validation_split,
'test_split': test_split,
'filter_scanpaths_by_attribute': add_stimuli_argument(filter_scanpaths_by_attribute),
'filter_fixations_by_attribute': add_stimuli_argument(filter_fixations_by_attribute),
'filter_stimuli_by_attribute': filter_stimuli_by_attribute,
'filter_scanpaths_by_length': add_stimuli_argument(filter_scanpaths_by_length)
}

if filter_config['type'] not in filter_dict:
Expand Down
2 changes: 1 addition & 1 deletion pysaliency/filter_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ def filter_stimuli_by_attribute(stimuli: Stimuli, fixations: Fixations, attribut
return create_subset(stimuli, fixations, indices)


def filter_scanpaths_by_lengths(scanpaths: FixationTrains, intervals: list):
def filter_scanpaths_by_length(scanpaths: FixationTrains, intervals: list):
"""Filter Scanpaths by number of fixations"""

intervals = _check_intervals(intervals, type=int)
Expand Down
150 changes: 148 additions & 2 deletions tests/test_dataset_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,15 @@

import os

import pytest
import numpy as np
import pytest
from imageio import imwrite
from test_datasets import compare_fixations, compare_scanpaths
from test_filter_datasets import assert_stimuli_equal

import pysaliency
import pysaliency.dataset_config as dc
from pysaliency.filter_datasets import create_subset


@pytest.fixture
Expand All @@ -25,7 +29,51 @@ def fixation_trains():
[50, 500, 900]]
ns = [0, 0, 1]
subjects = [0, 1, 1]
return pysaliency.FixationTrains.from_fixation_trains(xs_trains, ys_trains, ts_trains, ns, subjects)
tasks = [0, 1, 0]
multi_dim_attribute = [[0.0, 1],[0, 3], [4, 5.5]]
durations_train = [
[42, 25, 100],
[99, 98],
[200, 150, 120]
]
some_attribute = np.arange(len(sum(xs_trains, [])))
return pysaliency.FixationTrains.from_fixation_trains(
xs_trains,
ys_trains,
ts_trains,
ns,
subjects,
attributes={'some_attribute': some_attribute},
scanpath_attributes={
'task': tasks,
'multi_dim_attribute': multi_dim_attribute
},
scanpath_fixation_attributes={'durations': durations_train},
scanpath_attribute_mapping={'durations': 'duration'},
)


@pytest.fixture
def file_stimuli_with_attributes(tmpdir):
filenames = []
for i in range(3):
filename = tmpdir.join('stimulus_{:04d}.png'.format(i))
imwrite(str(filename), np.random.randint(low=0, high=255, size=(100, 100, 3), dtype=np.uint8))
filenames.append(str(filename))

for sub_directory_index in range(3):
sub_directory = tmpdir.join('sub_directory_{:04d}'.format(sub_directory_index))
sub_directory.mkdir()
for i in range(5):
filename = sub_directory.join('stimulus_{:04d}.png'.format(i))
imwrite(str(filename), np.random.randint(low=0, high=255, size=(100, 100, 3), dtype=np.uint8))
filenames.append(str(filename))
attributes = {
'dva': list(range(len(filenames))),
'other_stuff': np.random.randn(len(filenames)),
'some_strings': list('abcdefghijklmnopqr'),
}
return pysaliency.FileStimuli(filenames=filenames, attributes=attributes)


@pytest.fixture
Expand Down Expand Up @@ -66,3 +114,101 @@ def test_load_dataset_with_filter(hdf5_dataset, stimuli, fixation_trains):
assert len(loaded_stimuli) == len(stimuli)
assert len(loaded_fixations.x) == 6
assert np.all(loaded_fixations.lengths < 2)


def test_apply_dataset_filter_config_filter_scanpaths_by_attribute_task(stimuli, fixation_trains):
scanpaths = fixation_trains
filter_config = {
'type': 'filter_scanpaths_by_attribute',
'parameters': {
'attribute_name': 'task',
'attribute_value': 0,
'invert_match': False,
}
}
filtered_stimuli, filtered_scanpaths = dc.apply_dataset_filter_config(stimuli, scanpaths, filter_config)
inds = [0, 2]
expected_scanpaths = scanpaths.filter_fixation_trains(inds)
compare_scanpaths(filtered_scanpaths, expected_scanpaths)
assert_stimuli_equal(filtered_stimuli, stimuli)


def test_apply_dataset_filter_config_filter_scanpaths_by_attribute_multi_dim_attribute_invert_match(stimuli, fixation_trains):
scanpaths = fixation_trains
filter_config = {
'type': 'filter_scanpaths_by_attribute',
'parameters': {
'attribute_name': 'multi_dim_attribute',
'attribute_value': [0, 1],
'invert_match': True,
}
}
filtered_stimuli, filtered_scanpaths = dc.apply_dataset_filter_config(stimuli, scanpaths, filter_config)
inds = [1, 2]
expected_scanpaths = scanpaths.filter_fixation_trains(inds)
compare_scanpaths(filtered_scanpaths, expected_scanpaths)
assert_stimuli_equal(filtered_stimuli, stimuli)


def test_apply_dataset_filter_config_filter_fixations_by_attribute_subject_invert_match(stimuli, fixation_trains):
fixations = fixation_trains[:]
filter_config = {
'type': 'filter_fixations_by_attribute',
'parameters': {
'attribute_name': 'subjects',
'attribute_value': 0,
'invert_match': True,
}
}
filtered_stimuli, filtered_fixations = dc.apply_dataset_filter_config(stimuli, fixations, filter_config)
inds = [3, 4, 5, 6, 7]
expected_fixations = fixations[inds]
compare_fixations(filtered_fixations, expected_fixations)
assert_stimuli_equal(filtered_stimuli, stimuli)


def test_apply_dataset_filter_config_filter_stimuli_by_attribute_dva(file_stimuli_with_attributes, fixation_trains):
fixations = fixation_trains[:]
filter_config = {
'type': 'filter_stimuli_by_attribute',
'parameters': {
'attribute_name': 'dva',
'attribute_value': 1,
'invert_match': False,
}
}
filtered_stimuli, filtered_fixations = dc.apply_dataset_filter_config(file_stimuli_with_attributes, fixations, filter_config)
inds = [1]
expected_stimuli, expected_fixations = create_subset(file_stimuli_with_attributes, fixations, inds)
compare_fixations(filtered_fixations, expected_fixations)
assert_stimuli_equal(filtered_stimuli, expected_stimuli)


def test_apply_dataset_filter_config_filter_scanpaths_by_length_multiple_inputs(stimuli, fixation_trains):
scanpaths = fixation_trains
filter_config = {
'type': 'filter_scanpaths_by_length',
'parameters': {
'intervals': [(1, 2), (2, 3)]
}
}
filtered_stimuli, filtered_scanpaths = dc.apply_dataset_filter_config(stimuli, scanpaths, filter_config)
inds = [1]
expected_scanpaths = scanpaths.filter_fixation_trains(inds)
compare_scanpaths(filtered_scanpaths, expected_scanpaths)
assert_stimuli_equal(filtered_stimuli, stimuli)


def test_apply_dataset_filter_config_filter_scanpaths_by_length_single_input(stimuli, fixation_trains):
scanpaths = fixation_trains
filter_config = {
'type': 'filter_scanpaths_by_length',
'parameters': {
'intervals': [(3)]
}
}
filtered_stimuli, filtered_scanpaths = dc.apply_dataset_filter_config(stimuli, scanpaths, filter_config)
inds = [0, 2]
expected_scanpaths = scanpaths.filter_fixation_trains(inds)
compare_scanpaths(filtered_scanpaths, expected_scanpaths)
assert_stimuli_equal(filtered_stimuli, stimuli)
6 changes: 3 additions & 3 deletions tests/test_filter_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import pysaliency
import pysaliency.filter_datasets as filter_datasets
from pysaliency.filter_datasets import filter_fixations_by_attribute, filter_stimuli_by_attribute, filter_scanpaths_by_attribute, filter_scanpaths_by_lengths, create_subset
from pysaliency.filter_datasets import filter_fixations_by_attribute, filter_stimuli_by_attribute, filter_scanpaths_by_attribute, filter_scanpaths_by_length, create_subset
from test_datasets import compare_fixations, compare_scanpaths


Expand Down Expand Up @@ -434,9 +434,9 @@ def test_filter_scanpaths_by_attribute_multi_dim_attribute_invert_match(fixation


@pytest.mark.parametrize('intervals', [([(1, 2), (2, 3)]), ([(2, 3), (3, 4)]), ([(2)]), ([(3)])])
def test_filter_scanpaths_by_lengths(fixation_trains, intervals):
def test_filter_scanpaths_by_length(fixation_trains, intervals):
scanpaths = fixation_trains
filtered_scanpaths = filter_scanpaths_by_lengths(scanpaths, intervals)
filtered_scanpaths = filter_scanpaths_by_length(scanpaths, intervals)
if intervals == [(1, 2), (2, 3)]:
inds = [1]
expected_scanpaths = scanpaths.filter_fixation_trains(inds)
Expand Down

0 comments on commit b7f4b09

Please sign in to comment.