Skip to content

Commit

Permalink
oups
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelgarcia committed Nov 22, 2024
1 parent 5cf6ace commit c16ca72
Showing 1 changed file with 18 additions and 11 deletions.
29 changes: 18 additions & 11 deletions src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,25 @@ def test_calculate_pc_metrics(small_sorting_analyzer):
res2 = pd.DataFrame(res2)

for metric_name in res1.columns:
if metric_name != "nn_unit_id":
assert not np.all(np.isnan(res1[metric_name].values))
assert not np.all(np.isnan(res2[metric_name].values))

# import matplotlib.pyplot as plt
# fig, ax = plt.subplots()
# ax.plot(res1[metric_name].values)
# ax.plot(res2[metric_name].values)
# ax.plot(res2[metric_name].values - res1[metric_name].values)
# plt.show()
values1 = res1[metric_name].values
values2 = res1[metric_name].values

np.testing.assert_almost_equal(res1[metric_name].values, res2[metric_name].values, decimal=4)
if metric_name != "nn_unit_id":
assert not np.all(np.isnan(values1))
assert not np.all(np.isnan(values2))

if values1.dtype.kind == "f":
np.testing.assert_almost_equal(values1, values2, decimal=4)
# import matplotlib.pyplot as plt
# fig, axs = plt.subplots(nrows=2, share=True)
# ax =a xs[0]
# ax.plot(res1[metric_name].values)
# ax.plot(res2[metric_name].values)
# ax =a xs[1]
# ax.plot(res2[metric_name].values - res1[metric_name].values)
# plt.show()
else:
assert np.array_equal(values1, values2)


def test_pca_metrics_multi_processing(small_sorting_analyzer):
Expand Down

0 comments on commit c16ca72

Please sign in to comment.