Skip to content

Commit

Permalink
Merge pull request #3559 from chrishalcrow/hardcode-sync-sizes
Browse files Browse the repository at this point in the history
Hard code `synchony_sizes`
  • Loading branch information
alejoe91 authored Dec 5, 2024
2 parents d991382 + 2081916 commit c260df1
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 43 deletions.
2 changes: 1 addition & 1 deletion doc/get_started/quickstart.rst
Original file line number Diff line number Diff line change
Expand Up @@ -673,7 +673,7 @@ compute quality metrics (some quality metrics require certain extensions
'min_spikes': 0,
'window_size_s': 1},
'snr': {'peak_mode': 'extremum', 'peak_sign': 'neg'},
'synchrony': {'synchrony_sizes': (2, 4, 8)}}
'synchrony': {}
Since the recording is very short, let’s change some parameters to
Expand Down
4 changes: 2 additions & 2 deletions doc/modules/qualitymetrics/synchrony.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ trains. This way synchronous events can be found both in multi-unit and single-u
Complexity is calculated by counting the number of spikes (i.e. non-empty bins) that occur at the same sample index,
within and across spike trains.

Synchrony metrics can be computed for different synchrony sizes (>1), defining the number of simultaneous spikes to count.
Synchrony metrics are computed for 2, 4 and 8 synchronous spikes.



Expand All @@ -29,7 +29,7 @@ Example code
import spikeinterface.qualitymetrics as sqm
# Combine a sorting and recording into a sorting_analyzer
synchrony = sqm.compute_synchrony_metrics(sorting_analyzer=sorting_analyzer synchrony_sizes=(2, 4, 8))
synchrony = sqm.compute_synchrony_metrics(sorting_analyzer=sorting_analyzer)
# synchrony is a tuple of dicts with the synchrony metrics for each unit
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/qualitymetrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@
get_default_qm_params,
)
from .pca_metrics import get_quality_pca_metric_list
from .misc_metrics import get_synchrony_counts
from .misc_metrics import _get_synchrony_counts
33 changes: 17 additions & 16 deletions src/spikeinterface/qualitymetrics/misc_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,18 +520,18 @@ def compute_sliding_rp_violations(
)


def get_synchrony_counts(spikes, synchrony_sizes, all_unit_ids):
def _get_synchrony_counts(spikes, synchrony_sizes, all_unit_ids):
"""
Compute synchrony counts, the number of simultaneous spikes with sizes `synchrony_sizes`.
Parameters
----------
spikes : np.array
Structured numpy array with fields ("sample_index", "unit_index", "segment_index").
synchrony_sizes : numpy array
The synchrony sizes to compute. Should be pre-sorted.
all_unit_ids : list or None, default: None
List of unit ids to compute the synchrony metrics. Expecting all units.
synchrony_sizes : None or np.array, default: None
The synchrony sizes to compute. Should be pre-sorted.
Returns
-------
Expand Down Expand Up @@ -565,37 +565,38 @@ def get_synchrony_counts(spikes, synchrony_sizes, all_unit_ids):
return synchrony_counts


def compute_synchrony_metrics(sorting_analyzer, synchrony_sizes=(2, 4, 8), unit_ids=None):
def compute_synchrony_metrics(sorting_analyzer, unit_ids=None, synchrony_sizes=None):
"""
Compute synchrony metrics. Synchrony metrics represent the rate of occurrences of
"synchrony_size" spikes at the exact same sample index.
spikes at the exact same sample index, with synchrony sizes 2, 4 and 8.
Parameters
----------
sorting_analyzer : SortingAnalyzer
A SortingAnalyzer object.
synchrony_sizes : list or tuple, default: (2, 4, 8)
The synchrony sizes to compute.
unit_ids : list or None, default: None
List of unit ids to compute the synchrony metrics. If None, all units are used.
synchrony_sizes: None, default: None
Deprecated argument. Please use private `_get_synchrony_counts` if you need finer control over number of synchronous spikes.
Returns
-------
sync_spike_{X} : dict
The synchrony metric for synchrony size X.
Returns are as many as synchrony_sizes.
References
----------
Based on concepts described in [Grün]_
This code was adapted from `Elephant - Electrophysiology Analysis Toolkit <https://github.com/NeuralEnsemble/elephant/blob/master/elephant/spike_train_synchrony.py#L245>`_
"""
assert min(synchrony_sizes) > 1, "Synchrony sizes must be greater than 1"
# Sort the synchrony times so we can slice numpy arrays, instead of using dicts
synchrony_sizes_np = np.array(synchrony_sizes, dtype=np.int16)
synchrony_sizes_np.sort()

res = namedtuple("synchrony_metrics", [f"sync_spike_{size}" for size in synchrony_sizes_np])
if synchrony_sizes is not None:
warning_message = "Custom `synchrony_sizes` is deprecated; the `synchrony_metrics` will be computed using `synchrony_sizes = [2,4,8]`"
warnings.warn(warning_message, DeprecationWarning, stacklevel=2)

synchrony_sizes = np.array([2, 4, 8])

res = namedtuple("synchrony_metrics", [f"sync_spike_{size}" for size in synchrony_sizes])

sorting = sorting_analyzer.sorting

Expand All @@ -606,10 +607,10 @@ def compute_synchrony_metrics(sorting_analyzer, synchrony_sizes=(2, 4, 8), unit_

spikes = sorting.to_spike_vector()
all_unit_ids = sorting.unit_ids
synchrony_counts = get_synchrony_counts(spikes, synchrony_sizes_np, all_unit_ids)
synchrony_counts = _get_synchrony_counts(spikes, synchrony_sizes, all_unit_ids)

synchrony_metrics_dict = {}
for sync_idx, synchrony_size in enumerate(synchrony_sizes_np):
for sync_idx, synchrony_size in enumerate(synchrony_sizes):
sync_id_metrics_dict = {}
for i, unit_id in enumerate(all_unit_ids):
if unit_id not in unit_ids:
Expand All @@ -623,7 +624,7 @@ def compute_synchrony_metrics(sorting_analyzer, synchrony_sizes=(2, 4, 8), unit_
return res(**synchrony_metrics_dict)


_default_params["synchrony"] = dict(synchrony_sizes=(2, 4, 8))
_default_params["synchrony"] = dict()


def compute_firing_ranges(sorting_analyzer, bin_size_s=5, percentiles=(5, 95), unit_ids=None):
Expand Down
39 changes: 16 additions & 23 deletions src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
compute_firing_ranges,
compute_amplitude_cv_metrics,
compute_sd_ratio,
get_synchrony_counts,
_get_synchrony_counts,
compute_quality_metrics,
)

Expand Down Expand Up @@ -352,7 +352,7 @@ def test_synchrony_counts_no_sync():
one_spike["sample_index"] = spike_times
one_spike["unit_index"] = spike_units

sync_count = get_synchrony_counts(one_spike, np.array((2)), [0])
sync_count = _get_synchrony_counts(one_spike, np.array([2, 4, 8]), [0])

assert np.all(sync_count[0] == np.array([0]))

Expand All @@ -372,7 +372,7 @@ def test_synchrony_counts_one_sync():
two_spikes["sample_index"] = np.concatenate((spike_indices, added_spikes_indices))
two_spikes["unit_index"] = np.concatenate((spike_labels, added_spikes_labels))

sync_count = get_synchrony_counts(two_spikes, np.array((2)), [0, 1])
sync_count = _get_synchrony_counts(two_spikes, np.array([2, 4, 8]), [0, 1])

assert np.all(sync_count[0] == np.array([1, 1]))

Expand All @@ -392,7 +392,7 @@ def test_synchrony_counts_one_quad_sync():
four_spikes["sample_index"] = np.concatenate((spike_indices, added_spikes_indices))
four_spikes["unit_index"] = np.concatenate((spike_labels, added_spikes_labels))

sync_count = get_synchrony_counts(four_spikes, np.array((2, 4)), [0, 1, 2, 3])
sync_count = _get_synchrony_counts(four_spikes, np.array([2, 4, 8]), [0, 1, 2, 3])

assert np.all(sync_count[0] == np.array([1, 1, 1, 1]))
assert np.all(sync_count[1] == np.array([1, 1, 1, 1]))
Expand All @@ -409,7 +409,7 @@ def test_synchrony_counts_not_all_units():
three_spikes["sample_index"] = np.concatenate((spike_indices, added_spikes_indices))
three_spikes["unit_index"] = np.concatenate((spike_labels, added_spikes_labels))

sync_count = get_synchrony_counts(three_spikes, np.array((2)), [0, 1, 2])
sync_count = _get_synchrony_counts(three_spikes, np.array([2, 4, 8]), [0, 1, 2])

assert np.all(sync_count[0] == np.array([0, 1, 1]))

Expand Down Expand Up @@ -610,9 +610,9 @@ def test_calculate_rp_violations(sorting_analyzer_violations):
def test_synchrony_metrics(sorting_analyzer_simple):
sorting_analyzer = sorting_analyzer_simple
sorting = sorting_analyzer.sorting
synchrony_sizes = (2, 3, 4)
synchrony_metrics = compute_synchrony_metrics(sorting_analyzer, synchrony_sizes=synchrony_sizes)
print(synchrony_metrics)
synchrony_metrics = compute_synchrony_metrics(sorting_analyzer)

synchrony_sizes = np.array([2, 4, 8])

# check returns
for size in synchrony_sizes:
Expand All @@ -625,10 +625,8 @@ def test_synchrony_metrics(sorting_analyzer_simple):
sorting_sync = add_synchrony_to_sorting(sorting, sync_event_ratio=sync_level)
sorting_analyzer_sync = create_sorting_analyzer(sorting_sync, sorting_analyzer.recording, format="memory")

previous_synchrony_metrics = compute_synchrony_metrics(
previous_sorting_analyzer, synchrony_sizes=synchrony_sizes
)
current_synchrony_metrics = compute_synchrony_metrics(sorting_analyzer_sync, synchrony_sizes=synchrony_sizes)
previous_synchrony_metrics = compute_synchrony_metrics(previous_sorting_analyzer)
current_synchrony_metrics = compute_synchrony_metrics(sorting_analyzer_sync)
print(current_synchrony_metrics)
# check that all values increased
for i, col in enumerate(previous_synchrony_metrics._fields):
Expand All @@ -647,22 +645,17 @@ def test_synchrony_metrics_unit_id_subset(sorting_analyzer_simple):

unit_ids_subset = [3, 7]

synchrony_sizes = (2,)
(synchrony_metrics,) = compute_synchrony_metrics(
sorting_analyzer_simple, synchrony_sizes=synchrony_sizes, unit_ids=unit_ids_subset
)
synchrony_metrics = compute_synchrony_metrics(sorting_analyzer_simple, unit_ids=unit_ids_subset)

assert list(synchrony_metrics.keys()) == [3, 7]
assert list(synchrony_metrics.sync_spike_2.keys()) == [3, 7]
assert list(synchrony_metrics.sync_spike_4.keys()) == [3, 7]
assert list(synchrony_metrics.sync_spike_8.keys()) == [3, 7]


def test_synchrony_metrics_no_unit_ids(sorting_analyzer_simple):

# all_unit_ids = sorting_analyzer_simple.sorting.unit_ids

synchrony_sizes = (2,)
(synchrony_metrics,) = compute_synchrony_metrics(sorting_analyzer_simple, synchrony_sizes=synchrony_sizes)

assert np.all(list(synchrony_metrics.keys()) == sorting_analyzer_simple.unit_ids)
synchrony_metrics = compute_synchrony_metrics(sorting_analyzer_simple)
assert np.all(list(synchrony_metrics.sync_spike_2.keys()) == sorting_analyzer_simple.unit_ids)


@pytest.mark.sortingcomponents
Expand Down

0 comments on commit c260df1

Please sign in to comment.