Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

move peak_pipeline into core and rename it as node_pipeline. #1941

Merged
merged 17 commits into from
Sep 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
8d5e408
move peak_pipeline into core and rename it as node_pipeline.
samuelgarcia Aug 28, 2023
9ee00db
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 28, 2023
a516c63
oups
samuelgarcia Aug 28, 2023
3823bdb
Merge branch 'pipeline_in_core' of github.com:samuelgarcia/spikeinter…
samuelgarcia Aug 28, 2023
e7a4c86
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 28, 2023
da7a68b
remove scipy from core test
samuelgarcia Aug 29, 2023
6643ebe
Merge branch 'pipeline_in_core' of github.com:samuelgarcia/spikeinter…
samuelgarcia Aug 29, 2023
e8bae07
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 29, 2023
85809b6
Merge branch 'main' of github.com:SpikeInterface/spikeinterface into …
samuelgarcia Sep 1, 2023
78cc771
Merge branch 'generator' of github.com:samuelgarcia/spikeinterface in…
samuelgarcia Sep 1, 2023
b50bc90
Remove download from test_node_pipeline.py when in core.
samuelgarcia Sep 1, 2023
d07da4f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 1, 2023
4acf5c3
Merge branch 'generator' of github.com:samuelgarcia/spikeinterface in…
samuelgarcia Sep 1, 2023
2195df5
Merge branch 'pipeline_in_core' of github.com:samuelgarcia/spikeinter…
samuelgarcia Sep 1, 2023
8782d5f
Merge branch 'main' of github.com:SpikeInterface/spikeinterface into …
samuelgarcia Sep 1, 2023
748751d
remove test_peak_pipepeline.py from components (this is now in core)
samuelgarcia Sep 1, 2023
0ee1d11
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 1, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
605 changes: 605 additions & 0 deletions src/spikeinterface/core/node_pipeline.py

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,25 @@
from pathlib import Path
import shutil

import scipy.signal
from spikeinterface import extract_waveforms, get_template_extremum_channel, generate_ground_truth_recording

from spikeinterface import download_dataset, BaseSorting
from spikeinterface.extractors import MEArecRecordingExtractor
# from spikeinterface.extractors import MEArecRecordingExtractor
from spikeinterface.extractors import read_mearec

from spikeinterface.sortingcomponents.peak_detection import detect_peaks
from spikeinterface.sortingcomponents.peak_pipeline import (
# from spikeinterface.sortingcomponents.peak_detection import detect_peaks
from spikeinterface.core.node_pipeline import (
run_node_pipeline,
PeakRetriever,
PipelineNode,
ExtractDenseWaveforms,
ExtractSparseWaveforms,
base_peak_dtype,
)


if hasattr(pytest, "global_test_folder"):
cache_folder = pytest.global_test_folder / "sortingcomponents"
cache_folder = pytest.global_test_folder / "core"
else:
cache_folder = Path("cache_folder") / "sortingcomponents"
cache_folder = Path("cache_folder") / "core"


class AmplitudeExtractionNode(PipelineNode):
Expand Down Expand Up @@ -51,8 +51,8 @@ def get_dtype(self):
return np.dtype("float32")

def compute(self, traces, peaks, waveforms):
kernel = np.array([0.1, 0.8, 0.1])[np.newaxis, :, np.newaxis]
denoised_waveforms = scipy.signal.fftconvolve(waveforms, kernel, axes=1, mode="same")
kernel = np.array([0.1, 0.8, 0.1])
denoised_waveforms = np.apply_along_axis(lambda m: np.convolve(m, kernel, mode="same"), axis=1, arr=waveforms)
return denoised_waveforms


Expand All @@ -69,16 +69,23 @@ def compute(self, traces, peaks, waveforms):


def test_run_node_pipeline():
repo = "https://gin.g-node.org/NeuralEnsemble/ephy_testing_data"
remote_path = "mearec/mearec_test_10s.h5"
local_path = download_dataset(repo=repo, remote_path=remote_path, local_folder=None)
recording = MEArecRecordingExtractor(local_path)
recording, sorting = generate_ground_truth_recording(num_channels=10, num_units=10, durations=[10.0])

job_kwargs = dict(chunk_duration="0.5s", n_jobs=2, progress_bar=False)

peaks = detect_peaks(
recording, method="locally_exclusive", peak_sign="neg", detect_threshold=5, exclude_sweep_ms=0.1, **job_kwargs
)
spikes = sorting.to_spike_vector()

# create peaks from spikes
we = extract_waveforms(recording, sorting, mode="memory", **job_kwargs)
extremum_channel_inds = get_template_extremum_channel(we, peak_sign="neg", outputs="index")
# print(extremum_channel_inds)
ext_channel_inds = np.array([extremum_channel_inds[unit_id] for unit_id in sorting.unit_ids])
# print(ext_channel_inds)
peaks = np.zeros(spikes.size, dtype=base_peak_dtype)
peaks["sample_index"] = spikes["sample_index"]
peaks["channel_index"] = ext_channel_inds[spikes["unit_index"]]
peaks["amplitude"] = 0.0
peaks["segment_index"] = 0

# one step only : squeeze output
peak_retriever = PeakRetriever(recording, peaks)
Expand All @@ -93,19 +100,19 @@ def test_run_node_pipeline():
ms_before = 0.5
ms_after = 1.0
peak_retriever = PeakRetriever(recording, peaks)
extract_waveforms = ExtractDenseWaveforms(
dense_waveforms = ExtractDenseWaveforms(
recording, parents=[peak_retriever], ms_before=ms_before, ms_after=ms_after, return_output=False
)
waveform_denoiser = WaveformDenoiser(recording, parents=[peak_retriever, extract_waveforms], return_output=False)
waveform_denoiser = WaveformDenoiser(recording, parents=[peak_retriever, dense_waveforms], return_output=False)
amplitue_extraction = AmplitudeExtractionNode(recording, parents=[peak_retriever], param0=6.6, return_output=True)
waveforms_rms = WaveformsRootMeanSquare(recording, parents=[peak_retriever, extract_waveforms], return_output=True)
waveforms_rms = WaveformsRootMeanSquare(recording, parents=[peak_retriever, dense_waveforms], return_output=True)
denoised_waveforms_rms = WaveformsRootMeanSquare(
recording, parents=[peak_retriever, waveform_denoiser], return_output=True
)

nodes = [
peak_retriever,
extract_waveforms,
dense_waveforms,
waveform_denoiser,
amplitue_extraction,
waveforms_rms,
Expand All @@ -129,6 +136,7 @@ def test_run_node_pipeline():
folder = cache_folder / "pipeline_folder"
if folder.is_dir():
shutil.rmtree(folder)

output = run_node_pipeline(
recording,
nodes,
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/preprocessing/motion.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ def correct_motion(
from spikeinterface.sortingcomponents.peak_localization import localize_peaks, localize_peak_methods
from spikeinterface.sortingcomponents.motion_estimation import estimate_motion
from spikeinterface.sortingcomponents.motion_interpolation import InterpolateMotionRecording
from spikeinterface.sortingcomponents.peak_pipeline import ExtractDenseWaveforms, run_node_pipeline
from spikeinterface.core.node_pipeline import ExtractDenseWaveforms, run_node_pipeline

# get preset params and update if necessary
params = motion_options_preset[preset]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from spikeinterface.core.job_tools import fix_job_kwargs
from spikeinterface.core import get_channel_distances
from spikeinterface.sortingcomponents.peak_localization import LocalizeCenterOfMass, LocalizeMonopolarTriangulation
from spikeinterface.sortingcomponents.peak_pipeline import (
from spikeinterface.core.node_pipeline import (
run_node_pipeline,
PeakRetriever,
PipelineNode,
Expand Down
9 changes: 7 additions & 2 deletions src/spikeinterface/sortingcomponents/peak_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,16 @@
from spikeinterface.core.recording_tools import get_noise_levels, get_channel_distances

from spikeinterface.core.baserecording import BaseRecording
from spikeinterface.sortingcomponents.peak_pipeline import PeakDetector, WaveformsNode, ExtractSparseWaveforms
from spikeinterface.core.node_pipeline import (
PeakDetector,
WaveformsNode,
ExtractSparseWaveforms,
run_node_pipeline,
base_peak_dtype,
)

from ..core import get_chunk_with_margin

from .peak_pipeline import PeakDetector, run_node_pipeline, base_peak_dtype
from .tools import make_multi_method_doc

try:
Expand Down
3 changes: 2 additions & 1 deletion src/spikeinterface/sortingcomponents/peak_localization.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
import numpy as np
from spikeinterface.core.job_tools import _shared_job_kwargs_doc, split_job_kwargs, fix_job_kwargs

from .peak_pipeline import (

from spikeinterface.core.node_pipeline import (
run_node_pipeline,
find_parent_of_type,
PeakRetriever,
Expand Down
Loading