Skip to content

Commit

Permalink
Merge pull request #3055 from alejoe91/update-ks4
Browse files Browse the repository at this point in the history
Add support for kilosort>=4.0.12
  • Loading branch information
alejoe91 authored Jun 21, 2024
2 parents 328bba0 + 3a65457 commit 5e32a9f
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 4 deletions.
6 changes: 6 additions & 0 deletions src/spikeinterface/core/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -1145,6 +1145,8 @@ def get_traces(
) -> np.ndarray:
start_frame = 0 if start_frame is None else max(start_frame, 0)
end_frame = self.num_samples if end_frame is None else min(end_frame, self.num_samples)
start_frame = int(start_frame)
end_frame = int(end_frame)

start_frame_within_block = start_frame % self.noise_block_size
end_frame_within_block = end_frame % self.noise_block_size
Expand Down Expand Up @@ -1812,6 +1814,8 @@ def get_traces(
) -> np.ndarray:
start_frame = 0 if start_frame is None else start_frame
end_frame = self.num_samples if end_frame is None else end_frame
start_frame = int(start_frame)
end_frame = int(end_frame)

if channel_indices is None:
n_channels = self.templates.shape[2]
Expand Down Expand Up @@ -1848,6 +1852,8 @@ def get_traces(
end_traces = start_traces + template.shape[0]
if start_traces >= end_frame - start_frame or end_traces <= 0:
continue
start_traces = int(start_traces)
end_traces = int(end_traces)

start_template = 0
end_template = template.shape[0]
Expand Down
26 changes: 22 additions & 4 deletions src/spikeinterface/sorters/external/kilosort4.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from pathlib import Path
from typing import Union
from packaging import version

from ..basesorter import BaseSorter
from .kilosortbase import KilosortBase
Expand All @@ -24,11 +25,14 @@ class Kilosort4Sorter(BaseSorter):
"do_CAR": True,
"invert_sign": False,
"nt": 61,
"shift": None,
"scale": None,
"artifact_threshold": None,
"nskip": 25,
"whitening_range": 32,
"binning_depth": 5,
"sig_interp": 20,
"drift_smoothing": [0.5, 0.5, 0.5],
"nt0min": None,
"dmin": None,
"dminx": 32,
Expand Down Expand Up @@ -63,11 +67,14 @@ class Kilosort4Sorter(BaseSorter):
"do_CAR": "Whether to perform common average reference. Default value: True.",
"invert_sign": "Invert the sign of the data. Default value: False.",
"nt": "Number of samples per waveform. Also size of symmetric padding for filtering. Default value: 61.",
"shift": "Scalar shift to apply to data before all other operations. Default None.",
"scale": "Scaling factor to apply to data before all other operations. Default None.",
"artifact_threshold": "If a batch contains absolute values above this number, it will be zeroed out under the assumption that a recording artifact is present. By default, the threshold is infinite (so that no zeroing occurs). Default value: None.",
"nskip": "Batch stride for computing whitening matrix. Default value: 25.",
"whitening_range": "Number of nearby channels used to estimate the whitening matrix. Default value: 32.",
"binning_depth": "For drift correction, vertical bin size in microns used for 2D histogram. Default value: 5.",
"sig_interp": "For drift correction, sigma for interpolation (spatial standard deviation). Approximate smoothness scale in units of microns. Default value: 20.",
"drift_smoothing": "Amount of gaussian smoothing to apply to the spatiotemporal drift estimation, for x,y,time axes in units of registration blocks (for x,y axes) and batch size (for time axis). The x,y smoothing has no effect for `nblocks = 1`.",
"nt0min": "Sample index for aligning waveforms, so that their minimum or maximum value happens here. Default of 20. Default value: None.",
"dmin": "Vertical spacing of template centers used for spike detection, in microns. Determined automatically by default. Default value: None.",
"dminx": "Horizontal spacing of template centers used for spike detection, in microns. Default value: 32.",
Expand Down Expand Up @@ -153,6 +160,11 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
import torch
import numpy as np

if verbose:
import logging

logging.basicConfig(level=logging.INFO)

sorter_output_folder = sorter_output_folder.absolute()

probe_filename = sorter_output_folder / "probe.prb"
Expand Down Expand Up @@ -194,11 +206,17 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
data_dir = ""
results_dir = sorter_output_folder
filename, data_dir, results_dir, probe = set_files(settings, filename, probe, probe_name, data_dir, results_dir)
ops = initialize_ops(settings, probe, recording.get_dtype(), do_CAR, invert_sign, device)
if version.parse(cls.get_sorter_version()) >= version.parse("4.0.12"):
ops = initialize_ops(settings, probe, recording.get_dtype(), do_CAR, invert_sign, device, False)
n_chan_bin, fs, NT, nt, twav_min, chan_map, dtype, do_CAR, invert, _, _, tmin, tmax, artifact, _, _ = (
get_run_parameters(ops)
)
else:
ops = initialize_ops(settings, probe, recording.get_dtype(), do_CAR, invert_sign, device)
n_chan_bin, fs, NT, nt, twav_min, chan_map, dtype, do_CAR, invert, _, _, tmin, tmax, artifact = (
get_run_parameters(ops)
)

n_chan_bin, fs, NT, nt, twav_min, chan_map, dtype, do_CAR, invert, _, _, tmin, tmax, artifact = (
get_run_parameters(ops)
)
# Set preprocessing and drift correction parameters
if not params["skip_kilosort_preprocessing"]:
ops = compute_preprocessing(ops, device, tic0=tic0, file_object=file_object)
Expand Down

0 comments on commit 5e32a9f

Please sign in to comment.