diff --git a/doc/modules/qualitymetrics.rst b/doc/modules/qualitymetrics.rst index 8c7c0a2cc3..447d83db52 100644 --- a/doc/modules/qualitymetrics.rst +++ b/doc/modules/qualitymetrics.rst @@ -25,9 +25,11 @@ For more details about each metric and it's availability and use within SpikeInt :glob: qualitymetrics/amplitude_cutoff + qualitymetrics/amplitude_cv qualitymetrics/amplitude_median qualitymetrics/d_prime qualitymetrics/drift + qualitymetrics/firing_range qualitymetrics/firing_rate qualitymetrics/isi_violations qualitymetrics/isolation_distance diff --git a/doc/modules/qualitymetrics/amplitude_cutoff.rst b/doc/modules/qualitymetrics/amplitude_cutoff.rst index 9f747f8d40..a1e4d85d01 100644 --- a/doc/modules/qualitymetrics/amplitude_cutoff.rst +++ b/doc/modules/qualitymetrics/amplitude_cutoff.rst @@ -21,12 +21,12 @@ Example code .. code-block:: python - import spikeinterface.qualitymetrics as qm + import spikeinterface.qualitymetrics as sqm # It is also recommended to run `compute_spike_amplitudes(wvf_extractor)` # in order to use amplitudes from all spikes - fraction_missing = qm.compute_amplitude_cutoffs(wvf_extractor, peak_sign="neg") - # fraction_missing is a dict containing the units' IDs as keys, + fraction_missing = sqm.compute_amplitude_cutoffs(wvf_extractor, peak_sign="neg") + # fraction_missing is a dict containing the unit IDs as keys, # and their estimated fraction of missing spikes as values. Reference diff --git a/doc/modules/qualitymetrics/amplitude_cv.rst b/doc/modules/qualitymetrics/amplitude_cv.rst new file mode 100644 index 0000000000..13117b607c --- /dev/null +++ b/doc/modules/qualitymetrics/amplitude_cv.rst @@ -0,0 +1,55 @@ +Amplitude CV (:code:`amplitude_cv_median`, :code:`amplitude_cv_range`) +====================================================================== + + +Calculation +----------- + +The amplitude CV (coefficient of variation) is a measure of the amplitude variability. 
+It is computed as the ratio between the standard deviation and the amplitude mean. +To obtain a better estimate of this measure, it is first computed separately for several temporal bins. +Out of these values, the median and the range (percentile distance, by default between the +5th and 95th percentiles) are computed. + +The computation requires either spike amplitudes (see :py:func:`~spikeinterface.postprocessing.compute_spike_amplitudes()`) +or amplitude scalings (see :py:func:`~spikeinterface.postprocessing.compute_amplitude_scalings()`) to be pre-computed. + + +Expectation and use +------------------- + +The amplitude CV median is expected to be relatively low for well-isolated units, indicating a "stereotypical" spike shape. + +The amplitude CV range can be high in the presence of noise contamination, due to amplitude outliers like in +the example below. + +.. image:: amplitudes.png + :width: 600 + + +Example code +------------ + +.. code-block:: python + + import spikeinterface.qualitymetrics as sqm + + # Make recording, sorting and wvf_extractor object for your data. + # It is required to run `compute_spike_amplitudes(wvf_extractor)` or + # `compute_amplitude_scalings(wvf_extractor)` (if missing, values will be NaN) + amplitude_cv_median, amplitude_cv_range = sqm.compute_amplitude_cv_metrics(wvf_extractor) + # amplitude_cv_median and amplitude_cv_range are dicts containing the unit ids as keys, + # and their amplitude_cv metrics as values. + + + +References +---------- + +.. autofunction:: spikeinterface.qualitymetrics.misc_metrics.compute_amplitude_cv_metrics + + +Literature +---------- + +Designed by Simon Musall and adapted to SpikeInterface by Alessio Buccino. diff --git a/doc/modules/qualitymetrics/amplitude_median.rst b/doc/modules/qualitymetrics/amplitude_median.rst index ffc45d1cf6..3ac52560e8 100644 --- a/doc/modules/qualitymetrics/amplitude_median.rst +++ b/doc/modules/qualitymetrics/amplitude_median.rst @@ -20,12 +20,12 @@ Example code .. 
code-block:: python - import spikeinterface.qualitymetrics as qm + import spikeinterface.qualitymetrics as sqm # It is also recommended to run `compute_spike_amplitudes(wvf_extractor)` # in order to use amplitude values from all spikes. - amplitude_medians = qm.compute_amplitude_medians(wvf_extractor) - # amplitude_medians is a dict containing the units' IDs as keys, + amplitude_medians = sqm.compute_amplitude_medians(wvf_extractor) + # amplitude_medians is a dict containing the unit IDs as keys, # and their estimated amplitude medians as values. Reference diff --git a/doc/modules/qualitymetrics/amplitudes.png b/doc/modules/qualitymetrics/amplitudes.png new file mode 100644 index 0000000000..0ee4dd1eda Binary files /dev/null and b/doc/modules/qualitymetrics/amplitudes.png differ diff --git a/doc/modules/qualitymetrics/d_prime.rst b/doc/modules/qualitymetrics/d_prime.rst index abb8c1dc74..e3bd61c580 100644 --- a/doc/modules/qualitymetrics/d_prime.rst +++ b/doc/modules/qualitymetrics/d_prime.rst @@ -32,9 +32,9 @@ Example code .. code-block:: python - import spikeinterface.qualitymetrics as qm + import spikeinterface.qualitymetrics as sqm - d_prime = qm.lda_metrics(all_pcs, all_labels, 0) + d_prime = sqm.lda_metrics(all_pcs, all_labels, 0) Reference diff --git a/doc/modules/qualitymetrics/drift.rst b/doc/modules/qualitymetrics/drift.rst index 0a852f80af..ae52f7f883 100644 --- a/doc/modules/qualitymetrics/drift.rst +++ b/doc/modules/qualitymetrics/drift.rst @@ -40,11 +40,12 @@ Example code .. code-block:: python - import spikeinterface.qualitymetrics as qm + import spikeinterface.qualitymetrics as sqm + # Make recording, sorting and wvf_extractor object for your data. 
# It is required to run `compute_spike_locations(wvf_extractor)` # (if missing, values will be NaN) - drift_ptps, drift_stds, drift_mads = qm.compute_drift_metrics(wvf_extractor, peak_sign="neg") + drift_ptps, drift_stds, drift_mads = sqm.compute_drift_metrics(wvf_extractor, peak_sign="neg") # drift_ptps, drift_stds, and drift_mads are dict containing the units' ID as keys, # and their metrics as values. diff --git a/doc/modules/qualitymetrics/firing_range.rst b/doc/modules/qualitymetrics/firing_range.rst new file mode 100644 index 0000000000..925539e9c6 --- /dev/null +++ b/doc/modules/qualitymetrics/firing_range.rst @@ -0,0 +1,40 @@ +Firing range (:code:`firing_range`) +=================================== + + +Calculation +----------- + +The firing range indicates the dispersion of the firing rate of a unit across the recording. It is computed by +taking the difference between the 95th percentile's firing rate and the 5th percentile's firing rate computed over short time bins (e.g. 10 s). + + + +Expectation and use +------------------- + +Very high levels of firing ranges, outside of a physiological range, might indicate noise contamination. + + +Example code +------------ + +.. code-block:: python + + import spikeinterface.qualitymetrics as sqm + + # Make recording, sorting and wvf_extractor object for your data. + firing_range = sqm.compute_firing_ranges(wvf_extractor) + # firing_range is a dict containing the unit IDs as keys, + # and their firing_range as values (in Hz). + +References +---------- + +.. autofunction:: spikeinterface.qualitymetrics.misc_metrics.compute_firing_ranges + + +Literature +---------- + +Designed by Simon Musall and adapted to SpikeInterface by Alessio Buccino. 
diff --git a/doc/modules/qualitymetrics/firing_rate.rst b/doc/modules/qualitymetrics/firing_rate.rst index eddef3e48f..c0e15d7c2e 100644 --- a/doc/modules/qualitymetrics/firing_rate.rst +++ b/doc/modules/qualitymetrics/firing_rate.rst @@ -37,11 +37,11 @@ With SpikeInterface: .. code-block:: python - import spikeinterface.qualitymetrics as qm + import spikeinterface.qualitymetrics as sqm # Make recording, sorting and wvf_extractor object for your data. - firing_rate = qm.compute_firing_rates(wvf_extractor) - # firing_rate is a dict containing the units' IDs as keys, + firing_rate = sqm.compute_firing_rates(wvf_extractor) + # firing_rate is a dict containing the unit IDs as keys, # and their firing rates across segments as values (in Hz). References diff --git a/doc/modules/qualitymetrics/isi_violations.rst b/doc/modules/qualitymetrics/isi_violations.rst index 947e7d4938..725d9b0fd6 100644 --- a/doc/modules/qualitymetrics/isi_violations.rst +++ b/doc/modules/qualitymetrics/isi_violations.rst @@ -77,11 +77,11 @@ With SpikeInterface: .. code-block:: python - import spikeinterface.qualitymetrics as qm + import spikeinterface.qualitymetrics as sqm # Make recording, sorting and wvf_extractor object for your data. - isi_violations_ratio, isi_violations_count = qm.compute_isi_violations(wvf_extractor, isi_threshold_ms=1.0) + isi_violations_ratio, isi_violations_count = sqm.compute_isi_violations(wvf_extractor, isi_threshold_ms=1.0) References ---------- diff --git a/doc/modules/qualitymetrics/presence_ratio.rst b/doc/modules/qualitymetrics/presence_ratio.rst index e4de2248bd..5a420c8ccf 100644 --- a/doc/modules/qualitymetrics/presence_ratio.rst +++ b/doc/modules/qualitymetrics/presence_ratio.rst @@ -23,12 +23,12 @@ Example code .. code-block:: python - import spikeinterface.qualitymetrics as qm + import spikeinterface.qualitymetrics as sqm # Make recording, sorting and wvf_extractor object for your data. 
- presence_ratio = qm.compute_presence_ratios(wvf_extractor) - # presence_ratio is a dict containing the units' IDs as keys + presence_ratio = sqm.compute_presence_ratios(wvf_extractor) + # presence_ratio is a dict containing the unit IDs as keys # and their presence ratio (between 0 and 1) as values. Links to original implementations diff --git a/doc/modules/qualitymetrics/sliding_rp_violations.rst b/doc/modules/qualitymetrics/sliding_rp_violations.rst index 843242c1e8..de68c3a92f 100644 --- a/doc/modules/qualitymetrics/sliding_rp_violations.rst +++ b/doc/modules/qualitymetrics/sliding_rp_violations.rst @@ -27,11 +27,11 @@ With SpikeInterface: .. code-block:: python - import spikeinterface.qualitymetrics as qm + import spikeinterface.qualitymetrics as sqm # Make recording, sorting and wvf_extractor object for your data. - contamination = qm.compute_sliding_rp_violations(wvf_extractor, bin_size_ms=0.25) + contamination = sqm.compute_sliding_rp_violations(wvf_extractor, bin_size_ms=0.25) References ---------- diff --git a/doc/modules/qualitymetrics/snr.rst b/doc/modules/qualitymetrics/snr.rst index 288ab60515..b88d3291be 100644 --- a/doc/modules/qualitymetrics/snr.rst +++ b/doc/modules/qualitymetrics/snr.rst @@ -41,12 +41,12 @@ With SpikeInterface: .. code-block:: python - import spikeinterface.qualitymetrics as qm + import spikeinterface.qualitymetrics as sqm # Make recording, sorting and wvf_extractor object for your data. - SNRs = qm.compute_snrs(wvf_extractor) - # SNRs is a dict containing the units' IDs as keys and their SNRs as values. + SNRs = sqm.compute_snrs(wvf_extractor) + # SNRs is a dict containing the unit IDs as keys and their SNRs as values. 
Links to original implementations --------------------------------- diff --git a/doc/modules/qualitymetrics/synchrony.rst b/doc/modules/qualitymetrics/synchrony.rst index 2f566bf8a7..0750940199 100644 --- a/doc/modules/qualitymetrics/synchrony.rst +++ b/doc/modules/qualitymetrics/synchrony.rst @@ -27,9 +27,9 @@ Example code .. code-block:: python - import spikeinterface.qualitymetrics as qm + import spikeinterface.qualitymetrics as sqm # Make recording, sorting and wvf_extractor object for your data. - synchrony = qm.compute_synchrony_metrics(wvf_extractor, synchrony_sizes=(2, 4, 8)) + synchrony = sqm.compute_synchrony_metrics(wvf_extractor, synchrony_sizes=(2, 4, 8)) # synchrony is a tuple of dicts with the synchrony metrics for each unit diff --git a/src/spikeinterface/qualitymetrics/misc_metrics.py b/src/spikeinterface/qualitymetrics/misc_metrics.py index 8dd5f857f6..f449b3c31b 100644 --- a/src/spikeinterface/qualitymetrics/misc_metrics.py +++ b/src/spikeinterface/qualitymetrics/misc_metrics.py @@ -499,9 +499,8 @@ def compute_sliding_rp_violations( ) -def compute_synchrony_metrics(waveform_extractor, synchrony_sizes=(2, 4, 8), **kwargs): - """ - Compute synchrony metrics. Synchrony metrics represent the rate of occurrences of +def compute_synchrony_metrics(waveform_extractor, synchrony_sizes=(2, 4, 8), unit_ids=None, **kwargs): + """Compute synchrony metrics. Synchrony metrics represent the rate of occurrences of "synchrony_size" spikes at the exact same sample index. Parameters @@ -510,6 +509,8 @@ def compute_synchrony_metrics(waveform_extractor, synchrony_sizes=(2, 4, 8), **k The waveform extractor object. synchrony_sizes : list or tuple, default: (2, 4, 8) The synchrony sizes to compute. + unit_ids : list or None, default: None + List of unit ids to compute the synchrony metrics. If None, all units are used. 
Returns ------- @@ -522,16 +523,20 @@ def compute_synchrony_metrics(waveform_extractor, synchrony_sizes=(2, 4, 8), **k Based on concepts described in [Gruen]_ This code was adapted from `Elephant - Electrophysiology Analysis Toolkit `_ """ - assert np.all(s > 1 for s in synchrony_sizes), "Synchrony sizes must be greater than 1" + assert min(synchrony_sizes) > 1, "Synchrony sizes must be greater than 1" spike_counts = waveform_extractor.sorting.count_num_spikes_per_unit() sorting = waveform_extractor.sorting spikes = sorting.to_spike_vector(concatenated=False) + if unit_ids is None: + unit_ids = sorting.unit_ids + # Pre-allocate synchrony counts synchrony_counts = {} for synchrony_size in synchrony_sizes: synchrony_counts[synchrony_size] = np.zeros(len(waveform_extractor.unit_ids), dtype=np.int64) + all_unit_ids = list(sorting.unit_ids) for segment_index in range(sorting.get_num_segments()): spikes_in_segment = spikes[segment_index] @@ -539,20 +544,20 @@ def compute_synchrony_metrics(waveform_extractor, synchrony_sizes=(2, 4, 8), **k unique_spike_index, complexity = np.unique(spikes_in_segment["sample_index"], return_counts=True) # add counts for this segment - for unit_index in np.arange(len(sorting.unit_ids)): + for unit_id in unit_ids: + unit_index = all_unit_ids.index(unit_id) spikes_per_unit = spikes_in_segment[spikes_in_segment["unit_index"] == unit_index] # some segments/units might have no spikes if len(spikes_per_unit) == 0: continue spike_complexity = complexity[np.isin(unique_spike_index, spikes_per_unit["sample_index"])] for synchrony_size in synchrony_sizes: - synchrony_counts[synchrony_size][unit_index] += np.count_nonzero(spike_complexity >= synchrony_size) + synchrony_counts[synchrony_size][unit_id] += np.count_nonzero(spike_complexity >= synchrony_size) # add counts for this segment synchrony_metrics_dict = { f"sync_spike_{synchrony_size}": { - unit_id: synchrony_counts[synchrony_size][unit_index] / spike_counts[unit_id] - for unit_index, unit_id in 
enumerate(sorting.unit_ids) + unit_id: synchrony_counts[synchrony_size][unit_id] / spike_counts[unit_id] for unit_id in unit_ids } for synchrony_size in synchrony_sizes } @@ -563,7 +568,172 @@ def compute_synchrony_metrics(waveform_extractor, synchrony_sizes=(2, 4, 8), **k return synchrony_metrics -_default_params["synchrony_metrics"] = dict(synchrony_sizes=(0, 2, 4)) +_default_params["synchrony"] = dict(synchrony_sizes=(2, 4, 8)) + + +def compute_firing_ranges(waveform_extractor, bin_size_s=5, percentiles=(5, 95), unit_ids=None, **kwargs): + """Calculate firing range, the range between the 5th and 95th percentiles of the firing rates distribution + computed in non-overlapping time bins. + + Parameters + ---------- + waveform_extractor : WaveformExtractor + The waveform extractor object. + bin_size_s : float, default: 5 + The size of the bin in seconds. + percentiles : tuple, default: (5, 95) + The percentiles to compute. + unit_ids : list or None + List of unit ids to compute the firing range. If None, all units are used. + + Returns + ------- + firing_ranges : dict + The firing range for each unit. + + Notes + ----- + Designed by Simon Musall and ported to SpikeInterface by Alessio Buccino. 
+ """ + sampling_frequency = waveform_extractor.sampling_frequency + bin_size_samples = int(bin_size_s * sampling_frequency) + sorting = waveform_extractor.sorting + if unit_ids is None: + unit_ids = sorting.unit_ids + + # for each segment, we compute the firing rate histogram and we concatenate them + firing_rate_histograms = {unit_id: np.array([], dtype=float) for unit_id in sorting.unit_ids} + for segment_index in range(waveform_extractor.get_num_segments()): + num_samples = waveform_extractor.get_num_samples(segment_index) + edges = np.arange(0, num_samples + 1, bin_size_samples) + + for unit_id in unit_ids: + spike_times = sorting.get_unit_spike_train(unit_id=unit_id, segment_index=segment_index) + spike_counts, _ = np.histogram(spike_times, bins=edges) + firing_rates = spike_counts / bin_size_s + firing_rate_histograms[unit_id] = np.concatenate((firing_rate_histograms[unit_id], firing_rates)) + + # finally we compute the percentiles + firing_ranges = {} + for unit_id in unit_ids: + firing_ranges[unit_id] = np.percentile(firing_rate_histograms[unit_id], percentiles[1]) - np.percentile( + firing_rate_histograms[unit_id], percentiles[0] + ) + + return firing_ranges + + +_default_params["firing_range"] = dict(bin_size_s=5, percentiles=(5, 95)) + + +def compute_amplitude_cv_metrics( + waveform_extractor, + average_num_spikes_per_bin=50, + percentiles=(5, 95), + min_num_bins=10, + amplitude_extension="spike_amplitudes", + unit_ids=None, +): + """Calculate coefficient of variation of spike amplitudes within defined temporal bins. + From the distribution of coefficient of variations, both the median and the "range" (the distance between the + percentiles defined by `percentiles` parameter) are returned. + + Parameters + ---------- + waveform_extractor : WaveformExtractor + The waveform extractor object. + average_num_spikes_per_bin : int, default: 50 + The average number of spikes per bin. 
This is used to estimate a temporal bin size using the firing rate + of each unit. For example, if a unit has a firing rate of 10 Hz, and the average number of spikes per bin is + 100, then the temporal bin size will be 100/10 Hz = 10 s. + min_num_bins : int, default: 10 + The minimum number of bins to compute the median and range. If the number of bins is less than this then + the median and range are set to NaN. + amplitude_extension : str, default: 'spike_amplitudes' + The name of the extension to load the amplitudes from. 'spike_amplitudes' or 'amplitude_scalings'. + unit_ids : list or None + List of unit ids to compute the amplitude spread. If None, all units are used. + + Returns + ------- + amplitude_cv_median : dict + The median of the CV + amplitude_cv_range : dict + The range of the CV, computed as the distance between the percentiles. + + Notes + ----- + Designed by Simon Musall and Alessio Buccino. + """ + res = namedtuple("amplitude_cv", ["amplitude_cv_median", "amplitude_cv_range"]) + assert amplitude_extension in ( + "spike_amplitudes", + "amplitude_scalings", + ), "Invalid amplitude_extension. 
It can be either 'spike_amplitudes' or 'amplitude_scalings'" + sorting = waveform_extractor.sorting + total_duration = waveform_extractor.get_total_duration() + spikes = sorting.to_spike_vector() + num_spikes = sorting.count_num_spikes_per_unit() + if unit_ids is None: + unit_ids = sorting.unit_ids + + if waveform_extractor.is_extension(amplitude_extension): + sac = waveform_extractor.load_extension(amplitude_extension) + amps = sac.get_data(outputs="concatenated") + if amplitude_extension == "spike_amplitudes": + amps = np.concatenate(amps) + else: + warnings.warn("") + empty_dict = {unit_id: np.nan for unit_id in unit_ids} + return empty_dict + + # precompute segment slice + segment_slices = [] + for segment_index in range(waveform_extractor.get_num_segments()): + i0 = np.searchsorted(spikes["segment_index"], segment_index) + i1 = np.searchsorted(spikes["segment_index"], segment_index + 1) + segment_slices.append(slice(i0, i1)) + + all_unit_ids = list(sorting.unit_ids) + amplitude_cv_medians, amplitude_cv_ranges = {}, {} + for unit_id in unit_ids: + firing_rate = num_spikes[unit_id] / total_duration + temporal_bin_size_samples = int( + (average_num_spikes_per_bin / firing_rate) * waveform_extractor.sampling_frequency + ) + + amp_spreads = [] + # bins and amplitude means are computed for each segment + for segment_index in range(waveform_extractor.get_num_segments()): + sample_bin_edges = np.arange( + 0, waveform_extractor.get_num_samples(segment_index) + 1, temporal_bin_size_samples + ) + spikes_in_segment = spikes[segment_slices[segment_index]] + amps_in_segment = amps[segment_slices[segment_index]] + unit_mask = spikes_in_segment["unit_index"] == all_unit_ids.index(unit_id) + spike_indices_unit = spikes_in_segment["sample_index"][unit_mask] + amps_unit = amps_in_segment[unit_mask] + amp_mean = np.abs(np.mean(amps_unit)) + for t0, t1 in zip(sample_bin_edges[:-1], sample_bin_edges[1:]): + i0 = np.searchsorted(spike_indices_unit, t0) + i1 = 
np.searchsorted(spike_indices_unit, t1) + amp_spreads.append(np.std(amps_unit[i0:i1]) / amp_mean) + + if len(amp_spreads) < min_num_bins: + amplitude_cv_medians[unit_id] = np.nan + amplitude_cv_ranges[unit_id] = np.nan + else: + amplitude_cv_medians[unit_id] = np.median(amp_spreads) + amplitude_cv_ranges[unit_id] = np.percentile(amp_spreads, percentiles[1]) - np.percentile( + amp_spreads, percentiles[0] + ) + + return res(amplitude_cv_medians, amplitude_cv_ranges) + + +_default_params["amplitude_cv"] = dict( + average_num_spikes_per_bin=50, percentiles=(5, 95), min_num_bins=10, amplitude_extension="spike_amplitudes" +) def compute_amplitude_cutoffs( diff --git a/src/spikeinterface/qualitymetrics/quality_metric_list.py b/src/spikeinterface/qualitymetrics/quality_metric_list.py index 90dbb47a3a..97f14ec6f4 100644 --- a/src/spikeinterface/qualitymetrics/quality_metric_list.py +++ b/src/spikeinterface/qualitymetrics/quality_metric_list.py @@ -12,6 +12,8 @@ compute_amplitude_medians, compute_drift_metrics, compute_synchrony_metrics, + compute_firing_ranges, + compute_amplitude_cv_metrics, ) from .pca_metrics import ( @@ -40,6 +42,8 @@ "sliding_rp_violation": compute_sliding_rp_violations, "amplitude_cutoff": compute_amplitude_cutoffs, "amplitude_median": compute_amplitude_medians, + "amplitude_cv": compute_amplitude_cv_metrics, "synchrony": compute_synchrony_metrics, + "firing_range": compute_firing_ranges, "drift": compute_drift_metrics, } diff --git a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py index d927d64c4f..2d63a06b17 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py +++ b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py @@ -12,6 +12,7 @@ compute_principal_components, compute_spike_locations, compute_spike_amplitudes, + compute_amplitude_scalings, ) from spikeinterface.qualitymetrics import ( @@ -31,6 +32,8 @@ 
compute_drift_metrics, compute_amplitude_medians, compute_synchrony_metrics, + compute_firing_ranges, + compute_amplitude_cv_metrics, ) @@ -212,6 +215,12 @@ def test_calculate_firing_rate_num_spikes(waveform_extractor_simple): # np.testing.assert_array_equal(list(num_spikes_gt.values()), list(num_spikes.values())) +def test_calculate_firing_range(waveform_extractor_simple): + we = waveform_extractor_simple + firing_ranges = compute_firing_ranges(we) + print(firing_ranges) + + def test_calculate_amplitude_cutoff(waveform_extractor_simple): we = waveform_extractor_simple spike_amps = compute_spike_amplitudes(we) @@ -234,6 +243,24 @@ def test_calculate_amplitude_median(waveform_extractor_simple): # assert np.allclose(list(amp_medians_gt.values()), list(amp_medians.values()), rtol=0.05) +def test_calculate_amplitude_cv_metrics(waveform_extractor_simple): + we = waveform_extractor_simple + spike_amps = compute_spike_amplitudes(we) + amp_cv_median, amp_cv_range = compute_amplitude_cv_metrics(we, average_num_spikes_per_bin=20) + print(amp_cv_median) + print(amp_cv_range) + + amps_scalings = compute_amplitude_scalings(we) + amp_cv_median_scalings, amp_cv_range_scalings = compute_amplitude_cv_metrics( + we, + average_num_spikes_per_bin=20, + amplitude_extension="amplitude_scalings", + min_num_bins=5, + ) + print(amp_cv_median_scalings) + print(amp_cv_range_scalings) + + def test_calculate_snrs(waveform_extractor_simple): we = waveform_extractor_simple snrs = compute_snrs(we) @@ -358,4 +385,6 @@ def test_calculate_drift_metrics(waveform_extractor_simple): # test_calculate_isi_violations(we) # test_calculate_sliding_rp_violations(we) # test_calculate_drift_metrics(we) - test_synchrony_metrics(we) + # test_synchrony_metrics(we) + test_calculate_firing_range(we) + test_calculate_amplitude_cv_metrics(we)