diff --git a/.github/scripts/test_kilosort4_ci.py b/.github/scripts/test_kilosort4_ci.py index dbd8135b9a..61c10fd8e8 100644 --- a/.github/scripts/test_kilosort4_ci.py +++ b/.github/scripts/test_kilosort4_ci.py @@ -415,27 +415,46 @@ def test_use_binary_file(self, tmp_path): sorting_ks4 = si.run_sorter( "kilosort4", recording, - folder=tmp_path / "spikeinterface_output_dir_wrapper", - use_binary_file=False, + folder=tmp_path / "ks4_output_si_wrapper_default", + use_binary_file=None, remove_existing_folder=True, ) sorting_ks4_bin = si.run_sorter( "kilosort4", recording_bin, - folder=tmp_path / "spikeinterface_output_dir_bin", + folder=tmp_path / "ks4_output_bin_default", + use_binary_file=None, + remove_existing_folder=True, + ) + sorting_ks4_force_binary = si.run_sorter( + "kilosort4", + recording, + folder=tmp_path / "ks4_output_force_bin", + use_binary_file=True, + remove_existing_folder=True, + ) + assert not (tmp_path / "ks4_output_force_bin" / "sorter_output" / "recording.dat").exists() + sorting_ks4_force_non_binary = si.run_sorter( + "kilosort4", + recording_bin, + folder=tmp_path / "ks4_output_force_wrapper", use_binary_file=False, remove_existing_folder=True, ) - sorting_ks4_non_bin = si.run_sorter( + # test deleting recording.dat + sorting_ks4_force_binary_keep = si.run_sorter( "kilosort4", recording, - folder=tmp_path / "spikeinterface_output_dir_non_bin", + folder=tmp_path / "ks4_output_force_bin_keep", use_binary_file=True, + delete_recording_dat=False, remove_existing_folder=True, ) + assert (tmp_path / "ks4_output_force_bin_keep" / "sorter_output" / "recording.dat").exists() check_sortings_equal(sorting_ks4, sorting_ks4_bin) - check_sortings_equal(sorting_ks4, sorting_ks4_non_bin) + check_sortings_equal(sorting_ks4, sorting_ks4_force_binary) + check_sortings_equal(sorting_ks4, sorting_ks4_force_non_binary) @pytest.mark.parametrize( "param_to_test", diff --git a/src/spikeinterface/sorters/external/kilosort4.py b/src/spikeinterface/sorters/external/kilosort4.py index b0ba054e2d..8a15642af4 100644 --- a/src/spikeinterface/sorters/external/kilosort4.py +++ b/src/spikeinterface/sorters/external/kilosort4.py @@ -65,7 +65,8 @@ class Kilosort4Sorter(BaseSorter): "save_preprocessed_copy": False, "torch_device": "auto", "bad_channels": None, - "use_binary_file": False, + "use_binary_file": None, + "delete_recording_dat": True, } _params_description = { @@ -110,8 +111,10 @@ class Kilosort4Sorter(BaseSorter): "save_preprocessed_copy": "save a pre-processed copy of the data (including drift correction) to temp_wh.dat in the results directory and format Phy output to use that copy of the data", "torch_device": "Select the torch device auto/cuda/cpu", "bad_channels": "A list of channel indices (rows in the binary file) that should not be included in sorting. Listing channels here is equivalent to excluding them from the probe dictionary.", - "use_binary_file": "If True and the recording is not binary compatible, then Kilosort is written to a binary file in the output folder. If False, the Kilosort is run on the recording object directly using the RecordingExtractorAsArray object. " - "If the recording is binary compatible, then the sorter will always use the binary file. Default is False.", + "use_binary_file": "If True then Kilosort is run using a binary file. In this case, if the input recording is not binaru compatible, it is written to a binary file in the output folder. " + "If False then Kilosort is run on the recording object directly using the RecordingExtractorAsArray object. If None, then if the recording is binary compatible, the sorter will use the binary file, otherwise the RecordingExtractorAsArray. " + "Default is None.", + "delete_recording_dat": "If True, if a temporary binary file is created, it is deleted after the sorting is done. Default is True.", } sorter_description = """Kilosort4 is a Python package for spike sorting on GPUs with template matching. @@ -172,15 +175,16 @@ def _setup_recording(cls, recording, sorter_output_folder, params, verbose): probe_filename = sorter_output_folder / "probe.prb" write_prb(probe_filename, pg) - if params["use_binary_file"] and not recording.binary_compatible_with(time_axis=0, file_paths_lenght=1): - # local copy needed - binary_file_path = sorter_output_folder / "recording.dat" - write_binary_recording( - recording=recording, - file_paths=[binary_file_path], - **get_job_kwargs(params, verbose), - ) - params["filename"] = str(binary_file_path) + if params["use_binary_file"]: + if not recording.binary_compatible_with(time_axis=0, file_paths_lenght=1): + # local copy needed + binary_file_path = sorter_output_folder / "recording.dat" + write_binary_recording( + recording=recording, + file_paths=[binary_file_path], + **get_job_kwargs(params, verbose), + ) + params["filename"] = str(binary_file_path) @classmethod def _run_from_folder(cls, sorter_output_folder, params, verbose): @@ -227,18 +231,30 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): probe = load_probe(probe_path=probe_filename) probe_name = "" - if recording.binary_compatible_with(time_axis=0, file_paths_lenght=1): - # no copy - binary_description = recording.get_binary_description() - filename = str(binary_description["file_paths"][0]) - file_object = None + if params["use_binary_file"] is None: + if recording.binary_compatible_with(time_axis=0, file_paths_lenght=1): + # no copy + binary_description = recording.get_binary_description() + filename = str(binary_description["file_paths"][0]) + file_object = None + else: + # the recording is not binary compatible and no binary copy has been written. + # in this case, we use the RecordingExtractorAsArray object + filename = "" + file_object = RecordingExtractorAsArray(recording_extractor=recording) elif params["use_binary_file"]: - # a local copy has been written - filename = str(sorter_output_folder / "recording.dat") - file_object = None + # here we force the use of a binary file + if recording.binary_compatible_with(time_axis=0, file_paths_lenght=1): + # no copy + binary_description = recording.get_binary_description() + filename = str(binary_description["file_paths"][0]) + file_object = None + else: + # a local copy has been written + filename = str(sorter_output_folder / "recording.dat") + file_object = None else: - # the recording is not binary compatible and no binary copy has been written. - # in this case, we use the RecordingExtractorAsArray object + # here we force the use of the RecordingExtractorAsArray object filename = "" file_object = RecordingExtractorAsArray(recording_extractor=recording) @@ -362,6 +378,11 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): save_preprocessed_copy=save_preprocessed_copy, ) + if params["delete_recording_dat"]: + # only delete dat file if it was created by the wrapper + if (sorter_output_folder / "recording.dat").is_file(): + (sorter_output_folder / "recording.dat").unlink() + @classmethod def _get_result_from_folder(cls, sorter_output_folder): return KilosortBase._get_result_from_folder(sorter_output_folder)