diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index 9a9747bf0b..0ea9426674 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -498,24 +498,35 @@ def time_to_sample_index(self, time_s, segment_index=None): rs = self._recording_segments[segment_index] return rs.time_to_sample_index(time_s) - def _save(self, format="binary", verbose: bool = False, **save_kwargs): + def _get_t_starts(self): # handle t_starts t_starts = [] has_time_vectors = [] - for segment_index, rs in enumerate(self._recording_segments): + for rs in self._recording_segments: d = rs.get_times_kwargs() t_starts.append(d["t_start"]) - has_time_vectors.append(d["time_vector"] is not None) if all(t_start is None for t_start in t_starts): t_starts = None + return t_starts + def _get_time_vectors(self): + time_vectors = [] + for rs in self._recording_segments: + d = rs.get_times_kwargs() + time_vectors.append(d["time_vector"]) + if all(time_vector is None for time_vector in time_vectors): + time_vectors = None + return time_vectors + + def _save(self, format="binary", verbose: bool = False, **save_kwargs): kwargs, job_kwargs = split_job_kwargs(save_kwargs) if format == "binary": folder = kwargs["folder"] file_paths = [folder / f"traces_cached_seg{i}.raw" for i in range(self.get_num_segments())] dtype = kwargs.get("dtype", None) or self.get_dtype() + t_starts = self._get_t_starts() write_binary_recording(self, file_paths=file_paths, dtype=dtype, verbose=verbose, **job_kwargs) @@ -572,11 +583,11 @@ def _save(self, format="binary", verbose: bool = False, **save_kwargs): probegroup = self.get_probegroup() cached.set_probegroup(probegroup) - for segment_index, rs in enumerate(self._recording_segments): - d = rs.get_times_kwargs() - time_vector = d["time_vector"] - if time_vector is not None: - cached._recording_segments[segment_index].time_vector = time_vector + time_vectors = self._get_time_vectors() + if 
time_vectors is not None: + for segment_index, time_vector in enumerate(time_vectors): + if time_vector is not None: + cached.set_times(time_vector, segment_index=segment_index) return cached diff --git a/src/spikeinterface/core/numpyextractors.py b/src/spikeinterface/core/numpyextractors.py index 09ba743a8c..f4790817a8 100644 --- a/src/spikeinterface/core/numpyextractors.py +++ b/src/spikeinterface/core/numpyextractors.py @@ -83,6 +83,9 @@ def __init__(self, traces_list, sampling_frequency, t_starts=None, channel_ids=N @staticmethod def from_recording(source_recording, **job_kwargs): traces_list, shms = write_memory_recording(source_recording, dtype=None, **job_kwargs) + + t_starts = source_recording._get_t_starts() + if shms[0] is not None: # if the computation was done in parallel then traces_list is shared array # this can lead to problem @@ -91,13 +94,14 @@ def from_recording(source_recording, **job_kwargs): for shm in shms: shm.close() shm.unlink() - # TODO later : propagte t_starts ? + recording = NumpyRecording( traces_list, source_recording.get_sampling_frequency(), - t_starts=None, + t_starts=t_starts, channel_ids=source_recording.channel_ids, ) + return recording class NumpyRecordingSegment(BaseRecordingSegment): @@ -206,7 +210,7 @@ def __del__(self): def from_recording(source_recording, **job_kwargs): traces_list, shms = write_memory_recording(source_recording, buffer_type="sharedmem", **job_kwargs) - # TODO later : propagte t_starts ? 
+ t_starts = source_recording._get_t_starts() recording = SharedMemoryRecording( shm_names=[shm.name for shm in shms], @@ -214,7 +218,7 @@ def from_recording(source_recording, **job_kwargs): dtype=source_recording.dtype, sampling_frequency=source_recording.sampling_frequency, channel_ids=source_recording.channel_ids, - t_starts=None, + t_starts=t_starts, main_shm_owner=True, ) diff --git a/src/spikeinterface/core/tests/test_time_handling.py b/src/spikeinterface/core/tests/test_time_handling.py index 487a893096..049d5ab6e5 100644 --- a/src/spikeinterface/core/tests/test_time_handling.py +++ b/src/spikeinterface/core/tests/test_time_handling.py @@ -1,69 +1,289 @@ +import copy + import pytest import numpy as np from spikeinterface.core import generate_recording, generate_sorting +import spikeinterface.full as si + +class TestTimeHandling: + """ + This class tests how time is handled in SpikeInterface. Under the hood, + time can be represented as a full `time_vector` or only as + `t_start` attribute on segments from which a vector of times + is generated on the fly. Both time representations are tested here. + """ -def test_time_handling(create_cache_folder): - cache_folder = create_cache_folder - durations = [[10], [10, 5]] + # Fixtures ##### + @pytest.fixture(scope="session") + def time_vector_recording(self): + """ + Add time vectors to the recording, returning the + raw recording, recording with time vectors added to + segments, and a list of the time vectors added to the recording. 
+ """ + durations = [10, 15, 20] + raw_recording = generate_recording(num_channels=4, durations=durations) - # test multi-segment - for i, dur in enumerate(durations): - rec = generate_recording(num_channels=4, durations=dur) - sort = generate_sorting(num_units=10, durations=dur) + return self._get_time_vector_recording(raw_recording) - for segment_index in range(rec.get_num_segments()): - original_times = rec.get_times(segment_index=segment_index) - new_times = original_times + 5 - rec.set_times(new_times, segment_index=segment_index) + @pytest.fixture(scope="session") + def t_start_recording(self): + """ + Add t_starts to the recording, returning the + raw recording, recording with t_starts added to segments, + and a list of the time vectors generated from adding the + t_start to the recording times. + """ + durations = [10, 15, 20] + raw_recording = generate_recording(num_channels=4, durations=durations) - sort.register_recording(rec) - assert sort.has_recording() + return self._get_t_start_recording(raw_recording) - rec_cache = rec.save(folder=cache_folder / f"rec{i}") + def _get_time_vector_recording(self, raw_recording): + """ + Loop through all recording segments, adding a different time + vector to each segment. The time vector is the original times with + a t_start and irregularly spaced offsets to mimic irregularly + spaced timeseries data. Return the original recording, + recording with time vectors added and a list including the added time vectors. 
+ """ + times_recording = copy.deepcopy(raw_recording) + all_time_vectors = [] + for segment_index in range(raw_recording.get_num_segments()): - for segment_index in range(sort.get_num_segments()): - assert rec.has_time_vector(segment_index=segment_index) - assert sort.has_time_vector(segment_index=segment_index) + t_start = (segment_index + 1) * 100 - # times are correctly saved by the recording - assert np.allclose( - rec.get_times(segment_index=segment_index), rec_cache.get_times(segment_index=segment_index) + some_small_increasing_numbers = np.arange(times_recording.get_num_samples(segment_index)) * ( + 1 / times_recording.get_sampling_frequency() ) - # spike times are correctly adjusted - for u in sort.get_unit_ids(): - spike_times = sort.get_unit_spike_train(u, segment_index=segment_index, return_times=True) - rec_times = rec.get_times(segment_index=segment_index) - assert np.all(spike_times >= rec_times[0]) - assert np.all(spike_times <= rec_times[-1]) + offsets = np.cumsum(some_small_increasing_numbers) + time_vector = t_start + times_recording.get_times(segment_index) + offsets + + all_time_vectors.append(time_vector) + times_recording.set_times(times=time_vector, segment_index=segment_index) + + assert np.array_equal( + times_recording._recording_segments[segment_index].time_vector, + time_vector, + ), "time_vector was not properly set during test setup" + + return (raw_recording, times_recording, all_time_vectors) + + def _get_t_start_recording(self, raw_recording): + """ + For each segment in the recording, add a different `t_start`. + Return a list of time vectors generated from the recording times + + the t_starts. 
+ """ + t_start_recording = copy.deepcopy(raw_recording) + + all_t_starts = [] + for segment_index in range(raw_recording.get_num_segments()): + + t_start = (segment_index + 1) * 100 + + all_t_starts.append(t_start + t_start_recording.get_times(segment_index)) + t_start_recording._recording_segments[segment_index].t_start = t_start + + return (raw_recording, t_start_recording, all_t_starts) + + def _get_fixture_data(self, request, fixture_name): + """ + A convenience function to get the data from a fixture + based on the name. This is used to allow parameterising + tests across fixtures. + """ + time_recording_fixture = request.getfixturevalue(fixture_name) + raw_recording, times_recording, all_times = time_recording_fixture + return (raw_recording, times_recording, all_times) + + # Tests ##### + def test_has_time_vector(self, time_vector_recording): + """ + Test the `has_time_vector` function returns `False` before + a time vector is added and `True` afterwards. + """ + raw_recording, times_recording, _ = time_vector_recording + + for segment_idx in range(raw_recording.get_num_segments()): + + assert raw_recording.has_time_vector(segment_idx) is False + assert times_recording.has_time_vector(segment_idx) is True + + @pytest.mark.parametrize("mode", ["binary", "zarr"]) + @pytest.mark.parametrize("fixture_name", ["time_vector_recording", "t_start_recording"]) + def test_times_propagated_to_save_folder(self, request, fixture_name, mode, tmp_path): + """ + Test `t_start` or `time_vector` is propagated to a saved recording, + by saving, reloading, and checking times are correct. 
+ """ + _, times_recording, all_times = self._get_fixture_data(request, fixture_name) + + folder_name = "recording" + recording_cache = times_recording.save(format=mode, folder=tmp_path / folder_name) + + if mode == "zarr": + folder_name += ".zarr" + recording_load = si.load_extractor(tmp_path / folder_name) + + self._check_times_match(recording_cache, all_times) + self._check_times_match(recording_load, all_times) + + @pytest.mark.parametrize("sharedmem", [True, False]) + @pytest.mark.parametrize("fixture_name", ["time_vector_recording", "t_start_recording"]) + def test_times_propagated_to_save_memory(self, request, fixture_name, sharedmem): + """ + Test t_start and time_vector are propagated to recording saved into memory. + """ + _, times_recording, all_times = self._get_fixture_data(request, fixture_name) + + recording_load = times_recording.save(format="memory", sharedmem=sharedmem) + self._check_times_match(recording_load, all_times) -def test_frame_slicing(): - duration = [10] + @pytest.mark.parametrize("fixture_name", ["time_vector_recording", "t_start_recording"]) + def test_time_propagated_to_select_segments(self, request, fixture_name): + """ + Test that when `recording.select_segments()` is used, the times + are propagated to the new recording object. 
+ """ + _, times_recording, all_times = self._get_fixture_data(request, fixture_name) - rec = generate_recording(num_channels=4, durations=duration) - sort = generate_sorting(num_units=10, durations=duration) + for segment_index in range(times_recording.get_num_segments()): + segment = times_recording.select_segments(segment_index) + assert np.array_equal(segment.get_times(), all_times[segment_index]) - original_times = rec.get_times() - new_times = original_times + 5 - rec.set_times(new_times) + @pytest.mark.parametrize("fixture_name", ["time_vector_recording", "t_start_recording"]) + def test_times_propagated_to_sorting(self, request, fixture_name): + """ + Check that when attached to a sorting object, the times are propagated + to the object. This means that all spike times should respect the + `t_start` or `time_vector` added. + """ + raw_recording, times_recording, all_times = self._get_fixture_data(request, fixture_name) + sorting = self._get_sorting_with_recording_attached( + recording_for_durations=raw_recording, recording_to_attach=times_recording + ) + for segment_index in range(raw_recording.get_num_segments()): - sort.register_recording(rec) + if fixture_name == "time_vector_recording": + assert sorting.has_time_vector(segment_index=segment_index) - start_frame = 3 * rec.get_sampling_frequency() - end_frame = 7 * rec.get_sampling_frequency() + self._check_spike_times_are_correct(sorting, times_recording, segment_index) + + @pytest.mark.parametrize("fixture_name", ["time_vector_recording", "t_start_recording"]) + def test_time_sample_converters(self, request, fixture_name): + """ + Test the `recording.sample_index_to_time` and + `recording.time_to_sample_index` convenience functions. 
+ """ + raw_recording, times_recording, all_times = self._get_fixture_data(request, fixture_name) + with pytest.raises(ValueError) as e: + times_recording.sample_index_to_time(0) + assert "Provide 'segment_index'" in str(e) + + for segment_index in range(times_recording.get_num_segments()): + + sample_index = np.random.randint(low=0, high=times_recording.get_num_samples(segment_index)) + time_ = times_recording.sample_index_to_time(sample_index, segment_index=segment_index) + + assert time_ == all_times[segment_index][sample_index] + + new_sample_index = times_recording.time_to_sample_index(time_, segment_index=segment_index) + + assert new_sample_index == sample_index + + @pytest.mark.parametrize("time_type", ["time_vector", "t_start"]) + @pytest.mark.parametrize("bounds", ["start", "middle", "end"]) + def test_slice_recording(self, time_type, bounds): + """ + Test times are correct after applying `frame_slice` or `time_slice` + to a recording or sorting (for `frame_slice`). The recording times + should be correct with respect to the set `t_start` or `time_vector`. + """ + raw_recording = generate_recording(num_channels=4, durations=[10]) + + if time_type == "time_vector": + raw_recording, times_recording, all_times = self._get_time_vector_recording(raw_recording) + else: + raw_recording, times_recording, all_times = self._get_t_start_recording(raw_recording) + + sorting = self._get_sorting_with_recording_attached( + recording_for_durations=raw_recording, recording_to_attach=times_recording + ) + + # Take some different times, including min and max bounds of + # the recording, and some arbitrary times in the middle (20% and 80%). 
+ if bounds == "start": + start_frame = 0 + end_frame = int(times_recording.get_num_samples(0) * 0.8) + elif bounds == "end": + start_frame = int(times_recording.get_num_samples(0) * 0.2) + end_frame = times_recording.get_num_samples(0) - 1 + elif bounds == "middle": + start_frame = int(times_recording.get_num_samples(0) * 0.2) + end_frame = int(times_recording.get_num_samples(0) * 0.8) + + # Slice the recording and check the new times are correct + rec_frame_slice = times_recording.frame_slice(start_frame=start_frame, end_frame=end_frame) + sort_frame_slice = sorting.frame_slice(start_frame=start_frame, end_frame=end_frame) + + assert np.allclose(rec_frame_slice.get_times(0), all_times[0][start_frame:end_frame], rtol=0, atol=1e-8) + + self._check_spike_times_are_correct(sort_frame_slice, rec_frame_slice, segment_index=0) + + # Test `time_slice` + start_time = times_recording.sample_index_to_time(start_frame) + end_time = times_recording.sample_index_to_time(end_frame) + + rec_slice = times_recording.time_slice(start_time=start_time, end_time=end_time) + + assert np.allclose(rec_slice.get_times(0), all_times[0][start_frame:end_frame], rtol=0, atol=1e-8) + + # Helpers #### + def _check_times_match(self, recording, all_times): + """ + For every segment in a recording, check the `get_times()` + match the expected times in the list of time vectors, `all_times`. + """ + for segment_index in range(recording.get_num_segments()): + assert np.array_equal(recording.get_times(segment_index), all_times[segment_index]) + + def _check_spike_times_are_correct(self, sorting, times_recording, segment_index): + """ + For every unit in the `sorting`, for a particular segment, check that + the unit times match the times of the original recording as + retrieved with `get_times()`. 
+ """ + for unit_id in sorting.get_unit_ids(): + spike_times = sorting.get_unit_spike_train(unit_id, segment_index=segment_index, return_times=True) + spike_indexes = sorting.get_unit_spike_train(unit_id, segment_index=segment_index) + rec_times = times_recording.get_times(segment_index=segment_index) + + assert np.array_equal( + spike_times, + rec_times[spike_indexes], + ) - rec_slice = rec.frame_slice(start_frame=start_frame, end_frame=end_frame) - sort_slice = sort.frame_slice(start_frame=start_frame, end_frame=end_frame) + def _get_sorting_with_recording_attached(self, recording_for_durations, recording_to_attach): + """ + Convenience function to create a sorting object with + a recording attached. Typically use the raw recordings + for the durations of which to make the sorter, as + the generate_sorter is not setup to handle the + (strange) edge case of the irregularly spaced + test time vectors. + """ + durations = [ + recording_for_durations.get_duration(idx) for idx in range(recording_for_durations.get_num_segments()) + ] - for u in sort_slice.get_unit_ids(): - spike_times = sort_slice.get_unit_spike_train(u, return_times=True) - rec_times = rec_slice.get_times() - assert np.all(spike_times >= rec_times[0]) - assert np.all(spike_times <= rec_times[-1]) + sorting = generate_sorting(num_units=10, durations=durations) + sorting.register_recording(recording_to_attach) + assert sorting.has_recording() -if __name__ == "__main__": - test_frame_slicing() + return sorting