diff --git a/.github/scripts/README.MD b/.github/scripts/README.MD new file mode 100644 index 0000000000..1d3a622aae --- /dev/null +++ b/.github/scripts/README.MD @@ -0,0 +1,2 @@ +This folder contains test scripts for running in the CI, that are not run as part of the usual +CI because they are too long / heavy. These are run on cron-jobs once per week. diff --git a/.github/scripts/check_kilosort4_releases.py b/.github/scripts/check_kilosort4_releases.py new file mode 100644 index 0000000000..92e7bf277f --- /dev/null +++ b/.github/scripts/check_kilosort4_releases.py @@ -0,0 +1,35 @@ +import os +import re +from pathlib import Path +import requests +import json +from packaging.version import parse +import spikeinterface + +def get_pypi_versions(package_name): + """ + Make an API call to pypi to retrieve all + available versions of the kilosort package. + """ + url = f"https://pypi.org/pypi/{package_name}/json" + response = requests.get(url) + response.raise_for_status() + data = response.json() + versions = list(sorted(data["releases"].keys())) + + assert parse(spikeinterface.__version__) < parse("0.101.1"), ( + "Kilosort 4.0.5-12 are supported in SpikeInterface < 0.101.1." + "At version 0.101.1, this should be updated to support newer" + "kilosort verrsions." + ) + versions = [ver for ver in versions if parse("4.0.12") >= parse(ver) >= parse("4.0.5")] + return versions + + +if __name__ == "__main__": + # Get all KS4 versions from pipi and write to file. + package_name = "kilosort" + versions = get_pypi_versions(package_name) + with open(Path(os.path.realpath(__file__)).parent / "kilosort4-latest-version.json", "w") as f: + print(versions) + json.dump(versions, f) diff --git a/.github/scripts/kilosort4-latest-version.json b/.github/scripts/kilosort4-latest-version.json new file mode 100644 index 0000000000..03629ff842 --- /dev/null +++ b/.github/scripts/kilosort4-latest-version.json @@ -0,0 +1 @@ +["4.0.10", "4.0.11", "4.0.12", "4.0.5", "4.0.6", "4.0.7", "4.0.8", "4.0.9"] diff --git a/.github/scripts/test_kilosort4_ci.py b/.github/scripts/test_kilosort4_ci.py new file mode 100644 index 0000000000..e0d1f2a504 --- /dev/null +++ b/.github/scripts/test_kilosort4_ci.py @@ -0,0 +1,630 @@ +""" +This file tests the SpikeInterface wrapper of the Kilosort4. The general logic +of the tests are: +- Change every exposed parameter one at a time (PARAMS_TO_TEST). Check that + the result of the SpikeInterface wrapper and Kilosort run natively are + the same. The SpikeInterface wrapper is non-trivial and decomposes the + kilosort pipeline to allow additions such as skipping preprocessing. Therefore, + the idea is that is it safer to rely on the output directly rather than + try monkeypatching. One thing can could be better tested is parameter + changes when skipping KS4 preprocessing is true, because this takes a slightly + different path through the kilosort4.py wrapper logic. + This also checks that changing the parameter changes the test output from default + on our test case (otherwise, the test could not detect a failure). This is possible + for nearly all parameters, see `_check_test_parameters_are_changing_the_output()`. + +- Test that kilosort functions called from `kilosort4.py` wrapper have the expected + input signatures + +- Do some tests to check all KS4 parameters are tested against. +""" +import copy +from typing import Any +import spikeinterface.full as si +import numpy as np +import torch +import kilosort +from kilosort.io import load_probe +import pandas as pd +from spikeinterface.sorters.external.kilosort4 import Kilosort4Sorter +import pytest +from probeinterface.io import write_prb +from kilosort.parameters import DEFAULT_SETTINGS +from packaging.version import parse +from importlib.metadata import version +from inspect import signature +from kilosort.run_kilosort import (set_files, initialize_ops, + compute_preprocessing, + compute_drift_correction, detect_spikes, + cluster_spikes, save_sorting, + get_run_parameters, ) +from kilosort.io import load_probe, RecordingExtractorAsArray, BinaryFiltered +from kilosort.parameters import DEFAULT_SETTINGS +from kilosort import preprocessing as ks_preprocessing + +RUN_KILOSORT_ARGS = ["do_CAR", "invert_sign", "save_preprocessed_copy"] +# "device", "progress_bar", "save_extra_vars" are not tested. "save_extra_vars" could be. + +# Setup Params to test #### +PARAMS_TO_TEST = [ + # Not tested + # ("torch_device", "auto") + + # Stable across KS version 4.0.01 - 4.0.12 + ("change_nothing", None), + ("nblocks", 0), + ("do_CAR", False), + ("batch_size", 42743), + ("Th_universal", 12), + ("Th_learned", 14), + ("invert_sign", True), + ("nt", 93), + ("nskip", 1), + ("whitening_range", 16), + ("sig_interp", 5), + ("nt0min", 25), + ("dmin", 15), + ("dminx", 16), + ("min_template_size", 15), + ("template_sizes", 10), + ("nearest_chans", 8), + ("nearest_templates", 35), + ("max_channel_distance", 5), + ("templates_from_data", False), + ("n_templates", 10), + ("n_pcs", 3), + ("Th_single_ch", 4), + ("x_centers", 5), + ("binning_depth", 1), + # Note: These don't change the results from + # default when applied to the test case. + ("artifact_threshold", 200), + ("ccg_threshold", 1e12), + ("acg_threshold", 1e12), + ("cluster_downsampling", 2), + ("duplicate_spike_bins", 5), +] + +if parse(version("kilosort")) >= parse("4.0.11"): + PARAMS_TO_TEST.extend( + [ + ("shift", 1e9), + ("scale", -1e9), + ] + ) +if parse(version("kilosort")) == parse("4.0.9"): + # bug in 4.0.9 for "nblocks=0" + PARAMS_TO_TEST = [param for param in PARAMS_TO_TEST if param[0] != "nblocks"] + +if parse(version("kilosort")) >= parse("4.0.8"): + PARAMS_TO_TEST.extend( + [ + ("drift_smoothing", [250, 250, 250]), + ] + ) +if parse(version("kilosort")) <= parse("4.0.6"): + # AFAIK this parameter was always unused in KS (that's why it was removed) + PARAMS_TO_TEST.extend( + [ + ("cluster_pcs", 1e9), + ] + ) +if parse(version("kilosort")) <= parse("4.0.3"): + PARAMS_TO_TEST = [param for param in PARAMS_TO_TEST if param[0] not in ["x_centers", "max_channel_distance"]] + + +class TestKilosort4Long: + + # Fixtures ###### + @pytest.fixture(scope="session") + def recording_and_paths(self, tmp_path_factory): + """ + Create a ground-truth recording, and save it to binary + so KS4 can run on it. Fixture is set up once and shared between + all tests. + """ + tmp_path = tmp_path_factory.mktemp("kilosort4_tests") + + recording = self._get_ground_truth_recording() + + paths = self._save_ground_truth_recording(recording, tmp_path) + + return (recording, paths) + + @pytest.fixture(scope="session") + def default_results(self, recording_and_paths): + """ + Because we check each parameter at a time and check the + KS4 and SpikeInterface versions match, if changing the parameter + had no effect as compared to default then the test would not test + anything. Therefore, the default results are run once and stored, + to check changing params indeed changes the results during testing. + This is possibly for nearly all parameters. + """ + recording, paths = recording_and_paths + + settings, _, ks_format_probe = self._get_kilosort_native_settings(recording, paths, "change_nothing", None) + + defaults_ks_output_dir = paths["session_scope_tmp_path"] / "default_ks_output" + + kilosort.run_kilosort( + settings=settings, + probe=ks_format_probe, + data_dtype="float32", + results_dir=defaults_ks_output_dir, + ) + + default_results = self._get_sorting_output(defaults_ks_output_dir) + + return default_results + + def _get_ground_truth_recording(self): + """ + A ground truth recording chosen to be as small as possible (for speed). + But contain enough information so that changing most parameters + changes the results. + """ + num_channels = 32 + recording, _ = si.generate_ground_truth_recording( + durations=[5], + seed=0, + num_channels=num_channels, + num_units=5, + generate_sorting_kwargs=dict(firing_rates=100, refractory_period_ms=4.0), + ) + return recording + + def _save_ground_truth_recording(self, recording, tmp_path): + """ + Save the recording and its probe to file, so it can be + loaded by KS4. + """ + paths = { + "session_scope_tmp_path": tmp_path, + "recording_path": tmp_path / "my_test_recording", + "probe_path": tmp_path / "my_test_probe.prb", + } + + recording.save(folder=paths["recording_path"], overwrite=True) + + probegroup = recording.get_probegroup() + write_prb(paths["probe_path"].as_posix(), probegroup) + + return paths + + # Tests ###### + def test_params_to_test(self): + """ + Test that all values in PARAMS_TO_TEST are + different to the default values used in Kilosort, + otherwise there is no point to the test. + """ + for parameter in PARAMS_TO_TEST: + + param_key, param_value = parameter + + if param_key == "change_nothing": + continue + + if param_key not in RUN_KILOSORT_ARGS: + assert DEFAULT_SETTINGS[param_key] != param_value, f"{param_key} values should be different in test." + + def test_default_settings_all_represented(self): + """ + Test that every entry in DEFAULT_SETTINGS is tested in + PARAMS_TO_TEST, otherwise we are missing settings added + on the KS side. + """ + tested_keys = [entry[0] for entry in PARAMS_TO_TEST] + + for param_key in DEFAULT_SETTINGS: + + if param_key not in ["n_chan_bin", "fs", "tmin", "tmax"]: + if parse(version("kilosort")) == parse("4.0.9") and param_key == "nblocks": + continue + assert param_key in tested_keys, f"param: {param_key} in DEFAULT SETTINGS but not tested." + + def test_spikeinterface_defaults_against_kilsort(self): + """ + Here check that all _ + Don't check that every default in KS is exposed in params, + because they change across versions. Instead, this check + is performed here against PARAMS_TO_TEST. + """ + params = copy.deepcopy(Kilosort4Sorter._default_params) + + for key in params.keys(): + # "artifact threshold" is set to `np.inf` if `None` in + # the body of the `Kilosort4Sorter` class. + if key in DEFAULT_SETTINGS and key not in ["artifact_threshold"]: + assert params[key] == DEFAULT_SETTINGS[key], f"{key} is not the same across versions." + + # Testing Arguments ### + def test_set_files_arguments(self): + self._check_arguments( + set_files, + ["settings", "filename", "probe", "probe_name", "data_dir", "results_dir"] + ) + + def test_initialize_ops_arguments(self): + expected_arguments = ["settings", "probe", "data_dtype", "do_CAR", "invert_sign", "device"] + + if parse(version("kilosort")) >= parse("4.0.12"): + expected_arguments.append("save_preprocesed_copy") + + self._check_arguments( + initialize_ops, + expected_arguments, + ) + + def test_compute_preprocessing_arguments(self): + self._check_arguments( + compute_preprocessing, + ["ops", "device", "tic0", "file_object"] + ) + + def test_compute_drift_location_arguments(self): + self._check_arguments( + compute_drift_correction, + ["ops", "device", "tic0", "progress_bar", "file_object"] + ) + + def test_detect_spikes_arguments(self): + self._check_arguments( + detect_spikes, + ["ops", "device", "bfile", "tic0", "progress_bar"] + ) + + def test_cluster_spikes_arguments(self): + self._check_arguments( + cluster_spikes, + ["st", "tF", "ops", "device", "bfile", "tic0", "progress_bar"] + ) + + def test_save_sorting_arguments(self): + expected_arguments = ["ops", "results_dir", "st", "clu", "tF", "Wall", "imin", "tic0", "save_extra_vars"] + + if parse(version("kilosort")) > parse("4.0.11"): + expected_arguments.append("save_preprocessed_copy") + + self._check_arguments( + save_sorting, + expected_arguments + ) + + def test_get_run_parameters(self): + self._check_arguments( + get_run_parameters, + ["ops"] + ) + + def test_load_probe_parameters(self): + self._check_arguments( + load_probe, + ["probe_path"] + ) + + def test_recording_extractor_as_array_arguments(self): + self._check_arguments( + RecordingExtractorAsArray, + ["recording_extractor"] + ) + + def test_binary_filtered_arguments(self): + expected_arguments = [ + "filename", "n_chan_bin", "fs", "NT", "nt", "nt0min", + "chan_map", "hp_filter", "whiten_mat", "dshift", + "device", "do_CAR", "artifact_threshold", "invert_sign", + "dtype", "tmin", "tmax", "file_object" + ] + + if parse(version("kilosort")) >= parse("4.0.11"): + expected_arguments.pop(-1) + expected_arguments.extend(["shift", "scale", "file_object"]) + + self._check_arguments( + BinaryFiltered, + expected_arguments + ) + + def _check_arguments(self, object_, expected_arguments): + """ + Check that the argument signature of `object_` is as expected + (i..e has not changed across kilosort versions). + """ + sig = signature(object_) + obj_arguments = list(sig.parameters.keys()) + assert expected_arguments == obj_arguments + + # Full Test #### + @pytest.mark.parametrize("parameter", PARAMS_TO_TEST) + def test_kilosort4_main(self, recording_and_paths, default_results, tmp_path, parameter): + """ + Given a recording, paths to raw data, and a parameter to change, + run KS4 natively and within the SpikeInterface wrapper with the + new parameter value (all other values default) and + check the outputs are the same. + """ + recording, paths = recording_and_paths + param_key, param_value = parameter + + # Setup parameters for KS4 and run it natively + kilosort_output_dir = tmp_path / "kilosort_output_dir" + spikeinterface_output_dir = tmp_path / "spikeinterface_output_dir" + + settings, run_kilosort_kwargs, ks_format_probe = self._get_kilosort_native_settings(recording, paths, param_key, param_value) + + kilosort.run_kilosort( + settings=settings, + probe=ks_format_probe, + data_dtype="float32", + results_dir=kilosort_output_dir, + **run_kilosort_kwargs, + ) + + # Setup Parameters for SI and KS4 through SI + spikeinterface_settings = self._get_spikeinterface_settings(param_key, param_value) + + si.run_sorter( + "kilosort4", + recording, + remove_existing_folder=True, + folder=spikeinterface_output_dir, + **spikeinterface_settings, + ) + + # Get the results and check they match + results = self._get_sorting_output(kilosort_output_dir, spikeinterface_output_dir) + + assert np.array_equal(results["ks"]["st"], results["si"]["st"]), f"{param_key} spike times different" + assert np.array_equal(results["ks"]["clus"], results["si"]["clus"]), f"{param_key} cluster assignment different" + + # Check the ops file in KS4 output is as expected. This is saved on the + # SI side so not an extremely robust addition, but it can't hurt. + if param_key != "change_nothing": + ops = np.load(spikeinterface_output_dir / "sorter_output" / "ops.npy", allow_pickle=True) + ops = ops.tolist() # strangely this makes a dict + assert ops[param_key] == param_value + + # Finally, check out test parameters actually change the output of + # KS4, ensuring our tests are actually doing something. This is not + # done prior to 4.0.4 because a number of parameters seem to stop + # having an effect. This is probably due to small changes in their + # behaviour, and the test file chosen here. + if parse(version("kilosort")) > parse("4.0.4"): + self._check_test_parameters_are_changing_the_output(results, default_results, param_key) + + @pytest.mark.skipif(parse(version("kilosort")) == parse("4.0.9"), reason="nblock=0 fails on KS4=4.0.9") + def test_kilosort4_no_correction(self, recording_and_paths, tmp_path): + """ + Test the SpikeInterface wrappers `do_correction` argument. We set + `nblocks=0` for KS4 native, turning off motion correction. Then + we run KS$ through SpikeInterface with `do_correction=False` but + `nblocks=1` (KS4 default) - checking that `do_correction` overrides + this and the result matches KS4 when run without motion correction. + """ + recording, paths = recording_and_paths + + kilosort_output_dir = tmp_path / "kilosort_output_dir" + spikeinterface_output_dir = tmp_path / "spikeinterface_output_dir" + + settings, _, ks_format_probe = self._get_kilosort_native_settings(recording, paths, "nblocks", 0) + + kilosort.run_kilosort( + settings=settings, + probe=ks_format_probe, + data_dtype="float32", + results_dir=kilosort_output_dir, + do_CAR=True, + ) + + spikeinterface_settings = self._get_spikeinterface_settings("nblocks", 1) + si.run_sorter( + "kilosort4", + recording, + remove_existing_folder=True, + folder=spikeinterface_output_dir, + do_correction=False, + **spikeinterface_settings, + ) + + results = self._get_sorting_output(kilosort_output_dir, spikeinterface_output_dir) + + assert np.array_equal(results["ks"]["st"], results["si"]["st"]) + assert np.array_equal(results["ks"]["clus"], results["si"]["clus"]) + + @pytest.mark.skipif(parse(version("kilosort")) == parse("4.0.9"), reason="nblock=0 fails on KS4=4.0.9") + @pytest.mark.parametrize("param_to_test", [ + ("change_nothing", None), + ("do_CAR", False), + ("batch_size", 42743), + ("Th_learned", 14), + ("dmin", 15), + ("max_channel_distance", 5), + ("n_pcs", 3), + ]) + def test_kilosort4_skip_preprocessing_correction(self, tmp_path, monkeypatch, param_to_test): + """ + Test that skipping KS4 preprocessing works as expected. Run + KS4 natively, monkeypatching the relevant preprocessing functions + such that preprocessing is not performed. Then run in SpikeInterface + with `skip_kilosort_preprocessing=True` and check the outputs match. + + Run with a few randomly chosen parameters to check these are propagated + under this condition. + + TODO + ---- + It would be nice to check a few additional parameters here. Screw it! + """ + param_key, param_value = param_to_test + + recording = self._get_ground_truth_recording() + + # We need to filter and whiten the recording here to KS takes forever. + # Do this in a way different to KS. + recording = si.highpass_filter(recording, 300) + recording = si.whiten(recording, mode="local", apply_mean=False) + + paths = self._save_ground_truth_recording(recording, tmp_path) + + kilosort_output_dir = tmp_path / "kilosort_output_dir" + spikeinterface_output_dir = tmp_path / "spikeinterface_output_dir" + + def monkeypatch_filter_function(self, X, ops=None, ibatch=None): + """ + This is a direct copy of the kilosort io.BinaryFiltered.filter + function, with hp_filter and whitening matrix code sections, and + comments removed. This is the easiest way to monkeypatch (tried a few approaches) + """ + if self.chan_map is not None: + X = X[self.chan_map] + + if self.invert_sign: + X = X * -1 + + X = X - X.mean(1).unsqueeze(1) + if self.do_CAR: + X = X - torch.median(X, 0)[0] + + if self.hp_filter is not None: + pass + + if self.artifact_threshold < np.inf: + if torch.any(torch.abs(X) >= self.artifact_threshold): + return torch.zeros_like(X) + + if self.whiten_mat is not None: + pass + return X + + monkeypatch.setattr("kilosort.io.BinaryFiltered.filter", + monkeypatch_filter_function) + + ks_settings, _, ks_format_probe = self._get_kilosort_native_settings(recording, paths, param_key, param_value) + ks_settings["nblocks"] = 0 + + # Be explicit here and don't rely on defaults. + do_CAR = param_value if param_key == "do_CAR" else False + + kilosort.run_kilosort( + settings=ks_settings, + probe=ks_format_probe, + data_dtype="float32", + results_dir=kilosort_output_dir, + do_CAR=do_CAR, + ) + + monkeypatch.undo() + + # Now, run kilosort through spikeinterface with the same options. + spikeinterface_settings = self._get_spikeinterface_settings(param_key, param_value) + spikeinterface_settings["nblocks"] = 0 + + do_CAR = False if param_key != "do_CAR" else spikeinterface_settings.pop("do_CAR") + + si.run_sorter( + "kilosort4", + recording, + remove_existing_folder=True, + folder=spikeinterface_output_dir, + do_CAR=do_CAR, + skip_kilosort_preprocessing=True, + **spikeinterface_settings, + ) + + # There is a very slight difference caused by the batching between load vs. + # memory file. Because in this test recordings are preprocessed, there are + # some filter edge effects that depend on the chunking in `get_traces()`. + # These are all extremely close (usually just 1 spike, 1 idx different). + results = self._get_sorting_output(kilosort_output_dir, spikeinterface_output_dir) + assert np.allclose(results["ks"]["st"], results["si"]["st"], rtol=0, atol=1) + + # Helpers ###### + def _check_test_parameters_are_changing_the_output(self, results, default_results, param_key): + """ + If nothing is changed, default vs. results outputs are identical. + Otherwise, check they are not the same. Can't figure out how to get + the skipped three parameters below to change the results on this + small test file. + """ + if param_key in ["acg_threshold", "ccg_threshold", "artifact_threshold", "cluster_downsampling", "cluster_pcs"]: + return + + if param_key == "change_nothing": + assert all( + default_results["ks"]["st"] == results["ks"]["st"] + ) and all( + default_results["ks"]["clus"] == results["ks"]["clus"] + ), f"{param_key} changed somehow!." + else: + assert not ( + default_results["ks"]["st"].size == results["ks"]["st"].size + ) or not all( + default_results["ks"]["clus"] == results["ks"]["clus"] + ), f"{param_key} results did not change with parameter change." + + def _get_kilosort_native_settings(self, recording, paths, param_key, param_value): + """ + Function to generate the settings and function inputs to run kilosort. + Note when `binning_depth` is used we need to set `nblocks` high to + get the results to change from default. + + Some settings in KS4 are passed by `settings` dict while others + are through the function, these are split here. + """ + settings = { + "data_dir": paths["recording_path"], + "n_chan_bin": recording.get_num_channels(), + "fs": recording.get_sampling_frequency(), + } + + if param_key == "binning_depth": + settings.update({"nblocks": 5}) + + if param_key in RUN_KILOSORT_ARGS: + run_kilosort_kwargs = {param_key: param_value} + else: + if param_key != "change_nothing": + settings.update({param_key: param_value}) + run_kilosort_kwargs = {} + + ks_format_probe = load_probe(paths["probe_path"]) + + return settings, run_kilosort_kwargs, ks_format_probe + + def _get_spikeinterface_settings(self, param_key, param_value): + """ + Generate settings kwargs for running KS4 in SpikeInterface. + See `_get_kilosort_native_settings()` for some details. + """ + settings = {} # copy.deepcopy(DEFAULT_SETTINGS) + + if param_key != "change_nothing": + settings.update({param_key: param_value}) + + if param_key == "binning_depth": + settings.update({"nblocks": 5}) + + # for name in ["n_chan_bin", "fs", "tmin", "tmax"]: + # settings.pop(name) + + return settings + + def _get_sorting_output(self, kilosort_output_dir=None, spikeinterface_output_dir=None) -> dict[str, Any]: + """ + Load the results of sorting into a dict for easy comparison. + """ + results = { + "si": {}, + "ks": {}, + } + if kilosort_output_dir: + results["ks"]["st"] = np.load(kilosort_output_dir / "spike_times.npy") + results["ks"]["clus"] = np.load(kilosort_output_dir / "spike_clusters.npy") + + if spikeinterface_output_dir: + results["si"]["st"] = np.load(spikeinterface_output_dir / "sorter_output" / "spike_times.npy") + results["si"]["clus"] = np.load(spikeinterface_output_dir / "sorter_output" / "spike_clusters.npy") + + return results diff --git a/.github/workflows/test_kilosort4.yml b/.github/workflows/test_kilosort4.yml new file mode 100644 index 0000000000..390bec98be --- /dev/null +++ b/.github/workflows/test_kilosort4.yml @@ -0,0 +1,70 @@ +name: Testing Kilosort4 + +on: + workflow_dispatch: + schedule: + - cron: "0 12 * * 0" # Weekly on Sunday at noon UTC + +jobs: + versions: + # Poll Pypi for all released KS4 versions >4.0.4, save to JSON + # and store them in a matrix for the next job. + runs-on: ubuntu-latest + outputs: + matrix: ${{ steps.set-matrix.outputs.matrix }} + steps: + - name: Checkout repository + uses: actions/checkout@v2 + + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: 3.12 + + - name: Install dependencies + run: | + pip install requests packaging + + - name: Fetch package versions from PyPI + run: | + python .github/scripts/check_kilosort4_releases.py + shell: bash + + - name: Set matrix data + id: set-matrix + run: | + echo "matrix=$(jq -c . < .github/scripts/kilosort4-latest-version.json)" >> $GITHUB_OUTPUT + + test: + needs: versions + name: ${{ matrix.ks_version }} + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + python-version: ["3.12"] + os: [ubuntu-latest] + ks_version: ${{ fromJson(needs.versions.outputs.matrix) }} + steps: + - name: Checkout repository + uses: actions/checkout@v2 + + - name: Setup Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Install SpikeInterface + run: | + pip install -e .[test] + shell: bash + + - name: Install Kilosort + run: | + pip install kilosort==${{ matrix.ks_version }} + shell: bash + + - name: Run new kilosort4 tests + run: | + pytest .github/scripts/test_kilosort4_ci.py + shell: bash diff --git a/conftest.py b/conftest.py index ce5e07b47b..5bf7d74527 100644 --- a/conftest.py +++ b/conftest.py @@ -7,6 +7,7 @@ def create_cache_folder(tmp_path_factory): cache_folder = tmp_path_factory.mktemp("cache_folder") return cache_folder + def pytest_collection_modifyitems(config, items): """ This function marks (in the pytest sense) the tests according to their name and file_path location @@ -16,7 +17,11 @@ def pytest_collection_modifyitems(config, items): rootdir = Path(config.rootdir) modules_location = rootdir / "src" / "spikeinterface" for item in items: - rel_path = Path(item.fspath).relative_to(modules_location) + try: + rel_path = Path(item.fspath).relative_to(modules_location) + except: + continue + module = rel_path.parts[0] if module == "sorters": if "internal" in rel_path.parts: diff --git a/src/spikeinterface/sorters/external/kilosort4.py b/src/spikeinterface/sorters/external/kilosort4.py index 6d83249653..8499cef11f 100644 --- a/src/spikeinterface/sorters/external/kilosort4.py +++ b/src/spikeinterface/sorters/external/kilosort4.py @@ -6,6 +6,7 @@ from ..basesorter import BaseSorter from .kilosortbase import KilosortBase +from importlib.metadata import version as importlib_version PathType = Union[str, Path] @@ -94,7 +95,6 @@ class Kilosort4Sorter(BaseSorter): "cluster_pcs": "Maximum number of spatiotemporal PC features used for clustering. Default value: 64.", "x_centers": "Number of x-positions to use when determining center points for template groupings. If None, this will be determined automatically by finding peaks in channel density. For 2D array type probes, we recommend specifying this so that centers are placed every few hundred microns.", "duplicate_spike_bins": "Number of bins for which subsequent spikes from the same cluster are assumed to be artifacts. A value of 0 disables this step. Default value: 7.", - "keep_good_only": "If True only 'good' units are returned", "do_correction": "If True, drift correction is performed", "save_extra_kwargs": "If True, additional kwargs are saved to the output", "skip_kilosort_preprocessing": "Can optionally skip the internal kilosort preprocessing", @@ -131,9 +131,8 @@ def is_installed(cls): @classmethod def get_sorter_version(cls): - import kilosort as ks - - return ks.__version__ + """kilosort.__version__ <4.0.10 is always '4'""" + return importlib_version("kilosort") @classmethod def _setup_recording(cls, recording, sorter_output_folder, params, verbose): @@ -167,6 +166,13 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): logging.basicConfig(level=logging.INFO) + if version.parse(cls.get_sorter_version()) < version.parse("4.0.5"): + raise RuntimeError( + "Kilosort versions before 4.0.5 are not supported" + "in SpikeInterface. " + "Please upgrade Kilosort version." + ) + sorter_output_folder = sorter_output_folder.absolute() probe_filename = sorter_output_folder / "probe.prb" @@ -178,12 +184,12 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): # load probe recording = cls.load_recording_from_folder(sorter_output_folder.parent, with_warnings=False) - probe = load_probe(probe_filename) + probe = load_probe(probe_path=probe_filename) probe_name = "" filename = "" # this internally concatenates the recording - file_object = RecordingExtractorAsArray(recording) + file_object = RecordingExtractorAsArray(recording_extractor=recording) do_CAR = params["do_CAR"] invert_sign = params["invert_sign"] @@ -208,39 +214,58 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): # NOTE: Also modifies settings in-place data_dir = "" results_dir = sorter_output_folder - filename, data_dir, results_dir, probe = set_files(settings, filename, probe, probe_name, data_dir, results_dir) + + filename, data_dir, results_dir, probe = set_files( + settings=settings, + filename=filename, + probe=probe, + probe_name=probe_name, + data_dir=data_dir, + results_dir=results_dir, + ) + if version.parse(cls.get_sorter_version()) >= version.parse("4.0.12"): ops = initialize_ops( - settings, - probe, - recording.get_dtype(), - do_CAR, - invert_sign, - device, + settings=settings, + probe=probe, + data_dtype=recording.get_dtype(), + do_CAR=do_CAR, + invert_sign=invert_sign, + device=device, save_preprocesed_copy=save_preprocessed_copy, # this kwarg is correct (typo) ) + else: + ops = initialize_ops( + settings=settings, + probe=probe, + data_dtype=recording.get_dtype(), + do_CAR=do_CAR, + invert_sign=invert_sign, + device=device, + ) + + if version.parse(cls.get_sorter_version()) >= version.parse("4.0.11"): n_chan_bin, fs, NT, nt, twav_min, chan_map, dtype, do_CAR, invert, _, _, tmin, tmax, artifact, _, _ = ( get_run_parameters(ops) ) else: - ops = initialize_ops(settings, probe, recording.get_dtype(), do_CAR, invert_sign, device) n_chan_bin, fs, NT, nt, twav_min, chan_map, dtype, do_CAR, invert, _, _, tmin, tmax, artifact = ( get_run_parameters(ops) ) # Set preprocessing and drift correction parameters if not params["skip_kilosort_preprocessing"]: - ops = compute_preprocessing(ops, device, tic0=tic0, file_object=file_object) + ops = compute_preprocessing(ops=ops, device=device, tic0=tic0, file_object=file_object) else: print("Skipping kilosort preprocessing.") bfile = BinaryFiltered( - ops["filename"], - n_chan_bin, - fs, - NT, - nt, - twav_min, - chan_map, + filename=ops["filename"], + n_chan_bin=n_chan_bin, + fs=fs, + NT=NT, + nt=nt, + nt0min=twav_min, + chan_map=chan_map, hp_filter=None, device=device, do_CAR=do_CAR, @@ -254,26 +279,31 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): ops["preprocessing"] = dict(hp_filter=None, whiten_mat=None) ops["Wrot"] = torch.as_tensor(np.eye(recording.get_num_channels())) ops["Nbatches"] = bfile.n_batches + # bfile.close() # TODO: KS do this after preprocessing? np.random.seed(1) torch.cuda.manual_seed_all(1) torch.random.manual_seed(1) - # if not params["skip_kilosort_preprocessing"]: + if not params["do_correction"]: print("Skipping drift correction.") ops["nblocks"] = 0 # this function applies both preprocessing and drift correction ops, bfile, st0 = compute_drift_correction( - ops, device, tic0=tic0, progress_bar=progress_bar, file_object=file_object + ops=ops, device=device, tic0=tic0, progress_bar=progress_bar, file_object=file_object ) if save_preprocessed_copy: save_preprocessing(results_dir / "temp_wh.dat", ops, bfile) # Sort spikes and save results - st, tF, _, _ = detect_spikes(ops, device, bfile, tic0=tic0, progress_bar=progress_bar) - clu, Wall = cluster_spikes(st, tF, ops, device, bfile, tic0=tic0, progress_bar=progress_bar) + st, tF, _, _ = detect_spikes(ops=ops, device=device, bfile=bfile, tic0=tic0, progress_bar=progress_bar) + + clu, Wall = cluster_spikes( + st=st, tF=tF, ops=ops, device=device, bfile=bfile, tic0=tic0, progress_bar=progress_bar + ) + if params["skip_kilosort_preprocessing"]: ops["preprocessing"] = dict( hp_filter=torch.as_tensor(np.zeros(1)), whiten_mat=torch.as_tensor(np.eye(recording.get_num_channels())) @@ -281,14 +311,14 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): if version.parse(cls.get_sorter_version()) >= version.parse("4.0.12"): _ = save_sorting( - ops, - results_dir, - st, - clu, - tF, - Wall, - bfile.imin, - tic0, + ops=ops, + results_dir=results_dir, + st=st, + clu=clu, + tF=tF, + Wall=Wall, + imin=bfile.imin, + tic0=tic0, save_extra_vars=save_extra_vars, save_preprocessed_copy=save_preprocessed_copy, )