Skip to content

Commit

Permalink
Merge branch 'SpikeInterface:main' into sortingview-curation-fix
Browse files Browse the repository at this point in the history
  • Loading branch information
rkim48 authored Aug 20, 2024
2 parents d029f7d + 694f862 commit b6316b5
Showing 1 changed file with 31 additions and 3 deletions.
34 changes: 31 additions & 3 deletions src/spikeinterface/sorters/external/kilosort4.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ class Kilosort4Sorter(BaseSorter):
"save_extra_kwargs": False,
"skip_kilosort_preprocessing": False,
"scaleproc": None,
"save_preprocessed_copy": False,
"torch_device": "auto",
}

Expand Down Expand Up @@ -98,6 +99,7 @@ class Kilosort4Sorter(BaseSorter):
"save_extra_kwargs": "If True, additional kwargs are saved to the output",
"skip_kilosort_preprocessing": "Can optionally skip the internal kilosort preprocessing",
"scaleproc": "int16 scaling of whitened data, if None set to 200.",
"save_preprocessed_copy": "save a pre-processed copy of the data (including drift correction) to temp_wh.dat in the results directory and format Phy output to use that copy of the data",
"torch_device": "Select the torch device auto/cuda/cpu",
}

Expand Down Expand Up @@ -153,7 +155,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
save_sorting,
get_run_parameters,
)
from kilosort.io import load_probe, RecordingExtractorAsArray, BinaryFiltered
from kilosort.io import load_probe, RecordingExtractorAsArray, BinaryFiltered, save_preprocessing
from kilosort.parameters import DEFAULT_SETTINGS

import time
Expand Down Expand Up @@ -186,6 +188,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
do_CAR = params["do_CAR"]
invert_sign = params["invert_sign"]
save_extra_vars = params["save_extra_kwargs"]
save_preprocessed_copy = params["save_preprocessed_copy"]
progress_bar = None
settings_ks = {k: v for k, v in params.items() if k in DEFAULT_SETTINGS}
settings_ks["n_chan_bin"] = recording.get_num_channels()
Expand All @@ -207,7 +210,15 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
results_dir = sorter_output_folder
filename, data_dir, results_dir, probe = set_files(settings, filename, probe, probe_name, data_dir, results_dir)
if version.parse(cls.get_sorter_version()) >= version.parse("4.0.12"):
ops = initialize_ops(settings, probe, recording.get_dtype(), do_CAR, invert_sign, device, False)
ops = initialize_ops(
settings,
probe,
recording.get_dtype(),
do_CAR,
invert_sign,
device,
save_preprocesed_copy=save_preprocessed_copy, # this kwarg is correct (typo)
)
n_chan_bin, fs, NT, nt, twav_min, chan_map, dtype, do_CAR, invert, _, _, tmin, tmax, artifact, _, _ = (
get_run_parameters(ops)
)
Expand Down Expand Up @@ -257,6 +268,9 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
ops, device, tic0=tic0, progress_bar=progress_bar, file_object=file_object
)

if save_preprocessed_copy:
save_preprocessing(results_dir / "temp_wh.dat", ops, bfile)

# Sort spikes and save results
st, tF, _, _ = detect_spikes(ops, device, bfile, tic0=tic0, progress_bar=progress_bar)
clu, Wall = cluster_spikes(st, tF, ops, device, bfile, tic0=tic0, progress_bar=progress_bar)
Expand All @@ -265,7 +279,21 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
hp_filter=torch.as_tensor(np.zeros(1)), whiten_mat=torch.as_tensor(np.eye(recording.get_num_channels()))
)

_ = save_sorting(ops, results_dir, st, clu, tF, Wall, bfile.imin, tic0, save_extra_vars=save_extra_vars)
if version.parse(cls.get_sorter_version()) >= version.parse("4.0.12"):
_ = save_sorting(
ops,
results_dir,
st,
clu,
tF,
Wall,
bfile.imin,
tic0,
save_extra_vars=save_extra_vars,
save_preprocessed_copy=save_preprocessed_copy,
)
else:
_ = save_sorting(ops, results_dir, st, clu, tF, Wall, bfile.imin, tic0, save_extra_vars=save_extra_vars)

@classmethod
def _get_result_from_folder(cls, sorter_output_folder):
Expand Down

0 comments on commit b6316b5

Please sign in to comment.