From 0c88b39a875e5068b0bfd4f63db7ff45f025e202 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Fri, 29 Nov 2024 05:46:02 +0100 Subject: [PATCH] Patch to force remove sorters --- src/spikeinterface/benchmark/benchmark_base.py | 5 +++-- src/spikeinterface/benchmark/benchmark_sorter.py | 9 +++++++++ src/spikeinterface/curation/auto_merge.py | 2 +- 3 files changed, 13 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/benchmark/benchmark_base.py b/src/spikeinterface/benchmark/benchmark_base.py index b9cbf269c8..ddcf25f2ab 100644 --- a/src/spikeinterface/benchmark/benchmark_base.py +++ b/src/spikeinterface/benchmark/benchmark_base.py @@ -208,10 +208,11 @@ def run(self, case_keys=None, keep=True, verbose=False, **job_kwargs): for key in case_keys: result_folder = self.folder / "results" / self.key_to_str(key) - + sorter_folder = self.folder / "sorters" / self.key_to_str(key) + if keep and result_folder.exists(): continue - elif not keep and result_folder.exists(): + elif not keep and (result_folder.exists() or sorter_folder.exists()): self.remove_benchmark(key) job_keys.append(key) diff --git a/src/spikeinterface/benchmark/benchmark_sorter.py b/src/spikeinterface/benchmark/benchmark_sorter.py index f9267c785a..8180c943be 100644 --- a/src/spikeinterface/benchmark/benchmark_sorter.py +++ b/src/spikeinterface/benchmark/benchmark_sorter.py @@ -56,6 +56,15 @@ def create_benchmark(self, key): benchmark = SorterBenchmark(recording, gt_sorting, params, sorter_folder) return benchmark + def remove_benchmark(self, key): + BenchmarkStudy.remove_benchmark(self, key) + + sorter_folder = self.folder / "sorters" / self.key_to_str(key) + import shutil + if sorter_folder.exists(): + shutil.rmtree(sorter_folder) + + def get_performance_by_unit(self, case_keys=None): import pandas as pd diff --git a/src/spikeinterface/curation/auto_merge.py b/src/spikeinterface/curation/auto_merge.py index 4f4cff144e..89c24565c2 100644 --- a/src/spikeinterface/curation/auto_merge.py +++ b/src/spikeinterface/curation/auto_merge.py @@ -231,7 +231,7 @@ def compute_merge_unit_groups( params = _default_step_params.get(step).copy() if steps_params is not None and step in steps_params: params.update(steps_params[step]) - + # STEP : remove units with too few spikes if step == "num_spikes":