Skip to content

Commit

Permalink
Merge branch 'main' into add_errors_to_import_time
Browse files Browse the repository at this point in the history
  • Loading branch information
h-mayorquin authored Jun 1, 2024
2 parents c70fedb + dd8bac5 commit 8f673b5
Show file tree
Hide file tree
Showing 8 changed files with 229 additions and 183 deletions.
20 changes: 10 additions & 10 deletions src/spikeinterface/preprocessing/tests/test_zero_padding.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def test_trace_padded_recording_full_trace(recording, padding_start, padding_end
num_samples = recording.get_num_samples()

padded_recording = TracePaddedRecording(
parent_recording=recording,
recording=recording,
padding_start=padding_start,
padding_end=padding_end,
)
Expand Down Expand Up @@ -81,7 +81,7 @@ def test_trace_padded_recording_full_trace_with_channel_indices(recording, paddi
num_samples = recording.get_num_samples()

padded_recording = TracePaddedRecording(
parent_recording=recording,
recording=recording,
padding_start=padding_start,
padding_end=padding_end,
)
Expand Down Expand Up @@ -110,7 +110,7 @@ def test_trace_padded_recording_retrieve_original_trace(recording, padding_start
num_samples = recording.get_num_samples()

padded_recording = TracePaddedRecording(
parent_recording=recording,
recording=recording,
padding_start=padding_start,
padding_end=padding_end,
)
Expand All @@ -129,7 +129,7 @@ def test_trace_padded_recording_retrieve_partial_original_trace(recording, paddi
num_samples = recording.get_num_samples()

padded_recording = TracePaddedRecording(
parent_recording=recording,
recording=recording,
padding_start=padding_start,
padding_end=padding_end,
)
Expand All @@ -156,7 +156,7 @@ def test_trace_padded_recording_retrieve_start_padding_and_partial_original_trac
num_channels = recording.get_num_channels()

padded_recording = TracePaddedRecording(
parent_recording=recording,
recording=recording,
padding_start=padding_start,
padding_end=padding_end,
)
Expand Down Expand Up @@ -188,7 +188,7 @@ def test_trace_padded_recording_retrieve_end_padding_and_partial_original_trace(
num_channels = recording.get_num_channels()

padded_recording = TracePaddedRecording(
parent_recording=recording,
recording=recording,
padding_start=padding_start,
padding_end=padding_end,
)
Expand Down Expand Up @@ -222,7 +222,7 @@ def test_trace_padded_recording_retrieve_traces_with_partial_padding(recording,
num_channels = recording.get_num_channels()

padded_recording = TracePaddedRecording(
parent_recording=recording,
recording=recording,
padding_start=padding_start,
padding_end=padding_end,
)
Expand Down Expand Up @@ -264,7 +264,7 @@ def test_trace_padded_recording_retrieve_only_start_padding(recording, padding_s
num_channels = recording.get_num_channels()

padded_recording = TracePaddedRecording(
parent_recording=recording,
recording=recording,
padding_start=padding_start,
padding_end=padding_end,
)
Expand All @@ -281,7 +281,7 @@ def test_trace_padded_recording_retrieve_only_end_padding(recording, padding_sta
num_channels = recording.get_num_channels()

padded_recording = TracePaddedRecording(
parent_recording=recording,
recording=recording,
padding_start=padding_start,
padding_end=padding_end,
)
Expand Down Expand Up @@ -314,7 +314,7 @@ def test_trace_padded_recording_retrieve_only_end_padding_with_preprocessing(
recording = phase_shift(recording)

padded_recording = TracePaddedRecording(
parent_recording=recording,
recording=recording,
padding_start=padding_start,
padding_end=padding_end,
)
Expand Down
63 changes: 27 additions & 36 deletions src/spikeinterface/preprocessing/zero_channel_pad.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class TracePaddedRecording(BasePreprocessor):
Parameters
----------
parent_recording_segment : BaseRecording
recording_segment : BaseRecording
The parent recording segment from which the traces are to be retrieved.
padding_start : int, default: 0
The amount of padding to add to the left of the traces. It has to be non-negative.
Expand All @@ -29,19 +29,17 @@ class TracePaddedRecording(BasePreprocessor):
The value to pad with
"""

def __init__(
self, parent_recording: BaseRecording, padding_start: int = 0, padding_end: int = 0, fill_value: float = 0.0
):
def __init__(self, recording: BaseRecording, padding_start: int = 0, padding_end: int = 0, fill_value: float = 0.0):
assert padding_end >= 0 and padding_start >= 0, "Paddings must be >= 0"
super().__init__(recording=parent_recording)
super().__init__(recording=recording)

self.padding_start = padding_start
self.padding_end = padding_end
self.fill_value = fill_value
for segment in parent_recording._recording_segments:
for segment in recording._recording_segments:
recording_segment = TracePaddedRecordingSegment(
segment,
parent_recording.get_num_channels(),
recording.get_num_channels(),
self.dtype,
self.padding_start,
self.padding_end,
Expand All @@ -50,7 +48,7 @@ def __init__(
self.add_recording_segment(recording_segment)

self._kwargs = dict(
parent_recording=parent_recording,
parent_recording=recording,
padding_start=padding_start,
padding_end=padding_end,
fill_value=fill_value,
Expand All @@ -60,21 +58,21 @@ def __init__(
class TracePaddedRecordingSegment(BasePreprocessorSegment):
def __init__(
self,
parent_recording_segment: BaseRecordingSegment,
recording_segment: BaseRecordingSegment,
num_channels,
dtype,
paddign_left,
padding_left,
padding_end,
fill_value,
):
self.padding_start = paddign_left
self.padding_start = padding_left
self.padding_end = padding_end
self.fill_value = fill_value
self.num_channels = num_channels
self.num_samples_in_original_segment = parent_recording_segment.get_num_samples()
self.num_samples_in_original_segment = recording_segment.get_num_samples()
self.dtype = dtype

super().__init__(parent_recording_segment=parent_recording_segment)
super().__init__(parent_recording_segment=recording_segment)

def get_traces(self, start_frame, end_frame, channel_indices):
if start_frame is None:
Expand Down Expand Up @@ -146,12 +144,12 @@ class ZeroChannelPaddedRecording(BaseRecording):
name = "zero_channel_pad"
installed = True

def __init__(self, parent_recording: BaseRecording, num_channels: int, channel_mapping: Union[list, None] = None):
def __init__(self, recording: BaseRecording, num_channels: int, channel_mapping: Union[list, None] = None):
"""Pads a recording with channels that contain only zero.
Parameters
----------
parent_recording : BaseRecording
recording : BaseRecording
recording to zero-pad
num_channels : int
Total number of channels in the zero-channel-padded recording
Expand All @@ -160,51 +158,44 @@ def __init__(self, parent_recording: BaseRecording, num_channels: int, channel_m
If None, sorts the channel indices in ascending y channel location and puts them at the
beginning of the zero-channel-padded recording.
"""
BaseRecording.__init__(
self, parent_recording.get_sampling_frequency(), np.arange(num_channels), parent_recording.get_dtype()
)
BaseRecording.__init__(self, recording.get_sampling_frequency(), np.arange(num_channels), recording.get_dtype())

if channel_mapping is not None:
assert (
len(channel_mapping) == parent_recording.get_num_channels()
len(channel_mapping) == recording.get_num_channels()
), "The new mapping must be specified for all channels."
assert max(channel_mapping) < num_channels, (
"The new mapping cannot exceed total number of channels " "in the zero-chanenl-padded recording."
)
else:
if (
"locations" in parent_recording.get_property_keys()
or "contact_vector" in parent_recording.get_property_keys()
):
self.channel_mapping = np.argsort(parent_recording.get_channel_locations()[:, 1])
if "locations" in recording.get_property_keys() or "contact_vector" in recording.get_property_keys():
self.channel_mapping = np.argsort(recording.get_channel_locations()[:, 1])
else:
self.channel_mapping = np.arange(parent_recording.get_num_channels())
self.channel_mapping = np.arange(recording.get_num_channels())

self.parent_recording = parent_recording
self.parent_recording = recording
self.num_channels = num_channels
for segment in parent_recording._recording_segments:
for segment in recording._recording_segments:
recording_segment = ZeroChannelPaddedRecordingSegment(segment, self.num_channels, self.channel_mapping)
self.add_recording_segment(recording_segment)

# only copy relevant metadata and properties
parent_recording.copy_metadata(self, only_main=True)
self._parent = parent_recording
prop_keys = parent_recording.get_property_keys()
recording.copy_metadata(self, only_main=True)
self._parent = recording
prop_keys = recording.get_property_keys()

for k in prop_keys:
values = self.get_property(k)
if values is not None:
self.set_property(k, values, ids=self.channel_ids[self.channel_mapping])

self._kwargs = dict(
parent_recording=parent_recording, num_channels=num_channels, channel_mapping=channel_mapping
)
self._kwargs = dict(parent_recording=recording, num_channels=num_channels, channel_mapping=channel_mapping)


class ZeroChannelPaddedRecordingSegment(BasePreprocessorSegment):
def __init__(self, parent_recording_segment: BaseRecordingSegment, num_channels: int, channel_mapping: list):
BasePreprocessorSegment.__init__(self, parent_recording_segment)
self.parent_recording_segment = parent_recording_segment
def __init__(self, recording_segment: BaseRecordingSegment, num_channels: int, channel_mapping: list):
BasePreprocessorSegment.__init__(self, recording_segment)
self.parent_recording_segment = recording_segment
self.num_channels = num_channels
self.channel_mapping = channel_mapping

Expand Down
Loading

0 comments on commit 8f673b5

Please sign in to comment.