diff --git a/doc/bib/elephant.bib b/doc/bib/elephant.bib
index 9c4b30ac2..46ecea203 100644
--- a/doc/bib/elephant.bib
+++ b/doc/bib/elephant.bib
@@ -276,7 +276,8 @@ @article{Cutts2014_14288
   number={43},
   pages={14288--14303},
   year={2014},
-  publisher={Soc Neuroscience}
+  publisher={Soc Neuroscience},
+  doi={10.1523/JNEUROSCI.2767-14.2014}
 }
 
 @article{Holt1996_1806,
diff --git a/elephant/spike_train_correlation.py b/elephant/spike_train_correlation.py
index b64667e45..da79abdf2 100644
--- a/elephant/spike_train_correlation.py
+++ b/elephant/spike_train_correlation.py
@@ -25,7 +25,7 @@
 from scipy import integrate
 
 from elephant.conversion import BinnedSpikeTrain
-from elephant.utils import deprecated_alias
+from elephant.utils import deprecated_alias, check_neo_consistency
 
 __all__ = [
     "covariance",
@@ -824,7 +824,9 @@ def cross_correlation_histogram(
 
 
 @deprecated_alias(spiketrain_1='spiketrain_i', spiketrain_2='spiketrain_j')
-def spike_time_tiling_coefficient(spiketrain_i, spiketrain_j, dt=0.005 * pq.s):
+def spike_time_tiling_coefficient(spiketrain_i: neo.core.SpikeTrain,
+                                  spiketrain_j: neo.core.SpikeTrain,
+                                  dt: pq.Quantity = 0.005 * pq.s) -> float:
     """
     Calculates the Spike Time Tiling Coefficient (STTC) as described in
     :cite:`correlation-Cutts2014_14288` following their implementation in C.
@@ -832,7 +834,7 @@ def spike_time_tiling_coefficient(spiketrain_i, spiketrain_j, dt=0.005 * pq.s):
     It has been proposed as a replacement for the correlation index as it
     presents several advantages (e.g. it's not confounded by firing rate,
     appropriately distinguishes lack of correlation from anti-correlation,
-    periods of silence don't add to the correlation and it's sensitive to
+    periods of silence don't add to the correlation, and it's sensitive to
     firing patterns).
 
     The STTC is calculated as follows:
@@ -845,7 +847,7 @@ def spike_time_tiling_coefficient(spiketrain_i, spiketrain_j, dt=0.005 * pq.s):
     in train 1, `PB` is the same proportion for the spikes in train 2;
     `TA` is the proportion of total recording time within `[-dt, +dt]` of any
     spike in train 1, TB is the same proportion for train 2.
-    For :math:`TA = PB = 1`and for :math:`TB = PA = 1`
+    For :math:`TA = PB = 1` and for :math:`TB = PA = 1`
     the resulting :math:`0/0` is replaced with :math:`1`,
     since every spike from the train with :math:`T = 1` is within
     `[-dt, +dt]` of a spike of the other train.
@@ -857,7 +859,7 @@ def spike_time_tiling_coefficient(spiketrain_i, spiketrain_j, dt=0.005 * pq.s):
 
     Parameters
     ----------
-    spiketrain_i, spiketrain_j : neo.SpikeTrain
+    spiketrain_i, spiketrain_j : :class:`neo.core.SpikeTrain`
         Spike trains to cross-correlate. They must have the same `t_start` and
         `t_stop`.
     dt : pq.Quantity.
@@ -869,9 +871,9 @@ def spike_time_tiling_coefficient(spiketrain_i, spiketrain_j, dt=0.005 * pq.s):
 
     Returns
     -------
-    index : float or np.nan
-        The spike time tiling coefficient (STTC). Returns np.nan if any spike
-        train is empty.
+    index : :class:`float` or :obj:`numpy.nan`
+        The spike time tiling coefficient (STTC). Returns :obj:`numpy.nan` if
+        any spike train is empty.
 
     Notes
     -----
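A minimal usage sketch of the signature documented above; the spike times and the dt value below are illustrative and are not taken from the module's doctest.

    import neo
    import quantities as pq
    from elephant.spike_train_correlation import spike_time_tiling_coefficient

    # Two spike trains sharing t_start and t_stop, as required above.
    st_a = neo.SpikeTrain([1.3, 7.6, 15.9, 28.2, 30.9], units='ms', t_stop=50.)
    st_b = neo.SpikeTrain([1.0, 2.7, 18.8, 28.5, 28.8], units='ms', t_stop=50.)

    # dt is a pq.Quantity; the implementation rescales it to the trains' units.
    sttc_value = spike_time_tiling_coefficient(st_a, st_b, dt=0.005 * pq.s)
    print(sttc_value)  # a float in [-1, 1]; numpy.nan if either train is empty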
@@ -891,109 +893,105 @@ def spike_time_tiling_coefficient(spiketrain_i, spiketrain_j, dt=0.005 * pq.s):
     0.4958601655933762
     """
+    # input checks
+    if dt <= 0 * pq.s:
+        raise ValueError(f"dt must be > 0, found: {dt}")
 
-    def run_P(spiketrain_i, spiketrain_j):
+    check_neo_consistency([spiketrain_j, spiketrain_i], neo.core.SpikeTrain)
+
+    if dt.units != spiketrain_i.units:
+        dt = dt.rescale(spiketrain_i.units)
+
+    def run_p(spiketrain_j: neo.core.SpikeTrain,
+              spiketrain_i: neo.core.SpikeTrain,
+              dt: pq.Quantity = dt) -> float:
         """
-        Check every spike in train 1 to see if there's a spike in train 2
-        within dt
+        Returns the number of spikes in spiketrain_j which lie within +- dt
+        of any spike from spiketrain_i, divided by the total number of
+        spikes in spiketrain_j.
         """
-        N2 = len(spiketrain_j)
-
-        # Search spikes of spiketrain_i in spiketrain_j
-        # ind will contain index of
-        ind = np.searchsorted(spiketrain_j.times, spiketrain_i.times)
-
-        # To prevent IndexErrors
-        # If a spike of spiketrain_i is after the last spike of spiketrain_j,
-        # the index is N2, however spiketrain_j[N2] raises an IndexError.
-        # By shifting this index, the spike of spiketrain_i will be compared
-        # to the last 2 spikes of spiketrain_j (negligible overhead).
-        # Note: Not necessary for index 0 that will be shifted to -1,
-        # because spiketrain_j[-1] is valid (additional negligible comparison)
-        ind[ind == N2] = N2 - 1
-
-        # Compare to nearest spike in spiketrain_j BEFORE spike in spiketrain_i
-        close_left = np.abs(
-            spiketrain_j.times[ind - 1] - spiketrain_i.times) <= dt
-        # Compare to nearest spike in spiketrain_j AFTER (or simultaneous)
-        # spike in spiketrain_j
-        close_right = np.abs(
-            spiketrain_j.times[ind] - spiketrain_i.times) <= dt
-
-        # spiketrain_j spikes that are in [-dt, dt] range of spiketrain_i
-        # spikes are counted only ONCE (as per original implementation)
-        close = close_left + close_right
-
-        # Count how many spikes in spiketrain_i have a "partner" in
-        # spiketrain_j
-        return np.count_nonzero(close)
-
-    def run_T(spiketrain):
+        # Create a boolean array where each element represents whether a spike
+        # in spiketrain_j lies within +- dt of any spike in spiketrain_i.
+        tiled_spikes_j = np.isclose(
+            spiketrain_j.times.magnitude[:, np.newaxis],
+            spiketrain_i.times.magnitude,
+            atol=dt.item())
+        # Determine which spikes in spiketrain_j satisfy the time window
+        # condition.
+        tiled_spike_indices = np.any(tiled_spikes_j, axis=1)
+        # Extract the spike times in spiketrain_j that satisfy the condition.
+        tiled_spikes_j = spiketrain_j[tiled_spike_indices]
+        # Calculate the ratio of matching spikes in j to the total spikes in j.
+        return len(tiled_spikes_j) / len(spiketrain_j)
+
+    def run_t(spiketrain: neo.core.SpikeTrain, dt: pq.Quantity = dt) -> float:
         """
         Calculate the proportion of the total recording time 'tiled' by spikes.
         """
-        N = len(spiketrain)
-        time_A = 2 * N * dt  # maximum possible time
-
-        if N == 1:  # for only a single spike in the train
-
-            # Check difference between start of recording and single spike
-            if spiketrain[0] - spiketrain.t_start < dt:
-                time_A += - dt + spiketrain[0] - spiketrain.t_start
-
-            # Check difference between single spike and end of recording
-            elif spiketrain[0] + dt > spiketrain.t_stop:
-                time_A += - dt - spiketrain[0] + spiketrain.t_stop
-
-        else:  # if more than a single spike in the train
-
-            # Calculate difference between consecutive spikes
-            diff = np.diff(spiketrain)
-
-            # Find spikes whose tiles overlap
-            idx = np.where(diff < 2 * dt)[0]
-            # Subtract overlapping "2*dt" tiles and add differences instead
-            time_A += - 2 * dt * len(idx) + diff[idx].sum()
-
-            # Check if spikes are within +/-dt of the start and/or end
-            # if so, subtract overlap of first and/or last spike
-            if (spiketrain[0] - spiketrain.t_start) < dt:
-                time_A += spiketrain[0] - dt - spiketrain.t_start
-            if (spiketrain.t_stop - spiketrain[N - 1]) < dt:
-                time_A += - spiketrain[-1] - dt + spiketrain.t_stop
-
-        # Calculate the proportion of total recorded time to "tiled" time
-        T = time_A / (spiketrain.t_stop - spiketrain.t_start)
-        return T.simplified.item()  # enforce simplification, strip units
+        # Get the numerical value of 'dt'.
+        dt = dt.item()
+        # Get the start and stop times of the spike train.
+        t_start = spiketrain.t_start.item()
+        t_stop = spiketrain.t_stop.item()
+        # Get the spike times as a NumPy array.
+        sorted_spikes = spiketrain.times.magnitude
+        # Check if spikes are sorted and sort them if not.
+        if (np.diff(sorted_spikes) < 0).any():
+            sorted_spikes = np.sort(sorted_spikes)
+
+        # Calculate the time differences between consecutive spikes.
+        diff_spikes = np.diff(sorted_spikes)
+        # Calculate durations of spike overlaps within a time window of 2 * dt.
+        overlap_durations = diff_spikes[diff_spikes <= 2 * dt]
+        covered_time_overlap = np.sum(overlap_durations)
+
+        # Calculate the durations of non-overlapping spikes.
+        non_overlap_durations = diff_spikes[diff_spikes > 2 * dt]
+        covered_time_non_overlap = len(non_overlap_durations) * 2 * dt
+
+        # Check if the first and last spikes are within +/-dt of the start
+        # and end.
+        # If so, adjust the overlapping and non-overlapping times accordingly.
+        if sorted_spikes[0] - t_start < dt:
+            covered_time_overlap += sorted_spikes[0] - t_start
+        else:
+            covered_time_non_overlap += dt
+        if t_stop - sorted_spikes[-1] < dt:
+            covered_time_overlap += t_stop - sorted_spikes[-1]
+        else:
+            covered_time_non_overlap += dt
-    N1 = len(spiketrain_i)
-    N2 = len(spiketrain_j)
+        # Calculate the total time covered by spikes and the total recording
+        # time.
+        total_time_covered = covered_time_overlap + covered_time_non_overlap
+        total_time = t_stop - t_start
+        # Calculate and return the proportion of the total recording time
+        # covered by spikes.
+        return total_time_covered / total_time
 
-    if N1 == 0 or N2 == 0:
+    if len(spiketrain_i) == 0 or len(spiketrain_j) == 0:
         index = np.nan
     else:
-        TA = run_T(spiketrain_i)
-        TB = run_T(spiketrain_j)
-        PA = run_P(spiketrain_i, spiketrain_j)
-        PA = PA / N1
-        PB = run_P(spiketrain_j, spiketrain_i)
-        PB = PB / N2
+        TA = run_t(spiketrain_j, dt)
+        TB = run_t(spiketrain_i, dt)
+        PA = run_p(spiketrain_j, spiketrain_i, dt)
+        PB = run_p(spiketrain_i, spiketrain_j, dt)
+
         # check if the P and T values are 1 to avoid division by zero
         # This only happens for TA = PB = 1 and/or TB = PA = 1,
         # which leads to 0/0 in the calculation of the index.
         # In those cases, every spike in the train with P = 1
         # is within dt of a spike in the other train,
         # so we set the respective (partial) index to 1.
-        if PA * TB == 1:
-            if PB * TA == 1:
-                index = 1.
-            else:
-                index = 0.5 + 0.5 * (PB - TA) / (1 - PB * TA)
+        if PA * TB == 1 and PB * TA == 1:
+            index = 1.
+        elif PA * TB == 1:
+            index = 0.5 + 0.5 * (PB - TA) / (1 - PB * TA)
         elif PB * TA == 1:
             index = 0.5 + 0.5 * (PA - TB) / (1 - PA * TB)
         else:
-            index = 0.5 * (PA - TB) / (1 - PA * TB) + 0.5 * (PB - TA) / (
-                1 - PB * TA)
+            index = 0.5 * (PA - TB) / (1 - PA * TB) + \
+                0.5 * (PB - TA) / (1 - PB * TA)
 
     return index
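For reference, a self-contained sketch of the quantities computed by the helpers above, using plain NumPy arrays; the toy spike times and the names proportion_p and proportion_t are illustrative and not part of the module, and the 0/0 special cases handled above are omitted.

    import numpy as np

    # Toy data: spike times in seconds within a 1 s recording.
    t_start, t_stop, dt = 0.0, 1.0, 0.05
    spikes_a = np.array([0.10, 0.20, 0.60])
    spikes_b = np.array([0.11, 0.45, 0.62])

    def proportion_p(spikes_x, spikes_y, dt):
        # Fraction of spikes in x lying within +-dt of any spike in y
        # (mirrors run_p, with a plain absolute-difference comparison).
        hits = np.any(np.abs(spikes_x[:, None] - spikes_y) <= dt, axis=1)
        return hits.mean()

    def proportion_t(spikes, t_start, t_stop, dt):
        # Fraction of the recording tiled by +-dt windows around the spikes
        # (mirrors run_t: merge overlapping tiles, clip at the recording edges).
        spikes = np.sort(spikes)
        diffs = np.diff(spikes)
        covered = 2 * dt * len(spikes)
        covered -= np.sum(2 * dt - diffs[diffs < 2 * dt])  # overlapping tiles
        covered -= max(0.0, dt - (spikes[0] - t_start))    # clip at t_start
        covered -= max(0.0, dt - (t_stop - spikes[-1]))    # clip at t_stop
        return covered / (t_stop - t_start)

    PA = proportion_p(spikes_a, spikes_b, dt)
    PB = proportion_p(spikes_b, spikes_a, dt)
    TA = proportion_t(spikes_a, t_start, t_stop, dt)
    TB = proportion_t(spikes_b, t_start, t_stop, dt)
    sttc = 0.5 * (PA - TB) / (1 - PA * TB) + 0.5 * (PB - TA) / (1 - PB * TA)
    print(sttc)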
diff --git a/elephant/test/test_spike_train_correlation.py b/elephant/test/test_spike_train_correlation.py
index 547b12611..ff4d088ca 100644
--- a/elephant/test/test_spike_train_correlation.py
+++ b/elephant/test/test_spike_train_correlation.py
@@ -6,18 +6,20 @@
 :license: Modified BSD, see LICENSE.txt for details.
 """
 
+import math
 import unittest
 
 import neo
+from neo.io import NixIO
 import numpy as np
 import quantities as pq
 from numpy.testing import assert_array_equal, assert_array_almost_equal
 
 import elephant.conversion as conv
 import elephant.spike_train_correlation as sc
+from elephant.datasets import download_datasets, ELEPHANT_TMP_DIR
 from elephant.spike_train_generation import homogeneous_poisson_process, \
     homogeneous_gamma_process
-import math
 
 
 class CovarianceTestCase(unittest.TestCase):
@@ -727,8 +729,27 @@ def setUp(self):
         self.st_2 = neo.SpikeTrain(
             self.test_array_1d_2, units='ms', t_stop=50.)
 
-    def test_sttc(self):
+    def test_sttc_dt_smaller_zero(self):
+        self.assertRaises(ValueError, sc.sttc, self.st_1, self.st_2,
+                          dt=0 * pq.s)
+        self.assertRaises(ValueError, sc.sttc, self.st_1, self.st_2,
+                          dt=-1 * pq.ms)
+
+    def test_sttc_different_t_stop(self):
+        st_1 = neo.SpikeTrain([1], units='ms', t_stop=10.)
+        st_2 = neo.SpikeTrain([5], units='ms', t_stop=10.)
+        st_2.t_stop = 1 * pq.ms
+        self.assertRaises(ValueError, sc.sttc, st_1, st_2)
+
+    def test_sttc_different_t_start(self):
+        st_1 = neo.SpikeTrain([1], units='ms', t_stop=10.)
+        st_2 = neo.SpikeTrain([5], units='ms', t_stop=10.)
+        st_2.t_start = 1 * pq.ms
+        self.assertRaises(ValueError, sc.sttc, st_1, st_2)
+
+    def test_sttc_different_units_dt(self):
         # test for result
+        # target obtained with pencil and paper following the original paper.
         target = 0.495860165593
         self.assertAlmostEqual(target, sc.sttc(self.st_1, self.st_2,
                                                0.005 * pq.s))
@@ -737,30 +758,111 @@ def setUp(self):
         self.assertAlmostEqual(target, sc.sttc(self.st_1, self.st_2,
                                                5.0 * pq.ms))
 
+    def test_sttc_different_units_spiketrains(self):
+        st1 = neo.SpikeTrain([1], units='ms', t_stop=10.)
+        st2 = neo.SpikeTrain([5], units='s', t_stop=10.)
+        self.assertRaises(ValueError, sc.sttc, st1, st2)
+
+    def test_sttc_not_enough_spiketrains(self):
         # test no spiketrains
-        self.assertTrue(np.isnan(sc.sttc([], [])))
+        self.assertRaises(TypeError, sc.sttc, [], [])
 
         # test one spiketrain
-        self.assertTrue(np.isnan(sc.sttc(self.st_1, [])))
+        self.assertRaises(TypeError, sc.sttc, self.st_1, [])
 
+    def test_sttc_one_spike(self):
         # test for one spike in a spiketrain
-        st1 = neo.SpikeTrain([1], units='ms', t_stop=1.)
+        st1 = neo.SpikeTrain([1], units='ms', t_stop=10.)
         st2 = neo.SpikeTrain([5], units='ms', t_stop=10.)
         self.assertEqual(sc.sttc(st1, st2), 1.0)
         self.assertTrue(bool(sc.sttc(st1, st2, 0.1 * pq.ms) < 0))
 
+    def test_sttc_high_value_dt(self):
         # test for high value of dt
         self.assertEqual(sc.sttc(self.st_1, self.st_2, dt=5 * pq.s), 1.0)
 
+    def test_sttc_edge_cases(self):
         # test for TA = PB = 1 but TB /= PA /= 1 and vice versa
+        st2 = neo.SpikeTrain([5], units='ms', t_stop=10.)
         st3 = neo.SpikeTrain([1, 5, 9], units='ms', t_stop=10.)
         target2 = 1. / 3.
-        self.assertAlmostEqual(target2, sc.sttc(st3, st2,
-                                                0.003 * pq.s))
-        self.assertAlmostEqual(target2, sc.sttc(st2, st3,
-                                                0.003 * pq.s))
-
-    def test_exist_alias(self):
+        self.assertAlmostEqual(target2, sc.sttc(st3, st2, 0.003 * pq.s))
+        self.assertAlmostEqual(target2, sc.sttc(st2, st3, 0.003 * pq.s))
+
+    def test_sttc_unsorted_spiketimes(self):
+        # regression test for issue #563
+        # https://github.com/NeuralEnsemble/elephant/issues/563
+        spiketrain_E7 = neo.SpikeTrain(
+            [1678., 23786.3, 34641.8, 71520.7, 73606.9, 78383.3,
+             97387.9, 144313.4, 4607.6, 19275.1, 152894.2, 44240.1],
+            units='ms', t_stop=300000 * pq.ms)
+
+        spiketrain_E3 = neo.SpikeTrain(
+            [1678., 23786.3, 34641.8, 71520.7, 73606.9, 78383.3,
+             97387.9, 144313.4, 4607.6, 19275.1, 152894.2, 44240.1],
+            units='ms', t_stop=300000 * pq.ms)
+        sttc_unsorted_E7_E3 = sc.sttc(spiketrain_E7,
+                                      spiketrain_E3, dt=0.10 * pq.s)
+        self.assertAlmostEqual(sttc_unsorted_E7_E3, 1)
+        spiketrain_E7.sort()
+        spiketrain_E3.sort()
+        sttc_sorted_E7_E3 = sc.sttc(spiketrain_E7,
+                                    spiketrain_E3, dt=0.10 * pq.s)
+        self.assertAlmostEqual(sttc_unsorted_E7_E3, sttc_sorted_E7_E3)
+
+        spiketrain_E8 = neo.SpikeTrain(
+            [20646.8, 25875.1, 26154.4, 35121., 55909.7, 79164.8,
+             110849.8, 117484.1, 3731.5, 4213.9, 119995.1, 123748.1,
+             171016.8, 172989., 185145.2, 12043.5, 185995.9, 186740.1,
+             12629.8, 23394.3, 34993.2], units='ms', t_stop=300000 * pq.ms)
+
+        spiketrain_B3 = neo.SpikeTrain(
+            [10600.7, 19699.6, 22803., 40769.3, 121385.7, 127402.9,
+             130829.2, 134363.8, 1193.5, 8012.7, 142037.3, 146628.2,
+             165925.3, 168489.3, 175194.3, 10339.8, 178676.4, 180807.2,
+             201431.3, 22231.1, 38113.4], units='ms', t_stop=300000 * pq.ms)
+
+        self.assertTrue(
+            sc.sttc(spiketrain_E8, spiketrain_B3, dt=0.10 * pq.s) < 1)
+
+        sttc_unsorted_E8_B3 = sc.sttc(spiketrain_E8,
+                                      spiketrain_B3, dt=0.10 * pq.s)
+        spiketrain_E8.sort()
+        spiketrain_B3.sort()
+        sttc_sorted_E8_B3 = sc.sttc(spiketrain_E8,
+                                    spiketrain_B3, dt=0.10 * pq.s)
+        self.assertAlmostEqual(sttc_unsorted_E8_B3, sttc_sorted_E8_B3)
+
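The regression above guards against unsorted input: the removed run_P helper relied on np.searchsorted, which assumes its first argument is already sorted, whereas the rewritten run_p compares all pairs of spike times and run_t sorts explicitly. A two-line illustration (values chosen arbitrarily):

    import numpy as np

    unsorted_times = np.array([30.0, 5.0, 20.0])
    # The two insertion indices generally differ, because searchsorted
    # silently assumes a sorted array.
    print(np.searchsorted(unsorted_times, 6.0))
    print(np.searchsorted(np.sort(unsorted_times), 6.0))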
+    def test_sttc_validation_test(self):
+        """This test checks the results of Elephant's implementation of
+        the spike time tiling coefficient against the results of the
+        original C implementation.
+        The C code and the test data are located at
+        NeuralEnsemble/elephant-data/unittest/spike_train_correlation/
+        spike_time_tiling_coefficient"""
+
+        repo_path = r"unittest/spike_train_correlation/spike_time_tiling_coefficient/data"  # noqa
+
+        files_to_download = [("spike_time_tiling_coefficient_results.nix",
+                              "e3749d79046622494660a03e89950f51")]
+
+        for filename, checksum in files_to_download:
+            filepath = download_datasets(repo_path=f"{repo_path}/{filename}",
+                                         checksum=checksum)
+
+        reader = NixIO(filepath, mode='ro')
+        test_data_block = reader.read()
+
+        for segment in test_data_block[0].segments:
+            spiketrain_i = segment.spiketrains[0]
+            spiketrain_j = segment.spiketrains[1]
+            dt = segment.annotations['dt']
+            sttc_result = segment.annotations['sttc_result']
+            self.assertAlmostEqual(sc.sttc(spiketrain_i, spiketrain_j, dt),
+                                   sttc_result)
+
+    def test_sttc_exist_alias(self):
         # Test if alias cch still exists.
         self.assertEqual(sc.spike_time_tiling_coefficient, sc.sttc)
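Finally, a small usage sketch to accompany the new tests; homogeneous_poisson_process is the generator already imported in this test module, and the printed value depends on the random realisation.

    import quantities as pq
    import elephant.spike_train_correlation as sc
    from elephant.spike_train_generation import homogeneous_poisson_process

    # Two independent Poisson spike trains over the same recording window.
    st_i = homogeneous_poisson_process(rate=10 * pq.Hz, t_stop=10 * pq.s)
    st_j = homogeneous_poisson_process(rate=10 * pq.Hz, t_stop=10 * pq.s)

    # For independent trains the coefficient is expected to fluctuate around 0.
    print(sc.sttc(st_i, st_j, dt=0.005 * pq.s))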