diff --git a/tests/datasets/test_fixations.py b/tests/datasets/test_fixations.py index 9fc305f..43fc04f 100644 --- a/tests/datasets/test_fixations.py +++ b/tests/datasets/test_fixations.py @@ -161,57 +161,6 @@ def test_scanpaths(self): # cum_inds = inds[inds2] # compare_fixations_subset(__f, f, cum_inds) - def test_save_and_load(self): - xs_trains = [ - [0, 1, 2], - [2, 2], - [1, 5, 3]] - ys_trains = [ - [10, 11, 12], - [12, 12], - [21, 25, 33]] - ts_trains = [ - [0, 200, 600], - [100, 400], - [50, 500, 900]] - ns = [0, 0, 1] - subjects = [0, 1, 1] - # Create /Fixations - f = pysaliency.FixationTrains.from_fixation_trains(xs_trains, ys_trains, ts_trains, ns, subjects) - - filename = os.path.join(self.data_path, 'fixation.pydat') - with open(filename, 'wb') as out_file: - pickle.dump(f, out_file) - - with open(filename, 'rb') as in_file: - f = pickle.load(in_file) - # Test fixation trains - - assert_variable_length_array_equal(f.train_xs, VariableLengthArray(xs_trains)) - assert_variable_length_array_equal(f.train_ys, VariableLengthArray(ys_trains)) - assert_variable_length_array_equal(f.train_ts, VariableLengthArray(ts_trains)) - - np.testing.assert_allclose(f.train_ns, [0, 0, 1]) - np.testing.assert_allclose(f.train_subjects, [0, 1, 1]) - - # Test conditional fixations - np.testing.assert_allclose(f.x, [0, 1, 2, 2, 2, 1, 5, 3]) - np.testing.assert_allclose(f.y, [10, 11, 12, 12, 12, 21, 25, 33]) - np.testing.assert_allclose(f.t, [0, 200, 600, 100, 400, 50, 500, 900]) - np.testing.assert_allclose(f.n, [0, 0, 0, 0, 0, 1, 1, 1]) - np.testing.assert_allclose(f.subject, [0, 0, 0, 1, 1, 1, 1, 1]) - np.testing.assert_allclose(f.scanpath_history_length, [0, 1, 2, 0, 1, 0, 1, 2]) - np.testing.assert_allclose(f.x_hist._data, - [[np.nan, np.nan], - [0, np.nan], - [0, 1], - [np.nan, np.nan], - [2, np.nan], - [np.nan, np.nan], - [1, np.nan], - [1, 5]]) - - @pytest.fixture def scanpath_fixations() -> ScanpathFixations: @@ -309,6 +258,19 @@ def test_copy_fixations(fixation_trains): assert_fixations_equal(copied_fixations, fixations) +def test_write_read_scanpath_fixations_pickle(tmp_path, scanpath_fixations): + filename = tmp_path / 'scanpath_fixations.pydat' + with open(filename, 'wb') as out_file: + pickle.dump(scanpath_fixations, out_file) + + with open(filename, 'rb') as in_file: + new_scanpath_fixations = pickle.load(in_file) + + # make sure there is no sophisticated caching... + assert scanpath_fixations is not new_scanpath_fixations + assert_scanpath_fixations_equal(scanpath_fixations, new_scanpath_fixations) + + def test_write_read_scanpath_fixations_pathlib(tmp_path, scanpath_fixations): filename = tmp_path / 'scanpath_fixations.hdf5' scanpath_fixations.to_hdf5(filename)