Skip to content

Commit

Permalink
respond to review
Browse files Browse the repository at this point in the history
  • Loading branch information
chrishalcrow committed Dec 3, 2024
1 parent 45eb5b7 commit 2081916
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 8 deletions.
10 changes: 6 additions & 4 deletions src/spikeinterface/qualitymetrics/misc_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,7 +520,7 @@ def compute_sliding_rp_violations(
)


def _get_synchrony_counts(spikes, all_unit_ids, synchrony_sizes=np.array([2, 4, 8])):
def _get_synchrony_counts(spikes, synchrony_sizes, all_unit_ids):
"""
Compute synchrony counts, the number of simultaneous spikes with sizes `synchrony_sizes`.
Expand All @@ -530,7 +530,7 @@ def _get_synchrony_counts(spikes, all_unit_ids, synchrony_sizes=np.array([2, 4,
Structured numpy array with fields ("sample_index", "unit_index", "segment_index").
all_unit_ids : list or None, default: None
List of unit ids to compute the synchrony metrics. Expecting all units.
synchrony_sizes : numpy array
synchrony_sizes : None or np.array, default: None
The synchrony sizes to compute. Should be pre-sorted.
Returns
Expand Down Expand Up @@ -576,6 +576,8 @@ def compute_synchrony_metrics(sorting_analyzer, unit_ids=None, synchrony_sizes=N
A SortingAnalyzer object.
unit_ids : list or None, default: None
List of unit ids to compute the synchrony metrics. If None, all units are used.
synchrony_sizes: None, default: None
Deprecated argument. Please use private `_get_synchrony_counts` if you need finer control over number of synchronous spikes.
Returns
-------
Expand All @@ -590,7 +592,7 @@ def compute_synchrony_metrics(sorting_analyzer, unit_ids=None, synchrony_sizes=N

if synchrony_sizes is not None:
warning_message = "Custom `synchrony_sizes` is deprecated; the `synchrony_metrics` will be computed using `synchrony_sizes = [2,4,8]`"
warnings.warn(warning_message)
warnings.warn(warning_message, DeprecationWarning, stacklevel=2)

synchrony_sizes = np.array([2, 4, 8])

Expand All @@ -605,7 +607,7 @@ def compute_synchrony_metrics(sorting_analyzer, unit_ids=None, synchrony_sizes=N

spikes = sorting.to_spike_vector()
all_unit_ids = sorting.unit_ids
synchrony_counts = _get_synchrony_counts(spikes, all_unit_ids, synchrony_sizes=synchrony_sizes)
synchrony_counts = _get_synchrony_counts(spikes, synchrony_sizes, all_unit_ids)

synchrony_metrics_dict = {}
for sync_idx, synchrony_size in enumerate(synchrony_sizes):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,7 @@ def test_synchrony_counts_no_sync():
one_spike["sample_index"] = spike_times
one_spike["unit_index"] = spike_units

sync_count = _get_synchrony_counts(one_spike, [0])
sync_count = _get_synchrony_counts(one_spike, np.array([2, 4, 8]), [0])

assert np.all(sync_count[0] == np.array([0]))

Expand All @@ -372,7 +372,7 @@ def test_synchrony_counts_one_sync():
two_spikes["sample_index"] = np.concatenate((spike_indices, added_spikes_indices))
two_spikes["unit_index"] = np.concatenate((spike_labels, added_spikes_labels))

sync_count = _get_synchrony_counts(two_spikes, [0, 1])
sync_count = _get_synchrony_counts(two_spikes, np.array([2, 4, 8]), [0, 1])

assert np.all(sync_count[0] == np.array([1, 1]))

Expand All @@ -392,7 +392,7 @@ def test_synchrony_counts_one_quad_sync():
four_spikes["sample_index"] = np.concatenate((spike_indices, added_spikes_indices))
four_spikes["unit_index"] = np.concatenate((spike_labels, added_spikes_labels))

sync_count = _get_synchrony_counts(four_spikes, [0, 1, 2, 3])
sync_count = _get_synchrony_counts(four_spikes, np.array([2, 4, 8]), [0, 1, 2, 3])

assert np.all(sync_count[0] == np.array([1, 1, 1, 1]))
assert np.all(sync_count[1] == np.array([1, 1, 1, 1]))
Expand All @@ -409,7 +409,7 @@ def test_synchrony_counts_not_all_units():
three_spikes["sample_index"] = np.concatenate((spike_indices, added_spikes_indices))
three_spikes["unit_index"] = np.concatenate((spike_labels, added_spikes_labels))

sync_count = _get_synchrony_counts(three_spikes, [0, 1, 2])
sync_count = _get_synchrony_counts(three_spikes, np.array([2, 4, 8]), [0, 1, 2])

assert np.all(sync_count[0] == np.array([0, 1, 1]))

Expand Down

0 comments on commit 2081916

Please sign in to comment.