From 1b66a4980a141236c7e6c9552968e0bdb64fa5d9 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 16 Jul 2024 20:06:06 +0200 Subject: [PATCH 01/18] Smal fix or backward compatibuility --- src/spikeinterface/core/sortinganalyzer.py | 9 +++++++++ src/spikeinterface/postprocessing/template_similarity.py | 6 ++++++ 2 files changed, 15 insertions(+) diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 27a47a31ac..84e8043a04 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -1699,6 +1699,7 @@ class AnalyzerExtension: use_nodepipeline = False nodepipeline_variables = None need_job_kwargs = False + need_backward_compatibility_on_load = False def __init__(self, sorting_analyzer): self._sorting_analyzer = weakref.ref(sorting_analyzer) @@ -1737,6 +1738,11 @@ def _get_data(self): # must be implemented in subclass raise NotImplementedError + def _handle_backward_compatibility_on_load(self): + # must be implemented in subclass only if need_backward_compatibility_on_load=True + raise NotImplementedError + + @classmethod def function_factory(cls): # make equivalent @@ -1814,6 +1820,9 @@ def load(cls, sorting_analyzer): ext = cls(sorting_analyzer) ext.load_params() ext.load_data() + if cls.need_backward_compatibility_on_load: + self._handle_backward_compatibility_on_load() + return ext def load_params(self): diff --git a/src/spikeinterface/postprocessing/template_similarity.py b/src/spikeinterface/postprocessing/template_similarity.py index a9592b0b91..0481cdfaca 100644 --- a/src/spikeinterface/postprocessing/template_similarity.py +++ b/src/spikeinterface/postprocessing/template_similarity.py @@ -46,6 +46,12 @@ class ComputeTemplateSimilarity(AnalyzerExtension): def __init__(self, sorting_analyzer): AnalyzerExtension.__init__(self, sorting_analyzer) + + def _handle_backward_compatibility_on_load(self): + if "max_lag_ms" not in self.params: + # make compatible analyzer created between february 24 and july 24 + self.params["max_lag_ms"] = 0. + self.params["support"] = "union" def _set_params(self, method="cosine", max_lag_ms=0, support="union"): if method == "cosine_similarity": From c97802ac4787a7f9f91123e832a48f5db2f7f9a6 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 16 Jul 2024 18:07:15 +0000 Subject: [PATCH 02/18] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/core/sortinganalyzer.py | 1 - src/spikeinterface/postprocessing/template_similarity.py | 4 ++-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 84e8043a04..e6f55a2291 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -1742,7 +1742,6 @@ def _handle_backward_compatibility_on_load(self): # must be implemented in subclass only if need_backward_compatibility_on_load=True raise NotImplementedError - @classmethod def function_factory(cls): # make equivalent diff --git a/src/spikeinterface/postprocessing/template_similarity.py b/src/spikeinterface/postprocessing/template_similarity.py index 0481cdfaca..67c036cb28 100644 --- a/src/spikeinterface/postprocessing/template_similarity.py +++ b/src/spikeinterface/postprocessing/template_similarity.py @@ -46,11 +46,11 @@ class ComputeTemplateSimilarity(AnalyzerExtension): def __init__(self, sorting_analyzer): AnalyzerExtension.__init__(self, sorting_analyzer) - + def _handle_backward_compatibility_on_load(self): if "max_lag_ms" not in self.params: # make compatible analyzer created between february 24 and july 24 - self.params["max_lag_ms"] = 0. + self.params["max_lag_ms"] = 0.0 self.params["support"] = "union" def _set_params(self, method="cosine", max_lag_ms=0, support="union"): From 841796c8f4db97ab99b2563c863d0871e1dd9d40 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 16 Jul 2024 20:42:49 +0200 Subject: [PATCH 03/18] oups --- src/spikeinterface/postprocessing/template_similarity.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/spikeinterface/postprocessing/template_similarity.py b/src/spikeinterface/postprocessing/template_similarity.py index 0481cdfaca..f86cd4ea5e 100644 --- a/src/spikeinterface/postprocessing/template_similarity.py +++ b/src/spikeinterface/postprocessing/template_similarity.py @@ -43,6 +43,7 @@ class ComputeTemplateSimilarity(AnalyzerExtension): need_recording = True use_nodepipeline = False need_job_kwargs = False + need_backward_compatibility_on_load = True def __init__(self, sorting_analyzer): AnalyzerExtension.__init__(self, sorting_analyzer) From 7253e4fd7dcf47474bd8a9f1da950864df94bb68 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 16 Jul 2024 20:53:33 +0200 Subject: [PATCH 04/18] Handle dtype in merges for unit_ids --- src/spikeinterface/core/sorting_tools.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/core/sorting_tools.py b/src/spikeinterface/core/sorting_tools.py index 2a2f7b6b5a..942d0ca53a 100644 --- a/src/spikeinterface/core/sorting_tools.py +++ b/src/spikeinterface/core/sorting_tools.py @@ -347,6 +347,7 @@ def _get_ids_after_merging(old_unit_ids, merge_unit_groups, new_unit_ids): """ old_unit_ids = np.asarray(old_unit_ids) + dtype = old_unit_ids.dtype assert len(new_unit_ids) == len(merge_unit_groups), "new_unit_ids should have the same len as merge_unit_groups" @@ -361,7 +362,7 @@ def _get_ids_after_merging(old_unit_ids, merge_unit_groups, new_unit_ids): all_unit_ids.remove(unit_id) if new_unit_id not in all_unit_ids: all_unit_ids.append(new_unit_id) - return np.array(all_unit_ids) + return np.array(all_unit_ids, dtype=dtype) def generate_unit_ids_for_merge_group(old_unit_ids, merge_unit_groups, new_unit_ids=None, new_id_strategy="append"): From 0b29f705c5ced6755bd5b3ada42647538b458e74 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 17 Jul 2024 07:31:22 +0200 Subject: [PATCH 05/18] oups --- src/spikeinterface/core/sortinganalyzer.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index e6f55a2291..8c5159c865 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -1398,9 +1398,7 @@ def load_extension(self, extension_name: str): extension_class = get_extension_class(extension_name) - extension_instance = extension_class(self) - extension_instance.load_params() - extension_instance.load_data() + extension_instance = extension_class.load(self) self.extensions[extension_name] = extension_instance @@ -1820,7 +1818,7 @@ def load(cls, sorting_analyzer): ext.load_params() ext.load_data() if cls.need_backward_compatibility_on_load: - self._handle_backward_compatibility_on_load() + ext._handle_backward_compatibility_on_load() return ext From 1e026cca53433a8a733c4509cd6373af4d04105f Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 17 Jul 2024 08:28:14 +0200 Subject: [PATCH 06/18] fix dtype when merging --- src/spikeinterface/core/sorting_tools.py | 3 +++ src/spikeinterface/core/tests/test_sorting_tools.py | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/core/sorting_tools.py b/src/spikeinterface/core/sorting_tools.py index 942d0ca53a..336c3711aa 100644 --- a/src/spikeinterface/core/sorting_tools.py +++ b/src/spikeinterface/core/sorting_tools.py @@ -348,6 +348,9 @@ def _get_ids_after_merging(old_unit_ids, merge_unit_groups, new_unit_ids): """ old_unit_ids = np.asarray(old_unit_ids) dtype = old_unit_ids.dtype + if dtype.kind == 'U': + # the new dtype can be longer + dtype = 'U' assert len(new_unit_ids) == len(merge_unit_groups), "new_unit_ids should have the same len as merge_unit_groups" diff --git a/src/spikeinterface/core/tests/test_sorting_tools.py b/src/spikeinterface/core/tests/test_sorting_tools.py index 38baf62c35..6d0e61f844 100644 --- a/src/spikeinterface/core/tests/test_sorting_tools.py +++ b/src/spikeinterface/core/tests/test_sorting_tools.py @@ -161,4 +161,4 @@ def test_generate_unit_ids_for_merge_group(): test_apply_merges_to_sorting() test_get_ids_after_merging() - test_generate_unit_ids_for_merge_group() + # test_generate_unit_ids_for_merge_group() From 32844694afbe4d6c5de62535ce1bb0d98cf62d26 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 17 Jul 2024 06:30:54 +0000 Subject: [PATCH 07/18] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/core/sorting_tools.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/core/sorting_tools.py b/src/spikeinterface/core/sorting_tools.py index 336c3711aa..6994575150 100644 --- a/src/spikeinterface/core/sorting_tools.py +++ b/src/spikeinterface/core/sorting_tools.py @@ -348,9 +348,9 @@ def _get_ids_after_merging(old_unit_ids, merge_unit_groups, new_unit_ids): """ old_unit_ids = np.asarray(old_unit_ids) dtype = old_unit_ids.dtype - if dtype.kind == 'U': + if dtype.kind == "U": # the new dtype can be longer - dtype = 'U' + dtype = "U" assert len(new_unit_ids) == len(merge_unit_groups), "new_unit_ids should have the same len as merge_unit_groups" From 3cd4d03f3dac793cc1419f4ae39155ca943f3904 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 17 Jul 2024 12:55:49 +0200 Subject: [PATCH 08/18] Use also _handle_backward_compatibility_on_load for ComputeTemplates --- .../core/analyzer_extension_core.py | 31 +++++++------------ 1 file changed, 11 insertions(+), 20 deletions(-) diff --git a/src/spikeinterface/core/analyzer_extension_core.py b/src/spikeinterface/core/analyzer_extension_core.py index ad23a5f249..fe67da22be 100644 --- a/src/spikeinterface/core/analyzer_extension_core.py +++ b/src/spikeinterface/core/analyzer_extension_core.py @@ -364,6 +364,17 @@ class ComputeTemplates(AnalyzerExtension): need_recording = True use_nodepipeline = False need_job_kwargs = True + need_backward_compatibility_on_load = True + + def _handle_backward_compatibility_on_load(self): + if "ms_before" not in self.params: + # compatibility february 2024 > july 2024 + self.params["ms_before"] = self.params["nbefore"] * 1000.0 / self.sorting_analyzer.sampling_frequency + + if "ms_after" not in self.params: + # compatibility february 2024 > july 2024 + self.params["ms_after"] = self.params["nafter"] * 1000.0 / self.sorting_analyzer.sampling_frequency + def _set_params(self, ms_before: float = 1.0, ms_after: float = 2.0, operators=None): operators = operators or ["average", "std"] @@ -487,31 +498,11 @@ def _compute_and_append_from_waveforms(self, operators): @property def nbefore(self): - if "ms_before" not in self.params: - # compatibility february 2024 > july 2024 - self.params["ms_before"] = self.params["nbefore"] * 1000.0 / self.sorting_analyzer.sampling_frequency - warnings.warn( - "The 'nbefore' parameter is deprecated and it's been replaced by 'ms_before' in the params." - "You can save the sorting_analyzer to update the params.", - DeprecationWarning, - stacklevel=2, - ) - nbefore = int(self.params["ms_before"] * self.sorting_analyzer.sampling_frequency / 1000.0) return nbefore @property def nafter(self): - if "ms_after" not in self.params: - # compatibility february 2024 > july 2024 - warnings.warn( - "The 'nafter' parameter is deprecated and it's been replaced by 'ms_after' in the params." - "You can save the sorting_analyzer to update the params.", - DeprecationWarning, - stacklevel=2, - ) - self.params["ms_after"] = self.params["nafter"] * 1000.0 / self.sorting_analyzer.sampling_frequency - nafter = int(self.params["ms_after"] * self.sorting_analyzer.sampling_frequency / 1000.0) return nafter From f457ba7a5a29593cfa084749735de6aea2e8b637 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 17 Jul 2024 10:58:26 +0000 Subject: [PATCH 09/18] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/core/analyzer_extension_core.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/spikeinterface/core/analyzer_extension_core.py b/src/spikeinterface/core/analyzer_extension_core.py index fe67da22be..ff1dc5dafa 100644 --- a/src/spikeinterface/core/analyzer_extension_core.py +++ b/src/spikeinterface/core/analyzer_extension_core.py @@ -375,7 +375,6 @@ def _handle_backward_compatibility_on_load(self): # compatibility february 2024 > july 2024 self.params["ms_after"] = self.params["nafter"] * 1000.0 / self.sorting_analyzer.sampling_frequency - def _set_params(self, ms_before: float = 1.0, ms_after: float = 2.0, operators=None): operators = operators or ["average", "std"] assert isinstance(operators, list) From 8b1079bf4f77a711da811209921e8f122fea6716 Mon Sep 17 00:00:00 2001 From: alejoe91 Date: Wed, 17 Jul 2024 11:41:37 +0000 Subject: [PATCH 10/18] waveforms backward compatibility, unit locations, and quality metrics recomputation --- src/spikeinterface/core/sortinganalyzer.py | 4 ++++ .../waveforms_extractor_backwards_compatibility.py | 11 +++++------ src/spikeinterface/postprocessing/unit_locations.py | 7 +++++++ 3 files changed, 16 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 8c5159c865..1200f912fa 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -777,6 +777,10 @@ def _save_or_select_or_merge( # make a copy of extensions # note that the copy of extension handle itself the slicing of units when necessary and also the saveing sorted_extensions = _sort_extensions_by_dependency(self.extensions) + # hack: quality metrics are computed at last + qm_extension_params = sorted_extensions.pop("quality_metrics") + if qm_extension_params is not None: + sorted_extensions["quality_metrics"] = qm_extension_params recompute_dict = {} for extension_name, extension in sorted_extensions.items(): diff --git a/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py b/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py index d6d60ee73b..dc983c6b18 100644 --- a/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py +++ b/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py @@ -547,7 +547,7 @@ def _read_old_waveforms_extractor_binary(folder, sorting): # update params new_params = ext._set_params() updated_params = make_ext_params_up_to_date(ext, params, new_params) - ext.set_params(**updated_params) + ext.set_params(**updated_params, save=False) if new_name == "spike_amplitudes": amplitudes = [] @@ -614,13 +614,12 @@ def make_ext_params_up_to_date(ext, old_params, new_params): old_name = ext.extension_name updated_params = old_params.copy() for p, values in old_params.items(): - if isinstance(values, dict): + if p not in new_params: + warnings.warn(f"Removing legacy param {p} from {old_name} extension") + updated_params.pop(p) + elif isinstance(values, dict): new_values = new_params.get(p, {}) updated_params[p] = make_ext_params_up_to_date(ext, values, new_values) - else: - if p not in new_params: - warnings.warn(f"Removing legacy param {p} from {old_name} extension") - updated_params.pop(p) return updated_params diff --git a/src/spikeinterface/postprocessing/unit_locations.py b/src/spikeinterface/postprocessing/unit_locations.py index 516f22e31e..0aec6e155b 100644 --- a/src/spikeinterface/postprocessing/unit_locations.py +++ b/src/spikeinterface/postprocessing/unit_locations.py @@ -42,10 +42,17 @@ class ComputeUnitLocations(AnalyzerExtension): need_recording = True use_nodepipeline = False need_job_kwargs = False + need_backward_compatibility_on_load = True def __init__(self, sorting_analyzer): AnalyzerExtension.__init__(self, sorting_analyzer) + def _handle_backward_compatibility_on_load(self): + if "method_kwargs" in self.params: + # make compatible analyzer created between february 24 and july 24 + method_kwargs = self.params.pop("method_kwargs") + self.params.updated(**method_kwargs) + def _set_params(self, method="monopolar_triangulation", **method_kwargs): params = dict(method=method) params.update(method_kwargs) From 8091f6027e4c4ebb4aa997b16fb288d0b95e71fc Mon Sep 17 00:00:00 2001 From: alejoe91 Date: Wed, 17 Jul 2024 11:45:50 +0000 Subject: [PATCH 11/18] Fix pop --- src/spikeinterface/core/sortinganalyzer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 1200f912fa..6df5b025f9 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -778,7 +778,7 @@ def _save_or_select_or_merge( # note that the copy of extension handle itself the slicing of units when necessary and also the saveing sorted_extensions = _sort_extensions_by_dependency(self.extensions) # hack: quality metrics are computed at last - qm_extension_params = sorted_extensions.pop("quality_metrics") + qm_extension_params = sorted_extensions.pop("quality_metrics", None) if qm_extension_params is not None: sorted_extensions["quality_metrics"] = qm_extension_params recompute_dict = {} From 097706bf6975f53253ec11c54eeddfd849869728 Mon Sep 17 00:00:00 2001 From: alejoe91 Date: Wed, 17 Jul 2024 11:47:33 +0000 Subject: [PATCH 12/18] Extension job kwargs: check has_recording or has_temporary_recording --- src/spikeinterface/core/sortinganalyzer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 6df5b025f9..a4c25afed8 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -1208,7 +1208,7 @@ def compute_one_extension(self, extension_name, save=True, verbose=False, **kwar # check dependencies if extension_class.need_recording: - assert self.has_recording(), f"Extension {extension_name} requires the recording" + assert self.has_recording() or self.has_temporary_recording(), f"Extension {extension_name} requires the recording" for dependency_name in extension_class.depend_on: if "|" in dependency_name: ok = any(self.get_extension(name) is not None for name in dependency_name.split("|")) From 14615b3058275d6e3ddab6449759e71afc8c33d0 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 17 Jul 2024 11:47:59 +0000 Subject: [PATCH 13/18] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/core/sortinganalyzer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index a4c25afed8..3e92733974 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -1208,7 +1208,9 @@ def compute_one_extension(self, extension_name, save=True, verbose=False, **kwar # check dependencies if extension_class.need_recording: - assert self.has_recording() or self.has_temporary_recording(), f"Extension {extension_name} requires the recording" + assert ( + self.has_recording() or self.has_temporary_recording() + ), f"Extension {extension_name} requires the recording" for dependency_name in extension_class.depend_on: if "|" in dependency_name: ok = any(self.get_extension(name) is not None for name in dependency_name.split("|")) From 456d15bafb9dd10d282d858c48ff5d34a88d94cd Mon Sep 17 00:00:00 2001 From: Zach McKenzie <92116279+zm711@users.noreply.github.com> Date: Wed, 17 Jul 2024 09:13:12 -0400 Subject: [PATCH 14/18] fix dict method --- src/spikeinterface/postprocessing/unit_locations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/postprocessing/unit_locations.py b/src/spikeinterface/postprocessing/unit_locations.py index 0aec6e155b..818f0a8062 100644 --- a/src/spikeinterface/postprocessing/unit_locations.py +++ b/src/spikeinterface/postprocessing/unit_locations.py @@ -51,7 +51,7 @@ def _handle_backward_compatibility_on_load(self): if "method_kwargs" in self.params: # make compatible analyzer created between february 24 and july 24 method_kwargs = self.params.pop("method_kwargs") - self.params.updated(**method_kwargs) + self.params.update(**method_kwargs) def _set_params(self, method="monopolar_triangulation", **method_kwargs): params = dict(method=method) From 7bfccc28f5d07aa4ba29abcf40b559ca2904ce4c Mon Sep 17 00:00:00 2001 From: alejoe91 Date: Thu, 18 Jul 2024 07:20:43 +0000 Subject: [PATCH 15/18] Fix templates backward compatibility --- .../core/waveforms_extractor_backwards_compatibility.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py b/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py index dc983c6b18..c2ab7c8606 100644 --- a/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py +++ b/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py @@ -531,7 +531,11 @@ def _read_old_waveforms_extractor_binary(folder, sorting): templates[mode] = np.load(template_file) if len(templates) > 0: ext = ComputeTemplates(sorting_analyzer) - ext.params = dict(nbefore=nbefore, nafter=nafter, operators=list(templates.keys())) + ext.params = dict( + ms_before=params["ms_before"], + ms_after=params["ms_after"], + operators=list(templates.keys()) + ) for mode, arr in templates.items(): ext.data[mode] = arr sorting_analyzer.extensions["templates"] = ext @@ -548,6 +552,8 @@ def _read_old_waveforms_extractor_binary(folder, sorting): new_params = ext._set_params() updated_params = make_ext_params_up_to_date(ext, params, new_params) ext.set_params(**updated_params, save=False) + if ext.need_backward_compatibility_on_load: + ext._handle_backward_compatibility_on_load() if new_name == "spike_amplitudes": amplitudes = [] From a22182e4e889a0c40a07e7d1bc834eeae4171d4a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 18 Jul 2024 07:22:18 +0000 Subject: [PATCH 16/18] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../core/waveforms_extractor_backwards_compatibility.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py b/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py index c2ab7c8606..d7fad68b27 100644 --- a/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py +++ b/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py @@ -531,11 +531,7 @@ def _read_old_waveforms_extractor_binary(folder, sorting): templates[mode] = np.load(template_file) if len(templates) > 0: ext = ComputeTemplates(sorting_analyzer) - ext.params = dict( - ms_before=params["ms_before"], - ms_after=params["ms_after"], - operators=list(templates.keys()) - ) + ext.params = dict(ms_before=params["ms_before"], ms_after=params["ms_after"], operators=list(templates.keys())) for mode, arr in templates.items(): ext.data[mode] = arr sorting_analyzer.extensions["templates"] = ext From 03c3d412f5b33c7f4b224ed0c0b77885a1770bae Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 18 Jul 2024 14:44:38 +0200 Subject: [PATCH 17/18] Update src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py Co-authored-by: Zach McKenzie <92116279+zm711@users.noreply.github.com> --- .../core/waveforms_extractor_backwards_compatibility.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py b/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py index d7fad68b27..1c7676a302 100644 --- a/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py +++ b/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py @@ -617,7 +617,7 @@ def make_ext_params_up_to_date(ext, old_params, new_params): updated_params = old_params.copy() for p, values in old_params.items(): if p not in new_params: - warnings.warn(f"Removing legacy param {p} from {old_name} extension") + warnings.warn(f"Removing legacy parameter {p} from {old_name} extension") updated_params.pop(p) elif isinstance(values, dict): new_values = new_params.get(p, {}) From f2bc276c62ef41015320f6b9962bd37a3baecd95 Mon Sep 17 00:00:00 2001 From: alejoe91 Date: Thu, 18 Jul 2024 14:44:48 +0000 Subject: [PATCH 18/18] move update params --- .../waveforms_extractor_backwards_compatibility.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py b/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py index c2ab7c8606..28878c6b57 100644 --- a/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py +++ b/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py @@ -548,12 +548,6 @@ def _read_old_waveforms_extractor_binary(folder, sorting): ext = new_class(sorting_analyzer) with open(ext_folder / "params.json", "r") as f: params = json.load(f) - # update params - new_params = ext._set_params() - updated_params = make_ext_params_up_to_date(ext, params, new_params) - ext.set_params(**updated_params, save=False) - if ext.need_backward_compatibility_on_load: - ext._handle_backward_compatibility_on_load() if new_name == "spike_amplitudes": amplitudes = [] @@ -610,6 +604,13 @@ def _read_old_waveforms_extractor_binary(folder, sorting): pc_all[mask, ...] = pc_one ext.data["pca_projection"] = pc_all + # update params + new_params = ext._set_params() + updated_params = make_ext_params_up_to_date(ext, params, new_params) + ext.set_params(**updated_params, save=False) + if ext.need_backward_compatibility_on_load: + ext._handle_backward_compatibility_on_load() + sorting_analyzer.extensions[new_name] = ext return sorting_analyzer