diff --git a/.github/workflows/test_kilosort4.yml b/.github/workflows/test_kilosort4.yml index 2dba9d1590..6674280c65 100644 --- a/.github/workflows/test_kilosort4.yml +++ b/.github/workflows/test_kilosort4.yml @@ -52,8 +52,7 @@ jobs: # run: chmod +x .github/test_kilosort4.sh # TODO: must be a better way run: | - cd .. - pytest .github/test_kilosort4.py --durations=0 + pytest temp_test_file_dir/test_kilosort4.py --durations=0 shell: bash # TODO: pip install -e .[full,dev] is failing # diff --git a/.github/test_kilosort4.py b/src/spikeinterface/temp_test_file_dir/test_kilosort4.py similarity index 78% rename from .github/test_kilosort4.py rename to src/spikeinterface/temp_test_file_dir/test_kilosort4.py index 95090ef2a8..0fb9841728 100644 --- a/.github/test_kilosort4.py +++ b/src/spikeinterface/temp_test_file_dir/test_kilosort4.py @@ -45,47 +45,45 @@ PARAMS_TO_TEST = [ - - # Not tested - # ("torch_device", "auto") - - # Stable across KS version 4.0.01 - 4.0.12 - ("change_nothing", None), - ("nblocks", 0), - ("do_CAR", False), - ("batch_size", 42743), # Q: how much do these results change with batch size? - ("Th_universal", 12), - ("Th_learned", 14), - ("invert_sign", True), - ("nt", 93), - ("nskip", 1), - ("whitening_range", 16), - ("sig_interp", 5), - ("nt0min", 25), - ("dmin", 15), - ("dminx", 16), - ("min_template_size", 15), - ("template_sizes", 10), - ("nearest_chans", 8), - ("nearest_templates", 35), - ("max_channel_distance", 5), - ("templates_from_data", False), - ("n_templates", 10), - ("n_pcs", 3), - ("Th_single_ch", 4), - ("acg_threshold", 0.001), - ("x_centers", 5), - ("duplicate_spike_bins", 5), # TODO: why is this not erroring, it is deprecated. issue on KS - ("binning_depth", 1), - ("artifact_threshold", 200), - ("ccg_threshold", 1e9), - ("cluster_downsampling", 1e9), - ("duplicate_spike_bins", 5), # TODO: this is depcrecated and changed to _ms in 4.0.13! + # Not tested + # ("torch_device", "auto") + # Stable across KS version 4.0.01 - 4.0.12 + ("change_nothing", None), + ("nblocks", 0), + ("do_CAR", False), + ("batch_size", 42743), # Q: how much do these results change with batch size? + ("Th_universal", 12), + ("Th_learned", 14), + ("invert_sign", True), + ("nt", 93), + ("nskip", 1), + ("whitening_range", 16), + ("sig_interp", 5), + ("nt0min", 25), + ("dmin", 15), + ("dminx", 16), + ("min_template_size", 15), + ("template_sizes", 10), + ("nearest_chans", 8), + ("nearest_templates", 35), + ("max_channel_distance", 5), + ("templates_from_data", False), + ("n_templates", 10), + ("n_pcs", 3), + ("Th_single_ch", 4), + ("acg_threshold", 0.001), + ("x_centers", 5), + ("duplicate_spike_bins", 5), # TODO: why is this not erroring, it is deprecated. issue on KS + ("binning_depth", 1), + ("artifact_threshold", 200), + ("ccg_threshold", 1e9), + ("cluster_downsampling", 1e9), + ("duplicate_spike_bins", 5), # TODO: this is depcrecated and changed to _ms in 4.0.13! ] # Update PARAMS_TO_TEST with version-dependent kwargs if parse(version("kilosort")) >= parse("4.0.12"): - pass # TODO: expose? + pass # TODO: expose? # PARAMS_TO_TEST.extend( # [ # ("save_preprocessed_copy", False), @@ -124,8 +122,7 @@ class TestKilosort4Long: # Fixtures ###### @pytest.fixture(scope="session") def recording_and_paths(self, tmp_path_factory): - """ - """ + """ """ tmp_path = tmp_path_factory.mktemp("kilosort4_tests") np.random.seed(0) # TODO: check below... @@ -138,13 +135,10 @@ def recording_and_paths(self, tmp_path_factory): @pytest.fixture(scope="session") def default_results(self, recording_and_paths): - """ - """ + """ """ recording, paths = recording_and_paths - settings, ks_format_probe = self._run_kilosort_with_kilosort( - recording, paths - ) + settings, ks_format_probe = self._run_kilosort_with_kilosort(recording, paths) defaults_ks_output_dir = paths["session_scope_tmp_path"] / "default_ks_output" @@ -195,12 +189,9 @@ def test_default_settings_all_represented(self): if param_key not in ["n_chan_bin", "fs", "tmin", "tmax"]: assert param_key in tested_keys, f"param: {param_key} in DEFAULT SETTINGS but not tested." - @pytest.mark.parametrize("parameter", - PARAMS_TO_TEST - ) + @pytest.mark.parametrize("parameter", PARAMS_TO_TEST) def test_kilosort4(self, recording_and_paths, default_results, tmp_path, parameter): - """ - """ + """ """ recording, paths = recording_and_paths param_key, param_value = parameter @@ -218,9 +209,7 @@ def test_kilosort4(self, recording_and_paths, default_results, tmp_path, paramet extra_ks_settings.update({param_key: param_value}) run_kilosort_kwargs = {} - settings, ks_format_probe = self._run_kilosort_with_kilosort( - recording, paths, extra_ks_settings - ) + settings, ks_format_probe = self._run_kilosort_with_kilosort(recording, paths, extra_ks_settings) kilosort.run_kilosort( settings=settings, @@ -238,16 +227,24 @@ def test_kilosort4(self, recording_and_paths, default_results, tmp_path, paramet extra_si_settings.update({"nblocks": 5}) spikeinterface_settings = self._get_spikeinterface_settings(extra_settings=extra_si_settings) - si.run_sorter("kilosort4", recording, remove_existing_folder=True, - folder=spikeinterface_output_dir, **spikeinterface_settings) + si.run_sorter( + "kilosort4", + recording, + remove_existing_folder=True, + folder=spikeinterface_output_dir, + **spikeinterface_settings, + ) results = self._get_sorting_output(kilosort_output_dir, spikeinterface_output_dir) - assert np.array_equal(results["ks"]["st"], - results["si"]["st"]), f"{param_key} spike times different" + assert np.array_equal(results["ks"]["st"], results["si"]["st"]), f"{param_key} spike times different" - assert all(results["ks"]["clus"].iloc[:, 0] == results["si"]["clus"].iloc[:, 0]), f"{param_key} cluster assignment different" - assert all(results["ks"]["clus"].iloc[:, 1] == results["si"]["clus"].iloc[:, 1]), f"{param_key} cluster quality different" # TODO: check pandas probably better way + assert all( + results["ks"]["clus"].iloc[:, 0] == results["si"]["clus"].iloc[:, 0] + ), f"{param_key} cluster assignment different" + assert all( + results["ks"]["clus"].iloc[:, 1] == results["si"]["clus"].iloc[:, 1] + ), f"{param_key} cluster quality different" # TODO: check pandas probably better way # This is saved on the SI side so not an extremely # robust addition, but it can't hurt. @@ -258,21 +255,16 @@ def test_kilosort4(self, recording_and_paths, default_results, tmp_path, paramet # Finally, check out test parameters actually changes stuff! if parse(version("kilosort")) > parse("4.0.4"): - self._check_test_parameters_are_actually_changing_the_output( - results, default_results, param_key - ) + self._check_test_parameters_are_actually_changing_the_output(results, default_results, param_key) def test_kilosort4_no_correction(self, recording_and_paths, tmp_path): - """ - """ + """ """ recording, paths = recording_and_paths kilosort_output_dir = tmp_path / "kilosort_output_dir" # TODO: a lost of copying here spikeinterface_output_dir = tmp_path / "spikeinterface_output_dir" - settings, ks_format_probe = self._run_kilosort_with_kilosort( - recording, paths, extra_settings={"nblocks": 0} - ) + settings, ks_format_probe = self._run_kilosort_with_kilosort(recording, paths, extra_settings={"nblocks": 0}) kilosort.run_kilosort( settings=settings, @@ -283,8 +275,14 @@ def test_kilosort4_no_correction(self, recording_and_paths, tmp_path): ) spikeinterface_settings = self._get_spikeinterface_settings(extra_settings={"nblocks": 6}) - si.run_sorter("kilosort4", recording, remove_existing_folder=True, - folder=spikeinterface_output_dir, do_correction=False, **spikeinterface_settings) + si.run_sorter( + "kilosort4", + recording, + remove_existing_folder=True, + folder=spikeinterface_output_dir, + do_correction=False, + **spikeinterface_settings, + ) results = self._get_sorting_output(kilosort_output_dir, spikeinterface_output_dir) @@ -294,15 +292,13 @@ def test_kilosort4_no_correction(self, recording_and_paths, tmp_path): assert all(results["ks"]["clus"].iloc[:, 1] == results["si"]["clus"].iloc[:, 1]) def test_kilosort4_skip_preprocessing_correction(self, tmp_path, monkeypatch): - """ - """ + """ """ recording = self._get_ground_truth_recording() # We need to filter and whiten the recording here to KS takes forever. # Do this in a way differnt to KS. recording = si.highpass_filter(recording, 300) - recording = si.whiten(recording, mode="local", apply_mean=False - ) + recording = si.whiten(recording, mode="local", apply_mean=False) paths = self._save_ground_truth_recording(recording, tmp_path) @@ -310,10 +306,7 @@ def test_kilosort4_skip_preprocessing_correction(self, tmp_path, monkeypatch): kilosort_output_dir = tmp_path / "kilosort_output_dir" spikeinterface_output_dir = tmp_path / "spikeinterface_output_dir" - - ks_settings, ks_format_probe = self._run_kilosort_with_kilosort( - recording, paths, extra_settings={"nblocks": 0} - ) + ks_settings, ks_format_probe = self._run_kilosort_with_kilosort(recording, paths, extra_settings={"nblocks": 0}) kilosort.run_kilosort( settings=ks_settings, @@ -365,8 +358,15 @@ def fake_fftshift(X, dim): # Now, run kilosort through spikeinterface with the same options. spikeinterface_settings = self._get_spikeinterface_settings(extra_settings={"nblocks": 0}) - si.run_sorter("kilosort4", recording, remove_existing_folder=True, - folder=spikeinterface_output_dir, do_CAR=False, skip_kilosort_preprocessing=True, **spikeinterface_settings) + si.run_sorter( + "kilosort4", + recording, + remove_existing_folder=True, + folder=spikeinterface_output_dir, + do_CAR=False, + skip_kilosort_preprocessing=True, + **spikeinterface_settings, + ) default_results = self._get_sorting_output(kilosort_default_output_dir) results = self._get_sorting_output(kilosort_output_dir, spikeinterface_output_dir) @@ -379,38 +379,43 @@ def fake_fftshift(X, dim): # Helpers ###### def _check_test_parameters_are_actually_changing_the_output(self, results, default_results, param_key): - """ - """ + """ """ if param_key not in ["artifact_threshold", "ccg_threshold", "cluster_downsampling"]: num_clus = np.unique(results["si"]["clus"].iloc[:, 0]).size num_clus_default = np.unique(default_results["ks"]["clus"].iloc[:, 0]).size if param_key == "change_nothing": # TODO: lol - assert (results["si"]["st"].size == default_results["ks"]["st"].size) and num_clus == num_clus_default and all(results["si"]["clus"].iloc[:, 1] == default_results["ks"]["clus"].iloc[:, 1]), f"{param_key} changed somehow!." + assert ( + (results["si"]["st"].size == default_results["ks"]["st"].size) + and num_clus == num_clus_default + and all(results["si"]["clus"].iloc[:, 1] == default_results["ks"]["clus"].iloc[:, 1]) + ), f"{param_key} changed somehow!." else: - assert (results["si"]["st"].size != default_results["ks"]["st"].size) or num_clus != num_clus_default or not all(results["si"]["clus"].iloc[:, 1] == default_results["ks"]["clus"].iloc[:, 1]), f"{param_key} results did not change with parameter change." + assert ( + (results["si"]["st"].size != default_results["ks"]["st"].size) + or num_clus != num_clus_default + or not all(results["si"]["clus"].iloc[:, 1] == default_results["ks"]["clus"].iloc[:, 1]) + ), f"{param_key} results did not change with parameter change." def _run_kilosort_with_kilosort(self, recording, paths, extra_settings=None): - """ - """ + """ """ # dont actually run KS here because we will overwrite the defaults! - settings = {'data_dir': paths["recording_path"], - 'n_chan_bin': recording.get_num_channels(), - "fs": recording.get_sampling_frequency()} + settings = { + "data_dir": paths["recording_path"], + "n_chan_bin": recording.get_num_channels(), + "fs": recording.get_sampling_frequency(), + } if extra_settings is not None: settings.update(extra_settings) - ks_format_probe = load_probe( - paths["probe_path"] - ) + ks_format_probe = load_probe(paths["probe_path"]) return settings, ks_format_probe def _get_spikeinterface_settings(self, extra_settings=None): - """ - """ + """ """ # dont actually run here. settings = copy.deepcopy(DEFAULT_SETTINGS) @@ -423,8 +428,7 @@ def _get_spikeinterface_settings(self, extra_settings=None): return settings def _get_sorting_output(self, kilosort_output_dir=None, spikeinterface_output_dir=None) -> dict[str, Any]: - """ - """ + """ """ results = { "si": {}, "ks": {}, @@ -440,8 +444,7 @@ def _get_sorting_output(self, kilosort_output_dir=None, spikeinterface_output_di return results def _get_ground_truth_recording(self): - """ - """ + """ """ # Chosen so all parameter changes to indeed change the output num_channels = 32 recording, _ = si.generate_ground_truth_recording( @@ -449,18 +452,16 @@ def _get_ground_truth_recording(self): seed=0, num_channels=num_channels, num_units=5, - generate_sorting_kwargs=dict(firing_rates=100, - refractory_period_ms=4.0), + generate_sorting_kwargs=dict(firing_rates=100, refractory_period_ms=4.0), ) return recording def _save_ground_truth_recording(self, recording, tmp_path): - """ - """ + """ """ paths = { "session_scope_tmp_path": tmp_path, "recording_path": tmp_path / "my_test_recording", - "probe_path": tmp_path / "my_test_probe.prb" + "probe_path": tmp_path / "my_test_probe.prb", } recording.save(folder=paths["recording_path"], overwrite=True)