Skip to content

Commit

Permalink
add joe tests
Browse files Browse the repository at this point in the history
  • Loading branch information
h-mayorquin committed Sep 12, 2023
1 parent 0a6ea00 commit b69a172
Showing 1 changed file with 74 additions and 2 deletions.
76 changes: 74 additions & 2 deletions src/spikeinterface/preprocessing/tests/test_zero_padding.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from spikeinterface.core import generate_recording
from spikeinterface.core.numpyextractors import NumpyRecording

from spikeinterface.preprocessing import zero_channel_pad
from spikeinterface.preprocessing import zero_channel_pad, bandpass_filter, common_reference
from spikeinterface.preprocessing.zero_channel_pad import TracePaddedRecording

if hasattr(pytest, "global_test_folder"):
Expand Down Expand Up @@ -39,7 +39,7 @@ def test_zero_padding_channel():
@pytest.fixture
def recording():
num_channels = 4
num_samples = 10
num_samples = 10000
rng = np.random.default_rng(seed=0)
traces = rng.random(size=(num_samples, num_channels))
traces_list = [traces]
Expand Down Expand Up @@ -258,5 +258,77 @@ def test_trace_padded_recording_retrieve_traces_with_partial_padding(recording,
assert np.allclose(padded_traces_end, expected_zeros)


@pytest.mark.parametrize("padding_start, padding_end", [(5000, 5000), (5000, 0), (0, 5000)])
def test_trace_padded_recording_retrieve_full_recording_with_preprocessing(recording, padding_start, padding_end):
num_samples = recording.get_num_samples()
num_channels = recording.get_num_channels()

recording = bandpass_filter(recording, freq_min=300, freq_max=6000)

padded_recording = TracePaddedRecording(
parent_recording=recording,
padding_start=padding_start,
padding_end=padding_end,
)

# Cycle through the whole recording, using get_traces() to pull chunks of
# size `step`. This emulates the processing of writing to a binary file.
# Data that lie within the padding region should be fill value only, while
# data from original trace should match exactly. Note that the step
# size must be chosen to retreieve data that is purely padding or original data
step = 1000
start_frames = np.arange(padded_recording.get_num_samples(), step=step)
end_frames = start_frames + step

for start_frame, end_frame in zip(start_frames, end_frames):
padded_trace = padded_recording.get_traces(start_frame=start_frame, end_frame=end_frame)

end_padding_region_first_idx = padding_start + num_samples

if padding_start <= start_frame < end_padding_region_first_idx:
original_trace = recording.get_traces(
start_frame=start_frame - padding_start, end_frame=end_frame - padding_start
)
assert np.allclose(padded_trace, original_trace, rtol=0, atol=1e-10)
else:
assert np.all(padded_trace == padded_recording.fill_value)


@pytest.mark.parametrize("padding_start, padding_end", [(5000, 5000), (5000, 0), (0, 5000)])
def test_trace_padded_recording_retrieve_full_recording_with_preprocessing(recording, padding_start, padding_end):
num_samples = recording.get_num_samples()
num_channels = recording.get_num_channels()

recording = bandpass_filter(recording, freq_min=300, freq_max=6000)

padded_recording = TracePaddedRecording(
parent_recording=recording,
padding_start=padding_start,
padding_end=padding_end,
)

# Cycle through the whole recording, using get_traces() to pull chunks of
# size `step`. This emulates the processing of writing to a binary file.
# Data that lie within the padding region should be fill value only, while
# data from original trace should match exactly. Note that the step
# size must be chosen to retreieve data that is purely padding or original data
step = 1000
start_frames = np.arange(padded_recording.get_num_samples(), step=step)
end_frames = start_frames + step

for start_frame, end_frame in zip(start_frames, end_frames):
padded_trace = padded_recording.get_traces(start_frame=start_frame, end_frame=end_frame)

end_padding_region_first_idx = padding_start + num_samples

if padding_start <= start_frame < end_padding_region_first_idx:
original_trace = recording.get_traces(
start_frame=start_frame - padding_start, end_frame=end_frame - padding_start
)
assert np.allclose(padded_trace, original_trace, rtol=0, atol=1e-10)
else:
assert np.all(padded_trace == padded_recording.fill_value)


if __name__ == "__main__":
test_zero_padding_channel()

0 comments on commit b69a172

Please sign in to comment.