diff --git a/.github/workflows/ci-test.yml b/.github/workflows/ci-test.yml
new file mode 100644
index 0000000..6524d3c
--- /dev/null
+++ b/.github/workflows/ci-test.yml
@@ -0,0 +1,34 @@
+name: Testing pipeline
+
+on:
+  pull_request:
+    types: [synchronize, opened, reopened]
+    branches:
+      - main
+
+concurrency: # Cancel previous workflows on the same pull request
+  group: ${{ github.workflow }}-${{ github.ref }}
+  cancel-in-progress: true
+
+jobs:
+  build-and-test:
+    name: Test on ${{ matrix.os }} OS
+    runs-on: ${{ matrix.os }}
+    strategy:
+      fail-fast: false
+      matrix:
+        os: ["ubuntu-latest", "macos-latest", "windows-latest"]
+    steps:
+      - uses: actions/checkout@v3
+      - uses: actions/setup-python@v4
+        with:
+          python-version: '3.10'
+      - name: Install dependencies
+        run: |
+          python -m pip install -U pip # Official recommended way
+          pip install pytest
+          pip install -e .
+      - name: Test pipeline with pytest
+        run: |
+          pytest -v
+        shell: bash # Necessary for pipeline to work on Windows
diff --git a/.github/workflows/pypi_release.yaml b/.github/workflows/pypi_release.yaml
new file mode 100644
index 0000000..080bcaf
--- /dev/null
+++ b/.github/workflows/pypi_release.yaml
@@ -0,0 +1,31 @@
+name: Build and publish spikeinterface_pipelines Python package
+
+on:
+  push:
+    branches:
+      - main
+    paths:
+      - src/spikeinterface_pipelines/**
+
+jobs:
+  pypi-release:
+    name: PyPI release
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v3
+      - name: Set up Python
+        uses: actions/setup-python@v3
+        with:
+          python-version: "3.8"
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          pip install build
+          pip install twine
+      - name: Build and publish to PyPI
+        run: |
+          python -m build
+          twine upload dist/*
+        env:
+          TWINE_USERNAME: __token__
+          TWINE_PASSWORD: ${{ secrets.PYPI_TOKEN }}
\ No newline at end of file
diff --git a/in_progress/hd_shank_preprocessing.py b/in_progress/hd_shank_preprocessing.py
deleted file mode 100644
index 45eabc3..0000000
--- a/in_progress/hd_shank_preprocessing.py
+++ /dev/null
@@ -1,151 +0,0 @@
-from pathlib import Path
-import numpy as np
-from pydantic import BaseModel, Field
-import spikeinterface as si
-import spikeinterface.preprocessing as spre
-
-
-########################################################################################################################
-# Preprocessing parameters
-
-class PhaseShiftParameters(BaseModel):
-    margin_ms: float = Field(default=100.0, description="Margin in ms to use for phase shift")
-
-class HighpassFilterParameters(BaseModel):
-    freq_min: float = Field(default=300.0, description="Minimum frequency in Hz")
-    margin_ms: float = Field(default=5.0, description="Margin in ms to use for highpass filter")
-
-class DetectBadChannelsParameters(BaseModel):
-    method: str = Field(default="coherence+psd", description="Method to use for detecting bad channels: 'coherence+psd' or ...")
-    dead_channel_threshold: float = Field(default=-0.5, description="Threshold for dead channels")
-    noisy_channel_threshold: float = Field(default=1.0, description="Threshold for noisy channels")
-    outside_channel_threshold: float = Field(default=-0.3, description="Threshold for outside channels")
-    n_neighbors: int = Field(default=11, description="Number of neighbors to use for bad channel detection")
-    seed: int = Field(default=0, description="Seed for random number generator")
-
-class CommonReferenceParameters(BaseModel):
-    reference: str = Field(default="global", description="Reference to use for common reference: 'global' or ...")
-    
operator: str = Field(default="median", description="Operator to use for common reference: 'median' or ...") - -class HighpassSpatialFilterParameters(BaseModel): - n_channel_pad: int = Field(default=60, description="Number of channels to pad") - n_channel_taper: int = Field(default=None, description="Number of channels to taper") - direction: str = Field(default="y", description="Direction to use for highpass spatial filter: 'y' or ...") - apply_agc: bool = Field(default=True, description="Whether to apply automatic gain control") - agc_window_length_s: float = Field(default=0.01, description="Window length in seconds for automatic gain control") - highpass_butter_order: int = Field(default=3, description="Butterworth order for highpass filter") - highpass_butter_wn: float = Field(default=0.01, description="Butterworth wn for highpass filter") - -class HDShankPreprocessingParameters(BaseModel): - preprocessing_strategy: str = Field(default="cmr", description="Preprocessing strategy to use: destripe or cmr") - highpass_filter: HighpassFilterParameters = Field(default_factory=HighpassFilterParameters, description="Highpass filter parameters") - phase_shift: PhaseShiftParameters = Field(default_factory=PhaseShiftParameters, description="Phase shift parameters") - detect_bad_channels: DetectBadChannelsParameters = Field(default_factory=DetectBadChannelsParameters, description="Detect bad channels parameters") - remove_out_channels: bool = Field(default=True, description="Whether to remove out channels") - remove_bad_channels: bool = Field(default=True, description="Whether to remove bad channels") - max_bad_channel_fraction_to_remove: float = Field(default=0.5, description="Maximum fraction of bad channels to remove") - common_reference: CommonReferenceParameters = Field(default_factory=CommonReferenceParameters, description="Common reference parameters") - highpass_spatial_filter: HighpassSpatialFilterParameters = Field(default_factory=HighpassSpatialFilterParameters, description="Highpass spatial filter parameters") - -######################################################################################################################## - -def hd_shank_preprocessing( - recording: si.BaseRecording, - params: HDShankPreprocessingParameters, - preprocessed_output_folder: Path, - verbose: bool = False -): - if "inter_sample_shift" in recording.get_property_keys(): - recording_ps_full = spre.phase_shift( - recording, - margin_ms=params.phase_shift.margin_ms - ) - else: - recording_ps_full = recording - - recording_hp_full = spre.highpass_filter( - recording_ps_full, - freq_min=params.highpass_filter.freq_min, - margin_ms=params.highpass_filter.margin_ms - ) - - # IBL bad channel detection - _, channel_labels = spre.detect_bad_channels( - recording_hp_full, - method=params.detect_bad_channels.method, - dead_channel_threshold=params.detect_bad_channels.dead_channel_threshold, - noisy_channel_threshold=params.detect_bad_channels.noisy_channel_threshold, - outside_channel_threshold=params.detect_bad_channels.outside_channel_threshold, - n_neighbors=params.detect_bad_channels.n_neighbors, - seed=params.detect_bad_channels.seed - ) - - dead_channel_mask = channel_labels == "dead" - noise_channel_mask = channel_labels == "noise" - out_channel_mask = channel_labels == "out" - - if verbose: - print("\tBad channel detection:") - print( - f"\t\t- dead channels - {np.sum(dead_channel_mask)}\n\t\t- noise channels - {np.sum(noise_channel_mask)}\n\t\t- out channels - {np.sum(out_channel_mask)}" - ) - 
dead_channel_ids = recording_hp_full.channel_ids[dead_channel_mask]
-    noise_channel_ids = recording_hp_full.channel_ids[noise_channel_mask]
-    out_channel_ids = recording_hp_full.channel_ids[out_channel_mask]
-
-    all_bad_channel_ids = np.concatenate((dead_channel_ids, noise_channel_ids, out_channel_ids))
-
-    max_bad_channel_fraction_to_remove = params.max_bad_channel_fraction_to_remove
-
-    # skip_processing = False
-    if len(all_bad_channel_ids) >= int(
-        max_bad_channel_fraction_to_remove * recording.get_num_channels()
-    ):
-        # always print this message even if verbose is False?
-        print(
-            f"\tMore than {max_bad_channel_fraction_to_remove * 100}% bad channels ({len(all_bad_channel_ids)}). "
-            f"Skipping further processing for this recording."
-        )
-        # skip_processing = True
-        recording_ret = recording_hp_full
-    else:
-        if params.remove_out_channels:
-            if verbose:
-                print(f"\tRemoving {len(out_channel_ids)} out channels")
-            recording_rm_out = recording_hp_full.remove_channels(out_channel_ids)
-        else:
-            recording_rm_out = recording_hp_full
-
-        recording_processed_cmr = spre.common_reference(
-            recording_rm_out,
-            reference=params.common_reference.reference,
-            operator=params.common_reference.operator
-        )
-
-        bad_channel_ids = np.concatenate((dead_channel_ids, noise_channel_ids))
-        recording_interp = spre.interpolate_bad_channels(recording_rm_out, bad_channel_ids)
-        recording_hp_spatial = spre.highpass_spatial_filter(
-            recording_interp,
-            n_channel_pad=params.highpass_spatial_filter.n_channel_pad,
-            n_channel_taper=params.highpass_spatial_filter.n_channel_taper,
-            direction=params.highpass_spatial_filter.direction,
-            apply_agc=params.highpass_spatial_filter.apply_agc,
-            agc_window_length_s=params.highpass_spatial_filter.agc_window_length_s,
-            highpass_butter_order=params.highpass_spatial_filter.highpass_butter_order,
-            highpass_butter_wn=params.highpass_spatial_filter.highpass_butter_wn,
-        )
-
-        preproc_strategy = params.preprocessing_strategy
-        if preproc_strategy == "cmr":
-            recording_processed = recording_processed_cmr
-        else:
-            recording_processed = recording_hp_spatial
-
-        if params.remove_bad_channels:
-            if verbose:
-                print(f"\tRemoving {len(bad_channel_ids)} channels after {preproc_strategy} preprocessing")
-            recording_processed = recording_processed.remove_channels(bad_channel_ids)
-        recording_name = 'recording'  # not sure what this should be
-        recording_saved = recording_processed.save(folder=preprocessed_output_folder / recording_name)
-        recording_ret = recording_saved
-    return recording_ret
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000..827aec2
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,36 @@
+[project]
+name = "spikeinterface_pipelines"
+version = "0.0.2"
+description = "Collection of standardized analysis pipelines based on SpikeInterface."
+readme = "README.md"
+authors = [
+    { name = "Alessio Buccino", email = "alessiop.buccino@gmail.com" },
+    { name = "Jeremy Magland", email = "jmagland@flatironinstitute.org" },
+    { name = "Luiz Tauffer", email = "luiz.tauffer@catalystneuro.com" },
+]
+requires-python = ">=3.8"
+dependencies = ["spikeinterface[full]", "neo>=0.12.0", "pydantic>=2.4.2"]
+keywords = [
+    "spikeinterface",
+    "spike sorting",
+    "electrophysiology",
+    "neuroscience",
+]
+
+[project.urls]
+homepage = "https://github.com/SpikeInterface/spikeinterface_pipelines"
+documentation = "https://github.com/SpikeInterface/spikeinterface_pipelines"
+repository = "https://github.com/SpikeInterface/spikeinterface_pipelines"
+
+[build-system]
+requires = ["setuptools"]
+build-backend = "setuptools.build_meta"
+
+[tool.setuptools]
+package-dir = { "" = "src" }
+
+[tool.setuptools.packages.find]
+where = ["src"]
+
+[tool.black]
+line-length = 120
diff --git a/src/spikeinterface_pipelines/__init__.py b/src/spikeinterface_pipelines/__init__.py
new file mode 100644
index 0000000..8d9a59d
--- /dev/null
+++ b/src/spikeinterface_pipelines/__init__.py
@@ -0,0 +1 @@
+from .pipeline import run_pipeline
diff --git a/src/spikeinterface_pipelines/global_params.py b/src/spikeinterface_pipelines/global_params.py
new file mode 100644
index 0000000..73c79f6
--- /dev/null
+++ b/src/spikeinterface_pipelines/global_params.py
@@ -0,0 +1,7 @@
+from pydantic import BaseModel, Field
+
+
+class JobKwargs(BaseModel):
+    n_jobs: int = Field(default=-1, description="The number of jobs to run in parallel.")
+    chunk_duration: str = Field(default="1s", description="The duration of the chunks to process.")
+    progress_bar: bool = Field(default=False, description="Whether to display a progress bar.")
diff --git a/src/spikeinterface_pipelines/logger.py b/src/spikeinterface_pipelines/logger.py
new file mode 100644
index 0000000..0c8a7ea
--- /dev/null
+++ b/src/spikeinterface_pipelines/logger.py
@@ -0,0 +1,4 @@
+import logging
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
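A minimal sketch of how these global job kwargs are meant to be consumed (it mirrors the `si.set_global_job_kwargs` call in `pipeline.py` below; the field values are illustrative):

```python
# Validate parallelization settings with the JobKwargs model, then apply them
# globally so every downstream spikeinterface call inherits them.
import spikeinterface as si
from spikeinterface_pipelines.global_params import JobKwargs

job_kwargs = JobKwargs(n_jobs=4, chunk_duration="1s", progress_bar=True)  # illustrative values
si.set_global_job_kwargs(**job_kwargs.model_dump())
```
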
diff --git a/src/spikeinterface_pipelines/pipeline.py b/src/spikeinterface_pipelines/pipeline.py
new file mode 100644
index 0000000..44d140e
--- /dev/null
+++ b/src/spikeinterface_pipelines/pipeline.py
@@ -0,0 +1,75 @@
+from pathlib import Path
+import re
+from typing import Tuple
+
+import spikeinterface as si
+
+from .logger import logger
+from .global_params import JobKwargs
+from .preprocessing import preprocess, PreprocessingParams
+from .spikesorting import spikesort, SpikeSortingParams
+from .postprocessing import postprocess, PostprocessingParams
+
+
+# TODO - WIP
+def run_pipeline(
+    recording: si.BaseRecording,
+    scratch_folder: Path = Path("./scratch/"),
+    results_folder: Path = Path("./results/"),
+    job_kwargs: JobKwargs = JobKwargs(),
+    preprocessing_params: PreprocessingParams = PreprocessingParams(),
+    spikesorting_params: SpikeSortingParams = SpikeSortingParams(),
+    postprocessing_params: PostprocessingParams = PostprocessingParams(),
+    run_preprocessing: bool = True,
+) -> Tuple[si.BaseRecording, si.BaseSorting, si.WaveformExtractor]:
+    # Create folders
+    scratch_folder.mkdir(exist_ok=True, parents=True)
+    results_folder.mkdir(exist_ok=True, parents=True)
+
+    # Paths
+    results_folder_preprocessing = results_folder / "preprocessing"
+    results_folder_spikesorting = results_folder / "spikesorting"
+    results_folder_postprocessing = results_folder / "postprocessing"
+
+    # set global job kwargs
+    si.set_global_job_kwargs(**job_kwargs.model_dump())
+
+    # Preprocessing
+    if run_preprocessing:
+        logger.info("Preprocessing recording")
+        recording_preprocessed = preprocess(
+            recording=recording,
+            preprocessing_params=preprocessing_params,
+            scratch_folder=scratch_folder,
+            results_folder=results_folder_preprocessing,
+        )
+        if recording_preprocessed is None:
+            raise Exception("Preprocessing failed")
+    else:
+        logger.info("Skipping preprocessing")
+        recording_preprocessed = recording
+
+    # Spike Sorting
+    sorting = spikesort(
+        recording=recording_preprocessed,
+        scratch_folder=scratch_folder,
+        spikesorting_params=spikesorting_params,
+        results_folder=results_folder_spikesorting,
+    )
+    if sorting is None:
+        raise Exception("Spike sorting failed")
+
+    # Postprocessing
+    waveform_extractor = postprocess(
+        recording=recording_preprocessed,
+        sorting=sorting,
+        postprocessing_params=postprocessing_params,
+        scratch_folder=scratch_folder,
+        results_folder=results_folder_postprocessing,
+    )
+
+    # TODO: Curation
+
+    # TODO: Visualization
+
+    return (recording_preprocessed, sorting, waveform_extractor)
diff --git a/src/spikeinterface_pipelines/postprocessing/__init__.py b/src/spikeinterface_pipelines/postprocessing/__init__.py
new file mode 100644
index 0000000..5359475
--- /dev/null
+++ b/src/spikeinterface_pipelines/postprocessing/__init__.py
@@ -0,0 +1,2 @@
+from .postprocessing import postprocess
+from .params import PostprocessingParams
diff --git a/src/spikeinterface_pipelines/postprocessing/params.py b/src/spikeinterface_pipelines/postprocessing/params.py
new file mode 100644
index 0000000..1f81a8f
--- /dev/null
+++ b/src/spikeinterface_pipelines/postprocessing/params.py
@@ -0,0 +1,179 @@
+from pydantic import BaseModel, Field
+from typing import Optional, List, Tuple
+from enum import Enum
+
+
+class PresenceRatio(BaseModel):
+    bin_duration_s: float = Field(default=60, description="Duration of the bin in seconds.")
+
+
+class SNR(BaseModel):
+    peak_sign: str = Field(default="neg", description="Sign of the peak.")
+    peak_mode: str = Field(default="extremum", description="Mode of the peak.")
+    random_chunk_kwargs_dict: Optional[dict] = Field(default=None, description="Random chunk arguments.")
+
+
+class ISIViolation(BaseModel):
+    isi_threshold_ms: float = Field(default=1.5, description="ISI threshold in milliseconds.")
+    min_isi_ms: float = Field(default=0.0, description="Minimum ISI in milliseconds.")
+
+
+class RPViolation(BaseModel):
+    refractory_period_ms: float = Field(default=1.0, description="Refractory period in milliseconds.")
+    censored_period_ms: float = Field(default=0.0, description="Censored period in milliseconds.")
+
+
+class SlidingRPViolation(BaseModel):
+    bin_size_ms: float = Field(
+        default=0.25, description="The size of binning for the autocorrelogram in ms, by default 0.25."
+    )
+    window_size_s: float = Field(default=1, description="Window in seconds to compute correlogram, by default 1.")
+    exclude_ref_period_below_ms: float = Field(
+        default=0.5, description="Refractory periods below this value are excluded, by default 0.5"
+    )
+    max_ref_period_ms: float = Field(
+        default=10, description="Maximum refractory period to test in ms, by default 10 ms."
+ ) + contamination_values: Optional[list] = Field( + default=None, description="The contamination values to test, by default np.arange(0.5, 35, 0.5) %" + ) + + +class PeakSign(str, Enum): + neg = "neg" + pos = "pos" + both = "both" + + +class AmplitudeCutoff(BaseModel): + peak_sign: PeakSign = Field(default="neg", description="The sign of the peaks.") + num_histogram_bins: int = Field( + default=100, description="The number of bins to use to compute the amplitude histogram." + ) + histogram_smoothing_value: int = Field( + default=3, description="Controls the smoothing applied to the amplitude histogram." + ) + amplitudes_bins_min_ratio: int = Field( + default=5, + description="The minimum ratio between number of amplitudes for a unit and the number of bins. If the ratio is less than this threshold, the amplitude_cutoff for the unit is set to NaN.", + ) + + +class AmplitudeMedian(BaseModel): + peak_sign: PeakSign = Field(default="neg", description="The sign of the peaks.") + + +class NearestNeighbor(BaseModel): + max_spikes: int = Field( + default=10000, + description="The number of spikes to use, per cluster. Note that the calculation can be very slow when this number is >20000.", + ) + min_spikes: int = Field(default=10, description="Minimum number of spikes.") + n_neighbors: int = Field(default=4, description="The number of neighbors to use.") + + +class NNIsolation(NearestNeighbor): + n_components: int = Field(default=10, description="The number of PC components to use to project the snippets to.") + radius_um: int = Field( + default=100, description="The radius, in um, that channels need to be within the peak channel to be included." + ) + + +class QMParams(BaseModel): + presence_ratio: PresenceRatio = Field(default=PresenceRatio(), description="Presence ratio.") + snr: SNR = Field(default=SNR(), description="Signal to noise ratio.") + isi_violation: ISIViolation = Field(default=ISIViolation(), description="ISI violation.") + rp_violation: RPViolation = Field(default=RPViolation(), description="Refractory period violation.") + sliding_rp_violation: SlidingRPViolation = Field( + default=SlidingRPViolation(), description="Sliding refractory period violation." 
+    )
+    amplitude_cutoff: AmplitudeCutoff = Field(default=AmplitudeCutoff(), description="Amplitude cutoff.")
+    amplitude_median: AmplitudeMedian = Field(default=AmplitudeMedian(), description="Amplitude median.")
+    nearest_neighbor: NearestNeighbor = Field(default=NearestNeighbor(), description="Nearest neighbor.")
+    nn_isolation: NNIsolation = Field(default=NNIsolation(), description="Nearest neighbor isolation.")
+    nn_noise_overlap: NNIsolation = Field(default=NNIsolation(), description="Nearest neighbor noise overlap.")
+
+
+class QualityMetrics(BaseModel):
+    qm_params: QMParams = Field(default=QMParams(), description="Quality metric parameters.")
+    metric_names: Optional[List[str]] = Field(default=None, description="List of metric names to compute.")
+    n_jobs: int = Field(default=1, description="Number of jobs.")
+
+
+class Sparsity(BaseModel):
+    method: str = Field(default="radius", description="Method for determining sparsity.")
+    radius_um: int = Field(default=100, description="Radius in micrometers for sparsity.")
+
+
+class WaveformsRaw(BaseModel):
+    ms_before: float = Field(default=1.0, description="Milliseconds before")
+    ms_after: float = Field(default=2.0, description="Milliseconds after")
+    max_spikes_per_unit: int = Field(default=100, description="Maximum spikes per unit")
+    return_scaled: bool = Field(default=True, description="Flag to determine if results should be scaled")
+    dtype: Optional[str] = Field(default=None, description="Data type for the waveforms")
+    precompute_template: Tuple[str, str] = Field(
+        default=("average", "std"), description="Precomputation template method"
+    )
+    use_relative_path: bool = Field(default=True, description="Use relative paths")
+
+
+class Waveforms(BaseModel):
+    ms_before: float = Field(default=3.0, description="Milliseconds before")
+    ms_after: float = Field(default=4.0, description="Milliseconds after")
+    max_spikes_per_unit: int = Field(default=500, description="Maximum spikes per unit")
+    return_scaled: bool = Field(default=True, description="Flag to determine if results should be scaled")
+    dtype: Optional[str] = Field(default=None, description="Data type for the waveforms")
+    precompute_template: Tuple[str, str] = Field(
+        default=("average", "std"), description="Precomputation template method"
+    )
+    use_relative_path: bool = Field(default=True, description="Use relative paths")
+
+
+class SpikeAmplitudes(BaseModel):
+    peak_sign: str = Field(default="neg", description="Sign of the peak")
+    return_scaled: bool = Field(default=True, description="Flag to determine if amplitudes should be scaled")
+    outputs: str = Field(default="concatenated", description="Output format for the spike amplitudes")
+
+
+class Similarity(BaseModel):
+    method: str = Field(default="cosine_similarity", description="Method to compute similarity")
+
+
+class Correlograms(BaseModel):
+    window_ms: float = Field(default=100.0, description="Size of the window in milliseconds")
+    bin_ms: float = Field(default=2.0, description="Size of the bin in milliseconds")
+
+
+class ISIS(BaseModel):
+    window_ms: float = Field(default=100.0, description="Size of the window in milliseconds")
+    bin_ms: float = Field(default=5.0, description="Size of the bin in milliseconds")
+
+
+class Locations(BaseModel):
+    method: str = Field(default="monopolar_triangulation", description="Method to determine locations")
+
+
+class TemplateMetrics(BaseModel):
+    upsampling_factor: int = Field(default=10, description="Upsampling factor")
+    sparsity: Optional[str] = Field(default=None, description="Sparsity method")
+
+
+class PrincipalComponents(BaseModel):
+    n_components: int = Field(default=5, description="Number of principal components")
+    mode: str = Field(default="by_channel_local", description="Mode of principal component analysis")
+    whiten: bool = Field(default=True, description="Whiten the components")
+
+
+class PostprocessingParams(BaseModel):
+    sparsity: Sparsity = Field(default=Sparsity(), description="Sparsity")
+    waveforms_raw: WaveformsRaw = Field(default=WaveformsRaw(), description="Waveforms raw")
+    waveforms: Waveforms = Field(default=Waveforms(), description="Waveforms")
+    spike_amplitudes: SpikeAmplitudes = Field(default=SpikeAmplitudes(), description="Spike amplitudes")
+    similarity: Similarity = Field(default=Similarity(), description="Similarity")
+    correlograms: Correlograms = Field(default=Correlograms(), description="Correlograms")
+    isis: ISIS = Field(default=ISIS(), description="ISI histograms")
+    locations: Locations = Field(default=Locations(), description="Locations")
+    template_metrics: TemplateMetrics = Field(default=TemplateMetrics(), description="Template metrics")
+    principal_components: PrincipalComponents = Field(default=PrincipalComponents(), description="Principal components")
+    quality_metrics: QualityMetrics = Field(default=QualityMetrics(), description="Quality metrics")
+    duplicate_threshold: float = Field(default=0.9, description="Duplicate threshold")
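`PostprocessingParams` nests one model per computed extension, so a single step can be tuned without restating the rest. A hedged sketch (all override values are illustrative):

```python
from spikeinterface_pipelines.postprocessing import PostprocessingParams
from spikeinterface_pipelines.postprocessing.params import Sparsity, Correlograms

# Override only the pieces you care about; every other field keeps its default.
params = PostprocessingParams(
    sparsity=Sparsity(method="radius", radius_um=75),
    correlograms=Correlograms(window_ms=50.0, bin_ms=1.0),
    duplicate_threshold=0.95,
)
print(params.quality_metrics.qm_params.snr.peak_sign)  # unchanged default: "neg"
```
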
diff --git a/src/spikeinterface_pipelines/postprocessing/postprocessing.py b/src/spikeinterface_pipelines/postprocessing/postprocessing.py
new file mode 100644
index 0000000..60d21b5
--- /dev/null
+++ b/src/spikeinterface_pipelines/postprocessing/postprocessing.py
@@ -0,0 +1,107 @@
+import warnings
+from pathlib import Path
+import shutil
+
+import spikeinterface as si
+import spikeinterface.postprocessing as spost
+import spikeinterface.qualitymetrics as sqm
+import spikeinterface.curation as sc
+
+from .params import PostprocessingParams
+from ..logger import logger
+
+
+warnings.filterwarnings("ignore")
+
+
+def postprocess(
+    recording: si.BaseRecording,
+    sorting: si.BaseSorting,
+    postprocessing_params: PostprocessingParams = PostprocessingParams(),
+    scratch_folder: Path = Path("./scratch/"),
+    results_folder: Path = Path("./results/postprocessing/"),
+) -> si.WaveformExtractor:
+    """
+    Postprocess preprocessed and spike sorting output
+
+    Parameters
+    ----------
+    recording: si.BaseRecording
+        The input recording
+    sorting: si.BaseSorting
+        The input sorting
+    postprocessing_params: PostprocessingParams
+        Postprocessing parameters
+    results_folder: Path
+        Path to the results folder
+
+    Returns
+    -------
+    si.WaveformExtractor
+        The waveform extractor
+    """
+
+    tmp_folder = scratch_folder / "tmp_postprocessing"
+    tmp_folder.mkdir(parents=True, exist_ok=True)
+
+    # first extract some raw waveforms in memory to deduplicate based on peak alignment
+    wf_dedup_folder = tmp_folder / "waveforms_dense"
+    waveform_extractor_raw = si.extract_waveforms(
+        recording, sorting, folder=wf_dedup_folder, sparse=False, **postprocessing_params.waveforms_raw.model_dump()
+    )
+
+    # de-duplication
+    sorting_deduplicated = sc.remove_redundant_units(
+        waveform_extractor_raw, duplicate_threshold=postprocessing_params.duplicate_threshold
+    )
+    logger.info(
+        f"[Postprocessing] \tNumber of original units: {len(waveform_extractor_raw.sorting.unit_ids)} -- Number of units after de-duplication: {len(sorting_deduplicated.unit_ids)}"
+    )
+    deduplicated_unit_ids = sorting_deduplicated.unit_ids
+
+    # use existing deduplicated waveforms to compute sparsity
+    sparsity_raw = si.compute_sparsity(waveform_extractor_raw, **postprocessing_params.sparsity.model_dump())
+    sparsity_mask = sparsity_raw.mask[sorting.ids_to_indices(deduplicated_unit_ids), :]
+    sparsity = si.ChannelSparsity(mask=sparsity_mask, unit_ids=deduplicated_unit_ids, channel_ids=recording.channel_ids)
+
+    # this is a trick to make the postprocessed folder "self-contained"
+    sorting_folder = results_folder / "sorting"
+    sorting_deduplicated = sorting_deduplicated.save(folder=sorting_folder)
+
+    # now extract waveforms on de-duplicated units
+    logger.info("[Postprocessing] \tSaving sparse de-duplicated waveform extractor folder")
+    waveform_extractor = si.extract_waveforms(
+        recording,
+        sorting_deduplicated,
+        folder=results_folder / "waveforms",
+        sparsity=sparsity,
+        sparse=True,
+        overwrite=True,
+        **postprocessing_params.waveforms.model_dump(),
+    )
+
+    logger.info("[Postprocessing] \tComputing spike amplitudes")
+    _ = spost.compute_spike_amplitudes(waveform_extractor, **postprocessing_params.spike_amplitudes.model_dump())
+    logger.info("[Postprocessing] \tComputing unit locations")
+    _ = spost.compute_unit_locations(waveform_extractor, **postprocessing_params.locations.model_dump())
+    logger.info("[Postprocessing] \tComputing spike locations")
+    _ = spost.compute_spike_locations(waveform_extractor, **postprocessing_params.locations.model_dump())
+    logger.info("[Postprocessing] \tComputing correlograms")
+    _ = spost.compute_correlograms(waveform_extractor, **postprocessing_params.correlograms.model_dump())
+    logger.info("[Postprocessing] \tComputing ISI histograms")
+    _ = spost.compute_isi_histograms(waveform_extractor, **postprocessing_params.isis.model_dump())
+    logger.info("[Postprocessing] \tComputing template similarity")
+    _ = spost.compute_template_similarity(waveform_extractor, **postprocessing_params.similarity.model_dump())
+    logger.info("[Postprocessing] \tComputing template metrics")
+    _ = spost.compute_template_metrics(waveform_extractor, **postprocessing_params.template_metrics.model_dump())
+    logger.info("[Postprocessing] \tComputing PCA")
+    _ = spost.compute_principal_components(
+        waveform_extractor, **postprocessing_params.principal_components.model_dump()
+    )
+    logger.info("[Postprocessing] \tComputing quality metrics")
+    _ = sqm.compute_quality_metrics(waveform_extractor, **postprocessing_params.quality_metrics.model_dump())
+
+    # cleanup
+    shutil.rmtree(tmp_folder)
+
+    return waveform_extractor
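Because `postprocess()` persists the sparse waveform extractor under `results_folder / "waveforms"`, the outputs can be reloaded later. A sketch assuming spikeinterface's `load_waveforms`/`load_extension` API and its extension names ("quality_metrics", "spike_amplitudes", ...):

```python
from pathlib import Path
import spikeinterface as si

# Reload the folder written by postprocess() and pull out the quality metrics
# computed by sqm.compute_quality_metrics (returned as a pandas DataFrame).
we = si.load_waveforms(Path("./results/postprocessing/") / "waveforms")
quality_metrics = we.load_extension("quality_metrics").get_data()
print(quality_metrics.head())
```
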
diff --git a/src/spikeinterface_pipelines/preprocessing/__init__.py b/src/spikeinterface_pipelines/preprocessing/__init__.py
new file mode 100644
index 0000000..ad522e5
--- /dev/null
+++ b/src/spikeinterface_pipelines/preprocessing/__init__.py
@@ -0,0 +1,2 @@
+from .preprocessing import preprocess
+from .params import PreprocessingParams
diff --git a/src/spikeinterface_pipelines/preprocessing/params.py b/src/spikeinterface_pipelines/preprocessing/params.py
new file mode 100644
index 0000000..b5bc566
--- /dev/null
+++ b/src/spikeinterface_pipelines/preprocessing/params.py
@@ -0,0 +1,64 @@
+from pydantic import BaseModel, Field
+from typing import Optional
+from enum import Enum
+
+
+class PreprocessingStrategy(str, Enum):
+    cmr = "cmr"
+    destripe = "destripe"
+
+
+class HighpassFilter(BaseModel):
+    freq_min: float = Field(default=300.0, description="Minimum frequency for the highpass filter")
+    margin_ms: float = Field(default=5.0, description="Margin in milliseconds")
+
+
+class PhaseShift(BaseModel):
+    margin_ms: float = Field(default=100.0, description="Margin in milliseconds for phase shift")
+
+
+class DetectBadChannels(BaseModel):
+    method: str = Field(default="coherence+psd", description="Method to detect bad channels")
+    dead_channel_threshold: float = Field(default=-0.5, description="Threshold for dead channel")
+    noisy_channel_threshold: float = Field(default=1.0, description="Threshold for noisy channel")
+    outside_channel_threshold: float = Field(default=-0.3, description="Threshold for outside channel")
+    n_neighbors: int = Field(default=11, description="Number of neighbors")
+    seed: int = Field(default=0, description="Seed value")
+
+
+class CommonReference(BaseModel):
+    reference: str = Field(default="global", description="Type of reference")
+    operator: str = Field(default="median", description="Operator used for common reference")
+
+
+class HighpassSpatialFilter(BaseModel):
+    n_channel_pad: int = Field(default=60, description="Number of channels to pad")
+    n_channel_taper: Optional[int] = Field(default=None, description="Number of channels to taper")
+    direction: str = Field(default="y", description="Direction for the spatial filter")
+    apply_agc: bool = Field(default=True, description="Whether to apply automatic gain control")
+    agc_window_length_s: float = Field(default=0.01, description="Window length in seconds for AGC")
+    highpass_butter_order: int = Field(default=3, description="Order for the Butterworth filter")
+    highpass_butter_wn: float = Field(default=0.01, description="Natural frequency for the Butterworth filter")
+
+
+class MotionCorrection(BaseModel):
+    compute: bool = Field(default=True, description="Whether to compute motion correction")
+    apply: bool = Field(default=False, description="Whether to apply motion correction")
+    preset: str = Field(default="nonrigid_accurate", description="Preset for motion correction")
+
+
+class PreprocessingParams(BaseModel):
+    preprocessing_strategy: PreprocessingStrategy = Field(default="cmr", description="Strategy for preprocessing")
+    highpass_filter: HighpassFilter = Field(default=HighpassFilter(), description="Highpass filter")
+    phase_shift: PhaseShift = Field(default=PhaseShift(), description="Phase shift")
+    common_reference: CommonReference = Field(default=CommonReference(), description="Common reference")
+    highpass_spatial_filter: HighpassSpatialFilter = Field(
+        default=HighpassSpatialFilter(), description="Highpass spatial filter"
+    )
+    motion_correction: MotionCorrection = Field(default=MotionCorrection(), description="Motion correction")
+    detect_bad_channels: DetectBadChannels = Field(default=DetectBadChannels(), description="Detect bad channels")
+    remove_out_channels: bool = Field(default=True, description="Flag to remove out channels")
+    remove_bad_channels: bool = Field(default=True, description="Flag to remove bad channels")
+    max_bad_channel_fraction_to_remove: float = Field(
+        default=0.5, description="Maximum fraction of bad channels to remove"
+    )
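The preprocessing models compose the same way. A hedged sketch of switching strategy and opting into motion correction (overrides are illustrative):

```python
from spikeinterface_pipelines.preprocessing import PreprocessingParams
from spikeinterface_pipelines.preprocessing.params import MotionCorrection

# "destripe" is validated against the PreprocessingStrategy enum; a typo such as
# "destrip" would raise a pydantic ValidationError at construction time.
preprocessing_params = PreprocessingParams(
    preprocessing_strategy="destripe",
    motion_correction=MotionCorrection(compute=True, apply=True),
)
```
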
diff --git a/src/spikeinterface_pipelines/preprocessing/preprocessing.py b/src/spikeinterface_pipelines/preprocessing/preprocessing.py
new file mode 100644
index 0000000..ba05529
--- /dev/null
+++ b/src/spikeinterface_pipelines/preprocessing/preprocessing.py
@@ -0,0 +1,114 @@
+import warnings
+import numpy as np
+from pathlib import Path
+
+import spikeinterface as si
+import spikeinterface.preprocessing as spre
+
+from ..logger import logger
+from .params import PreprocessingParams
+
+
+warnings.filterwarnings("ignore")
+
+
+def preprocess(
+    recording: si.BaseRecording,
+    preprocessing_params: PreprocessingParams = PreprocessingParams(),
+    scratch_folder: Path = Path("./scratch/"),
+    results_folder: Path = Path("./results/preprocessing/"),
+) -> si.BaseRecording:
+    """
+    Apply preprocessing to recording.
+
+    Parameters
+    ----------
+    recording: si.BaseRecording
+        The input recording
+    preprocessing_params: PreprocessingParams
+        Preprocessing parameters
+    scratch_folder: Path
+        Path to the scratch folder
+    results_folder: Path
+        Path to the results folder
+
+    Returns
+    -------
+    si.BaseRecording | None
+        Preprocessed recording. If more than `max_bad_channel_fraction_to_remove` channels are detected as bad,
+        returns None.
+    """
+    logger.info("[Preprocessing] \tRunning Preprocessing stage")
+    logger.info(f"[Preprocessing] \tDuration: {np.round(recording.get_total_duration(), 2)} s")
+
+    # Phase shift correction
+    if "inter_sample_shift" in recording.get_property_keys():
+        logger.info("[Preprocessing] \tPhase shift")
+        recording = spre.phase_shift(recording, **preprocessing_params.phase_shift.model_dump())
+    else:
+        logger.info("[Preprocessing] \tSkipping phase shift: 'inter_sample_shift' property not found")
+
+    # Highpass filter
+    recording_hp_full = spre.highpass_filter(recording, **preprocessing_params.highpass_filter.model_dump())
+
+    # Detect and remove bad channels
+    _, channel_labels = spre.detect_bad_channels(
+        recording_hp_full, **preprocessing_params.detect_bad_channels.model_dump()
+    )
+    dead_channel_mask = channel_labels == "dead"
+    noise_channel_mask = channel_labels == "noise"
+    out_channel_mask = channel_labels == "out"
+    logger.info(
+        f"[Preprocessing] \tBad channel detection found: {np.sum(dead_channel_mask)} dead, {np.sum(noise_channel_mask)} noise, {np.sum(out_channel_mask)} out channels"
+    )
+    dead_channel_ids = recording_hp_full.channel_ids[dead_channel_mask]
+    noise_channel_ids = recording_hp_full.channel_ids[noise_channel_mask]
+    out_channel_ids = recording_hp_full.channel_ids[out_channel_mask]
+    all_bad_channel_ids = np.concatenate((dead_channel_ids, noise_channel_ids, out_channel_ids))
+
+    max_bad_channel_fraction_to_remove = preprocessing_params.max_bad_channel_fraction_to_remove
+    if len(all_bad_channel_ids) >= int(max_bad_channel_fraction_to_remove * recording.get_num_channels()):
+        logger.info(
+            f"[Preprocessing] \tMore than {max_bad_channel_fraction_to_remove * 100}% bad channels ({len(all_bad_channel_ids)})."
+        )
+        logger.info("[Preprocessing] \tSkipping further processing for this recording.")
+        return None
+
+    if preprocessing_params.remove_out_channels:
+        logger.info(f"[Preprocessing] \tRemoving {len(out_channel_ids)} out channels")
+        recording_rm_out = recording_hp_full.remove_channels(out_channel_ids)
+    else:
+        recording_rm_out = recording_hp_full
+
+    bad_channel_ids = np.concatenate((dead_channel_ids, noise_channel_ids))
+
+    # Denoise: CMR or destripe
+    if preprocessing_params.preprocessing_strategy == "cmr":
+        recording_processed = spre.common_reference(
+            recording_rm_out, **preprocessing_params.common_reference.model_dump()
+        )
+    else:
+        recording_interp = spre.interpolate_bad_channels(recording_rm_out, bad_channel_ids)
+        recording_processed = spre.highpass_spatial_filter(
+            recording_interp, **preprocessing_params.highpass_spatial_filter.model_dump()
+        )
+
+    if preprocessing_params.remove_bad_channels:
+        logger.info(
+            f"[Preprocessing] \tRemoving {len(bad_channel_ids)} channels after {preprocessing_params.preprocessing_strategy} preprocessing"
+        )
+        recording_processed = recording_processed.remove_channels(bad_channel_ids)
+
+    # Motion correction
+    if preprocessing_params.motion_correction.compute:
+        preset = preprocessing_params.motion_correction.preset
+        logger.info(f"[Preprocessing] \tComputing motion correction with preset: {preset}")
+        motion_folder = results_folder / "motion_correction"
+        recording_corrected = spre.correct_motion(
+            recording_processed, preset=preset, folder=motion_folder, verbose=False
+        )
+        if preprocessing_params.motion_correction.apply:
+            logger.info("[Preprocessing] \tApplying motion correction")
+            recording_processed = recording_corrected
+
+    return recording_processed
diff --git a/src/spikeinterface_pipelines/spikesorting/__init__.py b/src/spikeinterface_pipelines/spikesorting/__init__.py
new file mode 100644
index 0000000..9bdd5fa
--- /dev/null
+++ b/src/spikeinterface_pipelines/spikesorting/__init__.py
@@ -0,0 +1,2 @@
+from .spikesorting import spikesort
+from .params import SpikeSortingParams
diff --git a/src/spikeinterface_pipelines/spikesorting/params.py b/src/spikeinterface_pipelines/spikesorting/params.py
new file mode 100644
index 0000000..ab5fd57
--- /dev/null
+++ b/src/spikeinterface_pipelines/spikesorting/params.py
@@ -0,0 +1,66 @@
+from pydantic import BaseModel, Field
+from typing import Union, List, Optional
+from enum import Enum
+
+
+class SorterName(str, Enum):
+    ironclust = "ironclust"
+    kilosort25 = "kilosort2_5"
+    kilosort3 = "kilosort3"
+    mountainsort5 = "mountainsort5"
+
+
+class Kilosort25Model(BaseModel):
+    detect_threshold: float = Field(default=6, description="Threshold for spike detection")
+    projection_threshold: List[float] = Field(default=[10, 4], description="Threshold on projections")
+    preclust_threshold: float = Field(
+        default=8, description="Threshold crossings for pre-clustering (in PCA projection space)"
+    )
+    car: bool = Field(default=True, description="Enable or disable common reference")
+    minFR: float = Field(
+        default=0.1, description="Minimum spike rate (Hz), if a cluster falls below this for too long it gets removed"
+    )
+    minfr_goodchannels: float = Field(default=0.1, description="Minimum firing rate on a 'good' channel")
+    nblocks: int = Field(
+        default=5,
+        description="blocks for registration. 0 turns it off, 1 does rigid registration. Replaces 'datashift' option.",
+    )
+    sig: float = Field(default=20, description="spatial smoothness constant for registration")
+    freq_min: float = Field(default=150, description="High-pass filter cutoff frequency")
+    sigmaMask: float = Field(default=30, description="Spatial constant in um for computing residual variance of spike")
+    nPCs: int = Field(default=3, description="Number of PCA dimensions")
+    ntbuff: int = Field(default=64, description="Samples of symmetrical buffer for whitening and spike detection")
+    nfilt_factor: int = Field(default=4, description="Max number of clusters per good channel (even temporary ones) 4")
+    NT: Optional[int] = Field(default=None, description="Batch size (if None it is automatically computed)")
+    AUCsplit: float = Field(
+        default=0.9,
+        description="Threshold on the area under the curve (AUC) criterion for performing a split in the final step",
+    )
+    do_correction: bool = Field(default=True, description="If True drift registration is applied")
+    wave_length: float = Field(
+        default=61, description="size of the waveform extracted around each detected peak, (Default 61, maximum 81)"
+    )
+    keep_good_only: bool = Field(default=True, description="If True only 'good' units are returned")
+    skip_kilosort_preprocessing: bool = Field(
+        default=False, description="Can optionally skip the internal kilosort preprocessing"
+    )
+    scaleproc: int = Field(default=-1, description="int16 scaling of whitened data, if -1 set to 200.")
+
+
+class Kilosort3Model(BaseModel):
+    pass
+
+
+class IronClustModel(BaseModel):
+    pass
+
+
+class MountainSort5Model(BaseModel):
+    pass
+
+
+class SpikeSortingParams(BaseModel):
+    sorter_name: SorterName = Field(default="kilosort2_5", description="Name of the sorter to use.")
+    sorter_kwargs: Union[Kilosort25Model, Kilosort3Model, IronClustModel, MountainSort5Model] = Field(
+        default=Kilosort25Model(), description="Sorter specific kwargs."
+    )
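`SpikeSortingParams` pairs a `SorterName` with a sorter-specific kwargs model; only `Kilosort25Model` is fleshed out so far. A sketch of overriding a couple of Kilosort 2.5 options (the values are illustrative; the pattern mirrors `test_pipeline` below):

```python
from spikeinterface_pipelines.spikesorting import SpikeSortingParams
from spikeinterface_pipelines.spikesorting.params import Kilosort25Model

# Disable drift correction and lower the detection threshold; all other
# Kilosort 2.5 kwargs keep the defaults declared above.
spikesorting_params = SpikeSortingParams(
    sorter_name="kilosort2_5",
    sorter_kwargs=Kilosort25Model(do_correction=False, detect_threshold=5),
)
```
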
diff --git a/src/spikeinterface_pipelines/spikesorting/spikesorting.py b/src/spikeinterface_pipelines/spikesorting/spikesorting.py
new file mode 100644
index 0000000..831e846
--- /dev/null
+++ b/src/spikeinterface_pipelines/spikesorting/spikesorting.py
@@ -0,0 +1,63 @@
+from pathlib import Path
+import shutil
+from typing import Optional
+
+import spikeinterface as si
+import spikeinterface.sorters as ss
+import spikeinterface.curation as sc
+
+from ..logger import logger
+from .params import SpikeSortingParams
+
+
+def spikesort(
+    recording: si.BaseRecording,
+    spikesorting_params: SpikeSortingParams = SpikeSortingParams(),
+    scratch_folder: Path = Path("./scratch/"),
+    results_folder: Path = Path("./results/spikesorting/"),
+) -> Optional[si.BaseSorting]:
+    """
+    Apply spike sorting to recording
+
+    Parameters
+    ----------
+    recording: si.BaseRecording
+        The input recording
+    spikesorting_params: SpikeSortingParams
+        Spike sorting parameters
+    scratch_folder: Path
+        Path to the scratch folder
+    results_folder: Path
+        Path to the results folder
+
+    Returns
+    -------
+    si.BaseSorting | None
+        Spike sorted sorting. If spike sorting fails, None is returned.
+    """
+    output_folder = scratch_folder / "tmp_spikesorting"
+
+    try:
+        logger.info(f"[Spikesorting] \tStarting {spikesorting_params.sorter_name} spike sorter")
+        sorting = ss.run_sorter(
+            recording=recording,
+            sorter_name=spikesorting_params.sorter_name,
+            output_folder=str(output_folder),
+            verbose=False,
+            delete_output_folder=True,
+            remove_existing_folder=True,
+            **spikesorting_params.sorter_kwargs.model_dump(),
+        )
+        logger.info(f"[Spikesorting] \tFound {len(sorting.unit_ids)} raw units")
+        # remove spikes beyond num_samples (if any)
+        sorting = sc.remove_excess_spikes(sorting=sorting, recording=recording)
+        # save results
+        logger.info(f"[Spikesorting] \tSaving results to {results_folder}")
+        return sorting
+    except Exception as e:
+        # save log to results
+        results_folder.mkdir(exist_ok=True, parents=True)
+        if output_folder.is_dir():
+            shutil.copy(output_folder / "spikeinterface_log.json", results_folder)
+            shutil.rmtree(output_folder)
+        logger.error(f"Spike sorting error:\n{e}")
+        return None
diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py
new file mode 100644
index 0000000..59428d0
--- /dev/null
+++ b/tests/test_pipeline.py
@@ -0,0 +1,131 @@
+import shutil
+import pytest
+import numpy as np
+from pathlib import Path
+
+import spikeinterface as si
+import spikeinterface.sorters as ss
+
+from spikeinterface_pipelines import pipeline
+
+from spikeinterface_pipelines.preprocessing import preprocess
+from spikeinterface_pipelines.spikesorting import spikesort
+from spikeinterface_pipelines.postprocessing import postprocess
+
+from spikeinterface_pipelines.preprocessing.params import PreprocessingParams
+from spikeinterface_pipelines.spikesorting.params import Kilosort25Model, SpikeSortingParams
+from spikeinterface_pipelines.postprocessing.params import PostprocessingParams
+
+
+def _generate_gt_recording():
+    recording, sorting = si.generate_ground_truth_recording(durations=[30], num_channels=64, seed=0)
+    # add a (fake) inter_sample_shift property
+    inter_sample_shifts = np.zeros(recording.get_num_channels())
+    recording.set_property("inter_sample_shift", inter_sample_shifts)
+
+    return recording, sorting
+
+
+@pytest.fixture
+def generate_recording():
+    return _generate_gt_recording()
+
+
+def test_preprocessing(tmp_path, generate_recording):
+    recording, _ = generate_recording
+
+    results_folder = Path(tmp_path) / "results_preprocessing"
+    scratch_folder = Path(tmp_path) / "scratch_preprocessing"
+
+    recording_processed = preprocess(
+        recording=recording,
+        preprocessing_params=PreprocessingParams(),
+        results_folder=results_folder,
+        scratch_folder=scratch_folder,
+    )
+
+    assert isinstance(recording_processed, si.BaseRecording)
+
+
+@pytest.mark.skipif("kilosort2_5" not in ss.installed_sorters(), reason="kilosort2_5 not installed")
+def test_spikesorting(tmp_path, generate_recording):
+    recording, _ = generate_recording
+    if "inter_sample_shift" in recording.get_property_keys():
+        recording.delete_property("inter_sample_shift")
+
+    results_folder = Path(tmp_path) / "results_spikesorting"
+    scratch_folder = Path(tmp_path) / "scratch_spikesorting"
+
+    sorting = spikesort(
+        recording=recording,
+        spikesorting_params=SpikeSortingParams(),
+        results_folder=results_folder,
+        scratch_folder=scratch_folder,
+    )
+
+    assert isinstance(sorting, si.BaseSorting)
+
+
+def test_postprocessing(tmp_path, generate_recording):
+    recording, sorting = generate_recording
+    if "inter_sample_shift" in recording.get_property_keys():
+        recording.delete_property("inter_sample_shift")
+
+    results_folder = Path(tmp_path) / "results_postprocessing"
+    scratch_folder = Path(tmp_path) / "scratch_postprocessing"
+
+    waveform_extractor = postprocess(
+        recording=recording,
+        sorting=sorting,
+        postprocessing_params=PostprocessingParams(),
+        results_folder=results_folder,
+        scratch_folder=scratch_folder,
+    )
+
+    assert isinstance(waveform_extractor, si.WaveformExtractor)
+
+
+@pytest.mark.skipif("kilosort2_5" not in ss.installed_sorters(), reason="kilosort2_5 not installed")
+def test_pipeline(tmp_path, generate_recording):
+    recording, _ = generate_recording
+    if "inter_sample_shift" in recording.get_property_keys():
+        recording.delete_property("inter_sample_shift")
+
+    results_folder = Path(tmp_path) / "results"
+    scratch_folder = Path(tmp_path) / "scratch"
+
+    ks25_params = Kilosort25Model(do_correction=False)
+    spikesorting_params = SpikeSortingParams(
+        sorter_name="kilosort2_5",
+        sorter_kwargs=ks25_params,
+    )
+
+    recording_processed, sorting, waveform_extractor = pipeline.run_pipeline(
+        recording=recording,
+        results_folder=results_folder,
+        scratch_folder=scratch_folder,
+        spikesorting_params=spikesorting_params,
+    )
+
+    assert isinstance(recording_processed, si.BaseRecording)
+    assert isinstance(sorting, si.BaseSorting)
+    assert isinstance(waveform_extractor, si.WaveformExtractor)
+
+
+if __name__ == "__main__":
+    tmp_folder = Path("./tmp_pipeline_output")
+    if tmp_folder.is_dir():
+        shutil.rmtree(tmp_folder)
+    tmp_folder.mkdir()
+
+    recording, sorting = _generate_gt_recording()
+
+    print("TEST PREPROCESSING")
+    test_preprocessing(tmp_folder, (recording, sorting))
+    print("TEST SPIKESORTING")
+    test_spikesorting(tmp_folder, (recording, sorting))
+    print("TEST POSTPROCESSING")
+    test_postprocessing(tmp_folder, (recording, sorting))
+
+    print("TEST PIPELINE")
+    test_pipeline(tmp_folder, (recording, sorting))
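Putting it together, the intended entry point is the top-level `run_pipeline` export. A hedged end-to-end sketch mirroring `test_pipeline` above, on the same synthetic recording the tests generate (assumes kilosort2_5 is installed):

```python
from pathlib import Path
import spikeinterface as si
from spikeinterface_pipelines import run_pipeline
from spikeinterface_pipelines.spikesorting import SpikeSortingParams
from spikeinterface_pipelines.spikesorting.params import Kilosort25Model

# Synthetic ground-truth recording, as in tests/test_pipeline.py.
recording, _ = si.generate_ground_truth_recording(durations=[30], num_channels=64, seed=0)

recording_processed, sorting, waveform_extractor = run_pipeline(
    recording=recording,
    results_folder=Path("./results/"),
    scratch_folder=Path("./scratch/"),
    spikesorting_params=SpikeSortingParams(
        sorter_name="kilosort2_5",  # requires Kilosort 2.5 to be installed
        sorter_kwargs=Kilosort25Model(do_correction=False),
    ),
)
print(f"{len(sorting.unit_ids)} units after sorting and de-duplication")
```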