diff --git a/src/spikeinterface/preprocessing/decimate.py b/src/spikeinterface/preprocessing/decimate.py index c1b1cd9f80..2b47601fc2 100644 --- a/src/spikeinterface/preprocessing/decimate.py +++ b/src/spikeinterface/preprocessing/decimate.py @@ -99,6 +99,7 @@ def __init__( if parent_recording_segment.time_vector is not None: time_vector = parent_recording_segment.time_vector[decimation_offset::decimation_factor] decimated_sampling_frequency = None + t_start = None else: time_vector = None if parent_recording_segment.t_start is None: diff --git a/src/spikeinterface/preprocessing/tests/test_decimate.py b/src/spikeinterface/preprocessing/tests/test_decimate.py index adfcbd0d4a..dd521cbe9b 100644 --- a/src/spikeinterface/preprocessing/tests/test_decimate.py +++ b/src/spikeinterface/preprocessing/tests/test_decimate.py @@ -11,13 +11,8 @@ @pytest.mark.parametrize("num_segments", [1, 2]) @pytest.mark.parametrize("decimation_offset", [0, 5, 21, 101]) @pytest.mark.parametrize("decimation_factor", [1, 7, 50]) -@pytest.mark.parametrize("start_frame", [0, 1, 5, None, 1000]) -@pytest.mark.parametrize("end_frame", [0, 1, 5, None, 1000]) -def test_decimate(num_segments, decimation_offset, decimation_factor, start_frame, end_frame): - rec = generate_recording() - - segment_num_samps = [101 + i for i in range(num_segments)] - +def test_decimate(num_segments, decimation_offset, decimation_factor): + segment_num_samps = [20000, 40000] rec = NumpyRecording([np.arange(2 * N).reshape(N, 2) for N in segment_num_samps], 1) parent_traces = [rec.get_traces(i) for i in range(num_segments)] @@ -30,10 +25,19 @@ def test_decimate(num_segments, decimation_offset, decimation_factor, start_fram decimated_rec = DecimateRecording(rec, decimation_factor, decimation_offset=decimation_offset) decimated_parent_traces = [parent_traces[i][decimation_offset::decimation_factor] for i in range(num_segments)] - if start_frame is None: - start_frame = max(decimated_rec.get_num_samples(i) for i in range(num_segments)) - if end_frame is None: - end_frame = max(decimated_rec.get_num_samples(i) for i in range(num_segments)) + for start_frame in [0, 1, 5, None, 1000]: + for end_frame in [0, 1, 5, None, 1000]: + if start_frame is None: + start_frame = max(decimated_rec.get_num_samples(i) for i in range(num_segments)) + if end_frame is None: + end_frame = max(decimated_rec.get_num_samples(i) for i in range(num_segments)) + + for i in range(num_segments): + assert decimated_rec.get_num_samples(i) == decimated_parent_traces[i].shape[0] + assert np.all( + decimated_rec.get_traces(i, start_frame, end_frame) + == decimated_parent_traces[i][start_frame:end_frame] + ) for i in range(num_segments): assert decimated_rec.get_num_samples(i) == decimated_parent_traces[i].shape[0] @@ -42,5 +46,36 @@ def test_decimate(num_segments, decimation_offset, decimation_factor, start_fram ) +def test_decimate_with_times(): + rec = generate_recording(durations=[5, 10]) + + # test with times + times = [rec.get_times(0) + 10, rec.get_times(1) + 20] + for i, t in enumerate(times): + rec.set_times(t, i) + + decimation_factor = 2 + decimation_offset = 1 + decimated_rec = DecimateRecording(rec, decimation_factor, decimation_offset=decimation_offset) + + for segment_index in range(rec.get_num_segments()): + assert np.allclose( + decimated_rec.get_times(segment_index), + rec.get_times(segment_index)[decimation_offset::decimation_factor], + ) + + # test with t_start + rec = generate_recording(durations=[5, 10]) + t_starts = [10, 20] + for t_start, rec_segment in zip(t_starts, rec._recording_segments): + rec_segment.t_start = t_start + decimated_rec = DecimateRecording(rec, decimation_factor, decimation_offset=decimation_offset) + for segment_index in range(rec.get_num_segments()): + assert np.allclose( + decimated_rec.get_times(segment_index), + rec.get_times(segment_index)[decimation_offset::decimation_factor], + ) + + if __name__ == "__main__": test_decimate()