Skip to content

Commit

Permalink
[Fix] spike time tiling coefficient for unsorted spiketrains, added v…
Browse files Browse the repository at this point in the history
…alidation test (#564)

* refactor run_P
* refactor run_T
* add checks for t_start and t_stop in run_t
* add input checks and unittests
* add regression test for Issue #563
* add validation tests
* add check if spike times are sorted, if not sort the spikes
  • Loading branch information
Moritz-Alexander-Kern authored Oct 31, 2023
1 parent 971fc6a commit 8bac14c
Show file tree
Hide file tree
Showing 3 changed files with 204 additions and 103 deletions.
3 changes: 2 additions & 1 deletion doc/bib/elephant.bib
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,8 @@ @article{Cutts2014_14288
number={43},
pages={14288--14303},
year={2014},
publisher={Soc Neuroscience}
publisher={Soc Neuroscience},
doi={10.1523/JNEUROSCI.2767-14.2014}
}

@article{Holt1996_1806,
Expand Down
182 changes: 90 additions & 92 deletions elephant/spike_train_correlation.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from scipy import integrate

from elephant.conversion import BinnedSpikeTrain
from elephant.utils import deprecated_alias
from elephant.utils import deprecated_alias, check_neo_consistency

__all__ = [
"covariance",
Expand Down Expand Up @@ -824,15 +824,17 @@ def cross_correlation_histogram(


@deprecated_alias(spiketrain_1='spiketrain_i', spiketrain_2='spiketrain_j')
def spike_time_tiling_coefficient(spiketrain_i, spiketrain_j, dt=0.005 * pq.s):
def spike_time_tiling_coefficient(spiketrain_i: neo.core.SpikeTrain,
spiketrain_j: neo.core.SpikeTrain,
dt: pq.Quantity = 0.005 * pq.s) -> float:
"""
Calculates the Spike Time Tiling Coefficient (STTC) as described in
:cite:`correlation-Cutts2014_14288` following their implementation in C.
The STTC is a pairwise measure of correlation between spike trains.
It has been proposed as a replacement for the correlation index as it
presents several advantages (e.g. it's not confounded by firing rate,
appropriately distinguishes lack of correlation from anti-correlation,
periods of silence don't add to the correlation and it's sensitive to
periods of silence don't add to the correlation, and it's sensitive to
firing patterns).
The STTC is calculated as follows:
Expand All @@ -845,7 +847,7 @@ def spike_time_tiling_coefficient(spiketrain_i, spiketrain_j, dt=0.005 * pq.s):
in train 1, `PB` is the same proportion for the spikes in train 2;
`TA` is the proportion of total recording time within `[-dt, +dt]` of any
spike in train 1, TB is the same proportion for train 2.
For :math:`TA = PB = 1`and for :math:`TB = PA = 1`
For :math:`TA = PB = 1` and for :math:`TB = PA = 1`
the resulting :math:`0/0` is replaced with :math:`1`,
since every spike from the train with :math:`T = 1` is within
`[-dt, +dt]` of a spike of the other train.
Expand All @@ -857,7 +859,7 @@ def spike_time_tiling_coefficient(spiketrain_i, spiketrain_j, dt=0.005 * pq.s):
Parameters
----------
spiketrain_i, spiketrain_j : neo.SpikeTrain
spiketrain_i, spiketrain_j : :class:`neo.core.SpikeTrain`
Spike trains to cross-correlate. They must have the same `t_start` and
`t_stop`.
dt : pq.Quantity.
Expand All @@ -869,9 +871,9 @@ def spike_time_tiling_coefficient(spiketrain_i, spiketrain_j, dt=0.005 * pq.s):
Returns
-------
index : float or np.nan
The spike time tiling coefficient (STTC). Returns np.nan if any spike
train is empty.
index : :class:`float` or :obj:`numpy.nan`
The spike time tiling coefficient (STTC). Returns :obj:`numpy.nan` if
any spike train is empty.
Notes
-----
Expand All @@ -891,109 +893,105 @@ def spike_time_tiling_coefficient(spiketrain_i, spiketrain_j, dt=0.005 * pq.s):
0.4958601655933762
"""
# input checks
if dt <= 0 * pq.s:
raise ValueError(f"dt must be > 0, found: {dt}")

def run_P(spiketrain_i, spiketrain_j):
check_neo_consistency([spiketrain_j, spiketrain_i], neo.core.SpikeTrain)

if dt.units != spiketrain_i.units:
dt = dt.rescale(spiketrain_i.units)

def run_p(spiketrain_j: neo.core.SpikeTrain,
spiketrain_i: neo.core.SpikeTrain,
dt: pq.Quantity = dt) -> float:
"""
Check every spike in train 1 to see if there's a spike in train 2
within dt
Returns number of spikes in spiketrain_j which lie within +- dt of
any spike from spiketrain_i, divided by the total number of spikes in
spiketrain_j
"""
N2 = len(spiketrain_j)

# Search spikes of spiketrain_i in spiketrain_j
# ind will contain index of
ind = np.searchsorted(spiketrain_j.times, spiketrain_i.times)

# To prevent IndexErrors
# If a spike of spiketrain_i is after the last spike of spiketrain_j,
# the index is N2, however spiketrain_j[N2] raises an IndexError.
# By shifting this index, the spike of spiketrain_i will be compared
# to the last 2 spikes of spiketrain_j (negligible overhead).
# Note: Not necessary for index 0 that will be shifted to -1,
# because spiketrain_j[-1] is valid (additional negligible comparison)
ind[ind == N2] = N2 - 1

# Compare to nearest spike in spiketrain_j BEFORE spike in spiketrain_i
close_left = np.abs(
spiketrain_j.times[ind - 1] - spiketrain_i.times) <= dt
# Compare to nearest spike in spiketrain_j AFTER (or simultaneous)
# spike in spiketrain_j
close_right = np.abs(
spiketrain_j.times[ind] - spiketrain_i.times) <= dt

# spiketrain_j spikes that are in [-dt, dt] range of spiketrain_i
# spikes are counted only ONCE (as per original implementation)
close = close_left + close_right

# Count how many spikes in spiketrain_i have a "partner" in
# spiketrain_j
return np.count_nonzero(close)

def run_T(spiketrain):
# Create a boolean array where each element represents whether a spike
# in spiketrain_j lies within +- dt of any spike in spiketrain_i.
tiled_spikes_j = np.isclose(
spiketrain_j.times.magnitude[:, np.newaxis],
spiketrain_i.times.magnitude,
atol=dt.item())
# Determine which spikes in spiketrain_j satisfy the time window
# condition.
tiled_spike_indices = np.any(tiled_spikes_j, axis=1)
# Extract the spike times in spiketrain_j that satisfy the condition.
tiled_spikes_j = spiketrain_j[tiled_spike_indices]
# Calculate the ratio of matching spikes in j to the total spikes in j.
return len(tiled_spikes_j)/len(spiketrain_j)

def run_t(spiketrain: neo.core.SpikeTrain, dt: pq.Quantity = dt) -> float:
"""
Calculate the proportion of the total recording time 'tiled' by spikes.
"""
N = len(spiketrain)
time_A = 2 * N * dt # maximum possible time

if N == 1: # for only a single spike in the train

# Check difference between start of recording and single spike
if spiketrain[0] - spiketrain.t_start < dt:
time_A += - dt + spiketrain[0] - spiketrain.t_start

# Check difference between single spike and end of recording
elif spiketrain[0] + dt > spiketrain.t_stop:
time_A += - dt - spiketrain[0] + spiketrain.t_stop

else: # if more than a single spike in the train

# Calculate difference between consecutive spikes
diff = np.diff(spiketrain)

# Find spikes whose tiles overlap
idx = np.where(diff < 2 * dt)[0]
# Subtract overlapping "2*dt" tiles and add differences instead
time_A += - 2 * dt * len(idx) + diff[idx].sum()

# Check if spikes are within +/-dt of the start and/or end
# if so, subtract overlap of first and/or last spike
if (spiketrain[0] - spiketrain.t_start) < dt:
time_A += spiketrain[0] - dt - spiketrain.t_start
if (spiketrain.t_stop - spiketrain[N - 1]) < dt:
time_A += - spiketrain[-1] - dt + spiketrain.t_stop

# Calculate the proportion of total recorded time to "tiled" time
T = time_A / (spiketrain.t_stop - spiketrain.t_start)
return T.simplified.item() # enforce simplification, strip units
# Get the numerical value of 'dt'.
dt = dt.item()
# Get the start and stop times of the spike train.
t_start = spiketrain.t_start.item()
t_stop = spiketrain.t_stop.item()
# Get the spike times as a NumPy array.
sorted_spikes = spiketrain.times.magnitude
# Check if spikes are sorted and sort them if not.
if (np.diff(sorted_spikes) < 0).any():
sorted_spikes = np.sort(sorted_spikes)

# Calculate the time differences between consecutive spikes.
diff_spikes = np.diff(sorted_spikes)
# Calculate durations of spike overlaps within a time window of 2 * dt.
overlap_durations = diff_spikes[diff_spikes <= 2 * dt]
covered_time_overlap = np.sum(overlap_durations)

# Calculate the durations of non-overlapping spikes.
non_overlap_durations = diff_spikes[diff_spikes > 2 * dt]
covered_time_non_overlap = len(non_overlap_durations) * 2 * dt

# Check if the first and last spikes are within +/-dt of the start
# and end.
# If so, adjust the overlapping and non-overlapping times accordingly.
if sorted_spikes[0] - t_start < dt:
covered_time_overlap += sorted_spikes[0] - t_start
else:
covered_time_non_overlap += dt
if t_stop - sorted_spikes[- 1] < dt:
covered_time_overlap += t_stop - sorted_spikes[-1]
else:
covered_time_non_overlap += dt

N1 = len(spiketrain_i)
N2 = len(spiketrain_j)
# Calculate the total time covered by spikes and the total recording
# time.
total_time_covered = covered_time_overlap + covered_time_non_overlap
total_time = t_stop - t_start
# Calculate and return the proportion of the total recording time
# covered by spikes.
return total_time_covered / total_time

if N1 == 0 or N2 == 0:
if len(spiketrain_i) == 0 or len(spiketrain_j) == 0:
index = np.nan
else:
TA = run_T(spiketrain_i)
TB = run_T(spiketrain_j)
PA = run_P(spiketrain_i, spiketrain_j)
PA = PA / N1
PB = run_P(spiketrain_j, spiketrain_i)
PB = PB / N2
TA = run_t(spiketrain_j, dt)
TB = run_t(spiketrain_i, dt)
PA = run_p(spiketrain_j, spiketrain_i, dt)
PB = run_p(spiketrain_i, spiketrain_j, dt)

# check if the P and T values are 1 to avoid division by zero
# This only happens for TA = PB = 1 and/or TB = PA = 1,
# which leads to 0/0 in the calculation of the index.
# In those cases, every spike in the train with P = 1
# is within dt of a spike in the other train,
# so we set the respective (partial) index to 1.
if PA * TB == 1:
if PB * TA == 1:
index = 1.
else:
index = 0.5 + 0.5 * (PB - TA) / (1 - PB * TA)
if PA * TB == 1 and PB * TA == 1:
index = 1.
elif PA * TB == 1:
index = 0.5 + 0.5 * (PB - TA) / (1 - PB * TA)
elif PB * TA == 1:
index = 0.5 + 0.5 * (PA - TB) / (1 - PA * TB)
else:
index = 0.5 * (PA - TB) / (1 - PA * TB) + 0.5 * (PB - TA) / (
1 - PB * TA)
index = 0.5 * (PA - TB) / (1 - PA * TB) + \
0.5 * (PB - TA) / (1 - PB * TA)
return index


Expand Down
Loading

0 comments on commit 8bac14c

Please sign in to comment.