Skip to content

Commit

Permalink
Remove last change_nothing
Browse files Browse the repository at this point in the history
  • Loading branch information
alejoe91 committed Sep 3, 2024
1 parent 87fbe55 commit 10b7e1a
Showing 1 changed file with 15 additions and 12 deletions.
27 changes: 15 additions & 12 deletions .github/scripts/test_kilosort4_ci.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
- Do some tests to check all KS4 parameters are tested against.
"""

import pytest
import copy
from typing import Any
Expand Down Expand Up @@ -91,7 +92,7 @@
"acg_threshold",
"cluster_downsampling",
"cluster_pcs",
"duplicate_spike_ms" # this is because gorund-truth spikes don't have violations
"duplicate_spike_ms", # this is because gorund-truth spikes don't have violations
]

# THIS IS A PLACEHOLDER FOR FUTURE PARAMS TO TEST
Expand Down Expand Up @@ -145,7 +146,6 @@ def default_kilosort_sorting(self, recording_and_paths):

return si.read_kilosort(defaults_ks_output_dir)


def _get_ground_truth_recording(self):
"""
A ground truth recording chosen to be as small as possible (for speed).
Expand Down Expand Up @@ -225,7 +225,9 @@ def test_spikeinterface_defaults_against_kilsort(self):

# Testing Arguments ###
def test_set_files_arguments(self):
self._check_arguments(set_files, ["settings", "filename", "probe", "probe_name", "data_dir", "results_dir", "bad_channels"])
self._check_arguments(
set_files, ["settings", "filename", "probe", "probe_name", "data_dir", "results_dir", "bad_channels"]
)

def test_initialize_ops_arguments(self):
expected_arguments = [
Expand All @@ -247,13 +249,17 @@ def test_compute_preprocessing_arguments(self):
self._check_arguments(compute_preprocessing, ["ops", "device", "tic0", "file_object"])

def test_compute_drift_location_arguments(self):
self._check_arguments(compute_drift_correction, ["ops", "device", "tic0", "progress_bar", "file_object", "clear_cache"])
self._check_arguments(
compute_drift_correction, ["ops", "device", "tic0", "progress_bar", "file_object", "clear_cache"]
)

def test_detect_spikes_arguments(self):
self._check_arguments(detect_spikes, ["ops", "device", "bfile", "tic0", "progress_bar", "clear_cache"])

def test_cluster_spikes_arguments(self):
self._check_arguments(cluster_spikes, ["st", "tF", "ops", "device", "bfile", "tic0", "progress_bar", "clear_cache"])
self._check_arguments(
cluster_spikes, ["st", "tF", "ops", "device", "bfile", "tic0", "progress_bar", "clear_cache"]
)

def test_save_sorting_arguments(self):
expected_arguments = ["ops", "results_dir", "st", "clu", "tF", "Wall", "imin", "tic0", "save_extra_vars"]
Expand Down Expand Up @@ -397,7 +403,6 @@ def test_kilosort4_no_correction(self, recording_and_paths, tmp_path):
)
check_sortings_equal(sorting_ks, sorting_si)


def test_use_binary_file(self, tmp_path):
"""
Test that the SpikeInterface wrapper can run KS4 using a binary file as input or directly
Expand All @@ -410,21 +415,21 @@ def test_use_binary_file(self, tmp_path):
sorting_ks4 = si.run_sorter(
"kilosort4",
recording,
folder = tmp_path / "spikeinterface_output_dir_wrapper",
folder=tmp_path / "spikeinterface_output_dir_wrapper",
use_binary_file=False,
remove_existing_folder=True,
)
sorting_ks4_bin = si.run_sorter(
"kilosort4",
recording_bin,
folder = tmp_path / "spikeinterface_output_dir_bin",
folder=tmp_path / "spikeinterface_output_dir_bin",
use_binary_file=False,
remove_existing_folder=True,
)
sorting_ks4_non_bin = si.run_sorter(
"kilosort4",
recording,
folder = tmp_path / "spikeinterface_output_dir_non_bin",
folder=tmp_path / "spikeinterface_output_dir_non_bin",
use_binary_file=True,
remove_existing_folder=True,
)
Expand Down Expand Up @@ -546,7 +551,6 @@ def monkeypatch_filter_function(self, X, ops=None, ibatch=None):
results["si"]["clus"] = np.load(spikeinterface_output_dir / "sorter_output" / "spike_clusters.npy")
assert np.allclose(results["ks"]["st"], results["si"]["st"], rtol=0, atol=1)


##### Helpers ######
def _get_kilosort_native_settings(self, recording, paths, param_key, param_value):
"""
Expand All @@ -571,8 +575,7 @@ def _get_kilosort_native_settings(self, recording, paths, param_key, param_value
if param_key in RUN_KILOSORT_ARGS:
run_kilosort_kwargs = {param_key: param_value}
else:
if param_key != "change_nothing":
settings.update({param_key: param_value})
settings.update({param_key: param_value})
run_kilosort_kwargs = {}

ks_format_probe = load_probe(paths["probe_path"])
Expand Down

0 comments on commit 10b7e1a

Please sign in to comment.