From 22b90945c82e55c15e06c9c92ebd6b752889906a Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 30 Oct 2024 16:56:22 +0100 Subject: [PATCH] avoid copy when not necessary --- src/spikeinterface/curation/auto_merge.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/src/spikeinterface/curation/auto_merge.py b/src/spikeinterface/curation/auto_merge.py index eeeb5b2098..4f4cff144e 100644 --- a/src/spikeinterface/curation/auto_merge.py +++ b/src/spikeinterface/curation/auto_merge.py @@ -195,7 +195,19 @@ def compute_merge_unit_groups( raise ValueError(f"preset must be one of {list(_compute_merge_presets.keys())}") steps = _compute_merge_presets[preset] - if force_copy: + # check at least one extension is needed + at_least_one_extension_to_compute = False + for step in steps: + assert step in _default_step_params, f"{step} is not a valid step" + if step in _required_extensions: + for ext in _required_extensions[step]: + if sorting_analyzer.has_extension(ext): + continue + if not compute_needed_extensions: + raise ValueError(f"{step} requires {ext} extension") + at_least_one_extension_to_compute = True + + if force_copy and at_least_one_extension_to_compute: # To avoid erasing the extensions of the user sorting_analyzer = sorting_analyzer.copy() @@ -205,14 +217,10 @@ def compute_merge_unit_groups( for step in steps: - assert step in _default_step_params, f"{step} is not a valid step" - if step in _required_extensions: for ext in _required_extensions[step]: if sorting_analyzer.has_extension(ext): continue - if not compute_needed_extensions: - raise ValueError(f"{step} requires {ext} extension") # special case for templates if ext == "templates" and not sorting_analyzer.has_extension("random_spikes"):