diff --git a/.github/run_tests.sh b/.github/run_tests.sh index 558e0b64d3..02eb6ab8a1 100644 --- a/.github/run_tests.sh +++ b/.github/run_tests.sh @@ -10,5 +10,5 @@ fi pytest -m "$MARKER" -vv -ra --durations=0 --durations-min=0.001 | tee report.txt; test ${PIPESTATUS[0]} -eq 0 || exit 1 echo "# Timing profile of ${MARKER}" >> $GITHUB_STEP_SUMMARY -python $GITHUB_WORKSPACE/.github/build_job_summary.py report.txt >> $GITHUB_STEP_SUMMARY +python $GITHUB_WORKSPACE/.github/scripts/build_job_summary.py report.txt >> $GITHUB_STEP_SUMMARY rm report.txt diff --git a/.github/build_job_summary.py b/.github/scripts/build_job_summary.py similarity index 100% rename from .github/build_job_summary.py rename to .github/scripts/build_job_summary.py diff --git a/.github/scripts/check_kilosort4_releases.py b/.github/scripts/check_kilosort4_releases.py index 92e7bf277f..7a6368f3cf 100644 --- a/.github/scripts/check_kilosort4_releases.py +++ b/.github/scripts/check_kilosort4_releases.py @@ -16,13 +16,8 @@ def get_pypi_versions(package_name): response.raise_for_status() data = response.json() versions = list(sorted(data["releases"].keys())) - - assert parse(spikeinterface.__version__) < parse("0.101.1"), ( - "Kilosort 4.0.5-12 are supported in SpikeInterface < 0.101.1." - "At version 0.101.1, this should be updated to support newer" - "kilosort verrsions." - ) - versions = [ver for ver in versions if parse("4.0.12") >= parse(ver) >= parse("4.0.5")] + # Filter out versions that are less than 4.0.16 + versions = [ver for ver in versions if parse(ver) >= parse("4.0.16")] return versions diff --git a/.github/determine_testing_environment.py b/.github/scripts/determine_testing_environment.py similarity index 100% rename from .github/determine_testing_environment.py rename to .github/scripts/determine_testing_environment.py diff --git a/.github/import_test.py b/.github/scripts/import_test.py similarity index 100% rename from .github/import_test.py rename to .github/scripts/import_test.py diff --git a/.github/scripts/test_kilosort4_ci.py b/.github/scripts/test_kilosort4_ci.py index e0d1f2a504..6eeb71f1dd 100644 --- a/.github/scripts/test_kilosort4_ci.py +++ b/.github/scripts/test_kilosort4_ci.py @@ -10,111 +10,101 @@ changes when skipping KS4 preprocessing is true, because this takes a slightly different path through the kilosort4.py wrapper logic. This also checks that changing the parameter changes the test output from default - on our test case (otherwise, the test could not detect a failure). This is possible - for nearly all parameters, see `_check_test_parameters_are_changing_the_output()`. + on our test case (otherwise, the test could not detect a failure). - Test that kilosort functions called from `kilosort4.py` wrapper have the expected input signatures - Do some tests to check all KS4 parameters are tested against. """ + +import pytest import copy from typing import Any -import spikeinterface.full as si +from inspect import signature + import numpy as np import torch -import kilosort -from kilosort.io import load_probe -import pandas as pd + +import spikeinterface.full as si +from spikeinterface.core.testing import check_sortings_equal from spikeinterface.sorters.external.kilosort4 import Kilosort4Sorter -import pytest from probeinterface.io import write_prb + +import kilosort from kilosort.parameters import DEFAULT_SETTINGS -from packaging.version import parse -from importlib.metadata import version -from inspect import signature -from kilosort.run_kilosort import (set_files, initialize_ops, - compute_preprocessing, - compute_drift_correction, detect_spikes, - cluster_spikes, save_sorting, - get_run_parameters, ) +from kilosort.run_kilosort import ( + set_files, + initialize_ops, + compute_preprocessing, + compute_drift_correction, + detect_spikes, + cluster_spikes, + save_sorting, + get_run_parameters, +) from kilosort.io import load_probe, RecordingExtractorAsArray, BinaryFiltered -from kilosort.parameters import DEFAULT_SETTINGS -from kilosort import preprocessing as ks_preprocessing + RUN_KILOSORT_ARGS = ["do_CAR", "invert_sign", "save_preprocessed_copy"] # "device", "progress_bar", "save_extra_vars" are not tested. "save_extra_vars" could be. # Setup Params to test #### -PARAMS_TO_TEST = [ - # Not tested - # ("torch_device", "auto") - - # Stable across KS version 4.0.01 - 4.0.12 - ("change_nothing", None), - ("nblocks", 0), - ("do_CAR", False), - ("batch_size", 42743), - ("Th_universal", 12), - ("Th_learned", 14), - ("invert_sign", True), - ("nt", 93), - ("nskip", 1), - ("whitening_range", 16), - ("sig_interp", 5), - ("nt0min", 25), - ("dmin", 15), - ("dminx", 16), - ("min_template_size", 15), - ("template_sizes", 10), - ("nearest_chans", 8), - ("nearest_templates", 35), - ("max_channel_distance", 5), - ("templates_from_data", False), - ("n_templates", 10), - ("n_pcs", 3), - ("Th_single_ch", 4), - ("x_centers", 5), - ("binning_depth", 1), - # Note: These don't change the results from - # default when applied to the test case. - ("artifact_threshold", 200), - ("ccg_threshold", 1e12), - ("acg_threshold", 1e12), - ("cluster_downsampling", 2), - ("duplicate_spike_bins", 5), +PARAMS_TO_TEST_DICT = { + "nblocks": 0, + "do_CAR": False, + "batch_size": 42743, + "Th_universal": 12, + "Th_learned": 14, + "invert_sign": True, + "nt": 93, + "nskip": 1, + "whitening_range": 16, + "highpass_cutoff": 200, + "sig_interp": 5, + "nt0min": 25, + "dmin": 15, + "dminx": 16, + "min_template_size": 15, + "template_sizes": 10, + "nearest_chans": 8, + "nearest_templates": 35, + "max_channel_distance": 5, + "templates_from_data": False, + "n_templates": 10, + "n_pcs": 3, + "Th_single_ch": 4, + "x_centers": 5, + "binning_depth": 1, + "drift_smoothing": [250, 250, 250], + "artifact_threshold": 200, + "ccg_threshold": 1e12, + "acg_threshold": 1e12, + "cluster_downsampling": 2, + "duplicate_spike_ms": 0.3, +} + +PARAMS_TO_TEST = list(PARAMS_TO_TEST_DICT.keys()) + +PARAMETERS_NOT_AFFECTING_RESULTS = [ + "artifact_threshold", + "ccg_threshold", + "acg_threshold", + "cluster_downsampling", + "cluster_pcs", + "duplicate_spike_ms", # this is because ground-truth spikes don't have violations ] -if parse(version("kilosort")) >= parse("4.0.11"): - PARAMS_TO_TEST.extend( - [ - ("shift", 1e9), - ("scale", -1e9), - ] - ) -if parse(version("kilosort")) == parse("4.0.9"): - # bug in 4.0.9 for "nblocks=0" - PARAMS_TO_TEST = [param for param in PARAMS_TO_TEST if param[0] != "nblocks"] - -if parse(version("kilosort")) >= parse("4.0.8"): - PARAMS_TO_TEST.extend( - [ - ("drift_smoothing", [250, 250, 250]), - ] - ) -if parse(version("kilosort")) <= parse("4.0.6"): - # AFAIK this parameter was always unused in KS (that's why it was removed) - PARAMS_TO_TEST.extend( - [ - ("cluster_pcs", 1e9), - ] - ) -if parse(version("kilosort")) <= parse("4.0.3"): - PARAMS_TO_TEST = [param for param in PARAMS_TO_TEST if param[0] not in ["x_centers", "max_channel_distance"]] +# THIS IS A PLACEHOLDER FOR FUTURE PARAMS TO TEST +# if parse(version("kilosort")) >= parse("4.0.X"): +# PARAMS_TO_TEST_DICT.update( +# [ +# {"new_param": new_value}, +# ] +# ) class TestKilosort4Long: - # Fixtures ###### @pytest.fixture(scope="session") def recording_and_paths(self, tmp_path_factory): @@ -132,7 +122,7 @@ def recording_and_paths(self, tmp_path_factory): return (recording, paths) @pytest.fixture(scope="session") - def default_results(self, recording_and_paths): + def default_kilosort_sorting(self, recording_and_paths): """ Because we check each parameter at a time and check the KS4 and SpikeInterface versions match, if changing the parameter @@ -143,7 +133,7 @@ def default_results(self, recording_and_paths): """ recording, paths = recording_and_paths - settings, _, ks_format_probe = self._get_kilosort_native_settings(recording, paths, "change_nothing", None) + settings, _, ks_format_probe = self._get_kilosort_native_settings(recording, paths, None, None) defaults_ks_output_dir = paths["session_scope_tmp_path"] / "default_ks_output" @@ -154,9 +144,7 @@ def default_results(self, recording_and_paths): results_dir=defaults_ks_output_dir, ) - default_results = self._get_sorting_output(defaults_ks_output_dir) - - return default_results + return si.read_kilosort(defaults_ks_output_dir) def _get_ground_truth_recording(self): """ @@ -195,19 +183,16 @@ def _save_ground_truth_recording(self, recording, tmp_path): # Tests ###### def test_params_to_test(self): """ - Test that all values in PARAMS_TO_TEST are + Test that all values in PARAMS_TO_TEST_DICT are different to the default values used in Kilosort, otherwise there is no point to the test. """ - for parameter in PARAMS_TO_TEST: - - param_key, param_value = parameter - - if param_key == "change_nothing": - continue - + for param_key, param_value in PARAMS_TO_TEST_DICT.items(): if param_key not in RUN_KILOSORT_ARGS: - assert DEFAULT_SETTINGS[param_key] != param_value, f"{param_key} values should be different in test." + assert DEFAULT_SETTINGS[param_key] != param_value, ( + f"{param_key} values should be different in test: " + f"{param_value} vs. {DEFAULT_SETTINGS[param_key]}" + ) def test_default_settings_all_represented(self): """ @@ -215,13 +200,12 @@ def test_default_settings_all_represented(self): PARAMS_TO_TEST, otherwise we are missing settings added on the KS side. """ - tested_keys = [entry[0] for entry in PARAMS_TO_TEST] + tested_keys = PARAMS_TO_TEST + additional_non_tested_keys = ["shift", "scale", "save_preprocessed_copy"] + tested_keys += additional_non_tested_keys for param_key in DEFAULT_SETTINGS: - if param_key not in ["n_chan_bin", "fs", "tmin", "tmax"]: - if parse(version("kilosort")) == parse("4.0.9") and param_key == "nblocks": - continue assert param_key in tested_keys, f"param: {param_key} in DEFAULT SETTINGS but not tested." def test_spikeinterface_defaults_against_kilsort(self): @@ -242,15 +226,19 @@ def test_spikeinterface_defaults_against_kilsort(self): # Testing Arguments ### def test_set_files_arguments(self): self._check_arguments( - set_files, - ["settings", "filename", "probe", "probe_name", "data_dir", "results_dir"] + set_files, ["settings", "filename", "probe", "probe_name", "data_dir", "results_dir", "bad_channels"] ) def test_initialize_ops_arguments(self): - expected_arguments = ["settings", "probe", "data_dtype", "do_CAR", "invert_sign", "device"] - - if parse(version("kilosort")) >= parse("4.0.12"): - expected_arguments.append("save_preprocesed_copy") + expected_arguments = [ + "settings", + "probe", + "data_dtype", + "do_CAR", + "invert_sign", + "device", + "save_preprocessed_copy", + ] self._check_arguments( initialize_ops, @@ -258,79 +246,67 @@ def test_initialize_ops_arguments(self): ) def test_compute_preprocessing_arguments(self): - self._check_arguments( - compute_preprocessing, - ["ops", "device", "tic0", "file_object"] - ) + self._check_arguments(compute_preprocessing, ["ops", "device", "tic0", "file_object"]) def test_compute_drift_location_arguments(self): self._check_arguments( - compute_drift_correction, - ["ops", "device", "tic0", "progress_bar", "file_object"] + compute_drift_correction, ["ops", "device", "tic0", "progress_bar", "file_object", "clear_cache"] ) def test_detect_spikes_arguments(self): - self._check_arguments( - detect_spikes, - ["ops", "device", "bfile", "tic0", "progress_bar"] - ) + self._check_arguments(detect_spikes, ["ops", "device", "bfile", "tic0", "progress_bar", "clear_cache"]) def test_cluster_spikes_arguments(self): self._check_arguments( - cluster_spikes, - ["st", "tF", "ops", "device", "bfile", "tic0", "progress_bar"] + cluster_spikes, ["st", "tF", "ops", "device", "bfile", "tic0", "progress_bar", "clear_cache"] ) def test_save_sorting_arguments(self): expected_arguments = ["ops", "results_dir", "st", "clu", "tF", "Wall", "imin", "tic0", "save_extra_vars"] - if parse(version("kilosort")) > parse("4.0.11"): - expected_arguments.append("save_preprocessed_copy") + expected_arguments.append("save_preprocessed_copy") - self._check_arguments( - save_sorting, - expected_arguments - ) + self._check_arguments(save_sorting, expected_arguments) def test_get_run_parameters(self): - self._check_arguments( - get_run_parameters, - ["ops"] - ) + self._check_arguments(get_run_parameters, ["ops"]) def test_load_probe_parameters(self): - self._check_arguments( - load_probe, - ["probe_path"] - ) + self._check_arguments(load_probe, ["probe_path"]) def test_recording_extractor_as_array_arguments(self): - self._check_arguments( - RecordingExtractorAsArray, - ["recording_extractor"] - ) + self._check_arguments(RecordingExtractorAsArray, ["recording_extractor"]) def test_binary_filtered_arguments(self): expected_arguments = [ - "filename", "n_chan_bin", "fs", "NT", "nt", "nt0min", - "chan_map", "hp_filter", "whiten_mat", "dshift", - "device", "do_CAR", "artifact_threshold", "invert_sign", - "dtype", "tmin", "tmax", "file_object" + "filename", + "n_chan_bin", + "fs", + "NT", + "nt", + "nt0min", + "chan_map", + "hp_filter", + "whiten_mat", + "dshift", + "device", + "do_CAR", + "artifact_threshold", + "invert_sign", + "dtype", + "tmin", + "tmax", + "shift", + "scale", + "file_object", ] - if parse(version("kilosort")) >= parse("4.0.11"): - expected_arguments.pop(-1) - expected_arguments.extend(["shift", "scale", "file_object"]) - - self._check_arguments( - BinaryFiltered, - expected_arguments - ) + self._check_arguments(BinaryFiltered, expected_arguments) def _check_arguments(self, object_, expected_arguments): """ Check that the argument signature of `object_` is as expected - (i..e has not changed across kilosort versions). + (i.e. has not changed across kilosort versions). """ sig = signature(object_) obj_arguments = list(sig.parameters.keys()) @@ -338,7 +314,7 @@ def _check_arguments(self, object_, expected_arguments): # Full Test #### @pytest.mark.parametrize("parameter", PARAMS_TO_TEST) - def test_kilosort4_main(self, recording_and_paths, default_results, tmp_path, parameter): + def test_kilosort4_main(self, recording_and_paths, default_kilosort_sorting, tmp_path, parameter): """ Given a recording, paths to raw data, and a parameter to change, run KS4 natively and within the SpikeInterface wrapper with the @@ -346,13 +322,16 @@ def test_kilosort4_main(self, recording_and_paths, default_results, tmp_path, pa check the outputs are the same. """ recording, paths = recording_and_paths - param_key, param_value = parameter + param_key = parameter + param_value = PARAMS_TO_TEST_DICT[param_key] # Setup parameters for KS4 and run it natively kilosort_output_dir = tmp_path / "kilosort_output_dir" spikeinterface_output_dir = tmp_path / "spikeinterface_output_dir" - settings, run_kilosort_kwargs, ks_format_probe = self._get_kilosort_native_settings(recording, paths, param_key, param_value) + settings, run_kilosort_kwargs, ks_format_probe = self._get_kilosort_native_settings( + recording, paths, param_key, param_value + ) kilosort.run_kilosort( settings=settings, @@ -361,11 +340,12 @@ def test_kilosort4_main(self, recording_and_paths, default_results, tmp_path, pa results_dir=kilosort_output_dir, **run_kilosort_kwargs, ) + sorting_ks4 = si.read_kilosort(kilosort_output_dir) # Setup Parameters for SI and KS4 through SI spikeinterface_settings = self._get_spikeinterface_settings(param_key, param_value) - si.run_sorter( + sorting_si = si.run_sorter( "kilosort4", recording, remove_existing_folder=True, @@ -374,27 +354,44 @@ def test_kilosort4_main(self, recording_and_paths, default_results, tmp_path, pa ) # Get the results and check they match - results = self._get_sorting_output(kilosort_output_dir, spikeinterface_output_dir) - - assert np.array_equal(results["ks"]["st"], results["si"]["st"]), f"{param_key} spike times different" - assert np.array_equal(results["ks"]["clus"], results["si"]["clus"]), f"{param_key} cluster assignment different" + check_sortings_equal(sorting_ks4, sorting_si) # Check the ops file in KS4 output is as expected. This is saved on the # SI side so not an extremely robust addition, but it can't hurt. - if param_key != "change_nothing": - ops = np.load(spikeinterface_output_dir / "sorter_output" / "ops.npy", allow_pickle=True) - ops = ops.tolist() # strangely this makes a dict - assert ops[param_key] == param_value + ops = np.load(spikeinterface_output_dir / "sorter_output" / "ops.npy", allow_pickle=True) + ops = ops.tolist() # strangely this makes a dict + assert ops[param_key] == param_value # Finally, check out test parameters actually change the output of - # KS4, ensuring our tests are actually doing something. This is not - # done prior to 4.0.4 because a number of parameters seem to stop - # having an effect. This is probably due to small changes in their - # behaviour, and the test file chosen here. - if parse(version("kilosort")) > parse("4.0.4"): - self._check_test_parameters_are_changing_the_output(results, default_results, param_key) - - @pytest.mark.skipif(parse(version("kilosort")) == parse("4.0.9"), reason="nblock=0 fails on KS4=4.0.9") + # KS4, ensuring our tests are actually doing something (exxcept for some params). + if param_key not in PARAMETERS_NOT_AFFECTING_RESULTS: + with pytest.raises(AssertionError): + check_sortings_equal(default_kilosort_sorting, sorting_si) + + def test_clear_cache(self,recording_and_paths, tmp_path): + """ + Test clear_cache parameter in kilosort4.run_kilosort + """ + recording, paths = recording_and_paths + + spikeinterface_output_dir = tmp_path / "spikeinterface_output_clear" + sorting_si_clear = si.run_sorter( + "kilosort4", + recording, + remove_existing_folder=True, + folder=spikeinterface_output_dir, + clear_cache=True + ) + spikeinterface_output_dir = tmp_path / "spikeinterface_output_no_clear" + sorting_si_no_clear = si.run_sorter( + "kilosort4", + recording, + remove_existing_folder=True, + folder=spikeinterface_output_dir, + clear_cache=False + ) + check_sortings_equal(sorting_si_clear, sorting_si_no_clear) + def test_kilosort4_no_correction(self, recording_and_paths, tmp_path): """ Test the SpikeInterface wrappers `do_correction` argument. We set @@ -417,9 +414,10 @@ def test_kilosort4_no_correction(self, recording_and_paths, tmp_path): results_dir=kilosort_output_dir, do_CAR=True, ) + sorting_ks = si.read_kilosort(kilosort_output_dir) spikeinterface_settings = self._get_spikeinterface_settings("nblocks", 1) - si.run_sorter( + sorting_si = si.run_sorter( "kilosort4", recording, remove_existing_folder=True, @@ -427,22 +425,72 @@ def test_kilosort4_no_correction(self, recording_and_paths, tmp_path): do_correction=False, **spikeinterface_settings, ) + check_sortings_equal(sorting_ks, sorting_si) - results = self._get_sorting_output(kilosort_output_dir, spikeinterface_output_dir) + def test_use_binary_file(self, tmp_path): + """ + Test that the SpikeInterface wrapper can run KS4 using a binary file as input or directly + from the recording. + """ + recording = self._get_ground_truth_recording() + recording_bin = recording.save() - assert np.array_equal(results["ks"]["st"], results["si"]["st"]) - assert np.array_equal(results["ks"]["clus"], results["si"]["clus"]) + # run with SI wrapper + sorting_ks4 = si.run_sorter( + "kilosort4", + recording, + folder=tmp_path / "ks4_output_si_wrapper_default", + use_binary_file=None, + remove_existing_folder=True, + ) + sorting_ks4_bin = si.run_sorter( + "kilosort4", + recording_bin, + folder=tmp_path / "ks4_output_bin_default", + use_binary_file=None, + remove_existing_folder=True, + ) + sorting_ks4_force_binary = si.run_sorter( + "kilosort4", + recording, + folder=tmp_path / "ks4_output_force_bin", + use_binary_file=True, + remove_existing_folder=True, + ) + assert not (tmp_path / "ks4_output_force_bin" / "sorter_output" / "recording.dat").exists() + sorting_ks4_force_non_binary = si.run_sorter( + "kilosort4", + recording_bin, + folder=tmp_path / "ks4_output_force_wrapper", + use_binary_file=False, + remove_existing_folder=True, + ) + # test deleting recording.dat + sorting_ks4_force_binary_keep = si.run_sorter( + "kilosort4", + recording, + folder=tmp_path / "ks4_output_force_bin_keep", + use_binary_file=True, + delete_recording_dat=False, + remove_existing_folder=True, + ) + assert (tmp_path / "ks4_output_force_bin_keep" / "sorter_output" / "recording.dat").exists() + + check_sortings_equal(sorting_ks4, sorting_ks4_bin) + check_sortings_equal(sorting_ks4, sorting_ks4_force_binary) + check_sortings_equal(sorting_ks4, sorting_ks4_force_non_binary) - @pytest.mark.skipif(parse(version("kilosort")) == parse("4.0.9"), reason="nblock=0 fails on KS4=4.0.9") - @pytest.mark.parametrize("param_to_test", [ - ("change_nothing", None), - ("do_CAR", False), - ("batch_size", 42743), - ("Th_learned", 14), - ("dmin", 15), - ("max_channel_distance", 5), - ("n_pcs", 3), - ]) + @pytest.mark.parametrize( + "param_to_test", + [ + ("do_CAR", False), + ("batch_size", 42743), + ("Th_learned", 14), + ("dmin", 15), + ("max_channel_distance", 5), + ("n_pcs", 3), + ], + ) def test_kilosort4_skip_preprocessing_correction(self, tmp_path, monkeypatch, param_to_test): """ Test that skipping KS4 preprocessing works as expected. Run @@ -498,8 +546,7 @@ def monkeypatch_filter_function(self, X, ops=None, ibatch=None): pass return X - monkeypatch.setattr("kilosort.io.BinaryFiltered.filter", - monkeypatch_filter_function) + monkeypatch.setattr("kilosort.io.BinaryFiltered.filter", monkeypatch_filter_function) ks_settings, _, ks_format_probe = self._get_kilosort_native_settings(recording, paths, param_key, param_value) ks_settings["nblocks"] = 0 @@ -516,6 +563,7 @@ def monkeypatch_filter_function(self, X, ops=None, ibatch=None): ) monkeypatch.undo() + si.read_kilosort(kilosort_output_dir) # Now, run kilosort through spikeinterface with the same options. spikeinterface_settings = self._get_spikeinterface_settings(param_key, param_value) @@ -537,33 +585,17 @@ def monkeypatch_filter_function(self, X, ops=None, ibatch=None): # memory file. Because in this test recordings are preprocessed, there are # some filter edge effects that depend on the chunking in `get_traces()`. # These are all extremely close (usually just 1 spike, 1 idx different). - results = self._get_sorting_output(kilosort_output_dir, spikeinterface_output_dir) + results = {} + results["ks"] = {} + results["ks"]["st"] = np.load(kilosort_output_dir / "spike_times.npy") + results["ks"]["clus"] = np.load(kilosort_output_dir / "spike_clusters.npy") + results["si"] = {} + results["si"]["st"] = np.load(spikeinterface_output_dir / "sorter_output" / "spike_times.npy") + results["si"]["clus"] = np.load(spikeinterface_output_dir / "sorter_output" / "spike_clusters.npy") assert np.allclose(results["ks"]["st"], results["si"]["st"], rtol=0, atol=1) + assert np.array_equal(results["ks"]["clus"], results["si"]["clus"]) - # Helpers ###### - def _check_test_parameters_are_changing_the_output(self, results, default_results, param_key): - """ - If nothing is changed, default vs. results outputs are identical. - Otherwise, check they are not the same. Can't figure out how to get - the skipped three parameters below to change the results on this - small test file. - """ - if param_key in ["acg_threshold", "ccg_threshold", "artifact_threshold", "cluster_downsampling", "cluster_pcs"]: - return - - if param_key == "change_nothing": - assert all( - default_results["ks"]["st"] == results["ks"]["st"] - ) and all( - default_results["ks"]["clus"] == results["ks"]["clus"] - ), f"{param_key} changed somehow!." - else: - assert not ( - default_results["ks"]["st"].size == results["ks"]["st"].size - ) or not all( - default_results["ks"]["clus"] == results["ks"]["clus"] - ), f"{param_key} results did not change with parameter change." - + ##### Helpers ###### def _get_kilosort_native_settings(self, recording, paths, param_key, param_value): """ Function to generate the settings and function inputs to run kilosort. @@ -578,16 +610,17 @@ def _get_kilosort_native_settings(self, recording, paths, param_key, param_value "n_chan_bin": recording.get_num_channels(), "fs": recording.get_sampling_frequency(), } + run_kilosort_kwargs = {} - if param_key == "binning_depth": - settings.update({"nblocks": 5}) + if param_key is not None: + if param_key == "binning_depth": + settings.update({"nblocks": 5}) - if param_key in RUN_KILOSORT_ARGS: - run_kilosort_kwargs = {param_key: param_value} - else: - if param_key != "change_nothing": + if param_key in RUN_KILOSORT_ARGS: + run_kilosort_kwargs = {param_key: param_value} + else: settings.update({param_key: param_value}) - run_kilosort_kwargs = {} + run_kilosort_kwargs = {} ks_format_probe = load_probe(paths["probe_path"]) @@ -598,33 +631,14 @@ def _get_spikeinterface_settings(self, param_key, param_value): Generate settings kwargs for running KS4 in SpikeInterface. See `_get_kilosort_native_settings()` for some details. """ - settings = {} # copy.deepcopy(DEFAULT_SETTINGS) - - if param_key != "change_nothing": - settings.update({param_key: param_value}) + settings = {} # copy.deepcopy(DEFAULT_SETTINGS) if param_key == "binning_depth": settings.update({"nblocks": 5}) - # for name in ["n_chan_bin", "fs", "tmin", "tmax"]: + settings.update({param_key: param_value}) + + # for name in ["n_chan_bin", "fs", "tmin", "tmax"]: # settings.pop(name) return settings - - def _get_sorting_output(self, kilosort_output_dir=None, spikeinterface_output_dir=None) -> dict[str, Any]: - """ - Load the results of sorting into a dict for easy comparison. - """ - results = { - "si": {}, - "ks": {}, - } - if kilosort_output_dir: - results["ks"]["st"] = np.load(kilosort_output_dir / "spike_times.npy") - results["ks"]["clus"] = np.load(kilosort_output_dir / "spike_clusters.npy") - - if spikeinterface_output_dir: - results["si"]["st"] = np.load(spikeinterface_output_dir / "sorter_output" / "spike_times.npy") - results["si"]["clus"] = np.load(spikeinterface_output_dir / "sorter_output" / "spike_clusters.npy") - - return results diff --git a/.github/workflows/all-tests.yml b/.github/workflows/all-tests.yml index 2a50c976a5..e12cf6805d 100644 --- a/.github/workflows/all-tests.yml +++ b/.github/workflows/all-tests.yml @@ -50,7 +50,7 @@ jobs: shell: bash run: | changed_files="${{ steps.changed-files.outputs.all_changed_files }}" - python .github/determine_testing_environment.py $changed_files + python .github/scripts/determine_testing_environment.py $changed_files - name: Display testing environment shell: bash diff --git a/.github/workflows/core-test.yml b/.github/workflows/core-test.yml index a513d48f3b..1dbf0f5109 100644 --- a/.github/workflows/core-test.yml +++ b/.github/workflows/core-test.yml @@ -39,7 +39,7 @@ jobs: pip install tabulate echo "# Timing profile of core tests in ${{matrix.os}}" >> $GITHUB_STEP_SUMMARY # Outputs markdown summary to standard output - python ./.github/build_job_summary.py report.txt >> $GITHUB_STEP_SUMMARY + python ./.github/scripts/build_job_summary.py report.txt >> $GITHUB_STEP_SUMMARY cat $GITHUB_STEP_SUMMARY rm report.txt shell: bash # Necessary for pipeline to work on windows diff --git a/.github/workflows/full-test-with-codecov.yml b/.github/workflows/full-test-with-codecov.yml index ab4a083ae1..6a222f5e25 100644 --- a/.github/workflows/full-test-with-codecov.yml +++ b/.github/workflows/full-test-with-codecov.yml @@ -47,7 +47,7 @@ jobs: source ${{ github.workspace }}/test_env/bin/activate pytest -m "not sorters_external" --cov=./ --cov-report xml:./coverage.xml -vv -ra --durations=0 | tee report_full.txt; test ${PIPESTATUS[0]} -eq 0 || exit 1 echo "# Timing profile of full tests" >> $GITHUB_STEP_SUMMARY - python ./.github/build_job_summary.py report_full.txt >> $GITHUB_STEP_SUMMARY + python ./.github/scripts/build_job_summary.py report_full.txt >> $GITHUB_STEP_SUMMARY cat $GITHUB_STEP_SUMMARY rm report_full.txt - uses: codecov/codecov-action@v4 diff --git a/.github/workflows/test_imports.yml b/.github/workflows/test_imports.yml index d39fc37242..a2631f6eb7 100644 --- a/.github/workflows/test_imports.yml +++ b/.github/workflows/test_imports.yml @@ -34,7 +34,7 @@ jobs: echo "## OS: ${{ matrix.os }}" >> $GITHUB_STEP_SUMMARY echo "---" >> $GITHUB_STEP_SUMMARY echo "### Import times when only installing only core dependencies " >> $GITHUB_STEP_SUMMARY - python ./.github/import_test.py >> $GITHUB_STEP_SUMMARY + python ./.github/scripts/import_test.py >> $GITHUB_STEP_SUMMARY shell: bash # Necessary for pipeline to work on windows - name: Install in full mode run: | @@ -44,5 +44,5 @@ jobs: # Add a header to separate the two profiles echo "---" >> $GITHUB_STEP_SUMMARY echo "### Import times when installing full dependencies in " >> $GITHUB_STEP_SUMMARY - python ./.github/import_test.py >> $GITHUB_STEP_SUMMARY + python ./.github/scripts/import_test.py >> $GITHUB_STEP_SUMMARY shell: bash # Necessary for pipeline to work on windows diff --git a/.github/workflows/test_kilosort4.yml b/.github/workflows/test_kilosort4.yml index 390bec98be..42e6140917 100644 --- a/.github/workflows/test_kilosort4.yml +++ b/.github/workflows/test_kilosort4.yml @@ -4,26 +4,30 @@ on: workflow_dispatch: schedule: - cron: "0 12 * * 0" # Weekly on Sunday at noon UTC + pull_request: + paths: + - '**/kilosort4.py' jobs: versions: - # Poll Pypi for all released KS4 versions >4.0.4, save to JSON + # Poll Pypi for all released KS4 versions >4.0.16, save to JSON # and store them in a matrix for the next job. runs-on: ubuntu-latest outputs: matrix: ${{ steps.set-matrix.outputs.matrix }} steps: - name: Checkout repository - uses: actions/checkout@v2 + uses: actions/checkout@v4 - name: Set up Python - uses: actions/setup-python@v2 + uses: actions/setup-python@v5 with: python-version: 3.12 - name: Install dependencies run: | pip install requests packaging + pip install . - name: Fetch package versions from PyPI run: | @@ -47,7 +51,7 @@ jobs: ks_version: ${{ fromJson(needs.versions.outputs.matrix) }} steps: - name: Checkout repository - uses: actions/checkout@v2 + uses: actions/checkout@v4 - name: Setup Python ${{ matrix.python-version }} uses: actions/setup-python@v5 diff --git a/src/spikeinterface/sorters/external/kilosort4.py b/src/spikeinterface/sorters/external/kilosort4.py index 8499cef11f..e73ac2cb6c 100644 --- a/src/spikeinterface/sorters/external/kilosort4.py +++ b/src/spikeinterface/sorters/external/kilosort4.py @@ -4,8 +4,11 @@ from typing import Union from packaging import version -from ..basesorter import BaseSorter + +from ...core import write_binary_recording +from ..basesorter import BaseSorter, get_job_kwargs from .kilosortbase import KilosortBase +from ..basesorter import get_job_kwargs from importlib.metadata import version as importlib_version PathType = Union[str, Path] @@ -17,6 +20,7 @@ class Kilosort4Sorter(BaseSorter): sorter_name: str = "kilosort4" requires_locations = True gpu_capability = "nvidia-optional" + requires_binary_data = False _default_params = { "batch_size": 60000, @@ -31,6 +35,7 @@ class Kilosort4Sorter(BaseSorter): "artifact_threshold": None, "nskip": 25, "whitening_range": 32, + "highpass_cutoff": 300, "binning_depth": 5, "sig_interp": 20, "drift_smoothing": [0.5, 0.5, 0.5], @@ -51,14 +56,18 @@ class Kilosort4Sorter(BaseSorter): "cluster_downsampling": 20, "cluster_pcs": 64, "x_centers": None, - "duplicate_spike_bins": 7, - "do_correction": True, - "keep_good_only": False, - "save_extra_kwargs": False, - "skip_kilosort_preprocessing": False, + "duplicate_spike_ms": 0.25, "scaleproc": None, "save_preprocessed_copy": False, "torch_device": "auto", + "bad_channels": None, + "clear_cache": False, + "save_extra_vars": False, + "do_correction": True, + "keep_good_only": False, + "skip_kilosort_preprocessing": False, + "use_binary_file": None, + "delete_recording_dat": True, } _params_description = { @@ -74,6 +83,7 @@ class Kilosort4Sorter(BaseSorter): "artifact_threshold": "If a batch contains absolute values above this number, it will be zeroed out under the assumption that a recording artifact is present. By default, the threshold is infinite (so that no zeroing occurs). Default value: None.", "nskip": "Batch stride for computing whitening matrix. Default value: 25.", "whitening_range": "Number of nearby channels used to estimate the whitening matrix. Default value: 32.", + "highpass_cutoff": "High-pass filter cutoff frequency in Hz. Default value: 300.", "binning_depth": "For drift correction, vertical bin size in microns used for 2D histogram. Default value: 5.", "sig_interp": "For drift correction, sigma for interpolation (spatial standard deviation). Approximate smoothness scale in units of microns. Default value: 20.", "drift_smoothing": "Amount of gaussian smoothing to apply to the spatiotemporal drift estimation, for x,y,time axes in units of registration blocks (for x,y axes) and batch size (for time axis). The x,y smoothing has no effect for `nblocks = 1`.", @@ -95,12 +105,19 @@ class Kilosort4Sorter(BaseSorter): "cluster_pcs": "Maximum number of spatiotemporal PC features used for clustering. Default value: 64.", "x_centers": "Number of x-positions to use when determining center points for template groupings. If None, this will be determined automatically by finding peaks in channel density. For 2D array type probes, we recommend specifying this so that centers are placed every few hundred microns.", "duplicate_spike_bins": "Number of bins for which subsequent spikes from the same cluster are assumed to be artifacts. A value of 0 disables this step. Default value: 7.", - "do_correction": "If True, drift correction is performed", - "save_extra_kwargs": "If True, additional kwargs are saved to the output", - "skip_kilosort_preprocessing": "Can optionally skip the internal kilosort preprocessing", + "save_extra_vars": "If True, additional kwargs are saved to the output", "scaleproc": "int16 scaling of whitened data, if None set to 200.", - "save_preprocessed_copy": "save a pre-processed copy of the data (including drift correction) to temp_wh.dat in the results directory and format Phy output to use that copy of the data", + "save_preprocessed_copy": "Save a pre-processed copy of the data (including drift correction) to temp_wh.dat in the results directory and format Phy output to use that copy of the data", "torch_device": "Select the torch device auto/cuda/cpu", + "bad_channels": "A list of channel indices (rows in the binary file) that should not be included in sorting. Listing channels here is equivalent to excluding them from the probe dictionary.", + "clear_cache": "If True, force pytorch to free up memory reserved for its cache in between memory-intensive operations. Note that setting `clear_cache=True` is NOT recommended unless you encounter GPU out-of-memory errors, since this can result in slower sorting.", + "do_correction": "If True, drift correction is performed. Default is True. (spikeinterface parameter)", + "skip_kilosort_preprocessing": "Can optionally skip the internal kilosort preprocessing. (spikeinterface parameter)", + "keep_good_only": "If True, only the units labeled as 'good' by Kilosort are returned in the output. (spikeinterface parameter)", + "use_binary_file": "If True then Kilosort is run using a binary file. In this case, if the input recording is not binary compatible, it is written to a binary file in the output folder. " + "If False then Kilosort is run on the recording object directly using the RecordingExtractorAsArray object. If None, then if the recording is binary compatible, the sorter will use the binary file, otherwise the RecordingExtractorAsArray. " + "Default is None. (spikeinterface parameter)", + "delete_recording_dat": "If True, if a temporary binary file is created, it is deleted after the sorting is done. Default is True. (spikeinterface parameter)", } sorter_description = """Kilosort4 is a Python package for spike sorting on GPUs with template matching. @@ -110,7 +127,7 @@ class Kilosort4Sorter(BaseSorter): For more information see https://github.com/MouseLand/Kilosort""" installation_mesg = """\nTo use Kilosort4 run:\n - >>> pip install kilosort==4.0 + >>> pip install kilosort --upgrade More information on Kilosort4 at: https://github.com/MouseLand/Kilosort @@ -134,6 +151,25 @@ def get_sorter_version(cls): """kilosort.__version__ <4.0.10 is always '4'""" return importlib_version("kilosort") + @classmethod + def initialize_folder(cls, recording, output_folder, verbose, remove_existing_folder): + if not cls.is_installed(): + raise Exception( + f"The sorter {cls.sorter_name} is not installed. Please install it with:\n{cls.installation_mesg}" + ) + cls.check_sorter_version() + return super(Kilosort4Sorter, cls).initialize_folder(recording, output_folder, verbose, remove_existing_folder) + + @classmethod + def check_sorter_version(cls): + kilosort_version = version.parse(cls.get_sorter_version()) + if kilosort_version < version.parse("4.0.16"): + raise Exception( + f"""SpikeInterface only supports kilosort versions 4.0.16 and above. You are running version {kilosort_version}. To install the latest version, run: + >>> pip install kilosort --upgrade + """ + ) + @classmethod def _setup_recording(cls, recording, sorter_output_folder, params, verbose): from probeinterface import write_prb @@ -142,6 +178,17 @@ def _setup_recording(cls, recording, sorter_output_folder, params, verbose): probe_filename = sorter_output_folder / "probe.prb" write_prb(probe_filename, pg) + if params["use_binary_file"]: + if not recording.binary_compatible_with(time_axis=0, file_paths_lenght=1): + # local copy needed + binary_file_path = sorter_output_folder / "recording.dat" + write_binary_recording( + recording=recording, + file_paths=[binary_file_path], + **get_job_kwargs(params, verbose), + ) + params["filename"] = str(binary_file_path) + @classmethod def _run_from_folder(cls, sorter_output_folder, params, verbose): from kilosort.run_kilosort import ( @@ -186,14 +233,39 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): recording = cls.load_recording_from_folder(sorter_output_folder.parent, with_warnings=False) probe = load_probe(probe_path=probe_filename) probe_name = "" - filename = "" - # this internally concatenates the recording - file_object = RecordingExtractorAsArray(recording_extractor=recording) + if params["use_binary_file"] is None: + if recording.binary_compatible_with(time_axis=0, file_paths_lenght=1): + # no copy + binary_description = recording.get_binary_description() + filename = str(binary_description["file_paths"][0]) + file_object = None + else: + # the recording is not binary compatible and no binary copy has been written. + # in this case, we use the RecordingExtractorAsArray object + filename = "" + file_object = RecordingExtractorAsArray(recording_extractor=recording) + elif params["use_binary_file"]: + # here we force the use of a binary file + if recording.binary_compatible_with(time_axis=0, file_paths_lenght=1): + # no copy + binary_description = recording.get_binary_description() + filename = str(binary_description["file_paths"][0]) + file_object = None + else: + # a local copy has been written + filename = str(sorter_output_folder / "recording.dat") + file_object = None + else: + # here we force the use of the RecordingExtractorAsArray object + filename = "" + file_object = RecordingExtractorAsArray(recording_extractor=recording) + + data_dtype = recording.get_dtype() do_CAR = params["do_CAR"] invert_sign = params["invert_sign"] - save_extra_vars = params["save_extra_kwargs"] + save_extra_vars = params["save_extra_vars"] save_preprocessed_copy = params["save_preprocessed_copy"] progress_bar = None settings_ks = {k: v for k, v in params.items() if k in DEFAULT_SETTINGS} @@ -214,6 +286,8 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): # NOTE: Also modifies settings in-place data_dir = "" results_dir = sorter_output_folder + bad_channels = params["bad_channels"] + clear_cache = params["clear_cache"] filename, data_dir, results_dir, probe = set_files( settings=settings, @@ -222,36 +296,22 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): probe_name=probe_name, data_dir=data_dir, results_dir=results_dir, + bad_channels=bad_channels, ) - if version.parse(cls.get_sorter_version()) >= version.parse("4.0.12"): - ops = initialize_ops( - settings=settings, - probe=probe, - data_dtype=recording.get_dtype(), - do_CAR=do_CAR, - invert_sign=invert_sign, - device=device, - save_preprocesed_copy=save_preprocessed_copy, # this kwarg is correct (typo) - ) - else: - ops = initialize_ops( - settings=settings, - probe=probe, - data_dtype=recording.get_dtype(), - do_CAR=do_CAR, - invert_sign=invert_sign, - device=device, - ) + ops = initialize_ops( + settings=settings, + probe=probe, + data_dtype=data_dtype, + do_CAR=do_CAR, + invert_sign=invert_sign, + device=device, + save_preprocessed_copy=save_preprocessed_copy, + ) - if version.parse(cls.get_sorter_version()) >= version.parse("4.0.11"): - n_chan_bin, fs, NT, nt, twav_min, chan_map, dtype, do_CAR, invert, _, _, tmin, tmax, artifact, _, _ = ( - get_run_parameters(ops) - ) - else: - n_chan_bin, fs, NT, nt, twav_min, chan_map, dtype, do_CAR, invert, _, _, tmin, tmax, artifact = ( - get_run_parameters(ops) - ) + n_chan_bin, fs, NT, nt, twav_min, chan_map, dtype, do_CAR, invert, _, _, tmin, tmax, artifact, _, _ = ( + get_run_parameters(ops) + ) # Set preprocessing and drift correction parameters if not params["skip_kilosort_preprocessing"]: @@ -291,17 +351,31 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): # this function applies both preprocessing and drift correction ops, bfile, st0 = compute_drift_correction( - ops=ops, device=device, tic0=tic0, progress_bar=progress_bar, file_object=file_object + ops=ops, + device=device, + tic0=tic0, + progress_bar=progress_bar, + file_object=file_object, + clear_cache=clear_cache, ) if save_preprocessed_copy: save_preprocessing(results_dir / "temp_wh.dat", ops, bfile) # Sort spikes and save results - st, tF, _, _ = detect_spikes(ops=ops, device=device, bfile=bfile, tic0=tic0, progress_bar=progress_bar) + st, tF, _, _ = detect_spikes( + ops=ops, device=device, bfile=bfile, tic0=tic0, progress_bar=progress_bar, clear_cache=clear_cache + ) clu, Wall = cluster_spikes( - st=st, tF=tF, ops=ops, device=device, bfile=bfile, tic0=tic0, progress_bar=progress_bar + st=st, + tF=tF, + ops=ops, + device=device, + bfile=bfile, + tic0=tic0, + progress_bar=progress_bar, + clear_cache=clear_cache, ) if params["skip_kilosort_preprocessing"]: @@ -309,21 +383,23 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): hp_filter=torch.as_tensor(np.zeros(1)), whiten_mat=torch.as_tensor(np.eye(recording.get_num_channels())) ) - if version.parse(cls.get_sorter_version()) >= version.parse("4.0.12"): - _ = save_sorting( - ops=ops, - results_dir=results_dir, - st=st, - clu=clu, - tF=tF, - Wall=Wall, - imin=bfile.imin, - tic0=tic0, - save_extra_vars=save_extra_vars, - save_preprocessed_copy=save_preprocessed_copy, - ) - else: - _ = save_sorting(ops, results_dir, st, clu, tF, Wall, bfile.imin, tic0, save_extra_vars=save_extra_vars) + _ = save_sorting( + ops=ops, + results_dir=results_dir, + st=st, + clu=clu, + tF=tF, + Wall=Wall, + imin=bfile.imin, + tic0=tic0, + save_extra_vars=save_extra_vars, + save_preprocessed_copy=save_preprocessed_copy, + ) + + if params["delete_recording_dat"]: + # only delete dat file if it was created by the wrapper + if (sorter_output_folder / "recording.dat").is_file(): + (sorter_output_folder / "recording.dat").unlink() @classmethod def _get_result_from_folder(cls, sorter_output_folder):