From 14e9a13efeb53a66af24c73444718d5d2d64d866 Mon Sep 17 00:00:00 2001 From: gmartinonQM Date: Fri, 21 Jul 2023 18:28:23 +0200 Subject: [PATCH] FIX linting --- mapie/metrics.py | 6 +++++- mapie/tests/test_metrics.py | 3 ++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/mapie/metrics.py b/mapie/metrics.py index 85e3d1807..2f048fa40 100644 --- a/mapie/metrics.py +++ b/mapie/metrics.py @@ -868,7 +868,11 @@ def cumulative_differences( y_true = cast(NDArray, column_or_1d(y_true)) y_score = cast(NDArray, column_or_1d(y_score)) n = len(y_true) - y_score_jittered = jitter(y_score, noise_amplitude=noise_amplitude, random_state=random_state) + y_score_jittered = jitter( + y_score, + noise_amplitude=noise_amplitude, + random_state=random_state + ) y_true_sorted, y_score_sorted = sort_xy_by_y(y_true, y_score_jittered) cumulative_differences = np.cumsum(y_true_sorted - y_score_sorted)/n return cumulative_differences diff --git a/mapie/tests/test_metrics.py b/mapie/tests/test_metrics.py index 811eae419..0ccfda75e 100644 --- a/mapie/tests/test_metrics.py +++ b/mapie/tests/test_metrics.py @@ -662,10 +662,11 @@ def test_kuiper_statistic() -> None: ku_stat = kuiper_statistic(y_true, y_score) np.testing.assert_allclose(ku_stat, 5.354395, atol=1e-6) + def test_spiegelhalter_statistic() -> None: """Test that Spiegelhalter's statistics are well computed""" generator = RandomState(1) y_true = generator.choice([0, 1], size=100) y_score = generator.uniform(size=100) sp_stat = spiegelhalter_statistic(y_true, y_score) - np.testing.assert_allclose(sp_stat, 13.906833, atol=1e-6) \ No newline at end of file + np.testing.assert_allclose(sp_stat, 13.906833, atol=1e-6)