From 25373b649280f7163cde701d7f49e91cbcca97a5 Mon Sep 17 00:00:00 2001
From: hclark94
Date: Thu, 30 May 2024 12:29:59 +0100
Subject: [PATCH 1/2] standardise qualitymetrics docstrings to numpydocs standard

---
 .../qualitymetrics/misc_metrics.py         | 200 ++++++++++--------
 .../qualitymetrics/pca_metrics.py          |  57 ++---
 .../quality_metric_calculator.py           |  13 +-
 src/spikeinterface/qualitymetrics/utils.py |  17 +-
 4 files changed, 156 insertions(+), 131 deletions(-)

diff --git a/src/spikeinterface/qualitymetrics/misc_metrics.py b/src/spikeinterface/qualitymetrics/misc_metrics.py
index 6b77e23c35..fe784a36ec 100644
--- a/src/spikeinterface/qualitymetrics/misc_metrics.py
+++ b/src/spikeinterface/qualitymetrics/misc_metrics.py
@@ -37,12 +37,13 @@

 def compute_num_spikes(sorting_analyzer, unit_ids=None, **kwargs):
-    """Compute the number of spike across segments.
+    """
+    Compute the number of spikes across segments.

     Parameters
     ----------
-    sorting_analyzer: SortingAnalyzer
-        A SortingAnalyzer object
+    sorting_analyzer : SortingAnalyzer
+        A SortingAnalyzer object.
     unit_ids : list or None
         The list of unit ids to compute the number of spikes. If None, all units are used.

@@ -69,12 +70,13 @@ def compute_num_spikes(sorting_analyzer, unit_ids=None, **kwargs):

 def compute_firing_rates(sorting_analyzer, unit_ids=None, **kwargs):
-    """Compute the firing rate across segments.
+    """
+    Compute the firing rate across segments.

     Parameters
     ----------
-    sorting_analyzer: SortingAnalyzer
-        A SortingAnalyzer object
+    sorting_analyzer : SortingAnalyzer
+        A SortingAnalyzer object.
     unit_ids : list or None
         The list of unit ids to compute the firing rate. If None, all units are used.

@@ -97,17 +99,18 @@ def compute_firing_rates(sorting_analyzer, unit_ids=None, **kwargs):

 def compute_presence_ratios(sorting_analyzer, bin_duration_s=60.0, mean_fr_ratio_thresh=0.0, unit_ids=None, **kwargs):
-    """Calculate the presence ratio, the fraction of time the unit is firing above a certain threshold.
+    """
+    Calculate the presence ratio, the fraction of time the unit is firing above a certain threshold.

     Parameters
     ----------
-    sorting_analyzer: SortingAnalyzer
-        A SortingAnalyzer object
+    sorting_analyzer : SortingAnalyzer
+        A SortingAnalyzer object.
     bin_duration_s : float, default: 60
         The duration of each bin in seconds. If the duration is less than this value,
-        presence_ratio is set to NaN
-    mean_fr_ratio_thresh: float, default: 0
+        presence_ratio is set to NaN.
+    mean_fr_ratio_thresh : float, default: 0
         The unit is considered active in a bin if its firing rate during that bin
         is strictly above `mean_fr_ratio_thresh` times its mean firing rate throughout the recording.
     unit_ids : list or None
         The list of unit ids to compute the presence ratio. If None, all units are used.
@@ -176,24 +179,24 @@ def compute_presence_ratios(sorting_analyzer, bin_duration_s=60.0, mean_fr_ratio
     )

-def compute_snrs(
-    sorting_analyzer,
+def compute_snrs(sorting_analyzer,
     peak_sign: str = "neg",
     peak_mode: str = "extremum",
     unit_ids=None,
 ):
-    """Compute signal to noise ratio.
+    """
+    Compute signal to noise ratio.

     Parameters
     ----------
-    sorting_analyzer: SortingAnalyzer
-        A SortingAnalyzer object
+    sorting_analyzer : SortingAnalyzer
+        A SortingAnalyzer object.
     peak_sign : "neg" | "pos" | "both", default: "neg"
         The sign of the template to compute best channels.
-    peak_mode: "extremum" | "at_index", default: "extremum"
+    peak_mode : "extremum" | "at_index", default: "extremum"
         How to compute the amplitude.
         Extremum takes the maxima/minima
-        At_index takes the value at t=sorting_analyzer.nbefore
+        At_index takes the value at t=sorting_analyzer.nbefore.
     unit_ids : list or None
         The list of unit ids to compute the SNR. If None, all units are used.
@@ -232,7 +235,8 @@ def compute_snrs(

 def compute_isi_violations(sorting_analyzer, isi_threshold_ms=1.5, min_isi_ms=0, unit_ids=None):
-    """Calculate Inter-Spike Interval (ISI) violations.
+    """
+    Calculate Inter-Spike Interval (ISI) violations.

     It computes several metrics related to isi violations:
         * isi_violations_ratio: the relative firing rate of the hypothetical neurons that are

     Parameters
     ----------
     sorting_analyzer : SortingAnalyzer
-        The SortingAnalyzer object
+        The SortingAnalyzer object.
     isi_threshold_ms : float, default: 1.5
         Threshold for classifying adjacent spikes as an ISI violation, in ms.
-        This is the biophysical refractory period
+        This is the biophysical refractory period.
     min_isi_ms : float, default: 0
         Minimum possible inter-spike interval, in ms.
         This is the artificial refractory period enforced
         by the data acquisition system or post-processing algorithms.
     unit_ids : list or None
         List of unit ids to compute the ISI violations. If None, all units are used.
@@ -315,7 +319,8 @@ def compute_isi_violations(sorting_analyzer, isi_threshold_ms=1.5, min_isi_ms=0,
 def compute_refrac_period_violations(
     sorting_analyzer, refractory_period_ms: float = 1.0, censored_period_ms: float = 0.0, unit_ids=None
 ):
-    """Calculates the number of refractory period violations.
+    """
+    Calculate the number of refractory period violations.

     This is similar (but slightly different) to the ISI violations.
     The key difference being that the violations are not only computed on consecutive spikes.

     Parameters
     ----------
     sorting_analyzer : SortingAnalyzer
-        The SortingAnalyzer object
+        The SortingAnalyzer object.
     refractory_period_ms : float, default: 1.0
         The period (in ms) where no 2 good spikes can occur.
     censored_period_ms : float, default: 0.0
@@ -348,7 +353,6 @@ def compute_refrac_period_violations(
     References
     ----------
     Based on metrics described in [Llobet]_
-
     """
     res = namedtuple("rp_violations", ["rp_contamination", "rp_violations"])
@@ -411,34 +415,35 @@ def compute_sliding_rp_violations(
     contamination_values=None,
     unit_ids=None,
 ):
-    """Compute sliding refractory period violations, a metric developed by IBL which computes
+    """
+    Compute sliding refractory period violations, a metric developed by IBL which computes
     contamination by using a sliding refractory period. This metric computes the minimum
     contamination with at least 90% confidence.

     Parameters
     ----------
-    sorting_analyzer: SortingAnalyzer
-        A SortingAnalyzer object
+    sorting_analyzer : SortingAnalyzer
+        A SortingAnalyzer object.
     min_spikes : int, default: 0
         Contamination is set to np.nan if the unit has less than this many
         spikes across all segments.
     bin_size_ms : float, default: 0.25
-        The size of binning for the autocorrelogram in ms
+        The size of binning for the autocorrelogram in ms.
     window_size_s : float, default: 1
-        Window in seconds to compute correlogram
+        Window in seconds to compute correlogram.
     exclude_ref_period_below_ms : float, default: 0.5
-        Refractory periods below this value are excluded
+        Refractory periods below this value are excluded.
     max_ref_period_ms : float, default: 10
-        Maximum refractory period to test in ms
+        Maximum refractory period to test in ms.
     contamination_values : 1d array or None, default: None
-        The contamination values to test, If None, it is set to np.arange(0.5, 35, 0.5)
+        The contamination values to test. If None, it is set to np.arange(0.5, 35, 0.5).
     unit_ids : list or None
         List of unit ids to compute the sliding RP violations. If None, all units are used.

     Returns
     -------
     contamination : dict of floats
-        The minimum contamination at 90% confidence
+        The minimum contamination at 90% confidence.

     References
     ----------
@@ -497,7 +502,8 @@ def compute_sliding_rp_violations(

 def get_synchrony_counts(spikes, synchrony_sizes, all_unit_ids):
-    """Compute synchrony counts, the number of simultaneous spikes with sizes `synchrony_sizes`
+    """
+    Compute synchrony counts, the number of simultaneous spikes with sizes `synchrony_sizes`.

     Parameters
     ----------
     spikes : np.array
         Structured numpy array with fields ("sample_index", "unit_index", "segment_index").
     synchrony_sizes : numpy array
         The synchrony sizes to compute. Should be pre-sorted.
-    unit_ids : list or None, default: None
+    all_unit_ids : list
         List of unit ids to compute the synchrony metrics. Expecting all units.

     Returns
     -------
@@ -541,13 +547,14 @@ def get_synchrony_counts(spikes, synchrony_sizes, all_unit_ids):

 def compute_synchrony_metrics(sorting_analyzer, synchrony_sizes=(2, 4, 8), unit_ids=None):
-    """Compute synchrony metrics. Synchrony metrics represent the rate of occurrences of
+    """
+    Compute synchrony metrics. Synchrony metrics represent the rate of occurrences of
     "synchrony_size" spikes at the exact same sample index.

     Parameters
     ----------
-    sorting_analyzer: SortingAnalyzer
-        A SortingAnalyzer object
+    sorting_analyzer : SortingAnalyzer
+        A SortingAnalyzer object.
     synchrony_sizes : list or tuple, default: (2, 4, 8)
         The synchrony sizes to compute.
     unit_ids : list or None, default: None
@@ -604,12 +611,13 @@ def compute_synchrony_metrics(sorting_analyzer, synchrony_sizes=(2, 4, 8), unit_

 def compute_firing_ranges(sorting_analyzer, bin_size_s=5, percentiles=(5, 95), unit_ids=None, **kwargs):
-    """Calculate firing range, the range between the 5th and 95th percentiles of the firing rates distribution
+    """
+    Calculate firing range, the range between the 5th and 95th percentiles of the firing rates distribution
     computed in non-overlapping time bins.

     Parameters
     ----------
-    sorting_analyzer: SortingAnalyzer
+    sorting_analyzer : SortingAnalyzer
         A SortingAnalyzer object
     bin_size_s : float, default: 5
         The size of the bin in seconds.
@@ -675,18 +683,21 @@ def compute_firing_ranges(sorting_analyzer, bin_size_s=5, percentiles=(5, 95), u
 def compute_amplitude_cv_metrics(
     amplitude_extension="spike_amplitudes",
     unit_ids=None,
 ):
-    """Calculate coefficient of variation of spike amplitudes within defined temporal bins.
+    """
+    Calculate coefficient of variation of spike amplitudes within defined temporal bins.
     From the distribution of coefficient of variations, both the median and the "range"
     (the distance between the percentiles defined by `percentiles` parameter) are returned.

     Parameters
     ----------
-    sorting_analyzer: SortingAnalyzer
-        A SortingAnalyzer object
+    sorting_analyzer : SortingAnalyzer
+        A SortingAnalyzer object.
     average_num_spikes_per_bin : int, default: 50
         The average number of spikes per bin. This is used to estimate a temporal bin size using the firing rate
         of each unit. For example, if a unit has a firing rate of 10 Hz, and the average number of spikes per bin
         is 100, then the temporal bin size will be 100/10 Hz = 10 s.
+    percentiles : tuple, default: (5, 95)
+        The percentile values from which to calculate the range.
     min_num_bins : int, default: 10
         The minimum number of bins to compute the median and range. If the number of bins is less than this then
         the median and range are set to NaN.
@@ -810,12 +821,13 @@ def compute_amplitude_cv_metrics(
 def compute_amplitude_cutoffs(
     amplitudes_bins_min_ratio=5,
     unit_ids=None,
 ):
-    """Calculate approximate fraction of spikes missing from a distribution of amplitudes.
+    """
+    Calculate approximate fraction of spikes missing from a distribution of amplitudes.

     Parameters
     ----------
-    sorting_analyzer: SortingAnalyzer
-        A SortingAnalyzer object
+    sorting_analyzer : SortingAnalyzer
+        A SortingAnalyzer object.
     peak_sign : "neg" | "pos" | "both", default: "neg"
         The sign of the peaks.
     num_histogram_bins : int, default: 100
@@ -893,12 +905,13 @@ def compute_amplitude_cutoffs(

 def compute_amplitude_medians(sorting_analyzer, peak_sign="neg", unit_ids=None):
-    """Compute median of the amplitude distributions (in absolute value).
+    """
+    Compute median of the amplitude distributions (in absolute value).

     Parameters
     ----------
-    sorting_analyzer: SortingAnalyzer
-        A SortingAnalyzer object
+    sorting_analyzer : SortingAnalyzer
+        A SortingAnalyzer object.
     peak_sign : "neg" | "pos" | "both", default: "neg"
         The sign of the peaks.
     unit_ids : list or None
@@ -945,7 +958,8 @@ def compute_drift_metrics(
     return_positions=False,
     unit_ids=None,
 ):
-    """Compute drifts metrics using estimated spike locations.
+    """
+    Compute drift metrics using estimated spike locations.

     Over the duration of the recording, the drift signal for each unit is calculated as the median
     position in an interval with respect to the overall median positions over the entire duration
     (reference position).

     Parameters
     ----------
-    sorting_analyzer: SortingAnalyzer
-        A SortingAnalyzer object
+    sorting_analyzer : SortingAnalyzer
+        A SortingAnalyzer object.
     interval_s : int, default: 60
-        Interval length is seconds for computing spike depth
+        Interval length in seconds for computing spike depth.
     min_spikes_per_interval : int, default: 100
-        Minimum number of spikes for computing depth in an interval
+        Minimum number of spikes for computing depth in an interval.
     direction : "x" | "y" | "z", default: "y"
-        The direction along which drift metrics are estimated
+        The direction along which drift metrics are estimated.
     min_fraction_valid_intervals : float, default: 0.5
         The fraction of valid (not NaN) position estimates to estimate drifts.
         E.g., if 0.5 at least 50% of estimated positions in the intervals need to be valid,
-        otherwise drift metrics are set to None
+        otherwise drift metrics are set to None.
     min_num_bins : int, default: 2
         Minimum number of bins required to return a valid metric value. In case there are
         less bins, the metric values are set to NaN.
     return_positions : bool, default: False
-        If True, median positions are returned (for debugging)
+        If True, median positions are returned (for debugging).
     unit_ids : list or None, default: None
-        List of unit ids to compute the drift metrics. If None, all units are used
+        List of unit ids to compute the drift metrics. If None, all units are used.
     Returns
     -------
     drift_ptp : dict
-        The drift signal peak-to-peak in um
+        The drift signal peak-to-peak in um.
     drift_std : dict
-        The drift signal standard deviation in um
+        The drift signal standard deviation in um.
     drift_mad : dict
-        The drift signal median absolute deviation in um
+        The drift signal median absolute deviation in um.
     median_positions : np.array (optional)
-        The median positions of each unit over time (only returned if return_positions=True)
+        The median positions of each unit over time (only returned if return_positions=True).

     Notes
     -----
@@ -1112,26 +1126,27 @@ def compute_drift_metrics(
 ### LOW-LEVEL FUNCTIONS ###
 def presence_ratio(spike_train, total_length, bin_edges=None, num_bin_edges=None, bin_n_spikes_thres=0):
-    """Calculate the presence ratio for a single unit
+    """
+    Calculate the presence ratio for a single unit.

     Parameters
     ----------
     spike_train : np.ndarray
-        Spike times for this unit, in samples
+        Spike times for this unit, in samples.
     total_length : int
-        Total length of the recording in samples
-    bin_edges : np.array
+        Total length of the recording in samples.
+    bin_edges : np.array, optional
         Pre-computed bin edges (mutually exclusive with num_bin_edges).
-    num_bin_edges : int, default: 101
+    num_bin_edges : int, optional
         The number of bin edges to use to compute the presence ratio.
         (mutually exclusive with bin_edges).
-    bin_n_spikes_thres: int, default: 0
-        Minimum number of spikes within a bin to consider the unit active
+    bin_n_spikes_thres : int, default: 0
+        Minimum number of spikes within a bin to consider the unit active.

     Returns
     -------
     presence_ratio : float
-        The presence ratio for one unit
+        The presence ratio for one unit.
     """
     assert bin_edges is not None or num_bin_edges is not None, "Use either bin_edges or num_bin_edges"
@@ -1147,19 +1162,20 @@ def presence_ratio(spike_train, total_length, bin_edges=None, num_bin_edges=None

 def isi_violations(spike_trains, total_duration_s, isi_threshold_s=0.0015, min_isi_s=0):
-    """Calculate Inter-Spike Interval (ISI) violations.
+    """
+    Calculate Inter-Spike Interval (ISI) violations.

     See compute_isi_violations for additional documentation

     Parameters
     ----------
     spike_trains : list of np.ndarrays
-        The spike times for each recording segment for one unit, in seconds
+        The spike times for each recording segment for one unit, in seconds.
     total_duration_s : float
-        The total duration of the recording (in seconds)
+        The total duration of the recording (in seconds).
     isi_threshold_s : float, default: 0.0015
         Threshold for classifying adjacent spikes as an ISI violation, in seconds.
-        This is the biophysical refractory period
+        This is the biophysical refractory period.
     min_isi_s : float, default: 0
         Minimum possible inter-spike interval, in seconds.
         This is the artificial refractory period enforced
         by the data acquisition system or post-processing algorithms.

     Returns
     -------
     isi_violations_ratio : float
-        The isi violation ratio described in [1]
+        The isi violation ratio described in [1].
     isi_violations_rate : float
         Rate of contaminating spikes as a fraction of overall rate.
         Higher values indicate more contamination.
     isi_violation_count : int
-        Number of violations
+        Number of violations.
""" num_violations = 0 @@ -1201,8 +1217,8 @@ def isi_violations(spike_trains, total_duration_s, isi_threshold_s=0.0015, min_i def amplitude_cutoff(amplitudes, num_histogram_bins=500, histogram_smoothing_value=3, amplitudes_bins_min_ratio=5): - """Calculate approximate fraction of spikes missing from a distribution of amplitudes. - + """ + Calculate approximate fraction of spikes missing from a distribution of amplitudes. See compute_amplitude_cutoffs for additional documentation @@ -1210,8 +1226,6 @@ def amplitude_cutoff(amplitudes, num_histogram_bins=500, histogram_smoothing_val ---------- amplitudes : ndarray_like The amplitudes (in uV) of the spikes for one unit. - peak_sign : "neg" | "pos" | "both", default: "neg" - The sign of the template to compute best channels. num_histogram_bins : int, default: 500 The number of bins to use to compute the amplitude histogram. histogram_smoothing_value : int, default: 3 @@ -1275,21 +1289,21 @@ def slidingRP_violations( Parameters ---------- spike_samples : ndarray_like or list (for multi-segment) - The spike times in samples + The spike times in samples. sample_rate : float - The acquisition sampling rate + The acquisition sampling rate. bin_size_ms : float The size (in ms) of binning for the autocorrelogram. window_size_s : float, default: 1 - Window in seconds to compute correlogram + Window in seconds to compute correlogram. exclude_ref_period_below_ms : float, default: 0.5 Refractory periods below this value are excluded max_ref_period_ms : float, default: 10 - Maximum refractory period to test in ms + Maximum refractory period to test in ms. contamination_values : 1d array or None, default: None - The contamination values to test, if None it is set to np.arange(0.5, 35, 0.5) / 100 + The contamination values to test, if None it is set to np.arange(0.5, 35, 0.5) / 100. return_conf_matrix : bool, default: False - If True, the confidence matrix (n_contaminations, n_ref_periods) is returned + If True, the confidence matrix (n_contaminations, n_ref_periods) is returned. Code adapted from: https://github.com/SteinmetzLab/slidingRefractory/blob/master/python/slidingRP/metrics.py#L166 @@ -1297,7 +1311,7 @@ def slidingRP_violations( Returns ------- min_cont_with_90_confidence : dict of floats - The minimum contamination with confidence > 90% + The minimum contamination with confidence > 90%. """ if contamination_values is None: contamination_values = np.arange(0.5, 35, 0.5) / 100 # vector of contamination values to test @@ -1413,13 +1427,13 @@ def compute_sd_ratio( Parameters ---------- - sorting_analyzer: SortingAnalyzer - A SortingAnalyzer object + sorting_analyzer : SortingAnalyzer + A SortingAnalyzer object. censored_period_ms : float, default: 4.0 The censored period in milliseconds. This is to remove any potential bursts that could affect the SD. - correct_for_drift: bool, default: True + correct_for_drift : bool, default: True If True, will subtract the amplitudes sequentiially to significantly reduce the impact of drift. - correct_for_template_itself: bool, default: True + correct_for_template_itself : bool, default: True If true, will take into account that the template itself impacts the standard deviation of the noise, and will make a rough estimation of what that impact is (and remove it). 
     unit_ids : list or None, default: None

diff --git a/src/spikeinterface/qualitymetrics/pca_metrics.py b/src/spikeinterface/qualitymetrics/pca_metrics.py
index ac8e486b1f..8af8ac0acb 100644
--- a/src/spikeinterface/qualitymetrics/pca_metrics.py
+++ b/src/spikeinterface/qualitymetrics/pca_metrics.py
@@ -60,12 +60,13 @@ def get_quality_pca_metric_list():
 def calculate_pc_metrics(
     sorting_analyzer, metric_names=None, qm_params=None, unit_ids=None, seed=None, n_jobs=1, progress_bar=False
 ):
-    """Calculate principal component derived metrics.
+    """
+    Calculate principal component derived metrics.

     Parameters
     ----------
-    sorting_analyzer: SortingAnalyzer
-        A SortingAnalyzer object
+    sorting_analyzer : SortingAnalyzer
+        A SortingAnalyzer object.
     metric_names : list of str, default: None
         The list of PC metrics to compute.
         If not provided, defaults to all PC metrics.
@@ -185,7 +186,8 @@ def calculate_pc_metrics(

 def mahalanobis_metrics(all_pcs, all_labels, this_unit_id):
-    """Calculates isolation distance and L-ratio (metrics computed from Mahalanobis distance)
+    """
+    Calculate isolation distance and L-ratio (metrics computed from Mahalanobis distance).

     Parameters
     ----------
@@ -240,7 +242,8 @@ def mahalanobis_metrics(all_pcs, all_labels, this_unit_id):

 def lda_metrics(all_pcs, all_labels, this_unit_id):
-    """Calculates d-prime based on Linear Discriminant Analysis.
+    """
+    Calculate d-prime based on Linear Discriminant Analysis.

     Parameters
     ----------
@@ -282,7 +285,7 @@ def lda_metrics(all_pcs, all_labels, this_unit_id):
 def nearest_neighbors_metrics(all_pcs, all_labels, this_unit_id, max_spikes, n_neighbors):
     """
-    Calculates unit contamination based on NearestNeighbors search in PCA space.
+    Calculate unit contamination based on NearestNeighbors search in PCA space.

     Parameters
     ----------
@@ -365,18 +368,19 @@ def nearest_neighbors_isolation(
     min_spatial_overlap: float = 0.5,
     seed=None,
 ):
-    """Calculates unit isolation based on NearestNeighbors search in PCA space.
+    """
+    Calculate unit isolation based on NearestNeighbors search in PCA space.

     Parameters
     ----------
-    sorting_analyzer: SortingAnalyzer
-        A SortingAnalyzer object
+    sorting_analyzer : SortingAnalyzer
+        A SortingAnalyzer object.
     this_unit_id : int | str
         The ID for the unit to calculate these metrics for.
-    n_spikes_all_units: dict, default: None
+    n_spikes_all_units : dict, default: None
         Dictionary of the form ``{<unit_id>: <n_spikes>}`` for the waveform extractor.
         Recomputed if None.
-    fr_all_units: dict, default: None
+    fr_all_units : dict, default: None
         Dictionary of the form ``{<unit_id>: <firing_rate>}`` for the waveform extractor.
         Recomputed if None.
     max_spikes : int, default: 1000
         The number of PC components to use to project the snippets to.
     radius_um : float, default: 100
         The radius, in um, that channels need to be within the peak channel to be included.
-    peak_sign: "neg" | "pos" | "both", default: "neg"
+    peak_sign : "neg" | "pos" | "both", default: "neg"
         The peak_sign used to compute sparsity and neighbor units. Used if sorting_analyzer
         is not sparse already.
     min_spatial_overlap : float, default: 0.5
         In case sorting_analyzer is sparse, other units are selected if they share at least
-        `min_spatial_overlap` times `n_target_unit_channels` with the target unit
+        `min_spatial_overlap` times `n_target_unit_channels` with the target unit.
     seed : int, default: None
         Seed for random subsampling of spikes.
@@ -410,7 +414,7 @@ def nearest_neighbors_isolation(
         The calculated nearest neighbor isolation metric for `this_unit_id`.
         If the unit has fewer than `min_spikes`, returns numpy.NaN instead.
     nn_unit_id : np.int16
-        Id of the "nearest neighbor" unit (unit with lowest isolation score from `this_unit_id`)
+        Id of the "nearest neighbor" unit (unit with lowest isolation score from `this_unit_id`).

     Notes
     -----
@@ -578,18 +582,19 @@ def nearest_neighbors_noise_overlap(
     peak_sign: str = "neg",
     seed=None,
 ):
-    """Calculates unit noise overlap based on NearestNeighbors search in PCA space.
+    """
+    Calculate unit noise overlap based on NearestNeighbors search in PCA space.

     Parameters
     ----------
-    sorting_analyzer: SortingAnalyzer
-        A SortingAnalyzer object
+    sorting_analyzer : SortingAnalyzer
+        A SortingAnalyzer object.
     this_unit_id : int | str
         The ID of the unit to calculate this metric on.
-    n_spikes_all_units: dict, default: None
+    n_spikes_all_units : dict, default: None
         Dictionary of the form ``{<unit_id>: <n_spikes>}`` for the waveform extractor.
         Recomputed if None.
-    fr_all_units: dict, default: None
+    fr_all_units : dict, default: None
         Dictionary of the form ``{<unit_id>: <firing_rate>}`` for the waveform extractor.
         Recomputed if None.
     max_spikes : int, default: 1000
         The number of PC components to use to project the snippets to.
     radius_um : float, default: 100
         The radius, in um, that channels need to be within the peak channel to be included.
-    peak_sign: "neg" | "pos" | "both", default: "neg"
+    peak_sign : "neg" | "pos" | "both", default: "neg"
         The peak_sign used to compute sparsity and neighbor units. Used if sorting_analyzer
         is not sparse already.
     seed : int, default: None
@@ -740,7 +745,8 @@ def nearest_neighbors_noise_overlap(

 def simplified_silhouette_score(all_pcs, all_labels, this_unit_id):
-    """Calculates the simplified silhouette score for each cluster. The value ranges
+    """
+    Calculate the simplified silhouette score for each cluster. The value ranges
     from -1 (bad clustering) to 1 (good clustering). The simplified silhouette score
     utilizes the centroids for distance calculations rather than pairwise calculations.

     Parameters
     ----------
@@ -756,7 +762,7 @@ def simplified_silhouette_score(all_pcs, all_labels, this_unit_id):
     Returns
     -------
     unit_silhouette_score : float
-        Simplified Silhouette Score for this unit
+        Simplified Silhouette Score for this unit.

     References
     ----------
@@ -789,7 +795,8 @@ def simplified_silhouette_score(all_pcs, all_labels, this_unit_id):

 def silhouette_score(all_pcs, all_labels, this_unit_id):
-    """Calculates the silhouette score which is a marker of cluster quality ranging from
+    """
+    Calculate the silhouette score which is a marker of cluster quality ranging from
     -1 (bad clustering) to 1 (good clustering). Distances are all calculated as
     pairwise comparisons of all data points.

     Parameters
     ----------
@@ -805,7 +812,7 @@ def silhouette_score(all_pcs, all_labels, this_unit_id):
     Returns
     -------
     unit_silhouette_score : float
-        Silhouette Score for this unit
+        Silhouette Score for this unit.

     References
     ----------
@@ -845,7 +852,7 @@ def _subtract_clip_component(clip1, component):
 def _compute_isolation(pcs_target_unit, pcs_other_unit, n_neighbors: int):
     """
-    Computes the isolation score used for nn_isolation and nn_noise_overlap
+    Compute the isolation score used for nn_isolation and nn_noise_overlap.
     Parameters
     ----------

diff --git a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py
index 1dab1c602c..c286bb04a6 100644
--- a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py
+++ b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py
@@ -23,8 +23,8 @@ class ComputeQualityMetrics(AnalyzerExtension):

     Parameters
     ----------
-    sorting_analyzer: SortingAnalyzer
-        A SortingAnalyzer object
+    sorting_analyzer : SortingAnalyzer
+        A SortingAnalyzer object.
     metric_names : list or None
         List of quality metrics to compute.
     qm_params : dict or None
@@ -36,7 +36,7 @@ class ComputeQualityMetrics(AnalyzerExtension):

     Returns
     -------
     metrics : pandas.DataFrame
-        Data frame with the computed metrics
+        Data frame with the computed metrics.

     Notes
     -----
@@ -171,13 +171,16 @@ def _get_data(self):

 def get_quality_metric_list():
-    """Get a list of the available quality metrics."""
+    """
+    Return a list of the available quality metrics.
+    """

     return deepcopy(list(_misc_metric_name_to_func.keys()))

 def get_default_qm_params():
-    """Return default dictionary of quality metrics parameters.
+    """
+    Return default dictionary of quality metrics parameters.

     Returns
     -------
diff --git a/src/spikeinterface/qualitymetrics/utils.py b/src/spikeinterface/qualitymetrics/utils.py
index 553719bba6..91530e5304 100644
--- a/src/spikeinterface/qualitymetrics/utils.py
+++ b/src/spikeinterface/qualitymetrics/utils.py
@@ -5,22 +5,23 @@

 def create_ground_truth_pc_distributions(center_locations, total_points):
-    """Simulate PCs as multivariate Gaussians, for testing PC-based quality metrics
-    Values are created for only one channel and vary along one dimension
+    """
+    Simulate PCs as multivariate Gaussians, for testing PC-based quality metrics.
+    Values are created for only one channel and vary along one dimension.

     Parameters
     ----------
     center_locations : array-like (units, ) or (channels, units)
-        Mean of the multivariate gaussian at each channel for each unit
+        Mean of the multivariate gaussian at each channel for each unit.
     total_points : array-like
-        Number of points in each unit distribution
+        Number of points in each unit distribution.

     Returns
     -------
-    numpy.ndarray
-        PC scores for each point
-    numpy.array
-        Labels for each point
+    all_pcs : numpy.ndarray
+        PC scores for each point.
+    all_labels : numpy.array
+        Labels for each point.
     """

     np.random.seed(0)

From 597ef7d2006fc1232278bf608ee20ac7f6018e02 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Thu, 30 May 2024 11:56:14 +0000
Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 src/spikeinterface/qualitymetrics/misc_metrics.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/spikeinterface/qualitymetrics/misc_metrics.py b/src/spikeinterface/qualitymetrics/misc_metrics.py
index fe784a36ec..91117f1c08 100644
--- a/src/spikeinterface/qualitymetrics/misc_metrics.py
+++ b/src/spikeinterface/qualitymetrics/misc_metrics.py
@@ -179,7 +179,8 @@ def compute_presence_ratios(sorting_analyzer, bin_duration_s=60.0, mean_fr_ratio
 )

-def compute_snrs(sorting_analyzer,
+def compute_snrs(
+    sorting_analyzer,
     peak_sign: str = "neg",
     peak_mode: str = "extremum",
     unit_ids=None,
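
For reviewers who want to exercise the API these docstrings document, the sketch below shows how the quality metrics touched by this patch are typically computed. It is a minimal illustration, not part of the patch: the toy-data generator, the extension names, and the metric names are assumptions based on the SortingAnalyzer workflow the docstrings describe, and may differ in other versions of spikeinterface.

# Minimal sketch (assumed API, spikeinterface with the SortingAnalyzer workflow).
import spikeinterface.full as si

# Toy recording/sorting pair, for illustration only.
recording, sorting = si.generate_ground_truth_recording(durations=[60.0], seed=0)

# The SortingAnalyzer object that every metric function above takes as its
# first argument.
analyzer = si.create_sorting_analyzer(sorting, recording)

# Quality metrics depend on other extensions: templates and noise levels for
# SNR, spike amplitudes for the amplitude-based metrics.
analyzer.compute(["random_spikes", "waveforms", "templates", "noise_levels", "spike_amplitudes"])

# ComputeQualityMetrics (quality_metric_calculator.py above) is registered as
# the "quality_metrics" extension; get_data() returns a pandas DataFrame with
# one row per unit.
qm = analyzer.compute(
    "quality_metrics",
    metric_names=["num_spikes", "firing_rate", "presence_ratio", "snr", "isi_violation"],
)
print(qm.get_data().head())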