Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement Spike retriever #1944

Merged
merged 10 commits into from
Sep 13, 2023
98 changes: 95 additions & 3 deletions src/spikeinterface/core/node_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin, *ar
raise NotImplementedError


# nodes graph must have either a PeakSource (PeakDetector or PeakRetriever or SpikeRetriever)
# nodes graph must have a PeakSource (PeakDetector or PeakRetriever or SpikeRetriever)
# as the first element: they play the same role in the pipeline, providing some peaks (and possibly more)


Expand Down Expand Up @@ -138,7 +138,99 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin):

# this is not implemented yet; this will be done in a separate PR
class SpikeRetriever(PeakSource):
    """
    This class is useful to inject a sorting object in the node pipeline mechanism.
    It allows to compute some post-processing steps with the same machinery used for sorting components.
    This is used by:
      * compute_spike_locations()
      * compute_amplitude_scalings()
      * compute_spike_amplitudes()
      * compute_principal_components()

    Parameters
    ----------
    recording : BaseRecording
        The recording object.
    sorting : BaseSorting
        The sorting object.
    channel_from_template : bool, default: True
        If True, then the channel_index is inferred from the template and `extremum_channel_inds` must be provided.
        If False, the max channel is computed for each spike given a radius around the template max channel.
    extremum_channel_inds : dict of int
        The extremum channel index dict given from template.
    radius_um : float, default: 50
        The radius to find the real max channel.
        Used only when channel_from_template=False.
    peak_sign : str, default: "neg"
        Peak sign to find the max channel.
        Used only when channel_from_template=False.
    """

    def __init__(
        self, recording, sorting, channel_from_template=True, extremum_channel_inds=None, radius_um=50, peak_sign="neg"
    ):
        PipelineNode.__init__(self, recording, return_output=False)

        self.channel_from_template = channel_from_template

        # the extremum channel dict is always needed to build the initial peak vector,
        # even when channel_from_template=False (it seeds the per-spike search)
        assert extremum_channel_inds is not None, "SpikeRetriever needs the dict extremum_channel_inds"

        self.peaks = sorting_to_peaks(sorting, extremum_channel_inds)

        if not channel_from_template:
            channel_distance = get_channel_distances(recording)
            self.neighbours_mask = channel_distance < radius_um
            self.peak_sign = peak_sign

        # precompute segment slices: peaks are sorted by segment_index, so a
        # binary search gives the contiguous range of each segment
        self.segment_slices = []
        for segment_index in range(recording.get_num_segments()):
            i0 = np.searchsorted(self.peaks["segment_index"], segment_index)
            i1 = np.searchsorted(self.peaks["segment_index"], segment_index + 1)
            self.segment_slices.append(slice(i0, i1))

    def get_trace_margin(self):
        # spike times come from the sorting, so no extra margin is required here
        return 0

    def get_dtype(self):
        return base_peak_dtype

    def compute(self, traces, start_frame, end_frame, segment_index, max_margin):
        # get local peaks: restrict to this segment, then to [start_frame, end_frame)
        sl = self.segment_slices[segment_index]
        peaks_in_segment = self.peaks[sl]
        i0 = np.searchsorted(peaks_in_segment["sample_index"], start_frame)
        i1 = np.searchsorted(peaks_in_segment["sample_index"], end_frame)
        local_peaks = peaks_in_segment[i0:i1]

        # make sample index local to traces; copy so self.peaks is never mutated
        local_peaks = local_peaks.copy()
        local_peaks["sample_index"] -= start_frame - max_margin

        if not self.channel_from_template:
            # handle channel spike per spike: re-detect the extremum channel
            # within radius_um of the template extremum channel
            for i, peak in enumerate(local_peaks):
                chans = np.flatnonzero(self.neighbours_mask[peak["channel_index"]])
                sparse_wfs = traces[peak["sample_index"], chans]
                if self.peak_sign == "neg":
                    local_peaks[i]["channel_index"] = chans[np.argmin(sparse_wfs)]
                elif self.peak_sign == "pos":
                    local_peaks[i]["channel_index"] = chans[np.argmax(sparse_wfs)]
                elif self.peak_sign == "both":
                    local_peaks[i]["channel_index"] = chans[np.argmax(np.abs(sparse_wfs))]

        # TODO: "amplitude" is left at 0.0 for now (see sorting_to_peaks)

        return (local_peaks,)


def sorting_to_peaks(sorting, extremum_channel_inds):
    """Build a peak vector (base_peak_dtype) from a sorting, assigning each spike
    the extremum channel of its unit's template. Amplitudes are set to 0.0."""
    spikes = sorting.to_spike_vector()
    # per-unit extremum channel, indexable by unit_index
    unit_extremum_channels = np.array([extremum_channel_inds[unit_id] for unit_id in sorting.unit_ids])
    peaks = np.zeros(spikes.size, dtype=base_peak_dtype)
    peaks["sample_index"] = spikes["sample_index"]
    peaks["segment_index"] = spikes["segment_index"]
    peaks["channel_index"] = unit_extremum_channels[spikes["unit_index"]]
    peaks["amplitude"] = 0.0
    return peaks


class WaveformsNode(PipelineNode):
Expand Down Expand Up @@ -423,7 +515,7 @@ def _compute_peak_pipeline_chunk(segment_index, start_frame, end_frame, worker_c
node_output = node.compute(trace_detection, start_frame, end_frame, segment_index, max_margin)
# set sample index to local
node_output[0]["sample_index"] += extra_margin
elif isinstance(node, PeakRetriever):
elif isinstance(node, PeakSource):
node_output = node.compute(traces_chunk, start_frame, end_frame, segment_index, max_margin)
else:
# TODO later when in master: change the signature of all nodes (or maybe not!)
Expand Down
189 changes: 99 additions & 90 deletions src/spikeinterface/core/tests/test_node_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@
from spikeinterface.core.node_pipeline import (
run_node_pipeline,
PeakRetriever,
SpikeRetriever,
PipelineNode,
ExtractDenseWaveforms,
base_peak_dtype,
sorting_to_peaks,
)


Expand Down Expand Up @@ -78,99 +79,107 @@ def test_run_node_pipeline():
# create peaks from spikes
we = extract_waveforms(recording, sorting, mode="memory", **job_kwargs)
extremum_channel_inds = get_template_extremum_channel(we, peak_sign="neg", outputs="index")
# print(extremum_channel_inds)
ext_channel_inds = np.array([extremum_channel_inds[unit_id] for unit_id in sorting.unit_ids])
# print(ext_channel_inds)
peaks = np.zeros(spikes.size, dtype=base_peak_dtype)
peaks["sample_index"] = spikes["sample_index"]
peaks["channel_index"] = ext_channel_inds[spikes["unit_index"]]
peaks["amplitude"] = 0.0
peaks["segment_index"] = 0

# one step only : squeeze output
peak_retriever = PeakRetriever(recording, peaks)
nodes = [
peak_retriever,
AmplitudeExtractionNode(recording, parents=[peak_retriever], param0=6.6),
]
step_one = run_node_pipeline(recording, nodes, job_kwargs, squeeze_output=True)
assert np.allclose(np.abs(peaks["amplitude"]), step_one["abs_amplitude"])

# 3 nodes two have outputs
ms_before = 0.5
ms_after = 1.0
peaks = sorting_to_peaks(sorting, extremum_channel_inds)

peak_retriever = PeakRetriever(recording, peaks)
dense_waveforms = ExtractDenseWaveforms(
recording, parents=[peak_retriever], ms_before=ms_before, ms_after=ms_after, return_output=False
)
waveform_denoiser = WaveformDenoiser(recording, parents=[peak_retriever, dense_waveforms], return_output=False)
amplitue_extraction = AmplitudeExtractionNode(recording, parents=[peak_retriever], param0=6.6, return_output=True)
waveforms_rms = WaveformsRootMeanSquare(recording, parents=[peak_retriever, dense_waveforms], return_output=True)
denoised_waveforms_rms = WaveformsRootMeanSquare(
recording, parents=[peak_retriever, waveform_denoiser], return_output=True
# channel index is from template
spike_retriever_T = SpikeRetriever(
recording, sorting, channel_from_template=True, extremum_channel_inds=extremum_channel_inds
)

nodes = [
peak_retriever,
dense_waveforms,
waveform_denoiser,
amplitue_extraction,
waveforms_rms,
denoised_waveforms_rms,
]

# gather memory mode
output = run_node_pipeline(recording, nodes, job_kwargs, gather_mode="memory")
amplitudes, waveforms_rms, denoised_waveforms_rms = output
assert np.allclose(np.abs(peaks["amplitude"]), amplitudes["abs_amplitude"])

num_peaks = peaks.shape[0]
num_channels = recording.get_num_channels()
assert waveforms_rms.shape[0] == num_peaks
assert waveforms_rms.shape[1] == num_channels

assert waveforms_rms.shape[0] == num_peaks
assert waveforms_rms.shape[1] == num_channels

# gather npy mode
folder = cache_folder / "pipeline_folder"
if folder.is_dir():
shutil.rmtree(folder)

output = run_node_pipeline(
# channel index is per spike
spike_retriever_S = SpikeRetriever(
recording,
nodes,
job_kwargs,
gather_mode="npy",
folder=folder,
names=["amplitudes", "waveforms_rms", "denoised_waveforms_rms"],
sorting,
channel_from_template=False,
extremum_channel_inds=extremum_channel_inds,
radius_um=50,
peak_sign="neg",
)
amplitudes2, waveforms_rms2, denoised_waveforms_rms2 = output

amplitudes_file = folder / "amplitudes.npy"
assert amplitudes_file.is_file()
amplitudes3 = np.load(amplitudes_file)
assert np.array_equal(amplitudes, amplitudes2)
assert np.array_equal(amplitudes2, amplitudes3)

waveforms_rms_file = folder / "waveforms_rms.npy"
assert waveforms_rms_file.is_file()
waveforms_rms3 = np.load(waveforms_rms_file)
assert np.array_equal(waveforms_rms, waveforms_rms2)
assert np.array_equal(waveforms_rms2, waveforms_rms3)

denoised_waveforms_rms_file = folder / "denoised_waveforms_rms.npy"
assert denoised_waveforms_rms_file.is_file()
denoised_waveforms_rms3 = np.load(denoised_waveforms_rms_file)
assert np.array_equal(denoised_waveforms_rms, denoised_waveforms_rms2)
assert np.array_equal(denoised_waveforms_rms2, denoised_waveforms_rms3)

# Test pickle mechanism
for node in nodes:
import pickle

pickled_node = pickle.dumps(node)
unpickled_node = pickle.loads(pickled_node)

# test with 3 differents first nodes
for loop, peak_source in enumerate((peak_retriever, spike_retriever_T, spike_retriever_S)):
# one step only : squeeze output
nodes = [
peak_source,
AmplitudeExtractionNode(recording, parents=[peak_source], param0=6.6),
]
step_one = run_node_pipeline(recording, nodes, job_kwargs, squeeze_output=True)
assert np.allclose(np.abs(peaks["amplitude"]), step_one["abs_amplitude"])

# 3 nodes two have outputs
ms_before = 0.5
ms_after = 1.0
peak_retriever = PeakRetriever(recording, peaks)
dense_waveforms = ExtractDenseWaveforms(
recording, parents=[peak_source], ms_before=ms_before, ms_after=ms_after, return_output=False
)
waveform_denoiser = WaveformDenoiser(recording, parents=[peak_source, dense_waveforms], return_output=False)
amplitue_extraction = AmplitudeExtractionNode(recording, parents=[peak_source], param0=6.6, return_output=True)
waveforms_rms = WaveformsRootMeanSquare(recording, parents=[peak_source, dense_waveforms], return_output=True)
denoised_waveforms_rms = WaveformsRootMeanSquare(
recording, parents=[peak_source, waveform_denoiser], return_output=True
)

nodes = [
peak_source,
dense_waveforms,
waveform_denoiser,
amplitue_extraction,
waveforms_rms,
denoised_waveforms_rms,
]

# gather memory mode
output = run_node_pipeline(recording, nodes, job_kwargs, gather_mode="memory")
amplitudes, waveforms_rms, denoised_waveforms_rms = output
assert np.allclose(np.abs(peaks["amplitude"]), amplitudes["abs_amplitude"])

num_peaks = peaks.shape[0]
num_channels = recording.get_num_channels()
assert waveforms_rms.shape[0] == num_peaks
assert waveforms_rms.shape[1] == num_channels

assert waveforms_rms.shape[0] == num_peaks
assert waveforms_rms.shape[1] == num_channels

# gather npy mode
folder = cache_folder / f"pipeline_folder_{loop}"
if folder.is_dir():
shutil.rmtree(folder)
output = run_node_pipeline(
recording,
nodes,
job_kwargs,
gather_mode="npy",
folder=folder,
names=["amplitudes", "waveforms_rms", "denoised_waveforms_rms"],
)
amplitudes2, waveforms_rms2, denoised_waveforms_rms2 = output

amplitudes_file = folder / "amplitudes.npy"
assert amplitudes_file.is_file()
amplitudes3 = np.load(amplitudes_file)
assert np.array_equal(amplitudes, amplitudes2)
assert np.array_equal(amplitudes2, amplitudes3)

waveforms_rms_file = folder / "waveforms_rms.npy"
assert waveforms_rms_file.is_file()
waveforms_rms3 = np.load(waveforms_rms_file)
assert np.array_equal(waveforms_rms, waveforms_rms2)
assert np.array_equal(waveforms_rms2, waveforms_rms3)

denoised_waveforms_rms_file = folder / "denoised_waveforms_rms.npy"
assert denoised_waveforms_rms_file.is_file()
denoised_waveforms_rms3 = np.load(denoised_waveforms_rms_file)
assert np.array_equal(denoised_waveforms_rms, denoised_waveforms_rms2)
assert np.array_equal(denoised_waveforms_rms2, denoised_waveforms_rms3)

# Test pickle mechanism
for node in nodes:
import pickle

pickled_node = pickle.dumps(node)
unpickled_node = pickle.loads(pickled_node)


if __name__ == "__main__":
Expand Down