From 81b13e8988342c52a02b2e88007b2b8b585310a8 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Sat, 16 Mar 2024 09:36:51 +0100 Subject: [PATCH] Fix visualize tests --- .../visualization/visualization.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/src/spikeinterface_pipelines/visualization/visualization.py b/src/spikeinterface_pipelines/visualization/visualization.py index 71cd873..9e0fa33 100644 --- a/src/spikeinterface_pipelines/visualization/visualization.py +++ b/src/spikeinterface_pipelines/visualization/visualization.py @@ -74,14 +74,17 @@ def visualize( decimation_factor = recording_params["drift"]["decimation_factor"] alpha = recording_params["drift"]["alpha"] - # use spike locations - if not waveform_extractor.has_extension("quality_metrics"): - logger.info("[Visualization] \tVisualizing drift maps using pre-computed spike locations") - peaks = waveform_extractor.sorting.to_spike_vector() - peak_locations = waveform_extractor.load_extension("spike_locations").get_data() - peak_amps = np.concatenate(waveform_extractor.load_extension("spike_amplitudes").get_data()) + # check if spike locations are available + spike_locations_available = False + if waveform_extractor is not None: + if waveform_extractor.has_extension("spike_locations"): + logger.info("[Visualization] \tVisualizing drift maps using pre-computed spike locations") + peaks = waveform_extractor.sorting.to_spike_vector() + peak_locations = waveform_extractor.load_extension("spike_locations").get_data() + peak_amps = np.concatenate(waveform_extractor.load_extension("spike_amplitudes").get_data()) + spike_locations_available = True # otherwise detect peaks - else: + if not spike_locations_available: from spikeinterface.core.node_pipeline import ExtractDenseWaveforms, run_node_pipeline from spikeinterface.sortingcomponents.peak_detection import DetectPeakLocallyExclusive from spikeinterface.sortingcomponents.peak_localization import LocalizeCenterOfMass