Skip to content

Commit

Permalink
Merge branch 'master' into 310-learn-then-test
Browse files Browse the repository at this point in the history
  • Loading branch information
SZiane authored Jul 24, 2023
2 parents 41ad84e + cf93146 commit 3e7c5f2
Show file tree
Hide file tree
Showing 13 changed files with 1,000 additions and 381 deletions.
2 changes: 2 additions & 0 deletions HISTORY.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ History
------------------

* Add Learn-Then-Test method for multilabel-classification
* Refactor MapieRegressor and ConformityScore to add the possibility to use X in ConformityScore.
* Separate the handling of the estimator from MapieRegressor into a new class called EnsembleEstimator.
* Fix an unfixed random state in one of the classification tests

0.6.5 (2023-06-06)
Expand Down
18 changes: 11 additions & 7 deletions examples/regression/1-quickstart/plot_compare_conformity_scores.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
from mapie.metrics import regression_coverage_score
from mapie.regression import MapieRegressor

np.random.seed(0)
random_state = 42

# Parameters
features = [
Expand All @@ -50,7 +50,7 @@
"GarageArea",
]
alpha = 0.05
rf_kwargs = {"n_estimators": 10, "random_state": 0}
rf_kwargs = {"n_estimators": 10, "random_state": random_state}
model = RandomForestRegressor(**rf_kwargs)

##############################################################################
Expand All @@ -66,7 +66,7 @@
X, y = fetch_openml(name="house_prices", return_X_y=True)

X_train, X_test, y_train, y_test = train_test_split(
X[features], y, test_size=0.2
X[features], y, test_size=0.2, random_state=random_state
)

##############################################################################
Expand All @@ -87,9 +87,11 @@
##############################################################################
# First, train model with
# :class:`~mapie.conformity_scores.AbsoluteConformityScore`.
mapie = MapieRegressor(model)
mapie = MapieRegressor(model, random_state=random_state)
mapie.fit(X_train, y_train)
y_pred_absconfscore, y_pis_absconfscore = mapie.predict(X_test, alpha=alpha)
y_pred_absconfscore, y_pis_absconfscore = mapie.predict(
X_test, alpha=alpha, ensemble=True
)

coverage_absconfscore = regression_coverage_score(
y_test, y_pis_absconfscore[:, 0, 0], y_pis_absconfscore[:, 1, 0]
Expand Down Expand Up @@ -118,10 +120,12 @@ def get_yerr(y_pred, y_pis):
##############################################################################
# Then, train the model with
# :class:`~mapie.conformity_scores.GammaConformityScore`.
mapie = MapieRegressor(model, conformity_score=GammaConformityScore())
mapie = MapieRegressor(
model, conformity_score=GammaConformityScore(), random_state=random_state
)
mapie.fit(X_train, y_train)
y_pred_gammaconfscore, y_pis_gammaconfscore = mapie.predict(
X_test, alpha=[alpha]
X_test, alpha=[alpha], ensemble=True
)

coverage_gammaconfscore = regression_coverage_score(
Expand Down
222 changes: 193 additions & 29 deletions mapie/conformity_scores/conformity_scores.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from abc import ABCMeta, abstractmethod

import numpy as np
from typing import Tuple

from mapie._compatibility import np_nanquantile
from mapie._typing import ArrayLike, NDArray
from mapie.estimator.interface import EnsembleEstimator


class ConformityScore(metaclass=ABCMeta):
Expand Down Expand Up @@ -31,7 +34,9 @@ class ConformityScore(metaclass=ABCMeta):
- ``get_signed_conformity_scores``
The following equality must be verified:
``self.get_estimation_distribution(
y_pred, self.get_conformity_scores(y, y_pred)
X,
y_pred,
self.get_conformity_scores(X, y, y_pred)
) == y``
It should be specified if ``consistency_check==True``.
Expand All @@ -51,6 +56,7 @@ def __init__(
@abstractmethod
def get_signed_conformity_scores(
self,
X: ArrayLike,
y: ArrayLike,
y_pred: ArrayLike,
) -> NDArray:
Expand All @@ -63,47 +69,62 @@ def get_signed_conformity_scores(
Parameters
----------
y: NDArray
Observed values.
X: ArrayLike of shape (n_samples, n_features)
Observed feature values.
y: ArrayLike of shape (n_samples,)
Observed target values.
y_pred: NDArray
Predicted values.
y_pred: ArrayLike of shape (n_samples,)
Predicted target values.
Returns
-------
NDArray
Unsigned conformity scores.
NDArray of shape (n_samples,)
Signed conformity scores.
"""

@abstractmethod
def get_estimation_distribution(
self,
X: ArrayLike,
y_pred: ArrayLike,
conformity_scores: ArrayLike,
conformity_scores: ArrayLike
) -> NDArray:
"""
Placeholder for ``get_estimation_distribution``.
Subclasses should implement this method!
Compute samples of the estimation distribution from the predicted
values and the conformity scores.
targets and ``conformity_scores`` that can be either the conformity
scores or the quantile of the conformity scores.
Parameters
----------
y_pred: NDArray
Predicted values.
X: ArrayLike of shape (n_samples, n_features)
Observed feature values.
conformity_scores: NDArray
Conformity scores.
y_pred: ArrayLike
The shape is either (n_samples, n_references): when the
method is called in ``get_bounds`` it needs a prediction per train
sample for each test sample to compute the bounds.
Or (n_samples,): when it is called in ``check_consistency``
conformity_scores: ArrayLike
The shape is either (n_samples, 1) when it is the
conformity scores themselves or (1, n_alpha) when it is only the
quantile of the conformity scores.
Returns
-------
NDArray
NDArray of shape (n_samples, n_alpha) or
(n_samples, n_references) according to the shape of ``y_pred``
Observed values.
"""

def check_consistency(
self,
X: ArrayLike,
y: ArrayLike,
y_pred: ArrayLike,
conformity_scores: ArrayLike,
Expand All @@ -114,24 +135,32 @@ def check_consistency(
The following equality should be verified:
``self.get_estimation_distribution(
y_pred, self.get_conformity_scores(y, y_pred)
X,
y_pred,
self.get_conformity_scores(X, y, y_pred)
) == y``
Parameters
----------
y: NDArray
Observed values.
X: ArrayLike of shape (n_samples, n_features)
Observed feature values.
y: ArrayLike of shape (n_samples,)
Observed target values.
y_pred: NDArray
Predicted values.
y_pred: ArrayLike of shape (n_samples,)
Predicted target values.
conformity_scores: ArrayLike of shape (n_samples,)
Conformity scores.
Raises
------
ValueError
If the two methods are not consistent.
"""
score_distribution = self.get_estimation_distribution(
y_pred, conformity_scores
X, y_pred, conformity_scores
)
abs_conformity_scores = np.abs(np.subtract(score_distribution, y))
max_conf_score = np.max(abs_conformity_scores)
Expand All @@ -141,15 +170,16 @@ def check_consistency(
"get_estimation_distribution of the ConformityScore class "
"are not consistent. "
"The following equation must be verified: "
"self.get_estimation_distribution(y_pred, "
"self.get_conformity_scores(y, y_pred)) == y. " # noqa: E501
"self.get_estimation_distribution(X, y_pred, "
"self.get_conformity_scores(X, y, y_pred)) == y" # noqa: E501
f"The maximum conformity score is {max_conf_score}."
"The eps attribute may need to be increased if you are "
"sure that the two methods are consistent."
)

def get_conformity_scores(
self,
X: ArrayLike,
y: ArrayLike,
y_pred: ArrayLike,
) -> NDArray:
Expand All @@ -158,20 +188,154 @@ def get_conformity_scores(
Parameters
----------
y: NDArray
Observed values.
X: NDArray of shape (n_samples, n_features)
Observed feature values.
y: NDArray of shape (n_samples,)
Observed target values.
y_pred: NDArray
Predicted values.
y_pred: NDArray of shape (n_samples,)
Predicted target values.
Returns
-------
NDArray
NDArray of shape (n_samples,)
Conformity scores.
"""
conformity_scores = self.get_signed_conformity_scores(y, y_pred)
conformity_scores = self.get_signed_conformity_scores(X, y, y_pred)
if self.consistency_check:
self.check_consistency(y, y_pred, conformity_scores)
self.check_consistency(X, y, y_pred, conformity_scores)
if self.sym:
conformity_scores = np.abs(conformity_scores)
return conformity_scores

@staticmethod
def get_quantile(
conformity_scores: NDArray,
alpha_np: NDArray,
axis: int,
method: str
) -> NDArray:
"""
Compute the alpha quantile of the conformity scores or the conformity
scores aggregated with the predictions.
Parameters
----------
conformity_scores: NDArray of shape (n_samples,) or
(n_samples, n_references)
Values from which the quantile is computed, it can be the
conformity scores or the conformity scores aggregated with
the predictions.
alpha_np: NDArray of shape (n_alpha,)
NDArray of floats between ``0`` and ``1``, represents the
uncertainty of the confidence interval.
axis: int
The axis from which to compute the quantile.
method: str
``"higher"`` or ``"lower"`` the method to compute the quantile.
Returns
-------
NDArray of shape (1, n_alpha) or (n_samples, n_alpha)
The quantile of the conformity scores.
"""
quantile = np.column_stack([
np_nanquantile(
conformity_scores.astype(float),
_alpha,
axis=axis,
method=method
)
for _alpha in alpha_np
])
return quantile

def get_bounds(
self,
X: ArrayLike,
estimator: EnsembleEstimator,
conformity_scores: NDArray,
alpha_np: NDArray,
ensemble: bool,
method: str
) -> Tuple[NDArray, NDArray, NDArray]:
"""
Compute bounds of the prediction intervals from the observed values,
the estimator of type ``EnsembleEstimator`` and the conformity scores.
Parameters
----------
X: ArrayLike of shape (n_samples, n_features)
Observed feature values.
estimator: EnsembleEstimator
Estimator that is fitted to predict y from X.
conformity_scores: ArrayLike of shape (n_samples,)
Conformity scores.
alpha_np: NDArray of shape (n_alpha,)
NDArray of floats between ``0`` and ``1``, represents the
uncertainty of the confidence interval.
ensemble: bool
Boolean determining whether the predictions are ensembled or not.
method: str
Method to choose for prediction interval estimates.
The ``"plus"`` method implies that the quantile is calculated
after estimating the bounds, whereas the other methods
(among the ``"naive"``, ``"base"`` or ``"minmax"`` methods,
for example) do the opposite.
Returns
-------
Tuple[NDArray, NDArray, NDArray]
- The predictions itself. (y_pred) of shape (n_samples,).
- The lower bounds of the prediction intervals of shape
(n_samples, n_alpha).
- The upper bounds of the prediction intervals of shape
(n_samples, n_alpha).
"""
y_pred, y_pred_low, y_pred_up = estimator.predict(X, ensemble)
signed = -1 if self.sym else 1

if method == "plus":
alpha_low = alpha_np if self.sym else alpha_np / 2
alpha_up = 1 - alpha_np if self.sym else 1 - alpha_np / 2

conformity_scores_low = self.get_estimation_distribution(
X, y_pred_low, signed * conformity_scores
)
conformity_scores_up = self.get_estimation_distribution(
X, y_pred_up, conformity_scores
)
bound_low = self.get_quantile(
conformity_scores_low, alpha_low, axis=1, method="lower"
)
bound_up = self.get_quantile(
conformity_scores_up, alpha_up, axis=1, method="higher"
)
else:
quantile_search = "higher" if self.sym else "lower"
alpha_low = 1 - alpha_np if self.sym else alpha_np / 2
alpha_up = 1 - alpha_np if self.sym else 1 - alpha_np / 2

quantile_low = self.get_quantile(
conformity_scores, alpha_low, axis=0, method=quantile_search
)
quantile_up = self.get_quantile(
conformity_scores, alpha_up, axis=0, method="higher"
)
bound_low = self.get_estimation_distribution(
X, y_pred_low, signed * quantile_low
)
bound_up = self.get_estimation_distribution(
X, y_pred_up, quantile_up
)

return y_pred, bound_low, bound_up
Loading

0 comments on commit 3e7c5f2

Please sign in to comment.