From f9ff667188d2b70026dcf2267c367a0e92a9ce4d Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Mon, 24 Jun 2024 19:03:06 +0100 Subject: [PATCH 01/30] Add for 'set_files'. --- src/spikeinterface/sorters/external/kilosort4.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/sorters/external/kilosort4.py b/src/spikeinterface/sorters/external/kilosort4.py index a7f40a9558..92bfabbe73 100644 --- a/src/spikeinterface/sorters/external/kilosort4.py +++ b/src/spikeinterface/sorters/external/kilosort4.py @@ -205,7 +205,16 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): # NOTE: Also modifies settings in-place data_dir = "" results_dir = sorter_output_folder - filename, data_dir, results_dir, probe = set_files(settings, filename, probe, probe_name, data_dir, results_dir) + + filename, data_dir, results_dir, probe = set_files( + settings=settings, + filename=filename, + probe=probe, + probe_name=probe_name, + data_dir=data_dir, + results_dir=results_dir, + ) + if version.parse(cls.get_sorter_version()) >= version.parse("4.0.12"): ops = initialize_ops(settings, probe, recording.get_dtype(), do_CAR, invert_sign, device, False) n_chan_bin, fs, NT, nt, twav_min, chan_map, dtype, do_CAR, invert, _, _, tmin, tmax, artifact, _, _ = ( From aaa389f78243ac5c40c89f78cf282e69b591aebe Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Mon, 24 Jun 2024 19:06:30 +0100 Subject: [PATCH 02/30] Add for 'initialize_ops'. --- .../sorters/external/kilosort4.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/sorters/external/kilosort4.py b/src/spikeinterface/sorters/external/kilosort4.py index 92bfabbe73..b723e7a2bb 100644 --- a/src/spikeinterface/sorters/external/kilosort4.py +++ b/src/spikeinterface/sorters/external/kilosort4.py @@ -216,12 +216,27 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): ) if version.parse(cls.get_sorter_version()) >= version.parse("4.0.12"): - ops = initialize_ops(settings, probe, recording.get_dtype(), do_CAR, invert_sign, device, False) + ops = initialize_ops( + settings=settings, + probe=probe, + data_dtype=recording.get_dtype(), + do_CAR=do_CAR, + invert_sign=invert_sign, + device=device, + save_preprocessed_copy=False, + ) n_chan_bin, fs, NT, nt, twav_min, chan_map, dtype, do_CAR, invert, _, _, tmin, tmax, artifact, _, _ = ( get_run_parameters(ops) ) else: - ops = initialize_ops(settings, probe, recording.get_dtype(), do_CAR, invert_sign, device) + ops = initialize_ops( + settings=settings, + probe=probe, + data_dtype=recording.get_dtype(), + do_CAR=do_CAR, + invert_sign=invert_sign, + device=device, + ) n_chan_bin, fs, NT, nt, twav_min, chan_map, dtype, do_CAR, invert, _, _, tmin, tmax, artifact = ( get_run_parameters(ops) ) From 3ea9b8da6c9c0ce3672fc2b60d551cbfa96f8552 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Mon, 24 Jun 2024 19:07:08 +0100 Subject: [PATCH 03/30] Add for 'compute_preprocessing'. --- src/spikeinterface/sorters/external/kilosort4.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/sorters/external/kilosort4.py b/src/spikeinterface/sorters/external/kilosort4.py index b723e7a2bb..d8b1f1a60a 100644 --- a/src/spikeinterface/sorters/external/kilosort4.py +++ b/src/spikeinterface/sorters/external/kilosort4.py @@ -243,7 +243,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): # Set preprocessing and drift correction parameters if not params["skip_kilosort_preprocessing"]: - ops = compute_preprocessing(ops, device, tic0=tic0, file_object=file_object) + ops = compute_preprocessing(ops=ops, device=device, tic0=tic0, file_object=file_object) else: print("Skipping kilosort preprocessing.") bfile = BinaryFiltered( From 28656425eac96ef7be1573256c316c23b057f1c5 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Mon, 24 Jun 2024 19:07:43 +0100 Subject: [PATCH 04/30] Add for 'compute_drift_correction'. --- src/spikeinterface/sorters/external/kilosort4.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/sorters/external/kilosort4.py b/src/spikeinterface/sorters/external/kilosort4.py index d8b1f1a60a..d187b445ef 100644 --- a/src/spikeinterface/sorters/external/kilosort4.py +++ b/src/spikeinterface/sorters/external/kilosort4.py @@ -278,7 +278,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): # this function applies both preprocessing and drift correction ops, bfile, st0 = compute_drift_correction( - ops, device, tic0=tic0, progress_bar=progress_bar, file_object=file_object + ops=ops, device=device, tic0=tic0, progress_bar=progress_bar, file_object=file_object ) # Sort spikes and save results From 9e0207aed2f92424e5d8d8088ce6c95de286eb38 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Mon, 24 Jun 2024 19:11:42 +0100 Subject: [PATCH 05/30] Add for detect_spikes, cluster_spikes, save_sorting. --- .../sorters/external/kilosort4.py | 20 ++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/sorters/external/kilosort4.py b/src/spikeinterface/sorters/external/kilosort4.py index d187b445ef..032f980ee2 100644 --- a/src/spikeinterface/sorters/external/kilosort4.py +++ b/src/spikeinterface/sorters/external/kilosort4.py @@ -282,14 +282,28 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): ) # Sort spikes and save results - st, tF, _, _ = detect_spikes(ops, device, bfile, tic0=tic0, progress_bar=progress_bar) - clu, Wall = cluster_spikes(st, tF, ops, device, bfile, tic0=tic0, progress_bar=progress_bar) + st, tF, _, _ = detect_spikes(ops=ops, device=device, bfile=bfile, tic0=tic0, progress_bar=progress_bar) + + clu, Wall = cluster_spikes( + st=st, tF=tF, ops=ops, device=device, bfile=bfile, tic0=tic0, progress_bar=progress_bar + ) + if params["skip_kilosort_preprocessing"]: ops["preprocessing"] = dict( hp_filter=torch.as_tensor(np.zeros(1)), whiten_mat=torch.as_tensor(np.eye(recording.get_num_channels())) ) - _ = save_sorting(ops, results_dir, st, clu, tF, Wall, bfile.imin, tic0, save_extra_vars=save_extra_vars) + _ = save_sorting( + ops=ops, + results_dir=results_dir, + st=st, + clu=clu, + tF=tF, + Wall=Wall, + imin=bfile.imin, + tic0=tic0, + save_extra_vars=save_extra_vars, + ) @classmethod def _get_result_from_folder(cls, sorter_output_folder): From b07359ff360a210ae864fbf43c23d805b9507300 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Mon, 24 Jun 2024 19:13:46 +0100 Subject: [PATCH 06/30] Add for 'load_probe', 'RecordingExtractorAsArray'. --- src/spikeinterface/sorters/external/kilosort4.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/sorters/external/kilosort4.py b/src/spikeinterface/sorters/external/kilosort4.py index 032f980ee2..ba1b10b793 100644 --- a/src/spikeinterface/sorters/external/kilosort4.py +++ b/src/spikeinterface/sorters/external/kilosort4.py @@ -176,12 +176,12 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): # load probe recording = cls.load_recording_from_folder(sorter_output_folder.parent, with_warnings=False) - probe = load_probe(probe_filename) + probe = load_probe(probe_path=probe_filename) probe_name = "" filename = "" # this internally concatenates the recording - file_object = RecordingExtractorAsArray(recording) + file_object = RecordingExtractorAsArray(recording_extractor=recording) do_CAR = params["do_CAR"] invert_sign = params["invert_sign"] From ac844e9d550624f007f851c1cc061e5c36abb002 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Mon, 24 Jun 2024 19:47:22 +0100 Subject: [PATCH 07/30] Add for BinaryFiltered + some generate notes. --- .../sorters/external/kilosort4.py | 21 ++++++++++--------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/src/spikeinterface/sorters/external/kilosort4.py b/src/spikeinterface/sorters/external/kilosort4.py index ba1b10b793..47ef328b28 100644 --- a/src/spikeinterface/sorters/external/kilosort4.py +++ b/src/spikeinterface/sorters/external/kilosort4.py @@ -17,6 +17,7 @@ class Kilosort4Sorter(BaseSorter): requires_locations = True gpu_capability = "nvidia-optional" + # Q: Should we take these directly from the KS defaults? https://github.com/MouseLand/Kilosort/blob/59c03b060cc8e8ac75a7f1a972a8b5c5af3f41a6/kilosort/parameters.py#L164 _default_params = { "batch_size": 60000, "nblocks": 1, @@ -25,8 +26,8 @@ class Kilosort4Sorter(BaseSorter): "do_CAR": True, "invert_sign": False, "nt": 61, - "shift": None, - "scale": None, + "shift": None, # TODO: I don't think these are passed to BinaryFiltered when preprocessing skipped. Need to distinguish version +/ 4.0.9 + "scale": None, # TODO: I don't think these are passed to BinaryFiltered when preprocessing skipped. Need to distinguish version +/ 4.0.9 "artifact_threshold": None, "nskip": 25, "whitening_range": 32, @@ -247,16 +248,16 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): else: print("Skipping kilosort preprocessing.") bfile = BinaryFiltered( - ops["filename"], - n_chan_bin, - fs, - NT, - nt, - twav_min, - chan_map, + filename=ops["filename"], + n_chan_bin=n_chan_bin, + fs=fs, + nT=NT, + nt=nt, + nt0min=twav_min, + chan_map=chan_map, hp_filter=None, device=device, - do_CAR=do_CAR, + do_CAR=do_CAR, # TODO: should this always be False if we are in skipping KS preprocessing land? invert_sign=invert, dtype=dtype, tmin=tmin, From 44835bb397a36ebfd914a6c2a8038bf3727b95e3 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Mon, 24 Jun 2024 19:50:12 +0100 Subject: [PATCH 08/30] Update note on DEFAULT_SETTINGS. --- src/spikeinterface/sorters/external/kilosort4.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/spikeinterface/sorters/external/kilosort4.py b/src/spikeinterface/sorters/external/kilosort4.py index 47ef328b28..bcd8ddc617 100644 --- a/src/spikeinterface/sorters/external/kilosort4.py +++ b/src/spikeinterface/sorters/external/kilosort4.py @@ -18,6 +18,8 @@ class Kilosort4Sorter(BaseSorter): gpu_capability = "nvidia-optional" # Q: Should we take these directly from the KS defaults? https://github.com/MouseLand/Kilosort/blob/59c03b060cc8e8ac75a7f1a972a8b5c5af3f41a6/kilosort/parameters.py#L164 + # I see these overwrite the `DEFAULT_SETTINGS`. Do we want to do this? There is benefit to fixing on the SI side, but users switching KS version would expect + # the defaults to represent the KS version. This could lead to divergence in result between users running KS directly vs. the SI wrapper. _default_params = { "batch_size": 60000, "nblocks": 1, From 5bdc31e1ac6f2b3ecde2f2d428f4bae306dacfb3 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Mon, 24 Jun 2024 19:59:16 +0100 Subject: [PATCH 09/30] Remove some TODO and notes. --- src/spikeinterface/sorters/external/kilosort4.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/spikeinterface/sorters/external/kilosort4.py b/src/spikeinterface/sorters/external/kilosort4.py index bcd8ddc617..cba7e65517 100644 --- a/src/spikeinterface/sorters/external/kilosort4.py +++ b/src/spikeinterface/sorters/external/kilosort4.py @@ -17,9 +17,6 @@ class Kilosort4Sorter(BaseSorter): requires_locations = True gpu_capability = "nvidia-optional" - # Q: Should we take these directly from the KS defaults? https://github.com/MouseLand/Kilosort/blob/59c03b060cc8e8ac75a7f1a972a8b5c5af3f41a6/kilosort/parameters.py#L164 - # I see these overwrite the `DEFAULT_SETTINGS`. Do we want to do this? There is benefit to fixing on the SI side, but users switching KS version would expect - # the defaults to represent the KS version. This could lead to divergence in result between users running KS directly vs. the SI wrapper. _default_params = { "batch_size": 60000, "nblocks": 1, @@ -28,8 +25,8 @@ class Kilosort4Sorter(BaseSorter): "do_CAR": True, "invert_sign": False, "nt": 61, - "shift": None, # TODO: I don't think these are passed to BinaryFiltered when preprocessing skipped. Need to distinguish version +/ 4.0.9 - "scale": None, # TODO: I don't think these are passed to BinaryFiltered when preprocessing skipped. Need to distinguish version +/ 4.0.9 + "shift": None, + "scale": None, "artifact_threshold": None, "nskip": 25, "whitening_range": 32, From dc848eb2f8691206826d9545927d9cf28fbcd558 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Wed, 26 Jun 2024 02:04:22 +0100 Subject: [PATCH 10/30] Use version to handle all KS versions some which are missing .__version__ attribute. --- .../sorters/external/kilosort4.py | 21 ++++++++++++------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/src/spikeinterface/sorters/external/kilosort4.py b/src/spikeinterface/sorters/external/kilosort4.py index cba7e65517..ed41baeff9 100644 --- a/src/spikeinterface/sorters/external/kilosort4.py +++ b/src/spikeinterface/sorters/external/kilosort4.py @@ -6,6 +6,7 @@ from ..basesorter import BaseSorter from .kilosortbase import KilosortBase +from importlib.metadata import version PathType = Union[str, Path] @@ -129,9 +130,8 @@ def is_installed(cls): @classmethod def get_sorter_version(cls): - import kilosort as ks - - return ks.__version__ + """kilosort version <0.0.10 is always '4' z""" + return version("kilosort") @classmethod def _setup_recording(cls, recording, sorter_output_folder, params, verbose): @@ -216,6 +216,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): ) if version.parse(cls.get_sorter_version()) >= version.parse("4.0.12"): + # TODO: save_preprocessed_copy added ops = initialize_ops( settings=settings, probe=probe, @@ -225,9 +226,6 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): device=device, save_preprocessed_copy=False, ) - n_chan_bin, fs, NT, nt, twav_min, chan_map, dtype, do_CAR, invert, _, _, tmin, tmax, artifact, _, _ = ( - get_run_parameters(ops) - ) else: ops = initialize_ops( settings=settings, @@ -237,6 +235,13 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): invert_sign=invert_sign, device=device, ) + + if version.parse(cls.get_sorter_version()) >= version.parse("4.0.11"): + # TODO: shift, scaled added + n_chan_bin, fs, NT, nt, twav_min, chan_map, dtype, do_CAR, invert, _, _, tmin, tmax, artifact, _, _ = ( + get_run_parameters(ops) + ) + else: n_chan_bin, fs, NT, nt, twav_min, chan_map, dtype, do_CAR, invert, _, _, tmin, tmax, artifact = ( get_run_parameters(ops) ) @@ -259,10 +264,10 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): do_CAR=do_CAR, # TODO: should this always be False if we are in skipping KS preprocessing land? invert_sign=invert, dtype=dtype, - tmin=tmin, + tmin=tmin, # TODO: exposing tmin, max? tmax=tmax, artifact_threshold=artifact, - file_object=file_object, + file_object=file_object, # TODO: exposing shift, scale when skipping preprocessing? ) ops["preprocessing"] = dict(hp_filter=None, whiten_mat=None) ops["Wrot"] = torch.as_tensor(np.eye(recording.get_num_channels())) From 69e72bf0ddfafe42577959e35eac96f184acb727 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Wed, 26 Jun 2024 02:06:20 +0100 Subject: [PATCH 11/30] Remove unused vars that were left over I think from prev KS versions. --- src/spikeinterface/sorters/external/kilosort4.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/spikeinterface/sorters/external/kilosort4.py b/src/spikeinterface/sorters/external/kilosort4.py index ed41baeff9..9320022a20 100644 --- a/src/spikeinterface/sorters/external/kilosort4.py +++ b/src/spikeinterface/sorters/external/kilosort4.py @@ -94,11 +94,9 @@ class Kilosort4Sorter(BaseSorter): "cluster_pcs": "Maximum number of spatiotemporal PC features used for clustering. Default value: 64.", "x_centers": "Number of x-positions to use when determining center points for template groupings. If None, this will be determined automatically by finding peaks in channel density. For 2D array type probes, we recommend specifying this so that centers are placed every few hundred microns.", "duplicate_spike_bins": "Number of bins for which subsequent spikes from the same cluster are assumed to be artifacts. A value of 0 disables this step. Default value: 7.", - "keep_good_only": "If True only 'good' units are returned", "do_correction": "If True, drift correction is performed", "save_extra_kwargs": "If True, additional kwargs are saved to the output", "skip_kilosort_preprocessing": "Can optionally skip the internal kilosort preprocessing", - "scaleproc": "int16 scaling of whitened data, if None set to 200.", "torch_device": "Select the torch device auto/cuda/cpu", } From c3b2bdda3d2f2f1db009302529c6c9b50a3781b9 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Wed, 26 Jun 2024 12:19:16 +0100 Subject: [PATCH 12/30] Use importlib version instead of .__version__ --- src/spikeinterface/sorters/external/kilosort4.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/sorters/external/kilosort4.py b/src/spikeinterface/sorters/external/kilosort4.py index 9320022a20..65f1483348 100644 --- a/src/spikeinterface/sorters/external/kilosort4.py +++ b/src/spikeinterface/sorters/external/kilosort4.py @@ -6,7 +6,6 @@ from ..basesorter import BaseSorter from .kilosortbase import KilosortBase -from importlib.metadata import version PathType = Union[str, Path] @@ -129,7 +128,10 @@ def is_installed(cls): @classmethod def get_sorter_version(cls): """kilosort version <0.0.10 is always '4' z""" - return version("kilosort") + # Note this import clashes with version! + from importlib.metadata import version as importlib_version + + return importlib_version("kilosort") @classmethod def _setup_recording(cls, recording, sorter_output_folder, params, verbose): From 52457224b0c724e5c0ee4f5d1e659ae7c3159b91 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Wed, 26 Jun 2024 12:28:49 +0100 Subject: [PATCH 13/30] Add kilosort test script and CI workflow. --- .github/workflows/test_kilosort4.yml | 61 +++ .../temp_test_file_dir/test_kilosort4_new.py | 472 ++++++++++++++++++ 2 files changed, 533 insertions(+) create mode 100644 .github/workflows/test_kilosort4.yml create mode 100644 src/spikeinterface/temp_test_file_dir/test_kilosort4_new.py diff --git a/.github/workflows/test_kilosort4.yml b/.github/workflows/test_kilosort4.yml new file mode 100644 index 0000000000..8e57f79786 --- /dev/null +++ b/.github/workflows/test_kilosort4.yml @@ -0,0 +1,61 @@ +name: Testing Kilosort4 + +on: + workflow_dispatch: + schedule: + - cron: "0 12 * * 0" # Weekly on Sunday at noon UTC + pull_request: + types: [synchronize, opened, reopened] + branches: + - main + +# env: +# KACHERY_CLOUD_CLIENT_ID: ${{ secrets.KACHERY_CLOUD_CLIENT_ID }} +# KACHERY_CLOUD_PRIVATE_KEY: ${{ secrets.KACHERY_CLOUD_PRIVATE_KEY }} + +# concurrency: # Cancel previous workflows on the same pull request +# group: ${{ github.workflow }}-${{ github.ref }} +# cancel-in-progress: true + +jobs: + run: + name: ${{ matrix.os }} Python ${{ matrix.python-version }} + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + python-version: ["3.12"] # TODO: "3.9", # Lower and higher versions we support + os: [ubuntu-latest] # TODO: macos-13, windows-latest, + ks_version: ["4.0.12"] # TODO: add / build from pypi based on Christians PR + steps: + - uses: actions/checkout@v4 + - name: Setup Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Install packages + # TODO: maybe dont need full? + run: | + pip install -e .[test] + # git config --global user.email "CI@example.com" + # git config --global user.name "CI Almighty" + # pip install tabulate + shell: bash + + - name: Install Kilosort + run: | + pip install kilosort==${{ matrix.ks_version }} + shell: bash + + - name: Run new kilosort4 tests + # run: chmod +x .github/test_kilosort4.sh + # TODO: figure out the paths to be able to run this by calling the file directly + run: | + pytest -k test_kilosort4_new --durations=0 + shell: bash + +# TODO: pip install -e .[full,dev] is failing # +#The conflict is caused by: +# spikeinterface[docs] 0.101.0rc0 depends on datalad==0.16.2; extra == "docs" +# spikeinterface[test] 0.101.0rc0 depends on datalad>=1.0.2; extra == "test" diff --git a/src/spikeinterface/temp_test_file_dir/test_kilosort4_new.py b/src/spikeinterface/temp_test_file_dir/test_kilosort4_new.py new file mode 100644 index 0000000000..0fb9841728 --- /dev/null +++ b/src/spikeinterface/temp_test_file_dir/test_kilosort4_new.py @@ -0,0 +1,472 @@ +import copy +from typing import Any +import spikeinterface.full as si +import numpy as np +import torch +import kilosort +from kilosort.io import load_probe +import pandas as pd + +import pytest +from probeinterface.io import write_prb +from kilosort.parameters import DEFAULT_SETTINGS +from packaging.version import parse +from importlib.metadata import version + +# TODO: duplicate_spike_bins to duplicate_spike_ms +# TODO: write an issue on KS about bin! vs bin_ms! +# TODO: expose tmin and tmax +# TODO: expose save_preprocessed_copy +# TODO: make here a log of all API changes (or on kilosort4.py) +# TODO: try out longer recordings and do some benchmarking tests.. +# TODO: expose tmin and tmax +# There is no way to skip HP spatial filter +# might as well expose tmin and tmax +# might as well expose preprocessing save (across the two functions that use it) +# BinaryFilter added scale and shift as new arguments recently +# test with docker +# test all params once +# try and read func / class object to see kwargs +# Shift and scale are also taken as a function on BinaryFilter. Do we want to apply these even when +# do kilosort preprocessing is false? probably +# TODO: find a test case for the other annoying ones (larger recording, variable amplitude) +# TODO: test docker +# TODO: test multi-segment recording +# TODO: test do correction, skip preprocessing +# TODO: can we rename 'save_extra_kwargs' to 'save_extra_vars'. Currently untested. +# nt : # TODO: can't kilosort figure this out from sampling rate? +# TODO: also test runtimes +# TODO: test skip preprocessing separately +# TODO: the pure default case is not tested +# TODO: shift and scale - this is also added to BinaryFilter + +RUN_KILOSORT_ARGS = ["do_CAR", "invert_sign", "save_preprocessed_copy"] # TODO: ignore some of these +# "device", "progress_bar", "save_extra_vars" are not tested. "save_extra_vars" could be. + + +PARAMS_TO_TEST = [ + # Not tested + # ("torch_device", "auto") + # Stable across KS version 4.0.01 - 4.0.12 + ("change_nothing", None), + ("nblocks", 0), + ("do_CAR", False), + ("batch_size", 42743), # Q: how much do these results change with batch size? + ("Th_universal", 12), + ("Th_learned", 14), + ("invert_sign", True), + ("nt", 93), + ("nskip", 1), + ("whitening_range", 16), + ("sig_interp", 5), + ("nt0min", 25), + ("dmin", 15), + ("dminx", 16), + ("min_template_size", 15), + ("template_sizes", 10), + ("nearest_chans", 8), + ("nearest_templates", 35), + ("max_channel_distance", 5), + ("templates_from_data", False), + ("n_templates", 10), + ("n_pcs", 3), + ("Th_single_ch", 4), + ("acg_threshold", 0.001), + ("x_centers", 5), + ("duplicate_spike_bins", 5), # TODO: why is this not erroring, it is deprecated. issue on KS + ("binning_depth", 1), + ("artifact_threshold", 200), + ("ccg_threshold", 1e9), + ("cluster_downsampling", 1e9), + ("duplicate_spike_bins", 5), # TODO: this is depcrecated and changed to _ms in 4.0.13! +] + +# Update PARAMS_TO_TEST with version-dependent kwargs +if parse(version("kilosort")) >= parse("4.0.12"): + pass # TODO: expose? +# PARAMS_TO_TEST.extend( +# [ +# ("save_preprocessed_copy", False), +# ] +# ) +if parse(version("kilosort")) >= parse("4.0.11"): + PARAMS_TO_TEST.extend( + [ + ("shift", 1e9), + ("scale", -1e9), + ] + ) +if parse(version("kilosort")) == parse("4.0.9"): + # bug in 4.0.9 for "nblocks=0" + PARAMS_TO_TEST = [param for param in PARAMS_TO_TEST if param[0] != "nblocks"] + +if parse(version("kilosort")) >= parse("4.0.8"): + PARAMS_TO_TEST.extend( + [ + ("drift_smoothing", [250, 250, 250]), + ] + ) +if parse(version("kilosort")) <= parse("4.0.6"): + # AFAIK this parameter was always unused in KS (that's why it was removed) + PARAMS_TO_TEST.extend( + [ + ("cluster_pcs", 1e9), + ] + ) +if parse(version("kilosort")) <= parse("4.0.3"): + PARAMS_TO_TEST = [param for param in PARAMS_TO_TEST if param[0] not in ["x_centers", "max_channel_distance"]] + + +class TestKilosort4Long: + + # Fixtures ###### + @pytest.fixture(scope="session") + def recording_and_paths(self, tmp_path_factory): + """ """ + tmp_path = tmp_path_factory.mktemp("kilosort4_tests") + + np.random.seed(0) # TODO: check below... + + recording = self._get_ground_truth_recording() + + paths = self._save_ground_truth_recording(recording, tmp_path) + + return (recording, paths) + + @pytest.fixture(scope="session") + def default_results(self, recording_and_paths): + """ """ + recording, paths = recording_and_paths + + settings, ks_format_probe = self._run_kilosort_with_kilosort(recording, paths) + + defaults_ks_output_dir = paths["session_scope_tmp_path"] / "default_ks_output" + + kilosort.run_kilosort( + settings=settings, + probe=ks_format_probe, + data_dtype="float32", + results_dir=defaults_ks_output_dir, + ) + + default_results = self._get_sorting_output(defaults_ks_output_dir) + + return default_results + + # Tests ###### + def test_params_to_test(self): + """ + Test that all parameters in PARAMS_TO_TEST are + different than the default value used in Kilosort, otherwise + there is no point to the test. + + TODO: need to use _default_params vs. DEFAULT_SETTINGS + depending on decision + + TODO: write issue on this, we hope it will be on DEFAULT_SETTINGS + TODO: duplicate_spike_ms in POSTPROCESSING but seems unused? + """ + for parameter in PARAMS_TO_TEST: + + param_key, param_value = parameter + + if param_key == "change_nothing": + continue + + if param_key not in RUN_KILOSORT_ARGS: + assert DEFAULT_SETTINGS[param_key] != param_value, f"{param_key} values should be different in test." + + def test_default_settings_all_represented(self): + """ + Test that every entry in DEFAULT_SETTINGS is tested in + PARAMS_TO_TEST, otherwise we are missing settings added + on the KS side. + """ + tested_keys = [entry[0] for entry in PARAMS_TO_TEST] + + for param_key in DEFAULT_SETTINGS: + + if param_key not in ["n_chan_bin", "fs", "tmin", "tmax"]: + assert param_key in tested_keys, f"param: {param_key} in DEFAULT SETTINGS but not tested." + + @pytest.mark.parametrize("parameter", PARAMS_TO_TEST) + def test_kilosort4(self, recording_and_paths, default_results, tmp_path, parameter): + """ """ + recording, paths = recording_and_paths + param_key, param_value = parameter + + kilosort_output_dir = tmp_path / "kilosort_output_dir" + spikeinterface_output_dir = tmp_path / "spikeinterface_output_dir" + + extra_ks_settings = {} + if param_key == "binning_depth": + extra_ks_settings.update({"nblocks": 5}) + + if param_key in RUN_KILOSORT_ARGS: + run_kilosort_kwargs = {param_key: param_value} + else: + if param_key != "change_nothing": + extra_ks_settings.update({param_key: param_value}) + run_kilosort_kwargs = {} + + settings, ks_format_probe = self._run_kilosort_with_kilosort(recording, paths, extra_ks_settings) + + kilosort.run_kilosort( + settings=settings, + probe=ks_format_probe, + data_dtype="float32", + results_dir=kilosort_output_dir, + **run_kilosort_kwargs, + ) + + extra_si_settings = {} + if param_key != "change_nothing": + extra_si_settings.update({param_key: param_value}) + + if param_key == "binning_depth": + extra_si_settings.update({"nblocks": 5}) + + spikeinterface_settings = self._get_spikeinterface_settings(extra_settings=extra_si_settings) + si.run_sorter( + "kilosort4", + recording, + remove_existing_folder=True, + folder=spikeinterface_output_dir, + **spikeinterface_settings, + ) + + results = self._get_sorting_output(kilosort_output_dir, spikeinterface_output_dir) + + assert np.array_equal(results["ks"]["st"], results["si"]["st"]), f"{param_key} spike times different" + + assert all( + results["ks"]["clus"].iloc[:, 0] == results["si"]["clus"].iloc[:, 0] + ), f"{param_key} cluster assignment different" + assert all( + results["ks"]["clus"].iloc[:, 1] == results["si"]["clus"].iloc[:, 1] + ), f"{param_key} cluster quality different" # TODO: check pandas probably better way + + # This is saved on the SI side so not an extremely + # robust addition, but it can't hurt. + if param_key != "change_nothing": + ops = np.load(spikeinterface_output_dir / "sorter_output" / "ops.npy", allow_pickle=True) + ops = ops.tolist() # strangely this makes a dict + assert ops[param_key] == param_value + + # Finally, check out test parameters actually changes stuff! + if parse(version("kilosort")) > parse("4.0.4"): + self._check_test_parameters_are_actually_changing_the_output(results, default_results, param_key) + + def test_kilosort4_no_correction(self, recording_and_paths, tmp_path): + """ """ + recording, paths = recording_and_paths + + kilosort_output_dir = tmp_path / "kilosort_output_dir" # TODO: a lost of copying here + spikeinterface_output_dir = tmp_path / "spikeinterface_output_dir" + + settings, ks_format_probe = self._run_kilosort_with_kilosort(recording, paths, extra_settings={"nblocks": 0}) + + kilosort.run_kilosort( + settings=settings, + probe=ks_format_probe, + data_dtype="float32", + results_dir=kilosort_output_dir, + do_CAR=True, + ) + + spikeinterface_settings = self._get_spikeinterface_settings(extra_settings={"nblocks": 6}) + si.run_sorter( + "kilosort4", + recording, + remove_existing_folder=True, + folder=spikeinterface_output_dir, + do_correction=False, + **spikeinterface_settings, + ) + + results = self._get_sorting_output(kilosort_output_dir, spikeinterface_output_dir) + + assert np.array_equal(results["ks"]["st"], results["si"]["st"]) + + assert all(results["ks"]["clus"].iloc[:, 0] == results["si"]["clus"].iloc[:, 0]) + assert all(results["ks"]["clus"].iloc[:, 1] == results["si"]["clus"].iloc[:, 1]) + + def test_kilosort4_skip_preprocessing_correction(self, tmp_path, monkeypatch): + """ """ + recording = self._get_ground_truth_recording() + + # We need to filter and whiten the recording here to KS takes forever. + # Do this in a way differnt to KS. + recording = si.highpass_filter(recording, 300) + recording = si.whiten(recording, mode="local", apply_mean=False) + + paths = self._save_ground_truth_recording(recording, tmp_path) + + kilosort_default_output_dir = tmp_path / "kilosort_default_output_dir" + kilosort_output_dir = tmp_path / "kilosort_output_dir" + spikeinterface_output_dir = tmp_path / "spikeinterface_output_dir" + + ks_settings, ks_format_probe = self._run_kilosort_with_kilosort(recording, paths, extra_settings={"nblocks": 0}) + + kilosort.run_kilosort( + settings=ks_settings, + probe=ks_format_probe, + data_dtype="float32", + results_dir=kilosort_default_output_dir, + do_CAR=False, + ) + + # Now the tricky bit, we need to turn off preprocessing in kilosort. + # This is not exposed by run_kilosort() arguments (at 4.0.12 at least) + # and so we need to monkeypatch the internal functions. The easiest + # thing to do would be to set `get_highpass_filter()` and + # `get_whitening_matrix()` to return `None` so these steps are skipped + # in BinaryFilter. Unfortunately the ops saving machinery requires + # these to be torch arrays and will error otherwise, so instead + # we must set the filter (in frequency space) and whitening matrix + # to unity operations so the filter and whitening do nothing. It is + # also required to turn off motion correection to avoid some additional + # magic KS is doing at the whitening step when motion correction is on. + fake_filter = np.ones(60122, dtype="float32") # TODO: hard coded + fake_filter = torch.from_numpy(fake_filter).to("cpu") + + fake_white_matrix = np.eye(recording.get_num_channels(), dtype="float32") + fake_white_matrix = torch.from_numpy(fake_white_matrix).to("cpu") + + def fake_fft_highpass(*args, **kwargs): + return fake_filter + + def fake_get_whitening_matrix(*args, **kwargs): + return fake_white_matrix + + def fake_fftshift(X, dim): + return X + + monkeypatch.setattr("kilosort.io.fft_highpass", fake_fft_highpass) + monkeypatch.setattr("kilosort.preprocessing.get_whitening_matrix", fake_get_whitening_matrix) + monkeypatch.setattr("kilosort.io.fftshift", fake_fftshift) + + kilosort.run_kilosort( + settings=ks_settings, + probe=ks_format_probe, + data_dtype="float32", + results_dir=kilosort_output_dir, + do_CAR=False, + ) + + monkeypatch.undo() + + # Now, run kilosort through spikeinterface with the same options. + spikeinterface_settings = self._get_spikeinterface_settings(extra_settings={"nblocks": 0}) + si.run_sorter( + "kilosort4", + recording, + remove_existing_folder=True, + folder=spikeinterface_output_dir, + do_CAR=False, + skip_kilosort_preprocessing=True, + **spikeinterface_settings, + ) + + default_results = self._get_sorting_output(kilosort_default_output_dir) + results = self._get_sorting_output(kilosort_output_dir, spikeinterface_output_dir) + + # Check that out intervention actually make some difference to KS output + # (or this test would do nothing). Then check SI and KS outputs with + # preprocessing skipped are identical. + assert not np.array_equal(default_results["ks"]["st"], results["ks"]["st"]) + assert np.array_equal(results["ks"]["st"], results["si"]["st"]) + + # Helpers ###### + def _check_test_parameters_are_actually_changing_the_output(self, results, default_results, param_key): + """ """ + if param_key not in ["artifact_threshold", "ccg_threshold", "cluster_downsampling"]: + num_clus = np.unique(results["si"]["clus"].iloc[:, 0]).size + num_clus_default = np.unique(default_results["ks"]["clus"].iloc[:, 0]).size + + if param_key == "change_nothing": + # TODO: lol + assert ( + (results["si"]["st"].size == default_results["ks"]["st"].size) + and num_clus == num_clus_default + and all(results["si"]["clus"].iloc[:, 1] == default_results["ks"]["clus"].iloc[:, 1]) + ), f"{param_key} changed somehow!." + else: + assert ( + (results["si"]["st"].size != default_results["ks"]["st"].size) + or num_clus != num_clus_default + or not all(results["si"]["clus"].iloc[:, 1] == default_results["ks"]["clus"].iloc[:, 1]) + ), f"{param_key} results did not change with parameter change." + + def _run_kilosort_with_kilosort(self, recording, paths, extra_settings=None): + """ """ + # dont actually run KS here because we will overwrite the defaults! + settings = { + "data_dir": paths["recording_path"], + "n_chan_bin": recording.get_num_channels(), + "fs": recording.get_sampling_frequency(), + } + + if extra_settings is not None: + settings.update(extra_settings) + + ks_format_probe = load_probe(paths["probe_path"]) + + return settings, ks_format_probe + + def _get_spikeinterface_settings(self, extra_settings=None): + """ """ + # dont actually run here. + settings = copy.deepcopy(DEFAULT_SETTINGS) + + if extra_settings is not None: + settings.update(extra_settings) + + for name in ["n_chan_bin", "fs", "tmin", "tmax"]: # TODO: check tmin and tmax + settings.pop(name) + + return settings + + def _get_sorting_output(self, kilosort_output_dir=None, spikeinterface_output_dir=None) -> dict[str, Any]: + """ """ + results = { + "si": {}, + "ks": {}, + } + if kilosort_output_dir: + results["ks"]["st"] = np.load(kilosort_output_dir / "spike_times.npy") + results["ks"]["clus"] = pd.read_table(kilosort_output_dir / "cluster_group.tsv") + + if spikeinterface_output_dir: + results["si"]["st"] = np.load(spikeinterface_output_dir / "sorter_output" / "spike_times.npy") + results["si"]["clus"] = pd.read_table(spikeinterface_output_dir / "sorter_output" / "cluster_group.tsv") + + return results + + def _get_ground_truth_recording(self): + """ """ + # Chosen so all parameter changes to indeed change the output + num_channels = 32 + recording, _ = si.generate_ground_truth_recording( + durations=[5], + seed=0, + num_channels=num_channels, + num_units=5, + generate_sorting_kwargs=dict(firing_rates=100, refractory_period_ms=4.0), + ) + return recording + + def _save_ground_truth_recording(self, recording, tmp_path): + """ """ + paths = { + "session_scope_tmp_path": tmp_path, + "recording_path": tmp_path / "my_test_recording", + "probe_path": tmp_path / "my_test_probe.prb", + } + + recording.save(folder=paths["recording_path"], overwrite=True) + + probegroup = recording.get_probegroup() + write_prb(paths["probe_path"].as_posix(), probegroup) + + return paths From ede9dd482163728901dd118973c86d946ffd5f16 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Wed, 26 Jun 2024 14:20:43 +0100 Subject: [PATCH 14/30] Fix save_preprocesed copy, argument mispelled. --- src/spikeinterface/sorters/external/kilosort4.py | 4 ++-- src/spikeinterface/temp_test_file_dir/test_kilosort4_new.py | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/sorters/external/kilosort4.py b/src/spikeinterface/sorters/external/kilosort4.py index 65f1483348..449ddfbff1 100644 --- a/src/spikeinterface/sorters/external/kilosort4.py +++ b/src/spikeinterface/sorters/external/kilosort4.py @@ -216,7 +216,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): ) if version.parse(cls.get_sorter_version()) >= version.parse("4.0.12"): - # TODO: save_preprocessed_copy added + # TODO: save_preprocesed_copy added ops = initialize_ops( settings=settings, probe=probe, @@ -224,7 +224,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): do_CAR=do_CAR, invert_sign=invert_sign, device=device, - save_preprocessed_copy=False, + save_preprocesed_copy=False, ) else: ops = initialize_ops( diff --git a/src/spikeinterface/temp_test_file_dir/test_kilosort4_new.py b/src/spikeinterface/temp_test_file_dir/test_kilosort4_new.py index 0fb9841728..e4d48a1344 100644 --- a/src/spikeinterface/temp_test_file_dir/test_kilosort4_new.py +++ b/src/spikeinterface/temp_test_file_dir/test_kilosort4_new.py @@ -13,6 +13,7 @@ from packaging.version import parse from importlib.metadata import version +# TODO: save_preprocesed_copy is misspelled in KS4. # TODO: duplicate_spike_bins to duplicate_spike_ms # TODO: write an issue on KS about bin! vs bin_ms! # TODO: expose tmin and tmax From 9570c9273b4f86bd800120c6d05096ebfc82e85d Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Wed, 26 Jun 2024 14:36:55 +0100 Subject: [PATCH 15/30] Fix NT format for BinaryFiltered, double-check all again --- src/spikeinterface/sorters/external/kilosort4.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/sorters/external/kilosort4.py b/src/spikeinterface/sorters/external/kilosort4.py index 449ddfbff1..28a3c3ffa3 100644 --- a/src/spikeinterface/sorters/external/kilosort4.py +++ b/src/spikeinterface/sorters/external/kilosort4.py @@ -255,7 +255,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): filename=ops["filename"], n_chan_bin=n_chan_bin, fs=fs, - nT=NT, + NT=NT, nt=nt, nt0min=twav_min, chan_map=chan_map, From a8489a50a0d4ccaa1c6e75307b73fcae7a8c4bc2 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Wed, 26 Jun 2024 15:25:08 +0100 Subject: [PATCH 16/30] Add CI to test all kilosort4 versions. --- .github/scripts/README.MD | 2 + .github/scripts/check_kilosort4_releases.py | 20 ++++ .../scripts/test_kilosort4_ci.py | 106 +++++++++++++++++- .github/workflows/test_kilosort4.yml | 63 ++++++----- conftest.py | 7 +- 5 files changed, 170 insertions(+), 28 deletions(-) create mode 100644 .github/scripts/README.MD create mode 100644 .github/scripts/check_kilosort4_releases.py rename src/spikeinterface/temp_test_file_dir/test_kilosort4_new.py => .github/scripts/test_kilosort4_ci.py (83%) diff --git a/.github/scripts/README.MD b/.github/scripts/README.MD new file mode 100644 index 0000000000..1d3a622aae --- /dev/null +++ b/.github/scripts/README.MD @@ -0,0 +1,2 @@ +This folder contains test scripts for running in the CI, that are not run as part of the usual +CI because they are too long / heavy. These are run on cron-jobs once per week. diff --git a/.github/scripts/check_kilosort4_releases.py b/.github/scripts/check_kilosort4_releases.py new file mode 100644 index 0000000000..3d04d6948a --- /dev/null +++ b/.github/scripts/check_kilosort4_releases.py @@ -0,0 +1,20 @@ +import os +import re +from pathlib import Path +import requests +import json + + +def get_pypi_versions(package_name): + url = f"https://pypi.org/pypi/{package_name}/json" + response = requests.get(url) + response.raise_for_status() + data = response.json() + return list(sorted(data["releases"].keys())) + + +if __name__ == "__main__": + package_name = "kilosort" + versions = get_pypi_versions(package_name) + with open(Path(os.path.realpath(__file__)).parent / "kilosort4-latest-version.json", "w") as f: + json.dump(versions, f) diff --git a/src/spikeinterface/temp_test_file_dir/test_kilosort4_new.py b/.github/scripts/test_kilosort4_ci.py similarity index 83% rename from src/spikeinterface/temp_test_file_dir/test_kilosort4_new.py rename to .github/scripts/test_kilosort4_ci.py index e4d48a1344..4684038bd0 100644 --- a/src/spikeinterface/temp_test_file_dir/test_kilosort4_new.py +++ b/.github/scripts/test_kilosort4_ci.py @@ -12,6 +12,14 @@ from kilosort.parameters import DEFAULT_SETTINGS from packaging.version import parse from importlib.metadata import version +from inspect import signature +from kilosort.run_kilosort import (set_files, initialize_ops, + compute_preprocessing, + compute_drift_correction, detect_spikes, + cluster_spikes, save_sorting, + get_run_parameters, ) +from kilosort.io import load_probe, RecordingExtractorAsArray, BinaryFiltered +from kilosort.parameters import DEFAULT_SETTINGS # TODO: save_preprocesed_copy is misspelled in KS4. # TODO: duplicate_spike_bins to duplicate_spike_ms @@ -190,6 +198,102 @@ def test_default_settings_all_represented(self): if param_key not in ["n_chan_bin", "fs", "tmin", "tmax"]: assert param_key in tested_keys, f"param: {param_key} in DEFAULT SETTINGS but not tested." + def test_set_files_arguments(self): + self._check_arguments( + set_files, + ["settings", "filename", "probe", "probe_name", "data_dir", "results_dir"] + ) + + def test_initialize_ops_arguments(self): + + expected_arguments = ["settings", "probe", "data_dtype", "do_CAR", "invert_sign", "device"] + + if parse(version("kilosort")) >= parse("4.0.12"): + expected_arguments.append("save_preprocesed_copy") + + self._check_arguments( + initialize_ops, + expected_arguments, + ) + + def test_compute_preprocessing_arguments(self): + self._check_arguments( + compute_preprocessing, + ["ops", "device", "tic0", "file_object"] + ) + + def test_compute_drift_location_arguments(self): + self._check_arguments( + compute_drift_correction, + ["ops", "device", "tic0", "progress_bar", "file_object"] + ) + + def test_detect_spikes_arguments(self): + self._check_arguments( + detect_spikes, + ["ops", "device", "bfile", "tic0", "progress_bar"] + ) + + + def test_cluster_spikes_arguments(self): + self._check_arguments( + cluster_spikes, + ["st", "tF", "ops", "device", "bfile", "tic0", "progress_bar"] + ) + + def test_save_sorting_arguments(self): + + expected_arguments = ["ops", "results_dir", "st", "clu", "tF", "Wall", "imin", "tic0", "save_extra_vars"] + + if parse(version("kilosort")) > parse("4.0.11"): + expected_arguments.append("save_preprocessed_copy") + + self._check_arguments( + save_sorting, + expected_arguments + ) + + def test_get_run_parameters(self): + self._check_arguments( + get_run_parameters, + ["ops"] + ) + + def test_load_probe_parameters(self): + self._check_arguments( + load_probe, + ["probe_path"] + ) + + def test_recording_extractor_as_array_arguments(self): + self._check_arguments( + RecordingExtractorAsArray, + ["recording_extractor"] + ) + + def test_binary_filtered_arguments(self): + + expected_arguments = [ + "filename", "n_chan_bin", "fs", "NT", "nt", "nt0min", + "chan_map", "hp_filter", "whiten_mat", "dshift", + "device", "do_CAR", "artifact_threshold", "invert_sign", + "dtype", "tmin", "tmax", "file_object" + ] + + if parse(version("kilosort")) >= parse("4.0.11"): + expected_arguments.pop(-1) + expected_arguments.extend(["shift", "scale", "file_object"]) + + self._check_arguments( + BinaryFiltered, + expected_arguments + ) + + def _check_arguments(self, object_, expected_arguments): + sig = signature(object_) + obj_arguments = list(sig.parameters.keys()) + assert expected_arguments == obj_arguments + @pytest.mark.parametrize("parameter", PARAMS_TO_TEST) def test_kilosort4(self, recording_and_paths, default_results, tmp_path, parameter): """ """ @@ -381,7 +485,7 @@ def fake_fftshift(X, dim): # Helpers ###### def _check_test_parameters_are_actually_changing_the_output(self, results, default_results, param_key): """ """ - if param_key not in ["artifact_threshold", "ccg_threshold", "cluster_downsampling"]: + if param_key not in ["artifact_threshold", "ccg_threshold", "cluster_downsampling", "cluster_pcs"]: num_clus = np.unique(results["si"]["clus"].iloc[:, 0]).size num_clus_default = np.unique(default_results["ks"]["clus"].iloc[:, 0]).size diff --git a/.github/workflows/test_kilosort4.yml b/.github/workflows/test_kilosort4.yml index 8e57f79786..c216be20d0 100644 --- a/.github/workflows/test_kilosort4.yml +++ b/.github/workflows/test_kilosort4.yml @@ -9,38 +9,56 @@ on: branches: - main -# env: -# KACHERY_CLOUD_CLIENT_ID: ${{ secrets.KACHERY_CLOUD_CLIENT_ID }} -# KACHERY_CLOUD_PRIVATE_KEY: ${{ secrets.KACHERY_CLOUD_PRIVATE_KEY }} +jobs: + versions: + runs-on: ubuntu-latest + outputs: + matrix: ${{ steps.set-matrix.outputs.matrix }} + steps: + - name: Checkout repository + uses: actions/checkout@v2 -# concurrency: # Cancel previous workflows on the same pull request -# group: ${{ github.workflow }}-${{ github.ref }} -# cancel-in-progress: true + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: 3.12 -jobs: - run: - name: ${{ matrix.os }} Python ${{ matrix.python-version }} + - name: Install dependencies + run: | + pip install requests + + - name: Fetch package versions from PyPI + run: | + python .github/scripts/check_kilosort4_releases.py + shell: bash + + - name: Set matrix data + id: set-matrix + run: | + echo "matrix=$(jq -c . < .github/scripts/kilosort4-latest-version.json)" >> $GITHUB_OUTPUT + + test: + needs: versions + name: ${{ matrix.ks_version }} runs-on: ${{ matrix.os }} strategy: fail-fast: false matrix: - python-version: ["3.12"] # TODO: "3.9", # Lower and higher versions we support - os: [ubuntu-latest] # TODO: macos-13, windows-latest, - ks_version: ["4.0.12"] # TODO: add / build from pypi based on Christians PR + python-version: ["3.12"] + os: [ubuntu-latest] + ks_version: ${{ fromJson(needs.versions.outputs.matrix) }} steps: - - uses: actions/checkout@v4 + - name: Checkout repository + uses: actions/checkout@v2 + - name: Setup Python ${{ matrix.python-version }} uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - - name: Install packages - # TODO: maybe dont need full? + - name: Install SpikeInterface run: | pip install -e .[test] - # git config --global user.email "CI@example.com" - # git config --global user.name "CI Almighty" - # pip install tabulate shell: bash - name: Install Kilosort @@ -49,13 +67,6 @@ jobs: shell: bash - name: Run new kilosort4 tests - # run: chmod +x .github/test_kilosort4.sh - # TODO: figure out the paths to be able to run this by calling the file directly run: | - pytest -k test_kilosort4_new --durations=0 + pytest .github/scripts/test_kilosort4_ci.py shell: bash - -# TODO: pip install -e .[full,dev] is failing # -#The conflict is caused by: -# spikeinterface[docs] 0.101.0rc0 depends on datalad==0.16.2; extra == "docs" -# spikeinterface[test] 0.101.0rc0 depends on datalad>=1.0.2; extra == "test" diff --git a/conftest.py b/conftest.py index c4bac6628a..8c06830d25 100644 --- a/conftest.py +++ b/conftest.py @@ -19,6 +19,7 @@ def create_cache_folder(tmp_path_factory): cache_folder = tmp_path_factory.mktemp("cache_folder") return cache_folder + def pytest_collection_modifyitems(config, items): """ This function marks (in the pytest sense) the tests according to their name and file_path location @@ -28,7 +29,11 @@ def pytest_collection_modifyitems(config, items): rootdir = Path(config.rootdir) modules_location = rootdir / "src" / "spikeinterface" for item in items: - rel_path = Path(item.fspath).relative_to(modules_location) + try: # TODO: make a note on this, check with Herberto its okay. + rel_path = Path(item.fspath).relative_to(modules_location) + except: + continue + module = rel_path.parts[0] if module == "sorters": if "internal" in rel_path.parts: From 159e2b0a92b87ebaddedbf12cc68062bd0e5e5eb Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Thu, 27 Jun 2024 01:33:48 +0100 Subject: [PATCH 17/30] Tidying up tests and removing comments from kilosort4.py. --- .github/scripts/test_kilosort4_ci.py | 442 ++++++++++-------- conftest.py | 2 +- .../sorters/external/kilosort4.py | 14 +- 3 files changed, 247 insertions(+), 211 deletions(-) diff --git a/.github/scripts/test_kilosort4_ci.py b/.github/scripts/test_kilosort4_ci.py index 4684038bd0..8a455a41fe 100644 --- a/.github/scripts/test_kilosort4_ci.py +++ b/.github/scripts/test_kilosort4_ci.py @@ -1,3 +1,23 @@ +""" +This file tests the SpikeInterface wrapper of the Kilosort4. The general logic +of the tests are: +- Change every exposed parameter one at a time (PARAMS_TO_TEST). Check that + the result of the SpikeInterface wrapper and Kilosort run natively are + the same. The SpikeInterface wrapper is non-trivial and decomposes the + kilosort pipeline to allow additions such as skipping preprocessing. Therefore, + the idea is that is it safer to rely on the output directly rather than + try monkeypatching. One thing can could be better tested is parameter + changes when skipping KS4 preprocessing is true, because this takes a slightly + different path through the kilosort4.py wrapper logic. + This also checks that changing the parameter changes the test output from default + on our test case (otherwise, the test could not detect a failure). This is possible + for nearly all parameters, see `_check_test_parameters_are_changing_the_output()`. + +- Test that kilosort functions called from `kilosort4.py` wrapper have the expected + input signatures + +- Do some tests to check all KS4 parameters are tested against. +""" import copy from typing import Any import spikeinterface.full as si @@ -20,47 +40,21 @@ get_run_parameters, ) from kilosort.io import load_probe, RecordingExtractorAsArray, BinaryFiltered from kilosort.parameters import DEFAULT_SETTINGS +from kilosort import preprocessing as ks_preprocessing -# TODO: save_preprocesed_copy is misspelled in KS4. -# TODO: duplicate_spike_bins to duplicate_spike_ms -# TODO: write an issue on KS about bin! vs bin_ms! -# TODO: expose tmin and tmax -# TODO: expose save_preprocessed_copy -# TODO: make here a log of all API changes (or on kilosort4.py) -# TODO: try out longer recordings and do some benchmarking tests.. -# TODO: expose tmin and tmax -# There is no way to skip HP spatial filter -# might as well expose tmin and tmax -# might as well expose preprocessing save (across the two functions that use it) -# BinaryFilter added scale and shift as new arguments recently -# test with docker -# test all params once -# try and read func / class object to see kwargs -# Shift and scale are also taken as a function on BinaryFilter. Do we want to apply these even when -# do kilosort preprocessing is false? probably -# TODO: find a test case for the other annoying ones (larger recording, variable amplitude) -# TODO: test docker -# TODO: test multi-segment recording -# TODO: test do correction, skip preprocessing -# TODO: can we rename 'save_extra_kwargs' to 'save_extra_vars'. Currently untested. -# nt : # TODO: can't kilosort figure this out from sampling rate? -# TODO: also test runtimes -# TODO: test skip preprocessing separately -# TODO: the pure default case is not tested -# TODO: shift and scale - this is also added to BinaryFilter - -RUN_KILOSORT_ARGS = ["do_CAR", "invert_sign", "save_preprocessed_copy"] # TODO: ignore some of these +RUN_KILOSORT_ARGS = ["do_CAR", "invert_sign", "save_preprocessed_copy"] # "device", "progress_bar", "save_extra_vars" are not tested. "save_extra_vars" could be. - +# Setup Params to test #### PARAMS_TO_TEST = [ # Not tested # ("torch_device", "auto") + # Stable across KS version 4.0.01 - 4.0.12 ("change_nothing", None), ("nblocks", 0), ("do_CAR", False), - ("batch_size", 42743), # Q: how much do these results change with batch size? + ("batch_size", 42743), ("Th_universal", 12), ("Th_learned", 14), ("invert_sign", True), @@ -80,14 +74,15 @@ ("n_templates", 10), ("n_pcs", 3), ("Th_single_ch", 4), - ("acg_threshold", 0.001), ("x_centers", 5), - ("duplicate_spike_bins", 5), # TODO: why is this not erroring, it is deprecated. issue on KS ("binning_depth", 1), + # Note: These don't change the results from + # default when applied to the test case. ("artifact_threshold", 200), - ("ccg_threshold", 1e9), - ("cluster_downsampling", 1e9), - ("duplicate_spike_bins", 5), # TODO: this is depcrecated and changed to _ms in 4.0.13! + ("ccg_threshold", 1e12), + ("acg_threshold", 1e12), + ("cluster_downsampling", 2), + ("duplicate_spike_bins", 5), ] # Update PARAMS_TO_TEST with version-dependent kwargs @@ -131,11 +126,13 @@ class TestKilosort4Long: # Fixtures ###### @pytest.fixture(scope="session") def recording_and_paths(self, tmp_path_factory): - """ """ + """ + Create a ground-truth recording, and save it to binary + so KS4 can run on it. Fixture is set up once and shared between + all tests. + """ tmp_path = tmp_path_factory.mktemp("kilosort4_tests") - np.random.seed(0) # TODO: check below... - recording = self._get_ground_truth_recording() paths = self._save_ground_truth_recording(recording, tmp_path) @@ -144,10 +141,17 @@ def recording_and_paths(self, tmp_path_factory): @pytest.fixture(scope="session") def default_results(self, recording_and_paths): - """ """ + """ + Because we check each parameter at a time and check the + KS4 and SpikeInterface versions match, if changing the parameter + had no effect as compared to default then the test would not test + anything. Therefore, the default results are run once and stored, + to check changing params indeed changes the results during testing. + This is possibly for nearly all parameters. + """ recording, paths = recording_and_paths - settings, ks_format_probe = self._run_kilosort_with_kilosort(recording, paths) + settings, _, ks_format_probe = self._get_kilosort_native_settings(recording, paths, "change_nothing", None) defaults_ks_output_dir = paths["session_scope_tmp_path"] / "default_ks_output" @@ -162,18 +166,46 @@ def default_results(self, recording_and_paths): return default_results - # Tests ###### - def test_params_to_test(self): + def _get_ground_truth_recording(self): """ - Test that all parameters in PARAMS_TO_TEST are - different than the default value used in Kilosort, otherwise - there is no point to the test. + A ground truth recording chosen to be as small as possible (for speed). + But contain enough information so that changing most parameters + changes the results. + """ + num_channels = 32 + recording, _ = si.generate_ground_truth_recording( + durations=[5], + seed=0, + num_channels=num_channels, + num_units=5, + generate_sorting_kwargs=dict(firing_rates=100, refractory_period_ms=4.0), + ) + return recording - TODO: need to use _default_params vs. DEFAULT_SETTINGS - depending on decision + def _save_ground_truth_recording(self, recording, tmp_path): + """ + Save the recording and its probe to file, so it can be + loaded by KS4. + """ + paths = { + "session_scope_tmp_path": tmp_path, + "recording_path": tmp_path / "my_test_recording", + "probe_path": tmp_path / "my_test_probe.prb", + } - TODO: write issue on this, we hope it will be on DEFAULT_SETTINGS - TODO: duplicate_spike_ms in POSTPROCESSING but seems unused? + recording.save(folder=paths["recording_path"], overwrite=True) + + probegroup = recording.get_probegroup() + write_prb(paths["probe_path"].as_posix(), probegroup) + + return paths + + # Tests ###### + def test_params_to_test(self): + """ + Test that all values in PARAMS_TO_TEST are + different to the default values used in Kilosort, + otherwise there is no point to the test. """ for parameter in PARAMS_TO_TEST: @@ -198,6 +230,7 @@ def test_default_settings_all_represented(self): if param_key not in ["n_chan_bin", "fs", "tmin", "tmax"]: assert param_key in tested_keys, f"param: {param_key} in DEFAULT SETTINGS but not tested." + # Testing Arguments ### def test_set_files_arguments(self): self._check_arguments( set_files, @@ -205,7 +238,6 @@ def test_set_files_arguments(self): ) def test_initialize_ops_arguments(self): - expected_arguments = ["settings", "probe", "data_dtype", "do_CAR", "invert_sign", "device"] if parse(version("kilosort")) >= parse("4.0.12"): @@ -234,7 +266,6 @@ def test_detect_spikes_arguments(self): ["ops", "device", "bfile", "tic0", "progress_bar"] ) - def test_cluster_spikes_arguments(self): self._check_arguments( cluster_spikes, @@ -242,7 +273,6 @@ def test_cluster_spikes_arguments(self): ) def test_save_sorting_arguments(self): - expected_arguments = ["ops", "results_dir", "st", "clu", "tF", "Wall", "imin", "tic0", "save_extra_vars"] if parse(version("kilosort")) > parse("4.0.11"): @@ -272,7 +302,6 @@ def test_recording_extractor_as_array_arguments(self): ) def test_binary_filtered_arguments(self): - expected_arguments = [ "filename", "n_chan_bin", "fs", "NT", "nt", "nt0min", "chan_map", "hp_filter", "whiten_mat", "dshift", @@ -294,27 +323,23 @@ def _check_arguments(self, object_, expected_arguments): obj_arguments = list(sig.parameters.keys()) assert expected_arguments == obj_arguments + # Full Test #### @pytest.mark.parametrize("parameter", PARAMS_TO_TEST) def test_kilosort4(self, recording_and_paths, default_results, tmp_path, parameter): - """ """ + """ + Given a recording, paths to raw data, and a parameter to change, + run KS4 natively and within the SpikeInterface wrapper with the + new parameter value (all other values default) and + check the outputs are the same. + """ recording, paths = recording_and_paths param_key, param_value = parameter + # Setup parameters for KS4 and run it natively kilosort_output_dir = tmp_path / "kilosort_output_dir" spikeinterface_output_dir = tmp_path / "spikeinterface_output_dir" - extra_ks_settings = {} - if param_key == "binning_depth": - extra_ks_settings.update({"nblocks": 5}) - - if param_key in RUN_KILOSORT_ARGS: - run_kilosort_kwargs = {param_key: param_value} - else: - if param_key != "change_nothing": - extra_ks_settings.update({param_key: param_value}) - run_kilosort_kwargs = {} - - settings, ks_format_probe = self._run_kilosort_with_kilosort(recording, paths, extra_ks_settings) + settings, run_kilosort_kwargs, ks_format_probe = self._get_kilosort_native_settings(recording, paths, param_key, param_value) kilosort.run_kilosort( settings=settings, @@ -324,14 +349,9 @@ def test_kilosort4(self, recording_and_paths, default_results, tmp_path, paramet **run_kilosort_kwargs, ) - extra_si_settings = {} - if param_key != "change_nothing": - extra_si_settings.update({param_key: param_value}) + # Setup Parameters for SI and KS4 through SI + spikeinterface_settings = self._get_spikeinterface_settings(param_key, param_value) - if param_key == "binning_depth": - extra_si_settings.update({"nblocks": 5}) - - spikeinterface_settings = self._get_spikeinterface_settings(extra_settings=extra_si_settings) si.run_sorter( "kilosort4", recording, @@ -340,36 +360,41 @@ def test_kilosort4(self, recording_and_paths, default_results, tmp_path, paramet **spikeinterface_settings, ) + # Get the results and check they match results = self._get_sorting_output(kilosort_output_dir, spikeinterface_output_dir) assert np.array_equal(results["ks"]["st"], results["si"]["st"]), f"{param_key} spike times different" + assert np.array_equal(results["ks"]["clus"], results["si"]["clus"]), f"{param_key} cluster assignment different" - assert all( - results["ks"]["clus"].iloc[:, 0] == results["si"]["clus"].iloc[:, 0] - ), f"{param_key} cluster assignment different" - assert all( - results["ks"]["clus"].iloc[:, 1] == results["si"]["clus"].iloc[:, 1] - ), f"{param_key} cluster quality different" # TODO: check pandas probably better way - - # This is saved on the SI side so not an extremely - # robust addition, but it can't hurt. + # Check the ops file in KS4 output is as expected. This is saved on the + # SI side so not an extremely robust addition, but it can't hurt. if param_key != "change_nothing": ops = np.load(spikeinterface_output_dir / "sorter_output" / "ops.npy", allow_pickle=True) ops = ops.tolist() # strangely this makes a dict assert ops[param_key] == param_value - # Finally, check out test parameters actually changes stuff! + # Finally, check out test parameters actually change the output of + # KS4, ensuring our tests are actually doing something. This is not + # done prior to 4.0.4 because a number of parameters seem to stop + # having an effect. This is probably due to small changes in their + # behaviour, and the test file chosen here. if parse(version("kilosort")) > parse("4.0.4"): - self._check_test_parameters_are_actually_changing_the_output(results, default_results, param_key) + self._check_test_parameters_are_changing_the_output(results, default_results, param_key) def test_kilosort4_no_correction(self, recording_and_paths, tmp_path): - """ """ + """ + Test the SpikeInterface wrappers `do_correction` argument. We set + `nblocks=0` for KS4 native, turning off motion correction. Then + we run KS$ through SpikeInterface with `do_correction=False` but + `nblocks=1` (KS4 default) - checking that `do_correction` overrides + this and the result matches KS4 when run without motion correction. + """ recording, paths = recording_and_paths - kilosort_output_dir = tmp_path / "kilosort_output_dir" # TODO: a lost of copying here + kilosort_output_dir = tmp_path / "kilosort_output_dir" spikeinterface_output_dir = tmp_path / "spikeinterface_output_dir" - settings, ks_format_probe = self._run_kilosort_with_kilosort(recording, paths, extra_settings={"nblocks": 0}) + settings, _, ks_format_probe = self._get_kilosort_native_settings(recording, paths, "nblocks", 0) kilosort.run_kilosort( settings=settings, @@ -379,7 +404,7 @@ def test_kilosort4_no_correction(self, recording_and_paths, tmp_path): do_CAR=True, ) - spikeinterface_settings = self._get_spikeinterface_settings(extra_settings={"nblocks": 6}) + spikeinterface_settings = self._get_spikeinterface_settings("nblocks", 1) si.run_sorter( "kilosort4", recording, @@ -392,186 +417,199 @@ def test_kilosort4_no_correction(self, recording_and_paths, tmp_path): results = self._get_sorting_output(kilosort_output_dir, spikeinterface_output_dir) assert np.array_equal(results["ks"]["st"], results["si"]["st"]) + assert np.array_equal(results["ks"]["clus"], results["si"]["clus"]) + + @pytest.mark.parametrize("param_to_test", [ + ("change_nothing", None), + ("do_CAR", False), + ("batch_size", 42743), + ("Th_learned", 14), + ("dmin", 15), + ("max_channel_distance", 5), + ("n_pcs", 3), + ]) + def test_kilosort4_skip_preprocessing_correction(self, tmp_path, monkeypatch, param_to_test): + """ + Test that skipping KS4 preprocessing works as expected. Run + KS4 natively, monkeypatching the relevant preprocessing functions + such that preprocessing is not performed. Then run in SpikeInterface + with `skip_kilosort_preprocessing=True` and check the outputs match. + + Run with a few randomly chosen parameters to check these are propagated + under this condition. - assert all(results["ks"]["clus"].iloc[:, 0] == results["si"]["clus"].iloc[:, 0]) - assert all(results["ks"]["clus"].iloc[:, 1] == results["si"]["clus"].iloc[:, 1]) + TODO + ---- + It would be nice to check a few additional parameters here. Screw it! + """ + param_key, param_value = param_to_test - def test_kilosort4_skip_preprocessing_correction(self, tmp_path, monkeypatch): - """ """ recording = self._get_ground_truth_recording() # We need to filter and whiten the recording here to KS takes forever. - # Do this in a way differnt to KS. + # Do this in a way different to KS. recording = si.highpass_filter(recording, 300) recording = si.whiten(recording, mode="local", apply_mean=False) paths = self._save_ground_truth_recording(recording, tmp_path) - kilosort_default_output_dir = tmp_path / "kilosort_default_output_dir" kilosort_output_dir = tmp_path / "kilosort_output_dir" spikeinterface_output_dir = tmp_path / "spikeinterface_output_dir" - ks_settings, ks_format_probe = self._run_kilosort_with_kilosort(recording, paths, extra_settings={"nblocks": 0}) + def monkeypatch_filter_function(self, X, ops=None, ibatch=None): + """ + This is a direct copy of the kilosort io.BinaryFiltered.filter + function, with hp_filter and whitening matrix code sections, and + comments removed. This is the easiest way to monkeypatch (tried a few approaches) + """ + if self.chan_map is not None: + X = X[self.chan_map] - kilosort.run_kilosort( - settings=ks_settings, - probe=ks_format_probe, - data_dtype="float32", - results_dir=kilosort_default_output_dir, - do_CAR=False, - ) + if self.invert_sign: + X = X * -1 + + X = X - X.mean(1).unsqueeze(1) + if self.do_CAR: + X = X - torch.median(X, 0)[0] + + if self.hp_filter is not None: + pass - # Now the tricky bit, we need to turn off preprocessing in kilosort. - # This is not exposed by run_kilosort() arguments (at 4.0.12 at least) - # and so we need to monkeypatch the internal functions. The easiest - # thing to do would be to set `get_highpass_filter()` and - # `get_whitening_matrix()` to return `None` so these steps are skipped - # in BinaryFilter. Unfortunately the ops saving machinery requires - # these to be torch arrays and will error otherwise, so instead - # we must set the filter (in frequency space) and whitening matrix - # to unity operations so the filter and whitening do nothing. It is - # also required to turn off motion correection to avoid some additional - # magic KS is doing at the whitening step when motion correction is on. - fake_filter = np.ones(60122, dtype="float32") # TODO: hard coded - fake_filter = torch.from_numpy(fake_filter).to("cpu") - - fake_white_matrix = np.eye(recording.get_num_channels(), dtype="float32") - fake_white_matrix = torch.from_numpy(fake_white_matrix).to("cpu") - - def fake_fft_highpass(*args, **kwargs): - return fake_filter - - def fake_get_whitening_matrix(*args, **kwargs): - return fake_white_matrix - - def fake_fftshift(X, dim): + if self.artifact_threshold < np.inf: + if torch.any(torch.abs(X) >= self.artifact_threshold): + return torch.zeros_like(X) + + if self.whiten_mat is not None: + pass return X - monkeypatch.setattr("kilosort.io.fft_highpass", fake_fft_highpass) - monkeypatch.setattr("kilosort.preprocessing.get_whitening_matrix", fake_get_whitening_matrix) - monkeypatch.setattr("kilosort.io.fftshift", fake_fftshift) + monkeypatch.setattr("kilosort.io.BinaryFiltered.filter", + monkeypatch_filter_function) + + ks_settings, _, ks_format_probe = self._get_kilosort_native_settings(recording, paths, param_key, param_value) + ks_settings["nblocks"] = 0 + + # Be explicit here and don't rely on defaults. + do_CAR = param_value if param_key == "do_CAR" else False kilosort.run_kilosort( settings=ks_settings, probe=ks_format_probe, data_dtype="float32", results_dir=kilosort_output_dir, - do_CAR=False, + do_CAR=do_CAR, ) monkeypatch.undo() # Now, run kilosort through spikeinterface with the same options. - spikeinterface_settings = self._get_spikeinterface_settings(extra_settings={"nblocks": 0}) + spikeinterface_settings = self._get_spikeinterface_settings(param_key, param_value) + spikeinterface_settings["nblocks"] = 0 + + do_CAR = False if param_key != "do_CAR" else spikeinterface_settings.pop("do_CAR") + si.run_sorter( "kilosort4", recording, remove_existing_folder=True, folder=spikeinterface_output_dir, - do_CAR=False, + do_CAR=do_CAR, skip_kilosort_preprocessing=True, **spikeinterface_settings, ) - default_results = self._get_sorting_output(kilosort_default_output_dir) + # There is a very slight difference caused by the batching between load vs. + # memory file. Because in this test recordings are preprocessed, there are + # some filter edge effects that depend on the chunking in `get_traces()`. + # These are all extremely close (usually just 1 spike, 1 idx different). results = self._get_sorting_output(kilosort_output_dir, spikeinterface_output_dir) - - # Check that out intervention actually make some difference to KS output - # (or this test would do nothing). Then check SI and KS outputs with - # preprocessing skipped are identical. - assert not np.array_equal(default_results["ks"]["st"], results["ks"]["st"]) - assert np.array_equal(results["ks"]["st"], results["si"]["st"]) + assert np.allclose(results["ks"]["st"], results["si"]["st"], rtol=0, atol=1) # Helpers ###### - def _check_test_parameters_are_actually_changing_the_output(self, results, default_results, param_key): - """ """ - if param_key not in ["artifact_threshold", "ccg_threshold", "cluster_downsampling", "cluster_pcs"]: - num_clus = np.unique(results["si"]["clus"].iloc[:, 0]).size - num_clus_default = np.unique(default_results["ks"]["clus"].iloc[:, 0]).size + def _check_test_parameters_are_changing_the_output(self, results, default_results, param_key): + """ + If nothing is changed, default vs. results outputs are identical. + Otherwise, check they are not the same. Can't figure out how to get + the skipped three parameters below to change the results on this + small test file. + """ + if param_key in ["acg_threshold", "ccg_threshold", "artifact_threshold", "cluster_downsampling"]: + return + + if param_key == "change_nothing": + assert all( + default_results["ks"]["st"] == results["ks"]["st"] + ) and all( + default_results["ks"]["clus"] == results["ks"]["clus"] + ), f"{param_key} changed somehow!." + else: + assert not ( + default_results["ks"]["st"].size == results["ks"]["st"].size + ) or not all( + default_results["ks"]["clus"] == results["ks"]["clus"] + ), f"{param_key} results did not change with parameter change." - if param_key == "change_nothing": - # TODO: lol - assert ( - (results["si"]["st"].size == default_results["ks"]["st"].size) - and num_clus == num_clus_default - and all(results["si"]["clus"].iloc[:, 1] == default_results["ks"]["clus"].iloc[:, 1]) - ), f"{param_key} changed somehow!." - else: - assert ( - (results["si"]["st"].size != default_results["ks"]["st"].size) - or num_clus != num_clus_default - or not all(results["si"]["clus"].iloc[:, 1] == default_results["ks"]["clus"].iloc[:, 1]) - ), f"{param_key} results did not change with parameter change." - - def _run_kilosort_with_kilosort(self, recording, paths, extra_settings=None): - """ """ - # dont actually run KS here because we will overwrite the defaults! + def _get_kilosort_native_settings(self, recording, paths, param_key, param_value): + """ + Function to generate the settings and function inputs to run kilosort. + Note when `binning_depth` is used we need to set `nblocks` high to + get the results to change from default. + + Some settings in KS4 are passed by `settings` dict while others + are through the function, these are split here. + """ settings = { "data_dir": paths["recording_path"], "n_chan_bin": recording.get_num_channels(), "fs": recording.get_sampling_frequency(), } - if extra_settings is not None: - settings.update(extra_settings) + if param_key == "binning_depth": + settings.update({"nblocks": 5}) + + if param_key in RUN_KILOSORT_ARGS: + run_kilosort_kwargs = {param_key: param_value} + else: + if param_key != "change_nothing": + settings.update({param_key: param_value}) + run_kilosort_kwargs = {} ks_format_probe = load_probe(paths["probe_path"]) - return settings, ks_format_probe + return settings, run_kilosort_kwargs, ks_format_probe - def _get_spikeinterface_settings(self, extra_settings=None): - """ """ - # dont actually run here. + def _get_spikeinterface_settings(self, param_key, param_value): + """ + Generate settings kwargs for running KS4 in SpikeInterface. + See `_get_kilosort_native_settings()` for some details. + """ settings = copy.deepcopy(DEFAULT_SETTINGS) - if extra_settings is not None: - settings.update(extra_settings) + if param_key != "change_nothing": + settings.update({param_key: param_value}) + + if param_key == "binning_depth": + settings.update({"nblocks": 5}) - for name in ["n_chan_bin", "fs", "tmin", "tmax"]: # TODO: check tmin and tmax + for name in ["n_chan_bin", "fs", "tmin", "tmax"]: settings.pop(name) return settings def _get_sorting_output(self, kilosort_output_dir=None, spikeinterface_output_dir=None) -> dict[str, Any]: - """ """ + """ + Load the results of sorting into a dict for easy comparison. + """ results = { "si": {}, "ks": {}, } if kilosort_output_dir: results["ks"]["st"] = np.load(kilosort_output_dir / "spike_times.npy") - results["ks"]["clus"] = pd.read_table(kilosort_output_dir / "cluster_group.tsv") + results["ks"]["clus"] = np.load(kilosort_output_dir / "spike_clusters.npy") if spikeinterface_output_dir: results["si"]["st"] = np.load(spikeinterface_output_dir / "sorter_output" / "spike_times.npy") - results["si"]["clus"] = pd.read_table(spikeinterface_output_dir / "sorter_output" / "cluster_group.tsv") + results["si"]["clus"] = np.load(spikeinterface_output_dir / "sorter_output" / "spike_clusters.npy") return results - - def _get_ground_truth_recording(self): - """ """ - # Chosen so all parameter changes to indeed change the output - num_channels = 32 - recording, _ = si.generate_ground_truth_recording( - durations=[5], - seed=0, - num_channels=num_channels, - num_units=5, - generate_sorting_kwargs=dict(firing_rates=100, refractory_period_ms=4.0), - ) - return recording - - def _save_ground_truth_recording(self, recording, tmp_path): - """ """ - paths = { - "session_scope_tmp_path": tmp_path, - "recording_path": tmp_path / "my_test_recording", - "probe_path": tmp_path / "my_test_probe.prb", - } - - recording.save(folder=paths["recording_path"], overwrite=True) - - probegroup = recording.get_probegroup() - write_prb(paths["probe_path"].as_posix(), probegroup) - - return paths diff --git a/conftest.py b/conftest.py index 8c06830d25..544c2fb6cb 100644 --- a/conftest.py +++ b/conftest.py @@ -29,7 +29,7 @@ def pytest_collection_modifyitems(config, items): rootdir = Path(config.rootdir) modules_location = rootdir / "src" / "spikeinterface" for item in items: - try: # TODO: make a note on this, check with Herberto its okay. + try: rel_path = Path(item.fspath).relative_to(modules_location) except: continue diff --git a/src/spikeinterface/sorters/external/kilosort4.py b/src/spikeinterface/sorters/external/kilosort4.py index 28a3c3ffa3..8721ce1b89 100644 --- a/src/spikeinterface/sorters/external/kilosort4.py +++ b/src/spikeinterface/sorters/external/kilosort4.py @@ -127,8 +127,7 @@ def is_installed(cls): @classmethod def get_sorter_version(cls): - """kilosort version <0.0.10 is always '4' z""" - # Note this import clashes with version! + """kilosort version <0.0.10 is always '4'""" from importlib.metadata import version as importlib_version return importlib_version("kilosort") @@ -216,7 +215,6 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): ) if version.parse(cls.get_sorter_version()) >= version.parse("4.0.12"): - # TODO: save_preprocesed_copy added ops = initialize_ops( settings=settings, probe=probe, @@ -237,7 +235,6 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): ) if version.parse(cls.get_sorter_version()) >= version.parse("4.0.11"): - # TODO: shift, scaled added n_chan_bin, fs, NT, nt, twav_min, chan_map, dtype, do_CAR, invert, _, _, tmin, tmax, artifact, _, _ = ( get_run_parameters(ops) ) @@ -261,22 +258,23 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): chan_map=chan_map, hp_filter=None, device=device, - do_CAR=do_CAR, # TODO: should this always be False if we are in skipping KS preprocessing land? + do_CAR=do_CAR, invert_sign=invert, dtype=dtype, - tmin=tmin, # TODO: exposing tmin, max? + tmin=tmin, tmax=tmax, artifact_threshold=artifact, - file_object=file_object, # TODO: exposing shift, scale when skipping preprocessing? + file_object=file_object, ) ops["preprocessing"] = dict(hp_filter=None, whiten_mat=None) ops["Wrot"] = torch.as_tensor(np.eye(recording.get_num_channels())) ops["Nbatches"] = bfile.n_batches + # bfile.close() # TODO: KS do this after preprocessing? np.random.seed(1) torch.cuda.manual_seed_all(1) torch.random.manual_seed(1) - # if not params["skip_kilosort_preprocessing"]: + if not params["do_correction"]: print("Skipping drift correction.") ops["nblocks"] = 0 From 0817a5b3f10c986db04632fb979e2c30cf501dbc Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Thu, 27 Jun 2024 02:30:19 +0100 Subject: [PATCH 18/30] Add tests to check _default_params against KS params. --- .github/scripts/test_kilosort4_ci.py | 25 +++++++++++++++---- .../sorters/external/kilosort4.py | 7 +++--- 2 files changed, 23 insertions(+), 9 deletions(-) diff --git a/.github/scripts/test_kilosort4_ci.py b/.github/scripts/test_kilosort4_ci.py index 8a455a41fe..ecc931781c 100644 --- a/.github/scripts/test_kilosort4_ci.py +++ b/.github/scripts/test_kilosort4_ci.py @@ -26,7 +26,7 @@ import kilosort from kilosort.io import load_probe import pandas as pd - +from spikeinterface.sorters.external.kilosort4 import Kilosort4Sorter import pytest from probeinterface.io import write_prb from kilosort.parameters import DEFAULT_SETTINGS @@ -230,6 +230,21 @@ def test_default_settings_all_represented(self): if param_key not in ["n_chan_bin", "fs", "tmin", "tmax"]: assert param_key in tested_keys, f"param: {param_key} in DEFAULT SETTINGS but not tested." + def test_spikeinterface_defaults_against_kilsort(self): + """ + Here check that all _ + Don't check that every default in KS is exposed in params, + because they change across versions. Instead, this check + is performed here against PARAMS_TO_TEST. + """ + params = copy.deepcopy(Kilosort4Sorter._default_params) + + for key in params.keys(): + # "artifact threshold" is set to `np.inf` if `None` in + # the body of the `Kilosort4Sorter` class. + if key in DEFAULT_SETTINGS and key not in ["artifact_threshold"]: + assert params[key] == DEFAULT_SETTINGS[key], f"{key} is not the same across versions." + # Testing Arguments ### def test_set_files_arguments(self): self._check_arguments( @@ -533,7 +548,7 @@ def _check_test_parameters_are_changing_the_output(self, results, default_result the skipped three parameters below to change the results on this small test file. """ - if param_key in ["acg_threshold", "ccg_threshold", "artifact_threshold", "cluster_downsampling"]: + if param_key in ["acg_threshold", "ccg_threshold", "artifact_threshold", "cluster_downsampling", "cluster_pcs"]: return if param_key == "change_nothing": @@ -583,7 +598,7 @@ def _get_spikeinterface_settings(self, param_key, param_value): Generate settings kwargs for running KS4 in SpikeInterface. See `_get_kilosort_native_settings()` for some details. """ - settings = copy.deepcopy(DEFAULT_SETTINGS) + settings = {} # copy.deepcopy(DEFAULT_SETTINGS) if param_key != "change_nothing": settings.update({param_key: param_value}) @@ -591,8 +606,8 @@ def _get_spikeinterface_settings(self, param_key, param_value): if param_key == "binning_depth": settings.update({"nblocks": 5}) - for name in ["n_chan_bin", "fs", "tmin", "tmax"]: - settings.pop(name) + # for name in ["n_chan_bin", "fs", "tmin", "tmax"]: + # settings.pop(name) return settings diff --git a/src/spikeinterface/sorters/external/kilosort4.py b/src/spikeinterface/sorters/external/kilosort4.py index 8721ce1b89..82c033f61d 100644 --- a/src/spikeinterface/sorters/external/kilosort4.py +++ b/src/spikeinterface/sorters/external/kilosort4.py @@ -6,6 +6,7 @@ from ..basesorter import BaseSorter from .kilosortbase import KilosortBase +from importlib.metadata import version as importlib_version PathType = Union[str, Path] @@ -35,7 +36,7 @@ class Kilosort4Sorter(BaseSorter): "drift_smoothing": [0.5, 0.5, 0.5], "nt0min": None, "dmin": None, - "dminx": 32, + "dminx": 32 if version.parse(importlib_version("kilosort")) > version.parse("4.0.0.1") else None, "min_template_size": 10, "template_sizes": 5, "nearest_chans": 10, @@ -50,7 +51,7 @@ class Kilosort4Sorter(BaseSorter): "cluster_downsampling": 20, "cluster_pcs": 64, "x_centers": None, - "duplicate_spike_bins": 7, + "duplicate_spike_bins": 7 if version.parse(importlib_version("kilosort")) >= version.parse("4.0.4") else 15, "do_correction": True, "keep_good_only": False, "save_extra_kwargs": False, @@ -128,8 +129,6 @@ def is_installed(cls): @classmethod def get_sorter_version(cls): """kilosort version <0.0.10 is always '4'""" - from importlib.metadata import version as importlib_version - return importlib_version("kilosort") @classmethod From c8779fc87dfaa6aa1d2bdb72d6fa58ed36c7da7c Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Thu, 27 Jun 2024 03:01:37 +0100 Subject: [PATCH 19/30] Skip tests where relevant, try on slightly earlier python version to avoid weird xlabel bug. --- .github/scripts/test_kilosort4_ci.py | 3 +++ .github/workflows/test_kilosort4.yml | 2 +- src/spikeinterface/sorters/external/kilosort4.py | 4 ++-- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/.github/scripts/test_kilosort4_ci.py b/.github/scripts/test_kilosort4_ci.py index ecc931781c..3e74fa708e 100644 --- a/.github/scripts/test_kilosort4_ci.py +++ b/.github/scripts/test_kilosort4_ci.py @@ -228,6 +228,8 @@ def test_default_settings_all_represented(self): for param_key in DEFAULT_SETTINGS: if param_key not in ["n_chan_bin", "fs", "tmin", "tmax"]: + if parse(version("kilosort")) == parse("4.0.9") and param_key == "nblocks": + continue assert param_key in tested_keys, f"param: {param_key} in DEFAULT SETTINGS but not tested." def test_spikeinterface_defaults_against_kilsort(self): @@ -434,6 +436,7 @@ def test_kilosort4_no_correction(self, recording_and_paths, tmp_path): assert np.array_equal(results["ks"]["st"], results["si"]["st"]) assert np.array_equal(results["ks"]["clus"], results["si"]["clus"]) + @pytest.mark.skipif(parse(version("kilosort")) == parse("4.0.9"), reason="nblock=0 fails on KS4=4.0.9") @pytest.mark.parametrize("param_to_test", [ ("change_nothing", None), ("do_CAR", False), diff --git a/.github/workflows/test_kilosort4.yml b/.github/workflows/test_kilosort4.yml index c216be20d0..3ad61c0d2e 100644 --- a/.github/workflows/test_kilosort4.yml +++ b/.github/workflows/test_kilosort4.yml @@ -44,7 +44,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.12"] + python-version: ["3.10"] # TODO: just checking python version is not cause of failing test. os: [ubuntu-latest] ks_version: ${{ fromJson(needs.versions.outputs.matrix) }} steps: diff --git a/src/spikeinterface/sorters/external/kilosort4.py b/src/spikeinterface/sorters/external/kilosort4.py index 82c033f61d..811a6e8452 100644 --- a/src/spikeinterface/sorters/external/kilosort4.py +++ b/src/spikeinterface/sorters/external/kilosort4.py @@ -36,7 +36,7 @@ class Kilosort4Sorter(BaseSorter): "drift_smoothing": [0.5, 0.5, 0.5], "nt0min": None, "dmin": None, - "dminx": 32 if version.parse(importlib_version("kilosort")) > version.parse("4.0.0.1") else None, + "dminx": 32 if version.parse(importlib_version("kilosort")) > version.parse("4.0.2") else None, "min_template_size": 10, "template_sizes": 5, "nearest_chans": 10, @@ -128,7 +128,7 @@ def is_installed(cls): @classmethod def get_sorter_version(cls): - """kilosort version <0.0.10 is always '4'""" + """kilosort version <4.0.10 is always '4'""" return importlib_version("kilosort") @classmethod From 867729102ee5a76f412f1a8e7c025ceefadb7bff Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Thu, 27 Jun 2024 09:37:21 +0100 Subject: [PATCH 20/30] Don't support 4.0.4 --- .github/scripts/check_kilosort4_releases.py | 7 +++++++ .github/scripts/test_kilosort4_ci.py | 3 ++- .github/workflows/test_kilosort4.yml | 2 +- src/spikeinterface/sorters/external/kilosort4.py | 5 +++++ 4 files changed, 15 insertions(+), 2 deletions(-) diff --git a/.github/scripts/check_kilosort4_releases.py b/.github/scripts/check_kilosort4_releases.py index 3d04d6948a..9572f88330 100644 --- a/.github/scripts/check_kilosort4_releases.py +++ b/.github/scripts/check_kilosort4_releases.py @@ -6,14 +6,21 @@ def get_pypi_versions(package_name): + """ + Make an API call to pypi to retrieve all + available versions of the kilosort package. + """ url = f"https://pypi.org/pypi/{package_name}/json" response = requests.get(url) response.raise_for_status() data = response.json() + versions = list(sorted(data["releases"].keys())) + versions.pop(versions.index("4.0.4")) return list(sorted(data["releases"].keys())) if __name__ == "__main__": + # Get all KS4 versions from pipi and write to file. package_name = "kilosort" versions = get_pypi_versions(package_name) with open(Path(os.path.realpath(__file__)).parent / "kilosort4-latest-version.json", "w") as f: diff --git a/.github/scripts/test_kilosort4_ci.py b/.github/scripts/test_kilosort4_ci.py index 3e74fa708e..c894ed71ff 100644 --- a/.github/scripts/test_kilosort4_ci.py +++ b/.github/scripts/test_kilosort4_ci.py @@ -342,7 +342,7 @@ def _check_arguments(self, object_, expected_arguments): # Full Test #### @pytest.mark.parametrize("parameter", PARAMS_TO_TEST) - def test_kilosort4(self, recording_and_paths, default_results, tmp_path, parameter): + def test_kilosort4_main(self, recording_and_paths, default_results, tmp_path, parameter): """ Given a recording, paths to raw data, and a parameter to change, run KS4 natively and within the SpikeInterface wrapper with the @@ -398,6 +398,7 @@ def test_kilosort4(self, recording_and_paths, default_results, tmp_path, paramet if parse(version("kilosort")) > parse("4.0.4"): self._check_test_parameters_are_changing_the_output(results, default_results, param_key) + @pytest.mark.skipif(parse(version("kilosort")) == parse("4.0.9"), reason="nblock=0 fails on KS4=4.0.9") def test_kilosort4_no_correction(self, recording_and_paths, tmp_path): """ Test the SpikeInterface wrappers `do_correction` argument. We set diff --git a/.github/workflows/test_kilosort4.yml b/.github/workflows/test_kilosort4.yml index 3ad61c0d2e..03db2b6170 100644 --- a/.github/workflows/test_kilosort4.yml +++ b/.github/workflows/test_kilosort4.yml @@ -44,7 +44,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.10"] # TODO: just checking python version is not cause of failing test. + python-version: ["3.12"] # TODO: just checking python version is not cause of failing test. os: [ubuntu-latest] ks_version: ${{ fromJson(needs.versions.outputs.matrix) }} steps: diff --git a/src/spikeinterface/sorters/external/kilosort4.py b/src/spikeinterface/sorters/external/kilosort4.py index 811a6e8452..55e694a02f 100644 --- a/src/spikeinterface/sorters/external/kilosort4.py +++ b/src/spikeinterface/sorters/external/kilosort4.py @@ -163,6 +163,11 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): logging.basicConfig(level=logging.INFO) + if cls.get_sorter_version() == version.parse("4.0.4"): + raise RuntimeError( + "Kilosort version 4.0.4 is not supported" "in SpikeInterface. Please change Kilosort version." + ) + sorter_output_folder = sorter_output_folder.absolute() probe_filename = sorter_output_folder / "probe.prb" From 21caaf99bd93e7189725acc3de3079264e79d710 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Thu, 27 Jun 2024 10:09:58 +0100 Subject: [PATCH 21/30] Remove support for versions earlier that 4.0.5. --- .github/scripts/check_kilosort4_releases.py | 5 +++-- src/spikeinterface/sorters/external/kilosort4.py | 10 ++++++---- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/.github/scripts/check_kilosort4_releases.py b/.github/scripts/check_kilosort4_releases.py index 9572f88330..05d8c0c614 100644 --- a/.github/scripts/check_kilosort4_releases.py +++ b/.github/scripts/check_kilosort4_releases.py @@ -3,7 +3,7 @@ from pathlib import Path import requests import json - +from packaging.version import parse def get_pypi_versions(package_name): """ @@ -15,8 +15,9 @@ def get_pypi_versions(package_name): response.raise_for_status() data = response.json() versions = list(sorted(data["releases"].keys())) + versions = [ver for ver in versions if parse(ver) >= parse("4.0.5")] versions.pop(versions.index("4.0.4")) - return list(sorted(data["releases"].keys())) + return versions if __name__ == "__main__": diff --git a/src/spikeinterface/sorters/external/kilosort4.py b/src/spikeinterface/sorters/external/kilosort4.py index 55e694a02f..dba28f7244 100644 --- a/src/spikeinterface/sorters/external/kilosort4.py +++ b/src/spikeinterface/sorters/external/kilosort4.py @@ -36,7 +36,7 @@ class Kilosort4Sorter(BaseSorter): "drift_smoothing": [0.5, 0.5, 0.5], "nt0min": None, "dmin": None, - "dminx": 32 if version.parse(importlib_version("kilosort")) > version.parse("4.0.2") else None, + "dminx": 32, "min_template_size": 10, "template_sizes": 5, "nearest_chans": 10, @@ -51,7 +51,7 @@ class Kilosort4Sorter(BaseSorter): "cluster_downsampling": 20, "cluster_pcs": 64, "x_centers": None, - "duplicate_spike_bins": 7 if version.parse(importlib_version("kilosort")) >= version.parse("4.0.4") else 15, + "duplicate_spike_bins": 7, "do_correction": True, "keep_good_only": False, "save_extra_kwargs": False, @@ -163,9 +163,11 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): logging.basicConfig(level=logging.INFO) - if cls.get_sorter_version() == version.parse("4.0.4"): + if cls.get_sorter_version() < version.parse("4.0.5"): raise RuntimeError( - "Kilosort version 4.0.4 is not supported" "in SpikeInterface. Please change Kilosort version." + "Kilosort versions before 4.0.5 are not supported" + "in SpikeInterface. " + "Please upgrade Kilosort version." ) sorter_output_folder = sorter_output_folder.absolute() From 9bc18978fbb56917b0f4fe46df7c3bc531f850a4 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Thu, 27 Jun 2024 10:40:50 +0100 Subject: [PATCH 22/30] Add packaging to CI dependency. On branch add_kilosort4_wrapper_tests --- .github/scripts/check_kilosort4_releases.py | 1 - .github/workflows/test_kilosort4.yml | 2 +- src/spikeinterface/sorters/external/kilosort4.py | 2 +- 3 files changed, 2 insertions(+), 3 deletions(-) diff --git a/.github/scripts/check_kilosort4_releases.py b/.github/scripts/check_kilosort4_releases.py index 05d8c0c614..de11dc974b 100644 --- a/.github/scripts/check_kilosort4_releases.py +++ b/.github/scripts/check_kilosort4_releases.py @@ -16,7 +16,6 @@ def get_pypi_versions(package_name): data = response.json() versions = list(sorted(data["releases"].keys())) versions = [ver for ver in versions if parse(ver) >= parse("4.0.5")] - versions.pop(versions.index("4.0.4")) return versions diff --git a/.github/workflows/test_kilosort4.yml b/.github/workflows/test_kilosort4.yml index 03db2b6170..088dd1a6a4 100644 --- a/.github/workflows/test_kilosort4.yml +++ b/.github/workflows/test_kilosort4.yml @@ -25,7 +25,7 @@ jobs: - name: Install dependencies run: | - pip install requests + pip install requests packaging - name: Fetch package versions from PyPI run: | diff --git a/src/spikeinterface/sorters/external/kilosort4.py b/src/spikeinterface/sorters/external/kilosort4.py index dba28f7244..eb1df7c455 100644 --- a/src/spikeinterface/sorters/external/kilosort4.py +++ b/src/spikeinterface/sorters/external/kilosort4.py @@ -163,7 +163,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): logging.basicConfig(level=logging.INFO) - if cls.get_sorter_version() < version.parse("4.0.5"): + if version.parse(cls.get_sorter_version()) < version.parse("4.0.5"): raise RuntimeError( "Kilosort versions before 4.0.5 are not supported" "in SpikeInterface. " From 23d2c77533a2bc65791bd6d07eda9b8723133c33 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Thu, 27 Jun 2024 12:30:05 +0100 Subject: [PATCH 23/30] Add some more documentation to .yml --- .github/workflows/test_kilosort4.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/test_kilosort4.yml b/.github/workflows/test_kilosort4.yml index 088dd1a6a4..13d70acf88 100644 --- a/.github/workflows/test_kilosort4.yml +++ b/.github/workflows/test_kilosort4.yml @@ -11,6 +11,8 @@ on: jobs: versions: + # Poll Pypi for all released KS4 versions >4.0.4, save to JSON + # and store them in a matrix for the next job. runs-on: ubuntu-latest outputs: matrix: ${{ steps.set-matrix.outputs.matrix }} From 32568ca1a9637a7dc167dbf1a56e214dbe13cfb5 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Tue, 20 Aug 2024 15:46:49 +0100 Subject: [PATCH 24/30] Remove run CI on main, only run on cron job. --- .github/workflows/test_kilosort4.yml | 4 ---- 1 file changed, 4 deletions(-) diff --git a/.github/workflows/test_kilosort4.yml b/.github/workflows/test_kilosort4.yml index 13d70acf88..24b2e29440 100644 --- a/.github/workflows/test_kilosort4.yml +++ b/.github/workflows/test_kilosort4.yml @@ -4,10 +4,6 @@ on: workflow_dispatch: schedule: - cron: "0 12 * * 0" # Weekly on Sunday at noon UTC - pull_request: - types: [synchronize, opened, reopened] - branches: - - main jobs: versions: From 8580c975e0d26db4006883da7ff2c36a58a5832a Mon Sep 17 00:00:00 2001 From: Joe Ziminski <55797454+JoeZiminski@users.noreply.github.com> Date: Tue, 20 Aug 2024 15:47:42 +0100 Subject: [PATCH 25/30] Update .github/scripts/test_kilosort4_ci.py --- .github/scripts/test_kilosort4_ci.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.github/scripts/test_kilosort4_ci.py b/.github/scripts/test_kilosort4_ci.py index c894ed71ff..10855f2120 100644 --- a/.github/scripts/test_kilosort4_ci.py +++ b/.github/scripts/test_kilosort4_ci.py @@ -336,6 +336,10 @@ def test_binary_filtered_arguments(self): ) def _check_arguments(self, object_, expected_arguments): + """ + Check that the argument signature of `object_` is as expected + (i..e has not changed across kilosort versions). + """ sig = signature(object_) obj_arguments = list(sig.parameters.keys()) assert expected_arguments == obj_arguments From b3c6680f859d165bc6f4e11ea8d91cfd6c95eaf1 Mon Sep 17 00:00:00 2001 From: Joe Ziminski <55797454+JoeZiminski@users.noreply.github.com> Date: Tue, 20 Aug 2024 15:47:52 +0100 Subject: [PATCH 26/30] Update src/spikeinterface/sorters/external/kilosort4.py --- src/spikeinterface/sorters/external/kilosort4.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/sorters/external/kilosort4.py b/src/spikeinterface/sorters/external/kilosort4.py index eb1df7c455..3f7a0f7abe 100644 --- a/src/spikeinterface/sorters/external/kilosort4.py +++ b/src/spikeinterface/sorters/external/kilosort4.py @@ -128,7 +128,7 @@ def is_installed(cls): @classmethod def get_sorter_version(cls): - """kilosort version <4.0.10 is always '4'""" + """kilosort.__version__ <4.0.10 is always '4'""" return importlib_version("kilosort") @classmethod From 23c39831a9cadba7ab50c88c53536723e93fba2f Mon Sep 17 00:00:00 2001 From: Joe Ziminski <55797454+JoeZiminski@users.noreply.github.com> Date: Tue, 20 Aug 2024 15:48:10 +0100 Subject: [PATCH 27/30] Update .github/workflows/test_kilosort4.yml --- .github/workflows/test_kilosort4.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test_kilosort4.yml b/.github/workflows/test_kilosort4.yml index 24b2e29440..95fc30b0b2 100644 --- a/.github/workflows/test_kilosort4.yml +++ b/.github/workflows/test_kilosort4.yml @@ -42,7 +42,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.12"] # TODO: just checking python version is not cause of failing test. + python-version: ["3.12"] os: [ubuntu-latest] ks_version: ${{ fromJson(needs.versions.outputs.matrix) }} steps: From ed9ef3251504a8d2388a5c461e5c8531113ccb09 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Tue, 20 Aug 2024 16:04:16 +0100 Subject: [PATCH 28/30] Fix linting. --- .github/scripts/test_kilosort4_ci.py | 2 +- .github/workflows/test_kilosort4.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/scripts/test_kilosort4_ci.py b/.github/scripts/test_kilosort4_ci.py index 10855f2120..3ac8c7dd2b 100644 --- a/.github/scripts/test_kilosort4_ci.py +++ b/.github/scripts/test_kilosort4_ci.py @@ -337,7 +337,7 @@ def test_binary_filtered_arguments(self): def _check_arguments(self, object_, expected_arguments): """ - Check that the argument signature of `object_` is as expected + Check that the argument signature of `object_` is as expected (i..e has not changed across kilosort versions). """ sig = signature(object_) diff --git a/.github/workflows/test_kilosort4.yml b/.github/workflows/test_kilosort4.yml index 95fc30b0b2..390bec98be 100644 --- a/.github/workflows/test_kilosort4.yml +++ b/.github/workflows/test_kilosort4.yml @@ -42,7 +42,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.12"] + python-version: ["3.12"] os: [ubuntu-latest] ks_version: ${{ fromJson(needs.versions.outputs.matrix) }} steps: From ae44b4a908855b8495d1d9807fddc73d8452b86a Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Tue, 20 Aug 2024 21:20:27 +0100 Subject: [PATCH 29/30] Remove 'save_preprocessed' test. --- .github/scripts/test_kilosort4_ci.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/.github/scripts/test_kilosort4_ci.py b/.github/scripts/test_kilosort4_ci.py index 3ac8c7dd2b..e0d1f2a504 100644 --- a/.github/scripts/test_kilosort4_ci.py +++ b/.github/scripts/test_kilosort4_ci.py @@ -85,14 +85,6 @@ ("duplicate_spike_bins", 5), ] -# Update PARAMS_TO_TEST with version-dependent kwargs -if parse(version("kilosort")) >= parse("4.0.12"): - pass # TODO: expose? -# PARAMS_TO_TEST.extend( -# [ -# ("save_preprocessed_copy", False), -# ] -# ) if parse(version("kilosort")) >= parse("4.0.11"): PARAMS_TO_TEST.extend( [ From 642eea9b2c1242000dd847701eb89dc533def6be Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Wed, 21 Aug 2024 12:55:59 +0100 Subject: [PATCH 30/30] Update KS4 versions to test on, add a warning for the next version. --- .github/scripts/check_kilosort4_releases.py | 10 +++++++++- .github/scripts/kilosort4-latest-version.json | 1 + 2 files changed, 10 insertions(+), 1 deletion(-) create mode 100644 .github/scripts/kilosort4-latest-version.json diff --git a/.github/scripts/check_kilosort4_releases.py b/.github/scripts/check_kilosort4_releases.py index de11dc974b..92e7bf277f 100644 --- a/.github/scripts/check_kilosort4_releases.py +++ b/.github/scripts/check_kilosort4_releases.py @@ -4,6 +4,7 @@ import requests import json from packaging.version import parse +import spikeinterface def get_pypi_versions(package_name): """ @@ -15,7 +16,13 @@ def get_pypi_versions(package_name): response.raise_for_status() data = response.json() versions = list(sorted(data["releases"].keys())) - versions = [ver for ver in versions if parse(ver) >= parse("4.0.5")] + + assert parse(spikeinterface.__version__) < parse("0.101.1"), ( + "Kilosort 4.0.5-12 are supported in SpikeInterface < 0.101.1." + "At version 0.101.1, this should be updated to support newer" + "kilosort verrsions." + ) + versions = [ver for ver in versions if parse("4.0.12") >= parse(ver) >= parse("4.0.5")] return versions @@ -24,4 +31,5 @@ def get_pypi_versions(package_name): package_name = "kilosort" versions = get_pypi_versions(package_name) with open(Path(os.path.realpath(__file__)).parent / "kilosort4-latest-version.json", "w") as f: + print(versions) json.dump(versions, f) diff --git a/.github/scripts/kilosort4-latest-version.json b/.github/scripts/kilosort4-latest-version.json new file mode 100644 index 0000000000..03629ff842 --- /dev/null +++ b/.github/scripts/kilosort4-latest-version.json @@ -0,0 +1 @@ +["4.0.10", "4.0.11", "4.0.12", "4.0.5", "4.0.6", "4.0.7", "4.0.8", "4.0.9"]