Fix dtype of quality metrics before and after merging #3497

Open
wants to merge 16 commits into main
@@ -125,7 +125,13 @@ def _merge_extension_data(
all_unit_ids = new_sorting_analyzer.unit_ids
not_new_ids = all_unit_ids[~np.isin(all_unit_ids, new_unit_ids)]

# this creates a new metrics DataFrame, but the dtype of every column will be
# object
metrics = pd.DataFrame(index=all_unit_ids, columns=old_metrics.columns)
# iterate through the columns and convert them back to numeric dtypes with
# pandas.to_numeric; errors="coerce" keeps the NaN values
for column in metrics.columns:
    metrics[column] = pd.to_numeric(metrics[column], errors="coerce")
Member:
This is OK for me; pandas behavior is becoming quite cryptic. Could old_metrics[col].dtype also be used here?
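For context, the old_metrics[col].dtype alternative might look roughly like the sketch below. This is a standalone illustration with made-up column data, not code from the PR, and it shows the caveat that makes the approach less straightforward for integer columns.

```python
import numpy as np
import pandas as pd

# made-up stand-in for the existing metrics table (float columns only)
old_metrics = pd.DataFrame({"snr": [3.1, 4.2]}, index=[0, 1])

# rebuild the empty (object-dtype) frame over the merged unit ids, then
# restore each column's original dtype directly instead of inferring it
metrics = pd.DataFrame(index=[0, 1, 2], columns=old_metrics.columns)
for col in metrics.columns:
    metrics[col] = metrics[col].astype(old_metrics[col].dtype)

assert metrics["snr"].dtype == np.float64
# caveat: astype would raise for an int64 column, because the NaN
# placeholders in the empty frame cannot be cast to a numpy integer dtype
```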

Collaborator (Author):
Maybe. I agree that pandas is creating its own dtypes, like NAType, which don't play nicely with numpy (in my scripts I tend to query based on numpy types), so I don't know for sure; I could test that later. Personally, though, I would prefer to coerce everything to numpy dtypes, since that is what I'm used to, and none of my tables are big enough to worry about the dtype-efficiency work pandas has been doing with the new backend.
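The behavior the PR coerces can be reproduced in isolation. The sketch below uses invented unit ids and column names, but follows the same pattern as the diff above: an empty frame defaults to object dtype, and pd.to_numeric restores numpy dtypes.

```python
import numpy as np
import pandas as pd

# an empty frame built from an index and column names defaults to object dtype
metrics = pd.DataFrame(index=[0, 1, 2], columns=["snr", "isi_violations_ratio"])
assert metrics["snr"].dtype == object

# pd.to_numeric with errors="coerce" converts each column to a numeric numpy
# dtype; the all-NaN placeholder values simply stay NaN
for column in metrics.columns:
    metrics[column] = pd.to_numeric(metrics[column], errors="coerce")

assert metrics["snr"].dtype == np.float64
```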

zm711 marked this conversation as resolved.

metrics.loc[not_new_ids, :] = old_metrics.loc[not_new_ids, :]
metrics.loc[new_unit_ids, :] = self._compute_metrics(
@@ -48,6 +48,34 @@ def test_compute_quality_metrics(sorting_analyzer_simple):
assert "isolation_distance" in metrics.columns


def test_merging_quality_metrics(sorting_analyzer_simple):

sorting_analyzer = sorting_analyzer_simple

metrics = compute_quality_metrics(
sorting_analyzer,
metric_names=None,
qm_params=dict(isi_violation=dict(isi_threshold_ms=2)),
skip_pc_metrics=False,
seed=2205,
)

# sorting_analyzer_simple has ten units
new_sorting_analyzer = sorting_analyzer.merge_units([[0, 1]])

new_metrics = new_sorting_analyzer.get_extension("quality_metrics").get_data()

# we should copy over the metrics after merge
for column in metrics.columns:
assert column in new_metrics.columns

# 10 units vs 9 units
assert len(metrics.index) > len(new_metrics.index)

# dtype should be fine after merge but is cast from Float64->float64
assert np.float64 == new_metrics["snr"].dtype
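The Float64 vs float64 distinction mentioned in the comment above can be seen in isolation; the sketch below is standalone and not part of the test suite.

```python
import numpy as np
import pandas as pd

# pandas' nullable Float64 (capital F) is an extension dtype, distinct from
# numpy's float64; it uses pd.NA rather than np.nan for missing values
nullable = pd.Series([1.0, pd.NA], dtype="Float64")
plain = pd.Series([1.0, np.nan], dtype="float64")

assert nullable.dtype != np.float64
assert plain.dtype == np.float64

# casting back to the numpy dtype converts pd.NA to np.nan
back = nullable.astype("float64")
assert back.dtype == np.float64
assert np.isnan(back.iloc[1])
```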
Member:
We can add a test on int coercion if we end up using the suggestion here: https://github.com/SpikeInterface/spikeinterface/pull/3497/files#r1827487180
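The int-coercion concern can be seen directly with pd.to_numeric (a standalone sketch, not part of the PR): integer columns only survive as integers when no value is missing.

```python
import numpy as np
import pandas as pd

# an integer-valued object column converts to int64 when every value parses
ints = pd.Series([1, 2, 3], dtype=object)
assert pd.to_numeric(ints, errors="coerce").dtype == np.int64

# but a single NaN forces the result to float64, since numpy integer dtypes
# cannot represent missing values
with_nan = pd.Series([1, 2, np.nan], dtype=object)
assert pd.to_numeric(with_nan, errors="coerce").dtype == np.float64
```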



def test_compute_quality_metrics_recordingless(sorting_analyzer_simple):

sorting_analyzer = sorting_analyzer_simple