Commit
Merge branch 'main' into change_default_in_generate_recording
h-mayorquin authored Sep 11, 2023
2 parents 2596d12 + a26cb84 commit ac65bc5
Showing 11 changed files with 337 additions and 86 deletions.
2 changes: 2 additions & 0 deletions doc/modules/qualitymetrics/references.rst
@@ -11,6 +11,8 @@ References
.. [Hruschka] Hruschka, E.R., de Castro, L.N., Campello R.J.G.B. "Evolutionary algorithms for clustering gene-expression data." Fourth IEEE International Conference on Data Mining (ICDM'04) 2004, pp 403-406.
.. [Gruen] Sonja Grün, Moshe Abeles, and Markus Diesmann. Impact of higher-order correlations on coincidence distributions of massively parallel data. In International School on Neural Networks, Initiated by IIASS and EMFCSC, volume 5286, 96–114. Springer, 2007.
.. [IBL] International Brain Laboratory. “Spike sorting pipeline for the International Brain Laboratory”. 4 May 2022.
.. [Jackson] Jadin Jackson, Neil Schmitzer-Torbert, K.D. Harris, and A.D. Redish. Quantitative assessment of extracellular multichannel recording quality using measures of cluster separation. Soc Neurosci Abstr, 518, 01 2005.
49 changes: 49 additions & 0 deletions doc/modules/qualitymetrics/synchrony.rst
@@ -0,0 +1,49 @@
Synchrony Metrics (:code:`synchrony`)
=====================================

Calculation
-----------
This function provides a metric for the presence of synchronous spiking events across multiple spike trains.

The complexity is used to characterize synchronous events within the same spike train and across different spike
trains. In this way, synchronous events can be found both in multi-unit and single-unit spike trains.
Complexity is calculated by counting the number of spikes (i.e., non-empty bins) that occur at the same sample index,
within and across spike trains.

Synchrony metrics can be computed for different synchrony sizes (>1), defining the number of simultaneous spikes to count.
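
The following minimal NumPy sketch illustrates the counting idea with made-up spike times; it is a simplified illustration, not the SpikeInterface implementation:

.. code-block:: python

    import numpy as np

    # sample indices of all spikes pooled across units (illustrative values)
    sample_indices = np.array([10, 10, 10, 25, 40, 40, 77])

    # complexity = number of spikes that fall on each occupied sample index
    unique_indices, complexity = np.unique(sample_indices, return_counts=True)
    # unique_indices -> [10, 25, 40, 77], complexity -> [3, 1, 2, 1]

    # a synchrony size of 2 counts events where at least 2 spikes coincide
    num_sync_events = np.count_nonzero(complexity >= 2)  # -> 2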



Expectation and use
-------------------

A larger value indicates a higher degree of synchrony between the respective spike train and the other spike trains.
Larger values, especially for larger synchrony sizes, indicate a higher probability of noisy spikes in the spike trains.

Example code
------------

.. code-block:: python

    import spikeinterface.qualitymetrics as qm

    # Make recording, sorting and wvf_extractor object for your data.
    synchrony = qm.compute_synchrony_metrics(wvf_extractor, synchrony_sizes=(2, 4, 8))
    # synchrony is a tuple of dicts with the synchrony metrics for each unit

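The returned object is a named tuple with one field per synchrony size (e.g. ``sync_spike_2``), each mapping unit ids to values, so individual entries can be read out directly; the unit id below is only an example:

.. code-block:: python

    # size-2 synchrony value for one unit (unit id 0 is just an example)
    sync_2_unit_0 = synchrony.sync_spike_2[0]
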
Links to original implementations
---------------------------------

The SpikeInterface implementation is a partial port of the low-level complexity functions from `Elephant - Electrophysiology Analysis Toolkit <https://github.com/NeuralEnsemble/elephant/blob/master/elephant/spike_train_synchrony.py#L245>`_

References
----------

.. automodule:: spikeinterface.qualitymetrics.misc_metrics

.. autofunction:: compute_synchrony_metrics

Literature
----------

Based on concepts described in [Gruen]_
1 change: 1 addition & 0 deletions src/spikeinterface/core/__init__.py
@@ -28,6 +28,7 @@
from .generate import (
    generate_recording,
    generate_sorting,
    add_synchrony_to_sorting,
    create_sorting_npz,
    generate_snippets,
    synthesize_random_firings,
78 changes: 78 additions & 0 deletions src/spikeinterface/core/generate.py
@@ -125,6 +125,31 @@ def generate_sorting(
    refractory_period_ms=3.0,  # in ms
    seed=None,
):
"""
Generates sorting object with random firings.
Parameters
----------
num_units : int, default: 5
Number of units
sampling_frequency : float, default: 30000.0
The sampling frequency
durations : list, default: [10.325, 3.5]
Duration of each segment in s
firing_rates : float, default: 3.0
The firing rate of each unit (in Hz).
empty_units : list, default: None
List of units that will have no spikes. (used for testing mainly).
refractory_period_ms : float, default: 3.0
The refractory period in ms
seed : int, default: None
The random seed
Returns
-------
sorting : NumpySorting
The sorting object
"""
    seed = _ensure_seed(seed)
    num_segments = len(durations)
    unit_ids = np.arange(num_units)
@@ -157,6 +182,59 @@ def generate_sorting(
    return sorting


def add_synchrony_to_sorting(sorting, sync_event_ratio=0.3, seed=None):
    """
    Generates a sorting object with added synchronous events from an existing sorting object.

    Parameters
    ----------
    sorting : BaseSorting
        The sorting object.
    sync_event_ratio : float, default: 0.3
        The ratio of added synchronous spikes with respect to the total number of spikes.
        E.g., 0.5 means that the final sorting will have 1.5 times the number of spikes, and all the extra
        spikes are synchronous (same sample_index), but on different units (not duplicates).
    seed : int, default: None
        The random seed.

    Returns
    -------
    sorting : NumpySorting
        The sorting object.
    """
    rng = np.random.default_rng(seed)
    spikes = sorting.to_spike_vector()
    unit_ids = sorting.unit_ids

    # add synchronous events
    num_sync = int(len(spikes) * sync_event_ratio)
    spikes_duplicated = rng.choice(spikes, size=num_sync, replace=True)
    # change unit_index
    new_unit_indices = np.zeros(len(spikes_duplicated))
    # make sure labels are all unique, keep unit_indices used for each spike
    units_used_for_spike = {}
    for i, spike in enumerate(spikes_duplicated):
        sample_index = spike["sample_index"]
        if sample_index not in units_used_for_spike:
            units_used_for_spike[sample_index] = np.array([spike["unit_index"]])
        units_not_used = unit_ids[~np.in1d(unit_ids, units_used_for_spike[sample_index])]

        if len(units_not_used) == 0:
            continue
        new_unit_indices[i] = rng.choice(units_not_used)
        units_used_for_spike[sample_index] = np.append(units_used_for_spike[sample_index], new_unit_indices[i])
    spikes_duplicated["unit_index"] = new_unit_indices
    spikes_all = np.concatenate((spikes, spikes_duplicated))
    sort_idxs = np.lexsort([spikes_all["sample_index"], spikes_all["segment_index"]])
    spikes_all = spikes_all[sort_idxs]

    sorting = NumpySorting(spikes=spikes_all, sampling_frequency=sorting.sampling_frequency, unit_ids=unit_ids)

    return sorting


def create_sorting_npz(num_seg, file_path):
    # create a NPZ sorting file
    d = {}
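As a usage illustration of the two generators changed above, the following sketch (parameter values are illustrative, not defaults mandated by this commit) builds a random sorting and then injects synchronous events:

.. code-block:: python

    from spikeinterface.core import generate_sorting, add_synchrony_to_sorting

    # random baseline sorting (values are illustrative)
    sorting = generate_sorting(num_units=5, durations=[10.0], firing_rates=5.0, seed=0)

    # add ~30% extra spikes that share sample indices with existing spikes on other units
    sorting_sync = add_synchrony_to_sorting(sorting, sync_event_ratio=0.3, seed=0)
    print(sorting_sync.count_num_spikes_per_unit())
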
8 changes: 4 additions & 4 deletions src/spikeinterface/postprocessing/tests/test_align_sorting.py
@@ -6,7 +6,7 @@

import numpy as np

from spikeinterface import WaveformExtractor, load_extractor, extract_waveforms, NumpySorting
from spikeinterface import NumpySorting
from spikeinterface.core import generate_sorting

from spikeinterface.postprocessing import align_sorting
@@ -17,8 +17,8 @@
cache_folder = Path("cache_folder") / "postprocessing"


def test_compute_unit_center_of_mass():
    sorting = generate_sorting(durations=[10.0])
def test_align_sorting():
    sorting = generate_sorting(durations=[10.0], seed=0)
    print(sorting)

    unit_ids = sorting.unit_ids
@@ -43,4 +43,4 @@ def test_compute_unit_center_of_mass():


if __name__ == "__main__":
test_compute_unit_center_of_mass()
test_align_sorting()
6 changes: 3 additions & 3 deletions src/spikeinterface/postprocessing/tests/test_correlograms.py
@@ -38,7 +38,7 @@ def test_compute_correlograms(self):


def test_make_bins():
    sorting = generate_sorting(num_units=5, sampling_frequency=30000.0, durations=[10.325, 3.5])
    sorting = generate_sorting(num_units=5, sampling_frequency=30000.0, durations=[10.325, 3.5], seed=0)

    window_ms = 43.57
    bin_ms = 1.6421
@@ -82,14 +82,14 @@ def test_equal_results_correlograms():
    if HAVE_NUMBA:
        methods.append("numba")

    sorting = generate_sorting(num_units=5, sampling_frequency=30000.0, durations=[10.325, 3.5])
    sorting = generate_sorting(num_units=5, sampling_frequency=30000.0, durations=[10.325, 3.5], seed=0)

    _test_correlograms(sorting, window_ms=60.0, bin_ms=2.0, methods=methods)
    _test_correlograms(sorting, window_ms=43.57, bin_ms=1.6421, methods=methods)


def test_flat_cross_correlogram():
    sorting = generate_sorting(num_units=2, sampling_frequency=10000.0, durations=[100000.0])
    sorting = generate_sorting(num_units=2, sampling_frequency=10000.0, durations=[100000.0], seed=0)

    methods = ["numpy"]
    if HAVE_NUMBA:
2 changes: 1 addition & 1 deletion src/spikeinterface/preprocessing/normalize_scale.py
@@ -293,7 +293,7 @@ def __init__(
means = means[None, :]
stds = np.std(random_data, axis=0)
stds = stds[None, :]
gain = 1 / stds
gain = 1.0 / stds
offset = -means / stds

if int_scale is not None:
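The changed line above defines the per-channel gain and offset; assuming they are applied as the usual linear transform ``traces * gain + offset``, the result is the z-scored signal, which a quick standalone check confirms:

.. code-block:: python

    import numpy as np

    traces = np.random.default_rng(0).normal(loc=5.0, scale=2.0, size=(1000, 3))
    means = traces.mean(axis=0)[None, :]
    stds = traces.std(axis=0)[None, :]

    gain = 1.0 / stds
    offset = -means / stds

    # traces * gain + offset is algebraically (traces - means) / stds
    assert np.allclose(traces * gain + offset, (traces - means) / stds)
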
15 changes: 10 additions & 5 deletions src/spikeinterface/preprocessing/tests/test_normalize_scale.py
@@ -78,13 +78,18 @@ def test_zscore():
    assert np.all(np.abs(np.mean(tr, axis=0)) < 0.01)
    assert np.all(np.abs(np.std(tr, axis=0) - 1) < 0.01)


def test_zscore_int():
    seed = 1
    rec = generate_recording(seed=seed, mode="legacy")
    rec_int = scale(rec, dtype="int16", gain=100)
    with pytest.raises(AssertionError):
        rec4 = zscore(rec_int, dtype=None)
    rec4 = zscore(rec_int, dtype="int16", int_scale=256, mode="mean+std", seed=seed)
    tr = rec4.get_traces(segment_index=0)
    trace_mean = np.mean(tr, axis=0)
    trace_std = np.std(tr, axis=0)
        zscore(rec_int, dtype=None)

    zscore_recording = zscore(rec_int, dtype="int16", int_scale=256, mode="mean+std", seed=seed)
    traces = zscore_recording.get_traces(segment_index=0)
    trace_mean = np.mean(traces, axis=0)
    trace_std = np.std(traces, axis=0)
    assert np.all(np.abs(trace_mean) < 1)
    assert np.all(np.abs(trace_std - 256) < 1)

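The updated test expects a mean near 0 and a standard deviation near 256 because, with an integer output dtype, ``int_scale`` rescales the z-scored traces before casting; a conceptual NumPy sketch of that expectation (not the library's internal code path):

.. code-block:: python

    import numpy as np

    rng = np.random.default_rng(0)
    float_traces = rng.normal(loc=12.0, scale=3.5, size=(30000, 4))

    int_scale = 256
    zscored = (float_traces - float_traces.mean(axis=0)) / float_traces.std(axis=0)
    int_traces = (zscored * int_scale).astype("int16")

    # mean ~ 0 and std ~ int_scale, matching the assertions in the test above
    print(np.abs(int_traces.mean(axis=0)).max(), int_traces.std(axis=0))
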
67 changes: 67 additions & 0 deletions src/spikeinterface/qualitymetrics/misc_metrics.py
@@ -499,6 +499,73 @@ def compute_sliding_rp_violations(
)


def compute_synchrony_metrics(waveform_extractor, synchrony_sizes=(2, 4, 8), **kwargs):
    """
    Compute synchrony metrics. Synchrony metrics represent the rate of occurrences of
    "synchrony_size" spikes at the exact same sample index.

    Parameters
    ----------
    waveform_extractor : WaveformExtractor
        The waveform extractor object.
    synchrony_sizes : list or tuple, default: (2, 4, 8)
        The synchrony sizes to compute.

    Returns
    -------
    sync_spike_{X} : dict
        The synchrony metric for synchrony size X.
        One dict is returned per entry in synchrony_sizes.

    References
    ----------
    Based on concepts described in [Gruen]_
    This code was adapted from `Elephant - Electrophysiology Analysis Toolkit <https://github.com/NeuralEnsemble/elephant/blob/master/elephant/spike_train_synchrony.py#L245>`_
    """
    assert all(s > 1 for s in synchrony_sizes), "Synchrony sizes must be greater than 1"
    spike_counts = waveform_extractor.sorting.count_num_spikes_per_unit()
    sorting = waveform_extractor.sorting
    spikes = sorting.to_spike_vector(concatenated=False)

    # Pre-allocate synchrony counts
    synchrony_counts = {}
    for synchrony_size in synchrony_sizes:
        synchrony_counts[synchrony_size] = np.zeros(len(waveform_extractor.unit_ids), dtype=np.int64)

    for segment_index in range(sorting.get_num_segments()):
        spikes_in_segment = spikes[segment_index]

        # we compute just by counting the occurrence of each sample_index
        unique_spike_index, complexity = np.unique(spikes_in_segment["sample_index"], return_counts=True)

        # add counts for this segment
        for unit_index in np.arange(len(sorting.unit_ids)):
            spikes_per_unit = spikes_in_segment[spikes_in_segment["unit_index"] == unit_index]
            # some segments/units might have no spikes
            if len(spikes_per_unit) == 0:
                continue
            spike_complexity = complexity[np.in1d(unique_spike_index, spikes_per_unit["sample_index"])]
            for synchrony_size in synchrony_sizes:
                synchrony_counts[synchrony_size][unit_index] += np.count_nonzero(spike_complexity >= synchrony_size)

    # convert counts to rates by dividing by each unit's total number of spikes
    synchrony_metrics_dict = {
        f"sync_spike_{synchrony_size}": {
            unit_id: synchrony_counts[synchrony_size][unit_index] / spike_counts[unit_id]
            for unit_index, unit_id in enumerate(sorting.unit_ids)
        }
        for synchrony_size in synchrony_sizes
    }

    # Convert dict to named tuple
    synchrony_metrics_tuple = namedtuple("synchrony_metrics", synchrony_metrics_dict.keys())
    synchrony_metrics = synchrony_metrics_tuple(**synchrony_metrics_dict)
    return synchrony_metrics


_default_params["synchrony_metrics"] = dict(synchrony_sizes=(0, 2, 4))


def compute_amplitude_cutoffs(
    waveform_extractor,
    peak_sign="neg",
2 changes: 2 additions & 0 deletions src/spikeinterface/qualitymetrics/quality_metric_list.py
@@ -11,6 +11,7 @@
    compute_amplitude_cutoffs,
    compute_amplitude_medians,
    compute_drift_metrics,
    compute_synchrony_metrics,
)

from .pca_metrics import (
@@ -39,5 +40,6 @@
"sliding_rp_violation": compute_sliding_rp_violations,
"amplitude_cutoff": compute_amplitude_cutoffs,
"amplitude_median": compute_amplitude_medians,
"synchrony": compute_synchrony_metrics,
"drift": compute_drift_metrics,
}
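
With ``"synchrony"`` registered in the metric dictionary above, the metric should also be reachable through the generic quality-metrics entry point; a sketch under that assumption (the exact call is taken from the existing API, not from this diff):

.. code-block:: python

    import spikeinterface.qualitymetrics as qm

    # request the new metric by name alongside any others
    metrics = qm.compute_quality_metrics(wvf_extractor, metric_names=["synchrony"])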