diff --git a/.github/scripts/test_kilosort4_ci.py b/.github/scripts/test_kilosort4_ci.py index 35946a5a56..dbd8135b9a 100644 --- a/.github/scripts/test_kilosort4_ci.py +++ b/.github/scripts/test_kilosort4_ci.py @@ -17,6 +17,7 @@ - Do some tests to check all KS4 parameters are tested against. """ + import pytest import copy from typing import Any @@ -91,7 +92,7 @@ "acg_threshold", "cluster_downsampling", "cluster_pcs", - "duplicate_spike_ms" # this is because gorund-truth spikes don't have violations + "duplicate_spike_ms", # this is because gorund-truth spikes don't have violations ] # THIS IS A PLACEHOLDER FOR FUTURE PARAMS TO TEST @@ -145,7 +146,6 @@ def default_kilosort_sorting(self, recording_and_paths): return si.read_kilosort(defaults_ks_output_dir) - def _get_ground_truth_recording(self): """ A ground truth recording chosen to be as small as possible (for speed). @@ -225,7 +225,9 @@ def test_spikeinterface_defaults_against_kilsort(self): # Testing Arguments ### def test_set_files_arguments(self): - self._check_arguments(set_files, ["settings", "filename", "probe", "probe_name", "data_dir", "results_dir", "bad_channels"]) + self._check_arguments( + set_files, ["settings", "filename", "probe", "probe_name", "data_dir", "results_dir", "bad_channels"] + ) def test_initialize_ops_arguments(self): expected_arguments = [ @@ -247,13 +249,17 @@ def test_compute_preprocessing_arguments(self): self._check_arguments(compute_preprocessing, ["ops", "device", "tic0", "file_object"]) def test_compute_drift_location_arguments(self): - self._check_arguments(compute_drift_correction, ["ops", "device", "tic0", "progress_bar", "file_object", "clear_cache"]) + self._check_arguments( + compute_drift_correction, ["ops", "device", "tic0", "progress_bar", "file_object", "clear_cache"] + ) def test_detect_spikes_arguments(self): self._check_arguments(detect_spikes, ["ops", "device", "bfile", "tic0", "progress_bar", "clear_cache"]) def test_cluster_spikes_arguments(self): - self._check_arguments(cluster_spikes, ["st", "tF", "ops", "device", "bfile", "tic0", "progress_bar", "clear_cache"]) + self._check_arguments( + cluster_spikes, ["st", "tF", "ops", "device", "bfile", "tic0", "progress_bar", "clear_cache"] + ) def test_save_sorting_arguments(self): expected_arguments = ["ops", "results_dir", "st", "clu", "tF", "Wall", "imin", "tic0", "save_extra_vars"] @@ -397,7 +403,6 @@ def test_kilosort4_no_correction(self, recording_and_paths, tmp_path): ) check_sortings_equal(sorting_ks, sorting_si) - def test_use_binary_file(self, tmp_path): """ Test that the SpikeInterface wrapper can run KS4 using a binary file as input or directly @@ -410,21 +415,21 @@ def test_use_binary_file(self, tmp_path): sorting_ks4 = si.run_sorter( "kilosort4", recording, - folder = tmp_path / "spikeinterface_output_dir_wrapper", + folder=tmp_path / "spikeinterface_output_dir_wrapper", use_binary_file=False, remove_existing_folder=True, ) sorting_ks4_bin = si.run_sorter( "kilosort4", recording_bin, - folder = tmp_path / "spikeinterface_output_dir_bin", + folder=tmp_path / "spikeinterface_output_dir_bin", use_binary_file=False, remove_existing_folder=True, ) sorting_ks4_non_bin = si.run_sorter( "kilosort4", recording, - folder = tmp_path / "spikeinterface_output_dir_non_bin", + folder=tmp_path / "spikeinterface_output_dir_non_bin", use_binary_file=True, remove_existing_folder=True, ) @@ -546,7 +551,6 @@ def monkeypatch_filter_function(self, X, ops=None, ibatch=None): results["si"]["clus"] = np.load(spikeinterface_output_dir / "sorter_output" / "spike_clusters.npy") assert np.allclose(results["ks"]["st"], results["si"]["st"], rtol=0, atol=1) - ##### Helpers ###### def _get_kilosort_native_settings(self, recording, paths, param_key, param_value): """ @@ -571,8 +575,7 @@ def _get_kilosort_native_settings(self, recording, paths, param_key, param_value if param_key in RUN_KILOSORT_ARGS: run_kilosort_kwargs = {param_key: param_value} else: - if param_key != "change_nothing": - settings.update({param_key: param_value}) + settings.update({param_key: param_value}) run_kilosort_kwargs = {} ks_format_probe = load_probe(paths["probe_path"])