diff --git a/mapie/control_risk/ltt.py b/mapie/control_risk/ltt.py index 6fc4a0a6..216c24f6 100644 --- a/mapie/control_risk/ltt.py +++ b/mapie/control_risk/ltt.py @@ -18,7 +18,7 @@ def _ltt_procedure( Apply the Learn-Then-Test procedure for risk control. This procedure is called in ``MapieMultiLabelClassifier`` if ``metric=precision``. - This will apply learn then test procedure for + This will apply learn then test procedure for precision control. Note that we will do a multiple test for ``r_hat`` that are less than level ``alpha_np``. diff --git a/mapie/multi_label_classification.py b/mapie/multi_label_classification.py index 9908cf14..f9d53966 100644 --- a/mapie/multi_label_classification.py +++ b/mapie/multi_label_classification.py @@ -3,8 +3,6 @@ import warnings from typing import Iterable, Optional, Sequence, Tuple, Union, cast -import itertools - import numpy as np from sklearn.base import BaseEstimator, ClassifierMixin from sklearn.linear_model import LogisticRegression @@ -99,7 +97,7 @@ class MapieMultiLabelClassifier(BaseEstimator, ClassifierMixin): lambdas_star: ArrayLike of shape (n_lambdas) Optimal threshold for a given alpha. - + valid_index: List[List[Any]] List of list of all index that satisfy fwer controlling for learn then test procedure. This attribute is compute diff --git a/mapie/tests/test_control_risk.py b/mapie/tests/test_control_risk.py index 79c9174d..481035ea 100644 --- a/mapie/tests/test_control_risk.py +++ b/mapie/tests/test_control_risk.py @@ -66,6 +66,12 @@ wrong_delta = None +random_state = 42 +prng = np.random.RandomState(random_state) +y_1 = prng.random(51) +y_2 = prng.random((51, 5)) +y_3 = prng.randint(0, 2, 51) + def test_compute_recall_equal() -> None: """Test that compute_recall give good result""" @@ -122,7 +128,9 @@ def test_compute_precision_with_wrong_shape() -> None: with pytest.raises(ValueError, match=r".*y_pred_proba should be a 3d*"): _compute_risk_precision(lambdas, y_preds_proba.squeeze(), y_toy) with pytest.raises(ValueError, match=r".*y should be a 2d*"): - _compute_risk_precision(lambdas, y_preds_proba, np.expand_dims(y_toy, 2)) + _compute_risk_precision( + lambdas, y_preds_proba, np.expand_dims(y_toy, 2) + ) with pytest.raises(ValueError, match=r".*could not be broadcast*"): _compute_risk_precision(lambdas, y_preds_proba, y_toy[:-1]) @@ -188,4 +196,4 @@ def test_invalid_shape_alpha_hb() -> None: def test_delta_none_ltt(delta: Optional[float]) -> None: """Test error message when invalid delta""" with pytest.raises(ValueError, match=r".*Invalid delta"): - _ltt_procedure(r_hat, alpha, delta, n) \ No newline at end of file + _ltt_procedure(r_hat, alpha, delta, n) diff --git a/mapie/tests/test_multi_label_classification.py b/mapie/tests/test_multi_label_classification.py index 1d40c675..39f96623 100644 --- a/mapie/tests/test_multi_label_classification.py +++ b/mapie/tests/test_multi_label_classification.py @@ -324,7 +324,7 @@ def test_results_for_partial_fit(strategy: str) -> None: @pytest.mark.parametrize("strategy", [*STRATEGIES]) @pytest.mark.parametrize( - "alpha", [np.array([0.3, 0.4]), [0.3, 0.4], (0.3, 0.4)] + "alpha", [np.array([0.05, 0.1]), [0.05, 0.1], (0.05, 0.1)] ) def test_results_for_alpha_as_float_and_arraylike( strategy: str, alpha: Any