From 39078a1e11516c01634d317fa0f777c075feb6d4 Mon Sep 17 00:00:00 2001 From: luiz Date: Thu, 4 Jan 2024 15:49:16 +0100 Subject: [PATCH] types --- src/spikeinterface_pipelines/pipeline.py | 65 ++++++++++++------- .../preprocessing/preprocessing.py | 2 +- .../spikesorting/params.py | 3 +- .../spikesorting/spikesorting.py | 18 +++-- 4 files changed, 56 insertions(+), 32 deletions(-) diff --git a/src/spikeinterface_pipelines/pipeline.py b/src/spikeinterface_pipelines/pipeline.py index a24a0de..ecd2af3 100644 --- a/src/spikeinterface_pipelines/pipeline.py +++ b/src/spikeinterface_pipelines/pipeline.py @@ -1,5 +1,6 @@ +from __future__ import annotations from pathlib import Path -from typing import Tuple, Union +from typing import Tuple import spikeinterface as si @@ -12,14 +13,20 @@ def run_pipeline( recording: si.BaseRecording, - scratch_folder: Union[Path, str] = Path("./scratch/"), - results_folder: Union[Path, str] = Path("./results/"), - job_kwargs: Union[JobKwargs, dict] = JobKwargs(), - preprocessing_params: Union[PreprocessingParams, dict] = PreprocessingParams(), - spikesorting_params: Union[SpikeSortingParams, dict] = SpikeSortingParams(), - postprocessing_params: Union[PostprocessingParams, dict] = PostprocessingParams(), + scratch_folder: Path | str = Path("./scratch/"), + results_folder: Path | str = Path("./results/"), + job_kwargs: JobKwargs | dict = JobKwargs(), + preprocessing_params: PreprocessingParams | dict = PreprocessingParams(), + spikesorting_params: SpikeSortingParams | dict = SpikeSortingParams(), + postprocessing_params: PostprocessingParams | dict = PostprocessingParams(), run_preprocessing: bool = True, -) -> Tuple[si.BaseRecording, si.BaseSorting, si.WaveformExtractor]: + run_spikesorting: bool = True, + run_postprocessing: bool = True, +) -> Tuple[ + si.BaseRecording | None, + si.BaseSorting | None, + si.WaveformExtractor | None +]: # Create folders results_folder = Path(results_folder) scratch_folder = Path(scratch_folder) @@ -60,23 +67,33 @@ def run_pipeline( recording_preprocessed = recording # Spike Sorting - sorting = spikesort( - recording=recording_preprocessed, - scratch_folder=scratch_folder, - spikesorting_params=spikesorting_params, - results_folder=results_folder_spikesorting, - ) - if sorting is None: - raise Exception("Spike sorting failed") + if run_spikesorting: + sorting = spikesort( + recording=recording_preprocessed, + scratch_folder=scratch_folder, + spikesorting_params=spikesorting_params, + results_folder=results_folder_spikesorting, + ) + if sorting is None: + raise Exception("Spike sorting failed") - # Postprocessing - waveform_extractor = postprocess( - recording=recording_preprocessed, - sorting=sorting, - postprocessing_params=postprocessing_params, - scratch_folder=scratch_folder, - results_folder=results_folder_postprocessing, - ) + # Postprocessing + if run_postprocessing: + logger.info("Postprocessing sorting") + waveform_extractor = postprocess( + recording=recording_preprocessed, + sorting=sorting, + postprocessing_params=postprocessing_params, + scratch_folder=scratch_folder, + results_folder=results_folder_postprocessing, + ) + else: + logger.info("Skipping postprocessing") + waveform_extractor = None + else: + logger.info("Skipping spike sorting") + sorting = None + waveform_extractor = None # TODO: Curation diff --git a/src/spikeinterface_pipelines/preprocessing/preprocessing.py b/src/spikeinterface_pipelines/preprocessing/preprocessing.py index 8498e57..e65a808 100644 --- a/src/spikeinterface_pipelines/preprocessing/preprocessing.py +++ b/src/spikeinterface_pipelines/preprocessing/preprocessing.py @@ -72,7 +72,7 @@ def preprocess( f"[Preprocessing] \tMore than {max_bad_channel_fraction_to_remove * 100}% bad channels ({len(all_bad_channel_ids)}). " ) logger.info("[Preprocessing] \tSkipping further processing for this recording.") - return None + return recording_hp_full if preprocessing_params.remove_out_channels: logger.info(f"[Preprocessing] \tRemoving {len(out_channel_ids)} out channels") diff --git a/src/spikeinterface_pipelines/spikesorting/params.py b/src/spikeinterface_pipelines/spikesorting/params.py index ab5fd57..627687f 100644 --- a/src/spikeinterface_pipelines/spikesorting/params.py +++ b/src/spikeinterface_pipelines/spikesorting/params.py @@ -28,6 +28,7 @@ class Kilosort25Model(BaseModel): sig: float = Field(default=20, description="spatial smoothness constant for registration") freq_min: float = Field(default=150, description="High-pass filter cutoff frequency") sigmaMask: float = Field(default=30, description="Spatial constant in um for computing residual variance of spike") + lam: float = Field(default=10.0, description="The importance of the amplitude penalty (like in Kilosort1: 0 means not used, 10 is average, 50 is a lot)") nPCs: int = Field(default=3, description="Number of PCA dimensions") ntbuff: int = Field(default=64, description="Samples of symmetrical buffer for whitening and spike detection") nfilt_factor: int = Field(default=4, description="Max number of clusters per good channel (even temporary ones) 4") @@ -40,7 +41,7 @@ class Kilosort25Model(BaseModel): wave_length: float = Field( default=61, description="size of the waveform extracted around each detected peak, (Default 61, maximum 81)" ) - keep_good_only: bool = Field(default=True, description="If True only 'good' units are returned") + keep_good_only: bool = Field(default=False, description="If True only 'good' units are returned") skip_kilosort_preprocessing: bool = Field( default=False, description="Can optionaly skip the internal kilosort preprocessing" ) diff --git a/src/spikeinterface_pipelines/spikesorting/spikesorting.py b/src/spikeinterface_pipelines/spikesorting/spikesorting.py index 069bca3..d9243c1 100644 --- a/src/spikeinterface_pipelines/spikesorting/spikesorting.py +++ b/src/spikeinterface_pipelines/spikesorting/spikesorting.py @@ -1,8 +1,7 @@ +from __future__ import annotations from pathlib import Path import shutil -from typing import Union -import spikeinterface as si -import spikeinterface.sorters as ss +import spikeinterface.full as si import spikeinterface.curation as sc from ..logger import logger @@ -14,7 +13,7 @@ def spikesort( spikesorting_params: SpikeSortingParams = SpikeSortingParams(), scratch_folder: Path = Path("./scratch/"), results_folder: Path = Path("./results/spikesorting/"), -) -> Union[si.BaseSorting, None]: +) -> si.BaseSorting | None: """ Apply spike sorting to recording @@ -38,14 +37,21 @@ def spikesort( try: logger.info(f"[Spikesorting] \tStarting {spikesorting_params.sorter_name} spike sorter") - sorting = ss.run_sorter( + + ## TEST ONLY - REMOVE LATER ## + # si.get_default_sorter_params('kilosort2_5') + # params_kilosort2_5 = {'do_correction': False} + ## --------------------------## + + sorting = si.run_sorter( recording=recording, sorter_name=spikesorting_params.sorter_name, output_folder=str(output_folder), - verbose=False, + verbose=True, delete_output_folder=True, remove_existing_folder=True, **spikesorting_params.sorter_kwargs.model_dump(), + # **params_kilosort2_5 ) logger.info(f"[Spikesorting] \tFound {len(sorting.unit_ids)} raw units") # remove spikes beyond num_Samples (if any)