Skip to content

Commit

Permalink
ENH: Add contribution/ random state
Browse files Browse the repository at this point in the history
  • Loading branch information
SZiane committed Jul 11, 2023
1 parent e7078ca commit 97edf2f
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 12 deletions.
1 change: 1 addition & 0 deletions AUTHORS.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,5 +30,6 @@ Contributors
* Arnaud Capitaine <[email protected]>
* Tarik Tazi <[email protected]>
* Daniel Herbst <[email protected]>
* Sofiane Ziane <[email protected]>

To be continued ...
8 changes: 7 additions & 1 deletion HISTORY.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,12 @@ History
* Add split conformal option for regression and classification
* Update check method for calibration


0.6.4(2023-07-11)
------------------

* Add LTT for multilabel-classification

0.6.4 (2023-04-05)
------------------

Expand All @@ -34,7 +40,7 @@ History
0.6.0 (2023-01-19)
------------------

* Add RCPS and CRC for multilabel-classifcation
* Add RCPS and CRC for multilabel-classification
* Add Top-Label calibration
* Fix bug for classification with very low scores

Expand Down
2 changes: 1 addition & 1 deletion doc/theoretical_description_multilabel_classification.rst
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ on the recall. RCPS, LTT and CRC give three slightly different guarantees:
.. math::
\mathbb{P}(R(\mathcal{T}_{\lambda_{\lambda\in\hat{\Lambda}}) \leq \alpha ) \geq 1 - \delta
Notice that at the opposite of the other two methods, LTT allows to control any non-monotone loss. In Mapie for multilabel classification,
Notice that at the opposite of the other two methods, LTT allows to control any non-monotone loss. In MAPIE for multilabel classification,
we use CRC and RCPS for recall control and LTT for precision control.

1. Risk-Controlling Prediction Sets
Expand Down
2 changes: 1 addition & 1 deletion mapie/control_risk/risks.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def _compute_risk_recall(
y: NDArray
) -> NDArray:
"""
In `MapieMultiLabelClassifier` when`metric_control=recall`,
In `MapieMultiLabelClassifier` when `metric_control=recall`,
compute the recall per observation for each different
thresholds lambdas.
Expand Down
22 changes: 13 additions & 9 deletions mapie/tests/test_multi_label_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
METHODS = ["crc", "rcps", "ltt"]
METRICS = ['recall', 'precision']
BOUNDS = ["wsr", "hoeffding", "bernstein"]
random_state = 42

WRONG_METHODS = ["rpcs", "rcr", "test", "llt"]
WRONG_BOUNDS = ["wrs", "hoeff", "test", "", 1, 2.5, (1, 2)]
Expand All @@ -40,39 +41,39 @@
Params(
method="crc",
bound=None,
random_state=42,
random_state=random_state,
metric_control="recall"
),
),
"rcps_wsr": (
Params(
method="rcps",
bound="wsr",
random_state=42,
random_state=random_state,
metric_control='recall'
),
),
"rcps_hoeffding": (
Params(
method="rcps",
bound="hoeffding",
random_state=42,
random_state=random_state,
metric_control='recall'
),
),
"rcps_bernstein": (
Params(
method="rcps",
bound="bernstein",
random_state=42,
random_state=random_state,
metric_control='recall'
),
),
"ltt": (
Params(
method="ltt",
bound=None,
random_state=42,
random_state=random_state,
metric_control='precision'
),
),
Expand Down Expand Up @@ -218,17 +219,20 @@ def test_valid_method() -> None:
def test_valid_metric_method(strategy: str) -> None:
"""Test that valid metric raise no errors"""
args = STRATEGIES[strategy][0]
mapie_clf = MapieMultiLabelClassifier(random_state=42,
metric_control=args["metric_control"]
)
mapie_clf = MapieMultiLabelClassifier(
random_state=42,
metric_control=args["metric_control"]
)
mapie_clf.fit(X_toy, y_toy)
check_is_fitted(mapie_clf, mapie_clf.fit_attributes)


@pytest.mark.parametrize("bound", BOUNDS)
def test_valid_bound(bound: str) -> None:
"""Test that valid methods raise no errors."""
mapie_clf = MapieMultiLabelClassifier(random_state=42, method="rcps")
mapie_clf = MapieMultiLabelClassifier(
random_state=random_state, method="rcps"
)
mapie_clf.fit(X_toy, y_toy)
mapie_clf.predict(X_toy, bound=bound, delta=.1)
check_is_fitted(mapie_clf, mapie_clf.fit_attributes)
Expand Down

0 comments on commit 97edf2f

Please sign in to comment.