diff --git a/doc/install_sorters.rst b/doc/install_sorters.rst index e805f03eed..51fa94d7e6 100644 --- a/doc/install_sorters.rst +++ b/doc/install_sorters.rst @@ -136,6 +136,18 @@ Kilosort3 * See also for Matlab/CUDA: https://www.mathworks.com/help/parallel-computing/gpu-support-by-release.html +Kilosort4 +^^^^^^^^^ + +* Python, requires CUDA for GPU acceleration (highly recommended) +* Url: https://github.com/MouseLand/Kilosort +* Authors: Marius Pachitariu, Shashwat Sridhar, Carsen Stringer +* Installation:: + + pip install kilosort==4.0 torch + +* For more installation instruction refer to https://github.com/MouseLand/Kilosort + pyKilosort ^^^^^^^^^^ diff --git a/src/spikeinterface/sorters/external/kilosort4.py b/src/spikeinterface/sorters/external/kilosort4.py new file mode 100644 index 0000000000..236db8bb5b --- /dev/null +++ b/src/spikeinterface/sorters/external/kilosort4.py @@ -0,0 +1,299 @@ +from __future__ import annotations + +from pathlib import Path +import os +from typing import Union + +from ..basesorter import BaseSorter +from .kilosortbase import KilosortBase + +PathType = Union[str, Path] + + +class Kilosort4Sorter(BaseSorter): + """Kilosort4 Sorter object.""" + + sorter_name: str = "kilosort4" + requires_locations = True + + _default_params = { + "nblocks": 1, + "Th_universal": 9, + "Th_learned": 8, + "do_CAR": True, + "invert_sign": False, + "nt": 61, + "artifact_threshold": None, + "nskip": 25, + "whitening_range": 32, + "binning_depth": 5, + "sig_interp": 20, + "nt0min": None, + "dmin": None, + "dminx": None, + "min_template_size": 10, + "template_sizes": 5, + "nearest_chans": 10, + "nearest_templates": 100, + "templates_from_data": True, + "n_templates": 6, + "n_pcs": 6, + "Th_single_ch": 6, + "acg_threshold": 0.2, + "ccg_threshold": 0.25, + "cluster_downsampling": 20, + "cluster_pcs": 64, + "duplicate_spike_bins": 15, + "do_correction": True, + "keep_good_only": False, + "save_extra_kwargs": False, + "skip_kilosort_preprocessing": False, + "scaleproc": None, + } + + _params_description = { + "nblocks": "Number of non-overlapping blocks for drift correction (additional nblocks-1 blocks are created in the overlaps). Default value: 1.", + "Th_universal": "Spike detection threshold for universal templates. Th(1) in previous versions of Kilosort. Default value: 9.", + "Th_learned": "Spike detection threshold for learned templates. Th(2) in previous versions of Kilosort. Default value: 8.", + "do_CAR": "Whether to perform common average reference. Default value: True.", + "invert_sign": "Invert the sign of the data. Default value: False.", + "nt": "Number of samples per waveform. Also size of symmetric padding for filtering. Default value: 61.", + "artifact_threshold": "If a batch contains absolute values above this number, it will be zeroed out under the assumption that a recording artifact is present. By default, the threshold is infinite (so that no zeroing occurs). Default value: None.", + "nskip": "Batch stride for computing whitening matrix. Default value: 25.", + "whitening_range": "Number of nearby channels used to estimate the whitening matrix. Default value: 32.", + "binning_depth": "For drift correction, vertical bin size in microns used for 2D histogram. Default value: 5.", + "sig_interp": "For drift correction, sigma for interpolation (spatial standard deviation). Approximate smoothness scale in units of microns. Default value: 20.", + "nt0min": "Sample index for aligning waveforms, so that their minimum or maximum value happens here. Default of 20. Default value: None.", + "dmin": "Vertical spacing of template centers used for spike detection, in microns. Determined automatically by default. Default value: None.", + "dminx": "Horizontal spacing of template centers used for spike detection, in microns. Determined automatically by default. Default value: None.", + "min_template_size": "Standard deviation of the smallest, spatial envelope Gaussian used for universal templates. Default value: 10.", + "template_sizes": "Number of sizes for universal spike templates (multiples of the min_template_size). Default value: 5.", + "nearest_chans": "Number of nearest channels to consider when finding local maxima during spike detection. Default value: 10.", + "nearest_templates": "Number of nearest spike template locations to consider when finding local maxima during spike detection. Default value: 100.", + "templates_from_data": "Indicates whether spike shapes used in universal templates should be estimated from the data or loaded from the predefined templates. Default value: True.", + "n_templates": "Number of single-channel templates to use for the universal templates (only used if templates_from_data is True). Default value: 6.", + "n_pcs": "Number of single-channel PCs to use for extracting spike features (only used if templates_from_data is True). Default value: 6.", + "Th_single_ch": "For single channel threshold crossings to compute universal- templates. In units of whitened data standard deviations. Default value: 6.", + "acg_threshold": 'Fraction of refractory period violations that are allowed in the ACG compared to baseline; used to assign "good" units. Default value: 0.2.', + "ccg_threshold": "Fraction of refractory period violations that are allowed in the CCG compared to baseline; used to perform splits and merges. Default value: 0.25.", + "cluster_downsampling": "Inverse fraction of nodes used as landmarks during clustering (can be 1, but that slows down the optimization). Default value: 20.", + "cluster_pcs": "Maximum number of spatiotemporal PC features used for clustering. Default value: 64.", + "duplicate_spike_bins": "Number of bins for which subsequent spikes from the same cluster are assumed to be artifacts. A value of 0 disables this step. Default value: 15.", + "keep_good_only": "If True only 'good' units are returned", + "do_correction": "If True, drift correction is performed", + "save_extra_kwargs": "If True, additional kwargs are saved to the output", + "skip_kilosort_preprocessing": "Can optionally skip the internal kilosort preprocessing", + "scaleproc": "int16 scaling of whitened data, if None set to 200.", + } + + sorter_description = """Kilosort4 is a Python package for spike sorting on GPUs with template matching. + The software uses new graph-based approaches to clustering that improve performance compared to previous versions. + For detailed comparisons to past versions of Kilosort and to other spike-sorting methods, please see the pre-print + at https://www.biorxiv.org/content/10.1101/2023.01.07.523036v1 + For more information see https://github.com/MouseLand/Kilosort""" + + installation_mesg = """\nTo use Kilosort4 run:\n + >>> pip install kilosort==4.0 + + More information on Kilosort4 at: + https://github.com/MouseLand/Kilosort + """ + + handle_multi_segment = False + + @classmethod + def is_installed(cls): + try: + import kilosort as ks + import torch + + HAVE_KS = True + except ImportError: + HAVE_KS = False + return HAVE_KS + + @classmethod + def get_sorter_version(cls): + import kilosort as ks + + return ks.__version__ + + @classmethod + def _setup_recording(cls, recording, sorter_output_folder, params, verbose): + from probeinterface import write_prb + + pg = recording.get_probegroup() + probe_filename = sorter_output_folder / "probe.prb" + write_prb(probe_filename, pg) + + @classmethod + def _run_from_folder(cls, sorter_output_folder, params, verbose): + from kilosort.run_kilosort import ( + set_files, + initialize_ops, + compute_preprocessing, + compute_drift_correction, + detect_spikes, + cluster_spikes, + save_sorting, + get_run_parameters, + ) + from kilosort.io import load_probe, RecordingExtractorAsArray, BinaryFiltered + from kilosort.parameters import DEFAULT_SETTINGS + + import time + import torch + import numpy as np + + sorter_output_folder = sorter_output_folder.absolute() + + probe_filename = sorter_output_folder / "probe.prb" + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + # load probe + recording = cls.load_recording_from_folder(sorter_output_folder.parent, with_warnings=False) + probe = load_probe(probe_filename) + probe_name = "" + filename = "" + + # this internally concatenates the recording + file_object = RecordingExtractorAsArray(recording) + + do_CAR = params["do_CAR"] + invert_sign = params["invert_sign"] + save_extra_vars = params["save_extra_kwargs"] + progress_bar = None + settings_ks = {k: v for k, v in params.items() if k in DEFAULT_SETTINGS} + settings_ks["n_chan_bin"] = recording.get_num_channels() + settings_ks["fs"] = recording.sampling_frequency + if not do_CAR: + print("Skipping common average reference.") + + tic0 = time.time() + + settings = {**DEFAULT_SETTINGS, **settings_ks} + + if settings["nt0min"] is None: + settings["nt0min"] = int(20 * settings["nt"] / 61) + if settings["artifact_threshold"] is None: + settings["artifact_threshold"] = np.inf + + # NOTE: Also modifies settings in-place + data_dir = "" + results_dir = sorter_output_folder + filename, data_dir, results_dir, probe = set_files(settings, filename, probe, probe_name, data_dir, results_dir) + ops = initialize_ops(settings, probe, recording.get_dtype(), do_CAR, invert_sign, device) + + n_chan_bin, fs, NT, nt, twav_min, chan_map, dtype, do_CAR, invert, _, _, tmin, tmax, artifact = ( + get_run_parameters(ops) + ) + # Set preprocessing and drift correction parameters + if not params["skip_kilosort_preprocessing"]: + ops = compute_preprocessing(ops, device, tic0=tic0, file_object=file_object) + else: + print("Skipping kilosort preprocessing.") + bfile = BinaryFiltered( + ops["filename"], + n_chan_bin, + fs, + NT, + nt, + twav_min, + chan_map, + hp_filter=None, + device=device, + do_CAR=do_CAR, + invert_sign=invert, + dtype=dtype, + tmin=tmin, + tmax=tmax, + artifact_threshold=artifact, + file_object=file_object, + ) + ops["preprocessing"] = dict(hp_filter=None, whiten_mat=None) + ops["Wrot"] = torch.as_tensor(np.eye(recording.get_num_channels())) + ops["Nbatches"] = bfile.n_batches + + np.random.seed(1) + torch.cuda.manual_seed_all(1) + torch.random.manual_seed(1) + # if not params["skip_kilosort_preprocessing"]: + if params["do_correction"]: + # this function applies both preprocessing and drift correction + ops, bfile, st0 = compute_drift_correction( + ops, device, tic0=tic0, progress_bar=progress_bar, file_object=file_object + ) + else: + print("Skipping drift correction.") + hp_filter = ops["preprocessing"]["hp_filter"] + whiten_mat = ops["preprocessing"]["whiten_mat"] + + bfile = BinaryFiltered( + ops["filename"], + n_chan_bin, + fs, + NT, + nt, + twav_min, + chan_map, + hp_filter=hp_filter, + whiten_mat=whiten_mat, + device=device, + do_CAR=do_CAR, + invert_sign=invert, + dtype=dtype, + tmin=tmin, + tmax=tmax, + artifact_threshold=artifact, + file_object=file_object, + ) + + # TODO: don't think we need to do this actually + # Save intermediate `ops` for use by GUI plots + # io.save_ops(ops, results_dir) + + # Sort spikes and save results + st, tF, _, _ = detect_spikes(ops, device, bfile, tic0=tic0, progress_bar=progress_bar) + clu, Wall = cluster_spikes(st, tF, ops, device, bfile, tic0=tic0, progress_bar=progress_bar) + if params["skip_kilosort_preprocessing"]: + ops["preprocessing"] = dict( + hp_filter=torch.as_tensor(np.zeros(1)), whiten_mat=torch.as_tensor(np.eye(recording.get_num_channels())) + ) + ops, similar_templates, is_ref, est_contam_rate = save_sorting( + ops, results_dir, st, clu, tF, Wall, bfile.imin, tic0, save_extra_vars=save_extra_vars + ) + + # # Clean-up temporary files + # if params["delete_recording_dat"] and (recording_file := sorter_output_folder / "recording.dat").exists(): + # recording_file.unlink() + + # all_tmp_files = ("matlab_files", "temp_wh.dat") + + # if isinstance(params["delete_tmp_files"], bool): + # if params["delete_tmp_files"]: + # tmp_files_to_remove = all_tmp_files + # else: + # tmp_files_to_remove = () + # else: + # assert isinstance( + # params["delete_tmp_files"], (tuple, list) + # ), "`delete_tmp_files` must be a `Bool`, `Tuple` or `List`." + + # for name in params["delete_tmp_files"]: + # assert name in all_tmp_files, f"{name} is not a valid option, must be one of: {all_tmp_files}" + + # tmp_files_to_remove = params["delete_tmp_files"] + + # if "temp_wh.dat" in tmp_files_to_remove: + # if (temp_wh_file := sorter_output_folder / "temp_wh.dat").exists(): + # temp_wh_file.unlink() + + # if "matlab_files" in tmp_files_to_remove: + # for ext in ["*.m", "*.mat"]: + # for temp_file in sorter_output_folder.glob(ext): + # temp_file.unlink() + + @classmethod + def _get_result_from_folder(cls, sorter_output_folder): + return KilosortBase._get_result_from_folder(sorter_output_folder) diff --git a/src/spikeinterface/sorters/external/tests/test_kilosort4.py b/src/spikeinterface/sorters/external/tests/test_kilosort4.py new file mode 100644 index 0000000000..87346d1dbb --- /dev/null +++ b/src/spikeinterface/sorters/external/tests/test_kilosort4.py @@ -0,0 +1,138 @@ +import unittest +import pytest +from pathlib import Path + +from spikeinterface import load_extractor +from spikeinterface.extractors import toy_example +from spikeinterface.sorters import Kilosort4Sorter, run_sorter +from spikeinterface.sorters.tests.common_tests import SorterCommonTestSuite + +if hasattr(pytest, "global_test_folder"): + cache_folder = pytest.global_test_folder / "sorters" +else: + cache_folder = Path("cache_folder") / "sorters" + + +# This run several tests +@pytest.mark.skipif(not Kilosort4Sorter.is_installed(), reason="kilosort4 not installed") +class Kilosort4SorterCommonTestSuite(SorterCommonTestSuite, unittest.TestCase): + SorterClass = Kilosort4Sorter + + # 4 channels is to few for KS4 + def setUp(self): + if (cache_folder / "rec").is_dir(): + recording = load_extractor(cache_folder / "rec") + else: + recording, _ = toy_example(num_channels=32, duration=60, seed=0, num_segments=1) + recording = recording.save(folder=cache_folder / "rec", verbose=False, format="binary") + self.recording = recording + print(self.recording) + + def test_with_run_skip_correction(self): + recording = self.recording + + sorter_name = self.SorterClass.sorter_name + + output_folder = cache_folder / sorter_name + + sorter_params = self.SorterClass.default_params() + sorter_params["do_correction"] = False + + sorting = run_sorter( + sorter_name, + recording, + output_folder=output_folder, + remove_existing_folder=True, + delete_output_folder=True, + verbose=False, + raise_error=True, + **sorter_params, + ) + assert sorting.sorting_info is not None + assert "recording" in sorting.sorting_info.keys() + assert "params" in sorting.sorting_info.keys() + assert "log" in sorting.sorting_info.keys() + + del sorting + # test correct deletion of sorter folder, but not run metadata + assert not (output_folder / "sorter_output").is_dir() + assert (output_folder / "spikeinterface_recording.json").is_file() + assert (output_folder / "spikeinterface_params.json").is_file() + assert (output_folder / "spikeinterface_log.json").is_file() + + def test_with_run_skip_preprocessing(self): + from spikeinterface.preprocessing import whiten + + recording = self.recording + + sorter_name = self.SorterClass.sorter_name + + output_folder = cache_folder / sorter_name + + sorter_params = self.SorterClass.default_params() + sorter_params["skip_kilosort_preprocessing"] = True + recording = whiten(recording) + + sorting = run_sorter( + sorter_name, + recording, + output_folder=output_folder, + remove_existing_folder=True, + delete_output_folder=True, + verbose=False, + raise_error=True, + **sorter_params, + ) + assert sorting.sorting_info is not None + assert "recording" in sorting.sorting_info.keys() + assert "params" in sorting.sorting_info.keys() + assert "log" in sorting.sorting_info.keys() + + del sorting + # test correct deletion of sorter folder, but not run metadata + assert not (output_folder / "sorter_output").is_dir() + assert (output_folder / "spikeinterface_recording.json").is_file() + assert (output_folder / "spikeinterface_params.json").is_file() + assert (output_folder / "spikeinterface_log.json").is_file() + + def test_with_run_skip_preprocessing_and_correction(self): + from spikeinterface.preprocessing import whiten + + recording = self.recording + + sorter_name = self.SorterClass.sorter_name + + output_folder = cache_folder / sorter_name + + sorter_params = self.SorterClass.default_params() + sorter_params["skip_kilosort_preprocessing"] = True + sorter_params["do_correction"] = False + recording = whiten(recording) + + sorting = run_sorter( + sorter_name, + recording, + output_folder=output_folder, + remove_existing_folder=True, + delete_output_folder=True, + verbose=False, + raise_error=True, + **sorter_params, + ) + assert sorting.sorting_info is not None + assert "recording" in sorting.sorting_info.keys() + assert "params" in sorting.sorting_info.keys() + assert "log" in sorting.sorting_info.keys() + + del sorting + # test correct deletion of sorter folder, but not run metadata + assert not (output_folder / "sorter_output").is_dir() + assert (output_folder / "spikeinterface_recording.json").is_file() + assert (output_folder / "spikeinterface_params.json").is_file() + assert (output_folder / "spikeinterface_log.json").is_file() + + +if __name__ == "__main__": + test = Kilosort4SorterCommonTestSuite() + test.setUp() + test.test_with_run_skip_preprocessing_and_correction() diff --git a/src/spikeinterface/sorters/external/tests/test_singularity_containers_gpu.py b/src/spikeinterface/sorters/external/tests/test_singularity_containers_gpu.py index b164f16c43..8032826172 100644 --- a/src/spikeinterface/sorters/external/tests/test_singularity_containers_gpu.py +++ b/src/spikeinterface/sorters/external/tests/test_singularity_containers_gpu.py @@ -57,6 +57,12 @@ def test_kilosort3(run_kwargs): print(sorting) +def test_kilosort4(run_kwargs): + clean_singularity_cache() + sorting = ss.run_sorter(sorter_name="kilosort4", output_folder="kilosort4", **run_kwargs) + print(sorting) + + def test_pykilosort(run_kwargs): clean_singularity_cache() sorting = ss.run_sorter(sorter_name="pykilosort", output_folder="pykilosort", **run_kwargs) @@ -72,4 +78,4 @@ def test_yass(run_kwargs): if __name__ == "__main__": kwargs = generate_run_kwargs() - test_pykilosort(kwargs) + test_kilosort4(kwargs) diff --git a/src/spikeinterface/sorters/runsorter.py b/src/spikeinterface/sorters/runsorter.py index d633f24989..66a6138de7 100644 --- a/src/spikeinterface/sorters/runsorter.py +++ b/src/spikeinterface/sorters/runsorter.py @@ -34,6 +34,7 @@ SORTER_DOCKER_MAP = dict( combinato="combinato", herdingspikes="herdingspikes", + kilosort4="kilosort4", klusta="klusta", mountainsort4="mountainsort4", mountainsort5="mountainsort5", diff --git a/src/spikeinterface/sorters/sorterlist.py b/src/spikeinterface/sorters/sorterlist.py index 47557423f6..10d0421f87 100644 --- a/src/spikeinterface/sorters/sorterlist.py +++ b/src/spikeinterface/sorters/sorterlist.py @@ -8,6 +8,7 @@ from .external.kilosort2 import Kilosort2Sorter from .external.kilosort2_5 import Kilosort2_5Sorter from .external.kilosort3 import Kilosort3Sorter +from .external.kilosort4 import Kilosort4Sorter from .external.pykilosort import PyKilosortSorter from .external.klusta import KlustaSorter from .external.mountainsort4 import Mountainsort4Sorter @@ -32,6 +33,7 @@ Kilosort2Sorter, Kilosort2_5Sorter, Kilosort3Sorter, + Kilosort4Sorter, PyKilosortSorter, KlustaSorter, Mountainsort4Sorter,