
Implement synchrony metrics #1951

Merged · 11 commits · Sep 7, 2023
2 changes: 2 additions & 0 deletions doc/modules/qualitymetrics/references.rst
@@ -11,6 +11,8 @@ References

.. [Hruschka] Hruschka, E.R., de Castro, L.N., Campello R.J.G.B. "Evolutionary algorithms for clustering gene-expression data." Fourth IEEE International Conference on Data Mining (ICDM'04) 2004, pp 403-406.

.. [Gruen] Sonja Grün, Moshe Abeles, and Markus Diesmann. Impact of higher-order correlations on coincidence distributions of massively parallel data. In International School on Neural Networks, Initiated by IIASS and EMFCSC, volume 5286, 96–114. Springer, 2007.

.. [IBL] International Brain Laboratory. “Spike sorting pipeline for the International Brain Laboratory”. 4 May 2022.

.. [Jackson] Jadin Jackson, Neil Schmitzer-Torbert, K.D. Harris, and A.D. Redish. Quantitative assessment of extracellular multichannel recording quality using measures of cluster separation. Soc Neurosci Abstr, 518, 01 2005.
49 changes: 49 additions & 0 deletions doc/modules/qualitymetrics/synchrony.rst
@@ -0,0 +1,49 @@
Synchrony Metrics (:code:`synchrony`)
=====================================

Calculation
-----------
This function provides a metric for the presence of synchronous spiking events across multiple spike trains.

The complexity is used to characterize synchronous events within the same spike train and across different spike
trains. This way, synchronous events can be found both in multi-unit and single-unit spike trains.
The complexity is calculated by counting the number of spikes (i.e. non-empty bins) that occur at the same sample index,
within and across spike trains.

Synchrony metrics can be computed for different synchrony sizes (>1), which define the number of simultaneous spikes to count.
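
The counting step can be sketched as follows (a minimal illustration, not the library code; ``spike_samples`` and ``spike_units`` are hypothetical arrays holding the sample index and unit label of every spike in one segment):

.. code-block:: python

import numpy as np

# hypothetical inputs: one entry per spike
spike_samples = np.array([10, 10, 10, 25, 40, 40])
spike_units = np.array([0, 1, 2, 0, 1, 2])

# complexity = number of spikes (non-empty bins) sharing each sample index
unique_samples, complexity = np.unique(spike_samples, return_counts=True)
# complexity is [3, 1, 2]: a 3-spike event at sample 10, a 2-spike event at sample 40

# fraction of unit 0's spikes taking part in events of size >= 2
unit_samples = spike_samples[spike_units == 0]
unit_complexity = complexity[np.isin(unique_samples, unit_samples)]
sync_spike_2 = np.count_nonzero(unit_complexity >= 2) / len(unit_samples)  # 0.5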



Expectation and use
-------------------

A larger value indicates higher synchrony of the respective spike train with the other spike trains.
Larger values, especially for larger synchrony sizes, indicate a higher probability of noisy spikes in the spike train.

Example code
------------

.. code-block:: python

import spikeinterface.qualitymetrics as qm
# Create the recording, sorting, and wvf_extractor objects for your data.
synchrony = qm.compute_synchrony_metrics(wvf_extractor, synchrony_sizes=(2, 4, 8))
# synchrony is a namedtuple of dicts (one per synchrony size) mapping each unit id to its metric value
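# each metric can be accessed by attribute; keys are unit ids (values illustrative)
synchrony.sync_spike_2  # e.g. {0: 0.1, 1: 0.3, ...}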


Links to original implementations
---------------------------------

The SpikeInterface implementation is a partial port of the low-level complexity functions from `Elephant - Electrophysiology Analysis Toolkit <https://github.com/NeuralEnsemble/elephant/blob/master/elephant/spike_train_synchrony.py#L245>`_.

References
----------

.. automodule:: spikeinterface.qualitymetrics.misc_metrics

.. autofunction:: compute_synchrony_metrics

Literature
----------

Based on concepts described in [Gruen]_
1 change: 1 addition & 0 deletions src/spikeinterface/core/__init__.py
@@ -28,6 +28,7 @@
from .generate import (
generate_recording,
generate_sorting,
add_synchrony_to_sorting,
create_sorting_npz,
generate_snippets,
synthesize_random_firings,
78 changes: 78 additions & 0 deletions src/spikeinterface/core/generate.py
@@ -120,6 +120,31 @@ def generate_sorting(
refractory_period_ms=3.0, # in ms
seed=None,
):
"""
Generates sorting object with random firings.

Parameters
----------
num_units : int, default: 5
Number of units
sampling_frequency : float, default: 30000.0
The sampling frequency
durations : list, default: [10.325, 3.5]
Duration of each segment in s
firing_rates : float, default: 3.0
The firing rate of each unit (in Hz).
empty_units : list, default: None
List of units that will have no spikes (mainly used for testing).
refractory_period_ms : float, default: 3.0
The refractory period in ms
seed : int, default: None
The random seed

Returns
-------
sorting : NumpySorting
The sorting object
"""
seed = _ensure_seed(seed)
num_segments = len(durations)
unit_ids = np.arange(num_units)
@@ -152,6 +177,59 @@
return sorting
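
# Usage sketch (illustrative parameters, as documented above):
#   sorting = generate_sorting(num_units=10, durations=[60.0], firing_rates=5.0, seed=42)
#   sorting.get_num_units()  # 10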


def add_synchrony_to_sorting(sorting, sync_event_ratio=0.3, seed=None):
"""
Generates a sorting object with added synchronous events from an existing sorting object.

Parameters
----------
sorting : BaseSorting
The sorting object
sync_event_ratio : float, default: 0.3
The ratio of added synchronous spikes with respect to the total number of spikes.
E.g., 0.5 means that the final sorting will have 1.5 times the number of spikes, and all the extra
spikes are synchronous (same sample_index), but on different units (not duplicates).
seed : int, default: None
The random seed


Returns
-------
sorting : NumpySorting
The sorting object

"""
rng = np.random.default_rng(seed)
spikes = sorting.to_spike_vector()
unit_ids = sorting.unit_ids

# add synchronous events
num_sync = int(len(spikes) * sync_event_ratio)
spikes_duplicated = rng.choice(spikes, size=num_sync, replace=True)
# change unit_index
new_unit_indices = np.zeros(len(spikes_duplicated), dtype=spikes_duplicated["unit_index"].dtype)
# make sure labels are all unique, keep unit_indices used for each spike
units_used_for_spike = {}
for i, spike in enumerate(spikes_duplicated):
sample_index = spike["sample_index"]
if sample_index not in units_used_for_spike:
units_used_for_spike[sample_index] = np.array([spike["unit_index"]])
units_not_used = unit_ids[~np.in1d(unit_ids, units_used_for_spike[sample_index])]

if len(units_not_used) == 0:
continue
new_unit_indices[i] = rng.choice(units_not_used)
units_used_for_spike[sample_index] = np.append(units_used_for_spike[sample_index], new_unit_indices[i])
spikes_duplicated["unit_index"] = new_unit_indices
spikes_all = np.concatenate((spikes, spikes_duplicated))
sort_idxs = np.lexsort([spikes_all["sample_index"], spikes_all["segment_index"]])
spikes_all = spikes_all[sort_idxs]

sorting = NumpySorting(spikes=spikes_all, sampling_frequency=sorting.sampling_frequency, unit_ids=unit_ids)

return sorting
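
# Usage sketch (illustrative; a 30% ratio adds ~0.3 * len(spikes) extra spikes):
#   base_sorting = generate_sorting(num_units=5, durations=[10.0], seed=0)
#   sync_sorting = add_synchrony_to_sorting(base_sorting, sync_event_ratio=0.3, seed=0)
# The extra spikes share sample indices with existing spikes but are assigned to
# other units, so the synchrony metrics of sync_sorting increase accordingly.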


def create_sorting_npz(num_seg, file_path):
# create a NPZ sorting file
d = {}
8 changes: 4 additions & 4 deletions src/spikeinterface/postprocessing/tests/test_align_sorting.py
@@ -6,7 +6,7 @@

import numpy as np

from spikeinterface import WaveformExtractor, load_extractor, extract_waveforms, NumpySorting
from spikeinterface import NumpySorting
from spikeinterface.core import generate_sorting

from spikeinterface.postprocessing import align_sorting
@@ -17,8 +17,8 @@
cache_folder = Path("cache_folder") / "postprocessing"


def test_compute_unit_center_of_mass():
sorting = generate_sorting(durations=[10.0])
def test_align_sorting():
sorting = generate_sorting(durations=[10.0], seed=0)
print(sorting)

unit_ids = sorting.unit_ids
@@ -43,4 +43,4 @@ def test_compute_unit_center_of_mass():


if __name__ == "__main__":
test_compute_unit_center_of_mass()
test_align_sorting()
6 changes: 3 additions & 3 deletions src/spikeinterface/postprocessing/tests/test_correlograms.py
@@ -38,7 +38,7 @@ def test_compute_correlograms(self):


def test_make_bins():
sorting = generate_sorting(num_units=5, sampling_frequency=30000.0, durations=[10.325, 3.5])
sorting = generate_sorting(num_units=5, sampling_frequency=30000.0, durations=[10.325, 3.5], seed=0)

window_ms = 43.57
bin_ms = 1.6421
@@ -82,14 +82,14 @@ def test_equal_results_correlograms():
if HAVE_NUMBA:
methods.append("numba")

sorting = generate_sorting(num_units=5, sampling_frequency=30000.0, durations=[10.325, 3.5])
sorting = generate_sorting(num_units=5, sampling_frequency=30000.0, durations=[10.325, 3.5], seed=0)

_test_correlograms(sorting, window_ms=60.0, bin_ms=2.0, methods=methods)
_test_correlograms(sorting, window_ms=43.57, bin_ms=1.6421, methods=methods)


def test_flat_cross_correlogram():
sorting = generate_sorting(num_units=2, sampling_frequency=10000.0, durations=[100000.0])
sorting = generate_sorting(num_units=2, sampling_frequency=10000.0, durations=[100000.0], seed=0)

methods = ["numpy"]
if HAVE_NUMBA:
67 changes: 67 additions & 0 deletions src/spikeinterface/qualitymetrics/misc_metrics.py
@@ -499,6 +499,73 @@ def compute_sliding_rp_violations(
)


def compute_synchrony_metrics(waveform_extractor, synchrony_sizes=(2, 4, 8), **kwargs):
"""
Compute synchrony metrics. Synchrony metrics represent the rate of occurrences of
"synchrony_size" spikes at the exact same sample index.

Parameters
----------
waveform_extractor : WaveformExtractor
The waveform extractor object.
synchrony_sizes : list or tuple, default: (2, 4, 8)
The synchrony sizes to compute.

Returns
-------
sync_spike_{X} : dict
The synchrony metric for synchrony size X.
One dict is returned for each synchrony size in synchrony_sizes.

References
----------
Based on concepts described in [Gruen]_
This code was adapted from `Elephant - Electrophysiology Analysis Toolkit <https://github.com/NeuralEnsemble/elephant/blob/master/elephant/spike_train_synchrony.py#L245>`_
"""
assert all(s > 1 for s in synchrony_sizes), "Synchrony sizes must be greater than 1"
spike_counts = waveform_extractor.sorting.count_num_spikes_per_unit()
sorting = waveform_extractor.sorting
spikes = sorting.to_spike_vector(concatenated=False)

# Pre-allocate synchrony counts
synchrony_counts = {}
for synchrony_size in synchrony_sizes:
synchrony_counts[synchrony_size] = np.zeros(len(waveform_extractor.unit_ids), dtype=np.int64)

for segment_index in range(sorting.get_num_segments()):
spikes_in_segment = spikes[segment_index]

# we compute just by counting the occurrence of each sample_index
unique_spike_index, complexity = np.unique(spikes_in_segment["sample_index"], return_counts=True)

# add counts for this segment
for unit_index in np.arange(len(sorting.unit_ids)):
spikes_per_unit = spikes_in_segment[spikes_in_segment["unit_index"] == unit_index]
# some segments/units might have no spikes
if len(spikes_per_unit) == 0:
continue
spike_complexity = complexity[np.in1d(unique_spike_index, spikes_per_unit["sample_index"])]
for synchrony_size in synchrony_sizes:
synchrony_counts[synchrony_size][unit_index] += np.count_nonzero(spike_complexity >= synchrony_size)

# convert synchrony counts to rates by normalizing by each unit's spike count
synchrony_metrics_dict = {
f"sync_spike_{synchrony_size}": {
unit_id: synchrony_counts[synchrony_size][unit_index] / spike_counts[unit_id]
for unit_index, unit_id in enumerate(sorting.unit_ids)
}
for synchrony_size in synchrony_sizes
}

# Convert dict to named tuple
synchrony_metrics_tuple = namedtuple("synchrony_metrics", synchrony_metrics_dict.keys())
synchrony_metrics = synchrony_metrics_tuple(**synchrony_metrics_dict)
return synchrony_metrics
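
# Worked sketch (hypothetical numbers): if a unit fires 100 spikes and 20 of them
# land on sample indices shared with at least one other spike, then
# sync_spike_2 = 20 / 100 = 0.2 for that unit.
#   metrics = compute_synchrony_metrics(waveform_extractor, synchrony_sizes=(2, 4, 8))
#   metrics.sync_spike_2  # dict mapping unit ids to rates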


_default_params["synchrony_metrics"] = dict(synchrony_sizes=(2, 4, 8))


def compute_amplitude_cutoffs(
waveform_extractor,
peak_sign="neg",
2 changes: 2 additions & 0 deletions src/spikeinterface/qualitymetrics/quality_metric_list.py
@@ -11,6 +11,7 @@
compute_amplitude_cutoffs,
compute_amplitude_medians,
compute_drift_metrics,
compute_synchrony_metrics,
)

from .pca_metrics import (
@@ -39,5 +40,6 @@
"sliding_rp_violation": compute_sliding_rp_violations,
"amplitude_cutoff": compute_amplitude_cutoffs,
"amplitude_median": compute_amplitude_medians,
"synchrony": compute_synchrony_metrics,
"drift": compute_drift_metrics,
}