Merge pull request #2 from SpikeInterface/porting-code
porting code for pipeline modules
luiztauffer authored Nov 15, 2023
2 parents b3563e2 + 7d7c26d commit cae6fd4
Showing 18 changed files with 918 additions and 151 deletions.
34 changes: 34 additions & 0 deletions .github/workflows/ci-test.yml
@@ -0,0 +1,34 @@
name: Testing pipeline

on:
  pull_request:
    types: [synchronize, opened, reopened]
    branches:
      - main

concurrency: # Cancel previous workflows on the same pull request
  group: ${{ github.workflow }}-${{ github.ref }}
  cancel-in-progress: true

jobs:
  build-and-test:
    name: Test on ${{ matrix.os }} OS
    runs-on: ${{ matrix.os }}
    strategy:
      fail-fast: false
      matrix:
        os: ["ubuntu-latest", "macos-latest", "windows-latest"]
    steps:
      - uses: actions/checkout@v3
      - uses: actions/setup-python@v4
        with:
          python-version: '3.10'
      - name: Install dependencies
        run: |
          python -m pip install -U pip # Official recommended way
          pip install pytest
          pip install -e .
      - name: Test pipeline with pytest
        run: |
          pytest -v
        shell: bash # Necessary for pipeline to work on Windows
31 changes: 31 additions & 0 deletions .github/workflows/pypi_release.yaml
@@ -0,0 +1,31 @@
name: Build and publish spikeinterface_pipelines Python package

on:
  push:
    branches:
      - main
    paths:
      - src/spikeinterface_pipelines/**

jobs:
  pypi-release:
    name: PyPI release
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v3
      - name: Set up Python
        uses: actions/setup-python@v3
        with:
          python-version: "3.8"
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install build
          pip install twine
      - name: Build and publish to PyPI
        run: |
          python -m build
          twine upload dist/*
        env:
          TWINE_USERNAME: __token__
          TWINE_PASSWORD: ${{ secrets.PYPI_TOKEN }}
151 changes: 0 additions & 151 deletions in_progress/hd_shank_preprocessing.py

This file was deleted.

36 changes: 36 additions & 0 deletions pyproject.toml
@@ -0,0 +1,36 @@
[project]
name = "spikeinterface_pipelines"
version = "0.0.2"
description = "Collection of standardized analysis pipelines based on SpikeInterface."
readme = "README.md"
authors = [
    { name = "Alessio Buccino", email = "[email protected]" },
    { name = "Jeremy Magland", email = "[email protected]" },
    { name = "Luiz Tauffer", email = "[email protected]" },
]
requires-python = ">=3.8"
dependencies = ["spikeinterface[full]", "neo>=0.12.0", "pydantic>=2.4.2"]
keywords = [
    "spikeinterface",
    "spike sorting",
    "electrophysiology",
    "neuroscience",
]

[project.urls]
homepage = "https://github.com/SpikeInterface/spikeinterface_pipelines"
documentation = "https://github.com/SpikeInterface/spikeinterface_pipelines"
repository = "https://github.com/SpikeInterface/spikeinterface_pipelines"

[build-system]
requires = ["setuptools"]
build-backend = "setuptools.build_meta"

[tool.setuptools]
package-dir = { "" = "src" }

[tool.setuptools.packages.find]
where = ["src"]

[tool.black]
line-length = 120
1 change: 1 addition & 0 deletions src/spikeinterface_pipelines/__init__.py
@@ -0,0 +1 @@
from .pipeline import run_pipeline
7 changes: 7 additions & 0 deletions src/spikeinterface_pipelines/global_params.py
@@ -0,0 +1,7 @@
from pydantic import BaseModel, Field


class JobKwargs(BaseModel):
    n_jobs: int = Field(default=-1, description="The number of jobs to run in parallel.")
    chunk_duration: str = Field(default="1s", description="The duration of the chunks to process.")
    progress_bar: bool = Field(default=False, description="Whether to display a progress bar.")
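For reference, a minimal usage sketch (not part of this commit): the pydantic model validates the job settings, and model_dump() turns them into the plain kwargs that spikeinterface's set_global_job_kwargs expects, mirroring the call made in pipeline.py below.

import spikeinterface as si

from spikeinterface_pipelines.global_params import JobKwargs

# Validate parallelization settings, then hand them to spikeinterface
# as plain kwargs (this mirrors the call in pipeline.py below).
job_kwargs = JobKwargs(n_jobs=4, chunk_duration="2s", progress_bar=True)
si.set_global_job_kwargs(**job_kwargs.model_dump())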
4 changes: 4 additions & 0 deletions src/spikeinterface_pipelines/logger.py
@@ -0,0 +1,4 @@
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
75 changes: 75 additions & 0 deletions src/spikeinterface_pipelines/pipeline.py
@@ -0,0 +1,75 @@
from pathlib import Path
from typing import Tuple

import spikeinterface as si

from .logger import logger
from .global_params import JobKwargs
from .preprocessing import preprocess, PreprocessingParams
from .spikesorting import spikesort, SpikeSortingParams
from .postprocessing import postprocess, PostprocessingParams


# TODO - WIP
def run_pipeline(
    recording: si.BaseRecording,
    scratch_folder: Path = Path("./scratch/"),
    results_folder: Path = Path("./results/"),
    job_kwargs: JobKwargs = JobKwargs(),
    preprocessing_params: PreprocessingParams = PreprocessingParams(),
    spikesorting_params: SpikeSortingParams = SpikeSortingParams(),
    postprocessing_params: PostprocessingParams = PostprocessingParams(),
    run_preprocessing: bool = True,
) -> Tuple[si.BaseRecording, si.BaseSorting, si.WaveformExtractor]:
    # Create folders
    scratch_folder.mkdir(exist_ok=True, parents=True)
    results_folder.mkdir(exist_ok=True, parents=True)

    # Paths
    results_folder_preprocessing = results_folder / "preprocessing"
    results_folder_spikesorting = results_folder / "spikesorting"
    results_folder_postprocessing = results_folder / "postprocessing"

    # Set global job kwargs
    si.set_global_job_kwargs(**job_kwargs.model_dump())

    # Preprocessing
    if run_preprocessing:
        logger.info("Preprocessing recording")
        recording_preprocessed = preprocess(
            recording=recording,
            preprocessing_params=preprocessing_params,
            scratch_folder=scratch_folder,
            results_folder=results_folder_preprocessing,
        )
        if recording_preprocessed is None:
            raise Exception("Preprocessing failed")
    else:
        logger.info("Skipping preprocessing")
        recording_preprocessed = recording

    # Spike Sorting
    sorting = spikesort(
        recording=recording_preprocessed,
        scratch_folder=scratch_folder,
        spikesorting_params=spikesorting_params,
        results_folder=results_folder_spikesorting,
    )
    if sorting is None:
        raise Exception("Spike sorting failed")

    # Postprocessing
    waveform_extractor = postprocess(
        recording=recording_preprocessed,
        sorting=sorting,
        postprocessing_params=postprocessing_params,
        scratch_folder=scratch_folder,
        results_folder=results_folder_postprocessing,
    )

    # TODO: Curation

    # TODO: Visualization

    return (recording_preprocessed, sorting, waveform_extractor)
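For orientation, a hypothetical usage sketch of run_pipeline (not part of this commit): it assumes spikeinterface's generate_recording helper for a synthetic recording, and that a sorter matching the default SpikeSortingParams is installed.

from pathlib import Path

import spikeinterface as si

from spikeinterface_pipelines import run_pipeline

# Short synthetic recording; real use would load data via an extractor.
recording = si.generate_recording(durations=[10.0])

recording_preprocessed, sorting, waveform_extractor = run_pipeline(
    recording=recording,
    scratch_folder=Path("./scratch/"),
    results_folder=Path("./results/"),
    run_preprocessing=False,  # synthetic data has no artifacts to remove
)
print(sorting.get_unit_ids())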
2 changes: 2 additions & 0 deletions src/spikeinterface_pipelines/postprocessing/__init__.py
@@ -0,0 +1,2 @@
from .postprocessing import postprocess
from .params import PostprocessingParams