Commit

Feedback from Alessio
samuelgarcia committed Sep 12, 2023
1 parent aa09cc3 commit 45f2b15
Showing 2 changed files with 14 additions and 16 deletions.
21 changes: 10 additions & 11 deletions src/spikeinterface/core/node_pipeline.py
@@ -139,21 +139,20 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin):
 # this is not implemented yet this will be done in a separate PR
 class SpikeRetriever(PeakSource):
     """
-    This class is usefull to inject a sorting object in the node pipepline mechanisim.
+    This class is useful to inject a sorting object into the node pipeline mechanism.
     It allows computing some post-processing with the same machinery used for sorting components.
-    This is a first step to totaly refactor:
+    This is used by:
       * compute_spike_locations()
       * compute_amplitude_scalings()
       * compute_spike_amplitudes()
       * compute_principal_components()
-    recording:
-    sorting:
-    channel_from_template: bool (default True)
-        If True then the channel_index is infered from template and extremum_channel_inds must be provided.
+    recording : BaseRecording
+        The recording object.
+    sorting : BaseSorting
+        The sorting object.
+    channel_from_template : bool, default: True
+        If True, then the channel_index is inferred from the template and `extremum_channel_inds` must be provided.
         If False, every spike computes its own channel index given a radius around the template max channel.
     extremum_channel_inds : dict of int
         The extremum channel index dict given from the template.
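As a usage illustration (not part of this diff), here is a minimal sketch of how the parameters documented above might fit together when constructing a SpikeRetriever; the import paths and the exact constructor signature are assumptions inferred from the test file further down, not confirmed by this commit.

# Hedged sketch: constructor arguments and import paths are assumptions.
from spikeinterface.core import generate_ground_truth_recording, extract_waveforms, get_template_extremum_channel
from spikeinterface.core.node_pipeline import SpikeRetriever

recording, sorting = generate_ground_truth_recording(num_channels=10, num_units=10, durations=[10.0])

# the extremum channel of each unit template provides the channel_index of that unit's spikes
we = extract_waveforms(recording, sorting, mode="memory")
extremum_channel_inds = get_template_extremum_channel(we, peak_sign="neg", outputs="index")

spike_retriever = SpikeRetriever(
    recording,
    sorting,
    channel_from_template=True,
    extremum_channel_inds=extremum_channel_inds,
)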
@@ -174,7 +173,7 @@ def __init__(

         assert extremum_channel_inds is not None, "SpikeRetriever need the dict extremum_channel_inds"

-        self.peaks = sorting_to_peak(sorting, extremum_channel_inds)
+        self.peaks = sorting_to_peaks(sorting, extremum_channel_inds)

         if not channel_from_template:
             channel_distance = get_channel_distances(recording)
@@ -223,7 +222,7 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin):
         return (local_peaks,)


-def sorting_to_peak(sorting, extremum_channel_inds):
+def sorting_to_peaks(sorting, extremum_channel_inds):
     spikes = sorting.to_spike_vector()
     peaks = np.zeros(spikes.size, dtype=base_peak_dtype)
     peaks["sample_index"] = spikes["sample_index"]
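The sorting_to_peaks hunk above is truncated. Purely as an illustration of the conversion it performs, a self-contained version might look like the sketch below; the remaining peak fields and the base_peak_dtype layout are assumptions, not the actual continuation of the function.

# Illustrative sketch only: the real base_peak_dtype and the rest of the
# function body are not shown in the diff above.
import numpy as np

base_peak_dtype_sketch = [
    ("sample_index", "int64"),
    ("channel_index", "int64"),
    ("amplitude", "float64"),
    ("segment_index", "int64"),
]

def sorting_to_peaks_sketch(sorting, extremum_channel_inds):
    spikes = sorting.to_spike_vector()
    peaks = np.zeros(spikes.size, dtype=base_peak_dtype_sketch)
    peaks["sample_index"] = spikes["sample_index"]
    # map every spike to the extremum channel of its unit template
    extremum_inds = np.array([extremum_channel_inds[unit_id] for unit_id in sorting.unit_ids])
    peaks["channel_index"] = extremum_inds[spikes["unit_index"]]
    peaks["segment_index"] = spikes["segment_index"]
    return peaks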
9 changes: 4 additions & 5 deletions src/spikeinterface/core/tests/test_node_pipeline.py
@@ -15,7 +15,7 @@
     SpikeRetriever,
     PipelineNode,
     ExtractDenseWaveforms,
-    sorting_to_peak,
+    sorting_to_peaks,
 )


@@ -72,15 +72,14 @@ def compute(self, traces, peaks, waveforms):
 def test_run_node_pipeline():
     recording, sorting = generate_ground_truth_recording(num_channels=10, num_units=10, durations=[10.0])

-    # job_kwargs = dict(chunk_duration="0.5s", n_jobs=2, progress_bar=False)
-    job_kwargs = dict(chunk_duration="0.5s", n_jobs=1, progress_bar=False)
+    job_kwargs = dict(chunk_duration="0.5s", n_jobs=2, progress_bar=False)

     spikes = sorting.to_spike_vector()

     # create peaks from spikes
     we = extract_waveforms(recording, sorting, mode="memory", **job_kwargs)
     extremum_channel_inds = get_template_extremum_channel(we, peak_sign="neg", outputs="index")
-    peaks = sorting_to_peak(sorting, extremum_channel_inds)
+    peaks = sorting_to_peaks(sorting, extremum_channel_inds)

     peak_retriever = PeakRetriever(recording, peaks)
     # channel index is from template
@@ -97,7 +96,7 @@ def test_run_node_pipeline():
         peak_sign="neg",
     )

-    # test with 2 diffrents first node
+    # test with 3 different first nodes
     for loop, peak_source in enumerate((peak_retriever, spike_retriever_T, spike_retriever_S)):
         # one step only : squeeze output
         nodes = [
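To show how these pieces are meant to be combined, here is a hedged sketch of running the pipeline with either kind of first node; it assumes a run_node_pipeline(recording, nodes, job_kwargs) entry point as exercised by this test, and it reuses recording, sorting, extremum_channel_inds and spike_retriever from the construction sketch shown earlier.

# Hedged sketch: run_node_pipeline's exact signature is an assumption.
from spikeinterface.core.node_pipeline import PeakRetriever, run_node_pipeline, sorting_to_peaks

job_kwargs = dict(chunk_duration="0.5s", n_jobs=2, progress_bar=False)

peaks = sorting_to_peaks(sorting, extremum_channel_inds)
peak_retriever = PeakRetriever(recording, peaks)

for peak_source in (peak_retriever, spike_retriever):
    # a single first node: the pipeline output is simply the peaks/spikes it retrieves
    output = run_node_pipeline(recording, [peak_source], job_kwargs)

Either first node feeds downstream nodes the same peak structure, which is what lets the post-processing functions listed in the SpikeRetriever docstring share one machinery with the sorting components.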
