From 7a57b62edf6842bdf729b2672c61aa1c12575e01 Mon Sep 17 00:00:00 2001 From: luiz Date: Mon, 18 Mar 2024 12:39:07 +0100 Subject: [PATCH 1/2] update sorting models --- src/spikeinterface_pipelines/pipeline.py | 12 ++++--- .../spikesorting/params.py | 36 ++++++++++++++++--- .../spikesorting/spikesorting.py | 2 +- 3 files changed, 40 insertions(+), 10 deletions(-) diff --git a/src/spikeinterface_pipelines/pipeline.py b/src/spikeinterface_pipelines/pipeline.py index 04d6f52..0f492b0 100644 --- a/src/spikeinterface_pipelines/pipeline.py +++ b/src/spikeinterface_pipelines/pipeline.py @@ -1,6 +1,5 @@ from __future__ import annotations from pathlib import Path -import re from typing import Tuple import spikeinterface as si @@ -19,7 +18,7 @@ def run_pipeline( results_folder: Path | str = Path("./results/"), job_kwargs: JobKwargs | dict = JobKwargs(), preprocessing_params: PreprocessingParams | dict = PreprocessingParams(), - spikesorting_params: SpikeSortingParams | dict = SpikeSortingParams(), + spikesorting_params: SpikeSortingParams | dict = dict(), postprocessing_params: PostprocessingParams | dict = PostprocessingParams(), curation_params: CurationParams | dict = CurationParams(), visualization_params: VisualizationParams | dict = VisualizationParams(), @@ -54,7 +53,10 @@ def run_pipeline( if isinstance(preprocessing_params, dict): preprocessing_params = PreprocessingParams(**preprocessing_params) if isinstance(spikesorting_params, dict): - spikesorting_params = SpikeSortingParams(**spikesorting_params) + spikesorting_params = SpikeSortingParams( + sorter_name=spikesorting_params['sorter_name'], + sorter_kwargs=spikesorting_params['sorter_kwargs'] + ) if isinstance(postprocessing_params, dict): postprocessing_params = PostprocessingParams(**postprocessing_params) if isinstance(curation_params, dict): @@ -117,13 +119,13 @@ def run_pipeline( else: logger.info("Skipping postprocessing") waveform_extractor = None - + else: logger.info("Skipping spike sorting") sorting = None waveform_extractor = None sorting_curated = None - + # Visualization visualization_output = None diff --git a/src/spikeinterface_pipelines/spikesorting/params.py b/src/spikeinterface_pipelines/spikesorting/params.py index 627687f..50278eb 100644 --- a/src/spikeinterface_pipelines/spikesorting/params.py +++ b/src/spikeinterface_pipelines/spikesorting/params.py @@ -1,4 +1,4 @@ -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, ConfigDict from typing import Union, List from enum import Enum @@ -11,6 +11,7 @@ class SorterName(str, Enum): class Kilosort25Model(BaseModel): + model_config = ConfigDict(extra='forbid') detect_threshold: float = Field(default=6, description="Threshold for spike detection") projection_threshold: List[float] = Field(default=[10, 4], description="Threshold on projections") preclust_threshold: float = Field( @@ -49,19 +50,46 @@ class Kilosort25Model(BaseModel): class Kilosort3Model(BaseModel): + model_config = ConfigDict(extra='forbid') pass class IronClustModel(BaseModel): + model_config = ConfigDict(extra='forbid') pass class MountainSort5Model(BaseModel): - pass + model_config = ConfigDict(extra='forbid') + scheme: str = Field( + default='2', + description="Sorting scheme", + json_schema_extra={'options': ["1", "2", "3"]} + ) + detect_threshold: float = Field(default=5.5, description="Threshold for spike detection") + detect_sign: int = Field(default=-1, description="Sign of the peak") + detect_time_radius_msec: float = Field(default=0.5, description="Time radius in milliseconds") + snippet_T1: int = Field(default=20, description="Snippet T1") + snippet_T2: int = Field(default=20, description="Snippet T2") + npca_per_channel: int = Field(default=3, description="Number of PCA per channel") + npca_per_subdivision: int = Field(default=10, description="Number of PCA per subdivision") + snippet_mask_radius: int = Field(default=250, description="Snippet mask radius") + scheme1_detect_channel_radius: int = Field(default=150, description="Scheme 1 detect channel radius") + scheme2_phase1_detect_channel_radius: int = Field(default=200, description="Scheme 2 phase 1 detect channel radius") + scheme2_detect_channel_radius: int = Field(default=50, description="Scheme 2 detect channel radius") + scheme2_max_num_snippets_per_training_batch: int = Field(default=200, description="Scheme 2 max number of snippets per training batch") + scheme2_training_duration_sec: int = Field(default=300, description="Scheme 2 training duration in seconds") + scheme2_training_recording_sampling_mode: str = Field(default='uniform', description="Scheme 2 training recording sampling mode") + scheme3_block_duration_sec: int = Field(default=1800, description="Scheme 3 block duration in seconds") + freq_min: int = Field(default=300, description="High-pass filter cutoff frequency") + freq_max: int = Field(default=6000, description="Low-pass filter cutoff frequency") + filter: bool = Field(default=True, description="Enable or disable filter") + whiten: bool = Field(default=True, description="Enable or disable whiten") class SpikeSortingParams(BaseModel): - sorter_name: SorterName = Field(default="kilosort2_5", description="Name of the sorter to use.") + sorter_name: SorterName = Field(description="Name of the sorter to use.") sorter_kwargs: Union[Kilosort25Model, Kilosort3Model, IronClustModel, MountainSort5Model] = Field( - default=Kilosort25Model(), description="Sorter specific kwargs." + description="Sorter specific kwargs.", + union_mode='left_to_right' ) diff --git a/src/spikeinterface_pipelines/spikesorting/spikesorting.py b/src/spikeinterface_pipelines/spikesorting/spikesorting.py index d9243c1..3aefbdd 100644 --- a/src/spikeinterface_pipelines/spikesorting/spikesorting.py +++ b/src/spikeinterface_pipelines/spikesorting/spikesorting.py @@ -10,7 +10,7 @@ def spikesort( recording: si.BaseRecording, - spikesorting_params: SpikeSortingParams = SpikeSortingParams(), + spikesorting_params: SpikeSortingParams, scratch_folder: Path = Path("./scratch/"), results_folder: Path = Path("./results/spikesorting/"), ) -> si.BaseSorting | None: From 04e63445922f0b667e2d778262de8cd0dddb19ef Mon Sep 17 00:00:00 2001 From: luiz Date: Mon, 18 Mar 2024 12:42:57 +0100 Subject: [PATCH 2/2] remove old comment --- src/spikeinterface_pipelines/spikesorting/spikesorting.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/src/spikeinterface_pipelines/spikesorting/spikesorting.py b/src/spikeinterface_pipelines/spikesorting/spikesorting.py index 3aefbdd..d33c390 100644 --- a/src/spikeinterface_pipelines/spikesorting/spikesorting.py +++ b/src/spikeinterface_pipelines/spikesorting/spikesorting.py @@ -37,12 +37,6 @@ def spikesort( try: logger.info(f"[Spikesorting] \tStarting {spikesorting_params.sorter_name} spike sorter") - - ## TEST ONLY - REMOVE LATER ## - # si.get_default_sorter_params('kilosort2_5') - # params_kilosort2_5 = {'do_correction': False} - ## --------------------------## - sorting = si.run_sorter( recording=recording, sorter_name=spikesorting_params.sorter_name, @@ -51,7 +45,6 @@ def spikesort( delete_output_folder=True, remove_existing_folder=True, **spikesorting_params.sorter_kwargs.model_dump(), - # **params_kilosort2_5 ) logger.info(f"[Spikesorting] \tFound {len(sorting.unit_ids)} raw units") # remove spikes beyond num_Samples (if any)