diff --git a/.github/scripts/README.MD b/.github/scripts/README.MD new file mode 100644 index 0000000000..1d3a622aae --- /dev/null +++ b/.github/scripts/README.MD @@ -0,0 +1,2 @@ +This folder contains test scripts for running in the CI, that are not run as part of the usual +CI because they are too long / heavy. These are run on cron-jobs once per week. diff --git a/.github/scripts/check_kilosort4_releases.py b/.github/scripts/check_kilosort4_releases.py new file mode 100644 index 0000000000..92e7bf277f --- /dev/null +++ b/.github/scripts/check_kilosort4_releases.py @@ -0,0 +1,35 @@ +import os +import re +from pathlib import Path +import requests +import json +from packaging.version import parse +import spikeinterface + +def get_pypi_versions(package_name): + """ + Make an API call to pypi to retrieve all + available versions of the kilosort package. + """ + url = f"https://pypi.org/pypi/{package_name}/json" + response = requests.get(url) + response.raise_for_status() + data = response.json() + versions = list(sorted(data["releases"].keys())) + + assert parse(spikeinterface.__version__) < parse("0.101.1"), ( + "Kilosort 4.0.5-12 are supported in SpikeInterface < 0.101.1." + "At version 0.101.1, this should be updated to support newer" + "kilosort verrsions." + ) + versions = [ver for ver in versions if parse("4.0.12") >= parse(ver) >= parse("4.0.5")] + return versions + + +if __name__ == "__main__": + # Get all KS4 versions from pipi and write to file. + package_name = "kilosort" + versions = get_pypi_versions(package_name) + with open(Path(os.path.realpath(__file__)).parent / "kilosort4-latest-version.json", "w") as f: + print(versions) + json.dump(versions, f) diff --git a/.github/scripts/kilosort4-latest-version.json b/.github/scripts/kilosort4-latest-version.json new file mode 100644 index 0000000000..03629ff842 --- /dev/null +++ b/.github/scripts/kilosort4-latest-version.json @@ -0,0 +1 @@ +["4.0.10", "4.0.11", "4.0.12", "4.0.5", "4.0.6", "4.0.7", "4.0.8", "4.0.9"] diff --git a/.github/scripts/test_kilosort4_ci.py b/.github/scripts/test_kilosort4_ci.py new file mode 100644 index 0000000000..e0d1f2a504 --- /dev/null +++ b/.github/scripts/test_kilosort4_ci.py @@ -0,0 +1,630 @@ +""" +This file tests the SpikeInterface wrapper of the Kilosort4. The general logic +of the tests are: +- Change every exposed parameter one at a time (PARAMS_TO_TEST). Check that + the result of the SpikeInterface wrapper and Kilosort run natively are + the same. The SpikeInterface wrapper is non-trivial and decomposes the + kilosort pipeline to allow additions such as skipping preprocessing. Therefore, + the idea is that is it safer to rely on the output directly rather than + try monkeypatching. One thing can could be better tested is parameter + changes when skipping KS4 preprocessing is true, because this takes a slightly + different path through the kilosort4.py wrapper logic. + This also checks that changing the parameter changes the test output from default + on our test case (otherwise, the test could not detect a failure). This is possible + for nearly all parameters, see `_check_test_parameters_are_changing_the_output()`. + +- Test that kilosort functions called from `kilosort4.py` wrapper have the expected + input signatures + +- Do some tests to check all KS4 parameters are tested against. +""" +import copy +from typing import Any +import spikeinterface.full as si +import numpy as np +import torch +import kilosort +from kilosort.io import load_probe +import pandas as pd +from spikeinterface.sorters.external.kilosort4 import Kilosort4Sorter +import pytest +from probeinterface.io import write_prb +from kilosort.parameters import DEFAULT_SETTINGS +from packaging.version import parse +from importlib.metadata import version +from inspect import signature +from kilosort.run_kilosort import (set_files, initialize_ops, + compute_preprocessing, + compute_drift_correction, detect_spikes, + cluster_spikes, save_sorting, + get_run_parameters, ) +from kilosort.io import load_probe, RecordingExtractorAsArray, BinaryFiltered +from kilosort.parameters import DEFAULT_SETTINGS +from kilosort import preprocessing as ks_preprocessing + +RUN_KILOSORT_ARGS = ["do_CAR", "invert_sign", "save_preprocessed_copy"] +# "device", "progress_bar", "save_extra_vars" are not tested. "save_extra_vars" could be. + +# Setup Params to test #### +PARAMS_TO_TEST = [ + # Not tested + # ("torch_device", "auto") + + # Stable across KS version 4.0.01 - 4.0.12 + ("change_nothing", None), + ("nblocks", 0), + ("do_CAR", False), + ("batch_size", 42743), + ("Th_universal", 12), + ("Th_learned", 14), + ("invert_sign", True), + ("nt", 93), + ("nskip", 1), + ("whitening_range", 16), + ("sig_interp", 5), + ("nt0min", 25), + ("dmin", 15), + ("dminx", 16), + ("min_template_size", 15), + ("template_sizes", 10), + ("nearest_chans", 8), + ("nearest_templates", 35), + ("max_channel_distance", 5), + ("templates_from_data", False), + ("n_templates", 10), + ("n_pcs", 3), + ("Th_single_ch", 4), + ("x_centers", 5), + ("binning_depth", 1), + # Note: These don't change the results from + # default when applied to the test case. + ("artifact_threshold", 200), + ("ccg_threshold", 1e12), + ("acg_threshold", 1e12), + ("cluster_downsampling", 2), + ("duplicate_spike_bins", 5), +] + +if parse(version("kilosort")) >= parse("4.0.11"): + PARAMS_TO_TEST.extend( + [ + ("shift", 1e9), + ("scale", -1e9), + ] + ) +if parse(version("kilosort")) == parse("4.0.9"): + # bug in 4.0.9 for "nblocks=0" + PARAMS_TO_TEST = [param for param in PARAMS_TO_TEST if param[0] != "nblocks"] + +if parse(version("kilosort")) >= parse("4.0.8"): + PARAMS_TO_TEST.extend( + [ + ("drift_smoothing", [250, 250, 250]), + ] + ) +if parse(version("kilosort")) <= parse("4.0.6"): + # AFAIK this parameter was always unused in KS (that's why it was removed) + PARAMS_TO_TEST.extend( + [ + ("cluster_pcs", 1e9), + ] + ) +if parse(version("kilosort")) <= parse("4.0.3"): + PARAMS_TO_TEST = [param for param in PARAMS_TO_TEST if param[0] not in ["x_centers", "max_channel_distance"]] + + +class TestKilosort4Long: + + # Fixtures ###### + @pytest.fixture(scope="session") + def recording_and_paths(self, tmp_path_factory): + """ + Create a ground-truth recording, and save it to binary + so KS4 can run on it. Fixture is set up once and shared between + all tests. + """ + tmp_path = tmp_path_factory.mktemp("kilosort4_tests") + + recording = self._get_ground_truth_recording() + + paths = self._save_ground_truth_recording(recording, tmp_path) + + return (recording, paths) + + @pytest.fixture(scope="session") + def default_results(self, recording_and_paths): + """ + Because we check each parameter at a time and check the + KS4 and SpikeInterface versions match, if changing the parameter + had no effect as compared to default then the test would not test + anything. Therefore, the default results are run once and stored, + to check changing params indeed changes the results during testing. + This is possibly for nearly all parameters. + """ + recording, paths = recording_and_paths + + settings, _, ks_format_probe = self._get_kilosort_native_settings(recording, paths, "change_nothing", None) + + defaults_ks_output_dir = paths["session_scope_tmp_path"] / "default_ks_output" + + kilosort.run_kilosort( + settings=settings, + probe=ks_format_probe, + data_dtype="float32", + results_dir=defaults_ks_output_dir, + ) + + default_results = self._get_sorting_output(defaults_ks_output_dir) + + return default_results + + def _get_ground_truth_recording(self): + """ + A ground truth recording chosen to be as small as possible (for speed). + But contain enough information so that changing most parameters + changes the results. + """ + num_channels = 32 + recording, _ = si.generate_ground_truth_recording( + durations=[5], + seed=0, + num_channels=num_channels, + num_units=5, + generate_sorting_kwargs=dict(firing_rates=100, refractory_period_ms=4.0), + ) + return recording + + def _save_ground_truth_recording(self, recording, tmp_path): + """ + Save the recording and its probe to file, so it can be + loaded by KS4. + """ + paths = { + "session_scope_tmp_path": tmp_path, + "recording_path": tmp_path / "my_test_recording", + "probe_path": tmp_path / "my_test_probe.prb", + } + + recording.save(folder=paths["recording_path"], overwrite=True) + + probegroup = recording.get_probegroup() + write_prb(paths["probe_path"].as_posix(), probegroup) + + return paths + + # Tests ###### + def test_params_to_test(self): + """ + Test that all values in PARAMS_TO_TEST are + different to the default values used in Kilosort, + otherwise there is no point to the test. + """ + for parameter in PARAMS_TO_TEST: + + param_key, param_value = parameter + + if param_key == "change_nothing": + continue + + if param_key not in RUN_KILOSORT_ARGS: + assert DEFAULT_SETTINGS[param_key] != param_value, f"{param_key} values should be different in test." + + def test_default_settings_all_represented(self): + """ + Test that every entry in DEFAULT_SETTINGS is tested in + PARAMS_TO_TEST, otherwise we are missing settings added + on the KS side. + """ + tested_keys = [entry[0] for entry in PARAMS_TO_TEST] + + for param_key in DEFAULT_SETTINGS: + + if param_key not in ["n_chan_bin", "fs", "tmin", "tmax"]: + if parse(version("kilosort")) == parse("4.0.9") and param_key == "nblocks": + continue + assert param_key in tested_keys, f"param: {param_key} in DEFAULT SETTINGS but not tested." + + def test_spikeinterface_defaults_against_kilsort(self): + """ + Here check that all _ + Don't check that every default in KS is exposed in params, + because they change across versions. Instead, this check + is performed here against PARAMS_TO_TEST. + """ + params = copy.deepcopy(Kilosort4Sorter._default_params) + + for key in params.keys(): + # "artifact threshold" is set to `np.inf` if `None` in + # the body of the `Kilosort4Sorter` class. + if key in DEFAULT_SETTINGS and key not in ["artifact_threshold"]: + assert params[key] == DEFAULT_SETTINGS[key], f"{key} is not the same across versions." + + # Testing Arguments ### + def test_set_files_arguments(self): + self._check_arguments( + set_files, + ["settings", "filename", "probe", "probe_name", "data_dir", "results_dir"] + ) + + def test_initialize_ops_arguments(self): + expected_arguments = ["settings", "probe", "data_dtype", "do_CAR", "invert_sign", "device"] + + if parse(version("kilosort")) >= parse("4.0.12"): + expected_arguments.append("save_preprocesed_copy") + + self._check_arguments( + initialize_ops, + expected_arguments, + ) + + def test_compute_preprocessing_arguments(self): + self._check_arguments( + compute_preprocessing, + ["ops", "device", "tic0", "file_object"] + ) + + def test_compute_drift_location_arguments(self): + self._check_arguments( + compute_drift_correction, + ["ops", "device", "tic0", "progress_bar", "file_object"] + ) + + def test_detect_spikes_arguments(self): + self._check_arguments( + detect_spikes, + ["ops", "device", "bfile", "tic0", "progress_bar"] + ) + + def test_cluster_spikes_arguments(self): + self._check_arguments( + cluster_spikes, + ["st", "tF", "ops", "device", "bfile", "tic0", "progress_bar"] + ) + + def test_save_sorting_arguments(self): + expected_arguments = ["ops", "results_dir", "st", "clu", "tF", "Wall", "imin", "tic0", "save_extra_vars"] + + if parse(version("kilosort")) > parse("4.0.11"): + expected_arguments.append("save_preprocessed_copy") + + self._check_arguments( + save_sorting, + expected_arguments + ) + + def test_get_run_parameters(self): + self._check_arguments( + get_run_parameters, + ["ops"] + ) + + def test_load_probe_parameters(self): + self._check_arguments( + load_probe, + ["probe_path"] + ) + + def test_recording_extractor_as_array_arguments(self): + self._check_arguments( + RecordingExtractorAsArray, + ["recording_extractor"] + ) + + def test_binary_filtered_arguments(self): + expected_arguments = [ + "filename", "n_chan_bin", "fs", "NT", "nt", "nt0min", + "chan_map", "hp_filter", "whiten_mat", "dshift", + "device", "do_CAR", "artifact_threshold", "invert_sign", + "dtype", "tmin", "tmax", "file_object" + ] + + if parse(version("kilosort")) >= parse("4.0.11"): + expected_arguments.pop(-1) + expected_arguments.extend(["shift", "scale", "file_object"]) + + self._check_arguments( + BinaryFiltered, + expected_arguments + ) + + def _check_arguments(self, object_, expected_arguments): + """ + Check that the argument signature of `object_` is as expected + (i..e has not changed across kilosort versions). + """ + sig = signature(object_) + obj_arguments = list(sig.parameters.keys()) + assert expected_arguments == obj_arguments + + # Full Test #### + @pytest.mark.parametrize("parameter", PARAMS_TO_TEST) + def test_kilosort4_main(self, recording_and_paths, default_results, tmp_path, parameter): + """ + Given a recording, paths to raw data, and a parameter to change, + run KS4 natively and within the SpikeInterface wrapper with the + new parameter value (all other values default) and + check the outputs are the same. + """ + recording, paths = recording_and_paths + param_key, param_value = parameter + + # Setup parameters for KS4 and run it natively + kilosort_output_dir = tmp_path / "kilosort_output_dir" + spikeinterface_output_dir = tmp_path / "spikeinterface_output_dir" + + settings, run_kilosort_kwargs, ks_format_probe = self._get_kilosort_native_settings(recording, paths, param_key, param_value) + + kilosort.run_kilosort( + settings=settings, + probe=ks_format_probe, + data_dtype="float32", + results_dir=kilosort_output_dir, + **run_kilosort_kwargs, + ) + + # Setup Parameters for SI and KS4 through SI + spikeinterface_settings = self._get_spikeinterface_settings(param_key, param_value) + + si.run_sorter( + "kilosort4", + recording, + remove_existing_folder=True, + folder=spikeinterface_output_dir, + **spikeinterface_settings, + ) + + # Get the results and check they match + results = self._get_sorting_output(kilosort_output_dir, spikeinterface_output_dir) + + assert np.array_equal(results["ks"]["st"], results["si"]["st"]), f"{param_key} spike times different" + assert np.array_equal(results["ks"]["clus"], results["si"]["clus"]), f"{param_key} cluster assignment different" + + # Check the ops file in KS4 output is as expected. This is saved on the + # SI side so not an extremely robust addition, but it can't hurt. + if param_key != "change_nothing": + ops = np.load(spikeinterface_output_dir / "sorter_output" / "ops.npy", allow_pickle=True) + ops = ops.tolist() # strangely this makes a dict + assert ops[param_key] == param_value + + # Finally, check out test parameters actually change the output of + # KS4, ensuring our tests are actually doing something. This is not + # done prior to 4.0.4 because a number of parameters seem to stop + # having an effect. This is probably due to small changes in their + # behaviour, and the test file chosen here. + if parse(version("kilosort")) > parse("4.0.4"): + self._check_test_parameters_are_changing_the_output(results, default_results, param_key) + + @pytest.mark.skipif(parse(version("kilosort")) == parse("4.0.9"), reason="nblock=0 fails on KS4=4.0.9") + def test_kilosort4_no_correction(self, recording_and_paths, tmp_path): + """ + Test the SpikeInterface wrappers `do_correction` argument. We set + `nblocks=0` for KS4 native, turning off motion correction. Then + we run KS$ through SpikeInterface with `do_correction=False` but + `nblocks=1` (KS4 default) - checking that `do_correction` overrides + this and the result matches KS4 when run without motion correction. + """ + recording, paths = recording_and_paths + + kilosort_output_dir = tmp_path / "kilosort_output_dir" + spikeinterface_output_dir = tmp_path / "spikeinterface_output_dir" + + settings, _, ks_format_probe = self._get_kilosort_native_settings(recording, paths, "nblocks", 0) + + kilosort.run_kilosort( + settings=settings, + probe=ks_format_probe, + data_dtype="float32", + results_dir=kilosort_output_dir, + do_CAR=True, + ) + + spikeinterface_settings = self._get_spikeinterface_settings("nblocks", 1) + si.run_sorter( + "kilosort4", + recording, + remove_existing_folder=True, + folder=spikeinterface_output_dir, + do_correction=False, + **spikeinterface_settings, + ) + + results = self._get_sorting_output(kilosort_output_dir, spikeinterface_output_dir) + + assert np.array_equal(results["ks"]["st"], results["si"]["st"]) + assert np.array_equal(results["ks"]["clus"], results["si"]["clus"]) + + @pytest.mark.skipif(parse(version("kilosort")) == parse("4.0.9"), reason="nblock=0 fails on KS4=4.0.9") + @pytest.mark.parametrize("param_to_test", [ + ("change_nothing", None), + ("do_CAR", False), + ("batch_size", 42743), + ("Th_learned", 14), + ("dmin", 15), + ("max_channel_distance", 5), + ("n_pcs", 3), + ]) + def test_kilosort4_skip_preprocessing_correction(self, tmp_path, monkeypatch, param_to_test): + """ + Test that skipping KS4 preprocessing works as expected. Run + KS4 natively, monkeypatching the relevant preprocessing functions + such that preprocessing is not performed. Then run in SpikeInterface + with `skip_kilosort_preprocessing=True` and check the outputs match. + + Run with a few randomly chosen parameters to check these are propagated + under this condition. + + TODO + ---- + It would be nice to check a few additional parameters here. Screw it! + """ + param_key, param_value = param_to_test + + recording = self._get_ground_truth_recording() + + # We need to filter and whiten the recording here to KS takes forever. + # Do this in a way different to KS. + recording = si.highpass_filter(recording, 300) + recording = si.whiten(recording, mode="local", apply_mean=False) + + paths = self._save_ground_truth_recording(recording, tmp_path) + + kilosort_output_dir = tmp_path / "kilosort_output_dir" + spikeinterface_output_dir = tmp_path / "spikeinterface_output_dir" + + def monkeypatch_filter_function(self, X, ops=None, ibatch=None): + """ + This is a direct copy of the kilosort io.BinaryFiltered.filter + function, with hp_filter and whitening matrix code sections, and + comments removed. This is the easiest way to monkeypatch (tried a few approaches) + """ + if self.chan_map is not None: + X = X[self.chan_map] + + if self.invert_sign: + X = X * -1 + + X = X - X.mean(1).unsqueeze(1) + if self.do_CAR: + X = X - torch.median(X, 0)[0] + + if self.hp_filter is not None: + pass + + if self.artifact_threshold < np.inf: + if torch.any(torch.abs(X) >= self.artifact_threshold): + return torch.zeros_like(X) + + if self.whiten_mat is not None: + pass + return X + + monkeypatch.setattr("kilosort.io.BinaryFiltered.filter", + monkeypatch_filter_function) + + ks_settings, _, ks_format_probe = self._get_kilosort_native_settings(recording, paths, param_key, param_value) + ks_settings["nblocks"] = 0 + + # Be explicit here and don't rely on defaults. + do_CAR = param_value if param_key == "do_CAR" else False + + kilosort.run_kilosort( + settings=ks_settings, + probe=ks_format_probe, + data_dtype="float32", + results_dir=kilosort_output_dir, + do_CAR=do_CAR, + ) + + monkeypatch.undo() + + # Now, run kilosort through spikeinterface with the same options. + spikeinterface_settings = self._get_spikeinterface_settings(param_key, param_value) + spikeinterface_settings["nblocks"] = 0 + + do_CAR = False if param_key != "do_CAR" else spikeinterface_settings.pop("do_CAR") + + si.run_sorter( + "kilosort4", + recording, + remove_existing_folder=True, + folder=spikeinterface_output_dir, + do_CAR=do_CAR, + skip_kilosort_preprocessing=True, + **spikeinterface_settings, + ) + + # There is a very slight difference caused by the batching between load vs. + # memory file. Because in this test recordings are preprocessed, there are + # some filter edge effects that depend on the chunking in `get_traces()`. + # These are all extremely close (usually just 1 spike, 1 idx different). + results = self._get_sorting_output(kilosort_output_dir, spikeinterface_output_dir) + assert np.allclose(results["ks"]["st"], results["si"]["st"], rtol=0, atol=1) + + # Helpers ###### + def _check_test_parameters_are_changing_the_output(self, results, default_results, param_key): + """ + If nothing is changed, default vs. results outputs are identical. + Otherwise, check they are not the same. Can't figure out how to get + the skipped three parameters below to change the results on this + small test file. + """ + if param_key in ["acg_threshold", "ccg_threshold", "artifact_threshold", "cluster_downsampling", "cluster_pcs"]: + return + + if param_key == "change_nothing": + assert all( + default_results["ks"]["st"] == results["ks"]["st"] + ) and all( + default_results["ks"]["clus"] == results["ks"]["clus"] + ), f"{param_key} changed somehow!." + else: + assert not ( + default_results["ks"]["st"].size == results["ks"]["st"].size + ) or not all( + default_results["ks"]["clus"] == results["ks"]["clus"] + ), f"{param_key} results did not change with parameter change." + + def _get_kilosort_native_settings(self, recording, paths, param_key, param_value): + """ + Function to generate the settings and function inputs to run kilosort. + Note when `binning_depth` is used we need to set `nblocks` high to + get the results to change from default. + + Some settings in KS4 are passed by `settings` dict while others + are through the function, these are split here. + """ + settings = { + "data_dir": paths["recording_path"], + "n_chan_bin": recording.get_num_channels(), + "fs": recording.get_sampling_frequency(), + } + + if param_key == "binning_depth": + settings.update({"nblocks": 5}) + + if param_key in RUN_KILOSORT_ARGS: + run_kilosort_kwargs = {param_key: param_value} + else: + if param_key != "change_nothing": + settings.update({param_key: param_value}) + run_kilosort_kwargs = {} + + ks_format_probe = load_probe(paths["probe_path"]) + + return settings, run_kilosort_kwargs, ks_format_probe + + def _get_spikeinterface_settings(self, param_key, param_value): + """ + Generate settings kwargs for running KS4 in SpikeInterface. + See `_get_kilosort_native_settings()` for some details. + """ + settings = {} # copy.deepcopy(DEFAULT_SETTINGS) + + if param_key != "change_nothing": + settings.update({param_key: param_value}) + + if param_key == "binning_depth": + settings.update({"nblocks": 5}) + + # for name in ["n_chan_bin", "fs", "tmin", "tmax"]: + # settings.pop(name) + + return settings + + def _get_sorting_output(self, kilosort_output_dir=None, spikeinterface_output_dir=None) -> dict[str, Any]: + """ + Load the results of sorting into a dict for easy comparison. + """ + results = { + "si": {}, + "ks": {}, + } + if kilosort_output_dir: + results["ks"]["st"] = np.load(kilosort_output_dir / "spike_times.npy") + results["ks"]["clus"] = np.load(kilosort_output_dir / "spike_clusters.npy") + + if spikeinterface_output_dir: + results["si"]["st"] = np.load(spikeinterface_output_dir / "sorter_output" / "spike_times.npy") + results["si"]["clus"] = np.load(spikeinterface_output_dir / "sorter_output" / "spike_clusters.npy") + + return results diff --git a/.github/workflows/test_kilosort4.yml b/.github/workflows/test_kilosort4.yml new file mode 100644 index 0000000000..390bec98be --- /dev/null +++ b/.github/workflows/test_kilosort4.yml @@ -0,0 +1,70 @@ +name: Testing Kilosort4 + +on: + workflow_dispatch: + schedule: + - cron: "0 12 * * 0" # Weekly on Sunday at noon UTC + +jobs: + versions: + # Poll Pypi for all released KS4 versions >4.0.4, save to JSON + # and store them in a matrix for the next job. + runs-on: ubuntu-latest + outputs: + matrix: ${{ steps.set-matrix.outputs.matrix }} + steps: + - name: Checkout repository + uses: actions/checkout@v2 + + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: 3.12 + + - name: Install dependencies + run: | + pip install requests packaging + + - name: Fetch package versions from PyPI + run: | + python .github/scripts/check_kilosort4_releases.py + shell: bash + + - name: Set matrix data + id: set-matrix + run: | + echo "matrix=$(jq -c . < .github/scripts/kilosort4-latest-version.json)" >> $GITHUB_OUTPUT + + test: + needs: versions + name: ${{ matrix.ks_version }} + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + python-version: ["3.12"] + os: [ubuntu-latest] + ks_version: ${{ fromJson(needs.versions.outputs.matrix) }} + steps: + - name: Checkout repository + uses: actions/checkout@v2 + + - name: Setup Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Install SpikeInterface + run: | + pip install -e .[test] + shell: bash + + - name: Install Kilosort + run: | + pip install kilosort==${{ matrix.ks_version }} + shell: bash + + - name: Run new kilosort4 tests + run: | + pytest .github/scripts/test_kilosort4_ci.py + shell: bash diff --git a/conftest.py b/conftest.py index ce5e07b47b..5bf7d74527 100644 --- a/conftest.py +++ b/conftest.py @@ -7,6 +7,7 @@ def create_cache_folder(tmp_path_factory): cache_folder = tmp_path_factory.mktemp("cache_folder") return cache_folder + def pytest_collection_modifyitems(config, items): """ This function marks (in the pytest sense) the tests according to their name and file_path location @@ -16,7 +17,11 @@ def pytest_collection_modifyitems(config, items): rootdir = Path(config.rootdir) modules_location = rootdir / "src" / "spikeinterface" for item in items: - rel_path = Path(item.fspath).relative_to(modules_location) + try: + rel_path = Path(item.fspath).relative_to(modules_location) + except: + continue + module = rel_path.parts[0] if module == "sorters": if "internal" in rel_path.parts: diff --git a/doc/api.rst b/doc/api.rst index 1966b48a37..42f9fec299 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -171,6 +171,7 @@ spikeinterface.preprocessing .. autofunction:: interpolate_bad_channels .. autofunction:: normalize_by_quantile .. autofunction:: notch_filter + .. autofunction:: causal_filter .. autofunction:: phase_shift .. autofunction:: rectify .. autofunction:: remove_artifacts diff --git a/pyproject.toml b/pyproject.toml index 71919c072b..eb2c0f2fe9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ authors = [ ] description = "Python toolkit for analysis, visualization, and comparison of spike sorting output" readme = "README.md" -requires-python = ">=3.8,<4.0" +requires-python = ">=3.9,<4.0" classifiers = [ "Programming Language :: Python :: 3 :: Only", "License :: OSI Approved :: MIT License", diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index e65afabaca..fe670cbf3a 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -96,14 +96,18 @@ def list_to_string(lst, max_size=6): def _repr_header(self): num_segments = self.get_num_segments() num_channels = self.get_num_channels() - sf_hz = self.get_sampling_frequency() - sf_khz = sf_hz / 1000 dtype = self.get_dtype() total_samples = self.get_total_samples() total_duration = self.get_total_duration() total_memory_size = self.get_total_memory_size() - sampling_frequency_repr = f"{sf_khz:0.1f}kHz" if sf_hz > 10_000.0 else f"{sf_hz:0.1f}Hz" + + sf_hz = self.get_sampling_frequency() + if not sf_hz.is_integer(): + sampling_frequency_repr = f"{sf_hz:f} Hz" + else: + # Khz for high sampling rate and Hz for LFP + sampling_frequency_repr = f"{(sf_hz/1000.0):0.1f}kHz" if sf_hz > 10_000.0 else f"{sf_hz:0.1f}Hz" txt = ( f"{self.name}: " @@ -422,7 +426,7 @@ def get_time_info(self, segment_index=None) -> dict: return time_kwargs - def get_times(self, segment_index=None): + def get_times(self, segment_index=None) -> np.ndarray: """Get time vector for a recording segment. If the segment has a time_vector, then it is returned. Otherwise @@ -809,12 +813,10 @@ def __init__(self, sampling_frequency=None, t_start=None, time_vector=None): BaseSegment.__init__(self) - def get_times(self): + def get_times(self) -> np.ndarray: if self.time_vector is not None: - if isinstance(self.time_vector, np.ndarray): - return self.time_vector - else: - return np.array(self.time_vector) + self.time_vector = np.asarray(self.time_vector) + return self.time_vector else: time_vector = np.arange(self.get_num_samples(), dtype="float64") time_vector /= self.sampling_frequency diff --git a/src/spikeinterface/core/core_tools.py b/src/spikeinterface/core/core_tools.py index aad7613d01..b38222391c 100644 --- a/src/spikeinterface/core/core_tools.py +++ b/src/spikeinterface/core/core_tools.py @@ -684,3 +684,20 @@ def measure_memory_allocation(measure_in_process: bool = True) -> float: memory = mem_info.total - mem_info.available return memory + + +def is_path_remote(path: str | Path) -> bool: + """ + Returns True if the path is a remote path (e.g., s3:// or gcs://). + + Parameters + ---------- + path : str or Path + The path to check. + + Returns + ------- + bool + Whether the path is a remote path. + """ + return "s3://" in str(path) or "gcs://" in str(path) diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index f8ab8a2d3a..57b4a6d4f7 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -2,7 +2,7 @@ import math import warnings import numpy as np -from typing import Union, Optional, List, Literal +from typing import Literal from math import ceil from .basesorting import SpikeVectorSortingSegment @@ -27,12 +27,12 @@ def _ensure_seed(seed): def generate_recording( - num_channels: Optional[int] = 2, - sampling_frequency: Optional[float] = 30000.0, - durations: Optional[List[float]] = [5.0, 2.5], - set_probe: Optional[bool] = True, - ndim: Optional[int] = 2, - seed: Optional[int] = None, + num_channels: int = 2, + sampling_frequency: float = 30000.0, + durations: list[float] = [5.0, 2.5], + set_probe: bool | None = True, + ndim: int | None = 2, + seed: int | None = None, ) -> BaseRecording: """ Generate a lazy recording object. @@ -51,7 +51,7 @@ def generate_recording( If true, attaches probe to the returned `Recording` ndim : int, default: 2 The number of dimensions of the probe, default: 2. Set to 3 to make 3 dimensional probe. - seed : Optional[int] + seed : int | None, default: None A seed for the np.ramdom.default_rng function Returns @@ -106,7 +106,7 @@ def generate_sorting( num_units : int, default: 5 Number of units. sampling_frequency : float, default: 30000.0 - The sampling frequency. + The sampling frequency of the recording in Hz. durations : list, default: [10.325, 3.5] Duration of each segment in s. firing_rates : float, default: 3.0 @@ -189,7 +189,7 @@ def add_synchrony_to_sorting(sorting, sync_event_ratio=0.3, seed=None): ---------- sorting : BaseSorting The sorting object. - sync_event_ratio : float + sync_event_ratio : float, default: 0.3 The ratio of added synchronous spikes with respect to the total number of spikes. E.g., 0.5 means that the final sorting will have 1.5 times number of spikes, and all the extra spikes are synchronous (same sample_index), but on different units (not duplicates). @@ -237,7 +237,7 @@ def add_synchrony_to_sorting(sorting, sync_event_ratio=0.3, seed=None): def generate_sorting_to_inject( sorting: BaseSorting, - num_samples: List[int], + num_samples: list[int], max_injected_per_unit: int = 1000, injected_rate: float = 0.05, refractory_period_ms: float = 1.5, @@ -251,16 +251,16 @@ def generate_sorting_to_inject( ---------- sorting : BaseSorting The sorting object. - num_samples: list of size num_segments. + num_samples : list[int] of size num_segments. The number of samples in all the segments of the sorting, to generate spike times covering entire the entire duration of the segments. - max_injected_per_unit: int, default 1000 + max_injected_per_unit : int, default: 1000 The maximal number of spikes injected per units. - injected_rate: float, default 0.05 + injected_rate : float, default: 0.05 The rate at which spikes are injected. - refractory_period_ms: float, default 1.5 + refractory_period_ms : float, default: 1.5 The refractory period that should not be violated while injecting new spikes. - seed: int, default None + seed : int, default: None The random seed. Returns @@ -314,13 +314,13 @@ class TransformSorting(BaseSorting): ---------- sorting : BaseSorting The sorting object. - added_spikes_existing_units : np.array (spike_vector) + added_spikes_existing_units : np.array (spike_vector) | None, default: None The spikes that should be added to the sorting object, for existing units. - added_spikes_new_units: np.array (spike_vector) + added_spikes_new_units : np.array (spike_vector) | None, default: None The spikes that should be added to the sorting object, for new units. - new_units_ids: list + new_units_ids : list[str, int] | None, default: None The unit_ids that should be added if spikes for new units are added. - refractory_period_ms : float, default None + refractory_period_ms : float | None, default: None The refractory period violation to prevent duplicates and/or unphysiological addition of spikes. Any spike times in added_spikes violating the refractory period will be discarded. @@ -334,10 +334,10 @@ class TransformSorting(BaseSorting): def __init__( self, sorting: BaseSorting, - added_spikes_existing_units=None, - added_spikes_new_units=None, - new_unit_ids: Optional[List[Union[str, int]]] = None, - refractory_period_ms: Optional[float] = None, + added_spikes_existing_units: np.array | None = None, + added_spikes_new_units: np.array | None = None, + new_unit_ids: list[str | int] | None = None, + refractory_period_ms: float | None = None, ): sampling_frequency = sorting.get_sampling_frequency() unit_ids = list(sorting.get_unit_ids()) @@ -429,11 +429,11 @@ def add_from_sorting(sorting1: BaseSorting, sorting2: BaseSorting, refractory_pe Parameters ---------- - sorting1: BaseSorting + sorting1 : BaseSorting The first sorting. - sorting2: BaseSorting + sorting2 : BaseSorting The second sorting. - refractory_period_ms : float, default None + refractory_period_ms : float, default: None The refractory period violation to prevent duplicates and/or unphysiological addition of spikes. Any spike times in added_spikes violating the refractory period will be discarded. @@ -485,7 +485,7 @@ def add_from_sorting(sorting1: BaseSorting, sorting2: BaseSorting, refractory_pe @staticmethod def add_from_unit_dict( - sorting1: BaseSorting, units_dict_list: dict, refractory_period_ms=None + sorting1: BaseSorting, units_dict_list: list[dict] | dict, refractory_period_ms=None ) -> "TransformSorting": """ Construct TransformSorting by adding one sorting with a @@ -495,11 +495,11 @@ def add_from_unit_dict( Parameters ---------- - sorting1: BaseSorting + sorting1 : BaseSorting The first sorting - dict_list: list of dict + dict_list : list[dict] | dict A list of dict with unit_ids as keys and spike times as values. - refractory_period_ms : float, default None + refractory_period_ms : float, default: None The refractory period violation to prevent duplicates and/or unphysiological addition of spikes. Any spike times in added_spikes violating the refractory period will be discarded. @@ -520,16 +520,18 @@ def from_times_labels( Parameters ---------- - sorting1: BaseSorting + sorting1 : BaseSorting The first sorting - times_list: list of array (or array) + times_list : list[np.array] | np.array An array of spike times (in frames). - labels_list: list of array (or array) + labels_list : list[np.array] | np.array An array of spike labels corresponding to the given times. - unit_ids: list or None, default: None + sampling_frequency : float, default: 30000.0 + The sampling frequency of the recording in Hz. + unit_ids : list | None, default: None The explicit list of unit_ids that should be extracted from labels_list If None, then it will be np.unique(labels_list). - refractory_period_ms : float, default None + refractory_period_ms : float, default: None The refractory period violation to prevent duplicates and/or unphysiological addition of spikes. Any spike times in added_spikes violating the refractory period will be discarded. @@ -593,7 +595,7 @@ def generate_snippets( nafter=44, num_channels=2, wf_folder=None, - sampling_frequency=30000.0, # in Hz + sampling_frequency=30000.0, durations=[10.325, 3.5], #  in s for 2 segments set_probe=True, ndim=2, @@ -615,7 +617,7 @@ def generate_snippets( wf_folder : str | Path | None, default: None Optional folder to save the waveform snippets. If None, snippets are in memory. sampling_frequency : float, default: 30000.0 - The sampling frequency of the snippets. + The sampling frequency of the snippets in Hz. ndim : int, default: 2 The number of dimensions of the probe. num_units : int, default: 5 @@ -801,11 +803,11 @@ def synthesize_random_firings( Parameters ---------- - num_units : int + num_units : int, default: 20 Number of units. - sampling_frequency : float - Sampling rate. - duration : float + sampling_frequency : float, default: 30000.0 + Sampling rate in Hz. + duration : float, default: 60 Duration of the segment in seconds. refractory_period_ms : float Refractory period in ms. @@ -907,13 +909,13 @@ def inject_some_duplicate_units(sorting, num=4, max_shift=5, ratio=None, seed=No ---------- sorting : Original sorting. - num : int + num : int, default: 4 Number of injected units. - max_shift : int + max_shift : int, default: 5 range of the shift in sample. - ratio : float + ratio : float | None, default: None Proportion of original spike in the injected units. - seed : None|int, default: None + seed : int | None, default: None Random seed for creating unit peak shifts. Returns @@ -1070,23 +1072,23 @@ class NoiseGeneratorRecording(BaseRecording): The number of channels. sampling_frequency : float The sampling frequency of the recorder. - durations : List[float] + durations : list[float] The durations of each segment in seconds. Note that the length of this list is the number of segments. - noise_levels : float or array, default: 1 + noise_levels : float | np.array, default: 1.0 Std of the white noise (if an array, defined by per channels) - cov_matrix : np.array, default None + cov_matrix : np.array | None, default: None The covariance matrix of the noise - dtype : Optional[Union[np.dtype, str]], default: "float32" + dtype : np.dtype | str | None, default: "float32" The dtype of the recording. Note that only np.float32 and np.float64 are supported. - seed : Optional[int], default: None + seed : int | None, default: None The seed for np.random.default_rng. - strategy : "tile_pregenerated" or "on_the_fly" + strategy : "tile_pregenerated" | "on_the_fly", default: "tile_pregenerated" The strategy of generating noise chunk: * "tile_pregenerated": pregenerate a noise chunk of noise_block_size sample and repeat it very fast and cusume only one noise block. * "on_the_fly": generate on the fly a new noise block by combining seed + noise block index no memory preallocation but a bit more computaion (random) - noise_block_size : int + noise_block_size : int, default: 30000 Size in sample of noise block. Notes @@ -1099,11 +1101,11 @@ def __init__( self, num_channels: int, sampling_frequency: float, - durations: List[float], - noise_levels: float = 1.0, - cov_matrix: Optional[np.array] = None, - dtype: Optional[Union[np.dtype, str]] = "float32", - seed: Optional[int] = None, + durations: list[float], + noise_levels: float | np.array = 1.0, + cov_matrix: np.array | None = None, + dtype: np.dtype | str | None = "float32", + seed: int | None = None, strategy: Literal["tile_pregenerated", "on_the_fly"] = "tile_pregenerated", noise_block_size: int = 30000, ): @@ -1160,7 +1162,6 @@ def __init__( "sampling_frequency": sampling_frequency, "noise_levels": noise_levels, "cov_matrix": cov_matrix, - "noise_levels": noise_levels, "dtype": dtype, "seed": seed, "strategy": strategy, @@ -1215,9 +1216,9 @@ def get_num_samples(self) -> int: def get_traces( self, - start_frame: Union[int, None] = None, - end_frame: Union[int, None] = None, - channel_indices: Union[List, None] = None, + start_frame: int | None = None, + end_frame: int | None = None, + channel_indices: list | None = None, ) -> np.ndarray: start_frame_within_block = start_frame % self.noise_block_size end_frame_within_block = end_frame % self.noise_block_size @@ -1271,8 +1272,7 @@ def get_traces( def generate_recording_by_size( full_traces_size_GiB: float, - num_channels: int = 384, - seed: Optional[int] = None, + seed: int | None = None, strategy: Literal["tile_pregenerated", "on_the_fly"] = "tile_pregenerated", ) -> NoiseGeneratorRecording: """ @@ -1289,15 +1289,14 @@ def generate_recording_by_size( ---------- full_traces_size_GiB : float The size in gigabytes (GiB) of the recording. - num_channels : int - Number of channels. - seed : int, default: None + seed : int | None, default: None The seed for np.random.default_rng. - strategy : "tile_pregenerated"| "on_the_fly", default: "tile_pregenerated" + strategy : "tile_pregenerated" | "on_the_fly", default: "tile_pregenerated" The strategy of generating noise chunk: - * "tile_pregenerated": pregenerate a noise chunk of `noise_block_size` samples and repeat it quickly consuming only one noise block. - * "on_the_fly": generate on the fly a new noise block by combining seed + noise block index. No memory preallocation but a bit more computaion (random) - + * "tile_pregenerated": pregenerate a noise chunk of noise_block_size sample and repeat it + very fast and consume only one noise block. + * "on_the_fly": generate on the fly a new noise block by combining seed + noise block index + no memory preallocation but a bit more computation (random) Returns ------- GeneratorRecording @@ -1543,15 +1542,15 @@ def generate_templates( Cut out in ms before spike peak. ms_after : float Cut out in ms after spike peak. - seed : int or None + seed : int | None A seed for random. dtype : numpy.dtype, default: "float32" Templates dtype - upsample_factor : None or int + upsample_factor : int | None, default: None If not None then template are generated upsampled by this factor. Then a new dimention (axis=3) is added to the template with intermediate inter sample representation. This allow easy random jitter by choising a template this new dim - unit_params : dict of arrays or dict of scalar of dict of tuple + unit_params : dict[np.array] | dict[float] | dict[tuple] | None, default: None An optional dict containing parameters per units. Keys are parameter names: @@ -1568,6 +1567,10 @@ def generate_templates( * array of the same length of units * scalar, then an array is created * tuple, then this difine a range for random values. + mode : "ellipsoid" | "sphere", default: "ellipsoid" + Method used to calculate the distance between unit and channel location. + Ellipsoid injects some anisotropy dependent on unit shape, sphere is equivalent + to Euclidean distance. mode : "sphere" | "ellipsoid", default: "ellipsoid" Mode for how to calculate distances @@ -1694,7 +1697,7 @@ class InjectTemplatesRecording(BaseRecording): ---------- sorting : BaseSorting Sorting object containing all the units and their spike train. - templates : np.ndarray[n_units, n_samples, n_channels] or np.ndarray[n_units, n_samples, n_oversampling] + templates : np.ndarray[n_units, n_samples, n_channels] | np.ndarray[n_units, n_samples, n_oversampling] Array containing the templates to inject for all the units. Shape can be: @@ -1708,13 +1711,13 @@ class InjectTemplatesRecording(BaseRecording): Can be None (no scaling). Can be scalar all spikes have the same factor (certainly useless). Can be a vector with same shape of spike_vector of the sorting. - parent_recording : BaseRecording | None + parent_recording : BaseRecording | None, default: None The recording over which to add the templates. If None, will default to traces containing all 0. - num_samples : list[int] | int | None + num_samples : list[int] | int | None, default: None The number of samples in the recording per segment. You can use int for mono-segment objects. - upsample_vector : np.array or None, default: None. + upsample_vector : np.array | None, default: None. When templates is 4d we can simulate a jitter. Optional the upsample_vector is the jitter index with a number per spike in range 0-templates.shape[3]. check_borders : bool, default: False @@ -1730,11 +1733,11 @@ def __init__( self, sorting: BaseSorting, templates: np.ndarray, - nbefore: Union[List[int], int, None] = None, - amplitude_factor: Union[List[List[float]], List[float], float, None] = None, - parent_recording: Union[BaseRecording, None] = None, - num_samples: Optional[List[int]] = None, - upsample_vector: Union[List[int], None] = None, + nbefore: list[int] | int | None = None, + amplitude_factor: list[float] | float | None = None, + parent_recording: BaseRecording | None = None, + num_samples: list[int] | int | None = None, + upsample_vector: np.array | None = None, check_borders: bool = False, ) -> None: templates = np.asarray(templates) @@ -1866,10 +1869,10 @@ def __init__( spike_vector: np.ndarray, templates: np.ndarray, nbefore: int, - amplitude_vector: Union[List[float], None], - upsample_vector: Union[List[float], None], - parent_recording_segment: Union[BaseRecordingSegment, None] = None, - num_samples: Union[int, None] = None, + amplitude_vector: list[float] | None, + upsample_vector: list[float] | None, + parent_recording_segment: BaseRecordingSegment | None = None, + num_samples: int | None = None, ) -> None: BaseRecordingSegment.__init__( self, @@ -1889,9 +1892,9 @@ def __init__( def get_traces( self, - start_frame: Union[int, None] = None, - end_frame: Union[int, None] = None, - channel_indices: Union[List, None] = None, + start_frame: int | None = None, + end_frame: int | None = None, + channel_indices: list | None = None, ) -> np.ndarray: if channel_indices is None: n_channels = self.templates.shape[2] @@ -2070,13 +2073,13 @@ def generate_ground_truth_recording( Number of channels, not used when probe is given. num_units : int, default: 10 Number of units, not used when sorting is given. - sorting : Sorting or None + sorting : Sorting | None An external sorting object. If not provide, one is genrated. - probe : Probe or None + probe : Probe | None An external Probe object. If not provided a probe is generated using generate_probe_kwargs. generate_probe_kwargs : dict A dict to constuct the Probe using :py:func:`probeinterface.generate_multi_columns_probe()`. - templates : np.array or None + templates : np.array | None The templates of units. If None they are generated. Shape can be: @@ -2087,9 +2090,9 @@ def generate_ground_truth_recording( Cut out in ms before spike peak. ms_after : float, default: 3.0 Cut out in ms after spike peak. - upsample_factor : None or int, default: None + upsample_factor : None | int, default: None A upsampling factor used only when templates are not provided. - upsample_vector : np.array or None + upsample_vector : np.array | None Optional the upsample_vector can given. This has the same shape as spike_vector generate_sorting_kwargs : dict When sorting is not provide, this dict is used to generated a Sorting. @@ -2101,7 +2104,7 @@ def generate_ground_truth_recording( Dict used to generated template when template not provided. dtype : np.dtype, default: "float32" The dtype of the recording. - seed : int or None + seed : int | None Seed for random initialization. If None a diffrent Recording is generated at every call. Note: even with None a generated recording keep internaly a seed to regenerate the same signal after dump/load. diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index ac142405ab..fa4547d272 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -23,7 +23,7 @@ from .base import load_extractor from .recording_tools import check_probe_do_not_overlap, get_rec_attributes, do_recording_attributes_match -from .core_tools import check_json, retrieve_importing_provenance +from .core_tools import check_json, retrieve_importing_provenance, is_path_remote from .sorting_tools import generate_unit_ids_for_merge_group, _get_ids_after_merging from .job_tools import split_job_kwargs from .numpyextractors import NumpySorting @@ -195,6 +195,7 @@ def __init__( format=None, sparsity=None, return_scaled=True, + storage_options=None, ): # very fast init because checks are done in load and create self.sorting = sorting @@ -204,6 +205,7 @@ def __init__( self.format = format self.sparsity = sparsity self.return_scaled = return_scaled + self.storage_options = storage_options # this is used to store temporary recording self._temporary_recording = None @@ -276,17 +278,15 @@ def create( return sorting_analyzer @classmethod - def load(cls, folder, recording=None, load_extensions=True, format="auto"): + def load(cls, folder, recording=None, load_extensions=True, format="auto", storage_options=None): """ Load folder or zarr. The recording can be given if the recording location has changed. Otherwise the recording is loaded when possible. """ - folder = Path(folder) - assert folder.is_dir(), "Waveform folder does not exists" if format == "auto": # make better assumption and check for auto guess format - if folder.suffix == ".zarr": + if Path(folder).suffix == ".zarr": format = "zarr" else: format = "binary_folder" @@ -294,12 +294,18 @@ def load(cls, folder, recording=None, load_extensions=True, format="auto"): if format == "binary_folder": sorting_analyzer = SortingAnalyzer.load_from_binary_folder(folder, recording=recording) elif format == "zarr": - sorting_analyzer = SortingAnalyzer.load_from_zarr(folder, recording=recording) + sorting_analyzer = SortingAnalyzer.load_from_zarr( + folder, recording=recording, storage_options=storage_options + ) - sorting_analyzer.folder = folder + if is_path_remote(str(folder)): + sorting_analyzer.folder = folder + # in this case we only load extensions when needed + else: + sorting_analyzer.folder = Path(folder) - if load_extensions: - sorting_analyzer.load_all_saved_extension() + if load_extensions: + sorting_analyzer.load_all_saved_extension() return sorting_analyzer @@ -470,7 +476,9 @@ def load_from_binary_folder(cls, folder, recording=None): def _get_zarr_root(self, mode="r+"): import zarr - zarr_root = zarr.open(self.folder, mode=mode) + if is_path_remote(str(self.folder)): + mode = "r" + zarr_root = zarr.open(self.folder, mode=mode, storage_options=self.storage_options) return zarr_root @classmethod @@ -552,25 +560,22 @@ def create_zarr(cls, folder, sorting, recording, sparsity, return_scaled, rec_at recording_info = zarr_root.create_group("extensions") @classmethod - def load_from_zarr(cls, folder, recording=None): + def load_from_zarr(cls, folder, recording=None, storage_options=None): import zarr - folder = Path(folder) - assert folder.is_dir(), f"This folder does not exist {folder}" - - zarr_root = zarr.open(folder, mode="r") + zarr_root = zarr.open(str(folder), mode="r", storage_options=storage_options) # load internal sorting in memory - # TODO propagate storage_options sorting = NumpySorting.from_sorting( - ZarrSortingExtractor(folder, zarr_group="sorting"), with_metadata=True, copy_spike_vector=True + ZarrSortingExtractor(folder, zarr_group="sorting", storage_options=storage_options), + with_metadata=True, + copy_spike_vector=True, ) # load recording if possible if recording is None: rec_dict = zarr_root["recording"][0] try: - recording = load_extractor(rec_dict, base_folder=folder) except: recording = None @@ -1209,11 +1214,7 @@ def compute_one_extension(self, extension_name, save=True, verbose=False, **kwar print(f"Deleting {child}") self.delete_extension(child) - if extension_class.need_job_kwargs: - params, job_kwargs = split_job_kwargs(kwargs) - else: - params = kwargs - job_kwargs = {} + params, job_kwargs = split_job_kwargs(kwargs) # check dependencies if extension_class.need_recording: diff --git a/src/spikeinterface/core/zarrextractors.py b/src/spikeinterface/core/zarrextractors.py index 1b9637e097..17f1ac08b3 100644 --- a/src/spikeinterface/core/zarrextractors.py +++ b/src/spikeinterface/core/zarrextractors.py @@ -66,7 +66,7 @@ def __init__(self, folder_path: Path | str, storage_options: dict | None = None) time_kwargs = {} time_vector = self._root.get(f"times_seg{segment_index}", None) if time_vector is not None: - time_kwargs["time_vector"] = time_vector[:] + time_kwargs["time_vector"] = time_vector else: if t_starts is None: t_start = None diff --git a/src/spikeinterface/curation/curation_format.py b/src/spikeinterface/curation/curation_format.py index bb171fec0f..5f85538b08 100644 --- a/src/spikeinterface/curation/curation_format.py +++ b/src/spikeinterface/curation/curation_format.py @@ -92,13 +92,17 @@ def convert_from_sortingview_curation_format_v0(sortingview_dict, destination_fo """ assert destination_format == "1" - + if "mergeGroups" not in sortingview_dict.keys(): + sortingview_dict["mergeGroups"] = [] merge_groups = sortingview_dict["mergeGroups"] merged_units = sum(merge_groups, []) - if len(merged_units) > 0: - unit_id_type = int if isinstance(merged_units[0], int) else str + + first_unit_id = next(iter(sortingview_dict["labelsByUnit"].keys())) + if str.isdigit(first_unit_id): + unit_id_type = int else: unit_id_type = str + all_units = [] all_labels = [] manual_labels = [] diff --git a/src/spikeinterface/curation/tests/sv-sorting-curation-no-merge.json b/src/spikeinterface/curation/tests/sv-sorting-curation-no-merge.json new file mode 100644 index 0000000000..2a350340f3 --- /dev/null +++ b/src/spikeinterface/curation/tests/sv-sorting-curation-no-merge.json @@ -0,0 +1 @@ +{"labelsByUnit":{"2":["mua"],"3":["mua"],"4":["mua"],"5":["accept"],"6":["accept"],"7":["accept"],"8":["artifact"],"9":["artifact"]}} diff --git a/src/spikeinterface/curation/tests/test_sortingview_curation.py b/src/spikeinterface/curation/tests/test_sortingview_curation.py index bb152e7f71..945aca7937 100644 --- a/src/spikeinterface/curation/tests/test_sortingview_curation.py +++ b/src/spikeinterface/curation/tests/test_sortingview_curation.py @@ -243,11 +243,23 @@ def test_label_inheritance_str(): assert np.all(sorting_include_accept.get_property("accept")) +def test_json_no_merge_curation(): + """ + Test curation with no merges using a JSON file. + """ + sorting = generate_sorting(num_units=10) + + json_file = parent_folder / "sv-sorting-curation-no-merge.json" + sorting_curated = apply_sortingview_curation(sorting, uri_or_json=json_file) + + if __name__ == "__main__": # generate_sortingview_curation_dataset() # test_sha1_curation() + test_gh_curation() test_json_curation() test_false_positive_curation() test_label_inheritance_int() test_label_inheritance_str() + test_json_no_merge_curation() diff --git a/src/spikeinterface/extractors/neoextractors/alphaomega.py b/src/spikeinterface/extractors/neoextractors/alphaomega.py index b3f671ebf3..cf47b9819c 100644 --- a/src/spikeinterface/extractors/neoextractors/alphaomega.py +++ b/src/spikeinterface/extractors/neoextractors/alphaomega.py @@ -18,7 +18,7 @@ class AlphaOmegaRecordingExtractor(NeoBaseRecordingExtractor): folder_path : str or Path-like The folder path to the AlphaOmega recordings. lsx_files : list of strings or None, default: None - A list of listings files that refers to mpx files to load. + A list of files that refers to mpx files to load. stream_id : {"RAW", "LFP", "SPK", "ACC", "AI", "UD"}, default: "RAW" If there are several streams, specify the stream id you want to load. stream_name : str, default: None @@ -28,6 +28,12 @@ class AlphaOmegaRecordingExtractor(NeoBaseRecordingExtractor): use_names_as_ids : bool, default: False Determines the format of the channel IDs used by the extractor. If set to True, the channel IDs will be the names from NeoRawIO. If set to False, the channel IDs will be the ids provided by NeoRawIO. + + Examples + -------- + >>> from spikeinterface.extractors import read_alphaomega + >>> recording = read_alphaomega(folder_path="alphaomega_folder") + """ NeoRawIOClass = "AlphaOmegaRawIO" diff --git a/src/spikeinterface/extractors/neoextractors/axona.py b/src/spikeinterface/extractors/neoextractors/axona.py index adfdccddd9..9de39bef2e 100644 --- a/src/spikeinterface/extractors/neoextractors/axona.py +++ b/src/spikeinterface/extractors/neoextractors/axona.py @@ -22,6 +22,11 @@ class AxonaRecordingExtractor(NeoBaseRecordingExtractor): use_names_as_ids : bool, default: False Determines the format of the channel IDs used by the extractor. If set to True, the channel IDs will be the names from NeoRawIO. If set to False, the channel IDs will be the ids provided by NeoRawIO. + + Examples + -------- + >>> from spikeinterface.extractors import read_axona + >>> recording = read_axona(file_path=r'my_data.set') """ NeoRawIOClass = "AxonaRawIO" diff --git a/src/spikeinterface/extractors/neoextractors/ced.py b/src/spikeinterface/extractors/neoextractors/ced.py index a42a2d75a5..992d1a8941 100644 --- a/src/spikeinterface/extractors/neoextractors/ced.py +++ b/src/spikeinterface/extractors/neoextractors/ced.py @@ -28,6 +28,11 @@ class CedRecordingExtractor(NeoBaseRecordingExtractor): use_names_as_ids : bool, default: False Determines the format of the channel IDs used by the extractor. If set to True, the channel IDs will be the names from NeoRawIO. If set to False, the channel IDs will be the ids provided by NeoRawIO. + + Examples + -------- + >>> from spikeinterface.extractors import read_ced + >>> recording = read_ced(file_path=r'my_data.smr') """ NeoRawIOClass = "CedRawIO" diff --git a/src/spikeinterface/extractors/neoextractors/intan.py b/src/spikeinterface/extractors/neoextractors/intan.py index f0a1894f25..261472ede9 100644 --- a/src/spikeinterface/extractors/neoextractors/intan.py +++ b/src/spikeinterface/extractors/neoextractors/intan.py @@ -34,7 +34,13 @@ class IntanRecordingExtractor(NeoBaseRecordingExtractor): In Intan the ids provided by NeoRawIO are the hardware channel ids while the names are custom names given by the user - + Examples + -------- + >>> from spikeinterface.extractors import read_intan + # intan amplifier data is stored in stream_id = '0' + >>> recording = read_intan(file_path=r'my_data.rhd', stream_id='0') + # intan has multi-file formats as well, but in this case our path should point to the header file 'info.rhd' + >>> recording = read_intan(file_path=r'info.rhd', stream_id='0') """ NeoRawIOClass = "IntanRawIO" diff --git a/src/spikeinterface/extractors/neoextractors/plexon.py b/src/spikeinterface/extractors/neoextractors/plexon.py index eed3188d16..412027bc06 100644 --- a/src/spikeinterface/extractors/neoextractors/plexon.py +++ b/src/spikeinterface/extractors/neoextractors/plexon.py @@ -30,6 +30,11 @@ class PlexonRecordingExtractor(NeoBaseRecordingExtractor): Example for wideband signals: names: ["WB01", "WB02", "WB03", "WB04"] ids: ["0" , "1", "2", "3"] + + Examples + -------- + >>> from spikeinterface.extractors import read_plexon + >>> recording = read_plexon(file_path=r'my_data.plx') """ NeoRawIOClass = "PlexonRawIO" diff --git a/src/spikeinterface/extractors/neoextractors/plexon2.py b/src/spikeinterface/extractors/neoextractors/plexon2.py index 4434d02cc1..2f360ed864 100644 --- a/src/spikeinterface/extractors/neoextractors/plexon2.py +++ b/src/spikeinterface/extractors/neoextractors/plexon2.py @@ -28,6 +28,11 @@ class Plexon2RecordingExtractor(NeoBaseRecordingExtractor): ids: ["source3.1" , "source3.2", "source3.3", "source3.4"] all_annotations : bool, default: False Load exhaustively all annotations from neo. + + Examples + -------- + >>> from spikeinterface.extractors import read_plexon2 + >>> recording = read_plexon2(file_path=r'my_data.pl2') """ NeoRawIOClass = "Plexon2RawIO" diff --git a/src/spikeinterface/extractors/neoextractors/spikegadgets.py b/src/spikeinterface/extractors/neoextractors/spikegadgets.py index 89c457a573..e91a81398b 100644 --- a/src/spikeinterface/extractors/neoextractors/spikegadgets.py +++ b/src/spikeinterface/extractors/neoextractors/spikegadgets.py @@ -29,6 +29,11 @@ class SpikeGadgetsRecordingExtractor(NeoBaseRecordingExtractor): use_names_as_ids : bool, default: False Determines the format of the channel IDs used by the extractor. If set to True, the channel IDs will be the names from NeoRawIO. If set to False, the channel IDs will be the ids provided by NeoRawIO. + + Examples + -------- + >>> from spikeinterface.extractors import read_spikegadgets + >>> recording = read_spikegadgets(file_path=r'my_data.rec') """ NeoRawIOClass = "SpikeGadgetsRawIO" diff --git a/src/spikeinterface/extractors/neoextractors/spikeglx.py b/src/spikeinterface/extractors/neoextractors/spikeglx.py index cfe20bbfa6..874a65c045 100644 --- a/src/spikeinterface/extractors/neoextractors/spikeglx.py +++ b/src/spikeinterface/extractors/neoextractors/spikeglx.py @@ -41,6 +41,13 @@ class SpikeGLXRecordingExtractor(NeoBaseRecordingExtractor): use_names_as_ids : bool, default: False Determines the format of the channel IDs used by the extractor. If set to True, the channel IDs will be the names from NeoRawIO. If set to False, the channel IDs will be the ids provided by NeoRawIO. + + Examples + -------- + >>> from spikeinterface.extractors import read_spikeglx + >>> recording = read_spikeglx(folder_path=r'path_to_folder_with_data', load_sync_channel=False) + # we can load the sync channel, but then the probe is not loaded + >>> recording = read_spikeglx(folder_path=r'pat_to_folder_with_data', load_sync_channel=True) """ NeoRawIOClass = "SpikeGLXRawIO" diff --git a/src/spikeinterface/generation/drifting_generator.py b/src/spikeinterface/generation/drifting_generator.py index 69f1fb6375..6ff8adadd2 100644 --- a/src/spikeinterface/generation/drifting_generator.py +++ b/src/spikeinterface/generation/drifting_generator.py @@ -350,9 +350,6 @@ def generate_drifting_recording( This can be helpfull for motion benchmark. """ - - rng = np.random.default_rng(seed=seed) - # probe if generate_probe_kwargs is None: generate_probe_kwargs = _toy_probes[probe_name] diff --git a/src/spikeinterface/preprocessing/filter.py b/src/spikeinterface/preprocessing/filter.py index 54c5ab2b2d..a67d163d3d 100644 --- a/src/spikeinterface/preprocessing/filter.py +++ b/src/spikeinterface/preprocessing/filter.py @@ -24,10 +24,12 @@ class FilterRecording(BasePreprocessor): """ - Generic filter class based on: - - * scipy.signal.iirfilter - * scipy.signal.filtfilt or scipy.signal.sosfilt + A generic filter class based on: + For filter coefficient generation: + * scipy.signal.iirfilter + For filter application: + * scipy.signal.filtfilt or scipy.signal.sosfiltfilt when direction = "forward-backward" + * scipy.signal.lfilter or scipy.signal.sosfilt when direction = "forward" or "backward" BandpassFilterRecording is built on top of it. @@ -56,6 +58,11 @@ class FilterRecording(BasePreprocessor): - numerator/denominator : ("ba") ftype : str, default: "butter" Filter type for `scipy.signal.iirfilter` e.g. "butter", "cheby1". + direction : "forward" | "backward" | "forward-backward", default: "forward-backward" + Direction of filtering: + - "forward" - filter is applied to the timeseries in one direction, creating phase shifts + - "backward" - the timeseries is reversed, the filter is applied and filtered timeseries reversed again. Creates phase shifts in the opposite direction to "forward" + - "forward-backward" - Applies the filter in the forward and backward direction, resulting in zero-phase filtering. Note this doubles the effective filter order. Returns ------- @@ -75,6 +82,7 @@ def __init__( add_reflect_padding=False, coeff=None, dtype=None, + direction="forward-backward", ): import scipy.signal @@ -106,7 +114,13 @@ def __init__( for parent_segment in recording._recording_segments: self.add_recording_segment( FilterRecordingSegment( - parent_segment, filter_coeff, filter_mode, margin, dtype, add_reflect_padding=add_reflect_padding + parent_segment, + filter_coeff, + filter_mode, + margin, + dtype, + add_reflect_padding=add_reflect_padding, + direction=direction, ) ) @@ -121,14 +135,25 @@ def __init__( margin_ms=margin_ms, add_reflect_padding=add_reflect_padding, dtype=dtype.str, + direction=direction, ) class FilterRecordingSegment(BasePreprocessorSegment): - def __init__(self, parent_recording_segment, coeff, filter_mode, margin, dtype, add_reflect_padding=False): + def __init__( + self, + parent_recording_segment, + coeff, + filter_mode, + margin, + dtype, + add_reflect_padding=False, + direction="forward-backward", + ): BasePreprocessorSegment.__init__(self, parent_recording_segment) self.coeff = coeff self.filter_mode = filter_mode + self.direction = direction self.margin = margin self.add_reflect_padding = add_reflect_padding self.dtype = dtype @@ -150,11 +175,24 @@ def get_traces(self, start_frame, end_frame, channel_indices): import scipy.signal - if self.filter_mode == "sos": - filtered_traces = scipy.signal.sosfiltfilt(self.coeff, traces_chunk, axis=0) - elif self.filter_mode == "ba": - b, a = self.coeff - filtered_traces = scipy.signal.filtfilt(b, a, traces_chunk, axis=0) + if self.direction == "forward-backward": + if self.filter_mode == "sos": + filtered_traces = scipy.signal.sosfiltfilt(self.coeff, traces_chunk, axis=0) + elif self.filter_mode == "ba": + b, a = self.coeff + filtered_traces = scipy.signal.filtfilt(b, a, traces_chunk, axis=0) + else: + if self.direction == "backward": + traces_chunk = np.flip(traces_chunk, axis=0) + + if self.filter_mode == "sos": + filtered_traces = scipy.signal.sosfilt(self.coeff, traces_chunk, axis=0) + elif self.filter_mode == "ba": + b, a = self.coeff + filtered_traces = scipy.signal.lfilter(b, a, traces_chunk, axis=0) + + if self.direction == "backward": + filtered_traces = np.flip(filtered_traces, axis=0) if right_margin > 0: filtered_traces = filtered_traces[left_margin:-right_margin, :] @@ -289,6 +327,73 @@ def __init__(self, recording, freq=3000, q=30, margin_ms=5.0, dtype=None): notch_filter = define_function_from_class(source_class=NotchFilterRecording, name="notch_filter") highpass_filter = define_function_from_class(source_class=HighpassFilterRecording, name="highpass_filter") + +def causal_filter( + recording, + direction="forward", + band=[300.0, 6000.0], + btype="bandpass", + filter_order=5, + ftype="butter", + filter_mode="sos", + margin_ms=5.0, + add_reflect_padding=False, + coeff=None, + dtype=None, +): + """ + Generic causal filter built on top of the filter function. + + Parameters + ---------- + recording : Recording + The recording extractor to be re-referenced + direction : "forward" | "backward", default: "forward" + Direction of causal filter. The "backward" option flips the traces in time before applying the filter + and then flips them back. + band : float or list, default: [300.0, 6000.0] + If float, cutoff frequency in Hz for "highpass" filter type + If list. band (low, high) in Hz for "bandpass" filter type + btype : "bandpass" | "highpass", default: "bandpass" + Type of the filter + margin_ms : float, default: 5.0 + Margin in ms on border to avoid border effect + coeff : array | None, default: None + Filter coefficients in the filter_mode form. + dtype : dtype or None, default: None + The dtype of the returned traces. If None, the dtype of the parent recording is used + add_reflect_padding : Bool, default False + If True, uses a left and right margin during calculation. + filter_order : order + The order of the filter for `scipy.signal.iirfilter` + filter_mode : "sos" | "ba", default: "sos" + Filter form of the filter coefficients for `scipy.signal.iirfilter`: + - second-order sections ("sos") + - numerator/denominator : ("ba") + ftype : str, default: "butter" + Filter type for `scipy.signal.iirfilter` e.g. "butter", "cheby1". + + Returns + ------- + filter_recording : FilterRecording + The causal-filtered recording extractor object + """ + assert direction in ["forward", "backward"], "Direction must be either 'forward' or 'backward'" + return filter( + recording=recording, + direction=direction, + band=band, + btype=btype, + filter_order=filter_order, + ftype=ftype, + filter_mode=filter_mode, + margin_ms=margin_ms, + add_reflect_padding=add_reflect_padding, + coeff=coeff, + dtype=dtype, + ) + + bandpass_filter.__doc__ = bandpass_filter.__doc__.format(_common_filter_docs) highpass_filter.__doc__ = highpass_filter.__doc__.format(_common_filter_docs) diff --git a/src/spikeinterface/preprocessing/preprocessinglist.py b/src/spikeinterface/preprocessing/preprocessinglist.py index 149c6eb458..bdf5f2219c 100644 --- a/src/spikeinterface/preprocessing/preprocessinglist.py +++ b/src/spikeinterface/preprocessing/preprocessinglist.py @@ -12,6 +12,7 @@ notch_filter, HighpassFilterRecording, highpass_filter, + causal_filter, ) from .filter_gaussian import GaussianFilterRecording, gaussian_filter from .normalize_scale import ( diff --git a/src/spikeinterface/preprocessing/tests/test_filter.py b/src/spikeinterface/preprocessing/tests/test_filter.py index 68790b3273..9df60af3db 100644 --- a/src/spikeinterface/preprocessing/tests/test_filter.py +++ b/src/spikeinterface/preprocessing/tests/test_filter.py @@ -4,7 +4,140 @@ from spikeinterface.core import generate_recording from spikeinterface import NumpyRecording, set_global_tmp_folder -from spikeinterface.preprocessing import filter, bandpass_filter, notch_filter +from spikeinterface.preprocessing import filter, bandpass_filter, notch_filter, causal_filter + + +class TestCausalFilter: + """ + The only thing that is not tested (JZ, as of 23/07/2024) is the + propagation of margin kwargs, these are general filter params + and can be tested in an upcoming PR. + """ + + @pytest.fixture(scope="session") + def recording_and_data(self): + recording = generate_recording(durations=[1]) + raw_data = recording.get_traces() + + return (recording, raw_data) + + def test_causal_filter_main_kwargs(self, recording_and_data): + """ + Perform a test that expected output is returned under change + of all key filter-related kwargs. First run the filter in + the forward direction with options and compare it + to the expected output from scipy. + + Next, change every filter-related kwarg and set in the backwards + direction. Again check it matches expected scipy output. + """ + from scipy.signal import lfilter, sosfilt + + recording, raw_data = recording_and_data + + # First, check in the forward direction with + # the default set of kwargs + options = self._get_filter_options() + + sos = self._run_iirfilter(options, recording) + + test_data = sosfilt(sos, raw_data, axis=0) + test_data.astype(recording.dtype) + + filt_data = causal_filter(recording, direction="forward", **options, margin_ms=0).get_traces() + + assert np.allclose(test_data, filt_data, rtol=0, atol=1e-6) + + # Then, change all kwargs to ensure they are propagated + # and check the backwards version. + options["band"] = [671] + options["btype"] = "highpass" + options["filter_order"] = 8 + options["ftype"] = "bessel" + options["filter_mode"] = "ba" + options["dtype"] = np.float16 + + b, a = self._run_iirfilter(options, recording) + + flip_raw = np.flip(raw_data, axis=0) + test_data = lfilter(b, a, flip_raw, axis=0) + test_data = np.flip(test_data, axis=0) + test_data = test_data.astype(options["dtype"]) + + filt_data = causal_filter(recording, direction="backward", **options, margin_ms=0).get_traces() + + assert np.allclose(test_data, filt_data, rtol=0, atol=1e-6) + + def test_causal_filter_custom_coeff(self, recording_and_data): + """ + A different path is taken when custom coeff is selected. + Therefore, explicitly test the expected outputs are obtained + when passing custom coeff, under the "ba" and "sos" conditions. + """ + from scipy.signal import lfilter, sosfilt + + recording, raw_data = recording_and_data + + options = self._get_filter_options() + options["filter_mode"] = "ba" + options["coeff"] = (np.array([0.1, 0.2, 0.3]), np.array([0.4, 0.5, 0.6])) + + # Check the custom coeff are propagated in both modes. + # First, in "ba" mode + test_data = lfilter(options["coeff"][0], options["coeff"][1], raw_data, axis=0) + test_data = test_data.astype(recording.get_dtype()) + + filt_data = causal_filter(recording, direction="forward", **options, margin_ms=0).get_traces() + + assert np.allclose(test_data, filt_data, rtol=0, atol=1e-6, equal_nan=True) + + # Next, in "sos" mode + options["filter_mode"] = "sos" + options["coeff"] = np.ones((2, 6)) + + test_data = sosfilt(options["coeff"], raw_data, axis=0) + test_data = test_data.astype(recording.get_dtype()) + + filt_data = causal_filter(recording, direction="forward", **options, margin_ms=0).get_traces() + + assert np.allclose(test_data, filt_data, rtol=0, atol=1e-6, equal_nan=True) + + def test_causal_kwarg_error_raised(self, recording_and_data): + """ + Test that passing the "forward-backward" direction results in + an error. It is is critical this error is raised, + otherwise the filter will no longer be causal. + """ + recording, raw_data = recording_and_data + + with pytest.raises(BaseException) as e: + filt_data = causal_filter(recording, direction="forward-backward") + + def _run_iirfilter(self, options, recording): + """ + Convenience function to convert Si kwarg + names to Scipy. + """ + from scipy.signal import iirfilter + + return iirfilter( + N=options["filter_order"], + Wn=options["band"], + btype=options["btype"], + ftype=options["ftype"], + output=options["filter_mode"], + fs=recording.get_sampling_frequency(), + ) + + def _get_filter_options(self): + return { + "band": [300.0, 6000.0], + "btype": "bandpass", + "filter_order": 5, + "ftype": "butter", + "filter_mode": "sos", + "coeff": None, + } def test_filter(): @@ -28,6 +161,8 @@ def test_filter(): # other filtering types rec3 = filter(rec, band=500.0, btype="highpass", filter_mode="ba", filter_order=2) rec4 = notch_filter(rec, freq=3000, q=30, margin_ms=5.0) + rec5 = causal_filter(rec, direction="forward") + rec6 = causal_filter(rec, direction="backward") # filter from coefficients from scipy.signal import iirfilter diff --git a/src/spikeinterface/sorters/container_tools.py b/src/spikeinterface/sorters/container_tools.py index 6406919455..f9611586c9 100644 --- a/src/spikeinterface/sorters/container_tools.py +++ b/src/spikeinterface/sorters/container_tools.py @@ -99,7 +99,7 @@ def __init__(self, mode, container_image, volumes, py_user_base, extra_kwargs): singularity_image = sif_file else: - docker_image = self._get_docker_image(container_image) + docker_image = Client.load("docker://" + container_image) if docker_image and len(docker_image.tags) > 0: tag = docker_image.tags[0] print(f"Building singularity image from local docker image: {tag}") diff --git a/src/spikeinterface/sorters/external/kilosort4.py b/src/spikeinterface/sorters/external/kilosort4.py index a7f40a9558..8499cef11f 100644 --- a/src/spikeinterface/sorters/external/kilosort4.py +++ b/src/spikeinterface/sorters/external/kilosort4.py @@ -6,6 +6,7 @@ from ..basesorter import BaseSorter from .kilosortbase import KilosortBase +from importlib.metadata import version as importlib_version PathType = Union[str, Path] @@ -56,6 +57,7 @@ class Kilosort4Sorter(BaseSorter): "save_extra_kwargs": False, "skip_kilosort_preprocessing": False, "scaleproc": None, + "save_preprocessed_copy": False, "torch_device": "auto", } @@ -93,11 +95,11 @@ class Kilosort4Sorter(BaseSorter): "cluster_pcs": "Maximum number of spatiotemporal PC features used for clustering. Default value: 64.", "x_centers": "Number of x-positions to use when determining center points for template groupings. If None, this will be determined automatically by finding peaks in channel density. For 2D array type probes, we recommend specifying this so that centers are placed every few hundred microns.", "duplicate_spike_bins": "Number of bins for which subsequent spikes from the same cluster are assumed to be artifacts. A value of 0 disables this step. Default value: 7.", - "keep_good_only": "If True only 'good' units are returned", "do_correction": "If True, drift correction is performed", "save_extra_kwargs": "If True, additional kwargs are saved to the output", "skip_kilosort_preprocessing": "Can optionally skip the internal kilosort preprocessing", "scaleproc": "int16 scaling of whitened data, if None set to 200.", + "save_preprocessed_copy": "save a pre-processed copy of the data (including drift correction) to temp_wh.dat in the results directory and format Phy output to use that copy of the data", "torch_device": "Select the torch device auto/cuda/cpu", } @@ -129,9 +131,8 @@ def is_installed(cls): @classmethod def get_sorter_version(cls): - import kilosort as ks - - return ks.__version__ + """kilosort.__version__ <4.0.10 is always '4'""" + return importlib_version("kilosort") @classmethod def _setup_recording(cls, recording, sorter_output_folder, params, verbose): @@ -153,7 +154,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): save_sorting, get_run_parameters, ) - from kilosort.io import load_probe, RecordingExtractorAsArray, BinaryFiltered + from kilosort.io import load_probe, RecordingExtractorAsArray, BinaryFiltered, save_preprocessing from kilosort.parameters import DEFAULT_SETTINGS import time @@ -165,6 +166,13 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): logging.basicConfig(level=logging.INFO) + if version.parse(cls.get_sorter_version()) < version.parse("4.0.5"): + raise RuntimeError( + "Kilosort versions before 4.0.5 are not supported" + "in SpikeInterface. " + "Please upgrade Kilosort version." + ) + sorter_output_folder = sorter_output_folder.absolute() probe_filename = sorter_output_folder / "probe.prb" @@ -176,16 +184,17 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): # load probe recording = cls.load_recording_from_folder(sorter_output_folder.parent, with_warnings=False) - probe = load_probe(probe_filename) + probe = load_probe(probe_path=probe_filename) probe_name = "" filename = "" # this internally concatenates the recording - file_object = RecordingExtractorAsArray(recording) + file_object = RecordingExtractorAsArray(recording_extractor=recording) do_CAR = params["do_CAR"] invert_sign = params["invert_sign"] save_extra_vars = params["save_extra_kwargs"] + save_preprocessed_copy = params["save_preprocessed_copy"] progress_bar = None settings_ks = {k: v for k, v in params.items() if k in DEFAULT_SETTINGS} settings_ks["n_chan_bin"] = recording.get_num_channels() @@ -205,31 +214,58 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): # NOTE: Also modifies settings in-place data_dir = "" results_dir = sorter_output_folder - filename, data_dir, results_dir, probe = set_files(settings, filename, probe, probe_name, data_dir, results_dir) + + filename, data_dir, results_dir, probe = set_files( + settings=settings, + filename=filename, + probe=probe, + probe_name=probe_name, + data_dir=data_dir, + results_dir=results_dir, + ) + if version.parse(cls.get_sorter_version()) >= version.parse("4.0.12"): - ops = initialize_ops(settings, probe, recording.get_dtype(), do_CAR, invert_sign, device, False) + ops = initialize_ops( + settings=settings, + probe=probe, + data_dtype=recording.get_dtype(), + do_CAR=do_CAR, + invert_sign=invert_sign, + device=device, + save_preprocesed_copy=save_preprocessed_copy, # this kwarg is correct (typo) + ) + else: + ops = initialize_ops( + settings=settings, + probe=probe, + data_dtype=recording.get_dtype(), + do_CAR=do_CAR, + invert_sign=invert_sign, + device=device, + ) + + if version.parse(cls.get_sorter_version()) >= version.parse("4.0.11"): n_chan_bin, fs, NT, nt, twav_min, chan_map, dtype, do_CAR, invert, _, _, tmin, tmax, artifact, _, _ = ( get_run_parameters(ops) ) else: - ops = initialize_ops(settings, probe, recording.get_dtype(), do_CAR, invert_sign, device) n_chan_bin, fs, NT, nt, twav_min, chan_map, dtype, do_CAR, invert, _, _, tmin, tmax, artifact = ( get_run_parameters(ops) ) # Set preprocessing and drift correction parameters if not params["skip_kilosort_preprocessing"]: - ops = compute_preprocessing(ops, device, tic0=tic0, file_object=file_object) + ops = compute_preprocessing(ops=ops, device=device, tic0=tic0, file_object=file_object) else: print("Skipping kilosort preprocessing.") bfile = BinaryFiltered( - ops["filename"], - n_chan_bin, - fs, - NT, - nt, - twav_min, - chan_map, + filename=ops["filename"], + n_chan_bin=n_chan_bin, + fs=fs, + NT=NT, + nt=nt, + nt0min=twav_min, + chan_map=chan_map, hp_filter=None, device=device, do_CAR=do_CAR, @@ -243,29 +279,51 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): ops["preprocessing"] = dict(hp_filter=None, whiten_mat=None) ops["Wrot"] = torch.as_tensor(np.eye(recording.get_num_channels())) ops["Nbatches"] = bfile.n_batches + # bfile.close() # TODO: KS do this after preprocessing? np.random.seed(1) torch.cuda.manual_seed_all(1) torch.random.manual_seed(1) - # if not params["skip_kilosort_preprocessing"]: + if not params["do_correction"]: print("Skipping drift correction.") ops["nblocks"] = 0 # this function applies both preprocessing and drift correction ops, bfile, st0 = compute_drift_correction( - ops, device, tic0=tic0, progress_bar=progress_bar, file_object=file_object + ops=ops, device=device, tic0=tic0, progress_bar=progress_bar, file_object=file_object ) + if save_preprocessed_copy: + save_preprocessing(results_dir / "temp_wh.dat", ops, bfile) + # Sort spikes and save results - st, tF, _, _ = detect_spikes(ops, device, bfile, tic0=tic0, progress_bar=progress_bar) - clu, Wall = cluster_spikes(st, tF, ops, device, bfile, tic0=tic0, progress_bar=progress_bar) + st, tF, _, _ = detect_spikes(ops=ops, device=device, bfile=bfile, tic0=tic0, progress_bar=progress_bar) + + clu, Wall = cluster_spikes( + st=st, tF=tF, ops=ops, device=device, bfile=bfile, tic0=tic0, progress_bar=progress_bar + ) + if params["skip_kilosort_preprocessing"]: ops["preprocessing"] = dict( hp_filter=torch.as_tensor(np.zeros(1)), whiten_mat=torch.as_tensor(np.eye(recording.get_num_channels())) ) - _ = save_sorting(ops, results_dir, st, clu, tF, Wall, bfile.imin, tic0, save_extra_vars=save_extra_vars) + if version.parse(cls.get_sorter_version()) >= version.parse("4.0.12"): + _ = save_sorting( + ops=ops, + results_dir=results_dir, + st=st, + clu=clu, + tF=tF, + Wall=Wall, + imin=bfile.imin, + tic0=tic0, + save_extra_vars=save_extra_vars, + save_preprocessed_copy=save_preprocessed_copy, + ) + else: + _ = save_sorting(ops, results_dir, st, clu, tF, Wall, bfile.imin, tic0, save_extra_vars=save_extra_vars) @classmethod def _get_result_from_folder(cls, sorter_output_folder):