From dce34a8f48592fdf0085e3998d22e2506f62386e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 12 Sep 2023 16:18:46 +0000 Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../preprocessing/tests/test_zero_padding.py | 13 +++++++------ .../preprocessing/zero_channel_pad.py | 11 +++++++++-- 2 files changed, 16 insertions(+), 8 deletions(-) diff --git a/src/spikeinterface/preprocessing/tests/test_zero_padding.py b/src/spikeinterface/preprocessing/tests/test_zero_padding.py index 7bb8ae8aa6..a2fc890c74 100644 --- a/src/spikeinterface/preprocessing/tests/test_zero_padding.py +++ b/src/spikeinterface/preprocessing/tests/test_zero_padding.py @@ -257,9 +257,9 @@ def test_trace_padded_recording_retrieve_traces_with_partial_padding(recording, expected_zeros = np.zeros((number_of_paded_frames_at_end, num_channels)) assert np.allclose(padded_traces_end, expected_zeros) + @pytest.mark.parametrize("padding_start, padding_end", [(5000, 5000), (5000, 0), (0, 5000)]) def test_trace_padded_recording_retrieve_full_recording_with_preprocessing(recording, padding_start, padding_end): - num_samples = recording.get_num_samples() num_channels = recording.get_num_channels() @@ -281,13 +281,14 @@ def test_trace_padded_recording_retrieve_full_recording_with_preprocessing(recor end_frames = start_frames + step for start_frame, end_frame in zip(start_frames, end_frames): - padded_trace = padded_recording.get_traces(start_frame=start_frame, end_frame=end_frame) end_padding_region_first_idx = padding_start + num_samples if padding_start <= start_frame < end_padding_region_first_idx: - original_trace = recording.get_traces(start_frame=start_frame - padding_start, end_frame=end_frame - padding_start) + original_trace = recording.get_traces( + start_frame=start_frame - padding_start, end_frame=end_frame - padding_start + ) assert np.allclose(padded_trace, original_trace, rtol=0, atol=1e-10) else: assert np.all(padded_trace == padded_recording.fill_value) @@ -295,7 +296,6 @@ def test_trace_padded_recording_retrieve_full_recording_with_preprocessing(recor @pytest.mark.parametrize("padding_start, padding_end", [(5000, 5000), (5000, 0), (0, 5000)]) def test_trace_padded_recording_retrieve_full_recording_with_preprocessing(recording, padding_start, padding_end): - num_samples = recording.get_num_samples() num_channels = recording.get_num_channels() @@ -317,13 +317,14 @@ def test_trace_padded_recording_retrieve_full_recording_with_preprocessing(recor end_frames = start_frames + step for start_frame, end_frame in zip(start_frames, end_frames): - padded_trace = padded_recording.get_traces(start_frame=start_frame, end_frame=end_frame) end_padding_region_first_idx = padding_start + num_samples if padding_start <= start_frame < end_padding_region_first_idx: - original_trace = recording.get_traces(start_frame=start_frame - padding_start, end_frame=end_frame - padding_start) + original_trace = recording.get_traces( + start_frame=start_frame - padding_start, end_frame=end_frame - padding_start + ) assert np.allclose(padded_trace, original_trace, rtol=0, atol=1e-10) else: assert np.all(padded_trace == padded_recording.fill_value) diff --git a/src/spikeinterface/preprocessing/zero_channel_pad.py b/src/spikeinterface/preprocessing/zero_channel_pad.py index ae5f3c9ab0..eaac91bb18 100644 --- a/src/spikeinterface/preprocessing/zero_channel_pad.py +++ b/src/spikeinterface/preprocessing/zero_channel_pad.py @@ -95,7 +95,10 @@ def get_traces(self, start_frame, end_frame, channel_indices): output_traces = np.full(shape=(trace_size, num_channels), fill_value=self.fill_value, dtype=self.dtype) # If start and end frame are outside of the original data region (e.g. for Kilosort), return only paddding - if start_frame > self.num_samples_in_original_segment + self.padding_start and end_frame > self.num_samples_in_original_segment + self.padding_start: + if ( + start_frame > self.num_samples_in_original_segment + self.padding_start + and end_frame > self.num_samples_in_original_segment + self.padding_start + ): return output_traces # After the padding, the original traces are placed in the middle until the end of the original traces @@ -124,7 +127,11 @@ def get_original_traces_shifted(self, start_frame, end_frame, channel_indices): original_start_frame = max(start_frame - self.padding_start, 0) original_end_frame = min(end_frame - self.padding_start, self.num_samples_in_original_segment) - original_traces = self.parent_recording_segment.get_traces(start_frame=original_start_frame, end_frame=original_end_frame, channel_indices=channel_indices,) # BREAKPOINT HERE!! + original_traces = self.parent_recording_segment.get_traces( + start_frame=original_start_frame, + end_frame=original_end_frame, + channel_indices=channel_indices, + ) # BREAKPOINT HERE!! return original_traces