From a5cb02f64cded1a94ac17a4347b5321b839da210 Mon Sep 17 00:00:00 2001
From: Alessio Buccino
Date: Mon, 29 Apr 2024 13:24:09 +0200
Subject: [PATCH 1/2] Prepare release 0.100.6

---
 doc/releases/0.100.6.rst                      | 12 +++
 doc/whatisnew.rst                             |  1 +
 pyproject.toml                                |  2 +-
 .../sorters/external/kilosort4.py             | 81 ++++---------
 4 files changed, 29 insertions(+), 67 deletions(-)
 create mode 100644 doc/releases/0.100.6.rst

diff --git a/doc/releases/0.100.6.rst b/doc/releases/0.100.6.rst
new file mode 100644
index 0000000000..99f34f534f
--- /dev/null
+++ b/doc/releases/0.100.6.rst
@@ -0,0 +1,12 @@
+.. _release0.100.6:
+
+SpikeInterface 0.100.6 release notes
+------------------------------------
+
+30th April 2024
+
+Minor release with bug fixes
+
+* Improve caching of MS5 sorter (#2690)
+* Allow remove_excess_spikes to remove negative spike times (#2716)
+* Update KS4 wrapper for newer versions (>=4.0.3) (#2701, #2774)
diff --git a/doc/whatisnew.rst b/doc/whatisnew.rst
index 015d033385..48f566087f 100644
--- a/doc/whatisnew.rst
+++ b/doc/whatisnew.rst
@@ -8,6 +8,7 @@ Release notes
 .. toctree::
    :maxdepth: 1

+   releases/0.100.6.rst
    releases/0.100.5.rst
    releases/0.100.4.rst
    releases/0.100.3.rst
diff --git a/pyproject.toml b/pyproject.toml
index e17780890f..c9fc3ecf78 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "spikeinterface"
-version = "0.100.5"
+version = "0.100.6"
 authors = [
   { name="Alessio Buccino", email="alessiop.buccino@gmail.com" },
   { name="Samuel Garcia", email="sam.garcia.die@gmail.com" },
diff --git a/src/spikeinterface/sorters/external/kilosort4.py b/src/spikeinterface/sorters/external/kilosort4.py
index 6ff836b753..90bdc1056d 100644
--- a/src/spikeinterface/sorters/external/kilosort4.py
+++ b/src/spikeinterface/sorters/external/kilosort4.py
@@ -1,8 +1,8 @@
 from __future__ import annotations

 from pathlib import Path
-import os
 from typing import Union
+from packaging.version import parse

 from ..basesorter import BaseSorter
 from .kilosortbase import KilosortBase
@@ -51,10 +51,11 @@ class Kilosort4Sorter(BaseSorter):
         "save_extra_kwargs": False,
         "skip_kilosort_preprocessing": False,
         "scaleproc": None,
+        "torch_device": "auto",
     }

     _params_description = {
-        "batch_size": "Number of samples per batch. Default value: 60000.",
+        "batch_size": "Number of samples included in each batch of data.",
         "nblocks": "Number of non-overlapping blocks for drift correction (additional nblocks-1 blocks are created in the overlaps). Default value: 1.",
         "Th_universal": "Spike detection threshold for universal templates. Th(1) in previous versions of Kilosort. Default value: 9.",
         "Th_learned": "Spike detection threshold for learned templates. Th(2) in previous versions of Kilosort. Default value: 8.",
@@ -87,6 +88,7 @@ class Kilosort4Sorter(BaseSorter):
         "save_extra_kwargs": "If True, additional kwargs are saved to the output",
         "skip_kilosort_preprocessing": "Can optionally skip the internal kilosort preprocessing",
         "scaleproc": "int16 scaling of whitened data, if None set to 200.",
+        "torch_device": "Torch device to use: 'auto', 'cuda', or 'cpu'. With 'auto', 'cuda' is selected when available.",
     }

     sorter_description = """Kilosort4 is a Python package for spike sorting on GPUs with template matching.
@@ -152,7 +154,10 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):

         probe_filename = sorter_output_folder / "probe.prb"

-        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        torch_device = params["torch_device"]
+        if torch_device == "auto":
+            torch_device = "cuda" if torch.cuda.is_available() else "cpu"
+        device = torch.device(torch_device)

         # load probe
         recording = cls.load_recording_from_folder(sorter_output_folder.parent, with_warnings=False)
@@ -222,39 +227,14 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
         torch.cuda.manual_seed_all(1)
         torch.random.manual_seed(1)
         # if not params["skip_kilosort_preprocessing"]:
-        if params["do_correction"]:
-            # this function applies both preprocessing and drift correction
-            ops, bfile, st0 = compute_drift_correction(
-                ops, device, tic0=tic0, progress_bar=progress_bar, file_object=file_object
-            )
-        else:
+        if not params["do_correction"]:
             print("Skipping drift correction.")
-            hp_filter = ops["preprocessing"]["hp_filter"]
-            whiten_mat = ops["preprocessing"]["whiten_mat"]
-
-            bfile = BinaryFiltered(
-                ops["filename"],
-                n_chan_bin,
-                fs,
-                NT,
-                nt,
-                twav_min,
-                chan_map,
-                hp_filter=hp_filter,
-                whiten_mat=whiten_mat,
-                device=device,
-                do_CAR=do_CAR,
-                invert_sign=invert,
-                dtype=dtype,
-                tmin=tmin,
-                tmax=tmax,
-                artifact_threshold=artifact,
-                file_object=file_object,
-            )
+            ops["nblocks"] = 0

-        # TODO: don't think we need to do this actually
-        # Save intermediate `ops` for use by GUI plots
-        # io.save_ops(ops, results_dir)
+        # this function applies both preprocessing and drift correction
+        ops, bfile, st0 = compute_drift_correction(
+            ops, device, tic0=tic0, progress_bar=progress_bar, file_object=file_object
+        )

         # Sort spikes and save results
         st, tF, _, _ = detect_spikes(ops, device, bfile, tic0=tic0, progress_bar=progress_bar)
@@ -263,39 +243,8 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
             ops["preprocessing"] = dict(
                 hp_filter=torch.as_tensor(np.zeros(1)), whiten_mat=torch.as_tensor(np.eye(recording.get_num_channels()))
             )
-        ops, similar_templates, is_ref, est_contam_rate = save_sorting(
-            ops, results_dir, st, clu, tF, Wall, bfile.imin, tic0, save_extra_vars=save_extra_vars
-        )
-
-        # # Clean-up temporary files
-        # if params["delete_recording_dat"] and (recording_file := sorter_output_folder / "recording.dat").exists():
-        #     recording_file.unlink()
-
-        # all_tmp_files = ("matlab_files", "temp_wh.dat")
-
-        # if isinstance(params["delete_tmp_files"], bool):
-        #     if params["delete_tmp_files"]:
-        #         tmp_files_to_remove = all_tmp_files
-        #     else:
-        #         tmp_files_to_remove = ()
-        # else:
-        #     assert isinstance(
-        #         params["delete_tmp_files"], (tuple, list)
-        #     ), "`delete_tmp_files` must be a `Bool`, `Tuple` or `List`."
-
-        #     for name in params["delete_tmp_files"]:
-        #         assert name in all_tmp_files, f"{name} is not a valid option, must be one of: {all_tmp_files}"
-
-        #     tmp_files_to_remove = params["delete_tmp_files"]
-
-        # if "temp_wh.dat" in tmp_files_to_remove:
-        #     if (temp_wh_file := sorter_output_folder / "temp_wh.dat").exists():
-        #         temp_wh_file.unlink()
-        # if "matlab_files" in tmp_files_to_remove:
-        #     for ext in ["*.m", "*.mat"]:
-        #         for temp_file in sorter_output_folder.glob(ext):
-        #             temp_file.unlink()
+        _ = save_sorting(ops, results_dir, st, clu, tF, Wall, bfile.imin, tic0, save_extra_vars=save_extra_vars)

     @classmethod
     def _get_result_from_folder(cls, sorter_output_folder):

From d8b7fdbaf1e80e9c09fde3e9afa7e130481618d9 Mon Sep 17 00:00:00 2001
From: Alessio Buccino
Date: Tue, 30 Apr 2024 14:30:29 +0200
Subject: [PATCH 2/2] Propagate #2621

---
 doc/releases/0.100.6.rst                  | 1 +
 doc/whatisnew.rst                         | 7 +++++++
 src/spikeinterface/core/core_tools.py     | 4 +++-
 src/spikeinterface/core/waveform_tools.py | 2 +-
 4 files changed, 12 insertions(+), 2 deletions(-)

diff --git a/doc/releases/0.100.6.rst b/doc/releases/0.100.6.rst
index 99f34f534f..7f2bb5cd66 100644
--- a/doc/releases/0.100.6.rst
+++ b/doc/releases/0.100.6.rst
@@ -7,6 +7,7 @@ SpikeInterface 0.100.6 release notes

 Minor release with bug fixes

+* Avoid np.prod in make_shared_array (#2621)
 * Improve caching of MS5 sorter (#2690)
 * Allow remove_excess_spikes to remove negative spike times (#2716)
 * Update KS4 wrapper for newer versions (>=4.0.3) (#2701, #2774)
diff --git a/doc/whatisnew.rst b/doc/whatisnew.rst
index 48f566087f..2ba199eb94 100644
--- a/doc/whatisnew.rst
+++ b/doc/whatisnew.rst
@@ -40,11 +40,18 @@ Release notes
    releases/0.9.1.rst


+Version 0.100.6
+===============
+
+* Minor release with bug fixes
+
+
 Version 0.100.5
 ===============

 * Minor release with bug fixes

+
 Version 0.100.4
 ===============
diff --git a/src/spikeinterface/core/core_tools.py b/src/spikeinterface/core/core_tools.py
index 3b82436d5c..3725fcfba8 100644
--- a/src/spikeinterface/core/core_tools.py
+++ b/src/spikeinterface/core/core_tools.py
@@ -7,6 +7,7 @@
 import json
 from copy import deepcopy
+from math import prod

 import numpy as np
 from tqdm import tqdm
@@ -163,7 +164,8 @@ def make_shared_array(shape, dtype):
     from multiprocessing.shared_memory import SharedMemory

     dtype = np.dtype(dtype)
-    nbytes = int(np.prod(shape) * dtype.itemsize)
+    shape = tuple(int(x) for x in shape)  # we need to be sure that shape contains Python ints, not numpy scalars
+    nbytes = prod(shape) * dtype.itemsize
     shm = SharedMemory(name=None, create=True, size=nbytes)
     arr = np.ndarray(shape=shape, dtype=dtype, buffer=shm.buf)
     arr[:] = 0
diff --git a/src/spikeinterface/core/waveform_tools.py b/src/spikeinterface/core/waveform_tools.py
index f9e39382df..8864ae0d39 100644
--- a/src/spikeinterface/core/waveform_tools.py
+++ b/src/spikeinterface/core/waveform_tools.py
@@ -483,7 +483,7 @@ def extract_waveforms_to_single_buffer(
     if sparsity_mask is None:
         num_chans = recording.get_num_channels()
     else:
-        num_chans = max(np.sum(sparsity_mask, axis=1))
+        num_chans = int(max(np.sum(sparsity_mask, axis=1)))  # this is a numpy scalar, so we cast to int
     shape = (num_spikes, nsamples, num_chans)

     if mode == "memmap":
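
A note on the #2621 change above: np.prod follows numpy's integer-promotion rules, so on platforms where the default integer is 32-bit (notably numpy < 2.0 on Windows) the byte count of a large buffer can silently wrap around before it ever reaches SharedMemory, while math.prod over Python ints uses arbitrary-precision arithmetic and cannot overflow. A minimal sketch of the failure mode; the shape is made up for illustration, and the wrap-around is forced with an explicit dtype so it reproduces on any platform:

    import numpy as np
    from math import prod

    shape = (50_000, 384, 128)  # ~2.46e9 elements, larger than 2**31 - 1

    # int32 arithmetic silently wraps: 2_457_600_000 - 2**32 = -1_837_367_296
    print(np.prod(shape, dtype=np.int32))  # -1837367296

    # Python ints are arbitrary precision, so math.prod is always exact
    print(prod(int(x) for x in shape))  # 2457600000

The int(...) cast in extract_waveforms_to_single_buffer addresses the related numpy-scalar issue: np.sum returns a numpy scalar, and consumers of the resulting shape (such as SharedMemory's size argument) expect plain Python ints.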
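
A note on the new torch_device parameter in the KS4 wrapper: "auto" reproduces the previous hard-coded behaviour (CUDA when available, otherwise CPU), while "cuda" or "cpu" pin the device explicitly. A usage sketch, assuming kilosort >= 4.0.3 is installed; the generated toy recording and the output folder name are placeholders, and extra keyword arguments to run_sorter are forwarded to the sorter's params:

    import spikeinterface.core as si
    import spikeinterface.sorters as ss

    # toy recording only to keep the sketch self-contained; any recording with a probe works
    recording = si.generate_recording(num_channels=16, durations=[10.0])

    sorting = ss.run_sorter(
        "kilosort4",
        recording,
        output_folder="ks4_output",  # placeholder path
        torch_device="cpu",  # "auto" (default) selects "cuda" when available
    )

Pinning "cpu" is mainly useful for debugging, or on machines where CUDA is present but short on memory.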