-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #2 from SpikeInterface/porting-code
porting code for pipeline modules
- Loading branch information
Showing
18 changed files
with
918 additions
and
151 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
name: Testing pipeline | ||
|
||
on: | ||
pull_request: | ||
types: [synchronize, opened, reopened] | ||
branches: | ||
- main | ||
|
||
concurrency: # Cancel previous workflows on the same pull request | ||
group: ${{ github.workflow }}-${{ github.ref }} | ||
cancel-in-progress: true | ||
|
||
jobs: | ||
build-and-test: | ||
name: Test on ${{ matrix.os }} OS | ||
runs-on: ${{ matrix.os }} | ||
strategy: | ||
fail-fast: false | ||
matrix: | ||
os: ["ubuntu-latest", "macos-latest", "windows-latest"] | ||
steps: | ||
- uses: actions/checkout@v3 | ||
- uses: actions/setup-python@v4 | ||
with: | ||
python-version: '3.10' | ||
- name: Install dependencies | ||
run: | | ||
python -m pip install -U pip # Official recommended way | ||
pip install pytest | ||
pip install -e . | ||
- name: Test pipeline with pytest | ||
run: | | ||
pytest -v | ||
shell: bash # Necessary for pipeline to work on windows |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
name: Build and publish spikeinterface_pipelines Pyton package | ||
|
||
on: | ||
push: | ||
branches: | ||
- main | ||
paths: | ||
- src/spikeinterface_pipelines/** | ||
|
||
jobs: | ||
pypi-release: | ||
name: PyPI release | ||
runs-on: ubuntu-latest | ||
steps: | ||
- uses: actions/checkout@v3 | ||
- name: Set up Python | ||
uses: actions/setup-python@v3 | ||
with: | ||
python-version: "3.8" | ||
- name: Install dependencies | ||
run: | | ||
python -m pip install --upgrade pip | ||
pip install build | ||
pip install twine | ||
- name: Build and publish to PyPI | ||
run: | | ||
python -m build | ||
twine upload dist/* | ||
env: | ||
TWINE_USERNAME: __token__ | ||
TWINE_PASSWORD: ${{ secrets.PYPI_TOKEN }} |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
[project] | ||
name = "spikeinterface_pipelines" | ||
version = "0.0.2" | ||
description = "Collection of standardized analysis pipelines based on SpikeInterfacee." | ||
readme = "README.md" | ||
authors = [ | ||
{ name = "Alessio Buccino", email = "[email protected]" }, | ||
{ name = "Jeremy Magland", email = "[email protected]" }, | ||
{ name = "Luiz Tauffer", email = "[email protected]" }, | ||
] | ||
requires-python = ">=3.8" | ||
dependencies = ["spikeinterface[full]", "neo>=0.12.0", "pydantic>=2.4.2"] | ||
keywords = [ | ||
"spikeinterface", | ||
"spike sorting", | ||
"electrophysiology", | ||
"neuroscience", | ||
] | ||
|
||
[project.urls] | ||
homepage = "https://github.com/SpikeInterface/spikeinterface_pipelines" | ||
documentation = "https://github.com/SpikeInterface/spikeinterface_pipelines" | ||
repository = "https://github.com/SpikeInterface/spikeinterface_pipelines" | ||
|
||
[build-system] | ||
requires = ["setuptools"] | ||
build-backend = "setuptools.build_meta" | ||
|
||
[tool.setuptools] | ||
package-dir = { "" = "src" } | ||
|
||
[tool.setuptools.packages.find] | ||
where = ["src"] | ||
|
||
[tool.black] | ||
line-length = 120 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .pipeline import run_pipeline |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
from pydantic import BaseModel, Field | ||
|
||
|
||
class JobKwargs(BaseModel): | ||
n_jobs: int = Field(default=-1, description="The number of jobs to run in parallel.") | ||
chunk_duration: str = Field(default="1s", description="The duration of the chunks to process.") | ||
progress_bar: bool = Field(default=False, description="Whether to display a progress bar.") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
import logging | ||
|
||
logging.basicConfig(level=logging.INFO) | ||
logger = logging.getLogger(__name__) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
from pathlib import Path | ||
import re | ||
from typing import Tuple | ||
|
||
import spikeinterface as si | ||
|
||
from .logger import logger | ||
from .global_params import JobKwargs | ||
from .preprocessing import preprocess, PreprocessingParams | ||
from .spikesorting import spikesort, SpikeSortingParams | ||
from .postprocessing import postprocess, PostprocessingParams | ||
|
||
|
||
# TODO - WIP | ||
def run_pipeline( | ||
recording: si.BaseRecording, | ||
scratch_folder: Path = Path("./scratch/"), | ||
results_folder: Path = Path("./results/"), | ||
job_kwargs: JobKwargs = JobKwargs(), | ||
preprocessing_params: PreprocessingParams = PreprocessingParams(), | ||
spikesorting_params: SpikeSortingParams = SpikeSortingParams(), | ||
postprocessing_params: PostprocessingParams = PostprocessingParams(), | ||
run_preprocessing: bool = True, | ||
) -> Tuple[si.BaseRecording, si.BaseSorting, si.WaveformExtractor]: | ||
# Create folders | ||
scratch_folder.mkdir(exist_ok=True, parents=True) | ||
results_folder.mkdir(exist_ok=True, parents=True) | ||
|
||
# Paths | ||
results_folder_preprocessing = results_folder / "preprocessing" | ||
results_folder_spikesorting = results_folder / "spikesorting" | ||
results_folder_postprocessing = results_folder / "postprocessing" | ||
|
||
# set global job kwargs | ||
si.set_global_job_kwargs(**job_kwargs.model_dump()) | ||
|
||
# Preprocessing | ||
if run_preprocessing: | ||
logger.info("Preprocessing recording") | ||
recording_preprocessed = preprocess( | ||
recording=recording, | ||
preprocessing_params=preprocessing_params, | ||
scratch_folder=scratch_folder, | ||
results_folder=results_folder_preprocessing, | ||
) | ||
if recording_preprocessed is None: | ||
raise Exception("Preprocessing failed") | ||
else: | ||
logger.info("Skipping preprocessing") | ||
recording_preprocessed = recording | ||
|
||
# Spike Sorting | ||
sorting = spikesort( | ||
recording=recording_preprocessed, | ||
scratch_folder=scratch_folder, | ||
spikesorting_params=spikesorting_params, | ||
results_folder=results_folder_spikesorting, | ||
) | ||
if sorting is None: | ||
raise Exception("Spike sorting failed") | ||
|
||
# Postprocessing | ||
waveform_extractor = postprocess( | ||
recording=recording_preprocessed, | ||
sorting=sorting, | ||
postprocessing_params=postprocessing_params, | ||
scratch_folder=scratch_folder, | ||
results_folder=results_folder_postprocessing, | ||
) | ||
|
||
# TODO: Curation | ||
|
||
# TODO: Visualization | ||
|
||
return (recording_preprocessed, sorting, waveform_extractor) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from .postprocessing import postprocess | ||
from .params import PostprocessingParams |
Oops, something went wrong.