Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expose t_start in BaseRecording #3117

Closed
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 28 additions & 7 deletions src/spikeinterface/core/baserecording.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,12 +458,26 @@ def has_time_vector(self, segment_index=None):
return d["time_vector"] is not None

def set_times(self, times, segment_index=None, with_warning=True):
"""Set times for a recording segment.
"""Set times for a recording segment. Any existing times
will be overwritten.

Times can be manually set on the recording segment. If times are
not set, the sample index and sampling frequency are used to
calculate time. Otherwise, `t_start` or `time_vector` can be
provided:

`t_start` - the start time for the segment. The times for
this recording segment will be calculated as
t_start + sample_index * (1 / sampling_frequency)

`time_vector` - A vector of length segment.get_num_samples()
that holds the exact time for each sample in the recording.

Parameters
----------
times : 1d np.array
The time vector
times : float | 1d np.array
If `int`, this is the `t_start` for the segment,
otherwise, it is the time vector.
segment_index : int or None, default: None
The segment index (required for multi-segment)
with_warning : bool, default: True
Expand All @@ -472,11 +486,18 @@ def set_times(self, times, segment_index=None, with_warning=True):
segment_index = self._check_segment_index(segment_index)
rs = self._recording_segments[segment_index]

assert times.ndim == 1, "Time must have ndim=1"
assert rs.get_num_samples() == times.shape[0], "times have wrong shape"
if isinstance(times, float) or isinstance(times, int):
rs.t_start = times
rs.time_vector = None
elif isinstance(times, np.ndarray):

rs.t_start = None
rs.time_vector = times.astype("float64", copy=False)
assert times.ndim == 1, "Time must have ndim=1"
assert rs.get_num_samples() == times.shape[0], "times have wrong shape"

rs.t_start = None
rs.time_vector = times.astype("float64", copy=False)
else:
raise TypeError("`times` must be an integer / float (`t_start`) or " "numpy array (`time_vector`).")

if with_warning:
warnings.warn(
Expand Down
319 changes: 319 additions & 0 deletions src/spikeinterface/core/tests/test_time_handling.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,328 @@
import copy

import pytest
import numpy as np

from spikeinterface.core import generate_recording, generate_sorting
import spikeinterface.full as si


class TestTimeHandling:

# Fixtures #####
@pytest.fixture(scope="session")
def raw_recording(self):
"""
A three-segment raw recording without times added.
"""
durations = [10, 15, 20]
recording = generate_recording(num_channels=4, durations=durations)
return recording

@pytest.fixture(scope="session")
def time_vector_recording(self, raw_recording):
"""
Add time vectors to the recording, returning the
raw recording, recording with time vectors added to
segments, and list a the time vectors added to the recording.
"""
return self._get_time_vector_recording(raw_recording)

@pytest.fixture(scope="session")
def t_start_recording(self, raw_recording):
"""
Add a t_starts to the recording, returning the
raw recording, recording with t_starts added to segments,
and a list of the time vectors generated from adding the
t_start to the recording times.
"""
return self._get_t_start_recording(raw_recording)

def _get_time_vector_recording(self, raw_recording):
"""
Loop through all recording segments, adding a different time
vector to each segment. The time vector is the original times with
a t_start and irregularly spaced offsets to mimic irregularly
spaced timeseries data. Return the original recording,
recoridng with time vectors added and list including the added time vectors.
"""
times_recording = copy.deepcopy(raw_recording)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we have clone for this as an extractor method but if you really require this, why make the raw recording fixture per session?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The benefit of the raw_recording fixture is that durations only needs to be defined once, then copied as set_times() is in place. But I agree it is a lot of indirection and it is probably more readable to incorporate into the individual fixtures, possibly with DURATIONS=[...] set at the top of the script?

all_time_vectors = []
for segment_index in range(raw_recording.get_num_segments()):

t_start = segment_index + 1 * 100
offsets = np.arange(times_recording.get_num_samples(segment_index)) * (
1 / times_recording.get_sampling_frequency()
)
time_vector = t_start + times_recording.get_times(segment_index) + offsets

all_time_vectors.append(time_vector)
times_recording.set_times(times=time_vector, segment_index=segment_index)

assert np.array_equal(
times_recording._recording_segments[segment_index].time_vector,
time_vector,
), "time_vector was not properly set during test setup"

return (raw_recording, times_recording, all_time_vectors)

def _get_t_start_recording(self, raw_recording):
"""
For each segment in the recording, add a different `t_start`.
Return a list of time vectors generating from the recording times
+ the t_starts.
"""
t_start_recording = copy.deepcopy(raw_recording)

all_t_starts = []
for segment_index in range(raw_recording.get_num_segments()):

t_start = (segment_index + 1) * 100

all_t_starts.append(t_start + t_start_recording.get_times(segment_index))
t_start_recording.set_times(times=t_start, segment_index=segment_index)

assert np.array_equal(
t_start_recording._recording_segments[segment_index].t_start,
t_start,
), "t_start was not properly set during test setup"

return (raw_recording, t_start_recording, all_t_starts)

def _get_fixture_data(self, request, fixture_name):
"""
A convenience function to get the data from a fixture
based on the name. This is used to allow parameterising
tests across fixtures.
"""
time_recording_fixture = request.getfixturevalue(fixture_name)
raw_recording, times_recording, all_times = time_recording_fixture
return (raw_recording, times_recording, all_times)

# Tests #####
def test_has_time_vector(self, time_vector_recording):
"""
Test the `has_time_vector` function returns `False` before
a time vector is added and `True` afterwards.
"""
raw_recording, times_recording, _ = time_vector_recording

for segment_idx in range(raw_recording.get_num_segments()):

assert raw_recording.has_time_vector(segment_idx) is False
assert times_recording.has_time_vector(segment_idx) is True

def test_get_durations(self, time_vector_recording, t_start_recording):
"""
Test the `get_durations` functions that return the total duration
for a segment. Test that it is correct after adding both `t_start`
or `time_vector` to the recording.
"""
raw_recording, tvector_recording, all_time_vectors = time_vector_recording
_, tstart_recording, all_t_starts = t_start_recording

ts = 1 / raw_recording.get_sampling_frequency()

all_raw_durations = []
all_vector_durations = []
for segment_index in range(raw_recording.get_num_segments()):

# Test before `t_start` and `t_start` (`t_start` is just an offset,
# should not affect duration).
raw_duration = all_t_starts[segment_index][-1] - all_t_starts[segment_index][0] + ts

assert np.isclose(raw_recording.get_duration(segment_index), raw_duration, rtol=0, atol=1e-8)
assert np.isclose(tstart_recording.get_duration(segment_index), raw_duration, rtol=0, atol=1e-8)

# Test the duration from the time vector.
vector_duration = all_time_vectors[segment_index][-1] - all_time_vectors[segment_index][0] + ts

assert tvector_recording.get_duration(segment_index) == vector_duration

all_raw_durations.append(raw_duration)
all_vector_durations.append(vector_duration)

# Finally test the total recording duration
assert np.isclose(tstart_recording.get_total_duration(), sum(all_raw_durations), rtol=0, atol=1e-8)
assert np.isclose(tvector_recording.get_total_duration(), sum(all_vector_durations), rtol=0, atol=1e-8)

@pytest.mark.parametrize("mode", ["binary", "zarr"])
@pytest.mark.parametrize("fixture_name", ["time_vector_recording", "t_start_recording"])
def test_times_propagated_to_save_folder(self, request, fixture_name, mode, tmp_path):
"""
Test `t_start` or `time_vector` is propagated to a saved recording,
by saving, reloading, and checking times are correct.
"""
_, times_recording, all_times = self._get_fixture_data(request, fixture_name)

folder_name = "recording"
recording_cache = times_recording.save(format=mode, folder=tmp_path / folder_name)

if mode == "zarr":
folder_name += ".zarr"
recording_load = si.load_extractor(tmp_path / folder_name)

self._check_times_match(recording_cache, all_times)
self._check_times_match(recording_load, all_times)

@pytest.mark.parametrize("fixture_name", ["time_vector_recording", "t_start_recording"])
@pytest.mark.parametrize("sharedmem", [True, False])
def test_times_propagated_to_save_memory(self, request, fixture_name, sharedmem):
"""
Test t_start and time_vector are propagated to recording saved into memory.
"""
_, times_recording, all_times = self._get_fixture_data(request, fixture_name)

recording_load = times_recording.save(format="memory", sharedmem=sharedmem)

self._check_times_match(recording_load, all_times)

@pytest.mark.parametrize("fixture_name", ["time_vector_recording", "t_start_recording"])
def test_time_propagated_to_select_segments(self, request, fixture_name):
"""
Test that when `recording.select_segments()` is used, the times
are propagated to the new recoridng object.
"""
_, times_recording, all_times = self._get_fixture_data(request, fixture_name)

for segment_index in range(times_recording.get_num_segments()):
segment = times_recording.select_segments(segment_index)
assert np.array_equal(segment.get_times(), all_times[segment_index])

@pytest.mark.parametrize("fixture_name", ["time_vector_recording", "t_start_recording"])
def test_times_propagated_to_sorting(self, request, fixture_name):
"""
Check that when attached to a sorting object, the times are propagated
to the object. This means that all spike times should respect the
`t_start` or `time_vector` added.
"""
raw_recording, times_recording, all_times = self._get_fixture_data(request, fixture_name)
sorting = self._get_sorting_with_recording_attached(
recording_for_durations=raw_recording, recording_to_attach=times_recording
)
for segment_index in range(raw_recording.get_num_segments()):

if fixture_name == "time_vector_recording":
assert sorting.has_time_vector(segment_index=segment_index)

self._check_spike_times_are_correct(sorting, times_recording, segment_index)

@pytest.mark.parametrize("fixture_name", ["time_vector_recording", "t_start_recording"])
def test_time_sample_converters(self, request, fixture_name):
"""
Test the `recording.sample_time_to_index` and
`recording.time_to_sample_index` convenience functions.
"""
raw_recording, times_recording, all_times = self._get_fixture_data(request, fixture_name)
with pytest.raises(ValueError) as e:
times_recording.sample_index_to_time(0)
assert "Provide 'segment_index'" in str(e)

for segment_index in range(times_recording.get_num_segments()):

sample_index = np.random.randint(low=0, high=times_recording.get_num_samples(segment_index))
time_ = times_recording.sample_index_to_time(sample_index, segment_index=segment_index)

assert time_ == all_times[segment_index][sample_index]

new_sample_index = times_recording.time_to_sample_index(time_, segment_index=segment_index)

assert new_sample_index == sample_index

@pytest.mark.parametrize("time_type", ["time_vector", "t_start"])
@pytest.mark.parametrize("bounds", ["start", "middle", "end"])
def test_slice_recording(self, time_type, bounds):
"""
Test after `frame_slice` and `time_slice` a recording or
sorting (for `frame_slice`), the recording times are
correct with respect to the set `t_start` or `time_vector`.
"""
raw_recording = generate_recording(num_channels=4, durations=[10])

if time_type == "time_vector":
raw_recording, times_recording, all_times = self._get_time_vector_recording(raw_recording)
else:
raw_recording, times_recording, all_times = self._get_t_start_recording(raw_recording)

sorting = self._get_sorting_with_recording_attached(
recording_for_durations=raw_recording, recording_to_attach=times_recording
)

# Take some different times, including min and max bounds of
# the recording, and some arbitaray times in the middle (20% and 80%).
if bounds == "start":
start_frame = 0
end_frame = int(times_recording.get_num_samples(0) * 0.8)
elif bounds == "end":
start_frame = int(times_recording.get_num_samples(0) * 0.2)
end_frame = times_recording.get_num_samples(0) - 1
elif bounds == "middle":
start_frame = int(times_recording.get_num_samples(0) * 0.2)
end_frame = int(times_recording.get_num_samples(0) * 0.8)

# Slice the recording and get the new times are correct
rec_frame_slice = times_recording.frame_slice(start_frame=start_frame, end_frame=end_frame)
sort_frame_slice = sorting.frame_slice(start_frame=start_frame, end_frame=end_frame)

assert np.allclose(rec_frame_slice.get_times(0), all_times[0][start_frame:end_frame], rtol=0, atol=1e-8)

self._check_spike_times_are_correct(sort_frame_slice, rec_frame_slice, segment_index=0)

# Test `time_slice`
start_time = times_recording.sample_index_to_time(start_frame)
end_time = times_recording.sample_index_to_time(end_frame)

rec_slice = times_recording.time_slice(start_time=start_time, end_time=end_time)

assert np.allclose(rec_slice.get_times(0), all_times[0][start_frame:end_frame], rtol=0, atol=1e-8)

# Helpers ####
def _check_times_match(self, recording, all_times):
"""
For every segment in a recording, check the `get_times()`
match the expected times in the list of time vectors, `all_times`.
"""
for segment_index in range(recording.get_num_segments()):
assert np.array_equal(recording.get_times(segment_index), all_times[segment_index])

def _check_spike_times_are_correct(self, sorting, times_recording, segment_index):
"""
For every unit in the `sorting`, for a particular segment, check that
the unit times match the times of the original recording as
retrieved with `get_times()`.
"""
for unit_id in sorting.get_unit_ids():
spike_times = sorting.get_unit_spike_train(unit_id, segment_index=segment_index, return_times=True)
spike_indexes = sorting.get_unit_spike_train(unit_id, segment_index=segment_index)
rec_times = times_recording.get_times(segment_index=segment_index)

assert np.array_equal(
spike_times,
rec_times[spike_indexes],
)

def _get_sorting_with_recording_attached(self, recording_for_durations, recording_to_attach):
"""
Convenience function to create a sorting object with
a recording attached. Typically use the raw recordings
for the durations of which to make the sorter, as
the generate_sorter is not setup to handle the
(strange) edge case of the irregularly spaced
test time vectors.
"""
durations = [
recording_for_durations.get_duration(idx) for idx in range(recording_for_durations.get_num_segments())
]

sorting = generate_sorting(num_units=10, durations=durations)

sorting.register_recording(recording_to_attach)
assert sorting.has_recording()

return sorting


# TODO: deprecate original implementations ###
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TODO: this was messing up the diff so left for the end.

def test_time_handling(create_cache_folder):
cache_folder = create_cache_folder
durations = [[10], [10, 5]]
Expand Down
Loading