Skip to content

Commit

Permalink
More skimming and test decimate with times
Browse files Browse the repository at this point in the history
  • Loading branch information
alejoe91 committed Nov 5, 2024
1 parent 2ba37a8 commit f900118
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 11 deletions.
1 change: 1 addition & 0 deletions src/spikeinterface/preprocessing/decimate.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ def __init__(
if parent_recording_segment.time_vector is not None:
time_vector = parent_recording_segment.time_vector[decimation_offset::decimation_factor]
decimated_sampling_frequency = None
t_start = None
else:
time_vector = None
if parent_recording_segment.t_start is None:
Expand Down
57 changes: 46 additions & 11 deletions src/spikeinterface/preprocessing/tests/test_decimate.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,8 @@
@pytest.mark.parametrize("num_segments", [1, 2])
@pytest.mark.parametrize("decimation_offset", [0, 5, 21, 101])
@pytest.mark.parametrize("decimation_factor", [1, 7, 50])
@pytest.mark.parametrize("start_frame", [0, 1, 5, None, 1000])
@pytest.mark.parametrize("end_frame", [0, 1, 5, None, 1000])
def test_decimate(num_segments, decimation_offset, decimation_factor, start_frame, end_frame):
rec = generate_recording()

segment_num_samps = [101 + i for i in range(num_segments)]

def test_decimate(num_segments, decimation_offset, decimation_factor):
segment_num_samps = [20000, 40000]
rec = NumpyRecording([np.arange(2 * N).reshape(N, 2) for N in segment_num_samps], 1)

parent_traces = [rec.get_traces(i) for i in range(num_segments)]
Expand All @@ -30,10 +25,19 @@ def test_decimate(num_segments, decimation_offset, decimation_factor, start_fram
decimated_rec = DecimateRecording(rec, decimation_factor, decimation_offset=decimation_offset)
decimated_parent_traces = [parent_traces[i][decimation_offset::decimation_factor] for i in range(num_segments)]

if start_frame is None:
start_frame = max(decimated_rec.get_num_samples(i) for i in range(num_segments))
if end_frame is None:
end_frame = max(decimated_rec.get_num_samples(i) for i in range(num_segments))
for start_frame in [0, 1, 5, None, 1000]:
for end_frame in [0, 1, 5, None, 1000]:
if start_frame is None:
start_frame = max(decimated_rec.get_num_samples(i) for i in range(num_segments))
if end_frame is None:
end_frame = max(decimated_rec.get_num_samples(i) for i in range(num_segments))

for i in range(num_segments):
assert decimated_rec.get_num_samples(i) == decimated_parent_traces[i].shape[0]
assert np.all(
decimated_rec.get_traces(i, start_frame, end_frame)
== decimated_parent_traces[i][start_frame:end_frame]
)

for i in range(num_segments):
assert decimated_rec.get_num_samples(i) == decimated_parent_traces[i].shape[0]
Expand All @@ -42,5 +46,36 @@ def test_decimate(num_segments, decimation_offset, decimation_factor, start_fram
)


def test_decimate_with_times():
rec = generate_recording(durations=[5, 10])

# test with times
times = [rec.get_times(0) + 10, rec.get_times(1) + 20]
for i, t in enumerate(times):
rec.set_times(t, i)

decimation_factor = 2
decimation_offset = 1
decimated_rec = DecimateRecording(rec, decimation_factor, decimation_offset=decimation_offset)

for segment_index in range(rec.get_num_segments()):
assert np.allclose(
decimated_rec.get_times(segment_index),
rec.get_times(segment_index)[decimation_offset::decimation_factor],
)

# test with t_start
rec = generate_recording(durations=[5, 10])
t_starts = [10, 20]
for t_start, rec_segment in zip(t_starts, rec._recording_segments):
rec_segment.t_start = t_start
decimated_rec = DecimateRecording(rec, decimation_factor, decimation_offset=decimation_offset)
for segment_index in range(rec.get_num_segments()):
assert np.allclose(
decimated_rec.get_times(segment_index),
rec.get_times(segment_index)[decimation_offset::decimation_factor],
)


if __name__ == "__main__":
test_decimate()

0 comments on commit f900118

Please sign in to comment.