Don't let decimate mess with times and skim tests #3519

Merged: 3 commits, Nov 7, 2024
27 changes: 14 additions & 13 deletions src/spikeinterface/preprocessing/decimate.py
@@ -63,18 +63,15 @@ def __init__(
f"Consider combining DecimateRecording with FrameSliceRecording for fine control on the recording start/end frames."
)
self._decimation_offset = decimation_offset
resample_rate = self._orig_samp_freq / self._decimation_factor
decimated_sampling_frequency = self._orig_samp_freq / self._decimation_factor

BasePreprocessor.__init__(self, recording, sampling_frequency=resample_rate)
BasePreprocessor.__init__(self, recording, sampling_frequency=decimated_sampling_frequency)

# in case there was a time_vector, it will be dropped for sanity.
# This is not necessary but consistent with ResampleRecording
for parent_segment in recording._recording_segments:
parent_segment.time_vector = None
self.add_recording_segment(
DecimateRecordingSegment(
parent_segment,
resample_rate,
decimated_sampling_frequency,
self._orig_samp_freq,
decimation_factor,
decimation_offset,
@@ -93,22 +93,26 @@ class DecimateRecordingSegment(BaseRecordingSegment):
def __init__(
self,
parent_recording_segment,
resample_rate,
decimated_sampling_frequency,
Collaborator: This would be a breaking change, no? Should we deprecate?

Member Author: Yes, but it's so deep in the API that I'm 100% sure it wouldn't affect anyone's workflow.

If it were at the DecimateRecording level, then we would have had to worry about backward compatibility because of saved objects/JSON files. But since it's the segment, which is instantiated on the fly, we don't have to worry about it (and I think the naming is much more in line with the overall API).

Collaborator: Absolutely agree with the naming, and it makes sense if it's deep. I wasn't sure whether this was more on the private or public side, but it makes sense that the Segment level is basically private.
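For reference, a deprecation shim for the renamed keyword could look like the sketch below. This is an assumed, generic pattern rather than anything in this PR (the PR deliberately skips it because the segment class is effectively private): the old `resample_rate` argument is accepted as an alias, raises a `DeprecationWarning`, and is mapped onto `decimated_sampling_frequency`.

```python
import warnings


class DecimateRecordingSegment:
    # Sketch only: how the renamed keyword could have been deprecated instead of removed.
    def __init__(self, parent_recording_segment, decimated_sampling_frequency=None, resample_rate=None, **kwargs):
        if resample_rate is not None:
            warnings.warn(
                "`resample_rate` is deprecated; use `decimated_sampling_frequency` instead",
                DeprecationWarning,
                stacklevel=2,
            )
            decimated_sampling_frequency = resample_rate
        self.sampling_frequency = decimated_sampling_frequency
        # ... the remainder of the real __init__ would follow here
```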

parent_rate,
decimation_factor,
decimation_offset,
dtype,
):
if parent_recording_segment.t_start is None:
new_t_start = None
if parent_recording_segment.time_vector is not None:
time_vector = parent_recording_segment.time_vector[decimation_offset::decimation_factor]
decimated_sampling_frequency = None
Collaborator: What's this None for? I think Joe has thought about this more, so it's tricky for me to reason about time_vector vs t_start and when we want a frequency of None vs a value.

Member Author: Currently the time_vector and sampling_frequency/t_start representations of time in the segment are still mutually exclusive, so we need to set the sampling frequency to None.

Collaborator: Okay, cool. Maybe we can have a chat about this at some point. I don't want to take up developer meeting time for this necessarily, but since I don't use the time API I don't know it well enough. :)
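To make the mutual exclusivity concrete, here is a minimal sketch of the two time representations. The helper `get_segment_times` is purely illustrative (not the actual segment implementation); it also shows why the decimated t_start needs the `decimation_offset / parent_rate` shift applied in the diff below.

```python
import numpy as np


def get_segment_times(num_samples, sampling_frequency=None, t_start=None, time_vector=None):
    # Illustrative helper: a segment stores EITHER a time_vector OR a sampling_frequency (+ optional t_start).
    if time_vector is not None:
        return np.asarray(time_vector)
    times = np.arange(num_samples) / sampling_frequency
    if t_start is not None:
        times = times + t_start
    return times


# Decimating either representation stays consistent with the parent recording.
parent_rate, decimation_factor, decimation_offset = 100.0, 2, 1
parent_times = get_segment_times(10, sampling_frequency=parent_rate, t_start=5.0)

# time_vector path: just slice the parent timestamps.
decimated_times = parent_times[decimation_offset::decimation_factor]

# t_start path: shift t_start by decimation_offset / parent_rate and divide the rate.
expected = get_segment_times(
    len(decimated_times),
    sampling_frequency=parent_rate / decimation_factor,
    t_start=5.0 + decimation_offset / parent_rate,
)
assert np.allclose(decimated_times, expected)
```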

t_start = None
else:
new_t_start = parent_recording_segment.t_start + decimation_offset / parent_rate
time_vector = None
if parent_recording_segment.t_start is None:
t_start = None
else:
t_start = parent_recording_segment.t_start + (decimation_offset / parent_rate)

# Do not use BasePreprocessorSegment because we have to reset the sampling rate!
BaseRecordingSegment.__init__(
self,
sampling_frequency=resample_rate,
t_start=new_t_start,
self, sampling_frequency=decimated_sampling_frequency, t_start=t_start, time_vector=time_vector
)
self._parent_segment = parent_recording_segment
self._decimation_factor = decimation_factor
69 changes: 52 additions & 17 deletions src/spikeinterface/preprocessing/tests/test_decimate.py
@@ -8,39 +8,74 @@
import numpy as np


@pytest.mark.parametrize("N_segments", [1, 2])
@pytest.mark.parametrize("decimation_offset", [0, 1, 9, 10, 11, 100, 101])
@pytest.mark.parametrize("decimation_factor", [1, 9, 10, 11, 100, 101])
@pytest.mark.parametrize("start_frame", [0, 1, 5, None, 1000])
@pytest.mark.parametrize("end_frame", [0, 1, 5, None, 1000])
def test_decimate(N_segments, decimation_offset, decimation_factor, start_frame, end_frame):
rec = generate_recording()

segment_num_samps = [101 + i for i in range(N_segments)]

@pytest.mark.parametrize("num_segments", [1, 2])
@pytest.mark.parametrize("decimation_offset", [0, 1, 5, 21, 101])
@pytest.mark.parametrize("decimation_factor", [1, 7, 50])
def test_decimate(num_segments, decimation_offset, decimation_factor):
segment_num_samps = [20000, 40000]
rec = NumpyRecording([np.arange(2 * N).reshape(N, 2) for N in segment_num_samps], 1)

parent_traces = [rec.get_traces(i) for i in range(N_segments)]
parent_traces = [rec.get_traces(i) for i in range(num_segments)]

if decimation_offset >= min(segment_num_samps) or decimation_offset >= decimation_factor:
with pytest.raises(ValueError):
decimated_rec = DecimateRecording(rec, decimation_factor, decimation_offset=decimation_offset)
return

decimated_rec = DecimateRecording(rec, decimation_factor, decimation_offset=decimation_offset)
decimated_parent_traces = [parent_traces[i][decimation_offset::decimation_factor] for i in range(N_segments)]
decimated_parent_traces = [parent_traces[i][decimation_offset::decimation_factor] for i in range(num_segments)]

if start_frame is None:
start_frame = max(decimated_rec.get_num_samples(i) for i in range(N_segments))
if end_frame is None:
end_frame = max(decimated_rec.get_num_samples(i) for i in range(N_segments))
for start_frame in [0, 1, 5, None, 1000]:
for end_frame in [0, 1, 5, None, 1000]:
if start_frame is None:
start_frame = max(decimated_rec.get_num_samples(i) for i in range(num_segments))
if end_frame is None:
end_frame = max(decimated_rec.get_num_samples(i) for i in range(num_segments))

for i in range(N_segments):
for i in range(num_segments):
assert decimated_rec.get_num_samples(i) == decimated_parent_traces[i].shape[0]
assert np.all(
decimated_rec.get_traces(i, start_frame, end_frame)
== decimated_parent_traces[i][start_frame:end_frame]
)

for i in range(num_segments):
assert decimated_rec.get_num_samples(i) == decimated_parent_traces[i].shape[0]
assert np.all(
decimated_rec.get_traces(i, start_frame, end_frame) == decimated_parent_traces[i][start_frame:end_frame]
)


def test_decimate_with_times():
rec = generate_recording(durations=[5, 10])

# test with times
times = [rec.get_times(0) + 10, rec.get_times(1) + 20]
for i, t in enumerate(times):
rec.set_times(t, i)

decimation_factor = 2
decimation_offset = 1
decimated_rec = DecimateRecording(rec, decimation_factor, decimation_offset=decimation_offset)

for segment_index in range(rec.get_num_segments()):
assert np.allclose(
Collaborator: Should we explicitly set the tolerance we accept? We often get flakiness due to floating-point rounding. Does get_times return floats?

Member Author: Good point. I think all-equal would work here too. I'll push an update.

Member Author: @zm711 use np.testing.assert_array_almost_equal(..., decimal=10)

Collaborator: You think 10 decimal places? It looks like assert_array_almost_equal also deals with NaNs, which is nice. I vaguely remember Heberto having to slowly relax one of these styles of tests because it kept failing. 10 decimals seems very exact. Why not the default (which is 7, based on the most recent docs)?

Member Author: In my tests the absolute error was on the order of 1e-15 :)
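For concreteness, a hedged sketch of the two options discussed here, using illustrative arrays rather than the actual test data:

```python
import numpy as np

rng = np.random.default_rng(0)
expected = rng.random(1000)
actual = expected + 1e-15  # perturbation on the order of the error reported above

# Option 1: np.allclose with an explicit absolute tolerance instead of the defaults.
assert np.allclose(actual, expected, rtol=0.0, atol=1e-10)

# Option 2: np.testing.assert_array_almost_equal, which treats NaNs in matching
# positions as equal and fails if abs(desired - actual) >= 1.5 * 10**(-decimal).
np.testing.assert_array_almost_equal(actual, expected, decimal=10)
```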

decimated_rec.get_times(segment_index),
rec.get_times(segment_index)[decimation_offset::decimation_factor],
)

# test with t_start
rec = generate_recording(durations=[5, 10])
t_starts = [10, 20]
for t_start, rec_segment in zip(t_starts, rec._recording_segments):
rec_segment.t_start = t_start
decimated_rec = DecimateRecording(rec, decimation_factor, decimation_offset=decimation_offset)
for segment_index in range(rec.get_num_segments()):
assert np.allclose(
decimated_rec.get_times(segment_index),
rec.get_times(segment_index)[decimation_offset::decimation_factor],
)


if __name__ == "__main__":
test_decimate()