From 6f6acd419c62b178e79992b68d5e1ee56c5e9fbf Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Wed, 1 May 2024 20:51:26 +0200 Subject: [PATCH] refactor: quality metrics --- element_array_ephys/ephys_no_curation.py | 43 +++++++++++++++--------- 1 file changed, 28 insertions(+), 15 deletions(-) diff --git a/element_array_ephys/ephys_no_curation.py b/element_array_ephys/ephys_no_curation.py index 856ddfeb..3a245d4b 100644 --- a/element_array_ephys/ephys_no_curation.py +++ b/element_array_ephys/ephys_no_curation.py @@ -1218,19 +1218,20 @@ class Cluster(dj.Part): Attributes: QualityMetrics (foreign key): QualityMetrics primary key. CuratedClustering.Unit (foreign key): CuratedClustering.Unit primary key. - firing_rate (float): Firing rate of the unit. + firing_rate (float): Firing rate for a unit as the average number of spikes within the recording per second snr (float): Signal-to-noise ratio for a unit. presence_ratio (float): Fraction of time where spikes are present. isi_violation (float): rate of ISI violation as a fraction of overall rate. number_violation (int): Total ISI violations. - amplitude_cutoff (float): Estimate of miss rate based on amplitude histogram. + amplitude_cutoff (float): Estimate of the fraction of false negatives during intervals (missed rate) based on amplitude histogram isolation_distance (float): Distance to nearest cluster. l_ratio (float): Amount of empty space between a cluster and other spikes in dataset. d_prime (float): Classification accuracy based on LDA. nn_hit_rate (float): Fraction of neighbors for target cluster that are also in target cluster. nn_miss_rate (float): Fraction of neighbors outside target cluster that are in the target cluster. - silhouette_core (float): Maximum change in spike depth throughout recording. - cumulative_drift (float): Cumulative change in spike depth throughout recording. 
+ silhouette_score (float): Standard metric for cluster overlap. + max_drift (float): Peak-to-peak of the drift signal for each unit. + cumulative_drift (float): Median absolute deviation of the drift signal for each unit. contamination_rate (float): Frequency of spikes in the refractory period. """ @@ -1239,20 +1240,20 @@ class Cluster(dj.Part): -> master -> CuratedClustering.Unit --- - firing_rate=null: float # (Hz) firing rate for a unit + firing_rate=null: float # (Hz) firing rate for a unit as the average number of spikes within the recording per second snr=null: float # signal-to-noise ratio for a unit presence_ratio=null: float # fraction of time in which spikes are present isi_violation=null: float # rate of ISI violation as a fraction of overall rate number_violation=null: int # total number of ISI violations - amplitude_cutoff=null: float # estimate of miss rate based on amplitude histogram + amplitude_cutoff=null: float # estimate of the fraction of false negatives during intervals (missed rate) based on amplitude histogram isolation_distance=null: float # distance to nearest cluster in Mahalanobis space l_ratio=null: float # d_prime=null: float # Classification accuracy based on LDA nn_hit_rate=null: float # Fraction of neighbors for target cluster that are also in target cluster nn_miss_rate=null: float # Fraction of neighbors outside target cluster that are in target cluster - silhouette_score=null: float # Standard metric for cluster overlap - max_drift=null: float # Maximum change in spike depth throughout recording - cumulative_drift=null: float # Cumulative change in spike depth throughout recording + silhouette_score=null: float # Standard metric for cluster overlap + max_drift=null: float # Peak-to-peak of the drift signal for each unit + cumulative_drift=null: float # Median absolute deviation of the drift signal for each unit contamination_rate=null: float # """ @@ -1295,20 +1296,32 @@ def make(self, key): kilosort_dir = 
find_full_path(get_ephys_root_data_dir(), output_dir) metric_fp = kilosort_dir / "metrics.csv" - rename_dict = { - "isi_viol": "isi_violation", - "num_viol": "number_violation", - "contam_rate": "contamination_rate", - } - if not metric_fp.exists(): raise FileNotFoundError(f"QC metrics file not found: {metric_fp}") metrics_df = pd.read_csv(metric_fp) + + if "cluster_id" in metrics_df.columns: + metrics_df.set_index("cluster_id", inplace=True) + else: + metrics_df.rename( + columns={metrics_df.columns[0]: "cluster_id"}, inplace=True + ) + metrics_df.set_index("cluster_id", inplace=True) metrics_df.replace([np.inf, -np.inf], np.nan, inplace=True) metrics_df.columns = metrics_df.columns.str.lower() + + rename_dict = { + "isi_violations_ratio": "isi_violation", + "isi_violations_count": "number_violation", + "silhouette": "silhouette_score", + "rp_contamination": "contamination_rate", + "drift_ptp": "max_drift", + "drift_mad": "cumulative_drift", + } metrics_df.rename(columns=rename_dict, inplace=True) + metrics_list = [ dict(metrics_df.loc[unit_key["unit"]], **unit_key) for unit_key in (CuratedClustering.Unit & key).fetch("KEY")