diff --git a/src/spikeinterface_pipelines/pipeline.py b/src/spikeinterface_pipelines/pipeline.py
index b4ac6dd..ecd2af3 100644
--- a/src/spikeinterface_pipelines/pipeline.py
+++ b/src/spikeinterface_pipelines/pipeline.py
@@ -1,3 +1,4 @@
+from __future__ import annotations
 from pathlib import Path
 from typing import Tuple
 
@@ -10,18 +11,25 @@
 from .postprocessing import postprocess, PostprocessingParams
 
 
-# TODO - WIP
 def run_pipeline(
     recording: si.BaseRecording,
-    scratch_folder: Path = Path("./scratch/"),
-    results_folder: Path = Path("./results/"),
-    job_kwargs: JobKwargs = JobKwargs(),
-    preprocessing_params: PreprocessingParams = PreprocessingParams(),
-    spikesorting_params: SpikeSortingParams = SpikeSortingParams(),
-    postprocessing_params: PostprocessingParams = PostprocessingParams(),
+    scratch_folder: Path | str = Path("./scratch/"),
+    results_folder: Path | str = Path("./results/"),
+    job_kwargs: JobKwargs | dict = JobKwargs(),
+    preprocessing_params: PreprocessingParams | dict = PreprocessingParams(),
+    spikesorting_params: SpikeSortingParams | dict = SpikeSortingParams(),
+    postprocessing_params: PostprocessingParams | dict = PostprocessingParams(),
     run_preprocessing: bool = True,
-) -> Tuple[si.BaseRecording, si.BaseSorting, si.WaveformExtractor]:
+    run_spikesorting: bool = True,
+    run_postprocessing: bool = True,
+) -> Tuple[
+    si.BaseRecording | None,
+    si.BaseSorting | None,
+    si.WaveformExtractor | None
+]:
     # Create folders
+    results_folder = Path(results_folder)
+    scratch_folder = Path(scratch_folder)
+
     scratch_folder.mkdir(exist_ok=True, parents=True)
     results_folder.mkdir(exist_ok=True, parents=True)
@@ -30,6 +38,16 @@ def run_pipeline(
     results_folder_spikesorting = results_folder / "spikesorting"
     results_folder_postprocessing = results_folder / "postprocessing"
 
+    # Arguments Models validation, in case of dict
+    if isinstance(job_kwargs, dict):
+        job_kwargs = JobKwargs(**job_kwargs)
+    if isinstance(preprocessing_params, dict):
+        preprocessing_params = PreprocessingParams(**preprocessing_params)
+    if isinstance(spikesorting_params, dict):
+        spikesorting_params = SpikeSortingParams(**spikesorting_params)
+    if isinstance(postprocessing_params, dict):
+        postprocessing_params = PostprocessingParams(**postprocessing_params)
+
     # set global job kwargs
     si.set_global_job_kwargs(**job_kwargs.model_dump())
 
@@ -49,23 +67,33 @@ def run_pipeline(
         recording_preprocessed = recording
 
     # Spike Sorting
-    sorting = spikesort(
-        recording=recording_preprocessed,
-        scratch_folder=scratch_folder,
-        spikesorting_params=spikesorting_params,
-        results_folder=results_folder_spikesorting,
-    )
-    if sorting is None:
-        raise Exception("Spike sorting failed")
+    if run_spikesorting:
+        sorting = spikesort(
+            recording=recording_preprocessed,
+            scratch_folder=scratch_folder,
+            spikesorting_params=spikesorting_params,
+            results_folder=results_folder_spikesorting,
+        )
+        if sorting is None:
+            raise Exception("Spike sorting failed")
 
-    # Postprocessing
-    waveform_extractor = postprocess(
-        recording=recording_preprocessed,
-        sorting=sorting,
-        postprocessing_params=postprocessing_params,
-        scratch_folder=scratch_folder,
-        results_folder=results_folder_postprocessing,
-    )
+        # Postprocessing
+        if run_postprocessing:
+            logger.info("Postprocessing sorting")
+            waveform_extractor = postprocess(
+                recording=recording_preprocessed,
+                sorting=sorting,
+                postprocessing_params=postprocessing_params,
+                scratch_folder=scratch_folder,
+                results_folder=results_folder_postprocessing,
+            )
+        else:
+            logger.info("Skipping postprocessing")
+            waveform_extractor = None
+    else:
+        logger.info("Skipping spike sorting")
+        sorting = None
+        waveform_extractor = None
 
     # TODO: Curation
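Net effect for callers: run_pipeline now tolerates plain strings and dicts in place of Path and pydantic-model arguments, and each stage can be switched off independently, with skipped stages returned as None. A minimal usage sketch of the new signature (the toy recording and the JobKwargs field name are illustrative assumptions, not part of this patch):

    import spikeinterface.full as si
    from spikeinterface_pipelines.pipeline import run_pipeline  # import path assumed from this diff

    # Toy recording; any si.BaseRecording works here.
    recording = si.generate_recording(num_channels=4, durations=[10.0])

    # Strings and plain dicts are accepted now: they are coerced to Path and
    # validated into the pydantic models at the top of run_pipeline.
    recording_preprocessed, sorting, waveform_extractor = run_pipeline(
        recording=recording,
        scratch_folder="./scratch/",
        results_folder="./results/",
        job_kwargs={"n_jobs": 4},   # hypothetical JobKwargs field
        run_spikesorting=False,     # skips sorting and, with it, postprocessing
    )

    # Skipped stages come back as None, matching the widened return type.
    assert sorting is None and waveform_extractor is None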
postprocessing") + waveform_extractor = None + else: + logger.info("Skipping spike sorting") + sorting = None + waveform_extractor = None # TODO: Curation diff --git a/src/spikeinterface_pipelines/preprocessing/preprocessing.py b/src/spikeinterface_pipelines/preprocessing/preprocessing.py index ba05529..e65a808 100644 --- a/src/spikeinterface_pipelines/preprocessing/preprocessing.py +++ b/src/spikeinterface_pipelines/preprocessing/preprocessing.py @@ -43,10 +43,10 @@ def preprocess( # Phase shift correction if "inter_sample_shift" in recording.get_property_keys(): - logger.info(f"[Preprocessing] \tPhase shift") + logger.info("[Preprocessing] \tPhase shift") recording = spre.phase_shift(recording, **preprocessing_params.phase_shift.model_dump()) else: - logger.info(f"[Preprocessing] \tSkipping phase shift: 'inter_sample_shift' property not found") + logger.info("[Preprocessing] \tSkipping phase shift: 'inter_sample_shift' property not found") # Highpass filter recording_hp_full = spre.highpass_filter(recording, **preprocessing_params.highpass_filter.model_dump()) @@ -72,7 +72,7 @@ def preprocess( f"[Preprocessing] \tMore than {max_bad_channel_fraction_to_remove * 100}% bad channels ({len(all_bad_channel_ids)}). " ) logger.info("[Preprocessing] \tSkipping further processing for this recording.") - return None + return recording_hp_full if preprocessing_params.remove_out_channels: logger.info(f"[Preprocessing] \tRemoving {len(out_channel_ids)} out channels") diff --git a/src/spikeinterface_pipelines/spikesorting/params.py b/src/spikeinterface_pipelines/spikesorting/params.py index ab5fd57..627687f 100644 --- a/src/spikeinterface_pipelines/spikesorting/params.py +++ b/src/spikeinterface_pipelines/spikesorting/params.py @@ -28,6 +28,7 @@ class Kilosort25Model(BaseModel): sig: float = Field(default=20, description="spatial smoothness constant for registration") freq_min: float = Field(default=150, description="High-pass filter cutoff frequency") sigmaMask: float = Field(default=30, description="Spatial constant in um for computing residual variance of spike") + lam: float = Field(default=10.0, description="The importance of the amplitude penalty (like in Kilosort1: 0 means not used, 10 is average, 50 is a lot)") nPCs: int = Field(default=3, description="Number of PCA dimensions") ntbuff: int = Field(default=64, description="Samples of symmetrical buffer for whitening and spike detection") nfilt_factor: int = Field(default=4, description="Max number of clusters per good channel (even temporary ones) 4") @@ -40,7 +41,7 @@ class Kilosort25Model(BaseModel): wave_length: float = Field( default=61, description="size of the waveform extracted around each detected peak, (Default 61, maximum 81)" ) - keep_good_only: bool = Field(default=True, description="If True only 'good' units are returned") + keep_good_only: bool = Field(default=False, description="If True only 'good' units are returned") skip_kilosort_preprocessing: bool = Field( default=False, description="Can optionaly skip the internal kilosort preprocessing" ) diff --git a/src/spikeinterface_pipelines/spikesorting/spikesorting.py b/src/spikeinterface_pipelines/spikesorting/spikesorting.py index 831e846..d9243c1 100644 --- a/src/spikeinterface_pipelines/spikesorting/spikesorting.py +++ b/src/spikeinterface_pipelines/spikesorting/spikesorting.py @@ -1,8 +1,7 @@ +from __future__ import annotations from pathlib import Path import shutil - -import spikeinterface as si -import spikeinterface.sorters as ss +import spikeinterface.full as si import 
diff --git a/src/spikeinterface_pipelines/spikesorting/spikesorting.py b/src/spikeinterface_pipelines/spikesorting/spikesorting.py
index 831e846..d9243c1 100644
--- a/src/spikeinterface_pipelines/spikesorting/spikesorting.py
+++ b/src/spikeinterface_pipelines/spikesorting/spikesorting.py
@@ -1,8 +1,7 @@
+from __future__ import annotations
 from pathlib import Path
 import shutil
-
-import spikeinterface as si
-import spikeinterface.sorters as ss
+import spikeinterface.full as si
 import spikeinterface.curation as sc
 
 from ..logger import logger
@@ -38,14 +37,21 @@ def spikesort(
 
     try:
         logger.info(f"[Spikesorting] \tStarting {spikesorting_params.sorter_name} spike sorter")
-        sorting = ss.run_sorter(
+
+        ## TEST ONLY - REMOVE LATER ##
+        # si.get_default_sorter_params('kilosort2_5')
+        # params_kilosort2_5 = {'do_correction': False}
+        ## --------------------------##
+
+        sorting = si.run_sorter(
             recording=recording,
             sorter_name=spikesorting_params.sorter_name,
             output_folder=str(output_folder),
-            verbose=False,
+            verbose=True,
             delete_output_folder=True,
             remove_existing_folder=True,
             **spikesorting_params.sorter_kwargs.model_dump(),
+            # **params_kilosort2_5
         )
         logger.info(f"[Spikesorting] \tFound {len(sorting.unit_ids)} raw units")
         # remove spikes beyond num_Samples (if any)
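Since spikeinterface.full re-exports the sorters API, si.run_sorter and si.get_default_sorter_params take over from the old ss.* calls, and the leftover TEST ONLY block hints at how ad-hoc overrides would be merged in. A standalone sketch of that pattern, not committed behavior (kilosort2_5 must actually be installed for the run_sorter call to succeed; do_correction is the option named in the commented-out test code, and output_folder matches the argument name used in this diff):

    import spikeinterface.full as si

    # Inspect the defaults the Kilosort25Model mirrors
    # (this call works without Kilosort installed).
    print(si.get_default_sorter_params("kilosort2_5"))

    # Merge extra kwargs on top of the model-derived ones before run_sorter.
    recording = si.generate_recording(num_channels=4, durations=[10.0])
    params_kilosort2_5 = {"do_correction": False}
    sorting = si.run_sorter(
        recording=recording,
        sorter_name="kilosort2_5",
        output_folder="./scratch/tmp_ks25",
        verbose=True,
        **params_kilosort2_5,
    )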