diff --git a/src/spikeinterface_pipelines/spikesorting/params.py b/src/spikeinterface_pipelines/spikesorting/params.py
index 627687f..2a4070a 100644
--- a/src/spikeinterface_pipelines/spikesorting/params.py
+++ b/src/spikeinterface_pipelines/spikesorting/params.py
@@ -62,6 +62,7 @@ class MountainSort5Model(BaseModel):
 
 class SpikeSortingParams(BaseModel):
     sorter_name: SorterName = Field(default="kilosort2_5", description="Name of the sorter to use.")
+    spikesort_by_group: bool = Field(default=False, description="If True, spike sorting is run for each group separately.")
     sorter_kwargs: Union[Kilosort25Model, Kilosort3Model, IronClustModel, MountainSort5Model] = Field(
         default=Kilosort25Model(), description="Sorter specific kwargs."
     )
diff --git a/src/spikeinterface_pipelines/spikesorting/spikesorting.py b/src/spikeinterface_pipelines/spikesorting/spikesorting.py
index d9243c1..900748e 100644
--- a/src/spikeinterface_pipelines/spikesorting/spikesorting.py
+++ b/src/spikeinterface_pipelines/spikesorting/spikesorting.py
@@ -1,6 +1,8 @@
 from __future__ import annotations
-from pathlib import Path
 import shutil
+import numpy as np
+from pathlib import Path
+
 
 import spikeinterface.full as si
 import spikeinterface.curation as sc
@@ -38,21 +40,34 @@ def spikesort(
 
     try:
         logger.info(f"[Spikesorting] \tStarting {spikesorting_params.sorter_name} spike sorter")
+
         ## TEST ONLY - REMOVE LATER ##
         # si.get_default_sorter_params('kilosort2_5')
         # params_kilosort2_5 = {'do_correction': False}
         ## --------------------------##
 
-        sorting = si.run_sorter(
-            recording=recording,
-            sorter_name=spikesorting_params.sorter_name,
-            output_folder=str(output_folder),
-            verbose=True,
-            delete_output_folder=True,
-            remove_existing_folder=True,
-            **spikesorting_params.sorter_kwargs.model_dump(),
-            # **params_kilosort2_5
-        )
+        if spikesorting_params.spikesort_by_group and len(np.unique(recording.get_channel_groups())) > 1:
+            logger.info(f"[Spikesorting] \tSorting by channel groups")
+            sorting = si.run_sorter_by_property(
+                recording=recording,
+                sorter_name=spikesorting_params.sorter_name,
+                grouping_property="group",
+                working_folder=str(output_folder),
+                verbose=True,
+                delete_output_folder=True,
+                remove_existing_folder=True,
+                **spikesorting_params.sorter_kwargs.model_dump(),
+            )
+        else:
+            sorting = si.run_sorter(
+                recording=recording,
+                sorter_name=spikesorting_params.sorter_name,
+                output_folder=str(output_folder),
+                verbose=True,
+                delete_output_folder=True,
+                remove_existing_folder=True,
+                **spikesorting_params.sorter_kwargs.model_dump(),
+            )
         logger.info(f"[Spikesorting] \tFound {len(sorting.unit_ids)} raw units")
         # remove spikes beyond num_Samples (if any)
         sorting = sc.remove_excess_spikes(sorting=sorting, recording=recording)
@@ -62,8 +77,14 @@
     except Exception as e:
         # save log to results
         results_folder.mkdir(exist_ok=True, parents=True)
-        if (output_folder).is_dir():
-            shutil.copy(output_folder / "spikeinterface_log.json", results_folder)
+        if not spikesorting_params.spikesort_by_group:
+            if (output_folder).is_dir():
+                shutil.copy(output_folder / "spikeinterface_log.json", results_folder)
+                shutil.rmtree(output_folder)
+        else:
+            for group_folder in output_folder.iterdir():
+                if group_folder.is_dir():
+                    shutil.copy(group_folder / "spikeinterface_log.json", results_folder / group_folder.name)
             shutil.rmtree(output_folder)
         logger.info(f"Spike sorting error:\n{e}")
         return None
diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py
index 5cf5ce1..5edcc54 100644
--- a/tests/test_pipeline.py
+++ b/tests/test_pipeline.py
@@ -22,7 +22,7 @@ def _generate_gt_recording():
 
 
 def _generate_gt_recording():
-    recording, sorting = si.generate_ground_truth_recording(durations=[30], num_channels=64, seed=0)
+    recording, sorting = si.generate_ground_truth_recording(durations=[15], num_channels=128, seed=0)
     # add inter sample shift (but fake)
     inter_sample_shifts = np.zeros(recording.get_num_channels())
     recording.set_property("inter_sample_shift", inter_sample_shifts)
@@ -62,15 +62,41 @@ def test_spikesorting(tmp_path, generate_recording):
     results_folder = Path(tmp_path) / "results_spikesorting"
     scratch_folder = Path(tmp_path) / "scratch_spikesorting"
 
+    ks25_params = Kilosort25Model(do_correction=False)
+    spikesorting_params = SpikeSortingParams(
+        sorter_name="kilosort2_5",
+        sorter_kwargs=ks25_params,
+    )
+
     sorting = spikesort(
         recording=recording,
-        spikesorting_params=SpikeSortingParams(),
+        spikesorting_params=spikesorting_params,
         results_folder=results_folder,
         scratch_folder=scratch_folder,
     )
 
     assert isinstance(sorting, si.BaseSorting)
+    # by group
+    num_channels = recording.get_num_channels()
+    groups = [0] * (num_channels // 2) + [1] * (num_channels // 2)
+    recording.set_channel_groups(groups)
+
+    spikesorting_params = SpikeSortingParams(
+        sorter_name="kilosort2_5",
+        sorter_kwargs=ks25_params,
+        spikesort_by_group=True,
+    )
+    sorting_group = spikesort(
+        recording=recording,
+        spikesorting_params=spikesorting_params,
+        results_folder=results_folder,
+        scratch_folder=scratch_folder,
+    )
+
+    assert isinstance(sorting_group, si.BaseSorting)
+    assert "group" in sorting_group.get_property_keys()
+
 
 
 def test_postprocessing(tmp_path, generate_recording):
     recording, sorting, _ = generate_recording
@@ -160,13 +186,13 @@ def test_pipeline(tmp_path, generate_recording):
 
     recording, sorting, waveform_extractor = _generate_gt_recording()
 
     # print("TEST PREPROCESSING")
-    # test_preprocessing(tmp_folder, (recording, sorting))
-    print("TEST SPIKESORTING")
-    # test_spikesorting(tmp_folder, (recording, sorting))
+    # test_preprocessing(tmp_folder, (recording, sorting, waveform_extractor))
+    print("TEST SPIKESORTING")
+    test_spikesorting(tmp_folder, (recording, sorting, waveform_extractor))
     # print("TEST POSTPROCESSING")
-    # test_postprocessing(tmp_folder, (recording, sorting))
-    print("TEST CURATION")
-    test_curation(tmp_folder, (recording, sorting, waveform_extractor))
+    # test_postprocessing(tmp_folder, (recording, sorting, waveform_extractor))
+    # print("TEST CURATION")
+    # test_curation(tmp_folder, (recording, sorting, waveform_extractor))
     # print("TEST PIPELINE")
-    # test_pipeline(tmp_folder, (recording, sorting))
+    # test_pipeline(tmp_folder, (recording, sorting, waveform_extractor))