diff --git a/.github/scripts/test_kilosort4_ci.py b/.github/scripts/test_kilosort4_ci.py
index 0593534010..35946a5a56 100644
--- a/.github/scripts/test_kilosort4_ci.py
+++ b/.github/scripts/test_kilosort4_ci.py
@@ -10,31 +10,27 @@
   changes when skipping KS4 preprocessing is true, because this takes a
   slightly different path through the kilosort4.py wrapper logic.
   This also checks that changing the parameter changes the test output from default
-  on our test case (otherwise, the test could not detect a failure). This is possible
-  for nearly all parameters, see `_check_test_parameters_are_changing_the_output()`.
+  on our test case (otherwise, the test could not detect a failure).
 
 - Test that kilosort functions called from `kilosort4.py` wrapper have the expected
   input signatures
 
 - Do some tests to check all KS4 parameters are tested against.
 """
-
+import pytest
 import copy
 from typing import Any
+from inspect import signature
+
 import numpy as np
 import torch
-import kilosort
-from kilosort.io import load_probe
-import pandas as pd
-import pytest
-from packaging.version import parse
-from importlib.metadata import version
-from inspect import signature
 
 import spikeinterface.full as si
+from spikeinterface.core.testing import check_sortings_equal
 from spikeinterface.sorters.external.kilosort4 import Kilosort4Sorter
 from probeinterface.io import write_prb
 
+import kilosort
 from kilosort.parameters import DEFAULT_SETTINGS
 from kilosort.run_kilosort import (
     set_files,
@@ -47,59 +43,62 @@
     get_run_parameters,
 )
 from kilosort.io import load_probe, RecordingExtractorAsArray, BinaryFiltered
-from kilosort.parameters import DEFAULT_SETTINGS
-from kilosort import preprocessing as ks_preprocessing
+
 
 RUN_KILOSORT_ARGS = ["do_CAR", "invert_sign", "save_preprocessed_copy"]
 # "device", "progress_bar", "save_extra_vars" are not tested. "save_extra_vars" could be.
 
 # Setup Params to test ####
 
-PARAMS_TO_TEST = [
-    # Not tested
-    # ("torch_device", "auto")
-    # Stable across KS version 4.0.16 - 4.0.X (?)
-    ("change_nothing", None),
-    ("nblocks", 0),
-    ("do_CAR", False),
-    ("batch_size", 42743),
-    ("Th_universal", 12),
-    ("Th_learned", 14),
-    ("invert_sign", True),
-    ("nt", 93),
-    ("nskip", 1),
-    ("whitening_range", 16),
-    ("highpass_cutoff", 200),
-    ("sig_interp", 5),
-    ("nt0min", 25),
-    ("dmin", 15),
-    ("dminx", 16),
-    ("min_template_size", 15),
-    ("template_sizes", 10),
-    ("nearest_chans", 8),
-    ("nearest_templates", 35),
-    ("max_channel_distance", 5),
-    ("templates_from_data", False),
-    ("n_templates", 10),
-    ("n_pcs", 3),
-    ("Th_single_ch", 4),
-    ("x_centers", 5),
-    ("binning_depth", 1),
-    # Note: These don't change the results from
-    # default when applied to the test case.
-    ("artifact_threshold", 200),
-    ("ccg_threshold", 1e12),
-    ("acg_threshold", 1e12),
-    ("cluster_downsampling", 2),
-    ("drift_smoothing", [250, 250, 250]),
-    # Not tested beacuse with ground truth data it doesn't change the results
-    # ("duplicate_spike_ms", 0.3),
+PARAMS_TO_TEST_DICT = {
+    "nblocks": 0,
+    "do_CAR": False,
+    "batch_size": 42743,
+    "Th_universal": 12,
+    "Th_learned": 14,
+    "invert_sign": True,
+    "nt": 93,
+    "nskip": 1,
+    "whitening_range": 16,
+    "highpass_cutoff": 200,
+    "sig_interp": 5,
+    "nt0min": 25,
+    "dmin": 15,
+    "dminx": 16,
+    "min_template_size": 15,
+    "template_sizes": 10,
+    "nearest_chans": 8,
+    "nearest_templates": 35,
+    "max_channel_distance": 5,
+    "templates_from_data": False,
+    "n_templates": 10,
+    "n_pcs": 3,
+    "Th_single_ch": 4,
+    "x_centers": 5,
+    "binning_depth": 1,
+    "drift_smoothing": [250, 250, 250],
+    "artifact_threshold": 200,
+    "ccg_threshold": 1e12,
+    "acg_threshold": 1e12,
+    "cluster_downsampling": 2,
+    "duplicate_spike_ms": 0.3,
+}
+
+PARAMS_TO_TEST = list(PARAMS_TO_TEST_DICT.keys())
+
+PARAMETERS_NOT_AFFECTING_RESULTS = [
+    "artifact_threshold",
+    "ccg_threshold",
+    "acg_threshold",
+    "cluster_downsampling",
+    "cluster_pcs",
+    "duplicate_spike_ms"  # this is because ground-truth spikes don't have violations
 ]
-
+# THIS IS A PLACEHOLDER FOR FUTURE PARAMS TO TEST
 # if parse(version("kilosort")) >= parse("4.0.X"):
-#     PARAMS_TO_TEST.extend(
+#     PARAMS_TO_TEST_DICT.update(
 #         [
-#             ("new_param", new_values),
+#             {"new_param": new_value},
 #         ]
 #     )
@@ -122,7 +121,7 @@ def recording_and_paths(self, tmp_path_factory):
         return (recording, paths)
 
     @pytest.fixture(scope="session")
-    def default_results(self, recording_and_paths):
+    def default_kilosort_sorting(self, recording_and_paths):
         """
         Because we check each parameter at a time and check the KS4 and
         SpikeInterface versions match, if changing the parameter
@@ -133,7 +132,7 @@ def default_results(self, recording_and_paths):
         """
         recording, paths = recording_and_paths
 
-        settings, _, ks_format_probe = self._get_kilosort_native_settings(recording, paths, "change_nothing", None)
+        settings, _, ks_format_probe = self._get_kilosort_native_settings(recording, paths, None, None)
 
         defaults_ks_output_dir = paths["session_scope_tmp_path"] / "default_ks_output"
 
@@ -144,9 +143,8 @@ def default_results(self, recording_and_paths):
             results_dir=defaults_ks_output_dir,
         )
 
-        default_results = self._get_sorting_output(defaults_ks_output_dir)
+        return si.read_kilosort(defaults_ks_output_dir)
 
-        return default_results
 
     def _get_ground_truth_recording(self):
         """
@@ -185,16 +183,11 @@ def _save_ground_truth_recording(self, recording, tmp_path):
     # Tests ######
     def test_params_to_test(self):
         """
-        Test that all values in PARAMS_TO_TEST are
+        Test that all values in PARAMS_TO_TEST_DICT are
         different to the default values used in Kilosort,
         otherwise there is no point to the test.
         """
-        for parameter in PARAMS_TO_TEST:
-            param_key, param_value = parameter
-
-            if param_key == "change_nothing":
-                continue
-
+        for param_key, param_value in PARAMS_TO_TEST_DICT.items():
             if param_key not in RUN_KILOSORT_ARGS:
                 assert DEFAULT_SETTINGS[param_key] != param_value, (
                     f"{param_key} values should be different in test: "
@@ -207,8 +200,8 @@ def test_default_settings_all_represented(self):
         PARAMS_TO_TEST, otherwise we are missing settings added
         on the KS side.
         """
-        tested_keys = [entry[0] for entry in PARAMS_TO_TEST]
-        additional_non_tested_keys = ["shift", "scale", "save_preprocessed_copy", "duplicate_spike_ms"]
+        tested_keys = PARAMS_TO_TEST
+        additional_non_tested_keys = ["shift", "scale", "save_preprocessed_copy"]
         tested_keys += additional_non_tested_keys
 
         for param_key in DEFAULT_SETTINGS:
@@ -315,7 +308,7 @@ def _check_arguments(self, object_, expected_arguments):
 
     # Full Test ####
     @pytest.mark.parametrize("parameter", PARAMS_TO_TEST)
-    def test_kilosort4_main(self, recording_and_paths, default_results, tmp_path, parameter):
+    def test_kilosort4_main(self, recording_and_paths, default_kilosort_sorting, tmp_path, parameter):
         """
         Given a recording, paths to raw data, and a parameter to change,
         run KS4 natively and within the SpikeInterface wrapper with the
@@ -323,7 +316,8 @@ def test_kilosort4_main(self, recording_and_paths, default_results, tmp_path, pa
         check the outputs are the same.
         """
         recording, paths = recording_and_paths
-        param_key, param_value = parameter
+        param_key = parameter
+        param_value = PARAMS_TO_TEST_DICT[param_key]
 
         # Setup parameters for KS4 and run it natively
         kilosort_output_dir = tmp_path / "kilosort_output_dir"
@@ -340,11 +334,12 @@
             results_dir=kilosort_output_dir,
             **run_kilosort_kwargs,
         )
+        sorting_ks4 = si.read_kilosort(kilosort_output_dir)
 
         # Setup Parameters for SI and KS4 through SI
         spikeinterface_settings = self._get_spikeinterface_settings(param_key, param_value)
 
-        si.run_sorter(
+        sorting_si = si.run_sorter(
            "kilosort4",
            recording,
            remove_existing_folder=True,
@@ -353,21 +348,19 @@
         )
 
         # Get the results and check they match
-        results = self._get_sorting_output(kilosort_output_dir, spikeinterface_output_dir)
-
-        assert np.array_equal(results["ks"]["st"], results["si"]["st"]), f"{param_key} spike times different"
-        assert np.array_equal(results["ks"]["clus"], results["si"]["clus"]), f"{param_key} cluster assignment different"
+        check_sortings_equal(sorting_ks4, sorting_si)
 
         # Check the ops file in KS4 output is as expected. This is saved on the
         # SI side so not an extremely robust addition, but it can't hurt.
-        if param_key != "change_nothing":
-            ops = np.load(spikeinterface_output_dir / "sorter_output" / "ops.npy", allow_pickle=True)
-            ops = ops.tolist()  # strangely this makes a dict
-            assert ops[param_key] == param_value
+        ops = np.load(spikeinterface_output_dir / "sorter_output" / "ops.npy", allow_pickle=True)
+        ops = ops.tolist()  # strangely this makes a dict
+        assert ops[param_key] == param_value
 
         # Finally, check out test parameters actually change the output of
-        # KS4, ensuring our tests are actually doing something.
-        self._check_test_parameters_are_changing_the_output(results, default_results, param_key)
+        # KS4, ensuring our tests are actually doing something (except for some params).
+        if param_key not in PARAMETERS_NOT_AFFECTING_RESULTS:
+            with pytest.raises(AssertionError):
+                check_sortings_equal(default_kilosort_sorting, sorting_si)
 
     def test_kilosort4_no_correction(self, recording_and_paths, tmp_path):
         """
@@ -391,9 +384,10 @@
             results_dir=kilosort_output_dir,
             do_CAR=True,
         )
+        sorting_ks = si.read_kilosort(kilosort_output_dir)
 
         spikeinterface_settings = self._get_spikeinterface_settings("nblocks", 1)
-        si.run_sorter(
+        sorting_si = si.run_sorter(
            "kilosort4",
            recording,
            remove_existing_folder=True,
@@ -401,21 +395,46 @@
             do_correction=False,
             **spikeinterface_settings,
         )
+        check_sortings_equal(sorting_ks, sorting_si)
 
-        results = self._get_sorting_output(kilosort_output_dir, spikeinterface_output_dir)
-        assert np.array_equal(results["ks"]["st"], results["si"]["st"])
-        assert np.array_equal(results["ks"]["clus"], results["si"]["clus"])
+    def test_use_binary_file(self, tmp_path):
+        """
+        Test that the SpikeInterface wrapper can run KS4 using a binary file as input or directly
+        from the recording.
+        """
+        recording = self._get_ground_truth_recording()
+        recording_bin = recording.save()
+
+        # run with SI wrapper
+        sorting_ks4 = si.run_sorter(
+            "kilosort4",
+            recording,
+            folder = tmp_path / "spikeinterface_output_dir_wrapper",
+            use_binary_file=False,
+            remove_existing_folder=True,
+        )
+        sorting_ks4_bin = si.run_sorter(
+            "kilosort4",
+            recording_bin,
+            folder = tmp_path / "spikeinterface_output_dir_bin",
+            use_binary_file=False,
+            remove_existing_folder=True,
+        )
+        sorting_ks4_non_bin = si.run_sorter(
+            "kilosort4",
+            recording,
+            folder = tmp_path / "spikeinterface_output_dir_non_bin",
+            use_binary_file=True,
+            remove_existing_folder=True,
+        )
 
-    def test_kilosort4_use_binary_file(self, recording_and_paths, tmp_path):
-        # TODO
-        pass
+        check_sortings_equal(sorting_ks4, sorting_ks4_bin)
+        check_sortings_equal(sorting_ks4, sorting_ks4_non_bin)
 
     @pytest.mark.parametrize(
         "param_to_test",
         [
-            ("change_nothing", None),
             ("do_CAR", False),
             ("batch_size", 42743),
             ("Th_learned", 14),
@@ -496,6 +515,7 @@ def monkeypatch_filter_function(self, X, ops=None, ibatch=None):
         )
         monkeypatch.undo()
 
+        si.read_kilosort(kilosort_output_dir)
 
         # Now, run kilosort through spikeinterface with the same options.
         spikeinterface_settings = self._get_spikeinterface_settings(param_key, param_value)
@@ -517,29 +537,17 @@ def monkeypatch_filter_function(self, X, ops=None, ibatch=None):
         # memory file. Because in this test recordings are preprocessed, there are
         # some filter edge effects that depend on the chunking in `get_traces()`.
         # These are all extremely close (usually just 1 spike, 1 idx different).
-        results = self._get_sorting_output(kilosort_output_dir, spikeinterface_output_dir)
+        results = {}
+        results["ks"] = {}
+        results["ks"]["st"] = np.load(kilosort_output_dir / "spike_times.npy")
+        results["ks"]["clus"] = np.load(kilosort_output_dir / "spike_clusters.npy")
+        results["si"] = {}
+        results["si"]["st"] = np.load(spikeinterface_output_dir / "sorter_output" / "spike_times.npy")
+        results["si"]["clus"] = np.load(spikeinterface_output_dir / "sorter_output" / "spike_clusters.npy")
         assert np.allclose(results["ks"]["st"], results["si"]["st"], rtol=0, atol=1)
 
-    # Helpers ######
-    def _check_test_parameters_are_changing_the_output(self, results, default_results, param_key):
-        """
-        If nothing is changed, default vs. results outputs are identical.
-        Otherwise, check they are not the same. Can't figure out how to get
-        the skipped three parameters below to change the results on this
-        small test file.
-        """
-        if param_key in ["acg_threshold", "ccg_threshold", "artifact_threshold", "cluster_downsampling", "cluster_pcs"]:
-            return
-
-        if param_key == "change_nothing":
-            assert all(default_results["ks"]["st"] == results["ks"]["st"]) and all(
-                default_results["ks"]["clus"] == results["ks"]["clus"]
-            ), f"{param_key} changed somehow!."
-        else:
-            assert not (default_results["ks"]["st"].size == results["ks"]["st"].size) or not all(
-                default_results["ks"]["clus"] == results["ks"]["clus"]
-            ), f"{param_key} results did not change with parameter change."
+    ##### Helpers ######
     def _get_kilosort_native_settings(self, recording, paths, param_key, param_value):
         """
         Function to generate the settings and function inputs to run kilosort.
@@ -554,16 +562,18 @@ def _get_kilosort_native_settings(self, recording, paths, param_key, param_value
             "n_chan_bin": recording.get_num_channels(),
             "fs": recording.get_sampling_frequency(),
         }
+        run_kilosort_kwargs = {}
 
-        if param_key == "binning_depth":
-            settings.update({"nblocks": 5})
+        if param_key is not None:
+            if param_key == "binning_depth":
+                settings.update({"nblocks": 5})
 
-        if param_key in RUN_KILOSORT_ARGS:
-            run_kilosort_kwargs = {param_key: param_value}
-        else:
-            if param_key != "change_nothing":
-                settings.update({param_key: param_value})
-            run_kilosort_kwargs = {}
+            if param_key in RUN_KILOSORT_ARGS:
+                run_kilosort_kwargs = {param_key: param_value}
+            else:
+                if param_key != "change_nothing":
+                    settings.update({param_key: param_value})
+                run_kilosort_kwargs = {}
 
         ks_format_probe = load_probe(paths["probe_path"])
 
@@ -576,31 +586,12 @@ def _get_spikeinterface_settings(self, param_key, param_value):
         """
         settings = {}  # copy.deepcopy(DEFAULT_SETTINGS)
 
-        if param_key != "change_nothing":
-            settings.update({param_key: param_value})
-
         if param_key == "binning_depth":
             settings.update({"nblocks": 5})
 
+        settings.update({param_key: param_value})
+
         # for name in ["n_chan_bin", "fs", "tmin", "tmax"]:
         #     settings.pop(name)
 
         return settings
-
-    def _get_sorting_output(self, kilosort_output_dir=None, spikeinterface_output_dir=None) -> dict[str, Any]:
-        """
-        Load the results of sorting into a dict for easy comparison.
-        """
-        results = {
-            "si": {},
-            "ks": {},
-        }
-        if kilosort_output_dir:
-            results["ks"]["st"] = np.load(kilosort_output_dir / "spike_times.npy")
-            results["ks"]["clus"] = np.load(kilosort_output_dir / "spike_clusters.npy")
-
-        if spikeinterface_output_dir:
-            results["si"]["st"] = np.load(spikeinterface_output_dir / "sorter_output" / "spike_times.npy")
-            results["si"]["clus"] = np.load(spikeinterface_output_dir / "sorter_output" / "spike_clusters.npy")
-
-        return results