Skip to content

Commit

Permalink
ENH: add random state to test and lint change
Browse files Browse the repository at this point in the history
  • Loading branch information
SZiane committed Jul 10, 2023
1 parent 30dcb4b commit 5a9ffe4
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 7 deletions.
2 changes: 1 addition & 1 deletion mapie/control_risk/ltt.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def _ltt_procedure(
Apply the Learn-Then-Test procedure for risk control.
This procedure is called in ``MapieMultiLabelClassifier``
if ``metric=precision``.
This will apply learn then test procedure for
This will apply learn then test procedure for
precision control.
Note that we will do a multiple test for ``r_hat`` that are
less than level ``alpha_np``.
Expand Down
4 changes: 1 addition & 3 deletions mapie/multi_label_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
import warnings
from typing import Iterable, Optional, Sequence, Tuple, Union, cast

import itertools

import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.linear_model import LogisticRegression
Expand Down Expand Up @@ -99,7 +97,7 @@ class MapieMultiLabelClassifier(BaseEstimator, ClassifierMixin):
lambdas_star: ArrayLike of shape (n_lambdas)
Optimal threshold for a given alpha.
valid_index: List[List[Any]]
List of list of all index that satisfy fwer controlling
for learn then test procedure. This attribute is compute
Expand Down
12 changes: 10 additions & 2 deletions mapie/tests/test_control_risk.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,12 @@

wrong_delta = None

random_state = 42
prng = np.random.RandomState(random_state)
y_1 = prng.random(51)
y_2 = prng.random((51, 5))
y_3 = prng.randint(0, 2, 51)


def test_compute_recall_equal() -> None:
"""Test that compute_recall give good result"""
Expand Down Expand Up @@ -122,7 +128,9 @@ def test_compute_precision_with_wrong_shape() -> None:
with pytest.raises(ValueError, match=r".*y_pred_proba should be a 3d*"):
_compute_risk_precision(lambdas, y_preds_proba.squeeze(), y_toy)
with pytest.raises(ValueError, match=r".*y should be a 2d*"):
_compute_risk_precision(lambdas, y_preds_proba, np.expand_dims(y_toy, 2))
_compute_risk_precision(
lambdas, y_preds_proba, np.expand_dims(y_toy, 2)
)
with pytest.raises(ValueError, match=r".*could not be broadcast*"):
_compute_risk_precision(lambdas, y_preds_proba, y_toy[:-1])

Expand Down Expand Up @@ -188,4 +196,4 @@ def test_invalid_shape_alpha_hb() -> None:
def test_delta_none_ltt(delta: Optional[float]) -> None:
"""Test error message when invalid delta"""
with pytest.raises(ValueError, match=r".*Invalid delta"):
_ltt_procedure(r_hat, alpha, delta, n)
_ltt_procedure(r_hat, alpha, delta, n)
2 changes: 1 addition & 1 deletion mapie/tests/test_multi_label_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ def test_results_for_partial_fit(strategy: str) -> None:

@pytest.mark.parametrize("strategy", [*STRATEGIES])
@pytest.mark.parametrize(
"alpha", [np.array([0.3, 0.4]), [0.3, 0.4], (0.3, 0.4)]
"alpha", [np.array([0.05, 0.1]), [0.05, 0.1], (0.05, 0.1)]
)
def test_results_for_alpha_as_float_and_arraylike(
strategy: str, alpha: Any
Expand Down

0 comments on commit 5a9ffe4

Please sign in to comment.