Skip to content

Commit

Permalink
UPD: docstring/whitespace/blank line
Browse files Browse the repository at this point in the history
  • Loading branch information
SZiane committed Jul 20, 2023
1 parent c01878f commit 466cbd2
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 15 deletions.
12 changes: 5 additions & 7 deletions mapie/control_risk/crc_rcps.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Optional, Tuple


def _get_r_hat_plus(
def get_r_hat_plus(
risks: NDArray,
lambdas: NDArray,
method: Optional[str],
Expand All @@ -16,7 +16,6 @@ def _get_r_hat_plus(
Parameters
----------
risks: ArrayLike of shape (n_samples_cal, n_lambdas)
The risk for each observation for each threshold
Expand All @@ -28,16 +27,15 @@ def _get_r_hat_plus(
Correspond to the method use to control recall
score. Could be either CRC or RCPS.
bound: str
bound: Optional[str]
Bounds to compute. Either hoeffding, bernstein or wsr.
delta: float
delta: Optional[float]
Level of confidence.
sigma_init : float, optional
sigma_init : Optional[float]
First variance in the sigma_hat array. The default
value is the same as in the paper implementation [1].
By default .25
Returns
-------
Expand Down Expand Up @@ -150,7 +148,7 @@ def _get_r_hat_plus(
return r_hat, r_hat_plus


def _find_lambda_star(
def find_lambda_star(
lambdas: NDArray,
r_hat_plus: NDArray,
alpha_np: NDArray
Expand Down
3 changes: 0 additions & 3 deletions mapie/control_risk/ltt.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,8 @@ def ltt_procedure(
) -> Tuple[List[List[Any]], NDArray]:
"""
Apply the Learn-Then-Test procedure for risk control.
This will apply learn then test procedure for
risk control.
Note that we will do a multiple test for ``r_hat`` that are
less than level ``alpha_np``.
The procedure follows the instructions in [1]:
- Calculate p-values for each lambdas descretized
- Apply a family wise error rate algorithm,
Expand Down
14 changes: 9 additions & 5 deletions mapie/multi_label_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@

from .control_risk.ltt import find_lambda_control_star, ltt_procedure
from .control_risk.risks import compute_risk_precision, compute_risk_recall
from .control_risk.crc_rcps import _get_r_hat_plus
from .control_risk.crc_rcps import _find_lambda_star
from .control_risk.crc_rcps import get_r_hat_plus
from .control_risk.crc_rcps import find_lambda_star


class MapieMultiLabelClassifier(BaseEstimator, ClassifierMixin):
Expand Down Expand Up @@ -122,6 +122,10 @@ class MapieMultiLabelClassifier(BaseEstimator, ClassifierMixin):
[2] Angelopoulos, Anastasios N., Stephen, Bates, Adam, Fisch, Lihua,
Lei, and Tal, Schuster. "Conformal Risk Control." (2022).
[3] Angelopoulos, A. N., Bates, S., Candès, E. J., Jordan,
M. I., & Lei, L. (2021). Learn then test:
"Calibrating predictive algorithms to achieve risk control".
Examples
--------
>>> import numpy as np
Expand Down Expand Up @@ -152,7 +156,7 @@ class MapieMultiLabelClassifier(BaseEstimator, ClassifierMixin):
"single_estimator_",
"risks"
]
sigma_init = 0.25
sigma_init = 0.25 # Value given in the paper [1]
cal_size = .3

def __init__(
Expand Down Expand Up @@ -680,11 +684,11 @@ def predict(
)

else:
self.r_hat, self.r_hat_plus = _get_r_hat_plus(
self.r_hat, self.r_hat_plus = get_r_hat_plus(
self.risks, self.lambdas, self.method,
bound, delta, self.sigma_init
)
self.lambdas_star = _find_lambda_star(
self.lambdas_star = find_lambda_star(
self.lambdas, self.r_hat_plus, alpha_np
)
y_pred_proba_array = (
Expand Down

0 comments on commit 466cbd2

Please sign in to comment.