Skip to content

Commit

Permalink
Merge pull request #3509 from JoeZiminski/add_shift_time_function
Browse files Browse the repository at this point in the history
Add `shift start time` function.
  • Loading branch information
alejoe91 authored Nov 20, 2024
2 parents 681fb01 + 469b3b0 commit 3fd3d97
Show file tree
Hide file tree
Showing 2 changed files with 118 additions and 3 deletions.
29 changes: 29 additions & 0 deletions src/spikeinterface/core/baserecording.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,6 +509,35 @@ def reset_times(self):
rs.t_start = None
rs.sampling_frequency = self.sampling_frequency

def shift_times(self, shift: int | float, segment_index: int | None = None) -> None:
"""
Shift all times by a scalar value.
Parameters
----------
shift : int | float
The shift to apply. If positive, times will be increased by `shift`.
e.g. shifting by 1 will be like the recording started 1 second later.
If negative, the start time will be decreased i.e. as if the recording
started earlier.
segment_index : int | None
The segment on which to shift the times.
If `None`, all segments will be shifted.
"""
if segment_index is None:
segments_to_shift = range(self.get_num_segments())
else:
segments_to_shift = (segment_index,)

for idx in segments_to_shift:
rs = self._recording_segments[idx]

if self.has_time_vector(segment_index=idx):
rs.time_vector += shift
else:
rs.t_start += shift

def sample_index_to_time(self, sample_ind, segment_index=None):
"""
Transform sample index into time in seconds
Expand Down
92 changes: 89 additions & 3 deletions src/spikeinterface/core/tests/test_time_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@ class TestTimeHandling:
is generated on the fly. Both time representations are tested here.
"""

# Fixtures #####
# #########################################################################
# Fixtures
# #########################################################################

@pytest.fixture(scope="session")
def time_vector_recording(self):
"""
Expand Down Expand Up @@ -95,7 +98,10 @@ def _get_fixture_data(self, request, fixture_name):
raw_recording, times_recording, all_times = time_recording_fixture
return (raw_recording, times_recording, all_times)

# Tests #####
# #########################################################################
# Tests
# #########################################################################

def test_has_time_vector(self, time_vector_recording):
"""
Test the `has_time_vector` function returns `False` before
Expand Down Expand Up @@ -305,7 +311,87 @@ def test_sorting_analyzer_get_durations_no_recording(self, time_vector_recording

assert np.array_equal(sorting_analyzer.get_total_duration(), raw_recording.get_total_duration())

# Helpers ####
@pytest.mark.parametrize("fixture_name", ["time_vector_recording", "t_start_recording"])
@pytest.mark.parametrize("shift", [-123.456, 123.456])
def test_shift_time_all_segments(self, request, fixture_name, shift):
"""
Shift the times in every segment using the `None` default, then
check that every segment of the recording is shifted as expected.
"""
_, times_recording, all_times = self._get_fixture_data(request, fixture_name)

num_segments, orig_seg_data = self._store_all_times(times_recording)

times_recording.shift_times(shift) # use default `segment_index=None`

for idx in range(num_segments):
assert np.allclose(
orig_seg_data[idx], times_recording.get_times(segment_index=idx) - shift, rtol=0, atol=1e-8
)

@pytest.mark.parametrize("fixture_name", ["time_vector_recording", "t_start_recording"])
@pytest.mark.parametrize("shift", [-123.456, 123.456])
def test_shift_times_different_segments(self, request, fixture_name, shift):
"""
Shift each segment separately, and check the shifted segment only
is shifted as expected.
"""
_, times_recording, all_times = self._get_fixture_data(request, fixture_name)

num_segments, orig_seg_data = self._store_all_times(times_recording)

# For each segment, shift the segment only and check the
# times are updated as expected.
for idx in range(num_segments):

scaler = idx + 2
times_recording.shift_times(shift * scaler, segment_index=idx)

assert np.allclose(
orig_seg_data[idx], times_recording.get_times(segment_index=idx) - shift * scaler, rtol=0, atol=1e-8
)

# Just do a little check that we are not
# accidentally changing some other segments,
# which should remain unchanged at this point in the loop.
if idx != num_segments - 1:
assert np.array_equal(orig_seg_data[idx + 1], times_recording.get_times(segment_index=idx + 1))

@pytest.mark.parametrize("fixture_name", ["time_vector_recording", "t_start_recording"])
def test_save_and_load_time_shift(self, request, fixture_name, tmp_path):
"""
Save the shifted data and check the shift is propagated correctly.
"""
_, times_recording, all_times = self._get_fixture_data(request, fixture_name)

shift = 100
times_recording.shift_times(shift=shift)

times_recording.save(folder=tmp_path / "my_file")

loaded_recording = si.load_extractor(tmp_path / "my_file")

for idx in range(times_recording.get_num_segments()):
assert np.array_equal(
times_recording.get_times(segment_index=idx), loaded_recording.get_times(segment_index=idx)
)

def _store_all_times(self, recording):
"""
Convenience function to store original times of all segments to a dict.
"""
num_segments = recording.get_num_segments()
seg_data = {}

for idx in range(num_segments):
seg_data[idx] = copy.deepcopy(recording.get_times(segment_index=idx))

return num_segments, seg_data

# #########################################################################
# Helpers
# #########################################################################

def _check_times_match(self, recording, all_times):
"""
For every segment in a recording, check the `get_times()`
Expand Down

0 comments on commit 3fd3d97

Please sign in to comment.