Skip to content

Commit

Permalink
Upgrade FixationTrains.train_xs, train_ys and train_ts to VariableLengthArray
Browse files Browse the repository at this point in the history
Signed-off-by: Matthias Kümmerer <[email protected]>
  • Loading branch information
matthias-k committed Mar 29, 2024
1 parent 6fb8c06 commit 0ff3421
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 45 deletions.
53 changes: 30 additions & 23 deletions pysaliency/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,35 +480,42 @@ class FixationTrains(Fixations):
def __init__(self, train_xs, train_ys, train_ts, train_ns, train_subjects, scanpath_attributes=None, scanpath_fixation_attributes=None, attributes=None, scanpath_attribute_mapping=None):
self.__attributes__ = list(self.__attributes__)
self.__attributes__.append('scanpath_index')
self.train_xs = train_xs
self.train_ys = train_ys
self.train_ts = train_ts

if not isinstance(train_xs, VariableLengthArray):
self.train_lengths = np.array([len(remove_trailing_nans(train_x)) for train_x in train_xs])
else:
self.train_lengths = train_xs.lengths.copy()

self.train_xs = self._as_variable_length_scanpath_array(train_xs)
self.train_ys = self._as_variable_length_scanpath_array(train_ys)
self.train_ts = self._as_variable_length_scanpath_array(train_ts)
self.train_ns = train_ns
self.train_subjects = train_subjects
N_trains = self.train_xs.shape[0] * self.train_xs.shape[1] - np.isnan(self.train_xs).sum()
max_length_trains = self.train_xs.shape[1]
N_fixations = self.train_lengths.sum()
max_length_trains = self.train_lengths.max() if len(self.train_lengths) else 0
max_history_length = max(max_length_trains - 1, 0)


# Create conditional fixations
self.x = np.empty(N_trains)
self.y = np.empty(N_trains)
self.t = np.empty(N_trains)
self.x_hist = np.empty((N_trains, max_length_trains - 1))
self.y_hist = np.empty((N_trains, max_length_trains - 1))
self.t_hist = np.empty((N_trains, max_length_trains - 1))
self.x = np.empty(N_fixations)
self.y = np.empty(N_fixations)
self.t = np.empty(N_fixations)
self.x_hist = np.empty((N_fixations, max_length_trains))
self.y_hist = np.empty((N_fixations, max_length_trains))
self.t_hist = np.empty((N_fixations, max_length_trains))
self.x_hist[:] = np.nan
self.y_hist[:] = np.nan
self.t_hist[:] = np.nan
self.n = np.empty(N_trains, dtype=int)
self.lengths = np.empty(N_trains, dtype=int)
self.n = np.empty(N_fixations, dtype=int)
self.lengths = np.empty(N_fixations, dtype=int)
self.train_lengths = np.empty(len(self.train_xs), dtype=int)
self.subjects = np.empty(N_trains, dtype=int)
self.scanpath_index = np.empty(N_trains, dtype=int)
self.subjects = np.empty(N_fixations, dtype=int)
self.scanpath_index = np.empty(N_fixations, dtype=int)

out_index = 0
# TODO: maybe implement in numba?
# probably best: have function fill_fixation_data(scanpath_data, fixation_data, hist_data=None)
for train_index in range(self.train_xs.shape[0]):
for train_index in range(len(self.train_xs)):
fix_length = len(remove_trailing_nans(self.train_xs[train_index]))
self.train_lengths[train_index] = fix_length
for fix_index in range(fix_length):
Expand Down Expand Up @@ -560,11 +567,11 @@ def __init__(self, train_xs, train_ys, train_ts, train_ns, train_subjects, scanp
if new_attribute_name in attributes:
raise ValueError("attribute name clash: {new_attribute_name}".format(new_attribute_name=new_attribute_name))
attribute_shape = [] if not value.any() else np.asarray(value[0]).shape
attributes[new_attribute_name] = np.empty([N_trains] + list(attribute_shape), dtype=value.dtype)
attributes[new_attribute_name] = np.empty([N_fixations] + list(attribute_shape), dtype=value.dtype)
self.auto_attributes.append(new_attribute_name)

out_index = 0
for train_index in range(self.train_xs.shape[0]):
for train_index in range(len(self.train_xs)):
fix_length = (1 - np.isnan(self.train_xs[train_index])).sum()
for fix_index in range(fix_length):
attributes[new_attribute_name][out_index] = self.scanpath_attributes[attribute_name][train_index]
Expand All @@ -575,7 +582,7 @@ def __init__(self, train_xs, train_ys, train_ts, train_ns, train_subjects, scanp
new_attribute_name = self.scanpath_attribute_mapping.get(attribute_name, attribute_name)
if new_attribute_name in attributes:
raise ValueError("attribute name clash: {new_attribute_name}".format(new_attribute_name=new_attribute_name))
attributes[new_attribute_name] = np.empty(N_trains)
attributes[new_attribute_name] = np.empty(N_fixations)
self.auto_attributes.append(new_attribute_name)

hist_attribute_name = new_attribute_name + '_hist'
Expand All @@ -585,7 +592,7 @@ def __init__(self, train_xs, train_ys, train_ts, train_ns, train_subjects, scanp
self.auto_attributes.append(hist_attribute_name)

out_index = 0
for train_index in range(self.train_xs.shape[0]):
for train_index in range(len(self.train_xs)):
fix_length = (1 - np.isnan(self.train_xs[train_index])).sum()
for fix_index in range(fix_length):
attributes[new_attribute_name][out_index] = self.scanpath_fixation_attributes[attribute_name][train_index, fix_index]
Expand All @@ -608,7 +615,7 @@ def __init__(self, train_xs, train_ys, train_ts, train_ns, train_subjects, scanp
self.full_nonfixations = None

def _check_train_lengths(self, other: VariableLengthArray):
if not len(self.train_xs) == len(other):
if not len(self.train_lengths) == len(other):
raise ValueError("Length of scanpaths has to match")
if not np.all(self.train_lengths == other.lengths):
raise ValueError("Lengths of scanpaths have to match")
Expand Down Expand Up @@ -694,7 +701,7 @@ def set_scanpath_attribute(self, name, data, fixation_attribute_name=None):
self.auto_attributes.append(new_attribute_name)

out_index = 0
for train_index in range(self.train_xs.shape[0]):
for train_index in range(len(self.train_xs)):
fix_length = (1 - np.isnan(self.train_xs[train_index])).sum()
for _ in range(fix_length):
self.attributes[new_attribute_name][out_index] = self.scanpath_attributes[name][train_index]
Expand Down Expand Up @@ -758,7 +765,7 @@ def fixation_trains(self):
"""Yield for every fixation train of the dataset:
xs, ys, ts, n, subject
"""
for i in range(self.train_xs.shape[0]):
for i in range(len(self.train_xs)):
length = (1 - np.isnan(self.train_xs[i])).sum()
xs = self.train_xs[i][:length]
ys = self.train_ys[i][:length]
Expand Down
60 changes: 38 additions & 22 deletions tests/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,10 @@ def assert_fixations_equal(f1, f2, crop_length=False):


def assert_fixation_trains_equal(scanpaths1, scanpaths2):
np.testing.assert_array_equal(scanpaths1.train_xs, scanpaths2.train_xs)
np.testing.assert_array_equal(scanpaths1.train_ys, scanpaths2.train_ys)
np.testing.assert_array_equal(scanpaths1.train_xs, scanpaths2.train_xs)
assert_variable_length_array_equal(scanpaths1.train_xs, scanpaths2.train_xs)
assert_variable_length_array_equal(scanpaths1.train_ys, scanpaths2.train_ys)
assert_variable_length_array_equal(scanpaths1.train_ts, scanpaths2.train_ts)

np.testing.assert_array_equal(scanpaths1.train_ns, scanpaths2.train_ns)
np.testing.assert_array_equal(scanpaths1.train_subjects, scanpaths2.train_subjects)
np.testing.assert_array_equal(scanpaths1.train_lengths, scanpaths2.train_lengths)
Expand Down Expand Up @@ -121,11 +122,13 @@ def test_from_fixations(self):
)

# Test fixation trains
np.testing.assert_allclose(f.train_xs, [[0, 1, 2], [2, 2, np.nan], [1, 5, 3]])
np.testing.assert_allclose(f.train_ys, [[10, 11, 12], [12, 12, np.nan], [21, 25, 33]])
np.testing.assert_allclose(f.train_ts, [[0, 200, 600], [100, 400, np.nan], [50, 500, 900]])
np.testing.assert_allclose(f.train_ns, [0, 0, 1])
np.testing.assert_allclose(f.train_subjects, [0, 1, 1])

assert_variable_length_array_equal(f.train_xs, VariableLengthArray(xs_trains))
assert_variable_length_array_equal(f.train_ys, VariableLengthArray(ys_trains))
assert_variable_length_array_equal(f.train_ts, VariableLengthArray(ts_trains))

np.testing.assert_allclose(f.train_ns, ns)
np.testing.assert_allclose(f.train_subjects, subjects)

# Test conditional fixations
np.testing.assert_allclose(f.x, [0, 1, 2, 2, 2, 1, 5, 3])
Expand All @@ -134,14 +137,20 @@ def test_from_fixations(self):
np.testing.assert_allclose(f.n, [0, 0, 0, 0, 0, 1, 1, 1])
np.testing.assert_allclose(f.subjects, [0, 0, 0, 1, 1, 1, 1, 1])
np.testing.assert_allclose(f.lengths, [0, 1, 2, 0, 1, 0, 1, 2])
np.testing.assert_allclose(f.x_hist._data, [[np.nan, np.nan],
[0, np.nan],
[0, 1],
[np.nan, np.nan],
[2, np.nan],
[np.nan, np.nan],
[1, np.nan],
[1, 5]])

assert_variable_length_array_equal(
f.x_hist,
VariableLengthArray([
[],
[0],
[0, 1],
[],
[2],
[],
[1],
[1, 5]
])
)

def test_filter(self):
xs_trains = []
Expand Down Expand Up @@ -226,9 +235,11 @@ def test_save_and_load(self):
with open(filename, 'rb') as in_file:
f = pickle.load(in_file)
# Test fixation trains
np.testing.assert_allclose(f.train_xs, [[0, 1, 2], [2, 2, np.nan], [1, 5, 3]])
np.testing.assert_allclose(f.train_ys, [[10, 11, 12], [12, 12, np.nan], [21, 25, 33]])
np.testing.assert_allclose(f.train_ts, [[0, 200, 600], [100, 400, np.nan], [50, 500, 900]])

assert_variable_length_array_equal(f.train_xs, VariableLengthArray(xs_trains))
assert_variable_length_array_equal(f.train_ys, VariableLengthArray(ys_trains))
assert_variable_length_array_equal(f.train_ts, VariableLengthArray(ts_trains))

np.testing.assert_allclose(f.train_ns, [0, 0, 1])
np.testing.assert_allclose(f.train_subjects, [0, 1, 1])

Expand Down Expand Up @@ -499,26 +510,31 @@ def test_scanpath_fixation_attributes(fixation_trains):
def test_filter_fixation_trains(fixation_trains, scanpath_indices, fixation_indices):
sub_fixations = fixation_trains.filter_fixation_trains(scanpath_indices)

np.testing.assert_array_equal(
assert_variable_length_array_equal(
sub_fixations.train_xs,
fixation_trains.train_xs[scanpath_indices]
)
np.testing.assert_array_equal(

assert_variable_length_array_equal(
sub_fixations.train_ys,
fixation_trains.train_ys[scanpath_indices]
)
np.testing.assert_array_equal(

assert_variable_length_array_equal(
sub_fixations.train_ts,
fixation_trains.train_ts[scanpath_indices]
)

np.testing.assert_array_equal(
sub_fixations.train_ns,
fixation_trains.train_ns[scanpath_indices]
)

np.testing.assert_array_equal(
sub_fixations.some_attribute,
fixation_trains.some_attribute[fixation_indices]
)

np.testing.assert_array_equal(
sub_fixations.scanpath_attributes['task'],
fixation_trains.scanpath_attributes['task'][scanpath_indices]
Expand Down

0 comments on commit 0ff3421

Please sign in to comment.