Merge pull request #1944 from samuelgarcia/spike_retriever
Implement Spike retriever
alejoe91 authored Sep 13, 2023
2 parents b240298 + ad0f05e commit 904d6ee
Showing 2 changed files with 194 additions and 93 deletions.
98 changes: 95 additions & 3 deletions src/spikeinterface/core/node_pipeline.py
@@ -84,7 +84,7 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin, *ar
        raise NotImplementedError


# nodes graph must have either a PeakSource (PeakDetector or PeakRetriever or SpikeRetriever)
# nodes graph must have a PeakSource (PeakDetector, PeakRetriever, or SpikeRetriever)
# as the first element; they play the same role in the pipeline: provide some peaks (and possibly more)
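A minimal sketch of this ordering (an editor's illustration, not part of the commit; recording, peaks, and job_kwargs are assumed to exist, and the class and function names all come from this file):

peak_source = PeakRetriever(recording, peaks)  # or a SpikeRetriever / a PeakDetector
nodes = [
    peak_source,
    ExtractDenseWaveforms(recording, parents=[peak_source], ms_before=0.5, ms_after=1.0),
]
outputs = run_node_pipeline(recording, nodes, job_kwargs)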


@@ -138,7 +138,99 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin)

# this is not implemented yet; this will be done in a separate PR
class SpikeRetriever(PeakSource):
    pass
    """
    This class is useful to inject a sorting object into the node pipeline mechanism.
    It allows computing some post-processing steps with the same machinery used for sorting components.
    This is used by:
      * compute_spike_locations()
      * compute_amplitude_scalings()
      * compute_spike_amplitudes()
      * compute_principal_components()

    Parameters
    ----------
    recording : BaseRecording
        The recording object.
    sorting : BaseSorting
        The sorting object.
    channel_from_template : bool, default: True
        If True, the channel_index is inferred from the template and `extremum_channel_inds` must be provided.
        If False, the max channel is computed for each spike given a radius around the template max channel.
    extremum_channel_inds : dict of int
        The extremum channel index dict given from the template.
    radius_um : float, default: 50
        The radius used to find the real max channel.
        Used only when channel_from_template=False.
    peak_sign : str, default: "neg"
        Peak sign used to find the max channel.
        Used only when channel_from_template=False.
    """

    def __init__(
        self, recording, sorting, channel_from_template=True, extremum_channel_inds=None, radius_um=50, peak_sign="neg"
    ):
        PipelineNode.__init__(self, recording, return_output=False)

        self.channel_from_template = channel_from_template

        assert extremum_channel_inds is not None, "SpikeRetriever needs the extremum_channel_inds dictionary"

        self.peaks = sorting_to_peaks(sorting, extremum_channel_inds)

        if not channel_from_template:
            channel_distance = get_channel_distances(recording)
            self.neighbours_mask = channel_distance < radius_um
            self.peak_sign = peak_sign

        # precompute segment slices
        self.segment_slices = []
        for segment_index in range(recording.get_num_segments()):
            i0 = np.searchsorted(self.peaks["segment_index"], segment_index)
            i1 = np.searchsorted(self.peaks["segment_index"], segment_index + 1)
            self.segment_slices.append(slice(i0, i1))
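        # For intuition (toy values, an editor's illustration, not original code):
        # with peaks["segment_index"] == [0, 0, 0, 1, 1, 2], np.searchsorted gives
        # i0 == 3 and i1 == 5 for segment_index=1, so segment_slices[1] == slice(3, 5)
        # selects exactly the segment-1 peaks. This relies on the spike vector being
        # ordered by segment_index, which sorting.to_spike_vector() provides.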

    def get_trace_margin(self):
        return 0

    def get_dtype(self):
        return base_peak_dtype

    def compute(self, traces, start_frame, end_frame, segment_index, max_margin):
        # get local peaks
        sl = self.segment_slices[segment_index]
        peaks_in_segment = self.peaks[sl]
        i0 = np.searchsorted(peaks_in_segment["sample_index"], start_frame)
        i1 = np.searchsorted(peaks_in_segment["sample_index"], end_frame)
        local_peaks = peaks_in_segment[i0:i1]

        # make the sample index local to the traces chunk (which starts max_margin samples before start_frame)
        local_peaks = local_peaks.copy()
        local_peaks["sample_index"] -= start_frame - max_margin

        if not self.channel_from_template:
            # recompute the max channel spike by spike within the radius_um neighborhood
            for i, peak in enumerate(local_peaks):
                chans = np.flatnonzero(self.neighbours_mask[peak["channel_index"]])
                sparse_wfs = traces[peak["sample_index"], chans]
                if self.peak_sign == "neg":
                    local_peaks[i]["channel_index"] = chans[np.argmin(sparse_wfs)]
                elif self.peak_sign == "pos":
                    local_peaks[i]["channel_index"] = chans[np.argmax(sparse_wfs)]
                elif self.peak_sign == "both":
                    local_peaks[i]["channel_index"] = chans[np.argmax(np.abs(sparse_wfs))]

        # TODO: "amplitude" ???

        return (local_peaks,)
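For intuition, toy numbers (an editor's illustration, not from the commit) for the sample-index shift in compute() above: a chunk with start_frame=1000 and max_margin=100 carries traces that begin at absolute sample 900, so a spike at absolute sample 1005 maps to

local_sample_index = 1005 - (1000 - 100)  # == 105, i.e. row traces[105]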


def sorting_to_peaks(sorting, extremum_channel_inds):
    spikes = sorting.to_spike_vector()
    peaks = np.zeros(spikes.size, dtype=base_peak_dtype)
    peaks["sample_index"] = spikes["sample_index"]
    extremum_channel_inds_ = np.array([extremum_channel_inds[unit_id] for unit_id in sorting.unit_ids])
    peaks["channel_index"] = extremum_channel_inds_[spikes["unit_index"]]
    peaks["amplitude"] = 0.0
    peaks["segment_index"] = spikes["segment_index"]
    return peaks
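A minimal usage sketch of the new class (an editor's illustration, not part of the commit; assumes recording, sorting, and a job_kwargs dict exist, with imports as used in the test file below):

we = extract_waveforms(recording, sorting, mode="memory")
extremum_channel_inds = get_template_extremum_channel(we, peak_sign="neg", outputs="index")
spike_retriever = SpikeRetriever(
    recording,
    sorting,
    channel_from_template=False,  # recompute the max channel spike by spike within radius_um
    extremum_channel_inds=extremum_channel_inds,
    radius_um=50,
    peak_sign="neg",
)
nodes = [
    spike_retriever,
    ExtractDenseWaveforms(recording, parents=[spike_retriever], ms_before=0.5, ms_after=1.0, return_output=True),
]
waveforms = run_node_pipeline(recording, nodes, job_kwargs, squeeze_output=True)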


class WaveformsNode(PipelineNode):
@@ -423,7 +515,7 @@ def _compute_peak_pipeline_chunk(segment_index, start_frame, end_frame, worker_c
            node_output = node.compute(trace_detection, start_frame, end_frame, segment_index, max_margin)
            # set sample index to local
            node_output[0]["sample_index"] += extra_margin
        elif isinstance(node, PeakRetriever):
        elif isinstance(node, PeakSource):
            node_output = node.compute(traces_chunk, start_frame, end_frame, segment_index, max_margin)
        else:
            # TODO later when in master: change the signature of all nodes (or maybe not!)
189 changes: 99 additions & 90 deletions src/spikeinterface/core/tests/test_node_pipeline.py
@@ -12,9 +12,10 @@
from spikeinterface.core.node_pipeline import (
    run_node_pipeline,
    PeakRetriever,
    SpikeRetriever,
    PipelineNode,
    ExtractDenseWaveforms,
    base_peak_dtype,
    sorting_to_peaks,
)


@@ -78,99 +79,107 @@ def test_run_node_pipeline():
    # create peaks from spikes
    we = extract_waveforms(recording, sorting, mode="memory", **job_kwargs)
    extremum_channel_inds = get_template_extremum_channel(we, peak_sign="neg", outputs="index")

    # removed in this PR: manual peak construction and the un-looped test body
    ext_channel_inds = np.array([extremum_channel_inds[unit_id] for unit_id in sorting.unit_ids])
    peaks = np.zeros(spikes.size, dtype=base_peak_dtype)
    peaks["sample_index"] = spikes["sample_index"]
    peaks["channel_index"] = ext_channel_inds[spikes["unit_index"]]
    peaks["amplitude"] = 0.0
    peaks["segment_index"] = 0

    # one step only: squeeze output
    peak_retriever = PeakRetriever(recording, peaks)
    nodes = [
        peak_retriever,
        AmplitudeExtractionNode(recording, parents=[peak_retriever], param0=6.6),
    ]
    step_one = run_node_pipeline(recording, nodes, job_kwargs, squeeze_output=True)
    assert np.allclose(np.abs(peaks["amplitude"]), step_one["abs_amplitude"])

    # 3 nodes, two have outputs
    ms_before = 0.5
    ms_after = 1.0
    peak_retriever = PeakRetriever(recording, peaks)
    dense_waveforms = ExtractDenseWaveforms(
        recording, parents=[peak_retriever], ms_before=ms_before, ms_after=ms_after, return_output=False
    )
    waveform_denoiser = WaveformDenoiser(recording, parents=[peak_retriever, dense_waveforms], return_output=False)
    amplitude_extraction = AmplitudeExtractionNode(recording, parents=[peak_retriever], param0=6.6, return_output=True)
    waveforms_rms = WaveformsRootMeanSquare(recording, parents=[peak_retriever, dense_waveforms], return_output=True)
    denoised_waveforms_rms = WaveformsRootMeanSquare(
        recording, parents=[peak_retriever, waveform_denoiser], return_output=True
    )

    nodes = [
        peak_retriever,
        dense_waveforms,
        waveform_denoiser,
        amplitude_extraction,
        waveforms_rms,
        denoised_waveforms_rms,
    ]

    # gather memory mode
    output = run_node_pipeline(recording, nodes, job_kwargs, gather_mode="memory")
    amplitudes, waveforms_rms, denoised_waveforms_rms = output
    assert np.allclose(np.abs(peaks["amplitude"]), amplitudes["abs_amplitude"])

    num_peaks = peaks.shape[0]
    num_channels = recording.get_num_channels()
    assert waveforms_rms.shape[0] == num_peaks
    assert waveforms_rms.shape[1] == num_channels

    assert waveforms_rms.shape[0] == num_peaks
    assert waveforms_rms.shape[1] == num_channels

    # gather npy mode
    folder = cache_folder / "pipeline_folder"
    if folder.is_dir():
        shutil.rmtree(folder)
    output = run_node_pipeline(
        recording,
        nodes,
        job_kwargs,
        gather_mode="npy",
        folder=folder,
        names=["amplitudes", "waveforms_rms", "denoised_waveforms_rms"],
    )
    amplitudes2, waveforms_rms2, denoised_waveforms_rms2 = output

    amplitudes_file = folder / "amplitudes.npy"
    assert amplitudes_file.is_file()
    amplitudes3 = np.load(amplitudes_file)
    assert np.array_equal(amplitudes, amplitudes2)
    assert np.array_equal(amplitudes2, amplitudes3)

    waveforms_rms_file = folder / "waveforms_rms.npy"
    assert waveforms_rms_file.is_file()
    waveforms_rms3 = np.load(waveforms_rms_file)
    assert np.array_equal(waveforms_rms, waveforms_rms2)
    assert np.array_equal(waveforms_rms2, waveforms_rms3)

    denoised_waveforms_rms_file = folder / "denoised_waveforms_rms.npy"
    assert denoised_waveforms_rms_file.is_file()
    denoised_waveforms_rms3 = np.load(denoised_waveforms_rms_file)
    assert np.array_equal(denoised_waveforms_rms, denoised_waveforms_rms2)
    assert np.array_equal(denoised_waveforms_rms2, denoised_waveforms_rms3)

    # Test pickle mechanism
    for node in nodes:
        import pickle

        pickled_node = pickle.dumps(node)
        unpickled_node = pickle.loads(pickled_node)

    # added in this PR: build peaks with sorting_to_peaks and set up three peak sources
    peaks = sorting_to_peaks(sorting, extremum_channel_inds)

    peak_retriever = PeakRetriever(recording, peaks)
    # channel index is from template
    spike_retriever_T = SpikeRetriever(
        recording, sorting, channel_from_template=True, extremum_channel_inds=extremum_channel_inds
    )
    # channel index is per spike
    spike_retriever_S = SpikeRetriever(
        recording,
        sorting,
        channel_from_template=False,
        extremum_channel_inds=extremum_channel_inds,
        radius_um=50,
        peak_sign="neg",
    )

    # test with 3 different first nodes
    for loop, peak_source in enumerate((peak_retriever, spike_retriever_T, spike_retriever_S)):
        # one step only: squeeze output
        nodes = [
            peak_source,
            AmplitudeExtractionNode(recording, parents=[peak_source], param0=6.6),
        ]
        step_one = run_node_pipeline(recording, nodes, job_kwargs, squeeze_output=True)
        assert np.allclose(np.abs(peaks["amplitude"]), step_one["abs_amplitude"])

        # 3 nodes, two have outputs
        ms_before = 0.5
        ms_after = 1.0
        peak_retriever = PeakRetriever(recording, peaks)
        dense_waveforms = ExtractDenseWaveforms(
            recording, parents=[peak_source], ms_before=ms_before, ms_after=ms_after, return_output=False
        )
        waveform_denoiser = WaveformDenoiser(recording, parents=[peak_source, dense_waveforms], return_output=False)
        amplitude_extraction = AmplitudeExtractionNode(recording, parents=[peak_source], param0=6.6, return_output=True)
        waveforms_rms = WaveformsRootMeanSquare(recording, parents=[peak_source, dense_waveforms], return_output=True)
        denoised_waveforms_rms = WaveformsRootMeanSquare(
            recording, parents=[peak_source, waveform_denoiser], return_output=True
        )

        nodes = [
            peak_source,
            dense_waveforms,
            waveform_denoiser,
            amplitude_extraction,
            waveforms_rms,
            denoised_waveforms_rms,
        ]

        # gather memory mode
        output = run_node_pipeline(recording, nodes, job_kwargs, gather_mode="memory")
        amplitudes, waveforms_rms, denoised_waveforms_rms = output
        assert np.allclose(np.abs(peaks["amplitude"]), amplitudes["abs_amplitude"])

        num_peaks = peaks.shape[0]
        num_channels = recording.get_num_channels()
        assert waveforms_rms.shape[0] == num_peaks
        assert waveforms_rms.shape[1] == num_channels

        assert waveforms_rms.shape[0] == num_peaks
        assert waveforms_rms.shape[1] == num_channels

        # gather npy mode
        folder = cache_folder / f"pipeline_folder_{loop}"
        if folder.is_dir():
            shutil.rmtree(folder)
        output = run_node_pipeline(
            recording,
            nodes,
            job_kwargs,
            gather_mode="npy",
            folder=folder,
            names=["amplitudes", "waveforms_rms", "denoised_waveforms_rms"],
        )
        amplitudes2, waveforms_rms2, denoised_waveforms_rms2 = output

        amplitudes_file = folder / "amplitudes.npy"
        assert amplitudes_file.is_file()
        amplitudes3 = np.load(amplitudes_file)
        assert np.array_equal(amplitudes, amplitudes2)
        assert np.array_equal(amplitudes2, amplitudes3)

        waveforms_rms_file = folder / "waveforms_rms.npy"
        assert waveforms_rms_file.is_file()
        waveforms_rms3 = np.load(waveforms_rms_file)
        assert np.array_equal(waveforms_rms, waveforms_rms2)
        assert np.array_equal(waveforms_rms2, waveforms_rms3)

        denoised_waveforms_rms_file = folder / "denoised_waveforms_rms.npy"
        assert denoised_waveforms_rms_file.is_file()
        denoised_waveforms_rms3 = np.load(denoised_waveforms_rms_file)
        assert np.array_equal(denoised_waveforms_rms, denoised_waveforms_rms2)
        assert np.array_equal(denoised_waveforms_rms2, denoised_waveforms_rms3)

        # Test pickle mechanism
        for node in nodes:
            import pickle

            pickled_node = pickle.dumps(node)
            unpickled_node = pickle.loads(pickled_node)

if __name__ == "__main__":
