Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] Integrate trials object with Fano factor #645

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 37 additions & 28 deletions elephant/statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ def mean_firing_rate(spiketrain, t_start=None, t_stop=None, axis=None):
def fanofactor(spiketrains, warn_tolerance=0.1 * pq.ms):
r"""
Evaluates the empirical Fano factor F of the spike counts of
a list of `neo.SpikeTrain` objects.
a list of `neo.SpikeTrain` objects or `elephant.trials.Trial` object.

Given the vector v containing the observed spike counts (one per
spike train) in the time window [t0, t1], F is defined as:
Expand All @@ -288,9 +288,10 @@ def fanofactor(spiketrains, warn_tolerance=0.1 * pq.ms):

Parameters
----------
spiketrains : list
spiketrains : list or elephant.trials.Trial
List of `neo.SpikeTrain` or `pq.Quantity` or `np.ndarray` or list of
spike times for which to compute the Fano factor of spike counts.
spike times for which to compute the Fano factor of spike counts, or
an `elephant.trials.Trial` object containing multiple spiketrain lists.
warn_tolerance : pq.Quantity
In case of a list of input neo.SpikeTrains, if their durations vary by
more than `warn_tolerence` in their absolute values, throw a warning
Expand All @@ -299,10 +300,11 @@ def fanofactor(spiketrains, warn_tolerance=0.1 * pq.ms):

Returns
-------
fano : float
fano : float or list of floats
The Fano factor of the spike counts of the input spike trains.
Returns np.NaN if an empty list is specified, or if all spike trains
are empty.
are empty. If a `Trial` object is provided, returns a list of Fano
factors, one for each trial.

Raises
------
Expand All @@ -328,29 +330,36 @@ def fanofactor(spiketrains, warn_tolerance=0.1 * pq.ms):
0.07142857142857142

"""
# Build array of spike counts (one per spike train)
spike_counts = np.array([len(st) for st in spiketrains])

# Compute FF
if all(count == 0 for count in spike_counts):
# empty list of spiketrains reaches this branch, and NaN is returned
return np.nan

if all(isinstance(st, neo.SpikeTrain) for st in spiketrains):
if not is_time_quantity(warn_tolerance):
raise TypeError("'warn_tolerance' must be a time quantity.")
durations = [(st.t_stop - st.t_start).simplified.item()
for st in spiketrains]
durations_min = min(durations)
durations_max = max(durations)
if durations_max - durations_min > warn_tolerance.simplified.item():
warnings.warn("Fano factor calculated for spike trains of "
"different duration (minimum: {_min}s, maximum "
"{_max}s).".format(_min=durations_min,
_max=durations_max))

fano = spike_counts.var() / spike_counts.mean()
return fano
def _compute_fano(spiketrains):
# Build array of spike counts (one per spike train)
spike_counts = np.array([len(st) for st in spiketrains])

# Compute FF
if all(count == 0 for count in spike_counts):
# empty list of spiketrains reaches this branch, and NaN is returned
return np.nan

if all(isinstance(st, neo.SpikeTrain) for st in spiketrains):
if not is_time_quantity(warn_tolerance):
raise TypeError("'warn_tolerance' must be a time quantity.")
durations = [(st.t_stop - st.t_start).simplified.item()
for st in spiketrains]
durations_min = min(durations)
durations_max = max(durations)
if durations_max - durations_min > warn_tolerance.simplified.item():
warnings.warn("Fano factor calculated for spike trains of "
"different duration (minimum: {_min}s, maximum "
"{_max}s).".format(_min=durations_min,
_max=durations_max))

fano = spike_counts.var() / spike_counts.mean()
return fano

if isinstance(spiketrains, elephant.trials.Trials):
return [_compute_fano(spiketrains.get_spiketrains_from_trial_as_list(idx))
for idx in range(spiketrains.n_trials)]
else:
return _compute_fano(spiketrains)


def __variation_check(v, with_nan):
Expand Down
14 changes: 12 additions & 2 deletions elephant/test/test_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from elephant import statistics
from elephant.spike_train_generation import StationaryPoissonProcess
from elephant.test.test_trials import _create_trials_block
from elephant.trials import TrialsFromBlock
from elephant.trials import TrialsFromBlock, TrialsFromLists


class IsiTestCase(unittest.TestCase):
Expand Down Expand Up @@ -289,12 +289,13 @@ def setUp(self):
# for cross-validation
self.sp_counts[i] = len(st)

self.test_trials = TrialsFromLists([self.test_spiketrains, self.test_spiketrains])

def test_fanofactor_spiketrains(self):
# Test with list of spiketrains
self.assertEqual(
np.var(self.sp_counts) / np.mean(self.sp_counts),
statistics.fanofactor(self.test_spiketrains))

# One spiketrain in list
st = self.test_spiketrains[0]
self.assertEqual(statistics.fanofactor([st]), 0.0)
Expand Down Expand Up @@ -352,6 +353,15 @@ def test_fanofactor_wrong_type(self):
self.assertRaises(TypeError, statistics.fanofactor, [st1],
warn_tolerance=1e-4)

def test_fanofactor_trials(self):
# Test with Trial object
self.assertEqual(
np.var(self.sp_counts) / np.mean(self.sp_counts),
statistics.fanofactor(self.test_trials)[0])
self.assertEqual(
np.var(self.sp_counts) / np.mean(self.sp_counts),
statistics.fanofactor(self.test_trials)[1])


class LVTestCase(unittest.TestCase):
def setUp(self):
Expand Down
Loading