Skip to content

Commit

Permalink
Try a different file location just to get things working.
Browse files Browse the repository at this point in the history
  • Loading branch information
JoeZiminski committed Jun 26, 2024
1 parent 6f9dfc4 commit 2814de9
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 103 deletions.
3 changes: 1 addition & 2 deletions .github/workflows/test_kilosort4.yml
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,7 @@ jobs:
# run: chmod +x .github/test_kilosort4.sh
# TODO: must be a better way
run: |
cd ..
pytest .github/test_kilosort4.py --durations=0
pytest temp_test_file_dir/test_kilosort4.py --durations=0
shell: bash

# TODO: pip install -e .[full,dev] is failing #
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,47 +45,45 @@


PARAMS_TO_TEST = [

# Not tested
# ("torch_device", "auto")

# Stable across KS version 4.0.01 - 4.0.12
("change_nothing", None),
("nblocks", 0),
("do_CAR", False),
("batch_size", 42743), # Q: how much do these results change with batch size?
("Th_universal", 12),
("Th_learned", 14),
("invert_sign", True),
("nt", 93),
("nskip", 1),
("whitening_range", 16),
("sig_interp", 5),
("nt0min", 25),
("dmin", 15),
("dminx", 16),
("min_template_size", 15),
("template_sizes", 10),
("nearest_chans", 8),
("nearest_templates", 35),
("max_channel_distance", 5),
("templates_from_data", False),
("n_templates", 10),
("n_pcs", 3),
("Th_single_ch", 4),
("acg_threshold", 0.001),
("x_centers", 5),
("duplicate_spike_bins", 5), # TODO: why is this not erroring, it is deprecated. issue on KS
("binning_depth", 1),
("artifact_threshold", 200),
("ccg_threshold", 1e9),
("cluster_downsampling", 1e9),
("duplicate_spike_bins", 5), # TODO: this is deprecated and changed to _ms in 4.0.13!
# Not tested
# ("torch_device", "auto")
# Stable across KS version 4.0.01 - 4.0.12
("change_nothing", None),
("nblocks", 0),
("do_CAR", False),
("batch_size", 42743), # Q: how much do these results change with batch size?
("Th_universal", 12),
("Th_learned", 14),
("invert_sign", True),
("nt", 93),
("nskip", 1),
("whitening_range", 16),
("sig_interp", 5),
("nt0min", 25),
("dmin", 15),
("dminx", 16),
("min_template_size", 15),
("template_sizes", 10),
("nearest_chans", 8),
("nearest_templates", 35),
("max_channel_distance", 5),
("templates_from_data", False),
("n_templates", 10),
("n_pcs", 3),
("Th_single_ch", 4),
("acg_threshold", 0.001),
("x_centers", 5),
("duplicate_spike_bins", 5), # TODO: why is this not erroring, it is deprecated. issue on KS
("binning_depth", 1),
("artifact_threshold", 200),
("ccg_threshold", 1e9),
("cluster_downsampling", 1e9),
("duplicate_spike_bins", 5), # TODO: this is deprecated and changed to _ms in 4.0.13!
]

# Update PARAMS_TO_TEST with version-dependent kwargs
if parse(version("kilosort")) >= parse("4.0.12"):
pass # TODO: expose?
pass # TODO: expose?
# PARAMS_TO_TEST.extend(
# [
# ("save_preprocessed_copy", False),
Expand Down Expand Up @@ -124,8 +122,7 @@ class TestKilosort4Long:
# Fixtures ######
@pytest.fixture(scope="session")
def recording_and_paths(self, tmp_path_factory):
"""
"""
""" """
tmp_path = tmp_path_factory.mktemp("kilosort4_tests")

np.random.seed(0) # TODO: check below...
Expand All @@ -138,13 +135,10 @@ def recording_and_paths(self, tmp_path_factory):

@pytest.fixture(scope="session")
def default_results(self, recording_and_paths):
"""
"""
""" """
recording, paths = recording_and_paths

settings, ks_format_probe = self._run_kilosort_with_kilosort(
recording, paths
)
settings, ks_format_probe = self._run_kilosort_with_kilosort(recording, paths)

defaults_ks_output_dir = paths["session_scope_tmp_path"] / "default_ks_output"

Expand Down Expand Up @@ -195,12 +189,9 @@ def test_default_settings_all_represented(self):
if param_key not in ["n_chan_bin", "fs", "tmin", "tmax"]:
assert param_key in tested_keys, f"param: {param_key} in DEFAULT SETTINGS but not tested."

@pytest.mark.parametrize("parameter",
PARAMS_TO_TEST
)
@pytest.mark.parametrize("parameter", PARAMS_TO_TEST)
def test_kilosort4(self, recording_and_paths, default_results, tmp_path, parameter):
"""
"""
""" """
recording, paths = recording_and_paths
param_key, param_value = parameter

Expand All @@ -218,9 +209,7 @@ def test_kilosort4(self, recording_and_paths, default_results, tmp_path, paramet
extra_ks_settings.update({param_key: param_value})
run_kilosort_kwargs = {}

settings, ks_format_probe = self._run_kilosort_with_kilosort(
recording, paths, extra_ks_settings
)
settings, ks_format_probe = self._run_kilosort_with_kilosort(recording, paths, extra_ks_settings)

kilosort.run_kilosort(
settings=settings,
Expand All @@ -238,16 +227,24 @@ def test_kilosort4(self, recording_and_paths, default_results, tmp_path, paramet
extra_si_settings.update({"nblocks": 5})

spikeinterface_settings = self._get_spikeinterface_settings(extra_settings=extra_si_settings)
si.run_sorter("kilosort4", recording, remove_existing_folder=True,
folder=spikeinterface_output_dir, **spikeinterface_settings)
si.run_sorter(
"kilosort4",
recording,
remove_existing_folder=True,
folder=spikeinterface_output_dir,
**spikeinterface_settings,
)

results = self._get_sorting_output(kilosort_output_dir, spikeinterface_output_dir)

assert np.array_equal(results["ks"]["st"],
results["si"]["st"]), f"{param_key} spike times different"
assert np.array_equal(results["ks"]["st"], results["si"]["st"]), f"{param_key} spike times different"

assert all(results["ks"]["clus"].iloc[:, 0] == results["si"]["clus"].iloc[:, 0]), f"{param_key} cluster assignment different"
assert all(results["ks"]["clus"].iloc[:, 1] == results["si"]["clus"].iloc[:, 1]), f"{param_key} cluster quality different" # TODO: check pandas probably better way
assert all(
results["ks"]["clus"].iloc[:, 0] == results["si"]["clus"].iloc[:, 0]
), f"{param_key} cluster assignment different"
assert all(
results["ks"]["clus"].iloc[:, 1] == results["si"]["clus"].iloc[:, 1]
), f"{param_key} cluster quality different" # TODO: check pandas probably better way

# This is saved on the SI side so not an extremely
# robust addition, but it can't hurt.
Expand All @@ -258,21 +255,16 @@ def test_kilosort4(self, recording_and_paths, default_results, tmp_path, paramet

# Finally, check that our test parameters actually change the output!
if parse(version("kilosort")) > parse("4.0.4"):
self._check_test_parameters_are_actually_changing_the_output(
results, default_results, param_key
)
self._check_test_parameters_are_actually_changing_the_output(results, default_results, param_key)

def test_kilosort4_no_correction(self, recording_and_paths, tmp_path):
"""
"""
""" """
recording, paths = recording_and_paths

kilosort_output_dir = tmp_path / "kilosort_output_dir" # TODO: a lot of copying here
spikeinterface_output_dir = tmp_path / "spikeinterface_output_dir"

settings, ks_format_probe = self._run_kilosort_with_kilosort(
recording, paths, extra_settings={"nblocks": 0}
)
settings, ks_format_probe = self._run_kilosort_with_kilosort(recording, paths, extra_settings={"nblocks": 0})

kilosort.run_kilosort(
settings=settings,
Expand All @@ -283,8 +275,14 @@ def test_kilosort4_no_correction(self, recording_and_paths, tmp_path):
)

spikeinterface_settings = self._get_spikeinterface_settings(extra_settings={"nblocks": 6})
si.run_sorter("kilosort4", recording, remove_existing_folder=True,
folder=spikeinterface_output_dir, do_correction=False, **spikeinterface_settings)
si.run_sorter(
"kilosort4",
recording,
remove_existing_folder=True,
folder=spikeinterface_output_dir,
do_correction=False,
**spikeinterface_settings,
)

results = self._get_sorting_output(kilosort_output_dir, spikeinterface_output_dir)

Expand All @@ -294,26 +292,21 @@ def test_kilosort4_no_correction(self, recording_and_paths, tmp_path):
assert all(results["ks"]["clus"].iloc[:, 1] == results["si"]["clus"].iloc[:, 1])

def test_kilosort4_skip_preprocessing_correction(self, tmp_path, monkeypatch):
"""
"""
""" """
recording = self._get_ground_truth_recording()

# We need to filter and whiten the recording here, otherwise KS takes forever.
# Do this in a way different to KS.
recording = si.highpass_filter(recording, 300)
recording = si.whiten(recording, mode="local", apply_mean=False
)
recording = si.whiten(recording, mode="local", apply_mean=False)

paths = self._save_ground_truth_recording(recording, tmp_path)

kilosort_default_output_dir = tmp_path / "kilosort_default_output_dir"
kilosort_output_dir = tmp_path / "kilosort_output_dir"
spikeinterface_output_dir = tmp_path / "spikeinterface_output_dir"


ks_settings, ks_format_probe = self._run_kilosort_with_kilosort(
recording, paths, extra_settings={"nblocks": 0}
)
ks_settings, ks_format_probe = self._run_kilosort_with_kilosort(recording, paths, extra_settings={"nblocks": 0})

kilosort.run_kilosort(
settings=ks_settings,
Expand Down Expand Up @@ -365,8 +358,15 @@ def fake_fftshift(X, dim):

# Now, run kilosort through spikeinterface with the same options.
spikeinterface_settings = self._get_spikeinterface_settings(extra_settings={"nblocks": 0})
si.run_sorter("kilosort4", recording, remove_existing_folder=True,
folder=spikeinterface_output_dir, do_CAR=False, skip_kilosort_preprocessing=True, **spikeinterface_settings)
si.run_sorter(
"kilosort4",
recording,
remove_existing_folder=True,
folder=spikeinterface_output_dir,
do_CAR=False,
skip_kilosort_preprocessing=True,
**spikeinterface_settings,
)

default_results = self._get_sorting_output(kilosort_default_output_dir)
results = self._get_sorting_output(kilosort_output_dir, spikeinterface_output_dir)
Expand All @@ -379,38 +379,43 @@ def fake_fftshift(X, dim):

# Helpers ######
def _check_test_parameters_are_actually_changing_the_output(self, results, default_results, param_key):
"""
"""
""" """
if param_key not in ["artifact_threshold", "ccg_threshold", "cluster_downsampling"]:
num_clus = np.unique(results["si"]["clus"].iloc[:, 0]).size
num_clus_default = np.unique(default_results["ks"]["clus"].iloc[:, 0]).size

if param_key == "change_nothing":
# TODO: lol
assert (results["si"]["st"].size == default_results["ks"]["st"].size) and num_clus == num_clus_default and all(results["si"]["clus"].iloc[:, 1] == default_results["ks"]["clus"].iloc[:, 1]), f"{param_key} changed somehow!."
assert (
(results["si"]["st"].size == default_results["ks"]["st"].size)
and num_clus == num_clus_default
and all(results["si"]["clus"].iloc[:, 1] == default_results["ks"]["clus"].iloc[:, 1])
), f"{param_key} changed somehow!."
else:
assert (results["si"]["st"].size != default_results["ks"]["st"].size) or num_clus != num_clus_default or not all(results["si"]["clus"].iloc[:, 1] == default_results["ks"]["clus"].iloc[:, 1]), f"{param_key} results did not change with parameter change."
assert (
(results["si"]["st"].size != default_results["ks"]["st"].size)
or num_clus != num_clus_default
or not all(results["si"]["clus"].iloc[:, 1] == default_results["ks"]["clus"].iloc[:, 1])
), f"{param_key} results did not change with parameter change."

def _run_kilosort_with_kilosort(self, recording, paths, extra_settings=None):
"""
"""
""" """
# dont actually run KS here because we will overwrite the defaults!
settings = {'data_dir': paths["recording_path"],
'n_chan_bin': recording.get_num_channels(),
"fs": recording.get_sampling_frequency()}
settings = {
"data_dir": paths["recording_path"],
"n_chan_bin": recording.get_num_channels(),
"fs": recording.get_sampling_frequency(),
}

if extra_settings is not None:
settings.update(extra_settings)

ks_format_probe = load_probe(
paths["probe_path"]
)
ks_format_probe = load_probe(paths["probe_path"])

return settings, ks_format_probe

def _get_spikeinterface_settings(self, extra_settings=None):
"""
"""
""" """
# dont actually run here.
settings = copy.deepcopy(DEFAULT_SETTINGS)

Expand All @@ -423,8 +428,7 @@ def _get_spikeinterface_settings(self, extra_settings=None):
return settings

def _get_sorting_output(self, kilosort_output_dir=None, spikeinterface_output_dir=None) -> dict[str, Any]:
"""
"""
""" """
results = {
"si": {},
"ks": {},
Expand All @@ -440,27 +444,24 @@ def _get_sorting_output(self, kilosort_output_dir=None, spikeinterface_output_di
return results

def _get_ground_truth_recording(self):
"""
"""
""" """
# Chosen so all parameter changes to indeed change the output
num_channels = 32
recording, _ = si.generate_ground_truth_recording(
durations=[5],
seed=0,
num_channels=num_channels,
num_units=5,
generate_sorting_kwargs=dict(firing_rates=100,
refractory_period_ms=4.0),
generate_sorting_kwargs=dict(firing_rates=100, refractory_period_ms=4.0),
)
return recording

def _save_ground_truth_recording(self, recording, tmp_path):
"""
"""
""" """
paths = {
"session_scope_tmp_path": tmp_path,
"recording_path": tmp_path / "my_test_recording",
"probe_path": tmp_path / "my_test_probe.prb"
"probe_path": tmp_path / "my_test_probe.prb",
}

recording.save(folder=paths["recording_path"], overwrite=True)
Expand Down

0 comments on commit 2814de9

Please sign in to comment.