Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Sep 12, 2023
1 parent aeeb580 commit dce34a8
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 8 deletions.
13 changes: 7 additions & 6 deletions src/spikeinterface/preprocessing/tests/test_zero_padding.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,9 +257,9 @@ def test_trace_padded_recording_retrieve_traces_with_partial_padding(recording,
expected_zeros = np.zeros((number_of_paded_frames_at_end, num_channels))
assert np.allclose(padded_traces_end, expected_zeros)


@pytest.mark.parametrize("padding_start, padding_end", [(5000, 5000), (5000, 0), (0, 5000)])
def test_trace_padded_recording_retrieve_full_recording_with_preprocessing(recording, padding_start, padding_end):

num_samples = recording.get_num_samples()
num_channels = recording.get_num_channels()

Expand All @@ -281,21 +281,21 @@ def test_trace_padded_recording_retrieve_full_recording_with_preprocessing(recor
end_frames = start_frames + step

for start_frame, end_frame in zip(start_frames, end_frames):

padded_trace = padded_recording.get_traces(start_frame=start_frame, end_frame=end_frame)

end_padding_region_first_idx = padding_start + num_samples

if padding_start <= start_frame < end_padding_region_first_idx:
original_trace = recording.get_traces(start_frame=start_frame - padding_start, end_frame=end_frame - padding_start)
original_trace = recording.get_traces(
start_frame=start_frame - padding_start, end_frame=end_frame - padding_start
)
assert np.allclose(padded_trace, original_trace, rtol=0, atol=1e-10)
else:
assert np.all(padded_trace == padded_recording.fill_value)


@pytest.mark.parametrize("padding_start, padding_end", [(5000, 5000), (5000, 0), (0, 5000)])
def test_trace_padded_recording_retrieve_full_recording_with_preprocessing(recording, padding_start, padding_end):

num_samples = recording.get_num_samples()
num_channels = recording.get_num_channels()

Expand All @@ -317,13 +317,14 @@ def test_trace_padded_recording_retrieve_full_recording_with_preprocessing(recor
end_frames = start_frames + step

for start_frame, end_frame in zip(start_frames, end_frames):

padded_trace = padded_recording.get_traces(start_frame=start_frame, end_frame=end_frame)

end_padding_region_first_idx = padding_start + num_samples

if padding_start <= start_frame < end_padding_region_first_idx:
original_trace = recording.get_traces(start_frame=start_frame - padding_start, end_frame=end_frame - padding_start)
original_trace = recording.get_traces(
start_frame=start_frame - padding_start, end_frame=end_frame - padding_start
)
assert np.allclose(padded_trace, original_trace, rtol=0, atol=1e-10)
else:
assert np.all(padded_trace == padded_recording.fill_value)
Expand Down
11 changes: 9 additions & 2 deletions src/spikeinterface/preprocessing/zero_channel_pad.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,10 @@ def get_traces(self, start_frame, end_frame, channel_indices):
output_traces = np.full(shape=(trace_size, num_channels), fill_value=self.fill_value, dtype=self.dtype)

# If start and end frame are outside of the original data region (e.g. for Kilosort), return only paddding
if start_frame > self.num_samples_in_original_segment + self.padding_start and end_frame > self.num_samples_in_original_segment + self.padding_start:
if (
start_frame > self.num_samples_in_original_segment + self.padding_start
and end_frame > self.num_samples_in_original_segment + self.padding_start
):
return output_traces

# After the padding, the original traces are placed in the middle until the end of the original traces
Expand Down Expand Up @@ -124,7 +127,11 @@ def get_original_traces_shifted(self, start_frame, end_frame, channel_indices):
original_start_frame = max(start_frame - self.padding_start, 0)
original_end_frame = min(end_frame - self.padding_start, self.num_samples_in_original_segment)

original_traces = self.parent_recording_segment.get_traces(start_frame=original_start_frame, end_frame=original_end_frame, channel_indices=channel_indices,) # BREAKPOINT HERE!!
original_traces = self.parent_recording_segment.get_traces(
start_frame=original_start_frame,
end_frame=original_end_frame,
channel_indices=channel_indices,
) # BREAKPOINT HERE!!

return original_traces

Expand Down

0 comments on commit dce34a8

Please sign in to comment.