Skip to content

Commit

Permalink
Add highpass_cutoff and fix KS tests
Browse files Browse the repository at this point in the history
  • Loading branch information
alejoe91 committed Sep 3, 2024
1 parent c23d530 commit f9dfa04
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 12 deletions.
29 changes: 18 additions & 11 deletions .github/scripts/test_kilosort4_ci.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,19 +21,21 @@

import copy
from typing import Any
import spikeinterface.full as si
import numpy as np
import torch
import kilosort
from kilosort.io import load_probe
import pandas as pd
from spikeinterface.sorters.external.kilosort4 import Kilosort4Sorter
import pytest
from probeinterface.io import write_prb
from kilosort.parameters import DEFAULT_SETTINGS
from packaging.version import parse
from importlib.metadata import version
from inspect import signature

import spikeinterface.full as si
from spikeinterface.sorters.external.kilosort4 import Kilosort4Sorter
from probeinterface.io import write_prb

from kilosort.parameters import DEFAULT_SETTINGS
from kilosort.run_kilosort import (
set_files,
initialize_ops,
Expand Down Expand Up @@ -66,6 +68,7 @@
("nt", 93),
("nskip", 1),
("whitening_range", 16),
("highpass_cutoff", 200),
("sig_interp", 5),
("nt0min", 25),
("dmin", 15),
Expand All @@ -87,10 +90,11 @@
("ccg_threshold", 1e12),
("acg_threshold", 1e12),
("cluster_downsampling", 2),
("duplicate_spike_bins", 5),
("duplicate_spike_ms", 0.3),
("drift_smoothing", [250, 250, 250]),
("bad_channels", None),
("save_preprocessed_copy", False),
("shift", 0),
("scale", 1),
]


Expand Down Expand Up @@ -194,7 +198,10 @@ def test_params_to_test(self):
continue

if param_key not in RUN_KILOSORT_ARGS:
assert DEFAULT_SETTINGS[param_key] != param_value, f"{param_key} values should be different in test."
assert DEFAULT_SETTINGS[param_key] != param_value, (
f"{param_key} values should be different in test: "
f"{param_value} vs. {DEFAULT_SETTINGS[param_key]}"
)

def test_default_settings_all_represented(self):
"""
Expand Down Expand Up @@ -227,7 +234,7 @@ def test_spikeinterface_defaults_against_kilsort(self):

# Testing Arguments ###
def test_set_files_arguments(self):
self._check_arguments(set_files, ["settings", "filename", "probe", "probe_name", "data_dir", "results_dir"])
self._check_arguments(set_files, ["settings", "filename", "probe", "probe_name", "data_dir", "results_dir", "bad_channels"])

def test_initialize_ops_arguments(self):
expected_arguments = [
Expand All @@ -249,13 +256,13 @@ def test_compute_preprocessing_arguments(self):
self._check_arguments(compute_preprocessing, ["ops", "device", "tic0", "file_object"])

def test_compute_drift_location_arguments(self):
self._check_arguments(compute_drift_correction, ["ops", "device", "tic0", "progress_bar", "file_object"])
self._check_arguments(compute_drift_correction, ["ops", "device", "tic0", "progress_bar", "file_object", "clear_cache"])

def test_detect_spikes_arguments(self):
self._check_arguments(detect_spikes, ["ops", "device", "bfile", "tic0", "progress_bar"])
self._check_arguments(detect_spikes, ["ops", "device", "bfile", "tic0", "progress_bar", "clear_cache"])

def test_cluster_spikes_arguments(self):
self._check_arguments(cluster_spikes, ["st", "tF", "ops", "device", "bfile", "tic0", "progress_bar"])
self._check_arguments(cluster_spikes, ["st", "tF", "ops", "device", "bfile", "tic0", "progress_bar", "clear_cache"])

def test_save_sorting_arguments(self):
expected_arguments = ["ops", "results_dir", "st", "clu", "tF", "Wall", "imin", "tic0", "save_extra_vars"]
Expand Down
4 changes: 3 additions & 1 deletion src/spikeinterface/sorters/external/kilosort4.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ class Kilosort4Sorter(BaseSorter):
"artifact_threshold": None,
"nskip": 25,
"whitening_range": 32,
"highpass_cutoff": 300,
"binning_depth": 5,
"sig_interp": 20,
"drift_smoothing": [0.5, 0.5, 0.5],
Expand All @@ -55,7 +56,7 @@ class Kilosort4Sorter(BaseSorter):
"cluster_downsampling": 20,
"cluster_pcs": 64,
"x_centers": None,
"duplicate_spike_bins": 7,
"duplicate_spike_ms": 0.25,
"do_correction": True,
"keep_good_only": False,
"save_extra_kwargs": False,
Expand All @@ -80,6 +81,7 @@ class Kilosort4Sorter(BaseSorter):
"artifact_threshold": "If a batch contains absolute values above this number, it will be zeroed out under the assumption that a recording artifact is present. By default, the threshold is infinite (so that no zeroing occurs). Default value: None.",
"nskip": "Batch stride for computing whitening matrix. Default value: 25.",
"whitening_range": "Number of nearby channels used to estimate the whitening matrix. Default value: 32.",
"highpass_cutoff": "High-pass filter cutoff frequency in Hz. Default value: 300.",
"binning_depth": "For drift correction, vertical bin size in microns used for 2D histogram. Default value: 5.",
"sig_interp": "For drift correction, sigma for interpolation (spatial standard deviation). Approximate smoothness scale in units of microns. Default value: 20.",
"drift_smoothing": "Amount of gaussian smoothing to apply to the spatiotemporal drift estimation, for x,y,time axes in units of registration blocks (for x,y axes) and batch size (for time axis). The x,y smoothing has no effect for `nblocks = 1`.",
Expand Down

0 comments on commit f9dfa04

Please sign in to comment.