diff --git a/src/spikeinterface/exporters/to_ibl.py b/src/spikeinterface/exporters/to_ibl.py index f52d236897..d6dbea3cc3 100644 --- a/src/spikeinterface/exporters/to_ibl.py +++ b/src/spikeinterface/exporters/to_ibl.py @@ -1,21 +1,16 @@ from __future__ import annotations import os -from pathlib import Path -from typing import Literal, Optional import shutil +from pathlib import Path +from typing import Optional import numpy as np import numpy.typing as npt from scipy.signal import welch from tqdm.auto import tqdm -from spikeinterface.core import ( - BinaryFolderRecording, - BinaryRecordingExtractor, - ChannelSparsity, - WaveformExtractor, -) +from spikeinterface.core import ChannelSparsity, SortingAnalyzer from spikeinterface.core.template_tools import get_template_extremum_channel from spikeinterface.exporters import ( export_to_phy, @@ -28,20 +23,19 @@ save_object_npy, ) + def export_to_ibl( - recording: BinaryRecordingExtractor | BinaryFolderRecording, - waveform_extractor: WaveformExtractor, + analyzer: SortingAnalyzer, output_folder: str | Path, - rms_win_length_sec = 3, - welch_win_length_samples = 1024, - total_secs = 100, + rms_win_length_sec=3, + welch_win_length_samples=1024, + total_secs=100, only_ibl_specific_steps=False, - compute_pc_features: bool = True, + compute_pc_features: bool = False, # shouldn't need these? compute_amplitudes: bool = True, sparsity: Optional[ChannelSparsity] = None, copy_binary: bool = True, remove_if_exists: bool = False, - peak_sign: Literal["both", "neg", "pos"] = "neg", template_mode: str = "median", dtype: Optional[npt.DTypeLike] = None, verbose: bool = True, @@ -49,14 +43,12 @@ def export_to_ibl( **job_kwargs, ): """ - Exports a waveform extractor to the IBL gui format (similar to the Phy format with some extras). + Exports a sorting analyzer to the IBL gui format (similar to the Phy format with some extras). Parameters ---------- - recording: BinaryRecordingExtractor | BinaryFolderRecording - The recording extractor or the recording folder. - waveform_extractor: a WaveformExtractor or None - If WaveformExtractor is provide then the compute is faster otherwise [?]. + analyzer: SortingAnalyzer + The sorting analyzer object. output_folder: str | Path The output folder where the phy template-gui files are saved rms_win_length_sec: float, default: 3 @@ -67,18 +59,16 @@ def export_to_ibl( The total number of seconds to use for the spectral density calculation. only_ibl_specific_steps: bool, default: False If True, only the IBL specific steps are run (i.e. skips calling `export_to_phy`) - compute_pc_features: bool, default: True + compute_pc_features: bool, default: False If True, pc features are computed compute_amplitudes: bool, default: True If True, waveforms amplitudes are computed sparsity: ChannelSparsity or None, default: None - The sparsity object + The sparsity object (currently only respected for phy part of the export) copy_binary: bool, default: True If True, the recording is copied and saved in the phy "output_folder" remove_if_exists: bool, default: False If True and "output_folder" exists, it is removed and overwritten - peak_sign: "neg" | "pos" | "both", default: "neg" - Used by compute_spike_amplitudes template_mode: str, default: "median" Parameter "mode" to be given to WaveformExtractor.get_template() dtype: dtype or None, default: None @@ -92,32 +82,61 @@ def export_to_ibl( """ + # Output folder checks if isinstance(output_folder, str): output_folder = Path(output_folder) - output_folder = Path(output_folder).absolute() if output_folder.is_dir(): if remove_if_exists: shutil.rmtree(output_folder) else: raise FileExistsError(f"{output_folder} already exists") + else: + pass + # don't make the output dir yet, b/c export_to_phy will do that for us. - # Start by just exporting to phy - if verbose: + if verbose: print("Exporting recording to IBL format...") + # Compute any missing extensions + available_extension_names = analyzer.get_saved_extension_names() + required_exts = [ + "templates", + "template_similarity", + "spike_locations", + "noise_levels", + "quality_metrics", + ] + required_qms = ["amplitude_median", "isi_violations_ratio", "amplitude_cutoff"] + for ext in required_exts: + if ext not in available_extension_names: + if ext == "quality_metrics": + kwargs = {"skip_pc_metrics": not compute_pc_features} + else: + kwargs = {} + analyzer.compute(ext, verbose=verbose, **kwargs) + elif ext == "quality_metrics": + qm = analyzer.get_extension("quality_metrics").get_data() + for rqm in required_qms: + if rqm not in qm: + analyzer.compute( + "quality_metrics", + metric_names=[rqm], + verbose=verbose, + ) + + # # Start by just exporting to phy if not only_ibl_specific_steps: - if verbose: + if verbose: print("Doing phy-like export...") export_to_phy( - waveform_extractor, + analyzer, output_folder, compute_amplitudes=compute_amplitudes, compute_pc_features=compute_pc_features, sparsity=sparsity, copy_binary=copy_binary, remove_if_exists=remove_if_exists, - peak_sign=peak_sign, template_mode=template_mode, dtype=dtype, verbose=verbose, @@ -130,32 +149,34 @@ def export_to_ibl( # Now we need to add the extra IBL specific files (channel_inds,) = np.isin( - recording.channel_ids, waveform_extractor.channel_ids + analyzer.recording.channel_ids, analyzer.channel_ids ).nonzero() ### Run spectral density and rms ### - fs_ap = recording.sampling_frequency + fs_ap = analyzer.recording.sampling_frequency rms_win_length_samples_ap = 2 ** np.ceil(np.log2(fs_ap * rms_win_length_sec)) - total_samples_ap = int(np.min([fs_ap * total_secs, recording.get_num_samples()])) + total_samples_ap = int( + np.min([fs_ap * total_secs, analyzer.recording.get_num_samples()]) + ) # the window generator will generates window indices wingen = WindowGenerator( ns=total_samples_ap, nswin=rms_win_length_samples_ap, overlap=0 ) win = { - "TRMS": np.zeros((wingen.nwin, recording.get_num_channels())), + "TRMS": np.zeros((wingen.nwin, analyzer.recording.get_num_channels())), "nsamples": np.zeros((wingen.nwin,)), "fscale": fscale(welch_win_length_samples, 1 / fs_ap, one_sided=True), "tscale": wingen.tscale(fs=fs_ap), } win["spectral_density"] = np.zeros( - (len(win["fscale"]), recording.get_num_channels()) + (len(win["fscale"]), analyzer.recording.get_num_channels()) ) # @Josh: this could be dramatically sped up if we employ SpikeInterface parallelization with tqdm(total=wingen.nwin) as pbar: for first, last in wingen.firstlast: - D = recording.get_traces(start_frame=first, end_frame=last).T + D = analyzer.recording.get_traces(start_frame=first, end_frame=last).T # remove low frequency noise below 1 Hz D = hp(D, 1 / fs_ap, [0, 1]) iw = wingen.iw @@ -206,16 +227,7 @@ def export_to_ibl( ### Save spike info ### - # Confirm spike locations are available - available_extension_names = waveform_extractor.get_available_extension_names() - if "spike_locations" not in available_extension_names: - from spikeinterface.postprocessing import compute_spike_locations - - compute_spike_locations( - waveform_extractor, verbose=verbose - ) # this should auto-save it - - spike_locations = waveform_extractor.load_extension("spike_locations").get_data() + spike_locations = analyzer.load_extension("spike_locations").get_data() spike_depths = spike_locations["y"] # convert clusters and squeeze @@ -242,23 +254,20 @@ def export_to_ibl( cluster_channels = [] cluster_peakToTrough = [] cluster_waveforms = [] - # num_chans = [] - - templates = waveform_extractor.get_all_templates() - # channel_locs = waveform_extractor.get_channel_locations() - extremum_channel_indices = get_template_extremum_channel( - waveform_extractor, outputs="index" - ) + templates = analyzer.get_extension("templates").get_data() + extremum_channel_indices = get_template_extremum_channel(analyzer, outputs="index") - for unit_idx, unit_id in enumerate(waveform_extractor.unit_ids): + for unit_idx, unit_id in enumerate(analyzer.unit_ids): waveform = templates[unit_idx, :, :] extremum_channel_index = extremum_channel_indices[unit_id] peak_waveform = waveform[:, extremum_channel_index] peakToTrough = ( np.argmax(peak_waveform) - np.argmin(peak_waveform) - ) / waveform_extractor.sampling_frequency + ) / analyzer.sampling_frequency # cluster_channels.append(int(channel_locs[extremum_channel_index, 1] / 10)) # ??? fails for odd nums of units - cluster_channels.append(extremum_channel_index) # see: https://github.com/SpikeInterface/spikeinterface/issues/2843#issuecomment-2148164870 + cluster_channels.append( + extremum_channel_index + ) # see: https://github.com/SpikeInterface/spikeinterface/issues/2843#issuecomment-2148164870 cluster_peakToTrough.append(peakToTrough) cluster_waveforms.append(waveform) @@ -286,12 +295,16 @@ def export_to_ibl( os.rename(old_name, new_name) # save quality metrics - qm = waveform_extractor.load_extension("quality_metrics") + qm = analyzer.load_extension("quality_metrics") qm_data = qm.get_data() - qm_data.index.name = "cluster_id" qm_data["cluster_id.1"] = qm_data.index.values - + good_ibl = ( # rough estimate of ibl standards + (qm_data["amplitude_median"] > 50) + & (qm_data["isi_violations_ratio"] < 0.2) + & (qm_data["amplitude_cutoff"] < 0.1) + ) + qm_data["label"] = good_ibl.astype("int") qm_data.to_csv(output_folder / "clusters.metrics.csv")