From 7135f28d0a653d41b9d11662c310c5f88ee498c1 Mon Sep 17 00:00:00 2001 From: Jonah Pearl Date: Wed, 6 Nov 2024 12:54:53 -0500 Subject: [PATCH] update analyzer extension checks --- src/spikeinterface/exporters/to_phy.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/spikeinterface/exporters/to_phy.py b/src/spikeinterface/exporters/to_phy.py index 06041da231..4b3b914733 100644 --- a/src/spikeinterface/exporters/to_phy.py +++ b/src/spikeinterface/exporters/to_phy.py @@ -194,7 +194,7 @@ def export_to_phy( templates[unit_ind, :, :][:, : len(chan_inds)] = template templates_ind[unit_ind, : len(chan_inds)] = chan_inds - if not sorting_analyzer.has_extension("template_similarity"): + if sorting_analyzer.get_extension("template_similarity") is None: sorting_analyzer.compute("template_similarity") template_similarity = sorting_analyzer.get_extension("template_similarity").get_data() @@ -215,14 +215,14 @@ def export_to_phy( np.save(str(output_folder / "channel_groups.npy"), channel_groups) if compute_amplitudes: - if not sorting_analyzer.has_extension("spike_amplitudes"): + if sorting_analyzer.get_extension("spike_amplitudes") is None: sorting_analyzer.compute("spike_amplitudes", **job_kwargs) amplitudes = sorting_analyzer.get_extension("spike_amplitudes").get_data() amplitudes = amplitudes[:, np.newaxis] np.save(str(output_folder / "amplitudes.npy"), amplitudes) if compute_pc_features: - if not sorting_analyzer.has_extension("principal_components"): + if sorting_analyzer.get_extension("principal_components") is None: sorting_analyzer.compute("principal_components", n_components=5, mode="by_channel_local", **job_kwargs) pca_extension = sorting_analyzer.get_extension("principal_components") @@ -250,7 +250,7 @@ def export_to_phy( channel_group = pd.DataFrame({"cluster_id": [i for i in range(len(unit_ids))], "channel_group": unit_groups}) channel_group.to_csv(output_folder / "cluster_channel_group.tsv", sep="\t", index=False) - if sorting_analyzer.has_extension("quality_metrics") and add_quality_metrics: + if sorting_analyzer.get_extension("quality_metrics") is not None and add_quality_metrics: qm_data = sorting_analyzer.get_extension("quality_metrics").get_data() for column_name in qm_data.columns: # already computed by phy @@ -259,7 +259,7 @@ def export_to_phy( {"cluster_id": [i for i in range(len(unit_ids))], column_name: qm_data[column_name].values} ) metric.to_csv(output_folder / f"cluster_{column_name}.tsv", sep="\t", index=False) - if sorting_analyzer.has_extension("template_metrics") and add_template_metrics: + if sorting_analyzer.get_extension("template_metrics") is not None and add_template_metrics: tm_data = sorting_analyzer.get_extension("template_metrics").get_data() for column_name in tm_data.columns: metric = pd.DataFrame(