Implement Spike retriever #1944

Merged: 10 commits, Sep 13, 2023
99 changes: 96 additions & 3 deletions src/spikeinterface/core/node_pipeline.py
@@ -84,7 +84,7 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin, *ar
raise NotImplementedError


# nodes graph must have either a PeakSource (PeakDetector or PeakRetriever or SpikeRetriever)
# nodes graph must have a PeakSource (PeakDetector or PeakRetriever or SpikeRetriever)
# as the first element; they all play the same role in the pipeline: provide some peaks (and possibly more)


@@ -138,7 +138,100 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin):

# this is not implemented yet; this will be done in a separate PR
class SpikeRetriever(PeakSource):
pass
"""
This class is usefull to inject a sorting object in the node pipepline mechanisim.
samuelgarcia marked this conversation as resolved.
Show resolved Hide resolved
It allows to compute some post processing with the same machinery used for sorting components.
samuelgarcia marked this conversation as resolved.
Show resolved Hide resolved
This is a first step to totaly refactor:
samuelgarcia marked this conversation as resolved.
Show resolved Hide resolved
* compute_spike_locations()
* compute_amplitude_scalings()
* compute_spike_amplitudes()
* compute_principal_components()


recording:

sorting:

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

missing

channel_from_template: bool (default True)
samuelgarcia marked this conversation as resolved.
Show resolved Hide resolved
If True then the channel_index is infered from template and extremum_channel_inds must be provided.
samuelgarcia marked this conversation as resolved.
Show resolved Hide resolved
If False every spikes compute its own channel index given a radius around the template max channel.
samuelgarcia marked this conversation as resolved.
Show resolved Hide resolved
extremum_channel_inds: dict of int
The extremum channel index dict given from template.
radius_um: float (default 50.)
The radius to find the real max channel.
Used only when channel_from_template=False
peak_sign: str (default "neg")
Peak sign to find the max channel.
Used only when channel_from_template=False
"""

def __init__(
self, recording, sorting, channel_from_template=True, extremum_channel_inds=None, radius_um=50, peak_sign="neg"
):
PipelineNode.__init__(self, recording, return_output=False)

self.channel_from_template = channel_from_template

        assert extremum_channel_inds is not None, "SpikeRetriever needs the dict extremum_channel_inds"

self.peaks = sorting_to_peak(sorting, extremum_channel_inds)
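        # convert the sorting to a peak vector with the base_peak_dtype fields
        # (sample_index, channel_index, amplitude, segment_index)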

if not channel_from_template:
channel_distance = get_channel_distances(recording)
self.neighbours_mask = channel_distance < radius_um
self.peak_sign = peak_sign

# precompute segment slice
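        # (each segment's peaks are contiguous in the spike vector, so compute()
        # can select its segment's peaks with a single precomputed slice)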
self.segment_slices = []
for segment_index in range(recording.get_num_segments()):
i0 = np.searchsorted(self.peaks["segment_index"], segment_index)
i1 = np.searchsorted(self.peaks["segment_index"], segment_index + 1)
self.segment_slices.append(slice(i0, i1))

def get_trace_margin(self):
return 0

def get_dtype(self):
return base_peak_dtype

def compute(self, traces, start_frame, end_frame, segment_index, max_margin):
# get local peaks
sl = self.segment_slices[segment_index]
peaks_in_segment = self.peaks[sl]
i0 = np.searchsorted(peaks_in_segment["sample_index"], start_frame)
i1 = np.searchsorted(peaks_in_segment["sample_index"], end_frame)
local_peaks = peaks_in_segment[i0:i1]

# make sample index local to traces
local_peaks = local_peaks.copy()
local_peaks["sample_index"] -= start_frame - max_margin

if not self.channel_from_template:
            # compute the channel index spike by spike
for i, peak in enumerate(local_peaks):
chans = np.flatnonzero(self.neighbours_mask[peak["channel_index"]])
sparse_wfs = traces[peak["sample_index"], chans]
if self.peak_sign == "neg":
local_peaks[i]["channel_index"] = chans[np.argmin(sparse_wfs)]
elif self.peak_sign == "pos":
local_peaks[i]["channel_index"] = chans[np.argmax(sparse_wfs)]
elif self.peak_sign == "both":
local_peaks[i]["channel_index"] = chans[np.argmax(np.abs(sparse_wfs))]

# TODO: "amplitude" ???
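        # note: the "amplitude" field keeps the 0.0 placeholder written by sorting_to_peak();
        # whether and how to fill it with real values is still open, hence the TODO above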

return (local_peaks,)


def sorting_to_peak(sorting, extremum_channel_inds):
spikes = sorting.to_spike_vector()
peaks = np.zeros(spikes.size, dtype=base_peak_dtype)
peaks["sample_index"] = spikes["sample_index"]
extremum_channel_inds_ = np.array([extremum_channel_inds[unit_id] for unit_id in sorting.unit_ids])
peaks["channel_index"] = extremum_channel_inds_[spikes["unit_index"]]
peaks["amplitude"] = 0.0
peaks["segment_index"] = spikes["segment_index"]
return peaks
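
# Minimal usage sketch (illustrative, not part of this diff): `recording`, `sorting`
# and `job_kwargs` are assumed to exist, and `extremum_channel_inds` would typically
# come from get_template_extremum_channel(we, outputs="index"):
#
#     spike_retriever = SpikeRetriever(
#         recording, sorting, channel_from_template=True, extremum_channel_inds=extremum_channel_inds
#     )
#     nodes = [spike_retriever, MyNode(recording, parents=[spike_retriever])]
#     outputs = run_node_pipeline(recording, nodes, job_kwargs)
#
# MyNode stands for any downstream PipelineNode subclass.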


class WaveformsNode(PipelineNode):
@@ -423,7 +516,7 @@ def _compute_peak_pipeline_chunk(segment_index, start_frame, end_frame, worker_c
node_output = node.compute(trace_detection, start_frame, end_frame, segment_index, max_margin)
# set sample index to local
node_output[0]["sample_index"] += extra_margin
elif isinstance(node, PeakRetriever):
elif isinstance(node, PeakSource):
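            # PeakRetriever and SpikeRetriever are both PeakSource subclasses, so a single
            # branch covers them (PeakDetector is handled in the detection branch above)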
node_output = node.compute(traces_chunk, start_frame, end_frame, segment_index, max_margin)
else:
# TODO later when in master: change the signature of all nodes (or maybe not!)
192 changes: 101 additions & 91 deletions src/spikeinterface/core/tests/test_node_pipeline.py
@@ -12,9 +12,10 @@
from spikeinterface.core.node_pipeline import (
run_node_pipeline,
PeakRetriever,
SpikeRetriever,
PipelineNode,
ExtractDenseWaveforms,
base_peak_dtype,
sorting_to_peak,
)


@@ -71,106 +72,115 @@ def compute(self, traces, peaks, waveforms):
def test_run_node_pipeline():
recording, sorting = generate_ground_truth_recording(num_channels=10, num_units=10, durations=[10.0])

job_kwargs = dict(chunk_duration="0.5s", n_jobs=2, progress_bar=False)
# job_kwargs = dict(chunk_duration="0.5s", n_jobs=2, progress_bar=False)
job_kwargs = dict(chunk_duration="0.5s", n_jobs=1, progress_bar=False)

spikes = sorting.to_spike_vector()

# create peaks from spikes
we = extract_waveforms(recording, sorting, mode="memory", **job_kwargs)
extremum_channel_inds = get_template_extremum_channel(we, peak_sign="neg", outputs="index")
# print(extremum_channel_inds)
ext_channel_inds = np.array([extremum_channel_inds[unit_id] for unit_id in sorting.unit_ids])
# print(ext_channel_inds)
peaks = np.zeros(spikes.size, dtype=base_peak_dtype)
peaks["sample_index"] = spikes["sample_index"]
peaks["channel_index"] = ext_channel_inds[spikes["unit_index"]]
peaks["amplitude"] = 0.0
peaks["segment_index"] = 0

    # one step only: squeeze output
peak_retriever = PeakRetriever(recording, peaks)
nodes = [
peak_retriever,
AmplitudeExtractionNode(recording, parents=[peak_retriever], param0=6.6),
]
step_one = run_node_pipeline(recording, nodes, job_kwargs, squeeze_output=True)
assert np.allclose(np.abs(peaks["amplitude"]), step_one["abs_amplitude"])

    # 3 nodes, two have outputs
ms_before = 0.5
ms_after = 1.0
peaks = sorting_to_peak(sorting, extremum_channel_inds)

peak_retriever = PeakRetriever(recording, peaks)
dense_waveforms = ExtractDenseWaveforms(
recording, parents=[peak_retriever], ms_before=ms_before, ms_after=ms_after, return_output=False
)
waveform_denoiser = WaveformDenoiser(recording, parents=[peak_retriever, dense_waveforms], return_output=False)
    amplitude_extraction = AmplitudeExtractionNode(recording, parents=[peak_retriever], param0=6.6, return_output=True)
waveforms_rms = WaveformsRootMeanSquare(recording, parents=[peak_retriever, dense_waveforms], return_output=True)
denoised_waveforms_rms = WaveformsRootMeanSquare(
recording, parents=[peak_retriever, waveform_denoiser], return_output=True
# channel index is from template
spike_retriever_T = SpikeRetriever(
recording, sorting, channel_from_template=True, extremum_channel_inds=extremum_channel_inds
)

nodes = [
peak_retriever,
dense_waveforms,
waveform_denoiser,
        amplitude_extraction,
waveforms_rms,
denoised_waveforms_rms,
]

# gather memory mode
output = run_node_pipeline(recording, nodes, job_kwargs, gather_mode="memory")
amplitudes, waveforms_rms, denoised_waveforms_rms = output
assert np.allclose(np.abs(peaks["amplitude"]), amplitudes["abs_amplitude"])

num_peaks = peaks.shape[0]
num_channels = recording.get_num_channels()
assert waveforms_rms.shape[0] == num_peaks
assert waveforms_rms.shape[1] == num_channels

assert waveforms_rms.shape[0] == num_peaks
assert waveforms_rms.shape[1] == num_channels

# gather npy mode
folder = cache_folder / "pipeline_folder"
if folder.is_dir():
shutil.rmtree(folder)

output = run_node_pipeline(
# channel index is per spike
spike_retriever_S = SpikeRetriever(
recording,
nodes,
job_kwargs,
gather_mode="npy",
folder=folder,
names=["amplitudes", "waveforms_rms", "denoised_waveforms_rms"],
sorting,
channel_from_template=False,
extremum_channel_inds=extremum_channel_inds,
radius_um=50,
peak_sign="neg",
)
amplitudes2, waveforms_rms2, denoised_waveforms_rms2 = output

amplitudes_file = folder / "amplitudes.npy"
assert amplitudes_file.is_file()
amplitudes3 = np.load(amplitudes_file)
assert np.array_equal(amplitudes, amplitudes2)
assert np.array_equal(amplitudes2, amplitudes3)

waveforms_rms_file = folder / "waveforms_rms.npy"
assert waveforms_rms_file.is_file()
waveforms_rms3 = np.load(waveforms_rms_file)
assert np.array_equal(waveforms_rms, waveforms_rms2)
assert np.array_equal(waveforms_rms2, waveforms_rms3)

denoised_waveforms_rms_file = folder / "denoised_waveforms_rms.npy"
assert denoised_waveforms_rms_file.is_file()
denoised_waveforms_rms3 = np.load(denoised_waveforms_rms_file)
assert np.array_equal(denoised_waveforms_rms, denoised_waveforms_rms2)
assert np.array_equal(denoised_waveforms_rms2, denoised_waveforms_rms3)

# Test pickle mechanism
for node in nodes:
import pickle

pickled_node = pickle.dumps(node)
unpickled_node = pickle.loads(pickled_node)

    # test with 3 different first nodes
for loop, peak_source in enumerate((peak_retriever, spike_retriever_T, spike_retriever_S)):
        # one step only: squeeze output
nodes = [
peak_source,
AmplitudeExtractionNode(recording, parents=[peak_source], param0=6.6),
]
step_one = run_node_pipeline(recording, nodes, job_kwargs, squeeze_output=True)
assert np.allclose(np.abs(peaks["amplitude"]), step_one["abs_amplitude"])

        # 3 nodes, two have outputs
ms_before = 0.5
ms_after = 1.0
peak_retriever = PeakRetriever(recording, peaks)
dense_waveforms = ExtractDenseWaveforms(
recording, parents=[peak_source], ms_before=ms_before, ms_after=ms_after, return_output=False
)
waveform_denoiser = WaveformDenoiser(recording, parents=[peak_source, dense_waveforms], return_output=False)
        amplitude_extraction = AmplitudeExtractionNode(recording, parents=[peak_source], param0=6.6, return_output=True)
waveforms_rms = WaveformsRootMeanSquare(recording, parents=[peak_source, dense_waveforms], return_output=True)
denoised_waveforms_rms = WaveformsRootMeanSquare(
recording, parents=[peak_source, waveform_denoiser], return_output=True
)

nodes = [
peak_source,
dense_waveforms,
waveform_denoiser,
            amplitude_extraction,
waveforms_rms,
denoised_waveforms_rms,
]

# gather memory mode
output = run_node_pipeline(recording, nodes, job_kwargs, gather_mode="memory")
amplitudes, waveforms_rms, denoised_waveforms_rms = output
assert np.allclose(np.abs(peaks["amplitude"]), amplitudes["abs_amplitude"])

num_peaks = peaks.shape[0]
num_channels = recording.get_num_channels()
assert waveforms_rms.shape[0] == num_peaks
assert waveforms_rms.shape[1] == num_channels

assert waveforms_rms.shape[0] == num_peaks
assert waveforms_rms.shape[1] == num_channels

# gather npy mode
folder = cache_folder / f"pipeline_folder_{loop}"
if folder.is_dir():
shutil.rmtree(folder)
output = run_node_pipeline(
recording,
nodes,
job_kwargs,
gather_mode="npy",
folder=folder,
names=["amplitudes", "waveforms_rms", "denoised_waveforms_rms"],
)
amplitudes2, waveforms_rms2, denoised_waveforms_rms2 = output

amplitudes_file = folder / "amplitudes.npy"
assert amplitudes_file.is_file()
amplitudes3 = np.load(amplitudes_file)
assert np.array_equal(amplitudes, amplitudes2)
assert np.array_equal(amplitudes2, amplitudes3)

waveforms_rms_file = folder / "waveforms_rms.npy"
assert waveforms_rms_file.is_file()
waveforms_rms3 = np.load(waveforms_rms_file)
assert np.array_equal(waveforms_rms, waveforms_rms2)
assert np.array_equal(waveforms_rms2, waveforms_rms3)

denoised_waveforms_rms_file = folder / "denoised_waveforms_rms.npy"
assert denoised_waveforms_rms_file.is_file()
denoised_waveforms_rms3 = np.load(denoised_waveforms_rms_file)
assert np.array_equal(denoised_waveforms_rms, denoised_waveforms_rms2)
assert np.array_equal(denoised_waveforms_rms2, denoised_waveforms_rms3)

# Test pickle mechanism
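        # (presumably nodes must be picklable so that run_node_pipeline can dispatch
        # them to worker processes when n_jobs > 1)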
for node in nodes:
import pickle

pickled_node = pickle.dumps(node)
unpickled_node = pickle.loads(pickled_node)


if __name__ == "__main__":