Additional quality metrics #1981

Merged: 19 commits, Sep 27, 2023
Changes from 8 commits
48 changes: 48 additions & 0 deletions doc/modules/qualitymetrics/amplitude_spread.rst
@@ -0,0 +1,48 @@
Amplitude spread (:code:`amplitude_spread`)
===========================================


Calculation
-----------

The amplitude spread is a measure of the variability of a unit's spike amplitudes.
It is computed as the ratio between the standard deviation and the mean of the spike amplitudes (i.e. the coefficient of variation).
To obtain a more robust estimate, this ratio is first computed separately for consecutive bins of a fixed number of spikes
(e.g. 100), and the median of these values is then taken.
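
For illustration, here is a minimal NumPy sketch of the calculation for a single unit, assuming
:code:`amps` is a 1D array of that unit's spike amplitudes (the variable names are illustrative
only, not part of the SpikeInterface API):

.. code-block:: python

    import numpy as np

    # `amps`: 1D array of one unit's spike amplitudes (assumed pre-computed)
    num_spikes_per_bin = 100
    amp_mean = np.abs(np.mean(amps))  # overall mean amplitude (absolute value)

    # coefficient of variation within each bin of `num_spikes_per_bin` spikes
    spreads = [
        np.std(amps[i : i + num_spikes_per_bin]) / amp_mean
        for i in range(0, len(amps), num_spikes_per_bin)
    ]

    # the amplitude spread is the median across bins
    amplitude_spread = np.median(spreads)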

The computation requires either spike amplitudes (see :py:func:`~spikeinterface.postprocessing.compute_spike_amplitudes()`)
or amplitude scalings (see :py:func:`~spikeinterface.postprocessing.compute_amplitude_scalings()`) to be pre-computed.


Expectation and use
-------------------

A very high amplitude spread, outside of the physiological range, might indicate noise contamination.


Example code
------------

.. code-block:: python

import spikeinterface.qualitymetrics as qm

# Make recording, sorting and wvf_extractor object for your data.
# It is required to run `compute_spike_amplitudes(wvf_extractor)` or
# `compute_amplitude_scalings(wvf_extractor)` (if missing, values will be NaN)
amplitude_spread = qm.compute_amplitude_spreads(wvf_extractor, amplitude_extension='spike_amplitudes')
# amplitude_spread is a dict containing the units' IDs as keys,
# and their amplitude spread (standard deviation relative to the mean amplitude) as values.



References
----------

.. autofunction:: spikeinterface.qualitymetrics.misc_metrics.compute_amplitude_spreads


Literature
----------

Designed by Simon Musall and adapted to SpikeInterface by Alessio Buccino.
1 change: 1 addition & 0 deletions doc/modules/qualitymetrics/drift.rst
@@ -42,6 +42,7 @@ Example code

import spikeinterface.qualitymetrics as qm

# Make recording, sorting and wvf_extractor object for your data.
# It is required to run `compute_spike_locations(wvf_extractor)`
# (if missing, values will be NaN)
drift_ptps, drift_stds, drift_mads = qm.compute_drift_metrics(wvf_extractor, peak_sign="neg")
40 changes: 40 additions & 0 deletions doc/modules/qualitymetrics/firing_range.rst
@@ -0,0 +1,40 @@
Firing range (:code:`firing_range`)
===================================


Calculation
-----------

The firing range indicates the dispersion of the firing rate of a unit across the recording. It is computed as the
difference between the 95th and the 5th percentiles of the firing rates computed over short, non-overlapping time bins (e.g. 10 s).
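
A minimal sketch of this calculation, assuming :code:`firing_rates` is a 1D array of one unit's
firing rates computed over non-overlapping time bins (illustrative names, not the SpikeInterface API):

.. code-block:: python

    import numpy as np

    # `firing_rates`: firing rate (in Hz) in each non-overlapping time bin
    firing_range = np.percentile(firing_rates, 95) - np.percentile(firing_rates, 5)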



Expectation and use
-------------------

A very high firing range, outside of the physiological range, might indicate noise contamination.


Example code
------------

.. code-block:: python

import spikeinterface.qualitymetrics as qm

# Make recording, sorting and wvf_extractor object for your data.
firing_range = qm.compute_firing_ranges(wvf_extractor)
# firing_range is a dict containing the units' IDs as keys,
# and their firing_range as values (in Hz).

References
----------

.. autofunction:: spikeinterface.qualitymetrics.misc_metrics.compute_firing_ranges


Literature
----------

Designed by Simon Musall and adapted to SpikeInterface by Alessio Buccino.
124 changes: 123 additions & 1 deletion src/spikeinterface/qualitymetrics/misc_metrics.py
@@ -563,7 +563,129 @@ def compute_synchrony_metrics(waveform_extractor, synchrony_sizes=(2, 4, 8), **k
return synchrony_metrics


_default_params["synchrony_metrics"] = dict(synchrony_sizes=(0, 2, 4))
_default_params["synchrony"] = dict(synchrony_sizes=(0, 2, 4))


def compute_firing_ranges(waveform_extractor, bin_size_s=5, percentiles=(5, 95), unit_ids=None):
"""Calculate firing range, the range between the 5th and 95th percentiles of the firing rates distribution
computed in non-overlapping time bins.

Parameters
----------
waveform_extractor : WaveformExtractor
The waveform extractor object.
bin_size_s : float, default: 5
The size of the bin in seconds.
percentiles : tuple, default: (5, 95)
The lower and upper percentiles (in the [0, 100] range) used to compute the firing range.
unit_ids : list or None
List of unit ids to compute the firing range. If None, all units are used.

Returns
-------
firing_ranges : dict
The firing range for each unit.

Notes
-----
Designed by Simon Musall and ported to SpikeInterface by Alessio Buccino.
"""
sampling_frequency = waveform_extractor.sampling_frequency
bin_size_samples = int(bin_size_s * sampling_frequency)
sorting = waveform_extractor.sorting
if unit_ids is None:
unit_ids = sorting.unit_ids

# for each segment, we compute the firing rate histogram and we concatenate them
firing_rate_histograms = {unit_id: np.array([], dtype=float) for unit_id in sorting.unit_ids}
for segment_index in range(waveform_extractor.get_num_segments()):
num_samples = waveform_extractor.get_num_samples(segment_index)
edges = np.arange(0, num_samples + 1, bin_size_samples)

for unit_id in unit_ids:
spike_times = sorting.get_unit_spike_train(unit_id=unit_id, segment_index=segment_index)
spike_counts, _ = np.histogram(spike_times, bins=edges)
firing_rates = spike_counts / bin_size_s
firing_rate_histograms[unit_id] = np.concatenate((firing_rate_histograms[unit_id], firing_rates))

# finally we compute the percentiles
firing_ranges = {}
for unit_id in unit_ids:
firing_ranges[unit_id] = np.percentile(firing_rate_histograms[unit_id], percentiles[1]) - np.percentile(
firing_rate_histograms[unit_id], percentiles[0]
)

return firing_ranges


_default_params["firing_range"] = dict(bin_size_s=5, percentiles=(5, 95))


def compute_amplitude_spreads(
waveform_extractor, num_spikes_per_bin=100, amplitude_extension="spike_amplitudes", unit_ids=None
):
"""Calculate spread of spike amplitudes within defined bins of spike events.
The spread is the median, across bins of `num_spikes_per_bin` spikes, of the relative standard
deviation (standard deviation divided by the absolute value of the overall amplitude mean).

Parameters
----------
waveform_extractor : WaveformExtractor
The waveform extractor object.
num_spikes_per_bin : int, default: 100
The number of spikes per bin.
amplitude_extension : str, default: 'spike_amplitudes'
The name of the extension to load the amplitudes from. 'spike_amplitudes' or 'amplitude_scalings'.
unit_ids : list or None
List of unit ids to compute the amplitude spread. If None, all units are used.

Returns
-------
amplitude_spreads : dict
The amplitude spread for each unit.

Notes
-----
Designed by Simon Musall and ported to SpikeInterface by Alessio Buccino.
"""
assert amplitude_extension in (
"spike_amplitudes",
"amplitude_scalings",
), "Invalid amplitude_extension. It can be either 'spike_amplitudes' or 'amplitude_scalings'"
sorting = waveform_extractor.sorting
spikes = sorting.to_spike_vector()
num_spikes = sorting.count_num_spikes_per_unit()
if unit_ids is None:
unit_ids = sorting.unit_ids

if waveform_extractor.is_extension(amplitude_extension):
sac = waveform_extractor.load_extension(amplitude_extension)
amps = sac.get_data(outputs="concatenated")
if amplitude_extension == "spike_amplitudes":
amps = np.concatenate(amps)
else:
warnings.warn(f"The '{amplitude_extension}' extension is not available. Amplitude spreads will be set to NaN.")
empty_dict = {unit_id: np.nan for unit_id in unit_ids}
return empty_dict

all_unit_ids = list(sorting.unit_ids)
amplitude_spreads = {}
for unit_id in unit_ids:
amps_unit = amps[spikes["unit_index"] == all_unit_ids.index(unit_id)]
amp_mean = np.abs(np.mean(amps_unit))
if num_spikes[unit_id] < num_spikes_per_bin:
amp_spread = np.std(amps_unit) / amp_mean
else:
amp_spreads = []
for i in range(0, num_spikes[unit_id], num_spikes_per_bin):
amp_spreads.append(np.std(amps_unit[i : i + num_spikes_per_bin]) / amp_mean)
amp_spread = np.median(amp_spreads)
amplitude_spreads[unit_id] = amp_spread

return amplitude_spreads


_default_params["amplitude_spread"] = dict(num_spikes_per_bin=100, amplitude_extension="spike_amplitudes")


def compute_amplitude_cutoffs(
4 changes: 4 additions & 0 deletions src/spikeinterface/qualitymetrics/quality_metric_list.py
@@ -12,6 +12,8 @@
compute_amplitude_medians,
compute_drift_metrics,
compute_synchrony_metrics,
compute_firing_ranges,
compute_amplitude_spreads,
)

from .pca_metrics import (
@@ -40,6 +42,8 @@
"sliding_rp_violation": compute_sliding_rp_violations,
"amplitude_cutoff": compute_amplitude_cutoffs,
"amplitude_median": compute_amplitude_medians,
"amplitude_spread": compute_amplitude_spreads,
"synchrony": compute_synchrony_metrics,
"firing_range": compute_firing_ranges,
"drift": compute_drift_metrics,
}
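
# Usage sketch (not part of this file): once registered in the dictionary above, the new
# metrics can be requested by name through the standard quality-metrics entry point.
# `wvf_extractor` is assumed to be an existing WaveformExtractor with spike amplitudes computed.
#
# from spikeinterface.qualitymetrics import compute_quality_metrics
# metrics = compute_quality_metrics(
#     wvf_extractor, metric_names=["firing_range", "amplitude_spread", "synchrony"]
# )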
@@ -12,6 +12,7 @@
compute_principal_components,
compute_spike_locations,
compute_spike_amplitudes,
compute_amplitude_scalings,
)

from spikeinterface.qualitymetrics import (
@@ -31,6 +32,8 @@
compute_drift_metrics,
compute_amplitude_medians,
compute_synchrony_metrics,
compute_firing_ranges,
compute_amplitude_spreads,
)


@@ -212,6 +215,12 @@ def test_calculate_firing_rate_num_spikes(waveform_extractor_simple):
# np.testing.assert_array_equal(list(num_spikes_gt.values()), list(num_spikes.values()))


def test_calculate_firing_range(waveform_extractor_simple):
we = waveform_extractor_simple
firing_ranges = compute_firing_ranges(we)
print(firing_ranges)


def test_calculate_amplitude_cutoff(waveform_extractor_simple):
we = waveform_extractor_simple
spike_amps = compute_spike_amplitudes(we)
@@ -234,6 +243,19 @@ def test_calculate_amplitude_median(waveform_extractor_simple):
# assert np.allclose(list(amp_medians_gt.values()), list(amp_medians.values()), rtol=0.05)


def test_calculate_amplitude_spread(waveform_extractor_simple):
we = waveform_extractor_simple
spike_amps = compute_spike_amplitudes(we)
amp_spreads = compute_amplitude_spreads(we, num_spikes_per_bin=20)
print(amp_spreads)

amps_scalings = compute_amplitude_scalings(we)
amp_spreads_scalings = compute_amplitude_spreads(
we, num_spikes_per_bin=20, amplitude_extension="amplitude_scalings"
)
print(amp_spreads_scalings)


def test_calculate_snrs(waveform_extractor_simple):
we = waveform_extractor_simple
snrs = compute_snrs(we)
@@ -358,4 +380,6 @@ def test_calculate_drift_metrics(waveform_extractor_simple):
# test_calculate_isi_violations(we)
# test_calculate_sliding_rp_violations(we)
# test_calculate_drift_metrics(we)
test_synchrony_metrics(we)
# test_synchrony_metrics(we)
test_calculate_firing_range(we)
test_calculate_amplitude_spread(we)