Skip to content

Commit

Permalink
Remove failing pickle test for obsolete class, add test for new class
Browse files Browse the repository at this point in the history
Signed-off-by: Matthias Kümmerer <[email protected]>
  • Loading branch information
matthias-k committed Apr 14, 2024
1 parent 66326ec commit a788b06
Showing 1 changed file with 13 additions and 51 deletions.
64 changes: 13 additions & 51 deletions tests/datasets/test_fixations.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,57 +161,6 @@ def test_scanpaths(self):
# cum_inds = inds[inds2]
# compare_fixations_subset(__f, f, cum_inds)

def test_save_and_load(self):
xs_trains = [
[0, 1, 2],
[2, 2],
[1, 5, 3]]
ys_trains = [
[10, 11, 12],
[12, 12],
[21, 25, 33]]
ts_trains = [
[0, 200, 600],
[100, 400],
[50, 500, 900]]
ns = [0, 0, 1]
subjects = [0, 1, 1]
# Create /Fixations
f = pysaliency.FixationTrains.from_fixation_trains(xs_trains, ys_trains, ts_trains, ns, subjects)

filename = os.path.join(self.data_path, 'fixation.pydat')
with open(filename, 'wb') as out_file:
pickle.dump(f, out_file)

with open(filename, 'rb') as in_file:
f = pickle.load(in_file)
# Test fixation trains

assert_variable_length_array_equal(f.train_xs, VariableLengthArray(xs_trains))
assert_variable_length_array_equal(f.train_ys, VariableLengthArray(ys_trains))
assert_variable_length_array_equal(f.train_ts, VariableLengthArray(ts_trains))

np.testing.assert_allclose(f.train_ns, [0, 0, 1])
np.testing.assert_allclose(f.train_subjects, [0, 1, 1])

# Test conditional fixations
np.testing.assert_allclose(f.x, [0, 1, 2, 2, 2, 1, 5, 3])
np.testing.assert_allclose(f.y, [10, 11, 12, 12, 12, 21, 25, 33])
np.testing.assert_allclose(f.t, [0, 200, 600, 100, 400, 50, 500, 900])
np.testing.assert_allclose(f.n, [0, 0, 0, 0, 0, 1, 1, 1])
np.testing.assert_allclose(f.subject, [0, 0, 0, 1, 1, 1, 1, 1])
np.testing.assert_allclose(f.scanpath_history_length, [0, 1, 2, 0, 1, 0, 1, 2])
np.testing.assert_allclose(f.x_hist._data,
[[np.nan, np.nan],
[0, np.nan],
[0, 1],
[np.nan, np.nan],
[2, np.nan],
[np.nan, np.nan],
[1, np.nan],
[1, 5]])



@pytest.fixture
def scanpath_fixations() -> ScanpathFixations:
Expand Down Expand Up @@ -309,6 +258,19 @@ def test_copy_fixations(fixation_trains):
assert_fixations_equal(copied_fixations, fixations)


def test_write_read_scanpath_fixations_pickle(tmp_path, scanpath_fixations):
filename = tmp_path / 'scanpath_fixations.pydat'
with open(filename, 'wb') as out_file:
pickle.dump(scanpath_fixations, out_file)

with open(filename, 'rb') as in_file:
new_scanpath_fixations = pickle.load(in_file)

# make sure there is no sophisticated caching...
assert scanpath_fixations is not new_scanpath_fixations
assert_scanpath_fixations_equal(scanpath_fixations, new_scanpath_fixations)


def test_write_read_scanpath_fixations_pathlib(tmp_path, scanpath_fixations):
filename = tmp_path / 'scanpath_fixations.hdf5'
scanpath_fixations.to_hdf5(filename)
Expand Down

0 comments on commit a788b06

Please sign in to comment.