diff --git a/src/spikeinterface/preprocessing/tests/test_zero_padding.py b/src/spikeinterface/preprocessing/tests/test_zero_padding.py index 75d64b0088..a2fc890c74 100644 --- a/src/spikeinterface/preprocessing/tests/test_zero_padding.py +++ b/src/spikeinterface/preprocessing/tests/test_zero_padding.py @@ -6,7 +6,7 @@ from spikeinterface.core import generate_recording from spikeinterface.core.numpyextractors import NumpyRecording -from spikeinterface.preprocessing import zero_channel_pad +from spikeinterface.preprocessing import zero_channel_pad, bandpass_filter, common_reference from spikeinterface.preprocessing.zero_channel_pad import TracePaddedRecording if hasattr(pytest, "global_test_folder"): @@ -39,7 +39,7 @@ def test_zero_padding_channel(): @pytest.fixture def recording(): num_channels = 4 - num_samples = 10 + num_samples = 10000 rng = np.random.default_rng(seed=0) traces = rng.random(size=(num_samples, num_channels)) traces_list = [traces] @@ -258,5 +258,77 @@ def test_trace_padded_recording_retrieve_traces_with_partial_padding(recording, assert np.allclose(padded_traces_end, expected_zeros) +@pytest.mark.parametrize("padding_start, padding_end", [(5000, 5000), (5000, 0), (0, 5000)]) +def test_trace_padded_recording_retrieve_full_recording_with_preprocessing(recording, padding_start, padding_end): + num_samples = recording.get_num_samples() + num_channels = recording.get_num_channels() + + recording = bandpass_filter(recording, freq_min=300, freq_max=6000) + + padded_recording = TracePaddedRecording( + parent_recording=recording, + padding_start=padding_start, + padding_end=padding_end, + ) + + # Cycle through the whole recording, using get_traces() to pull chunks of + # size `step`. This emulates the processing of writing to a binary file. + # Data that lie within the padding region should be fill value only, while + # data from original trace should match exactly. Note that the step + # size must be chosen to retreieve data that is purely padding or original data + step = 1000 + start_frames = np.arange(padded_recording.get_num_samples(), step=step) + end_frames = start_frames + step + + for start_frame, end_frame in zip(start_frames, end_frames): + padded_trace = padded_recording.get_traces(start_frame=start_frame, end_frame=end_frame) + + end_padding_region_first_idx = padding_start + num_samples + + if padding_start <= start_frame < end_padding_region_first_idx: + original_trace = recording.get_traces( + start_frame=start_frame - padding_start, end_frame=end_frame - padding_start + ) + assert np.allclose(padded_trace, original_trace, rtol=0, atol=1e-10) + else: + assert np.all(padded_trace == padded_recording.fill_value) + + +@pytest.mark.parametrize("padding_start, padding_end", [(5000, 5000), (5000, 0), (0, 5000)]) +def test_trace_padded_recording_retrieve_full_recording_with_preprocessing(recording, padding_start, padding_end): + num_samples = recording.get_num_samples() + num_channels = recording.get_num_channels() + + recording = bandpass_filter(recording, freq_min=300, freq_max=6000) + + padded_recording = TracePaddedRecording( + parent_recording=recording, + padding_start=padding_start, + padding_end=padding_end, + ) + + # Cycle through the whole recording, using get_traces() to pull chunks of + # size `step`. This emulates the processing of writing to a binary file. + # Data that lie within the padding region should be fill value only, while + # data from original trace should match exactly. Note that the step + # size must be chosen to retreieve data that is purely padding or original data + step = 1000 + start_frames = np.arange(padded_recording.get_num_samples(), step=step) + end_frames = start_frames + step + + for start_frame, end_frame in zip(start_frames, end_frames): + padded_trace = padded_recording.get_traces(start_frame=start_frame, end_frame=end_frame) + + end_padding_region_first_idx = padding_start + num_samples + + if padding_start <= start_frame < end_padding_region_first_idx: + original_trace = recording.get_traces( + start_frame=start_frame - padding_start, end_frame=end_frame - padding_start + ) + assert np.allclose(padded_trace, original_trace, rtol=0, atol=1e-10) + else: + assert np.all(padded_trace == padded_recording.fill_value) + + if __name__ == "__main__": test_zero_padding_channel()