Merge pull request #11 from SpikeInterface/spikesort-by-group
Add option to spikesort by group
alejoe91 authored Mar 16, 2024
2 parents f5bcba2 + 81b13e8 commit fb029c1
Showing 5 changed files with 74 additions and 23 deletions.
2 changes: 1 addition & 1 deletion src/spikeinterface_pipelines/curation/curation.py
@@ -38,7 +38,7 @@ def curate(
         Curated sorting
     """
     # get quality metrics
-    if not waveform_extractor.is_extension("quality_metrics"):
+    if not waveform_extractor.has_extension("quality_metrics"):
         logger.info(f"[Curation] \tQuality metrics not found in WaveformExtractor.")
         return

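The fix swaps is_extension for has_extension, the method WaveformExtractor exposes for checking saved extensions. A minimal sketch of the corrected guard, assuming a previously saved waveform folder (the path is a placeholder):

import spikeinterface.full as si

# Load a previously saved WaveformExtractor (hypothetical path).
we = si.load_waveforms("path/to/waveforms")

# has_extension() reports whether an extension was computed and saved.
if not we.has_extension("quality_metrics"):
    print("Quality metrics not found; nothing to curate.")
else:
    qm = we.load_extension("quality_metrics").get_data()  # pandas DataFrame
    print(qm.columns.tolist())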
1 change: 1 addition & 0 deletions src/spikeinterface_pipelines/spikesorting/params.py
@@ -62,6 +62,7 @@ class MountainSort5Model(BaseModel):

 class SpikeSortingParams(BaseModel):
     sorter_name: SorterName = Field(default="kilosort2_5", description="Name of the sorter to use.")
+    spikesort_by_group: bool = Field(default=False, description="If True, spike sorting is run for each group separately.")
     sorter_kwargs: Union[Kilosort25Model, Kilosort3Model, IronClustModel, MountainSort5Model] = Field(
         default=Kilosort25Model(), description="Sorter specific kwargs."
     )
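A minimal sketch of enabling the new flag, assuming the models are imported from this package's params module as laid out in src/spikeinterface_pipelines/spikesorting/params.py:

from spikeinterface_pipelines.spikesorting.params import (
    Kilosort25Model,
    SpikeSortingParams,
)

params = SpikeSortingParams(
    sorter_name="kilosort2_5",
    sorter_kwargs=Kilosort25Model(do_correction=False),
    spikesort_by_group=True,  # sort each channel group separately
)
assert params.spikesort_by_group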
47 changes: 34 additions & 13 deletions src/spikeinterface_pipelines/spikesorting/spikesorting.py
@@ -1,6 +1,8 @@
 from __future__ import annotations
 from pathlib import Path
 import shutil
+import numpy as np
+from pathlib import Path
 
 import spikeinterface.full as si
 import spikeinterface.curation as sc

@@ -38,21 +40,34 @@ def spikesort(
     try:
         logger.info(f"[Spikesorting] \tStarting {spikesorting_params.sorter_name} spike sorter")
 
-
-        ## TEST ONLY - REMOVE LATER ##
-        # si.get_default_sorter_params('kilosort2_5')
-        # params_kilosort2_5 = {'do_correction': False}
-        ## --------------------------##
-
-        sorting = si.run_sorter(
-            recording=recording,
-            sorter_name=spikesorting_params.sorter_name,
-            output_folder=str(output_folder),
-            verbose=True,
-            delete_output_folder=True,
-            remove_existing_folder=True,
-            **spikesorting_params.sorter_kwargs.model_dump(),
-            # **params_kilosort2_5
-        )
+        if spikesorting_params.spikesort_by_group and len(np.unique(recording.get_channel_groups())) > 1:
+            logger.info(f"[Spikesorting] \tSorting by channel groups")
+            sorting = si.run_sorter_by_property(
+                recording=recording,
+                sorter_name=spikesorting_params.sorter_name,
+                grouping_property="group",
+                working_folder=str(output_folder),
+                verbose=True,
+                delete_output_folder=True,
+                remove_existing_folder=True,
+                **spikesorting_params.sorter_kwargs.model_dump(),
+            )
+        else:
+            sorting = si.run_sorter(
+                recording=recording,
+                sorter_name=spikesorting_params.sorter_name,
+                output_folder=str(output_folder),
+                verbose=True,
+                delete_output_folder=True,
+                remove_existing_folder=True,
+                **spikesorting_params.sorter_kwargs.model_dump(),
+            )
         logger.info(f"[Spikesorting] \tFound {len(sorting.unit_ids)} raw units")
         # remove spikes beyond num_Samples (if any)
         sorting = sc.remove_excess_spikes(sorting=sorting, recording=recording)
@@ -62,8 +77,14 @@
     except Exception as e:
         # save log to results
         results_folder.mkdir(exist_ok=True, parents=True)
-        if (output_folder).is_dir():
-            shutil.copy(output_folder / "spikeinterface_log.json", results_folder)
+        if not spikesorting_params.spikesort_by_group:
+            if (output_folder).is_dir():
+                shutil.copy(output_folder / "spikeinterface_log.json", results_folder)
+                shutil.rmtree(output_folder)
+        else:
+            for group_folder in output_folder.iterdir():
+                if group_folder.is_dir():
+                    shutil.copy(group_folder / "spikeinterface_log.json", results_folder / group_folder.name)
+            shutil.rmtree(output_folder)
         logger.info(f"Spike sorting error:\n{e}")
         return None
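The new branch keys off the recording's channel groups: when more than one group is present and the flag is set, run_sorter_by_property launches one sorter run per group and aggregates the results (tagging each unit with its group), otherwise the single-recording path is unchanged. A standalone sketch of the same dispatch, using a simulated recording and mountainsort5 (assumed pip-installed) instead of Kilosort; folder paths are placeholders:

import numpy as np
import spikeinterface.full as si

recording, _ = si.generate_ground_truth_recording(durations=[10], num_channels=8, seed=0)
recording.set_channel_groups([0] * 4 + [1] * 4)  # two shanks/groups

if len(np.unique(recording.get_channel_groups())) > 1:
    # one sorter run per channel group; the "group" unit property records the origin
    sorting = si.run_sorter_by_property(
        sorter_name="mountainsort5",
        recording=recording,
        grouping_property="group",
        working_folder="scratch/by_group",
    )
else:
    sorting = si.run_sorter(
        sorter_name="mountainsort5",
        recording=recording,
        output_folder="scratch/single",
    )
print(len(sorting.unit_ids), sorting.get_property("group"))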
17 changes: 10 additions & 7 deletions src/spikeinterface_pipelines/visualization/visualization.py
@@ -74,14 +74,17 @@ def visualize(
         decimation_factor = recording_params["drift"]["decimation_factor"]
         alpha = recording_params["drift"]["alpha"]
 
-        # use spike locations
-        if not waveform_extractor.has_extension("quality_metrics"):
-            logger.info("[Visualization] \tVisualizing drift maps using pre-computed spike locations")
-            peaks = waveform_extractor.sorting.to_spike_vector()
-            peak_locations = waveform_extractor.load_extension("spike_locations").get_data()
-            peak_amps = np.concatenate(waveform_extractor.load_extension("spike_amplitudes").get_data())
+        # check if spike locations are available
+        spike_locations_available = False
+        if waveform_extractor is not None:
+            if waveform_extractor.has_extension("spike_locations"):
+                logger.info("[Visualization] \tVisualizing drift maps using pre-computed spike locations")
+                peaks = waveform_extractor.sorting.to_spike_vector()
+                peak_locations = waveform_extractor.load_extension("spike_locations").get_data()
+                peak_amps = np.concatenate(waveform_extractor.load_extension("spike_amplitudes").get_data())
+                spike_locations_available = True
         # otherwise detect peaks
-        else:
+        if not spike_locations_available:
             from spikeinterface.core.node_pipeline import ExtractDenseWaveforms, run_node_pipeline
             from spikeinterface.sortingcomponents.peak_detection import DetectPeakLocallyExclusive
             from spikeinterface.sortingcomponents.peak_localization import LocalizeCenterOfMass
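The rewrite replaces the old (mis-targeted) quality_metrics check with an explicit flag: pre-computed spike locations are used when present, and peak detection runs only as a fallback. A minimal sketch of the flag pattern as a helper function; the function name is hypothetical and the extractor argument may be None:

import numpy as np

def get_drift_peaks(waveform_extractor):
    """Return (peaks, locations, amplitudes) from saved extensions,
    or None when the caller should fall back to peak detection."""
    if waveform_extractor is not None and waveform_extractor.has_extension("spike_locations"):
        peaks = waveform_extractor.sorting.to_spike_vector()
        locations = waveform_extractor.load_extension("spike_locations").get_data()
        amps = np.concatenate(waveform_extractor.load_extension("spike_amplitudes").get_data())
        return peaks, locations, amps
    return None  # caller detects/localizes peaks from the recording instead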
30 changes: 28 additions & 2 deletions tests/test_pipeline.py
@@ -24,7 +24,7 @@


 def _generate_gt_recording():
-    recording, sorting = si.generate_ground_truth_recording(durations=[30], num_channels=64, seed=0)
+    recording, sorting = si.generate_ground_truth_recording(durations=[15], num_channels=128, seed=0)
     # add inter sample shift (but fake)
     inter_sample_shifts = np.zeros(recording.get_num_channels())
     recording.set_property("inter_sample_shift", inter_sample_shifts)
@@ -69,15 +69,41 @@ def test_spikesorting(tmp_path, generate_recording):
     results_folder = Path(tmp_path) / "results_spikesorting"
     scratch_folder = Path(tmp_path) / "scratch_spikesorting"
 
+    ks25_params = Kilosort25Model(do_correction=False)
+    spikesorting_params = SpikeSortingParams(
+        sorter_name="kilosort2_5",
+        sorter_kwargs=ks25_params,
+    )
+
     sorting = spikesort(
         recording=recording,
-        spikesorting_params=SpikeSortingParams(),
+        spikesorting_params=spikesorting_params,
         results_folder=results_folder,
         scratch_folder=scratch_folder,
     )
 
     assert isinstance(sorting, si.BaseSorting)
 
+    # by group
+    num_channels = recording.get_num_channels()
+    groups = [0] * (num_channels // 2) + [1] * (num_channels // 2)
+    recording.set_channel_groups(groups)
+
+    spikesorting_params = SpikeSortingParams(
+        sorter_name="kilosort2_5",
+        sorter_kwargs=ks25_params,
+        spikesort_by_group=True,
+    )
+    sorting_group = spikesort(
+        recording=recording,
+        spikesorting_params=spikesorting_params,
+        results_folder=results_folder,
+        scratch_folder=scratch_folder,
+    )
+
+    assert isinstance(sorting_group, si.BaseSorting)
+    assert "group" in sorting_group.get_property_keys()
+
 
 def test_postprocessing(tmp_path, generate_recording):
     recording, sorting, _ = generate_recording
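The new assertions check that units coming out of the by-group path carry the group they were sorted in. A toy illustration of that unit property, setting the group values by hand the way run_sorter_by_property does internally (no sorter run needed):

import spikeinterface.full as si

sorting = si.generate_sorting(num_units=4, seed=0)
sorting.set_property("group", [0, 0, 1, 1])  # per-unit group, as set by run_sorter_by_property

assert "group" in sorting.get_property_keys()
print(dict(zip(sorting.unit_ids, sorting.get_property("group"))))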
