diff --git a/src/spikeinterface/preprocessing/silence_periods.py b/src/spikeinterface/preprocessing/silence_periods.py index 4299d199ed..9cfefa9618 100644 --- a/src/spikeinterface/preprocessing/silence_periods.py +++ b/src/spikeinterface/preprocessing/silence_periods.py @@ -20,6 +20,8 @@ class SilencedPeriodsRecording(BasePreprocessor): The recording extractor to silance periods list_periods: list of lists/arrays One list per segment of tuples (start_frame, end_frame) to silence + noise_levels: array + Noise levels if already computed mode: "zeros" | "noise, default: "zeros" Determines what periods are replaced by. Can be one of the following: @@ -39,9 +41,7 @@ class SilencedPeriodsRecording(BasePreprocessor): name = "silence_periods" - def __init__(self, recording, list_periods, mode="zeros", **random_chunk_kwargs): - import scipy.interpolate - + def __init__(self, recording, list_periods, mode="zeros", noise_levels=None, **random_chunk_kwargs): available_modes = ("zeros", "noise") num_seg = recording.get_num_segments() @@ -67,7 +67,10 @@ def __init__(self, recording, list_periods, mode="zeros", **random_chunk_kwargs) ), "Intervals should not overlap" if mode in ["noise"]: - noise_levels = get_noise_levels(recording, return_scaled=False, concatenated=True, **random_chunk_kwargs) + if noise_levels is None: + noise_levels = get_noise_levels( + recording, return_scaled=False, concatenated=True, **random_chunk_kwargs + ) else: noise_levels = None @@ -79,9 +82,7 @@ def __init__(self, recording, list_periods, mode="zeros", **random_chunk_kwargs) rec_segment = SilencedPeriodsRecordingSegment(parent_segment, periods, mode, noise_levels) self.add_recording_segment(rec_segment) - self._kwargs = dict( - recording=recording.to_dict(), list_periods=list_periods, mode=mode, noise_levels=noise_levels - ) + self._kwargs = dict(recording=recording, list_periods=list_periods, mode=mode, noise_levels=noise_levels) class SilencedPeriodsRecordingSegment(BasePreprocessorSegment):