Skip to content

Commit

Permalink
Merge pull request #1941 from samuelgarcia/pipeline_in_core
Browse files Browse the repository at this point in the history
Move peak_pipeline into core and rename it to node_pipeline.
  • Loading branch information
samuelgarcia authored Sep 1, 2023
2 parents 23aef27 + 0ee1d11 commit 2e549a9
Show file tree
Hide file tree
Showing 18 changed files with 663 additions and 633 deletions.
605 changes: 605 additions & 0 deletions src/spikeinterface/core/node_pipeline.py

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,25 @@
from pathlib import Path
import shutil

import scipy.signal
from spikeinterface import extract_waveforms, get_template_extremum_channel, generate_ground_truth_recording

from spikeinterface import download_dataset, BaseSorting
from spikeinterface.extractors import MEArecRecordingExtractor
# from spikeinterface.extractors import MEArecRecordingExtractor
from spikeinterface.extractors import read_mearec

from spikeinterface.sortingcomponents.peak_detection import detect_peaks
from spikeinterface.sortingcomponents.peak_pipeline import (
# from spikeinterface.sortingcomponents.peak_detection import detect_peaks
from spikeinterface.core.node_pipeline import (
run_node_pipeline,
PeakRetriever,
PipelineNode,
ExtractDenseWaveforms,
ExtractSparseWaveforms,
base_peak_dtype,
)


if hasattr(pytest, "global_test_folder"):
cache_folder = pytest.global_test_folder / "sortingcomponents"
cache_folder = pytest.global_test_folder / "core"
else:
cache_folder = Path("cache_folder") / "sortingcomponents"
cache_folder = Path("cache_folder") / "core"


class AmplitudeExtractionNode(PipelineNode):
Expand Down Expand Up @@ -51,8 +51,8 @@ def get_dtype(self):
return np.dtype("float32")

def compute(self, traces, peaks, waveforms):
kernel = np.array([0.1, 0.8, 0.1])[np.newaxis, :, np.newaxis]
denoised_waveforms = scipy.signal.fftconvolve(waveforms, kernel, axes=1, mode="same")
kernel = np.array([0.1, 0.8, 0.1])
denoised_waveforms = np.apply_along_axis(lambda m: np.convolve(m, kernel, mode="same"), axis=1, arr=waveforms)
return denoised_waveforms


Expand All @@ -69,16 +69,23 @@ def compute(self, traces, peaks, waveforms):


def test_run_node_pipeline():
repo = "https://gin.g-node.org/NeuralEnsemble/ephy_testing_data"
remote_path = "mearec/mearec_test_10s.h5"
local_path = download_dataset(repo=repo, remote_path=remote_path, local_folder=None)
recording = MEArecRecordingExtractor(local_path)
recording, sorting = generate_ground_truth_recording(num_channels=10, num_units=10, durations=[10.0])

job_kwargs = dict(chunk_duration="0.5s", n_jobs=2, progress_bar=False)

peaks = detect_peaks(
recording, method="locally_exclusive", peak_sign="neg", detect_threshold=5, exclude_sweep_ms=0.1, **job_kwargs
)
spikes = sorting.to_spike_vector()

# create peaks from spikes
we = extract_waveforms(recording, sorting, mode="memory", **job_kwargs)
extremum_channel_inds = get_template_extremum_channel(we, peak_sign="neg", outputs="index")
# print(extremum_channel_inds)
ext_channel_inds = np.array([extremum_channel_inds[unit_id] for unit_id in sorting.unit_ids])
# print(ext_channel_inds)
peaks = np.zeros(spikes.size, dtype=base_peak_dtype)
peaks["sample_index"] = spikes["sample_index"]
peaks["channel_index"] = ext_channel_inds[spikes["unit_index"]]
peaks["amplitude"] = 0.0
peaks["segment_index"] = 0

# one step only : squeeze output
peak_retriever = PeakRetriever(recording, peaks)
Expand All @@ -93,19 +100,19 @@ def test_run_node_pipeline():
ms_before = 0.5
ms_after = 1.0
peak_retriever = PeakRetriever(recording, peaks)
extract_waveforms = ExtractDenseWaveforms(
dense_waveforms = ExtractDenseWaveforms(
recording, parents=[peak_retriever], ms_before=ms_before, ms_after=ms_after, return_output=False
)
waveform_denoiser = WaveformDenoiser(recording, parents=[peak_retriever, extract_waveforms], return_output=False)
waveform_denoiser = WaveformDenoiser(recording, parents=[peak_retriever, dense_waveforms], return_output=False)
amplitue_extraction = AmplitudeExtractionNode(recording, parents=[peak_retriever], param0=6.6, return_output=True)
waveforms_rms = WaveformsRootMeanSquare(recording, parents=[peak_retriever, extract_waveforms], return_output=True)
waveforms_rms = WaveformsRootMeanSquare(recording, parents=[peak_retriever, dense_waveforms], return_output=True)
denoised_waveforms_rms = WaveformsRootMeanSquare(
recording, parents=[peak_retriever, waveform_denoiser], return_output=True
)

nodes = [
peak_retriever,
extract_waveforms,
dense_waveforms,
waveform_denoiser,
amplitue_extraction,
waveforms_rms,
Expand All @@ -129,6 +136,7 @@ def test_run_node_pipeline():
folder = cache_folder / "pipeline_folder"
if folder.is_dir():
shutil.rmtree(folder)

output = run_node_pipeline(
recording,
nodes,
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/preprocessing/motion.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ def correct_motion(
from spikeinterface.sortingcomponents.peak_localization import localize_peaks, localize_peak_methods
from spikeinterface.sortingcomponents.motion_estimation import estimate_motion
from spikeinterface.sortingcomponents.motion_interpolation import InterpolateMotionRecording
from spikeinterface.sortingcomponents.peak_pipeline import ExtractDenseWaveforms, run_node_pipeline
from spikeinterface.core.node_pipeline import ExtractDenseWaveforms, run_node_pipeline

# get preset params and update if necessary
params = motion_options_preset[preset]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from spikeinterface.core.job_tools import fix_job_kwargs
from spikeinterface.core import get_channel_distances
from spikeinterface.sortingcomponents.peak_localization import LocalizeCenterOfMass, LocalizeMonopolarTriangulation
from spikeinterface.sortingcomponents.peak_pipeline import (
from spikeinterface.core.node_pipeline import (
run_node_pipeline,
PeakRetriever,
PipelineNode,
Expand Down
9 changes: 7 additions & 2 deletions src/spikeinterface/sortingcomponents/peak_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,16 @@
from spikeinterface.core.recording_tools import get_noise_levels, get_channel_distances

from spikeinterface.core.baserecording import BaseRecording
from spikeinterface.sortingcomponents.peak_pipeline import PeakDetector, WaveformsNode, ExtractSparseWaveforms
from spikeinterface.core.node_pipeline import (
PeakDetector,
WaveformsNode,
ExtractSparseWaveforms,
run_node_pipeline,
base_peak_dtype,
)

from ..core import get_chunk_with_margin

from .peak_pipeline import PeakDetector, run_node_pipeline, base_peak_dtype
from .tools import make_multi_method_doc

try:
Expand Down
3 changes: 2 additions & 1 deletion src/spikeinterface/sortingcomponents/peak_localization.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
import numpy as np
from spikeinterface.core.job_tools import _shared_job_kwargs_doc, split_job_kwargs, fix_job_kwargs

from .peak_pipeline import (

from spikeinterface.core.node_pipeline import (
run_node_pipeline,
find_parent_of_type,
PeakRetriever,
Expand Down
Loading

0 comments on commit 2e549a9

Please sign in to comment.