Add visualization step #10

Merged (14 commits, Feb 6, 2024)
5 changes: 5 additions & 0 deletions .github/workflows/ci-test.yml
@@ -6,6 +6,10 @@ on:
    branches:
      - main

env: # For the sortingview backend
  KACHERY_CLOUD_CLIENT_ID: ${{ secrets.KACHERY_CLOUD_CLIENT_ID }}
  KACHERY_CLOUD_PRIVATE_KEY: ${{ secrets.KACHERY_CLOUD_PRIVATE_KEY }}

concurrency: # Cancel previous workflows on the same pull request
  group: ${{ github.workflow }}-${{ github.ref }}
  cancel-in-progress: true
@@ -27,6 +31,7 @@ jobs:
        run: |
          python -m pip install -U pip # Official recommended way
          pip install pytest
          pip install pyvips
          pip install -e .
      - name: Test pipeline with pytest
        run: |
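The two secrets exported in the `env` block above are the credentials that kachery-cloud, which backs the sortingview visualizations, reads from the environment. A minimal pre-flight check, as an illustrative sketch (the variable names come from the workflow; the check itself is not part of the pipeline):

```python
import os

# Illustrative pre-flight check: sortingview uploads go through kachery-cloud,
# which authenticates via these environment variables (populated from
# repository secrets in the CI workflow above).
missing = [
    name
    for name in ("KACHERY_CLOUD_CLIENT_ID", "KACHERY_CLOUD_PRIVATE_KEY")
    if not os.environ.get(name)
]
if missing:
    raise RuntimeError(f"Missing kachery-cloud credentials: {', '.join(missing)}")
```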
8 changes: 6 additions & 2 deletions pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "spikeinterface_pipelines"
-version = "0.0.6"
+version = "0.0.8"
description = "Collection of standardized analysis pipelines based on SpikeInterface."
readme = "README.md"
authors = [
@@ -9,7 +9,11 @@ authors = [
    { name = "Luiz Tauffer", email = "[email protected]" },
]
requires-python = ">=3.8"
-dependencies = ["spikeinterface[full]", "neo>=0.12.0", "pydantic>=2.4.2"]
+dependencies = [
+    "spikeinterface[full,widgets]>=0.100.0",
+    "neo>=0.12.0",
+    "pydantic>=2.4.2",
+]
keywords = [
    "spikeinterface",
    "spike sorting",
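The dependency bump pins spikeinterface to at least 0.100.0 and adds the `widgets` extra that the new visualization step relies on. A runtime guard mirroring that constraint, as an illustrative sketch (assumes the `packaging` library is available):

```python
from importlib.metadata import version

from packaging.version import Version  # assumption: packaging is installed

# Mirror the pyproject constraint at runtime.
installed = Version(version("spikeinterface"))
if installed < Version("0.100.0"):
    raise ImportError(f"spikeinterface>=0.100.0 required, found {installed}")
```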
2 changes: 1 addition & 1 deletion src/spikeinterface_pipelines/curation/curation.py
@@ -4,7 +4,7 @@

import numpy as np

-import spikeinterface.full as si
+import spikeinterface as si
import spikeinterface.curation as sc

from ..logger import logger
50 changes: 45 additions & 5 deletions src/spikeinterface_pipelines/pipeline.py
@@ -1,5 +1,6 @@
from __future__ import annotations
from pathlib import Path
import re
from typing import Tuple
import spikeinterface as si

@@ -8,6 +9,8 @@
from .preprocessing import preprocess, PreprocessingParams
from .spikesorting import spikesort, SpikeSortingParams
from .postprocessing import postprocess, PostprocessingParams
from .curation import curate, CurationParams
from .visualization import visualize, VisualizationParams


def run_pipeline(
@@ -18,13 +21,19 @@ def run_pipeline(
    preprocessing_params: PreprocessingParams | dict = PreprocessingParams(),
    spikesorting_params: SpikeSortingParams | dict = SpikeSortingParams(),
    postprocessing_params: PostprocessingParams | dict = PostprocessingParams(),
    curation_params: CurationParams | dict = CurationParams(),
    visualization_params: VisualizationParams | dict = VisualizationParams(),
    run_preprocessing: bool = True,
    run_spikesorting: bool = True,
    run_postprocessing: bool = True,
    run_curation: bool = True,
    run_visualization: bool = True,
) -> Tuple[
    si.BaseRecording | None,
    si.BaseSorting | None,
-   si.WaveformExtractor | None
+   si.WaveformExtractor | None,
+   si.BaseSorting | None,
+   dict | None,
]:
    # Create folders
    results_folder = Path(results_folder)
@@ -36,6 +45,8 @@
    results_folder_preprocessing = results_folder / "preprocessing"
    results_folder_spikesorting = results_folder / "spikesorting"
    results_folder_postprocessing = results_folder / "postprocessing"
    results_folder_curation = results_folder / "curation"
    results_folder_visualization = results_folder / "visualization"

    # Argument model validation, in case dicts were passed
    if isinstance(job_kwargs, dict):
@@ -46,6 +57,10 @@
        spikesorting_params = SpikeSortingParams(**spikesorting_params)
    if isinstance(postprocessing_params, dict):
        postprocessing_params = PostprocessingParams(**postprocessing_params)
    if isinstance(curation_params, dict):
        curation_params = CurationParams(**curation_params)
    if isinstance(visualization_params, dict):
        visualization_params = VisualizationParams(**visualization_params)

    # set global job kwargs
    si.set_global_job_kwargs(**job_kwargs.model_dump())
@@ -77,6 +92,7 @@
            raise Exception("Spike sorting failed")

        # Postprocessing
        sorting_curated = sorting
        if run_postprocessing:
            logger.info("Postprocessing sorting")
            waveform_extractor = postprocess(
@@ -86,16 +102,40 @@
                scratch_folder=scratch_folder,
                results_folder=results_folder_postprocessing,
            )

            # Curation
            if run_curation:
                logger.info("Curating sorting")
                sorting_curated = curate(
                    waveform_extractor=waveform_extractor,
                    curation_params=curation_params,
                    scratch_folder=scratch_folder,
                    results_folder=results_folder_curation,
                )
            else:
                logger.info("Skipping curation")
        else:
            logger.info("Skipping postprocessing")
            waveform_extractor = None

    else:
        logger.info("Skipping spike sorting")
        sorting = None
        waveform_extractor = None
        sorting_curated = None


-    # TODO: Curation
-
-    # TODO: Visualization
+    # Visualization
    visualization_output = None
    if run_visualization:
        logger.info("Visualizing results")
        visualization_output = visualize(
            recording=recording_preprocessed,
            sorting_curated=sorting_curated,
            waveform_extractor=waveform_extractor,
            visualization_params=visualization_params,
            scratch_folder=scratch_folder,
            results_folder=results_folder_visualization,
        )

-    return (recording_preprocessed, sorting, waveform_extractor)
+    return (recording_preprocessed, sorting, waveform_extractor, sorting_curated, visualization_output)
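With curation and visualization wired in, `run_pipeline` now returns a 5-tuple. A usage sketch; the toy-recording construction, the leading `recording` argument, and the package-root import are assumptions, since this diff does not show the full signature:

```python
from pathlib import Path

from spikeinterface.core import generate_ground_truth_recording
from spikeinterface_pipelines import run_pipeline  # assumption: re-exported at package root

# Hypothetical toy input; any si.BaseRecording works here.
recording, _ = generate_ground_truth_recording(durations=[10.0], num_channels=4, seed=0)

# All *_params arguments accept plain dicts, which are validated into the
# corresponding pydantic models (see the isinstance checks above).
recording_preprocessed, sorting, waveform_extractor, sorting_curated, visualization_output = run_pipeline(
    recording=recording,
    results_folder=Path("results"),
    scratch_folder=Path("scratch"),
    visualization_params={"sorting_summary": {"curation": False}},
    run_curation=True,
    run_visualization=True,
)
```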
2 changes: 2 additions & 0 deletions src/spikeinterface_pipelines/visualization/__init__.py
@@ -0,0 +1,2 @@
from .visualization import visualize
from .params import VisualizationParams
108 changes: 108 additions & 0 deletions src/spikeinterface_pipelines/visualization/params.py
@@ -0,0 +1,108 @@
from pydantic import BaseModel, Field
from typing import Literal, Union


class TracesParams(BaseModel):
    """
    Traces parameters.
    """
    n_snippets_per_segment: int = Field(default=2, description="Number of snippets per segment to visualize.")
    snippet_duration_s: float = Field(default=0.5, description="Duration of each snippet in seconds.")
    skip: bool = Field(default=False, description="Skip traces visualization.")


class DetectionParams(BaseModel):
    """
    Detection parameters.
    """
    peak_sign: Literal["neg", "pos", "both"] = Field(default="neg", description="Peak sign for peak detection.")
    detect_threshold: float = Field(default=5.0, description="Threshold for peak detection.")
    exclude_sweep_ms: float = Field(default=0.1, description="Exclude sweep in ms around peak detection.")


class LocalizationParams(BaseModel):
    """
    Localization parameters.
    """
    ms_before: float = Field(default=0.1, description="Time before peak in ms.")
    ms_after: float = Field(default=0.3, description="Time after peak in ms.")
    radius_um: float = Field(default=100.0, description="Radius in um for sparsifying waveforms before localization.")


class DriftParams(BaseModel):
    """
    Drift parameters.
    """
    detection: DetectionParams = Field(
        default=DetectionParams(),
        description="Detection parameters (only used if spike localization was not performed in postprocessing).",
    )
    localization: LocalizationParams = Field(
        default=LocalizationParams(),
        description="Localization parameters (only used if spike localization was not performed in postprocessing).",
    )
    decimation_factor: int = Field(
        default=30,
        description="The decimation factor for drift visualization. E.g. 30 means that 1 out of 30 spikes is plotted.",
    )
    alpha: float = Field(default=0.15, description="Alpha for scatter plot.")
    vmin: float = Field(default=-200, description="Min value for colormap.")
    vmax: float = Field(default=0, description="Max value for colormap.")
    cmap: str = Field(default="Greys_r", description="Matplotlib colormap for drift visualization.")
    figsize: Union[list, tuple] = Field(default=(10, 10), description="Figure size for drift visualization.")


class SortingSummaryVisualizationParams(BaseModel):
    """
    Sorting summary visualization parameters.
    """
    unit_table_properties: list = Field(
        default=["default_qc"],
        description="List of properties to show in the unit table.",
    )
    curation: bool = Field(
        default=True,
        description="Whether to show curation buttons.",
    )
    label_choices: list = Field(
        default=["SUA", "MUA", "noise"],
        description="List of labels to choose from (if `curation=True`).",
    )
    label: str = Field(
        default="Sorting summary from SI pipelines",
        description="Label for the sorting summary.",
    )


class RecordingVisualizationParams(BaseModel):
    """
    Recording visualization parameters.
    """
    timeseries: TracesParams = Field(
        default=TracesParams(),
        description="Traces visualization parameters.",
    )
    drift: DriftParams = Field(
        default=DriftParams(),
        description="Drift visualization parameters.",
    )
    label: str = Field(
        default="Recording visualization from SI pipelines",
        description="Label for the recording.",
    )


class VisualizationParams(BaseModel):
    """
    Visualization parameters.
    """
    recording: RecordingVisualizationParams = Field(
        default=RecordingVisualizationParams(),
        description="Recording visualization parameters.",
    )
    sorting_summary: SortingSummaryVisualizationParams = Field(
        default=SortingSummaryVisualizationParams(),
        description="Sorting summary visualization parameters.",
    )
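Since every level here is a pydantic model with defaults, callers can override any nested field from plain dicts, and pydantic coerces them into the corresponding models. A small sketch of the intended usage:

```python
from spikeinterface_pipelines.visualization import VisualizationParams

# Nested dicts are validated and coerced into the models defined above.
params = VisualizationParams(
    recording={"drift": {"decimation_factor": 15, "cmap": "viridis"}},
    sorting_summary={"curation": False},
)

assert params.recording.drift.decimation_factor == 15
assert params.recording.timeseries.n_snippets_per_segment == 2  # untouched default
print(params.model_dump())  # plain-dict view, e.g. for logging or serialization
```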