From 581d8970e6eafb483ce450777cd0db3ffb63d4fd Mon Sep 17 00:00:00 2001 From: luiz Date: Wed, 20 Dec 2023 12:27:20 +0100 Subject: [PATCH 1/4] dev --- src/spikeinterface_pipelines/pipeline.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/spikeinterface_pipelines/pipeline.py b/src/spikeinterface_pipelines/pipeline.py index b4ac6dd..a372f6a 100644 --- a/src/spikeinterface_pipelines/pipeline.py +++ b/src/spikeinterface_pipelines/pipeline.py @@ -10,7 +10,6 @@ from .postprocessing import postprocess, PostprocessingParams -# TODO - WIP def run_pipeline( recording: si.BaseRecording, scratch_folder: Path = Path("./scratch/"), From 7a0d12d5b0b72272147750eec45c826397b47f3f Mon Sep 17 00:00:00 2001 From: luiz Date: Thu, 21 Dec 2023 10:56:35 +0100 Subject: [PATCH 2/4] update models --- src/spikeinterface_pipelines/pipeline.py | 24 ++++++++++++++++++------ 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface_pipelines/pipeline.py b/src/spikeinterface_pipelines/pipeline.py index a372f6a..1e5d08f 100644 --- a/src/spikeinterface_pipelines/pipeline.py +++ b/src/spikeinterface_pipelines/pipeline.py @@ -12,15 +12,17 @@ def run_pipeline( recording: si.BaseRecording, - scratch_folder: Path = Path("./scratch/"), - results_folder: Path = Path("./results/"), - job_kwargs: JobKwargs = JobKwargs(), - preprocessing_params: PreprocessingParams = PreprocessingParams(), - spikesorting_params: SpikeSortingParams = SpikeSortingParams(), - postprocessing_params: PostprocessingParams = PostprocessingParams(), + scratch_folder: Path | str = Path("./scratch/"), + results_folder: Path | str = Path("./results/"), + job_kwargs: JobKwargs | dict = JobKwargs(), + preprocessing_params: PreprocessingParams | dict = PreprocessingParams(), + spikesorting_params: SpikeSortingParams | dict = SpikeSortingParams(), + postprocessing_params: PostprocessingParams | dict = PostprocessingParams(), run_preprocessing: bool = True, ) -> Tuple[si.BaseRecording, si.BaseSorting, si.WaveformExtractor]: # Create folders + results_folder = Path(results_folder) + scratch_folder = Path(scratch_folder) scratch_folder.mkdir(exist_ok=True, parents=True) results_folder.mkdir(exist_ok=True, parents=True) @@ -29,6 +31,16 @@ def run_pipeline( results_folder_spikesorting = results_folder / "spikesorting" results_folder_postprocessing = results_folder / "postprocessing" + # Arguments Models validation, in case of dict + if isinstance(job_kwargs, dict): + job_kwargs = JobKwargs(**job_kwargs) + if isinstance(preprocessing_params, dict): + preprocessing_params = PreprocessingParams(**preprocessing_params) + if isinstance(spikesorting_params, dict): + spikesorting_params = SpikeSortingParams(**spikesorting_params) + if isinstance(postprocessing_params, dict): + postprocessing_params = PostprocessingParams(**postprocessing_params) + # set global job kwargs si.set_global_job_kwargs(**job_kwargs.model_dump()) From 9b8aa78e96b856bd4a6885b2afa690e5dfd2cc88 Mon Sep 17 00:00:00 2001 From: luiz Date: Fri, 22 Dec 2023 15:40:17 +0100 Subject: [PATCH 3/4] change to union type --- src/spikeinterface_pipelines/pipeline.py | 14 +++++++------- .../preprocessing/preprocessing.py | 4 ++-- .../spikesorting/spikesorting.py | 4 ++-- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/src/spikeinterface_pipelines/pipeline.py b/src/spikeinterface_pipelines/pipeline.py index 1e5d08f..a24a0de 100644 --- a/src/spikeinterface_pipelines/pipeline.py +++ b/src/spikeinterface_pipelines/pipeline.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Tuple +from typing import Tuple, Union import spikeinterface as si @@ -12,12 +12,12 @@ def run_pipeline( recording: si.BaseRecording, - scratch_folder: Path | str = Path("./scratch/"), - results_folder: Path | str = Path("./results/"), - job_kwargs: JobKwargs | dict = JobKwargs(), - preprocessing_params: PreprocessingParams | dict = PreprocessingParams(), - spikesorting_params: SpikeSortingParams | dict = SpikeSortingParams(), - postprocessing_params: PostprocessingParams | dict = PostprocessingParams(), + scratch_folder: Union[Path, str] = Path("./scratch/"), + results_folder: Union[Path, str] = Path("./results/"), + job_kwargs: Union[JobKwargs, dict] = JobKwargs(), + preprocessing_params: Union[PreprocessingParams, dict] = PreprocessingParams(), + spikesorting_params: Union[SpikeSortingParams, dict] = SpikeSortingParams(), + postprocessing_params: Union[PostprocessingParams, dict] = PostprocessingParams(), run_preprocessing: bool = True, ) -> Tuple[si.BaseRecording, si.BaseSorting, si.WaveformExtractor]: # Create folders diff --git a/src/spikeinterface_pipelines/preprocessing/preprocessing.py b/src/spikeinterface_pipelines/preprocessing/preprocessing.py index ba05529..8498e57 100644 --- a/src/spikeinterface_pipelines/preprocessing/preprocessing.py +++ b/src/spikeinterface_pipelines/preprocessing/preprocessing.py @@ -43,10 +43,10 @@ def preprocess( # Phase shift correction if "inter_sample_shift" in recording.get_property_keys(): - logger.info(f"[Preprocessing] \tPhase shift") + logger.info("[Preprocessing] \tPhase shift") recording = spre.phase_shift(recording, **preprocessing_params.phase_shift.model_dump()) else: - logger.info(f"[Preprocessing] \tSkipping phase shift: 'inter_sample_shift' property not found") + logger.info("[Preprocessing] \tSkipping phase shift: 'inter_sample_shift' property not found") # Highpass filter recording_hp_full = spre.highpass_filter(recording, **preprocessing_params.highpass_filter.model_dump()) diff --git a/src/spikeinterface_pipelines/spikesorting/spikesorting.py b/src/spikeinterface_pipelines/spikesorting/spikesorting.py index 831e846..069bca3 100644 --- a/src/spikeinterface_pipelines/spikesorting/spikesorting.py +++ b/src/spikeinterface_pipelines/spikesorting/spikesorting.py @@ -1,6 +1,6 @@ from pathlib import Path import shutil - +from typing import Union import spikeinterface as si import spikeinterface.sorters as ss import spikeinterface.curation as sc @@ -14,7 +14,7 @@ def spikesort( spikesorting_params: SpikeSortingParams = SpikeSortingParams(), scratch_folder: Path = Path("./scratch/"), results_folder: Path = Path("./results/spikesorting/"), -) -> si.BaseSorting | None: +) -> Union[si.BaseSorting, None]: """ Apply spike sorting to recording From 39078a1e11516c01634d317fa0f777c075feb6d4 Mon Sep 17 00:00:00 2001 From: luiz Date: Thu, 4 Jan 2024 15:49:16 +0100 Subject: [PATCH 4/4] types --- src/spikeinterface_pipelines/pipeline.py | 65 ++++++++++++------- .../preprocessing/preprocessing.py | 2 +- .../spikesorting/params.py | 3 +- .../spikesorting/spikesorting.py | 18 +++-- 4 files changed, 56 insertions(+), 32 deletions(-) diff --git a/src/spikeinterface_pipelines/pipeline.py b/src/spikeinterface_pipelines/pipeline.py index a24a0de..ecd2af3 100644 --- a/src/spikeinterface_pipelines/pipeline.py +++ b/src/spikeinterface_pipelines/pipeline.py @@ -1,5 +1,6 @@ +from __future__ import annotations from pathlib import Path -from typing import Tuple, Union +from typing import Tuple import spikeinterface as si @@ -12,14 +13,20 @@ def run_pipeline( recording: si.BaseRecording, - scratch_folder: Union[Path, str] = Path("./scratch/"), - results_folder: Union[Path, str] = Path("./results/"), - job_kwargs: Union[JobKwargs, dict] = JobKwargs(), - preprocessing_params: Union[PreprocessingParams, dict] = PreprocessingParams(), - spikesorting_params: Union[SpikeSortingParams, dict] = SpikeSortingParams(), - postprocessing_params: Union[PostprocessingParams, dict] = PostprocessingParams(), + scratch_folder: Path | str = Path("./scratch/"), + results_folder: Path | str = Path("./results/"), + job_kwargs: JobKwargs | dict = JobKwargs(), + preprocessing_params: PreprocessingParams | dict = PreprocessingParams(), + spikesorting_params: SpikeSortingParams | dict = SpikeSortingParams(), + postprocessing_params: PostprocessingParams | dict = PostprocessingParams(), run_preprocessing: bool = True, -) -> Tuple[si.BaseRecording, si.BaseSorting, si.WaveformExtractor]: + run_spikesorting: bool = True, + run_postprocessing: bool = True, +) -> Tuple[ + si.BaseRecording | None, + si.BaseSorting | None, + si.WaveformExtractor | None +]: # Create folders results_folder = Path(results_folder) scratch_folder = Path(scratch_folder) @@ -60,23 +67,33 @@ def run_pipeline( recording_preprocessed = recording # Spike Sorting - sorting = spikesort( - recording=recording_preprocessed, - scratch_folder=scratch_folder, - spikesorting_params=spikesorting_params, - results_folder=results_folder_spikesorting, - ) - if sorting is None: - raise Exception("Spike sorting failed") + if run_spikesorting: + sorting = spikesort( + recording=recording_preprocessed, + scratch_folder=scratch_folder, + spikesorting_params=spikesorting_params, + results_folder=results_folder_spikesorting, + ) + if sorting is None: + raise Exception("Spike sorting failed") - # Postprocessing - waveform_extractor = postprocess( - recording=recording_preprocessed, - sorting=sorting, - postprocessing_params=postprocessing_params, - scratch_folder=scratch_folder, - results_folder=results_folder_postprocessing, - ) + # Postprocessing + if run_postprocessing: + logger.info("Postprocessing sorting") + waveform_extractor = postprocess( + recording=recording_preprocessed, + sorting=sorting, + postprocessing_params=postprocessing_params, + scratch_folder=scratch_folder, + results_folder=results_folder_postprocessing, + ) + else: + logger.info("Skipping postprocessing") + waveform_extractor = None + else: + logger.info("Skipping spike sorting") + sorting = None + waveform_extractor = None # TODO: Curation diff --git a/src/spikeinterface_pipelines/preprocessing/preprocessing.py b/src/spikeinterface_pipelines/preprocessing/preprocessing.py index 8498e57..e65a808 100644 --- a/src/spikeinterface_pipelines/preprocessing/preprocessing.py +++ b/src/spikeinterface_pipelines/preprocessing/preprocessing.py @@ -72,7 +72,7 @@ def preprocess( f"[Preprocessing] \tMore than {max_bad_channel_fraction_to_remove * 100}% bad channels ({len(all_bad_channel_ids)}). " ) logger.info("[Preprocessing] \tSkipping further processing for this recording.") - return None + return recording_hp_full if preprocessing_params.remove_out_channels: logger.info(f"[Preprocessing] \tRemoving {len(out_channel_ids)} out channels") diff --git a/src/spikeinterface_pipelines/spikesorting/params.py b/src/spikeinterface_pipelines/spikesorting/params.py index ab5fd57..627687f 100644 --- a/src/spikeinterface_pipelines/spikesorting/params.py +++ b/src/spikeinterface_pipelines/spikesorting/params.py @@ -28,6 +28,7 @@ class Kilosort25Model(BaseModel): sig: float = Field(default=20, description="spatial smoothness constant for registration") freq_min: float = Field(default=150, description="High-pass filter cutoff frequency") sigmaMask: float = Field(default=30, description="Spatial constant in um for computing residual variance of spike") + lam: float = Field(default=10.0, description="The importance of the amplitude penalty (like in Kilosort1: 0 means not used, 10 is average, 50 is a lot)") nPCs: int = Field(default=3, description="Number of PCA dimensions") ntbuff: int = Field(default=64, description="Samples of symmetrical buffer for whitening and spike detection") nfilt_factor: int = Field(default=4, description="Max number of clusters per good channel (even temporary ones) 4") @@ -40,7 +41,7 @@ class Kilosort25Model(BaseModel): wave_length: float = Field( default=61, description="size of the waveform extracted around each detected peak, (Default 61, maximum 81)" ) - keep_good_only: bool = Field(default=True, description="If True only 'good' units are returned") + keep_good_only: bool = Field(default=False, description="If True only 'good' units are returned") skip_kilosort_preprocessing: bool = Field( default=False, description="Can optionaly skip the internal kilosort preprocessing" ) diff --git a/src/spikeinterface_pipelines/spikesorting/spikesorting.py b/src/spikeinterface_pipelines/spikesorting/spikesorting.py index 069bca3..d9243c1 100644 --- a/src/spikeinterface_pipelines/spikesorting/spikesorting.py +++ b/src/spikeinterface_pipelines/spikesorting/spikesorting.py @@ -1,8 +1,7 @@ +from __future__ import annotations from pathlib import Path import shutil -from typing import Union -import spikeinterface as si -import spikeinterface.sorters as ss +import spikeinterface.full as si import spikeinterface.curation as sc from ..logger import logger @@ -14,7 +13,7 @@ def spikesort( spikesorting_params: SpikeSortingParams = SpikeSortingParams(), scratch_folder: Path = Path("./scratch/"), results_folder: Path = Path("./results/spikesorting/"), -) -> Union[si.BaseSorting, None]: +) -> si.BaseSorting | None: """ Apply spike sorting to recording @@ -38,14 +37,21 @@ def spikesort( try: logger.info(f"[Spikesorting] \tStarting {spikesorting_params.sorter_name} spike sorter") - sorting = ss.run_sorter( + + ## TEST ONLY - REMOVE LATER ## + # si.get_default_sorter_params('kilosort2_5') + # params_kilosort2_5 = {'do_correction': False} + ## --------------------------## + + sorting = si.run_sorter( recording=recording, sorter_name=spikesorting_params.sorter_name, output_folder=str(output_folder), - verbose=False, + verbose=True, delete_output_folder=True, remove_existing_folder=True, **spikesorting_params.sorter_kwargs.model_dump(), + # **params_kilosort2_5 ) logger.info(f"[Spikesorting] \tFound {len(sorting.unit_ids)} raw units") # remove spikes beyond num_Samples (if any)