From 97edf2f7edb254264e6dfd47d86fbde08e51d514 Mon Sep 17 00:00:00 2001 From: sofiane Date: Tue, 11 Jul 2023 11:31:25 +0200 Subject: [PATCH] ENH: Add contribution/ random state --- AUTHORS.rst | 1 + HISTORY.rst | 8 ++++++- ..._description_multilabel_classification.rst | 2 +- mapie/control_risk/risks.py | 2 +- .../tests/test_multi_label_classification.py | 22 +++++++++++-------- 5 files changed, 23 insertions(+), 12 deletions(-) diff --git a/AUTHORS.rst b/AUTHORS.rst index 17763372..2d9f0403 100644 --- a/AUTHORS.rst +++ b/AUTHORS.rst @@ -30,5 +30,6 @@ Contributors * Arnaud Capitaine * Tarik Tazi * Daniel Herbst +* Sofiane Ziane To be continued ... diff --git a/HISTORY.rst b/HISTORY.rst index 37ce7df6..e09ae0d2 100644 --- a/HISTORY.rst +++ b/HISTORY.rst @@ -10,6 +10,12 @@ History * Add split conformal option for regression and classification * Update check method for calibration + +0.6.4 (2023-07-11) +------------------ + +* Add LTT for multilabel-classification + 0.6.4 (2023-04-05) ------------------ @@ -34,7 +40,7 @@ History 0.6.0 (2023-01-19) ------------------ -* Add RCPS and CRC for multilabel-classifcation +* Add RCPS and CRC for multilabel-classification * Add Top-Label calibration * Fix bug for classification with very low scores diff --git a/doc/theoretical_description_multilabel_classification.rst b/doc/theoretical_description_multilabel_classification.rst index f8aff6a2..ed2bf1b8 100644 --- a/doc/theoretical_description_multilabel_classification.rst +++ b/doc/theoretical_description_multilabel_classification.rst @@ -32,7 +32,7 @@ on the recall. RCPS, LTT and CRC give three slightly different guarantees: .. math:: \mathbb{P}(R(\mathcal{T}_{\lambda_{\lambda\in\hat{\Lambda}}) \leq \alpha ) \geq 1 - \delta -Notice that at the opposite of the other two methods, LTT allows to control any non-monotone loss. In Mapie for multilabel classification, +Notice that, in contrast to the other two methods, LTT allows controlling any non-monotone loss. 
In MAPIE for multilabel classification, we use CRC and RCPS for recall control and LTT for precision control. 1. Risk-Controlling Prediction Sets diff --git a/mapie/control_risk/risks.py b/mapie/control_risk/risks.py index f396c7d2..e4f13fc9 100644 --- a/mapie/control_risk/risks.py +++ b/mapie/control_risk/risks.py @@ -12,7 +12,7 @@ def _compute_risk_recall( y: NDArray ) -> NDArray: """ - In `MapieMultiLabelClassifier` when`metric_control=recall`, + In `MapieMultiLabelClassifier` when `metric_control=recall`, compute the recall per observation for each different thresholds lambdas. diff --git a/mapie/tests/test_multi_label_classification.py b/mapie/tests/test_multi_label_classification.py index 39f96623..401ecf3f 100644 --- a/mapie/tests/test_multi_label_classification.py +++ b/mapie/tests/test_multi_label_classification.py @@ -29,6 +29,7 @@ METHODS = ["crc", "rcps", "ltt"] METRICS = ['recall', 'precision'] BOUNDS = ["wsr", "hoeffding", "bernstein"] +random_state = 42 WRONG_METHODS = ["rpcs", "rcr", "test", "llt"] WRONG_BOUNDS = ["wrs", "hoeff", "test", "", 1, 2.5, (1, 2)] @@ -40,7 +41,7 @@ Params( method="crc", bound=None, - random_state=42, + random_state=random_state, metric_control="recall" ), ), @@ -48,7 +49,7 @@ Params( method="rcps", bound="wsr", - random_state=42, + random_state=random_state, metric_control='recall' ), ), @@ -56,7 +57,7 @@ Params( method="rcps", bound="hoeffding", - random_state=42, + random_state=random_state, metric_control='recall' ), ), @@ -64,7 +65,7 @@ Params( method="rcps", bound="bernstein", - random_state=42, + random_state=random_state, metric_control='recall' ), ), @@ -72,7 +73,7 @@ Params( method="ltt", bound=None, - random_state=42, + random_state=random_state, metric_control='precision' ), ), @@ -218,9 +219,10 @@ def test_valid_method() -> None: def test_valid_metric_method(strategy: str) -> None: """Test that valid metric raise no errors""" args = STRATEGIES[strategy][0] - mapie_clf = 
MapieMultiLabelClassifier(random_state=42, - metric_control=args["metric_control"] - ) + mapie_clf = MapieMultiLabelClassifier( + random_state=42, + metric_control=args["metric_control"] + ) mapie_clf.fit(X_toy, y_toy) check_is_fitted(mapie_clf, mapie_clf.fit_attributes) @@ -228,7 +230,9 @@ def test_valid_metric_method(strategy: str) -> None: @pytest.mark.parametrize("bound", BOUNDS) def test_valid_bound(bound: str) -> None: """Test that valid methods raise no errors.""" - mapie_clf = MapieMultiLabelClassifier(random_state=42, method="rcps") + mapie_clf = MapieMultiLabelClassifier( + random_state=random_state, method="rcps" + ) mapie_clf.fit(X_toy, y_toy) mapie_clf.predict(X_toy, bound=bound, delta=.1) check_is_fitted(mapie_clf, mapie_clf.fit_attributes)