Skip to content

Commit

Permalink
FIX linting
Browse files Browse the repository at this point in the history
  • Loading branch information
gmartinonQM committed Jul 21, 2023
1 parent b5ef581 commit 14e9a13
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 2 deletions.
6 changes: 5 additions & 1 deletion mapie/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -868,7 +868,11 @@ def cumulative_differences(
y_true = cast(NDArray, column_or_1d(y_true))
y_score = cast(NDArray, column_or_1d(y_score))
n = len(y_true)
y_score_jittered = jitter(y_score, noise_amplitude=noise_amplitude, random_state=random_state)
y_score_jittered = jitter(
y_score,
noise_amplitude=noise_amplitude,
random_state=random_state
)
y_true_sorted, y_score_sorted = sort_xy_by_y(y_true, y_score_jittered)
cumulative_differences = np.cumsum(y_true_sorted - y_score_sorted)/n
return cumulative_differences
Expand Down
3 changes: 2 additions & 1 deletion mapie/tests/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -662,10 +662,11 @@ def test_kuiper_statistic() -> None:
ku_stat = kuiper_statistic(y_true, y_score)
np.testing.assert_allclose(ku_stat, 5.354395, atol=1e-6)


def test_spiegelhalter_statistic() -> None:
"""Test that Spiegelhalter's statistics are well computed"""
generator = RandomState(1)
y_true = generator.choice([0, 1], size=100)
y_score = generator.uniform(size=100)
sp_stat = spiegelhalter_statistic(y_true, y_score)
np.testing.assert_allclose(sp_stat, 13.906833, atol=1e-6)
np.testing.assert_allclose(sp_stat, 13.906833, atol=1e-6)

0 comments on commit 14e9a13

Please sign in to comment.