Hard code synchrony_sizes #3559

Merged
merged 3 commits on Dec 5, 2024

Changes from 2 commits
doc/get_started/quickstart.rst (2 changes: 1 addition & 1 deletion)
@@ -673,7 +673,7 @@ compute quality metrics (some quality metrics require certain extensions
'min_spikes': 0,
'window_size_s': 1},
'snr': {'peak_mode': 'extremum', 'peak_sign': 'neg'},
- 'synchrony': {'synchrony_sizes': (2, 4, 8)}}
+ 'synchrony': {}}


Since the recording is very short, let’s change some parameters to
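As a quick sanity check of the new default (an editorial sketch, not part of the diff; it assumes the package-level helper get_default_qm_params, imported in __init__.py below, can be called with no arguments):

import spikeinterface.qualitymetrics as sqm

# With this PR, the synchrony entry in the default quality-metric
# parameters carries no options at all.
default_params = sqm.get_default_qm_params()
print(default_params["synchrony"])  # expected: {}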
doc/modules/qualitymetrics/synchrony.rst (4 changes: 2 additions & 2 deletions)
@@ -12,7 +12,7 @@ trains. This way synchronous events can be found both in multi-unit and single-unit
Complexity is calculated by counting the number of spikes (i.e. non-empty bins) that occur at the same sample index,
within and across spike trains.

- Synchrony metrics can be computed for different synchrony sizes (>1), defining the number of simultaneous spikes to count.
+ Synchrony metrics are computed for 2, 4 and 8 synchronous spikes.
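For readers skimming the diff, a standalone toy sketch (an illustration for this review, not code from the PR) of the complexity counting described above: spikes sharing a sample index form one event, and every unit participating in an event of size >= N is credited toward sync_spike_N.

import numpy as np

# Toy data: units 0, 1, 2 coincide at sample 10; units 1 and 2 at sample 40.
sample_index = np.array([10, 10, 10, 25, 40, 40])
unit_index = np.array([0, 1, 2, 0, 1, 2])

# Complexity = number of spikes sharing a sample index.
unique_samples, complexity = np.unique(sample_index, return_counts=True)
complexity_per_spike = complexity[np.searchsorted(unique_samples, sample_index)]

for size in (2, 4, 8):
    in_sync = complexity_per_spike >= size
    counts = np.bincount(unit_index[in_sync], minlength=3)
    print(size, counts)  # size 2 -> [1 2 2]; sizes 4 and 8 -> [0 0 0]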



@@ -29,7 +29,7 @@ Example code

import spikeinterface.qualitymetrics as sqm
# Combine a sorting and recording into a sorting_analyzer
- synchrony = sqm.compute_synchrony_metrics(sorting_analyzer=sorting_analyzer synchrony_sizes=(2, 4, 8))
+ synchrony = sqm.compute_synchrony_metrics(sorting_analyzer=sorting_analyzer)
# synchrony is a tuple of dicts with the synchrony metrics for each unit
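A usage note (editorial addition, not in the PR's docs): the return value is a namedtuple whose fields are per-unit dicts, so each hard-coded size can be read off directly.

# Each field maps unit_id -> that unit's synchrony metric for the given size.
print(synchrony.sync_spike_2)
print(synchrony.sync_spike_4)
print(synchrony.sync_spike_8)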


src/spikeinterface/qualitymetrics/__init__.py (2 changes: 1 addition & 1 deletion)
@@ -6,4 +6,4 @@
get_default_qm_params,
)
from .pca_metrics import get_quality_pca_metric_list
- from .misc_metrics import get_synchrony_counts
+ from .misc_metrics import _get_synchrony_counts
src/spikeinterface/qualitymetrics/misc_metrics.py (31 changes: 15 additions & 16 deletions)
@@ -520,18 +520,18 @@ def compute_sliding_rp_violations(
)


- def get_synchrony_counts(spikes, synchrony_sizes, all_unit_ids):
Collaborator: This does change the order in order to have the default. I guess I don't see a way around this, so it's okay.

Collaborator: Or do we even need the auto default if the user API already feeds this? We could have this be the same, but private, no? Something to think about @alejoe91 @chrishalcrow

Collaborator (Author): Yeah, I think you're right. The function is only used in compute_synchrony_metrics, and this allows people to use get_synchrony_counts with the synchrony_sizes in another context and keeps the order intact. Nice.

+ def _get_synchrony_counts(spikes, all_unit_ids, synchrony_sizes=np.array([2, 4, 8])):
"""
Compute synchrony counts, the number of simultaneous spikes with sizes `synchrony_sizes`.

Parameters
----------
spikes : np.array
Structured numpy array with fields ("sample_index", "unit_index", "segment_index").
- synchrony_sizes : numpy array
- The synchrony sizes to compute. Should be pre-sorted.
all_unit_ids : list or None, default: None
List of unit ids to compute the synchrony metrics. Expecting all units.
+ synchrony_sizes : numpy array
+ The synchrony sizes to compute. Should be pre-sorted.

Returns
-------
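To make the reviewers' point about the argument reordering concrete, a hypothetical call against the new signature (an editorial sketch; the structured dtype follows the fields named in the docstring):

import numpy as np
from spikeinterface.qualitymetrics import _get_synchrony_counts

# Minimal spike vector: units 0 and 1 fire together at sample 100.
spikes = np.zeros(4, dtype=[("sample_index", "int64"), ("unit_index", "int64"), ("segment_index", "int64")])
spikes["sample_index"] = [100, 100, 250, 300]
spikes["unit_index"] = [0, 1, 0, 1]

counts_default = _get_synchrony_counts(spikes, [0, 1])  # default sizes (2, 4, 8)
counts_custom = _get_synchrony_counts(spikes, [0, 1], synchrony_sizes=np.array([2, 3]))
print(counts_custom)  # one row per size; [[1, 1], [0, 0]] expected here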
@@ -565,37 +565,36 @@ def get_synchrony_counts(spikes, synchrony_sizes, all_unit_ids):
return synchrony_counts


- def compute_synchrony_metrics(sorting_analyzer, synchrony_sizes=(2, 4, 8), unit_ids=None):
+ def compute_synchrony_metrics(sorting_analyzer, unit_ids=None, synchrony_sizes=None):
"""
Compute synchrony metrics. Synchrony metrics represent the rate of occurrences of
"synchrony_size" spikes at the exact same sample index.
spikes at the exact same sample index, with synchrony sizes 2, 4 and 8.

Parameters
----------
sorting_analyzer : SortingAnalyzer
A SortingAnalyzer object.
- synchrony_sizes : list or tuple, default: (2, 4, 8)
- The synchrony sizes to compute.
unit_ids : list or None, default: None
List of unit ids to compute the synchrony metrics. If None, all units are used.

Returns
-------
sync_spike_{X} : dict
The synchrony metric for synchrony size X.
- Returns are as many as synchrony_sizes.

References
----------
Based on concepts described in [Grün]_
This code was adapted from `Elephant - Electrophysiology Analysis Toolkit <https://github.com/NeuralEnsemble/elephant/blob/master/elephant/spike_train_synchrony.py#L245>`_
"""
- assert min(synchrony_sizes) > 1, "Synchrony sizes must be greater than 1"
- # Sort the synchrony times so we can slice numpy arrays, instead of using dicts
- synchrony_sizes_np = np.array(synchrony_sizes, dtype=np.int16)
- synchrony_sizes_np.sort()
-
- res = namedtuple("synchrony_metrics", [f"sync_spike_{size}" for size in synchrony_sizes_np])
+ if synchrony_sizes is not None:
+     warning_message = "Custom `synchrony_sizes` is deprecated; the `synchrony_metrics` will be computed using `synchrony_sizes = [2,4,8]`"
+     warnings.warn(warning_message)
+
+ synchrony_sizes = np.array([2, 4, 8])
+
+ res = namedtuple("synchrony_metrics", [f"sync_spike_{size}" for size in synchrony_sizes])

sorting = sorting_analyzer.sorting

@@ -606,10 +605,10 @@ def compute_synchrony_metrics(sorting_analyzer, synchrony_sizes=(2, 4, 8), unit_ids=None):

spikes = sorting.to_spike_vector()
all_unit_ids = sorting.unit_ids
- synchrony_counts = get_synchrony_counts(spikes, synchrony_sizes_np, all_unit_ids)
+ synchrony_counts = _get_synchrony_counts(spikes, all_unit_ids, synchrony_sizes=synchrony_sizes)

synchrony_metrics_dict = {}
- for sync_idx, synchrony_size in enumerate(synchrony_sizes_np):
+ for sync_idx, synchrony_size in enumerate(synchrony_sizes):
sync_id_metrics_dict = {}
for i, unit_id in enumerate(all_unit_ids):
if unit_id not in unit_ids:
Expand All @@ -623,7 +622,7 @@ def compute_synchrony_metrics(sorting_analyzer, synchrony_sizes=(2, 4, 8), unit_
return res(**synchrony_metrics_dict)


_default_params["synchrony"] = dict(synchrony_sizes=(2, 4, 8))
_default_params["synchrony"] = dict()


def compute_firing_ranges(sorting_analyzer, bin_size_s=5, percentiles=(5, 95), unit_ids=None):
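End to end, the deprecation path added above behaves roughly like this (an editorial sketch with a hypothetical sorting_analyzer; it assumes compute_synchrony_metrics is exported at the package level like the other compute_* metrics):

import warnings
from spikeinterface.qualitymetrics import compute_synchrony_metrics

# The deprecated keyword still computes the metrics, but with the
# hard-coded sizes (2, 4, 8) and a warning.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    metrics = compute_synchrony_metrics(sorting_analyzer, synchrony_sizes=(3, 5))

assert any("deprecated" in str(w.message) for w in caught)
assert hasattr(metrics, "sync_spike_8")  # sizes unaffected by the keyword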
src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py (39 changes: 16 additions & 23 deletions)
Expand Up @@ -39,7 +39,7 @@
compute_firing_ranges,
compute_amplitude_cv_metrics,
compute_sd_ratio,
- get_synchrony_counts,
+ _get_synchrony_counts,
compute_quality_metrics,
)

@@ -352,7 +352,7 @@ def test_synchrony_counts_no_sync():
one_spike["sample_index"] = spike_times
one_spike["unit_index"] = spike_units

- sync_count = get_synchrony_counts(one_spike, np.array((2)), [0])
+ sync_count = _get_synchrony_counts(one_spike, [0])

assert np.all(sync_count[0] == np.array([0]))

@@ -372,7 +372,7 @@ def test_synchrony_counts_one_sync():
two_spikes["sample_index"] = np.concatenate((spike_indices, added_spikes_indices))
two_spikes["unit_index"] = np.concatenate((spike_labels, added_spikes_labels))

- sync_count = get_synchrony_counts(two_spikes, np.array((2)), [0, 1])
+ sync_count = _get_synchrony_counts(two_spikes, [0, 1])

assert np.all(sync_count[0] == np.array([1, 1]))

@@ -392,7 +392,7 @@ def test_synchrony_counts_one_quad_sync():
four_spikes["sample_index"] = np.concatenate((spike_indices, added_spikes_indices))
four_spikes["unit_index"] = np.concatenate((spike_labels, added_spikes_labels))

- sync_count = get_synchrony_counts(four_spikes, np.array((2, 4)), [0, 1, 2, 3])
+ sync_count = _get_synchrony_counts(four_spikes, [0, 1, 2, 3])

assert np.all(sync_count[0] == np.array([1, 1, 1, 1]))
assert np.all(sync_count[1] == np.array([1, 1, 1, 1]))
@@ -409,7 +409,7 @@ def test_synchrony_counts_not_all_units():
three_spikes["sample_index"] = np.concatenate((spike_indices, added_spikes_indices))
three_spikes["unit_index"] = np.concatenate((spike_labels, added_spikes_labels))

- sync_count = get_synchrony_counts(three_spikes, np.array((2)), [0, 1, 2])
+ sync_count = _get_synchrony_counts(three_spikes, [0, 1, 2])

assert np.all(sync_count[0] == np.array([0, 1, 1]))

@@ -610,9 +610,9 @@ def test_calculate_rp_violations(sorting_analyzer_violations):
def test_synchrony_metrics(sorting_analyzer_simple):
sorting_analyzer = sorting_analyzer_simple
sorting = sorting_analyzer.sorting
- synchrony_sizes = (2, 3, 4)
- synchrony_metrics = compute_synchrony_metrics(sorting_analyzer, synchrony_sizes=synchrony_sizes)
- print(synchrony_metrics)
+ synchrony_metrics = compute_synchrony_metrics(sorting_analyzer)
+
+ synchrony_sizes = np.array([2, 4, 8])

# check returns
for size in synchrony_sizes:
@@ -625,10 +625,8 @@ def test_synchrony_metrics(sorting_analyzer_simple):
sorting_sync = add_synchrony_to_sorting(sorting, sync_event_ratio=sync_level)
sorting_analyzer_sync = create_sorting_analyzer(sorting_sync, sorting_analyzer.recording, format="memory")

- previous_synchrony_metrics = compute_synchrony_metrics(
-     previous_sorting_analyzer, synchrony_sizes=synchrony_sizes
- )
- current_synchrony_metrics = compute_synchrony_metrics(sorting_analyzer_sync, synchrony_sizes=synchrony_sizes)
+ previous_synchrony_metrics = compute_synchrony_metrics(previous_sorting_analyzer)
+ current_synchrony_metrics = compute_synchrony_metrics(sorting_analyzer_sync)
print(current_synchrony_metrics)
# check that all values increased
for i, col in enumerate(previous_synchrony_metrics._fields):
@@ -647,22 +645,17 @@ def test_synchrony_metrics_unit_id_subset(sorting_analyzer_simple):

unit_ids_subset = [3, 7]

- synchrony_sizes = (2,)
- (synchrony_metrics,) = compute_synchrony_metrics(
-     sorting_analyzer_simple, synchrony_sizes=synchrony_sizes, unit_ids=unit_ids_subset
- )
+ synchrony_metrics = compute_synchrony_metrics(sorting_analyzer_simple, unit_ids=unit_ids_subset)

- assert list(synchrony_metrics.keys()) == [3, 7]
+ assert list(synchrony_metrics.sync_spike_2.keys()) == [3, 7]
+ assert list(synchrony_metrics.sync_spike_4.keys()) == [3, 7]
+ assert list(synchrony_metrics.sync_spike_8.keys()) == [3, 7]


def test_synchrony_metrics_no_unit_ids(sorting_analyzer_simple):

- # all_unit_ids = sorting_analyzer_simple.sorting.unit_ids
-
- synchrony_sizes = (2,)
- (synchrony_metrics,) = compute_synchrony_metrics(sorting_analyzer_simple, synchrony_sizes=synchrony_sizes)
-
- assert np.all(list(synchrony_metrics.keys()) == sorting_analyzer_simple.unit_ids)
+ synchrony_metrics = compute_synchrony_metrics(sorting_analyzer_simple)
+ assert np.all(list(synchrony_metrics.sync_spike_2.keys()) == sorting_analyzer_simple.unit_ids)
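Finally, a short migration sketch for downstream users of the old keyword (editorial, reusing the hypothetical sorting_analyzer from the earlier sketch):

# Old call, now warns and ignores the custom sizes:
# metrics = compute_synchrony_metrics(sorting_analyzer, synchrony_sizes=(2, 3, 4))

# New call; the result always exposes the three hard-coded sizes:
metrics = compute_synchrony_metrics(sorting_analyzer)
for field in ("sync_spike_2", "sync_spike_4", "sync_spike_8"):
    print(field, getattr(metrics, field))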


@pytest.mark.sortingcomponents