Skip to content

Commit

Permalink
first draft ibl exporter
Browse files Browse the repository at this point in the history
  • Loading branch information
jonahpearl committed Nov 6, 2024
1 parent 29e5b61 commit ac8b481
Showing 1 changed file with 72 additions and 59 deletions.
131 changes: 72 additions & 59 deletions src/spikeinterface/exporters/to_ibl.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,16 @@
from __future__ import annotations

import os
from pathlib import Path
from typing import Literal, Optional
import shutil
from pathlib import Path
from typing import Optional

import numpy as np
import numpy.typing as npt
from scipy.signal import welch
from tqdm.auto import tqdm

from spikeinterface.core import (
BinaryFolderRecording,
BinaryRecordingExtractor,
ChannelSparsity,
WaveformExtractor,
)
from spikeinterface.core import ChannelSparsity, SortingAnalyzer
from spikeinterface.core.template_tools import get_template_extremum_channel
from spikeinterface.exporters import (
export_to_phy,
Expand All @@ -28,35 +23,32 @@
save_object_npy,
)


def export_to_ibl(
recording: BinaryRecordingExtractor | BinaryFolderRecording,
waveform_extractor: WaveformExtractor,
analyzer: SortingAnalyzer,
output_folder: str | Path,
rms_win_length_sec = 3,
welch_win_length_samples = 1024,
total_secs = 100,
rms_win_length_sec=3,
welch_win_length_samples=1024,
total_secs=100,
only_ibl_specific_steps=False,
compute_pc_features: bool = True,
compute_pc_features: bool = False, # shouldn't need these?
compute_amplitudes: bool = True,
sparsity: Optional[ChannelSparsity] = None,
copy_binary: bool = True,
remove_if_exists: bool = False,
peak_sign: Literal["both", "neg", "pos"] = "neg",
template_mode: str = "median",
dtype: Optional[npt.DTypeLike] = None,
verbose: bool = True,
use_relative_path: bool = False,
**job_kwargs,
):
"""
Exports a waveform extractor to the IBL gui format (similar to the Phy format with some extras).
Exports a sorting analyzer to the IBL gui format (similar to the Phy format with some extras).
Parameters
----------
recording: BinaryRecordingExtractor | BinaryFolderRecording
The recording extractor or the recording folder.
waveform_extractor: a WaveformExtractor or None
If WaveformExtractor is provide then the compute is faster otherwise [?].
analyzer: SortingAnalyzer
The sorting analyzer object.
output_folder: str | Path
The output folder where the phy template-gui files are saved
rms_win_length_sec: float, default: 3
Expand All @@ -67,18 +59,16 @@ def export_to_ibl(
The total number of seconds to use for the spectral density calculation.
only_ibl_specific_steps: bool, default: False
If True, only the IBL specific steps are run (i.e. skips calling `export_to_phy`)
compute_pc_features: bool, default: True
compute_pc_features: bool, default: False
If True, pc features are computed
compute_amplitudes: bool, default: True
If True, waveform amplitudes are computed
sparsity: ChannelSparsity or None, default: None
The sparsity object
The sparsity object (currently only respected for phy part of the export)
copy_binary: bool, default: True
If True, the recording is copied and saved in the phy "output_folder"
remove_if_exists: bool, default: False
If True and "output_folder" exists, it is removed and overwritten
peak_sign: "neg" | "pos" | "both", default: "neg"
Used by compute_spike_amplitudes
template_mode: str, default: "median"
Template "mode" parameter, forwarded to export_to_phy as template_mode
dtype: dtype or None, default: None
Expand All @@ -92,32 +82,61 @@ def export_to_ibl(
"""

# Output folder checks
if isinstance(output_folder, str):
output_folder = Path(output_folder)

output_folder = Path(output_folder).absolute()
if output_folder.is_dir():
if remove_if_exists:
shutil.rmtree(output_folder)
else:
raise FileExistsError(f"{output_folder} already exists")
else:
pass
# don't make the output dir yet, b/c export_to_phy will do that for us.

# Start by just exporting to phy
if verbose:
if verbose:
print("Exporting recording to IBL format...")

# Compute any missing extensions
available_extension_names = analyzer.get_saved_extension_names()
required_exts = [
"templates",
"template_similarity",
"spike_locations",
"noise_levels",
"quality_metrics",
]
required_qms = ["amplitude_median", "isi_violations_ratio", "amplitude_cutoff"]
for ext in required_exts:
if ext not in available_extension_names:
if ext == "quality_metrics":
kwargs = {"skip_pc_metrics": not compute_pc_features}
else:
kwargs = {}
analyzer.compute(ext, verbose=verbose, **kwargs)
elif ext == "quality_metrics":
qm = analyzer.get_extension("quality_metrics").get_data()
for rqm in required_qms:
if rqm not in qm:
analyzer.compute(
"quality_metrics",
metric_names=[rqm],
verbose=verbose,
)

# Start by just exporting to phy
if not only_ibl_specific_steps:
if verbose:
if verbose:
print("Doing phy-like export...")
export_to_phy(
waveform_extractor,
analyzer,
output_folder,
compute_amplitudes=compute_amplitudes,
compute_pc_features=compute_pc_features,
sparsity=sparsity,
copy_binary=copy_binary,
remove_if_exists=remove_if_exists,
peak_sign=peak_sign,
template_mode=template_mode,
dtype=dtype,
verbose=verbose,
Expand All @@ -130,32 +149,34 @@ def export_to_ibl(

# Now we need to add the extra IBL specific files
(channel_inds,) = np.isin(
recording.channel_ids, waveform_extractor.channel_ids
analyzer.recording.channel_ids, analyzer.channel_ids
).nonzero()

### Run spectral density and rms ###
fs_ap = recording.sampling_frequency
fs_ap = analyzer.recording.sampling_frequency
rms_win_length_samples_ap = 2 ** np.ceil(np.log2(fs_ap * rms_win_length_sec))
total_samples_ap = int(np.min([fs_ap * total_secs, recording.get_num_samples()]))
total_samples_ap = int(
np.min([fs_ap * total_secs, analyzer.recording.get_num_samples()])
)

# the window generator yields the window indices
wingen = WindowGenerator(
ns=total_samples_ap, nswin=rms_win_length_samples_ap, overlap=0
)
win = {
"TRMS": np.zeros((wingen.nwin, recording.get_num_channels())),
"TRMS": np.zeros((wingen.nwin, analyzer.recording.get_num_channels())),
"nsamples": np.zeros((wingen.nwin,)),
"fscale": fscale(welch_win_length_samples, 1 / fs_ap, one_sided=True),
"tscale": wingen.tscale(fs=fs_ap),
}
win["spectral_density"] = np.zeros(
(len(win["fscale"]), recording.get_num_channels())
(len(win["fscale"]), analyzer.recording.get_num_channels())
)

# @Josh: this could be dramatically sped up if we employ SpikeInterface parallelization
with tqdm(total=wingen.nwin) as pbar:
for first, last in wingen.firstlast:
D = recording.get_traces(start_frame=first, end_frame=last).T
D = analyzer.recording.get_traces(start_frame=first, end_frame=last).T
# remove low frequency noise below 1 Hz
D = hp(D, 1 / fs_ap, [0, 1])
iw = wingen.iw
Expand Down Expand Up @@ -206,16 +227,7 @@ def export_to_ibl(

### Save spike info ###

# Confirm spike locations are available
available_extension_names = waveform_extractor.get_available_extension_names()
if "spike_locations" not in available_extension_names:
from spikeinterface.postprocessing import compute_spike_locations

compute_spike_locations(
waveform_extractor, verbose=verbose
) # this should auto-save it

spike_locations = waveform_extractor.load_extension("spike_locations").get_data()
spike_locations = analyzer.load_extension("spike_locations").get_data()
spike_depths = spike_locations["y"]

# convert clusters and squeeze
Expand All @@ -242,23 +254,20 @@ def export_to_ibl(
cluster_channels = []
cluster_peakToTrough = []
cluster_waveforms = []
# num_chans = []

templates = waveform_extractor.get_all_templates()
# channel_locs = waveform_extractor.get_channel_locations()
extremum_channel_indices = get_template_extremum_channel(
waveform_extractor, outputs="index"
)
templates = analyzer.get_extension("templates").get_data()
extremum_channel_indices = get_template_extremum_channel(analyzer, outputs="index")

for unit_idx, unit_id in enumerate(waveform_extractor.unit_ids):
for unit_idx, unit_id in enumerate(analyzer.unit_ids):
waveform = templates[unit_idx, :, :]
extremum_channel_index = extremum_channel_indices[unit_id]
peak_waveform = waveform[:, extremum_channel_index]
peakToTrough = (
np.argmax(peak_waveform) - np.argmin(peak_waveform)
) / waveform_extractor.sampling_frequency
) / analyzer.sampling_frequency
# cluster_channels.append(int(channel_locs[extremum_channel_index, 1] / 10)) # ??? fails for odd nums of units
cluster_channels.append(extremum_channel_index) # see: https://github.com/SpikeInterface/spikeinterface/issues/2843#issuecomment-2148164870
cluster_channels.append(
extremum_channel_index
) # see: https://github.com/SpikeInterface/spikeinterface/issues/2843#issuecomment-2148164870
cluster_peakToTrough.append(peakToTrough)
cluster_waveforms.append(waveform)

Expand Down Expand Up @@ -286,12 +295,16 @@ def export_to_ibl(
os.rename(old_name, new_name)

# save quality metrics
qm = waveform_extractor.load_extension("quality_metrics")
qm = analyzer.load_extension("quality_metrics")
qm_data = qm.get_data()

qm_data.index.name = "cluster_id"
qm_data["cluster_id.1"] = qm_data.index.values

good_ibl = ( # rough estimate of ibl standards
(qm_data["amplitude_median"] > 50)
& (qm_data["isi_violations_ratio"] < 0.2)
& (qm_data["amplitude_cutoff"] < 0.1)
)
qm_data["label"] = good_ibl.astype("int")
qm_data.to_csv(output_folder / "clusters.metrics.csv")


Expand Down

0 comments on commit ac8b481

Please sign in to comment.