Add synchrony metrics to quality metrics #1205

49 changes: 49 additions & 0 deletions doc/quality_metrics/synchrony_metrics.rst
@@ -0,0 +1,49 @@
Synchrony Metrics (:code:`synchrony_metrics`)
=============================================

Calculation
-----------
This function provides a metric for the presence of synchronous spiking events across multiple spike trains.

The complexity is used to characterize synchronous events within the same spike train and across different spike
trains. This way, synchronous events can be detected in both multi-unit and single-unit spike trains.

Complexity is calculated by counting the number of spikes (i.e. non-empty bins) that occur separated by ``spread - 1`` or fewer empty bins,
within and across spike trains in the ``spiketrains`` list.
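
For intuition, here is a minimal sketch of the exact-coincidence case (``spread = 0``),
assuming spike trains that have already been binned into a common set of time bins
(illustrative only, not the elephant implementation):

.. code-block:: python

    import numpy as np

    # three toy spike trains over the same four time bins (1 = spike in bin)
    binned = np.array([[1, 0, 1, 0],
                       [1, 0, 0, 0],
                       [1, 0, 1, 0]])

    # complexity of a bin = number of spikes that coincide in that bin
    complexity = binned.sum(axis=0)  # -> array([3, 0, 2, 0])

    # a spike whose bin reaches complexity >= 2 takes part in a synchronous event
    synchronous_bins = complexity >= 2  # -> array([ True, False,  True, False])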

Expectation and use
-------------------
A larger value indicates a higher degree of synchrony between the respective spike train and the other spike trains.

Example code
------------

.. code-block:: python

    import spikeinterface.qualitymetrics as qm

    # Make recording, sorting and wvf_extractor object for your data.
    synchrony_metrics = qm.compute_synchrony_metrics(wvf_extractor)
    # synchrony_metrics is a namedtuple of dicts with the synchrony metrics for each unit

Links to source code
--------------------

From `Elephant - Electrophysiology Analysis Toolkit <https://github.com/NeuralEnsemble/elephant/blob/master/elephant/spike_train_synchrony.py#L245>`_


References
----------

.. automodule:: spikeinterface.qualitymetrics.misc_metrics

.. autofunction:: compute_synchrony_metrics

Literature
----------

Described in [Gruen]_

Citations
---------
.. [Gruen] Sonja Grün, Moshe Abeles, and Markus Diesmann. Impact of higher-order correlations on coincidence distributions of massively parallel data.
In International School on Neural Networks, Initiated by IIASS and EMFCSC, volume 5286, 96–114. Springer, 2007.
5 changes: 3 additions & 2 deletions pyproject.toml
@@ -26,6 +26,7 @@ dependencies = [
"threadpoolctl",
"tqdm",
"probeinterface>=0.2.14",
"elephant>=0.11.0"
Member:
This should be moved to the `full` extra, since elephant is used by submodules.

In addition, the new code really only uses this one function from elephant. I think it would be better to directly port that function (with proper references), so we don't have to rely on an additional dependency.
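
For reference, a rough sketch of what such a port might look like for the exact-coincidence case (`spread = 0`); the name and signature here are placeholders, not a final API:

```python
import numpy as np

def _complexity_per_spike(spike_frames_per_unit):
    """For each unit, count how many spikes (across all units, the unit's
    own spikes included) fall on the same sample frame as each of its spikes."""
    all_frames = np.concatenate(spike_frames_per_unit)
    frames, counts = np.unique(all_frames, return_counts=True)
    counts_at = dict(zip(frames, counts))
    return [np.array([counts_at[f] for f in unit_frames])
            for unit_frames in spike_frames_per_unit]
```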

Author:

Thanks for the feedback, I see your point.

Please allow me some time to look into this.

Member:
Of course :) no rush!

]

[build-system]
@@ -123,12 +124,12 @@ test = [
"pymde",
"torch",
"pynndescent",

# for github test : probeinterface and neo from master
# for release we need pypi, so this need to be commented
"probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git",
"neo @ git+https://github.com/NeuralEnsemble/python-neo.git",

]


102 changes: 94 additions & 8 deletions spikeinterface/qualitymetrics/misc_metrics.py
@@ -13,6 +13,10 @@
import numpy as np
import warnings
import scipy.ndimage
from elephant.spike_train_synchrony import Synchrotool
import quantities as pq
import neo


from ..core import get_noise_levels
from ..core.template_tools import (
@@ -103,7 +107,7 @@ def compute_presence_ratio(waveform_extractor, bin_duration_s=60):
    waveform_extractor : WaveformExtractor
        The waveform extractor object.
    bin_duration_s : float, optional, default: 60
        The duration of each bin in seconds. If the duration is less than this value,
        presence_ratio is set to NaN

    Returns
@@ -277,12 +281,12 @@ def compute_isi_violations(waveform_extractor, isi_threshold_ms=1.5, min_isi_ms=
            num_spikes += len(spike_train)
            num_violations += np.sum(isis < isi_threshold_samples)
        violation_time = 2 * num_spikes * (isi_threshold_s - min_isi_s)

        if num_spikes > 0:
            total_rate = num_spikes / total_duration
            violation_rate = num_violations / violation_time
            isi_violations_ratio[unit_id] = violation_rate / total_rate
            isi_violations_count[unit_id] = num_violations
        else:
            isi_violations_ratio[unit_id] = np.nan
            isi_violations_count[unit_id] = np.nan
@@ -294,7 +298,7 @@ def compute_isi_violations(waveform_extractor, isi_threshold_ms=1.5, min_isi_ms=


_default_params["isi_violations"] = dict(
isi_threshold_ms=1.5,
isi_threshold_ms=1.5,
min_isi_ms=0
)

@@ -396,7 +400,7 @@ def compute_amplitudes_cutoff(waveform_extractor, peak_sign='neg',
        Controls the smoothing applied to the amplitude histogram.
    amplitudes_bins_min_ratio : int, optional, default: 5
        The minimum ratio between number of amplitudes for a unit and the number of bins.
        If the ratio is less than this threshold, the amplitude_cutoff for the unit is set
        to NaN

Returns
@@ -413,7 +417,7 @@
    Notes
    -----
    This approach assumes the amplitude histogram is symmetric (not valid in the presence of drift).
    If available, amplitudes are extracted from the "spike_amplitude" extension (recommended).
    If the "spike_amplitude" extension is not available, the amplitudes are extracted from the waveform extractor,
    which usually has waveforms for a small subset of spikes (500 by default).
    """
@@ -462,13 +466,13 @@ def compute_amplitudes_cutoff(waveform_extractor, peak_sign='neg',
        support = b[:-1]
        bin_size = np.mean(np.diff(support))
        peak_index = np.argmax(pdf)

        pdf_above = np.abs(pdf[peak_index:] - pdf[0])

        if len(np.where(pdf_above == pdf_above.min())[0]) > 1:
            warnings.warn("Amplitude PDF does not have a unique minimum! More spikes might be required for a correct "
                          "amplitude_cutoff computation!")

        G = np.argmin(pdf_above) + peak_index
        fraction_missing = np.sum(pdf[G:]) * bin_size
        fraction_missing = np.min([fraction_missing, 0.5])
@@ -517,3 +521,85 @@ def _compute_rp_violations_numba(nb_rp_violations, spike_trains, spike_clusters,
        spike_train = spike_trains[spike_clusters == i]
        n_v = _compute_nb_violations_numba(spike_train, t_r)
        nb_rp_violations[i] += n_v


def compute_synchrony_metrics(waveform_extractor, synchrony_sizes=(0, 2), **kwargs):
    """Compute synchrony metrics for each unit and for each synchrony size.

    Parameters
    ----------
    waveform_extractor : WaveformExtractor
        The waveform extractor object.
    synchrony_sizes : list of int
        Sizes of synchronous events to consider for synchrony metrics.

    Returns
    -------
    synchrony_metrics : namedtuple
        Synchrony metrics for each unit and for each synchrony size.

    Notes
    -----
    This function uses the Synchrotool from the elephant library to compute synchrony metrics.
    """

    sampling_rate = waveform_extractor.sorting.get_sampling_frequency()
    # get a list of neo.SpikeTrains
    spiketrains = _create_list_of_neo_spiketrains(waveform_extractor.sorting, sampling_rate)
    # get spike counts
    spike_counts = np.array([len(st) for st in spiketrains])
    # avoid division by zero; for zero spikes we want metric = 0
    spike_counts[spike_counts == 0] = 1

    # Synchrony
    synchrotool = Synchrotool(spiketrains, sampling_rate=sampling_rate * pq.Hz)
    # free some memory
    synchrotool.complexity_histogram = []
    synchrotool.time_histogram = []
    # annotate each spike with its complexity ('complexity' array annotation on each spiketrain)
    synchrotool.annotate_synchrofacts()

    # create a dictionary 'synchrony_metrics'
    synchrony_metrics = {
        # create a dictionary for each synchrony_size
        f'syncSpike_{synchrony_size}': {
            # create an entry for each spiketrain
            spiketrain.annotations['cluster_id']:
                # count occurrences where 'complexity' >= synchrony_size and divide by the spike count
                np.count_nonzero(spiketrain.array_annotations['complexity'] >= synchrony_size) / spike_counts[idx]
            for idx, spiketrain in enumerate(spiketrains)}
        for synchrony_size in synchrony_sizes}

    # convert dict to namedtuple
    synchrony_metrics_tuple = namedtuple('SynchronyMetrics', synchrony_metrics.keys())
    synchrony_metrics = synchrony_metrics_tuple(**synchrony_metrics)
    return synchrony_metrics


_default_params["synchrony_metrics"] = dict(
synchrony_sizes=(0, 2)
)

def _create_list_of_neo_spiketrains(sorting, sampling_rate):
    """Create a list of neo.SpikeTrains from a SortingExtractor."""

    def _create_neo_spiketrain(unit_id, segment_index):
        """Create a neo.SpikeTrain object from a unit_id and segment_index."""
        unit_spiketrain = sorting.get_unit_spike_train(unit_id=unit_id, segment_index=segment_index)
        return neo.SpikeTrain(
            unit_spiketrain * pq.ms,
            t_stop=max(unit_spiketrain) * pq.ms if len(unit_spiketrain) != 0 else 1 * pq.ms,
            sampling_rate=sampling_rate * pq.Hz,
            cluster_id=unit_id)

    unit_ids = sorting.unit_ids
    num_segs = sorting.get_num_segments()

    # create a list of neo.SpikeTrain
    spiketrains = [_create_neo_spiketrain(unit_id, segment_index)
                   for unit_id in unit_ids for segment_index in range(num_segs)]

    # set common t_start, t_stop for all spiketrains
    t_start = min(st.t_start for st in spiketrains)
    t_stop = max(st.t_stop for st in spiketrains) + 1 * pq.s
    for spiketrain in spiketrains:
        spiketrain.t_start = t_start
        spiketrain.t_stop = t_stop
    return spiketrains
4 changes: 3 additions & 1 deletion spikeinterface/qualitymetrics/quality_metric_list.py
@@ -8,6 +8,7 @@
    compute_isi_violations,
    compute_refrac_period_violations,
    compute_amplitudes_cutoff,
    compute_synchrony_metrics,
)

from .pca_metrics import (
@@ -30,7 +31,8 @@
"snr" : compute_snrs,
"isi_violations" : compute_isi_violations,
"rp_violations" : compute_refrac_period_violations,
"amplitude_cutoff" : compute_amplitudes_cutoff
"amplitude_cutoff" : compute_amplitudes_cutoff,
"synchrony_metrics" : compute_synchrony_metrics,
}


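For completeness, a usage sketch for the metric registered above (assuming a `WaveformExtractor` named `we` has already been computed for your recording and sorting):

```python
import spikeinterface.qualitymetrics as qm

# call the new metric directly ...
sync = qm.compute_synchrony_metrics(we, synchrony_sizes=(0, 2))
# each field is a dict {unit_id: fraction of spikes in synchronous events of >= that size}
print(sync.syncSpike_2)

# ... or request it by its registered name via the generic entry point
metrics = qm.compute_quality_metrics(we, metric_names=["synchrony_metrics"])
```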