Skip to content

Commit

Permalink
Merge branch 'main' into spikesort-by-group
Browse files Browse the repository at this point in the history
  • Loading branch information
alejoe91 authored Feb 6, 2024
2 parents 2090c24 + 2e2941a commit a513c55
Show file tree
Hide file tree
Showing 8 changed files with 472 additions and 37 deletions.
5 changes: 5 additions & 0 deletions .github/workflows/ci-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@ on:
branches:
- main

env: # For the sortingview backend
KACHERY_CLOUD_CLIENT_ID: ${{ secrets.KACHERY_CLOUD_CLIENT_ID }}
KACHERY_CLOUD_PRIVATE_KEY: ${{ secrets.KACHERY_CLOUD_PRIVATE_KEY }}

concurrency: # Cancel previous workflows on the same pull request
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true
Expand All @@ -27,6 +31,7 @@ jobs:
run: |
python -m pip install -U pip # Official recommended way
pip install pytest
pip install pyvips
pip install -e .
- name: Test pipeline with pytest
run: |
Expand Down
8 changes: 6 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "spikeinterface_pipelines"
version = "0.0.6"
version = "0.0.8"
description = "Collection of standardized analysis pipelines based on SpikeInterfacee."
readme = "README.md"
authors = [
Expand All @@ -9,7 +9,11 @@ authors = [
{ name = "Luiz Tauffer", email = "[email protected]" },
]
requires-python = ">=3.8"
dependencies = ["spikeinterface[full]", "neo>=0.12.0", "pydantic>=2.4.2"]
dependencies = [
"spikeinterface[full,widgets]>=0.100.0",
"neo>=0.12.0",
"pydantic>=2.4.2",
]
keywords = [
"spikeinterface",
"spike sorting",
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface_pipelines/curation/curation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import numpy as np

import spikeinterface.full as si
import spikeinterface as si
import spikeinterface.curation as sc

from ..logger import logger
Expand Down
50 changes: 45 additions & 5 deletions src/spikeinterface_pipelines/pipeline.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations
from pathlib import Path
import re
from typing import Tuple
import spikeinterface as si

Expand All @@ -8,6 +9,8 @@
from .preprocessing import preprocess, PreprocessingParams
from .spikesorting import spikesort, SpikeSortingParams
from .postprocessing import postprocess, PostprocessingParams
from .curation import curate, CurationParams
from .visualization import visualize, VisualizationParams


def run_pipeline(
Expand All @@ -18,13 +21,19 @@ def run_pipeline(
preprocessing_params: PreprocessingParams | dict = PreprocessingParams(),
spikesorting_params: SpikeSortingParams | dict = SpikeSortingParams(),
postprocessing_params: PostprocessingParams | dict = PostprocessingParams(),
curation_params: CurationParams | dict = CurationParams(),
visualization_params: VisualizationParams | dict = VisualizationParams(),
run_preprocessing: bool = True,
run_spikesorting: bool = True,
run_postprocessing: bool = True,
run_curation: bool = True,
run_visualization: bool = True,
) -> Tuple[
si.BaseRecording | None,
si.BaseSorting | None,
si.WaveformExtractor | None
si.WaveformExtractor | None,
si.BaseSorting | None,
dict | None,
]:
# Create folders
results_folder = Path(results_folder)
Expand All @@ -36,6 +45,8 @@ def run_pipeline(
results_folder_preprocessing = results_folder / "preprocessing"
results_folder_spikesorting = results_folder / "spikesorting"
results_folder_postprocessing = results_folder / "postprocessing"
results_folder_curation = results_folder / "curation"
results_folder_visualization = results_folder / "visualization"

# Arguments Models validation, in case of dict
if isinstance(job_kwargs, dict):
Expand All @@ -46,6 +57,10 @@ def run_pipeline(
spikesorting_params = SpikeSortingParams(**spikesorting_params)
if isinstance(postprocessing_params, dict):
postprocessing_params = PostprocessingParams(**postprocessing_params)
if isinstance(curation_params, dict):
curation_params = CurationParams(**curation_params)
if isinstance(visualization_params, dict):
visualization_params = VisualizationParams(**visualization_params)

# set global job kwargs
si.set_global_job_kwargs(**job_kwargs.model_dump())
Expand Down Expand Up @@ -77,6 +92,7 @@ def run_pipeline(
raise Exception("Spike sorting failed")

# Postprocessing
sorting_curated = sorting
if run_postprocessing:
logger.info("Postprocessing sorting")
waveform_extractor = postprocess(
Expand All @@ -86,16 +102,40 @@ def run_pipeline(
scratch_folder=scratch_folder,
results_folder=results_folder_postprocessing,
)

# Curation
if run_curation:
logger.info("Curating sorting")
sorting_curated = curate(
waveform_extractor=waveform_extractor,
curation_params=curation_params,
scratch_folder=scratch_folder,
results_folder=results_folder_curation,
)
else:
logger.info("Skipping curation")
else:
logger.info("Skipping postprocessing")
waveform_extractor = None

else:
logger.info("Skipping spike sorting")
sorting = None
waveform_extractor = None
sorting_curated = None


# TODO: Curation

# TODO: Visualization
# Visualization
visualization_output = None
if run_visualization:
logger.info("Visualizing results")
visualization_output = visualize(
recording=recording_preprocessed,
sorting_curated=sorting_curated,
waveform_extractor=waveform_extractor,
visualization_params=visualization_params,
scratch_folder=scratch_folder,
results_folder=results_folder_visualization,
)

return (recording_preprocessed, sorting, waveform_extractor)
return (recording_preprocessed, sorting, waveform_extractor, sorting_curated, visualization_output)
2 changes: 2 additions & 0 deletions src/spikeinterface_pipelines/visualization/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .visualization import visualize
from .params import VisualizationParams
108 changes: 108 additions & 0 deletions src/spikeinterface_pipelines/visualization/params.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
from cProfile import label
from pydantic import BaseModel, Field
from typing import Literal, Union

from spikeinterface.widgets import sorting_summary


class TracesParams(BaseModel):
"""
Traces parameters.
"""
n_snippets_per_segment: int = Field(default=2, description="Number of snippets per segment to visualize.")
snippet_duration_s: float = Field(default=0.5, description="Duration of each snippet in seconds.")
skip: bool = Field(default=False, description="Skip traces visualization.")


class DetectionParams(BaseModel):
"""
Detection parameters.
"""
peak_sign: Literal["neg", "pos", "both"] = Field(default="neg", description="Peak sign for peak detection.")
detect_threshold: float = Field(default=5.0, description="Threshold for peak detection.")
exclude_sweep_ms: float = Field(default=0.1, description="Exclude sweep in ms around peak detection.")


class LocalizationParams(BaseModel):
"""
Localization parameters.
"""
ms_before: float = Field(default=0.1, description="Time before peak in ms.")
ms_after: float = Field(default=0.3, description="Time after peak in ms.")
radius_um: float = Field(default=100.0, description="Radius in um for sparsifying waveforms before localization.")


class DriftParams(BaseModel):
"""
Drift parameters.
"""
detection: DetectionParams = Field(
default=DetectionParams(),
description="Detection parameters (only used if spike localization was not performed in postprocessing)"
)
localization: LocalizationParams = Field(
default=LocalizationParams(),
description="Localization parameters (only used if spike localization was not performed in postprocessing)"
)
decimation_factor: int = Field(
default=30,
description="The decimation factor for drift visualization. E.g. 30 means that 1 out of 30 spikes is plotted."
)
alpha: float = Field(default=0.15, description="Alpha for scatter plot.")
vmin: float = Field(default=-200, description="Min value for colormap.")
vmax: float = Field(default=0, description="Max value for colormap.")
cmap: str = Field(default="Greys_r", description="Matplotlib colormap for drift visualization.")
figsize: Union[list, tuple] = Field(default=(10, 10), description="Figure size for drift visualization.")


class SortingSummaryVisualizationParams(BaseModel):
"""
Sorting summary visualization parameters.
"""
unit_table_properties: list = Field(
default=["default_qc"],
description="List of properties to show in the unit table."
)
curation: bool = Field(
default=True,
description="Whether to show curation buttons."
)
label_choices: list = Field(
default=["SUA", "MUA", "noise"],
description="List of labels to choose from (if `curation=True`)"
)
label: str = Field(
default="Sorting summary from SI pipelines",
description="Label for the sorting summary."
)


class RecordingVisualizationParams(BaseModel):
"""
Recording visualization parameters.
"""
timeseries: TracesParams = Field(
default=TracesParams(),
description="Traces visualization parameters."
)
drift: DriftParams = Field(
default=DriftParams(),
description="Drift visualization parameters."
)
label: str = Field(
default="Recording visualization from SI pipelines",
description="Label for the recording."
)

class VisualizationParams(BaseModel):
"""
Visualization parameters.
"""
recording: RecordingVisualizationParams = Field(
default=RecordingVisualizationParams(),
description="Recording visualization parameters."
)
sorting_summary: SortingSummaryVisualizationParams = Field(
default=SortingSummaryVisualizationParams(),
description="Sorting summary visualization parameters."
)
Loading

0 comments on commit a513c55

Please sign in to comment.