diff --git a/doc/reference/phase_analysis.rst b/doc/reference/phase_analysis.rst index 48ac66dea..04f5f4e9b 100644 --- a/doc/reference/phase_analysis.rst +++ b/doc/reference/phase_analysis.rst @@ -3,3 +3,11 @@ Phase Analysis ============== .. automodule:: elephant.phase_analysis + +References +---------- + +.. bibliography:: ../bib/elephant.bib + :labelprefix: ph + :keyprefix: phase- + :style: unsrt diff --git a/elephant/phase_analysis.py b/elephant/phase_analysis.py index a7a785bd2..e195f6c58 100644 --- a/elephant/phase_analysis.py +++ b/elephant/phase_analysis.py @@ -9,6 +9,7 @@ phase_locking_value mean_phase_vector phase_difference + pairwise_phase_consistency weighted_phase_lag_index References @@ -31,6 +32,7 @@ __all__ = [ "spike_triggered_phase", + "pairwise_phase_consistency", "phase_locking_value", "mean_phase_vector", "phase_difference", @@ -161,8 +163,8 @@ def spike_triggered_phase(hilbert_transform, spiketrains, interpolate): # Find index into signal for each spike ind_at_spike = ( - (spiketrain[sttimeind] - hilbert_transform[phase_i].t_start) / - hilbert_transform[phase_i].sampling_period). \ + (spiketrain[sttimeind] - hilbert_transform[phase_i].t_start) / + hilbert_transform[phase_i].sampling_period). \ simplified.magnitude.astype(int) # Append new list to the results for this spiketrain @@ -173,7 +175,7 @@ def spike_triggered_phase(hilbert_transform, spiketrains, interpolate): # Step through all spikes for spike_i, ind_at_spike_j in enumerate(ind_at_spike): - if interpolate and ind_at_spike_j+1 < len(times): + if interpolate and ind_at_spike_j + 1 < len(times): # Get relative spike occurrence between the two closest signal # sample points # if z->0 spike is more to the left sample @@ -182,12 +184,14 @@ def spike_triggered_phase(hilbert_transform, spiketrains, interpolate): hilbert_transform[phase_i].sampling_period # Save hilbert_transform (interpolate on circle) + p1 = np.angle(hilbert_transform[phase_i][ind_at_spike_j] ).item() p2 = np.angle(hilbert_transform[phase_i][ind_at_spike_j + 1] ).item() interpolation = (1 - z) * np.exp(complex(0, p1)) \ + z * np.exp(complex(0, p2)) + p12 = np.angle([interpolation]) result_phases[spiketrain_i].append(p12) @@ -217,6 +221,91 @@ def spike_triggered_phase(hilbert_transform, spiketrains, interpolate): return result_phases, result_amps, result_times +def pairwise_phase_consistency(phases, method='ppc0'): + r""" + The Pairwise Phase Consistency (PPC0) :cite:`phase-Vinck2010_51` is an + improved measure of phase consistency/phase locking value, accounting for + bias due to low trial counts. + + PPC0 is computed according to Eq. 14 and 15 of the cited paper. + + An improved version of the PPC (PPC1) :cite:`phase-Vinck2012_33` computes + angular difference ony between pairs of spikes within trials. + + PPC1 is not implemented yet + + + .. math:: + \text{PPC} = \frac{2}{N(N-1)} \sum_{j=1}^{N-1} \sum_{k=j+1}^N + f(\theta_j, \theta_k) + + wherein the function :math:`f` computes the dot product between two unit + vectors and is defined by + + .. math:: + f(\phi, \omega) = \cos(\phi) \cos(\omega) + \sin(\phi) \sin(\omega) + + Parameters + ---------- + phases : np.ndarray or list of np.ndarray + Spike-triggered phases (output from :func:`spike_triggered_phase`). + If phases is a list of arrays, each array is considered a trial + + method : str + 'ppc0' - compute PPC between all pairs of spikes + + Returns + ------- + result_ppc : list of float + Pairwise Phase Consistency + + """ + if isinstance(phases, np.ndarray): + phases = [phases] + if not isinstance(phases, (list, tuple)): + raise TypeError("Input must be a list of 1D numpy arrays with phases") + + for phase_array in phases: + if not isinstance(phase_array, np.ndarray): + raise TypeError("Each entry of the input list must be an 1D " + "numpy array with phases") + if phase_array.ndim != 1: + raise ValueError("Phase arrays must be 1D (use .flatten())") + + if method not in ['ppc0']: + raise ValueError('For method choose out of: ["ppc0"]') + + phase_array = np.hstack(phases) + n_trials = phase_array.shape[0] # 'spikes' are 'trials' as in paper + + # Compute the distance between each pair of phases using dot product + # Optimize computation time using array multiplications instead of for + # loops + p_cos_2d = np.broadcast_to(np.cos(phase_array), (n_trials, n_trials)) + p_sin_2d = np.broadcast_to(np.sin(phase_array), (n_trials, n_trials)) + + # By doing the element-wise multiplication of this matrix with its + # transpose, we get the distance between phases for all possible pairs + # of elements in 'phase' + dot_prod = np.multiply(p_cos_2d, p_cos_2d.T, dtype=np.float32) + \ + np.multiply(p_sin_2d, p_sin_2d.T, dtype=np.float32) + + # Now average over all elements in temp_results (the diagonal are 1 + # and should not be included) + np.fill_diagonal(dot_prod, 0) + + if method == 'ppc0': + # Note: each pair i,j is computed twice in dot_prod. do not + # multiply by 2. n_trial * n_trials - n_trials = nr of filled elements + # in dot_prod + ppc = np.sum(dot_prod) / (n_trials * n_trials - n_trials) + return ppc + + elif method == 'ppc1': + # TODO: remove all indices from the same trial + return + + def phase_locking_value(phases_i, phases_j): r""" Calculates the phase locking value (PLV) :cite:`phase-Lachaux99_194`. diff --git a/elephant/test/test_phase_analysis.py b/elephant/test/test_phase_analysis.py index 50b4f340f..7aeaa612c 100644 --- a/elephant/test/test_phase_analysis.py +++ b/elephant/test/test_phase_analysis.py @@ -203,6 +203,132 @@ def test_regression_269(self): self.assertEqual(len(phases_noint[0]), 2) +class PairwisePhaseConsistencyTestCase(unittest.TestCase): + + @classmethod + def setUpClass(cls): # Note: using setUp makes the class call this + # function per test, while this way the function is called only + # 1 time per TestCase, slightly more efficient (0.5s tough) + + # Same setup as SpikeTriggerePhaseTestCase + tlen0 = 100 * pq.s + f0 = 20. * pq.Hz + fs0 = 1 * pq.ms + t0 = np.arange( + 0, tlen0.rescale(pq.s).magnitude, + fs0.rescale(pq.s).magnitude) * pq.s + cls.anasig0 = AnalogSignal( + np.sin(2 * np.pi * (f0 * t0).simplified.magnitude), + units=pq.mV, t_start=0 * pq.ms, sampling_period=fs0) + + # Spiketrain with perfect locking + cls.st_perfect = SpikeTrain( + np.arange(50, tlen0.rescale(pq.ms).magnitude - 50, 50) * pq.ms, + t_start=0 * pq.ms, t_stop=tlen0) + + # Spiketrain with inperfect locking + cls.st_inperfect = SpikeTrain( + [100., 100.1, 100.2, 100.3, 100.9, 101.] * pq.ms, + t_start=0 * pq.ms, t_stop=tlen0) + + # Generate 2 'bursting' spiketrains, both locking on sinus period, + # but with different strengths + n_spikes = 3 # n spikes per burst + burst_interval = (1 / f0.magnitude) * pq.s + burst_start_times = np.arange( + 0, + tlen0.rescale('ms').magnitude, + burst_interval.rescale('ms').magnitude + ) + + # Spiketrain with strong locking + burst_freq_strong = 200. * pq.Hz # strongly locking unit + burst_spike_interval = (1 / burst_freq_strong.magnitude) * pq.s + st_in_burst = np.arange( + 0, + burst_spike_interval.rescale('ms').magnitude * n_spikes, + burst_spike_interval.rescale('ms').magnitude + ) + st = [st_in_burst + t_offset for t_offset in burst_start_times] + st = np.hstack(st) * pq.ms + cls.st_bursting_strong = SpikeTrain(st, + t_start=0 * pq.ms, + t_stop=tlen0 + ) + + # Spiketrain with weak locking + burst_freq_weak = 100. * pq.Hz # weak locking unit + burst_spike_interval = (1 / burst_freq_weak.magnitude) * pq.s + st_in_burst = np.arange( + 0, + burst_spike_interval.rescale('ms').magnitude * n_spikes, + burst_spike_interval.rescale('ms').magnitude + ) + st = [st_in_burst + t_offset for t_offset in burst_start_times] + st = np.hstack(st) * pq.ms + cls.st_bursting_weak = SpikeTrain(st, + t_start=0 * pq.ms, + t_stop=tlen0 + ) + + def test_perfect_locking(self): + phases, _, _ = elephant.phase_analysis.spike_triggered_phase( + elephant.signal_processing.hilbert(self.anasig0), + self.st_perfect, + interpolate=True + ) + # Pass input as single array + ppc0 = elephant.phase_analysis.pairwise_phase_consistency( + phases[0], method='ppc0' + ) + self.assertEqual(ppc0, 1) + self.assertIsInstance(ppc0, float) + + # Pass input as list of arrays + n_phases = int(phases[0].shape[0] / 2) + phases_cut = [phases[0][i * 2:i * 2 + 2] for i in range(n_phases)] + ppc0 = elephant.phase_analysis.pairwise_phase_consistency( + phases_cut, method='ppc0' + ) + self.assertEqual(ppc0, 1) + self.assertIsInstance(ppc0, float) + + def test_inperfect_locking(self): + phases, _, _ = elephant.phase_analysis.spike_triggered_phase( + elephant.signal_processing.hilbert(self.anasig0), + self.st_inperfect, + interpolate=True + ) + # Pass input as single array + ppc0 = elephant.phase_analysis.pairwise_phase_consistency( + phases[0], method='ppc0' + ) + self.assertLess(ppc0, 1) + self.assertIsInstance(ppc0, float) + + def test_strong_vs_weak_locking(self): + phases_weak, _, _ = elephant.phase_analysis.spike_triggered_phase( + elephant.signal_processing.hilbert(self.anasig0), + self.st_bursting_weak, + interpolate=True + ) + # Pass input as single array + ppc0_weak = elephant.phase_analysis.pairwise_phase_consistency( + phases_weak[0], method='ppc0' + ) + phases_strong, _, _ = elephant.phase_analysis.spike_triggered_phase( + elephant.signal_processing.hilbert(self.anasig0), + self.st_bursting_strong, + interpolate=True + ) + # Pass input as single array + ppc0_strong = elephant.phase_analysis.pairwise_phase_consistency( + phases_strong[0], method='ppc0' + ) + + self.assertLess(ppc0_weak, ppc0_strong) + + class MeanVectorTestCase(unittest.TestCase): def setUp(self): self.tolerance = 1e-15