Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix dtype of quality metrics before and after merging #3497

Open
wants to merge 16 commits into
base: main
Choose a base branch
from
25 changes: 24 additions & 1 deletion src/spikeinterface/qualitymetrics/quality_metric_calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
_misc_metric_name_to_func,
_possible_pc_metric_names,
compute_name_to_column_names,
column_name_to_column_dtype,
)
from .misc_metrics import _default_params as misc_metrics_params
from .pca_metrics import _default_params as pca_metrics_params
Expand Down Expand Up @@ -125,13 +126,20 @@ def _merge_extension_data(
all_unit_ids = new_sorting_analyzer.unit_ids
not_new_ids = all_unit_ids[~np.isin(all_unit_ids, new_unit_ids)]

# this creates a new metrics dictionary, but the dtype for everything will be
# object. So we will need to fix this later after computing metrics
metrics = pd.DataFrame(index=all_unit_ids, columns=old_metrics.columns)

metrics.loc[not_new_ids, :] = old_metrics.loc[not_new_ids, :]
metrics.loc[new_unit_ids, :] = self._compute_metrics(
new_sorting_analyzer, new_unit_ids, verbose, metric_names, **job_kwargs
)

# we need to fix the dtypes after we compute everything because we have nans
# we can iterate through the columns and convert them back to the dtype
# of the original quality dataframe.
for column in old_metrics.columns:
metrics[column] = metrics[column].astype(old_metrics[column].dtype)

new_data = dict(metrics=metrics)
return new_data

Expand Down Expand Up @@ -214,10 +222,25 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, metri
# add NaN for empty units
if len(empty_unit_ids) > 0:
metrics.loc[empty_unit_ids] = np.nan
# num_spikes is an int and should be 0
if "num_spikes" in metrics.columns:
metrics.loc[empty_unit_ids, ["num_spikes"]] = 0

# we use the convert_dtypes to convert the columns to the most appropriate dtype and avoid object columns
# (in case of NaN values)
metrics = metrics.convert_dtypes()

# we do this because the convert_dtypes infers the wrong types sometimes.
# the actual types for columns can be found in column_name_to_column_dtype dictionary.
for column in metrics.columns:
# we have one issue where the name of the columns for synchrony are named based on
# what the user has input as arguments so we need a way to handle this separately
# everything else should be handled with the column name.
if "sync" in column:
metrics[column] = metrics[column].astype(column_name_to_column_dtype["sync"])
else:
metrics[column] = metrics[column].astype(column_name_to_column_dtype[column])

return metrics

def _run(self, verbose=False, **job_kwargs):
Expand Down
39 changes: 38 additions & 1 deletion src/spikeinterface/qualitymetrics/quality_metric_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,11 @@
"amplitude_cutoff": ["amplitude_cutoff"],
"amplitude_median": ["amplitude_median"],
"amplitude_cv": ["amplitude_cv_median", "amplitude_cv_range"],
"synchrony": ["sync_spike_2", "sync_spike_4", "sync_spike_8"],
"synchrony": [
"sync_spike_2",
"sync_spike_4",
"sync_spike_8",
], # we probably shouldn't hard code this. This is determined by the arguments in the function...
"firing_range": ["firing_range"],
"drift": ["drift_ptp", "drift_std", "drift_mad"],
"sd_ratio": ["sd_ratio"],
Expand All @@ -79,3 +83,36 @@
"silhouette": ["silhouette"],
"silhouette_full": ["silhouette_full"],
}

# Explicit dtypes for every quality-metric column, so we pin the column types
# ourselves instead of letting pandas infer them (inference can yield object
# columns when NaNs are mixed in).
# "num_spikes" is the only integer column; every other metric is a float
# (counts that can be NaN for empty units must stay float).
_float_metric_column_names = (
    "firing_rate",
    "presence_ratio",
    "snr",
    "isi_violations_ratio",
    "isi_violations_count",
    "rp_violations",
    "rp_contamination",
    "sliding_rp_violation",
    "amplitude_cutoff",
    "amplitude_median",
    "amplitude_cv_median",
    "amplitude_cv_range",
    "sync",
    "firing_range",
    "drift_ptp",
    "drift_std",
    "drift_mad",
    "sd_ratio",
    "isolation_distance",
    "l_ratio",
    "d_prime",
    "nn_hit_rate",
    "nn_miss_rate",
    "nn_isolation",
    "nn_unit_id",
    "nn_noise_overlap",
    "silhouette",
    "silhouette_full",
)

column_name_to_column_dtype = {"num_spikes": int}
column_name_to_column_dtype.update((name, float) for name in _float_metric_column_names)
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,33 @@ def test_compute_quality_metrics(sorting_analyzer_simple):
assert "isolation_distance" in metrics.columns


def test_merging_quality_metrics(sorting_analyzer_simple):
    """After merging units, quality-metric columns and their dtypes must carry over."""
    analyzer = sorting_analyzer_simple

    original_metrics = compute_quality_metrics(
        analyzer,
        metric_names=None,
        qm_params=dict(isi_violation=dict(isi_threshold_ms=2)),
        skip_pc_metrics=False,
        seed=2205,
    )

    # sorting_analyzer_simple has ten units; merging units 0 and 1 leaves nine
    merged_analyzer = analyzer.merge_units([[0, 1]])
    merged_metrics = merged_analyzer.get_extension("quality_metrics").get_data()

    # each column must survive the merge, with its dtype preserved
    for column_name in original_metrics.columns:
        assert column_name in merged_metrics.columns
        assert original_metrics[column_name].dtype == merged_metrics[column_name].dtype

    # 10 units before the merge vs 9 after
    assert len(original_metrics.index) > len(merged_metrics.index)


def test_compute_quality_metrics_recordingless(sorting_analyzer_simple):

sorting_analyzer = sorting_analyzer_simple
Expand Down Expand Up @@ -106,10 +133,15 @@ def test_empty_units(sorting_analyzer_simple):
seed=2205,
)

for empty_unit_id in sorting_empty.get_empty_unit_ids():
# num_spikes are ints not nans so we confirm empty units are nans for everything except
# num_spikes which should be 0
nan_containing_columns = [column for column in metrics_empty.columns if column != "num_spikes"]
for empty_unit_ids in sorting_empty.get_empty_unit_ids():
from pandas import isnull

assert np.all(isnull(metrics_empty.loc[empty_unit_id].values))
assert np.all(isnull(metrics_empty.loc[empty_unit_ids, nan_containing_columns].values))
if "num_spikes" in metrics_empty.columns:
assert sum(metrics_empty.loc[empty_unit_ids, ["num_spikes"]]) == 0


# TODO @alessio all theses old test should be moved in test_metric_functions.py or test_pca_metrics()
Expand Down