Skip to content

Commit

Permalink
refactor: quality metrics
Browse files Browse the repository at this point in the history
  • Loading branch information
MilagrosMarin committed May 1, 2024
1 parent 9299142 commit 6f6acd4
Showing 1 changed file with 28 additions and 15 deletions.
43 changes: 28 additions & 15 deletions element_array_ephys/ephys_no_curation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1218,19 +1218,20 @@ class Cluster(dj.Part):
Attributes:
QualityMetrics (foreign key): QualityMetrics primary key.
CuratedClustering.Unit (foreign key): CuratedClustering.Unit primary key.
firing_rate (float): Firing rate of the unit.
firing_rate (float): Firing rate for a unit as the average number of spikes within the recording per second
snr (float): Signal-to-noise ratio for a unit.
presence_ratio (float): Fraction of time where spikes are present.
isi_violation (float): rate of ISI violation as a fraction of overall rate.
number_violation (int): Total ISI violations.
amplitude_cutoff (float): Estimate of miss rate based on amplitude histogram.
amplitude_cutoff (float): Estimate of the fraction of false negatives during intervals (missed rate) based on amplitude histogram
isolation_distance (float): Distance to nearest cluster.
l_ratio (float): Amount of empty space between a cluster and other spikes in dataset.
d_prime (float): Classification accuracy based on LDA.
nn_hit_rate (float): Fraction of neighbors for target cluster that are also in target cluster.
nn_miss_rate (float): Fraction of neighbors outside target cluster that are in the target cluster.
silhouette_core (float): Maximum change in spike depth throughout recording.
cumulative_drift (float): Cumulative change in spike depth throughout recording.
silhouette_core (float): Standard metric for cluster overlap
max_drift (float): Peak-to-peak of the drift signal for each unit
cumulative_drift (float): Median absolute deviation of the drift signal for each unit.
contamination_rate (float): Frequency of spikes in the refractory period.
"""

Expand All @@ -1239,20 +1240,20 @@ class Cluster(dj.Part):
-> master
-> CuratedClustering.Unit
---
firing_rate=null: float # (Hz) firing rate for a unit
firing_rate=null: float # (Hz) firing rate for a unit as the average number of spikes within the recording per second
snr=null: float # signal-to-noise ratio for a unit
presence_ratio=null: float # fraction of time in which spikes are present
isi_violation=null: float # rate of ISI violation as a fraction of overall rate
number_violation=null: int # total number of ISI violations
amplitude_cutoff=null: float # estimate of miss rate based on amplitude histogram
amplitude_cutoff=null: float # estimate of the fraction of false negatives during intervals (missed rate) based on amplitude histogram
isolation_distance=null: float # distance to nearest cluster in Mahalanobis space
l_ratio=null: float #
d_prime=null: float # Classification accuracy based on LDA
nn_hit_rate=null: float # Fraction of neighbors for target cluster that are also in target cluster
nn_miss_rate=null: float # Fraction of neighbors outside target cluster that are in target cluster
silhouette_score=null: float # Standard metric for cluster overlap
max_drift=null: float # Maximum change in spike depth throughout recording
cumulative_drift=null: float # Cumulative change in spike depth throughout recording
silhouette_core (float): Standard metric for cluster overlap
max_drift=null: float # Peak-to-peak of the drift signal for each unit
cumulative_drift (float): Median absolute deviation of the drift signal for each unit
contamination_rate=null: float #
"""

Expand Down Expand Up @@ -1295,20 +1296,32 @@ def make(self, key):
kilosort_dir = find_full_path(get_ephys_root_data_dir(), output_dir)

metric_fp = kilosort_dir / "metrics.csv"
rename_dict = {
"isi_viol": "isi_violation",
"num_viol": "number_violation",
"contam_rate": "contamination_rate",
}

if not metric_fp.exists():
raise FileNotFoundError(f"QC metrics file not found: {metric_fp}")

metrics_df = pd.read_csv(metric_fp)

if "cluster_id" in metrics_df.columns:
metrics_df.set_index("cluster_id", inplace=True)
else:
metrics_df.rename(
columns={metrics_df.columns[0]: "cluster_id"}, inplace=True
)

metrics_df.set_index("cluster_id", inplace=True)
metrics_df.replace([np.inf, -np.inf], np.nan, inplace=True)
metrics_df.columns = metrics_df.columns.str.lower()

rename_dict = {
"isi_violations_ratio": "isi_violation",
"isi_violations_count": "number_violation",
"silhouette": "silhouette_score",
"rp_contamination": "contamination_rate",
"drift_ptp": "max_drift",
"drift_mad": "cumulative_drift",
}
metrics_df.rename(columns=rename_dict, inplace=True)

metrics_list = [
dict(metrics_df.loc[unit_key["unit"]], **unit_key)
for unit_key in (CuratedClustering.Unit & key).fetch("KEY")
Expand Down

0 comments on commit 6f6acd4

Please sign in to comment.