diff --git a/src/spikeinterface/sorters/external/kilosort4.py b/src/spikeinterface/sorters/external/kilosort4.py index a7f40a9558..6d83249653 100644 --- a/src/spikeinterface/sorters/external/kilosort4.py +++ b/src/spikeinterface/sorters/external/kilosort4.py @@ -56,6 +56,7 @@ class Kilosort4Sorter(BaseSorter): "save_extra_kwargs": False, "skip_kilosort_preprocessing": False, "scaleproc": None, + "save_preprocessed_copy": False, "torch_device": "auto", } @@ -98,6 +99,7 @@ class Kilosort4Sorter(BaseSorter): "save_extra_kwargs": "If True, additional kwargs are saved to the output", "skip_kilosort_preprocessing": "Can optionally skip the internal kilosort preprocessing", "scaleproc": "int16 scaling of whitened data, if None set to 200.", + "save_preprocessed_copy": "save a pre-processed copy of the data (including drift correction) to temp_wh.dat in the results directory and format Phy output to use that copy of the data", "torch_device": "Select the torch device auto/cuda/cpu", } @@ -153,7 +155,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): save_sorting, get_run_parameters, ) - from kilosort.io import load_probe, RecordingExtractorAsArray, BinaryFiltered + from kilosort.io import load_probe, RecordingExtractorAsArray, BinaryFiltered, save_preprocessing from kilosort.parameters import DEFAULT_SETTINGS import time @@ -186,6 +188,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): do_CAR = params["do_CAR"] invert_sign = params["invert_sign"] save_extra_vars = params["save_extra_kwargs"] + save_preprocessed_copy = params["save_preprocessed_copy"] progress_bar = None settings_ks = {k: v for k, v in params.items() if k in DEFAULT_SETTINGS} settings_ks["n_chan_bin"] = recording.get_num_channels() @@ -207,7 +210,15 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): results_dir = sorter_output_folder filename, data_dir, results_dir, probe = set_files(settings, filename, probe, probe_name, data_dir, results_dir) if version.parse(cls.get_sorter_version()) >= version.parse("4.0.12"): - ops = initialize_ops(settings, probe, recording.get_dtype(), do_CAR, invert_sign, device, False) + ops = initialize_ops( + settings, + probe, + recording.get_dtype(), + do_CAR, + invert_sign, + device, + save_preprocesed_copy=save_preprocessed_copy, # this kwarg is correct (typo) + ) n_chan_bin, fs, NT, nt, twav_min, chan_map, dtype, do_CAR, invert, _, _, tmin, tmax, artifact, _, _ = ( get_run_parameters(ops) ) @@ -257,6 +268,9 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): ops, device, tic0=tic0, progress_bar=progress_bar, file_object=file_object ) + if save_preprocessed_copy: + save_preprocessing(results_dir / "temp_wh.dat", ops, bfile) + # Sort spikes and save results st, tF, _, _ = detect_spikes(ops, device, bfile, tic0=tic0, progress_bar=progress_bar) clu, Wall = cluster_spikes(st, tF, ops, device, bfile, tic0=tic0, progress_bar=progress_bar) @@ -265,7 +279,21 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): hp_filter=torch.as_tensor(np.zeros(1)), whiten_mat=torch.as_tensor(np.eye(recording.get_num_channels())) ) - _ = save_sorting(ops, results_dir, st, clu, tF, Wall, bfile.imin, tic0, save_extra_vars=save_extra_vars) + if version.parse(cls.get_sorter_version()) >= version.parse("4.0.12"): + _ = save_sorting( + ops, + results_dir, + st, + clu, + tF, + Wall, + bfile.imin, + tic0, + save_extra_vars=save_extra_vars, + save_preprocessed_copy=save_preprocessed_copy, + ) + else: + _ = save_sorting(ops, results_dir, st, clu, tF, Wall, bfile.imin, tic0, save_extra_vars=save_extra_vars) @classmethod def _get_result_from_folder(cls, sorter_output_folder):